Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 134 additions & 14 deletions python/training_infra/builder_actions_data_loader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse, pprint as pp, torch, sys, numpy as np

from collections import defaultdict
import copy
sys.path.append('..')

from utils import *
Expand All @@ -10,7 +11,7 @@ class BuilderActionsDataset(CwCDataset):
def __init__(self, itemize_args, **kwargs):
if not kwargs["items_only"]:
# samples-only and jsons-only mode
extra_args = ["items_only", "load_items", "dump_items", "development_mode"]
extra_args = ["items_only", "load_items", "dump_items", "development_mode", "split_questions"]
kwargs_super = {i: kwargs[i] for i in kwargs if i not in extra_args}
super(BuilderActionsDataset, self).__init__(**kwargs_super)
else:
Expand Down Expand Up @@ -50,7 +51,7 @@ def __init__(self, itemize_args, **kwargs):
self.item_batches = self.item_batches[:5]
else:
# load samples
extra_args = ["items_only", "load_items", "dump_items", "development_mode"]
extra_args = ["items_only", "load_items", "dump_items", "development_mode", "split_questions"]
kwargs_super = {i: kwargs[i] for i in kwargs if i not in extra_args}
super(BuilderActionsDataset, self).__init__(**kwargs_super)

Expand All @@ -67,6 +68,7 @@ def __init__(self, itemize_args, **kwargs):
self.action_history_weighting_scheme = itemize_args.action_history_weighting_scheme
self.concatenate_action_history_weight = itemize_args.concatenate_action_history_weight
self.two_dim_posterior = itemize_args.two_dim_posterior
self.split_questions = kwargs['split_questions']

# generate items
if kwargs["split"] == "train": # items
Expand All @@ -84,6 +86,10 @@ def __init__(self, itemize_args, **kwargs):
else: # item batches
print("Generating item batches...")

if kwargs['split_questions']:
print("Splitting on question data")
self.samples = self.split_qs()

loader = self.get_data_loader(
batch_size=itemize_args.batch_size, shuffle=False, num_workers=itemize_args.num_workers
)
Expand All @@ -109,12 +115,100 @@ def __init__(self, itemize_args, **kwargs):
print("\nSaving git commit hashes ...\n")
write_commit_hashes("../..", kwargs["saved_dataset_dir"], filepath_modifier="_items_" + kwargs["split"])

def is_sorted(self, lst):
    """
    Check whether the provided sequence is sorted in non-decreasing order.

    The check is iterative, so it cannot exhaust the recursion limit on
    long inputs, and an empty or single-element sequence counts as sorted.

    Args:
        lst: sliceable sequence of mutually comparable values.

    Returns:
        True if every element is <= its successor, False otherwise.
    """
    # Compare each element with its right-hand neighbour; zip() stops at the
    # shorter operand, so sequences of length 0 or 1 yield no pairs -> True.
    return all(left <= right for left, right in zip(lst, lst[1:]))

def create_question_sample(self, sample, utterances):
    """
    Derive a question-labelled sample from an existing one.

    The result is a shallow copy of ``sample`` whose 'prev_utterances'
    entry is replaced by a shallow copy of ``utterances`` and whose
    'question' entry is set to True. All other values are shared with the
    original sample; ``sample`` itself is not modified.
    """
    question_sample = copy.copy(sample)
    history_snapshot = copy.copy(utterances)
    question_sample['prev_utterances'] = history_snapshot
    question_sample['question'] = True
    # NOTE: 'next_builder_actions' is deliberately left in place because
    # __getitem__ reads it. Clearing it would break that interface, although
    # keeping it may hand the agent more information than it should have.
    # question_sample['next_builder_actions'] = []

    return question_sample

def check_zeroth_sample(self, sample):
    """
    Expand the first step of an episode on builder questions.

    Walks the sample's 'prev_utterances' in order. For every utterance
    spoken by 'Builder' that contains a '?', a question sample is created
    whose history holds only the utterances preceding that question. The
    original sample is then marked with 'question' = False and placed last.

    Returns:
        list of the generated question samples followed by ``sample``.
    """
    seen_so_far = []
    emitted = []
    for turn in sample['prev_utterances']:
        if turn['speaker'] == 'Builder':
            if '?' in turn['utterance']:
                # History at this point excludes the question itself.
                emitted.append(self.create_question_sample(sample, seen_so_far))
        seen_so_far.append(turn)
    sample['question'] = False
    return emitted + [sample]

def check_step_diff(self, sample_prev, sample_new):
    """
    Check if there were questions asked in the chat difference between two
    steps (_prev and _new), and create additional samples if there are.

    For every utterance that appears in ``sample_new`` but not yet in
    ``sample_prev``, a 'Builder' utterance containing '?' produces a
    question sample (based on ``sample_new``) whose history ends just
    before that question.

    Args:
        sample_prev: sample of the preceding step; it is only read.
        sample_new: sample of the current step; its 'question' entry is set
            to False.

    Returns:
        list of the generated question samples followed by ``sample_new``.
    """
    chat_prev_len = len(sample_prev['prev_utterances'])
    chat_diff = sample_new['prev_utterances'][chat_prev_len:]
    intermediate_samples = []
    # Work on a copy: appending to sample_prev's own list would silently
    # extend the previous sample's history with utterances from this step.
    recreate_utterances = list(sample_prev['prev_utterances'])
    for utterance in chat_diff:
        if utterance['speaker'] == 'Builder' and '?' in utterance['utterance']:
            new_sample = self.create_question_sample(sample_new, recreate_utterances)
            intermediate_samples.append(new_sample)
        recreate_utterances.append(utterance)
    sample_new['question'] = False
    intermediate_samples.append(sample_new)
    return intermediate_samples

