138 changes: 73 additions & 65 deletions python/infinilm/llm/llm.py
@@ -14,9 +14,6 @@
from typing import List, Optional, Union, AsyncIterator
from dataclasses import dataclass

from transformers import AutoTokenizer
from tokenizers import decoders as _dec

import infinicore

from infinilm.llm.request import (
@@ -28,11 +25,12 @@
from infinilm.llm.sampling_params import SamplingParams
from infinilm.llm.scheduler import Scheduler
from infinilm.llm.static_scheduler import StaticScheduler

from infinilm.processors import AutoInfinilmProcessor
from infinilm.distributed import DistConfig
from infinilm.infer_engine import InferEngine
from infinilm.cache.cache import PagedKVCacheConfig, StaticKVCacheConfig
from infinilm.modeling_utils import load_model_state_dict_by_file
from infinilm.multimodal.multimodal import resolve_multimodal_inputs

logger = logging.getLogger(__name__)

@@ -99,11 +97,9 @@ def __init__(self, config: EngineConfig):
self.model_engine, config.model_path, dtype=self.model_engine.dtype
)

# Initialize tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(
config.model_path, trust_remote_code=True
)
self._fix_tokenizer_decoder()
# Initialize processor/tokenizer
self.processor = AutoInfinilmProcessor.from_pretrained(config.model_path)
self.tokenizer = self.processor.get_tokenizer()

# Initialize KV cache based on cache type
if config.cache_type == "static":
@@ -166,26 +162,6 @@ def _init_device(self):

self.dtype = dtype_map[self.config.dtype]

def _fix_tokenizer_decoder(self):
"""Fix tokenizer decoder for llama models."""
if "llama" in self.model_engine.model_type.lower():
backend = getattr(self.tokenizer, "backend_tokenizer", None)
target = getattr(backend, "_tokenizer", backend)
norm = getattr(target, "normalizer", None)
dec = getattr(target, "decoder", None)
sn = repr(norm)[:800] if norm is not None else ""
sd = repr(dec)[:800] if dec is not None else ""
has_prepend = "Prepend" in sn
has_strip = "Strip" in sd
if has_prepend and has_strip:
target.decoder = _dec.Sequence(
[
_dec.Replace("▁", " "),
_dec.ByteFallback(),
_dec.Fuse(),
]
)

def add_request(self, request: InferenceRequest):
"""Add a request to the scheduler."""
self.scheduler.add_request(request)
@@ -204,10 +180,12 @@ def step(self) -> tuple[list[InferenceRequest], list[tuple]]:
return [], []

# Build model inputs
model_input_dict = scheduler_output.build_model_inputs(
self.config.temperature, self.config.top_p, self.config.top_k
model_input = self.processor.build_model_inputs(
scheduler_output,
self.config.temperature,
self.config.top_p,
self.config.top_k,
)
model_input = self._prepare_model_input(model_input_dict)

# Run inference
sampled_tokens = self.model_engine.forward(**model_input)
@@ -222,28 +200,6 @@ def step(self) -> tuple[list[InferenceRequest], list[tuple]]:

return scheduler_output.scheduled_requests, pending

def _prepare_model_input(self, model_input_dict: dict) -> dict:
"""Convert model input dict to infinicore tensors."""
model_input = {}
for key, value in model_input_dict.items():
if value is None:
# Skip None values (block_tables/slot_mapping for static cache)
model_input[key] = None
elif key in ["input_ids", "position_ids", "slot_mapping"]:
model_input[key] = infinicore.from_list(value, dtype=infinicore.int64)
elif key in [
"past_kv_lengths",
"total_kv_lengths",
"input_offsets",
"cu_seqlens",
"block_tables",
]:
model_input[key] = infinicore.from_list(value, dtype=infinicore.int32)
else:
# temperature, top_k, top_p, etc.
model_input[key] = value
return model_input

def _update_requests(
self,
is_prefill: bool,
@@ -361,6 +317,12 @@ def detokenize(self, token_ids: List[int]) -> str:
"""Detokenize token IDs to text."""
return self.tokenizer.decode(token_ids)

def process(self, prompt, images, videos, audios, **kwargs) -> dict:
"""Process the input prompt and media into final model inputs."""
return self.processor(
prompt, images=images, videos=videos, audios=audios, **kwargs
)

def apply_chat_template(
self,
messages: List[dict],
@@ -369,7 +331,7 @@
) -> str:
"""Apply chat template to messages."""
chat_template_kwargs = chat_template_kwargs or {}
return self.tokenizer.apply_chat_template(
return self.processor.apply_chat_template(
conversation=messages,
add_generation_prompt=add_generation_prompt,
tokenize=False,
@@ -654,6 +616,9 @@ def _batch_put(pending):

def add_request(
self,
messages: Optional[List[dict]],
apply_chat_template: bool = True,
add_generation_prompt: bool = True,
prompt: Optional[str] = None,
prompt_token_ids: Optional[List[int]] = None,
sampling_params: Optional[SamplingParams] = None,
@@ -665,8 +630,28 @@
"""Add a request to the engine.

Args:
prompt: Text prompt for generation.
prompt_token_ids: Pre-tokenized prompt.
messages: List of message dicts (chat conversation), in the following format:
[
{
"role": "user",
"content": [
{
"type": "text",
"text": "xxxxxxxxx"
},
{
"type": "image_url",
"image_url": {
"url": "xxx.jpg"
}
},
]
},
]
apply_chat_template: Whether to apply the chat template.
add_generation_prompt: Whether to add a generation prompt.
prompt: Text prompt for generation. If provided, it is encoded by the tokenizer and used directly, ignoring messages.
prompt_token_ids: Pre-tokenized prompt. If provided, it is used directly as the model input.
sampling_params: Sampling parameters.
request_id: Optional request ID.
request_data: Optional request data dict (for server use).
@@ -678,8 +663,32 @@
if request_id is None:
request_id = f"cmpl-{uuid.uuid4().hex}"

if prompt_token_ids is None and prompt is not None:
images, videos, audios = None, None, None
processed_inputs = None

if prompt_token_ids is not None:
prompt = self.engine.detokenize(prompt_token_ids)
elif prompt is not None:
prompt_token_ids = self.engine.tokenize(prompt)
else:
assert messages is not None, (
"Either messages or prompt/prompt_token_ids must be provided"
)

assert apply_chat_template, (
"apply_chat_template needs to be true for multi-role conversation"
)

prompt = self.engine.apply_chat_template(
messages, add_generation_prompt=add_generation_prompt
)

images, videos, audios = resolve_multimodal_inputs(messages)
processed_inputs = self.engine.process(
prompt, images, videos, audios, return_tensors="pt"
)

prompt_token_ids = processed_inputs.get("input_ids").flatten().tolist()

if sampling_params is None:
sampling_params = SamplingParams(max_tokens=self.config.max_tokens)
@@ -691,6 +700,7 @@
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
processed_inputs=processed_inputs,
sampling_params=sampling_params,
eos_token_ids=self.engine.eos_token_ids,
request_data=request_data,
@@ -711,7 +721,7 @@ def add_chat_request(
request_data: Optional[dict] = None,
http_request: Optional[any] = None,
add_generation_prompt: bool = True,
chat_template_kwargs: Optional[dict] = None,
**kwargs,
) -> InferenceRequest:
"""Add a chat request to the engine.

@@ -725,13 +735,11 @@
Returns:
The created InferenceRequest object.
"""
prompt = self.engine.apply_chat_template(
messages,
add_generation_prompt=add_generation_prompt,
chat_template_kwargs=chat_template_kwargs,
)

return self.add_request(
prompt=prompt,
messages=messages,
apply_chat_template=True,
add_generation_prompt=add_generation_prompt,
sampling_params=sampling_params,
request_id=request_id,
request_data=request_data,
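
For reference, a minimal usage sketch of the new messages-based entry points; this is not part of the diff. `llm` stands for an already-constructed instance of the engine wrapper whose `add_request`/`add_chat_request` methods are shown above, and the image path is a placeholder:

```python
# Hedged sketch: exercising the messages-based API added in this diff.
# `llm` is assumed to be an already-constructed engine wrapper instance.
from infinilm.llm.sampling_params import SamplingParams

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this picture."},
            {"type": "image_url", "image_url": {"url": "example.jpg"}},  # placeholder image
        ],
    },
]

# The chat path applies the chat template, resolves images via
# resolve_multimodal_inputs, runs the processor, and stores the resulting
# processed_inputs on the InferenceRequest.
request = llm.add_chat_request(
    messages=messages,
    sampling_params=SamplingParams(max_tokens=128),
)

# Plain-text and pre-tokenized prompts bypass the messages path entirely.
text_request = llm.add_request(messages=None, prompt="Hello, world!")
```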
7 changes: 5 additions & 2 deletions python/infinilm/llm/request.py
@@ -105,20 +105,23 @@ def __init__(
request_id: str,
prompt: Optional[str] = None,
prompt_token_ids: Optional[List[int]] = None,
processed_inputs: Optional[dict] = None,
sampling_params: Optional[SamplingParams] = None,
eos_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
# For server use
request_data: Optional[dict] = None,
http_request: Optional[Any] = None,
):
self.arrival_time: float = arrival_time or time.time()
self.finished_time: Optional[float] = None

# Request metadata
self.request_id: str = request_id
self.prompt: Optional[str] = prompt
self.prompt_token_ids: List[int] = prompt_token_ids or []
self.prompt_length: int = len(self.prompt_token_ids)
self.arrival_time: float = arrival_time or time.time()
self.finished_time: Optional[float] = None
self.processed_inputs: Optional[dict] = processed_inputs

# Sampling parameters
self.sampling_params: SamplingParams = sampling_params or SamplingParams()
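
For illustration, a hedged sketch of constructing a request directly with the new `processed_inputs` field, mirroring what `add_request` in llm.py does; the processor output here is a placeholder mapping:

```python
# Hedged sketch; `processed` stands in for the processor(...) output, which is
# assumed to be an HF-style mapping containing at least "input_ids".
from infinilm.llm.request import InferenceRequest
from infinilm.llm.sampling_params import SamplingParams

processed = {"input_ids": [[1, 2, 3]]}  # placeholder
request = InferenceRequest(
    request_id="cmpl-demo",
    prompt="Hello",
    prompt_token_ids=[1, 2, 3],
    processed_inputs=processed,
    sampling_params=SamplingParams(max_tokens=16),
)
```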
27 changes: 27 additions & 0 deletions python/infinilm/multimodal/multimodal.py
@@ -0,0 +1,27 @@
from typing import List, Union
from PIL import Image


def resolve_multimodal_inputs(messages: Union[List[dict], dict]):
"""Get images, videos, audios from the messages."""
if isinstance(messages, dict):
messages = [messages]

images = []
videos = []
audios = []

for msg in messages:
content = msg.get("content", [])
if not isinstance(content, list):
continue

for item in content:
item_type = item.get("type")
if item_type == "text":
# plain text is handled by the chat template / tokenizer
continue
elif item_type in ("image", "image_url"):
# TODO support other image url formats
image_url = item.get("image_url")
if isinstance(image_url, dict):
image_url = image_url.get("url")
images.append(Image.open(image_url))

else: # TODO support video/audio
raise NotImplementedError("Only image input is supported for now")

return images, videos, audios
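
A small usage sketch, not part of the diff; the image path is a placeholder:

```python
# Hedged usage sketch for resolve_multimodal_inputs; "cat.jpg" is a placeholder path.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {"type": "image", "image_url": "cat.jpg"},
        ],
    },
]

images, videos, audios = resolve_multimodal_inputs(messages)
# images -> [PIL.Image.Image], videos -> [], audios -> [] (video/audio are not supported yet)
```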
18 changes: 18 additions & 0 deletions python/infinilm/processors/__init__.py
@@ -0,0 +1,18 @@
from .processor import InfinilmProcessor
from .basic_llm_processor import BasicLLMProcessor
from .llama_processor import LlamaProcessor

from transformers import AutoConfig


class AutoInfinilmProcessor:
@classmethod
def from_pretrained(cls, model_dir_path: str, **kwargs) -> InfinilmProcessor:
"""Factory method to get the appropriate processor based on model config."""
config = AutoConfig.from_pretrained(model_dir_path, trust_remote_code=True)
model_type = config.model_type.lower()

if model_type in ["llama"]:
return LlamaProcessor(model_dir_path)
else:
return BasicLLMProcessor(model_dir_path)
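
For context, a hedged usage sketch of the factory and the processor interface it returns; every call below mirrors one made in llm.py above, and the model path is a placeholder:

```python
# Hedged usage sketch; "/path/to/model" is a placeholder directory.
from infinilm.processors import AutoInfinilmProcessor

processor = AutoInfinilmProcessor.from_pretrained("/path/to/model")

# The engine uses the processor both as a tokenizer provider and as the
# prompt/media preprocessor (see llm.py above).
tokenizer = processor.get_tokenizer()
prompt = processor.apply_chat_template(
    conversation=[{"role": "user", "content": "Hello"}],
    add_generation_prompt=True,
    tokenize=False,
)
inputs = processor(prompt, images=None, videos=None, audios=None, return_tensors="pt")
```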