Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion mei/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from nnfabrik.utility.dj_helpers import make_hash
from nnfabrik.utility.nnf_helper import dynamic_import, split_module_name

from .modules import ConstrainedOutputModel


def load_pickled_data(path):
with open(path, "rb") as datafile:
Expand All @@ -15,7 +17,7 @@ def load_pickled_data(path):

def get_mappings(dataset_config, key, load_func=load_pickled_data):
entities = []
for datafile_path in dataset_config["datafiles"]:
for datafile_path in dataset_config["paths"]:
data = load_func(datafile_path)
for neuron_pos, neuron_id in enumerate(data["unit_indices"]):
entities.append(
Expand All @@ -33,6 +35,28 @@ def import_module(path):
return dynamic_import(*split_module_name(path))


def get_output_selected_model(neuron_pos, session_id, model):
    """Creates a version of the model that has its output selected down to a single uniquely identified neuron.

    Args:
        neuron_pos: An integer, the position of the neuron in the model's output.
        session_id: A string that uniquely identifies one of the model's readouts.
        model: A PyTorch module that can be called with a keyword argument called "data_key". The output of the
            module is expected to be a two dimensional Torch tensor where the first dimension corresponds to the
            batch size and the second to the number of neurons.

    Returns:
        A callable "ConstrainedOutputModel" instance that takes the model input(s) as parameter(s) and returns the
        model output corresponding to the selected neuron.
    """
    # The readout belonging to the given session is selected by forwarding "data_key" on every call.
    return ConstrainedOutputModel(model, neuron_pos, forward_kwargs=dict(data_key=session_id))


class ModelLoader:
"""
A utility class for loading and caching models.
Expand Down
16 changes: 14 additions & 2 deletions mei/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import datajoint as dj
from nnfabrik.main import Dataset, schema

from . import mixins
from . import integration, mixins


class TrainedEnsembleModelTemplate(mixins.TrainedEnsembleModelTemplateMixin, dj.Manual):
Expand Down Expand Up @@ -48,7 +48,7 @@ class MEIMethod(mixins.MEIMethodMixin, dj.Lookup):
"""Table that contains MEI methods and their configurations."""


class MEITemplate(mixins.MEITemplateMixin, dj.Computed):
class MEITemplate(mixins.MEIMixin, dj.Computed):
"""MEI table template.

To create a functional "MEI" table, create a new class that inherits from this template and decorate it with your
Expand All @@ -60,3 +60,15 @@ class MEITemplate(mixins.MEITemplateMixin, dj.Computed):

method_table = MEIMethod
seed_table = MEISeed


class CSRFV1SelectorTemplate(mixins.CSRFV1SelectorTemplateMixin, dj.Computed):
    """Template for a table mapping CSRF V1 neuron IDs to positions in the model output.

    Subclass this template and decorate the subclass with a DataJoint schema of your choice to obtain a working
    "CSRFV1Selector" table. By default the table depends on the "Dataset" table from the "nnfabrik.main" schema;
    override the "dataset_table" class attribute to point it at a different dataset table.
    """

    dataset_table = Dataset
42 changes: 37 additions & 5 deletions mei/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,22 +62,24 @@ def load_model(
key: Optional[Key] = None,
include_dataloader: Optional[bool] = True,
include_state_dict: Optional[bool] = True,
**kwargs,
) -> Tuple[Dataloaders, EnsembleModel]:
if key is None:
key = self.fetch1("KEY")
return self._load_ensemble_model(
key=key,
include_dataloader=include_dataloader,
include_state_dict=include_state_dict,
**kwargs,
)

def _load_ensemble_model(
self,
key: Optional[Key] = None,
include_dataloader: Optional[bool] = True,
include_state_dict: Optional[bool] = True,
**kwargs,
) -> Tuple[Dataloaders, EnsembleModel]:

ensemble_key = (self & key).fetch1()
model_keys = (self.Member() & ensemble_key).fetch(as_dict=True)

Expand All @@ -90,6 +92,7 @@ def _load_ensemble_model(
key=k,
include_dataloader=include_dataloader,
include_state_dict=include_state_dict,
**kwargs,
)
for k in model_keys
]
Expand All @@ -101,6 +104,7 @@ def _load_ensemble_model(
key=k,
include_dataloader=include_dataloader,
include_state_dict=include_state_dict,
**kwargs,
)
for k in model_keys
]
Expand Down Expand Up @@ -170,21 +174,26 @@ class MEIMethodMixin:
"postprocessing",
)

