Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions diffsynth/configs/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1072,6 +1072,27 @@
"model_class": "diffsynth.models.fid.FIDInceptionModel",
"state_dict_converter": "diffsynth.utils.state_dict_converters.image_metrics.ImageMetricsFIDStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="LPIPS/alexnet.safetensors")
"model_hash": "08a75c660c9b2e775c530a0955857f1f",
"model_name": "image_metrics_lpips_alex",
"model_class": "diffsynth.models.lpips.LPIPSModel",
"extra_kwargs": {"net": "alex"},
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="LPIPS/vgg.safetensors")
"model_hash": "5740953aaa8aba2ecd9b9c23da813591",
"model_name": "image_metrics_lpips_vgg",
"model_class": "diffsynth.models.lpips.LPIPSModel",
"extra_kwargs": {"net": "vgg"},
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="LPIPS/squeezenet.safetensors")
"model_hash": "ff994b70a30599287a332105396d5004",
"model_name": "image_metrics_lpips_squeeze",
"model_class": "diffsynth.models.lpips.LPIPSModel",
"extra_kwargs": {"net": "squeeze"},
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="UnifiedReward-2.0-qwen35-9b/model-*.safetensors")
"model_hash": "f9786d06eca5c0f1ece89843b2c4cc66",
Expand All @@ -1088,7 +1109,6 @@
"state_dict_converter": "diffsynth.utils.state_dict_converters.image_metrics.ImageMetricsUnifiedRewardStateDictConverter",
"extra_kwargs": {"variant": "qwen35"},
},

]

hidream_o1_image_series = [
Expand All @@ -1104,4 +1124,4 @@
stable_diffusion_xl_series + stable_diffusion_series + qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series
+ z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series + ace_step_series + hidream_o1_image_series
+ image_metrics_series
)
)
2 changes: 2 additions & 0 deletions diffsynth/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .hpsv2 import HPSv2Metric
from .hpsv3 import HPSv3Metric
from .image_reward import ImageRewardMetric
from .lpips import LPIPSMetric
from .pickscore import PickScoreMetric
from .qwen_image_bench import QwenImageBenchMetric
from .unified_reward_2 import UnifiedReward2Metric
Expand All @@ -22,6 +23,7 @@
"CLIPMetric",
"AestheticMetric",
"FIDMetric",
"LPIPSMetric",
"QwenImageBenchMetric",
"UnifiedReward2Metric",
"UnifiedRewardEditMetric",
Expand Down
63 changes: 63 additions & 0 deletions diffsynth/metrics/lpips.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import torch

from ..core import ModelConfig
from ..core.device.npu_compatible_device import get_device_type
from ..models.lpips import LPIPSModel, LPIPS_NET_CHOICES, LPIPSCompute
from .base import Metric


# Default weight file (used as the ModelConfig `origin_file_pattern`) for each
# supported LPIPS backbone; consulted only when the caller supplies no
# model_config of their own.
_LPIPS_DEFAULT_FILES = dict(
    alex="LPIPS/alexnet.safetensors",
    vgg="LPIPS/vgg.safetensors",
    squeeze="LPIPS/squeezenet.safetensors",
)

# Name under which each backbone is looked up in the loaded model pool.
# These strings have to stay in sync with the "model_name" values registered
# in diffsynth/configs/model_configs.py.
_LPIPS_MODEL_NAMES = dict(
    alex="image_metrics_lpips_alex",
    vgg="image_metrics_lpips_vgg",
    squeeze="image_metrics_lpips_squeeze",
)


class LPIPSMetric(Metric):
# Run the parent initialiser, then keep `model` (annotated as LPIPSCompute)
# on the instance as `self.model`; compute() delegates to it.
def __init__(self, model: LPIPSCompute):
super().__init__()
self.model = model

