Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion diffsynth/configs/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,7 @@
},
{
# Example: ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model-*.safetensors")
# Example: ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="UnifiedReward-Edit-qwen3vl-8b/model-*.safetensors")
"model_hash": "2d11bf14bba8b4e87477c8199a895403",
"model_name": "joyai_image_text_encoder",
"model_class": "diffsynth.models.joyai_image_text_encoder.JoyAIImageTextEncoder",
Expand Down Expand Up @@ -1071,6 +1072,23 @@
"model_class": "diffsynth.models.fid.FIDInceptionModel",
"state_dict_converter": "diffsynth.utils.state_dict_converters.image_metrics.ImageMetricsFIDStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="UnifiedReward-2.0-qwen35-9b/model-*.safetensors")
"model_hash": "f9786d06eca5c0f1ece89843b2c4cc66",
"model_name": "image_metrics_unified_reward_2",
"model_class": "diffsynth.models.unified_reward_2.UnifiedReward2Qwen35ForConditionalGeneration",
"state_dict_converter": "diffsynth.utils.state_dict_converters.image_metrics.ImageMetricsUnifiedRewardStateDictConverter",
"extra_kwargs": {"variant": "qwen35_9b"},
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="Qwen-Image-Bench/model-*.safetensors")
"model_hash": "ff4ad0463675e96738483611f6dd551b",
"model_name": "image_metrics_qwen_image_bench",
"model_class": "diffsynth.models.qwen_image_bench.QwenImageBenchQwen35ForConditionalGeneration",
"state_dict_converter": "diffsynth.utils.state_dict_converters.image_metrics.ImageMetricsUnifiedRewardStateDictConverter",
"extra_kwargs": {"variant": "qwen35"},
},

]

hidream_o1_image_series = [
Expand All @@ -1086,4 +1104,4 @@
stable_diffusion_xl_series + stable_diffusion_series + qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series
+ z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series + ace_step_series + hidream_o1_image_series
+ image_metrics_series
)
)
70 changes: 70 additions & 0 deletions diffsynth/metrics/qwen_image_bench.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import torch
from transformers import AutoProcessor
from ..core import ModelConfig
from ..core.device.npu_compatible_device import get_device_type
from ..models.qwen_image_bench import QwenImageBenchModel
from .base import Metric
from transformers.utils import logging
logging.set_verbosity_error()