def add_method(self, method_fn: str, method_config: Mapping, comment: str = "") -> None:
def add_method(self, method_fn: str, method_config: Mapping, comment: str = "", **kwargs) -> None:
self.insert1(
dict(
method_fn=method_fn,
method_hash=make_hash(method_config),
method_config=method_config,
method_comment=comment,
)
),
**kwargs,
)

def generate_mei(self, dataloaders: Dataloaders, model: Module, key: Key, seed: int) -> Dict[str, Any]:
method_fn, method_config = (self & key).fetch1("method_fn", "method_config")
method_fn = self.import_func(method_fn)
self.insert_key_in_ops(method_config=method_config, key=key)
mei, score, output = method_fn(dataloaders, model, method_config, seed)

train_input_shape = next(iter(next(iter(dataloaders["train"].values())))).images.shape
input_shape = (1, 4) + train_input_shape[2:]
Comment thread
Mvystrcilova marked this conversation as resolved.
Outdated
mei, score, output = method_fn(model=model, config=method_config, seed=seed, shape=input_shape)

return dict(key, mei=mei, score=score, output=output)

def generate_ringmei(
Expand All @@ -209,7 +218,7 @@ class MEISeedMixin:
"""


class MEITemplateMixin:
class MEIMixin:
definition = """
# contains maximally exciting images (MEIs)
-> self.method_table
Expand Down Expand Up @@ -260,3 +269,26 @@ def _save_to_disk(self, mei_entity: Dict[str, Any], temp_dir: str, name: str) ->
@staticmethod
def _create_random_filename(length: Optional[int] = 32) -> str:
return "".join(choice(ascii_letters) for _ in range(length))


class CSRFV1SelectorTemplateMixin:
    """Mixin supplying the table definition and behavior of the CSRF V1 selector table."""

    # Must be set by the inheriting table class (e.g. to nnfabrik's "Dataset" table).
    dataset_table = None

    definition = """
    # contains information that can be used to map a neuron's id to its corresponding integer position in the output of
    # the model.
    -> self.dataset_table
    neuron_id : smallint unsigned # unique neuron identifier
    ---
    neuron_position : smallint unsigned # integer position of the neuron in the model's output
    session_id : varchar(13) # unique session identifier
    """

    def make(self, key):
        """Populates the table for one dataset, inserting one row per neuron found in the dataset's files."""
        config = (self.dataset_table & key).fetch1("dataset_config")
        self.insert(integration.get_mappings(config, key))

    def get_output_selected_model(self, model, key):
        """Returns a version of "model" whose output is restricted to the single neuron identified by "key"."""
        position, session = (self & key).fetch1("neuron_position", "session_id")
        return integration.get_output_selected_model(position, session, model)
9 changes: 8 additions & 1 deletion mei/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ def __init__(
self.constraint = constraint if (isinstance(constraint, Iterable) or constraint is None) else [constraint]
self.forward_kwargs = forward_kwargs if forward_kwargs else dict()
self.target_fn = target_fn
if hasattr(model, "core") and hasattr(model, "readout"):
self.core = model.core
self.readout = model.readout

def __call__(self, x: Tensor, *args, **kwargs) -> Tensor:
"""Computes the constrained output of the model.
Expand All @@ -87,7 +90,11 @@ def __call__(self, x: Tensor, *args, **kwargs) -> Tensor:
Returns:
A tensor representing the constrained output of the model.
"""
output = self.model(x, *args, **self.forward_kwargs, **kwargs)
duplicate_keys = self.forward_kwargs.keys() & kwargs.keys()
reduced_forward_kwargs = self.forward_kwargs
for key in duplicate_keys:
reduced_forward_kwargs.pop(key)
output = self.model(x, *args, **reduced_forward_kwargs, **kwargs)
return (
self.target_fn(output)
if self.constraint is None or len(self.constraint) == 0
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ datajoint<=0.12.9
nnfabrik
torch
scipy>=1.7.0,<=1.12.0
numpy>=1.22.0,<=1.26.4
numpy>=1.21.0,<=1.26.4
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why was it previously less loose, do we know?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not aware that we would know. @PPierzc do you have a hunch?

torchvision
pytest