diff --git a/src/cache_dit/_utils/backend_selector.py b/src/cache_dit/_utils/backend_selector.py
new file mode 100644
index 00000000..f215461e
--- /dev/null
+++ b/src/cache_dit/_utils/backend_selector.py
@@ -0,0 +1,78 @@
+from typing import TYPE_CHECKING
+
+import torch
+from cache_dit.envs import ENV
+from cache_dit.logger import init_logger
+
+if TYPE_CHECKING:
+    from cache_dit.kernels.backend import KernelBackend
+
+logger = init_logger(__name__)
+
+
+class BackendSelector:
+    """Auto-select MindIE-SD backends when running on Ascend NPU.
+
+    The attention choice is computed once and cached on the class, so later
+    calls return the first result regardless of the pipeline passed in.
+    """
+
+    _attn_backend: str | None = None
+    _selected: bool = False
+
+    @classmethod
+    def auto_select(cls, pipe_or_adapter) -> str | None:
+        """Return the attention backend to use, or None to keep the default.
+
+        Only NPU devices get an automatic choice: `_mindiesd_laser` when MindIE-SD
+        is usable and CACHE_DIT_ENABLE_MINDIESD_ATTN is not "0", else `_native_npu`.
+        """
+        if cls._selected:
+            return cls._attn_backend
+
+        # Detect the device before honoring CACHE_DIT_ENABLE_MINDIESD_ATTN=0 so the
+        # NPU-only `_native_npu` backend is never forced onto a CUDA/CPU pipeline.
+        device = cls._detect_device(pipe_or_adapter)
+        if device.type == "npu":
+            if ENV.CACHE_DIT_ENABLE_MINDIESD_ATTN == "0":
+                cls._attn_backend = "_native_npu"
+            elif cls._mindiesd_attention_available():
+                cls._attn_backend = "_mindiesd_laser"
+                logger.info(f"Auto-selected MindIE-SD attention backend: {cls._attn_backend}")
+            else:
+                cls._attn_backend = "_native_npu"
+                logger.info(f"MindIE-SD not found, fallback attention backend: {cls._attn_backend}")
+        cls._selected = True
+        return cls._attn_backend
+
+    @classmethod
+    def auto_select_kernel_backend(cls) -> "KernelBackend | None":
+        """Return KernelBackend.MINDIESD when MindIE-SD is importable, else None."""
+        try:
+            import mindiesd  # noqa F401
+
+            from cache_dit.kernels.backend import KernelBackend
+
+            return KernelBackend.MINDIESD
+        except Exception:
+            return None
+
+    @staticmethod
+    def _mindiesd_attention_available() -> bool:
+        """Probe the import that the `_mindiesd_laser` attention backend needs."""
+        try:
+            from mindiesd.layers import attention_forward  # noqa F401
+        except Exception:
+            return False
+        return True
+
+    @staticmethod
+    def _detect_device(pipe_or_adapter) -> torch.device:
+        """Best-effort lookup: `.device` attribute, then first parameter, else CPU."""
+        try:
+            if hasattr(pipe_or_adapter, "device"):
+                return pipe_or_adapter.device
+            param = next(pipe_or_adapter.parameters())
+            return param.device
+        except (StopIteration, AttributeError):
+            return torch.device("cpu")
diff --git a/src/cache_dit/_utils/utils.py b/src/cache_dit/_utils/utils.py
index 8c1492ed..88af3e8d 100644
--- a/src/cache_dit/_utils/utils.py
+++ b/src/cache_dit/_utils/utils.py
@@ -698,6 +698,7 @@ def get_args(parse: bool = True, ) -> argparse.ArgumentParser | argparse.Namespa
             "sage",  # Need install sageattention: https://github.com/thu-ml/SageAttention
             "_native_npu",  # native npu attention
             "_npu_fia",  # npu fused infer attention
+            "_mindiesd_laser",  # MindIE-SD laser attention
         ],
     )
     # Ulysses context parallelism settings
