diff --git a/run_squad.py b/run_squad.py index edd4c3ed9..20d862fe3 100644 --- a/run_squad.py +++ b/run_squad.py @@ -903,8 +903,11 @@ def write_predictions(all_examples, all_features, all_results, n_best_size, all_predictions[example.qas_id] = nbest_json[0]["text"] else: # predict "" iff the null score - the score of best non-null > threshold - score_diff = score_null - best_non_null_entry.start_logit - ( - best_non_null_entry.end_logit) + try: + score_diff = score_null - best_non_null_entry.start_logit - (best_non_null_entry.end_logit) + except: + score_diff = score_null + scores_diff_json[example.qas_id] = score_diff if score_diff > FLAGS.null_score_diff_threshold: all_predictions[example.qas_id] = ""