Skip to content
7 changes: 7 additions & 0 deletions slime/backends/megatron_utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,13 @@ def log_rollout_data(
"rollout_routed_experts",
"max_seq_lens",
"dynamic_global_batch_size",
# rollout-source aggregates emitted by RolloutManager._log_rollout_data
# (pre-filter); skip here to keep one wandb writer per key.
"raw_reward",
"rewards",
"truncated",
"response_lengths",
"total_lengths",
]:
continue
# Upload per sample mean for each rollout value
Expand Down
96 changes: 75 additions & 21 deletions slime/ray/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,11 +483,22 @@ def generate(self, rollout_id):
self._try_ci_fault_injection()
data, metrics = self._get_rollout_data(rollout_id=rollout_id)
self._save_debug_rollout_data(data, rollout_id=rollout_id, evaluation=False)
_log_rollout_data(rollout_id, self.args, data, metrics, time.time() - start_time)

raw_rewards, rewards = self._post_process_rewards(data)
_log_rollout_data(rollout_id, self.args, data, raw_rewards, rewards, metrics, time.time() - start_time)
if self.args.debug_rollout_only:
# if debug rollout only, we don't convert samples to train data and directly return
return
data = self._convert_samples_to_train_data(data)

if self.args.filter_zero_advantage_samples:
data, raw_rewards, rewards = self._filter_zero_advantage_samples(data, raw_rewards, rewards)
self._dynamic_global_batch_size = self._compute_dynamic_global_batch_size(len(data))

if self.custom_convert_samples_to_train_data_func is not None:
data = self.custom_convert_samples_to_train_data_func(self.args, data, raw_rewards, rewards)
else:
data = self._convert_samples_to_train_data(data, raw_rewards, rewards)

return self._split_train_data_by_dp(data, self.train_parallel_config["dp_size"])

def eval(self, rollout_id):
Expand Down Expand Up @@ -679,15 +690,51 @@ def _post_process_rewards(self, samples: list[Sample] | list[list[Sample]]):

return raw_rewards, raw_rewards

def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sample]]):
"""
Convert inference generated samples to training data.
"""
if self.custom_convert_samples_to_train_data_func is not None:
return self.custom_convert_samples_to_train_data_func(self.args, samples)

raw_rewards, rewards = self._post_process_rewards(samples)

def _filter_zero_advantage_samples(self, samples, raw_rewards, rewards):
"""Drop zero-advantage samples; pad to dp_size with dropped samples when survivors fall short."""
dp_size = self.train_parallel_config["dp_size"]
total = len(samples)
nonzero_idx = [i for i, r in enumerate(rewards) if r != 0.0]
zero_idx = [i for i, r in enumerate(rewards) if r == 0.0]

if len(nonzero_idx) >= dp_size:
keep_count = (len(nonzero_idx) // dp_size) * dp_size
kept_idx = nonzero_idx[:keep_count]
padding_count = 0
else:
padding_count = dp_size - len(nonzero_idx)
assert len(zero_idx) >= padding_count, (
f"Not enough samples to pad to dp_size={dp_size}: "
f"total={total} nonzero={len(nonzero_idx)} zero={len(zero_idx)}"
)
pad_idx = zero_idx[:padding_count]
for i in pad_idx:
samples[i].remove_sample = True
kept_idx = nonzero_idx + pad_idx

kept_samples = [samples[i] for i in kept_idx]
kept_raw = [raw_rewards[i] for i in kept_idx]
kept_rewards = [rewards[i] for i in kept_idx]

log_dict = {
"rollout/filter/total": total,
"rollout/filter/kept": len(kept_samples),
"rollout/filter/dropped_ratio": (total - len(kept_samples)) / total,
"rollout/filter/zero_advantage_ratio": len(zero_idx) / total,
"rollout/filter/padding_count": padding_count,
"rollout/step": compute_rollout_step(self.args, self.rollout_id),
}
logger.info(f"filter {self.rollout_id}: {log_dict}")
logging_utils.log(self.args, log_dict, step_key="rollout/step")
return kept_samples, kept_raw, kept_rewards

def _convert_samples_to_train_data(
self,
samples: list[Sample] | list[list[Sample]],
raw_rewards: list[float],
rewards: list[float],
):
"""Convert inference generated samples to training data."""
assert len(raw_rewards) == len(samples)
assert len(rewards) == len(samples)

Expand All @@ -697,7 +744,7 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sampl
# some reward model, e.g. remote rm, may return multiple rewards,
# we could use key to select the reward.
"rewards": rewards,
"raw_reward": raw_rewards,
"raw_reward": _resolve_raw_rewards(samples, raw_rewards),
"truncated": [1 if sample.status == Sample.Status.TRUNCATED else 0 for sample in samples],
"sample_indices": [sample.index for sample in samples],
}
Expand All @@ -718,14 +765,6 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sampl
loss_masks.append(sample.loss_mask)
train_data["loss_masks"] = loss_masks

# Overwrite raw_reward when available. Mixed-source batches may only
# populate this field for a subset of samples (e.g. SWE but not code).
if any(sample.metadata and "raw_reward" in sample.metadata for sample in samples):
train_data["raw_reward"] = [
sample.metadata["raw_reward"] if sample.metadata and "raw_reward" in sample.metadata else sample.reward
for sample in samples
]

