diff --git a/nemoguardrails/eval/models.py b/nemoguardrails/eval/models.py index 8f1afbb3a2..7e6205661f 100644 --- a/nemoguardrails/eval/models.py +++ b/nemoguardrails/eval/models.py @@ -16,7 +16,7 @@ import os from typing import Any, Dict, List, Optional, Union -from pydantic import BaseModel, Field, root_validator +from pydantic import BaseModel, Field, model_validator from nemoguardrails.eval.utils import load_dict_from_path from nemoguardrails.logging.explain import LLMCallInfo @@ -96,7 +96,8 @@ class InteractionSet(BaseModel): description="A list of tags that should be associated with the interactions. Useful for filtering when reporting.", ) - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def instantiate_expected_output(cls, values: Any): """Creates the right instance of the expected output.""" type_mapping = { @@ -132,18 +133,18 @@ class EvalConfig(BaseModel): description="The prompts that should be used for the various LLM tasks.", ) - @root_validator(pre=False, skip_on_failure=True) - def validate_policy_ids(cls, values: Any): + @model_validator(mode="after") + def validate_policy_ids(self): """Validates the policy ids used in the interactions.""" - policy_ids = {policy.id for policy in values.get("policies")} - for interaction_set in values.get("interactions"): + policy_ids = {policy.id for policy in self.policies} + for interaction_set in self.interactions: for expected_output in interaction_set.expected_output: if expected_output.policy not in policy_ids: raise ValueError(f"Invalid policy id {expected_output.policy} used in interaction set.") for policy_id in interaction_set.include_policies + interaction_set.exclude_policies: if policy_id not in policy_ids: raise ValueError(f"Invalid policy id {policy_id} used in interaction set.") - return values + return self @classmethod def from_path( diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py index 2615fe8380..5630030b8d 100644 --- a/nemoguardrails/rails/llm/config.py +++ b/nemoguardrails/rails/llm/config.py @@ -29,9 +29,8 @@ Field, PrivateAttr, SecretStr, + field_validator, model_validator, - root_validator, - validator, ) from nemoguardrails import utils @@ -436,7 +435,8 @@ class TaskPrompt(BaseModel): ge=1, ) - @root_validator(pre=True, allow_reuse=True) + @model_validator(mode="before") + @classmethod def check_fields(cls, values): if not values.get("content") and not values.get("messages"): raise InvalidRailsConfigurationError("One of `content` or `messages` must be provided.") @@ -1592,7 +1592,8 @@ class RailsConfig(BaseModel): description="Configuration for tracing.", ) - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def check_model_exists_for_input_rails(cls, values): """Make sure we have a model for each input rail where one is provided using $model=""" rails = values.get("rails", {}) @@ -1618,7 +1619,8 @@ def check_model_exists_for_input_rails(cls, values): ) return values - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def check_model_exists_for_output_rails(cls, values): """Make sure we have a model for each output rail where one is provided using $model=""" rails = values.get("rails", {}) @@ -1644,7 +1646,8 @@ def check_model_exists_for_output_rails(cls, values): ) return values - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def check_prompt_exist_for_self_check_rails(cls, values): rails = values.get("rails", {}) prompts = values.get("prompts", []) or [] @@ -1701,7 +1704,8 @@ def check_prompt_exist_for_self_check_rails(cls, values): return values - @root_validator(pre=True, allow_reuse=True) + @model_validator(mode="before") + @classmethod def check_output_parser_exists(cls, values): tasks_requiring_output_parser = [ "self_check_input", @@ -1724,7 +1728,8 @@ def check_output_parser_exists(cls, values): ) return values - @root_validator(pre=True, allow_reuse=True) + @model_validator(mode="before") + @classmethod def check_jailbreak_detection_config(cls, values): """Validate jailbreak detection configuration against enabled flows.""" rails = values.get("rails") or {} @@ -1778,7 +1783,8 @@ def check_jailbreak_detection_config(cls, values): return values - @root_validator(pre=True, allow_reuse=True) + @model_validator(mode="before") + @classmethod def fill_in_default_values_for_v2_x(cls, values): instructions = values.get("instructions", {}) sample_conversation = values.get("sample_conversation") @@ -1793,7 +1799,8 @@ def fill_in_default_values_for_v2_x(cls, values): return values - @validator("models") + @field_validator("models") + @classmethod def validate_models_api_key_env_var(cls, models): """Model API Key Env var must be set to make LLM calls""" api_keys = [m.api_key_env_var for m in models] diff --git a/nemoguardrails/rails/llm/options.py b/nemoguardrails/rails/llm/options.py index a95a1347ea..3731260654 100644 --- a/nemoguardrails/rails/llm/options.py +++ b/nemoguardrails/rails/llm/options.py @@ -80,7 +80,7 @@ from enum import Enum from typing import Any, Dict, List, Optional, Union -from pydantic import BaseModel, Field, root_validator +from pydantic import BaseModel, Field, model_validator from nemoguardrails.logging.explain import LLMCallInfo @@ -196,7 +196,8 @@ class GenerationOptions(BaseModel): description="Options about what to include in the log. By default, nothing is included. ", ) - @root_validator(pre=True, allow_reuse=True) + @model_validator(mode="before") + @classmethod def check_fields(cls, values): # Translate the `rails` generation option from List[str] to dict. if "rails" in values and isinstance(values["rails"], list):