diff --git a/src/cache_dit/attention/backends/npu.py b/src/cache_dit/attention/backends/npu.py
index 7a024dbe..e92a66df 100644
--- a/src/cache_dit/attention/backends/npu.py
+++ b/src/cache_dit/attention/backends/npu.py
@@ -49,6 +49,8 @@ def _native_npu_attention(
     dropout_p: float = 0.0,
     scale: Optional[float] = None,
     return_lse: bool = False,
+    is_causal: bool = False,
+    enable_gqa: bool = False,
     _cp_config: Optional["_ContextParallelConfig"] = None,
 ) -> torch.Tensor:
     if return_lse:
@@ -97,6 +99,8 @@ def _npu_fused_infer_attention(
     dropout_p: float = 0.0,
     scale: Optional[float] = None,
     return_lse: bool = False,
+    is_causal: bool = False,
+    enable_gqa: bool = False,
     _cp_config: Optional["_ContextParallelConfig"] = None,
 ) -> torch.Tensor:
     if _cp_config is None:
@@ -128,3 +132,61 @@ def _npu_fused_infer_attention(
             _cp_config=_cp_config,
         )
     return out
+
+
+try:
+    from mindiesd.layers import attention_forward
+
+    _mindiesd_available = True
+except Exception:
+    _mindiesd_available = False
+    attention_forward = None
+
+if _mindiesd_available:
+
+    # NOTE: `_cp_config` is accepted for signature parity but never used here, so
+    # the backend must not advertise context-parallel support.
+    @_AttnBackendRegistry.register(
+        _AttnBackend._MINDIESD_LASER,
+        constraints=[],
+        supports_context_parallel=False,
+    )
+    def _mindiesd_laser_attention(
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+        dropout_p: float = 0.0,
+        scale: Optional[float] = None,
+        return_lse: bool = False,
+        is_causal: bool = False,
+        enable_gqa: bool = False,
+        _cp_config: Optional["_ContextParallelConfig"] = None,
+    ) -> torch.Tensor:
+        """Run MindIE-SD `ascend_laser_attention` on query/key/value.
+
+        Raises ValueError for options this call cannot honor instead of silently
+        ignoring them.
+        """
+        if return_lse:
+            raise ValueError(
+                "MindIE-SD laser attention backend does not support setting `return_lse=True`.")
+        if is_causal:
+            raise ValueError(
+                "MindIE-SD laser attention backend does not support setting `is_causal=True`.")
+        if dropout_p != 0.0:
+            raise ValueError(
+                "MindIE-SD laser attention backend does not support `dropout_p` != 0.")
+        # Avoid depending on a module-level `math` import for the default scale.
+        scale_val = scale if scale is not None else query.shape[-1]**-0.5
+        return attention_forward(
+            query,
+            key,
+            value,
+            attn_mask=attn_mask,
+            scale=scale_val,
+            fused=True,
+            head_first=False,
+            opt_mode="manual",
+            op_type="ascend_laser_attention",
+        )
diff --git a/src/cache_dit/attention/backends/register.py b/src/cache_dit/attention/backends/register.py
index a030fa70..282fa9d3 100644
--- a/src/cache_dit/attention/backends/register.py
+++ b/src/cache_dit/attention/backends/register.py
@@ -38,6 +38,7 @@ class _AttnBackend(str, Enum):
     _SDPA_CUDNN = "_sdpa_cudnn"
     _NATIVE_NPU = "_native_npu"
     _NPU_FIA = "_npu_fia"
+    _MINDIESD_LASER = "_mindiesd_laser"
 
 
 def _default_active_backend() -> _AttnBackend:
diff --git a/src/cache_dit/caching/cache_interface.py b/src/cache_dit/caching/cache_interface.py
index 04a8bf16..e3831f77 100644
--- a/src/cache_dit/caching/cache_interface.py
+++ b/src/cache_dit/caching/cache_interface.py
@@ -17,11 +17,22 @@ from ..utils import check_controlnet
 
 from ..utils import parse_extra_modules
 from ..logger import init_logger
+from ..envs import ENV
 from ..attention import set_attn_backend
 
 logger = init_logger(__name__)
 
 
