diff --git a/demo/dj_integration.ipynb b/demo/dj_integration.ipynb index 686a6ce..a1513c7 100644 --- a/demo/dj_integration.ipynb +++ b/demo/dj_integration.ipynb @@ -33,7 +33,7 @@ "from torch import load\n", "\n", "from mei.main import TrainedEnsembleModelTemplate, CSRFV1SelectorTemplate, MEISeed, MEIMethod, MEITemplate\n", - "from nnfabrik.template import TrainedModelBase\n", + "from nnfabrik.templates import TrainedModelBase\n", "from nnfabrik.main import Dataset" ] }, @@ -246,7 +246,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.0" + "version": "3.9.18" } }, "nbformat": 4, diff --git a/mei/integration.py b/mei/integration.py index 8f2f0d7..4077547 100755 --- a/mei/integration.py +++ b/mei/integration.py @@ -6,6 +6,8 @@ from nnfabrik.utility.dj_helpers import make_hash from nnfabrik.utility.nnf_helper import dynamic_import, split_module_name +from .modules import ConstrainedOutputModel + def load_pickled_data(path): with open(path, "rb") as datafile: @@ -15,7 +17,7 @@ def load_pickled_data(path): def get_mappings(dataset_config, key, load_func=load_pickled_data): entities = [] - for datafile_path in dataset_config["datafiles"]: + for datafile_path in dataset_config["paths"]: data = load_func(datafile_path) for neuron_pos, neuron_id in enumerate(data["unit_indices"]): entities.append( @@ -33,6 +35,28 @@ def import_module(path): return dynamic_import(*split_module_name(path)) +def get_output_selected_model(neuron_pos, session_id, model): + """Creates a version of the model that has its output selected down to a single uniquely identified neuron. + + Args: + neuron_pos: An integer, the position of the neuron in the model's output. + session_id: A string that uniquely identifies one of the model's readouts. + model: A PyTorch module that can be called with a keyword argument called "data_key". 
The output of the + module is expected to be a two dimensional Torch tensor where the first dimension corresponds to the + batch size and the second to the number of neurons. + + Returns: + A function that takes the model input(s) as parameter(s) and returns the model output corresponding to the + selected neuron. + """ + + # def output_selected_model(x, *args, **kwargs): + # output = model(x, *args, data_key=session_id, **kwargs) + # return output[:, neuron_pos] + + return ConstrainedOutputModel(model, neuron_pos, forward_kwargs=dict(data_key=session_id)) + + class ModelLoader: """ A utility class for loading and caching models. diff --git a/mei/main.py b/mei/main.py index c0f9002..c135d8d 100755 --- a/mei/main.py +++ b/mei/main.py @@ -7,7 +7,7 @@ import datajoint as dj from nnfabrik.main import Dataset, schema -from . import mixins +from . import integration, mixins class TrainedEnsembleModelTemplate(mixins.TrainedEnsembleModelTemplateMixin, dj.Manual): @@ -48,7 +48,7 @@ class MEIMethod(mixins.MEIMethodMixin, dj.Lookup): """Table that contains MEI methods and their configurations.""" -class MEITemplate(mixins.MEITemplateMixin, dj.Computed): +class MEITemplate(mixins.MEIMixin, dj.Computed): """MEI table template. To create a functional "MEI" table, create a new class that inherits from this template and decorate it with your @@ -60,3 +60,15 @@ class MEITemplate(mixins.MEITemplateMixin, dj.Computed): method_table = MEIMethod seed_table = MEISeed + + +class CSRFV1SelectorTemplate(mixins.CSRFV1SelectorTemplateMixin, dj.Computed): + """CSRF V1 selector table template. + + To create a functional "CSRFV1Selector" table, create a new class that inherits from this template and decorate it + with your preferred Datajoint schema. By default, the created table will point to the "Dataset" table in the + Datajoint schema called "nnfabrik.main". This behavior can be changed by overwriting the class attribute called + "dataset_table". 
+ """ + + dataset_table = Dataset diff --git a/mei/mixins.py b/mei/mixins.py index 7a97a62..7b6ac58 100755 --- a/mei/mixins.py +++ b/mei/mixins.py @@ -62,6 +62,7 @@ def load_model( key: Optional[Key] = None, include_dataloader: Optional[bool] = True, include_state_dict: Optional[bool] = True, + **kwargs, ) -> Tuple[Dataloaders, EnsembleModel]: if key is None: key = self.fetch1("KEY") @@ -69,6 +70,7 @@ def load_model( key=key, include_dataloader=include_dataloader, include_state_dict=include_state_dict, + **kwargs, ) def _load_ensemble_model( @@ -76,8 +78,8 @@ def _load_ensemble_model( key: Optional[Key] = None, include_dataloader: Optional[bool] = True, include_state_dict: Optional[bool] = True, + **kwargs, ) -> Tuple[Dataloaders, EnsembleModel]: - ensemble_key = (self & key).fetch1() model_keys = (self.Member() & ensemble_key).fetch(as_dict=True) @@ -90,6 +92,7 @@ def _load_ensemble_model( key=k, include_dataloader=include_dataloader, include_state_dict=include_state_dict, + **kwargs, ) for k in model_keys ] @@ -101,6 +104,7 @@ def _load_ensemble_model( key=k, include_dataloader=include_dataloader, include_state_dict=include_state_dict, + **kwargs, ) for k in model_keys ] @@ -170,21 +174,26 @@ class MEIMethodMixin: "postprocessing", ) - def add_method(self, method_fn: str, method_config: Mapping, comment: str = "") -> None: + def add_method(self, method_fn: str, method_config: Mapping, comment: str = "", **kwargs) -> None: self.insert1( dict( method_fn=method_fn, method_hash=make_hash(method_config), method_config=method_config, method_comment=comment, - ) + ), + **kwargs, ) def generate_mei(self, dataloaders: Dataloaders, model: Module, key: Key, seed: int) -> Dict[str, Any]: method_fn, method_config = (self & key).fetch1("method_fn", "method_config") method_fn = self.import_func(method_fn) self.insert_key_in_ops(method_config=method_config, key=key) - mei, score, output = method_fn(dataloaders, model, method_config, seed) + + train_input_shape = 
next(iter(next(iter(dataloaders["train"].values())))).images.shape + input_shape = (1,) + train_input_shape[1:] + mei, score, output = method_fn(model=model, config=method_config, seed=seed, shape=input_shape) + return dict(key, mei=mei, score=score, output=output) def generate_ringmei( @@ -209,7 +218,7 @@ class MEISeedMixin: """ -class MEITemplateMixin: +class MEIMixin: definition = """ # contains maximally exciting images (MEIs) -> self.method_table @@ -260,3 +269,26 @@ def _save_to_disk(self, mei_entity: Dict[str, Any], temp_dir: str, name: str) -> @staticmethod def _create_random_filename(length: Optional[int] = 32) -> str: return "".join(choice(ascii_letters) for _ in range(length)) + + +class CSRFV1SelectorTemplateMixin: + dataset_table = None + + definition = """ + # contains information that can be used to map a neuron's id to its corresponding integer position in the output of + # the model. + -> self.dataset_table + neuron_id : smallint unsigned # unique neuron identifier + --- + neuron_position : smallint unsigned # integer position of the neuron in the model's output + session_id : varchar(13) # unique session identifier + """ + + def make(self, key): + dataset_config = (self.dataset_table & key).fetch1("dataset_config") + mappings = integration.get_mappings(dataset_config, key) + self.insert(mappings) + + def get_output_selected_model(self, model, key): + neuron_pos, session_id = (self & key).fetch1("neuron_position", "session_id") + return integration.get_output_selected_model(neuron_pos, session_id, model) diff --git a/mei/modules.py b/mei/modules.py index ebf4371..55d6f03 100644 --- a/mei/modules.py +++ b/mei/modules.py @@ -75,6 +75,9 @@ def __init__( self.constraint = constraint if (isinstance(constraint, Iterable) or constraint is None) else [constraint] self.forward_kwargs = forward_kwargs if forward_kwargs else dict() self.target_fn = target_fn + if hasattr(model, "core") and hasattr(model, "readout"): + self.core = model.core + self.readout = 
model.readout def __call__(self, x: Tensor, *args, **kwargs) -> Tensor: """Computes the constrained output of the model. @@ -87,7 +90,11 @@ def __call__(self, x: Tensor, *args, **kwargs) -> Tensor: Returns: A tensor representing the constrained output of the model. """ - output = self.model(x, *args, **self.forward_kwargs, **kwargs) + duplicate_keys = self.forward_kwargs.keys() & kwargs.keys() + reduced_forward_kwargs = dict(self.forward_kwargs) + for key in duplicate_keys: + reduced_forward_kwargs.pop(key) + output = self.model(x, *args, **reduced_forward_kwargs, **kwargs) return ( self.target_fn(output) if self.constraint is None or len(self.constraint) == 0 diff --git a/requirements.txt b/requirements.txt index c08d2bd..acf47e1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,6 @@ datajoint<=0.12.9 nnfabrik torch scipy>=1.7.0,<=1.12.0 -numpy>=1.22.0,<=1.26.4 +numpy>=1.21.0,<=1.26.4 torchvision pytest