Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 44 additions & 15 deletions src/openlifu/db/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,22 @@
from .subject import Subject
from .user import User

OnConflictOpts = Enum('OnConflictOpts', ['ERROR', 'OVERWRITE', 'SKIP'])

class OnConflictOpts(str, Enum):
ERROR = "error"
OVERWRITE = "overwrite"
SKIP = "skip"


def _normalize_on_conflict(on_conflict: OnConflictOpts | str) -> OnConflictOpts:
if isinstance(on_conflict, OnConflictOpts):
return on_conflict
if isinstance(on_conflict, str):
try:
return OnConflictOpts(on_conflict.lower())
except ValueError as exc:
raise ValueError("Invalid 'on_conflict' option. Use 'error', 'overwrite', or 'skip'.") from exc
raise ValueError("Invalid 'on_conflict' option. Use 'error', 'overwrite', or 'skip'.")


class Database:
Expand All @@ -38,7 +53,8 @@ def __init__(self, path: str | None = None):
self.path = os.path.normpath(path)
self.logger = logging.getLogger(__name__)

def write_gridweights(self, transducer_id: str, grid_hash: str, grid_weights, on_conflict: OnConflictOpts = OnConflictOpts.ERROR):
def write_gridweights(self, transducer_id: str, grid_hash: str, grid_weights, on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR):
on_conflict = _normalize_on_conflict(on_conflict)
grid_hashes = self.get_gridweight_hashes(transducer_id)
if grid_hash in grid_hashes:
if on_conflict == OnConflictOpts.ERROR:
Expand All @@ -54,7 +70,8 @@ def write_gridweights(self, transducer_id: str, grid_hash: str, grid_weights, on
f.create_dataset("grid_weights", data=grid_weights)
self.logger.info(f"Added grid weights with hash {grid_hash} for transducer {transducer_id} to the database.")

def write_user(self, user: User, on_conflict: OnConflictOpts = OnConflictOpts.ERROR) -> None:
def write_user(self, user: User, on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR) -> None:
on_conflict = _normalize_on_conflict(on_conflict)
# Check if the sonication user ID already exists in the database
user_id = user.id
user_ids = self.get_user_ids()
Expand All @@ -79,7 +96,8 @@ def write_user(self, user: User, on_conflict: OnConflictOpts = OnConflictOpts.ER

self.logger.info(f"Added User with ID {user_id} to the database.")

def delete_user(self, user_id: str, on_conflict: OnConflictOpts = OnConflictOpts.ERROR) -> None:
def delete_user(self, user_id: str, on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR) -> None:
on_conflict = _normalize_on_conflict(on_conflict)
# Check if the user ID already exists in the database
Comment on lines +99 to 101
user_ids = self.get_user_ids()

Expand All @@ -103,7 +121,8 @@ def delete_user(self, user_id: str, on_conflict: OnConflictOpts = OnConflictOpts

self.logger.info(f"Removed Sonication User with ID {user_id} from the database.")

def write_protocol(self, protocol: Protocol, on_conflict: OnConflictOpts = OnConflictOpts.ERROR):
def write_protocol(self, protocol: Protocol, on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR):
on_conflict = _normalize_on_conflict(on_conflict)
# Check if the sonication protocol ID already exists in the database
protocol_id = protocol.id
protocol_ids = self.get_protocol_ids()
Expand All @@ -130,7 +149,8 @@ def write_protocol(self, protocol: Protocol, on_conflict: OnConflictOpts = OnCon

self.logger.info(f"Added Sonication Protocol with ID {protocol_id} to the database.")

def delete_protocol(self, protocol_id: str, on_conflict: OnConflictOpts = OnConflictOpts.ERROR):
def delete_protocol(self, protocol_id: str, on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR):
on_conflict = _normalize_on_conflict(on_conflict)
# Check if the sonication protocol ID already exists in the database
Comment on lines +152 to 154
protocol_ids = self.get_protocol_ids()

Expand All @@ -154,7 +174,8 @@ def delete_protocol(self, protocol_id: str, on_conflict: OnConflictOpts = OnConf

self.logger.info(f"Removed Sonication Protocol with ID {protocol_id} from the database.")

def write_session(self, subject:Subject, session:Session, on_conflict=OnConflictOpts.ERROR):
def write_session(self, subject:Subject, session:Session, on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR):
on_conflict = _normalize_on_conflict(on_conflict)
# Generate session ID
session_id = session.id

Expand Down Expand Up @@ -219,14 +240,15 @@ def write_session(self, subject:Subject, session:Session, on_conflict=OnConflict

self.logger.info(f"Added session with ID {session_id} for subject {subject.id} to the database.")

def delete_session(self, subject_id: str, session_id: str, on_conflict: OnConflictOpts = OnConflictOpts.ERROR):
def delete_session(self, subject_id: str, session_id: str, on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR):
"""Delete a session and its associated data from the database.

Args:
subject_id: ID of the subject the session belongs to
session_id: ID of the session to delete
on_conflict: Behavior when session doesn't exist ('error' or 'skip')
"""
on_conflict = _normalize_on_conflict(on_conflict)
Comment on lines 249 to +251
# Check if the session ID exists in the database for this subject
session_ids = self.get_session_ids(subject_id)

Expand All @@ -249,7 +271,7 @@ def delete_session(self, subject_id: str, session_id: str, on_conflict: OnConfli

self.logger.info(f"Removed session with ID {session_id} from the database.")

def write_run(self, run:Run, session:Session = None, protocol:Protocol = None, on_conflict=OnConflictOpts.ERROR):
def write_run(self, run:Run, session:Session = None, protocol:Protocol = None, on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR):
"""Write a run with a snapshot of session and a snapshot of protocol if provided

Args:
Expand All @@ -260,6 +282,7 @@ def write_run(self, run:Run, session:Session = None, protocol:Protocol = None, o
Returns:
None: This method does not return a value
"""
on_conflict = _normalize_on_conflict(on_conflict)
# Check whether the run already exist in the session
run_ids = self.get_run_ids(session.subject_id, session.id)

Expand Down Expand Up @@ -290,7 +313,8 @@ def write_run(self, run:Run, session:Session = None, protocol:Protocol = None, o
# Write snapshot of the protocol
protocol.to_file(run_metadata_filepath.parent / f'{run.id}_protocol_snapshot.json')

def write_subject(self, subject, on_conflict=OnConflictOpts.ERROR):
def write_subject(self, subject, on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR):
on_conflict = _normalize_on_conflict(on_conflict)
subject_id = subject.id
subject_ids = self.get_subject_ids()

Expand Down Expand Up @@ -325,7 +349,7 @@ def write_transducer(
transducer,
registration_surface_model_filepath: PathLike | None = None,
transducer_body_model_filepath: PathLike | None = None,
on_conflict: OnConflictOpts=OnConflictOpts.ERROR,
on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR,
) -> None:
""" Writes a transducer object to database and copies the affiliated transducer data files to the database if provided. When a transducer that is already present in the database is being re-written,
the associated model data files do not need to be provided if they have previously been added to the database.
Expand All @@ -336,6 +360,7 @@ def write_transducer(
Returns:
None: This method does not return a value
"""
on_conflict = _normalize_on_conflict(on_conflict)
transducer_id = transducer.id
transducer_ids = self.get_transducer_ids()

Expand Down Expand Up @@ -385,7 +410,8 @@ def write_transducer(

self.logger.info(f"Added transducer with ID {transducer_id} to the database.")

def write_volume(self, subject_id, volume_id, volume_name, volume_data_filepath, on_conflict=OnConflictOpts.ERROR):
def write_volume(self, subject_id, volume_id, volume_name, volume_data_filepath, on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR):
on_conflict = _normalize_on_conflict(on_conflict)
if not Path(volume_data_filepath).exists():
raise ValueError(f'Volume data filepath does not exist: {volume_data_filepath}')

Expand Down Expand Up @@ -441,10 +467,11 @@ def write_volume(self, subject_id, volume_id, volume_name, volume_data_filepath,
if temp_nifti_path is not None and temp_nifti_path.exists():
temp_nifti_path.unlink()

def write_photocollection(self, subject_id, session_id, reference_number: str, photo_paths: List[PathLike], on_conflict=OnConflictOpts.ERROR):
def write_photocollection(self, subject_id, session_id, reference_number: str, photo_paths: List[PathLike], on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR):
""" Writes a photocollection to database and copies the associated
photos into the database, specified by the subject, session, and
reference_number of the photocollection."""
on_conflict = _normalize_on_conflict(on_conflict)

photocollection_dir = Path(self.get_session_dir(subject_id, session_id)) / 'photocollections' / reference_number

Expand Down Expand Up @@ -477,10 +504,11 @@ def write_photocollection(self, subject_id, session_id, reference_number: str, p

self.logger.info(f"Added photocollection with reference number {reference_number} for session {session_id} to the database.")

def write_photoscan(self, subject_id, session_id, photoscan: Photoscan, model_data_filepath: str | None = None, texture_data_filepath: str | None = None, mtl_data_filepath: str | None = None, on_conflict=OnConflictOpts.ERROR):
def write_photoscan(self, subject_id, session_id, photoscan: Photoscan, model_data_filepath: str | None = None, texture_data_filepath: str | None = None, mtl_data_filepath: str | None = None, on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR):
""" Writes a photoscan object to database and copies the associated data filepaths into the database.
While the model data file is required, the associated texture and .mtl files are optional and can be provided if present.
When a photoscan that is already present in the database is being re-written,the associated data files do not need to be specified """
on_conflict = _normalize_on_conflict(on_conflict)

photoscan_ids = self.get_photoscan_ids(subject_id, session_id)
if photoscan.id in photoscan_ids:
Expand Down Expand Up @@ -538,7 +566,8 @@ def write_photoscan(self, subject_id, session_id, photoscan: Photoscan, model_da

self.logger.info(f"Added photoscan with ID {photoscan.id} for session {session_id} to the database.")

def write_solution(self, session:Session, solution:Solution, on_conflict: OnConflictOpts=OnConflictOpts.ERROR):
def write_solution(self, session:Session, solution:Solution, on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR):
on_conflict = _normalize_on_conflict(on_conflict)
solution_ids = self.get_solution_ids(session.subject_id, session.id)

if solution.id in solution_ids:
Expand Down
39 changes: 39 additions & 0 deletions tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,45 @@ def test_write_protocol(example_database: Database):
reloaded_protocol = example_database.load_protocol(protocol.id)
assert reloaded_protocol.name == "new_name"


def test_on_conflict_accepts_enum_and_strings(example_database: Database):
assert OnConflictOpts.OVERWRITE.value == "overwrite"

protocol = Protocol(name="bleh", id="a_protocol_with_string_conflict_option")
example_database.write_protocol(protocol)

with pytest.raises(ValueError, match="already exists"):
example_database.write_protocol(protocol, on_conflict="ERROR")
Comment on lines +97 to +104

protocol.name = "skipped_name"
example_database.write_protocol(protocol, on_conflict="SkIp")
reloaded_protocol = example_database.load_protocol(protocol.id)
assert reloaded_protocol.name == "bleh"

protocol.name = "overwritten_name"
example_database.write_protocol(protocol, on_conflict="OVERWRITE")
reloaded_protocol = example_database.load_protocol(protocol.id)
assert reloaded_protocol.name == "overwritten_name"

example_database.delete_protocol("non_existent_protocol", on_conflict="skip")

user = User(name="initial_name", id="a_user_with_string_conflict_option")
example_database.write_user(user)

user.name = "skipped_name"
example_database.write_user(user, on_conflict="skip")
reloaded_user = example_database.load_user(user.id)
assert reloaded_user.name == "initial_name"

user.name = "overwritten_name"
example_database.write_user(user, on_conflict="overwrite")
reloaded_user = example_database.load_user(user.id)
assert reloaded_user.name == "overwritten_name"

with pytest.raises(ValueError, match="Invalid 'on_conflict' option"):
example_database.write_protocol(protocol, on_conflict="replace")


def test_delete_protocol(example_database: Database):
# Write a protocol
protocol = Protocol(name="bleh", id="a_protocol_to_be_deleted")
Expand Down
Loading