Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions nemoguardrails/eval/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import os
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, Field, root_validator
from pydantic import BaseModel, Field, model_validator

from nemoguardrails.eval.utils import load_dict_from_path
from nemoguardrails.logging.explain import LLMCallInfo
Expand Down Expand Up @@ -96,7 +96,8 @@ class InteractionSet(BaseModel):
description="A list of tags that should be associated with the interactions. Useful for filtering when reporting.",
)

@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def instantiate_expected_output(cls, values: Any):
"""Creates the right instance of the expected output."""
type_mapping = {
Expand Down Expand Up @@ -132,18 +133,18 @@ class EvalConfig(BaseModel):
description="The prompts that should be used for the various LLM tasks.",
)

@root_validator(pre=False, skip_on_failure=True)
def validate_policy_ids(cls, values: Any):
@model_validator(mode="after")
def validate_policy_ids(self):
"""Validates the policy ids used in the interactions."""
policy_ids = {policy.id for policy in values.get("policies")}
for interaction_set in values.get("interactions"):
policy_ids = {policy.id for policy in self.policies}
for interaction_set in self.interactions:
for expected_output in interaction_set.expected_output:
if expected_output.policy not in policy_ids:
raise ValueError(f"Invalid policy id {expected_output.policy} used in interaction set.")
for policy_id in interaction_set.include_policies + interaction_set.exclude_policies:
if policy_id not in policy_ids:
raise ValueError(f"Invalid policy id {policy_id} used in interaction set.")
return values
return self

@classmethod
def from_path(
Expand Down
27 changes: 17 additions & 10 deletions nemoguardrails/rails/llm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,8 @@
Field,
PrivateAttr,
SecretStr,
field_validator,
model_validator,
root_validator,
validator,
)

from nemoguardrails import utils
Expand Down Expand Up @@ -436,7 +435,8 @@ class TaskPrompt(BaseModel):
ge=1,
)

@root_validator(pre=True, allow_reuse=True)
@model_validator(mode="before")
@classmethod
def check_fields(cls, values):
if not values.get("content") and not values.get("messages"):
raise InvalidRailsConfigurationError("One of `content` or `messages` must be provided.")
Expand Down Expand Up @@ -1592,7 +1592,8 @@ class RailsConfig(BaseModel):
description="Configuration for tracing.",
)

@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def check_model_exists_for_input_rails(cls, values):
"""Make sure we have a model for each input rail where one is provided using $model=<model_type>"""
rails = values.get("rails", {})
Expand All @@ -1618,7 +1619,8 @@ def check_model_exists_for_input_rails(cls, values):
)
return values

@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def check_model_exists_for_output_rails(cls, values):
"""Make sure we have a model for each output rail where one is provided using $model=<model_type>"""
rails = values.get("rails", {})
Expand All @@ -1644,7 +1646,8 @@ def check_model_exists_for_output_rails(cls, values):
)
return values

@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def check_prompt_exist_for_self_check_rails(cls, values):
rails = values.get("rails", {})
prompts = values.get("prompts", []) or []
Expand Down Expand Up @@ -1701,7 +1704,8 @@ def check_prompt_exist_for_self_check_rails(cls, values):

return values

@root_validator(pre=True, allow_reuse=True)
@model_validator(mode="before")
@classmethod
def check_output_parser_exists(cls, values):
tasks_requiring_output_parser = [
"self_check_input",
Expand All @@ -1724,7 +1728,8 @@ def check_output_parser_exists(cls, values):
)
return values

@root_validator(pre=True, allow_reuse=True)
@model_validator(mode="before")
@classmethod
def check_jailbreak_detection_config(cls, values):
"""Validate jailbreak detection configuration against enabled flows."""
rails = values.get("rails") or {}
Expand Down Expand Up @@ -1778,7 +1783,8 @@ def check_jailbreak_detection_config(cls, values):

return values

@root_validator(pre=True, allow_reuse=True)
@model_validator(mode="before")
@classmethod
def fill_in_default_values_for_v2_x(cls, values):
instructions = values.get("instructions", {})
sample_conversation = values.get("sample_conversation")
Expand All @@ -1793,7 +1799,8 @@ def fill_in_default_values_for_v2_x(cls, values):

return values

@validator("models")
@field_validator("models")
@classmethod
def validate_models_api_key_env_var(cls, models):
"""Model API Key Env var must be set to make LLM calls"""
api_keys = [m.api_key_env_var for m in models]
Expand Down
5 changes: 3 additions & 2 deletions nemoguardrails/rails/llm/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, Field, root_validator
from pydantic import BaseModel, Field, model_validator

from nemoguardrails.logging.explain import LLMCallInfo

Expand Down Expand Up @@ -196,7 +196,8 @@ class GenerationOptions(BaseModel):
description="Options about what to include in the log. By default, nothing is included. ",
)

@root_validator(pre=True, allow_reuse=True)
@model_validator(mode="before")
@classmethod
def check_fields(cls, values):
# Translate the `rails` generation option from List[str] to dict.
if "rails" in values and isinstance(values["rails"], list):
Expand Down
Loading