55 changes: 55 additions & 0 deletions src/cache_dit/_utils/backend_selector.py
@@ -0,0 +1,55 @@
from typing import TYPE_CHECKING

import torch

from cache_dit.envs import ENV
from cache_dit.logger import init_logger

if TYPE_CHECKING:
    from cache_dit.kernels.backend import KernelBackend

logger = init_logger(__name__)


class BackendSelector:
    """Picks the attention backend once per process and caches the result."""

    _attn_backend: str | None = None
    _selected: bool = False

@classmethod
def auto_select(cls, pipe_or_adapter) -> str | None:
if cls._selected:
return cls._attn_backend

enable = ENV.CACHE_DIT_ENABLE_MINDIESD_ATTN
if enable == "0":
cls._attn_backend = "_native_npu"
cls._selected = True
return cls._attn_backend

device = cls._detect_device(pipe_or_adapter)
if device.type == "npu":
            try:
                import mindiesd  # noqa: F401

                cls._attn_backend = "_mindiesd_laser"
                logger.info(f"Auto-selected MindIE-SD attention backend: {cls._attn_backend}")
            except Exception:
                cls._attn_backend = "_native_npu"
                logger.info(f"MindIE-SD not available; falling back to: {cls._attn_backend}")
cls._selected = True
return cls._attn_backend

    @classmethod
    def auto_select_kernel_backend(cls) -> "KernelBackend | None":
        """Return KernelBackend.MINDIESD when mindiesd is importable, else None."""
        try:
            import mindiesd  # noqa: F401

            from cache_dit.kernels.backend import KernelBackend

            return KernelBackend.MINDIESD
        except Exception:
            return None

    @staticmethod
    def _detect_device(pipe_or_adapter):
        """Best-effort device detection; falls back to CPU."""
        try:
            if hasattr(pipe_or_adapter, "device"):
                return pipe_or_adapter.device
            param = next(pipe_or_adapter.parameters())
            return param.device
        except (StopIteration, AttributeError):
            return torch.device("cpu")
1 change: 1 addition & 0 deletions src/cache_dit/_utils/utils.py
@@ -698,6 +698,7 @@ def get_args(parse: bool = True, ) -> argparse.ArgumentParser | argparse.Namespace:
"sage", # Need install sageattention: https://github.com/thu-ml/SageAttention
"_native_npu", # native npu attention
"_npu_fia", # npu fused infer attention
"_mindiesd_laser", # MindIE-SD laser attention
],
)
# Ulysses context parallelism settings
48 changes: 48 additions & 0 deletions src/cache_dit/attention/backends/npu.py
@@ -49,6 +49,8 @@ def _native_npu_attention(
dropout_p: float = 0.0,
scale: Optional[float] = None,
return_lse: bool = False,
is_causal: bool = False,
enable_gqa: bool = False,
_cp_config: Optional["_ContextParallelConfig"] = None,
) -> torch.Tensor:
if return_lse:
@@ -97,6 +99,8 @@ def _npu_fused_infer_attention(
dropout_p: float = 0.0,
scale: Optional[float] = None,
return_lse: bool = False,
is_causal: bool = False,
enable_gqa: bool = False,
_cp_config: Optional["_ContextParallelConfig"] = None,
) -> torch.Tensor:
if _cp_config is None:
@@ -128,3 +132,47 @@ def _npu_fused_infer_attention(
_cp_config=_cp_config,
)
return out


try:
from mindiesd.layers import attention_forward

_mindiesd_available = True
except Exception:
_mindiesd_available = False
attention_forward = None

if _mindiesd_available:

@_AttnBackendRegistry.register(
_AttnBackend._MINDIESD_LASER,
constraints=[],
supports_context_parallel=True,
)
def _mindiesd_laser_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
scale: Optional[float] = None,
return_lse: bool = False,
is_causal: bool = False,
enable_gqa: bool = False,
_cp_config: Optional["_ContextParallelConfig"] = None,
) -> torch.Tensor:
        if return_lse:
            raise ValueError(
                "MindIE-SD laser attention backend does not support `return_lse=True`.")
        # `is_causal`, `enable_gqa` and `dropout_p` are accepted for signature
        # parity with the other backends but are not forwarded here.
        scale_val = scale if scale is not None else 1.0 / math.sqrt(query.shape[-1])
