Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/strategies/ar_spec/model_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def step(self,
stop_words_rsp = stop_words.reshape(1, 1, -1)
assert stop_words_rsp.ndim == token_ids_rsp.ndim == 3
stop_mask = (token_ids_rsp == stop_words_rsp).any(-1)
mask = mask ^ stop_mask
mask = torch.logical_or(mask, stop_mask)
# find the index of first `1`, if not found, would be 0
index = torch.argmax(mask.int(), dim=-1, keepdim=True)
# update index of 0 to -1 if not found
Expand Down
13 changes: 13 additions & 0 deletions lmdeploy/pytorch/strategies/ar_spec/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,17 @@ def num_valid_ids(self):
def num_spec_ids(self):
return self._num_spec_ids

@property
def routed_experts(self) -> np.ndarray:
if (not self.return_routed_experts) or self.all_routed_experts is None:
return None

end = max(0, self.num_valid_ids - 1)
if 0 < end <= len(self.all_routed_experts):
return self.all_routed_experts.get_real()[:end]
else:
return None

@property
def generated_ids(self) -> np.ndarray:
end = self.num_valid_ids
Expand All @@ -59,6 +70,8 @@ def set_stop_pos(self, pos: int):
self._num_spec_ids = 0
self._num_new_valid = 0
self.history_cache.resize(self.num_valid_ids)
if self.all_routed_experts is not None:
self.all_routed_experts.resize(self.num_valid_ids-1)
Comment thread
RunningLeon marked this conversation as resolved.
Outdated

def _update_token_ids_inputs(self, token_ids: np.ndarray):
"""Append tokens."""
Expand Down
Loading