diff --git a/umbrella/models/auto_model.py b/umbrella/models/auto_model.py
index f7c4105..064e4f3 100644
--- a/umbrella/models/auto_model.py
+++ b/umbrella/models/auto_model.py
@@ -2,6 +2,7 @@
 from .qwen import Qwen, QwenOffload, QwenAwq, QwenAwqOffload, QwenCudagraph
 from .gemma import Gemma2
 from .mistral import Mistral, MistralAwqOffload, MistralOffload, MistralCudagraph, MistralAwq
+from .glm4 import Glm4
 class AutoModelLM:
     """
     Automatic model loader: dynamically loads the class matching the requested model type.
@@ -117,7 +118,8 @@ class AutoModelLM:
         "mistralai/Mistral-Small-24B-Instruct-2501": Mistral,
         "stelterlab/Mistral-Small-24B-Instruct-2501-AWQ": MistralAwq,
         "PyrTools/Ministral-8B-Instruct-2410-AWQ": MistralAwq,
-        "mistralai/Ministral-8B-Instruct-2410": Mistral
+        "mistralai/Ministral-8B-Instruct-2410": Mistral,
+        "THUDM/glm-4-9b-chat-hf": Glm4,
     }
 
     _CUDAGRAPH_MODEL_MAPPING = {
diff --git a/umbrella/models/glm4.py b/umbrella/models/glm4.py
new file mode 100644
index 0000000..2f05f75
--- /dev/null
+++ b/umbrella/models/glm4.py
@@ -0,0 +1,186 @@
+from transformers import GlmForCausalLM, GlmConfig
+import torch
+import torch.nn.functional as F
+import gc
+import flashinfer
+from ..attn.cache import KV_Cache, StaticKV_Cache
+from .glm4_layer import Glm4Layer
+from .base import LLMBase
+from .model_utils import layer_norm, capture_graph
+
+from tqdm import tqdm
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., 0::2]
+    x2 = x[..., 1::2]
+    return torch.stack((-x2, x1), dim=-1).flatten(-2)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=2):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Indices of the token positions, used to gather the matching rows of `cos` and `sin`.
+        unsqueeze_dim (`int`, *optional*, defaults to 2):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
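+
+    Example (illustrative shapes only, for the unsqueeze_dim=2 case described above):
+        With q and k of shape [batch_size=1, seq_len=8, heads=32, head_dim=128] and position_ids of shape
+        [1, 8], cos[position_ids] has shape [1, 8, 128]; unsqueeze(2) makes it [1, 8, 1, 128], which then
+        broadcasts against q and k.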
+ """ + + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + # Interleave them instead of usual shape + cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1) + sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1) + + # Keep half or full tensor for later concatenation + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + # Apply rotary embeddings on the first half or full tensor + q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) + k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) + + # Concatenate back to full shape + q_embed = torch.cat([q_embed, q_pass], dim=-1) + k_embed = torch.cat([k_embed, k_pass], dim=-1) + return q_embed, k_embed + + + + +class Glm4(LLMBase): + def __init__(self, + model_name: str, + batch_size :int = 1, + max_length :int = 256, + device :str = 'cuda:0', + dtype = torch.float16) -> None: + super().__init__() + self.batch_size = batch_size + self.device = device + self.dtype = dtype + self.config = GlmConfig.from_pretrained(model_name) + self.model_name = model_name + self.max_length = max_length + self.hidden_size = self.config.hidden_size + self.num_heads = self.config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = self.config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = self.config.max_position_embeddings + self.rope_theta = self.config.rope_theta + self.eos_tokens = self.config.eos_token_id if (isinstance(self.config.eos_token_id, list)) else [self.config.eos_token_id] + + + def alloc(self, **kwargs): + self.kv_cache = KV_Cache(self.config, max_length=self.max_length, device=self.device, dtype=self.dtype, batch_size=self.batch_size) + hf_model = GlmForCausalLM.from_pretrained(self.model_name, torch_dtype=self.dtype) + self.embed_tokens = hf_model.model.embed_tokens.weight.detach().to(self.device) + if self.config.tie_word_embeddings: + self.lm_head = self.embed_tokens + else: + self.lm_head = hf_model.lm_head.weight.detach().to(self.device) + self.norm_weight = hf_model.model.norm.weight.detach().to(self.device) + self.norm_variance_epsilon = hf_model.model.norm.variance_epsilon + self.inv_freq = hf_model.model.rotary_emb.inv_freq.detach().to(self.device) + self.attention_scaling = hf_model.model.rotary_emb.attention_scaling + position_ids = torch.arange(0, self.max_length).unsqueeze(0).to(self.device) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cache = emb.cos()[0] + self.sin_cache = emb.sin()[0] + self.cos_cache = self.cos_cache * self.attention_scaling + self.sin_cache = self.sin_cache * self.attention_scaling + self.cos_cache = self.cos_cache.to(self.dtype) + self.sin_cache = self.sin_cache.to(self.dtype) + + self.layers :list[Glm4Layer] = [] + for idx, hf_layer in enumerate(hf_model.model.layers): + layer = Glm4Layer(idx) + layer.init_parameters(hf_layer=hf_layer) + layer.to(self.device) + self.layers.append(layer) + hf_model.model.layers[idx] = None + gc.collect() + + self.num_layers = len(self.layers) + + + + + @torch.inference_mode() + def layer_compute(self, + buffer: Glm4Layer, + layer_idx 
+
+    @torch.inference_mode()
+    def layer_compute(self,
+                      buffer: Glm4Layer,
+                      layer_idx :int,
+                      hidden_states: torch.FloatTensor,
+                      position_ids: torch.LongTensor,
+                      attention_mask: torch.FloatTensor,
+                      storage_ids: torch.LongTensor):
+        residual = hidden_states
+        bsz, q_len, _ = hidden_states.size()
+        hidden_states = layer_norm(hidden_states, buffer.input_layernorm_variance_epsilon, buffer.input_layernorm_weight)
+
+        # attention
+        query_states = F.linear(hidden_states, buffer.wq, bias=buffer.qbias)
+        key_states = F.linear(hidden_states, buffer.wk, bias=buffer.kbias)
+        value_states = F.linear(hidden_states, buffer.wv, bias=buffer.vbias)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, self.cos_cache, self.sin_cache, position_ids)
+        hidden_states = self.kv_cache.compute_attention(
+            query_states, key_states, value_states, layer_idx, storage_ids, attention_mask
+        )
+        hidden_states = hidden_states.reshape(bsz, q_len, self.hidden_size)
+        hidden_states = F.linear(hidden_states, buffer.wo)
+        hidden_states = residual + hidden_states
+        residual = hidden_states
+        hidden_states = layer_norm(hidden_states, buffer.post_attention_layernorm_variance_epsilon, buffer.post_attention_layernorm_weight)
+
+        # MLP
+        up_states = F.linear(hidden_states, buffer.gate_up_proj)
+        gate, up_states = up_states.chunk(2, dim=-1)
+        up_states = up_states * F.silu(gate)
+        hidden_states = F.linear(up_states, buffer.down_proj)
+        hidden_states = residual + hidden_states
+        return hidden_states
+
+    @torch.inference_mode()
+    def inference(self,
+                  input_ids: torch.LongTensor,
+                  position_ids: torch.LongTensor,
+                  attention_mask: torch.FloatTensor,
+                  storage_ids: torch.LongTensor):
+        hidden_states = F.embedding(input_ids, self.embed_tokens)
+        for idx in range(self.num_layers):
+            hidden_states = self.layer_compute(self.layers[idx], idx, hidden_states, position_ids, attention_mask, storage_ids)
+        b, s, h = hidden_states.shape
+
+        hidden_states = hidden_states.reshape(b * s, h)
+        hidden_states = flashinfer.rmsnorm(hidden_states, self.norm_weight, self.norm_variance_epsilon)
+        hidden_states = hidden_states.reshape(b, s, h)
+        logits = F.linear(hidden_states, self.lm_head).float()
+        return logits
+
\ No newline at end of file
diff --git a/umbrella/models/glm4_layer.py b/umbrella/models/glm4_layer.py
new file mode 100644
index 0000000..9ae3636
--- /dev/null
+++ b/umbrella/models/glm4_layer.py
@@ -0,0 +1,95 @@
+from __future__ import annotations
+import torch
+from transformers.models.glm.modeling_glm import GlmDecoderLayer
+
+class Glm4Layer:
+    def __init__(self, layer_idx, device = "cpu") -> None:
+
+        self.wq :torch.Tensor = None
+        self.wk :torch.Tensor = None
+        self.wv :torch.Tensor = None
+        self.wo :torch.Tensor = None
+        self.qbias :torch.Tensor = None
+        self.kbias :torch.Tensor = None
+        self.vbias :torch.Tensor = None
+        self.obias :torch.Tensor = None
+        self.gate_up_proj: torch.Tensor = None
+        self.down_proj: torch.Tensor = None
+
+        self.input_layernorm_weight :torch.Tensor = None
+        self.input_layernorm_variance_epsilon :float = 0.0
+
+        self.post_attention_layernorm_weight :torch.Tensor = None
+        self.post_attention_layernorm_variance_epsilon :float = 0.0
+
+        self.layer_idx = layer_idx
+        self.device = device
+
+    def init_parameters(self, hf_layer: GlmDecoderLayer):
+        self.wq :torch.Tensor= hf_layer.self_attn.q_proj.weight.detach()
+        self.wk :torch.Tensor= hf_layer.self_attn.k_proj.weight.detach()
+        self.wv :torch.Tensor= hf_layer.self_attn.v_proj.weight.detach()
+        self.wo :torch.Tensor= hf_layer.self_attn.o_proj.weight.detach()
+        self.qbias :torch.Tensor = hf_layer.self_attn.q_proj.bias.detach()
+        self.kbias :torch.Tensor = hf_layer.self_attn.k_proj.bias.detach()
+        self.vbias :torch.Tensor = hf_layer.self_attn.v_proj.bias.detach()
+        self.gate_up_proj: torch.Tensor = hf_layer.mlp.gate_up_proj.weight.detach()
+        self.down_proj: torch.Tensor = hf_layer.mlp.down_proj.weight.detach()
+
+        self.input_layernorm_weight = hf_layer.input_layernorm.weight.detach()
+        self.input_layernorm_variance_epsilon = hf_layer.input_layernorm.variance_epsilon
+
+        self.post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight.detach()
+        self.post_attention_layernorm_variance_epsilon = hf_layer.post_attention_layernorm.variance_epsilon
+
+    def to(self, device:str = 'cuda:0', non_blocking = True):
+
+        self.device = device
+        self.input_layernorm_weight = self.input_layernorm_weight.to(device, non_blocking=non_blocking)
+        self.post_attention_layernorm_weight = self.post_attention_layernorm_weight.to(device, non_blocking=non_blocking)
+        self.wq = self.wq.to(device, non_blocking=non_blocking)
+        self.wk = self.wk.to(device, non_blocking=non_blocking)
+        self.wv = self.wv.to(device, non_blocking=non_blocking)
+        self.wo = self.wo.to(device, non_blocking=non_blocking)
+        self.qbias = self.qbias.to(device, non_blocking=non_blocking)
+        self.kbias = self.kbias.to(device, non_blocking=non_blocking)
+        self.vbias = self.vbias.to(device, non_blocking=non_blocking)
+        self.gate_up_proj = self.gate_up_proj.to(device, non_blocking=non_blocking)
+        self.down_proj = self.down_proj.to(device, non_blocking=non_blocking)
+
+    def copy(self, layer: Glm4Layer):
+
+        self.wq.copy_(layer.wq, non_blocking=True)
+        self.wk.copy_(layer.wk, non_blocking=True)
+        self.wv.copy_(layer.wv, non_blocking=True)
+        self.wo.copy_(layer.wo, non_blocking=True)
+        self.qbias.copy_(layer.qbias, non_blocking=True)
+        self.kbias.copy_(layer.kbias, non_blocking=True)
+        self.vbias.copy_(layer.vbias, non_blocking=True)
+        self.gate_up_proj.copy_(layer.gate_up_proj, non_blocking=True)
+        self.down_proj.copy_(layer.down_proj, non_blocking=True)
+
+        self.input_layernorm_weight.copy_(layer.input_layernorm_weight, non_blocking=True)
+        self.post_attention_layernorm_weight.copy_(layer.post_attention_layernorm_weight, non_blocking=True)
+        self.input_layernorm_variance_epsilon = layer.input_layernorm_variance_epsilon
+        self.post_attention_layernorm_variance_epsilon = layer.post_attention_layernorm_variance_epsilon
+        self.layer_idx = layer.layer_idx
+
+    def alloc_space(self, layer: Glm4Layer, device):
+
+        self.device = device
+        self.wq = torch.zeros_like(layer.wq).to(device)
+        self.wk = torch.zeros_like(layer.wk).to(device)
+        self.wv = torch.zeros_like(layer.wv).to(device)
+        self.wo = torch.zeros_like(layer.wo).to(device)
+        self.qbias = torch.zeros_like(layer.qbias).to(device)
+        self.kbias = torch.zeros_like(layer.kbias).to(device)
+        self.vbias = torch.zeros_like(layer.vbias).to(device)
+        self.gate_up_proj = torch.zeros_like(layer.gate_up_proj).to(device)
+        self.down_proj = torch.zeros_like(layer.down_proj).to(device)
+
+        self.input_layernorm_weight = torch.zeros_like(layer.input_layernorm_weight).to(device)
+        self.post_attention_layernorm_weight = torch.zeros_like(layer.post_attention_layernorm_weight).to(device)
diff --git a/umbrella/templates.py b/umbrella/templates.py
index ac30a80..5c0beb1 100644
--- a/umbrella/templates.py
+++ b/umbrella/templates.py
@@ -22,8 +22,10 @@
 """,
 'gemma2': "{}",
-'mistral': "[INST] {} [/INST]"
-
+'mistral': "[INST] {} [/INST]",
+'glm4': """<|user|>
+{}<|assistant|>
+"""
 }
 
 
 SysPrompts = {
@@ -39,6 +41,9 @@
     'gemma2': "",
     'gemma2-it': "",
     'mistral': "",
+    'glm4': """[gMASK]<|system|>
+You are a helpful assistant
+"""
 }
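
For quick manual verification of this change, the sketch below drives the new Glm4 backend directly with the freshly added glm4 chat template. It is illustrative only: the greedy single-step decode, the storage_ids layout, and the additive causal attention_mask shape are assumptions about what KV_Cache.compute_attention expects, not something this diff specifies.

    # Illustrative driver for the new GLM-4 backend (assumptions noted in the comments).
    import torch
    from transformers import AutoTokenizer

    from umbrella.models.glm4 import Glm4
    from umbrella.templates import Prompts, SysPrompts

    MODEL = "THUDM/glm-4-9b-chat-hf"
    tokenizer = AutoTokenizer.from_pretrained(MODEL)

    model = Glm4(MODEL, batch_size=1, max_length=512, device="cuda:0", dtype=torch.float16)
    model.alloc()

    # Build the chat prompt from the newly added glm4 templates.
    text = SysPrompts["glm4"] + Prompts["glm4"].format("Briefly explain rotary position embeddings.")
    input_ids = tokenizer(text, return_tensors="pt").input_ids.to(model.device)
    q_len = input_ids.shape[1]

    # Assumed conventions: positions 0..q_len-1, storage slots equal to positions, and an
    # additive float mask with -inf above the diagonal (causal); adjust to match KV_Cache.
    position_ids = torch.arange(q_len, device=model.device).unsqueeze(0)
    storage_ids = torch.arange(q_len, device=model.device)
    attention_mask = torch.triu(
        torch.full((q_len, q_len), torch.finfo(torch.float16).min,
                   device=model.device, dtype=torch.float16),
        diagonal=1,
    )

    # Single prefill step followed by a greedy pick of the next token.
    logits = model.inference(input_ids, position_ids, attention_mask, storage_ids)
    next_token = logits[:, -1, :].argmax(dim=-1)
    print(tokenizer.decode(next_token))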