diff --git a/.gitignore b/.gitignore index 04c27b4..643d990 100644 --- a/.gitignore +++ b/.gitignore @@ -173,3 +173,6 @@ runs/ exps/ wandb wandb/ + +# generated calib data +datasets/ diff --git a/deepcompressor/app/diffusion/dataset/base.py b/deepcompressor/app/diffusion/dataset/base.py index 535b191..f0c177d 100644 --- a/deepcompressor/app/diffusion/dataset/base.py +++ b/deepcompressor/app/diffusion/dataset/base.py @@ -39,6 +39,9 @@ def __len__(self) -> int: return len(self.filepaths) def __getitem__(self, idx) -> dict[str, tp.Any]: + + # TODO verfify ZImage data loading. + data = np.load(self.filepaths[idx], allow_pickle=True).item() if isinstance(data["input_args"][0], str): name = data["input_args"][0] diff --git a/deepcompressor/app/diffusion/dataset/calib.py b/deepcompressor/app/diffusion/dataset/calib.py index f794d30..4b5b637 100644 --- a/deepcompressor/app/diffusion/dataset/calib.py +++ b/deepcompressor/app/diffusion/dataset/calib.py @@ -15,6 +15,7 @@ FluxSingleTransformerBlock, FluxTransformerBlock, ) +from diffusers.models.transformers.transformer_z_image import ZImageTransformerBlock from omniconfig import configclass from deepcompressor.data.cache import ( @@ -172,6 +173,18 @@ def _init_cache(self, name: str, module: nn.Module) -> IOTensorsCache: ), outputs=TensorCache(channels_dim=-1, reshape=LinearReshapeFn()), ) + elif isinstance(module, ZImageTransformerBlock): + return IOTensorsCache( + inputs=TensorsCache( + OrderedDict( + x=TensorCache(channels_dim=-1, reshape=LinearReshapeFn()), + attn_mask=TensorCache(channels_dim=-1, reshape=LinearReshapeFn()), + freqs_cis=TensorCache(channels_dim=-1, reshape=LinearReshapeFn()), + # TODO verify + ) + ), + outputs=TensorCache(channels_dim=-1, reshape=LinearReshapeFn()), + ) elif isinstance(module, Attention): return IOTensorsCache( inputs=TensorsCache( diff --git a/deepcompressor/app/diffusion/dataset/collect/utils.py b/deepcompressor/app/diffusion/dataset/collect/utils.py index dc4169d..be9b1ce 100644 --- a/deepcompressor/app/diffusion/dataset/collect/utils.py +++ b/deepcompressor/app/diffusion/dataset/collect/utils.py @@ -10,6 +10,7 @@ FluxTransformer2DModel, PixArtTransformer2DModel, SanaTransformer2DModel, + ZImageTransformer2DModel, ) from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel @@ -58,10 +59,19 @@ def __call__( new_args.append(input_kwargs.pop("hidden_states")) elif isinstance(module, FluxTransformer2DModel): new_args.append(input_kwargs.pop("hidden_states")) + elif isinstance(module, ZImageTransformer2DModel): + new_args.append(input_kwargs.pop("x")) + new_args.append(input_kwargs.pop("t")) + new_args.append(input_kwargs.pop("cap_feats")) else: raise ValueError(f"Unknown model: {module}") cache = tree_map(lambda x: x.cpu(), {"input_args": new_args, "input_kwargs": input_kwargs, "outputs": output}) - split_cache = tree_split(cache) + + if isinstance(module, ZImageTransformer2DModel): + # assume that batch size is 1. + split_cache = [cache] + else: + split_cache = tree_split(cache) if isinstance(module, PixArtTransformer2DModel) and self.zero_redundancy: for cache in split_cache: diff --git a/deepcompressor/app/diffusion/nn/patch.py b/deepcompressor/app/diffusion/nn/patch.py index a39ff40..5bf4c1c 100644 --- a/deepcompressor/app/diffusion/nn/patch.py +++ b/deepcompressor/app/diffusion/nn/patch.py @@ -1,8 +1,10 @@ import torch.nn as nn from diffusers.models.attention_processor import Attention from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock +from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel, ZImageTransformerBlock from deepcompressor.nn.patch.conv import ConcatConv2d, ShiftedConv2d +from deepcompressor.nn.patch.ff import convert_z_image_ff from deepcompressor.nn.patch.linear import ConcatLinear, ShiftedLinear from deepcompressor.utils import patch, tools @@ -116,3 +118,11 @@ def replace_attn_processor(model: nn.Module) -> None: logger.info(f"+ Replacing {name} processor with DiffusionAttentionProcessor.") module.set_processor(DiffusionAttentionProcessor(module.processor)) tools.logging.Formatter.indent_dec() + + +def replace_zimage_feedforward(z_image_model: ZImageTransformer2DModel) -> None: + """Replace custom FeedForward module in `ZImageTransformerBlock`s with standard FeedForward in diffusers lib.""" + for _, module in z_image_model.named_modules(): + if isinstance(module, ZImageTransformerBlock): + orig_ff = module.feed_forward + module.feed_forward = convert_z_image_ff(orig_ff) diff --git a/deepcompressor/app/diffusion/nn/struct.py b/deepcompressor/app/diffusion/nn/struct.py index a25c4f4..4d7197a 100644 --- a/deepcompressor/app/diffusion/nn/struct.py +++ b/deepcompressor/app/diffusion/nn/struct.py @@ -35,6 +35,12 @@ FluxTransformer2DModel, FluxTransformerBlock, ) +from diffusers.models.transformers.transformer_z_image import ( + ZImageTransformerBlock, + ZImageTransformer2DModel, + TimestepEmbedder, + RopeEmbedder +) from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel from diffusers.models.unets.unet_2d import UNet2DModel from diffusers.models.unets.unet_2d_blocks import ( @@ -46,6 +52,7 @@ UpBlock2D, ) from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel +from diffusers.pipelines.z_image.pipeline_z_image import ZImagePipeline from diffusers.pipelines import ( FluxControlPipeline, FluxFillPipeline, @@ -71,7 +78,7 @@ from deepcompressor.nn.struct.base import BaseModuleStruct from deepcompressor.utils.common import join_name -from .attention import DiffusionAttentionProcessor +from deepcompressor.app.diffusion.nn.attention import DiffusionAttentionProcessor # endregion @@ -85,6 +92,7 @@ FluxSingleTransformerBlock, FluxTransformerBlock, SanaTransformerBlock, + ZImageTransformerBlock, ] UNET_BLOCK_CLS = tp.Union[ DownBlock2D, @@ -100,6 +108,7 @@ SD3Transformer2DModel, FluxTransformer2DModel, SanaTransformer2DModel, + ZImageTransformer2DModel, ] UNET_CLS = tp.Union[UNet2DModel, UNet2DConditionModel] MODEL_CLS = tp.Union[DIT_CLS, UNET_CLS] @@ -112,6 +121,7 @@ FluxControlPipeline, FluxFillPipeline, SanaPipeline, + ZImagePipeline, ] PIPELINE_CLS = tp.Union[UNET_PIPELINE_CLS, DIT_PIPELINE_CLS] @@ -268,15 +278,18 @@ def _default_construct( @classmethod def _get_default_key_map(cls) -> dict[str, set[str]]: - unet_key_map = UNetStruct._get_default_key_map() - dit_key_map = DiTStruct._get_default_key_map() - flux_key_map = FluxStruct._get_default_key_map() key_map: dict[str, set[str]] = defaultdict(set) - for rkey, keys in unet_key_map.items(): - key_map[rkey].update(keys) - for rkey, keys in dit_key_map.items(): - key_map[rkey].update(keys) - for rkey, keys in flux_key_map.items(): + # unet_key_map = UNetStruct._get_default_key_map() + # dit_key_map = DiTStruct._get_default_key_map() + # flux_key_map = FluxStruct._get_default_key_map() + # for rkey, keys in unet_key_map.items(): + # key_map[rkey].update(keys) + # for rkey, keys in dit_key_map.items(): + # key_map[rkey].update(keys) + # for rkey, keys in flux_key_map.items(): + # key_map[rkey].update(keys) + zimage_key_map = ZImageStruct._get_default_key_map() + for rkey, keys in zimage_key_map.items(): key_map[rkey].update(keys) return {k: v for k, v in key_map.items() if v} @@ -362,7 +375,7 @@ def _default_construct( o_proj_rname = "to_out.0" assert isinstance(o_proj, nn.Linear) elif parent is not None: - assert isinstance(parent.module, FluxSingleTransformerBlock) + assert isinstance(parent.module, (FluxSingleTransformerBlock, ZImageTransformerBlock)) assert isinstance(parent.module.proj_out, ConcatLinear) assert len(parent.module.proj_out.linears) == 2 o_proj = parent.module.proj_out.linears[0] @@ -580,11 +593,11 @@ def __post_init__(self) -> None: super().__post_init__() self.norm_key = join_name(self.key, self.norm_rkey, sep="_") self.add_norm_key = join_name(self.key, self.add_norm_rkey, sep="_") - self.attn_norm_structs = [ + self.pre_attn_norm_structs = [ DiffusionModuleStruct(norm, parent=self, fname="pre_attn_norm", rname=rname, rkey=self.norm_rkey, idx=idx) for idx, (norm, rname) in enumerate(zip(self.pre_attn_norms, self.pre_attn_norm_rnames, strict=True)) ] - self.add_attn_norm_structs = [ + self.pre_attn_add_norm_structs = [ DiffusionModuleStruct( norm, parent=self, fname="pre_attn_add_norm", rname=rname, rkey=self.add_norm_rkey, idx=idx ) @@ -606,10 +619,10 @@ def __post_init__(self) -> None: ) def named_key_modules(self) -> tp.Generator[tp.Tuple[str, str, nn.Module, BaseModuleStruct, str], None, None]: - for attn_norm in self.attn_norm_structs: + for attn_norm in self.pre_attn_norm_structs: if attn_norm.module is not None: yield from attn_norm.named_key_modules() - for add_attn_norm in self.add_attn_norm_structs: + for add_attn_norm in self.pre_attn_add_norm_structs: if add_attn_norm.module is not None: yield from add_attn_norm.named_key_modules() for attn_struct in self.attn_structs: @@ -705,6 +718,15 @@ def _default_construct( ffn, ffn_rname = module.ff, "ff" pre_add_ffn_norm, pre_add_ffn_norm_rname = module.norm2_context, "norm2_context" add_ffn, add_ffn_rname = module.ff_context, "ff_context" + elif isinstance(module, ZImageTransformerBlock): + parallel = False + norm_type, add_norm_type = "rms_norm", None + pre_attn_norms, pre_attn_norm_rnames = [module.attention_norm1], ["attention_norm1"] + attns, attn_rnames = [module.attention], ["attention"] + pre_attn_add_norms, pre_attn_add_norm_rnames = [], [] + pre_ffn_norm, pre_ffn_norm_rname = module.ffn_norm1, "ffn_norm1" + ffn, ffn_rname = module.feed_forward, "feed_forward" + pre_add_ffn_norm, pre_add_ffn_norm_rname, add_ffn, add_ffn_rname = None, "", None, "" else: raise NotImplementedError(f"Unsupported module type: {type(module)}") return DiffusionTransformerBlockStruct( @@ -1696,6 +1718,8 @@ def _default_construct( module = module.transformer if isinstance(module, FluxTransformer2DModel): return FluxStruct.construct(module, parent=parent, fname=fname, rname=rname, rkey=rkey, idx=idx, **kwargs) + elif isinstance(module, ZImageTransformer2DModel): + return ZImageStruct.construct(module, parent=parent, fname=fname, rname=rname, rkey=rkey, idx=idx, **kwargs) else: if isinstance(module, PixArtTransformer2DModel): input_embed, input_embed_rname = module.pos_embed, "pos_embed" @@ -1946,6 +1970,260 @@ def _get_default_key_map(cls) -> dict[str, set[str]]: return {k: v for k, v in key_map.items() if v} +@dataclass(kw_only=True) +class ZImageStruct(DiffusionModelStruct, DiffusionTransformerStruct): + # region relative keys + # TODO + noise_refiner_rkey: tp.ClassVar[str] = "nrk" + context_refiner_rkey: tp.ClassVar[str] = "crk" + layers_rkey: tp.ClassVar[str] = "lk" + # endregion + + module: ZImageTransformer2DModel = field(repr=False, kw_only=False) + """the module of ZImageTransformer2DModel""" + + # region child modules + all_x_embedder: nn.ModuleDict + all_final_layer: nn.ModuleDict + noise_refiner: nn.ModuleList + context_refiner: nn.ModuleList + t_embedder: TimestepEmbedder + cap_embedder: nn.Sequential + x_pad_token: nn.Parameter + cap_pad_token: nn.Parameter + layers: nn.ModuleList + rope_embedder: RopeEmbedder + # endregion + + # region relative names + all_x_embedder_rname: str + all_final_layer_rname: str + noise_refiner_rname: str + context_refiner_rname: str + t_embedder_rname: str + cap_embedder_rname: str + x_pad_token_rname: str + cap_pad_token_rname: str + layers_rname: str + rope_embedder_rname: str + # endregion + + # region absolute names + noise_refiner_names: list[str] = field(init=False, repr=False) + context_refiner_names: list[str] = field(init=False, repr=False) + layers_names: list[str] = field(init=False, repr=False) + # endregion + + # region absolute keys + # TODO + # endregion + + # region child structs + noise_refiner_structs: list[DiffusionTransformerBlockStruct] = field(init=False) + context_refiner_structs: list[DiffusionTransformerBlockStruct] = field(init=False) + layers_structs: list[DiffusionTransformerBlockStruct] = field(init=False) + # endregion + + + @property + def num_blocks(self) -> int: + return len(self.noise_refiner_structs) + len(self.context_refiner_structs) + len(self.layers_structs) + + @property + def block_structs(self) -> list[DiffusionTransformerBlockStruct]: + return [*self.noise_refiner_structs, *self.context_refiner_structs, *self.layers_structs] + + @property + def block_names(self) -> list[str]: + return [*self.noise_refiner_names, *self.context_refiner_names, *self.layers_names] + + def __post_init__(self) -> None: + BaseModuleStruct.__post_init__(self) + noise_refiner_indexed_rnames = [ + f"{self.noise_refiner_rname}.{idx}" for idx in range(len(self.noise_refiner)) + ] + self.noise_refiner_names = [join_name(self.name, indexed_rname) for indexed_rname in noise_refiner_indexed_rnames] + + context_refiner_indexed_rnames = [ + f"{self.context_refiner_rname}.{idx}" for idx in range(len(self.context_refiner)) + ] + self.context_refiner_names = [join_name(self.name, indexed_rname) for indexed_rname in context_refiner_indexed_rnames] + + layers_indexed_rnames = [ + f"{self.layers_rname}.{idx}" for idx in range(len(self.layers)) + ] + self.layers_names = [join_name(self.name, indexed_rname) for indexed_rname in layers_indexed_rnames] + + self.pre_module_structs = OrderedDict() + self.post_module_structs = OrderedDict() + + self.noise_refiner_structs = [ + self.transformer_block_struct_cls.construct( + block, + parent=self, + fname="noise_refiner", + rname=rname, + rkey=self.noise_refiner_rkey, # TODO + idx=idx, + ) + for idx, (block, rname) in enumerate( + zip(self.noise_refiner, noise_refiner_indexed_rnames, strict=True) + ) + ] + + self.context_refiner_structs = [ + self.transformer_block_struct_cls.construct( + block, + parent=self, + fname="context_refiner", + rname=rname, + rkey=self.context_refiner_rkey, # TODO + idx=idx, + ) + for idx, (block, rname) in enumerate( + zip(self.context_refiner, context_refiner_indexed_rnames, strict=True) + ) + ] + + self.layers_structs = [ + self.transformer_block_struct_cls.construct( + block, + parent=self, + fname="layers", + rname=rname, + rkey=self.layers_rkey, # TODO + idx=idx, + ) + for idx, (block, rname) in enumerate( + zip(self.layers, layers_indexed_rnames, strict=True) + ) + ] + + + def _get_iter_block_activations_args( + self, **input_kwargs + ) -> tuple[list[nn.Module], list[DiffusionModuleStruct | DiffusionBlockStruct], list[bool], list[bool]]: + layers, layer_structs, recomputes, use_prev_layer_outputs = [], [], [], [] + + layers.extend(self.noise_refiner) + layer_structs.extend(self.noise_refiner_structs) + use_prev_layer_outputs.append(False) + use_prev_layer_outputs.extend([True] * (len(self.noise_refiner) - 1)) + recomputes.extend([False] * len(self.noise_refiner)) + + layers.extend(self.context_refiner) + layer_structs.extend(self.context_refiner_structs) + use_prev_layer_outputs.append(False) + use_prev_layer_outputs.extend([True] * (len(self.context_refiner) - 1)) + recomputes.extend([False] * len(self.context_refiner)) + + layers.extend(self.layers) + layer_structs.extend(self.layers_structs) + use_prev_layer_outputs.append(False) + use_prev_layer_outputs.extend([True] * (len(self.layers) - 1)) + recomputes.extend([False] * len(self.layers)) + + return layers, layer_structs, recomputes, use_prev_layer_outputs + + + def get_prev_module_keys(self) -> tuple[str, ...]: + return tuple() + + + def get_post_module_keys(self) -> tuple[str, ...]: + return tuple() + + + @staticmethod + def _default_construct( + module: tp.Union[ZImagePipeline, ZImageTransformer2DModel], + /, + parent: tp.Optional[BaseModuleStruct] = None, + fname: str = "", + rname: str = "", + rkey: str = "", + idx: int = 0, + **kwargs, + ) -> "ZImageStruct": + if isinstance(module, ZImagePipeline): + module = module.transformer + if isinstance(module, ZImageTransformer2DModel): + all_x_embedder, all_x_embedder_rname = module.all_x_embedder, "all_x_embedder" + all_final_layer, all_final_layer_rname = module.all_final_layer, "all_final_layer" + noise_refiner, noise_refiner_rname = module.noise_refiner, "noise_refiner" + context_refiner, context_refiner_rname = module.context_refiner, "context_refiner" + t_embedder, t_embedder_rname = module.t_embedder, "t_embedder" + cap_embedder, cap_embedder_rname = module.cap_embedder, "cap_embedder" + x_pad_token, x_pad_token_rname = module.x_pad_token, "x_pad_token" + cap_pad_token, cap_pad_token_rname = module.cap_pad_token, "cap_pad_token" + layers, layers_rname = module.layers, "layers" + rope_embedder, rope_embedder_rname = module.rope_embedder, "rope_embedder" + return ZImageStruct( + module=module, + parent=parent, + fname=fname, + idx=idx, + rname=rname, + rkey=rkey, + + all_x_embedder=all_x_embedder, + all_final_layer=all_final_layer, + noise_refiner=noise_refiner, + context_refiner=context_refiner, + t_embedder=t_embedder, + cap_embedder=cap_embedder, + x_pad_token=x_pad_token, + cap_pad_token=cap_pad_token, + layers=layers, + rope_embedder=rope_embedder, + + all_x_embedder_rname=all_x_embedder_rname, + all_final_layer_rname=all_final_layer_rname, + noise_refiner_rname=noise_refiner_rname, + context_refiner_rname=context_refiner_rname, + t_embedder_rname=t_embedder_rname, + cap_embedder_rname=cap_embedder_rname, + x_pad_token_rname=x_pad_token_rname, + cap_pad_token_rname=cap_pad_token_rname, + layers_rname=layers_rname, + rope_embedder_rname=rope_embedder_rname, + + # these fields are not valid in Z-Image model, just hard code to None + norm_in=None, + proj_in=None, + norm_out=None, + proj_out=None, + norm_in_rname=None, + proj_in_rname=None, + norm_out_rname=None, + proj_out_rname=None, + transformer_blocks=None, + transformer_blocks_rname=None + ) + raise NotImplementedError(f"Unsupported module type: {type(module)}") + + @classmethod + def _get_default_key_map(cls) -> dict[str, set[str]]: + key_map: dict[str, set[str]] = defaultdict(set) + for block_rkey, block_cls in ( + (cls.noise_refiner_rkey, cls.transformer_block_struct_cls), + (cls.context_refiner_rkey, cls.transformer_block_struct_cls), + (cls.layers_rkey, cls.transformer_block_struct_cls), + ): + block_key = block_rkey + block_key_map = block_cls._get_default_key_map() + for rkey, keys in block_key_map.items(): + brkey = join_name(block_rkey, rkey, sep="_") + for key in keys: + key = join_name(block_key, key, sep="_") + key_map[rkey].add(key) + key_map[brkey].add(key) + if block_rkey: + key_map[block_rkey].add(key) + return {k: v for k, v in key_map.items() if v} + + + DiffusionAttentionStruct.register_factory(Attention, DiffusionAttentionStruct._default_construct) DiffusionFeedForwardStruct.register_factory( @@ -1962,8 +2240,16 @@ def _get_default_key_map(cls) -> dict[str, set[str]]: tp.Union[FluxPipeline, FluxControlPipeline, FluxTransformer2DModel], FluxStruct._default_construct ) +ZImageStruct.register_factory( + tp.Union[ZImagePipeline, ZImageTransformer2DModel], ZImageStruct._default_construct +) + DiTStruct.register_factory(tp.Union[DIT_PIPELINE_CLS, DIT_CLS], DiTStruct._default_construct) DiffusionTransformerStruct.register_factory(Transformer2DModel, DiffusionTransformerStruct._default_construct) DiffusionModelStruct.register_factory(tp.Union[PIPELINE_CLS, MODEL_CLS], DiffusionModelStruct._default_construct) + + +if __name__ == "__main__": + print(ZImageStruct._get_default_key_map()) diff --git a/deepcompressor/app/diffusion/pipeline/config.py b/deepcompressor/app/diffusion/pipeline/config.py index bc7cbe2..5d2488c 100644 --- a/deepcompressor/app/diffusion/pipeline/config.py +++ b/deepcompressor/app/diffusion/pipeline/config.py @@ -6,6 +6,8 @@ from dataclasses import dataclass, field import torch +from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel +from diffusers.pipelines.z_image.pipeline_z_image import ZImagePipeline from diffusers.pipelines import ( AutoPipelineForText2Image, DiffusionPipeline, @@ -22,12 +24,14 @@ from deepcompressor.utils import tools from deepcompressor.utils.hooks import AccumBranchHook, ProcessHook +from .customized.zimage import check_z_image_customized_path, build_customized_z_image_pipeline from ....nn.patch.linear import ConcatLinear, ShiftedLinear from ....nn.patch.lowrank import LowRankBranch from ..nn.patch import ( - replace_fused_linear_with_concat_linear, - replace_up_block_conv_with_concat_conv, - shift_input_activations, + # replace_fused_linear_with_concat_linear, + # replace_up_block_conv_with_concat_conv, + replace_zimage_feedforward, + # shift_input_activations, ) __all__ = ["DiffusionPipelineConfig"] @@ -344,6 +348,10 @@ def _default_build( path = "black-forest-labs/FLUX.1-Fill-dev" elif name == "flux.1-schnell": path = "black-forest-labs/FLUX.1-schnell" + elif name == "z-image-turbo": + path = "Tongyi-MAI/Z-Image-Turbo" + elif name == "z-image-customized": + assert check_z_image_customized_path(path), f"Invalid Path for z-image-customized {path}." else: raise ValueError(f"Path for {name} is not specified.") if name in ["flux.1-canny-dev", "flux.1-depth-dev"]: @@ -357,14 +365,20 @@ def _default_build( pipeline.text_encoder.to(dtype) else: pipeline = SanaPipeline.from_pretrained(path, torch_dtype=dtype) + elif name == "z-image-turbo": + pipeline = ZImagePipeline.from_pretrained(path, torch_dtype=dtype, low_cpu_mem_usage=False) + elif name == "z-image-customized": + pipeline = build_customized_z_image_pipeline(name, path, dtype, device) else: pipeline = AutoPipelineForText2Image.from_pretrained(path, torch_dtype=dtype) pipeline = pipeline.to(device) model = pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer - replace_fused_linear_with_concat_linear(model) - replace_up_block_conv_with_concat_conv(model) - if shift_activations: - shift_input_activations(model) + # replace_fused_linear_with_concat_linear(model) + # replace_up_block_conv_with_concat_conv(model) + if isinstance(model, ZImageTransformer2DModel): + replace_zimage_feedforward(model) + # if shift_activations: + # shift_input_activations(model) return pipeline @staticmethod diff --git a/deepcompressor/app/diffusion/pipeline/customized/__init__.py b/deepcompressor/app/diffusion/pipeline/customized/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/deepcompressor/app/diffusion/pipeline/customized/state_dict.txt b/deepcompressor/app/diffusion/pipeline/customized/state_dict.txt new file mode 100644 index 0000000..c1dc50c --- /dev/null +++ b/deepcompressor/app/diffusion/pipeline/customized/state_dict.txt @@ -0,0 +1,454 @@ +model.diffusion_model.cap_embedder.0.weight: tensor shape: torch.Size([2560]), dtype: torch.bfloat16 +model.diffusion_model.cap_embedder.1.bias: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.cap_embedder.1.weight: tensor shape: torch.Size([3840, 2560]), dtype: torch.bfloat16 +model.diffusion_model.cap_pad_token: tensor shape: torch.Size([1, 3840]), dtype: torch.bfloat16 +model.diffusion_model.context_refiner.0.attention.k_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.context_refiner.0.attention.out.weight: tensor shape: torch.Size([3840, 3840]), dtype: torch.bfloat16 +model.diffusion_model.context_refiner.0.attention.q_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.context_refiner.0.attention.qkv.weight: tensor shape: torch.Size([11520, 3840]), dtype: torch.bfloat16 +model.diffusion_model.context_refiner.0.attention_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.context_refiner.0.attention_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.context_refiner.0.feed_forward.w1.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.context_refiner.0.feed_forward.w2.weight: tensor shape: torch.Size([3840, 10240]), dtype: torch.bfloat16 +model.diffusion_model.context_refiner.0.feed_forward.w3.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.context_refiner.0.ffn_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.context_refiner.0.ffn_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.context_refiner.1.attention.k_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.context_refiner.1.attention.out.weight: tensor shape: torch.Size([3840, 3840]), dtype: torch.bfloat16 +model.diffusion_model.context_refiner.1.attention.q_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.context_refiner.1.attention.qkv.weight: tensor shape: torch.Size([11520, 3840]), dtype: torch.bfloat16 +model.diffusion_model.context_refiner.1.attention_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.context_refiner.1.attention_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.context_refiner.1.feed_forward.w1.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.context_refiner.1.feed_forward.w2.weight: tensor shape: torch.Size([3840, 10240]), dtype: torch.bfloat16 +model.diffusion_model.context_refiner.1.feed_forward.w3.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.context_refiner.1.ffn_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.context_refiner.1.ffn_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.final_layer.adaLN_modulation.1.bias: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.final_layer.adaLN_modulation.1.weight: tensor shape: torch.Size([3840, 256]), dtype: torch.bfloat16 +model.diffusion_model.final_layer.linear.bias: tensor shape: torch.Size([64]), dtype: torch.bfloat16 +model.diffusion_model.final_layer.linear.weight: tensor shape: torch.Size([64, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.0.adaLN_modulation.0.bias: tensor shape: torch.Size([15360]), dtype: torch.bfloat16 +model.diffusion_model.layers.0.adaLN_modulation.0.weight: tensor shape: torch.Size([15360, 256]), dtype: torch.bfloat16 +model.diffusion_model.layers.0.attention.k_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.0.attention.out.weight: tensor shape: torch.Size([3840, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.0.attention.q_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.0.attention.qkv.weight: tensor shape: torch.Size([11520, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.0.attention_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.0.attention_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.0.feed_forward.w1.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.0.feed_forward.w2.weight: tensor shape: torch.Size([3840, 10240]), dtype: torch.bfloat16 +model.diffusion_model.layers.0.feed_forward.w3.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.0.ffn_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.0.ffn_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.1.adaLN_modulation.0.bias: tensor shape: torch.Size([15360]), dtype: torch.bfloat16 +model.diffusion_model.layers.1.adaLN_modulation.0.weight: tensor shape: torch.Size([15360, 256]), dtype: torch.bfloat16 +model.diffusion_model.layers.1.attention.k_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.1.attention.out.weight: tensor shape: torch.Size([3840, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.1.attention.q_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.1.attention.qkv.weight: tensor shape: torch.Size([11520, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.1.attention_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.1.attention_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.1.feed_forward.w1.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.1.feed_forward.w2.weight: tensor shape: torch.Size([3840, 10240]), dtype: torch.bfloat16 +model.diffusion_model.layers.1.feed_forward.w3.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.1.ffn_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.1.ffn_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.10.adaLN_modulation.0.bias: tensor shape: torch.Size([15360]), dtype: torch.bfloat16 +model.diffusion_model.layers.10.adaLN_modulation.0.weight: tensor shape: torch.Size([15360, 256]), dtype: torch.bfloat16 +model.diffusion_model.layers.10.attention.k_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.10.attention.out.weight: tensor shape: torch.Size([3840, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.10.attention.q_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.10.attention.qkv.weight: tensor shape: torch.Size([11520, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.10.attention_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.10.attention_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.10.feed_forward.w1.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.10.feed_forward.w2.weight: tensor shape: torch.Size([3840, 10240]), dtype: torch.bfloat16 +model.diffusion_model.layers.10.feed_forward.w3.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.10.ffn_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.10.ffn_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.11.adaLN_modulation.0.bias: tensor shape: torch.Size([15360]), dtype: torch.bfloat16 +model.diffusion_model.layers.11.adaLN_modulation.0.weight: tensor shape: torch.Size([15360, 256]), dtype: torch.bfloat16 +model.diffusion_model.layers.11.attention.k_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.11.attention.out.weight: tensor shape: torch.Size([3840, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.11.attention.q_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.11.attention.qkv.weight: tensor shape: torch.Size([11520, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.11.attention_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.11.attention_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.11.feed_forward.w1.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.11.feed_forward.w2.weight: tensor shape: torch.Size([3840, 10240]), dtype: torch.bfloat16 +model.diffusion_model.layers.11.feed_forward.w3.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.11.ffn_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.11.ffn_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.12.adaLN_modulation.0.bias: tensor shape: torch.Size([15360]), dtype: torch.bfloat16 +model.diffusion_model.layers.12.adaLN_modulation.0.weight: tensor shape: torch.Size([15360, 256]), dtype: torch.bfloat16 +model.diffusion_model.layers.12.attention.k_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.12.attention.out.weight: tensor shape: torch.Size([3840, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.12.attention.q_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.12.attention.qkv.weight: tensor shape: torch.Size([11520, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.12.attention_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.12.attention_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.12.feed_forward.w1.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.12.feed_forward.w2.weight: tensor shape: torch.Size([3840, 10240]), dtype: torch.bfloat16 +model.diffusion_model.layers.12.feed_forward.w3.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.12.ffn_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.12.ffn_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.13.adaLN_modulation.0.bias: tensor shape: torch.Size([15360]), dtype: torch.bfloat16 +model.diffusion_model.layers.13.adaLN_modulation.0.weight: tensor shape: torch.Size([15360, 256]), dtype: torch.bfloat16 +model.diffusion_model.layers.13.attention.k_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.13.attention.out.weight: tensor shape: torch.Size([3840, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.13.attention.q_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.13.attention.qkv.weight: tensor shape: torch.Size([11520, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.13.attention_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.13.attention_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.13.feed_forward.w1.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.13.feed_forward.w2.weight: tensor shape: torch.Size([3840, 10240]), dtype: torch.bfloat16 +model.diffusion_model.layers.13.feed_forward.w3.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.13.ffn_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.13.ffn_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.14.adaLN_modulation.0.bias: tensor shape: torch.Size([15360]), dtype: torch.bfloat16 +model.diffusion_model.layers.14.adaLN_modulation.0.weight: tensor shape: torch.Size([15360, 256]), dtype: torch.bfloat16 +model.diffusion_model.layers.14.attention.k_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.14.attention.out.weight: tensor shape: torch.Size([3840, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.14.attention.q_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.14.attention.qkv.weight: tensor shape: torch.Size([11520, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.14.attention_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.14.attention_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.14.feed_forward.w1.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.14.feed_forward.w2.weight: tensor shape: torch.Size([3840, 10240]), dtype: torch.bfloat16 +model.diffusion_model.layers.14.feed_forward.w3.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.14.ffn_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.14.ffn_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.15.adaLN_modulation.0.bias: tensor shape: torch.Size([15360]), dtype: torch.bfloat16 +model.diffusion_model.layers.15.adaLN_modulation.0.weight: tensor shape: torch.Size([15360, 256]), dtype: torch.bfloat16 +model.diffusion_model.layers.15.attention.k_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.15.attention.out.weight: tensor shape: torch.Size([3840, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.15.attention.q_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.15.attention.qkv.weight: tensor shape: torch.Size([11520, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.15.attention_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.15.attention_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.15.feed_forward.w1.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.15.feed_forward.w2.weight: tensor shape: torch.Size([3840, 10240]), dtype: torch.bfloat16 +model.diffusion_model.layers.15.feed_forward.w3.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.15.ffn_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.15.ffn_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.16.adaLN_modulation.0.bias: tensor shape: torch.Size([15360]), dtype: torch.bfloat16 +model.diffusion_model.layers.16.adaLN_modulation.0.weight: tensor shape: torch.Size([15360, 256]), dtype: torch.bfloat16 +model.diffusion_model.layers.16.attention.k_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.16.attention.out.weight: tensor shape: torch.Size([3840, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.16.attention.q_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.16.attention.qkv.weight: tensor shape: torch.Size([11520, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.16.attention_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.16.attention_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.16.feed_forward.w1.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.16.feed_forward.w2.weight: tensor shape: torch.Size([3840, 10240]), dtype: torch.bfloat16 +model.diffusion_model.layers.16.feed_forward.w3.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.16.ffn_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.16.ffn_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.17.adaLN_modulation.0.bias: tensor shape: torch.Size([15360]), dtype: torch.bfloat16 +model.diffusion_model.layers.17.adaLN_modulation.0.weight: tensor shape: torch.Size([15360, 256]), dtype: torch.bfloat16 +model.diffusion_model.layers.17.attention.k_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.17.attention.out.weight: tensor shape: torch.Size([3840, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.17.attention.q_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.17.attention.qkv.weight: tensor shape: torch.Size([11520, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.17.attention_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.17.attention_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.17.feed_forward.w1.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.17.feed_forward.w2.weight: tensor shape: torch.Size([3840, 10240]), dtype: torch.bfloat16 +model.diffusion_model.layers.17.feed_forward.w3.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.17.ffn_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.17.ffn_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.18.adaLN_modulation.0.bias: tensor shape: torch.Size([15360]), dtype: torch.bfloat16 +model.diffusion_model.layers.18.adaLN_modulation.0.weight: tensor shape: torch.Size([15360, 256]), dtype: torch.bfloat16 +model.diffusion_model.layers.18.attention.k_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.18.attention.out.weight: tensor shape: torch.Size([3840, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.18.attention.q_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.18.attention.qkv.weight: tensor shape: torch.Size([11520, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.18.attention_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.18.attention_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.18.feed_forward.w1.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.18.feed_forward.w2.weight: tensor shape: torch.Size([3840, 10240]), dtype: torch.bfloat16 +model.diffusion_model.layers.18.feed_forward.w3.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.18.ffn_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.18.ffn_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.19.adaLN_modulation.0.bias: tensor shape: torch.Size([15360]), dtype: torch.bfloat16 +model.diffusion_model.layers.19.adaLN_modulation.0.weight: tensor shape: torch.Size([15360, 256]), dtype: torch.bfloat16 +model.diffusion_model.layers.19.attention.k_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.19.attention.out.weight: tensor shape: torch.Size([3840, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.19.attention.q_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.19.attention.qkv.weight: tensor shape: torch.Size([11520, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.19.attention_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.19.attention_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.19.feed_forward.w1.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.19.feed_forward.w2.weight: tensor shape: torch.Size([3840, 10240]), dtype: torch.bfloat16 +model.diffusion_model.layers.19.feed_forward.w3.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.19.ffn_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.19.ffn_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.2.adaLN_modulation.0.bias: tensor shape: torch.Size([15360]), dtype: torch.bfloat16 +model.diffusion_model.layers.2.adaLN_modulation.0.weight: tensor shape: torch.Size([15360, 256]), dtype: torch.bfloat16 +model.diffusion_model.layers.2.attention.k_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.2.attention.out.weight: tensor shape: torch.Size([3840, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.2.attention.q_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.2.attention.qkv.weight: tensor shape: torch.Size([11520, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.2.attention_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.2.attention_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.2.feed_forward.w1.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.2.feed_forward.w2.weight: tensor shape: torch.Size([3840, 10240]), dtype: torch.bfloat16 +model.diffusion_model.layers.2.feed_forward.w3.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.2.ffn_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.2.ffn_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.20.adaLN_modulation.0.bias: tensor shape: torch.Size([15360]), dtype: torch.bfloat16 +model.diffusion_model.layers.20.adaLN_modulation.0.weight: tensor shape: torch.Size([15360, 256]), dtype: torch.bfloat16 +model.diffusion_model.layers.20.attention.k_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.20.attention.out.weight: tensor shape: torch.Size([3840, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.20.attention.q_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.20.attention.qkv.weight: tensor shape: torch.Size([11520, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.20.attention_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.20.attention_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.20.feed_forward.w1.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.20.feed_forward.w2.weight: tensor shape: torch.Size([3840, 10240]), dtype: torch.bfloat16 +model.diffusion_model.layers.20.feed_forward.w3.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.20.ffn_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.20.ffn_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.21.adaLN_modulation.0.bias: tensor shape: torch.Size([15360]), dtype: torch.bfloat16 +model.diffusion_model.layers.21.adaLN_modulation.0.weight: tensor shape: torch.Size([15360, 256]), dtype: torch.bfloat16 +model.diffusion_model.layers.21.attention.k_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.21.attention.out.weight: tensor shape: torch.Size([3840, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.21.attention.q_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.21.attention.qkv.weight: tensor shape: torch.Size([11520, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.21.attention_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.21.attention_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.21.feed_forward.w1.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.21.feed_forward.w2.weight: tensor shape: torch.Size([3840, 10240]), dtype: torch.bfloat16 +model.diffusion_model.layers.21.feed_forward.w3.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.21.ffn_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.21.ffn_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.22.adaLN_modulation.0.bias: tensor shape: torch.Size([15360]), dtype: torch.bfloat16 +model.diffusion_model.layers.22.adaLN_modulation.0.weight: tensor shape: torch.Size([15360, 256]), dtype: torch.bfloat16 +model.diffusion_model.layers.22.attention.k_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.22.attention.out.weight: tensor shape: torch.Size([3840, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.22.attention.q_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.22.attention.qkv.weight: tensor shape: torch.Size([11520, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.22.attention_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.22.attention_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.22.feed_forward.w1.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.22.feed_forward.w2.weight: tensor shape: torch.Size([3840, 10240]), dtype: torch.bfloat16 +model.diffusion_model.layers.22.feed_forward.w3.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.22.ffn_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.22.ffn_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.23.adaLN_modulation.0.bias: tensor shape: torch.Size([15360]), dtype: torch.bfloat16 +model.diffusion_model.layers.23.adaLN_modulation.0.weight: tensor shape: torch.Size([15360, 256]), dtype: torch.bfloat16 +model.diffusion_model.layers.23.attention.k_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.23.attention.out.weight: tensor shape: torch.Size([3840, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.23.attention.q_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.23.attention.qkv.weight: tensor shape: torch.Size([11520, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.23.attention_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.23.attention_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.23.feed_forward.w1.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.23.feed_forward.w2.weight: tensor shape: torch.Size([3840, 10240]), dtype: torch.bfloat16 +model.diffusion_model.layers.23.feed_forward.w3.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.23.ffn_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.23.ffn_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.24.adaLN_modulation.0.bias: tensor shape: torch.Size([15360]), dtype: torch.bfloat16 +model.diffusion_model.layers.24.adaLN_modulation.0.weight: tensor shape: torch.Size([15360, 256]), dtype: torch.bfloat16 +model.diffusion_model.layers.24.attention.k_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.24.attention.out.weight: tensor shape: torch.Size([3840, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.24.attention.q_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.24.attention.qkv.weight: tensor shape: torch.Size([11520, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.24.attention_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.24.attention_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.24.feed_forward.w1.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.24.feed_forward.w2.weight: tensor shape: torch.Size([3840, 10240]), dtype: torch.bfloat16 +model.diffusion_model.layers.24.feed_forward.w3.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.24.ffn_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.24.ffn_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.25.adaLN_modulation.0.bias: tensor shape: torch.Size([15360]), dtype: torch.bfloat16 +model.diffusion_model.layers.25.adaLN_modulation.0.weight: tensor shape: torch.Size([15360, 256]), dtype: torch.bfloat16 +model.diffusion_model.layers.25.attention.k_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.25.attention.out.weight: tensor shape: torch.Size([3840, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.25.attention.q_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.25.attention.qkv.weight: tensor shape: torch.Size([11520, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.25.attention_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.25.attention_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.25.feed_forward.w1.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.25.feed_forward.w2.weight: tensor shape: torch.Size([3840, 10240]), dtype: torch.bfloat16 +model.diffusion_model.layers.25.feed_forward.w3.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.25.ffn_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.25.ffn_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.26.adaLN_modulation.0.bias: tensor shape: torch.Size([15360]), dtype: torch.bfloat16 +model.diffusion_model.layers.26.adaLN_modulation.0.weight: tensor shape: torch.Size([15360, 256]), dtype: torch.bfloat16 +model.diffusion_model.layers.26.attention.k_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.26.attention.out.weight: tensor shape: torch.Size([3840, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.26.attention.q_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.26.attention.qkv.weight: tensor shape: torch.Size([11520, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.26.attention_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.26.attention_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.26.feed_forward.w1.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.26.feed_forward.w2.weight: tensor shape: torch.Size([3840, 10240]), dtype: torch.bfloat16 +model.diffusion_model.layers.26.feed_forward.w3.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.26.ffn_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.26.ffn_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.27.adaLN_modulation.0.bias: tensor shape: torch.Size([15360]), dtype: torch.bfloat16 +model.diffusion_model.layers.27.adaLN_modulation.0.weight: tensor shape: torch.Size([15360, 256]), dtype: torch.bfloat16 +model.diffusion_model.layers.27.attention.k_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.27.attention.out.weight: tensor shape: torch.Size([3840, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.27.attention.q_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.27.attention.qkv.weight: tensor shape: torch.Size([11520, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.27.attention_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.27.attention_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.27.feed_forward.w1.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.27.feed_forward.w2.weight: tensor shape: torch.Size([3840, 10240]), dtype: torch.bfloat16 +model.diffusion_model.layers.27.feed_forward.w3.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.27.ffn_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.27.ffn_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.28.adaLN_modulation.0.bias: tensor shape: torch.Size([15360]), dtype: torch.bfloat16 +model.diffusion_model.layers.28.adaLN_modulation.0.weight: tensor shape: torch.Size([15360, 256]), dtype: torch.bfloat16 +model.diffusion_model.layers.28.attention.k_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.28.attention.out.weight: tensor shape: torch.Size([3840, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.28.attention.q_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.28.attention.qkv.weight: tensor shape: torch.Size([11520, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.28.attention_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.28.attention_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.28.feed_forward.w1.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.28.feed_forward.w2.weight: tensor shape: torch.Size([3840, 10240]), dtype: torch.bfloat16 +model.diffusion_model.layers.28.feed_forward.w3.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.28.ffn_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.28.ffn_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.29.adaLN_modulation.0.bias: tensor shape: torch.Size([15360]), dtype: torch.bfloat16 +model.diffusion_model.layers.29.adaLN_modulation.0.weight: tensor shape: torch.Size([15360, 256]), dtype: torch.bfloat16 +model.diffusion_model.layers.29.attention.k_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.29.attention.out.weight: tensor shape: torch.Size([3840, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.29.attention.q_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.29.attention.qkv.weight: tensor shape: torch.Size([11520, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.29.attention_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.29.attention_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.29.feed_forward.w1.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.29.feed_forward.w2.weight: tensor shape: torch.Size([3840, 10240]), dtype: torch.bfloat16 +model.diffusion_model.layers.29.feed_forward.w3.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.29.ffn_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.29.ffn_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.3.adaLN_modulation.0.bias: tensor shape: torch.Size([15360]), dtype: torch.bfloat16 +model.diffusion_model.layers.3.adaLN_modulation.0.weight: tensor shape: torch.Size([15360, 256]), dtype: torch.bfloat16 +model.diffusion_model.layers.3.attention.k_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.3.attention.out.weight: tensor shape: torch.Size([3840, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.3.attention.q_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.3.attention.qkv.weight: tensor shape: torch.Size([11520, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.3.attention_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.3.attention_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.3.feed_forward.w1.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.3.feed_forward.w2.weight: tensor shape: torch.Size([3840, 10240]), dtype: torch.bfloat16 +model.diffusion_model.layers.3.feed_forward.w3.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.3.ffn_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.3.ffn_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.4.adaLN_modulation.0.bias: tensor shape: torch.Size([15360]), dtype: torch.bfloat16 +model.diffusion_model.layers.4.adaLN_modulation.0.weight: tensor shape: torch.Size([15360, 256]), dtype: torch.bfloat16 +model.diffusion_model.layers.4.attention.k_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.4.attention.out.weight: tensor shape: torch.Size([3840, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.4.attention.q_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.4.attention.qkv.weight: tensor shape: torch.Size([11520, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.4.attention_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.4.attention_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.4.feed_forward.w1.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.4.feed_forward.w2.weight: tensor shape: torch.Size([3840, 10240]), dtype: torch.bfloat16 +model.diffusion_model.layers.4.feed_forward.w3.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.4.ffn_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.4.ffn_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.5.adaLN_modulation.0.bias: tensor shape: torch.Size([15360]), dtype: torch.bfloat16 +model.diffusion_model.layers.5.adaLN_modulation.0.weight: tensor shape: torch.Size([15360, 256]), dtype: torch.bfloat16 +model.diffusion_model.layers.5.attention.k_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.5.attention.out.weight: tensor shape: torch.Size([3840, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.5.attention.q_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.5.attention.qkv.weight: tensor shape: torch.Size([11520, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.5.attention_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.5.attention_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.5.feed_forward.w1.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.5.feed_forward.w2.weight: tensor shape: torch.Size([3840, 10240]), dtype: torch.bfloat16 +model.diffusion_model.layers.5.feed_forward.w3.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.5.ffn_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.5.ffn_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.6.adaLN_modulation.0.bias: tensor shape: torch.Size([15360]), dtype: torch.bfloat16 +model.diffusion_model.layers.6.adaLN_modulation.0.weight: tensor shape: torch.Size([15360, 256]), dtype: torch.bfloat16 +model.diffusion_model.layers.6.attention.k_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.6.attention.out.weight: tensor shape: torch.Size([3840, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.6.attention.q_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.6.attention.qkv.weight: tensor shape: torch.Size([11520, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.6.attention_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.6.attention_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.6.feed_forward.w1.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.6.feed_forward.w2.weight: tensor shape: torch.Size([3840, 10240]), dtype: torch.bfloat16 +model.diffusion_model.layers.6.feed_forward.w3.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.6.ffn_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.6.ffn_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.7.adaLN_modulation.0.bias: tensor shape: torch.Size([15360]), dtype: torch.bfloat16 +model.diffusion_model.layers.7.adaLN_modulation.0.weight: tensor shape: torch.Size([15360, 256]), dtype: torch.bfloat16 +model.diffusion_model.layers.7.attention.k_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.7.attention.out.weight: tensor shape: torch.Size([3840, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.7.attention.q_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.7.attention.qkv.weight: tensor shape: torch.Size([11520, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.7.attention_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.7.attention_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.7.feed_forward.w1.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.7.feed_forward.w2.weight: tensor shape: torch.Size([3840, 10240]), dtype: torch.bfloat16 +model.diffusion_model.layers.7.feed_forward.w3.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.7.ffn_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.7.ffn_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.8.adaLN_modulation.0.bias: tensor shape: torch.Size([15360]), dtype: torch.bfloat16 +model.diffusion_model.layers.8.adaLN_modulation.0.weight: tensor shape: torch.Size([15360, 256]), dtype: torch.bfloat16 +model.diffusion_model.layers.8.attention.k_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.8.attention.out.weight: tensor shape: torch.Size([3840, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.8.attention.q_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.8.attention.qkv.weight: tensor shape: torch.Size([11520, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.8.attention_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.8.attention_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.8.feed_forward.w1.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.8.feed_forward.w2.weight: tensor shape: torch.Size([3840, 10240]), dtype: torch.bfloat16 +model.diffusion_model.layers.8.feed_forward.w3.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.8.ffn_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.8.ffn_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.9.adaLN_modulation.0.bias: tensor shape: torch.Size([15360]), dtype: torch.bfloat16 +model.diffusion_model.layers.9.adaLN_modulation.0.weight: tensor shape: torch.Size([15360, 256]), dtype: torch.bfloat16 +model.diffusion_model.layers.9.attention.k_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.9.attention.out.weight: tensor shape: torch.Size([3840, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.9.attention.q_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.layers.9.attention.qkv.weight: tensor shape: torch.Size([11520, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.9.attention_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.9.attention_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.9.feed_forward.w1.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.9.feed_forward.w2.weight: tensor shape: torch.Size([3840, 10240]), dtype: torch.bfloat16 +model.diffusion_model.layers.9.feed_forward.w3.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.9.ffn_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.layers.9.ffn_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.noise_refiner.0.adaLN_modulation.0.bias: tensor shape: torch.Size([15360]), dtype: torch.bfloat16 +model.diffusion_model.noise_refiner.0.adaLN_modulation.0.weight: tensor shape: torch.Size([15360, 256]), dtype: torch.bfloat16 +model.diffusion_model.noise_refiner.0.attention.k_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.noise_refiner.0.attention.out.weight: tensor shape: torch.Size([3840, 3840]), dtype: torch.bfloat16 +model.diffusion_model.noise_refiner.0.attention.q_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.noise_refiner.0.attention.qkv.weight: tensor shape: torch.Size([11520, 3840]), dtype: torch.bfloat16 +model.diffusion_model.noise_refiner.0.attention_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.noise_refiner.0.attention_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.noise_refiner.0.feed_forward.w1.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.noise_refiner.0.feed_forward.w2.weight: tensor shape: torch.Size([3840, 10240]), dtype: torch.bfloat16 +model.diffusion_model.noise_refiner.0.feed_forward.w3.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.noise_refiner.0.ffn_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.noise_refiner.0.ffn_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.noise_refiner.1.adaLN_modulation.0.bias: tensor shape: torch.Size([15360]), dtype: torch.bfloat16 +model.diffusion_model.noise_refiner.1.adaLN_modulation.0.weight: tensor shape: torch.Size([15360, 256]), dtype: torch.bfloat16 +model.diffusion_model.noise_refiner.1.attention.k_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.noise_refiner.1.attention.out.weight: tensor shape: torch.Size([3840, 3840]), dtype: torch.bfloat16 +model.diffusion_model.noise_refiner.1.attention.q_norm.weight: tensor shape: torch.Size([128]), dtype: torch.bfloat16 +model.diffusion_model.noise_refiner.1.attention.qkv.weight: tensor shape: torch.Size([11520, 3840]), dtype: torch.bfloat16 +model.diffusion_model.noise_refiner.1.attention_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.noise_refiner.1.attention_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.noise_refiner.1.feed_forward.w1.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.noise_refiner.1.feed_forward.w2.weight: tensor shape: torch.Size([3840, 10240]), dtype: torch.bfloat16 +model.diffusion_model.noise_refiner.1.feed_forward.w3.weight: tensor shape: torch.Size([10240, 3840]), dtype: torch.bfloat16 +model.diffusion_model.noise_refiner.1.ffn_norm1.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.noise_refiner.1.ffn_norm2.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.norm_final.weight: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.t_embedder.mlp.0.bias: tensor shape: torch.Size([1024]), dtype: torch.bfloat16 +model.diffusion_model.t_embedder.mlp.0.weight: tensor shape: torch.Size([1024, 256]), dtype: torch.bfloat16 +model.diffusion_model.t_embedder.mlp.2.bias: tensor shape: torch.Size([256]), dtype: torch.bfloat16 +model.diffusion_model.t_embedder.mlp.2.weight: tensor shape: torch.Size([256, 1024]), dtype: torch.bfloat16 +model.diffusion_model.x_embedder.bias: tensor shape: torch.Size([3840]), dtype: torch.bfloat16 +model.diffusion_model.x_embedder.weight: tensor shape: torch.Size([3840, 64]), dtype: torch.bfloat16 +model.diffusion_model.x_pad_token: tensor shape: torch.Size([1, 3840]), dtype: torch.bfloat16 \ No newline at end of file diff --git a/deepcompressor/app/diffusion/pipeline/customized/zimage.py b/deepcompressor/app/diffusion/pipeline/customized/zimage.py new file mode 100644 index 0000000..8c365a6 --- /dev/null +++ b/deepcompressor/app/diffusion/pipeline/customized/zimage.py @@ -0,0 +1,156 @@ +import json +import torch +from pathlib import Path + +from diffusers.pipelines import DiffusionPipeline +from diffusers.pipelines.z_image.pipeline_z_image import ZImagePipeline +from diffusers.models.transformers import ZImageTransformer2DModel +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL + +from safetensors.torch import load_file + + +def patch_transformer_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + replaced = {} + for key, value in state_dict.items(): + patched_key = key.replace("model.diffusion_model.", "") + if "attention.qkv" in patched_key: + to_q, to_k, to_v = torch.chunk(value, 3, dim=0) + replaced[patched_key.replace('qkv', 'to_q')] = to_q + replaced[patched_key.replace('qkv', 'to_k')] = to_k + replaced[patched_key.replace('qkv', 'to_v')] = to_v + elif "attention.out" in patched_key: + replaced[patched_key.replace("out", "to_out.0")] = value + elif "attention.q_norm" in patched_key: + replaced[patched_key.replace("q_norm", "norm_q")] = value + elif "attention.k_norm" in patched_key: + replaced[patched_key.replace("k_norm", "norm_k")] = value + elif "final_layer" in patched_key: + replaced[patched_key.replace("final_layer", "all_final_layer.2-1")] = value + elif "x_embedder" in patched_key: + replaced[patched_key.replace("x_embedder", "all_x_embedder.2-1")] = value + elif "norm_final" in patched_key: + # `norm_final` is not used in Z-Image Turbo + continue + else: + replaced[patched_key] = value + return replaced + + +def load_customized_z_image_transformer(path: str, dtype: str | torch.dtype): + transformer_config = json.load(open(f"{path}/config.json", "r"))["transformer_config"] + with torch.device("meta"): + transformer = ZImageTransformer2DModel.from_config(transformer_config).to(dtype) + state_dict = load_file(f"{path}/transformer.safetensors", device="cpu") + state_dict = patch_transformer_state_dict(state_dict) + transformer.load_state_dict(state_dict, assign=True) + return transformer + + +def patch_vae_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + replaced = {} + for key, value in state_dict.items(): + if "decoder.norm_out" in key or "encoder.norm_out" in key: + replaced[key.replace(".norm_out.", ".conv_norm_out.")] = value + elif "decoder.mid" in key or "encoder.mid" in key: + patched_key = key.replace(".mid.", ".mid_block.") + if "attn_1" in patched_key: + patched_key = patched_key.replace("attn_1", "attentions.0") + if "attentions.0.k" in patched_key: + patched_key = patched_key.replace("attentions.0.k", "attentions.0.to_k") + if "to_k.weight" in patched_key: + replaced[patched_key] = value.squeeze() + else: + replaced[patched_key] = value + elif "attentions.0.norm" in patched_key: + replaced[patched_key.replace("attentions.0.norm", "attentions.0.group_norm")] = value + elif "attentions.0.proj_out" in patched_key: + replaced[patched_key.replace("attentions.0.proj_out", "attentions.0.to_out.0")] = value.squeeze() + elif "attentions.0.q" in patched_key: + patched_key = patched_key.replace("attentions.0.q", "attentions.0.to_q") + if "to_q.weight" in patched_key: + replaced[patched_key] = value.squeeze() + else: + replaced[patched_key] = value + elif "attentions.0.v" in patched_key: + patched_key = patched_key.replace("attentions.0.v", "attentions.0.to_v") + if "to_v.weight" in patched_key: + replaced[patched_key] = value.squeeze() + else: + replaced[patched_key] = value + else: + raise ValueError(f"Unexpected key in VAE state dict: {key}") + elif "block_1" in patched_key: + replaced[patched_key.replace("block_1", "resnets.0")] = value + elif "block_2" in patched_key: + replaced[patched_key.replace("block_2", "resnets.1")] = value + else: + raise ValueError(f"Unexpected key in VAE state dict: {key}") + elif "decoder.up" in key: + if "decoder.up.0" in key: + patched_key = key.replace("decoder.up.0", "decoder.up_blocks.3") + elif "decoder.up.1" in key: + patched_key = key.replace("decoder.up.1", "decoder.up_blocks.2") + elif "decoder.up.2" in key: + patched_key = key.replace("decoder.up.2", "decoder.up_blocks.1") + elif "decoder.up.3" in key: + patched_key = key.replace("decoder.up.3", "decoder.up_blocks.0") + else: + raise ValueError(f"Unexpected key in VAE state dict: {key}") + if ".block." in patched_key: + patched_key = patched_key.replace(".block.", ".resnets.") + if "nin_shortcut" in patched_key: + patched_key = patched_key.replace("nin_shortcut", "conv_shortcut") + elif ".upsample." in patched_key: + patched_key = patched_key.replace(".upsample.", ".upsamplers.0.") + else: + raise ValueError(f"Unexpected key in VAE state dict: {key}") + replaced[patched_key] = value + elif "encoder.down" in key: + patched_key = key.replace("encoder.down.", "encoder.down_blocks.") + if ".block." in patched_key: + patched_key = patched_key.replace(".block.", ".resnets.") + if "nin_shortcut" in patched_key: + patched_key = patched_key.replace("nin_shortcut", "conv_shortcut") + elif ".downsample." in patched_key: + patched_key = patched_key.replace(".downsample.", ".downsamplers.0.") + else: + raise ValueError(f"Unexpected key in VAE state dict: {key}") + replaced[patched_key] = value + else: + replaced[key] = value + return replaced + +def load_customized_z_image_vae(path: str, dtype: str | torch.dtype): + vae_config = json.load(open(f"{path}/config.json", "r"))["vae_config"] + with torch.device("meta"): + vae = AutoencoderKL.from_config(vae_config).to(dtype) + state_dict = load_file(f"{path}/ae.safetensors", device="cpu") + state_dict = patch_vae_state_dict(state_dict) + vae.load_state_dict(state_dict, assign=True) + return vae + + +def build_customized_z_image_pipeline(name: str, path: str, dtype: str | torch.dtype, device: str | torch.device) -> DiffusionPipeline: + assert name == "z-image-customized", f"Unsupported pipeline name: {name}" + DIFFUSERS_REPO_ID = "Tongyi-MAI/Z-Image-Turbo" + transformer = load_customized_z_image_transformer(path, dtype) + vae = load_customized_z_image_vae(path, dtype) + pipe = ZImagePipeline.from_pretrained( + DIFFUSERS_REPO_ID, + transformer=transformer, + vae=vae, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=False + ).to(device) + return pipe + + +def check_z_image_customized_path(path: str) -> bool: + dir_path = Path(path) + if not dir_path.is_dir(): + return False + transformer_file = dir_path / "transformer.safetensors" + ae_file = dir_path / "ae.safetensors" + config_file = dir_path / "config.json" + return transformer_file.is_file() and ae_file.is_file() and config_file.is_file() diff --git a/deepcompressor/backend/nunchaku/common.py b/deepcompressor/backend/nunchaku/common.py index f0886ef..57a6b51 100644 --- a/deepcompressor/backend/nunchaku/common.py +++ b/deepcompressor/backend/nunchaku/common.py @@ -194,6 +194,7 @@ def convert_to_nunchaku_transformer_block_state_dict( print( f" - Converting {block_name} weights of {candidate_local_names} to {converted_local_name}." f" (smooth_fused={smooth_fused}, shifted={shift is not None}, float_point={float_point})" + f" smooth={type(smooth)}, branch={type(branch)}, bias={type(bias)}, shift={type(shift)}" ) update_state_dict( converted, diff --git a/deepcompressor/backend/nunchaku/convert.py b/deepcompressor/backend/nunchaku/convert.py index d37d0ba..a50bdd8 100644 --- a/deepcompressor/backend/nunchaku/convert.py +++ b/deepcompressor/backend/nunchaku/convert.py @@ -7,6 +7,7 @@ import torch import tqdm +from .z_image import convert_to_nunchaku_z_image_state_dicts from .flux import convert_to_nunchaku_flux_state_dicts @@ -16,6 +17,7 @@ parser.add_argument("--output-root", type=str, default="", help="root to the output checkpoint directory.") parser.add_argument("--model-name", type=str, default=None, help="name of the model.") parser.add_argument("--float-point", action="store_true", help="use float-point 4-bit quantization.") + parser.add_argument("--dry-run", type=bool, default=False, help="if True, state dicts will be NOT be saved") args = parser.parse_args() if not args.output_root: args.output_root = args.quant_path @@ -26,25 +28,46 @@ else: model_name = args.model_name assert model_name, "Model name must be provided." - assert "flux" in model_name.lower(), "Only Flux models are supported." + assert "flux" in model_name.lower() or "z-image" in model_name.lower(), f"{model_name} model is NOT supported so far." state_dict_path = os.path.join(args.quant_path, "model.pt") scale_dict_path = os.path.join(args.quant_path, "scale.pt") smooth_dict_path = os.path.join(args.quant_path, "smooth.pt") branch_dict_path = os.path.join(args.quant_path, "branch.pt") - map_location = "cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu" + map_location = "cuda:1" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu" state_dict = torch.load(state_dict_path, map_location=map_location) scale_dict = torch.load(scale_dict_path, map_location="cpu") smooth_dict = torch.load(smooth_dict_path, map_location=map_location) if os.path.exists(smooth_dict_path) else {} branch_dict = torch.load(branch_dict_path, map_location=map_location) if os.path.exists(branch_dict_path) else {} - converted_state_dict, other_state_dict = convert_to_nunchaku_flux_state_dicts( - state_dict=state_dict, - scale_dict=scale_dict, - smooth_dict=smooth_dict, - branch_dict=branch_dict, - float_point=args.float_point, - ) - output_dirpath = os.path.join(args.output_root, model_name) - os.makedirs(output_dirpath, exist_ok=True) - safetensors.torch.save_file(converted_state_dict, os.path.join(output_dirpath, "transformer_blocks.safetensors")) - safetensors.torch.save_file(other_state_dict, os.path.join(output_dirpath, "unquantized_layers.safetensors")) - print(f"Quantized model saved to {output_dirpath}.") + if "flux" in model_name.lower(): + converted_state_dict, other_state_dict = convert_to_nunchaku_flux_state_dicts( + state_dict=state_dict, + scale_dict=scale_dict, + smooth_dict=smooth_dict, + branch_dict=branch_dict, + float_point=args.float_point, + ) + elif "z-image" in model_name.lower(): + if any("refiner" in k for k in smooth_dict): + skip_refiners = False + else: + skip_refiners = True + print(f" - skip_refiners = {skip_refiners}, float_point = {type(args.float_point)} {args.float_point}") + converted_state_dict, other_state_dict = convert_to_nunchaku_z_image_state_dicts( + model_dict=state_dict, + scale_dict=scale_dict, + smooth_dict=smooth_dict, + branch_dict=branch_dict, + float_point=args.float_point, + skip_refiners=skip_refiners, + ) + else: + raise ValueError(f"{model_name} model is NOT supported so far.") + + if args.dry_run: + print(f"Program in DRY RUN mode. Quantized checkpoints NOT saved.") + else: + output_dirpath = os.path.join(args.output_root, model_name) + os.makedirs(output_dirpath, exist_ok=True) + safetensors.torch.save_file(converted_state_dict, os.path.join(output_dirpath, "transformer_blocks.safetensors")) + safetensors.torch.save_file(other_state_dict, os.path.join(output_dirpath, "unquantized_layers.safetensors")) + print(f"Quantized model saved to {output_dirpath}.") diff --git a/deepcompressor/backend/nunchaku/z_image.py b/deepcompressor/backend/nunchaku/z_image.py new file mode 100644 index 0000000..2815700 --- /dev/null +++ b/deepcompressor/backend/nunchaku/z_image.py @@ -0,0 +1,171 @@ +from collections import OrderedDict +import torch + +from .common import convert_to_nunchaku_transformer_block_state_dict, update_state_dict + + +def _print_dict(d: dict, dict_name: str): + print(f"################# print {dict_name} ################") + for k, v in d.items(): + if isinstance(v, torch.Tensor): + print( + f"{dict_name}_key: {k} -> value tensor shape: {v.shape}, dtype: {v.dtype}") + elif isinstance(v, OrderedDict): + for sub_k, sub_v in v.items(): + if isinstance(sub_v, torch.Tensor): + print( + f"{dict_name}_key: {k}/{sub_k} -> tensor shape: {sub_v.shape}, dtype: {sub_v.dtype}") + else: + print( + f"{dict_name}_key: {k}/{sub_k} -> value type: {type(sub_v)}, {sub_v}") + else: + print(f"{dict_name}_key: {k} -> value type: {type(v)}, {v}") + print("\n") + + +def _replace_lora_and_smooth_key(transformer_block_state_dict: dict): + replaced = {} + for k, v in transformer_block_state_dict.items(): + if ".lora_down" in k: + new_k = k.replace(".lora_down", ".proj_down") + elif ".lora_up" in k: + new_k = k.replace(".lora_up", ".proj_up") + elif ".smooth_orig" in k: + new_k = k.replace(".smooth_orig", ".smooth_factor_orig") + elif ".smooth" in k: + new_k = k.replace(".smooth", ".smooth_factor") + else: + new_k = k + replaced[new_k] = v + return replaced + + + +def z_image_transformer_block_convert( + model_dict: dict[str, torch.Tensor], + scale_dict: dict[str, torch.Tensor], + smooth_dict: dict[str, torch.Tensor], + branch_dict: dict[str, torch.Tensor], + block_name: str, + float_point: bool = False, +) -> dict[str, torch.Tensor]: + converted_quantized_part_names = [ + "attention.to_qkv", # attention q,k,v + "attention.to_out.0", # attention to_out + "feed_forward.net.0.proj", # feed forward up proj + "feed_forward.net.2", # feed forward down proj + ] + + def _original_name(converted_name: str): + if "to_qkv" in converted_name: + return [ + converted_name.replace("to_qkv", "to_q"), + converted_name.replace("to_qkv", "to_k"), + converted_name.replace("to_qkv", "to_v") + ] + else: + return converted_name + + def _smooth_name(converted_name: str): + if "to_qkv" in converted_name: + return converted_name.replace("to_qkv", "to_q") + else: + return converted_name + + _branch_name = _smooth_name + + converted_transformer_block_state_dict = convert_to_nunchaku_transformer_block_state_dict( + state_dict=model_dict, + scale_dict=scale_dict, + smooth_dict=smooth_dict, + branch_dict=branch_dict, + block_name=block_name, + local_name_map={ + name: _original_name(name) for name in converted_quantized_part_names + }, + smooth_name_map={ + name: _smooth_name(name) for name in converted_quantized_part_names + }, + branch_name_map={ + name: _branch_name(name) for name in converted_quantized_part_names + }, + convert_map={ + name: "linear" for name in converted_quantized_part_names + }, + float_point=float_point, + ) + + not_quantized_parts = [ + # all norm layers are not quantized. + "attention.norm_q.weight", + "attention.norm_k.weight", + "attention_norm1.weight", + "attention_norm2.weight", + "ffn_norm1.weight", + "ffn_norm2.weight", + "adaLN_modulation.0.weight", + "adaLN_modulation.0.bias", + ] + + for part_name in not_quantized_parts: + absolute_name = f"{block_name}.{part_name}" + if absolute_name in model_dict: + print(f" - Copying {block_name} weights: {part_name}") + converted_transformer_block_state_dict[part_name] = model_dict[absolute_name].clone().cpu() + + converted_transformer_block_state_dict = _replace_lora_and_smooth_key(converted_transformer_block_state_dict) + return converted_transformer_block_state_dict + + +def convert_to_nunchaku_z_image_state_dicts( + model_dict: dict[str, torch.Tensor], + scale_dict: dict[str, torch.Tensor], + smooth_dict: dict[str, torch.Tensor], + branch_dict: dict[str, torch.Tensor], + float_point: bool = False, + skip_refiners: bool = False, +) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: + + _print_dict(model_dict, "model_dict") + _print_dict(scale_dict, "scale_dict") + _print_dict(smooth_dict, "smooth_dict") + _print_dict(branch_dict, "branch_dict") + + transformer_block_names: set[str] = set() + others: dict[str, torch.Tensor] = {} + + if skip_refiners: + transfomer_block_name_prefix = ("layers.",) + else: + transfomer_block_name_prefix = ( + "noise_refiner.", "context_refiner.", "layers.") + for param_name in model_dict.keys(): + if param_name.startswith(transfomer_block_name_prefix): + block_name = ".".join(param_name.split(".")[:2]) + transformer_block_names.add(block_name) + else: + others[param_name] = model_dict[param_name] + + transformer_block_names = sorted(transformer_block_names, key=lambda x: ( + x.split(".")[0], int(x.split(".")[-1]))) + print(f"Converting {len(transformer_block_names)} transformer blocks...") + converted_state_dict: dict[str, torch.Tensor] = {} + for b_name in transformer_block_names: + converted_tranzformer_block = z_image_transformer_block_convert( + model_dict=model_dict, + scale_dict=scale_dict, + smooth_dict=smooth_dict, + branch_dict=branch_dict, + block_name=b_name, + float_point=float_point, + ) + update_state_dict( + converted_state_dict, + converted_tranzformer_block, + prefix=b_name, + ) + + _print_dict(converted_state_dict, "converted_state_dict") + _print_dict(others, "others") + + return converted_state_dict, others diff --git a/deepcompressor/csrc/load.py b/deepcompressor/csrc/load.py index 6aae3f1..2a1ed42 100644 --- a/deepcompressor/csrc/load.py +++ b/deepcompressor/csrc/load.py @@ -16,6 +16,7 @@ extra_cuda_cflags=[ "-O3", "-std=c++20", + "-I/usr/local/cuda/include", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF2_OPERATORS__", diff --git a/deepcompressor/dataset/cache.py b/deepcompressor/dataset/cache.py index 7d63d48..07bdefb 100644 --- a/deepcompressor/dataset/cache.py +++ b/deepcompressor/dataset/cache.py @@ -311,7 +311,7 @@ def _iter_layer_activations( # noqa: C901 if early_stop_module is not None: forward_hooks.append(early_stop_module.register_forward_hook(EarlyStopHook())) with torch.inference_mode(): - device = "cuda" if torch.cuda.is_available() else "cpu" + device = "cuda:1" if torch.cuda.is_available() else "cpu" tbar = tqdm( desc="collecting acts info", leave=False, @@ -380,6 +380,7 @@ def _iter_layer_activations( # noqa: C901 if psutil.virtual_memory().percent > 90: raise RuntimeError("memory usage > 90%%, aborting") gc.collect() + tbar.close() else: # region we then forward the layer to collect activations device = next(layer.parameters()).device diff --git a/deepcompressor/nn/patch/ff.py b/deepcompressor/nn/patch/ff.py new file mode 100644 index 0000000..bde5699 --- /dev/null +++ b/deepcompressor/nn/patch/ff.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +from torch.nn.parameter import Parameter +from diffusers.models.attention import FeedForward +from diffusers.models.transformers.transformer_z_image import FeedForward as ZImageFeedForward + + +def convert_z_image_ff(zff: ZImageFeedForward) -> FeedForward: + assert isinstance(zff, ZImageFeedForward) + assert zff.w1.in_features == zff.w3.in_features + assert zff.w1.out_features == zff.w3.out_features + assert zff.w1.out_features == zff.w2.in_features + converted_ff = FeedForward( + dim=zff.w1.in_features, + dim_out=zff.w2.out_features, + dropout=0.0, + activation_fn="swiglu", + inner_dim=zff.w2.in_features, + bias=False, + ).to(dtype=zff.w1.weight.dtype, device=zff.w1.weight.device) + + up_proj: nn.Linear = converted_ff.net[0].proj + down_proj: nn.Linear = converted_ff.net[2] + with torch.no_grad(): + up_proj.weight.copy_(torch.cat([zff.w3.weight, zff.w1.weight], dim=0)) + down_proj.weight.copy_(zff.w2.weight) + + return converted_ff + + +if __name__ == "__main__": + _dim = 50 + _hidden_dim = 100 + z_image_ff = ZImageFeedForward(dim=_dim, hidden_dim=_hidden_dim) + with torch.no_grad(): + z_image_ff.w1.weight = Parameter(torch.randn(_hidden_dim, _dim, dtype=torch.float32)) + z_image_ff.w2.weight = Parameter(torch.randn(_dim, _hidden_dim, dtype=torch.float32)) + z_image_ff.w3.weight = Parameter(torch.randn(_hidden_dim, _dim, dtype=torch.float32)) + + converted_ff = convert_z_image_ff(z_image_ff) + + x = torch.randn(_dim, dtype=torch.float32) + + y1 = z_image_ff(x) + + y2 = converted_ff(x) + + print(f"y1: {y1}") + print(f"y2: {y2}") + + print(f"allclose: {torch.allclose(y1, y2)}") + diff --git a/deepcompressor/utils/common.py b/deepcompressor/utils/common.py index 8040a9a..12233d9 100644 --- a/deepcompressor/utils/common.py +++ b/deepcompressor/utils/common.py @@ -192,10 +192,10 @@ def tree_collate(batch: list[tp.Any] | tuple[tp.Any, ...]) -> tp.Any: return [tree_collate(samples) for samples in zip(*batch, strict=True)] elif isinstance(batch[0], torch.Tensor): # if all tensors in batch are exactly the same, return the tensor itself - if all(torch.equal(batch[0], b) for b in batch): - return batch[0] - else: - return torch.cat(batch) + # if all(torch.equal(batch[0], b) for b in batch): + # return batch[0] + # else: + return torch.cat(batch) else: return batch[0] diff --git a/examples/diffusion/configs/__default__.yaml b/examples/diffusion/configs/__default__.yaml index c0ad56c..cab9ef1 100644 --- a/examples/diffusion/configs/__default__.yaml +++ b/examples/diffusion/configs/__default__.yaml @@ -20,7 +20,7 @@ eval: gen_root: "{output}/{job}" ref_root: baselines/{dtype}/{model}/{protocol} gt_stats_root: benchmarks/stats - num_gpus: 8 + num_gpus: 1 batch_size_per_gpu: 1 chunk_start: 0 chunk_step: 1 diff --git a/examples/diffusion/configs/collect/qdiff.yaml b/examples/diffusion/configs/collect/qdiff.yaml index 3142610..b7cb714 100644 --- a/examples/diffusion/configs/collect/qdiff.yaml +++ b/examples/diffusion/configs/collect/qdiff.yaml @@ -1,5 +1,5 @@ collect: root: datasets dataset_name: qdiff - data_path: prompts/qdiff.yaml + data_path: examples/diffusion/prompts/qdiff.yaml num_samples: 128 \ No newline at end of file diff --git a/examples/diffusion/configs/model/z-image-customized-rank256.yaml b/examples/diffusion/configs/model/z-image-customized-rank256.yaml new file mode 100644 index 0000000..e25db71 --- /dev/null +++ b/examples/diffusion/configs/model/z-image-customized-rank256.yaml @@ -0,0 +1,66 @@ +pipeline: + name: z-image-customized + path: /data/dongd/dc_downloaded_model + dtype: torch.bfloat16 + device: cuda:1 +collect: + root: /data/dongd/dc_calib/datasets/{dtype}/{model}/{protocol}/{data}/s128 +eval: + num_steps: 9 + guidance_scale: 0 + protocol: fmeuler{num_steps}-g{guidance_scale} + benchmarks: + - "deepcompressor/dataset/MJHQ30KMETA.json" + num_samples: 300 +quant: + calib: + batch_size: 1 + path: datasets/{dtype}/{model}/{protocol}/{data}/s128 + num_samples: 128 + num_workers: 1 + wgts: + calib_range: + element_batch_size: 64 + sample_batch_size: 16 + element_size: 512 + sample_size: -1 + low_rank: + rank: 256 + sample_batch_size: 16 + sample_size: -1 + skips: + - transformer_norm + - transformer_add_norm + - attn_add + - ffn_add + ipts: + calib_range: + element_batch_size: 64 + sample_batch_size: 16 + element_size: 512 + sample_size: -1 + skips: + - transformer_norm + - transformer_add_norm + - attn_add + - ffn_add + opts: + calib_range: + element_batch_size: 64 + sample_batch_size: 16 + element_size: 512 + sample_size: -1 + smooth: + proj: + element_batch_size: -1 + sample_batch_size: 16 + element_size: -1 + sample_size: -1 + skips: + - transformer_norm + - transformer_add_norm + - attn_add + - ffn_add + attn: + sample_batch_size: 16 + sample_size: -1 diff --git a/examples/diffusion/configs/model/z-image-turbo-rank128-skip-refiners.yaml b/examples/diffusion/configs/model/z-image-turbo-rank128-skip-refiners.yaml new file mode 100644 index 0000000..5d4af08 --- /dev/null +++ b/examples/diffusion/configs/model/z-image-turbo-rank128-skip-refiners.yaml @@ -0,0 +1,70 @@ +pipeline: + name: z-image-turbo + dtype: torch.bfloat16 +collect: + root: /data/dongd/dc_calib/datasets/{dtype}/{model}/{protocol}/{data}/s128 +eval: + num_steps: 9 + guidance_scale: 0 + protocol: fmeuler{num_steps}-g{guidance_scale} + benchmarks: + - "deepcompressor/dataset/MJHQ30KMETA.json" + num_samples: 300 +quant: + calib: + batch_size: 1 + path: datasets/{dtype}/{model}/{protocol}/{data}/s128 + num_samples: 128 + num_workers: 1 + wgts: + calib_range: + element_batch_size: 64 + sample_batch_size: 16 + element_size: 512 + sample_size: -1 + low_rank: + rank: 128 + sample_batch_size: 16 + sample_size: -1 + skips: + - nrk + - crk + - transformer_norm + - transformer_add_norm + - attn_add + - ffn_add + ipts: + calib_range: + element_batch_size: 64 + sample_batch_size: 16 + element_size: 512 + sample_size: -1 + skips: + - nrk + - crk + - transformer_norm + - transformer_add_norm + - attn_add + - ffn_add + opts: + calib_range: + element_batch_size: 64 + sample_batch_size: 16 + element_size: 512 + sample_size: -1 + smooth: + proj: + element_batch_size: -1 + sample_batch_size: 16 + element_size: -1 + sample_size: -1 + skips: + - nrk + - crk + - transformer_norm + - transformer_add_norm + - attn_add + - ffn_add + attn: + sample_batch_size: 16 + sample_size: -1 diff --git a/examples/diffusion/configs/model/z-image-turbo-rank128.yaml b/examples/diffusion/configs/model/z-image-turbo-rank128.yaml new file mode 100644 index 0000000..eaa602d --- /dev/null +++ b/examples/diffusion/configs/model/z-image-turbo-rank128.yaml @@ -0,0 +1,65 @@ +pipeline: + name: z-image-turbo + dtype: torch.bfloat16 + device: cuda:4 +collect: + root: /data/dongd/dc_calib/datasets/{dtype}/{model}/{protocol}/{data}/s128 +eval: + num_steps: 9 + guidance_scale: 0 + protocol: fmeuler{num_steps}-g{guidance_scale} + benchmarks: + - "deepcompressor/dataset/MJHQ30KMETA.json" + num_samples: 300 +quant: + calib: + batch_size: 1 + path: datasets/{dtype}/{model}/{protocol}/{data}/s128 + num_samples: 128 + num_workers: 1 + wgts: + calib_range: + element_batch_size: 64 + sample_batch_size: 16 + element_size: 512 + sample_size: -1 + low_rank: + rank: 128 + sample_batch_size: 16 + sample_size: -1 + skips: + - transformer_norm + - transformer_add_norm + - attn_add + - ffn_add + ipts: + calib_range: + element_batch_size: 64 + sample_batch_size: 16 + element_size: 512 + sample_size: -1 + skips: + - transformer_norm + - transformer_add_norm + - attn_add + - ffn_add + opts: + calib_range: + element_batch_size: 64 + sample_batch_size: 16 + element_size: 512 + sample_size: -1 + smooth: + proj: + element_batch_size: -1 + sample_batch_size: 16 + element_size: -1 + sample_size: -1 + skips: + - transformer_norm + - transformer_add_norm + - attn_add + - ffn_add + attn: + sample_batch_size: 16 + sample_size: -1 diff --git a/examples/diffusion/configs/model/z-image-turbo-rank256.yaml b/examples/diffusion/configs/model/z-image-turbo-rank256.yaml new file mode 100644 index 0000000..ecb865b --- /dev/null +++ b/examples/diffusion/configs/model/z-image-turbo-rank256.yaml @@ -0,0 +1,65 @@ +pipeline: + name: z-image-turbo + dtype: torch.bfloat16 + device: cuda:4 +collect: + root: /data/dongd/dc_calib/datasets/{dtype}/{model}/{protocol}/{data}/s128 +eval: + num_steps: 9 + guidance_scale: 0 + protocol: fmeuler{num_steps}-g{guidance_scale} + benchmarks: + - "deepcompressor/dataset/MJHQ30KMETA.json" + num_samples: 300 +quant: + calib: + batch_size: 1 + path: datasets/{dtype}/{model}/{protocol}/{data}/s128 + num_samples: 128 + num_workers: 1 + wgts: + calib_range: + element_batch_size: 64 + sample_batch_size: 16 + element_size: 512 + sample_size: -1 + low_rank: + rank: 256 + sample_batch_size: 16 + sample_size: -1 + skips: + - transformer_norm + - transformer_add_norm + - attn_add + - ffn_add + ipts: + calib_range: + element_batch_size: 64 + sample_batch_size: 16 + element_size: 512 + sample_size: -1 + skips: + - transformer_norm + - transformer_add_norm + - attn_add + - ffn_add + opts: + calib_range: + element_batch_size: 64 + sample_batch_size: 16 + element_size: 512 + sample_size: -1 + smooth: + proj: + element_batch_size: -1 + sample_batch_size: 16 + element_size: -1 + sample_size: -1 + skips: + - transformer_norm + - transformer_add_norm + - attn_add + - ffn_add + attn: + sample_batch_size: 16 + sample_size: -1 diff --git a/examples/diffusion/configs/model/z-image-turbo-rank64.yaml b/examples/diffusion/configs/model/z-image-turbo-rank64.yaml new file mode 100644 index 0000000..33f7a57 --- /dev/null +++ b/examples/diffusion/configs/model/z-image-turbo-rank64.yaml @@ -0,0 +1,64 @@ +pipeline: + name: z-image-turbo + dtype: torch.bfloat16 +collect: + root: /data/dongd/dc_calib/datasets/{dtype}/{model}/{protocol}/{data}/s128 +eval: + num_steps: 9 + guidance_scale: 0 + protocol: fmeuler{num_steps}-g{guidance_scale} + benchmarks: + - "deepcompressor/dataset/MJHQ30KMETA.json" + num_samples: 300 +quant: + calib: + batch_size: 1 + path: datasets/{dtype}/{model}/{protocol}/{data}/s128 + num_samples: 128 + num_workers: 1 + wgts: + calib_range: + element_batch_size: 64 + sample_batch_size: 16 + element_size: 512 + sample_size: -1 + low_rank: + rank: 64 + sample_batch_size: 16 + sample_size: -1 + skips: + - transformer_norm + - transformer_add_norm + - attn_add + - ffn_add + ipts: + calib_range: + element_batch_size: 64 + sample_batch_size: 16 + element_size: 512 + sample_size: -1 + skips: + - transformer_norm + - transformer_add_norm + - attn_add + - ffn_add + opts: + calib_range: + element_batch_size: 64 + sample_batch_size: 16 + element_size: 512 + sample_size: -1 + smooth: + proj: + element_batch_size: -1 + sample_batch_size: 16 + element_size: -1 + sample_size: -1 + skips: + - transformer_norm + - transformer_add_norm + - attn_add + - ffn_add + attn: + sample_batch_size: 16 + sample_size: -1 diff --git a/examples/diffusion/configs/model/z-image-turbo-skip-refiners.yaml b/examples/diffusion/configs/model/z-image-turbo-skip-refiners.yaml new file mode 100644 index 0000000..ce32e4b --- /dev/null +++ b/examples/diffusion/configs/model/z-image-turbo-skip-refiners.yaml @@ -0,0 +1,69 @@ +pipeline: + name: z-image-turbo + dtype: torch.bfloat16 +collect: + root: /data/dongd/dc_calib/datasets/{dtype}/{model}/{protocol}/{data}/s128 +eval: + num_steps: 9 + guidance_scale: 0 + protocol: fmeuler{num_steps}-g{guidance_scale} + benchmarks: + - "deepcompressor/dataset/MJHQ30KMETA.json" + num_samples: 300 +quant: + calib: + batch_size: 1 + path: datasets/{dtype}/{model}/{protocol}/{data}/s128 + num_samples: 128 + num_workers: 1 + wgts: + calib_range: + element_batch_size: 64 + sample_batch_size: 16 + element_size: 512 + sample_size: -1 + low_rank: + sample_batch_size: 16 + sample_size: -1 + skips: + - nrk + - crk + - transformer_norm + - transformer_add_norm + - attn_add + - ffn_add + ipts: + calib_range: + element_batch_size: 64 + sample_batch_size: 16 + element_size: 512 + sample_size: -1 + skips: + - nrk + - crk + - transformer_norm + - transformer_add_norm + - attn_add + - ffn_add + opts: + calib_range: + element_batch_size: 64 + sample_batch_size: 16 + element_size: 512 + sample_size: -1 + smooth: + proj: + element_batch_size: -1 + sample_batch_size: 16 + element_size: -1 + sample_size: -1 + skips: + - nrk + - crk + - transformer_norm + - transformer_add_norm + - attn_add + - ffn_add + attn: + sample_batch_size: 16 + sample_size: -1 diff --git a/examples/diffusion/configs/model/z-image-turbo.yaml b/examples/diffusion/configs/model/z-image-turbo.yaml new file mode 100644 index 0000000..3630015 --- /dev/null +++ b/examples/diffusion/configs/model/z-image-turbo.yaml @@ -0,0 +1,63 @@ +pipeline: + name: z-image-turbo + dtype: torch.bfloat16 +collect: + root: /data/dongd/dc_calib/datasets/{dtype}/{model}/{protocol}/{data}/s128 +eval: + num_steps: 9 + guidance_scale: 0 + protocol: fmeuler{num_steps}-g{guidance_scale} + benchmarks: + - "deepcompressor/dataset/MJHQ30KMETA.json" + num_samples: 300 +quant: + calib: + batch_size: 1 + path: datasets/{dtype}/{model}/{protocol}/{data}/s128 + num_samples: 128 + num_workers: 1 + wgts: + calib_range: + element_batch_size: 64 + sample_batch_size: 16 + element_size: 512 + sample_size: -1 + low_rank: + sample_batch_size: 16 + sample_size: -1 + skips: + - transformer_norm + - transformer_add_norm + - attn_add + - ffn_add + ipts: + calib_range: + element_batch_size: 64 + sample_batch_size: 16 + element_size: 512 + sample_size: -1 + skips: + - transformer_norm + - transformer_add_norm + - attn_add + - ffn_add + opts: + calib_range: + element_batch_size: 64 + sample_batch_size: 16 + element_size: 512 + sample_size: -1 + smooth: + proj: + element_batch_size: -1 + sample_batch_size: 16 + element_size: -1 + sample_size: -1 + skips: + - transformer_norm + - transformer_add_norm + - attn_add + - ffn_add + attn: + sample_batch_size: 16 + sample_size: -1 diff --git a/examples/diffusion/configs/svdquant/int4.yaml b/examples/diffusion/configs/svdquant/int4.yaml index 626635b..4190555 100644 --- a/examples/diffusion/configs/svdquant/int4.yaml +++ b/examples/diffusion/configs/svdquant/int4.yaml @@ -22,4 +22,4 @@ quant: - null allow_unsigned: true pipeline: - shift_activations: true \ No newline at end of file + shift_activations: false \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 55eb2ef..f9f1612 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ torchvision = ">= 0.18.1" torchmetrics = ">= 1.4.0" ninja = ">= 1.11.1" bitsandbytes = ">= 0.42.0" -transformers = ">= 4.46.0" +transformers = ">= 4.46.0, <=4.57.3" lm_eval = ">= 0.4.2" accelerate = ">= 0.26.0" datasets = ">= 2.16.0" @@ -44,7 +44,7 @@ bs4 = ">= 0.0.2" ftfy = ">= 6.2.0" cd-fvd = ">= 0.1.1" xformers = ">= 0.0.26" -pyav = ">= 13.0.0" +av = ">= 13.0.0" clip = ">= 0.2.0" image_reward = { git = "https://github.com/THUDM/ImageReward.git", branch = "main" } diff --git a/z_image_scripts/z_image_key_map.py b/z_image_scripts/z_image_key_map.py new file mode 100644 index 0000000..40f0c86 --- /dev/null +++ b/z_image_scripts/z_image_key_map.py @@ -0,0 +1,259 @@ +'''printed from ZImageStruct._get_default_key_map()''' +_map = { + 'transformer_norm': { + 'lk_transformer_norm', + 'nrk_transformer_norm', + 'crk_transformer_norm' + }, + 'nrk_transformer_norm': { + 'nrk_transformer_norm' + }, + 'nrk': { + 'nrk_attn_add_qkv_proj', + 'nrk_ffn_add_up_proj', + 'nrk_ffn_up_proj', + 'nrk_ffn_add_down_proj', + 'nrk_attn_out_proj', + 'nrk_transformer_norm', + 'nrk_ffn_down_proj', + 'nrk_transformer_add_norm', + 'nrk_attn_add_out_proj', + 'nrk_attn_qkv_proj' + }, + 'transformer_add_norm': { + 'nrk_transformer_add_norm', + 'lk_transformer_add_norm', + 'crk_transformer_add_norm' + }, + 'nrk_transformer_add_norm': { + 'nrk_transformer_add_norm' + }, + 'attn': { + 'crk_attn_out_proj', + 'nrk_attn_out_proj', + 'lk_attn_out_proj', + 'lk_attn_qkv_proj', + 'crk_attn_qkv_proj', + 'nrk_attn_qkv_proj' + }, + 'nrk_attn': { + 'nrk_attn_qkv_proj', + 'nrk_attn_out_proj' + }, + 'attn_add': { + 'nrk_attn_add_qkv_proj', + 'lk_attn_add_qkv_proj', + 'crk_attn_add_out_proj', + 'crk_attn_add_qkv_proj', + 'nrk_attn_add_out_proj', + 'lk_attn_add_out_proj' + }, + 'nrk_attn_add': { + 'nrk_attn_add_qkv_proj', + 'nrk_attn_add_out_proj' + }, + 'attn_qkv_proj': { + 'lk_attn_qkv_proj', + 'crk_attn_qkv_proj', + 'nrk_attn_qkv_proj' + }, + 'nrk_attn_qkv_proj': { + 'nrk_attn_qkv_proj' + }, + 'attn_out_proj': { + 'lk_attn_out_proj', + 'crk_attn_out_proj', + 'nrk_attn_out_proj' + }, + 'nrk_attn_out_proj': { + 'nrk_attn_out_proj' + }, + 'attn_add_qkv_proj': { + 'nrk_attn_add_qkv_proj', + 'lk_attn_add_qkv_proj', + 'crk_attn_add_qkv_proj' + }, + 'nrk_attn_add_qkv_proj': { + 'nrk_attn_add_qkv_proj' + }, + 'attn_add_out_proj': { + 'lk_attn_add_out_proj', + 'crk_attn_add_out_proj', + 'nrk_attn_add_out_proj' + }, + 'nrk_attn_add_out_proj': { + 'nrk_attn_add_out_proj' + }, + 'ffn': { + 'crk_ffn_up_proj', + 'nrk_ffn_up_proj', + 'lk_ffn_up_proj', + 'nrk_ffn_down_proj', + 'crk_ffn_down_proj', + 'lk_ffn_down_proj' + }, + 'nrk_ffn': { + 'nrk_ffn_down_proj', + 'nrk_ffn_up_proj' + }, + 'ffn_add': { + 'nrk_ffn_add_up_proj', + 'crk_ffn_add_down_proj', + 'lk_ffn_add_up_proj', + 'nrk_ffn_add_down_proj', + 'crk_ffn_add_up_proj', + 'lk_ffn_add_down_proj' + }, + 'nrk_ffn_add': { + 'nrk_ffn_add_up_proj', + 'nrk_ffn_add_down_proj' + }, + 'ffn_up_proj': { + 'nrk_ffn_up_proj', + 'lk_ffn_up_proj', + 'crk_ffn_up_proj' + }, + 'nrk_ffn_up_proj': { + 'nrk_ffn_up_proj' + }, + 'ffn_down_proj': { + 'nrk_ffn_down_proj', + 'crk_ffn_down_proj', + 'lk_ffn_down_proj' + }, + 'nrk_ffn_down_proj': { + 'nrk_ffn_down_proj' + }, + 'ffn_add_up_proj': { + 'nrk_ffn_add_up_proj', + 'crk_ffn_add_up_proj', + 'lk_ffn_add_up_proj' + }, + 'nrk_ffn_add_up_proj': { + 'nrk_ffn_add_up_proj' + }, + 'ffn_add_down_proj': { + 'lk_ffn_add_down_proj', + 'nrk_ffn_add_down_proj', + 'crk_ffn_add_down_proj' + }, + 'nrk_ffn_add_down_proj': { + 'nrk_ffn_add_down_proj' + }, + 'crk_transformer_norm': { + 'crk_transformer_norm' + }, + 'crk': { + 'crk_ffn_up_proj', + 'crk_attn_out_proj', + 'crk_ffn_add_down_proj', + 'crk_attn_add_out_proj', + 'crk_transformer_add_norm', + 'crk_attn_add_qkv_proj', + 'crk_ffn_add_up_proj', + 'crk_ffn_down_proj', + 'crk_transformer_norm', + 'crk_attn_qkv_proj' + }, + 'crk_transformer_add_norm': { + 'crk_transformer_add_norm' + }, + 'crk_attn': { + 'crk_attn_qkv_proj', + 'crk_attn_out_proj' + }, + 'crk_attn_add': { + 'crk_attn_add_out_proj', + 'crk_attn_add_qkv_proj' + }, + 'crk_attn_qkv_proj': { + 'crk_attn_qkv_proj' + }, + 'crk_attn_out_proj': { + 'crk_attn_out_proj' + }, + 'crk_attn_add_qkv_proj': { + 'crk_attn_add_qkv_proj' + }, + 'crk_attn_add_out_proj': { + 'crk_attn_add_out_proj' + }, + 'crk_ffn': { + 'crk_ffn_up_proj', + 'crk_ffn_down_proj' + }, + 'crk_ffn_add': { + 'crk_ffn_add_up_proj', + 'crk_ffn_add_down_proj' + }, + 'crk_ffn_up_proj': { + 'crk_ffn_up_proj' + }, + 'crk_ffn_down_proj': { + 'crk_ffn_down_proj' + }, + 'crk_ffn_add_up_proj': { + 'crk_ffn_add_up_proj' + }, + 'crk_ffn_add_down_proj': { + 'crk_ffn_add_down_proj' + }, + 'lk_transformer_norm': { + 'lk_transformer_norm' + }, + 'lk': { + 'lk_ffn_add_up_proj', + 'lk_transformer_add_norm', + 'lk_attn_add_qkv_proj', + 'lk_transformer_norm', + 'lk_ffn_up_proj', + 'lk_attn_out_proj', + 'lk_attn_qkv_proj', + 'lk_ffn_add_down_proj', + 'lk_ffn_down_proj', + 'lk_attn_add_out_proj' + }, + 'lk_transformer_add_norm': { + 'lk_transformer_add_norm' + }, + 'lk_attn': { + 'lk_attn_out_proj', + 'lk_attn_qkv_proj' + }, + 'lk_attn_add': { + 'lk_attn_add_qkv_proj', + 'lk_attn_add_out_proj' + }, + 'lk_attn_qkv_proj': { + 'lk_attn_qkv_proj' + }, + 'lk_attn_out_proj': { + 'lk_attn_out_proj' + }, + 'lk_attn_add_qkv_proj': { + 'lk_attn_add_qkv_proj' + }, + 'lk_attn_add_out_proj': { + 'lk_attn_add_out_proj' + }, + 'lk_ffn': { + 'lk_ffn_down_proj', + 'lk_ffn_up_proj' + }, + 'lk_ffn_add': { + 'lk_ffn_add_down_proj', + 'lk_ffn_add_up_proj' + }, + 'lk_ffn_up_proj': { + 'lk_ffn_up_proj' + }, + 'lk_ffn_down_proj': { + 'lk_ffn_down_proj' + }, + 'lk_ffn_add_up_proj': { + 'lk_ffn_add_up_proj' + }, + 'lk_ffn_add_down_proj': { + 'lk_ffn_add_down_proj' + } +} \ No newline at end of file diff --git a/z_image_scripts/z_image_key_map_test.sh b/z_image_scripts/z_image_key_map_test.sh new file mode 100755 index 0000000..246e084 --- /dev/null +++ b/z_image_scripts/z_image_key_map_test.sh @@ -0,0 +1 @@ +TORCH_CUDA_ARCH_LIST="8.9" python3 -c "import runpy; runpy.run_path('deepcompressor/app/diffusion/nn/struct.py', run_name='__main__')" \ No newline at end of file diff --git a/z_image_scripts/z_image_turbo_calib.sh b/z_image_scripts/z_image_turbo_calib.sh new file mode 100755 index 0000000..f9e2b81 --- /dev/null +++ b/z_image_scripts/z_image_turbo_calib.sh @@ -0,0 +1,6 @@ +TORCH_CUDA_ARCH_LIST="9.0" python3 -m deepcompressor.app.diffusion.dataset.collect.calib \ + examples/diffusion/configs/model/z-image-customized-rank256.yaml examples/diffusion/configs/collect/qdiff.yaml + +echo $? + +# nohup z_image_scripts/z_image_turbo_calib.sh > z_image_scripts/z_image_turbo_calib_20251217_1931.log 2>&1 & \ No newline at end of file diff --git a/z_image_scripts/z_image_turbo_convert.sh b/z_image_scripts/z_image_turbo_convert.sh new file mode 100755 index 0000000..3d55b6d --- /dev/null +++ b/z_image_scripts/z_image_turbo_convert.sh @@ -0,0 +1,12 @@ +echo "-quant-path /data/dongd/dc_saved_model/Z_IMAGE_TURBO_20251217_2034" + +TORCH_CUDA_ARCH_LIST="9.0" python -m deepcompressor.backend.nunchaku.convert \ + --quant-path /data/dongd/dc_saved_model/Z_IMAGE_TURBO_20251217_2034 \ + --output-root /data/dongd/dc_converted_model/Z_IMAGE_CUSTOM_20251217_2034_r256_int4 \ + --model-name z-image-customized \ + # --float-point \ + + +echo $? + +# nohup z_image_scripts/z_image_turbo_convert.sh > z_image_scripts/z_image_turbo_convert_20251218_0750.log 2>&1 & \ No newline at end of file diff --git a/z_image_scripts/z_image_turbo_gen_image_after_quant.sh b/z_image_scripts/z_image_turbo_gen_image_after_quant.sh new file mode 100755 index 0000000..a160ad1 --- /dev/null +++ b/z_image_scripts/z_image_turbo_gen_image_after_quant.sh @@ -0,0 +1,9 @@ +echo "--load-from /data/dongd/dc_saved_model/Z_IMAGE_TURBO_20251203_0031" + +TORCH_CUDA_ARCH_LIST="9.0" python3 -m deepcompressor.app.diffusion.ptq \ + examples/diffusion/configs/model/z-image-turbo.yaml examples/diffusion/configs/svdquant/int4.yaml \ + --load-from /data/dongd/dc_saved_model/Z_IMAGE_TURBO_20251203_0031 --skip-eval true + +echo $? + +# nohup z_image_scripts/z_image_turbo_gen_image_after_quant.sh > z_image_scripts/z_image_turbo_gen_image_after_quant_20251203_1136.log 2>&1 & \ No newline at end of file diff --git a/z_image_scripts/z_image_turbo_quantize.sh b/z_image_scripts/z_image_turbo_quantize.sh new file mode 100755 index 0000000..b83fbcb --- /dev/null +++ b/z_image_scripts/z_image_turbo_quantize.sh @@ -0,0 +1,7 @@ +TORCH_CUDA_ARCH_LIST="9.0" python3 -m deepcompressor.app.diffusion.ptq \ + examples/diffusion/configs/model/z-image-customized-rank256.yaml examples/diffusion/configs/svdquant/int4.yaml \ + --save-model /data/dongd/dc_saved_model/Z_IMAGE_TURBO_20251217_2034 --copy-on-save true --skip-eval true + +echo $? + +# nohup z_image_scripts/z_image_turbo_quantize.sh > z_image_scripts/z_image_turbo_quantize_20251217_2034.log 2>&1 & \ No newline at end of file