diff --git a/python/training_infra/builder_actions_data_loader.py b/python/training_infra/builder_actions_data_loader.py index 513afba2..b896fe88 100644 --- a/python/training_infra/builder_actions_data_loader.py +++ b/python/training_infra/builder_actions_data_loader.py @@ -1,5 +1,6 @@ import argparse, pprint as pp, torch, sys, numpy as np - +from collections import defaultdict +import copy sys.path.append('..') from utils import * @@ -10,7 +11,7 @@ class BuilderActionsDataset(CwCDataset): def __init__(self, itemize_args, **kwargs): if not kwargs["items_only"]: # samples-only and jsons-only mode - extra_args = ["items_only", "load_items", "dump_items", "development_mode"] + extra_args = ["items_only", "load_items", "dump_items", "development_mode", "split_questions"] kwargs_super = {i: kwargs[i] for i in kwargs if i not in extra_args} super(BuilderActionsDataset, self).__init__(**kwargs_super) else: @@ -50,7 +51,7 @@ def __init__(self, itemize_args, **kwargs): self.item_batches = self.item_batches[:5] else: # load samples - extra_args = ["items_only", "load_items", "dump_items", "development_mode"] + extra_args = ["items_only", "load_items", "dump_items", "development_mode", "split_questions"] kwargs_super = {i: kwargs[i] for i in kwargs if i not in extra_args} super(BuilderActionsDataset, self).__init__(**kwargs_super) @@ -67,6 +68,7 @@ def __init__(self, itemize_args, **kwargs): self.action_history_weighting_scheme = itemize_args.action_history_weighting_scheme self.concatenate_action_history_weight = itemize_args.concatenate_action_history_weight self.two_dim_posterior = itemize_args.two_dim_posterior + self.split_questions = kwargs['split_questions'] # generate items if kwargs["split"] == "train": # items @@ -84,6 +86,10 @@ def __init__(self, itemize_args, **kwargs): else: # item batches print("Generating item batches...") + if kwargs['split_questions']: + print("Splitting on question data") + self.samples = self.split_qs() + loader = self.get_data_loader( 
def is_sorted(self, lst):
    """
    Return True if ``lst`` is in non-decreasing order.

    Empty and single-element lists count as sorted. The check is iterative,
    so long episodes cannot hit the interpreter's recursion limit.
    """
    return all(left <= right for left, right in zip(lst, lst[1:]))


def create_question_sample(self, sample, utterances):
    """
    Build a question-labelled sample derived from ``sample``.

    The result is a shallow copy of ``sample`` (nested values are shared with
    the original) whose 'prev_utterances' is replaced by a shallow copy of
    ``utterances`` and whose 'question' flag is set to True.
    """
    new_sample = copy.copy(sample)
    new_sample['prev_utterances'] = copy.copy(utterances)
    new_sample['question'] = True
    # NOTE: 'next_builder_actions' is deliberately left in place because
    # __getitem__ reads orig_sample["next_builder_actions"] and indexes its
    # first element. Clearing it would break that interface, although keeping
    # it may expose more information to the agent than intended.
    return new_sample


def check_zeroth_sample(self, sample):
    """
    Expand the first step of an episode on builder questions.

    For every utterance in ``sample['prev_utterances']`` that was spoken by
    'Builder' and contains a '?', a question sample is emitted whose dialogue
    history holds only the utterances preceding that question. ``sample``
    itself is flagged ``question = False`` (in place) and returned last.
    """
    intermediate_samples = []
    history = []
    for utterance in sample['prev_utterances']:
        if utterance['speaker'] == 'Builder' and '?' in utterance['utterance']:
            intermediate_samples.append(self.create_question_sample(sample, history))
        history.append(utterance)
    sample['question'] = False
    intermediate_samples.append(sample)
    return intermediate_samples