# For rollout buffer
if samples[0].metadata and "round_number" in samples[0].metadata:
train_data["round_number"] = [sample.metadata["round_number"] for sample in samples]
Expand Down Expand Up @@ -1172,7 +1211,7 @@ def _log_eval_rollout_data(rollout_id, args, data, extra_metrics: dict[str, Any]
return log_dict


def _log_rollout_data(rollout_id, args, samples, rollout_extra_metrics, rollout_time):
def _log_rollout_data(rollout_id, args, samples, raw_rewards, rewards, rollout_extra_metrics, rollout_time):
if args.custom_rollout_log_function_path is not None:
custom_log_func = load_function(args.custom_rollout_log_function_path)
if custom_log_func(rollout_id, args, samples, rollout_extra_metrics, rollout_time):
Expand All @@ -1181,15 +1220,30 @@ def _log_rollout_data(rollout_id, args, samples, rollout_extra_metrics, rollout_
if args.load_debug_rollout_data:
return

num_samples = len(samples)
log_dict = {**(rollout_extra_metrics or {})}
log_dict |= dict_add_prefix(compute_metrics_from_samples(args, samples), "rollout/")
log_dict |= dict_add_prefix(compute_perf_metrics_from_samples(args, samples, rollout_time), "perf/")
log_dict["rollout/raw_reward"] = sum(_resolve_raw_rewards(samples, raw_rewards)) / num_samples
log_dict["rollout/rewards"] = sum(rewards) / num_samples
log_dict["rollout/response_lengths"] = sum(s.response_length for s in samples) / num_samples
log_dict["rollout/total_lengths"] = sum(len(s.tokens) for s in samples) / num_samples
logger.info(f"perf {rollout_id}: {log_dict}")
step = compute_rollout_step(args, rollout_id)
log_dict["rollout/step"] = step
logging_utils.log(args, log_dict, step_key="rollout/step")


def _resolve_raw_rewards(samples, raw_rewards):
"""Apply the sample.metadata['raw_reward'] override used by train_data."""
if not any(sample.metadata and "raw_reward" in sample.metadata for sample in samples):
return raw_rewards
return [
sample.metadata["raw_reward"] if sample.metadata and "raw_reward" in sample.metadata else sample.reward
for sample in samples
]


def compute_metrics_from_samples(args, samples):
response_lengths = [sample.effective_response_length for sample in samples]

Expand Down
20 changes: 19 additions & 1 deletion slime/utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1231,7 +1231,8 @@ def add_reward_model_arguments(parser):
help=(
"Path to a custom function that converts samples to training data. "
"If set, this function will replace the default _convert_samples_to_train_data. "
"The function should have the signature `def convert_samples_to_train_data(args, samples) -> dict`."
"The function should have the signature "
"`def convert_samples_to_train_data(args, samples, raw_rewards, rewards) -> dict`."
),
)
return parser
Expand Down Expand Up @@ -1307,6 +1308,18 @@ def add_rollout_buffer_arguments(parser):
default=False,
help="enable dynamic global batch size, disable trim samples in rollout buffer when converting samples to train data",
)
parser.add_argument(
"--filter-zero-advantage-samples",
action="store_true",
default=False,
help=(
"Drop samples whose post-processed reward (i.e. advantage in GRPO/GSPO with "
"normalization) is 0 before training, since they contribute zero gradient. "
"When fewer than dp_size non-zero samples survive, pad back to dp_size with "
"dropped zero-advantage samples whose loss mask is zeroed via remove_sample. "
"Requires --use-dynamic-global-batch-size."
),
)
return parser

def add_custom_megatron_plugins_arguments(parser):
Expand Down Expand Up @@ -1699,6 +1712,11 @@ def slime_validate_args(args):
if args.log_probs_max_tokens_per_gpu is None:
args.log_probs_max_tokens_per_gpu = args.max_tokens_per_gpu

if args.filter_zero_advantage_samples:
assert args.use_dynamic_global_batch_size, (
"--filter-zero-advantage-samples requires --use-dynamic-global-batch-size so the global "
"batch size adapts to the post-filter sample count."
)
if args.eps_clip_high is None:
args.eps_clip_high = args.eps_clip

Expand Down
13 changes: 7 additions & 6 deletions tests/plugin_contracts/test_plugin_runtime_hook_contracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,12 @@ def reference_reward_post_process(args, samples):
return raw_rewards, rewards


def reference_convert_samples_to_train_data(args, samples):
def reference_convert_samples_to_train_data(args, samples, raw_rewards, rewards):
return {
"tokens": [sample.tokens for sample in samples],
"response_lengths": [sample.response_length for sample in samples],
"rewards": [sample.reward for sample in samples],
"raw_reward": [sample.reward for sample in samples],
"rewards": rewards,
"raw_reward": raw_rewards,
"truncated": [0 for _ in samples],
"sample_indices": [sample.index for sample in samples],
"loss_masks": [sample.loss_mask for sample in samples],
Expand Down Expand Up @@ -116,7 +116,8 @@ def invoke_reward_post_process(fn):


def invoke_convert_samples_to_train_data(fn):
train_data = fn(type("Args", (), {})(), [make_sample(0, 0.5), make_sample(1, 1.5)])
samples = [make_sample(0, 0.5), make_sample(1, 1.5)]
train_data = fn(type("Args", (), {})(), samples, [0.5, 1.5], [0.5, 1.5])
assert {"tokens", "response_lengths", "rewards", "raw_reward", "truncated", "sample_indices", "loss_masks"} <= set(
train_data
)
Expand Down Expand Up @@ -161,8 +162,8 @@ def invoke_rollout_data_postprocess(fn):
"CUSTOM_CONVERT_SAMPLES_TO_TRAIN_DATA_PATH",
"plugin_contracts.test_plugin_runtime_hook_contracts.reference_convert_samples_to_train_data",
"slime/ray/rollout.py",
"self.custom_convert_samples_to_train_data_func(self.args, samples)",
("args", "samples"),
"self.custom_convert_samples_to_train_data_func(self.args, data, raw_rewards, rewards)",
("args", "samples", "raw_rewards", "rewards"),
invoke_convert_samples_to_train_data,
),
HookCase(
Expand Down