diff --git a/.claude/tsc-cache/d2ffffd0-e71d-465f-b644-2132f4e62ff5/affected-repos.txt b/.claude/tsc-cache/d2ffffd0-e71d-465f-b644-2132f4e62ff5/affected-repos.txt new file mode 100644 index 0000000..d8649da --- /dev/null +++ b/.claude/tsc-cache/d2ffffd0-e71d-465f-b644-2132f4e62ff5/affected-repos.txt @@ -0,0 +1 @@ +root diff --git a/.claude/tsc-cache/d2ffffd0-e71d-465f-b644-2132f4e62ff5/edited-files.log b/.claude/tsc-cache/d2ffffd0-e71d-465f-b644-2132f4e62ff5/edited-files.log new file mode 100644 index 0000000..64e83eb --- /dev/null +++ b/.claude/tsc-cache/d2ffffd0-e71d-465f-b644-2132f4e62ff5/edited-files.log @@ -0,0 +1,6 @@ +1772731018:/scratch3/f007yzf/repos/Step1X-Edit-clean/inference_v2.py:root +1772731024:/scratch3/f007yzf/repos/Step1X-Edit-clean/inference_v2.py:root +1772731034:/scratch3/f007yzf/repos/Step1X-Edit-clean/inference_v2.py:root +1772731040:/scratch3/f007yzf/repos/Step1X-Edit-clean/inference_v2.py:root +1772731844:/scratch3/f007yzf/repos/Step1X-Edit-clean/inference_v2.py:root +1772731852:/scratch3/f007yzf/repos/Step1X-Edit-clean/inference_v2.py:root diff --git a/.gitignore b/.gitignore index cd5feba..b4cb691 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,13 @@ */.DS_store __pycache__ */__pycache__/ -test* \ No newline at end of file +test* +tmp_* +training_data/source_img +training_data/reference_img +training_data/target_img +training_data/reference_img/left +training_data/reference_img/right +training_6k/source_img +training_6k/reference_img +training_6k/target_img \ No newline at end of file diff --git a/examples/reference/0000_I.png b/examples/reference/0000_I.png new file mode 100644 index 0000000..45ce5aa Binary files /dev/null and b/examples/reference/0000_I.png differ diff --git a/examples/source/0000_l.png b/examples/source/0000_l.png new file mode 100644 index 0000000..80fd610 Binary files /dev/null and b/examples/source/0000_l.png differ diff --git a/finetuning.py b/finetuning.py index a41ec4e..37d2293 100644 --- a/finetuning.py +++ 
b/finetuning.py @@ -143,7 +143,7 @@ def get_tokenize_strategy(self, args): """ 获取分词策略。 """ - return strategy_step1x.Step1xEditTokenizeStrategy(tokenizer_cache_dir=args.qwen2p5vl) + return strategy_step1x.Step1xEditTokenizeStrategy(max_length=1280, tokenizer_cache_dir=args.qwen2p5vl) def get_tokenizers(self, tokenize_strategy): return [tokenize_strategy.processor] @@ -157,7 +157,7 @@ def get_latents_caching_strategy(self, args): return latents_caching_strategy def get_text_encoding_strategy(self, args): - return strategy_step1x.Step1XEditEncodingStrategy() + return strategy_step1x.Step1XEditEncodingStrategy(max_length=1280) def post_process_network(self, args, accelerator, network, text_encoders, unet): pass @@ -556,4 +556,4 @@ def setup_parser() -> argparse.ArgumentParser: args = train_util.read_config_from_file(args, parser) trainer = Step1XEditNetworkTrainer() - trainer.train(args) \ No newline at end of file + trainer.train(args) diff --git a/finetuning_v1.py b/finetuning_v1.py new file mode 100644 index 0000000..a41ec4e --- /dev/null +++ b/finetuning_v1.py @@ -0,0 +1,559 @@ +import argparse +import copy +import math +import random +from typing import Any, Optional, Union + +import torch +from accelerate import Accelerator + +from library.device_utils import clean_memory_on_device, init_ipex + +init_ipex() + +from library import kohya_trainer +from library import ( + step1x_edit_train_utils, + step1x_utils, + strategy_step1x, + train_util, +) +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class Step1XEditNetworkTrainer(kohya_trainer.NetworkTrainer): + def __init__(self): + """ + 初始化 Step1XEditNetworkTrainer 类。 + """ + super().__init__() + self.sample_prompts_te_outputs = None + self.is_schnell: Optional[bool] = None + self.is_swapping_blocks: bool = False + + def assert_extra_args( + self, + args, + train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], + 
val_dataset_group: Optional[train_util.DatasetGroup], + ): + """ + 断言额外的参数是否有效。 + + Args: + args: 命令行参数。 + train_dataset_group: 训练数据集组。 + val_dataset_group: 验证数据集组。 + """ + super().assert_extra_args(args, train_dataset_group, val_dataset_group) + + if args.fp8_base_unet: + args.fp8_base = True # 如果启用了 fp8_base_unet,则 fp8_base 也为 base model 启用 + + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_disk 已启用,因此 cache_text_encoder_outputs 也将启用" + ) + args.cache_text_encoder_outputs = True + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / 缓存文本编码器输出时,caption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rate 不能使用" + + if args.max_token_length is not None: + logger.warning("max_token_length is not used in base model training / max_token_length 在基模训练中未使用") + + assert ( + args.blocks_to_swap is None or args.blocks_to_swap == 0 + ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swap 与 cpu_offload_checkpointing 不兼容" + + train_dataset_group.verify_bucket_reso_steps(32) # TODO 检查这个 + if val_dataset_group is not None: + val_dataset_group.verify_bucket_reso_steps(32) # TODO 检查这个 + + def load_target_model(self, args, weight_dtype, accelerator): + """ + 加载目标模型(base模型、文本编码器、AE)。 + + Args: + args: 命令行参数。 + weight_dtype: 权重的数据类型。 + accelerator: Accelerator 对象。 + + Returns: + Tuple: 包含模型版本、文本编码器列表、AE 模型和 base 模型的元组。 + """ + # 当前将某些模型卸载到 CPU + + # 如果文件是 fp8 并且我们正在使用 fp8_base,我们可以按原样加载它 (fp8) + loading_dtype = None if args.fp8_base else weight_dtype + + model = step1x_utils.load_models( + 
dit_path=args.pretrained_model_name_or_path, + device='cpu', + dtype=loading_dtype + ) + if args.fp8_base: + # 检查模型的 dtype + if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}") + elif model.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 model") + else: + logger.info( + "Cast model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint." + " / 正在将模型转换为 fp8。这可能需要一些时间。您可以使用 fp8 检查点来缩短时间。" + ) + model.to(torch.float8_e4m3fn) + + self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if self.is_swapping_blocks: + # 在前向和后向传递中,在 CPU 和 GPU 之间交换块以减少内存使用。 + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + model.enable_block_swap(args.blocks_to_swap, accelerator.device) + + # 如果文件是 fp8 并且我们正在使用 fp8_base (而不是 unet),我们可以按原样加载它 (fp8) + if args.fp8_base and not args.fp8_base_unet: + loading_dtype = None # 按原样 + else: + loading_dtype = weight_dtype + + qwen2p5vl = step1x_utils.load_qwen2p5vl( + args.qwen2p5vl, dtype=weight_dtype, device="cpu" + ) + qwen2p5vl.eval() + if args.fp8_base and not args.fp8_base_unet: + # 检查模型的 dtype + if qwen2p5vl.dtype == torch.float8_e4m3fnuz or qwen2p5vl.dtype == torch.float8_e5m2 or qwen2p5vl.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {qwen2p5vl.dtype}") + elif qwen2p5vl.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 qwen2p5vl model") + + ae = step1x_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + + return "Step1X-Edit", [qwen2p5vl], ae, model + + def get_tokenize_strategy(self, args): + """ + 获取分词策略。 + """ + return strategy_step1x.Step1xEditTokenizeStrategy(tokenizer_cache_dir=args.qwen2p5vl) + + def get_tokenizers(self, tokenize_strategy): + return [tokenize_strategy.processor] + + def get_latents_caching_strategy(self, args): + 
latents_caching_strategy = strategy_step1x.Step1XEditLatentsCachingStrategy( + cache_to_disk=args.cache_latents_to_disk, + batch_size=args.vae_batch_size, + skip_disk_cache_validity_check=False, + ) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_step1x.Step1XEditEncodingStrategy() + + def post_process_network(self, args, accelerator, network, text_encoders, unet): + pass + + def get_models_for_text_encoding(self, args, accelerator, text_encoders): + """ + 获取用于文本编码的模型。 + + Args: + args: 命令行参数。 + accelerator: Accelerator 对象。 + text_encoders: 文本编码器列表。 + + Returns: + Optional[List[torch.nn.Module]]: 用于文本编码的模型列表,如果不需要则返回 None。 + """ + if args.cache_text_encoder_outputs: + return None # 不需要文本编码器进行编码,因为两者都已缓存 + else: + return text_encoders + + def get_text_encoders_train_flags(self, args, text_encoders): + return [False] + + def get_text_encoder_outputs_caching_strategy(self, args): + if args.cache_text_encoder_outputs: + return strategy_step1x.Step1xEditEncoderOutputsCachingStrategy( + cache_to_disk=args.cache_text_encoder_outputs_to_disk, + batch_size=args.text_encoder_batch_size, + skip_disk_cache_validity_check=args.skip_cache_check, + is_partial=False, + ) + else: + return None + + def cache_text_encoder_outputs_if_needed( + self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype + ): + """ + 如果需要,缓存文本编码器的输出。 + + Args: + args: 命令行参数。 + accelerator: Accelerator 对象。 + unet: U-Net 模型。 + vae: VAE 模型。 + text_encoders: 文本编码器列表。 + dataset: 数据集组。 + weight_dtype: 权重的数据类型。 + """ + if args.cache_text_encoder_outputs: + if not args.lowram: + # 减少内存消耗 + logger.info("move vae and unet to cpu to save memory") + org_vae_device = vae.device + org_unet_device = unet.device + vae.to("cpu") + unet.to("cpu") + clean_memory_on_device(accelerator.device) + + # 当 TE 未被训练时,它不会被准备,所以我们需要使用显式的 autocast + logger.info("move text encoders to gpu") + # text_encoders[0].to(accelerator.device, 
dtype=weight_dtype) # 始终不是 fp8 + [text_encoder.to(accelerator.device) for text_encoder in text_encoders] + + if text_encoders[0].dtype == torch.float8_e4m3fn: + # 如果我们加载 fp8 权重,模型已经是 fp8,所以我们按原样使用它 + self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype) + else: + # 否则,我们需要将其转换为目标 dtype + text_encoders[0].to(weight_dtype) + + with accelerator.autocast(): + dataset.new_cache_text_encoder_outputs(text_encoders, accelerator) + + # 缓存样本提示 + if args.sample_prompts is not None: + raise ValueError('not converted') + + accelerator.wait_for_everyone() + + # 移回 CPU + if not self.is_train_text_encoder(args): + text_encoders[0].to("cpu") + # text_encoders[0].to("cpu") + clean_memory_on_device(accelerator.device) + + if not args.lowram: + logger.info("move vae and unet back to original device") + vae.to(org_vae_device) + unet.to(org_unet_device) + else: + # 每次都从文本编码器获取输出,因此将其放在 GPU 上 + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + # text_encoders[1].to(accelerator.device) + + def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: + """ + 获取噪声调度器。 + + Args: + args: 命令行参数。 + device: 设备 (CPU 或 GPU)。 + + Returns: + Any: 噪声调度器对象。 + """ + noise_scheduler = step1x_edit_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) + self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) + return noise_scheduler + + def encode_images_to_latents(self, args, vae, images): + """ + 将图像编码为潜变量。 + + Args: + args: 命令行参数。 + vae: VAE 模型。 + images: 图像张量。 + + Returns: + torch.Tensor: 潜变量张量。 + """ + import pdb;pdb.set_trace() + return vae.encode(images) + + def shift_scale_latents(self, args, latents): + """ + 对潜变量进行移位和缩放。 + + Args: + args: 命令行参数。 + latents: 潜变量张量。 + + Returns: + torch.Tensor: 经过移位和缩放的潜变量张量。 + """ + return latents + + def get_noise_pred_and_target( + self, + args, + accelerator, + noise_scheduler, + latents, + ref_latents, + batch, + text_encoder_conds, + 
unet, + network, + weight_dtype, + train_unet, + is_train=True, + ): + """ + 获取噪声预测和目标。 + 这里之所以有了batch还要有latents和text_encoder_conds是因为有可能没有采取cache策略 + 这部分的处理是在外部完成的 + + Args: + args: 命令行参数。 + accelerator: Accelerator 对象。 + noise_scheduler: 噪声调度器。 + latents: 潜变量。 + batch: 当前批次的数据。 + text_encoder_conds: 文本编码器的条件。 + unet: 基模 + network: 训练的网络。 + weight_dtype: 权重的数据类型。 + train_unet: 是否训练 U-Net。 + is_train: 是否处于训练模式。 + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + 模型预测、目标、时间步长和权重 (如果适用)。 + """ + # 采样我们将添加到潜变量中的噪声 + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # 获取带噪模型输入和时间步长 + noisy_model_input, timesteps, sigmas = step1x_edit_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, accelerator.device, weight_dtype + ) + + # 打包潜变量并获取 img_ids + packed_noisy_model_input = step1x_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 + packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 + img_ids = step1x_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) + + # 处理ref_latents + packed_ref_model_input = step1x_utils.pack_latents(ref_latents) + + # concate latents + packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_ref_model_input], dim=1) + img_ids = torch.cat([img_ids, img_ids], dim=1) + + # 确保隐藏状态需要梯度 + if args.gradient_checkpointing: + noisy_model_input.requires_grad_(True) + for t in text_encoder_conds: + if t is not None and t.dtype.is_floating_point: + t.requires_grad_(True) + img_ids.requires_grad_(True) + + # 预测噪声残差 + embeds, masks = text_encoder_conds + masks = masks.to(torch.long) + txt_ids = torch.zeros(bsz, embeds.shape[1], 3).to(packed_noisy_model_input.device) + with torch.set_grad_enabled(is_train), torch.autocast(device_type='cuda', dtype=torch.bfloat16): + # 原版有个标记,YiYi 注意:暂时将其除以 1000,因为我们在 Transformer 模型中将其缩放了 
1000(我们不应该保留它,但我想保持模型的输入相同以进行测试) + packed_noisy_model_input = packed_noisy_model_input.to(weight_dtype) + masks = masks.to(device=accelerator.device) + model_pred = unet( + img=packed_noisy_model_input, + img_ids=img_ids, + txt_ids=txt_ids, + timesteps=timesteps / 1000, + llm_embedding=embeds, + t_vec=timesteps, + mask=masks, + ) + + def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor: + """ + x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2 + """ + import einops + x = einops.rearrange(x, "b (p h w) (c ph pw) -> b p c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2, p=2) + return x[:, 0] + + # 解包潜变量 + model_pred = unpack_latents(model_pred, packed_latent_height, packed_latent_width) + + weighting = None + + # 流匹配损失:这与 SD3 不同 + target = noise - latents + + return model_pred, target, timesteps, weighting + + def post_process_loss(self, loss, args, timesteps, noise_scheduler): + """ + 后处理损失。 + + Args: + loss: 计算得到的损失。 + args: 命令行参数。 + timesteps: 时间步长。 + noise_scheduler: 噪声调度器。 + + Returns: + torch.Tensor: 后处理后的损失。 + """ + return loss + + def get_sai_model_spec(self, args): + """ + 获取 SAI 模型规范。 + + Args: + args: 命令行参数。 + + Returns: + Dict: SAI 模型规范字典。 + """ + return { + # === Must === + "modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec + "modelspec.architecture": "Step1X-Edit", + "modelspec.implementation": "https://github.com/stepfun-ai/Step1X-Edit", + "modelspec.title": "Lora", + "modelspec.resolution": "1024", + # === Should === + "modelspec.description": "Lora for Step1X-Edit", + "modelspec.author": "Step1X-Edit Team", + "modelspec.date": "2025", + } + + def update_metadata(self, metadata, args): + """ + 更新元数据。 + + Args: + metadata: 要更新的元数据字典。 + args: 命令行参数。 + """ + metadata["ss_weighting_scheme"] = args.weighting_scheme + metadata["ss_logit_mean"] = args.logit_mean + metadata["ss_logit_std"] = args.logit_std + metadata["ss_mode_scale"] = args.mode_scale + 
metadata["ss_guidance_scale"] = args.guidance_scale + metadata["ss_timestep_sampling"] = args.timestep_sampling + metadata["ss_sigmoid_scale"] = args.sigmoid_scale + metadata["ss_model_prediction_type"] = args.model_prediction_type + metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift + + def is_text_encoder_not_needed_for_training(self, args): + """ + 判断在训练过程中是否不需要文本编码器。 + + Args: + args: 命令行参数。 + + Returns: + bool: 如果不需要文本编码器则返回 True,否则返回 False。 + """ + return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args) + + def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): + """ + 为文本编码器的梯度检查点准备解决方法。 + """ + if index == 0: + text_encoder.model.model.embed_tokens.encoder.embed_tokens.requires_grad_(True) + + def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): + """ + 为 fp8 准备文本编码器。 + + Args: + index: 文本编码器的索引 + text_encoder: 文本编码器模型。 + te_weight_dtype: 文本编码器的权重数据类型。 + weight_dtype: 整体权重的数据类型。 + """ + raise ValueError('qwen still not tested for fp8') + if step1x_utils.get_qwen_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype: + logger.info(f"Qwen already prepared for fp8") + else: + logger.info(f"prepare Qwen for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks") + text_encoder.to(te_weight_dtype) # fp8 + + def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + """ + 在验证步骤结束时调用。 + + Args: + args: 命令行参数。 + accelerator: Accelerator 对象。 + network: 网络模型。 + text_encoders: 文本编码器列表。 + unet: U-Net 模型。 + batch: 当前批次的数据。 + weight_dtype: 权重的数据类型。 + """ + if self.is_swapping_blocks: + # 为下一次前向传播做准备:因为没有调用后向传播,所以我们需要在这里准备 + accelerator.unwrap_model(unet).prepare_block_swap_before_forward() + + def prepare_unet_with_accelerator( + self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module + ) -> torch.nn.Module: + """ + 使用 Accelerator 准备 U-Net 模型。 + + Args: + 
args: 命令行参数。 + accelerator: Accelerator 对象。 + unet: U-Net 模型。 + + Returns: + torch.nn.Module: 准备好的 U-Net 模型。 + """ + if not self.is_swapping_blocks: + return super().prepare_unet_with_accelerator(args, accelerator, unet) + + # 如果我们不交换块,我们可以将模型移动到设备 + new_unet = unet + new_unet = accelerator.prepare(new_unet, device_placement=[not self.is_swapping_blocks]) + accelerator.unwrap_model(new_unet).move_to_device_except_swap_blocks(accelerator.device) # 减少峰值内存使用 + accelerator.unwrap_model(new_unet).prepare_block_swap_before_forward() + + return new_unet + + +def setup_parser() -> argparse.ArgumentParser: + """ + 设置命令行参数解析器。 + + Returns: + argparse.ArgumentParser: 参数解析器对象。 + """ + parser = kohya_trainer.setup_parser() + train_util.add_dit_training_arguments(parser) + step1x_edit_train_utils.add_step1x_edit_train_arguments(parser) + parser.add_argument('--qwen2p5vl', type=str, help='Path to Qwen2.5VL model / Qwen2.5VL模型的路径') + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + trainer = Step1XEditNetworkTrainer() + trainer.train(args) \ No newline at end of file diff --git a/inference_v1.py b/inference_v1.py new file mode 100644 index 0000000..cdc9a59 --- /dev/null +++ b/inference_v1.py @@ -0,0 +1,760 @@ +import argparse +import datetime +import itertools +import math +import os +import time +import functools +from pathlib import Path + +import numpy as np +import torch +from einops import rearrange, repeat +from PIL import Image, ImageOps +from safetensors.torch import load_file +from torchvision.transforms import functional as F +from tqdm import tqdm + +import sampling +from modules.autoencoder import AutoEncoder +from modules.conditioner import Qwen25VL_7b_Embedder as Qwen2VLEmbedder +from modules.model_edit import Step1XParams, Step1XEdit +from modules.multigpu import parallel_transformer, 
teacache_transformer, parallel_teacache_transformer + +from torch import Tensor +import torch.distributed as dist +from xfuser.core.distributed import ( + get_world_group, + initialize_model_parallel, +) +from qwen_vl_utils import process_vision_info + +DUAL_IMAGE_CAPTION_PROMPT = """You are analyzing facial expressions for a controlled editing task. +Given: +- Image 1: source face to be edited +- Image 2: target expression reference + +Output a structured editing plan +""" + +def cfg_usp_level_setting(ring_degree: int = 1, ulysses_degree: int = 1, cfg_degree: int = 1): + # restriction: dist.get_world_size() == x x + initialize_model_parallel( + ring_degree=ring_degree, + ulysses_degree=ulysses_degree, + classifier_free_guidance_degree=cfg_degree, + ) + +def teacache_init(pipe, args): + pipe.dit.__class__.enable_teacache = True + pipe.dit.__class__.cnt = 0 + pipe.dit.__class__.num_steps = args.num_steps + pipe.dit.__class__.rel_l1_thresh = args.teacache_threshold + pipe.dit.__class__.accumulated_rel_l1_distance = 0 + pipe.dit.__class__.previous_modulated_input = None + pipe.dit.__class__.previous_residual = None + + +def cudagc(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + +def load_state_dict(model, ckpt_path, device="cuda", strict=False, assign=True): + if Path(ckpt_path).suffix == ".safetensors": + state_dict = load_file(ckpt_path, device) + else: + state_dict = torch.load(ckpt_path, map_location="cpu") + + missing, unexpected = model.load_state_dict( + state_dict, strict=strict, assign=assign + ) + if len(missing) > 0 and len(unexpected) > 0: + print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + print("\n" + "-" * 79 + "\n") + print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + elif len(missing) > 0: + print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + elif len(unexpected) > 0: + print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + return model + + 
+def load_models( + dit_path=None, + ae_path=None, + qwen2vl_model_path=None, + mode="flash", + device="cuda", + max_length=256, + dtype=torch.bfloat16, + version='v1.0' +): + qwen2vl_encoder = Qwen2VLEmbedder( + qwen2vl_model_path, + device=device, + max_length=max_length, + dtype=dtype, + ) + + with torch.device("meta"): + ae = AutoEncoder( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ) + + step1x_params = Step1XParams( + in_channels=64, + out_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + mode=mode, + version=version, + ) + dit = Step1XEdit(step1x_params) + + ae = load_state_dict(ae, ae_path, 'cpu') + dit = load_state_dict( + dit, dit_path, 'cpu' + ) + + ae = ae.to(dtype=torch.float32) + + return ae, dit, qwen2vl_encoder + +def equip_dit_with_lora_sd_scripts(ae, text_encoders, dit, lora, device='cuda'): + from safetensors.torch import load_file + weights_sd = load_file(lora) + is_lora = True + from library import lora_module + module = lora_module + lora_model, _ = module.create_network_from_weights(1.0, None, ae, text_encoders, dit, weights_sd, True) + lora_model.merge_to(text_encoders, dit, weights_sd) + + lora_model.set_multiplier(1.0) + return lora_model + +class ImageGenerator: + def __init__( + self, + dit_path=None, + ae_path=None, + qwen2vl_model_path=None, + device="cuda", + max_length=1280, + dtype=torch.bfloat16, + quantized=False, + offload=False, + lora=None, + mode="flash", + version='v1.0' + ) -> None: + self.version = version + if os.getenv("TORCHELASTIC_RUN_ID") is not None: + local_rank = get_world_group().local_rank + torch.cuda.set_device(local_rank) + self.device = torch.device(f"cuda:{local_rank}") + else: + self.device = torch.device(device) + + self.ae, self.dit, 
self.llm_encoder = load_models( + dit_path=dit_path, + ae_path=ae_path, + qwen2vl_model_path=qwen2vl_model_path, + max_length=max_length, + dtype=dtype, + device=self.device, + mode=mode, + version=version, + ) + + if not quantized: + self.dit = self.dit.to(dtype=torch.bfloat16) + else: + self.dit = self.dit.to(dtype=torch.float8_e4m3fn) + if not offload: + self.dit = self.dit.to(device=self.device) + self.ae = self.ae.to(device=self.device) + self.quantized = quantized + self.offload = offload + if lora is not None: + self.lora_module = equip_dit_with_lora_sd_scripts( + self.ae, + [self.llm_encoder], + self.dit, + lora, + device=self.dit.device, + ) + else: + self.lora_module = None + self.mode = mode + + + def prepare(self, prompt, img, ref_image, ref_image_raw): + bs, _, h, w = img.shape + bs, _, ref_h, ref_w = ref_image.shape + + assert h == ref_h and w == ref_w + + if bs == 1 and not isinstance(prompt, str): + bs = len(prompt) + elif bs >= 1 and isinstance(prompt, str): + prompt = [prompt] * bs + + img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + ref_img = rearrange(ref_image, "b c (ref_h ph) (ref_w pw) -> b (ref_h ref_w) (c ph pw)", ph=2, pw=2) + if img.shape[0] == 1 and bs > 1: + img = repeat(img, "1 ... -> bs ...", bs=bs) + ref_img = repeat(ref_img, "1 ... 
-> bs ...", bs=bs) + + img_ids = torch.zeros(h // 2, w // 2, 3) + + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + if self.version == 'v1.0': + ref_img_ids = torch.zeros(ref_h // 2, ref_w // 2, 3) + else: + ref_img_ids = torch.ones(ref_h // 2, ref_w // 2, 3) + + ref_img_ids[..., 1] = ref_img_ids[..., 1] + torch.arange(ref_h // 2)[:, None] + ref_img_ids[..., 2] = ref_img_ids[..., 2] + torch.arange(ref_w // 2)[None, :] + ref_img_ids = repeat(ref_img_ids, "ref_h ref_w c -> b (ref_h ref_w) c", b=bs) + + if isinstance(prompt, str): + prompt = [prompt] + if self.offload: + self.llm_encoder = self.llm_encoder.to(self.device) + txt, mask = self.llm_encoder(prompt, ref_image_raw) + if self.offload: + self.llm_encoder = self.llm_encoder.cpu() + cudagc() + + txt_ids = torch.zeros(bs, txt.shape[1], 3) + + img = torch.cat([img, ref_img.to(device=img.device, dtype=img.dtype)], dim=-2) + img_ids = torch.cat([img_ids, ref_img_ids], dim=-2) + + + return { + "img": img, + "mask": mask, + "img_ids": img_ids.to(img.device), + "llm_embedding": txt.to(img.device), + "txt_ids": txt_ids.to(img.device), + } + + + def prepare_t2i(self, prompt, img, ref_image_raw): + bs, _, h, w = img.shape + + if bs == 1 and not isinstance(prompt, str): + bs = len(prompt) + elif bs >= 1 and isinstance(prompt, str): + prompt = [prompt] * bs + + img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + + if img.shape[0] == 1 and bs > 1: + img = repeat(img, "1 ... 
-> bs ...", bs=bs) + + + img_ids = torch.zeros(h // 2, w // 2, 3) + + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + + if isinstance(prompt, str): + prompt = [prompt] + if self.offload: + self.llm_encoder = self.llm_encoder.to(self.device) + txt, mask = self.llm_encoder(prompt, ref_image_raw) + if self.offload: + self.llm_encoder = self.llm_encoder.cpu() + cudagc() + + txt_ids = torch.zeros(bs, txt.shape[1], 3) + + + return { + "img": img, + "mask": mask, + "img_ids": img_ids.to(img.device), + "llm_embedding": txt.to(img.device), + "txt_ids": txt_ids.to(img.device), + } + @staticmethod + def process_diff_norm(diff_norm, k): + pow_result = torch.pow(diff_norm, k) + + result = torch.where( + diff_norm > 1.0, + pow_result, + torch.where(diff_norm < 1.0, torch.ones_like(diff_norm), diff_norm), + ) + return result + + def denoise_t2i( + self, + img: torch.Tensor, + img_ids: torch.Tensor, + llm_embedding: torch.Tensor, + txt_ids: torch.Tensor, + timesteps: list[float], + cfg_guidance: float = 4.5, + mask=None, + show_progress=False, + timesteps_truncate=0.93, + ): + if self.offload: + self.dit = self.dit.to(self.device) + if show_progress: + pbar = tqdm(itertools.pairwise(timesteps), desc='denoising...') + else: + pbar = itertools.pairwise(timesteps) + for idx, (t_curr, t_prev) in enumerate(pbar): + if img.shape[0] == 1 and cfg_guidance != -1: + img = torch.cat([img, img], dim=0) + t_vec = torch.full( + (img.shape[0],), t_curr, dtype=img.dtype, device=img.device + ) + pred = self.dit( + img=img, + img_ids=img_ids, + txt_ids=txt_ids, + timesteps=t_vec, + llm_embedding=llm_embedding, + t_vec=t_vec, + mask=mask, + ) + + if cfg_guidance != -1: + cond, uncond = ( + pred[0 : pred.shape[0] // 2, :], + pred[pred.shape[0] // 2 :, :], + ) + if t_curr > timesteps_truncate: + diff = cond - uncond + diff_norm = torch.norm(diff, dim=(2), 
keepdim=True) + pred = uncond + cfg_guidance * ( + cond - uncond + ) / self.process_diff_norm(diff_norm, k=0.4) + else: + pred = uncond + cfg_guidance * (cond - uncond) + img = img[0 : img.shape[0] // 2] + (t_prev - t_curr) * pred + if self.offload: + self.dit = self.dit.cpu() + cudagc() + + return img + + def denoise( + self, + img: torch.Tensor, + img_ids: torch.Tensor, + llm_embedding: torch.Tensor, + txt_ids: torch.Tensor, + timesteps: list[float], + cfg_guidance: float = 4.5, + mask=None, + show_progress=False, + timesteps_truncate=0.93, + ): + ref_img_tensor = img[0, img.shape[1] // 2:].clone() + if self.offload: + self.dit = self.dit.to(self.device) + if show_progress: + pbar = tqdm(itertools.pairwise(timesteps), desc='denoising...') + else: + pbar = itertools.pairwise(timesteps) + for idx, (t_curr, t_prev) in enumerate(pbar): + if img.shape[0] == 1 and cfg_guidance != -1: + img = torch.cat([img, img], dim=0) + t_vec = torch.full( + (img.shape[0],), t_curr, dtype=img.dtype, device=img.device + ) + pred = self.dit( + img=img, + img_ids=img_ids, + txt_ids=txt_ids, + timesteps=t_vec, + llm_embedding=llm_embedding, + t_vec=t_vec, + mask=mask, + ) + pred = pred[:, :pred.shape[1] // 2] + + if cfg_guidance != -1: + cond, uncond = ( + pred[0 : pred.shape[0] // 2, :], + pred[pred.shape[0] // 2 :, :], + ) + if t_curr > timesteps_truncate: + diff = cond - uncond + diff_norm = torch.norm(diff, dim=(2), keepdim=True) + pred = uncond + cfg_guidance * ( + cond - uncond + ) / self.process_diff_norm(diff_norm, k=0.4) + else: + pred = uncond + cfg_guidance * (cond - uncond) + tem_img = img[0 : img.shape[0] // 2, : img.shape[1] // 2] + (t_prev - t_curr) * pred + img = torch.cat( + [ + tem_img, + ref_img_tensor.unsqueeze(0), + ], dim=1 + ) + if self.offload: + self.dit = self.dit.cpu() + cudagc() + + return img[:, :img.shape[1] // 2] + + @staticmethod + def unpack(x: torch.Tensor, height: int, width: int) -> torch.Tensor: + return rearrange( + x, + "b (h w) (c ph pw) -> b c (h 
ph) (w pw)", + h=math.ceil(height / 16), + w=math.ceil(width / 16), + ph=2, + pw=2, + ) + + @staticmethod + def load_image(image): + from PIL import Image + + if isinstance(image, np.ndarray): + image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0 + image = image.unsqueeze(0) + return image + elif isinstance(image, Image.Image): + image = F.to_tensor(image.convert("RGB")) + image = image.unsqueeze(0) + return image + elif isinstance(image, torch.Tensor): + return image + elif isinstance(image, str): + image = F.to_tensor(Image.open(image).convert("RGB")) + image = image.unsqueeze(0) + return image + else: + raise ValueError(f"Unsupported image type: {type(image)}") + + def output_process_image(self, resize_img, image_size): + res_image = resize_img.resize(image_size) + return res_image + return resize_img + + def input_process_image(self, img, img_size=512): + # 1. 打开图片 + w, h = img.size + r = w / h + + if w > h: + w_new = math.ceil(math.sqrt(img_size * img_size * r)) + h_new = math.ceil(w_new / r) + else: + h_new = math.ceil(math.sqrt(img_size * img_size / r)) + w_new = math.ceil(h_new * r) + h_new = h_new // 16 * 16 + w_new = w_new // 16 * 16 + + img_resized = img.resize((w_new, h_new), Image.LANCZOS) + return img_resized, img.size + + def build_caption_from_dual_images(self, source_image, reference_image, user_prompt="", max_new_tokens=150): + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": DUAL_IMAGE_CAPTION_PROMPT}, + {"type": "text", "text": "\n[Source Image (Structure/Identity)]:"}, + {"type": "image", "image": source_image}, + {"type": "text", "text": "\n[Reference Image (Style/Expression)]:"}, + {"type": "image", "image": reference_image}, + {"type": "text", "text": f"\nUser prompt: {user_prompt}"}, + {"type": "text", "text": "\nPlease generate the structured editing plan now."}, + ], + } + ] + + processor = self.llm_encoder.processor + model = self.llm_encoder.model + + text = processor.apply_chat_template( + 
messages, tokenize=False, add_generation_prompt=True, add_vision_id=True + ) + image_inputs, _ = process_vision_info(messages) + inputs = processor( + text=[text], + images=image_inputs, + padding=True, + return_tensors="pt", + ).to(self.device) + + with torch.no_grad(): + generated_ids = model.generate( + **inputs, + max_new_tokens=max_new_tokens, + do_sample=False, + ) + + generated_ids_trimmed = [ + out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) + ] + generated_text = processor.batch_decode( + generated_ids_trimmed, skip_special_tokens=True + )[0].strip() + return generated_text if generated_text else user_prompt + + @torch.inference_mode() + def generate_image( + self, + prompt, + negative_prompt, + ref_images, + num_steps, + cfg_guidance, + seed, + caption_ref_images=None, + num_samples=1, + init_image=None, + image2image_strength=0.0, + show_progress=False, + size_level=512, + height=None, + width=None, + caption_max_new_tokens=150, + ): + assert num_samples == 1, "num_samples > 1 is not supported yet." 
+ if ref_images == None: + self.task_type='t2i' + ref_images = Image.new('RGB', (1024, 1024)) + ref_images_raw = ref_images + img_info = (width, height) if width is not None and height is not None else (1024, 1024) + else: + self.task_type = 'edit' + ref_images_raw, img_info = self.input_process_image(ref_images, img_size=size_level) + if caption_ref_images is not None: + caption_ref_images_raw, _ = self.input_process_image(caption_ref_images, img_size=size_level) + if caption_ref_images_raw.size != ref_images_raw.size: + caption_ref_images_raw = caption_ref_images_raw.resize(ref_images_raw.size, Image.LANCZOS) + if self.offload: + self.llm_encoder = self.llm_encoder.to(self.device) + prompt = self.build_caption_from_dual_images( + source_image=ref_images_raw, + reference_image=caption_ref_images_raw, + user_prompt=prompt, + max_new_tokens=caption_max_new_tokens, + ) + if os.getenv("TORCHELASTIC_RUN_ID") is None or dist.get_rank() == 0: + print(f"[Dual-image generated prompt] {prompt}") + + if self.task_type == 'edit': + width, height = ref_images_raw.width, ref_images_raw.height + + ref_images_raw = self.load_image(ref_images_raw) + ref_images_raw = ref_images_raw.to(self.device) + if self.offload: + self.ae = self.ae.to(self.device) + with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16): + ref_images = self.ae.encode(ref_images_raw.to(self.device) * 2 - 1) + if self.offload: + self.ae = self.ae.cpu() + cudagc() + else: + width, height = img_info + ref_images_raw = self.load_image(ref_images_raw) + ref_images_raw = ref_images_raw.to(self.device) + ref_images = None + + seed = int(seed) + seed = torch.Generator(device="cpu").seed() if seed < 0 else seed + + t0 = time.perf_counter() + + if init_image is not None: + init_image = self.load_image(init_image) + init_image = init_image.to(self.device) + init_image = torch.nn.functional.interpolate(init_image, (height, width)) + if self.offload: + self.ae = self.ae.to(self.device) + init_image = 
self.ae.encode(init_image.to() * 2 - 1) + if self.offload: + self.ae = self.ae.cpu() + cudagc() + + x = torch.randn( + num_samples, + 16, + height // 8, + width // 8, + device=self.device, + dtype=torch.bfloat16, + generator=torch.Generator(device=self.device).manual_seed(seed), + ) + timesteps = sampling.get_schedule( + num_steps, x.shape[-1] * x.shape[-2] // 4, shift=True + ) + + if init_image is not None: + t_idx = int((1 - image2image_strength) * num_steps) + t = timesteps[t_idx] + timesteps = timesteps[t_idx:] + x = t * x + (1.0 - t) * init_image.to(x.dtype) + + x = torch.cat([x, x], dim=0) + if self.task_type == 'edit': + ref_images = torch.cat([ref_images, ref_images], dim=0) + ref_images_raw = torch.cat([ref_images_raw, ref_images_raw], dim=0) + inputs = self.prepare([prompt, negative_prompt], x, ref_image=ref_images, ref_image_raw=ref_images_raw) + else: + ref_images_raw = torch.cat([ref_images_raw, ref_images_raw], dim=0) + inputs = self.prepare_t2i([prompt, negative_prompt], x, ref_images_raw) + + + + with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16): + if self.task_type == 'edit': + x = self.denoise( + **inputs, + cfg_guidance=cfg_guidance, + timesteps=timesteps, + show_progress=show_progress, + timesteps_truncate=0.93, + ) + else: + x = self.denoise_t2i( + **inputs, + cfg_guidance=cfg_guidance, + timesteps=timesteps, + show_progress=show_progress, + timesteps_truncate=0.93, + ) + x = self.unpack(x.float(), height, width) + if self.offload: + self.ae = self.ae.to(self.device) + x = self.ae.decode(x) + if self.offload: + self.ae = self.ae.cpu() + cudagc() + x = x.clamp(-1, 1) + x = x.mul(0.5).add(0.5) + + t1 = time.perf_counter() + if os.getenv("TORCHELASTIC_RUN_ID") is None or dist.get_rank() == 0: + print(f"Done in {t1 - t0:.1f}s.") + images_list = [] + for img in x.float(): + images_list.append(self.output_process_image(F.to_pil_image(img), img_info)) + return images_list + + +def main(): + torch.backends.cudnn.deterministic = 
True + + parser = argparse.ArgumentParser() + parser.add_argument('--model_path', type=str, required=True, help='Path to the model checkpoint') + parser.add_argument('--source_path', type=str, required=True, help='Path to source image') + parser.add_argument('--reference_path', type=str, required=True, help='Path to reference image for dual-image caption generation') + parser.add_argument('--output_path', type=str, required=True, help='Path to output image') + parser.add_argument('--prompt', type=str, default='', help='Optional user prompt. If empty, caption is generated only from dual images.') + parser.add_argument('--seed', type=int, default=42, help='Random seed for generation') + parser.add_argument('--num_steps', type=int, default=28, help='Number of diffusion steps') + parser.add_argument('--cfg_guidance', type=float, default=6.0, help='CFG guidance strength') + parser.add_argument('--size_level', default=512, type=int) + parser.add_argument('--offload', action='store_true', help='Use offload for large models') + parser.add_argument('--quantized', action='store_true', help='Use fp8 model weights') + parser.add_argument('--lora', type=str, default=None) + parser.add_argument('--ring_degree', type=int, default=1) + parser.add_argument('--ulysses_degree', type=int, default=1) + parser.add_argument('--cfg_degree', type=int, default=1) + parser.add_argument('--teacache', action='store_true') + parser.add_argument('--teacache_threshold', type=float, default=0.2, help='Used to control the acceleration ratio of teacache') + parser.add_argument('--version', type=str, default='v1.1', choices=['v1.0', 'v1.1']) + parser.add_argument('--task_type', type=str, default='edit', choices=['edit', 't2i'], help='Task type: edit or t2i') + parser.add_argument('--height', type=int, default=1024, help='Size of the output image (for t2i task)') + parser.add_argument('--width', type=int, default=1024, help='Size of the output image (for t2i task)') + + args = parser.parse_args() + + 
assert os.path.exists(args.source_path), f"Source image {args.source_path} does not exist." + assert os.path.exists(args.reference_path), f"Reference image {args.reference_path} does not exist." + os.makedirs(os.path.dirname(args.output_path) or ".", exist_ok=True) + + mode = "flash" if args.ring_degree * args.ulysses_degree * args.cfg_degree == 1 else "xdit" + + if args.version == 'v1.0': + ckpt_name = 'step1x-edit-i1258.safetensors' + elif args.version == 'v1.1': + ckpt_name = 'step1x-edit-v1p1-official.safetensors' + + image_edit = ImageGenerator( + ae_path=os.path.join(args.model_path, 'vae.safetensors'), + dit_path=os.path.join(args.model_path, ckpt_name), + qwen2vl_model_path=os.path.join(args.model_path, 'Qwen2.5-VL-7B-Instruct'), + max_length=1280, + quantized=args.quantized, + offload=args.offload, + lora=args.lora, + mode=mode, + version=args.version, + ) + + if args.teacache: + teacache_init(image_edit, args) + if args.ring_degree * args.ulysses_degree * args.cfg_degree != 1: + cfg_usp_level_setting(args.ring_degree, args.ulysses_degree, args.cfg_degree) + parallel_teacache_transformer(image_edit) + else: + teacache_transformer(image_edit) + else: + if args.ring_degree * args.ulysses_degree * args.cfg_degree != 1: + cfg_usp_level_setting(args.ring_degree, args.ulysses_degree, args.cfg_degree) + parallel_transformer(image_edit) + + start_time = time.time() + source_image = Image.open(args.source_path).convert("RGB") + reference_image = Image.open(args.reference_path).convert("RGB") + image = image_edit.generate_image( + args.prompt, + negative_prompt="" if args.task_type == 'edit' else "worst quality, wrong limbs, unreasonable limbs, normal quality, low quality, low res, blurry, text, watermark, logo, banner, extra digits, cropped, jpeg artifacts, signature, username, error, sketch ,duplicate, ugly, monochrome, horror, geometry, mutation, disgusting", + ref_images=source_image if args.task_type == 'edit' else None, + caption_ref_images=reference_image if 
args.task_type == 'edit' else None, + num_samples=1, + num_steps=args.num_steps, + cfg_guidance=args.cfg_guidance, + seed=args.seed, + show_progress=True, + size_level=args.size_level, + height=args.height, + width=args.width, + )[0] + + if os.getenv("TORCHELASTIC_RUN_ID") is None or dist.get_rank() == 0: + print(f"Time taken: {time.time() - start_time:.2f} seconds") + image.save(args.output_path, lossless=True) + print(f"Saved result to: {args.output_path}") + + +if __name__ == "__main__": + main() diff --git a/inference_v2.py b/inference_v2.py new file mode 100644 index 0000000..bac139e --- /dev/null +++ b/inference_v2.py @@ -0,0 +1,768 @@ +import argparse +import datetime +import itertools +import math +import os +import time +import functools +from pathlib import Path + +import numpy as np +import torch +from einops import rearrange, repeat +from PIL import Image, ImageOps +from safetensors.torch import load_file +from torchvision.transforms import functional as F +from tqdm import tqdm + +import sampling +from modules.autoencoder import AutoEncoder +from modules.conditioner import Qwen25VL_7b_Embedder as Qwen2VLEmbedder +from modules.model_edit import Step1XParams, Step1XEdit +from modules.multigpu import parallel_transformer, teacache_transformer, parallel_teacache_transformer + +from torch import Tensor +import torch.distributed as dist +from xfuser.core.distributed import ( + get_world_group, + initialize_model_parallel, +) +from qwen_vl_utils import process_vision_info + +DUAL_IMAGE_CAPTION_PROMPT = """You are analyzing facial expressions for a controlled editing task. 
+Given: +- Image 1: source face to be edited +- Image 2: target expression reference + +Output a structured editing plan +""" + +def cfg_usp_level_setting(ring_degree: int = 1, ulysses_degree: int = 1, cfg_degree: int = 1): + # restriction: dist.get_world_size() == x x + initialize_model_parallel( + ring_degree=ring_degree, + ulysses_degree=ulysses_degree, + classifier_free_guidance_degree=cfg_degree, + ) + +def teacache_init(pipe, args): + pipe.dit.__class__.enable_teacache = True + pipe.dit.__class__.cnt = 0 + pipe.dit.__class__.num_steps = args.num_steps + pipe.dit.__class__.rel_l1_thresh = args.teacache_threshold + pipe.dit.__class__.accumulated_rel_l1_distance = 0 + pipe.dit.__class__.previous_modulated_input = None + pipe.dit.__class__.previous_residual = None + + +def cudagc(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + +def load_state_dict(model, ckpt_path, device="cuda", strict=False, assign=True): + if Path(ckpt_path).suffix == ".safetensors": + state_dict = load_file(ckpt_path, device) + else: + state_dict = torch.load(ckpt_path, map_location="cpu") + + missing, unexpected = model.load_state_dict( + state_dict, strict=strict, assign=assign + ) + if len(missing) > 0 and len(unexpected) > 0: + print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + print("\n" + "-" * 79 + "\n") + print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + elif len(missing) > 0: + print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + elif len(unexpected) > 0: + print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + return model + + +def load_models( + dit_path=None, + ae_path=None, + qwen2vl_model_path=None, + mode="flash", + device="cuda", + max_length=256, + dtype=torch.bfloat16, + version='v1.0' +): + qwen2vl_encoder = Qwen2VLEmbedder( + qwen2vl_model_path, + device=device, + max_length=max_length, + dtype=dtype, + ) + + with torch.device("meta"): + ae = AutoEncoder( + 
resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ) + + step1x_params = Step1XParams( + in_channels=64, + out_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + mode=mode, + version=version, + ) + dit = Step1XEdit(step1x_params) + + ae = load_state_dict(ae, ae_path, 'cpu') + dit = load_state_dict( + dit, dit_path, 'cpu' + ) + + ae = ae.to(dtype=torch.float32) + + return ae, dit, qwen2vl_encoder + +def equip_dit_with_lora_sd_scripts(ae, text_encoders, dit, lora, device='cuda'): + from safetensors.torch import load_file + weights_sd = load_file(lora) + keys = list(weights_sd.keys()) + is_lora = any(("lora_down" in k or "lora_up" in k or k.startswith("lora_")) for k in keys) + is_connector = any(k.startswith("connector.") for k in keys) or not is_lora + + if is_lora and not is_connector: + from library import lora_module as module + print(f"[Adapter loader] Detected LoRA weights: {lora}") + else: + from library import connector_module as module + print(f"[Adapter loader] Detected connector weights: {lora}") + + lora_model, _ = module.create_network_from_weights(1.0, None, ae, text_encoders, dit, weights_sd, True) + lora_model.merge_to(text_encoders, dit, weights_sd) + + lora_model.set_multiplier(1.0) + return lora_model + +class ImageGenerator: + def __init__( + self, + dit_path=None, + ae_path=None, + qwen2vl_model_path=None, + device="cuda", + max_length=640, + dtype=torch.bfloat16, + quantized=False, + offload=False, + lora=None, + mode="flash", + version='v1.0' + ) -> None: + self.version = version + if os.getenv("TORCHELASTIC_RUN_ID") is not None: + local_rank = get_world_group().local_rank + torch.cuda.set_device(local_rank) + self.device = torch.device(f"cuda:{local_rank}") + else: + self.device = 
torch.device(device) + + self.ae, self.dit, self.llm_encoder = load_models( + dit_path=dit_path, + ae_path=ae_path, + qwen2vl_model_path=qwen2vl_model_path, + max_length=max_length, + dtype=dtype, + device=self.device, + mode=mode, + version=version, + ) + + if not quantized: + self.dit = self.dit.to(dtype=torch.bfloat16) + else: + self.dit = self.dit.to(dtype=torch.float8_e4m3fn) + if not offload: + self.dit = self.dit.to(device=self.device) + self.ae = self.ae.to(device=self.device) + self.quantized = quantized + self.offload = offload + if lora is not None: + self.lora_module = equip_dit_with_lora_sd_scripts( + self.ae, + [self.llm_encoder], + self.dit, + lora, + device=self.dit.device, + ) + else: + self.lora_module = None + self.mode = mode + + + def prepare(self, prompt, img, ref_image, ref_image_raw): + bs, _, h, w = img.shape + bs, _, ref_h, ref_w = ref_image.shape + + assert h == ref_h and w == ref_w + + if bs == 1 and not isinstance(prompt, str): + bs = len(prompt) + elif bs >= 1 and isinstance(prompt, str): + prompt = [prompt] * bs + + img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + ref_img = rearrange(ref_image, "b c (ref_h ph) (ref_w pw) -> b (ref_h ref_w) (c ph pw)", ph=2, pw=2) + if img.shape[0] == 1 and bs > 1: + img = repeat(img, "1 ... -> bs ...", bs=bs) + ref_img = repeat(ref_img, "1 ... 
-> bs ...", bs=bs) + + img_ids = torch.zeros(h // 2, w // 2, 3) + + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + if self.version == 'v1.0': + ref_img_ids = torch.zeros(ref_h // 2, ref_w // 2, 3) + else: + ref_img_ids = torch.ones(ref_h // 2, ref_w // 2, 3) + + ref_img_ids[..., 1] = ref_img_ids[..., 1] + torch.arange(ref_h // 2)[:, None] + ref_img_ids[..., 2] = ref_img_ids[..., 2] + torch.arange(ref_w // 2)[None, :] + ref_img_ids = repeat(ref_img_ids, "ref_h ref_w c -> b (ref_h ref_w) c", b=bs) + + if isinstance(prompt, str): + prompt = [prompt] + if self.offload: + self.llm_encoder = self.llm_encoder.to(self.device) + txt, mask = self.llm_encoder(prompt, ref_image_raw) + if self.offload: + self.llm_encoder = self.llm_encoder.cpu() + cudagc() + + txt_ids = torch.zeros(bs, txt.shape[1], 3) + + img = torch.cat([img, ref_img.to(device=img.device, dtype=img.dtype)], dim=-2) + img_ids = torch.cat([img_ids, ref_img_ids], dim=-2) + + + return { + "img": img, + "mask": mask, + "img_ids": img_ids.to(img.device), + "llm_embedding": txt.to(img.device), + "txt_ids": txt_ids.to(img.device), + } + + + def prepare_t2i(self, prompt, img, ref_image_raw): + bs, _, h, w = img.shape + + if bs == 1 and not isinstance(prompt, str): + bs = len(prompt) + elif bs >= 1 and isinstance(prompt, str): + prompt = [prompt] * bs + + img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + + if img.shape[0] == 1 and bs > 1: + img = repeat(img, "1 ... 
-> bs ...", bs=bs) + + + img_ids = torch.zeros(h // 2, w // 2, 3) + + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + + if isinstance(prompt, str): + prompt = [prompt] + if self.offload: + self.llm_encoder = self.llm_encoder.to(self.device) + txt, mask = self.llm_encoder(prompt, ref_image_raw) + if self.offload: + self.llm_encoder = self.llm_encoder.cpu() + cudagc() + + txt_ids = torch.zeros(bs, txt.shape[1], 3) + + + return { + "img": img, + "mask": mask, + "img_ids": img_ids.to(img.device), + "llm_embedding": txt.to(img.device), + "txt_ids": txt_ids.to(img.device), + } + @staticmethod + def process_diff_norm(diff_norm, k): + pow_result = torch.pow(diff_norm, k) + + result = torch.where( + diff_norm > 1.0, + pow_result, + torch.where(diff_norm < 1.0, torch.ones_like(diff_norm), diff_norm), + ) + return result + + def denoise_t2i( + self, + img: torch.Tensor, + img_ids: torch.Tensor, + llm_embedding: torch.Tensor, + txt_ids: torch.Tensor, + timesteps: list[float], + cfg_guidance: float = 4.5, + mask=None, + show_progress=False, + timesteps_truncate=0.93, + ): + if self.offload: + self.dit = self.dit.to(self.device) + if show_progress: + pbar = tqdm(itertools.pairwise(timesteps), desc='denoising...') + else: + pbar = itertools.pairwise(timesteps) + for idx, (t_curr, t_prev) in enumerate(pbar): + if img.shape[0] == 1 and cfg_guidance != -1: + img = torch.cat([img, img], dim=0) + t_vec = torch.full( + (img.shape[0],), t_curr, dtype=img.dtype, device=img.device + ) + pred = self.dit( + img=img, + img_ids=img_ids, + txt_ids=txt_ids, + timesteps=t_vec, + llm_embedding=llm_embedding, + t_vec=t_vec, + mask=mask, + ) + + if cfg_guidance != -1: + cond, uncond = ( + pred[0 : pred.shape[0] // 2, :], + pred[pred.shape[0] // 2 :, :], + ) + if t_curr > timesteps_truncate: + diff = cond - uncond + diff_norm = torch.norm(diff, dim=(2), 
keepdim=True) + pred = uncond + cfg_guidance * ( + cond - uncond + ) / self.process_diff_norm(diff_norm, k=0.4) + else: + pred = uncond + cfg_guidance * (cond - uncond) + img = img[0 : img.shape[0] // 2] + (t_prev - t_curr) * pred + if self.offload: + self.dit = self.dit.cpu() + cudagc() + + return img + + def denoise( + self, + img: torch.Tensor, + img_ids: torch.Tensor, + llm_embedding: torch.Tensor, + txt_ids: torch.Tensor, + timesteps: list[float], + cfg_guidance: float = 4.5, + mask=None, + show_progress=False, + timesteps_truncate=0.93, + ): + ref_img_tensor = img[0, img.shape[1] // 2:].clone() + if self.offload: + self.dit = self.dit.to(self.device) + if show_progress: + pbar = tqdm(itertools.pairwise(timesteps), desc='denoising...') + else: + pbar = itertools.pairwise(timesteps) + for idx, (t_curr, t_prev) in enumerate(pbar): + if img.shape[0] == 1 and cfg_guidance != -1: + img = torch.cat([img, img], dim=0) + t_vec = torch.full( + (img.shape[0],), t_curr, dtype=img.dtype, device=img.device + ) + pred = self.dit( + img=img, + img_ids=img_ids, + txt_ids=txt_ids, + timesteps=t_vec, + llm_embedding=llm_embedding, + t_vec=t_vec, + mask=mask, + ) + pred = pred[:, :pred.shape[1] // 2] + + if cfg_guidance != -1: + cond, uncond = ( + pred[0 : pred.shape[0] // 2, :], + pred[pred.shape[0] // 2 :, :], + ) + if t_curr > timesteps_truncate: + diff = cond - uncond + diff_norm = torch.norm(diff, dim=(2), keepdim=True) + pred = uncond + cfg_guidance * ( + cond - uncond + ) / self.process_diff_norm(diff_norm, k=0.4) + else: + pred = uncond + cfg_guidance * (cond - uncond) + tem_img = img[0 : img.shape[0] // 2, : img.shape[1] // 2] + (t_prev - t_curr) * pred + img = torch.cat( + [ + tem_img, + ref_img_tensor.unsqueeze(0), + ], dim=1 + ) + if self.offload: + self.dit = self.dit.cpu() + cudagc() + + return img[:, :img.shape[1] // 2] + + @staticmethod + def unpack(x: torch.Tensor, height: int, width: int) -> torch.Tensor: + return rearrange( + x, + "b (h w) (c ph pw) -> b c (h 
ph) (w pw)", + h=math.ceil(height / 16), + w=math.ceil(width / 16), + ph=2, + pw=2, + ) + + @staticmethod + def load_image(image): + from PIL import Image + + if isinstance(image, np.ndarray): + image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0 + image = image.unsqueeze(0) + return image + elif isinstance(image, Image.Image): + image = F.to_tensor(image.convert("RGB")) + image = image.unsqueeze(0) + return image + elif isinstance(image, torch.Tensor): + return image + elif isinstance(image, str): + image = F.to_tensor(Image.open(image).convert("RGB")) + image = image.unsqueeze(0) + return image + else: + raise ValueError(f"Unsupported image type: {type(image)}") + + def output_process_image(self, resize_img, image_size): + res_image = resize_img.resize(image_size) + return res_image + return resize_img + + def input_process_image(self, img, img_size=512): + # 1. 打开图片 + w, h = img.size + r = w / h + + if w > h: + w_new = math.ceil(math.sqrt(img_size * img_size * r)) + h_new = math.ceil(w_new / r) + else: + h_new = math.ceil(math.sqrt(img_size * img_size / r)) + w_new = math.ceil(h_new * r) + h_new = h_new // 16 * 16 + w_new = w_new // 16 * 16 + + img_resized = img.resize((w_new, h_new), Image.LANCZOS) + return img_resized, img.size + + def build_caption_from_dual_images(self, source_image, reference_image, user_prompt="", max_new_tokens=150): + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": DUAL_IMAGE_CAPTION_PROMPT}, + {"type": "text", "text": "\n[Source Image (Structure/Identity)]:"}, + {"type": "image", "image": source_image}, + {"type": "text", "text": "\n[Reference Image (Style/Expression)]:"}, + {"type": "image", "image": reference_image}, + {"type": "text", "text": f"\nUser prompt: {user_prompt}"}, + {"type": "text", "text": "\nPlease generate the structured editing plan now."}, + ], + } + ] + + processor = self.llm_encoder.processor + model = self.llm_encoder.model + + text = processor.apply_chat_template( + 
messages, tokenize=False, add_generation_prompt=True, add_vision_id=True + ) + image_inputs, _ = process_vision_info(messages) + inputs = processor( + text=[text], + images=image_inputs, + padding=True, + return_tensors="pt", + ).to(self.device) + + with torch.no_grad(): + generated_ids = model.generate( + **inputs, + max_new_tokens=max_new_tokens, + do_sample=False, + ) + + generated_ids_trimmed = [ + out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) + ] + generated_text = processor.batch_decode( + generated_ids_trimmed, skip_special_tokens=True + )[0].strip() + return generated_text if generated_text else user_prompt + + @torch.inference_mode() + def generate_image( + self, + prompt, + negative_prompt, + ref_images, + num_steps, + cfg_guidance, + seed, + caption_ref_images=None, + num_samples=1, + init_image=None, + image2image_strength=0.0, + show_progress=False, + size_level=512, + height=None, + width=None, + caption_max_new_tokens=150, + ): + assert num_samples == 1, "num_samples > 1 is not supported yet." 
+ if ref_images == None: + self.task_type='t2i' + ref_images = Image.new('RGB', (1024, 1024)) + ref_images_raw = ref_images + img_info = (width, height) if width is not None and height is not None else (1024, 1024) + else: + self.task_type = 'edit' + ref_images_raw, img_info = self.input_process_image(ref_images, img_size=size_level) + if caption_ref_images is not None: + caption_ref_images_raw, _ = self.input_process_image(caption_ref_images, img_size=size_level) + if caption_ref_images_raw.size != ref_images_raw.size: + caption_ref_images_raw = caption_ref_images_raw.resize(ref_images_raw.size, Image.LANCZOS) + if self.offload: + self.llm_encoder = self.llm_encoder.to(self.device) + prompt = self.build_caption_from_dual_images( + source_image=ref_images_raw, + reference_image=caption_ref_images_raw, + user_prompt=prompt, + max_new_tokens=caption_max_new_tokens, + ) + if os.getenv("TORCHELASTIC_RUN_ID") is None or dist.get_rank() == 0: + print(f"[Dual-image generated prompt] {prompt}") + + if self.task_type == 'edit': + width, height = ref_images_raw.width, ref_images_raw.height + + ref_images_raw = self.load_image(ref_images_raw) + ref_images_raw = ref_images_raw.to(self.device) + if self.offload: + self.ae = self.ae.to(self.device) + with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16): + ref_images = self.ae.encode(ref_images_raw.to(self.device) * 2 - 1) + if self.offload: + self.ae = self.ae.cpu() + cudagc() + else: + width, height = img_info + ref_images_raw = self.load_image(ref_images_raw) + ref_images_raw = ref_images_raw.to(self.device) + ref_images = None + + seed = int(seed) + seed = torch.Generator(device="cpu").seed() if seed < 0 else seed + + t0 = time.perf_counter() + + if init_image is not None: + init_image = self.load_image(init_image) + init_image = init_image.to(self.device) + init_image = torch.nn.functional.interpolate(init_image, (height, width)) + if self.offload: + self.ae = self.ae.to(self.device) + init_image = 
self.ae.encode(init_image.to() * 2 - 1) + if self.offload: + self.ae = self.ae.cpu() + cudagc() + + x = torch.randn( + num_samples, + 16, + height // 8, + width // 8, + device=self.device, + dtype=torch.bfloat16, + generator=torch.Generator(device=self.device).manual_seed(seed), + ) + timesteps = sampling.get_schedule( + num_steps, x.shape[-1] * x.shape[-2] // 4, shift=True + ) + + if init_image is not None: + t_idx = int((1 - image2image_strength) * num_steps) + t = timesteps[t_idx] + timesteps = timesteps[t_idx:] + x = t * x + (1.0 - t) * init_image.to(x.dtype) + + x = torch.cat([x, x], dim=0) + if self.task_type == 'edit': + ref_images = torch.cat([ref_images, ref_images], dim=0) + ref_images_raw = torch.cat([ref_images_raw, ref_images_raw], dim=0) + inputs = self.prepare([prompt, negative_prompt], x, ref_image=ref_images, ref_image_raw=ref_images_raw) + else: + ref_images_raw = torch.cat([ref_images_raw, ref_images_raw], dim=0) + inputs = self.prepare_t2i([prompt, negative_prompt], x, ref_images_raw) + + + + with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16): + if self.task_type == 'edit': + x = self.denoise( + **inputs, + cfg_guidance=cfg_guidance, + timesteps=timesteps, + show_progress=show_progress, + timesteps_truncate=0.93, + ) + else: + x = self.denoise_t2i( + **inputs, + cfg_guidance=cfg_guidance, + timesteps=timesteps, + show_progress=show_progress, + timesteps_truncate=0.93, + ) + x = self.unpack(x.float(), height, width) + if self.offload: + self.ae = self.ae.to(self.device) + x = self.ae.decode(x) + if self.offload: + self.ae = self.ae.cpu() + cudagc() + x = x.clamp(-1, 1) + x = x.mul(0.5).add(0.5) + + t1 = time.perf_counter() + if os.getenv("TORCHELASTIC_RUN_ID") is None or dist.get_rank() == 0: + print(f"Done in {t1 - t0:.1f}s.") + images_list = [] + for img in x.float(): + images_list.append(self.output_process_image(F.to_pil_image(img), img_info)) + return images_list + + +def main(): + torch.backends.cudnn.deterministic = 
True + + parser = argparse.ArgumentParser() + parser.add_argument('--model_path', type=str, required=True, help='Path to the model checkpoint') + parser.add_argument('--source_path', type=str, required=True, help='Path to source image') + parser.add_argument('--reference_path', type=str, required=True, help='Path to reference image for dual-image caption generation') + parser.add_argument('--output_path', type=str, required=True, help='Path to output image') + parser.add_argument('--prompt', type=str, default='', help='Optional user prompt. If empty, caption is generated only from dual images.') + parser.add_argument('--seed', type=int, default=42, help='Random seed for generation') + parser.add_argument('--num_steps', type=int, default=28, help='Number of diffusion steps') + parser.add_argument('--cfg_guidance', type=float, default=6.0, help='CFG guidance strength') + parser.add_argument('--size_level', default=512, type=int) + parser.add_argument('--offload', action='store_true', help='Use offload for large models') + parser.add_argument('--quantized', action='store_true', help='Use fp8 model weights') + parser.add_argument('--lora', type=str, default=None) + parser.add_argument('--ring_degree', type=int, default=1) + parser.add_argument('--ulysses_degree', type=int, default=1) + parser.add_argument('--cfg_degree', type=int, default=1) + parser.add_argument('--teacache', action='store_true') + parser.add_argument('--teacache_threshold', type=float, default=0.2, help='Used to control the acceleration ratio of teacache') + parser.add_argument('--version', type=str, default='v1.1', choices=['v1.0', 'v1.1']) + parser.add_argument('--task_type', type=str, default='edit', choices=['edit', 't2i'], help='Task type: edit or t2i') + parser.add_argument('--height', type=int, default=1024, help='Size of the output image (for t2i task)') + parser.add_argument('--width', type=int, default=1024, help='Size of the output image (for t2i task)') + + args = parser.parse_args() + + 
assert os.path.exists(args.source_path), f"Source image {args.source_path} does not exist." + assert os.path.exists(args.reference_path), f"Reference image {args.reference_path} does not exist." + os.makedirs(os.path.dirname(args.output_path) or ".", exist_ok=True) + + mode = "flash" if args.ring_degree * args.ulysses_degree * args.cfg_degree == 1 else "xdit" + + if args.version == 'v1.0': + ckpt_name = 'step1x-edit-i1258.safetensors' + elif args.version == 'v1.1': + ckpt_name = 'step1x-edit-v1p1-official.safetensors' + + image_edit = ImageGenerator( + ae_path=os.path.join(args.model_path, 'vae.safetensors'), + dit_path=os.path.join(args.model_path, ckpt_name), + qwen2vl_model_path=os.path.join(args.model_path, 'Qwen2.5-VL-7B-Instruct'), + max_length=640, + quantized=args.quantized, + offload=args.offload, + lora=args.lora, + mode=mode, + version=args.version, + ) + + if args.teacache: + teacache_init(image_edit, args) + if args.ring_degree * args.ulysses_degree * args.cfg_degree != 1: + cfg_usp_level_setting(args.ring_degree, args.ulysses_degree, args.cfg_degree) + parallel_teacache_transformer(image_edit) + else: + teacache_transformer(image_edit) + else: + if args.ring_degree * args.ulysses_degree * args.cfg_degree != 1: + cfg_usp_level_setting(args.ring_degree, args.ulysses_degree, args.cfg_degree) + parallel_transformer(image_edit) + + start_time = time.time() + source_image = Image.open(args.source_path).convert("RGB") + reference_image = Image.open(args.reference_path).convert("RGB") + image = image_edit.generate_image( + args.prompt, + negative_prompt="" if args.task_type == 'edit' else "worst quality, wrong limbs, unreasonable limbs, normal quality, low quality, low res, blurry, text, watermark, logo, banner, extra digits, cropped, jpeg artifacts, signature, username, error, sketch ,duplicate, ugly, monochrome, horror, geometry, mutation, disgusting", + ref_images=source_image if args.task_type == 'edit' else None, + caption_ref_images=reference_image if 
args.task_type == 'edit' else None, + num_samples=1, + num_steps=args.num_steps, + cfg_guidance=args.cfg_guidance, + seed=args.seed, + show_progress=True, + size_level=args.size_level, + height=args.height, + width=args.width, + )[0] + + if os.getenv("TORCHELASTIC_RUN_ID") is None or dist.get_rank() == 0: + print(f"Time taken: {time.time() - start_time:.2f} seconds") + image.save(args.output_path, lossless=True) + print(f"Saved result to: {args.output_path}") + + +if __name__ == "__main__": + main() diff --git a/library/data_configs/step1x_edit.toml b/library/data_configs/step1x_edit.toml index 1ccd472..ea259d7 100644 --- a/library/data_configs/step1x_edit.toml +++ b/library/data_configs/step1x_edit.toml @@ -10,5 +10,5 @@ batch_size = 1 edit_dataset = true # necessary for editing tasks [[datasets.subsets]] - image_dir = - metadata_file = \ No newline at end of file + image_dir = "/scratch3/f007yzf/repos/Step1X-Edit-clean/training_data/target_img" + metadata_file = "/scratch3/f007yzf/repos/Step1X-Edit-clean/training_data/metadata.json" \ No newline at end of file diff --git a/library/kohya_trainer.py b/library/kohya_trainer.py index 0f1fe53..1242f8e 100644 --- a/library/kohya_trainer.py +++ b/library/kohya_trainer.py @@ -825,7 +825,9 @@ def train(self, args): t_enc.requires_grad_(False) # in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16 - if t_enc.device.type != "cpu": + t_enc_device = t_enc.device + t_enc_device_type = t_enc_device.type if hasattr(t_enc_device, "type") else str(t_enc_device) + if t_enc_device_type != "cpu": t_enc.to(dtype=te_weight_dtype) # nn.Embedding not support FP8 diff --git a/library/qwen_connector_module_v1.py b/library/qwen_connector_module_v1.py new file mode 100644 index 0000000..9ef570b --- /dev/null +++ b/library/qwen_connector_module_v1.py @@ -0,0 +1,164 @@ +import os +from typing import Any, Dict, Iterable, Optional, Tuple + +import torch +from torch import nn + +from library.utils import 
setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def _load_weights_file(file: str) -> Dict[str, torch.Tensor]: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + return load_file(file) + return torch.load(file, map_location="cpu") + + +def _is_fp8_dtype(dtype: torch.dtype) -> bool: + return dtype in { + getattr(torch, "float8_e4m3fn", None), + getattr(torch, "float8_e4m3fnuz", None), + getattr(torch, "float8_e5m2", None), + getattr(torch, "float8_e5m2fnuz", None), + } + + +class QwenConnectorNetwork(nn.Module): + def __init__(self, multiplier: float = 1.0) -> None: + super().__init__() + self.multiplier = multiplier + self.connector_ref: Optional[list[nn.Module]] = None + + @property + def connector(self) -> nn.Module: + if self.connector_ref is None or len(self.connector_ref) == 0 or self.connector_ref[0] is None: + raise RuntimeError("connector is not attached. call apply_to() before using the network") + return self.connector_ref[0] + + def train(self, mode: bool = True): + super().train(mode) + if self.connector_ref is not None and len(self.connector_ref) > 0 and self.connector_ref[0] is not None: + self.connector_ref[0].train(mode) + return self + + def parameters(self, recurse: bool = True) -> Iterable[torch.nn.Parameter]: + if self.connector_ref is None: + return iter(()) + return self.connector.parameters(recurse=recurse) + + def named_parameters(self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True): + if self.connector_ref is None: + return iter(()) + return self.connector.named_parameters(prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate) + + def apply_to(self, text_encoders, dit, apply_text_encoder: bool = True, apply_unet: bool = True): + if not apply_unet: + raise ValueError("QwenConnectorNetwork requires apply_unet=True because connector belongs to the base DiT") + self.connector_ref = [dit.connector] + logger.info("enable 
connector-only training for U-Net connector") + + def is_mergeable(self): + return False + + def set_multiplier(self, multiplier: float): + self.multiplier = multiplier + + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + lr = unet_lr if unet_lr is not None else default_lr + if lr is None or lr == 0: + raise ValueError("unet_lr or learning_rate must be specified for connector training") + params = list(self.connector.parameters()) + if len(params) == 0: + raise ValueError("no connector parameters found") + return [{"params": params, "lr": lr}], ["unet_connector"] + + def enable_gradient_checkpointing(self): + pass + + def prepare_grad_etc(self, text_encoder, unet): + connector = self.connector + for param in connector.parameters(): + if _is_fp8_dtype(param.dtype): + raise ValueError("QwenConnectorNetwork v1 does not support fp8_base/fp8_base_unet") + connector.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.connector.parameters() + + def save_weights(self, file, dtype, metadata): + from library import train_util + + state_dict = self.connector.state_dict() + if dtype is not None: + for key in list(state_dict.keys()): + state_dict[key] = state_dict[key].detach().clone().to("cpu").to(dtype) + else: + for key in list(state_dict.keys()): + state_dict[key] = state_dict[key].detach().clone().to("cpu") + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + + metadata = {} if metadata is None else dict(metadata) + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + def load_weights(self, file): + weights_sd = _load_weights_file(file) + info = self.connector.load_state_dict(weights_sd, strict=True) + return info + 
+ def merge_to(self, text_encoders, dit, weights_sd, dtype=None, device=None): + self.apply_to(text_encoders, dit, apply_text_encoder=False, apply_unet=True) + target_connector = self.connector + if dtype is not None: + converted = {} + for key, value in weights_sd.items(): + converted[key] = value.to(device=device or value.device, dtype=dtype) + weights_sd = converted + target_connector.load_state_dict(weights_sd, strict=True) + logger.info("connector weights are loaded into the base DiT") + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + ae: Any, + text_encoders, + base_dit, + neuron_dropout: Optional[float] = None, + **kwargs, +): + return QwenConnectorNetwork(multiplier=multiplier) + + +def create_network_from_weights( + multiplier, + file, + ae, + text_encoders, + base_dit, + weights_sd=None, + for_inference=False, + **kwargs, +) -> Tuple[QwenConnectorNetwork, Dict[str, torch.Tensor]]: + if weights_sd is None: + if file is None: + raise ValueError("file must be specified when weights_sd is None") + weights_sd = _load_weights_file(file) + network = QwenConnectorNetwork(multiplier=multiplier) + return network, weights_sd diff --git a/library/step1x_utils.py b/library/step1x_utils.py index e52b884..7c7b00b 100644 --- a/library/step1x_utils.py +++ b/library/step1x_utils.py @@ -56,6 +56,8 @@ def load_models( axes_dim=[16, 56, 56], theta=10_000, qkv_bias=True, + mode="flash", + version="v1.0", ) dit = Step1XEdit(step1x_params) @@ -69,7 +71,7 @@ def load_models( def load_qwen2p5vl( qwen2vl_model_path=None, device="cuda", - max_length=640, + max_length=1280, dtype=torch.bfloat16, ): qwen2vl_encoder = Qwen2VLEmbedder( @@ -122,4 +124,4 @@ def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_wi def get_qwen_actual_dtype(input_model) -> torch.dtype: # nn.Embedding is the first layer, but it could be casted to bfloat16 or float32 - return 
input_model.model.model.layers[0].mlp.gate_proj.weight.dtype \ No newline at end of file + return input_model.model.model.layers[0].mlp.gate_proj.weight.dtype diff --git a/library/step1x_utils_v1.py b/library/step1x_utils_v1.py new file mode 100644 index 0000000..6e96a3e --- /dev/null +++ b/library/step1x_utils_v1.py @@ -0,0 +1,127 @@ +import einops +import torch +from modules.autoencoder import AutoEncoder +from modules.conditioner import Qwen25VL_7b_Embedder as Qwen2VLEmbedder +from modules.model_edit import Step1XParams, Step1XEdit +from pathlib import Path +from safetensors.torch import load_file + +from modules import autoencoder + +from library.utils import setup_logging +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from library.utils import load_safetensors + +def load_state_dict(model, ckpt_path, device="cuda", strict=False, assign=True): + if Path(ckpt_path).suffix == ".safetensors": + state_dict = load_file(ckpt_path, device) + else: + state_dict = torch.load(ckpt_path, map_location="cpu") + + missing, unexpected = model.load_state_dict( + state_dict, strict=strict, assign=assign + ) + if len(missing) > 0 and len(unexpected) > 0: + print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + print("\n" + "-" * 79 + "\n") + print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + elif len(missing) > 0: + print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + elif len(unexpected) > 0: + print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + return model + +def load_models( + dit_path=None, + device="cpu", + dtype=torch.bfloat16, +): + + with torch.device("meta"): + + step1x_params = Step1XParams( + in_channels=64, + out_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + mode="flash", + 
version="v1.0", + ) + dit = Step1XEdit(step1x_params) + + dit = load_state_dict( + dit, dit_path, device + ) + dit = dit.to(dtype=dtype, device=device) + + return dit + +def load_qwen2p5vl( + qwen2vl_model_path=None, + device="cuda", + max_length=640, + dtype=torch.bfloat16, +): + qwen2vl_encoder = Qwen2VLEmbedder( + qwen2vl_model_path, + device=device, + max_length=max_length, + dtype=dtype, + ) + return qwen2vl_encoder + + +def load_ae( + ckpt_path: str, dtype: torch.dtype, device, disable_mmap: bool = False +) -> autoencoder.AutoEncoder: + logger.info("Building AutoEncoder") + with torch.device("meta"): + # dev and schnell have the same AE params + ae = AutoEncoder( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ).to(dtype) + + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + info = ae.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded AE: {info}") + return ae + + +def pack_latents(x: torch.Tensor) -> torch.Tensor: + """ + x: [b c (h ph) (w pw)] -> [b (h w) (c ph pw)], ph=2, pw=2 + """ + x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + return x + +def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int): + img_ids = torch.zeros(packed_latent_height, packed_latent_width, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(packed_latent_width)[None, :] + img_ids = einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size) + return img_ids + +def get_qwen_actual_dtype(input_model) -> torch.dtype: + # nn.Embedding is the first layer, but it could be casted to bfloat16 or float32 + return input_model.model.model.layers[0].mlp.gate_proj.weight.dtype diff --git a/library/strategy_step1x.py 
b/library/strategy_step1x.py index 6392b2e..5d40062 100644 --- a/library/strategy_step1x.py +++ b/library/strategy_step1x.py @@ -56,7 +56,7 @@ def split_string(s): return result class Step1xEditTokenizeStrategy(TokenizeStrategy): - def __init__(self, max_length: int = 640, tokenizer_cache_dir: Optional[str] = None) -> None: + def __init__(self, max_length: int = 1280, tokenizer_cache_dir: Optional[str] = None) -> None: self.max_length = max_length self.processor = AutoProcessor.from_pretrained( tokenizer_cache_dir, min_pixels=256 * 28 * 28, max_pixels=324 * 28 * 28 @@ -130,7 +130,7 @@ def tokenize(self, text: Union[str, List[str]], ref_images: Union[PIL.Image.Imag return res_list class Step1XEditEncodingStrategy(TextEncodingStrategy): - def __init__(self, max_length=640, hidden_size=None) -> None: + def __init__(self, max_length=1280, hidden_size=None) -> None: self.max_length = max_length self.hidden_size=hidden_size self.dtype = None @@ -485,4 +485,4 @@ def save_latents_to_disk( if alpha_mask is not None: kwargs["alpha_mask" + key_reso_suffix] = alpha_mask.float().cpu().numpy() kwargs["ref_alpha_mask" + key_reso_suffix] = ref_alpha_mask.float().cpu().numpy() - np.savez(npz_path, **kwargs) \ No newline at end of file + np.savez(npz_path, **kwargs) diff --git a/library/strategy_step1x_v1.py b/library/strategy_step1x_v1.py new file mode 100644 index 0000000..6392b2e --- /dev/null +++ b/library/strategy_step1x_v1.py @@ -0,0 +1,488 @@ +import os +import glob +from typing import Any, List, Optional, Tuple, Union +import PIL +import torch +import numpy as np +from transformers import AutoProcessor + +from library import train_util, step1x_utils +from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy + +from qwen_vl_utils import process_vision_info + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +Qwen25VL_7b_PREFIX = 
'''Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt: +- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes. +- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n +Here are examples of how to transform or refine prompts: +- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers. +- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\n +Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations: +User Prompt:''' + +def split_string(s): + s = s.replace("“", '"').replace("”", '"').replace("'", '''"''') # use english quotes + result = [] + in_quotes = False + temp = "" + + for idx,char in enumerate(s): + if char == '"' and idx>155: + temp += char + if not in_quotes: + result.append(temp) + temp = "" + + in_quotes = not in_quotes + continue + if in_quotes: + if char.isspace(): + pass # have space token + + result.append("“" + char + "”") + else: + temp += char + + if temp: + result.append(temp) + + return result + +class Step1xEditTokenizeStrategy(TokenizeStrategy): + def __init__(self, max_length: int = 640, tokenizer_cache_dir: Optional[str] = None) -> None: + self.max_length = max_length + self.processor = AutoProcessor.from_pretrained( + tokenizer_cache_dir, min_pixels=256 * 28 * 28, max_pixels=324 * 28 * 28 + ) + self.prefix = Qwen25VL_7b_PREFIX + + def tokenize(self, text: Union[str, List[str]], ref_images: 
Union[PIL.Image.Image, List[PIL.Image.Image]]) -> List[torch.Tensor]: + text = [text] if isinstance(text, str) else text + ref_images = [ref_images] if isinstance(ref_images, PIL.Image.Image) else ref_images + + res_list = [] + + for idx, (txt, imgs) in enumerate(zip(text, ref_images)): + messages = [{"role": "user", "content": []}] + + messages[0]["content"].append({"type": "text", "text": f"{self.prefix}"}) + + messages[0]["content"].append({"type": "image", "image": imgs}) + + # 再添加 text + messages[0]["content"].append({"type": "text", "text": f"{txt}"}) + + # Preparation for inference + text = self.processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, add_vision_id=True + ) + + image_inputs, video_inputs = process_vision_info(messages) + + inputs = self.processor( + text=[text], + images=image_inputs, + padding=True, + return_tensors="pt", + ) + + old_inputs_ids = inputs.input_ids + text_split_list = split_string(text) + + token_list = [] + for text_each in text_split_list: + txt_inputs = self.processor( + text=text_each, + images=None, + videos=None, + padding=True, + return_tensors="pt", + ) + token_each = txt_inputs.input_ids + if token_each[0][0] == 2073 and token_each[0][-1] == 854: + token_each = token_each[:, 1:-1] + token_list.append(token_each) + else: + token_list.append(token_each) + + new_txt_ids = torch.cat(token_list, dim=1).to("cuda") + + new_txt_ids = new_txt_ids.to(old_inputs_ids.device) + + idx1 = (old_inputs_ids == 151653).nonzero(as_tuple=True)[1][0] + idx2 = (new_txt_ids == 151653).nonzero(as_tuple=True)[1][0] + inputs.input_ids = ( + torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0) + .unsqueeze(0) + .to("cuda") + ) + inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda") + + res_list.append(inputs) + + return res_list + +class Step1XEditEncodingStrategy(TextEncodingStrategy): + def __init__(self, max_length=640, hidden_size=None) -> None: + self.max_length = max_length + 
self.hidden_size=hidden_size + self.dtype = None + + def encode_tokens(self, tokenize_strategy, models, tokens): + qwen2p5vl = models[0] + if self.dtype is None: + self.dtype = qwen2p5vl.model.lm_head.weight.dtype + if self.hidden_size is None: + self.hidden_size = qwen2p5vl.model.config.hidden_size + embs = torch.zeros( + len(tokens), + self.max_length, + self.hidden_size, + dtype=self.dtype, + device=torch.cuda.current_device() + ) + masks = torch.zeros( + len(tokens), + self.max_length, + dtype=torch.long, + device=torch.cuda.current_device(), + ) + self.device = torch.device(torch.cuda.current_device()) + for idx, inputs in enumerate(tokens): + for key, value in inputs.items(): + if isinstance(value, torch.Tensor): + if value.dtype == torch.float32: + inputs[key] = value.to(self.dtype) + # with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16): + outputs = qwen2p5vl.model( + input_ids = inputs.input_ids.to(torch.cuda.current_device()), + attention_mask=inputs.attention_mask, + pixel_values=inputs.pixel_values.to(torch.cuda.current_device()), + image_grid_thw=inputs.image_grid_thw.to(torch.cuda.current_device()), + output_hidden_states=True, + ) + emb = outputs['hidden_states'][-1] + embs[idx, : min(self.max_length, emb.shape[1] - 217)] = emb[0, 217:][ + : self.max_length + ] + + masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones( + (min(self.max_length, emb.shape[1] - 217)), + dtype=torch.long, + device=torch.cuda.current_device(), + ) + return embs, masks + +class Step1xEditEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): + Step1XEdit_ENCODER_OUTPUTS_NPZ_SUFFIX = "_step1x_te.npz" + def __init__( + self, + cache_to_disk: bool, + batch_size: int, + skip_disk_cache_validity_check: bool, + is_partial: bool = False, + ) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) + + self.warn_fp8_weights = False + + def get_outputs_npz_path(self, image_abs_path: str) -> str: + 
return os.path.splitext(image_abs_path)[0] + Step1xEditEncoderOutputsCachingStrategy.Step1XEdit_ENCODER_OUTPUTS_NPZ_SUFFIX + + def is_disk_cached_outputs_expected(self, npz_path): + if not self.cache_to_disk: + return False + if not os.path.exists(npz_path): + return False + if self.skip_disk_cache_validity_check: + return True + try: + npz = np.load(npz_path) + if 'embeds' not in npz: + return False + if 'masks' not in npz: + return False + except Exception as e: + logger.error(f'Error loading file: {npz_path}') + raise e + return True + + def load_outputs_npz(self, npz_path): + data = np.load(npz_path) + embeds = data['embeds'] + masks = data['masks'] + return [embeds, masks] + + def cache_batch_outputs( + self, tokenize_strategy: TokenizeStrategy, + models: List[Any], + text_encoding_strategy: TextEncodingStrategy, + infos: List + ): + if not self.warn_fp8_weights: + if models[0].model.lm_head.weight.dtype == torch.float8_e4m3fn: + logger.warning( + "Qwen2VL model is using fp8 weights for caching. This may affect the quality of the cached outputs." 
+ ) + self.warn_fp8_weights = True + step1x_edit_encoding_strategy: Step1XEditEncodingStrategy = text_encoding_strategy + captions = [info.caption for info in infos] + images = [PIL.Image.open(info.ref_absolute_path).convert('RGB') for info in infos] + + tokens_and_masks = tokenize_strategy.tokenize(captions, images) + with torch.no_grad(): + embs, masks = step1x_edit_encoding_strategy.encode_tokens( + tokenize_strategy, models, tokens_and_masks + ) + if embs.dtype == torch.bfloat16: + embs = embs.float() + if masks.dtype == torch.bfloat16: + masks = masks.float() + + for i, info in enumerate(infos): + emb_i = embs[i] + mask_i = masks[i] + if self.cache_to_disk: + np.savez( + info.text_encoder_outputs_npz, + embeds=emb_i.cpu().numpy(), + masks=mask_i.cpu().numpy() + ) + else: + info.text_encoder_outputs = (emb_i, mask_i) + + + +from library.train_util import load_image, trim_and_resize_if_required, ImageInfo, IMAGE_TRANSFORMS +# for new_cache_latents +def step1x_load_images_and_masks_for_caching( + image_infos: List[ImageInfo], use_alpha_mask: bool, random_crop: bool +) -> Tuple[torch.Tensor, List[np.ndarray], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]: + r""" + requires image_infos to have: [absolute_path or image], bucket_reso, resized_size + + returns: image_tensor, alpha_masks, original_sizes, crop_ltrbs + + image_tensor: torch.Tensor = torch.Size([B, 3, H, W]), ...], normalized to [-1, 1] + alpha_masks: List[np.ndarray] = [np.ndarray([H, W]), ...], normalized to [0, 1] + original_sizes: List[Tuple[int, int]] = [(W, H), ...] + crop_ltrbs: List[Tuple[int, int, int, int]] = [(L, T, R, B), ...] 
+ """ + images: List[torch.Tensor] = [] + alpha_masks: List[np.ndarray] = [] + original_sizes: List[Tuple[int, int]] = [] + crop_ltrbs: List[Tuple[int, int, int, int]] = [] + + ref_images: List[torch.Tensor] = [] + ref_alpha_masks: List[np.ndarray] = [] + + for info in image_infos: + image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) + ref_image = load_image(info.ref_absolute_path, use_alpha_mask) if info.ref_image is None else np.array(info.ref_image, np.uint8) + # thanks for the authors not introducing randomness into cropping. + image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation) + ref_image, ref_original_size, ref_crop_ltrb = trim_and_resize_if_required(random_crop, ref_image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation) + + original_sizes.append(original_size) + crop_ltrbs.append(crop_ltrb) + + if use_alpha_mask: + if image.shape[2] == 4: + alpha_mask = image[:, :, 3] # [H,W] + alpha_mask = alpha_mask.astype(np.float32) / 255.0 + alpha_mask = torch.FloatTensor(alpha_mask) # [H,W] + else: + alpha_mask = torch.ones_like(image[:, :, 0], dtype=torch.float32) # [H,W] + + if ref_image.shape[2] == 4: + ref_alpha_mask = ref_image[:, :, 3] # [H,W] + ref_alpha_mask = ref_alpha_mask.astype(np.float32) / 255.0 + ref_alpha_mask = torch.FloatTensor(ref_alpha_mask) # [H,W] + else: + ref_alpha_mask = torch.ones_like(ref_image[:, :, 0], dtype=torch.float32) # [H,W] + else: + alpha_mask = None + ref_alpha_mask = None + alpha_masks.append(alpha_mask) + ref_alpha_masks.append(ref_alpha_mask) + + image = image[:, :, :3] # remove alpha channel if exists + image = IMAGE_TRANSFORMS(image) + images.append(image) + + ref_image = ref_image[:, :, :3] + ref_image = IMAGE_TRANSFORMS(ref_image) + ref_images.append(ref_image) + + img_tensor = torch.stack(images, dim=0) + 
ref_img_tensor = torch.stack(ref_images, dim=0) + return img_tensor, alpha_masks, ref_img_tensor, ref_alpha_masks, original_sizes, crop_ltrbs + +class Step1XEditLatentsCachingStrategy(LatentsCachingStrategy): + Step1XEdit_LATENTS_NPZ_SUFFIX = "_step1x_latents.npz" + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + + @property + def cache_suffix(self) -> str: + return Step1XEditLatentsCachingStrategy.Step1XEdit_LATENTS_NPZ_SUFFIX + + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: + return ( + os.path.splitext(absolute_path)[0] + + f"_{image_size[0]:04d}x{image_size[1]:04d}" + + Step1XEditLatentsCachingStrategy.Step1XEdit_LATENTS_NPZ_SUFFIX + ) + + def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): + return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True) + + def load_latents_from_disk( + self, npz_path: str, bucket_reso: Tuple[int, int] + ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution + + def _default_load_latents_from_disk( + self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int] + ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + if latents_stride is None: + key_reso_suffix = "" + else: + latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H) + key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" # e.g. 
"_32x64", HxW + + npz = np.load(npz_path) + if "latents" + key_reso_suffix not in npz: + raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}") + + latents = npz["latents" + key_reso_suffix] + ref_latents = npz["ref_latents" + key_reso_suffix] + original_size = npz["original_size" + key_reso_suffix].tolist() + crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist() + flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None + ref_flipped_latents = npz["ref_latents_flipped" + key_reso_suffix] if "ref_latents_flipped" + key_reso_suffix in npz else None + alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None + ref_alpha_mask = npz["ref_alpha_mask" + key_reso_suffix] if "ref_alpha_mask" + key_reso_suffix in npz else None + return latents, ref_latents, original_size, crop_ltrb, flipped_latents, ref_flipped_latents, alpha_mask, ref_alpha_mask + + # TODO remove circular dependency for ImageInfo + def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): + encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu") + vae_device = vae.device + vae_dtype = vae.dtype + + self._custom_cache_batch_latents( + encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True + ) + + if not train_util.HIGH_VRAM: + train_util.clean_memory_on_device(vae.device) + + def _custom_cache_batch_latents( + self, + encode_by_vae, + vae_device, + vae_dtype, + image_infos: List, + flip_aug: bool, + alpha_mask: bool, + random_crop: bool, + multi_resolution: bool = False, + ): + """ + Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common. 
+ """ + from library import train_util # import here to avoid circular import + + img_tensor, alpha_masks, ref_img_tensor, ref_alpha_masks, original_sizes, crop_ltrbs = step1x_load_images_and_masks_for_caching( + image_infos, alpha_mask, random_crop + ) + img_tensor = img_tensor.to(device=vae_device, dtype=vae_dtype) + ref_img_tensor = ref_img_tensor.to(device=vae_device, dtype=vae_dtype) + + with torch.no_grad(): + latents_tensors = encode_by_vae(img_tensor).to("cpu") + ref_latents_tensors = encode_by_vae(ref_img_tensor).to("cpu") + if flip_aug: + img_tensor = torch.flip(img_tensor, dims=[3]) + ref_img_tensor = torch.flip(ref_img_tensor, dims=[3]) + with torch.no_grad(): + flipped_latents = encode_by_vae(img_tensor).to("cpu") + ref_flipped_latents = encode_by_vae(ref_img_tensor).to("cpu") + else: + flipped_latents = [None] * len(latents_tensors) + ref_flipped_latents = [None] * len(ref_latents_tensors) + + # for info, latents, flipped_latent, alpha_mask in zip(image_infos, latents_tensors, flipped_latents, alpha_masks): + for i in range(len(image_infos)): + info = image_infos[i] + latents = latents_tensors[i] + ref_latents = ref_latents_tensors[i] + flipped_latent = flipped_latents[i] + ref_flipped_latent = ref_flipped_latents[i] + alpha_mask = alpha_masks[i] + ref_alpha_mask = ref_alpha_masks[i] + original_size = original_sizes[i] + crop_ltrb = crop_ltrbs[i] + + latents_size = latents.shape[1:3] # H, W + + key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" if multi_resolution else "" # e.g. 
"_32x64", HxW + + if self.cache_to_disk: + self.save_latents_to_disk( + info.latents_npz, latents, ref_latents, + original_size, crop_ltrb, + flipped_latent, ref_flipped_latent, + alpha_mask, ref_alpha_mask, + key_reso_suffix + ) + else: + info.latents_original_size = original_size + info.latents_crop_ltrb = crop_ltrb + info.latents = latents + info.ref_latents = ref_latents + if flip_aug: + info.latents_flipped = flipped_latent + info.ref_latents_flipped = ref_flipped_latent + info.alpha_mask = alpha_mask + + def save_latents_to_disk( + self, + npz_path, + latents_tensor, + ref_latents_tensor, + original_size, + crop_ltrb, + flipped_latents_tensor=None, + ref_flipped_latents_tensor=None, + alpha_mask=None, + ref_alpha_mask=None, + key_reso_suffix="", + ): + kwargs = {} + + if os.path.exists(npz_path): + # load existing npz and update it + npz = np.load(npz_path) + for key in npz.files: + kwargs[key] = npz[key] + + kwargs["latents" + key_reso_suffix] = latents_tensor.float().cpu().numpy() + kwargs["ref_latents" + key_reso_suffix] = ref_latents_tensor.float().cpu().numpy() + kwargs["original_size" + key_reso_suffix] = np.array(original_size) + kwargs["crop_ltrb" + key_reso_suffix] = np.array(crop_ltrb) + if flipped_latents_tensor is not None: + kwargs["latents_flipped" + key_reso_suffix] = flipped_latents_tensor.float().cpu().numpy() + kwargs["ref_latents_flipped" + key_reso_suffix] = ref_flipped_latents_tensor.float().cpu().numpy() + if alpha_mask is not None: + kwargs["alpha_mask" + key_reso_suffix] = alpha_mask.float().cpu().numpy() + kwargs["ref_alpha_mask" + key_reso_suffix] = ref_alpha_mask.float().cpu().numpy() + np.savez(npz_path, **kwargs) \ No newline at end of file diff --git a/library/train_util.py b/library/train_util.py index f0f6e28..97c1d92 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2080,10 +2080,15 @@ def __init__( if os.path.exists(image_key): abs_path = image_key else: - # わりといい加減だがいい方法が思いつかん - paths = 
glob_images(subset.image_dir, image_key) - if len(paths) > 0: - abs_path = paths[0] + # try direct join with image_dir (handles keys with extension) + direct_path = os.path.join(subset.image_dir, image_key) + if os.path.exists(direct_path): + abs_path = direct_path + else: + # わりといい加減だがいい方法が思いつかん + paths = glob_images(subset.image_dir, os.path.splitext(image_key)[0]) + if len(paths) > 0: + abs_path = paths[0] # なければnpzを探す if abs_path is None: @@ -2310,10 +2315,15 @@ def __init__( if os.path.exists(image_key): abs_path = image_key else: - # わりといい加減だがいい方法が思いつかん - paths = glob_images(subset.image_dir, image_key) - if len(paths) > 0: - abs_path = paths[0] + # try direct join with image_dir (handles keys with extension) + direct_path = os.path.join(subset.image_dir, image_key) + if os.path.exists(direct_path): + abs_path = direct_path + else: + # わりといい加減だがいい方法が思いつかん + paths = glob_images(subset.image_dir, os.path.splitext(image_key)[0]) + if len(paths) > 0: + abs_path = paths[0] # なければnpzを探す if abs_path is None: diff --git a/modules/conditioner.py b/modules/conditioner.py index cd1db96..9f153c7 100644 --- a/modules/conditioner.py +++ b/modules/conditioner.py @@ -63,7 +63,7 @@ def split_string(s): return result class Qwen25VL_7b_Embedder(torch.nn.Module): - def __init__(self, model_path, max_length=640, dtype=torch.bfloat16, device='cuda'): + def __init__(self, model_path, max_length=1280, dtype=torch.bfloat16, device='cuda'): super(Qwen25VL_7b_Embedder, self).__init__() self.max_length = max_length self.dtype = dtype @@ -200,7 +200,7 @@ def split_string(s): class Qwen25VL_7b_Embedder_backup(torch.nn.Module): - def __init__(self, model_path, max_length=640, dtype=torch.bfloat16, device="cuda"): + def __init__(self, model_path, max_length=1280, dtype=torch.bfloat16, device="cuda"): super(Qwen25VL_7b_Embedder_backup, self).__init__() self.max_length = max_length @@ -355,4 +355,4 @@ def split_string(s): device=torch.cuda.current_device(), ) - return embs, masks \ No newline 
at end of file + return embs, masks diff --git a/modules/conditioner_v1.py b/modules/conditioner_v1.py new file mode 100644 index 0000000..cd1db96 --- /dev/null +++ b/modules/conditioner_v1.py @@ -0,0 +1,358 @@ +import torch +from qwen_vl_utils import process_vision_info +from transformers import ( + AutoProcessor, + Qwen2VLForConditionalGeneration, + Qwen2_5_VLForConditionalGeneration, +) +from torchvision.transforms import ToPILImage + +to_pil = ToPILImage() + +Qwen25VL_7b_PREFIX = '''Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt: +- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes. +- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n +Here are examples of how to transform or refine prompts: +- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers. 
+- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\n +Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations: +User Prompt:''' + + +def split_string(s): + # 将中文引号替换为英文引号 + s = s.replace("“", '"').replace("”", '"') # use english quotes + result = [] + # 标记是否在引号内 + in_quotes = False + temp = "" + + # 遍历字符串中的每个字符及其索引 + for idx, char in enumerate(s): + # 如果字符是引号且索引大于 155 + if char == '"' and idx > 155: + # 将引号添加到临时字符串 + temp += char + # 如果不在引号内 + if not in_quotes: + # 将临时字符串添加到结果列表 + result.append(temp) + # 清空临时字符串 + temp = "" + + # 切换引号状态 + in_quotes = not in_quotes + continue + # 如果在引号内 + if in_quotes: + # 如果字符是空格 + if char.isspace(): + pass # have space token + + # 将字符用中文引号包裹后添加到结果列表 + result.append("“" + char + "”") + else: + # 将字符添加到临时字符串 + temp += char + + # 如果临时字符串不为空 + if temp: + # 将临时字符串添加到结果列表 + result.append(temp) + + return result + +class Qwen25VL_7b_Embedder(torch.nn.Module): + def __init__(self, model_path, max_length=640, dtype=torch.bfloat16, device='cuda'): + super(Qwen25VL_7b_Embedder, self).__init__() + self.max_length = max_length + self.dtype = dtype + self.device = device + + self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained( + model_path, + torch_dtype=dtype, + attn_implementation="sdpa",#目前还没装好 + ).to(torch.cuda.current_device()) + + self.model.requires_grad_(False) + self.processor = AutoProcessor.from_pretrained(model_path,min_pixels = 256 * 28 * 28, max_pixels = 324 * 28 * 28) + + self.prefix = '''Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. 
Evaluate the level of detail in the user prompt: +- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes. +- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n +Here are examples of how to transform or refine prompts: +- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers. +- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\n +Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations: +User Prompt:''' + + def forward(self, caption, ref_images): + text_list=caption + embs = torch.zeros(len(text_list),self.max_length, self.model.config.hidden_size, dtype=torch.bfloat16, device=torch.cuda.current_device()) + hidden_states = torch.zeros(len(text_list),self.max_length, self.model.config.hidden_size, dtype=torch.bfloat16, device=torch.cuda.current_device()) + masks = torch.zeros(len(text_list),self.max_length, dtype=torch.long, device=torch.cuda.current_device()) + input_ids_list = [] + attention_mask_list = [] + emb_list = [] + + def split_string(s): + s = s.replace("'", '"').replace("“", '"').replace("”", '"') # use english quotes + result = [] + in_quotes = False + temp = "" + + for idx,char in enumerate(s): + if char == '"' and idx>155: + temp += char + if not in_quotes: + result.append(temp) + temp = "" + + in_quotes = not in_quotes + continue + if in_quotes: + if char.isspace(): + pass # have space token + + result.append("“" + char + "”") + else: + temp += char + + if temp: + result.append(temp) + + return result + + for idx, (txt, 
imgs) in enumerate(zip(text_list, ref_images)): + # image_list = [] + # for idx_img in imgs: + # image_list.append(to_pil(idx_img)) + + messages = [ + { + "role": "user", + "content": [] + } + ] + + messages[0]["content"].append({"type": "text", "text": f"{self.prefix}"}) + # 先添加所有的 image + # messages[0]["content"].extend([{"type": "image", "image": img} for img in image_list]) + messages[0]['content'].append({"type": "image", "image": to_pil(imgs)}) + + # 再添加 text + messages[0]["content"].append({"type": "text", "text": f"{txt}"}) + + # Preparation for inference + text = self.processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, add_vision_id=True + ) + + image_inputs, video_inputs = process_vision_info(messages) + + inputs = self.processor( + text=[text], + images=image_inputs, + padding=True, + return_tensors="pt", + ) + + old_inputs_ids = inputs.input_ids + text_split_list = split_string(text) + + token_list = [] + for text_each in text_split_list: + txt_inputs = self.processor( + text=text_each, + images=None, + videos=None, + padding=True, + return_tensors="pt", + ) + token_each=txt_inputs.input_ids + if token_each[0][0] == 2073 and token_each[0][-1] == 854: + token_each = token_each[:,1:-1] + token_list.append(token_each) + else: + token_list.append(token_each) + + new_txt_ids=torch.cat(token_list,dim=1).to("cuda") + + new_txt_ids = new_txt_ids.to(old_inputs_ids.device) + idx1 = (old_inputs_ids == 151653).nonzero(as_tuple=True)[1][0] + idx2 = (new_txt_ids == 151653).nonzero(as_tuple=True)[1][0] + inputs.input_ids = torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]],dim=0).unsqueeze(0).to("cuda") + inputs.attention_mask= (inputs.input_ids>0).long().to("cuda") + outputs = self.model(input_ids = inputs.input_ids, attention_mask = inputs.attention_mask, pixel_values = inputs.pixel_values.to("cuda"), image_grid_thw = inputs.image_grid_thw.to("cuda"), output_hidden_states=True) + # outputs = self.model.base_model(input_ids 
= inputs.input_ids, attention_mask = inputs.attention_mask, pixel_values = inputs.pixel_values.to("cuda"), image_grid_thw = inputs.image_grid_thw.to("cuda"), output_hidden_states=True) + + emb = outputs['hidden_states'][-1] + # hidden_state = output['hidden_states'][8] + + embs[idx,:min(self.max_length,emb.shape[1]-217)] = emb[0,217:][:self.max_length] + # hidden_states[idx,:min(self.max_length,hidden_state.shape[1]-217)] = hidden_state[0,217:][:self.max_length] + + masks[idx,:min(self.max_length,emb.shape[1]-217)]=torch.ones((min(self.max_length,emb.shape[1]-217)), dtype=torch.long, device=torch.cuda.current_device()) + + return embs,masks + + + +class Qwen25VL_7b_Embedder_backup(torch.nn.Module): + def __init__(self, model_path, max_length=640, dtype=torch.bfloat16, device="cuda"): + super(Qwen25VL_7b_Embedder_backup, self).__init__() + self.max_length = max_length + + self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained( + model_path, + torch_dtype=dtype, + attn_implementation="flash_attention_2", + ).to(torch.cuda.current_device()) + + self.model.requires_grad_(False) + self.processor = AutoProcessor.from_pretrained( + model_path, min_pixels=256 * 28 * 28, max_pixels=324 * 28 * 28 + ) + + self.prefix = Qwen25VL_7b_PREFIX + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + def forward(self, caption, ref_images): + text_list = caption + embs = torch.zeros( + len(text_list), + self.max_length, + self.model.config.hidden_size, + dtype=torch.bfloat16, + device=torch.cuda.current_device(), + ) + hidden_states = torch.zeros( + len(text_list), + self.max_length, + self.model.config.hidden_size, + dtype=torch.bfloat16, + device=torch.cuda.current_device(), + ) + masks = torch.zeros( + len(text_list), + self.max_length, + dtype=torch.long, + device=torch.cuda.current_device(), + ) + input_ids_list = [] + attention_mask_list = [] + 
emb_list = [] + + def split_string(s): + s = s.replace("“", '"').replace("”", '"').replace("'", '''"''') # use english quotes + result = [] + in_quotes = False + temp = "" + + for idx,char in enumerate(s): + if char == '"' and idx>155: + temp += char + if not in_quotes: + result.append(temp) + temp = "" + + in_quotes = not in_quotes + continue + if in_quotes: + if char.isspace(): + pass # have space token + + result.append("“" + char + "”") + else: + temp += char + + if temp: + result.append(temp) + + return result + + for idx, (txt, imgs) in enumerate(zip(text_list, ref_images)): + + messages = [{"role": "user", "content": []}] + + messages[0]["content"].append({"type": "text", "text": f"{self.prefix}"}) + + messages[0]["content"].append({"type": "image", "image": to_pil(imgs)}) + + # 再添加 text + messages[0]["content"].append({"type": "text", "text": f"{txt}"}) + + # Preparation for inference + text = self.processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, add_vision_id=True + ) + + image_inputs, video_inputs = process_vision_info(messages) + + inputs = self.processor( + text=[text], + images=image_inputs, + padding=True, + return_tensors="pt", + ) + + old_inputs_ids = inputs.input_ids + text_split_list = split_string(text) + + token_list = [] + for text_each in text_split_list: + txt_inputs = self.processor( + text=text_each, + images=None, + videos=None, + padding=True, + return_tensors="pt", + ) + token_each = txt_inputs.input_ids + if token_each[0][0] == 2073 and token_each[0][-1] == 854: + token_each = token_each[:, 1:-1] + token_list.append(token_each) + else: + token_list.append(token_each) + + new_txt_ids = torch.cat(token_list, dim=1).to("cuda") + + new_txt_ids = new_txt_ids.to(old_inputs_ids.device) + + idx1 = (old_inputs_ids == 151653).nonzero(as_tuple=True)[1][0] + idx2 = (new_txt_ids == 151653).nonzero(as_tuple=True)[1][0] + inputs.input_ids = ( + torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0) + 
.unsqueeze(0) + .to("cuda") + ) + inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda") + outputs = self.model( + input_ids=inputs.input_ids, + attention_mask=inputs.attention_mask, + pixel_values=inputs.pixel_values.to("cuda"), + image_grid_thw=inputs.image_grid_thw.to("cuda"), + output_hidden_states=True, + ) + + emb = outputs["hidden_states"][-1] + + embs[idx, : min(self.max_length, emb.shape[1] - 217)] = emb[0, 217:][ + : self.max_length + ] + + masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones( + (min(self.max_length, emb.shape[1] - 217)), + dtype=torch.long, + device=torch.cuda.current_device(), + ) + + return embs, masks \ No newline at end of file diff --git a/scripts/build_step1x_prompt_cache.py b/scripts/build_step1x_prompt_cache.py new file mode 100644 index 0000000..d345f5f --- /dev/null +++ b/scripts/build_step1x_prompt_cache.py @@ -0,0 +1,576 @@ +from __future__ import annotations + +import argparse +import json +import os +import random +import re +import shutil +from collections import defaultdict +from pathlib import Path +from typing import Dict, Iterable, List, Tuple + +import numpy as np +import torch +from PIL import Image +from tqdm.auto import tqdm +from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration + +from qwen_vl_utils import process_vision_info + + +REPO_ROOT = Path("/scratch3/f007yzf/repos/Step1X-Edit-clean") +DATA_ROOT = REPO_ROOT / "training_data" +EXPERIMENT_ROOT = REPO_ROOT / "training_6k" +MODEL_PATH = Path("/scratch3/f007yzf/models/step1x_v11/Qwen2.5-VL-7B-Instruct") +OUTPUT_METADATA_DIR = EXPERIMENT_ROOT / "metadata" +EXPERIMENT_ROOT.mkdir(parents=True, exist_ok=True) +OUTPUT_METADATA_DIR.mkdir(parents=True, exist_ok=True) + +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32 +MAX_NEW_TOKENS = 150 +MAX_LENGTH = 640 +MASTER_SEED = 20260307 +SET1_SEED = MASTER_SEED +SET2_SEED = MASTER_SEED + 1 +PAIRS = 
[f"{i:02d}" for i in range(5)] +PREFIXES = ["data", "data_seed_2", "data_seed_3", "data_seed_4", "data_seed_5"] + +SINGLE_IMG1_PATH = DATA_ROOT / "source_img" / "data__p0000_pair00_s0__source.png" +SINGLE_IMG2_PATH = DATA_ROOT / "reference_img" / "left" / "data__p0000_pair00_s0.png" +SINGLE_PREVIEW_OUTPUT = OUTPUT_METADATA_DIR / "llm_prompt_preview_v1.json" + +DUAL_IMAGE_CAPTION_PROMPT = """You are analyzing facial expressions for a controlled editing task. +Given: +- Image 1: source face to be edited +- Image 2: target expression reference + +Output a structured expression editing plan +""" + +EMBEDDER_PREFIX = """Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt: +- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes. +- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating. + +Here are examples of how to transform or refine prompts: +- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers. +- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers. 
+ +Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations: +User Prompt:""" + +IMAGE_RE = re.compile(r"^(data(?:_seed_\d+)?)__p(\d+)_pair(\d+)_s(\d+)(?:__(source|target))?\.png$") + +ROOTS = { + "flux_neg": DATA_ROOT / "source_img", + "flux_pos": DATA_ROOT / "target_img", + "iy_pos": DATA_ROOT / "reference_img" / "left", + "iy_neg": DATA_ROOT / "reference_img" / "right", +} + + +print("DEVICE =", DEVICE) +print("MODEL_PATH =", MODEL_PATH) +print("OUTPUT_METADATA_DIR =", OUTPUT_METADATA_DIR) +print("EXPERIMENT_ROOT =", EXPERIMENT_ROOT) + +print("Loading processor...") +processor = AutoProcessor.from_pretrained( + MODEL_PATH, + min_pixels=256 * 28 * 28, + max_pixels=324 * 28 * 28, +) +print("Loading Qwen2.5-VL model...") +model = Qwen2_5_VLForConditionalGeneration.from_pretrained( + MODEL_PATH, + torch_dtype=DTYPE, + device_map=DEVICE, +) +model.requires_grad_(False) +model.eval() +HIDDEN_SIZE = model.config.hidden_size +print("Model loaded. 
hidden_size =", HIDDEN_SIZE) + + +def load_pil(path: Path) -> Image.Image: + return Image.open(path).convert("RGB") + + +def split_string_for_embedder(text: str) -> List[str]: + text = text.replace("'", '"').replace("“", '"').replace("”", '"') + result = [] + in_quotes = False + temp = "" + for idx, char in enumerate(text): + if char == '"' and idx > 155: + temp += char + if not in_quotes: + result.append(temp) + temp = "" + in_quotes = not in_quotes + continue + if in_quotes: + result.append("“" + char + "”") + else: + temp += char + if temp: + result.append(temp) + return result + + +@torch.inference_mode() +def build_dual_image_prompt( + img1_path: Path, img2_path: Path, user_prompt: str = "", max_new_tokens: int = MAX_NEW_TOKENS +) -> str: + img1 = load_pil(img1_path) + img2 = load_pil(img2_path) + messages = [{ + "role": "user", + "content": [ + {"type": "text", "text": DUAL_IMAGE_CAPTION_PROMPT}, + {"type": "text", "text": "\n[Source Image (Structure/Identity)]:"}, + {"type": "image", "image": img1}, + {"type": "text", "text": "\n[Reference Image (Style/Expression)]:"}, + {"type": "image", "image": img2}, + {"type": "text", "text": f"\nUser prompt: {user_prompt}"}, + {"type": "text", "text": "\nPlease generate the structured editing plan now."}, + ], + }] + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, add_vision_id=True) + image_inputs, _ = process_vision_info(messages) + inputs = processor(text=[text], images=image_inputs, padding=True, return_tensors="pt").to(DEVICE) + generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False) + generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)] + generated_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True)[0].strip() + return generated_text or user_prompt + + +@torch.inference_mode() +def build_prompt_embedding(img1_path: Path, prompt: str, max_length: int = 
MAX_LENGTH) -> Tuple[np.ndarray, np.ndarray]: + img1 = load_pil(img1_path) + messages = [{ + "role": "user", + "content": [ + {"type": "text", "text": EMBEDDER_PREFIX}, + {"type": "image", "image": img1}, + {"type": "text", "text": prompt}, + ], + }] + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, add_vision_id=True) + image_inputs, _ = process_vision_info(messages) + inputs = processor(text=[text], images=image_inputs, padding=True, return_tensors="pt") + old_input_ids = inputs.input_ids + token_list = [] + for text_each in split_string_for_embedder(text): + txt_inputs = processor(text=text_each, images=None, videos=None, padding=True, return_tensors="pt") + token_each = txt_inputs.input_ids + if token_each[0][0] == 2073 and token_each[0][-1] == 854: + token_each = token_each[:, 1:-1] + token_list.append(token_each) + new_txt_ids = torch.cat(token_list, dim=1).to(old_input_ids.device) + idx1 = (old_input_ids == 151653).nonzero(as_tuple=True)[1][0] + idx2 = (new_txt_ids == 151653).nonzero(as_tuple=True)[1][0] + input_ids = torch.cat([old_input_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0).unsqueeze(0) + attention_mask = (input_ids > 0).long() + outputs = model( + input_ids=input_ids.to(DEVICE), + attention_mask=attention_mask.to(DEVICE), + pixel_values=inputs.pixel_values.to(DEVICE), + image_grid_thw=inputs.image_grid_thw.to(DEVICE), + output_hidden_states=True, + ) + emb = outputs.hidden_states[-1] + embeds = torch.zeros((max_length, HIDDEN_SIZE), dtype=torch.bfloat16, device=DEVICE) + masks = torch.zeros((max_length,), dtype=torch.long, device=DEVICE) + usable = max(0, emb.shape[1] - 217) + length = min(max_length, usable) + if length > 0: + embeds[:length] = emb[0, 217:217 + length] + masks[:length] = 1 + return embeds.to(torch.float32).cpu().numpy(), masks.cpu().numpy() + + +@torch.inference_mode() +def run_single_preview( + img1_path: Path = SINGLE_IMG1_PATH, + img2_path: Path = SINGLE_IMG2_PATH, + output_json: Path 
= SINGLE_PREVIEW_OUTPUT, +): + prompt = build_dual_image_prompt(img1_path, img2_path) + embeds, masks = build_prompt_embedding(img1_path, prompt) + payload = { + "img1_path": str(img1_path.resolve()), + "img2_path": str(img2_path.resolve()), + "prompt": prompt, + "embedding_shape": list(embeds.shape), + "mask_sum": int(masks.sum()), + } + output_json.parent.mkdir(parents=True, exist_ok=True) + output_json.write_text(json.dumps(payload, ensure_ascii=False, indent=2)) + return payload + + +def candidate_filename(prefix: str, person_id: str, pair_id: str, kind: str) -> str: + stem = f"{prefix}__p{person_id}_pair{pair_id}_s0" + if kind == "flux_neg": + return stem + "__source.png" + if kind == "flux_pos": + return stem + "__target.png" + if kind in {"iy_pos", "iy_neg"}: + return stem + ".png" + raise KeyError(kind) + + +def resolve_image_paths(prefix: str, person_id: str, pair_id: str) -> Dict[str, Path]: + return {kind: ROOTS[kind] / candidate_filename(prefix, person_id, pair_id, kind) for kind in ROOTS} + + +def mirror_path(original_path: Path) -> Path: + rel = original_path.relative_to(DATA_ROOT) + return EXPERIMENT_ROOT / rel + + +def ensure_mirror_file(original_path: Path) -> Path: + mirrored_path = mirror_path(original_path) + mirrored_path.parent.mkdir(parents=True, exist_ok=True) + if mirrored_path.exists(): + return mirrored_path + try: + mirrored_path.symlink_to(original_path) + except OSError: + shutil.copy2(original_path, mirrored_path) + return mirrored_path + + +def person_complete(prefix: str, person_id: str) -> bool: + for pair_id in PAIRS: + paths = resolve_image_paths(prefix, person_id, pair_id) + if not all(path.exists() for path in paths.values()): + return False + return True + + +def scan_candidates() -> Dict[str, List[str]]: + seen = defaultdict(set) + for root in ROOTS.values(): + for path in root.glob("*.png"): + m = IMAGE_RE.match(path.name) + if not m: + continue + prefix, person_id, pair_id, s, _ = m.groups() + if prefix not in PREFIXES or s 
!= "0": + continue + seen[prefix].add(person_id) + candidates = {} + for prefix in PREFIXES: + valid = sorted(pid for pid in seen[prefix] if person_complete(prefix, pid)) + candidates[prefix] = valid + return candidates + + +def build_person_to_prefixes(candidates_by_prefix: Dict[str, List[str]]) -> Dict[str, List[str]]: + person_to_prefixes = defaultdict(list) + for prefix, persons in candidates_by_prefix.items(): + for pid in persons: + person_to_prefixes[pid].append(prefix) + return {pid: sorted(prefixes) for pid, prefixes in person_to_prefixes.items()} + + +def hopcroft_karp(graph: Dict[str, List[str]], left_nodes: List[str], right_nodes: List[str]): + inf = 10**9 + pair_u = {u: None for u in left_nodes} + pair_v = {v: None for v in right_nodes} + dist = {} + + from collections import deque + + def bfs(): + queue = deque() + for u in left_nodes: + if pair_u[u] is None: + dist[u] = 0 + queue.append(u) + else: + dist[u] = inf + found = False + while queue: + u = queue.popleft() + for v in graph[u]: + pu = pair_v[v] + if pu is None: + found = True + elif dist[pu] == inf: + dist[pu] = dist[u] + 1 + queue.append(pu) + return found + + def dfs(u): + for v in graph[u]: + pu = pair_v[v] + if pu is None or (dist[pu] == dist[u] + 1 and dfs(pu)): + pair_u[u] = v + pair_v[v] = u + return True + dist[u] = inf + return False + + matching = 0 + while bfs(): + for u in left_nodes: + if pair_u[u] is None and dfs(u): + matching += 1 + return matching, pair_u + + +def assign_people_to_prefixes(candidates_by_prefix: Dict[str, List[str]], seed: int) -> Dict[str, List[str]]: + person_to_prefixes = build_person_to_prefixes(candidates_by_prefix) + left_nodes = sorted(person_to_prefixes) + if len(left_nodes) != 600: + raise RuntimeError(f"Expected 600 unique person ids, got {len(left_nodes)}") + rng = random.Random(seed) + slot_map = {} + right_nodes = [] + for prefix in PREFIXES: + slot_names = [f"{prefix}#{idx:03d}" for idx in range(120)] + rng.shuffle(slot_names) + slot_map[prefix] 
= slot_names + right_nodes.extend(slot_names) + graph = {} + shuffled_left = left_nodes[:] + rng.shuffle(shuffled_left) + for pid in shuffled_left: + neighbors = [] + prefixes = person_to_prefixes[pid][:] + rng.shuffle(prefixes) + for prefix in prefixes: + neighbors.extend(slot_map[prefix]) + graph[pid] = neighbors + matching, pair_u = hopcroft_karp(graph, shuffled_left, right_nodes) + if matching != 600: + raise RuntimeError(f"Failed to assign 600 people to prefix slots, only matched {matching}") + result = defaultdict(list) + for pid, slot in pair_u.items(): + prefix = slot.split("#", 1)[0] + result[prefix].append(pid) + for prefix in PREFIXES: + result[prefix] = sorted(result[prefix]) + if len(result[prefix]) != 120: + raise RuntimeError(f"Prefix {prefix} expected 120 people, got {len(result[prefix])}") + return dict(result) + + +def split_prefix_people( + prefix_people: Dict[str, List[str]], seed: int, first_name: str, second_name: str +) -> Dict[str, Dict[str, List[str]]]: + rng = random.Random(seed) + result = {} + for prefix, people in prefix_people.items(): + shuffled = people[:] + rng.shuffle(shuffled) + result[prefix] = { + first_name: sorted(shuffled[:60]), + second_name: sorted(shuffled[60:120]), + } + return result + + +def save_step1x_npz( + target_image_path: Path, embeds: np.ndarray, masks: np.ndarray, force_overwrite: bool = False +) -> Path: + npz_path = target_image_path.with_suffix("") + npz_path = npz_path.parent / f"{npz_path.name}_step1x_te.npz" + if npz_path.exists() and not force_overwrite: + existing = np.load(npz_path) + if "embeds" in existing and "masks" in existing: + return npz_path + np.savez(npz_path, embeds=embeds, masks=masks) + return npz_path + + +def build_record(set_id: str, direction: str, prefix: str, person_id: str, pair_id: str) -> Dict[str, str]: + paths = resolve_image_paths(prefix, person_id, pair_id) + if set_id == "set1" and direction == "neg_to_pos": + img1 = paths["flux_neg"] + img2 = paths["iy_pos"] + 
ref_image_path = paths["flux_neg"] + target_image_path = paths["flux_pos"] + elif set_id == "set1" and direction == "pos_to_neg": + img1 = paths["flux_pos"] + img2 = paths["iy_neg"] + ref_image_path = paths["flux_pos"] + target_image_path = paths["flux_neg"] + elif set_id == "set2" and direction == "pos_to_neg": + img1 = paths["iy_pos"] + img2 = paths["flux_neg"] + ref_image_path = paths["iy_pos"] + target_image_path = paths["iy_neg"] + elif set_id == "set2" and direction == "neg_to_pos": + img1 = paths["iy_neg"] + img2 = paths["flux_pos"] + ref_image_path = paths["iy_neg"] + target_image_path = paths["iy_pos"] + else: + raise ValueError((set_id, direction)) + return { + "set_id": set_id, + "direction": direction, + "prefix_family": prefix, + "person_id": person_id, + "pair_id": pair_id, + "img1_path": str(img1.resolve()), + "img2_path": str(img2.resolve()), + "ref_image_path": str(ref_image_path.resolve()), + "target_image_path": str(target_image_path.resolve()), + } + + +def expand_split_to_records(split: Dict[str, Dict[str, List[str]]], set_id: str) -> List[Dict[str, str]]: + records = [] + for prefix in PREFIXES: + for direction, people in split[prefix].items(): + for person_id in people: + for pair_id in PAIRS: + records.append(build_record(set_id, direction, prefix, person_id, pair_id)) + return records + + +def write_json(path: Path, payload) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(payload, ensure_ascii=False, indent=2)) + + +def write_jsonl(path: Path, rows: Iterable[dict]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as f: + for row in rows: + f.write(json.dumps(row, ensure_ascii=False) + "\n") + + +def generate_dataset(records: List[Dict[str, str]], force_overwrite_npz: bool = False) -> Tuple[dict, List[dict]]: + metadata = {} + audit_rows = [] + for record in tqdm(records): + img1 = Path(record["img1_path"]) + img2 = Path(record["img2_path"]) + 
ref_image_path = Path(record["ref_image_path"]) + target_image_path = Path(record["target_image_path"]) + + mirrored_img1_path = ensure_mirror_file(img1) + mirrored_img2_path = ensure_mirror_file(img2) + mirrored_ref_image_path = ensure_mirror_file(ref_image_path) + mirrored_target_image_path = ensure_mirror_file(target_image_path) + + prompt = build_dual_image_prompt(img1, img2) + embeds, masks = build_prompt_embedding(img1, prompt) + npz_path = save_step1x_npz(mirrored_target_image_path, embeds, masks, force_overwrite=force_overwrite_npz) + key = str(mirrored_target_image_path.resolve()) + if key in metadata: + raise RuntimeError(f"Duplicate target image key: {key}") + metadata[key] = { + "ref_image_path": str(mirrored_ref_image_path.resolve()), + "caption": prompt, + } + audit_row = dict(record) + audit_row["caption"] = prompt + audit_row["original_img1_path"] = str(img1.resolve()) + audit_row["original_img2_path"] = str(img2.resolve()) + audit_row["original_ref_image_path"] = str(ref_image_path.resolve()) + audit_row["original_target_image_path"] = str(target_image_path.resolve()) + audit_row["mirrored_img1_path"] = str(mirrored_img1_path.resolve()) + audit_row["mirrored_img2_path"] = str(mirrored_img2_path.resolve()) + audit_row["mirrored_ref_image_path"] = str(mirrored_ref_image_path.resolve()) + audit_row["mirrored_target_image_path"] = str(mirrored_target_image_path.resolve()) + audit_row["embedding_npz_path"] = str(npz_path.resolve()) + audit_rows.append(audit_row) + return metadata, audit_rows + + +def build_splits(): + candidates_by_prefix = scan_candidates() + for prefix in PREFIXES: + print(prefix, len(candidates_by_prefix[prefix])) + set1_prefix_people = assign_people_to_prefixes(candidates_by_prefix, SET1_SEED) + set2_prefix_people = assign_people_to_prefixes(candidates_by_prefix, SET2_SEED) + set1_split = split_prefix_people(set1_prefix_people, SET1_SEED + 100, "neg_to_pos", "pos_to_neg") + set2_split = split_prefix_people(set2_prefix_people, 
SET2_SEED + 100, "pos_to_neg", "neg_to_pos") + return set1_split, set2_split + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--run-set1", action="store_true") + parser.add_argument("--run-set2", action="store_true") + parser.add_argument("--run-all", action="store_true") + parser.add_argument("--run-single-preview", action="store_true") + parser.add_argument("--force-overwrite-npz", action="store_true") + return parser.parse_args() + + +def main(): + args = parse_args() + + if args.run_single_preview: + preview = run_single_preview() + print(json.dumps(preview, ensure_ascii=False, indent=2)) + return + + set1_split, set2_split = build_splits() + set1_records = expand_split_to_records(set1_split, "set1") + set2_records = expand_split_to_records(set2_split, "set2") + assert len(set1_records) == 3000 + assert len(set2_records) == 3000 + all_records = set1_records + set2_records + assert len(all_records) == 6000 + + split_output = OUTPUT_METADATA_DIR / "connector_6000_split_v1.json" + set1_output = OUTPUT_METADATA_DIR / "connector_set1_flux2_3000_v1.json" + set2_output = OUTPUT_METADATA_DIR / "connector_set2_infiniteyou_3000_v1.json" + combined_output = OUTPUT_METADATA_DIR / "connector_6000_train_v1.json" + audit_output = OUTPUT_METADATA_DIR / "connector_6000_audit_v1.jsonl" + + split_payload = { + "master_seed": MASTER_SEED, + "set1_seed": SET1_SEED, + "set2_seed": SET2_SEED, + "prefixes": PREFIXES, + "pairs": PAIRS, + "set1": set1_split, + "set2": set2_split, + } + write_json(split_output, split_payload) + print("Prepared split payload at", split_output) + + run_all = args.run_all or (not args.run_set1 and not args.run_set2) + + if run_all or args.run_set1: + set1_metadata, set1_audit = generate_dataset(set1_records, force_overwrite_npz=args.force_overwrite_npz) + write_json(set1_output, set1_metadata) + print("set1 samples =", len(set1_metadata)) + print("Saved:", set1_output) + else: + set1_metadata, set1_audit = {}, [] + + if 
run_all or args.run_set2: + set2_metadata, set2_audit = generate_dataset(set2_records, force_overwrite_npz=args.force_overwrite_npz) + write_json(set2_output, set2_metadata) + print("set2 samples =", len(set2_metadata)) + print("Saved:", set2_output) + else: + set2_metadata, set2_audit = {}, [] + + if run_all: + combined_metadata = dict(set1_metadata) + overlap = set(combined_metadata).intersection(set2_metadata) + if overlap: + raise RuntimeError(f"Unexpected overlap between set1 and set2 target keys: {len(overlap)}") + combined_metadata.update(set2_metadata) + write_json(combined_output, combined_metadata) + write_jsonl(audit_output, set1_audit + set2_audit) + print("combined samples =", len(combined_metadata)) + print("Saved:", combined_output) + print("Saved:", audit_output) + + +if __name__ == "__main__": + main() diff --git a/scripts/finetuning_connector_v1.sh b/scripts/finetuning_connector_v1.sh new file mode 100644 index 0000000..5e529f3 --- /dev/null +++ b/scripts/finetuning_connector_v1.sh @@ -0,0 +1,48 @@ +set -euo pipefail + +echo "GPU status:" +nvidia-smi --query-gpu=index,name,memory.used,memory.total,utilization.gpu --format=csv,noheader,nounits + +GPU_ID=$(nvidia-smi --query-gpu=index,memory.free --format=csv,noheader,nounits | sort -t',' -k2 -nr | head -n1 | cut -d',' -f1 | tr -d ' +') +echo "Auto-selected GPU_ID=${GPU_ID}" + +GPU_ID="${GPU_ID:-1}" + +CUDA_VISIBLE_DEVICES="${GPU_ID}" accelerate launch \ + --main_process_port 29502 \ + --mixed_precision bf16 \ + --num_cpu_threads_per_process 1 \ + --num_processes 1 \ + --config_file ./library/accelerate_config.yaml \ + finetuning.py \ + --pretrained_model_name_or_path /scratch3/f007yzf/models/step1x_v11/step1x-edit-i1258.safetensors \ + --qwen2p5vl /scratch3/f007yzf/models/step1x_v11/Qwen2.5-VL-7B-Instruct \ + --ae /scratch3/f007yzf/models/step1x_v11/vae.safetensors \ + --cache_latents_to_disk \ + --save_model_as safetensors \ + --sdpa \ + --persistent_data_loader_workers \ + 
#!/usr/bin/env bash
# Launch Step1X-Edit connector fine-tuning on the GPU with the most free memory.
set -euo pipefail

echo "GPU status:"
nvidia-smi --query-gpu=index,name,memory.used,memory.total,utilization.gpu --format=csv,noheader,nounits

# Pick the GPU index with the largest free memory.  `tr -d ' \n'` strips the
# padding space and trailing newline from the csv field (the original embedded
# a literal newline inside the quotes, which is fragile to edit).
GPU_ID=$(nvidia-smi --query-gpu=index,memory.free --format=csv,noheader,nounits \
  | sort -t',' -k2 -nr | head -n1 | cut -d',' -f1 | tr -d ' \n')
# Apply the fallback BEFORE echoing, so a failed detection never prints an
# empty id (the original echoed first and defaulted afterwards).
GPU_ID="${GPU_ID:-1}"
echo "Auto-selected GPU_ID=${GPU_ID}"

CUDA_VISIBLE_DEVICES="${GPU_ID}" accelerate launch \
  --main_process_port 29502 \
  --mixed_precision bf16 \
  --num_cpu_threads_per_process 1 \
  --num_processes 1 \
  --config_file ./library/accelerate_config.yaml \
  finetuning.py \
  --pretrained_model_name_or_path /scratch3/f007yzf/models/step1x_v11/step1x-edit-i1258.safetensors \
  --qwen2p5vl /scratch3/f007yzf/models/step1x_v11/Qwen2.5-VL-7B-Instruct \
  --ae /scratch3/f007yzf/models/step1x_v11/vae.safetensors \
  --cache_latents_to_disk \
  --save_model_as safetensors \
  --sdpa \
  --persistent_data_loader_workers \
  --max_data_loader_n_workers 2 \
  --seed 20260307 \
  --gradient_checkpointing \
  --mixed_precision bf16 \
  --save_precision bf16 \
  --network_module library.qwen_connector_module_v1 \
  --network_train_unet_only \
  --optimizer_type adamw8bit \
  --learning_rate 1e-4 \
  --cache_text_encoder_outputs \
  --cache_text_encoder_outputs_to_disk \
  --highvram \
  --max_train_epochs 1000 \
  --save_every_n_epochs 100 \
  --dataset_config library/data_configs/step1x_edit.toml \
  --output_dir /scratch3/f007yzf/repos/Step1X-Edit-clean/output \
  --output_name step1x-edit-qwen-connector-v1 \
  --timestep_sampling shift \
  --discrete_flow_shift 3.1582 \
  --model_prediction_type raw \
  --guidance_scale 1.0 \
  --train_batch_size 1 \
  --gradient_accumulation_steps 4
按 6000 样本规则写出 metadata 与 `_step1x_te.npz`\n", + "\n", + "当前设计:\n", + "- 不再走 `双图 -> prompt 文本 -> embedding` 两段式\n", + "- 直接走 `双图 + 固定 prompt -> embedding`\n", + "- metadata 里的 `caption` 仅保留固定字符串占位\n", + "- 最终 cache 仍然兼容 Step1X 训练读取\n", + "\n", + "运行前注意:\n", + "- 这个 notebook 默认通过 `CUDA_VISIBLE_DEVICES=1` 使用物理 1 号卡\n", + "- 必须先 Restart Kernel,再从第一格开始运行,否则环境变量不会生效\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CUDA_VISIBLE_DEVICES = 1\n", + "torch.cuda.is_available() = True\n", + "torch.cuda.device_count() = 1\n", + "DEVICE = cuda:0\n", + "MODEL_PATH = /scratch3/f007yzf/models/step1x_v11/Qwen2.5-VL-7B-Instruct\n", + "OUTPUT_METADATA_DIR = /scratch3/f007yzf/repos/Step1X-Edit-clean/training_data/metadata\n" + ] + } + ], + "source": [ + "from __future__ import annotations\n", + "\n", + "import os\n", + "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n", + "\n", + "import json\n", + "import math\n", + "import random\n", + "import re\n", + "from collections import defaultdict\n", + "from pathlib import Path\n", + "from typing import Dict, Iterable, List, Tuple\n", + "\n", + "import numpy as np\n", + "import torch\n", + "from PIL import Image\n", + "from tqdm.auto import tqdm\n", + "from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration\n", + "from qwen_vl_utils import process_vision_info\n", + "\n", + "REPO_ROOT = Path('/scratch3/f007yzf/repos/Step1X-Edit-clean')\n", + "DATA_ROOT = REPO_ROOT / 'training_data'\n", + "MODEL_PATH = Path('/scratch3/f007yzf/models/step1x_v11/Qwen2.5-VL-7B-Instruct')\n", + "OUTPUT_METADATA_DIR = DATA_ROOT / 'metadata'\n", + "OUTPUT_METADATA_DIR.mkdir(parents=True, exist_ok=True)\n", + "\n", + "DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n", + "DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32\n", + "MAX_LENGTH = 1280\n", + "MASTER_SEED = 20260307\n", + "SET1_SEED = MASTER_SEED\n", + 
"SET2_SEED = MASTER_SEED + 1\n", + "PAIRS = [f'{i:02d}' for i in range(5)]\n", + "PREFIXES = ['data', 'data_seed_2', 'data_seed_3', 'data_seed_4', 'data_seed_5']\n", + "\n", + "SINGLE_IMG1_PATH = DATA_ROOT / 'source_img' / 'data__p0000_pair00_s0__source.png'\n", + "SINGLE_IMG2_PATH = DATA_ROOT / 'reference_img' / 'left' / 'data__p0000_pair00_s0.png'\n", + "SINGLE_PREVIEW_OUTPUT = OUTPUT_METADATA_DIR / 'llm_dual_embedding_preview_v1.json'\n", + "\n", + "RUN_SINGLE_PREVIEW = False\n", + "RUN_FULL_GENERATION = False\n", + "FORCE_OVERWRITE_NPZ = False\n", + "\n", + "FIXED_EMBED_PROMPT = \"\"\"You are analyzing facial expressions for a controlled editing task.\n", + "Given:\n", + "- Image 1: source face to be edited\n", + "- Image 2: target expression reference\n", + "\n", + "Output a structured expression editing plan.\n", + "Focus only on expression change. Keep identity, hairstyle, clothing, background, lighting, and camera unchanged.\n", + "Describe the target expression change clearly and concretely.\n", + "\"\"\"\n", + "\n", + "FIXED_METADATA_CAPTION = '编辑图片表情'\n", + "\n", + "IMAGE_RE = re.compile(r'^(data(?:_seed_\\d+)?)__p(\\d+)_pair(\\d+)_s(\\d+)(?:__(source|target))?\\.png$')\n", + "\n", + "print('CUDA_VISIBLE_DEVICES =', os.environ.get('CUDA_VISIBLE_DEVICES'))\n", + "print('torch.cuda.is_available() =', torch.cuda.is_available())\n", + "print('torch.cuda.device_count() =', torch.cuda.device_count())\n", + "print('DEVICE =', DEVICE)\n", + "print('MODEL_PATH =', MODEL_PATH)\n", + "print('OUTPUT_METADATA_DIR =', OUTPUT_METADATA_DIR)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. 
def load_pil(path: 'Path') -> 'Image.Image':
    """Open *path* and force a 3-channel RGB PIL image."""
    return Image.open(path).convert('RGB')

def get_vision_token_ids():
    """Return ``(vision_start_id, vision_end_id)`` for the loaded tokenizer.

    Falls back to the hard-coded ids 151652/151653 when the tokenizer cannot
    resolve the special tokens — presumably the published Qwen2-VL vocabulary
    ids; confirm against the tokenizer config before changing.
    """
    tokenizer = processor.tokenizer
    v_start_id = tokenizer.convert_tokens_to_ids('<|vision_start|>')
    v_end_id = tokenizer.convert_tokens_to_ids('<|vision_end|>')
    if v_start_id in (None, tokenizer.unk_token_id):
        v_start_id = 151652
    if v_end_id in (None, tokenizer.unk_token_id):
        v_end_id = 151653
    return v_start_id, v_end_id

@torch.inference_mode()
def build_dual_image_prompt(img1_path: 'Path', img2_path: 'Path', user_prompt: str = '') -> str:
    """Return the metadata caption for an image pair.

    In this v2 notebook the prompt is fixed, so this is a stub kept for
    interface parity with the v1 two-stage pipeline; the images are ignored.
    """
    return user_prompt or FIXED_METADATA_CAPTION

@torch.inference_mode()
def build_dual_image_embedding(
    img1_path: 'Path',
    img2_path: 'Path',
    fixed_prompt: str = FIXED_EMBED_PROMPT,
    max_length: int = MAX_LENGTH,
    return_debug: bool = False,
):
    """Encode (source, reference) images plus *fixed_prompt* in one forward pass.

    Only the hidden states AFTER the last vision block (i.e. the text tail)
    are kept, zero-padded/truncated to *max_length*.

    Returns:
        ``(embeds, masks)`` — float32 ``(max_length, HIDDEN_SIZE)`` and long
        ``(max_length,)`` numpy arrays — or ``(embeds, masks, debug)`` with a
        length-bookkeeping dict when *return_debug* is True.  (The previous
        ``Tuple[np.ndarray, np.ndarray]`` annotation was wrong for the debug
        path, so the annotation is dropped and the contract documented here.)
    """
    img1 = load_pil(img1_path)
    img2 = load_pil(img2_path)
    messages = [{
        'role': 'user',
        'content': [
            {'type': 'text', 'text': '\n[Source Image]:'},
            {'type': 'image', 'image': img1},
            {'type': 'text', 'text': '\n[Reference Image]:'},
            {'type': 'image', 'image': img2},
            {'type': 'text', 'text': f'\nTask:\n{fixed_prompt}'},
        ],
    }]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, add_vision_id=True)
    image_inputs, _ = process_vision_info(messages)
    inputs = processor(text=[text], images=image_inputs, padding=True, return_tensors='pt')
    outputs = model(
        input_ids=inputs.input_ids.to(DEVICE),
        attention_mask=inputs.attention_mask.to(DEVICE),
        pixel_values=inputs.pixel_values.to(DEVICE),
        image_grid_thw=inputs.image_grid_thw.to(DEVICE),
        output_hidden_states=True,
    )
    emb = outputs.hidden_states[-1]
    v_start_id, v_end_id = get_vision_token_ids()
    input_ids_0 = inputs.input_ids[0]
    # The text tail starts right after the SECOND <|vision_end|> (two images).
    v_ends = (input_ids_0 == v_end_id).nonzero(as_tuple=True)[0]
    if len(v_ends) < 2:
        raise RuntimeError(f'Expected at least two vision_end tokens, got {len(v_ends)}')
    text_start = int(v_ends[-1].item()) + 1
    usable = max(0, emb.shape[1] - text_start)
    length = min(max_length, usable)

    embeds = torch.zeros((max_length, HIDDEN_SIZE), dtype=torch.float32, device=DEVICE)
    masks = torch.zeros((max_length,), dtype=torch.long, device=DEVICE)
    if length > 0:
        embeds[:length] = emb[0, text_start:text_start + length].to(torch.float32)
        masks[:length] = 1

    if return_debug:
        debug = {
            'full_seq_len': int(emb.shape[1]),
            'text_start_index': text_start,
            'usable_len_before_truncation': int(usable),
            'final_len': int(length),
        }
        return embeds.cpu().numpy(), masks.cpu().numpy(), debug
    return embeds.cpu().numpy(), masks.cpu().numpy()

@torch.inference_mode()
def run_single_preview(img1_path: 'Path' = SINGLE_IMG1_PATH, img2_path: 'Path' = SINGLE_IMG2_PATH, output_json: 'Path' = SINGLE_PREVIEW_OUTPUT):
    """Build one dual-image embedding and persist a JSON preview payload.

    The payload records the prompt configuration, embedding shape, mask sum,
    and the debug length bookkeeping; it is also returned for display.
    """
    embeds, masks, debug = build_dual_image_embedding(img1_path, img2_path, return_debug=True)
    payload = {
        'img1_path': str(img1_path.resolve()),
        'img2_path': str(img2_path.resolve()),
        'fixed_prompt': FIXED_EMBED_PROMPT,
        'fixed_metadata_caption': FIXED_METADATA_CAPTION,
        'embedding_shape': list(embeds.shape),
        'mask_sum': int(masks.sum()),
        **debug,
    }
    output_json.parent.mkdir(parents=True, exist_ok=True)
    output_json.write_text(json.dumps(payload, ensure_ascii=False, indent=2))
    return payload
def person_complete(prefix: str, person_id: str) -> bool:
    """True iff every pair of this person has all four image variants on disk."""
    return all(
        path.exists()
        for pair_id in PAIRS
        for path in resolve_image_paths(prefix, person_id, pair_id).values()
    )

def scan_candidates() -> Dict[str, List[str]]:
    """Collect, per prefix family, the sorted person ids with a complete image set."""
    found = defaultdict(set)
    for root in ROOTS.values():
        for path in root.glob('*.png'):
            match = IMAGE_RE.match(path.name)
            if match is None:
                continue
            prefix, person_id, _pair_id, sample, _suffix = match.groups()
            # Only the configured prefix families and sample index 0 count.
            if prefix in PREFIXES and sample == '0':
                found[prefix].add(person_id)
    return {
        prefix: sorted(pid for pid in found[prefix] if person_complete(prefix, pid))
        for prefix in PREFIXES
    }

def build_person_to_prefixes(candidates_by_prefix: Dict[str, List[str]]) -> Dict[str, List[str]]:
    """Invert prefix->people into person->sorted list of eligible prefixes."""
    mapping = defaultdict(list)
    for prefix, persons in candidates_by_prefix.items():
        for pid in persons:
            mapping[pid].append(prefix)
    return {pid: sorted(prefixes) for pid, prefixes in mapping.items()}

def hopcroft_karp(graph: Dict[str, List[str]], left_nodes: List[str], right_nodes: List[str]):
    """Maximum bipartite matching (Hopcroft–Karp).

    *graph* maps each left node to its adjacent right nodes.  Returns
    ``(matching_size, left_to_right)`` where unmatched left nodes map to None.
    """
    from collections import deque

    INF = 10 ** 9
    match_left = {u: None for u in left_nodes}
    match_right = {v: None for v in right_nodes}
    dist = {}

    def bfs():
        # Layered BFS from all free left nodes; returns True when an
        # augmenting path to some free right node exists.
        frontier = deque()
        for u in left_nodes:
            if match_left[u] is None:
                dist[u] = 0
                frontier.append(u)
            else:
                dist[u] = INF
        reachable_free = False
        while frontier:
            u = frontier.popleft()
            for v in graph[u]:
                w = match_right[v]
                if w is None:
                    reachable_free = True
                elif dist[w] == INF:
                    dist[w] = dist[u] + 1
                    frontier.append(w)
        return reachable_free

    def dfs(u):
        # Augment along the BFS layering; mark dead ends with INF.
        for v in graph[u]:
            w = match_right[v]
            if w is None or (dist[w] == dist[u] + 1 and dfs(w)):
                match_left[u] = v
                match_right[v] = u
                return True
        dist[u] = INF
        return False

    size = 0
    while bfs():
        for u in left_nodes:
            if match_left[u] is None and dfs(u):
                size += 1
    return size, match_left

def assign_people_to_prefixes(candidates_by_prefix: Dict[str, List[str]], seed: int) -> Dict[str, List[str]]:
    """Deterministically assign 600 people to 5 prefixes, 120 slots each.

    A seeded RNG shuffles slot names, person order, and per-person prefix
    order, then Hopcroft–Karp finds a perfect matching onto the 600 slots.
    The RNG call sequence is fixed, so results reproduce for a given seed.
    """
    person_to_prefixes = build_person_to_prefixes(candidates_by_prefix)
    all_people = sorted(person_to_prefixes)
    if len(all_people) != 600:
        raise RuntimeError(f'Expected 600 unique person ids, got {len(all_people)}')
    rng = random.Random(seed)
    slots_by_prefix = {}
    all_slots = []
    for prefix in PREFIXES:
        slots = [f'{prefix}#{idx:03d}' for idx in range(120)]
        rng.shuffle(slots)
        slots_by_prefix[prefix] = slots
        all_slots.extend(slots)
    order = all_people[:]
    rng.shuffle(order)
    graph = {}
    for pid in order:
        neighbors = []
        eligible = person_to_prefixes[pid][:]
        rng.shuffle(eligible)
        for prefix in eligible:
            neighbors.extend(slots_by_prefix[prefix])
        graph[pid] = neighbors
    matched, left_match = hopcroft_karp(graph, order, all_slots)
    if matched != 600:
        raise RuntimeError(f'Failed to assign 600 people to prefix slots, only matched {matched}')
    grouped = defaultdict(list)
    for pid, slot in left_match.items():
        grouped[slot.split('#', 1)[0]].append(pid)
    for prefix in PREFIXES:
        grouped[prefix] = sorted(grouped[prefix])
        if len(grouped[prefix]) != 120:
            raise RuntimeError(f'Prefix {prefix} expected 120 people, got {len(grouped[prefix])}')
    return dict(grouped)
"\n", + "def split_prefix_people(prefix_people: Dict[str, List[str]], seed: int, first_name: str, second_name: str) -> Dict[str, Dict[str, List[str]]]:\n", + " rng = random.Random(seed)\n", + " result = {}\n", + " for prefix, people in prefix_people.items():\n", + " shuffled = people[:]\n", + " rng.shuffle(shuffled)\n", + " result[prefix] = {\n", + " first_name: sorted(shuffled[:60]),\n", + " second_name: sorted(shuffled[60:120]),\n", + " }\n", + " return result\n", + "\n", + "candidates_by_prefix = scan_candidates()\n", + "for prefix in PREFIXES:\n", + " print(prefix, len(candidates_by_prefix[prefix]))\n", + "set1_prefix_people = assign_people_to_prefixes(candidates_by_prefix, SET1_SEED)\n", + "set2_prefix_people = assign_people_to_prefixes(candidates_by_prefix, SET2_SEED)\n", + "set1_split = split_prefix_people(set1_prefix_people, SET1_SEED + 100, 'neg_to_pos', 'pos_to_neg')\n", + "set2_split = split_prefix_people(set2_prefix_people, SET2_SEED + 100, 'pos_to_neg', 'neg_to_pos')\n", + "print('set1 unique persons =', len(set().union(*[set(v) for v in set1_prefix_people.values()])))\n", + "print('set2 unique persons =', len(set().union(*[set(v) for v in set2_prefix_people.values()])))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Prepared split payload at /scratch3/f007yzf/repos/Step1X-Edit-clean/training_data/metadata/connector_6000_split_v1.json\n" + ] + } + ], + "source": [ + "def save_step1x_npz(target_image_path: Path, embeds: np.ndarray, masks: np.ndarray, force_overwrite: bool = False) -> Path:\n", + " npz_path = target_image_path.with_suffix('')\n", + " npz_path = npz_path.parent / f'{npz_path.name}_step1x_te.npz'\n", + " if npz_path.exists() and not force_overwrite:\n", + " existing = np.load(npz_path)\n", + " if 'embeds' in existing and 'masks' in existing:\n", + " return npz_path\n", + " np.savez(npz_path, embeds=embeds, masks=masks)\n", + " 
def save_step1x_npz(target_image_path: Path, embeds: np.ndarray, masks: np.ndarray, force_overwrite: bool = False) -> Path:
    """Persist the text-encoder cache next to the target image.

    The cache lives at ``<stem>_step1x_te.npz``.  An existing file containing
    both ``embeds`` and ``masks`` is reused unless *force_overwrite* is set;
    an incomplete file is rewritten.  Returns the npz path.
    """
    stem = target_image_path.with_suffix('')
    npz_path = stem.parent / f'{stem.name}_step1x_te.npz'
    if npz_path.exists() and not force_overwrite:
        # Use a context manager so the NpzFile handle is closed promptly
        # (the previous version leaked one open file per existing cache).
        with np.load(npz_path) as existing:
            if 'embeds' in existing and 'masks' in existing:
                return npz_path
    np.savez(npz_path, embeds=embeds, masks=masks)
    return npz_path

def build_record(set_id: str, direction: str, prefix: str, person_id: str, pair_id: str) -> Dict[str, str]:
    """Build one training record mapping (set, direction) to concrete image paths.

    set1 trains on the flux pair (target = flux image), set2 on the
    InfiniteYou pair (target = iy image); img2 always supplies the target
    expression from the *other* generator family.
    """
    paths = resolve_image_paths(prefix, person_id, pair_id)
    if set_id == 'set1' and direction == 'neg_to_pos':
        img1 = paths['flux_neg']
        img2 = paths['iy_pos']
        ref_image_path = paths['flux_neg']
        target_image_path = paths['flux_pos']
    elif set_id == 'set1' and direction == 'pos_to_neg':
        img1 = paths['flux_pos']
        img2 = paths['iy_neg']
        ref_image_path = paths['flux_pos']
        target_image_path = paths['flux_neg']
    elif set_id == 'set2' and direction == 'pos_to_neg':
        img1 = paths['iy_pos']
        img2 = paths['flux_neg']
        ref_image_path = paths['iy_pos']
        target_image_path = paths['iy_neg']
    elif set_id == 'set2' and direction == 'neg_to_pos':
        img1 = paths['iy_neg']
        img2 = paths['flux_pos']
        ref_image_path = paths['iy_neg']
        target_image_path = paths['iy_pos']
    else:
        raise ValueError((set_id, direction))
    return {
        'set_id': set_id,
        'direction': direction,
        'prefix_family': prefix,
        'person_id': person_id,
        'pair_id': pair_id,
        'img1_path': str(img1.resolve()),
        'img2_path': str(img2.resolve()),
        'ref_image_path': str(ref_image_path.resolve()),
        'target_image_path': str(target_image_path.resolve()),
    }

def expand_split_to_records(split: Dict[str, Dict[str, List[str]]], set_id: str) -> List[Dict[str, str]]:
    """Expand a prefix->direction->people split into one record per (person, pair)."""
    records = []
    for prefix in PREFIXES:
        for direction, people in split[prefix].items():
            for person_id in people:
                for pair_id in PAIRS:
                    records.append(build_record(set_id, direction, prefix, person_id, pair_id))
    return records

def write_json(path: Path, payload) -> None:
    """Write *payload* as pretty, non-ASCII-escaped JSON, creating parent dirs."""
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(payload, ensure_ascii=False, indent=2))

def write_jsonl(path: Path, rows: Iterable[dict]) -> None:
    """Write *rows* as one JSON object per line, creating parent dirs."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open('w', encoding='utf-8') as f:
        for row in rows:
            f.write(json.dumps(row, ensure_ascii=False) + '\n')

