diff --git a/rust/src/collector/ppo.rs b/rust/src/collector/ppo.rs index 24dc9df..f042408 100644 --- a/rust/src/collector/ppo.rs +++ b/rust/src/collector/ppo.rs @@ -38,17 +38,16 @@ impl PPOCollector { } impl PPOCollector { - fn get_step_data( + fn get_action_data( &self, env: &dyn Env, policy: &Policy, - ) -> (Vec, Vec, usize, f32, f32, Option) { + ) -> (Vec, Vec, usize, f32, Option) { let obs = env.observe(); // Vec or whatever your Env returns let masks = env.masks(); - let reward = env.reward(); let (logits, value, perm_idx) = policy.forward_with_perm(obs.clone(), masks); let action = sample_from_logits(&logits); - (obs, logits, action, value, reward, perm_idx) + (obs, logits, action, value, perm_idx) } fn single_collect( @@ -67,20 +66,31 @@ impl PPOCollector { let mut perms = Vec::new(); loop { - let (obs, log_prob, act, val, rew, perm_idx) = self.get_step_data(&*env, policy); + if env.is_final() { break; } + let (obs, log_prob, act, val, perm_idx) = self.get_action_data(&*env, policy); + // Step first, then read reward so collected transitions follow env step semantics. + env.step(act); + let rew = env.reward(); obss.push(obs); log_probs.push(log_prob); vals.push(val); rews.push(rew); acts.push(act); perms.push(perm_idx); - - if env.is_final() { break; } - env.step(act); } // compute GAE advs/rets let n = rews.len(); + if n == 0 { + return CollectedData::new( + obss, + log_probs, + perms, + vals, + rews, + acts, + ); + } let mut advs = vec![0.0; n]; let mut rets = vec![0.0; n]; advs[n-1] = rews[n-1] - vals[n-1]; @@ -177,7 +187,7 @@ mod tests { let collector = PPOCollector::new(1, 0.9, 0.95, 1); let data = collector.collect(&env, &policy).unwrap(); - assert_eq!(data.obs.len(), 2); + assert_eq!(data.obs.len(), 1); assert!(data.additional_data.contains_key("rets")); } } diff --git a/rust/src/nn/policy.rs b/rust/src/nn/policy.rs index c473989..34c0f4f 100644 --- a/rust/src/nn/policy.rs +++ b/rust/src/nn/policy.rs @@ -16,6 +16,23 @@ use rand::{prelude::Distribution, Rng}; use crate::nn::modules::Sequential; use crate::nn::layers::EmbeddingBag; +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum ActionMode { + // One logit per discrete action. + Categorical, + // One logit per binary factor; expanded to categorical logits on demand. + FactorizedBernoulli, +} + +impl ActionMode { + fn from_name(name: &str) -> Self { + match name.trim().to_ascii_lowercase().as_str() { + "factorized_bernoulli" => Self::FactorizedBernoulli, + _ => Self::Categorical, + } + } +} + #[derive(Clone)] pub struct Policy { embeddings: Box, @@ -23,12 +40,90 @@ pub struct Policy { action_net: Box, value_net: Box, obs_perms: Vec>, - act_perms: Vec> + act_perms: Vec>, + action_mode: ActionMode, + num_action_factors: usize, + num_actions: usize, } impl Policy { pub fn new(embeddings: Box, common: Box, action_net: Box, value_net: Box, obs_perms: Vec>, act_perms: Vec>) -> Self { - Self { embeddings: embeddings, common, action_net, value_net, obs_perms, act_perms } + let inferred_num_actions = act_perms.first().map(|p| p.len()).unwrap_or(0); + Self { + embeddings, + common, + action_net, + value_net, + obs_perms, + act_perms, + action_mode: ActionMode::Categorical, + num_action_factors: 0, + num_actions: inferred_num_actions, + } + } + + pub fn new_with_action_mode( + embeddings: Box, + common: Box, + action_net: Box, + value_net: Box, + obs_perms: Vec>, + act_perms: Vec>, + action_mode: String, + num_action_factors: usize, + num_actions: usize, + ) -> Self { + let mut out = Self::new(embeddings, common, action_net, value_net, obs_perms, act_perms); + out.action_mode = ActionMode::from_name(&action_mode); + out.num_action_factors = num_action_factors; + out.num_actions = if num_actions > 0 { + num_actions + } else { + out.num_actions + }; + out + } + + fn effective_num_actions(&self) -> usize { + if self.num_actions > 0 { + return self.num_actions; + } + if let Some(first_perm) = self.act_perms.first() { + if !first_perm.is_empty() { + return first_perm.len(); + } + } + if self.action_mode == ActionMode::FactorizedBernoulli && self.num_action_factors > 0 { + return 1usize.checked_shl(self.num_action_factors as u32).unwrap_or(0); + } + 0 + } + + fn expand_factorized_logits(&self, factor_logits: &[f32]) -> Vec { + // Convert per-factor logits into per-action logits by summing logits of active bits. + let num_factors = if self.num_action_factors > 0 { + self.num_action_factors + } else { + factor_logits.len() + }; + if num_factors == 0 || factor_logits.len() < num_factors { + return factor_logits.to_vec(); + } + let num_actions = self.effective_num_actions(); + if num_actions == 0 { + return factor_logits.to_vec(); + } + let mut expanded = vec![0.0f32; num_actions]; + for action in 0..num_actions { + let mut logit = 0.0f32; + for bit in 0..num_factors { + if ((action >> bit) & 1usize) == 1usize { + logit += factor_logits[bit]; + } + } + expanded[action] = logit; + } + expanded } pub fn predict(&self, obs: Vec, masks: Vec) -> (Vec, f32) { @@ -89,11 +184,19 @@ impl Policy { let value = self.value_net.forward(common_out.clone()).sum(); // This only has one element // Forward of the action net - let mut action_logits = self.action_net.forward(common_out).data.as_vec().to_owned(); + let raw_action_logits = self.action_net.forward(common_out).data.as_vec().to_owned(); + let mut action_logits = match self.action_mode { + ActionMode::Categorical => raw_action_logits, + ActionMode::FactorizedBernoulli => self.expand_factorized_logits(&raw_action_logits), + }; // Permute logits according to the corresponding act_perm if let Some(pi) = n_perm { - action_logits = self.act_perms[pi].iter().map(|&v| action_logits[v]).collect(); + if let Some(act_perm) = self.act_perms.get(pi) { + if act_perm.len() == action_logits.len() { + action_logits = act_perm.iter().map(|&v| action_logits[v]).collect(); + } + } } (action_logits, value) @@ -103,7 +206,7 @@ impl Policy { if self.obs_perms.len() == 0 {return self.predict(obs, masks);}; // Forward of the action net for each perm - let mut action_logits = vec![0.0f32; self.act_perms[0].len()]; + let mut action_logits = vec![0.0f32; self.effective_num_actions()]; let mut value = 0.0f32; for pi in 0..self.obs_perms.len() { diff --git a/rust/src/python_interface/policy.rs b/rust/src/python_interface/policy.rs index 46b3abe..192d7c5 100644 --- a/rust/src/python_interface/policy.rs +++ b/rust/src/python_interface/policy.rs @@ -25,8 +25,39 @@ pub struct PyPolicy { #[pymethods] impl PyPolicy { #[new] - pub fn new(embeddings: PyEmbeddingBag, common: PySequential, action_net: PySequential, value_net: PySequential, obs_perms: Vec>, act_perms: Vec>) -> Self { - let policy = Box::new(Policy::new(embeddings.embedding, common.seq, action_net.seq, value_net.seq, obs_perms, act_perms)); + #[pyo3(signature = ( + embeddings, + common, + action_net, + value_net, + obs_perms, + act_perms, + action_mode = "categorical", + num_action_factors = 0, + num_actions = 0 + ))] + pub fn new( + embeddings: PyEmbeddingBag, + common: PySequential, + action_net: PySequential, + value_net: PySequential, + obs_perms: Vec>, + act_perms: Vec>, + action_mode: &str, + num_action_factors: usize, + num_actions: usize, + ) -> Self { + let policy = Box::new(Policy::new_with_action_mode( + embeddings.embedding, + common.seq, + action_net.seq, + value_net.seq, + obs_perms, + act_perms, + action_mode.to_string(), + num_action_factors, + num_actions, + )); PyPolicy { policy } } diff --git a/src/twisterl/defaults.py b/src/twisterl/defaults.py index ebe9a77..1c24dfd 100644 --- a/src/twisterl/defaults.py +++ b/src/twisterl/defaults.py @@ -101,7 +101,17 @@ # Learning -LEARNING_CONFIG = {"diff_threshold": 0.85, "diff_metric": "ppo_deterministic"} +LEARNING_CONFIG = { + "diff_threshold": 0.85, + # Lower hysteresis threshold used to avoid rapid on/off difficulty toggling. + "threshold_min": 0.85, + "diff_max": 256, + "diff_step": 1, + # While difficulty <= warmup, keep +1 increments regardless of diff_step. + "warmup": 0, + "final_diff_is_none": False, + "diff_metric": "ppo_deterministic", +} # Logging and checkpoints diff --git a/src/twisterl/nn/policy.py b/src/twisterl/nn/policy.py index 120a172..ae57b4e 100644 --- a/src/twisterl/nn/policy.py +++ b/src/twisterl/nn/policy.py @@ -33,12 +33,48 @@ def __init__( value_layers=tuple(), obs_perms=tuple(), act_perms=tuple(), + action_mode: str = "categorical", + num_action_factors: Optional[int] = None, device="cuda", ): super().__init__() self.obs_shape = obs_shape self.obs_size = np.prod(obs_shape) self.num_actions = num_actions + self.action_mode = str(action_mode).strip().lower() + if self.action_mode not in ("categorical", "factorized_bernoulli"): + raise ValueError( + f"Unsupported action_mode='{action_mode}'. " + "Expected 'categorical' or 'factorized_bernoulli'." + ) + + if self.action_mode == "factorized_bernoulli": + # Routing-like MultiBinary spaces can be interpreted as 2^N discrete actions. + # Keep the policy head size at N factors and expand to categorical logits at runtime. + inferred = ( + int(num_action_factors) + if num_action_factors is not None + else _infer_num_action_factors(self.num_actions) + ) + if inferred < 1: + raise ValueError( + "num_action_factors must be >= 1 for factorized_bernoulli." + ) + if (1 << inferred) != self.num_actions: + raise ValueError( + "factorized_bernoulli requires num_actions == 2 ** num_action_factors " + f"(got num_actions={self.num_actions}, num_action_factors={inferred})." + ) + self.num_action_factors = inferred + action_out_size = self.num_action_factors + action_index_bits = _build_action_index_bits( + self.num_actions, self.num_action_factors + ) + else: + self.num_action_factors = 0 + action_out_size = self.num_actions + action_index_bits = torch.empty((0, 0), dtype=torch.float32) + self.embeddings = torch.nn.Linear(self.obs_size, embedding_size) self.device = device self._expects_conv_input = False @@ -51,11 +87,12 @@ def __init__( self.common = torch.nn.Sequential() self.action = make_sequential( - in_size, tuple(policy_layers) + (num_actions,), final_relu=False + in_size, tuple(policy_layers) + (action_out_size,), final_relu=False ) self.value = make_sequential( in_size, tuple(value_layers) + (1,), final_relu=False ) + self.register_buffer("_action_index_bits", action_index_bits, persistent=False) self.register_buffer( "_obs_perm_tensor", torch.empty((0, 0), dtype=torch.long), persistent=False ) @@ -117,7 +154,14 @@ def _forward_core(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: x = x.reshape((-1, *self.obs_shape)) common_in = torch.nn.functional.relu(self.embeddings(x)) common = self.common(common_in) - return self.action(common), self.value(common) + action_logits = self.action(common) + if self.action_mode == "factorized_bernoulli": + # Project per-factor logits to full action logits using the precomputed bit matrix. + bits = self._action_index_bits.to( + device=action_logits.device, dtype=action_logits.dtype + ) + action_logits = action_logits @ bits.t() + return action_logits, self.value(common) def _forward_with_indices( self, x: torch.Tensor, perm_indices: torch.Tensor @@ -196,6 +240,9 @@ def to_rust(self): sequential_to_rust(self.value), self.obs_perms, self.act_perms, + self.action_mode, + self.num_action_factors, + self.num_actions, ) @@ -216,6 +263,8 @@ def __init__( value_layers=tuple(), obs_perms=tuple(), act_perms=tuple(), + action_mode: str = "categorical", + num_action_factors: Optional[int] = None, ): super().__init__( obs_shape, @@ -226,6 +275,8 @@ def __init__( value_layers, obs_perms, act_perms, + action_mode=action_mode, + num_action_factors=num_action_factors, ) self.conv_dim = conv_dim self._expects_conv_input = True @@ -263,4 +314,27 @@ def to_rust(self): sequential_to_rust(self.value), self.obs_perms, self.act_perms, + self.action_mode, + self.num_action_factors, + self.num_actions, + ) + + +def _infer_num_action_factors(num_actions: int) -> int: + if num_actions < 1: + return 0 + factors = int(round(np.log2(num_actions))) + if (1 << factors) != num_actions: + raise ValueError( + f"Cannot infer num_action_factors from non-power-of-two num_actions={num_actions}." ) + return factors + + +def _build_action_index_bits(num_actions: int, num_factors: int) -> torch.Tensor: + # Row i encodes the binary representation of action i. + bits = np.zeros((num_actions, num_factors), dtype=np.float32) + for action in range(num_actions): + for bit in range(num_factors): + bits[action, bit] = float((action >> bit) & 1) + return torch.tensor(bits, dtype=torch.float32) diff --git a/src/twisterl/rl/algorithm.py b/src/twisterl/rl/algorithm.py index 292c65d..5304584 100644 --- a/src/twisterl/rl/algorithm.py +++ b/src/twisterl/rl/algorithm.py @@ -141,6 +141,18 @@ def learn_step(self): return times_dict, bench_dict, train_dict def learn(self, num_steps, best_metrics=None): + learning_cfg = self.config.get("learning", {}) + diff_threshold = float(learning_cfg.get("diff_threshold", 1.0)) + threshold_min_raw = learning_cfg.get("threshold_min", diff_threshold) + threshold_min = ( + diff_threshold if threshold_min_raw is None else float(threshold_min_raw) + ) + diff_max = int(learning_cfg.get("diff_max", 1)) + diff_step = max(1, int(learning_cfg.get("diff_step", 1))) + warmup = int(learning_cfg.get("warmup", 0)) + final_diff_is_none = bool(learning_cfg.get("final_diff_is_none", False)) + increasing = False + # Init best metrics with a benchmark if best_metrics is None: (success, reward), _ = self.evaluate( @@ -162,13 +174,33 @@ def learn(self, num_steps, best_metrics=None): best_metrics = current_metrics # Maybe increase difficulty - if ( - bench_dict["success"] >= self.config["learning"]["diff_threshold"] - ) and self.env.difficulty < self.config["learning"]["diff_max"]: - self.env.difficulty += 1 - logger.info( - f"({self.env.difficulty}/{iteration}) Diff increased to {self.env.difficulty}, {current_metrics}" - ) + current_difficulty = self.env.difficulty + if current_difficulty is not None: + last_success = bench_dict["success"] + # Hysteresis: start increasing above diff_threshold and stop only below threshold_min. + if (not increasing) and (last_success >= diff_threshold): + increasing = True + elif increasing and (last_success < threshold_min): + increasing = False + + if increasing and current_difficulty < diff_max: + # Keep +1 increments during warmup; then switch to configured step size. + increment = diff_step if current_difficulty > warmup else 1 + next_difficulty = current_difficulty + increment + if next_difficulty > diff_max: + next_difficulty = None if final_diff_is_none else diff_max + + try: + self.env.difficulty = next_difficulty + except Exception: + # Some environments expose difficulty as integer-only. + if next_difficulty is None: + self.env.difficulty = diff_max + else: + raise + logger.info( + f"({self.env.difficulty}/{iteration}) Diff increased to {self.env.difficulty}, {current_metrics}" + ) # Pring logs if (self.config["logging"]["log_freq"] > 0) and ( diff --git a/src/twisterl/rl/az.py b/src/twisterl/rl/az.py index 4d73c1c..7e967f9 100644 --- a/src/twisterl/rl/az.py +++ b/src/twisterl/rl/az.py @@ -35,7 +35,14 @@ def data_to_torch(self, data): np_obs = np.zeros((len(obs), self.obs_size), dtype=float) for i, obs_i in enumerate(obs): - np_obs[i, obs_i] = 1.0 + if len(obs_i) == 0: + continue + obs_idx = np.asarray(obs_i, dtype=np.int64) + valid = (obs_idx >= 0) & (obs_idx < self.obs_size) + if not np.any(valid): + continue + # Keep sparse multiplicities (duplicates) instead of collapsing to binary. + np.add.at(np_obs[i], obs_idx[valid], 1.0) pt_obs = torch.tensor(np_obs, dtype=torch.float, device=self.config["device"]) pt_probs = torch.tensor(probs, dtype=torch.float, device=self.config["device"]) diff --git a/src/twisterl/rl/ppo.py b/src/twisterl/rl/ppo.py index dd24cbe..7935c8b 100644 --- a/src/twisterl/rl/ppo.py +++ b/src/twisterl/rl/ppo.py @@ -36,7 +36,14 @@ def data_to_torch(self, data): ) np_obs = np.zeros((len(obs), self.obs_size), dtype=float) for i, obs_i in enumerate(obs): - np_obs[i, obs_i] = 1.0 + if len(obs_i) == 0: + continue + obs_idx = np.asarray(obs_i, dtype=np.int64) + valid = (obs_idx >= 0) & (obs_idx < self.obs_size) + if not np.any(valid): + continue + # Keep sparse multiplicities (duplicates) instead of collapsing to binary. + np.add.at(np_obs[i], obs_idx[valid], 1.0) pt_obs = torch.tensor(np_obs, dtype=torch.float, device=self.config["device"]) pt_logits = torch.tensor( diff --git a/src/twisterl/utils.py b/src/twisterl/utils.py index 9bf8d4c..a9c5bc5 100644 --- a/src/twisterl/utils.py +++ b/src/twisterl/utils.py @@ -11,8 +11,10 @@ # that they have been altered from the originals. import importlib +import inspect import json import torch +import numpy as np from huggingface_hub import HfApi, snapshot_download import fnmatch from loguru import logger @@ -191,11 +193,51 @@ def prepare_algorithm(config, run_path=None, load_checkpoint_path=None): # Import policy class and make policy policy_cls = dynamic_import(config["policy_cls"]) + policy_kwargs = dict(config["policy"]) + action_space = getattr(env, "action_space", None) + action_mode = str(policy_kwargs.get("action_mode", "categorical")).strip().lower() + num_action_factors = policy_kwargs.get("num_action_factors", None) + should_auto_factorize = action_mode == "categorical" and ( + num_action_factors is None or int(num_action_factors) <= 0 + ) + if ( + should_auto_factorize + and action_space is not None + and action_space.__class__.__name__ == "MultiBinary" + ): + # Most configs default to categorical; auto-upgrade for MultiBinary envs + # to avoid building an impractically large categorical head. + n_bits = getattr(action_space, "n", None) + if n_bits is None: + shape = getattr(action_space, "shape", None) + n_bits = int(np.prod(shape)) if shape is not None else 0 + else: + n_bits = int(np.prod(np.asarray(n_bits))) + if n_bits > 0: + policy_kwargs["action_mode"] = "factorized_bernoulli" + policy_kwargs["num_action_factors"] = n_bits + + # Backward compatibility: older policy classes may not accept the newest + # config kwargs (e.g., action_mode / num_action_factors). + sig = inspect.signature(policy_cls.__init__) + has_var_kwargs = any( + p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() + ) + if not has_var_kwargs: + accepted = {name for name in sig.parameters if name != "self"} + dropped = [k for k in list(policy_kwargs.keys()) if k not in accepted] + for key in dropped: + policy_kwargs.pop(key, None) + if dropped: + logger.warning( + f"Dropping unsupported policy kwargs for {policy_cls.__name__}: {dropped}" + ) + obs_perms, act_perms = env.twists() policy = policy_cls( env.obs_shape(), env.num_actions(), - **config["policy"], + **policy_kwargs, obs_perms=obs_perms, act_perms=act_perms, )