diff --git a/src/monkey_patch/monkey.py b/src/monkey_patch/monkey.py index 62cf613..6b9fdb8 100644 --- a/src/monkey_patch/monkey.py +++ b/src/monkey_patch/monkey.py @@ -181,19 +181,6 @@ def mock_func(*args, **kwargs): else: return patched_func(*args, **kwargs) - def _get_args(func_args, kwarg_names, num_args): - num_pos_args = num_args - len(kwarg_names) # Calculate number of positional arguments - args_for_call = func_args[:num_pos_args] - # Pop keyword arguments off the stack - kwargs_for_call = {} # New dictionary to hold keyword arguments for the call - for name in reversed(kwarg_names): # Reverse to match the order on the stack - try: - kwargs_for_call[name] = func_args.pop() # Pop the value off the stack - except IndexError: - print(f"Debug: func_args is empty, can't pop for {name}") - func_args = func_args[:-num_pos_args] # Remove the positional arguments from func_args - return args_for_call, func_args, kwargs_for_call - return wrapper @staticmethod @@ -227,7 +214,6 @@ def wrapper(*args, **kwargs): raise TypeError(f"Output type was not valid. Expected an object of type {function_description.output_type_hint}, got '{output.generated_response}'") output.generated_response = choice output.distilled_model = False - datapoint = FunctionExample(args, kwargs, output.generated_response) if output.suitable_for_finetuning and not output.distilled_model: diff --git a/src/monkey_patch/register.py b/src/monkey_patch/register.py index eff6bfc..c1e7521 100644 --- a/src/monkey_patch/register.py +++ b/src/monkey_patch/register.py @@ -2,6 +2,7 @@ from typing import get_type_hints, Literal, Optional from monkey_patch.models.function_description import FunctionDescription +from monkey_patch.utils import get_source alignable_functions = {} @@ -74,7 +75,7 @@ def get_class_definition(class_type): elif hasattr(class_type, "__args__"): # Access inner types return [get_class_definition(arg) for arg in class_type.__args__ if arg is not None] elif inspect.isclass(class_type) and class_type.__module__ != "builtins": - return inspect.getsource(class_type) + return get_source(class_type) return class_type.__name__ # Extract class definitions for input and output types diff --git a/src/monkey_patch/utils.py b/src/monkey_patch/utils.py index 7c2a382..fa71985 100644 --- a/src/monkey_patch/utils.py +++ b/src/monkey_patch/utils.py @@ -3,7 +3,7 @@ import json import typing from typing import get_args, Literal - +import inspect def json_default(thing): try: @@ -121,3 +121,40 @@ def get_key(args, kwargs) -> tuple: args_tuple = _deep_tuple(args) kwargs_tuple = _deep_tuple(kwargs) return args_tuple, kwargs_tuple + +def _get_source_ipython(func) -> str: + """ + Get the source code of a function from IPython (to support Colab and Jupyter notebooks) + :param func: The function to get the source code from + :return: The source code of the function + """ + # Get the IPython instance + from IPython import get_ipython + ipython = get_ipython() + + # Get the input history + input_cells = ipython.history_manager.input_hist_parsed + + class_name = func.__name__ + source_code = None + + for cell in input_cells: + if f"class {class_name}" in cell: + source_code = cell + break + + # If found, print the source code + return source_code + +def get_source(func) -> str: + """ + Get the source code of a function + Args: + func (function): the function to get the source code from + Returns: + source (str): the source code of the function + """ + try: + return inspect.getsource(func) + except Exception: + return _get_source_ipython(func) \ No newline at end of file