diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index b841d52a6b..3cffdc22cf 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -1174,7 +1174,12 @@ def get_model_from_task( ) if library_name == "timm": - model = model_class(f"hf_hub:{model_name_or_path}", pretrained=True, exportable=True) + import os + + if os.path.isdir(model_name_or_path): + model = model_class(model_name_or_path, pretrained=True, exportable=True) + else: + model = model_class(f"hf_hub:{model_name_or_path}", pretrained=True, exportable=True) model = model.to(torch_dtype).to(device) elif library_name == "sentence_transformers": token = model_kwargs.pop("token", None)