Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,15 @@ python examples/inference/inference.py \
--output_path output_images
```

Or run weights locally with

```bash
python examples/inference/inference.py \
--model_path path_to_local_model_directory \
--source_image path_to_your_image.jpg \
--output_path output_images
```

See the trained models on the HF Hub 🤗
- [Surface normals Checkpoint](https://huggingface.co/jasperai/LBM_normals)
- [Depth Checkpoint](https://huggingface.co/jasperai/LBM_depth)
Expand Down Expand Up @@ -163,6 +172,12 @@ To train the model, you can use the following command:
python examples/training/train_lbm_surface.py examples/training/config/surface.yaml
```

To prune trained output ckpt to just model weights for inference

```bash
python examples/training/convert_checkpoint_to_safetensors.py --checkpoint_path examples/training/output --output_dir out
```

*Note*: Make sure to update the relevant section of the `yaml` file to use your own data and log the results on your own [WandB](https://wandb.ai/site).

## Citation
Expand Down
14 changes: 12 additions & 2 deletions examples/inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,24 @@
default="normals",
choices=["normals", "depth", "relighting"],
)
parser.add_argument(
"--model_path",
type=str,
help="Path to local model directory (overrides model_name if provided)",
)


args = parser.parse_args()


def main():
# download the weights from HF hub
if not os.path.exists(os.path.join(PATH, "ckpts", f"{args.model_name}")):
# Use custom model path if provided
if args.model_path:
logging.info(f"Loading LBM model from custom path: {args.model_path}")
model = get_model(args.model_path, torch_dtype=torch.bfloat16, device="cuda")

# Otherwise use model_name with HF hub or local cache
elif not os.path.exists(os.path.join(PATH, "ckpts", f"{args.model_name}")):
logging.info(f"Downloading {args.model_name} LBM model from HF hub...")
model = get_model(
f"jasperai/LBM_{args.model_name}",
Expand Down
135 changes: 135 additions & 0 deletions examples/training/convert_checkpoint_to_safetensors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
#!/usr/bin/env python3
"""
Convert PyTorch Lightning checkpoint to safetensors format for inference.

This script converts the large training checkpoints (~14GB) that include optimizer state
and training metadata to lightweight safetensors files (~5GB) with just model weights.
"""

import argparse
import logging
import os
import shutil
from pathlib import Path

import torch
import yaml
from safetensors.torch import save_file

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def convert_checkpoint_to_safetensors(
checkpoint_path: str,
output_dir: str,
config_path: str = None,
):
"""
Convert a PyTorch Lightning checkpoint to safetensors format.

Args:
checkpoint_path: Path to the .ckpt file
output_dir: Directory to save the converted files
config_path: Path to config.yaml (if None, will look in checkpoint directory)
"""
checkpoint_path = Path(checkpoint_path)
output_dir = Path(output_dir)

if not checkpoint_path.exists():
raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

# Create output directory
output_dir.mkdir(parents=True, exist_ok=True)

# Load checkpoint
logger.info(f"Loading checkpoint from {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

# Extract model state dict
if "state_dict" in checkpoint:
state_dict = checkpoint["state_dict"]
else:
raise ValueError("No 'state_dict' found in checkpoint")

# Remove "model." prefix from keys (as done in inference/utils.py line 76)
logger.info("Cleaning state dict - removing 'model.' prefix")
cleaned_state_dict = {}
model_prefix = "model."

for key, value in state_dict.items():
if key.startswith(model_prefix):
new_key = key[len(model_prefix):]
# Clone tensors to break memory sharing (fixes safetensors shared memory error)
cleaned_state_dict[new_key] = value.clone()
else:
# Keep keys that don't have the model prefix
cleaned_state_dict[key] = value.clone()

# Save as safetensors
safetensors_path = output_dir / "model.safetensors"
logger.info(f"Saving safetensors to {safetensors_path}")
save_file(cleaned_state_dict, safetensors_path)

# Handle config.yaml
if config_path is None:
# Look for config.yaml in the same directory as checkpoint
config_path = checkpoint_path.parent / "config.yaml"
else:
config_path = Path(config_path)

if config_path.exists():
output_config_path = output_dir / "config.yaml"
logger.info(f"Copying config from {config_path} to {output_config_path}")
shutil.copy2(config_path, output_config_path)
else:
logger.warning(f"Config file not found at {config_path}")
logger.info("You may need to manually create config.yaml for inference")

# Log size comparison
original_size = checkpoint_path.stat().st_size / (1024**3) # GB
new_size = safetensors_path.stat().st_size / (1024**3) # GB

logger.info(f"Conversion complete!")
logger.info(f"Original checkpoint: {original_size:.2f} GB")
logger.info(f"Safetensors file: {new_size:.2f} GB")
logger.info(f"Size reduction: {((original_size - new_size) / original_size * 100):.1f}%")
logger.info(f"Output directory: {output_dir}")


def main():
parser = argparse.ArgumentParser(
description="Convert PyTorch Lightning checkpoint to safetensors format"
)
parser.add_argument(
"--checkpoint_path",
required=True,
help="Path to the .ckpt file to convert"
)
parser.add_argument(
"--output_dir",
required=True,
help="Directory to save the converted files"
)
parser.add_argument(
"--config_path",
help="Path to config.yaml (optional, will look in checkpoint directory if not provided)"
)

args = parser.parse_args()

try:
convert_checkpoint_to_safetensors(
checkpoint_path=args.checkpoint_path,
output_dir=args.output_dir,
config_path=args.config_path,
)
except Exception as e:
logger.error(f"Conversion failed: {e}")
return 1

return 0


if __name__ == "__main__":
exit(main())