Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions nemoguardrails/actions/action_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast

from langchain_core.runnables import Runnable

from nemoguardrails import utils
from nemoguardrails.exceptions import LLMCallException

Expand Down Expand Up @@ -218,11 +216,10 @@ async def execute_action(
else:
log.warning(f"Synchronous action `{action_name}` has been called.")

elif isinstance(fn, Runnable):
# If it's a Runnable, we invoke it as well
runnable = fn

result = await runnable.ainvoke(input=params)
elif hasattr(fn, "ainvoke") and callable(fn.ainvoke): # type: ignore[union-attr]
# Duck-type check for LangChain Runnables (or any object
# with ainvoke) to avoid importing langchain in core.
result = await fn.ainvoke(input=params) # type: ignore[union-attr]
else:
# TODO: there should be a common base class here
fn_run_func = getattr(fn, "run", None)
Expand Down
12 changes: 8 additions & 4 deletions nemoguardrails/eval/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
from typing import List

import typer
from langchain_community.cache import SQLiteCache
from langchain_core.globals import set_llm_cache

from nemoguardrails.eval.check import LLMJudgeComplianceChecker
from nemoguardrails.eval.eval import run_eval
Expand Down Expand Up @@ -172,8 +170,14 @@ def check_compliance(
if disable_llm_cache:
console.print("[orange]Caching is disabled.[/]")
else:
console.print("[green]Caching is enabled.[/]")
set_llm_cache(SQLiteCache(database_path=".langchain.db"))
try:
from langchain_community.cache import SQLiteCache
from langchain_core.globals import set_llm_cache

set_llm_cache(SQLiteCache(database_path=".langchain.db"))
console.print("[green]Caching is enabled.[/]")
except ImportError:
console.print("[yellow]langchain not installed, LLM caching unavailable.[/]")

console.print(f"Using eval configuration from {eval_config_path}.")
console.print(f"Using output paths: {output_path}.")
Expand Down
22 changes: 7 additions & 15 deletions nemoguardrails/evaluate/evaluate_factcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@

import tqdm
import typer
from langchain_core.prompts import PromptTemplate

from nemoguardrails import LLMRails
from nemoguardrails.actions.llm.utils import llm_call
from nemoguardrails.evaluate.utils import load_dataset
from nemoguardrails.llm.prompts import Task
from nemoguardrails.llm.taskmanager import LLMTaskManager
from nemoguardrails.rails.llm.config import RailsConfig
from nemoguardrails.utils import get_or_create_event_loop


class FactCheckEvaluation:
Expand Down Expand Up @@ -89,27 +89,19 @@ def create_negative_samples(self, dataset):
that it will not be grounded in the evidence passage. change details in the answer to make the answer
wrong but yet believable.\nevidence: {evidence}\nanswer: {answer}\nincorrect answer:"""

create_negatives_prompt = PromptTemplate(
template=create_negatives_template,
input_variables=["evidence", "answer"],
)

# Bind config parameters to the LLM for generating negative samples
llm_with_config = self.llm.bind(temperature=0.8, max_tokens=300)
loop = get_or_create_event_loop()

print("Creating negative samples...")
for data in tqdm.tqdm(dataset):
assert "evidence" in data and "question" in data and "answer" in data
evidence = data["evidence"]
answer = data["answer"]

# Format the prompt and invoke the LLM directly
formatted_prompt = create_negatives_prompt.format(evidence=evidence, answer=answer)
negative_answer = llm_with_config.invoke(formatted_prompt)
if isinstance(negative_answer, str):
data["incorrect_answer"] = negative_answer.strip()
else:
data["incorrect_answer"] = negative_answer.content.strip()
formatted_prompt = create_negatives_template.format(evidence=evidence, answer=answer)
response = loop.run_until_complete(
self.llm.generate_async(formatted_prompt, temperature=0.8, max_tokens=300)
)
data["incorrect_answer"] = response.content.strip()

return dataset

Expand Down
5 changes: 1 addition & 4 deletions nemoguardrails/library/hallucination/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
import logging
from typing import Optional

from langchain_core.prompts import PromptTemplate

from nemoguardrails import RailsConfig
from nemoguardrails.actions import action
from nemoguardrails.actions.llm.utils import (
Expand Down Expand Up @@ -56,8 +54,7 @@ async def self_check_hallucination(
if bot_response and last_bot_prompt_string:
num_responses = HALLUCINATION_NUM_EXTRA_RESPONSES

last_bot_prompt = PromptTemplate(template="{text}", input_variables=["text"])
formatted_prompt = last_bot_prompt.format(text=last_bot_prompt_string)
formatted_prompt = last_bot_prompt_string

async def _generate_extra_response(index: int) -> Optional[str]:
llm_call_info_var.set(LLMCallInfo(task=Task.SELF_CHECK_HALLUCINATION.value))
Expand Down
Loading