diff --git a/docs/en/user_guides/how_to_deploy.md b/docs/en/user_guides/how_to_deploy.md index 0b8e31a395..458bc3eb3f 100644 --- a/docs/en/user_guides/how_to_deploy.md +++ b/docs/en/user_guides/how_to_deploy.md @@ -32,6 +32,12 @@ For example: python tools/misc/publish_model.py ./epoch_10.pth ./epoch_10_publish.pth ``` +To save the model in float16 (half precision), add the `--float16` flag: + +```shell +python tools/misc/publish_model.py ${IN_FILE} ${OUT_FILE} --float16 +``` + The script will automatically simplify the model, save the simplified model to the specified path, and add a timestamp to the filename, for example, `./epoch_10_publish-21815b2c_20230726.pth`. ## Deployment with MMDeploy diff --git a/docs/zh_cn/user_guides/how_to_deploy.md b/docs/zh_cn/user_guides/how_to_deploy.md index 2349fcca09..4cad3c70d3 100644 --- a/docs/zh_cn/user_guides/how_to_deploy.md +++ b/docs/zh_cn/user_guides/how_to_deploy.md @@ -32,6 +32,12 @@ python tools/misc/publish_model.py ${IN_FILE} ${OUT_FILE} python tools/misc/publish_model.py ./epoch_10.pth ./epoch_10_publish.pth ``` +要将模型保存为 float16 (half),请添加 --float16,如下所示: + +```shell +python tools/misc/publish_model.py ${IN_FILE} ${OUT_FILE} --float16 +``` + 脚本会自动对模型进行精简,并将精简后的模型保存到制定路径,并在文件名的最后加上时间戳,例如 `./epoch_10_publish-21815b2c_20230726.pth`。 ## 使用 MMDeploy 部署 diff --git a/mmpose/models/pose_estimators/base.py b/mmpose/models/pose_estimators/base.py index 216f592fda..db902ec6ae 100644 --- a/mmpose/models/pose_estimators/base.py +++ b/mmpose/models/pose_estimators/base.py @@ -158,6 +158,9 @@ def forward(self, if self.metainfo is not None: for data_sample in data_samples: data_sample.set_metainfo(self.metainfo) + param = next(self.backbone.parameters()) + if param.is_cuda and param.dtype == torch.float16: + inputs = inputs.half() return self.predict(inputs, data_samples) elif mode == 'tensor': return self._forward(inputs) diff --git a/tools/misc/publish_model.py b/tools/misc/publish_model.py index addf4cca64..d024f3151a 100644 --- 
a/tools/misc/publish_model.py
+++ b/tools/misc/publish_model.py
@@ -20,12 +20,23 @@ def parse_args():
         type=str,
         default=['meta', 'state_dict'],
         help='keys to save in published checkpoint (default: meta state_dict)')
+    parser.add_argument(
+        '--float16',
+        action='store_true',
+        default=False,
+        help='Whether save model as float16')
     args = parser.parse_args()
     return args
 
 
-def process_checkpoint(in_file, out_file, save_keys=['meta', 'state_dict']):
+def process_checkpoint(in_file,
+                       out_file,
+                       save_keys=['meta', 'state_dict'],
+                       float16=False):
     checkpoint = torch.load(in_file, map_location='cpu')
+    # Record the requested precision in `meta`. `setdefault` keeps the
+    # script usable on checkpoints that were saved without a `meta` entry.
+    checkpoint.setdefault('meta', {})['float16'] = float16
 
     # only keep `meta` and `state_dict` for smaller file size
     ckpt_keys = list(checkpoint.keys())
@@ -41,6 +52,18 @@ def process_checkpoint(in_file, out_file, save_keys=['meta', 'state_dict']):
     # if it is necessary to remove some sensitive data in checkpoint['meta'],
     # add the code here.
 
+    if float16:
+        if 'meta' not in save_keys:
+            raise ValueError(
+                'Key `meta` must be in save_keys to save model as float16. '
+                'Change float16 to False or add `meta` in save_keys.')
+        print_log('Saving model as float16.', logger='current')
+        # Cast only floating-point tensors; integer entries (e.g. step
+        # counters kept as long tensors) must keep their original dtype.
+        for key, value in checkpoint['state_dict'].items():
+            if torch.is_floating_point(value):
+                checkpoint['state_dict'][key] = value.half()
+
     if digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
         torch.save(checkpoint, out_file, _use_new_zipfile_serialization=False)
     else:
@@ -58,7 +81,8 @@ def process_checkpoint(in_file, out_file, save_keys=['meta', 'state_dict']):
 
 def main():
     args = parse_args()
-    process_checkpoint(args.in_file, args.out_file, args.save_keys)
+    process_checkpoint(args.in_file, args.out_file, args.save_keys,
+                       args.float16)
 
 
 if __name__ == '__main__':