def generate_dataset(records: List[Dict[str, str]], force_overwrite_npz: bool = False) -> Tuple[dict, List[dict]]:
    """Encode every record's image pair and cache the embedding next to its target.

    Returns ``(metadata, audit_rows)``: *metadata* maps resolved target-image
    paths to ``{'ref_image_path', 'caption'}`` (duplicate targets raise), and
    *audit_rows* extends each input record with caption, prompt, npz path and
    the embedding debug info.
    """
    metadata = {}
    audit_rows = []
    for record in tqdm(records):
        img1 = Path(record['img1_path'])
        img2 = Path(record['img2_path'])
        target_image_path = Path(record['target_image_path'])
        ref_image_path = Path(record['ref_image_path'])
        embeds, masks, debug = build_dual_image_embedding(img1, img2, return_debug=True)
        npz_path = save_step1x_npz(target_image_path, embeds, masks, force_overwrite=force_overwrite_npz)
        key = str(target_image_path.resolve())
        if key in metadata:
            raise RuntimeError(f'Duplicate target image key: {key}')
        metadata[key] = {
            'ref_image_path': str(ref_image_path.resolve()),
            'caption': FIXED_METADATA_CAPTION,
        }
        audit_row = dict(record)
        audit_row['caption'] = FIXED_METADATA_CAPTION
        audit_row['fixed_prompt'] = FIXED_EMBED_PROMPT
        audit_row['embedding_npz_path'] = str(npz_path.resolve())
        audit_row.update(debug)
        audit_rows.append(audit_row)
    return metadata, audit_rows