class QwenImageBenchMetric(Metric):
    """Image metric backed by a ``QwenImageBenchModel``.

    The metric holds a single wrapped model (``self.model``) and exposes
    ``evaluate`` for the model's raw outputs and ``score`` / ``compute`` /
    ``forward`` for the per-output primary score.
    """

    def __init__(self, model: QwenImageBenchModel):
        """Keep a reference to the already-constructed scoring model."""
        super().__init__()
        self.model = model

    @classmethod
    def from_pretrained(
        cls,
        model_config: ModelConfig = ModelConfig(
            model_id="Qwen/Qwen-Image-Bench",
            origin_file_pattern="model-*.safetensors",
        ),
        processor_config: ModelConfig = ModelConfig(
            model_id="Qwen/Qwen-Image-Bench",
            origin_file_pattern="",
        ),
        torch_dtype: torch.dtype = None,
        device: torch.device = get_device_type(),
        max_new_tokens: int = 4096,
        resize_long_edge: int = 1024,
        processor_kwargs: dict = None,
        vram_limit: float = None,
    ):
        """Load the weights and processor, then build a ready-to-use metric.

        Args:
            model_config: Location of the model weights.
            processor_config: Location of the processor files handed to
                ``AutoProcessor.from_pretrained``.
            torch_dtype: Weight dtype; ``torch.bfloat16`` is used when this
                is ``None`` (or otherwise falsy).
            device: Device forwarded to the model loader.
            max_new_tokens: Forwarded to ``QwenImageBenchModel``.
            resize_long_edge: Forwarded to ``QwenImageBenchModel``.
            processor_kwargs: Extra keyword arguments for
                ``AutoProcessor.from_pretrained``; ``None`` means none.
            vram_limit: Forwarded to the model loader.

        Returns:
            A ``QwenImageBenchMetric`` wrapping the model in eval mode.

        Raises:
            ValueError: If the loaded pool has no model registered as
                ``image_metrics_qwen_image_bench``.

        Note:
            The default ``ModelConfig`` objects and the default ``device``
            are evaluated once, when the class body is executed, and are
            therefore shared by every call that relies on the defaults.
        """
        pool = cls.download_and_load_models(
            [model_config],
            torch_dtype=torch_dtype or torch.bfloat16,
            device=device,
            vram_limit=vram_limit,
        )
        backbone = pool.fetch_model("image_metrics_qwen_image_bench")
        if backbone is None:
            raise ValueError("Cannot find model: image_metrics_qwen_image_bench")
        # Some loaded objects expose the network under ``.model``; use that
        # inner module when it is present, otherwise keep the object itself.
        backbone = getattr(backbone, "model", backbone)

        processor_config.download_if_necessary()
        processor = AutoProcessor.from_pretrained(
            processor_config.path,
            **(processor_kwargs or {}),
        )

        scorer = QwenImageBenchModel(
            model=backbone,
            processor=processor,
            max_new_tokens=max_new_tokens,
            resize_long_edge=resize_long_edge,
        )
        return cls(scorer.eval())

    @torch.no_grad()
    def evaluate(self, prompt: str | list[str] | None, images, dimensions=None):
        """Run the wrapped model and return its outputs unchanged."""
        return self.model(prompt, images, dimensions=dimensions)

    @torch.no_grad()
    def score(self, prompt: str | list[str] | None, images, dimensions=None):
        """Return one primary score per output produced by ``evaluate``."""
        scores = []
        for result in self.evaluate(prompt, images, dimensions=dimensions):
            scores.append(self.model._primary_score(result))
        return scores

    def compute(self, prompt: str | list[str] | None, images, dimensions=None):
        """Alias of ``score``."""
        return self.score(prompt, images, dimensions=dimensions)

    def forward(self, prompt: str | list[str] | None, images, dimensions=None):
        """Alias of ``score``, so calling the metric yields the scores."""
        return self.score(prompt, images, dimensions=dimensions)
69 changes: 69 additions & 0 deletions diffsynth/metrics/unified_reward_2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import torch
from transformers import AutoProcessor

from ..core import ModelConfig
from ..core.device.npu_compatible_device import get_device_type
from ..models.unified_reward_2 import UnifiedReward2Model
from .base import Metric
from transformers.utils import logging
logging.set_verbosity_error()


class UnifiedReward2Metric(Metric):
    """Image metric backed by a ``UnifiedReward2Model``.

    ``evaluate`` returns the wrapped model's outputs as-is, while ``score``
    (and its aliases ``compute`` / ``forward``) reduces each output to the
    value returned by the model's ``_primary_score``.
    """

    def __init__(self, model: UnifiedReward2Model):
        """Keep a reference to the already-constructed scoring model."""
        super().__init__()
        self.model = model

    @classmethod
    def from_pretrained(
        cls,
        model_config: ModelConfig = ModelConfig(
            model_id="DiffSynth-Studio/ImageMetrics",
            origin_file_pattern="UnifiedReward-2.0-qwen35-9b/model-*.safetensors",
        ),
        processor_config: ModelConfig = ModelConfig(
            model_id="DiffSynth-Studio/ImageMetrics",
            origin_file_pattern="UnifiedReward-2.0-qwen35-9b/",
        ),
        torch_dtype: torch.dtype = None,
        device: torch.device = get_device_type(),
        max_new_tokens: int = 1024,
        processor_kwargs: dict = None,
        vram_limit: float = None,
    ):
        """Load the weights and processor, then build a ready-to-use metric.

        Args:
            model_config: Location of the model weights.
            processor_config: Location of the processor files handed to
                ``AutoProcessor.from_pretrained``.
            torch_dtype: Weight dtype; ``torch.bfloat16`` is used when this
                is ``None`` (or otherwise falsy).
            device: Device forwarded to the model loader.
            max_new_tokens: Forwarded to ``UnifiedReward2Model``.
            processor_kwargs: Extra keyword arguments for
                ``AutoProcessor.from_pretrained``; ``None`` means none.
            vram_limit: Forwarded to the model loader.

        Returns:
            A ``UnifiedReward2Metric`` wrapping the model in eval mode.

        Raises:
            ValueError: If the loaded pool has no model registered as
                ``image_metrics_unified_reward_2``.

        Note:
            The default ``ModelConfig`` objects and the default ``device``
            are evaluated once, when the class body is executed, and are
            therefore shared by every call that relies on the defaults.
        """
        pool = cls.download_and_load_models(
            [model_config],
            torch_dtype=torch_dtype or torch.bfloat16,
            device=device,
            vram_limit=vram_limit,
        )
        backbone = pool.fetch_model("image_metrics_unified_reward_2")
        if backbone is None:
            raise ValueError("Cannot find model: image_metrics_unified_reward_2")
        # Use the inner ``.model`` attribute when the loaded object has one.
        backbone = getattr(backbone, "model", backbone)

        processor_config.download_if_necessary()
        processor = AutoProcessor.from_pretrained(
            processor_config.path,
            **(processor_kwargs or {}),
        )

        scorer = UnifiedReward2Model(
            model=backbone,
            processor=processor,
            max_new_tokens=max_new_tokens,
        )
        return cls(scorer.eval())

    @torch.no_grad()
    def evaluate(self, prompt: str | list[str] | None, images):
        """Run the wrapped model and return its outputs unchanged."""
        return self.model(prompt, images)

    @torch.no_grad()
    def score(self, prompt: str | list[str] | None, images):
        """Return one primary score per output produced by ``evaluate``."""
        scores = []
        for result in self.evaluate(prompt, images):
            scores.append(self.model._primary_score(result))
        return scores

    def compute(self, prompt: str | list[str] | None, images):
        """Alias of ``score``."""
        return self.score(prompt, images)

    def forward(self, prompt: str | list[str] | None, images):
        """Alias of ``score``, so calling the metric yields the scores."""
        return self.score(prompt, images)
