Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion lm_eval/api/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -1359,6 +1359,27 @@ def doc_to_prefix(self, doc):
return utils.apply_template(gen_prefix, doc)
return None

def _apply_gen_kwargs_templates(self, gen_kwargs: dict, doc: dict) -> dict:
    """Render templated string values in ``gen_kwargs`` against ``doc``.

    Every string value is passed through ``utils.apply_template``. The
    rendered text is then parsed so that structured values reach the model
    as native Python objects instead of text:

    1. ``ast.literal_eval`` is tried first, so any Python literal
       (numbers, lists, dicts, ``True``/``None``, quoted strings, ...) is
       converted.
    2. If that fails, ``json.loads`` is tried, and its result is used only
       when it is a list or dict. This covers JSON payloads such as tool
       definitions that contain ``true``/``false``/``null``, which are not
       Python literals. Bare scalars such as ``"null"`` stay strings.
    3. Otherwise the rendered string is kept unchanged.

    Non-string values are copied into the result as-is.

    Args:
        gen_kwargs: Generation kwargs to process, or ``None``.
        doc: The document used as the template rendering context.

    Returns:
        A new dict with the same keys and processed values, or ``None``
        when ``gen_kwargs`` is ``None``.
    """
    import json  # local import: only this method needs it

    if gen_kwargs is None:
        return gen_kwargs

    result = {}
    for key, value in gen_kwargs.items():
        if not isinstance(value, str):
            result[key] = value
            continue

        rendered = utils.apply_template(value, doc)

        try:
            result[key] = ast.literal_eval(rendered)
            continue
        except (
            ValueError,
            SyntaxError,
            TypeError,
            MemoryError,
            RecursionError,
        ):
            # literal_eval is documented to raise any of these on input that
            # is not a well-formed literal (e.g. TypeError for an unhashable
            # dict key); none of them should abort request construction.
            pass

        try:
            parsed = json.loads(rendered)
        except (ValueError, TypeError, RecursionError):
            parsed = rendered

        # Only accept JSON containers; scalar tokens such as "true" or
        # "null" remain plain strings, as they did before the JSON fallback.
        result[key] = parsed if isinstance(parsed, (list, dict)) else rendered
    return result

def construct_requests(
self, doc: dict, ctx: str | list[str], **kwargs
) -> list[Instance] | Instance:
Expand Down Expand Up @@ -1405,7 +1426,12 @@ def construct_requests(
arguments.extend(aux_arguments)

elif self.OUTPUT_TYPE == "generate_until":
arguments = (ctx, deepcopy(self.config.generation_kwargs))
arguments = (
ctx,
self._apply_gen_kwargs_templates(
deepcopy(self.config.generation_kwargs), doc
),
)

multimodal_arg = {}
if (
Expand Down
Loading