# ---- derive the 6000 training records from the two splits -------------------
set1_records = expand_split_to_records(set1_split, 'set1')
set2_records = expand_split_to_records(set2_split, 'set2')
assert len(set1_records) == 3000
assert len(set2_records) == 3000
all_records = set1_records + set2_records
assert len(all_records) == 6000

# ---- output locations -------------------------------------------------------
SPLIT_OUTPUT = OUTPUT_METADATA_DIR / 'connector_6000_split_v1.json'
SET1_OUTPUT = OUTPUT_METADATA_DIR / 'connector_set1_flux2_3000_v1.json'
SET2_OUTPUT = OUTPUT_METADATA_DIR / 'connector_set2_infiniteyou_3000_v1.json'
COMBINED_OUTPUT = OUTPUT_METADATA_DIR / 'connector_6000_train_v1.json'
AUDIT_OUTPUT = OUTPUT_METADATA_DIR / 'connector_6000_audit_v1.jsonl'

# ---- persist everything needed to reproduce the split -----------------------
split_payload = {
    'master_seed': MASTER_SEED,
    'set1_seed': SET1_SEED,
    'set2_seed': SET2_SEED,
    'prefixes': PREFIXES,
    'pairs': PAIRS,
    'fixed_embed_prompt': FIXED_EMBED_PROMPT,
    'fixed_metadata_caption': FIXED_METADATA_CAPTION,
    'set1': set1_split,
    'set2': set2_split,
}
write_json(SPLIT_OUTPUT, split_payload)
print('Prepared split payload at', SPLIT_OUTPUT)

