diff --git a/tensorrt_llm/bench/benchmark/utils/general.py b/tensorrt_llm/bench/benchmark/utils/general.py index 3a35008daba..0b812100ed5 100755 --- a/tensorrt_llm/bench/benchmark/utils/general.py +++ b/tensorrt_llm/bench/benchmark/utils/general.py @@ -111,7 +111,10 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str, else: model_config = get_model_config(model, model_path) - if isinstance(model_config, NemotronHybridConfig): + if isinstance( + model_config, + NemotronHybridConfig) and mamba_ssm_cache_dtype not in (None, + "auto"): model_config.set_mamba_ssm_cache_dtype(mamba_ssm_cache_dtype) from tensorrt_llm._torch.model_config import ModelConfig