"""Utilities for the CoQA task (lm-evaluation-harness).

Each CoQA document is a full conversation grounded in a story.
``process_docs`` expands every conversation into one instance per question
turn, ``doc_to_text`` renders the prompt for a turn, and ``doc_to_target``
collects the reference answer(s) for that turn.

NOTE(review): this module was reconstructed from a patch hunk.  The same
patch also updates the task YAML (adds ``process_docs: !function
utils.process_docs``, switches the decontamination query to join the
per-turn ``questions`` list, and bumps ``version`` to 4.0), and the
original file imports ``transformers.data.metrics.squad_metrics`` for a
``process_results`` helper that lies outside the visible hunk — confirm
both against the repository before merging.
"""


def process_docs(dataset):
    """Expand each CoQA conversation into one instance per turn.

    Args:
        dataset: a ``datasets.Dataset`` whose rows carry ``story`` (str),
            ``questions``/``answers`` (dicts with an ``input_text`` list,
            one entry per turn) and optionally ``additional_answers``
            (dict of alternative answer sets, each with an ``input_text``
            list) — assumed schema of ``EleutherAI/coqa``; TODO confirm.

    Returns:
        A new ``datasets.Dataset`` with one row per (conversation, turn):
        ``story`` (str), ``questions``/``answers`` (the history up to and
        including this turn), ``additional_answers`` (dict mapping
        alternative-set key -> answer string for this turn) and
        ``turn_id`` (0-based turn index).
    """
    # Imported lazily so the module stays importable without ``datasets``.
    from datasets import Dataset

    columns = {
        "story": [],
        "questions": [],
        "answers": [],
        "additional_answers": [],
        "turn_id": [],
    }
    for doc in dataset:
        story = doc["story"]
        questions = doc["questions"]["input_text"]
        answers = doc["answers"]["input_text"]
        additional = doc.get("additional_answers") or {}
        for turn_idx in range(len(questions)):
            columns["story"].append(story)
            # History up to and *including* this turn; doc_to_text leaves
            # the final answer out when building the prompt.
            columns["questions"].append(questions[: turn_idx + 1])
            columns["answers"].append(answers[: turn_idx + 1])
            columns["turn_id"].append(turn_idx)
            # Alternative answers for this turn only.  Guard the index:
            # some alternative sets can be shorter than the turn count
            # (unguarded indexing raised IndexError before).
            turn_additional = {}
            for key, value in additional.items():
                texts = value.get("input_text") or []
                if turn_idx < len(texts):
                    turn_additional[key] = texts[turn_idx]
            columns["additional_answers"].append(turn_additional)
    return Dataset.from_dict(columns)


def doc_to_text(doc):
    """Render the prompt for one turn.

    Given a passage p, the conversation history {q1, a1, ..., qi-1, ai-1}
    and a question qi, the task is to predict the answer ai: the story is
    followed by every answered Q/A pair, then the open question with a
    bare ``A:`` for the model to complete.
    """
    questions = doc["questions"]
    answers = doc["answers"]
    parts = [doc["story"], "\n\n"]
    # All turns but the last are shown fully answered.  ``answers`` has one
    # entry per question, so zipping against questions[:-1] drops exactly
    # the final (target) answer.
    for question, answer in zip(questions[:-1], answers):
        parts.append(f"Q: {question}\n\nA: {answer}\n\n")
    # The final question is left unanswered.
    parts.append(f"Q: {questions[-1]}\n\nA:")
    return "".join(parts)


def doc_to_target(doc):
    """Return the reference answer plus any distinct valid alternatives.

    Some CoQA questions have multiple valid answers; alternatives that are
    empty or case-insensitive duplicates of an already-collected answer
    are skipped.
    """
    # The target is the last answer in this turn's history.
    answers = [doc["answers"][-1]]
    additional_answers = doc.get("additional_answers") or {}
    for value in additional_answers.values():
        if value and value.lower() not in map(str.lower, answers):
            answers.append(value)
    return answers