diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py index 86619d611..c40004820 100644 --- a/diffsynth/configs/model_configs.py +++ b/diffsynth/configs/model_configs.py @@ -951,6 +951,7 @@ }, { # Example: ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model-*.safetensors") + # Example: ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="UnifiedReward-Edit-qwen3vl-8b/model-*.safetensors") "model_hash": "2d11bf14bba8b4e87477c8199a895403", "model_name": "joyai_image_text_encoder", "model_class": "diffsynth.models.joyai_image_text_encoder.JoyAIImageTextEncoder", @@ -1071,6 +1072,23 @@ "model_class": "diffsynth.models.fid.FIDInceptionModel", "state_dict_converter": "diffsynth.utils.state_dict_converters.image_metrics.ImageMetricsFIDStateDictConverter", }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="UnifiedReward-2.0-qwen35-9b/model-*.safetensors") + "model_hash": "f9786d06eca5c0f1ece89843b2c4cc66", + "model_name": "image_metrics_unified_reward_2", + "model_class": "diffsynth.models.unified_reward_2.UnifiedReward2Qwen35ForConditionalGeneration", + "state_dict_converter": "diffsynth.utils.state_dict_converters.image_metrics.ImageMetricsUnifiedRewardStateDictConverter", + "extra_kwargs": {"variant": "qwen35_9b"}, + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="Qwen-Image-Bench/model-*.safetensors") + "model_hash": "ff4ad0463675e96738483611f6dd551b", + "model_name": "image_metrics_qwen_image_bench", + "model_class": "diffsynth.models.qwen_image_bench.QwenImageBenchQwen35ForConditionalGeneration", + "state_dict_converter": "diffsynth.utils.state_dict_converters.image_metrics.ImageMetricsUnifiedRewardStateDictConverter", + "extra_kwargs": {"variant": "qwen35"}, + }, + ] hidream_o1_image_series = [ @@ -1086,4 +1104,4 @@ stable_diffusion_xl_series + stable_diffusion_series + 
class QwenImageBenchMetric(Metric):
    """Metric front-end that delegates scoring to a ``QwenImageBenchModel``."""

    def __init__(self, model: QwenImageBenchModel):
        super().__init__()
        self.model = model

    @classmethod
    def from_pretrained(
        cls,
        model_config: ModelConfig = ModelConfig(
            model_id="Qwen/Qwen-Image-Bench",
            origin_file_pattern="model-*.safetensors",
        ),
        processor_config: ModelConfig = ModelConfig(
            model_id="Qwen/Qwen-Image-Bench",
            origin_file_pattern="",
        ),
        torch_dtype: torch.dtype = None,
        device: torch.device = get_device_type(),
        max_new_tokens: int = 4096,
        resize_long_edge: int = 1024,
        processor_kwargs: dict = None,
        vram_limit: float = None,
    ):
        """Load the weights and the processor, then build a ready-to-use metric.

        Args:
            model_config: Where to find the model weights.
            processor_config: Where to find the processor files.
            torch_dtype: Weight dtype; ``torch.bfloat16`` is used when falsy.
            device: Device the weights are loaded onto.
            max_new_tokens: Forwarded to ``QwenImageBenchModel``.
            resize_long_edge: Forwarded to ``QwenImageBenchModel``.
            processor_kwargs: Extra keyword arguments for
                ``AutoProcessor.from_pretrained``.
            vram_limit: Forwarded to ``download_and_load_models``.

        Raises:
            ValueError: If the loaded pool has no model registered as
                ``image_metrics_qwen_image_bench``.
        """
        pool = cls.download_and_load_models(
            [model_config],
            torch_dtype=torch_dtype or torch.bfloat16,
            device=device,
            vram_limit=vram_limit,
        )
        backbone = pool.fetch_model("image_metrics_qwen_image_bench")
        if backbone is None:
            raise ValueError("Cannot find model: image_metrics_qwen_image_bench")
        # Unwrap the loader-side wrapper when it exposes the real network as ``.model``.
        backbone = getattr(backbone, "model", backbone)

        processor_config.download_if_necessary()
        extra_processor_args = processor_kwargs or {}
        processor = AutoProcessor.from_pretrained(processor_config.path, **extra_processor_args)

        scorer = QwenImageBenchModel(
            model=backbone,
            processor=processor,
            max_new_tokens=max_new_tokens,
            resize_long_edge=resize_long_edge,
        )
        return cls(scorer.eval())

    @torch.no_grad()
    def evaluate(self, prompt: str | list[str] | None, images, dimensions=None):
        """Return the wrapped model's full per-sample outputs."""
        return self.model(prompt, images, dimensions=dimensions)

    @torch.no_grad()
    def score(self, prompt: str | list[str] | None, images, dimensions=None):
        """Return one primary score per sample, derived from ``evaluate``."""
        scores = []
        for sample_output in self.evaluate(prompt, images, dimensions=dimensions):
            scores.append(self.model._primary_score(sample_output))
        return scores

    def compute(self, prompt: str | list[str] | None, images, dimensions=None):
        """Alias of ``score``."""
        return self.score(prompt, images, dimensions=dimensions)

    def forward(self, prompt: str | list[str] | None, images, dimensions=None):
        """Alias of ``score``."""
        return self.score(prompt, images, dimensions=dimensions)
class UnifiedRewardEditMetric(Metric):
    """Metric front-end that delegates image-edit scoring to a ``UnifiedRewardEditModel``."""

    def __init__(self, model: UnifiedRewardEditModel):
        super().__init__()
        self.model = model

    @classmethod
    def from_pretrained(
        cls,
        model_config: ModelConfig = ModelConfig(
            model_id="DiffSynth-Studio/ImageMetrics",
            origin_file_pattern="UnifiedReward-Edit-qwen3vl-8b/model-*.safetensors",
        ),
        processor_config: ModelConfig = ModelConfig(
            model_id="DiffSynth-Studio/ImageMetrics",
            origin_file_pattern="UnifiedReward-Edit-qwen3vl-8b/",
        ),
        torch_dtype: torch.dtype = None,
        device: torch.device = get_device_type(),
        task: str = DEFAULT_UNIFIED_REWARD_TASK,
        max_new_tokens: int = 256,
        processor_kwargs: dict = None,
        vram_limit: float = None,
    ):
        """Load the weights and the processor, then build a ready-to-use metric.

        Args:
            model_config: Where to find the model weights.
            processor_config: Where to find the processor files.
            torch_dtype: Weight dtype; ``torch.bfloat16`` is used when falsy.
            device: Device the weights are loaded onto.
            task: Default task forwarded to ``UnifiedRewardEditModel``.
            max_new_tokens: Forwarded to ``UnifiedRewardEditModel``.
            processor_kwargs: Extra keyword arguments for
                ``AutoProcessor.from_pretrained``.
            vram_limit: Forwarded to ``download_and_load_models``.

        Raises:
            ValueError: If the loaded pool has no model registered as
                ``image_metrics_unified_reward_edit``.
        """
        pool = cls.download_and_load_models(
            [model_config],
            torch_dtype=torch_dtype or torch.bfloat16,
            device=device,
            vram_limit=vram_limit,
        )
        backbone = pool.fetch_model("image_metrics_unified_reward_edit")
        if backbone is None:
            raise ValueError("Cannot find model: image_metrics_unified_reward_edit")
        # Unwrap the loader-side wrapper when it exposes the real network as ``.model``.
        backbone = getattr(backbone, "model", backbone)

        processor_config.download_if_necessary()
        extra_processor_args = processor_kwargs or {}
        processor = AutoProcessor.from_pretrained(processor_config.path, **extra_processor_args)

        scorer = UnifiedRewardEditModel(
            model=backbone,
            processor=processor,
            task=task,
            max_new_tokens=max_new_tokens,
        )
        return cls(scorer.eval())

    @staticmethod
    def _primary_score(parsed: dict, task: str):
        """Reduce one parsed model output to the task's headline value.

        * ``edit_pairwise_score``: ``[image_1_score, image_2_score]``.
        * ``edit_pairwise_rank``: ``1`` or ``2`` for the winning image, ``0``
          for a tie or an unrecognised verdict; a numeric ``winner`` is
          returned as ``int``.
        * any other task: the ``score`` entry (``None`` when absent).
        """
        if task == "edit_pairwise_score":
            return [parsed.get("image_1_score"), parsed.get("image_2_score")]
        if task != "edit_pairwise_rank":
            return parsed.get("score")

        winner = parsed.get("winner")
        if isinstance(winner, (int, float)):
            return int(winner)
        if not isinstance(winner, str):
            return 0

        verdict = winner.lower()
        names_first = "image 1" in verdict
        names_second = "image 2" in verdict
        # A verdict naming both images is treated as a tie, so this check must come first.
        if "equally" in verdict or "tie" in verdict or (names_first and names_second):
            return 0
        if names_first or "first image" in verdict:
            return 1
        if names_second or "second image" in verdict:
            return 2
        return 0

    @torch.no_grad()
    def evaluate(self, prompt: str | list[str] | None, images, task: str = DEFAULT_UNIFIED_REWARD_TASK):
        """Return the model outputs, each with ``score`` set to its primary score."""
        enriched = []
        for sample_output in self.model(prompt, images, task=task):
            record = dict(sample_output)
            record["score"] = self._primary_score(sample_output, task)
            enriched.append(record)
        return enriched

    @torch.no_grad()
    def score(self, prompt: str | list[str] | None, images, task: str = DEFAULT_UNIFIED_REWARD_TASK):
        """Return only the primary score of every evaluated sample."""
        scores = []
        for record in self.evaluate(prompt, images, task=task):
            scores.append(record["score"])
        return scores

    def compute(self, prompt: str | list[str] | None, images, task: str = DEFAULT_UNIFIED_REWARD_TASK):
        """Alias of ``score``."""
        return self.score(prompt, images, task=task)

    def forward(self, prompt: str | list[str] | None, images, task: str = DEFAULT_UNIFIED_REWARD_TASK):
        """Alias of ``score``."""
        return self.score(prompt, images, task=task)
"qwen3_5_text", + "mtp_num_hidden_layers": 1, + "mtp_use_dedicated_embeddings": False, + "num_attention_heads": 24, + "num_hidden_layers": 64, + "num_key_value_heads": 4, + "pad_token_id": None, + "partial_rotary_factor": 0.25, + "rms_norm_eps": 1e-6, + "rope_parameters": { + "mrope_interleaved": True, + "mrope_section": [11, 11, 10], + "partial_rotary_factor": 0.25, + "rope_theta": 10000000, + "rope_type": "default", + }, + "tie_word_embeddings": False, + "use_cache": True, + "vocab_size": 248320, + }, + tie_word_embeddings=False, + use_cache=True, + video_token_id=248057, + vision_config={ + "deepstack_visual_indexes": [], + "depth": 27, + "hidden_act": "gelu_pytorch_tanh", + "hidden_size": 1152, + "in_channels": 3, + "initializer_range": 0.02, + "intermediate_size": 4304, + "model_type": "qwen3_5", + "num_heads": 16, + "num_position_embeddings": 2304, + "out_hidden_size": 5120, + "patch_size": 16, + "spatial_merge_size": 2, + "temporal_patch_size": 2, + }, + vision_end_token_id=248054, + vision_start_token_id=248053, + ) + layer_types = ["linear_attention", "linear_attention", "linear_attention", "full_attention"] * 16 + config.text_config.layer_types = layer_types + self.model = Qwen3_5ForConditionalGeneration(config) + + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def generate(self, *args, **kwargs): + return self.model.generate(*args, **kwargs) + + +QUALITY_CHECKLIST = """## Realism +- Physical Logic: Does the image adhere to real-world physical laws (e.g., gravity, reflection, shadow direction, object stability)? +- Material Texture: Do the surface materials of objects (such as skin, fabric, metal, wood) exhibit realistic texture and material properties? +## Detail +- Noise: Is the image rich in detail without excessive noise or unnatural smoothing? +- Edge Clarity: Are the outlines and edges of objects sharp, well-defined, and free from blurring or aliasing? 
+- Naturalness: Does the image appear natural and free from the artificial "plastic" or "greasy" look commonly associated with AI-generated images? +## Resolution +- Resolution: Is the overall image resolution high-definition, free from visible pixelation or compression artifacts?""" + + +AESTHETICS_CHECKLIST = """## Composition +- Composition: Is the composition of the image balanced, visually guided, and aesthetically pleasing? +## Color Harmony +- Color Harmony: Is the overall color palette harmonious, cohesive, and appropriate for the mood of the image? +## Lighting +- Lighting & Atmosphere: Does the lighting and shadow atmosphere of the image (such as contrast between light and dark, and the overall lighting atmosphere) match the scene setting of the prompt? +## Anatomical Portraiture +- Anatomical Fidelity: Are the facial feature proportions, skeletal structure, and limb articulation anatomically correct and consistent with human biology? Does the facial skin exhibit realistic micro-level textures such as pores and fine lines? +## Emotional Expression +- Emotional Expression: Does the image's overall aesthetic tone effectively convey the intended emotion and mood described in the prompt? +## Style Control +- Style Control: Does the image accurately capture and represent the specific artistic style requested in the prompt (e.g., Van Gogh's brushwork, Cyberpunk aesthetic)?""" + +ALIGNMENT_CHECKLIST = """## Attributes +- Quantity: Does the number of objects in the image match the quantity specified in the prompt? +- Facial Expression: Does the facial expression of the person or animal accurately reflect the emotional state specified in the prompt? +- Material Properties: Do the materials of objects in the image match the material descriptions in the prompt? +- Color: Do the colors of objects in the image match the color specifications in the prompt? +- Shape: Do the shapes of objects in the image match the shape descriptions in the prompt? 
+- Size: Do the sizes of objects in the image match the size specifications in the prompt? +## Actions +- Contact Interaction: If the prompt involves physical contact between subjects, is the contact interaction depicted naturally and realistically? +- Non-contact Interaction: If the prompt involves non-contact relationships between subjects, is the spatial and social relationship depicted naturally and logically? +- Full-body Action: Does the overall posture and body action of the subject (person or animal) accurately perform the activity described in the prompt? +## Layout +- 2D Space: Are the relative positions of objects on the 2D plane (e.g., left/right, top/bottom, foreground/background) consistent with the prompt's spatial instructions? +- 3D Space: Does the layout, occlusion, and relative position of objects in 3D space conform to the prompt requirements or spatial logic? +## Relations +- Composition Relationship: Does the image successfully integrate multiple elements into a visually coherent and logically consistent whole? +- Difference/Similarity: Are the specified differences or similarities in shape, color, or material between objects accurately represented? +- Containment: Are the containment or enclosure relationships between objects correctly depicted? +## Scene +- Real-world Scene: Does the scene type and environmental setting (e.g., office, forest, street) match the location described in the prompt? +- Virtual Scene: Are the elements within a fictional or fantasy scene internally consistent and logically coherent?""" + +REAL_WORLD_FIDELITY_CHECKLIST = """## Fairness +- Social Bias: Does the image avoid reinforcing social biases by automatically associating specific genders with particular professions or settings? +- Cultural Fairness: Is the image free from stereotypical portrayals based on region, race, or cultural background? 
+## Safety & Compliance +- Safety & Compliance: Is the image safe and compliant, effectively avoiding prohibited content such as pornography, violence, or hate symbols? +## World Knowledge +- Animals: Are real-world animals depicted with anatomically accurate features and realistic biological details? +- Objects: Are the typical appearance, structure, brand logo, or iconic characteristics of real-world items accurately reproduced? +- Information Visualization: Does the image accurately and clearly translate abstract or scientific concepts from the prompt into an effective and understandable visual form? +- Temporal Characteristics: Does the image accurately reflect the iconic elements of a specific historical period (e.g., technology, clothing, architecture, lifestyle of that era)? +- Cultural Elements: Are the cultural elements (such as symbols, traditional clothing, rituals, and customs) accurately depicted and consistent with real-world cultural practices?""" + + +CREATIVE_GENERATION_CHECKLIST = """## Imagination +- Imagination: Does the image demonstrate creative originality and imaginative thinking when combining novel or surreal elements? +## Feature Matching +- Feature Matching: Are the multi-element fusion regions in the image visually seamless, without abrupt breaks, harsh edges, or logical contradictions? +## Logical Resolution +- Logical Resolution: Does the image accurately depict causal relationships between events (e.g., breaking glass → shards flying, rain → wet surfaces)? +## Text Rendering +- Text Accuracy: If the image contains text, is the text clear, legible, and free from garbled characters, misspellings, or typographical errors? +- Text Layout: Is the text layout (e.g., centering, alignment, line spacing, margins) in the image visually appealing and professionally structured? +- Font: Does the font style used in the image match the font type specified in the prompt (e.g., SimSun, Heiti, handwritten, serif)? 
+- Cross-lingual Generation: Does the image correctly follow the translation instructions in the prompt, producing accurate text in the target language? +## Design Applications +- Graphic Design: Does the graphic design (e.g., advertisement, poster) exhibit a clear information hierarchy, effective visual guidance, and professional layout? +- Product Design: Does the product design in the image demonstrate reasonable industrial design logic (e.g., ergonomic grip, logical interface placement, structural integrity)? +- Spatial Design: Does the interior or architectural space conform to the principles of perspective, proportion, and building design standards? +- Fashion Styling: Does the clothing cut and silhouette match the style described in the prompt (e.g., Hanfu, cyberpunk, haute couture)? Does the makeup style (e.g., smoky eyes, nude makeup, theatrical look) suit the occasion and character setting? +- Game Design: Do the game props and UI elements have practical in-game usability (e.g., icon recognizability, interactive affordances, clear feedback cues)? +- Art Design: Does the image successfully demonstrate the specific artistic design style required by the prompt (e.g., unique brushstrokes, distinctive color scheme, coherent artistic language)? +## Visual Storytelling +- Cinematic Style: Does the image reproduce the signature visual language of the specific director referenced in the prompt (e.g., Wes Anderson's symmetrical composition, Wong Kar-wai's warm color palette)? +- Camera / Lens Style: Does the image reflect the characteristic imaging effects of the specific photographic equipment or lens referenced in the prompt (e.g., film grain, bokeh, digital sharpening)? +- Storyboard Creation: Does the image's scene composition follow the panel layout requirements outlined in the prompt (e.g., three-panel, four-panel, split-screen)? 
+- Shot Sizes: Does the image meet the framing and shot size requirements specified in the prompt (e.g., close-up, medium shot, wide shot)? +- Composition: Does the image follow the specific composition rules required by the prompt (e.g., rule of thirds, golden ratio, leading lines)? +- Angles: Does the camera angle comply with the prompt's specification (e.g., bird's-eye view, low angle, Dutch angle)? +- Comic Creation: Does the image conform to the comic style required by the prompt (e.g., American comics, Japanese manga, European BD)?""" + +DIM_TO_CHECKLIST = { + "Quality": QUALITY_CHECKLIST, + "Aesthetics": AESTHETICS_CHECKLIST, + "Alignment": ALIGNMENT_CHECKLIST, + "Real-world Fidelity": REAL_WORLD_FIDELITY_CHECKLIST, + "Creative Generation": CREATIVE_GENERATION_CHECKLIST, +} +LEVEL1_DIMS = list(DIM_TO_CHECKLIST) + +SYSTEM_PROMPT = ( + "You are an expert evaluator for text-to-image (T2I) generation quality. " + "Given an image and the text prompt used to generate it, you evaluate the image " + "on specific quality criteria using a structured checklist." +) + +USER_PROMPT_HEADER = """# Text Prompt Used to Generate the Image +{prompt} + +# Generated Image +""" + +USER_PROMPT_BODY = """ + +# Evaluation Dimension +{level1_dim} + +# Scoring Rules +- **0 (Fail)**: Clear defect present. Would noticeably reduce image quality. +- **1 (Pass)**: No defect. Meets baseline expectations. +- **2 (Excel)**: Exceptionally executed. Only when concrete excellence is observable. +- **N/A**: This criterion does not apply to this image/prompt. 
+ +# Evaluation Checklist +{format_checklist} + +# Output Format +Respond with a valid JSON object only (no markdown code blocks): +{{ + "{{level2_dim}}": {{ + "{{level3_dim}}": {{"score": 0|1|2}}, + "{{level3_dim}}": {{"score": "N/A"}} + }} +}}""" + +SCORE_MAP = {0: 0.0, 1: 60.0, 2: 100.0} + +CHECKLIST_L3_TO_L2 = { + "Quality": { + "Physical Logic": "Realism", "Material Texture": "Realism", + "Noise": "Detail", "Edge Clarity": "Detail", "Naturalness": "Detail", + "Resolution": "Resolution", + }, + "Aesthetics": { + "Composition": "Composition", "Color Harmony": "Color Harmony", + "Lighting & Atmosphere": "Lighting", + "Anatomical Fidelity": "Anatomical Portraiture", + "Emotional Expression": "Emotional Expression", + "Style Control": "Style Control", + }, + "Alignment": { + "Quantity": "Attributes", "Facial Expression": "Attributes", + "Material Properties": "Attributes", "Color": "Attributes", + "Shape": "Attributes", "Size": "Attributes", + "Contact Interaction": "Actions", "Non-contact Interaction": "Actions", + "Full-body Action": "Actions", + "2D Space": "Layout", "3D Space": "Layout", + "Composition Relationship": "Relations", "Difference/Similarity": "Relations", + "Containment": "Relations", + "Real-world Scene": "Scene", "Virtual Scene": "Scene", + }, + "Real-world Fidelity": { + "Social Bias": "Fairness", "Cultural Fairness": "Fairness", + "Safety & Compliance": "Safety & Compliance", + "Animals": "World Knowledge", "Objects": "World Knowledge", + "Information Visualization": "World Knowledge", + "Temporal Characteristics": "World Knowledge", + "Cultural Elements": "World Knowledge", + }, + "Creative Generation": { + "Imagination": "Imagination", + "Feature Matching": "Feature Matching", + "Logical Resolution": "Logical Resolution", + "Text Accuracy": "Text Rendering", "Text Layout": "Text Rendering", + "Font": "Text Rendering", "Cross-lingual Generation": "Text Rendering", + "Graphic Design": "Design Applications", "Product Design": "Design 
def extract_json_from_response(response_text: str):
    """Extract the JSON payload from a raw model reply.

    Everything up to and including the last ``</think>`` tag is discarded, so
    JSON-looking fragments inside the reasoning trace are never mistaken for
    the answer. The remainder is parsed as JSON directly; if that fails, the
    widest ``{...}`` span in it is tried (this covers replies wrapped in
    markdown fences or surrounded by prose).

    Args:
        response_text: Decoded model output. ``None`` and ``""`` are accepted.

    Returns:
        The decoded JSON value, or ``None`` when nothing parseable is found.
    """
    text = response_text or ""

    # The marker must be a non-empty literal: ``str.rfind("")`` returns
    # ``len(text)``, which would slice the whole reply away.
    think_close = "</think>"
    think_end = text.rfind(think_close)
    if think_end != -1:
        text = text[think_end + len(think_close):]
    text = text.strip()
    if not text:
        return None

    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    # Fallback: first "{" to last "}", tolerating fences or surrounding prose.
    json_match = re.search(r"\{[\s\S]*\}", text)
    if json_match:
        try:
            return json.loads(json_match.group())
        except json.JSONDecodeError:
            pass
    return None
def fix_score_json(score_json, l1_dim):
    """Normalise a model reply into ``{level2: {level3: score_obj}}`` form.

    Three kinds of deviation are repaired using ``CHECKLIST_L3_TO_L2`` and
    ``L3_RENAME`` for the given level-1 dimension:

    * flat replies (``{level3: {"score": ...}}``) are regrouped under their
      level-2 parent;
    * level-3 items placed under the wrong level-2 key are moved;
    * known level-3 misspellings are renamed.

    Args:
        score_json: Decoded JSON reply of the model.
        l1_dim: Level-1 dimension the reply belongs to.

    Returns:
        The normalised nested dict. Falsy input is returned unchanged. A
        non-empty value that is not a dict (valid JSON, but a list, string
        or number) yields ``{}``, so callers get an empty score set instead
        of an ``AttributeError``.
    """
    if not score_json:
        return score_json
    if not isinstance(score_json, dict):
        # Valid JSON but not an object: there is nothing to regroup.
        return {}

    mapping = CHECKLIST_L3_TO_L2.get(l1_dim, {})
    rename = L3_RENAME.get(l1_dim, {})

    # The shape of the first entry decides flat vs. nested for the whole reply.
    first_val = next(iter(score_json.values()), None)
    if isinstance(first_val, dict) and "score" in first_val:
        result = {}
        for l3_name, score_obj in score_json.items():
            l3_name = rename.get(l3_name, l3_name)
            # Unknown level-3 names become their own level-2 group.
            l2_name = mapping.get(l3_name, l3_name)
            result.setdefault(l2_name, {})[l3_name] = score_obj
        return result

    result = {}
    for l2_key, l3_dict in score_json.items():
        if not isinstance(l3_dict, dict):
            continue
        for l3_name, score_obj in l3_dict.items():
            l3_name = rename.get(l3_name, l3_name)
            # Keep the model's own level-2 key when the item is not in the checklist.
            correct_l2 = mapping.get(l3_name, l2_key)
            result.setdefault(correct_l2, {})[l3_name] = score_obj
    return result
+ + Input: {"Realism": {"Physical Logic": {"score": 0}, "Material Properties": {"score": 1}}, ...} + Output: { + "level1_score": float | None, + "level2_scores": {"Realism": float | None, ...}, + "level3_scores": {"Realism": {"Physical Logic": 0.0, ...}, ...} + } + """ + level2_scores = {} + level3_scores = {} + + for level2_name, level3_dict in score_json.items(): + level3_scores[level2_name] = {} + level3_mapped = [] + + for level3_name, score_obj in level3_dict.items(): + raw = score_obj.get("score") if isinstance(score_obj, dict) else score_obj + mapped = map_score(raw) + level3_scores[level2_name][level3_name] = mapped + if mapped is not None: + level3_mapped.append(mapped) + + level2_scores[level2_name] = _mean_non_none(level3_mapped) + + level1_score = _mean_non_none(list(level2_scores.values())) + + return { + "level1_score": level1_score, + "level2_scores": level2_scores, + "level3_scores": level3_scores, + } + + +def aggregate_total_score(dim_results): + """ + Aggregate across all level-1 dimensions to total score. 
class QwenImageBenchModel(torch.nn.Module):
    """Score generated images with a Qwen-Image-Bench vision-language judge.

    For every requested level-1 dimension the wrapped model is prompted with
    the caption, the image and that dimension's checklist. The generated text
    is parsed as JSON and folded into per-dimension and total scores.
    """

    def __init__(self, model: torch.nn.Module, processor, max_new_tokens: int = 4096, resize_long_edge: int = 1024):
        """Store the judge model and its processor.

        Args:
            model: Module exposing ``generate``.
            processor: Object providing ``apply_chat_template``, ``__call__``,
                ``batch_decode`` and a ``tokenizer`` attribute.
            max_new_tokens: Generation budget for each dimension.
            resize_long_edge: Upper bound in pixels for the longer image side;
                a falsy value disables resizing.
        """
        super().__init__()
        self.model = model
        self.processor = processor
        self.max_new_tokens = max_new_tokens
        self.resize_long_edge = resize_long_edge

    @property
    def device(self):
        """Device of the first parameter, or of an empty default tensor if there are none."""
        return next(self.parameters(), torch.tensor([])).device

    @property
    def dtype(self):
        """Dtype of the first parameter, or of a default float tensor if there are none."""
        return next(self.parameters(), torch.tensor(0.0)).dtype

    def _prepare_image(self, image: Image.Image):
        """Convert to RGB and shrink so the longer side is at most ``resize_long_edge``.

        The aspect ratio is preserved; images already within the limit are
        returned at their original size.
        """
        image = image.convert("RGB")
        width, height = image.size
        long_edge = max(width, height)
        if self.resize_long_edge and long_edge > self.resize_long_edge:
            # Scale both sides by the same factor: forcing a square would
            # distort non-square images before the judge ever sees them.
            scale = self.resize_long_edge / long_edge
            target_size = (max(1, round(width * scale)), max(1, round(height * scale)))
            image = image.resize(target_size, Image.LANCZOS)
        return image

    @staticmethod
    def _validate_dimension(level1_dim: str):
        """Raise ``ValueError`` when ``level1_dim`` has no checklist."""
        if level1_dim not in DIM_TO_CHECKLIST:
            supported = ", ".join(LEVEL1_DIMS)
            raise ValueError(f"Unsupported Qwen-Image-Bench dimension: {level1_dim}. Supported dimensions: {supported}.")

    @classmethod
    def _dims_to_level1(cls, dimensions=None):
        """Turn a dimension spec into a validated list of level-1 names.

        Accepts ``None`` (all dimensions), a dict (its keys), a string (either a
        ``/``- or ``;``-delimited spec handled by ``parse_dims_by_level1``, a
        comma-separated list, or a single name) or any other iterable of names.

        Raises:
            ValueError: If any resulting name is unsupported.
        """
        if dimensions is None:
            # Copy so callers can never mutate the module-level default.
            return list(LEVEL1_DIMS)
        if isinstance(dimensions, dict):
            dimensions = list(dimensions.keys())
        elif isinstance(dimensions, str):
            if "/" in dimensions or ";" in dimensions:
                dimensions = list(parse_dims_by_level1(dimensions).keys())
            elif "," in dimensions:
                dimensions = [dim.strip() for dim in dimensions.split(",") if dim.strip()]
            else:
                dimensions = [dimensions]
        else:
            dimensions = list(dimensions)
        if not dimensions:
            dimensions = list(LEVEL1_DIMS)
        for level1_dim in dimensions:
            cls._validate_dimension(level1_dim)
        return dimensions

    @classmethod
    def _normalize_dimensions(cls, dimensions, batch_size: int):
        """Return one list of level-1 dimensions per sample.

        A non-empty list/tuple whose items are not all strings is treated as a
        per-sample spec and must have ``batch_size`` entries; anything else is
        a single spec applied to every sample.

        Raises:
            ValueError: If a per-sample spec has the wrong length.
        """
        per_sample = (
            isinstance(dimensions, (list, tuple))
            and dimensions
            and not all(isinstance(dim, str) for dim in dimensions)
        )
        if per_sample:
            if len(dimensions) != batch_size:
                raise ValueError(f"Expected {batch_size} dimension sets, got {len(dimensions)}.")
            return [cls._dims_to_level1(dim_set) for dim_set in dimensions]
        dims = cls._dims_to_level1(dimensions)
        # Give each sample its own list instead of aliasing a single object.
        return [list(dims) for _ in range(batch_size)]

    def _build_messages(self, prompt: str, image: Image.Image, level1_dim: str):
        """Build the system + user chat messages for one dimension."""
        self._validate_dimension(level1_dim)
        checklist = DIM_TO_CHECKLIST[level1_dim]
        user_content = [
            {"type": "text", "text": USER_PROMPT_HEADER.format(prompt=prompt or "")},
            {"type": "image", "image": image},
            {"type": "text", "text": USER_PROMPT_BODY.format(level1_dim=level1_dim, format_checklist=checklist)},
        ]
        return [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
        ]

    def _apply_chat_template(self, messages):
        """Render ``messages`` to prompt text, requesting thinking mode when accepted."""
        try:
            return self.processor.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=True,
            )
        except TypeError:
            # Fallback for processors whose template call rejects `enable_thinking`.
            return self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    def _processor_inputs(self, text, image):
        """Tokenize one text/image pair and move it to the model's device and dtype.

        Returns:
            ``(inputs, input_ids)`` where ``input_ids`` is captured before the
            device move and is later used only to measure the prompt length.
        """
        inputs = self.processor(text=[text], images=[image], padding=True, return_tensors="pt")
        input_ids = inputs["input_ids"]
        inputs = inputs.to(self.device)
        if self.dtype != torch.float32:
            # Cast floating tensors only; integer ids and masks stay untouched.
            inputs = {
                name: value.to(dtype=self.dtype) if torch.is_tensor(value) and torch.is_floating_point(value) else value
                for name, value in inputs.items()
            }
        return inputs, input_ids

    def _decode_dimension(self, prompt: str, image: Image.Image, level1_dim: str):
        """Generate greedily for one dimension and return the newly generated text."""
        messages = self._build_messages(prompt, image, level1_dim)
        text = self._apply_chat_template(messages)
        inputs, input_ids = self._processor_inputs(text, image)
        with torch.no_grad():
            generated_ids = self.model.generate(
                **inputs,
                do_sample=False,
                repetition_penalty=1.05,
                max_new_tokens=self.max_new_tokens,
                pad_token_id=getattr(self.processor.tokenizer, "pad_token_id", None),
                eos_token_id=getattr(self.processor.tokenizer, "eos_token_id", None),
            )
        # Drop the echoed prompt tokens so only the completion is decoded.
        generated_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(input_ids, generated_ids)]
        return self.processor.batch_decode(generated_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

    @staticmethod
    def _parse_dimension_output(raw_text: str, level1_dim: str):
        """Parse raw model text into ``(dimension_score, fixed_json)``, or ``None`` when no JSON is found."""
        score_json = extract_json_from_response(raw_text)
        if score_json is None:
            return None
        score_json = fix_score_json(score_json, level1_dim)
        return compute_dimension_score(score_json), score_json

    def _evaluate_sample(self, prompt: str, image: Image.Image, dimensions):
        """Evaluate one image on every requested dimension.

        Returns:
            Dict with ``total_score``, ``level1_scores``, ``level2_scores``,
            ``level3_scores``, ``raw_scores``, ``raw_outputs`` and
            ``parse_failures`` (dimensions whose output could not be parsed).
        """
        image = self._prepare_image(image)
        raw_outputs = {}
        raw_scores = {}
        dimension_scores = {}
        parse_failures = []
        for level1_dim in dimensions:
            raw_text = self._decode_dimension(prompt, image, level1_dim)
            raw_outputs[level1_dim] = raw_text
            parsed = self._parse_dimension_output(raw_text, level1_dim)
            if parsed is None:
                raw_scores[level1_dim] = None
                dimension_scores[level1_dim] = None
                parse_failures.append(level1_dim)
                continue
            dimension_score, fixed_score_json = parsed
            raw_scores[level1_dim] = fixed_score_json
            dimension_scores[level1_dim] = dimension_score
        total_score = aggregate_total_score(dimension_scores)
        return {
            "total_score": total_score,
            "level1_scores": {
                dim: data.get("level1_score") if data is not None else None for dim, data in dimension_scores.items()
            },
            "level2_scores": {
                dim: data.get("level2_scores", {}) if data is not None else {} for dim, data in dimension_scores.items()
            },
            "level3_scores": {
                dim: data.get("level3_scores", {}) if data is not None else {} for dim, data in dimension_scores.items()
            },
            "raw_scores": raw_scores,
            "raw_outputs": raw_outputs,
            "parse_failures": parse_failures,
        }

    def _normalize_inputs(self, prompts, images):
        """Broadcast a single prompt or image so both lists have equal length.

        Raises:
            ValueError: If the lengths still differ after broadcasting.
        """
        prompts = _as_list(prompts if prompts is not None else "")
        images = _as_list(images)
        if len(prompts) == 1 and len(images) > 1:
            prompts = prompts * len(images)
        if len(images) == 1 and len(prompts) > 1:
            images = images * len(prompts)
        if len(prompts) != len(images):
            raise ValueError(f"Expected the same number of prompts and images, got {len(prompts)} and {len(images)}.")
        return prompts, images

    @staticmethod
    def _primary_score(parsed: dict):
        """Return ``total_score`` as a float, or ``0.0`` when it is missing."""
        score = parsed.get("total_score")
        return float(score) if score is not None else 0.0

    @torch.no_grad()
    def forward(self, prompt: str | list[str] | None, images: ImageInput, dimensions=None):
        """Evaluate every prompt/image pair and return one result dict per pair.

        Args:
            prompt: One caption, a list of captions, or ``None`` for an empty caption.
            images: One image or a sequence of images.
            dimensions: Dimension spec shared by all samples, or one spec per sample.
        """
        prompts, images = self._normalize_inputs(prompt, images)
        dimension_sets = self._normalize_dimensions(dimensions, len(prompts))
        outputs = []
        for single_prompt, single_image, single_dimensions in zip(prompts, images, dimension_sets):
            outputs.append(self._evaluate_sample(single_prompt, single_image, single_dimensions))
        return outputs
class UnifiedReward2Qwen35ForConditionalGeneration(torch.nn.Module):
    """Thin wrapper that builds a Qwen3.5 conditional-generation model from a hard-coded config."""

    def __init__(self, variant: str = "qwen35_9b"):
        """Instantiate the model for ``variant``.

        Raises:
            ValueError: If ``variant`` is anything other than ``"qwen35_9b"``.
        """
        super().__init__()
        from transformers import Qwen3_5Config, Qwen3_5ForConditionalGeneration

        if variant != "qwen35_9b":
            raise ValueError(f"Unsupported UnifiedReward-2 variant: {variant}")

        # Sizes shared between the top-level, text and vision sections.
        hidden_size = 4096
        intermediate_size = 12288
        num_hidden_layers = 32
        num_attention_heads = 16
        linear_num_value_heads = 32

        rope_parameters = {
            "mrope_interleaved": True,
            "mrope_section": [11, 11, 10],
            "partial_rotary_factor": 0.25,
            "rope_theta": 10000000,
            "rope_type": "default",
        }
        text_config = {
            "attention_bias": False,
            "attention_dropout": 0.0,
            "attn_output_gate": True,
            "bos_token_id": None,
            "eos_token_id": 248044,
            "full_attention_interval": 4,
            "head_dim": 256,
            "hidden_act": "silu",
            "hidden_size": hidden_size,
            "initializer_range": 0.02,
            "intermediate_size": intermediate_size,
            "linear_conv_kernel_dim": 4,
            "linear_key_head_dim": 128,
            "linear_num_key_heads": 16,
            "linear_num_value_heads": linear_num_value_heads,
            "linear_value_head_dim": 128,
            "mamba_ssm_dtype": "float32",
            "max_position_embeddings": 262144,
            "mlp_only_layers": [],
            "model_type": "qwen3_5_text",
            "mtp_num_hidden_layers": 1,
            "mtp_use_dedicated_embeddings": False,
            "num_attention_heads": num_attention_heads,
            "num_hidden_layers": num_hidden_layers,
            "num_key_value_heads": 4,
            "pad_token_id": None,
            "partial_rotary_factor": 0.25,
            "rms_norm_eps": 1e-6,
            "rope_parameters": rope_parameters,
            "tie_word_embeddings": False,
            "use_cache": True,
            "vocab_size": 248320,
        }
        vision_config = {
            "deepstack_visual_indexes": [],
            "depth": 27,
            "hidden_act": "gelu_pytorch_tanh",
            "hidden_size": 1152,
            "in_channels": 3,
            "initializer_range": 0.02,
            "intermediate_size": 4304,
            "model_type": "qwen3_5",
            "num_heads": 16,
            "num_position_embeddings": 2304,
            "out_hidden_size": hidden_size,
            "patch_size": 16,
            "spatial_merge_size": 2,
            "temporal_patch_size": 2,
        }

        config = Qwen3_5Config(
            bos_token_id=None,
            eos_token_id=248046,
            hidden_size=hidden_size,
            image_token_id=248056,
            model_type="qwen3_5",
            pad_token_id=248044,
            text_config=text_config,
            tie_word_embeddings=False,
            use_cache=True,
            video_token_id=248057,
            vision_config=vision_config,
            vision_end_token_id=248054,
            vision_start_token_id=248053,
        )
        # Three linear-attention layers followed by one full-attention layer,
        # repeated to cover all hidden layers; set after construction.
        layer_group = ["linear_attention", "linear_attention", "linear_attention", "full_attention"]
        config.text_config.layer_types = layer_group * (num_hidden_layers // 4)
        self.model = Qwen3_5ForConditionalGeneration(config)

    def forward(self, *args, **kwargs):
        """Delegate to the wrapped model's forward call."""
        return self.model(*args, **kwargs)

    def generate(self, *args, **kwargs):
        """Delegate to the wrapped model's ``generate``."""
        return self.model.generate(*args, **kwargs)
def _processor_inputs(self, text, images):
    """Tokenize one prompt with its images and move the batch to the model.

    Returns:
        ``(inputs, input_ids)``; ``input_ids`` is taken before the device
        move, so it is whatever tensor the processor produced.
    """
    batch = self.processor(text=[text], images=images, padding=True, return_tensors="pt")
    input_ids = batch["input_ids"]
    batch = batch.to(self.device)
    if self.dtype == torch.float32:
        return batch, input_ids

    # Non-fp32 model: cast floating tensors only, leaving integer ids and
    # masks (and any non-tensor entries) exactly as they are.
    converted = {}
    for name, value in batch.items():
        if torch.is_tensor(value) and torch.is_floating_point(value):
            converted[name] = value.to(dtype=self.dtype)
        else:
            converted[name] = value
    return converted, input_ids
pad_token_id=getattr(self.processor.tokenizer, "pad_token_id", None), + eos_token_id=getattr(self.processor.tokenizer, "eos_token_id", None), + ) + generated_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(input_ids, generated_ids)] + return self.processor.batch_decode(generated_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + + @staticmethod + def _extract_labeled_number(text: str, label: str): + text = text.replace("*", "") + pattern = rf"{re.escape(label)}\s*(?:\([^)]+\))?\s*[::]\s*([-+]?\d*\.?\d+)" + match = re.search(pattern, text, flags=re.I) + return _coerce_float(match.group(1)) if match else None + + def _parse_output(self, text: str): + alignment = self._extract_labeled_number(text, "Alignment Score") + coherence = self._extract_labeled_number(text, "Coherence Score") + style = self._extract_labeled_number(text, "Style Score") + raw_score = [score for score in (alignment, coherence, style) if score is not None] + return { + "alignment": alignment, + "coherence": coherence, + "style": style, + "score": _mean(raw_score), + } + + @staticmethod + def _primary_score(parsed: dict): + return parsed.get("score") + + @torch.no_grad() + def forward(self, prompt: str | list[str] | None, images: ImageInput): + prompts = _as_list(prompt if prompt is not None else "") + images = _as_list(images) + if len(prompts) == 1 and len(images) > 1: + prompts = prompts * len(images) + if len(images) == 1 and len(prompts) > 1: + images = images * len(prompts) + if len(prompts) != len(images): + raise ValueError(f"Expected the same number of prompts and images, got {len(prompts)} and {len(images)}.") + outputs = [] + for single_prompt, single_image in zip(prompts, images): + raw_text = self._decode_sample(single_prompt, single_image) + outputs.append(self._parse_output(raw_text)) + return outputs diff --git a/diffsynth/models/unified_reward_edit.py b/diffsynth/models/unified_reward_edit.py new file mode 100644 index 000000000..71964b31a --- 
class UnifiedRewardQwen3VLForConditionalGeneration(torch.nn.Module):
    """Thin wrapper that builds a Qwen3-VL conditional-generation model from a hard-coded config."""

    def __init__(self, variant: str = "qwen3vl"):
        """Instantiate the model; ``variant`` is accepted but not consulted."""
        super().__init__()
        from transformers import Qwen3VLConfig, Qwen3VLForConditionalGeneration

        rope_scaling = {
            "mrope_interleaved": True,
            "mrope_section": [24, 20, 20],
            "rope_type": "default",
        }
        text_config = {
            "attention_bias": False,
            "attention_dropout": 0.0,
            "bos_token_id": 151643,
            "eos_token_id": 151645,
            "head_dim": 128,
            "hidden_act": "silu",
            "hidden_size": 4096,
            "initializer_range": 0.02,
            "intermediate_size": 12288,
            "max_position_embeddings": 262144,
            "model_type": "qwen3_vl_text",
            "num_attention_heads": 32,
            "num_hidden_layers": 36,
            "num_key_value_heads": 8,
            "rms_norm_eps": 1e-6,
            "rope_scaling": rope_scaling,
            "rope_theta": 5000000,
            "use_cache": True,
            "vocab_size": 151936,
        }
        vision_config = {
            "deepstack_visual_indexes": [8, 16, 24],
            "depth": 27,
            "hidden_act": "gelu_pytorch_tanh",
            "hidden_size": 1152,
            "in_channels": 3,
            "initializer_range": 0.02,
            "intermediate_size": 4304,
            "model_type": "qwen3_vl",
            "num_heads": 16,
            "num_position_embeddings": 2304,
            "out_hidden_size": 4096,
            "patch_size": 16,
            "spatial_merge_size": 2,
            "temporal_patch_size": 2,
        }
        config = Qwen3VLConfig(
            text_config=text_config,
            vision_config=vision_config,
            image_token_id=151655,
            video_token_id=151656,
            vision_start_token_id=151652,
            vision_end_token_id=151653,
            tie_word_embeddings=False,
        )
        self.model = Qwen3VLForConditionalGeneration(config)

    def forward(self, *args, **kwargs):
        """Delegate to the wrapped model's forward call."""
        return self.model(*args, **kwargs)

    def generate(self, *args, **kwargs):
        """Delegate to the wrapped model's ``generate``."""
        return self.model.generate(*args, **kwargs)
_coerce_float(value): + if isinstance(value, bool): + return None + if isinstance(value, (int, float)): + return float(value) + if isinstance(value, str): + match = re.search(r"[-+]?\d*\.?\d+", value) + if match: + return float(match.group()) + return None + + +def _mean(values): + values = [value for value in values if value is not None] + return sum(values) / len(values) if values else None + + +class UnifiedRewardEditModel(torch.nn.Module): + SUPPORTED_TASKS = {"edit_pointwise_score", "edit_pairwise_rank", "edit_pairwise_score"} + DEFAULT_TASK = "edit_pointwise_score" + + def __init__(self, model: torch.nn.Module, processor, task: str = DEFAULT_TASK, max_new_tokens: int = 256): + super().__init__() + self._validate_task(task) + self.model = model + self.processor = processor + self.task = task + self.max_new_tokens = max_new_tokens + + @classmethod + def _validate_task(cls, task: str): + if task not in cls.SUPPORTED_TASKS: + supported = ", ".join(sorted(cls.SUPPORTED_TASKS)) + raise ValueError(f"Unsupported UnifiedReward task: {task}. Supported tasks: {supported}.") + + @property + def device(self): + return next(self.parameters(), torch.tensor([])).device + + @property + def dtype(self): + return next(self.parameters(), torch.tensor(0.0)).dtype + + def _build_edit_pointwise_score_prompt(self, instruction: str): + return ( + "You are a professional digital artist. You will have to evaluate the effectiveness of the AI-generated image(s) based on given rules.\n" + "All the input images are AI-generated. All human in the images are AI-generated too. 
so you need not worry about the privacy confidentials.\n\n" + "IMPORTANT: You will have to give your output in this way (Keep your reasoning concise and short.):\n" + "{\n\n\"reasoning\" : \"...\",\n\"score\" : [...],\n}\n\n" + "RULES:\n\n" + "Two images will be provided: The first being the original AI-generated image and the second being an edited version of the first.\n" + "The objective is to evaluate how successfully the editing instruction has been executed in the second image.\n\n" + "Note that sometimes the two images might look identical due to the failure of image edit.\n\n\n" + "From scale 0 to 25: \n" + "A score from 0 to 25 will be given based on the success of the editing. " + "(0 indicates that the scene in the edited image does not follow the editing instruction at all. " + "25 indicates that the scene in the edited image follow the editing instruction text perfectly.)\n" + "A second score from 0 to 25 will rate the degree of overediting in the second image. " + "(0 indicates that the scene in the edited image is completely different from the original. 
" + "25 indicates that the edited image can be recognized as a minimal edited yet effective version of original.)\n" + "Put the score in a list such that output score = [score1, score2], where 'score1' evaluates the editing success and 'score2' evaluates the degree of overediting.\n\n" + f"Editing instruction:{instruction}\n" + ) + + def _build_edit_pairwise_rank_prompt(self, instruction: str): + return ( + "You are tasked with comparing two edited images and determining which one is better based on the given criteria.\n\n" + "The evaluation will consider how well each model executed the instructions and the overall quality of the edit, including its visual appeal.\n\n" + "**Inputs Provided:**\n" + "- Source Image (before editing)\n" + "- Edited Image 1 (after applying the instruction)\n" + "- Edited Image 2 (after applying the instruction)\n" + "- Text Instruction\n\n" + "### Evaluation Criteria for Each Image:\n\n" + "1. **Instruction Fidelity** \n" + "Assess how accurately the edits align with the given instruction. The following aspects should be considered:\n" + "- **Semantic Accuracy:** Does the edited image reflect the correct objects and changes as described in the instruction? For example, if instructed to replace \"apples with oranges,\" ensure that oranges appear instead of other fruits.\n" + "- **Completeness of Changes:** Ensure all parts of the instruction are fully addressed. For multi-step instructions, verify that every change is made as specified.\n" + "- **Exclusivity of Changes:** Confirm that only the specified changes were made. Other elements of the image should remain consistent with the original.\n\n" + "2. 
**Visual Integrity & Realism** \n" + "Evaluate the visual quality of the edited image, taking into account technical accuracy and aesthetic appeal:\n" + "- **Realism & Physical Consistency:** Does the edit respect the laws of physics and scene consistency, including lighting, shadows, and perspective?\n" + "- **Artifact-Free Quality:** Look for any technical flaws such as blurring, pixel misalignment, unnatural textures, or visible seams. The image should be clean and free from distractions.\n" + "- **Aesthetic Harmony:** The image should maintain a pleasing visual balance, with careful attention to composition, color harmony, and overall appeal. The changes should enhance the image rather than detract from it.\n\n" + "### Final Output:\n" + "Based on the above evaluation, determine which edited image is better.\n\n" + f"Text instruction - {instruction}\n" + ) + + def _build_edit_pairwise_score_prompt(self, instruction: str): + return ( + "You are tasked with assigning scores to two edited images, comparing each with the original source image. \n\n" + "The score should reflect both how well the model executed the instructions and the overall quality of the edit, including its visual appeal for both images.\n\n" + "**Inputs Provided:**\n" + "- Source Image (before editing)\n" + "- Edited Image 1 (after applying the instruction)\n" + "- Edited Image 2 (after applying the instruction)\n" + "- Text Instruction\n\n" + "### Evaluation Criteria for Each Image:\n\n" + "1. **Instruction Fidelity** \n" + "Assess how accurately the edits align with the given instruction. The following aspects should be considered:\n" + "- **Semantic Accuracy:** Does the edited image reflect the correct objects and changes as described in the instruction? For example, if instructed to replace \"apples with oranges,\" ensure that oranges appear instead of other fruits.\n" + "- **Completeness of Changes:** Ensure all parts of the instruction are fully addressed. 
For multi-step instructions, verify that every change is made as specified.\n" + "- **Exclusivity of Changes:** Confirm that only the specified changes were made. Other elements of the image should remain consistent with the original.\n\n" + "2. **Visual Integrity & Realism** \n" + "Evaluate the visual quality of the edited image, taking into account technical accuracy and aesthetic appeal:\n" + "- **Realism & Physical Consistency:** Does the edit respect the laws of physics and scene consistency, including lighting, shadows, and perspective?\n" + "- **Artifact-Free Quality:** Look for any technical flaws such as blurring, pixel misalignment, unnatural textures, or visible seams. The image should be clean and free from distractions.\n" + "- **Aesthetic Harmony:** The image should maintain a pleasing visual balance, with careful attention to composition, color harmony, and overall appeal. The changes should enhance the image rather than detract from it.\n\n" + "### Scoring Guidelines:\n" + "- The score can range from **positive to negative** based on how well the edit follows the instruction and maintains visual quality.\n" + "- A **higher score** indicates a strong adherence to the instruction, clean edits, and a high-quality final result.\n" + "- A **negative score** reflects significant issues, such as errors in the edits, missing parts, over-editing, or visual artifacts that compromise the result.\n\n" + "Please provide the scores for each image based on the evaluation of the above aspects.\n\n" + f"Text instruction - {instruction}\n" + ) + + def _build_prompt(self, prompt: str | None, task: str): + self._validate_task(task) + instruction = prompt or "" + if task == "edit_pointwise_score": + return self._build_edit_pointwise_score_prompt(instruction) + if task == "edit_pairwise_rank": + return self._build_edit_pairwise_rank_prompt(instruction) + return self._build_edit_pairwise_score_prompt(instruction) + + def _build_messages(self, prompt: str | None, images, 
def _decode_sample(self, prompt: str | None, images, task: str):
    """Run greedy generation for one sample and return only the new text.

    Args:
        prompt: Editing instruction, or ``None``.
        images: One image or a sequence of images for this sample.
        task: Task name forwarded to ``_build_messages``.
    """
    messages = self._build_messages(prompt, images, task)
    chat_text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    rgb_images = [image.convert("RGB") for image in _as_list(images)]
    inputs, input_ids = self._processor_inputs(chat_text, rgb_images)

    tokenizer = self.processor.tokenizer
    with torch.no_grad():
        generated_ids = self.model.generate(
            **inputs,
            do_sample=False,
            max_new_tokens=self.max_new_tokens,
            pad_token_id=getattr(tokenizer, "pad_token_id", None),
            eos_token_id=getattr(tokenizer, "eos_token_id", None),
        )

    # Slice off the echoed prompt tokens so only the completion is decoded.
    completions = []
    for in_ids, out_ids in zip(input_ids, generated_ids):
        completions.append(out_ids[len(in_ids) :])
    decoded = self.processor.batch_decode(
        completions, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return decoded[0]
flags=re.S) + if not match: + return None + payload = match.group(0) + for loader in (json.loads, lambda s: json.loads(s.replace("'", '"'))): + try: + return loader(payload) + except Exception: + pass + return None + + @staticmethod + def _extract_score_list(text: str): + match = re.search(r'"?score"?\s*:\s*\[([^\]]+)\]', text, flags=re.I) + if not match: + return [] + scores = [_coerce_float(value) for value in re.findall(r"[-+]?\d*\.?\d+", match.group(1))] + return [score for score in scores if score is not None] + + @staticmethod + def _extract_pair_scores(text: str): + patterns = [ + r"Image\s*1[^-+\d]*([-+]?\d*\.?\d+).*?Image\s*2[^-+\d]*([-+]?\d*\.?\d+)", + r"Edited\s*Image\s*1[^-+\d]*([-+]?\d*\.?\d+).*?Edited\s*Image\s*2[^-+\d]*([-+]?\d*\.?\d+)", + r"score(?:s)?[^-+\d]*([-+]?\d*\.?\d+)[^\n\r-+\d]+([-+]?\d*\.?\d+)", + ] + for pattern in patterns: + match = re.search(pattern, text, flags=re.I | re.S) + if match: + return [_coerce_float(match.group(1)), _coerce_float(match.group(2))] + return [] + + @staticmethod + def _parse_rank(text: str): + if re.search(r"both images? 
(are )?(equally good|equal|tie)", text, flags=re.I): + return "Edited image 1 and 2 are equally good" + if re.search(r"((edited )?image\s*1|first image).{0,24}\b(better|best|wins?|preferred|superior)\b", text, flags=re.I | re.S): + return "Edited image 1" + if re.search(r"((edited )?image\s*2|second image).{0,24}\b(better|best|wins?|preferred|superior)\b", text, flags=re.I | re.S): + return "Edited image 2" + return None + + def _parse_pointwise_score(self, text: str): + parsed = self._extract_first_json(text) + scores = [] + reasoning = None + if isinstance(parsed, dict): + if isinstance(parsed.get("score"), list): + scores = [_coerce_float(value) for value in parsed["score"]] + scores = [value for value in scores if value is not None] + reasoning = parsed.get("reasoning") + if not scores: + scores = self._extract_score_list(text) + editing_success = scores[0] if len(scores) > 0 else None + overediting = scores[1] if len(scores) > 1 else None + return { + "editing_success": editing_success, + "overediting": overediting, + "score": _mean([editing_success, overediting]), + "reasoning": reasoning, + } + + def _parse_pairwise_score(self, text: str): + scores = self._extract_score_list(text) + if len(scores) < 2: + scores = self._extract_pair_scores(text) + scores = [score for score in scores if score is not None] + return { + "image_1_score": scores[0] if len(scores) > 0 else None, + "image_2_score": scores[1] if len(scores) > 1 else None, + } + + def _parse_pairwise_rank(self, text: str): + return { + "winner": self._parse_rank(text), + } + + def _parse_output(self, text: str, task: str): + self._validate_task(task) + if task == "edit_pointwise_score": + return self._parse_pointwise_score(text) + if task == "edit_pairwise_score": + return self._parse_pairwise_score(text) + return self._parse_pairwise_rank(text) + + @staticmethod + def _primary_score(parsed: dict, task: str): + if task == "edit_pairwise_rank": + return int(parsed.get("winner") or 0) + if task == 
"edit_pairwise_score": + return [parsed.get("image_1_score"), parsed.get("image_2_score")] + return parsed.get("score") + + @torch.no_grad() + def forward(self, prompt: str | list[str] | None, images, task: str | None = None): + task = task or self.task + self._validate_task(task) + prompts = _as_list(prompt if prompt is not None else "") + images = _as_list(images) + + if images and isinstance(images[0], Image.Image): + images = [images] + + if len(prompts) == 1 and len(images) > 1: + prompts = prompts * len(images) + if len(images) == 1 and len(prompts) > 1: + images = images * len(prompts) + if len(prompts) != len(images): + raise ValueError(f"Expected the same number of prompts and images, got {len(prompts)} and {len(images)}.") + + outputs = [] + for single_prompt, single_images in zip(prompts, images): + raw_text = self._decode_sample(single_prompt, single_images, task) + outputs.append(self._parse_output(raw_text, task)) + return outputs + diff --git a/diffsynth/utils/state_dict_converters/image_metrics.py b/diffsynth/utils/state_dict_converters/image_metrics.py index 30c8b55a3..7ec8478ce 100644 --- a/diffsynth/utils/state_dict_converters/image_metrics.py +++ b/diffsynth/utils/state_dict_converters/image_metrics.py @@ -90,6 +90,20 @@ def ImageMetricsHPSv3StateDictConverter(state_dict): return converted +def ImageMetricsUnifiedRewardStateDictConverter(state_dict): + converted = {} + for key in state_dict: + value = state_dict[key] + if key == "lm_head.weight": + key = "model.lm_head.weight" + elif key.startswith("model.language_model."): + key = "model.model." + key[len("model.") :] + elif key.startswith("model.visual."): + key = "model.model." 
+ key[len("model.") :] + converted[key] = value + return converted + + def _convert_open_clip_resblock(prefix, suffix, value): converted = {} parts = suffix.split(".") diff --git a/docs/en/Model_Details/Image-Quality-Metrics.md b/docs/en/Model_Details/Image-Quality-Metrics.md index c02f26e07..71c2cf204 100644 --- a/docs/en/Model_Details/Image-Quality-Metrics.md +++ b/docs/en/Model_Details/Image-Quality-Metrics.md @@ -47,12 +47,15 @@ print(f"PickScore score:: {score:.3f}") | HPSv2 | prompt + PIL Image | Preference Score | [code](../../../examples/image_quality_metric/hpsv2.py) | | HPSv3 | prompt + PIL Image | Preference Score | [code](../../../examples/image_quality_metric/hpsv3.py) | | CLIP Score | prompt + PIL Image | Text-Image Similarity | [code](../../../examples/image_quality_metric/clipscore.py) | +| UnifiedReward 2.0 | prompt + PIL Image | multi-dimension scores | [code](../../../examples/image_quality_metric/unified_reward_2.py) | +| Qwen-Image-Bench | prompt + PIL Image | Overall score and multi-level dimension scores | [code](../../../examples/image_quality_metric/qwen_image_bench.py) | +| UnifiedReward Edit | editing instruction + source image + edited image | Image editing quality score | [code](../../../examples/image_quality_metric/unified_reward_edit.py) | | Aesthetic | PIL Image | Aesthetic Score | [code](../../../examples/image_quality_metric/aesthetic.py) | | FID | reference image directory + generated image directory | Distribution Distance | [code](../../../examples/image_quality_metric/fid.py) | ### Text-Image Alignment and Preference Evaluation -Applicable metrics: **PickScore**, **ImageReward**, **HPSv2**, **HPSv3**, **CLIP Score** +Applicable metrics: **PickScore**, **ImageReward**, **HPSv2**, **HPSv3**, **CLIP Score**, **UnifiedReward 2.0**, **Qwen-Image-Bench** These models are used to evaluate whether an image follows the prompt and aligns with human visual preferences. They must receive both the `prompt` and the `image` simultaneously. 
@@ -74,6 +77,69 @@ scores = metric.compute(["a cat", "a dog"], [image_cat, image_dog]) When prompt is a single string, the same prompt will be applied to every image. When prompt is a list of strings, the number of prompts must exactly match the number of images. +### Multi-Dimensional Image Quality Evaluation + +Applicable metrics: **UnifiedReward 2.0**, **Qwen-Image-Bench** + +These metrics also receive a `prompt` and an `image`, but in addition to the primary score, `evaluate()` returns more detailed evaluation dimensions. They are useful when you need to analyze text-image alignment, visual coherence, style, or multi-level quality dimensions. + +**Qwen-Image-Bench** + +```python +from diffsynth.metrics import ModelConfig, QwenImageBenchMetric + +metric = QwenImageBenchMetric.from_pretrained( + model_config=ModelConfig( + model_id="Qwen/Qwen-Image-Bench", + origin_file_pattern="model-*.safetensors", + ), + processor_config=ModelConfig( + model_id="Qwen/Qwen-Image-Bench", + origin_file_pattern="", + ), + device="cuda", +) +details = metric.evaluate(prompt, image)[0] +score = details["total_score"] +print(details["level1_scores"]) +print(details["level2_scores"]) +``` + +If you only need the primary score, you can also call `metric.compute(prompt, image)`. + +### Image Editing Quality Evaluation + +Applicable metric: **UnifiedReward Edit** + +UnifiedReward Edit evaluates whether an edited image follows the editing instruction and whether it is over-edited. The input usually includes an editing instruction, a source image, and edited image candidates. It supports three tasks: + +* `edit_pointwise_score`: scores a single edited result with `[source_image, edited_image]`. +* `edit_pairwise_rank`: compares two edited results and returns the winner with `[source_image, edited_image_1, edited_image_2]`. +* `edit_pairwise_score`: returns separate scores for two edited results with `[source_image, edited_image_1, edited_image_2]`. 
+ +```python +from diffsynth.metrics import ModelConfig, UnifiedRewardEditMetric + +metric = UnifiedRewardEditMetric.from_pretrained( + model_config=ModelConfig( + model_id="DiffSynth-Studio/ImageMetrics", + origin_file_pattern="UnifiedReward-Edit-qwen3vl-8b/model-*.safetensors", + ), + processor_config=ModelConfig( + model_id="DiffSynth-Studio/ImageMetrics", + origin_file_pattern="UnifiedReward-Edit-qwen3vl-8b/", + ), + device="cuda", +) + +details = metric.evaluate( + instruction, + [source_image, edited_image], + task="edit_pointwise_score", +)[0] +print(details["score"], details["editing_success"], details["overediting"]) +``` + ### Pure Image Aesthetics Evaluation Applicable metric: **Aesthetic** @@ -108,6 +174,6 @@ The baseline for FID is not fixed or unique. For general image generation, COCO ## Important Notes -* The scores from PickScore, ImageReward, HPSv2, HPSv3, CLIPScore, and Aesthetic are suitable for relative comparison within the same metric. It is not recommended to directly compare the numerical values across different metrics. -* HPSv3 is based on Qwen2-VL and is a larger model, requiring significantly more VRAM than CLIP-based metrics. +* The scores from PickScore, ImageReward, HPSv2, HPSv3, CLIPScore, UnifiedReward 2.0, Qwen-Image-Bench, UnifiedReward Edit, and Aesthetic are suitable for relative comparison within the same metric. It is not recommended to directly compare the numerical values across different metrics. +* HPSv3, UnifiedReward 2.0, UnifiedReward Edit, and Qwen-Image-Bench are based on multimodal large models, requiring significantly more VRAM than CLIP-based metrics. * FID is sensitive to the choice of reference, the reference sample size, and the generated sample size. 
diff --git a/docs/en/Model_Details/Overview.md b/docs/en/Model_Details/Overview.md index 28f3fbf63..50c11cebc 100644 --- a/docs/en/Model_Details/Overview.md +++ b/docs/en/Model_Details/Overview.md @@ -334,4 +334,7 @@ print("PickScore score:", metric.compute(prompt, image)[0]) |HPSv3|[GitHub](https://github.com/MizzenAI/HPSv3)|[code](../../../examples/image_quality_metric/hpsv3.py)| |CLIP Score|[GitHub](https://github.com/openai/CLIP)|[code](../../../examples/image_quality_metric/clipscore.py)| |Aesthetic|[GitHub](https://github.com/christophschuhmann/improved-aesthetic-predictor)|[code](../../../examples/image_quality_metric/aesthetic.py)| -|FID|[GitHub](https://github.com/mseitzer/pytorch-fid)|[code](../../../examples/image_quality_metric/fid.py)| \ No newline at end of file +|FID|[GitHub](https://github.com/mseitzer/pytorch-fid)|[code](../../../examples/image_quality_metric/fid.py)| +|UnifiedReward|[GitHub](https://github.com/CodeGoat24/UnifiedReward)|[code](../../../examples/image_quality_metric/unified_reward_2.py)| +|UnifiedReward Edit|[GitHub](https://github.com/CodeGoat24/UnifiedReward)|[code](../../../examples/image_quality_metric/unified_reward_edit.py)| +|Qwen-Image-Bench|[GitHub](https://github.com/QwenLM/Qwen-Image-Bench)|[code](../../../examples/image_quality_metric/qwen_image_bench.py)| \ No newline at end of file diff --git a/docs/zh/Model_Details/Image-Quality-Metrics.md b/docs/zh/Model_Details/Image-Quality-Metrics.md index c00a9bd6b..a4eacb2e3 100644 --- a/docs/zh/Model_Details/Image-Quality-Metrics.md +++ b/docs/zh/Model_Details/Image-Quality-Metrics.md @@ -47,12 +47,15 @@ print(f"PickScore score:: {score:.3f}") |HPSv2|prompt + PIL 图像|偏好分数|[code](../../../examples/image_quality_metric/hpsv2.py)| |HPSv3|prompt + PIL 图像|偏好分数|[code](../../../examples/image_quality_metric/hpsv3.py)| |CLIP Score|prompt + PIL 图像|图文匹配度|[code](../../../examples/image_quality_metric/clipscore.py)| +|UnifiedReward 2.0|prompt + PIL 
图像|多维度分数|[code](../../../examples/image_quality_metric/unified_reward_2.py)| +|Qwen-Image-Bench|prompt + PIL 图像|多级维度分数|[code](../../../examples/image_quality_metric/qwen_image_bench.py)| +|UnifiedReward Edit|编辑指令 + 源图 + 编辑图|图像编辑质量分数|[code](../../../examples/image_quality_metric/unified_reward_edit.py)| |Aesthetic|PIL 图像|美学分数|[code](../../../examples/image_quality_metric/aesthetic.py)| |FID|reference 图像目录 + generated 图像目录|分布距离|[code](../../../examples/image_quality_metric/fid.py)| ### 文本-图像对齐与偏好评估 -适用指标: **PickScore**,**ImageReward**,**HPSv2**,**HPSv3**,**CLIP Score** +适用指标: **PickScore**,**ImageReward**,**HPSv2**,**HPSv3**,**CLIP Score**,**UnifiedReward 2.0**,**Qwen-Image-Bench** 这类模型用于评估图像是否遵循提示词以及是否符合人类视觉偏好。它们必须同时接收 `prompt` 和 `image`。 @@ -72,6 +75,69 @@ scores = metric.compute(["a cat", "a dog"], [image_cat, image_dog]) 其中 prompt 为单个字符串时,会对每张图像使用同一个 prompt。prompt 为字符串列表时,prompt 数量需要和图像数量一致。 +### 多维度图像质量评估 + +适用指标: **UnifiedReward 2.0**,**Qwen-Image-Bench** + +这两个指标同样接收 `prompt` 和 `image`,但除了主分数外,还会通过 `evaluate()` 返回更细的评估维度,适合需要分析图文对齐、画面一致性、风格或多级质量维度的场景。 + +**Qwen-Image-Bench** + +```python +from diffsynth.metrics import ModelConfig, QwenImageBenchMetric + +metric = QwenImageBenchMetric.from_pretrained( + model_config=ModelConfig( + model_id="Qwen/Qwen-Image-Bench", + origin_file_pattern="model-*.safetensors", + ), + processor_config=ModelConfig( + model_id="Qwen/Qwen-Image-Bench", + origin_file_pattern="", + ), + device="cuda", +) +details = metric.evaluate(prompt, image)[0] +score = details["total_score"] +print(details["level1_scores"]) +print(details["level2_scores"]) +``` + +如果只需要主分数,也可以调用 `metric.compute(prompt, image)`。 + +### 图像编辑质量评估 + +适用指标: **UnifiedReward Edit** + +UnifiedReward Edit 用于评估编辑结果是否遵循编辑指令,并衡量是否存在过度编辑。输入通常包括编辑指令、源图和编辑图。它支持三种任务: + +* `edit_pointwise_score`:对单个编辑结果打分,输入为 `[source_image, edited_image]`。 +* `edit_pairwise_rank`:比较两个编辑结果并返回胜者,输入为 `[source_image, edited_image_1, edited_image_2]`。 +* `edit_pairwise_score`:分别返回两个编辑结果的分数,输入为 
`[source_image, edited_image_1, edited_image_2]`。 + +```python +from diffsynth.metrics import ModelConfig, UnifiedRewardEditMetric + +metric = UnifiedRewardEditMetric.from_pretrained( + model_config=ModelConfig( + model_id="DiffSynth-Studio/ImageMetrics", + origin_file_pattern="UnifiedReward-Edit-qwen3vl-8b/model-*.safetensors", + ), + processor_config=ModelConfig( + model_id="DiffSynth-Studio/ImageMetrics", + origin_file_pattern="UnifiedReward-Edit-qwen3vl-8b/", + ), + device="cuda", +) + +details = metric.evaluate( + instruction, + [source_image, edited_image], + task="edit_pointwise_score", +)[0] +print(details["score"], details["editing_success"], details["overediting"]) +``` + ### 纯图像美学评估 适用指标: **Aesthetic** @@ -107,6 +173,6 @@ FID 的基准不是固定唯一的。对于通用图像生成,常使用 COCO V ## 注意事项 -* PickScore、ImageReward、HPSv2、HPSv3、CLIPScore、Aesthetic 的分数适合做同一指标内部的相对比较,不建议直接把不同指标的数值大小相互比较。 -* HPSv3 基于 Qwen2-VL,模型较大,显存需求明显高于 CLIP 类指标。 +* PickScore、ImageReward、HPSv2、HPSv3、CLIPScore、UnifiedReward 2.0、Qwen-Image-Bench、UnifiedReward Edit、Aesthetic 的分数适合做同一指标内部的相对比较,不建议直接把不同指标的数值大小相互比较。 +* HPSv3、UnifiedReward 2.0、UnifiedReward Edit 和 Qwen-Image-Bench 基于多模态大模型,显存需求明显高于 CLIP 类指标。 * FID 对 reference 选择、样本量和 generated 样本量较敏感。 diff --git a/docs/zh/Model_Details/Overview.md b/docs/zh/Model_Details/Overview.md index 0aa0c2e3b..0fe64db55 100644 --- a/docs/zh/Model_Details/Overview.md +++ b/docs/zh/Model_Details/Overview.md @@ -331,4 +331,7 @@ print("PickScore score:", metric.compute(prompt, image)[0]) |HPSv3|[GitHub](https://github.com/MizzenAI/HPSv3)|[code](../../../examples/image_quality_metric/hpsv3.py)| |CLIP Score|[GitHub](https://github.com/openai/CLIP)|[code](../../../examples/image_quality_metric/clipscore.py)| |Aesthetic|[GitHub](https://github.com/christophschuhmann/improved-aesthetic-predictor)|[code](../../../examples/image_quality_metric/aesthetic.py)| -|FID|[GitHub](https://github.com/mseitzer/pytorch-fid)|[code](../../../examples/image_quality_metric/fid.py)| \ No newline at end of 
"""Example: score a single image against its prompt with Qwen-Image-Bench."""
from diffsynth.metrics import ModelConfig, QwenImageBenchMetric
from modelscope import dataset_snapshot_download
from PIL import Image


# Download the example images used below into ./data/diffsynth_example_dataset.
dataset_snapshot_download(
    "DiffSynth-Studio/diffsynth_example_dataset",
    allow_file_pattern="flux/FLUX.1-dev/*",
    local_dir="./data/diffsynth_example_dataset",
)

image = Image.open("data/diffsynth_example_dataset/flux/FLUX.1-dev/1.jpg").convert("RGB")
prompt = "dog, white and brown dog, sitting on wall, under pink flowers"
device = "cuda"

# The weights and the processor are resolved through two separate ModelConfig
# entries pointing at the same model repository.
metric = QwenImageBenchMetric.from_pretrained(
    model_config=ModelConfig(
        model_id="Qwen/Qwen-Image-Bench",
        origin_file_pattern="model-*.safetensors",
    ),
    processor_config=ModelConfig(
        model_id="Qwen/Qwen-Image-Bench",
        origin_file_pattern="",
    ),
    device=device,
)

# evaluate() returns one result per sample; this example has a single sample.
details = metric.evaluate(prompt, image)[0]
# total_score may be None; fall back to 0.0 so the ":.3f" format cannot fail.
score = details["total_score"] if details["total_score"] is not None else 0.0
print(f"Total Score: {score:.3f}")
print(details["level1_scores"])
print(details["level2_scores"])
"""Example: evaluate image edits with the UnifiedReward Edit metric.

Runs the three tasks shown below (pointwise score, pairwise rank, pairwise
score) on one source image and two candidate edits.
"""
from diffsynth.metrics import ModelConfig, UnifiedRewardEditMetric
from modelscope import dataset_snapshot_download
from PIL import Image


def _fmt(value):
    """Format a score with three decimals, or "N/A" when it is None."""
    # NOTE(review): a score is presumably None when the model reply could not
    # be parsed -- confirm against the metric's parsing code. Guarding here
    # keeps the example from raising TypeError inside the format spec.
    return "N/A" if value is None else f"{value:.3f}"


# Download the example images used below into ./data/diffsynth_example_dataset.
dataset_snapshot_download(
    "DiffSynth-Studio/diffsynth_example_dataset",
    allow_file_pattern="qwen_image/Qwen-Image-Edit/*",
    local_dir="./data/diffsynth_example_dataset",
)

source_image = Image.open("data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit/edit/image1.jpg").convert("RGB")
edited_image_1 = Image.open("data/diffsynth_example_dataset/qwen_image/Qwen-Image-Edit/edit/image2.jpg").convert("RGB")
# The unedited source image doubles as the second candidate in pairwise tasks.
edited_image_2 = source_image
# Editing instruction (Chinese): "change the dress to pink".
instruction = "将裙子改为粉色"
device = "cuda"

metric = UnifiedRewardEditMetric.from_pretrained(
    model_config=ModelConfig(
        model_id="DiffSynth-Studio/ImageMetrics",
        origin_file_pattern="UnifiedReward-Edit-qwen3vl-8b/model-*.safetensors",
    ),
    processor_config=ModelConfig(
        model_id="DiffSynth-Studio/ImageMetrics",
        origin_file_pattern="UnifiedReward-Edit-qwen3vl-8b/",
    ),
    device=device,
)

# Pointwise: images are [source_image, edited_image].
pointwise_details = metric.evaluate(
    instruction,
    [source_image, edited_image_1],
    task="edit_pointwise_score",
)[0]
print("---UnifiedReward edit pointwise score---")
print(f"Score: {_fmt(pointwise_details['score'])}")
print(
    f"Editing Success: {_fmt(pointwise_details['editing_success'])}\n"
    f"Overediting: {_fmt(pointwise_details['overediting'])}"
)
print(pointwise_details, "\n")

# Pairwise rank: images are [source_image, edited_image_1, edited_image_2].
pairwise_rank_details = metric.evaluate(
    instruction,
    [source_image, edited_image_1, edited_image_2],
    task="edit_pairwise_rank",
)[0]
print("---UnifiedReward edit pairwise rank score---")
print(f"Score: {pairwise_rank_details['score']}")
print(f"Winner: {pairwise_rank_details['winner']}")
print(pairwise_rank_details, "\n")

# Pairwise score: same image layout as pairwise rank, one score per candidate.
pairwise_score_details = metric.evaluate(
    instruction,
    [source_image, edited_image_1, edited_image_2],
    task="edit_pairwise_score",
)[0]
print("---UnifiedReward edit pairwise score---")
print(
    f"Image 1 Score: {_fmt(pairwise_score_details['image_1_score'])}\n"
    f"Image 2 Score: {_fmt(pairwise_score_details['image_2_score'])}"
)
print(pairwise_score_details)