+def _auto_select_attention_backend(pipe_or_adapter) -> Optional[str]:
+    """Try to auto-select an optimal attention backend when none was specified."""
+    try:
+        from cache_dit._utils.backend_selector import BackendSelector
+
+        return BackendSelector.auto_select(pipe_or_adapter)
+    except Exception:
+        return None
+
+
 def enable_cache(
     # DiffusionPipeline or BlockAdapter
     pipe_or_adapter: Union[
@@ -177,6 +188,10 @@ def enable_cache(
         logger.warning("cache_config is None, skip cache acceleration for "
                        f"{pipe_or_adapter.__class__.__name__}.")
 
+    # Auto-select attention backend when none specified
+    if attention_backend is None and parallelism_config is None:
+        attention_backend = _auto_select_attention_backend(pipe_or_adapter)
+
     # Set custom attention backend for non-parallelism case
     if attention_backend is not None:
         if parallelism_config is not None:
@@ -281,6 +296,28 @@ def enable_cache(
                 # Enable quantization for the specified component inplace
                 quantized_component = quantize(component, quantize_config=config)
                 setattr(pipe, name, quantized_component)
+
+    # Auto-enable MindieSDBackend when available on NPU
+    if not ENV.CACHE_DIT_FORCE_DISABLE_MINDIESD_COMPILE_CONFIG:
+        try:
+            if hasattr(torch, "npu") and torch.npu.is_available():
+                from mindiesd.compilation import MindieSDBackend
+
+                t = getattr(pipe_or_adapter, "transformer", None)
+                targets = [] if t is None else (t if isinstance(t, list) else [t])
+                for target in targets:
+                    # Compile in place: the wrapper returned by torch.compile() was
+                    # previously discarded, leaving a single transformer uncompiled.
+                    target.compile(backend=MindieSDBackend(), dynamic=True)
+                if targets:
+                    logger.info("Auto-enabled MindieSDBackend compile for transformer(s).")
+        except ImportError:
+            # MindIE-SD is optional; nothing to enable without it.
+            pass
+        except Exception as e:
+            # Stay best-effort, but do not hide why the compile was skipped.
+            logger.warning(f"Skip MindieSDBackend auto-compile: {e}")
+
     return pipe_or_adapter
 
 
diff --git a/src/cache_dit/compile/utils.py b/src/cache_dit/compile/utils.py
index c0f647e8..d26e5830 100644
--- a/src/cache_dit/compile/utils.py
+++ b/src/cache_dit/compile/utils.py
@@ -117,27 +117,33 @@ def set_compile_configs(
     # them to your needs and test the performance
     inductor_config.max_fusion_size = 64
     inductor_config.max_pointwise_cat_inputs = 8
-    inductor_config.triton.cudagraphs = cuda_graphs
-    inductor_config.triton.use_block_ptr = False
-    inductor_config.triton.codegen_upcast_to_fp32 = True
-
-    # Copy from https://pytorch.org/blog/accelerating-generative-ai-3/
-    inductor_config.conv_1x1_as_mm = True
-    inductor_config.coordinate_descent_tuning = True
-    inductor_config.coordinate_descent_check_all_directions = True
-    inductor_config.epilogue_fusion = False
-
-    # Enable epilogue and prologue fusion
-    if ENV.CACHE_DIT_EPILOGUE_PROLOGUE_FUSION or kwargs.get(
-            "epilogue_prologue_fusion",
-            False,
-    ):
-        inductor_config.epilogue_fusion = True
-        inductor_config.prologue_fusion = True
-        inductor_config.epilogue_fusion_first = True
 
-    # Dead code elimination
-    inductor_config.dce = True  # default is False
+    if current_platform.device_type == "npu":
+        # NPU: skip CUDA-specific inductor configs (triton, coordinate_descent, etc)
+        inductor_config.dce = True  # default is False
+        inductor_config.epilogue_fusion = False
+    else:
+        inductor_config.triton.cudagraphs = cuda_graphs
+        inductor_config.triton.use_block_ptr = False
+        inductor_config.triton.codegen_upcast_to_fp32 = True
+
+        # Copy from https://pytorch.org/blog/accelerating-generative-ai-3/
+        inductor_config.conv_1x1_as_mm = True
+        inductor_config.coordinate_descent_tuning = True
+        inductor_config.coordinate_descent_check_all_directions = True
+        inductor_config.epilogue_fusion = False
+
+        # Enable epilogue and prologue fusion
+        if ENV.CACHE_DIT_EPILOGUE_PROLOGUE_FUSION or kwargs.get(
+                "epilogue_prologue_fusion",
+                False,
+        ):
+            inductor_config.epilogue_fusion = True
+            inductor_config.prologue_fusion = True
+            inductor_config.epilogue_fusion_first = True
+
+        # Dead code elimination
+        inductor_config.dce = True  # default is False
 
     # May need to force disable all cache
     if force_disable_compile_caches:
diff --git a/src/cache_dit/envs.py b/src/cache_dit/envs.py
index cfc7800a..3a3cd614 100644
--- a/src/cache_dit/envs.py
+++ b/src/cache_dit/envs.py
@@ -73,3 +73,8 @@ class ENV(object):
     # recommended now since it may cause scaled_mm error for some models after TP.
     CACHE_DIT_DISABLE_EXCLUDE_FOR_QUANTIZE_AFTER_TP: bool = bool(
         int(os.environ.get("CACHE_DIT_DISABLE_EXCLUDE_FOR_QUANTIZE_AFTER_TP", "0")))
+
+    # MindIE-SD Backend ENVs
+    CACHE_DIT_ENABLE_MINDIESD_ATTN: str = os.environ.get("CACHE_DIT_ENABLE_MINDIESD_ATTN", "1")
+    CACHE_DIT_FORCE_DISABLE_MINDIESD_COMPILE_CONFIG: bool = bool(
+        int(os.environ.get("CACHE_DIT_FORCE_DISABLE_MINDIESD_COMPILE_CONFIG", "0")))
diff --git a/src/cache_dit/kernels/backend.py b/src/cache_dit/kernels/backend.py
index fa74eee4..651cfe4f 100644
--- a/src/cache_dit/kernels/backend.py
+++ b/src/cache_dit/kernels/backend.py
@@ -7,6 +7,7 @@ class KernelBackend(Enum):
     TRITON = "Triton"
     CUDA = "CUDA"
     CUTEDSL = "CuteDSL"
+    MINDIESD = "MindIESD"
     NONE = "None"
 
     @classmethod
@@ -41,4 +42,11 @@ def is_supported(cls, backend: "KernelBackend") -> bool:
                 return True
             except ImportError:
                 return False
+        if backend == cls.MINDIESD:
+            try:
+                import mindiesd  # noqa F401
+
+                return True
+            except Exception:
+                return False
         return False
diff --git a/src/cache_dit/kernels/ops.py b/src/cache_dit/kernels/ops.py
index 6768b186..aad7e181 100644
--- a/src/cache_dit/kernels/ops.py
+++ b/src/cache_dit/kernels/ops.py
@@ -14,6 +14,10 @@ def _select_cuda_backend() -> KernelBackend:
     return KernelBackend.CUDA
 
 
+def _select_mindiesd_backend() -> KernelBackend:
+    return KernelBackend.MINDIESD
+
+
 def _select_kernel_backend() -> KernelBackend:
     """Select the default backend for kernel ops.
 
@@ -149,6 +153,21 @@ def _fused_merge_attn_states_impl(
             suff_out,
             suff_lse,
         )
+    elif backend == KernelBackend.MINDIESD:
+        # Pure-PyTorch merge of two partial attention results via their log-sum-exp.
+        max_lse = torch.maximum(prev_lse, suff_lse)
+        # Both partials empty (-inf) would give `-inf - -inf = nan`; shift by 0 there.
+        max_lse = torch.where(torch.isfinite(max_lse), max_lse, torch.zeros_like(max_lse))
+        prev_scale = torch.exp(prev_lse - max_lse)
+        suff_scale = torch.exp(suff_lse - max_lse)
+        denom = prev_scale + suff_scale
+        # The LSE terms have one axis fewer than the outputs (hence unsqueeze(-1));
+        # the mask needs the same unsqueeze or torch.where cannot broadcast it.
+        merged = (prev_out * prev_scale.unsqueeze(-1) +
+                  suff_out * suff_scale.unsqueeze(-1)) / denom.clamp_min(1e-12).unsqueeze(-1)
+        out = torch.where(denom.unsqueeze(-1) > 0, merged.to(prev_out.dtype), prev_out)
+        lse = max_lse + torch.log(denom)
+        return out, lse
     else:
         raise ValueError(_ERROR_TEMPLATE.format(backend))
 