95 changes: 95 additions & 0 deletions diffsynth/metrics/unified_reward_edit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import torch
from transformers import AutoProcessor
from ..core import ModelConfig
from ..core.device.npu_compatible_device import get_device_type
from ..models.unified_reward_edit import UnifiedRewardEditModel
from .base import Metric
from transformers.utils import logging
logging.set_verbosity_error()


# Task used by every scoring method when the caller does not pick one.
DEFAULT_UNIFIED_REWARD_TASK = "edit_pointwise_score"


class UnifiedRewardEditMetric(Metric):
    """Image-editing metric backed by a ``UnifiedRewardEditModel``.

    The task name selects how a parsed model output is reduced to a score:

    * ``"edit_pairwise_rank"``: ``0`` (tie / undecided), ``1`` or ``2``.
    * ``"edit_pairwise_score"``: ``[image_1_score, image_2_score]``.
    * anything else (including the default ``"edit_pointwise_score"``):
      the ``"score"`` entry of the parsed output.
    """

    def __init__(self, model: UnifiedRewardEditModel):
        """Keep a reference to the already-constructed scoring model."""
        super().__init__()
        self.model = model

    @classmethod
    def from_pretrained(
        cls,
        model_config: ModelConfig = ModelConfig(
            model_id="DiffSynth-Studio/ImageMetrics",
            origin_file_pattern="UnifiedReward-Edit-qwen3vl-8b/model-*.safetensors",
        ),
        processor_config: ModelConfig = ModelConfig(
            model_id="DiffSynth-Studio/ImageMetrics",
            origin_file_pattern="UnifiedReward-Edit-qwen3vl-8b/",
        ),
        torch_dtype: torch.dtype = None,
        device: torch.device = get_device_type(),
        task: str = DEFAULT_UNIFIED_REWARD_TASK,
        max_new_tokens: int = 256,
        processor_kwargs: dict = None,
        vram_limit: float = None,
    ):
        """Load the weights and processor, then build a ready-to-use metric.

        Args:
            model_config: Location of the model weights.
            processor_config: Location of the processor files handed to
                ``AutoProcessor.from_pretrained``.
            torch_dtype: Weight dtype; ``torch.bfloat16`` is used when this
                is ``None`` (or otherwise falsy).
            device: Device forwarded to the model loader.
            task: Task name forwarded to ``UnifiedRewardEditModel``.
            max_new_tokens: Forwarded to ``UnifiedRewardEditModel``.
            processor_kwargs: Extra keyword arguments for
                ``AutoProcessor.from_pretrained``; ``None`` means none.
            vram_limit: Forwarded to the model loader.

        Returns:
            A ``UnifiedRewardEditMetric`` wrapping the model in eval mode.

        Raises:
            ValueError: If the loaded pool has no model registered as
                ``image_metrics_unified_reward_edit``.

        Note:
            The default ``ModelConfig`` objects and the default ``device``
            are evaluated once, when the class body is executed, and are
            therefore shared by every call that relies on the defaults.
        """
        processor_kwargs = processor_kwargs or {}
        model_pool = cls.download_and_load_models(
            [model_config],
            torch_dtype=torch_dtype or torch.bfloat16,
            device=device,
            vram_limit=vram_limit,
        )
        model = model_pool.fetch_model("image_metrics_unified_reward_edit")
        if model is None:
            raise ValueError("Cannot find model: image_metrics_unified_reward_edit")
        # Use the inner ``.model`` attribute when the loaded object has one.
        if hasattr(model, "model"):
            model = model.model

        processor_config.download_if_necessary()
        processor = AutoProcessor.from_pretrained(processor_config.path, **processor_kwargs)
        model = UnifiedRewardEditModel(
            model=model,
            processor=processor,
            task=task,
            max_new_tokens=max_new_tokens,
        ).eval()
        return cls(model)

    @staticmethod
    def _primary_score(parsed: dict, task: str):
        """Reduce one parsed model output to the task's primary score.

        Args:
            parsed: Parsed model output; read with ``dict.get`` so missing
                keys yield ``None`` rather than raising.
            task: Task name that decides which fields are read.

        Returns:
            * ``"edit_pairwise_rank"``: ``int(winner)`` when ``winner`` is a
              number; for a string verdict ``1`` or ``2`` for the preferred
              image and ``0`` for a tie, a verdict naming both "image 1" and
              "image 2", or an unrecognised verdict. Any other ``winner``
              type gives ``0``.
            * ``"edit_pairwise_score"``: ``[image_1_score, image_2_score]``.
            * otherwise: ``parsed.get("score")``.
        """
        if task == "edit_pairwise_score":
            return [parsed.get("image_1_score"), parsed.get("image_2_score")]
        if task != "edit_pairwise_rank":
            return parsed.get("score")

        winner = parsed.get("winner")
        if isinstance(winner, (int, float)):
            return int(winner)
        if not isinstance(winner, str):
            return 0

        # Imported here so the class does not depend on a module-level import.
        import re

        verdict = winner.lower()
        # Match whole words only: a bare substring test for "tie" also fires
        # inside "qualities" / "properties" / "entities", and "image 1" would
        # fire inside "image 10", turning a clear verdict into a wrong score.
        is_tie = "equally" in verdict or re.search(r"\btie[ds]?\b", verdict) is not None
        names_first = re.search(r"\bimage 1\b", verdict) is not None
        names_second = re.search(r"\bimage 2\b", verdict) is not None

        if is_tie or (names_first and names_second):
            return 0
        if names_first or "first image" in verdict:
            return 1
        if names_second or "second image" in verdict:
            return 2
        return 0

    @torch.no_grad()
    def evaluate(self, prompt: str | list[str] | None, images, task: str = DEFAULT_UNIFIED_REWARD_TASK):
        """Run the model and attach the primary score to every output.

        Each returned dict is a copy of the model output whose ``"score"``
        key is set (or overwritten) with ``_primary_score(output, task)``.
        """
        outputs = self.model(prompt, images, task=task)
        return [{**output, "score": self._primary_score(output, task)} for output in outputs]

    @torch.no_grad()
    def score(self, prompt: str | list[str] | None, images, task: str = DEFAULT_UNIFIED_REWARD_TASK):
        """Return only the ``"score"`` value of each ``evaluate`` output."""
        outputs = self.evaluate(prompt, images, task=task)
        return [output["score"] for output in outputs]

    def compute(self, prompt: str | list[str] | None, images, task: str = DEFAULT_UNIFIED_REWARD_TASK):
        """Alias of ``score``."""
        return self.score(prompt, images, task=task)

    def forward(self, prompt: str | list[str] | None, images, task: str = DEFAULT_UNIFIED_REWARD_TASK):
        """Alias of ``score``, so calling the metric yields the scores."""
        return self.score(prompt, images, task=task)
Loading