16 changes: 16 additions & 0 deletions csrc/engine/infer_engine.cpp
@@ -197,4 +197,20 @@ void InferEngine::reset_cache(const cache::CacheConfig *new_config) {
this->compile();
}

std::vector<std::vector<infinicore::Tensor>> InferEngine::get_kv_cache() {
std::vector<std::vector<infinicore::Tensor>> kv_cache_list;
if (workers_.empty()) {
throw std::runtime_error("InferEngine::get_kv_cache: no workers");
}
kv_cache_list.reserve(workers_.size());
for (auto &worker : workers_) {
kv_cache_list.push_back(worker->get_kv_cache());
}

for (auto &worker : workers_) {
worker->wait();
}
return kv_cache_list;
}

} // namespace infinilm::engine
2 changes: 2 additions & 0 deletions csrc/engine/infer_engine.hpp
@@ -63,6 +63,8 @@ class InferEngine {

void reset_cache(const cache::CacheConfig *new_config);

std::vector<std::vector<infinicore::Tensor>> get_kv_cache();

~InferEngine();

const distributed::DistConfig &get_dist_config() const;
16 changes: 16 additions & 0 deletions csrc/engine/rank_worker.cpp
@@ -201,6 +201,22 @@ void RankWorker::reset_cache(const cache::CacheConfig *new_config) {
cv_.notify_all();
}

//------------------------------------------------------
// get_kv_cache
//------------------------------------------------------
std::vector<infinicore::Tensor> RankWorker::get_kv_cache() {
std::unique_lock<std::mutex> lk(mutex_);
cv_.wait(lk, [&] { return init_done_ || should_exit_; });

if (should_exit_) {
throw std::runtime_error("RankWorker stopped; cannot get_kv_cache");
}

ASSERT(!forward_context_.kv_cache_vec.empty() && "RankWorker::get_kv_cache(): kv_cache_vec is empty");

return forward_context_.kv_cache_vec;
}

//------------------------------------------------------
// close -- request shutdown and join thread
//------------------------------------------------------
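For readers less familiar with the synchronization here: RankWorker::get_kv_cache does a guarded wait on the worker's condition variable, returning only after initialization has completed and throwing if the worker is shutting down. A minimal Python analogue of that pattern, with hypothetical class and attribute names, purely for illustration:

import threading

class WorkerSketch:
    """Illustrative stand-in for RankWorker's guarded-wait accessor."""

    def __init__(self):
        self._cv = threading.Condition()
        self._init_done = False
        self._should_exit = False
        self._kv_cache = []

    def get_kv_cache(self):
        with self._cv:
            # Block until init finishes or shutdown is requested, mirroring
            # cv_.wait(lk, [&] { return init_done_ || should_exit_; });
            self._cv.wait_for(lambda: self._init_done or self._should_exit)
            if self._should_exit:
                raise RuntimeError("worker stopped; cannot get_kv_cache")
            assert self._kv_cache, "kv_cache is empty"
            return list(self._kv_cache)  # copy out, like the C++ return-by-value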
2 changes: 2 additions & 0 deletions csrc/engine/rank_worker.hpp
@@ -90,6 +90,8 @@ class RankWorker {
// Reset the internal cache with a new configuration
void reset_cache(const cache::CacheConfig *new_config);

// Return the per-rank KV cache tensors (blocks until worker init completes)
std::vector<infinicore::Tensor> get_kv_cache();

// Compile the model graph if enabled.
void compile();

13 changes: 5 additions & 8 deletions csrc/pybind11/engine/engine.hpp
@@ -66,10 +66,8 @@ inline void bind_infer_engine(py::module &m) {
}
return state_dict_tp_all;
})
.def(
"forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
.def(
"reset_cache", [](InferEngine &self, std::shared_ptr<cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
.def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
.def("reset_cache", [](InferEngine &self, std::shared_ptr<cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
.def("get_cache_config", [](const InferEngine &self) -> std::shared_ptr<cache::CacheConfig> {
auto cfg = self.get_cache_config();
return cfg ? std::shared_ptr<cache::CacheConfig>(cfg->unique_copy()) : nullptr; })
@@ -114,10 +112,9 @@ inline void bind_infer_engine(py::module &m) {
}
return state_dict_tp_all;
})
.def(
"forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
.def(
"reset_cache", [](InferEngine &self, std::shared_ptr<cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
.def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
.def("reset_cache", [](InferEngine &self, std::shared_ptr<cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
.def("get_kv_cache", &InferEngine::get_kv_cache, "Get per-rank kv cache list")
.def("get_cache_config", [](const InferEngine &self) {
auto cfg = self.get_cache_config();
return std::shared_ptr<cache::CacheConfig>(std::move(cfg->unique_copy())); })
6 changes: 6 additions & 0 deletions python/infinilm/infer_engine.py
@@ -286,6 +286,12 @@ def reset_cache(self, cache_config):
def state_dict_keyname(self):
return super().state_dict()[0].keys()

def get_kv_cache(self) -> list[list[infinicore.Tensor]]:
"""
Get the per-rank KV cache.
"""
return super().get_kv_cache()

def load_state_dict(self, state_dict, strict=None):
for name, param in state_dict.items():
super().load_param(name, param._underlying)
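A usage sketch for the new accessor. The engine below is assumed to be an already-constructed infinilm.InferEngine (construction arguments elided); per the binding, the outer list has one entry per rank worker and the inner list holds that rank's KV cache tensors:

# engine: an already-initialized infinilm.InferEngine (construction elided)
kv_cache = engine.get_kv_cache()

num_ranks = len(kv_cache)        # one inner list per rank worker
rank0_tensors = kv_cache[0]      # that rank's infinicore.Tensor objects
print(num_ranks, len(rank0_tensors))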
56 changes: 56 additions & 0 deletions python/infinilm/llm/engine_config.py
@@ -0,0 +1,56 @@
"""
Engine configuration — shared by LLMEngine, Worker, ModelRunner.
"""

from dataclasses import dataclass
from typing import Optional


@dataclass
class EngineConfig:
"""Configuration for LLM Engine.

Attributes:
model_path: Path to the model directory.
device: Device type string ('cpu', 'cuda', 'mlu', etc.).
dtype: Data type string ('float16', 'bfloat16', 'float32').
tensor_parallel_size: Number of devices for tensor parallelism.
cache_type: Cache type ('paged' or 'static').
max_batch_size: Maximum batch size for inference (only for paged cache).
max_tokens: Default maximum tokens to generate.
num_blocks: Number of KV cache blocks (only for paged cache).
block_size: Size of each KV cache block (only for paged cache).
max_cache_len: Maximum sequence length (only for static cache).
temperature: Default sampling temperature.
top_p: Default top-p sampling parameter.
top_k: Default top-k sampling parameter.
enable_graph: Whether to enable graph compiling.
attn_backend: Attention backend to use ('default', 'flash-attn').
kv_connector_type: KV connector type for PD separation ('null', etc.).
kv_connector_role: KV connector role ('none', 'sender', 'receiver', 'both').
kv_connector_kwargs: Extra keyword arguments for the KV connector.
"""

model_path: str
device: str = "cuda"
dtype: str = "float16"
tensor_parallel_size: int = 1
cache_type: str = "paged" # "paged" or "static"
max_batch_size: int = 16
max_tokens: int = 4096
num_blocks: int = 512
block_size: int = 256
max_cache_len: int = 4096
temperature: float = 1.0
top_p: float = 0.8
top_k: int = 1
enable_graph: bool = False
attn_backend: str = "default"
# ---- PD separation ----
kv_connector_type: str = "null"
kv_connector_role: str = "none"
kv_connector_kwargs: Optional[dict] = None

def __post_init__(self):
if self.kv_connector_kwargs is None:
self.kv_connector_kwargs = {}
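A construction sketch with a placeholder model path. __post_init__ normalizes a None kv_connector_kwargs to an empty dict, which sidesteps the mutable-default restriction on dataclass fields:

from infinilm.llm.engine_config import EngineConfig

# "/path/to/model" is a placeholder; every other field keeps its default.
cfg = EngineConfig(model_path="/path/to/model", tensor_parallel_size=2)

assert cfg.cache_type == "paged"
assert cfg.kv_connector_kwargs == {}  # normalized from None in __post_init__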
51 changes: 51 additions & 0 deletions python/infinilm/llm/kv_connector/__init__.py
@@ -0,0 +1,51 @@
"""
KV Connector package for Prefill-Decode disaggregated inference.
"""

from infinilm.llm.kv_connector.base import (
KVConnectorBase,
KVConnectorMetadata,
KVConnectorRole,
NullKVConnector,
)

__all__ = [
"KVConnectorBase",
"KVConnectorMetadata",
"KVConnectorRole",
"NullKVConnector",
"create_kv_connector",
]


def create_kv_connector(
connector_type: str = "null",
role: str = KVConnectorRole.NONE,
**kwargs,
) -> KVConnectorBase:
"""Factory function to create KV connectors.

Args:
connector_type: Type of connector.
- "null": No-op connector (standalone mode, default).
- Future: "mooncake", "rdma", "tcp", etc.
role: Role of the connector (none/sender/receiver/both).
**kwargs: Additional connector-specific arguments.

Returns:
A KVConnectorBase instance.

Raises:
ValueError: If connector_type is not recognized.
"""
if connector_type is None or connector_type == "null":
return NullKVConnector(**kwargs)
# ---- Future connector types can be registered here ----
# elif connector_type == "mooncake":
# from infinilm.llm.kv_connector.mooncake_connector import MooncakeKVConnector
# return MooncakeKVConnector(role=role, **kwargs)
else:
raise ValueError(
f"Unknown KV connector type: '{connector_type}'. "
f"Supported types: ['null']"
)
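A usage sketch for the factory. Only the "null" connector exists so far (this assumes NullKVConnector can be constructed with no extra kwargs); any other type fails fast with ValueError:

from infinilm.llm.kv_connector import KVConnectorRole, create_kv_connector

connector = create_kv_connector("null", role=KVConnectorRole.NONE)  # no-op connector

try:
    create_kv_connector("mooncake")  # not registered yet
except ValueError as err:
    print(err)  # Unknown KV connector type: 'mooncake'. Supported types: ['null']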