diff --git a/examples/hstu/configs/inference_config.py b/examples/hstu/configs/inference_config.py index 6f44362c6..a89850b19 100755 --- a/examples/hstu/configs/inference_config.py +++ b/examples/hstu/configs/inference_config.py @@ -262,16 +262,11 @@ def get_kvcache_metadata_buffer( hstu_config: InferenceHSTUConfig, kvcache_config: KVCacheConfig ): device = torch.cuda.current_device() - torch.bfloat16 if hstu_config.bf16 else torch.float16 if hstu_config.fp16 else torch.float32 max_new_history_seqlen = hstu_config.max_batch_size * hstu_config.max_seq_len max_num_pages_per_seq = ( hstu_config.max_seq_len + hstu_config.max_seq_len + kvcache_config.page_size - 1 ) // kvcache_config.page_size - max_host_kv_buffer_size = ( - hstu_config.max_batch_size * hstu_config.max_seq_len, - hstu_config.num_heads * hstu_config.head_dim, - ) default_num_pages_per_seq = 4 paged_indices_buffer = torch.randint(