# Alternate constructor: build an LPIPSMetric from pretrained weights.
#
# Steps, in order:
#   1. Reject any `net` that is not in LPIPS_NET_CHOICES (ValueError).
#   2. When `model_config` is None, build a ModelConfig pointing at the
#      "DiffSynth-Studio/ImageMetrics" repo and the file pattern that
#      _LPIPS_DEFAULT_FILES maps `net` to.
#   3. Load through cls.download_and_load_models in float32, forwarding
#      `device` and `vram_limit`.
#   4. Fetch the model registered under _LPIPS_MODEL_NAMES[net]; a None result
#      raises RuntimeError.
#   5. Wrap it in LPIPSCompute (device, batch_size, target_size) and return
#      cls(compute_model).
#
# NOTE: the `device` default calls get_device_type() once, when the function
# is defined, not on every call.
# NOTE(review): `device` is annotated torch.device, but what get_device_type()
# returns is not visible here — confirm the annotation matches.
# NOTE(review): the prose and indented snippets interleaved with the code
# below ("Comment on lines ...", "Suggested change") are review-page text
# captured with the diff; they propose `force_resize` / `num_workers`
# arguments that this signature and this LPIPSCompute call do not have.
@classmethod
def from_pretrained(
cls,
net: str = "alex",
model_config: ModelConfig = None,
device: torch.device = get_device_type(),
batch_size: int = 16,
target_size: int = 512,
vram_limit: float = None,
):
Comment on lines +28 to +36
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Expose the force_resize parameter in from_pretrained to allow skipping the expensive size-checking loop.

    def from_pretrained(
        cls,
        net: str = "alex",
        model_config: ModelConfig = None,
        device: torch.device = get_device_type(),
        batch_size: int = 16,
        num_workers: int = 0,
        target_size: int = 512,
        vram_limit: float = None,
        force_resize: bool = False,
    ):

# Validate first so an unsupported backbone name fails before anything is loaded.
if net not in LPIPS_NET_CHOICES:
raise ValueError(f"net must be one of {LPIPS_NET_CHOICES}, got {net!r}")
# No explicit config: fall back to the default weights file for this backbone.
if model_config is None:
model_config = ModelConfig(
model_id="DiffSynth-Studio/ImageMetrics",
origin_file_pattern=_LPIPS_DEFAULT_FILES[net],
)
model_pool = cls.download_and_load_models([model_config], torch_dtype=torch.float32, device=device, vram_limit=vram_limit)
backbone = model_pool.fetch_model(_LPIPS_MODEL_NAMES[net])
# fetch_model gave back None: nothing is available under the expected name,
# which is reported to the caller as a weights/hash mismatch.
if backbone is None:
raise RuntimeError(
f"Failed to load LPIPS model for net={net!r}. The provided weights do not match the registered hash for {_LPIPS_MODEL_NAMES[net]}."
)
compute_model = LPIPSCompute(
model=backbone,
device=device,
batch_size=batch_size,
target_size=target_size,
)
Comment on lines +50 to +55
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Pass force_resize to LPIPSCompute.

Suggested change
compute_model = LPIPSCompute(
model=backbone,
device=device,
batch_size=batch_size,
num_workers=num_workers,
target_size=target_size,
)
compute_model = LPIPSCompute(
model=backbone,
device=device,
batch_size=batch_size,
num_workers=num_workers,
target_size=target_size,
force_resize=force_resize,
)

return cls(compute_model)

# Score a pair of inputs by delegating to self.model.compute; the decorator
# runs the whole call with autograd disabled. Annotated to return a float.
# Which input types are accepted is decided by self.model.compute and is not
# visible here.
@torch.no_grad()
def compute(self, image_a, image_b) -> float:
return self.model.compute(image_a, image_b)

# Same as compute(): both arguments are passed through unchanged. Because
# compute() is wrapped in torch.no_grad(), calls made through forward() do not
# track gradients either.
def forward(self, image_a, image_b):
return self.compute(image_a, image_b)
Loading