def extract_questions_in_episode(self, samples_in_episode):
    """
    Expand one episode into its samples plus builder-question samples.

    The first step is handled by ``check_zeroth_sample``; every later step
    is compared against its predecessor with ``check_step_diff``. Results
    are concatenated in step order.
    """
    ordering = [entry['sample_id'] for entry in samples_in_episode]
    assert self.is_sorted(ordering)  # episode should be ordered

    collected = list(self.check_zeroth_sample(samples_in_episode[0]))
    # Pair each step with the one that follows it.
    for earlier, later in zip(samples_in_episode, samples_in_episode[1:]):
        collected += self.check_step_diff(earlier, later)
    return collected

def split_qs(self):
    """
    Split the dataset on questions the builder asked in the chat.

    Samples in ``self.samples`` are grouped by their 'json_id' (one group
    per episode) and each group is expanded through
    ``extract_questions_in_episode``. The sample counts before and after
    are printed.

    Returns:
        the expanded list of samples; ``self.samples`` is not reassigned
        here.
    """
    print(f"Length before splitting: {len(self.samples)}")
    by_episode = defaultdict(list)
    for entry in self.samples:
        by_episode[entry['json_id']].append(entry)

    expanded = []
    for episode_samples in by_episode.values():
        expanded.extend(self.extract_questions_in_episode(episode_samples))
    print(f"Length after splitting: {len(expanded)}")
    return expanded

def __getitem__(self, idx):
""" Computes the tensor representations of a sample """
orig_sample = self.samples[idx]

all_actions = orig_sample["next_builder_actions"]
# print(all_actions)
perspective_coords = orig_sample["perspective_coordinates"]

initial_prev_config_raw = all_actions[0].prev_config
Expand Down Expand Up @@ -262,18 +356,30 @@ def __getitem__(self, idx):
prev_utterances.append(start_token)
prev_utterances.extend(self.encoder_vocab(token) for token in utterance)
prev_utterances.append(end_token)

return (
torch.Tensor(prev_utterances),
torch.stack(dec_inputs_1),
torch.stack(dec_inputs_2),
torch.Tensor(dec_outputs),
RawInputs(initial_prev_config_raw, initial_action_history_raw, end_built_config_raw, perspective_coords)
)
if self.split_questions:
return (
torch.Tensor(prev_utterances),
torch.stack(dec_inputs_1),
torch.stack(dec_inputs_2),
torch.Tensor(dec_outputs),
RawInputs(initial_prev_config_raw, initial_action_history_raw, end_built_config_raw, perspective_coords),
orig_sample['question']
)
else:
return (
torch.Tensor(prev_utterances),
torch.stack(dec_inputs_1),
torch.stack(dec_inputs_2),
torch.Tensor(dec_outputs),
RawInputs(initial_prev_config_raw, initial_action_history_raw, end_built_config_raw, perspective_coords),
)

def collate_fn(self, data):
# NOTE: assumes batch size = 1
prev_utterances, dec_inputs_1, dec_inputs_2, dec_outputs, raw_inputs = zip(*data)
if self.split_questions:
prev_utterances, dec_inputs_1, dec_inputs_2, dec_outputs, raw_inputs, question_label = zip(*data)
else:
prev_utterances, dec_inputs_1, dec_inputs_2, dec_outputs, raw_inputs = zip(*data)

def merge_text(sequences):
lengths = [len(seq) for seq in sequences]
Expand Down Expand Up @@ -524,6 +630,7 @@ def __init__(self, initial_prev_config_raw, initial_action_history_raw, end_buil
parser.add_argument('--use_builder_actions', default=False, action='store_true', help='include builder action tokens in the dialogue history')

parser.add_argument('--add_perspective_coords', default=False, action='store_true', help='whether or not to include perspective coords in world state repr')
parser.add_argument('--split_questions', default=False, action='store_true', help='split the dataset also on question actions')

args = parser.parse_args()

Expand Down Expand Up @@ -554,5 +661,18 @@ def __init__(self, initial_prev_config_raw, initial_action_history_raw, end_buil
items_only=True,
load_items=args.load_items,
dump_items=args.dump_items,
development_mode=args.development_mode
development_mode=args.development_mode,
split_questions=args.split_questions
)

# Format of item in dataset
# torch.Tensor(prev_utterances),
# torch.stack(dec_inputs_1),
# torch.stack(dec_inputs_2),
# torch.Tensor(dec_outputs),
# RawInputs(initial_prev_config_raw, initial_action_history_raw, end_built_config_raw, perspective_coords)
# <question label>
print(dataset[0][0])
print(dataset[0][-1])
print(dataset[0][-2].initial_prev_config_raw)
print(dataset[0][-2].end_built_config_raw)