# ---- optional preview / full generation (gated by the toggles above) --------
RUN_SINGLE_PREVIEW = False
if RUN_SINGLE_PREVIEW:
    preview = run_single_preview()
    print(json.dumps(preview, ensure_ascii=False, indent=2))
else:
    print('Set RUN_SINGLE_PREVIEW = True to generate one dual-image embedding preview.')

if RUN_FULL_GENERATION:
    set1_metadata, set1_audit = generate_dataset(set1_records, force_overwrite_npz=FORCE_OVERWRITE_NPZ)
    set2_metadata, set2_audit = generate_dataset(set2_records, force_overwrite_npz=FORCE_OVERWRITE_NPZ)
    combined_metadata = dict(set1_metadata)
    # The two sets must key on disjoint target images.
    overlap = set(combined_metadata).intersection(set2_metadata)
    if overlap:
        raise RuntimeError(f'Unexpected overlap between set1 and set2 target keys: {len(overlap)}')
    combined_metadata.update(set2_metadata)
    write_json(SET1_OUTPUT, set1_metadata)
    write_json(SET2_OUTPUT, set2_metadata)
    write_json(COMBINED_OUTPUT, combined_metadata)
    write_jsonl(AUDIT_OUTPUT, set1_audit + set2_audit)
    print('set1 samples =', len(set1_metadata))
    print('set2 samples =', len(set2_metadata))
    print('combined samples =', len(combined_metadata))
    print('Saved:', SET1_OUTPUT)
    print('Saved:', SET2_OUTPUT)
    print('Saved:', COMBINED_OUTPUT)
    print('Saved:', AUDIT_OUTPUT)