return attention_forward(
query,
key,
value,
attn_mask=attn_mask,
scale=scale_val,
fused=True,
head_first=False,
opt_mode="manual",
op_type="ascend_laser_attention",
)
1 change: 1 addition & 0 deletions src/cache_dit/attention/backends/register.py
@@ -38,6 +38,7 @@ class _AttnBackend(str, Enum):
_SDPA_CUDNN = "_sdpa_cudnn"
_NATIVE_NPU = "_native_npu"
_NPU_FIA = "_npu_fia"
_MINDIESD_LASER = "_mindiesd_laser"


def _default_active_backend() -> _AttnBackend:
39 changes: 39 additions & 0 deletions src/cache_dit/caching/cache_interface.py
@@ -17,11 +17,22 @@
from ..utils import check_controlnet
from ..utils import parse_extra_modules
from ..logger import init_logger
from ..envs import ENV
from ..attention import set_attn_backend

logger = init_logger(__name__)


def _auto_select_attention_backend(pipe_or_adapter) -> Optional[str]:
"""Try to auto-select an optimal attention backend when none was specified."""
try:
from cache_dit._utils.backend_selector import BackendSelector

return BackendSelector.auto_select(pipe_or_adapter)
except Exception:
return None


def enable_cache(
# DiffusionPipeline or BlockAdapter
pipe_or_adapter: Union[
@@ -177,6 +188,10 @@ def enable_cache(
        logger.warning("cache_config is None, skipping cache acceleration for "
                       f"{pipe_or_adapter.__class__.__name__}.")

# Auto-select attention backend when none specified
if attention_backend is None and parallelism_config is None:
attention_backend = _auto_select_attention_backend(pipe_or_adapter)

# Set custom attention backend for non-parallelism case
if attention_backend is not None:
if parallelism_config is not None:
@@ -281,6 +296,30 @@ def enable_cache(
# Enable quantization for the specified component inplace
quantized_component = quantize(component, quantize_config=config)
setattr(pipe, name, quantized_component)

    # Auto-enable MindieSDBackend when available on NPU
    if not ENV.CACHE_DIT_FORCE_DISABLE_MINDIESD_COMPILE_CONFIG:
        try:
            import mindiesd  # noqa: F401

            if hasattr(torch, "npu") and torch.npu.is_available():
                from mindiesd.compilation import MindieSDBackend

                t = getattr(pipe_or_adapter, "transformer", None)
                if isinstance(t, list):
                    # Compile in place so the pipeline keeps referencing this list.
                    for i, target in enumerate(t):
                        t[i] = torch.compile(
                            target, backend=MindieSDBackend(), dynamic=True)
                elif t is not None:
                    # torch.compile returns a wrapper module; assign it back,
                    # otherwise the compiled module would be discarded.
                    pipe_or_adapter.transformer = torch.compile(
                        t, backend=MindieSDBackend(), dynamic=True)
                if t is not None:
                    logger.info("Auto-enabled MindieSDBackend compile for transformer(s).")
        except Exception:
            pass

return pipe_or_adapter


46 changes: 26 additions & 20 deletions src/cache_dit/compile/utils.py
@@ -117,27 +117,33 @@ def set_compile_configs(
# them to your needs and test the performance
inductor_config.max_fusion_size = 64
inductor_config.max_pointwise_cat_inputs = 8
-    inductor_config.triton.cudagraphs = cuda_graphs
-    inductor_config.triton.use_block_ptr = False
-    inductor_config.triton.codegen_upcast_to_fp32 = True
-
-    # Copy from https://pytorch.org/blog/accelerating-generative-ai-3/
-    inductor_config.conv_1x1_as_mm = True
-    inductor_config.coordinate_descent_tuning = True
-    inductor_config.coordinate_descent_check_all_directions = True
-    inductor_config.epilogue_fusion = False
-
-    # Enable epilogue and prologue fusion
-    if ENV.CACHE_DIT_EPILOGUE_PROLOGUE_FUSION or kwargs.get(
-        "epilogue_prologue_fusion",
-        False,
-    ):
-        inductor_config.epilogue_fusion = True
-        inductor_config.prologue_fusion = True
-        inductor_config.epilogue_fusion_first = True
-
-    # Dead code elimination
-    inductor_config.dce = True  # default is False
+    if current_platform.device_type == "npu":
+        # NPU: skip CUDA-specific inductor configs (triton, coordinate_descent, etc.)
+        inductor_config.dce = True  # default is False
+        inductor_config.epilogue_fusion = False
+    else:
+        inductor_config.triton.cudagraphs = cuda_graphs
+        inductor_config.triton.use_block_ptr = False
+        inductor_config.triton.codegen_upcast_to_fp32 = True
+
+        # Copy from https://pytorch.org/blog/accelerating-generative-ai-3/
+        inductor_config.conv_1x1_as_mm = True
+        inductor_config.coordinate_descent_tuning = True
+        inductor_config.coordinate_descent_check_all_directions = True
+        inductor_config.epilogue_fusion = False
+
+        # Enable epilogue and prologue fusion
+        if ENV.CACHE_DIT_EPILOGUE_PROLOGUE_FUSION or kwargs.get(
+            "epilogue_prologue_fusion",
+            False,
+        ):
+            inductor_config.epilogue_fusion = True
+            inductor_config.prologue_fusion = True
+            inductor_config.epilogue_fusion_first = True
+
+        # Dead code elimination
+        inductor_config.dce = True  # default is False

# May need to force disable all cache
if force_disable_compile_caches:
5 changes: 5 additions & 0 deletions src/cache_dit/envs.py
@@ -73,3 +73,8 @@ class ENV(object):
# recommended now since it may cause scaled_mm error for some models after TP.
CACHE_DIT_DISABLE_EXCLUDE_FOR_QUANTIZE_AFTER_TP: bool = bool(
int(os.environ.get("CACHE_DIT_DISABLE_EXCLUDE_FOR_QUANTIZE_AFTER_TP", "0")))

    # MindIE-SD backend ENVs. Kept as a raw "1"/"0" string; the selector
    # treats "0" as an explicit opt-out of the MindIE-SD attention backend.
    CACHE_DIT_ENABLE_MINDIESD_ATTN: str = os.environ.get("CACHE_DIT_ENABLE_MINDIESD_ATTN", "1")
CACHE_DIT_FORCE_DISABLE_MINDIESD_COMPILE_CONFIG: bool = bool(
int(os.environ.get("CACHE_DIT_FORCE_DISABLE_MINDIESD_COMPILE_CONFIG", "0")))
8 changes: 8 additions & 0 deletions src/cache_dit/kernels/backend.py
@@ -7,6 +7,7 @@ class KernelBackend(Enum):
TRITON = "Triton"
CUDA = "CUDA"
CUTEDSL = "CuteDSL"
MINDIESD = "MindIESD"
NONE = "None"

@classmethod
@@ -41,4 +42,11 @@ def is_supported(cls, backend: "KernelBackend") -> bool:
return True
except ImportError:
return False
if backend == cls.MINDIESD:
try:
                import mindiesd  # noqa: F401

return True
except Exception:
return False
return False
18 changes: 18 additions & 0 deletions src/cache_dit/kernels/ops.py
@@ -14,6 +14,10 @@ def _select_cuda_backend() -> KernelBackend:
return KernelBackend.CUDA


def _select_mindiesd_backend() -> KernelBackend:
return KernelBackend.MINDIESD


def _select_kernel_backend() -> KernelBackend:
"""Select the default backend for kernel ops.

@@ -149,6 +153,20 @@ def _fused_merge_attn_states_impl(
suff_out,
suff_lse,
)
    if backend == KernelBackend.MINDIESD:
        # Numerically stable log-sum-exp merge of two partial attention
        # results: rescale each output by exp(lse - max_lse) and renormalize.
        max_lse = torch.max(prev_lse, suff_lse)
        prev_scale = torch.exp(prev_lse - max_lse)
        suff_scale = torch.exp(suff_lse - max_lse)
        denom = prev_scale + suff_scale + 1e-12
        out = torch.where(
            # Guard degenerate inputs (e.g. both LSEs -inf); fall back to prev_out.
            denom > 1e-12,
            (prev_out * prev_scale.unsqueeze(-1) + suff_out * suff_scale.unsqueeze(-1)) /
            denom.unsqueeze(-1),
            prev_out,
        )
        lse = max_lse + torch.log(denom)
        return out, lse
else:
raise ValueError(_ERROR_TEMPLATE.format(backend))
