diff --git a/lm_eval/api/metrics.py b/lm_eval/api/metrics.py index 0f142f750b4..0387d452402 100644 --- a/lm_eval/api/metrics.py +++ b/lm_eval/api/metrics.py @@ -411,12 +411,13 @@ def acc_all_stderr(items): docs = list(zip(*items))[1] for doc, pred in zip(docs, preds): + paragraph_id = doc["idx"]["paragraph"] question_id = doc["idx"]["question"] - if question_id not in question_scoring_dict: - question_scoring_dict[question_id] = [] + if (paragraph_id, question_id) not in question_scoring_dict: + question_scoring_dict[(paragraph_id, question_id)] = [] gold_label = doc["label"] == 1 - question_scoring_dict[question_id].append(gold_label == pred) + question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred) acc = mean_stderr([int(all(x)) for x in question_scoring_dict.values()]) return acc