diff --git a/examples/run_sft.py b/examples/run_sft.py
index 4e80414f8d..69f3128231 100644
--- a/examples/run_sft.py
+++ b/examples/run_sft.py
@@ -17,7 +17,6 @@
 import pprint
 from functools import partial
 
-from datasets import concatenate_datasets
 from omegaconf import OmegaConf
 from transformers import AutoTokenizer
 
@@ -29,6 +28,7 @@
     load_response_dataset,
     update_single_dataset_config,
 )
+from nemo_rl.data.utils import merge_datasets
 from nemo_rl.distributed.virtual_cluster import init_ray
 from nemo_rl.utils.config import (
     load_config,
@@ -89,7 +89,7 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig):
         if hasattr(data, "preprocessor") and data.preprocessor is not None:
             task_data_preprocessors[data.task_name] = data.preprocessor
 
-    merged_data = concatenate_datasets([data.dataset for data in data_list])
+    merged_data = merge_datasets([data.dataset for data in data_list])
     dataset = AllTaskProcessedDataset(
         merged_data,
         tokenizer,
@@ -144,7 +144,7 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig):
 
     val_dataset = None
     if len(val_data_list) > 0:
-        merged_val_data = concatenate_datasets(val_data_list)
+        merged_val_data = merge_datasets(val_data_list)
         val_dataset = AllTaskProcessedDataset(
             merged_val_data,
             tokenizer,
diff --git a/nemo_rl/data/utils.py b/nemo_rl/data/utils.py
index 2819e27582..cbfc50cf52 100644
--- a/nemo_rl/data/utils.py
+++ b/nemo_rl/data/utils.py
@@ -14,7 +14,7 @@
 from typing import Any, Optional, Union
 
-from datasets import concatenate_datasets
+from datasets import Dataset, concatenate_datasets
 from transformers import AutoProcessor, AutoTokenizer
 
 from nemo_rl.data import DataConfig
 
@@ -25,11 +25,27 @@
     load_response_dataset,
     update_single_dataset_config,
 )
+from nemo_rl.data.datasets.response_datasets.oai_format_dataset import (
+    PreservingDataset,
+)
 from nemo_rl.data.processors import preference_preprocessor
 from nemo_rl.environments.interfaces import EnvironmentInterface
 from nemo_rl.environments.utils import create_env
 
+
+def merge_datasets(datasets: list) -> Union[Dataset, "PreservingDataset"]:
+    """Merge a list of datasets, handling both HuggingFace Dataset and PreservingDataset.
+
+    HuggingFace's ``concatenate_datasets`` does not accept ``PreservingDataset`` objects.
+    This helper detects the dataset types and merges them appropriately.
+    """
+    if all(isinstance(d, PreservingDataset) for d in datasets):
+        merged_data = [item for d in datasets for item in d.data]
+        return PreservingDataset(merged_data)
+
+    return concatenate_datasets(datasets)
+
 
 # TODO: @yukih: unify to setup_data after dataset refactored
 def setup_response_data(
     tokenizer: AutoProcessor | AutoTokenizer,
@@ -134,7 +150,7 @@ def setup_response_data(
         }
     else:
         # merge datasets into a single dataset
-        merged_data = concatenate_datasets([data.dataset for data in data_list])
+        merged_data = merge_datasets([data.dataset for data in data_list])
         dataset = AllTaskProcessedDataset(
             merged_data,
             tokenizer,
@@ -199,7 +215,7 @@ def setup_response_data(
     # merge datasets
     val_dataset = None
     if len(val_data_list) > 0:
-        merged_val_data = concatenate_datasets(val_data_list)
+        merged_val_data = merge_datasets(val_data_list)
         val_dataset = AllTaskProcessedDataset(
             merged_val_data,
             tokenizer,
diff --git a/tests/unit/data/datasets/test_preserving_dataset.py b/tests/unit/data/datasets/test_preserving_dataset.py
index 9c16a6ffeb..9c06ef5b00 100644
--- a/tests/unit/data/datasets/test_preserving_dataset.py
+++ b/tests/unit/data/datasets/test_preserving_dataset.py
@@ -313,3 +313,59 @@ def test_comparison_with_standard_dataset(self):
         preserving_dataset = PreservingDataset(data)
         assert preserving_dataset[0]["tool_id"] == "123"
         assert "tool_id" not in preserving_dataset[1]  # Key doesn't exist
+
+
+class TestMergeDatasets:
+    """Test merge_datasets helper that handles both HF Dataset and PreservingDataset."""
+
+    def test_merge_preserving_datasets(self):
+        """Test merging multiple PreservingDatasets."""
+        from nemo_rl.data.utils import merge_datasets
+
+        ds1 = PreservingDataset([{"a": 1}, {"b": 2}])
+        ds2 = PreservingDataset([{"c": 3}])
+
+        merged = merge_datasets([ds1, ds2])
+
+        assert isinstance(merged, PreservingDataset)
+        assert len(merged) == 3
+        assert merged[0] == {"a": 1}
+        assert merged[1] == {"b": 2}
+        assert merged[2] == {"c": 3}
+
+    def test_merge_hf_datasets(self):
+        """Test merging standard HuggingFace Datasets still works."""
+        from nemo_rl.data.utils import merge_datasets
+
+        ds1 = Dataset.from_list([{"x": 1}, {"x": 2}])
+        ds2 = Dataset.from_list([{"x": 3}])
+
+        merged = merge_datasets([ds1, ds2])
+
+        assert isinstance(merged, Dataset)
+        assert len(merged) == 3
+        assert merged[0]["x"] == 1
+        assert merged[2]["x"] == 3
+
+    def test_merge_single_preserving_dataset(self):
+        """Test merging a single PreservingDataset."""
+        from nemo_rl.data.utils import merge_datasets
+
+        ds = PreservingDataset([{"a": 1, "b": 2}, {"c": 3}])
+
+        merged = merge_datasets([ds])
+
+        assert isinstance(merged, PreservingDataset)
+        assert len(merged) == 2
+
+    def test_merge_preserving_datasets_preserves_heterogeneous_structure(self):
+        """Test that merging PreservingDatasets doesn't introduce None-filling."""
+        from nemo_rl.data.utils import merge_datasets
+
+        ds1 = PreservingDataset([{"role": "user", "content": "hi", "tool_id": "1"}])
+        ds2 = PreservingDataset([{"role": "assistant", "content": "hello"}])
+
+        merged = merge_datasets([ds1, ds2])
+
+        assert "tool_id" in merged[0]
+        assert "tool_id" not in merged[1]  # No None-filling