Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions lm_eval/tasks/coqa/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ dataset_path: EleutherAI/coqa
output_type: generate_until
training_split: train
validation_split: validation
process_docs: !function utils.process_docs
doc_to_text: !function utils.doc_to_text
doc_to_target: !function utils.doc_to_target
process_results: !function utils.process_results
should_decontaminate: true
doc_to_decontamination_query: "{{story}} {{questions|join('\n')}}"
generation_kwargs:
until:
- "\nQ:"
Expand All @@ -19,4 +20,4 @@ metric_list:
aggregation: mean
higher_is_better: true
metadata:
version: 4.0
94 changes: 77 additions & 17 deletions lm_eval/tasks/coqa/utils.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,96 @@
from itertools import zip_longest
from datasets import Dataset

import transformers.data.metrics.squad_metrics as squad_metrics


def process_docs(dataset):
    """
    Expand each CoQA conversation into multiple instances, one per turn.

    Each output row contains the story, the flat question/answer history up
    to and including that turn, the zero-based turn index, and (when present)
    one alternative answer string per annotator for that turn.
    """

    def _expand_turns(doc):
        """Expand a single conversation into per-turn rows (column -> list)."""
        story = doc["story"]
        # Raw CoQA schema nests the actual strings under "input_text".
        questions = doc["questions"]["input_text"]
        answers = doc["answers"]["input_text"]
        additional_answers = doc.get("additional_answers", {})

        # Create lists to store all turns
        expanded = {
            "story": [],
            "questions": [],
            "answers": [],
            "additional_answers": [],
            "turn_id": [],
        }

        # Create one instance per turn
        for turn_idx in range(len(questions)):
            expanded["story"].append(story)
            # Store questions and answers up to and including this turn
            expanded["questions"].append(questions[: turn_idx + 1])
            expanded["answers"].append(answers[: turn_idx + 1])
            expanded["turn_id"].append(turn_idx)

            # Handle additional answers for this turn: flatten each
            # annotator's per-turn alternatives to a single string.
            turn_additional = {}
            if additional_answers:
                for key, value in additional_answers.items():
                    # NOTE(review): assumes each annotator's "input_text" list
                    # is at least len(questions) long — otherwise this raises
                    # IndexError. TODO confirm against the dataset.
                    if "input_text" in value:
                        turn_additional[key] = value["input_text"][turn_idx]
            expanded["additional_answers"].append(turn_additional)

        return expanded

    # Apply the expansion.
    # remove_columns drops every original column except "story"; the per-turn
    # columns returned by _expand_turns are lists with one element per turn.
    dataset = dataset.map(
        _expand_turns,
        remove_columns=[
            key for key in dataset.features.keys() if key not in ["story"]
        ],
    )

    # Flatten the lists: each cell currently holds a per-conversation list of
    # per-turn values; concatenate them so the final dataset has one row per turn.
    new_dataset = {}
    for key in dataset.features.keys():
        new_dataset[key] = [x for row in dataset[key] for x in row]

    return Dataset.from_dict(new_dataset)


def doc_to_text(doc):
    """
    Build the prompt for one CoQA turn.

    Given a passage p, the conversation history {q1, a1, ..., q_{i-1}, a_{i-1}}
    and the current question q_i, the task is to predict the answer a_i.

    Args:
        doc: A processed per-turn document (see ``process_docs``) with flat
            fields: ``story`` (str), ``questions`` (list[str]) and
            ``answers`` (list[str]) covering every turn up to and including
            the current one.

    Returns:
        str: the story, the full Q/A history, then the current question
        followed by a bare "A:" for the model to complete.
    """
    # Fix: the previous body still contained the stale pre-refactor loop
    # (zip_longest over doc["questions"]["input_text"]), which both assumed
    # the old nested schema and duplicated the history. Only the flat-list
    # implementation remains.
    doc_text = doc["story"] + "\n\n"

    questions = doc["questions"]
    answers = doc["answers"]

    # Conversation history: all Q&A pairs except the final (target) answer.
    for i in range(len(questions) - 1):
        doc_text += f"Q: {questions[i]}\n\n"
        doc_text += f"A: {answers[i]}\n\n"

    # Current question, left unanswered for the model to complete.
    doc_text += f"Q: {questions[-1]}\n\nA:"

    return doc_text


def doc_to_target(doc):
    """
    Collect the gold answer(s) for the current turn.

    Returns the reference answer for the last question in the document plus
    any distinct alternatives (some CoQA questions have multiple valid
    answers), deduplicated case-insensitively.

    Args:
        doc: A processed per-turn document (see ``process_docs``);
            ``answers`` is a flat list whose last element is this turn's
            gold answer, and the optional ``additional_answers`` maps
            annotator ids to a single alternative answer string.

    Returns:
        list[str]: unique valid answers for this turn, gold answer first.
    """
    # Fix: the previous body still contained the stale pre-refactor lines
    # indexing doc["answers"]["input_text"] and per-turn nested
    # additional_answers, which both assumed the old schema and appended the
    # target twice. Only the flat-schema implementation remains.
    answers = []
    # The target is the last answer in this turn's history.
    answer_for_turn = doc["answers"][-1]
    answers.append(answer_for_turn)

    additional_answers = doc.get("additional_answers", {})
    if additional_answers:
        # Annotator keys are irrelevant here; only the alternative strings
        # matter (idiom fix: iterate .values() instead of an unused key).
        for value in additional_answers.values():
            # Skip empty strings and case-insensitive duplicates.
            if value and value.lower() not in map(str.lower, answers):
                answers.append(value)

    return answers


Expand Down