else:
    print('Set RUN_FULL_GENERATION = True to generate all 6000 samples.')
"PosixPath('/scratch3/f007yzf/repos/Step1X-Edit-clean/training_data/metadata/connector_6000_audit_v1.jsonl')" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c3a2f3f3", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "step1x_v11", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.19" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/scripts/llm_prompt_v1_orig.ipynb b/scripts/llm_prompt_v1_orig.ipynb new file mode 100644 index 0000000..ff0b7a6 --- /dev/null +++ b/scripts/llm_prompt_v1_orig.ipynb @@ -0,0 +1,555 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Qwen dual-image prompt + embedding generation for connector v1\n", + "\n", + "这个 notebook 做三件事:\n", + "1. 输入两张图片,生成编辑 prompt\n", + "2. 用 `img1 + prompt` 生成当前 connector 兼容的 embedding/mask\n", + "3. 
按 6000 样本规则写出 metadata 与 `_step1x_te.npz`\n", + "\n", + "实现约束:\n", + "- 不替换当前 llm encoder\n", + "- 双图只用于生成 prompt\n", + "- 最终 embedding 只使用 `img1 + prompt`\n", + "- 每个 set 独立采样\n", + "- 每个 set 内 `data/data_seed_2/3/4/5` 各 120 人\n", + "- 每个 prefix 的 120 人拆成 60 人 `neg_to_pos` 和 60 人 `pos_to_neg`,人不重叠\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from __future__ import annotations\n", + "\n", + "import json\n", + "import math\n", + "import os\n", + "import random\n", + "import re\n", + "from collections import defaultdict\n", + "from pathlib import Path\n", + "from typing import Dict, Iterable, List, Tuple\n", + "\n", + "import numpy as np\n", + "import torch\n", + "from PIL import Image\n", + "from tqdm.auto import tqdm\n", + "from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration\n", + "from qwen_vl_utils import process_vision_info\n", + "\n", + "REPO_ROOT = Path('/scratch3/f007yzf/repos/Step1X-Edit-clean')\n", + "DATA_ROOT = REPO_ROOT / 'training_data'\n", + "MODEL_PATH = Path('/scratch3/f007yzf/models/step1x_v11/Qwen2.5-VL-7B-Instruct')\n", + "OUTPUT_METADATA_DIR = DATA_ROOT / 'metadata'\n", + "OUTPUT_METADATA_DIR.mkdir(parents=True, exist_ok=True)\n", + "\n", + "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n", + "DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32\n", + "MAX_NEW_TOKENS = 150\n", + "MAX_LENGTH = 640\n", + "MASTER_SEED = 20260307\n", + "SET1_SEED = MASTER_SEED\n", + "SET2_SEED = MASTER_SEED + 1\n", + "PAIRS = [f'{i:02d}' for i in range(5)]\n", + "PREFIXES = ['data', 'data_seed_2', 'data_seed_3', 'data_seed_4', 'data_seed_5']\n", + "\n", + "SINGLE_IMG1_PATH = DATA_ROOT / 'source_img' / 'data__p0000_pair00_s0__source.png'\n", + "SINGLE_IMG2_PATH = DATA_ROOT / 'reference_img' / 'left' / 'data__p0000_pair00_s0.png'\n", + "SINGLE_PREVIEW_OUTPUT = OUTPUT_METADATA_DIR / 'llm_prompt_preview_v1.json'\n", + "\n", + "RUN_SINGLE_PREVIEW = 
False\n", + "RUN_FULL_GENERATION = False\n", + "FORCE_OVERWRITE_NPZ = False\n", + "\n", + "DUAL_IMAGE_CAPTION_PROMPT = '''You are analyzing facial expressions for a controlled editing task.\n", + "Given:\n", + "- Image 1: source face to be edited\n", + "- Image 2: target expression reference\n", + "\n", + "Output a structured expression editing plan\n", + "'''\n", + "\n", + "EMBEDDER_PREFIX = '''Given a user prompt, generate an \"Enhanced prompt\" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:\n", + "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.\n", + "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n", + "\n", + "Here are examples of how to transform or refine prompts:\n", + "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.\n", + "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\n", + "\n", + "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:\n", + "User Prompt:'''\n", + "\n", + "IMAGE_RE = re.compile(r'^(data(?:_seed_\\d+)?)__p(\\d+)_pair(\\d+)_s(\\d+)(?:__(source|target))?\\.png$')\n", + "\n", + "print('DEVICE =', DEVICE)\n", + "print('MODEL_PATH =', MODEL_PATH)\n", + "print('OUTPUT_METADATA_DIR =', OUTPUT_METADATA_DIR)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print('Loading processor...')\n", + "processor = AutoProcessor.from_pretrained(MODEL_PATH, 
def load_pil(path: 'Path') -> 'Image.Image':
    """Open *path* and force a 3-channel RGB PIL image."""
    return Image.open(path).convert('RGB')

def split_string_for_embedder(text: str) -> List[str]:
    """Split *text* into alternating outside-quote / inside-quote segments.

    Apostrophes and curly quotes are first normalized to '"'.  A '"' only acts
    as a delimiter past index 155 so the quotes inside the fixed
    EMBEDDER_PREFIX are left untouched (prefix-length heuristic).  Each
    character inside a quoted region is emitted on its own, wrapped in curly
    quotes; everything else accumulates into plain segments.
    """
    text = text.replace("'", '"').replace('“', '"').replace('”', '"')
    segments: List[str] = []
    buffer = ''
    inside = False
    for idx, ch in enumerate(text):
        if ch == '"' and idx > 155:
            buffer += ch
            if not inside:
                # Closing the outside run: flush it (with the opening quote).
                segments.append(buffer)
                buffer = ''
            inside = not inside
        elif inside:
            segments.append('“' + ch + '”')
        else:
            buffer += ch
    if buffer:
        segments.append(buffer)
    return segments
add_generation_prompt=True, add_vision_id=True)\n", + " image_inputs, _ = process_vision_info(messages)\n", + " inputs = processor(text=[text], images=image_inputs, padding=True, return_tensors='pt').to(DEVICE)\n", + " generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)\n", + " generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]\n", + " generated_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True)[0].strip()\n", + " return generated_text or user_prompt\n", + "\n", + "@torch.inference_mode()\n", + "def build_prompt_embedding(img1_path: Path, prompt: str, max_length: int = MAX_LENGTH) -> Tuple[np.ndarray, np.ndarray]:\n", + " img1 = load_pil(img1_path)\n", + " messages = [{\n", + " 'role': 'user',\n", + " 'content': [\n", + " {'type': 'text', 'text': EMBEDDER_PREFIX},\n", + " {'type': 'image', 'image': img1},\n", + " {'type': 'text', 'text': prompt},\n", + " ],\n", + " }]\n", + " text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, add_vision_id=True)\n", + " image_inputs, _ = process_vision_info(messages)\n", + " inputs = processor(text=[text], images=image_inputs, padding=True, return_tensors='pt')\n", + " old_input_ids = inputs.input_ids\n", + " token_list = []\n", + " for text_each in split_string_for_embedder(text):\n", + " txt_inputs = processor(text=text_each, images=None, videos=None, padding=True, return_tensors='pt')\n", + " token_each = txt_inputs.input_ids\n", + " if token_each[0][0] == 2073 and token_each[0][-1] == 854:\n", + " token_each = token_each[:, 1:-1]\n", + " token_list.append(token_each)\n", + " new_txt_ids = torch.cat(token_list, dim=1).to(old_input_ids.device)\n", + " idx1 = (old_input_ids == 151653).nonzero(as_tuple=True)[1][0]\n", + " idx2 = (new_txt_ids == 151653).nonzero(as_tuple=True)[1][0]\n", + " input_ids = torch.cat([old_input_ids[0, :idx1], new_txt_ids[0, idx2:]], 
dim=0).unsqueeze(0)\n", + "    attention_mask = (input_ids > 0).long()\n", + "    outputs = model(\n", + "        input_ids=input_ids.to(DEVICE),\n", + "        attention_mask=attention_mask.to(DEVICE),\n", + "        pixel_values=inputs.pixel_values.to(DEVICE),\n", + "        image_grid_thw=inputs.image_grid_thw.to(DEVICE),\n", + "        output_hidden_states=True,\n", + "    )\n", + "    emb = outputs.hidden_states[-1]\n", + "    embeds = torch.zeros((max_length, HIDDEN_SIZE), dtype=torch.bfloat16, device=DEVICE)\n", + "    masks = torch.zeros((max_length,), dtype=torch.long, device=DEVICE)\n", + "    usable = max(0, emb.shape[1] - 217)\n", + "    length = min(max_length, usable)\n", + "    if length > 0:\n", + "        embeds[:length] = emb[0, 217:217 + length]\n", + "        masks[:length] = 1\n", + "    return embeds.to(torch.float32).cpu().numpy(), masks.cpu().numpy()\n", + "\n", + "@torch.inference_mode()\n", + "def run_single_preview(img1_path: Path = SINGLE_IMG1_PATH, img2_path: Path = SINGLE_IMG2_PATH, output_json: Path = SINGLE_PREVIEW_OUTPUT):\n", + "    prompt = build_dual_image_prompt(img1_path, img2_path)\n", + "    embeds, masks = build_prompt_embedding(img1_path, prompt)\n", + "    payload = {\n", + "        'img1_path': str(img1_path.resolve()),\n", + "        'img2_path': str(img2_path.resolve()),\n", + "        'prompt': prompt,\n", + "        'embedding_shape': list(embeds.shape),\n", + "        'mask_sum': int(masks.sum()),\n", + "    }\n", + "    output_json.parent.mkdir(parents=True, exist_ok=True)\n", + "    output_json.write_text(json.dumps(payload, ensure_ascii=False, indent=2))\n", + "    return payload\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ROOTS = {\n", + "    'flux_neg': DATA_ROOT / 'source_img',\n", + "    'flux_pos': DATA_ROOT / 'target_img',\n", + "    'iy_pos': DATA_ROOT / 'reference_img' / 'left',\n", + "    'iy_neg': DATA_ROOT / 'reference_img' / 'right',\n", + "}\n", + "\n", + "def candidate_filename(prefix: str, person_id: str, pair_id: str, kind: str) -> str:\n", + "    stem = 
f'{prefix}__p{person_id}_pair{pair_id}_s0'\n", + " if kind == 'flux_neg':\n", + " return stem + '__source.png'\n", + " if kind == 'flux_pos':\n", + " return stem + '__target.png'\n", + " if kind in {'iy_pos', 'iy_neg'}:\n", + " return stem + '.png'\n", + " raise KeyError(kind)\n", + "\n", + "def resolve_image_paths(prefix: str, person_id: str, pair_id: str) -> Dict[str, Path]:\n", + " return {kind: ROOTS[kind] / candidate_filename(prefix, person_id, pair_id, kind) for kind in ROOTS}\n", + "\n", + "def person_complete(prefix: str, person_id: str) -> bool:\n", + " for pair_id in PAIRS:\n", + " paths = resolve_image_paths(prefix, person_id, pair_id)\n", + " if not all(path.exists() for path in paths.values()):\n", + " return False\n", + " return True\n", + "\n", + "def scan_candidates() -> Dict[str, List[str]]:\n", + " seen = defaultdict(set)\n", + " for root in ROOTS.values():\n", + " for path in root.glob('*.png'):\n", + " m = IMAGE_RE.match(path.name)\n", + " if not m:\n", + " continue\n", + " prefix, person_id, pair_id, s, _ = m.groups()\n", + " if prefix not in PREFIXES or s != '0':\n", + " continue\n", + " seen[prefix].add(person_id)\n", + " candidates = {}\n", + " for prefix in PREFIXES:\n", + " valid = sorted(pid for pid in seen[prefix] if person_complete(prefix, pid))\n", + " candidates[prefix] = valid\n", + " return candidates\n", + "\n", + "def build_person_to_prefixes(candidates_by_prefix: Dict[str, List[str]]) -> Dict[str, List[str]]:\n", + " person_to_prefixes = defaultdict(list)\n", + " for prefix, persons in candidates_by_prefix.items():\n", + " for pid in persons:\n", + " person_to_prefixes[pid].append(prefix)\n", + " return {pid: sorted(prefixes) for pid, prefixes in person_to_prefixes.items()}\n", + "\n", + "def hopcroft_karp(graph: Dict[str, List[str]], left_nodes: List[str], right_nodes: List[str]):\n", + " INF = 10 ** 9\n", + " pair_u = {u: None for u in left_nodes}\n", + " pair_v = {v: None for v in right_nodes}\n", + " dist = {}\n", + "\n", + " 
from collections import deque\n", + "\n", + " def bfs():\n", + " queue = deque()\n", + " for u in left_nodes:\n", + " if pair_u[u] is None:\n", + " dist[u] = 0\n", + " queue.append(u)\n", + " else:\n", + " dist[u] = INF\n", + " found = False\n", + " while queue:\n", + " u = queue.popleft()\n", + " for v in graph[u]:\n", + " pu = pair_v[v]\n", + " if pu is None:\n", + " found = True\n", + " elif dist[pu] == INF:\n", + " dist[pu] = dist[u] + 1\n", + " queue.append(pu)\n", + " return found\n", + "\n", + " def dfs(u):\n", + " for v in graph[u]:\n", + " pu = pair_v[v]\n", + " if pu is None or (dist[pu] == dist[u] + 1 and dfs(pu)):\n", + " pair_u[u] = v\n", + " pair_v[v] = u\n", + " return True\n", + " dist[u] = INF\n", + " return False\n", + "\n", + " matching = 0\n", + " while bfs():\n", + " for u in left_nodes:\n", + " if pair_u[u] is None and dfs(u):\n", + " matching += 1\n", + " return matching, pair_u\n", + "\n", + "def assign_people_to_prefixes(candidates_by_prefix: Dict[str, List[str]], seed: int) -> Dict[str, List[str]]:\n", + " person_to_prefixes = build_person_to_prefixes(candidates_by_prefix)\n", + " left_nodes = sorted(person_to_prefixes)\n", + " if len(left_nodes) != 600:\n", + " raise RuntimeError(f'Expected 600 unique person ids, got {len(left_nodes)}')\n", + " rng = random.Random(seed)\n", + " slot_map = {}\n", + " right_nodes = []\n", + " for prefix in PREFIXES:\n", + " slot_names = [f'{prefix}#{idx:03d}' for idx in range(120)]\n", + " rng.shuffle(slot_names)\n", + " slot_map[prefix] = slot_names\n", + " right_nodes.extend(slot_names)\n", + " graph = {}\n", + " shuffled_left = left_nodes[:]\n", + " rng.shuffle(shuffled_left)\n", + " for pid in shuffled_left:\n", + " neighbors = []\n", + " prefixes = person_to_prefixes[pid][:]\n", + " rng.shuffle(prefixes)\n", + " for prefix in prefixes:\n", + " neighbors.extend(slot_map[prefix])\n", + " graph[pid] = neighbors\n", + " matching, pair_u = hopcroft_karp(graph, shuffled_left, right_nodes)\n", + " if matching 
!= 600:\n", + " raise RuntimeError(f'Failed to assign 600 people to prefix slots, only matched {matching}')\n", + " result = defaultdict(list)\n", + " for pid, slot in pair_u.items():\n", + " prefix = slot.split('#', 1)[0]\n", + " result[prefix].append(pid)\n", + " for prefix in PREFIXES:\n", + " result[prefix] = sorted(result[prefix])\n", + " if len(result[prefix]) != 120:\n", + " raise RuntimeError(f'Prefix {prefix} expected 120 people, got {len(result[prefix])}')\n", + " return dict(result)\n", + "\n", + "def split_prefix_people(prefix_people: Dict[str, List[str]], seed: int, first_name: str, second_name: str) -> Dict[str, Dict[str, List[str]]]:\n", + " rng = random.Random(seed)\n", + " result = {}\n", + " for prefix, people in prefix_people.items():\n", + " shuffled = people[:]\n", + " rng.shuffle(shuffled)\n", + " result[prefix] = {\n", + " first_name: sorted(shuffled[:60]),\n", + " second_name: sorted(shuffled[60:120]),\n", + " }\n", + " return result\n", + "\n", + "candidates_by_prefix = scan_candidates()\n", + "for prefix in PREFIXES:\n", + " print(prefix, len(candidates_by_prefix[prefix]))\n", + "set1_prefix_people = assign_people_to_prefixes(candidates_by_prefix, SET1_SEED)\n", + "set2_prefix_people = assign_people_to_prefixes(candidates_by_prefix, SET2_SEED)\n", + "set1_split = split_prefix_people(set1_prefix_people, SET1_SEED + 100, 'neg_to_pos', 'pos_to_neg')\n", + "set2_split = split_prefix_people(set2_prefix_people, SET2_SEED + 100, 'pos_to_neg', 'neg_to_pos')\n", + "print('set1 unique persons =', len(set().union(*[set(v) for v in set1_prefix_people.values()])))\n", + "print('set2 unique persons =', len(set().union(*[set(v) for v in set2_prefix_people.values()])))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def save_step1x_npz(target_image_path: Path, embeds: np.ndarray, masks: np.ndarray, force_overwrite: bool = False) -> Path:\n", + " npz_path = 
target_image_path.with_suffix('')\n", + " npz_path = npz_path.parent / f'{npz_path.name}_step1x_te.npz'\n", + " if npz_path.exists() and not force_overwrite:\n", + " existing = np.load(npz_path)\n", + " if 'embeds' in existing and 'masks' in existing:\n", + " return npz_path\n", + " np.savez(npz_path, embeds=embeds, masks=masks)\n", + " return npz_path\n", + "\n", + "def build_record(set_id: str, direction: str, prefix: str, person_id: str, pair_id: str) -> Dict[str, str]:\n", + " paths = resolve_image_paths(prefix, person_id, pair_id)\n", + " if set_id == 'set1' and direction == 'neg_to_pos':\n", + " img1 = paths['flux_neg']\n", + " img2 = paths['iy_pos']\n", + " ref_image_path = paths['flux_neg']\n", + " target_image_path = paths['flux_pos']\n", + " elif set_id == 'set1' and direction == 'pos_to_neg':\n", + " img1 = paths['flux_pos']\n", + " img2 = paths['iy_neg']\n", + " ref_image_path = paths['flux_pos']\n", + " target_image_path = paths['flux_neg']\n", + " elif set_id == 'set2' and direction == 'pos_to_neg':\n", + " img1 = paths['iy_pos']\n", + " img2 = paths['flux_neg']\n", + " ref_image_path = paths['iy_pos']\n", + " target_image_path = paths['iy_neg']\n", + " elif set_id == 'set2' and direction == 'neg_to_pos':\n", + " img1 = paths['iy_neg']\n", + " img2 = paths['flux_pos']\n", + " ref_image_path = paths['iy_neg']\n", + " target_image_path = paths['iy_pos']\n", + " else:\n", + " raise ValueError((set_id, direction))\n", + " return {\n", + " 'set_id': set_id,\n", + " 'direction': direction,\n", + " 'prefix_family': prefix,\n", + " 'person_id': person_id,\n", + " 'pair_id': pair_id,\n", + " 'img1_path': str(img1.resolve()),\n", + " 'img2_path': str(img2.resolve()),\n", + " 'ref_image_path': str(ref_image_path.resolve()),\n", + " 'target_image_path': str(target_image_path.resolve()),\n", + " }\n", + "\n", + "def expand_split_to_records(split: Dict[str, Dict[str, List[str]]], set_id: str) -> List[Dict[str, str]]:\n", + " records = []\n", + " for prefix in 
PREFIXES:\n", + " for direction, people in split[prefix].items():\n", + " for person_id in people:\n", + " for pair_id in PAIRS:\n", + " records.append(build_record(set_id, direction, prefix, person_id, pair_id))\n", + " return records\n", + "\n", + "def write_json(path: Path, payload) -> None:\n", + " path.parent.mkdir(parents=True, exist_ok=True)\n", + " path.write_text(json.dumps(payload, ensure_ascii=False, indent=2))\n", + "\n", + "def write_jsonl(path: Path, rows: Iterable[dict]) -> None:\n", + " path.parent.mkdir(parents=True, exist_ok=True)\n", + " with path.open('w', encoding='utf-8') as f:\n", + " for row in rows:\n", + " f.write(json.dumps(row, ensure_ascii=False) + '\\n')\n", + "\n", + "def generate_dataset(records: List[Dict[str, str]], force_overwrite_npz: bool = False) -> Tuple[dict, List[dict]]:\n", + " metadata = {}\n", + " audit_rows = []\n", + " for record in tqdm(records):\n", + " img1 = Path(record['img1_path'])\n", + " img2 = Path(record['img2_path'])\n", + " target_image_path = Path(record['target_image_path'])\n", + " ref_image_path = Path(record['ref_image_path'])\n", + " prompt = build_dual_image_prompt(img1, img2)\n", + " embeds, masks = build_prompt_embedding(img1, prompt)\n", + " npz_path = save_step1x_npz(target_image_path, embeds, masks, force_overwrite=force_overwrite_npz)\n", + " key = str(target_image_path.resolve())\n", + " if key in metadata:\n", + " raise RuntimeError(f'Duplicate target image key: {key}')\n", + " metadata[key] = {\n", + " 'ref_image_path': str(ref_image_path.resolve()),\n", + " 'caption': prompt,\n", + " }\n", + " audit_row = dict(record)\n", + " audit_row['caption'] = prompt\n", + " audit_row['embedding_npz_path'] = str(npz_path.resolve())\n", + " audit_rows.append(audit_row)\n", + " return metadata, audit_rows\n", + "\n", + "set1_records = expand_split_to_records(set1_split, 'set1')\n", + "set2_records = expand_split_to_records(set2_split, 'set2')\n", + "assert len(set1_records) == 3000\n", + "assert 
len(set2_records) == 3000\n", + "all_records = set1_records + set2_records\n", + "assert len(all_records) == 6000\n", + "\n", + "SPLIT_OUTPUT = OUTPUT_METADATA_DIR / 'connector_6000_split_v1.json'\n", + "SET1_OUTPUT = OUTPUT_METADATA_DIR / 'connector_set1_flux2_3000_v1.json'\n", + "SET2_OUTPUT = OUTPUT_METADATA_DIR / 'connector_set2_infiniteyou_3000_v1.json'\n", + "COMBINED_OUTPUT = OUTPUT_METADATA_DIR / 'connector_6000_train_v1.json'\n", + "AUDIT_OUTPUT = OUTPUT_METADATA_DIR / 'connector_6000_audit_v1.jsonl'\n", + "\n", + "split_payload = {\n", + " 'master_seed': MASTER_SEED,\n", + " 'set1_seed': SET1_SEED,\n", + " 'set2_seed': SET2_SEED,\n", + " 'prefixes': PREFIXES,\n", + " 'pairs': PAIRS,\n", + " 'set1': set1_split,\n", + " 'set2': set2_split,\n", + "}\n", + "write_json(SPLIT_OUTPUT, split_payload)\n", + "print('Prepared split payload at', SPLIT_OUTPUT)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if RUN_SINGLE_PREVIEW:\n", + " preview = run_single_preview()\n", + " print(json.dumps(preview, ensure_ascii=False, indent=2))\n", + "else:\n", + " print('Set RUN_SINGLE_PREVIEW = True to generate one prompt preview.')\n", + "\n", + "if RUN_FULL_GENERATION:\n", + " set1_metadata, set1_audit = generate_dataset(set1_records, force_overwrite_npz=FORCE_OVERWRITE_NPZ)\n", + " set2_metadata, set2_audit = generate_dataset(set2_records, force_overwrite_npz=FORCE_OVERWRITE_NPZ)\n", + " combined_metadata = dict(set1_metadata)\n", + " overlap = set(combined_metadata).intersection(set2_metadata)\n", + " if overlap:\n", + " raise RuntimeError(f'Unexpected overlap between set1 and set2 target keys: {len(overlap)}')\n", + " combined_metadata.update(set2_metadata)\n", + " write_json(SET1_OUTPUT, set1_metadata)\n", + " write_json(SET2_OUTPUT, set2_metadata)\n", + " write_json(COMBINED_OUTPUT, combined_metadata)\n", + " write_jsonl(AUDIT_OUTPUT, set1_audit + set2_audit)\n", + " print('set1 samples =', 
len(set1_metadata))\n", + " print('set2 samples =', len(set2_metadata))\n", + " print('combined samples =', len(combined_metadata))\n", + " print('Saved:', SET1_OUTPUT)\n", + " print('Saved:', SET2_OUTPUT)\n", + " print('Saved:', COMBINED_OUTPUT)\n", + " print('Saved:', AUDIT_OUTPUT)\n", + "else:\n", + " print('Set RUN_FULL_GENERATION = True to generate all 6000 samples.')\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "step1x_v11", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.19" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/scripts/llm_prompt_v1_v1.ipynb b/scripts/llm_prompt_v1_v1.ipynb new file mode 100644 index 0000000..7d0f8d7 --- /dev/null +++ b/scripts/llm_prompt_v1_v1.ipynb @@ -0,0 +1,734 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Qwen dual-image prompt + embedding generation for connector v1\n", + "\n", + "这个 notebook 做三件事:\n", + "1. 输入两张图片,生成编辑 prompt\n", + "2. 用 `img1 + prompt` 生成当前 connector 兼容的 embedding/mask\n", + "3. 
按 6000 样本规则写出 metadata 与 `_step1x_te.npz`\n", + "\n", + "实现约束:\n", + "- 不替换当前 llm encoder\n", + "- 双图只用于生成 prompt\n", + "- 最终 embedding 只使用 `img1 + prompt`\n", + "- 每个 set 独立采样\n", + "- 每个 set 内 `data/data_seed_2/3/4/5` 各 120 人\n", + "- 每个 prefix 的 120 人拆成 60 人 `neg_to_pos` 和 60 人 `pos_to_neg`,人不重叠\n", + "\n", + "运行前注意:\n", + "- 这个 notebook 默认通过 `CUDA_VISIBLE_DEVICES=1` 使用物理 1 号卡\n", + "- 必须先 Restart Kernel,再从第一格开始运行,否则环境变量不会生效\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CUDA_VISIBLE_DEVICES = 1\n", + "torch.cuda.is_available() = True\n", + "torch.cuda.device_count() = 1\n", + "DEVICE = cuda:0\n", + "MODEL_PATH = /scratch3/f007yzf/models/step1x_v11/Qwen2.5-VL-7B-Instruct\n", + "OUTPUT_METADATA_DIR = /scratch3/f007yzf/repos/Step1X-Edit-clean/training_data/metadata\n" + ] + } + ], + "source": [ + "from __future__ import annotations\n", + "\n", + "import os\n", + "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n", + "\n", + "import json\n", + "import math\n", + "import random\n", + "import re\n", + "from collections import defaultdict\n", + "from pathlib import Path\n", + "from typing import Dict, Iterable, List, Tuple\n", + "\n", + "import numpy as np\n", + "import torch\n", + "from PIL import Image\n", + "from tqdm.auto import tqdm\n", + "from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration\n", + "from qwen_vl_utils import process_vision_info\n", + "\n", + "REPO_ROOT = Path('/scratch3/f007yzf/repos/Step1X-Edit-clean')\n", + "DATA_ROOT = REPO_ROOT / 'training_data'\n", + "MODEL_PATH = Path('/scratch3/f007yzf/models/step1x_v11/Qwen2.5-VL-7B-Instruct')\n", + "OUTPUT_METADATA_DIR = DATA_ROOT / 'metadata'\n", + "OUTPUT_METADATA_DIR.mkdir(parents=True, exist_ok=True)\n", + "\n", + "DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n", + "DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32\n", + "MAX_NEW_TOKENS 
= 150\n", + "MAX_LENGTH = 640\n", + "MASTER_SEED = 20260307\n", + "SET1_SEED = MASTER_SEED\n", + "SET2_SEED = MASTER_SEED + 1\n", + "PAIRS = [f'{i:02d}' for i in range(5)]\n", + "PREFIXES = ['data', 'data_seed_2', 'data_seed_3', 'data_seed_4', 'data_seed_5']\n", + "\n", + "SINGLE_IMG1_PATH = DATA_ROOT / 'source_img' / 'data__p0000_pair00_s0__source.png'\n", + "SINGLE_IMG2_PATH = DATA_ROOT / 'reference_img' / 'left' / 'data__p0000_pair00_s0.png'\n", + "SINGLE_PREVIEW_OUTPUT = OUTPUT_METADATA_DIR / 'llm_prompt_preview_v1.json'\n", + "\n", + "RUN_SINGLE_PREVIEW = False\n", + "RUN_FULL_GENERATION = False\n", + "FORCE_OVERWRITE_NPZ = False\n", + "\n", + "DUAL_IMAGE_CAPTION_PROMPT = '''You are analyzing facial expressions for a controlled editing task.\n", + "Given:\n", + "- Image 1: source face to be edited\n", + "- Image 2: target expression reference\n", + "\n", + "Output a structured expression editing plan\n", + "'''\n", + "\n", + "EMBEDDER_PREFIX = '''Given a user prompt, generate an \"Enhanced prompt\" that provides detailed visual descriptions suitable for image generation. 
Evaluate the level of detail in the user prompt:\n", + "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.\n", + "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n", + "\n", + "Here are examples of how to transform or refine prompts:\n", + "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.\n", + "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\n", + "\n", + "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:\n", + "User Prompt:'''\n", + "\n", + "IMAGE_RE = re.compile(r'^(data(?:_seed_\\d+)?)__p(\\d+)_pair(\\d+)_s(\\d+)(?:__(source|target))?\\.png$')\n", + "\n", + "print('CUDA_VISIBLE_DEVICES =', os.environ.get('CUDA_VISIBLE_DEVICES'))\n", + "print('torch.cuda.is_available() =', torch.cuda.is_available())\n", + "print('torch.cuda.device_count() =', torch.cuda.device_count())\n", + "print('DEVICE =', DEVICE)\n", + "print('MODEL_PATH =', MODEL_PATH)\n", + "print('OUTPUT_METADATA_DIR =', OUTPUT_METADATA_DIR)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. 
You'll still be able to use a slow processor with `use_fast=False`.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading processor...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "You have video processor config saved in `preprocessor.json` file which is deprecated. Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename the file or load and save the processor back which renames it automatically. Loading from `preprocessor.json` will be removed in v5.0.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading Qwen2.5-VL model...\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "63dd799659e148749a96f3f0b62c2d46", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/5 [00:00 Image.Image:\n", + " return Image.open(path).convert('RGB')\n", + "\n", + "def split_string_for_embedder(text: str) -> List[str]:\n", + " text = text.replace(\"'\", '\"').replace('“', '\"').replace('”', '\"')\n", + " result = []\n", + " in_quotes = False\n", + " temp = ''\n", + " for idx, char in enumerate(text):\n", + " if char == '\"' and idx > 155:\n", + " temp += char\n", + " if not in_quotes:\n", + " result.append(temp)\n", + " temp = ''\n", + " in_quotes = not in_quotes\n", + " continue\n", + " if in_quotes:\n", + " result.append('“' + char + '”')\n", + " else:\n", + " temp += char\n", + " if temp:\n", + " result.append(temp)\n", + " return result\n", + "\n", + "@torch.inference_mode()\n", + "def build_dual_image_prompt(img1_path: Path, img2_path: Path, user_prompt: str = '', max_new_tokens: int = MAX_NEW_TOKENS) -> str:\n", + " img1 = load_pil(img1_path)\n", + " img2 = load_pil(img2_path)\n", + " messages = [{\n", + " 'role': 'user',\n", + " 'content': [\n", + " {'type': 'text', 'text': DUAL_IMAGE_CAPTION_PROMPT},\n", + " {'type': 'text', 'text': '\\n[Source Image 
(Structure/Identity)]:'},\n", + " {'type': 'image', 'image': img1},\n", + " {'type': 'text', 'text': '\\n[Reference Image (Style/Expression)]:'},\n", + " {'type': 'image', 'image': img2},\n", + " {'type': 'text', 'text': f'\\nUser prompt: {user_prompt}'},\n", + " {'type': 'text', 'text': '\\nPlease generate the structured editing plan now.'},\n", + " ],\n", + " }]\n", + " text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, add_vision_id=True)\n", + " image_inputs, _ = process_vision_info(messages)\n", + " inputs = processor(text=[text], images=image_inputs, padding=True, return_tensors='pt').to(DEVICE)\n", + " generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)\n", + " generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]\n", + " generated_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True)[0].strip()\n", + " return generated_text or user_prompt\n", + "\n", + "@torch.inference_mode()\n", + "def build_prompt_embedding(img1_path: Path, prompt: str, max_length: int = MAX_LENGTH) -> Tuple[np.ndarray, np.ndarray]:\n", + " img1 = load_pil(img1_path)\n", + " messages = [{\n", + " 'role': 'user',\n", + " 'content': [\n", + " {'type': 'text', 'text': EMBEDDER_PREFIX},\n", + " {'type': 'image', 'image': img1},\n", + " {'type': 'text', 'text': prompt},\n", + " ],\n", + " }]\n", + " text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, add_vision_id=True)\n", + " image_inputs, _ = process_vision_info(messages)\n", + " inputs = processor(text=[text], images=image_inputs, padding=True, return_tensors='pt')\n", + " old_input_ids = inputs.input_ids\n", + " token_list = []\n", + " for text_each in split_string_for_embedder(text):\n", + " txt_inputs = processor(text=text_each, images=None, videos=None, padding=True, return_tensors='pt')\n", + " token_each = txt_inputs.input_ids\n", + " if 
token_each[0][0] == 2073 and token_each[0][-1] == 854:\n", + " token_each = token_each[:, 1:-1]\n", + " token_list.append(token_each)\n", + " new_txt_ids = torch.cat(token_list, dim=1).to(old_input_ids.device)\n", + " idx1 = (old_input_ids == 151653).nonzero(as_tuple=True)[1][0]\n", + " idx2 = (new_txt_ids == 151653).nonzero(as_tuple=True)[1][0]\n", + " input_ids = torch.cat([old_input_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0).unsqueeze(0)\n", + " attention_mask = (input_ids > 0).long()\n", + " outputs = model(\n", + " input_ids=input_ids.to(DEVICE),\n", + " attention_mask=attention_mask.to(DEVICE),\n", + " pixel_values=inputs.pixel_values.to(DEVICE),\n", + " image_grid_thw=inputs.image_grid_thw.to(DEVICE),\n", + " output_hidden_states=True,\n", + " )\n", + " emb = outputs.hidden_states[-1]\n", + " embeds = torch.zeros((max_length, HIDDEN_SIZE), dtype=torch.bfloat16, device=DEVICE)\n", + " masks = torch.zeros((max_length,), dtype=torch.long, device=DEVICE)\n", + " usable = max(0, emb.shape[1] - 217)\n", + " length = min(max_length, usable)\n", + " if length > 0:\n", + " embeds[:length] = emb[0, 217:217 + length]\n", + " masks[:length] = 1\n", + " # return embeds.cpu().numpy(), masks.cpu().numpy()\n", + " return embeds.to(torch.float32).cpu().numpy(), masks.cpu().numpy()\n", + "\n", + "@torch.inference_mode()\n", + "def run_single_preview(img1_path: Path = SINGLE_IMG1_PATH, img2_path: Path = SINGLE_IMG2_PATH, output_json: Path = SINGLE_PREVIEW_OUTPUT):\n", + " prompt = build_dual_image_prompt(img1_path, img2_path)\n", + " embeds, masks = build_prompt_embedding(img1_path, prompt)\n", + " payload = {\n", + " 'img1_path': str(img1_path.resolve()),\n", + " 'img2_path': str(img2_path.resolve()),\n", + " 'prompt': prompt,\n", + " 'embedding_shape': list(embeds.shape),\n", + " 'mask_sum': int(masks.sum()),\n", + " }\n", + " output_json.parent.mkdir(parents=True, exist_ok=True)\n", + " output_json.write_text(json.dumps(payload, ensure_ascii=False, indent=2))\n", + " 
return payload\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "data 579\n", + "data_seed_2 600\n", + "data_seed_3 571\n", + "data_seed_4 583\n", + "data_seed_5 584\n", + "set1 unique persons = 600\n", + "set2 unique persons = 600\n" + ] + } + ], + "source": [ + "ROOTS = {\n", + " 'flux_neg': DATA_ROOT / 'source_img',\n", + " 'flux_pos': DATA_ROOT / 'target_img',\n", + " 'iy_pos': DATA_ROOT / 'reference_img' / 'left',\n", + " 'iy_neg': DATA_ROOT / 'reference_img' / 'right',\n", + "}\n", + "\n", + "def candidate_filename(prefix: str, person_id: str, pair_id: str, kind: str) -> str:\n", + " stem = f'{prefix}__p{person_id}_pair{pair_id}_s0'\n", + " if kind == 'flux_neg':\n", + " return stem + '__source.png'\n", + " if kind == 'flux_pos':\n", + " return stem + '__target.png'\n", + " if kind in {'iy_pos', 'iy_neg'}:\n", + " return stem + '.png'\n", + " raise KeyError(kind)\n", + "\n", + "def resolve_image_paths(prefix: str, person_id: str, pair_id: str) -> Dict[str, Path]:\n", + " return {kind: ROOTS[kind] / candidate_filename(prefix, person_id, pair_id, kind) for kind in ROOTS}\n", + "\n", + "def person_complete(prefix: str, person_id: str) -> bool:\n", + " for pair_id in PAIRS:\n", + " paths = resolve_image_paths(prefix, person_id, pair_id)\n", + " if not all(path.exists() for path in paths.values()):\n", + " return False\n", + " return True\n", + "\n", + "def scan_candidates() -> Dict[str, List[str]]:\n", + " seen = defaultdict(set)\n", + " for root in ROOTS.values():\n", + " for path in root.glob('*.png'):\n", + " m = IMAGE_RE.match(path.name)\n", + " if not m:\n", + " continue\n", + " prefix, person_id, pair_id, s, _ = m.groups()\n", + " if prefix not in PREFIXES or s != '0':\n", + " continue\n", + " seen[prefix].add(person_id)\n", + " candidates = {}\n", + " for prefix in PREFIXES:\n", + " valid = sorted(pid for pid in seen[prefix] if 
person_complete(prefix, pid))\n", + " candidates[prefix] = valid\n", + " return candidates\n", + "\n", + "def build_person_to_prefixes(candidates_by_prefix: Dict[str, List[str]]) -> Dict[str, List[str]]:\n", + " person_to_prefixes = defaultdict(list)\n", + " for prefix, persons in candidates_by_prefix.items():\n", + " for pid in persons:\n", + " person_to_prefixes[pid].append(prefix)\n", + " return {pid: sorted(prefixes) for pid, prefixes in person_to_prefixes.items()}\n", + "\n", + "def hopcroft_karp(graph: Dict[str, List[str]], left_nodes: List[str], right_nodes: List[str]):\n", + " INF = 10 ** 9\n", + " pair_u = {u: None for u in left_nodes}\n", + " pair_v = {v: None for v in right_nodes}\n", + " dist = {}\n", + "\n", + " from collections import deque\n", + "\n", + " def bfs():\n", + " queue = deque()\n", + " for u in left_nodes:\n", + " if pair_u[u] is None:\n", + " dist[u] = 0\n", + " queue.append(u)\n", + " else:\n", + " dist[u] = INF\n", + " found = False\n", + " while queue:\n", + " u = queue.popleft()\n", + " for v in graph[u]:\n", + " pu = pair_v[v]\n", + " if pu is None:\n", + " found = True\n", + " elif dist[pu] == INF:\n", + " dist[pu] = dist[u] + 1\n", + " queue.append(pu)\n", + " return found\n", + "\n", + " def dfs(u):\n", + " for v in graph[u]:\n", + " pu = pair_v[v]\n", + " if pu is None or (dist[pu] == dist[u] + 1 and dfs(pu)):\n", + " pair_u[u] = v\n", + " pair_v[v] = u\n", + " return True\n", + " dist[u] = INF\n", + " return False\n", + "\n", + " matching = 0\n", + " while bfs():\n", + " for u in left_nodes:\n", + " if pair_u[u] is None and dfs(u):\n", + " matching += 1\n", + " return matching, pair_u\n", + "\n", + "def assign_people_to_prefixes(candidates_by_prefix: Dict[str, List[str]], seed: int) -> Dict[str, List[str]]:\n", + " person_to_prefixes = build_person_to_prefixes(candidates_by_prefix)\n", + " left_nodes = sorted(person_to_prefixes)\n", + " if len(left_nodes) != 600:\n", + " raise RuntimeError(f'Expected 600 unique person ids, got 
{len(left_nodes)}')\n", + " rng = random.Random(seed)\n", + " slot_map = {}\n", + " right_nodes = []\n", + " for prefix in PREFIXES:\n", + " slot_names = [f'{prefix}#{idx:03d}' for idx in range(120)]\n", + " rng.shuffle(slot_names)\n", + " slot_map[prefix] = slot_names\n", + " right_nodes.extend(slot_names)\n", + " graph = {}\n", + " shuffled_left = left_nodes[:]\n", + " rng.shuffle(shuffled_left)\n", + " for pid in shuffled_left:\n", + " neighbors = []\n", + " prefixes = person_to_prefixes[pid][:]\n", + " rng.shuffle(prefixes)\n", + " for prefix in prefixes:\n", + " neighbors.extend(slot_map[prefix])\n", + " graph[pid] = neighbors\n", + " matching, pair_u = hopcroft_karp(graph, shuffled_left, right_nodes)\n", + " if matching != 600:\n", + " raise RuntimeError(f'Failed to assign 600 people to prefix slots, only matched {matching}')\n", + " result = defaultdict(list)\n", + " for pid, slot in pair_u.items():\n", + " prefix = slot.split('#', 1)[0]\n", + " result[prefix].append(pid)\n", + " for prefix in PREFIXES:\n", + " result[prefix] = sorted(result[prefix])\n", + " if len(result[prefix]) != 120:\n", + " raise RuntimeError(f'Prefix {prefix} expected 120 people, got {len(result[prefix])}')\n", + " return dict(result)\n", + "\n", + "def split_prefix_people(prefix_people: Dict[str, List[str]], seed: int, first_name: str, second_name: str) -> Dict[str, Dict[str, List[str]]]:\n", + " rng = random.Random(seed)\n", + " result = {}\n", + " for prefix, people in prefix_people.items():\n", + " shuffled = people[:]\n", + " rng.shuffle(shuffled)\n", + " result[prefix] = {\n", + " first_name: sorted(shuffled[:60]),\n", + " second_name: sorted(shuffled[60:120]),\n", + " }\n", + " return result\n", + "\n", + "candidates_by_prefix = scan_candidates()\n", + "for prefix in PREFIXES:\n", + " print(prefix, len(candidates_by_prefix[prefix]))\n", + "set1_prefix_people = assign_people_to_prefixes(candidates_by_prefix, SET1_SEED)\n", + "set2_prefix_people = 
assign_people_to_prefixes(candidates_by_prefix, SET2_SEED)\n", + "set1_split = split_prefix_people(set1_prefix_people, SET1_SEED + 100, 'neg_to_pos', 'pos_to_neg')\n", + "set2_split = split_prefix_people(set2_prefix_people, SET2_SEED + 100, 'pos_to_neg', 'neg_to_pos')\n", + "print('set1 unique persons =', len(set().union(*[set(v) for v in set1_prefix_people.values()])))\n", + "print('set2 unique persons =', len(set().union(*[set(v) for v in set2_prefix_people.values()])))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Prepared split payload at /scratch3/f007yzf/repos/Step1X-Edit-clean/training_data/metadata/connector_6000_split_v1.json\n" + ] + } + ], + "source": [ + "def save_step1x_npz(target_image_path: Path, embeds: np.ndarray, masks: np.ndarray, force_overwrite: bool = False) -> Path:\n", + " npz_path = target_image_path.with_suffix('')\n", + " npz_path = npz_path.parent / f'{npz_path.name}_step1x_te.npz'\n", + " if npz_path.exists() and not force_overwrite:\n", + " existing = np.load(npz_path)\n", + " if 'embeds' in existing and 'masks' in existing:\n", + " return npz_path\n", + " np.savez(npz_path, embeds=embeds, masks=masks)\n", + " return npz_path\n", + "\n", + "def build_record(set_id: str, direction: str, prefix: str, person_id: str, pair_id: str) -> Dict[str, str]:\n", + " paths = resolve_image_paths(prefix, person_id, pair_id)\n", + " if set_id == 'set1' and direction == 'neg_to_pos':\n", + " img1 = paths['flux_neg']\n", + " img2 = paths['iy_pos']\n", + " ref_image_path = paths['flux_neg']\n", + " target_image_path = paths['flux_pos']\n", + " elif set_id == 'set1' and direction == 'pos_to_neg':\n", + " img1 = paths['flux_pos']\n", + " img2 = paths['iy_neg']\n", + " ref_image_path = paths['flux_pos']\n", + " target_image_path = paths['flux_neg']\n", + " elif set_id == 'set2' and direction == 'pos_to_neg':\n", + " img1 = paths['iy_pos']\n", 
+ " img2 = paths['flux_neg']\n", + " ref_image_path = paths['iy_pos']\n", + " target_image_path = paths['iy_neg']\n", + " elif set_id == 'set2' and direction == 'neg_to_pos':\n", + " img1 = paths['iy_neg']\n", + " img2 = paths['flux_pos']\n", + " ref_image_path = paths['iy_neg']\n", + " target_image_path = paths['iy_pos']\n", + " else:\n", + " raise ValueError((set_id, direction))\n", + " return {\n", + " 'set_id': set_id,\n", + " 'direction': direction,\n", + " 'prefix_family': prefix,\n", + " 'person_id': person_id,\n", + " 'pair_id': pair_id,\n", + " 'img1_path': str(img1.resolve()),\n", + " 'img2_path': str(img2.resolve()),\n", + " 'ref_image_path': str(ref_image_path.resolve()),\n", + " 'target_image_path': str(target_image_path.resolve()),\n", + " }\n", + "\n", + "def expand_split_to_records(split: Dict[str, Dict[str, List[str]]], set_id: str) -> List[Dict[str, str]]:\n", + " records = []\n", + " for prefix in PREFIXES:\n", + " for direction, people in split[prefix].items():\n", + " for person_id in people:\n", + " for pair_id in PAIRS:\n", + " records.append(build_record(set_id, direction, prefix, person_id, pair_id))\n", + " return records\n", + "\n", + "def write_json(path: Path, payload) -> None:\n", + " path.parent.mkdir(parents=True, exist_ok=True)\n", + " path.write_text(json.dumps(payload, ensure_ascii=False, indent=2))\n", + "\n", + "def write_jsonl(path: Path, rows: Iterable[dict]) -> None:\n", + " path.parent.mkdir(parents=True, exist_ok=True)\n", + " with path.open('w', encoding='utf-8') as f:\n", + " for row in rows:\n", + " f.write(json.dumps(row, ensure_ascii=False) + '\\n')\n", + "\n", + "def generate_dataset(records: List[Dict[str, str]], force_overwrite_npz: bool = False) -> Tuple[dict, List[dict]]:\n", + " metadata = {}\n", + " audit_rows = []\n", + " for record in tqdm(records):\n", + " img1 = Path(record['img1_path'])\n", + " img2 = Path(record['img2_path'])\n", + " target_image_path = Path(record['target_image_path'])\n", + " 
ref_image_path = Path(record['ref_image_path'])\n", + " prompt = build_dual_image_prompt(img1, img2)\n", + " embeds, masks = build_prompt_embedding(img1, prompt)\n", + " npz_path = save_step1x_npz(target_image_path, embeds, masks, force_overwrite=force_overwrite_npz)\n", + " key = str(target_image_path.resolve())\n", + " if key in metadata:\n", + " raise RuntimeError(f'Duplicate target image key: {key}')\n", + " metadata[key] = {\n", + " 'ref_image_path': str(ref_image_path.resolve()),\n", + " 'caption': prompt,\n", + " }\n", + " audit_row = dict(record)\n", + " audit_row['caption'] = prompt\n", + " audit_row['embedding_npz_path'] = str(npz_path.resolve())\n", + " audit_rows.append(audit_row)\n", + " return metadata, audit_rows\n", + "\n", + "set1_records = expand_split_to_records(set1_split, 'set1')\n", + "set2_records = expand_split_to_records(set2_split, 'set2')\n", + "assert len(set1_records) == 3000\n", + "assert len(set2_records) == 3000\n", + "all_records = set1_records + set2_records\n", + "assert len(all_records) == 6000\n", + "\n", + "SPLIT_OUTPUT = OUTPUT_METADATA_DIR / 'connector_6000_split_v1.json'\n", + "SET1_OUTPUT = OUTPUT_METADATA_DIR / 'connector_set1_flux2_3000_v1.json'\n", + "SET2_OUTPUT = OUTPUT_METADATA_DIR / 'connector_set2_infiniteyou_3000_v1.json'\n", + "COMBINED_OUTPUT = OUTPUT_METADATA_DIR / 'connector_6000_train_v1.json'\n", + "AUDIT_OUTPUT = OUTPUT_METADATA_DIR / 'connector_6000_audit_v1.jsonl'\n", + "\n", + "split_payload = {\n", + " 'master_seed': MASTER_SEED,\n", + " 'set1_seed': SET1_SEED,\n", + " 'set2_seed': SET2_SEED,\n", + " 'prefixes': PREFIXES,\n", + " 'pairs': PAIRS,\n", + " 'set1': set1_split,\n", + " 'set2': set2_split,\n", + "}\n", + "write_json(SPLIT_OUTPUT, split_payload)\n", + "print('Prepared split payload at', SPLIT_OUTPUT)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The following generation flags are 
not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\n", + " \"img1_path\": \"/scratch3/f007yzf/repos/Step1X-Edit-clean/training_data/source_img/data__p0000_pair00_s0__source.png\",\n", + " \"img2_path\": \"/scratch3/f007yzf/repos/Step1X-Edit-clean/training_data/reference_img/left/data__p0000_pair00_s0.png\",\n", + " \"prompt\": \"Editing instruction:\\n\\n- Raise the left corner of the mouth upwards and outward.\\n- Raise the right corner of the mouth upwards and outward.\\n- Lower the left side of the eyebrows slightly.\\n- Lower the right side of the eyebrows slightly.\\n- Tighten the lips into a smile.\\n- Relax the forehead muscles to reduce any frown lines.\",\n", + " \"embedding_shape\": [\n", + " 640,\n", + " 3584\n", + " ],\n", + " \"mask_sum\": 381\n", + "}\n", + "Set RUN_FULL_GENERATION = True to generate all 6000 samples.\n" + ] + } + ], + "source": [ + "RUN_SINGLE_PREVIEW = True\n", + "if RUN_SINGLE_PREVIEW:\n", + " preview = run_single_preview()\n", + " print(json.dumps(preview, ensure_ascii=False, indent=2))\n", + "else:\n", + " print('Set RUN_SINGLE_PREVIEW = True to generate one prompt preview.')\n", + "\n", + "if RUN_FULL_GENERATION:\n", + " set1_metadata, set1_audit = generate_dataset(set1_records, force_overwrite_npz=FORCE_OVERWRITE_NPZ)\n", + " set2_metadata, set2_audit = generate_dataset(set2_records, force_overwrite_npz=FORCE_OVERWRITE_NPZ)\n", + " combined_metadata = dict(set1_metadata)\n", + " overlap = set(combined_metadata).intersection(set2_metadata)\n", + " if overlap:\n", + " raise RuntimeError(f'Unexpected overlap between set1 and set2 target keys: {len(overlap)}')\n", + " combined_metadata.update(set2_metadata)\n", + " write_json(SET1_OUTPUT, set1_metadata)\n", + " write_json(SET2_OUTPUT, set2_metadata)\n", + " write_json(COMBINED_OUTPUT, combined_metadata)\n", + " write_jsonl(AUDIT_OUTPUT, set1_audit + 
set2_audit)\n", + " print('set1 samples =', len(set1_metadata))\n", + " print('set2 samples =', len(set2_metadata))\n", + " print('combined samples =', len(combined_metadata))\n", + " print('Saved:', SET1_OUTPUT)\n", + " print('Saved:', SET2_OUTPUT)\n", + " print('Saved:', COMBINED_OUTPUT)\n", + " print('Saved:', AUDIT_OUTPUT)\n", + "else:\n", + " print('Set RUN_FULL_GENERATION = True to generate all 6000 samples.')\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "89ce9ff1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "PosixPath('/scratch3/f007yzf/repos/Step1X-Edit-clean/training_data/metadata/connector_set1_flux2_3000_v1.json')" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "SET1_OUTPUT" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6586cdaf", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "PosixPath('/scratch3/f007yzf/repos/Step1X-Edit-clean/training_data/metadata/connector_6000_audit_v1.jsonl')" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c3a2f3f3", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "step1x_v11", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.19" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/scripts/llm_prompt_v1_v2.ipynb b/scripts/llm_prompt_v1_v2.ipynb new file mode 100644 index 0000000..b5a4171 --- /dev/null +++ b/scripts/llm_prompt_v1_v2.ipynb @@ -0,0 +1,711 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + 
"source": [ + "# Qwen dual-image direct embedding cache for connector v2\n", + "\n", + "这个 notebook 现在做两件事:\n", + "1. 输入两张图片和固定 prompt,直接生成双图 embedding cache\n", + "2. 按 6000 样本规则写出 metadata 与 `_step1x_te.npz`\n", + "\n", + "当前设计:\n", + "- 不再走 `双图 -> prompt 文本 -> embedding` 两段式\n", + "- 直接走 `双图 + 固定 prompt -> embedding`\n", + "- metadata 里的 `caption` 仅保留固定字符串占位\n", + "- 最终 cache 仍然兼容 Step1X 训练读取\n", + "\n", + "运行前注意:\n", + "- 这个 notebook 默认通过 `CUDA_VISIBLE_DEVICES=1` 使用物理 1 号卡\n", + "- 必须先 Restart Kernel,再从第一格开始运行,否则环境变量不会生效\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CUDA_VISIBLE_DEVICES = 1\n", + "torch.cuda.is_available() = True\n", + "torch.cuda.device_count() = 1\n", + "DEVICE = cuda:0\n", + "MODEL_PATH = /scratch3/f007yzf/models/step1x_v11/Qwen2.5-VL-7B-Instruct\n", + "OUTPUT_METADATA_DIR = /scratch3/f007yzf/repos/Step1X-Edit-clean/training_data/metadata\n" + ] + } + ], + "source": [ + "from __future__ import annotations\n", + "\n", + "import os\n", + "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n", + "\n", + "import json\n", + "import math\n", + "import random\n", + "import re\n", + "from collections import defaultdict\n", + "from pathlib import Path\n", + "from typing import Dict, Iterable, List, Tuple\n", + "\n", + "import numpy as np\n", + "import torch\n", + "from PIL import Image\n", + "from tqdm.auto import tqdm\n", + "from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration\n", + "from qwen_vl_utils import process_vision_info\n", + "\n", + "REPO_ROOT = Path('/scratch3/f007yzf/repos/Step1X-Edit-clean')\n", + "DATA_ROOT = REPO_ROOT / 'training_data'\n", + "MODEL_PATH = Path('/scratch3/f007yzf/models/step1x_v11/Qwen2.5-VL-7B-Instruct')\n", + "OUTPUT_METADATA_DIR = DATA_ROOT / 'metadata'\n", + "OUTPUT_METADATA_DIR.mkdir(parents=True, exist_ok=True)\n", + "\n", + "DEVICE = 'cuda:0' if torch.cuda.is_available() else 
'cpu'\n", + "DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32\n", + "MAX_LENGTH = 1280\n", + "MASTER_SEED = 20260307\n", + "SET1_SEED = MASTER_SEED\n", + "SET2_SEED = MASTER_SEED + 1\n", + "PAIRS = [f'{i:02d}' for i in range(5)]\n", + "PREFIXES = ['data', 'data_seed_2', 'data_seed_3', 'data_seed_4', 'data_seed_5']\n", + "\n", + "SINGLE_IMG1_PATH = DATA_ROOT / 'source_img' / 'data__p0000_pair00_s0__source.png'\n", + "SINGLE_IMG2_PATH = DATA_ROOT / 'reference_img' / 'left' / 'data__p0000_pair00_s0.png'\n", + "SINGLE_PREVIEW_OUTPUT = OUTPUT_METADATA_DIR / 'llm_dual_embedding_preview_v2.json'\n", + "\n", + "RUN_SINGLE_PREVIEW = False\n", + "RUN_FULL_GENERATION = False\n", + "FORCE_OVERWRITE_NPZ = False\n", + "\n", + "FIXED_EMBED_PROMPT = \"\"\"You are analyzing facial expressions for a controlled editing task.\n", + "Given:\n", + "- Image 1: source face to be edited\n", + "- Image 2: target expression reference\n", + "\n", + "Output a structured expression editing plan.\n", + "Focus only on expression change. 
Keep identity, hairstyle, clothing, background, lighting, and camera unchanged.\n", + "Describe the target expression change clearly and concretely.\n", + "\"\"\"\n", + "\n", + "FIXED_METADATA_CAPTION = 'dual-image connector cache placeholder'\n", + "\n", + "IMAGE_RE = re.compile(r'^(data(?:_seed_\\d+)?)__p(\\d+)_pair(\\d+)_s(\\d+)(?:__(source|target))?\\.png$')\n", + "\n", + "print('CUDA_VISIBLE_DEVICES =', os.environ.get('CUDA_VISIBLE_DEVICES'))\n", + "print('torch.cuda.is_available() =', torch.cuda.is_available())\n", + "print('torch.cuda.device_count() =', torch.cuda.device_count())\n", + "print('DEVICE =', DEVICE)\n", + "print('MODEL_PATH =', MODEL_PATH)\n", + "print('OUTPUT_METADATA_DIR =', OUTPUT_METADATA_DIR)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading processor...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "You have video processor config saved in `preprocessor.json` file which is deprecated. Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename the file or load and save the processor back which renames it automatically. 
Loading from `preprocessor.json` will be removed in v5.0.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading Qwen2.5-VL model...\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "63dd799659e148749a96f3f0b62c2d46", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/5 [00:00 Image.Image:\n", + " return Image.open(path).convert('RGB')\n", + "\n", + "def get_vision_token_ids():\n", + " tokenizer = processor.tokenizer\n", + " v_start_id = tokenizer.convert_tokens_to_ids('<|vision_start|>')\n", + " v_end_id = tokenizer.convert_tokens_to_ids('<|vision_end|>')\n", + " if v_start_id in [None, tokenizer.unk_token_id]:\n", + " v_start_id = 151652\n", + " if v_end_id in [None, tokenizer.unk_token_id]:\n", + " v_end_id = 151653\n", + " return v_start_id, v_end_id\n", + "\n", + "@torch.inference_mode()\n", + "def build_dual_image_prompt(img1_path: Path, img2_path: Path, user_prompt: str = '') -> str:\n", + " return user_prompt or FIXED_METADATA_CAPTION\n", + "\n", + "@torch.inference_mode()\n", + "def build_dual_image_embedding(\n", + " img1_path: Path,\n", + " img2_path: Path,\n", + " fixed_prompt: str = FIXED_EMBED_PROMPT,\n", + " max_length: int = MAX_LENGTH,\n", + " return_debug: bool = False,\n", + ") -> Tuple[np.ndarray, np.ndarray]:\n", + " img1 = load_pil(img1_path)\n", + " img2 = load_pil(img2_path)\n", + " messages = [{\n", + " 'role': 'user',\n", + " 'content': [\n", + " {'type': 'text', 'text': '\n", + "[Source Image]:'},\n", + " {'type': 'image', 'image': img1},\n", + " {'type': 'text', 'text': '\n", + "[Reference Image]:'},\n", + " {'type': 'image', 'image': img2},\n", + " {'type': 'text', 'text': f'\n", + "Task:\n", + "{fixed_prompt}'},\n", + " ],\n", + " }]\n", + " text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, add_vision_id=True)\n", + " image_inputs, _ = process_vision_info(messages)\n", + " 
inputs = processor(text=[text], images=image_inputs, padding=True, return_tensors='pt')\n", + " outputs = model(\n", + " input_ids=inputs.input_ids.to(DEVICE),\n", + " attention_mask=inputs.attention_mask.to(DEVICE),\n", + " pixel_values=inputs.pixel_values.to(DEVICE),\n", + " image_grid_thw=inputs.image_grid_thw.to(DEVICE),\n", + " output_hidden_states=True,\n", + " )\n", + " emb = outputs.hidden_states[-1]\n", + " v_start_id, v_end_id = get_vision_token_ids()\n", + " input_ids_0 = inputs.input_ids[0]\n", + " v_ends = (input_ids_0 == v_end_id).nonzero(as_tuple=True)[0]\n", + " if len(v_ends) < 2:\n", + " raise RuntimeError(f'Expected at least two vision_end tokens, got {len(v_ends)}')\n", + " text_start = int(v_ends[-1].item()) + 1\n", + " usable = max(0, emb.shape[1] - text_start)\n", + " length = min(max_length, usable)\n", + "\n", + " embeds = torch.zeros((max_length, HIDDEN_SIZE), dtype=torch.float32, device=DEVICE)\n", + " masks = torch.zeros((max_length,), dtype=torch.long, device=DEVICE)\n", + " if length > 0:\n", + " embeds[:length] = emb[0, text_start:text_start + length].to(torch.float32)\n", + " masks[:length] = 1\n", + "\n", + " if return_debug:\n", + " debug = {\n", + " 'full_seq_len': int(emb.shape[1]),\n", + " 'text_start_index': text_start,\n", + " 'usable_len_before_truncation': int(usable),\n", + " 'final_len': int(length),\n", + " }\n", + " return embeds.cpu().numpy(), masks.cpu().numpy(), debug\n", + " return embeds.cpu().numpy(), masks.cpu().numpy()\n", + "\n", + "@torch.inference_mode()\n", + "def run_single_preview(img1_path: Path = SINGLE_IMG1_PATH, img2_path: Path = SINGLE_IMG2_PATH, output_json: Path = SINGLE_PREVIEW_OUTPUT):\n", + " embeds, masks, debug = build_dual_image_embedding(img1_path, img2_path, return_debug=True)\n", + " payload = {\n", + " 'img1_path': str(img1_path.resolve()),\n", + " 'img2_path': str(img2_path.resolve()),\n", + " 'fixed_prompt': FIXED_EMBED_PROMPT,\n", + " 'fixed_metadata_caption': FIXED_METADATA_CAPTION,\n", 
+ " 'embedding_shape': list(embeds.shape),\n", + " 'mask_sum': int(masks.sum()),\n", + " **debug,\n", + " }\n", + " output_json.parent.mkdir(parents=True, exist_ok=True)\n", + " output_json.write_text(json.dumps(payload, ensure_ascii=False, indent=2))\n", + " return payload\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "data 579\n", + "data_seed_2 600\n", + "data_seed_3 571\n", + "data_seed_4 583\n", + "data_seed_5 584\n", + "set1 unique persons = 600\n", + "set2 unique persons = 600\n" + ] + } + ], + "source": [ + "ROOTS = {\n", + " 'flux_neg': DATA_ROOT / 'source_img',\n", + " 'flux_pos': DATA_ROOT / 'target_img',\n", + " 'iy_pos': DATA_ROOT / 'reference_img' / 'left',\n", + " 'iy_neg': DATA_ROOT / 'reference_img' / 'right',\n", + "}\n", + "\n", + "def candidate_filename(prefix: str, person_id: str, pair_id: str, kind: str) -> str:\n", + " stem = f'{prefix}__p{person_id}_pair{pair_id}_s0'\n", + " if kind == 'flux_neg':\n", + " return stem + '__source.png'\n", + " if kind == 'flux_pos':\n", + " return stem + '__target.png'\n", + " if kind in {'iy_pos', 'iy_neg'}:\n", + " return stem + '.png'\n", + " raise KeyError(kind)\n", + "\n", + "def resolve_image_paths(prefix: str, person_id: str, pair_id: str) -> Dict[str, Path]:\n", + " return {kind: ROOTS[kind] / candidate_filename(prefix, person_id, pair_id, kind) for kind in ROOTS}\n", + "\n", + "def person_complete(prefix: str, person_id: str) -> bool:\n", + " for pair_id in PAIRS:\n", + " paths = resolve_image_paths(prefix, person_id, pair_id)\n", + " if not all(path.exists() for path in paths.values()):\n", + " return False\n", + " return True\n", + "\n", + "def scan_candidates() -> Dict[str, List[str]]:\n", + " seen = defaultdict(set)\n", + " for root in ROOTS.values():\n", + " for path in root.glob('*.png'):\n", + " m = IMAGE_RE.match(path.name)\n", + " if not m:\n", + " continue\n", + " prefix, 
person_id, pair_id, s, _ = m.groups()\n", + " if prefix not in PREFIXES or s != '0':\n", + " continue\n", + " seen[prefix].add(person_id)\n", + " candidates = {}\n", + " for prefix in PREFIXES:\n", + " valid = sorted(pid for pid in seen[prefix] if person_complete(prefix, pid))\n", + " candidates[prefix] = valid\n", + " return candidates\n", + "\n", + "def build_person_to_prefixes(candidates_by_prefix: Dict[str, List[str]]) -> Dict[str, List[str]]:\n", + " person_to_prefixes = defaultdict(list)\n", + " for prefix, persons in candidates_by_prefix.items():\n", + " for pid in persons:\n", + " person_to_prefixes[pid].append(prefix)\n", + " return {pid: sorted(prefixes) for pid, prefixes in person_to_prefixes.items()}\n", + "\n", + "def hopcroft_karp(graph: Dict[str, List[str]], left_nodes: List[str], right_nodes: List[str]):\n", + " INF = 10 ** 9\n", + " pair_u = {u: None for u in left_nodes}\n", + " pair_v = {v: None for v in right_nodes}\n", + " dist = {}\n", + "\n", + " from collections import deque\n", + "\n", + " def bfs():\n", + " queue = deque()\n", + " for u in left_nodes:\n", + " if pair_u[u] is None:\n", + " dist[u] = 0\n", + " queue.append(u)\n", + " else:\n", + " dist[u] = INF\n", + " found = False\n", + " while queue:\n", + " u = queue.popleft()\n", + " for v in graph[u]:\n", + " pu = pair_v[v]\n", + " if pu is None:\n", + " found = True\n", + " elif dist[pu] == INF:\n", + " dist[pu] = dist[u] + 1\n", + " queue.append(pu)\n", + " return found\n", + "\n", + " def dfs(u):\n", + " for v in graph[u]:\n", + " pu = pair_v[v]\n", + " if pu is None or (dist[pu] == dist[u] + 1 and dfs(pu)):\n", + " pair_u[u] = v\n", + " pair_v[v] = u\n", + " return True\n", + " dist[u] = INF\n", + " return False\n", + "\n", + " matching = 0\n", + " while bfs():\n", + " for u in left_nodes:\n", + " if pair_u[u] is None and dfs(u):\n", + " matching += 1\n", + " return matching, pair_u\n", + "\n", + "def assign_people_to_prefixes(candidates_by_prefix: Dict[str, List[str]], seed: int) 
-> Dict[str, List[str]]:\n", + " person_to_prefixes = build_person_to_prefixes(candidates_by_prefix)\n", + " left_nodes = sorted(person_to_prefixes)\n", + " if len(left_nodes) != 600:\n", + " raise RuntimeError(f'Expected 600 unique person ids, got {len(left_nodes)}')\n", + " rng = random.Random(seed)\n", + " slot_map = {}\n", + " right_nodes = []\n", + " for prefix in PREFIXES:\n", + " slot_names = [f'{prefix}#{idx:03d}' for idx in range(120)]\n", + " rng.shuffle(slot_names)\n", + " slot_map[prefix] = slot_names\n", + " right_nodes.extend(slot_names)\n", + " graph = {}\n", + " shuffled_left = left_nodes[:]\n", + " rng.shuffle(shuffled_left)\n", + " for pid in shuffled_left:\n", + " neighbors = []\n", + " prefixes = person_to_prefixes[pid][:]\n", + " rng.shuffle(prefixes)\n", + " for prefix in prefixes:\n", + " neighbors.extend(slot_map[prefix])\n", + " graph[pid] = neighbors\n", + " matching, pair_u = hopcroft_karp(graph, shuffled_left, right_nodes)\n", + " if matching != 600:\n", + " raise RuntimeError(f'Failed to assign 600 people to prefix slots, only matched {matching}')\n", + " result = defaultdict(list)\n", + " for pid, slot in pair_u.items():\n", + " prefix = slot.split('#', 1)[0]\n", + " result[prefix].append(pid)\n", + " for prefix in PREFIXES:\n", + " result[prefix] = sorted(result[prefix])\n", + " if len(result[prefix]) != 120:\n", + " raise RuntimeError(f'Prefix {prefix} expected 120 people, got {len(result[prefix])}')\n", + " return dict(result)\n", + "\n", + "def split_prefix_people(prefix_people: Dict[str, List[str]], seed: int, first_name: str, second_name: str) -> Dict[str, Dict[str, List[str]]]:\n", + " rng = random.Random(seed)\n", + " result = {}\n", + " for prefix, people in prefix_people.items():\n", + " shuffled = people[:]\n", + " rng.shuffle(shuffled)\n", + " result[prefix] = {\n", + " first_name: sorted(shuffled[:60]),\n", + " second_name: sorted(shuffled[60:120]),\n", + " }\n", + " return result\n", + "\n", + "candidates_by_prefix = 
scan_candidates()\n", + "for prefix in PREFIXES:\n", + " print(prefix, len(candidates_by_prefix[prefix]))\n", + "set1_prefix_people = assign_people_to_prefixes(candidates_by_prefix, SET1_SEED)\n", + "set2_prefix_people = assign_people_to_prefixes(candidates_by_prefix, SET2_SEED)\n", + "set1_split = split_prefix_people(set1_prefix_people, SET1_SEED + 100, 'neg_to_pos', 'pos_to_neg')\n", + "set2_split = split_prefix_people(set2_prefix_people, SET2_SEED + 100, 'pos_to_neg', 'neg_to_pos')\n", + "print('set1 unique persons =', len(set().union(*[set(v) for v in set1_prefix_people.values()])))\n", + "print('set2 unique persons =', len(set().union(*[set(v) for v in set2_prefix_people.values()])))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Prepared split payload at /scratch3/f007yzf/repos/Step1X-Edit-clean/training_data/metadata/connector_6000_split_v1.json\n" + ] + } + ], + "source": [ + "def save_step1x_npz(target_image_path: Path, embeds: np.ndarray, masks: np.ndarray, force_overwrite: bool = False) -> Path:\n", + " npz_path = target_image_path.with_suffix('')\n", + " npz_path = npz_path.parent / f'{npz_path.name}_step1x_te.npz'\n", + " if npz_path.exists() and not force_overwrite:\n", + " existing = np.load(npz_path)\n", + " if 'embeds' in existing and 'masks' in existing:\n", + " return npz_path\n", + " np.savez(npz_path, embeds=embeds, masks=masks)\n", + " return npz_path\n", + "\n", + "def build_record(set_id: str, direction: str, prefix: str, person_id: str, pair_id: str) -> Dict[str, str]:\n", + " paths = resolve_image_paths(prefix, person_id, pair_id)\n", + " if set_id == 'set1' and direction == 'neg_to_pos':\n", + " img1 = paths['flux_neg']\n", + " img2 = paths['iy_pos']\n", + " ref_image_path = paths['flux_neg']\n", + " target_image_path = paths['flux_pos']\n", + " elif set_id == 'set1' and direction == 'pos_to_neg':\n", + " img1 = 
paths['flux_pos']\n", + " img2 = paths['iy_neg']\n", + " ref_image_path = paths['flux_pos']\n", + " target_image_path = paths['flux_neg']\n", + " elif set_id == 'set2' and direction == 'pos_to_neg':\n", + " img1 = paths['iy_pos']\n", + " img2 = paths['flux_neg']\n", + " ref_image_path = paths['iy_pos']\n", + " target_image_path = paths['iy_neg']\n", + " elif set_id == 'set2' and direction == 'neg_to_pos':\n", + " img1 = paths['iy_neg']\n", + " img2 = paths['flux_pos']\n", + " ref_image_path = paths['iy_neg']\n", + " target_image_path = paths['iy_pos']\n", + " else:\n", + " raise ValueError((set_id, direction))\n", + " return {\n", + " 'set_id': set_id,\n", + " 'direction': direction,\n", + " 'prefix_family': prefix,\n", + " 'person_id': person_id,\n", + " 'pair_id': pair_id,\n", + " 'img1_path': str(img1.resolve()),\n", + " 'img2_path': str(img2.resolve()),\n", + " 'ref_image_path': str(ref_image_path.resolve()),\n", + " 'target_image_path': str(target_image_path.resolve()),\n", + " }\n", + "\n", + "def expand_split_to_records(split: Dict[str, Dict[str, List[str]]], set_id: str) -> List[Dict[str, str]]:\n", + " records = []\n", + " for prefix in PREFIXES:\n", + " for direction, people in split[prefix].items():\n", + " for person_id in people:\n", + " for pair_id in PAIRS:\n", + " records.append(build_record(set_id, direction, prefix, person_id, pair_id))\n", + " return records\n", + "\n", + "def write_json(path: Path, payload) -> None:\n", + " path.parent.mkdir(parents=True, exist_ok=True)\n", + " path.write_text(json.dumps(payload, ensure_ascii=False, indent=2))\n", + "\n", + "def write_jsonl(path: Path, rows: Iterable[dict]) -> None:\n", + " path.parent.mkdir(parents=True, exist_ok=True)\n", + " with path.open('w', encoding='utf-8') as f:\n", + " for row in rows:\n", + " f.write(json.dumps(row, ensure_ascii=False) + '\n", + "')\n", + "\n", + "def generate_dataset(records: List[Dict[str, str]], force_overwrite_npz: bool = False) -> Tuple[dict, List[dict]]:\n", + " 
metadata = {}\n", + " audit_rows = []\n", + " for record in tqdm(records):\n", + " img1 = Path(record['img1_path'])\n", + " img2 = Path(record['img2_path'])\n", + " target_image_path = Path(record['target_image_path'])\n", + " ref_image_path = Path(record['ref_image_path'])\n", + " embeds, masks, debug = build_dual_image_embedding(img1, img2, return_debug=True)\n", + " npz_path = save_step1x_npz(target_image_path, embeds, masks, force_overwrite=force_overwrite_npz)\n", + " key = str(target_image_path.resolve())\n", + " if key in metadata:\n", + " raise RuntimeError(f'Duplicate target image key: {key}')\n", + " metadata[key] = {\n", + " 'ref_image_path': str(ref_image_path.resolve()),\n", + " 'caption': FIXED_METADATA_CAPTION,\n", + " }\n", + " audit_row = dict(record)\n", + " audit_row['caption'] = FIXED_METADATA_CAPTION\n", + " audit_row['fixed_prompt'] = FIXED_EMBED_PROMPT\n", + " audit_row['embedding_npz_path'] = str(npz_path.resolve())\n", + " audit_row.update(debug)\n", + " audit_rows.append(audit_row)\n", + " return metadata, audit_rows\n", + "\n", + "set1_records = expand_split_to_records(set1_split, 'set1')\n", + "set2_records = expand_split_to_records(set2_split, 'set2')\n", + "assert len(set1_records) == 3000\n", + "assert len(set2_records) == 3000\n", + "all_records = set1_records + set2_records\n", + "assert len(all_records) == 6000\n", + "\n", + "SPLIT_OUTPUT = OUTPUT_METADATA_DIR / 'connector_6000_split_v2.json'\n", + "SET1_OUTPUT = OUTPUT_METADATA_DIR / 'connector_set1_flux2_3000_v2.json'\n", + "SET2_OUTPUT = OUTPUT_METADATA_DIR / 'connector_set2_infiniteyou_3000_v2.json'\n", + "COMBINED_OUTPUT = OUTPUT_METADATA_DIR / 'connector_6000_train_v2.json'\n", + "AUDIT_OUTPUT = OUTPUT_METADATA_DIR / 'connector_6000_audit_v2.jsonl'\n", + "\n", + "split_payload = {\n", + " 'master_seed': MASTER_SEED,\n", + " 'set1_seed': SET1_SEED,\n", + " 'set2_seed': SET2_SEED,\n", + " 'prefixes': PREFIXES,\n", + " 'pairs': PAIRS,\n", + " 'fixed_embed_prompt': 
FIXED_EMBED_PROMPT,\n", + " 'fixed_metadata_caption': FIXED_METADATA_CAPTION,\n", + " 'set1': set1_split,\n", + " 'set2': set2_split,\n", + "}\n", + "write_json(SPLIT_OUTPUT, split_payload)\n", + "print('Prepared split payload at', SPLIT_OUTPUT)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\n", + " \"img1_path\": \"/scratch3/f007yzf/repos/Step1X-Edit-clean/training_data/source_img/data__p0000_pair00_s0__source.png\",\n", + " \"img2_path\": \"/scratch3/f007yzf/repos/Step1X-Edit-clean/training_data/reference_img/left/data__p0000_pair00_s0.png\",\n", + " \"prompt\": \"Editing instruction:\\n\\n- Raise the left corner of the mouth upwards and outward.\\n- Raise the right corner of the mouth upwards and outward.\\n- Lower the left side of the eyebrows slightly.\\n- Lower the right side of the eyebrows slightly.\\n- Tighten the lips into a smile.\\n- Relax the forehead muscles to reduce any frown lines.\",\n", + " \"embedding_shape\": [\n", + " 640,\n", + " 3584\n", + " ],\n", + " \"mask_sum\": 381\n", + "}\n", + "Set RUN_FULL_GENERATION = True to generate all 6000 samples.\n" + ] + } + ], + "source": [ + "RUN_SINGLE_PREVIEW = False\n", + "if RUN_SINGLE_PREVIEW:\n", + " preview = run_single_preview()\n", + " print(json.dumps(preview, ensure_ascii=False, indent=2))\n", + "else:\n", + " print('Set RUN_SINGLE_PREVIEW = True to generate one dual-image embedding preview.')\n", + "\n", + "if RUN_FULL_GENERATION:\n", + " set1_metadata, set1_audit = generate_dataset(set1_records, force_overwrite_npz=FORCE_OVERWRITE_NPZ)\n", + " set2_metadata, set2_audit = generate_dataset(set2_records, force_overwrite_npz=FORCE_OVERWRITE_NPZ)\n", + " combined_metadata = 
dict(set1_metadata)\n", + " overlap = set(combined_metadata).intersection(set2_metadata)\n", + " if overlap:\n", + " raise RuntimeError(f'Unexpected overlap between set1 and set2 target keys: {len(overlap)}')\n", + " combined_metadata.update(set2_metadata)\n", + " write_json(SET1_OUTPUT, set1_metadata)\n", + " write_json(SET2_OUTPUT, set2_metadata)\n", + " write_json(COMBINED_OUTPUT, combined_metadata)\n", + " write_jsonl(AUDIT_OUTPUT, set1_audit + set2_audit)\n", + " print('set1 samples =', len(set1_metadata))\n", + " print('set2 samples =', len(set2_metadata))\n", + " print('combined samples =', len(combined_metadata))\n", + " print('Saved:', SET1_OUTPUT)\n", + " print('Saved:', SET2_OUTPUT)\n", + " print('Saved:', COMBINED_OUTPUT)\n", + " print('Saved:', AUDIT_OUTPUT)\n", + "else:\n", + " print('Set RUN_FULL_GENERATION = True to generate all 6000 samples.')\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "89ce9ff1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "PosixPath('/scratch3/f007yzf/repos/Step1X-Edit-clean/training_data/metadata/connector_set1_flux2_3000_v1.json')" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "SET1_OUTPUT" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6586cdaf", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "PosixPath('/scratch3/f007yzf/repos/Step1X-Edit-clean/training_data/metadata/connector_6000_audit_v1.jsonl')" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c3a2f3f3", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "step1x_v11", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", 
+ "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.19" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/scripts/llm_prompt_v1_v3.ipynb b/scripts/llm_prompt_v1_v3.ipynb new file mode 100644 index 0000000..a387d30 --- /dev/null +++ b/scripts/llm_prompt_v1_v3.ipynb @@ -0,0 +1,768 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Qwen dual-image direct embedding cache for connector v3\n", + "\n", + "这个 notebook 做两件事:\n", + "1. 输入两张图片和固定 prompt,直接生成新的 embedding/mask cache\n", + "2. 按 6000 样本规则写出 metadata 与 `_step1x_te.npz`\n", + "\n", + "实现约束:\n", + "- 不修改当前 `modules/conditioner.py`\n", + "- 仅替换离线缓存生成逻辑\n", + "- 输出 `embeds/masks` 的 shape、dtype、字段名与原缓存兼容\n", + "- metadata 里的 `caption` 只保留固定字符串占位\n", + "- 每个 set 独立采样\n", + "- 每个 set 内 `data/data_seed_2/3/4/5` 各 120 人\n", + "- 每个 prefix 的 120 人拆成 60 人 `neg_to_pos` 和 60 人 `pos_to_neg`,人不重叠\n", + "\n", + "运行前注意:\n", + "- 这个 notebook 默认通过 `CUDA_VISIBLE_DEVICES=1` 使用物理 1 号卡\n", + "- 必须先 Restart Kernel,再从第一格开始运行,否则环境变量不会生效\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CUDA_VISIBLE_DEVICES = 1\n", + "torch.cuda.is_available() = True\n", + "torch.cuda.device_count() = 1\n", + "DEVICE = cuda:0\n", + "MODEL_PATH = /scratch3/f007yzf/models/step1x_v11/Qwen2.5-VL-7B-Instruct\n", + "OUTPUT_METADATA_DIR = /scratch3/f007yzf/repos/Step1X-Edit-clean/training_data/metadata\n" + ] + } + ], + "source": [ + "from __future__ import annotations\n", + "\n", + "import os\n", + "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n", + "\n", + "import json\n", + "import math\n", + "import random\n", + "import re\n", + "import shutil\n", + "from collections import defaultdict\n", + "from pathlib import Path\n", + "from typing import Dict, Iterable, List, Tuple\n", + "\n", + "import 
numpy as np\n", + "import torch\n", + "from PIL import Image\n", + "from tqdm.auto import tqdm\n", + "from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration\n", + "from qwen_vl_utils import process_vision_info\n", + "\n", + "REPO_ROOT = Path('/scratch3/f007yzf/repos/Step1X-Edit-clean')\n", + "DATA_ROOT = REPO_ROOT / 'training_data'\n", + "CACHE_ROOT = REPO_ROOT / 'cache' / '0307_dual_test'\n", + "MODEL_PATH = Path('/scratch3/f007yzf/models/step1x_v11/Qwen2.5-VL-7B-Instruct')\n", + "OUTPUT_METADATA_DIR = DATA_ROOT / 'metadata'\n", + "OUTPUT_METADATA_DIR.mkdir(parents=True, exist_ok=True)\n", + "CACHE_ROOT.mkdir(parents=True, exist_ok=True)\n", + "\n", + "DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n", + "DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32\n", + "PREVIEW_MAX_NEW_TOKENS = 150\n", + "MAX_LENGTH = 640\n", + "MASTER_SEED = 20260307\n", + "SET1_SEED = MASTER_SEED\n", + "SET2_SEED = MASTER_SEED + 1\n", + "PAIRS = [f'{i:02d}' for i in range(5)]\n", + "PREFIXES = ['data', 'data_seed_2', 'data_seed_3', 'data_seed_4', 'data_seed_5']\n", + "\n", + "SINGLE_IMG1_PATH = DATA_ROOT / 'source_img' / 'data__p0000_pair00_s0__source.png'\n", + "SINGLE_IMG2_PATH = DATA_ROOT / 'reference_img' / 'left' / 'data__p0000_pair00_s0.png'\n", + "SINGLE_PREVIEW_OUTPUT = OUTPUT_METADATA_DIR / 'llm_dual_embedding_preview_v3.json'\n", + "\n", + "RUN_SINGLE_PREVIEW = False\n", + "RUN_SMOKE_GENERATION = False\n", + "RUN_FULL_GENERATION = False\n", + "FORCE_OVERWRITE_NPZ = False\n", + "SMOKE_PERSON_COUNT = 10\n", + "\n", + "FIXED_DUAL_IMAGE_PROMPT = \"\"\"You are analyzing facial expressions for a controlled editing task.\n", + "Given:\n", + "- Image 1: source face to be edited\n", + "- Image 2: target expression reference\n", + "\n", + "Describe only the target facial-expression edit.\n", + "Focus only on expression change.\n", + "Keep identity, hairstyle, clothing, background, lighting, and camera unchanged.\"\"\"\n", + "\n", + 
"FIXED_METADATA_CAPTION = 'edit facial expression'\n", + "\n", + "IMAGE_RE = re.compile(r'^(data(?:_seed_\\d+)?)__p(\\d+)_pair(\\d+)_s(\\d+)(?:__(source|target))?\\.png$')\n", + "\n", + "print('CUDA_VISIBLE_DEVICES =', os.environ.get('CUDA_VISIBLE_DEVICES'))\n", + "print('torch.cuda.is_available() =', torch.cuda.is_available())\n", + "print('torch.cuda.device_count() =', torch.cuda.device_count())\n", + "print('DEVICE =', DEVICE)\n", + "print('MODEL_PATH =', MODEL_PATH)\n", + "print('OUTPUT_METADATA_DIR =', OUTPUT_METADATA_DIR)\n", + "print('CACHE_ROOT =', CACHE_ROOT)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading processor...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "You have video processor config saved in `preprocessor.json` file which is deprecated. Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename the file or load and save the processor back which renames it automatically. 
Loading from `preprocessor.json` will be removed in v5.0.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading Qwen2.5-VL model...\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "63dd799659e148749a96f3f0b62c2d46", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/5 [00:00 Image.Image:\n", + " return Image.open(path).convert('RGB')\n", + "\n", + "def build_dual_image_messages(img1: Image.Image, img2: Image.Image, fixed_prompt: str = FIXED_DUAL_IMAGE_PROMPT):\n", + " return [{\n", + " 'role': 'user',\n", + " 'content': [\n", + " {'type': 'text', 'text': '\\n[Source Image]:'},\n", + " {'type': 'image', 'image': img1},\n", + " {'type': 'text', 'text': '\\n[Reference Image]:'},\n", + " {'type': 'image', 'image': img2},\n", + " {'type': 'text', 'text': f'\\nTask:\\n{fixed_prompt}'},\n", + " ],\n", + " }]\n", + "\n", + "def get_vision_token_ids() -> Tuple[int, int]:\n", + " tokenizer = processor.tokenizer\n", + " v_start_id = tokenizer.convert_tokens_to_ids('<|vision_start|>')\n", + " v_end_id = tokenizer.convert_tokens_to_ids('<|vision_end|>')\n", + " if v_start_id in [None, tokenizer.unk_token_id]:\n", + " v_start_id = 151652\n", + " if v_end_id in [None, tokenizer.unk_token_id]:\n", + " v_end_id = 151653\n", + " return v_start_id, v_end_id\n", + "\n", + "@torch.inference_mode()\n", + "def build_dual_image_embedding(\n", + " img1_path: Path,\n", + " img2_path: Path,\n", + " fixed_prompt: str = FIXED_DUAL_IMAGE_PROMPT,\n", + " max_length: int = MAX_LENGTH,\n", + " return_debug: bool = False,\n", + "):\n", + " img1 = load_pil(img1_path)\n", + " img2 = load_pil(img2_path)\n", + " messages = build_dual_image_messages(img1, img2, fixed_prompt=fixed_prompt)\n", + " text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, add_vision_id=True)\n", + " image_inputs, _ = process_vision_info(messages)\n", + " inputs = 
processor(text=[text], images=image_inputs, padding=True, return_tensors='pt')\n", + " outputs = model(\n", + " input_ids=inputs.input_ids.to(DEVICE),\n", + " attention_mask=inputs.attention_mask.to(DEVICE),\n", + " pixel_values=inputs.pixel_values.to(DEVICE),\n", + " image_grid_thw=inputs.image_grid_thw.to(DEVICE),\n", + " output_hidden_states=True,\n", + " )\n", + " emb = outputs.hidden_states[-1]\n", + " _, v_end_id = get_vision_token_ids()\n", + " input_ids_0 = inputs.input_ids[0]\n", + " v_ends = (input_ids_0 == v_end_id).nonzero(as_tuple=True)[0]\n", + " if len(v_ends) < 2:\n", + " raise RuntimeError(f'Expected at least two vision_end tokens, got {len(v_ends)}')\n", + " text_start = int(v_ends[-1].item()) + 1\n", + " usable = max(0, emb.shape[1] - text_start)\n", + " length = min(max_length, usable)\n", + " embeds = torch.zeros((max_length, HIDDEN_SIZE), dtype=torch.float32, device=DEVICE)\n", + " masks = torch.zeros((max_length,), dtype=torch.long, device=DEVICE)\n", + " if length > 0:\n", + " embeds[:length] = emb[0, text_start:text_start + length].to(torch.float32)\n", + " masks[:length] = 1\n", + " if return_debug:\n", + " debug = {\n", + " 'full_seq_len': int(emb.shape[1]),\n", + " 'text_start_index': text_start,\n", + " 'usable_len_before_truncation': int(usable),\n", + " 'final_len': int(length),\n", + " }\n", + " return embeds.cpu().numpy(), masks.cpu().numpy(), debug\n", + " return embeds.cpu().numpy(), masks.cpu().numpy()\n", + "\n", + "@torch.inference_mode()\n", + "def run_single_preview(img1_path: Path = SINGLE_IMG1_PATH, img2_path: Path = SINGLE_IMG2_PATH, output_json: Path = SINGLE_PREVIEW_OUTPUT):\n", + " img1 = load_pil(img1_path)\n", + " img2 = load_pil(img2_path)\n", + " messages = build_dual_image_messages(img1, img2)\n", + " preview_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, add_vision_id=True)\n", + " preview_image_inputs, _ = process_vision_info(messages)\n", + " preview_inputs = 
processor(text=[preview_text], images=preview_image_inputs, padding=True, return_tensors='pt').to(DEVICE)\n", + " generated_ids = model.generate(**preview_inputs, max_new_tokens=PREVIEW_MAX_NEW_TOKENS, do_sample=False)\n", + " generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(preview_inputs.input_ids, generated_ids)]\n", + " decoded_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True)[0].strip()\n", + " embeds, masks, debug = build_dual_image_embedding(img1_path, img2_path, return_debug=True)\n", + " payload = {\n", + " 'img1_path': str(img1_path.resolve()),\n", + " 'img2_path': str(img2_path.resolve()),\n", + " 'fixed_prompt': FIXED_DUAL_IMAGE_PROMPT,\n", + " 'fixed_metadata_caption': FIXED_METADATA_CAPTION,\n", + " 'decoded_text': decoded_text,\n", + " 'embedding_shape': list(embeds.shape),\n", + " 'mask_sum': int(masks.sum()),\n", + " **debug,\n", + " }\n", + " output_json.parent.mkdir(parents=True, exist_ok=True)\n", + " output_json.write_text(json.dumps(payload, ensure_ascii=False, indent=2))\n", + " return payload\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "data 579\n", + "data_seed_2 600\n", + "data_seed_3 571\n", + "data_seed_4 583\n", + "data_seed_5 584\n", + "set1 unique persons = 600\n", + "set2 unique persons = 600\n" + ] + } + ], + "source": [ + "ROOTS = {\n", + " 'flux_neg': DATA_ROOT / 'source_img',\n", + " 'flux_pos': DATA_ROOT / 'target_img',\n", + " 'iy_pos': DATA_ROOT / 'reference_img' / 'left',\n", + " 'iy_neg': DATA_ROOT / 'reference_img' / 'right',\n", + "}\n", + "\n", + "def candidate_filename(prefix: str, person_id: str, pair_id: str, kind: str) -> str:\n", + " stem = f'{prefix}__p{person_id}_pair{pair_id}_s0'\n", + " if kind == 'flux_neg':\n", + " return stem + '__source.png'\n", + " if kind == 'flux_pos':\n", + " return stem + '__target.png'\n", + " if kind in {'iy_pos', 
'iy_neg'}:\n", + " return stem + '.png'\n", + " raise KeyError(kind)\n", + "\n", + "def resolve_image_paths(prefix: str, person_id: str, pair_id: str) -> Dict[str, Path]:\n", + " return {kind: ROOTS[kind] / candidate_filename(prefix, person_id, pair_id, kind) for kind in ROOTS}\n", + "\n", + "def person_complete(prefix: str, person_id: str) -> bool:\n", + " for pair_id in PAIRS:\n", + " paths = resolve_image_paths(prefix, person_id, pair_id)\n", + " if not all(path.exists() for path in paths.values()):\n", + " return False\n", + " return True\n", + "\n", + "def scan_candidates() -> Dict[str, List[str]]:\n", + " seen = defaultdict(set)\n", + " for root in ROOTS.values():\n", + " for path in root.glob('*.png'):\n", + " m = IMAGE_RE.match(path.name)\n", + " if not m:\n", + " continue\n", + " prefix, person_id, pair_id, s, _ = m.groups()\n", + " if prefix not in PREFIXES or s != '0':\n", + " continue\n", + " seen[prefix].add(person_id)\n", + " candidates = {}\n", + " for prefix in PREFIXES:\n", + " valid = sorted(pid for pid in seen[prefix] if person_complete(prefix, pid))\n", + " candidates[prefix] = valid\n", + " return candidates\n", + "\n", + "def build_person_to_prefixes(candidates_by_prefix: Dict[str, List[str]]) -> Dict[str, List[str]]:\n", + " person_to_prefixes = defaultdict(list)\n", + " for prefix, persons in candidates_by_prefix.items():\n", + " for pid in persons:\n", + " person_to_prefixes[pid].append(prefix)\n", + " return {pid: sorted(prefixes) for pid, prefixes in person_to_prefixes.items()}\n", + "\n", + "def hopcroft_karp(graph: Dict[str, List[str]], left_nodes: List[str], right_nodes: List[str]):\n", + " INF = 10 ** 9\n", + " pair_u = {u: None for u in left_nodes}\n", + " pair_v = {v: None for v in right_nodes}\n", + " dist = {}\n", + "\n", + " from collections import deque\n", + "\n", + " def bfs():\n", + " queue = deque()\n", + " for u in left_nodes:\n", + " if pair_u[u] is None:\n", + " dist[u] = 0\n", + " queue.append(u)\n", + " else:\n", + " 
dist[u] = INF\n", + " found = False\n", + " while queue:\n", + " u = queue.popleft()\n", + " for v in graph[u]:\n", + " pu = pair_v[v]\n", + " if pu is None:\n", + " found = True\n", + " elif dist[pu] == INF:\n", + " dist[pu] = dist[u] + 1\n", + " queue.append(pu)\n", + " return found\n", + "\n", + " def dfs(u):\n", + " for v in graph[u]:\n", + " pu = pair_v[v]\n", + " if pu is None or (dist[pu] == dist[u] + 1 and dfs(pu)):\n", + " pair_u[u] = v\n", + " pair_v[v] = u\n", + " return True\n", + " dist[u] = INF\n", + " return False\n", + "\n", + " matching = 0\n", + " while bfs():\n", + " for u in left_nodes:\n", + " if pair_u[u] is None and dfs(u):\n", + " matching += 1\n", + " return matching, pair_u\n", + "\n", + "def assign_people_to_prefixes(candidates_by_prefix: Dict[str, List[str]], seed: int) -> Dict[str, List[str]]:\n", + " person_to_prefixes = build_person_to_prefixes(candidates_by_prefix)\n", + " left_nodes = sorted(person_to_prefixes)\n", + " if len(left_nodes) != 600:\n", + " raise RuntimeError(f'Expected 600 unique person ids, got {len(left_nodes)}')\n", + " rng = random.Random(seed)\n", + " slot_map = {}\n", + " right_nodes = []\n", + " for prefix in PREFIXES:\n", + " slot_names = [f'{prefix}#{idx:03d}' for idx in range(120)]\n", + " rng.shuffle(slot_names)\n", + " slot_map[prefix] = slot_names\n", + " right_nodes.extend(slot_names)\n", + " graph = {}\n", + " shuffled_left = left_nodes[:]\n", + " rng.shuffle(shuffled_left)\n", + " for pid in shuffled_left:\n", + " neighbors = []\n", + " prefixes = person_to_prefixes[pid][:]\n", + " rng.shuffle(prefixes)\n", + " for prefix in prefixes:\n", + " neighbors.extend(slot_map[prefix])\n", + " graph[pid] = neighbors\n", + " matching, pair_u = hopcroft_karp(graph, shuffled_left, right_nodes)\n", + " if matching != 600:\n", + " raise RuntimeError(f'Failed to assign 600 people to prefix slots, only matched {matching}')\n", + " result = defaultdict(list)\n", + " for pid, slot in pair_u.items():\n", + " prefix = 
slot.split('#', 1)[0]\n", + " result[prefix].append(pid)\n", + " for prefix in PREFIXES:\n", + " result[prefix] = sorted(result[prefix])\n", + " if len(result[prefix]) != 120:\n", + " raise RuntimeError(f'Prefix {prefix} expected 120 people, got {len(result[prefix])}')\n", + " return dict(result)\n", + "\n", + "def split_prefix_people(prefix_people: Dict[str, List[str]], seed: int, first_name: str, second_name: str) -> Dict[str, Dict[str, List[str]]]:\n", + " rng = random.Random(seed)\n", + " result = {}\n", + " for prefix, people in prefix_people.items():\n", + " shuffled = people[:]\n", + " rng.shuffle(shuffled)\n", + " result[prefix] = {\n", + " first_name: sorted(shuffled[:60]),\n", + " second_name: sorted(shuffled[60:120]),\n", + " }\n", + " return result\n", + "\n", + "candidates_by_prefix = scan_candidates()\n", + "for prefix in PREFIXES:\n", + " print(prefix, len(candidates_by_prefix[prefix]))\n", + "set1_prefix_people = assign_people_to_prefixes(candidates_by_prefix, SET1_SEED)\n", + "set2_prefix_people = assign_people_to_prefixes(candidates_by_prefix, SET2_SEED)\n", + "set1_split = split_prefix_people(set1_prefix_people, SET1_SEED + 100, 'neg_to_pos', 'pos_to_neg')\n", + "set2_split = split_prefix_people(set2_prefix_people, SET2_SEED + 100, 'pos_to_neg', 'neg_to_pos')\n", + "print('set1 unique persons =', len(set().union(*[set(v) for v in set1_prefix_people.values()])))\n", + "print('set2 unique persons =', len(set().union(*[set(v) for v in set2_prefix_people.values()])))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Prepared split payload at /scratch3/f007yzf/repos/Step1X-Edit-clean/training_data/metadata/connector_6000_split_v1.json\n" + ] + } + ], + "source": [ + "def save_step1x_npz(target_image_path: Path, embeds: np.ndarray, masks: np.ndarray, force_overwrite: bool = False) -> Path:\n", + " npz_path = 
target_image_path.with_suffix('')\n", + " npz_path = npz_path.parent / f'{npz_path.name}_step1x_te.npz'\n", + " if npz_path.exists() and not force_overwrite:\n", + " existing = np.load(npz_path)\n", + " if 'embeds' in existing and 'masks' in existing:\n", + " return npz_path\n", + " np.savez(npz_path, embeds=embeds, masks=masks)\n", + " return npz_path\n", + "\n", + "def build_record(set_id: str, direction: str, prefix: str, person_id: str, pair_id: str) -> Dict[str, str]:\n", + " paths = resolve_image_paths(prefix, person_id, pair_id)\n", + " if set_id == 'set1' and direction == 'neg_to_pos':\n", + " img1 = paths['flux_neg']\n", + " img2 = paths['iy_pos']\n", + " ref_image_path = paths['flux_neg']\n", + " target_image_path = paths['flux_pos']\n", + " elif set_id == 'set1' and direction == 'pos_to_neg':\n", + " img1 = paths['flux_pos']\n", + " img2 = paths['iy_neg']\n", + " ref_image_path = paths['flux_pos']\n", + " target_image_path = paths['flux_neg']\n", + " elif set_id == 'set2' and direction == 'pos_to_neg':\n", + " img1 = paths['iy_pos']\n", + " img2 = paths['flux_neg']\n", + " ref_image_path = paths['iy_pos']\n", + " target_image_path = paths['iy_neg']\n", + " elif set_id == 'set2' and direction == 'neg_to_pos':\n", + " img1 = paths['iy_neg']\n", + " img2 = paths['flux_pos']\n", + " ref_image_path = paths['iy_neg']\n", + " target_image_path = paths['iy_pos']\n", + " else:\n", + " raise ValueError((set_id, direction))\n", + " return {\n", + " 'set_id': set_id,\n", + " 'direction': direction,\n", + " 'prefix_family': prefix,\n", + " 'person_id': person_id,\n", + " 'pair_id': pair_id,\n", + " 'img1_path': str(img1.resolve()),\n", + " 'img2_path': str(img2.resolve()),\n", + " 'ref_image_path': str(ref_image_path.resolve()),\n", + " 'target_image_path': str(target_image_path.resolve()),\n", + " }\n", + "\n", + "def expand_split_to_records(split: Dict[str, Dict[str, List[str]]], set_id: str) -> List[Dict[str, str]]:\n", + " records = []\n", + " for prefix in 
PREFIXES:\n", + " for direction, people in split[prefix].items():\n", + " for person_id in people:\n", + " for pair_id in PAIRS:\n", + " records.append(build_record(set_id, direction, prefix, person_id, pair_id))\n", + " return records\n", + "\n", + "def build_smoke_records_set1(person_count: int = SMOKE_PERSON_COUNT) -> List[Dict[str, str]]:\n", + " selected_people = []\n", + " for prefix in PREFIXES:\n", + " directions = set1_split[prefix]\n", + " people = sorted(set(directions['neg_to_pos']) | set(directions['pos_to_neg']))\n", + " for person_id in people:\n", + " selected_people.append((prefix, person_id))\n", + " selected_people = selected_people[:person_count]\n", + " records = []\n", + " for prefix, person_id in selected_people:\n", + " for pair_id in PAIRS:\n", + " records.append(build_record('set1', 'neg_to_pos', prefix, person_id, pair_id))\n", + " records.append(build_record('set1', 'pos_to_neg', prefix, person_id, pair_id))\n", + " return records\n", + "\n", + "def write_json(path: Path, payload) -> None:\n", + " path.parent.mkdir(parents=True, exist_ok=True)\n", + " path.write_text(json.dumps(payload, ensure_ascii=False, indent=2))\n", + "\n", + "def write_jsonl(path: Path, rows: Iterable[dict]) -> None:\n", + " path.parent.mkdir(parents=True, exist_ok=True)\n", + " with path.open('w', encoding='utf-8') as f:\n", + " for row in rows:\n", + " f.write(json.dumps(row, ensure_ascii=False) + '\\n')\n", + "\n", + "def prepare_cached_target_image(target_image_path: Path, cache_root: Path) -> Path:\n", + " target_dir = cache_root / 'target_img'\n", + " target_dir.mkdir(parents=True, exist_ok=True)\n", + " cached_target_path = target_dir / target_image_path.name\n", + " if cached_target_path.exists():\n", + " return cached_target_path\n", + " try:\n", + " cached_target_path.symlink_to(target_image_path)\n", + " except OSError:\n", + " shutil.copy2(target_image_path, cached_target_path)\n", + " return cached_target_path\n", + "\n", + "def 
generate_dataset(records: List[Dict[str, str]], force_overwrite_npz: bool = False, cache_root: Path | None = None) -> Tuple[dict, List[dict]]:\n", + " metadata = {}\n", + " audit_rows = []\n", + " for record in tqdm(records):\n", + " img1 = Path(record['img1_path'])\n", + " img2 = Path(record['img2_path'])\n", + " original_target_image_path = Path(record['target_image_path'])\n", + " ref_image_path = Path(record['ref_image_path'])\n", + " target_image_path = prepare_cached_target_image(original_target_image_path, cache_root) if cache_root is not None else original_target_image_path\n", + " embeds, masks, debug = build_dual_image_embedding(img1, img2, return_debug=True)\n", + " npz_path = save_step1x_npz(target_image_path, embeds, masks, force_overwrite=force_overwrite_npz)\n", + " key = str(target_image_path.resolve())\n", + " if key in metadata:\n", + " raise RuntimeError(f'Duplicate target image key: {key}')\n", + " metadata[key] = {\n", + " 'ref_image_path': str(ref_image_path.resolve()),\n", + " 'caption': FIXED_METADATA_CAPTION,\n", + " }\n", + " audit_row = dict(record)\n", + " audit_row['caption'] = FIXED_METADATA_CAPTION\n", + " audit_row['fixed_prompt'] = FIXED_DUAL_IMAGE_PROMPT\n", + " audit_row['original_target_image_path'] = str(original_target_image_path.resolve())\n", + " audit_row['embedding_npz_path'] = str(npz_path.resolve())\n", + " audit_row.update(debug)\n", + " audit_rows.append(audit_row)\n", + " return metadata, audit_rows\n", + "\n", + "set1_records = expand_split_to_records(set1_split, 'set1')\n", + "set2_records = expand_split_to_records(set2_split, 'set2')\n", + "assert len(set1_records) == 3000\n", + "assert len(set2_records) == 3000\n", + "all_records = set1_records + set2_records\n", + "assert len(all_records) == 6000\n", + "\n", + "SPLIT_OUTPUT = OUTPUT_METADATA_DIR / 'connector_6000_split_v3.json'\n", + "SET1_OUTPUT = OUTPUT_METADATA_DIR / 'connector_set1_flux2_3000_v3.json'\n", + "SET2_OUTPUT = OUTPUT_METADATA_DIR / 
'connector_set2_infiniteyou_3000_v3.json'\n", + "COMBINED_OUTPUT = OUTPUT_METADATA_DIR / 'connector_6000_train_v3.json'\n", + "AUDIT_OUTPUT = OUTPUT_METADATA_DIR / 'connector_6000_audit_v3.jsonl'\n", + "SMOKE_RECORDS_OUTPUT = CACHE_ROOT / 'connector_smoke_set1_100_records_v3.json'\n", + "SMOKE_TRAIN_OUTPUT = CACHE_ROOT / 'connector_smoke_set1_100_train_v3.json'\n", + "SMOKE_AUDIT_OUTPUT = CACHE_ROOT / 'connector_smoke_set1_100_audit_v3.jsonl'\n", + "\n", + "split_payload = {\n", + " 'master_seed': MASTER_SEED,\n", + " 'set1_seed': SET1_SEED,\n", + " 'set2_seed': SET2_SEED,\n", + " 'prefixes': PREFIXES,\n", + " 'pairs': PAIRS,\n", + " 'fixed_dual_image_prompt': FIXED_DUAL_IMAGE_PROMPT,\n", + " 'fixed_metadata_caption': FIXED_METADATA_CAPTION,\n", + " 'set1': set1_split,\n", + " 'set2': set2_split,\n", + "}\n", + "write_json(SPLIT_OUTPUT, split_payload)\n", + "print('Prepared split payload at', SPLIT_OUTPUT)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The following generation flags are not valid and may be ignored: ['temperature']. 
Set `TRANSFORMERS_VERBOSITY=info` for more details.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\n", + " \"img1_path\": \"/scratch3/f007yzf/repos/Step1X-Edit-clean/training_data/source_img/data__p0000_pair00_s0__source.png\",\n", + " \"img2_path\": \"/scratch3/f007yzf/repos/Step1X-Edit-clean/training_data/reference_img/left/data__p0000_pair00_s0.png\",\n", + " \"prompt\": \"Editing instruction:\\n\\n- Raise the left corner of the mouth upwards and outward.\\n- Raise the right corner of the mouth upwards and outward.\\n- Lower the left side of the eyebrows slightly.\\n- Lower the right side of the eyebrows slightly.\\n- Tighten the lips into a smile.\\n- Relax the forehead muscles to reduce any frown lines.\",\n", + " \"embedding_shape\": [\n", + " 640,\n", + " 3584\n", + " ],\n", + " \"mask_sum\": 381\n", + "}\n", + "Set RUN_FULL_GENERATION = True to generate all 6000 samples.\n" + ] + } + ], + "source": [ + "RUN_SINGLE_PREVIEW = False\n", + "if RUN_SINGLE_PREVIEW:\n", + " preview = run_single_preview()\n", + " print(json.dumps(preview, ensure_ascii=False, indent=2))\n", + "else:\n", + " print('Set RUN_SINGLE_PREVIEW = True to generate one dual-image embedding preview.')\n", + "\n", + "if RUN_SMOKE_GENERATION:\n", + " smoke_records = build_smoke_records_set1(person_count=SMOKE_PERSON_COUNT)\n", + " assert len(smoke_records) == SMOKE_PERSON_COUNT * len(PAIRS) * 2\n", + " write_json(SMOKE_RECORDS_OUTPUT, smoke_records)\n", + " smoke_metadata, smoke_audit = generate_dataset(smoke_records, force_overwrite_npz=FORCE_OVERWRITE_NPZ, cache_root=CACHE_ROOT)\n", + " write_json(SMOKE_TRAIN_OUTPUT, smoke_metadata)\n", + " write_jsonl(SMOKE_AUDIT_OUTPUT, smoke_audit)\n", + " print('smoke samples =', len(smoke_metadata))\n", + " print('Saved:', SMOKE_RECORDS_OUTPUT)\n", + " print('Saved:', SMOKE_TRAIN_OUTPUT)\n", + " print('Saved:', SMOKE_AUDIT_OUTPUT)\n", + "else:\n", + " print('Set RUN_SMOKE_GENERATION = True to generate the set1 smoke 
subset.')\n", + "\n", + "if RUN_FULL_GENERATION:\n", + " set1_metadata, set1_audit = generate_dataset(set1_records, force_overwrite_npz=FORCE_OVERWRITE_NPZ)\n", + " set2_metadata, set2_audit = generate_dataset(set2_records, force_overwrite_npz=FORCE_OVERWRITE_NPZ)\n", + " combined_metadata = dict(set1_metadata)\n", + " overlap = set(combined_metadata).intersection(set2_metadata)\n", + " if overlap:\n", + " raise RuntimeError(f'Unexpected overlap between set1 and set2 target keys: {len(overlap)}')\n", + " combined_metadata.update(set2_metadata)\n", + " write_json(SET1_OUTPUT, set1_metadata)\n", + " write_json(SET2_OUTPUT, set2_metadata)\n", + " write_json(COMBINED_OUTPUT, combined_metadata)\n", + " write_jsonl(AUDIT_OUTPUT, set1_audit + set2_audit)\n", + " print('set1 samples =', len(set1_metadata))\n", + " print('set2 samples =', len(set2_metadata))\n", + " print('combined samples =', len(combined_metadata))\n", + " print('Saved:', SET1_OUTPUT)\n", + " print('Saved:', SET2_OUTPUT)\n", + " print('Saved:', COMBINED_OUTPUT)\n", + " print('Saved:', AUDIT_OUTPUT)\n", + "else:\n", + " print('Set RUN_FULL_GENERATION = True to generate all 6000 samples.')\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "89ce9ff1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "PosixPath('/scratch3/f007yzf/repos/Step1X-Edit-clean/training_data/metadata/connector_set1_flux2_3000_v1.json')" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "SET1_OUTPUT" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6586cdaf", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "PosixPath('/scratch3/f007yzf/repos/Step1X-Edit-clean/training_data/metadata/connector_6000_audit_v1.jsonl')" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": 
"c3a2f3f3", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "step1x_v11", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.19" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/train_connector_v1.log b/train_connector_v1.log new file mode 100644 index 0000000..ed3c115 --- /dev/null +++ b/train_connector_v1.log @@ -0,0 +1,617 @@ +GPU status: +0, NVIDIA RTX 6000 Ada Generation, 25124, 49140, 100 +1, NVIDIA RTX 6000 Ada Generation, 2, 49140, 0 +2, NVIDIA RTX 6000 Ada Generation, 2, 49140, 0 +3, NVIDIA RTX 6000 Ada Generation, 25036, 49140, 100 +4, NVIDIA RTX 6000 Ada Generation, 4, 49140, 0 +5, NVIDIA RTX 6000 Ada Generation, 25102, 49140, 100 +6, NVIDIA RTX 6000 Ada Generation, 25002, 49140, 100 +7, NVIDIA RTX 6000 Ada Generation, 25054, 49140, 100 +Auto-selected GPU_ID=2 +[W307 14:46:41.584914773 socket.cpp:752] [c10d] The client socket cannot be initialized to connect to [localhost]:29502 (errno: 97 - Address family not supported by protocol). +2026-03-07 14:46:45 INFO set VIDEO_TOTAL_PIXELS: 90316800 vision_process.py:41 +2026-03-07 14:46:46 INFO highvram is enabled / highvramが有効です train_util.py:4246 + WARNING cache_latents_to_disk is enabled, so cache_latents is also enabled / train_util.py:4263 + cache_latents_to_diskが有効なため、cache_latentsを有効にします +Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`. 
+You have video processor config saved in `preprocessor.json` file which is deprecated. Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename the file or load and save the processor back which renames it automatically. Loading from `preprocessor.json` will be removed in v5.0. +2026-03-07 14:46:47 INFO Loading dataset config from library/data_configs/step1x_edit.toml kohya_trainer.py:500 + INFO loading existing metadata: train_util.py:2296 + /scratch3/f007yzf/repos/Step1X-Edit-clean/training_data/metadata.json + WARNING npz file does not exist. ignore npz files / train_util.py:2409 + npzファイルが見つからないためnpzファイルを無視します + INFO [Dataset 0] config_util.py:609 + batch_size: 1 + resolution: (768, 768) + resize_interpolation: None + enable_bucket: False + + [Subset 0 of Dataset 0] + image_dir: "/scratch3/f007yzf/repos/Step1X-Edit-clean/training_data/target_img" + image_count: 301 + num_repeats: 1 + shuffle_caption: False + keep_tokens: 1 + caption_dropout_rate: 0.0 + caption_dropout_every_n_epochs: 0 + caption_tag_dropout_rate: 0.0 + caption_prefix: None + caption_suffix: None + color_aug: False + flip_aug: False + face_crop_aug_range: None + random_crop: False + token_warmup_min: 1, + token_warmup_step: 0, + alpha_mask: False + resize_interpolation: None + custom_attributes: {} + metadata_file: /scratch3/f007yzf/repos/Step1X-Edit-clean/training_data/metadata.json + + + INFO [Prepare dataset 0] config_util.py:621 + INFO loading image sizes. train_util.py:1073 + 0%| | 0/301 [00:00 step1x_utils.py:107 +import network module: library.qwen_connector_module_v1 + INFO [Dataset 0] train_util.py:2737 + INFO caching latents with caching strategy. train_util.py:1201 + INFO caching latents... train_util.py:1250 + 0%| | 0/301 [00:00