diff --git a/optimum/gptq/data.py b/optimum/gptq/data.py index 7e5fc0b43d..7cb57d31bc 100644 --- a/optimum/gptq/data.py +++ b/optimum/gptq/data.py @@ -156,7 +156,7 @@ def get_c4(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train"): while True: i = random.randint(0, len(data) - 1) enc = tokenizer(data[i]["text"], return_tensors="pt") - if enc.input_ids.shape[1] >= seqlen: + if enc.input_ids.shape[1] > seqlen: break i = random.randint(0, enc.input_ids.shape[1] - seqlen - 1) j = i + seqlen @@ -184,7 +184,7 @@ def get_c4_new(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train") while True: i = random.randint(0, len(data) - 1) enc = tokenizer(data[i]["text"], return_tensors="pt") - if enc.input_ids.shape[1] >= seqlen: + if enc.input_ids.shape[1] > seqlen: break i = random.randint(0, enc.input_ids.shape[1] - seqlen - 1) j = i + seqlen