diff --git a/llava/model/configuration_llava.py b/llava/model/configuration_llava.py index a7872515..c855664b 100755 --- a/llava/model/configuration_llava.py +++ b/llava/model/configuration_llava.py @@ -66,6 +66,7 @@ def __init__( time_token_format=None, image_encoder: str = '{"_target_": "llava.model.encoders.BasicImageEncoder"}', video_encoder: str = '{"_target_": "llava.model.encoders.BasicVideoEncoder"}', + model_dtype: str = "torch.bfloat16", **kwargs, ): super().__init__() @@ -111,6 +112,8 @@ def __init__( self.image_encoder = image_encoder self.video_encoder = video_encoder + self.model_dtype = model_dtype + class JsonSchemaResponseFormat(BaseModel): schema_: str = Field(alias="schema")