def check_step_diff(self, sample_prev, sample_new):
    """
    Expand the chat delta between two consecutive steps on builder questions.

    The delta is every utterance of ``sample_new['prev_utterances']`` beyond
    the length of ``sample_prev['prev_utterances']``. For each builder
    question in the delta a question sample (derived from ``sample_new``) is
    emitted with the history up to, but excluding, that question.
    ``sample_new`` is flagged ``question = False`` (in place) and returned
    last. ``sample_prev`` is never modified.
    """
    prev_utterances = sample_prev['prev_utterances']
    chat_diff = sample_new['prev_utterances'][len(prev_utterances):]
    # Work on a copy: appending to the original list would silently extend
    # the dialogue history of sample_prev, which has already been emitted.
    history = list(prev_utterances)
    intermediate_samples = []
    for utterance in chat_diff:
        if utterance['speaker'] == 'Builder' and '?' in utterance['utterance']:
            intermediate_samples.append(self.create_question_sample(sample_new, history))
        history.append(utterance)
    sample_new['question'] = False
    intermediate_samples.append(sample_new)
    return intermediate_samples


def extract_questions_in_episode(self, samples_in_episode):
    """
    Return the samples of one episode, interleaved with question samples.

    ``samples_in_episode`` must be ordered by 'sample_id'. The first sample
    is expanded with ``check_zeroth_sample``; every following sample is
    expanded against its predecessor with ``check_step_diff``. An empty
    episode yields an empty list.
    """
    if not samples_in_episode:
        return []

    sample_ids = [sample['sample_id'] for sample in samples_in_episode]
    # Internal invariant: the step diffs below assume chronological order.
    assert self.is_sorted(sample_ids), "episode samples must be ordered by sample_id"

    new_samples = list(self.check_zeroth_sample(samples_in_episode[0]))
    for sample_prev, sample_new in zip(samples_in_episode, samples_in_episode[1:]):
        new_samples.extend(self.check_step_diff(sample_prev, sample_new))
    return new_samples


def split_qs(self):
    """
    Split the dataset on questions asked by the builder in the chat.

    Samples in ``self.samples`` are grouped by 'json_id' (one group per
    episode, in first-seen order) and each episode is expanded with
    ``extract_questions_in_episode``. This inflates the dataset with extra
    question-labelled samples; those samples keep the actions of the sample
    they were derived from. Every returned sample carries a boolean
    'question' key.

    Returns the new list of samples; ``self.samples`` is not reassigned here.
    """
    print(f"Length before splitting: {len(self.samples)}")
    episode_data = defaultdict(list)
    for sample in self.samples:
        episode_data[sample['json_id']].append(sample)

    new_samples = []
    for episode in episode_data.values():
        new_samples.extend(self.extract_questions_in_episode(episode))
    print(f"Length after splitting: {len(new_samples)}")
    return new_samples
prev_utterances, dec_inputs_1, dec_inputs_2, dec_outputs, raw_inputs = zip(*data) def merge_text(sequences): lengths = [len(seq) for seq in sequences] @@ -524,6 +630,7 @@ def __init__(self, initial_prev_config_raw, initial_action_history_raw, end_buil parser.add_argument('--use_builder_actions', default=False, action='store_true', help='include builder action tokens in the dialogue history') parser.add_argument('--add_perspective_coords', default=False, action='store_true', help='whether or not to include perspective coords in world state repr') + parser.add_argument('--split_questions', default=False, action='store_true', help='split the dataset also on question actions') args = parser.parse_args() @@ -554,5 +661,18 @@ def __init__(self, initial_prev_config_raw, initial_action_history_raw, end_buil items_only=True, load_items=args.load_items, dump_items=args.dump_items, - development_mode=args.development_mode + development_mode=args.development_mode, + split_questions=args.split_questions ) + + # Format of item in dataset + # torch.Tensor(prev_utterances), + # torch.stack(dec_inputs_1), + # torch.stack(dec_inputs_2), + # torch.Tensor(dec_outputs), + # RawInputs(initial_prev_config_raw, initial_action_history_raw, end_built_config_raw, perspective_coords) + # + print(dataset[0][0]) + print(dataset[0][-1]) + print(dataset[0][-2].initial_prev_config_raw) + print(dataset[0][-2].end_built_config_raw) \ No newline at end of file