diff --git a/trainer/distillation.py b/trainer/distillation.py
index f556da7..a76b114 100644
--- a/trainer/distillation.py
+++ b/trainer/distillation.py
@@ -347,14 +347,9 @@ def __init__(self, config):
                 self.name_to_trainable_params[renamed_n] = p
         ema_weight = config.ema_weight
         self.generator_ema = None
-        if (ema_weight is not None) and (ema_weight > 0.0):
-            if self.is_lora_enabled:
-                if self.is_main_process:
-                    print(f"EMA disabled in LoRA mode (LoRA provides efficient parameter updates without EMA)")
-                self.generator_ema = None
-            else:
-                print(f"Setting up EMA with weight {ema_weight}")
-                self.generator_ema = EMA_FSDP(self.model.generator, decay=ema_weight)
+        if (ema_weight is not None) and (ema_weight > 0.0) and self.is_lora_enabled:
+            if self.is_main_process:
+                print(f"EMA disabled in LoRA mode (LoRA provides efficient parameter updates without EMA)")
 
 
         if self.one_logger is not None:
@@ -483,7 +478,17 @@ def __init__(self, config):
             if self.is_main_process:
                 print(f"Loading checkpoint from {checkpoint_path}")
             checkpoint = torch.load(checkpoint_path, map_location="cpu")
-
+
+            if (ema_weight is not None) and (ema_weight > 0.0) and not self.is_lora_enabled and "generator_ema" in checkpoint:
+                if self.is_main_process:
+                    print(f"Loading pretrained EMA from {checkpoint_path}")
+                self.model.generator.load_state_dict(checkpoint["generator_ema"], strict=True)
+                if self.is_main_process:
+                    print(f"Setting up EMA with weight {ema_weight}")
+                self.generator_ema = EMA_FSDP(self.model.generator, decay=ema_weight)
+            elif (ema_weight is not None) and (ema_weight > 0.0) and (not self.is_lora_enabled) and self.is_main_process:
+                print("Warning: EMA checkpoint not found or EMA not initialized.")
+
             # Load generator
             if "generator" in checkpoint:
                 if self.is_main_process:
@@ -496,7 +501,12 @@ def __init__(self, config):
             else:
                 if self.is_main_process:
                     print("Warning: Generator checkpoint not found.")
-
+
+            if (ema_weight is not None) and (ema_weight > 0.0) and (not self.is_lora_enabled) and self.generator_ema is None:
+                if self.is_main_process:
+                    print(f"Setting up EMA with weight {ema_weight}")
+                self.generator_ema = EMA_FSDP(self.model.generator, decay=ema_weight)
+
             # Load critic
             if "critic" in checkpoint:
                 if self.is_main_process:
@@ -505,16 +515,7 @@ def __init__(self, config):
             else:
                 if self.is_main_process:
                     print("Warning: Critic checkpoint not found.")
-
-            # Load EMA
-            if "generator_ema" in checkpoint and self.generator_ema is not None:
-                if self.is_main_process:
-                    print(f"Loading pretrained EMA from {checkpoint_path}")
-                self.generator_ema.load_state_dict(checkpoint["generator_ema"])
-            else:
-                if self.is_main_process:
-                    print("Warning: EMA checkpoint not found or EMA not initialized.")
-
+
             # For auto resume, always resume full training state
             # Load optimizers
             if "generator_optimizer" in checkpoint:
@@ -529,7 +530,7 @@ def __init__(self, config):
             else:
                 if self.is_main_process:
                     print("Warning: Generator optimizer checkpoint not found.")
-
+
             if "critic_optimizer" in checkpoint:
                 if self.is_main_process:
                     print("Resuming critic optimizer...")
@@ -542,7 +543,7 @@ def __init__(self, config):
             else:
                 if self.is_main_process:
                     print("Warning: Critic optimizer checkpoint not found.")
-
+
             # Load training step
             if "step" in checkpoint:
                 self.step = checkpoint["step"]
@@ -551,6 +552,10 @@ def __init__(self, config):
             else:
                 if self.is_main_process:
                     print("Warning: Step not found in checkpoint, starting from step 0.")
+        elif (ema_weight is not None) and (ema_weight > 0.0) and not self.is_lora_enabled:
+            if self.is_main_process:
+                print(f"Setting up EMA with weight {ema_weight}")
+            self.generator_ema = EMA_FSDP(self.model.generator, decay=ema_weight)
 
         if self.one_logger is not None:
             self.one_logger.on_load_checkpoint_end()
@@ -779,7 +784,7 @@ def save(self):
         state_dict = {
             "generator": generator_state_dict,
             "critic": critic_state_dict,
-            "generator_ema": self.generator_ema.state_dict(),
+            "generator_ema": self.generator_ema.full_state_dict(self.model.generator),
             "generator_optimizer": generator_opim_state_dict,
             "critic_optimizer": critic_opim_state_dict,
             "step": self.step,
diff --git a/utils/distributed.py b/utils/distributed.py
index 4367ded..c2e8913 100644
--- a/utils/distributed.py
+++ b/utils/distributed.py
@@ -96,18 +96,14 @@ def __init__(self, fsdp_module: torch.nn.Module, decay: float = 0.999):
 
     @torch.no_grad()
     def _init_shadow(self, fsdp_module):
-        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-        with FSDP.summon_full_params(fsdp_module, writeback=False):
-            for n, p in fsdp_module.module.named_parameters():
-                self.shadow[n] = p.detach().clone().float().cpu()
+        for n, p in fsdp_module.module.named_parameters():
+            self.shadow[n] = p.detach().clone().float().cpu()
 
     @torch.no_grad()
     def update(self, fsdp_module):
         d = self.decay
-        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-        with FSDP.summon_full_params(fsdp_module, writeback=False):
-            for n, p in fsdp_module.module.named_parameters():
-                self.shadow[n].mul_(d).add_(p.detach().float().cpu(), alpha=1. - d)
+        for n, p in fsdp_module.module.named_parameters():
+            self.shadow[n].mul_(d).add_(p.detach().float().cpu(), alpha=1. - d)
 
     # Optional helpers ---------------------------------------------------
     def state_dict(self):
@@ -117,9 +113,22 @@ def load_state_dict(self, sd):
         self.shadow = {k: v.clone() for k, v in sd.items()}
 
     def copy_to(self, fsdp_module):
-        # load EMA weights into an (unwrapped) copy of the generator
-        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-        with FSDP.summon_full_params(fsdp_module, writeback=True):
-            for n, p in fsdp_module.module.named_parameters():
-                if n in self.shadow:
-                    p.data.copy_(self.shadow[n].to(p.dtype, device=p.device))
+        for n, p in fsdp_module.module.named_parameters():
+            if n in self.shadow:
+                p.data.copy_(self.shadow[n].to(dtype=p.dtype, device=p.device))
+
+    @torch.no_grad()
+    def full_state_dict(self, fsdp_module):
+        live_state = {}
+        for n, p in fsdp_module.module.named_parameters():
+            live_state[n] = p.detach().clone()
+        for n, p in fsdp_module.module.named_parameters():
+            if n in self.shadow:
+                p.data.copy_(self.shadow[n].to(dtype=p.dtype, device=p.device))
+
+        checkpoint = fsdp_state_dict(fsdp_module)
+        for n, p in fsdp_module.module.named_parameters():
+            if n in live_state:
+                p.data.copy_(live_state[n].to(dtype=p.dtype, device=p.device))
+
+        return checkpoint