diff --git a/lm_eval/api/task.py b/lm_eval/api/task.py index 5eb78c8a47e..fd814e5910e 100644 --- a/lm_eval/api/task.py +++ b/lm_eval/api/task.py @@ -1359,6 +1359,27 @@ def doc_to_prefix(self, doc): return utils.apply_template(gen_prefix, doc) return None + def _apply_gen_kwargs_templates(self, gen_kwargs: dict, doc: dict) -> dict: + """Apply Jinja2 templating to string values inside gen_kwargs. + + After rendering, any string that is a valid Python literal (e.g. a + JSON-encoded list of tool definitions) is parsed with ast.literal_eval + so that structured values arrive at the model as native Python objects. + """ + if gen_kwargs is None: + return gen_kwargs + result = {} + for k, v in gen_kwargs.items(): + if isinstance(v, str): + rendered = utils.apply_template(v, doc) + try: + result[k] = ast.literal_eval(rendered) + except (ValueError, SyntaxError): + result[k] = rendered + else: + result[k] = v + return result + def construct_requests( self, doc: dict, ctx: str | list[str], **kwargs ) -> list[Instance] | Instance: @@ -1405,7 +1426,12 @@ def construct_requests( arguments.extend(aux_arguments) elif self.OUTPUT_TYPE == "generate_until": - arguments = (ctx, deepcopy(self.config.generation_kwargs)) + arguments = ( + ctx, + self._apply_gen_kwargs_templates( + deepcopy(self.config.generation_kwargs), doc + ), + ) multimodal_arg = {} if (