From 4ea198efb79c6b192ccb9cdba17643f483b20a3e Mon Sep 17 00:00:00 2001 From: samueljwu <56311527+samueljwu@users.noreply.github.com> Date: Mon, 25 May 2026 08:48:33 +0800 Subject: [PATCH 1/2] Enhance db.OnConflictOpts #346 --- src/openlifu/db/database.py | 31 ++++++++++++++++++++++++++++- tests/test_database.py | 39 +++++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 1 deletion(-) diff --git a/src/openlifu/db/database.py b/src/openlifu/db/database.py index 952d4731..b1d89058 100644 --- a/src/openlifu/db/database.py +++ b/src/openlifu/db/database.py @@ -28,7 +28,22 @@ from .subject import Subject from .user import User -OnConflictOpts = Enum('OnConflictOpts', ['ERROR', 'OVERWRITE', 'SKIP']) + +class OnConflictOpts(str, Enum): + ERROR = "error" + OVERWRITE = "overwrite" + SKIP = "skip" + + +def _normalize_on_conflict(on_conflict: OnConflictOpts | str) -> OnConflictOpts: + if isinstance(on_conflict, OnConflictOpts): + return on_conflict + if isinstance(on_conflict, str): + try: + return OnConflictOpts(on_conflict.lower()) + except ValueError as exc: + raise ValueError("Invalid 'on_conflict' option. Use 'error', 'overwrite', or 'skip'.") from exc + raise ValueError("Invalid 'on_conflict' option. Use 'error', 'overwrite', or 'skip'.") class Database: @@ -39,6 +54,7 @@ def __init__(self, path: str | None = None): self.logger = logging.getLogger(__name__) def write_gridweights(self, transducer_id: str, grid_hash: str, grid_weights, on_conflict: OnConflictOpts = OnConflictOpts.ERROR): + on_conflict = _normalize_on_conflict(on_conflict) grid_hashes = self.get_gridweight_hashes(transducer_id) if grid_hash in grid_hashes: if on_conflict == OnConflictOpts.ERROR: @@ -55,6 +71,7 @@ def write_gridweights(self, transducer_id: str, grid_hash: str, grid_weights, on self.logger.info(f"Added grid weights with hash {grid_hash} for transducer {transducer_id} to the database.") def write_user(self, user: User, on_conflict: OnConflictOpts = OnConflictOpts.ERROR) -> None: + on_conflict = _normalize_on_conflict(on_conflict) # Check if the sonication user ID already exists in the database user_id = user.id user_ids = self.get_user_ids() @@ -80,6 +97,7 @@ def write_user(self, user: User, on_conflict: OnConflictOpts = OnConflictOpts.ER self.logger.info(f"Added User with ID {user_id} to the database.") def delete_user(self, user_id: str, on_conflict: OnConflictOpts = OnConflictOpts.ERROR) -> None: + on_conflict = _normalize_on_conflict(on_conflict) # Check if the user ID already exists in the database user_ids = self.get_user_ids() @@ -104,6 +122,7 @@ def delete_user(self, user_id: str, on_conflict: OnConflictOpts = OnConflictOpts self.logger.info(f"Removed Sonication User with ID {user_id} from the database.") def write_protocol(self, protocol: Protocol, on_conflict: OnConflictOpts = OnConflictOpts.ERROR): + on_conflict = _normalize_on_conflict(on_conflict) # Check if the sonication protocol ID already exists in the database protocol_id = protocol.id protocol_ids = self.get_protocol_ids() @@ -131,6 +150,7 @@ def write_protocol(self, protocol: Protocol, on_conflict: OnConflictOpts = OnCon self.logger.info(f"Added Sonication Protocol with ID {protocol_id} to the database.") def delete_protocol(self, protocol_id: str, on_conflict: OnConflictOpts = OnConflictOpts.ERROR): + on_conflict = _normalize_on_conflict(on_conflict) # Check if the sonication protocol ID already exists in the database protocol_ids = self.get_protocol_ids() @@ -155,6 +175,7 @@ def delete_protocol(self, protocol_id: str, on_conflict: OnConflictOpts = OnConf self.logger.info(f"Removed Sonication Protocol with ID {protocol_id} from the database.") def write_session(self, subject:Subject, session:Session, on_conflict=OnConflictOpts.ERROR): + on_conflict = _normalize_on_conflict(on_conflict) # Generate session ID session_id = session.id @@ -227,6 +248,7 @@ def delete_session(self, subject_id: str, session_id: str, on_conflict: OnConfli session_id: ID of the session to delete on_conflict: Behavior when session doesn't exist ('error' or 'skip') """ + on_conflict = _normalize_on_conflict(on_conflict) # Check if the session ID exists in the database for this subject session_ids = self.get_session_ids(subject_id) @@ -260,6 +282,7 @@ def write_run(self, run:Run, session:Session = None, protocol:Protocol = None, o Returns: None: This method does not return a value """ + on_conflict = _normalize_on_conflict(on_conflict) # Check whether the run already exist in the session run_ids = self.get_run_ids(session.subject_id, session.id) @@ -291,6 +314,7 @@ def write_run(self, run:Run, session:Session = None, protocol:Protocol = None, o protocol.to_file(run_metadata_filepath.parent / f'{run.id}_protocol_snapshot.json') def write_subject(self, subject, on_conflict=OnConflictOpts.ERROR): + on_conflict = _normalize_on_conflict(on_conflict) subject_id = subject.id subject_ids = self.get_subject_ids() @@ -336,6 +360,7 @@ def write_transducer( Returns: None: This method does not return a value """ + on_conflict = _normalize_on_conflict(on_conflict) transducer_id = transducer.id transducer_ids = self.get_transducer_ids() @@ -386,6 +411,7 @@ def write_transducer( self.logger.info(f"Added transducer with ID {transducer_id} to the database.") def write_volume(self, subject_id, volume_id, volume_name, volume_data_filepath, on_conflict=OnConflictOpts.ERROR): + on_conflict = _normalize_on_conflict(on_conflict) if not Path(volume_data_filepath).exists(): raise ValueError(f'Volume data filepath does not exist: {volume_data_filepath}') @@ -445,6 +471,7 @@ def write_photocollection(self, subject_id, session_id, reference_number: str, p """ Writes a photocollection to database and copies the associated photos into the database, specified by the subject, session, and reference_number of the photocollection.""" + on_conflict = _normalize_on_conflict(on_conflict) photocollection_dir = Path(self.get_session_dir(subject_id, session_id)) / 'photocollections' / reference_number @@ -481,6 +508,7 @@ def write_photoscan(self, subject_id, session_id, photoscan: Photoscan, model_da """ Writes a photoscan object to database and copies the associated data filepaths into the database. While the model data file is required, the associated texture and .mtl files are optional and can be provided if present. When a photoscan that is already present in the database is being re-written,the associated data files do not need to be specified """ + on_conflict = _normalize_on_conflict(on_conflict) photoscan_ids = self.get_photoscan_ids(subject_id, session_id) if photoscan.id in photoscan_ids: @@ -539,6 +567,7 @@ def write_photoscan(self, subject_id, session_id, photoscan: Photoscan, model_da self.logger.info(f"Added photoscan with ID {photoscan.id} for session {session_id} to the database.") def write_solution(self, session:Session, solution:Solution, on_conflict: OnConflictOpts=OnConflictOpts.ERROR): + on_conflict = _normalize_on_conflict(on_conflict) solution_ids = self.get_solution_ids(session.subject_id, session.id) if solution.id in solution_ids: diff --git a/tests/test_database.py b/tests/test_database.py index eb478c3e..2b89abca 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -93,6 +93,45 @@ def test_write_protocol(example_database: Database): reloaded_protocol = example_database.load_protocol(protocol.id) assert reloaded_protocol.name == "new_name" + +def test_on_conflict_accepts_enum_and_strings(example_database: Database): + assert OnConflictOpts.OVERWRITE.value == "overwrite" + + protocol = Protocol(name="bleh", id="a_protocol_with_string_conflict_option") + example_database.write_protocol(protocol) + + with pytest.raises(ValueError, match="already exists"): + example_database.write_protocol(protocol, on_conflict="ERROR") + + protocol.name = "skipped_name" + example_database.write_protocol(protocol, on_conflict="SkIp") + reloaded_protocol = example_database.load_protocol(protocol.id) + assert reloaded_protocol.name == "bleh" + + protocol.name = "overwritten_name" + example_database.write_protocol(protocol, on_conflict="OVERWRITE") + reloaded_protocol = example_database.load_protocol(protocol.id) + assert reloaded_protocol.name == "overwritten_name" + + example_database.delete_protocol("non_existent_protocol", on_conflict="skip") + + user = User(name="initial_name", id="a_user_with_string_conflict_option") + example_database.write_user(user) + + user.name = "skipped_name" + example_database.write_user(user, on_conflict="skip") + reloaded_user = example_database.load_user(user.id) + assert reloaded_user.name == "initial_name" + + user.name = "overwritten_name" + example_database.write_user(user, on_conflict="overwrite") + reloaded_user = example_database.load_user(user.id) + assert reloaded_user.name == "overwritten_name" + + with pytest.raises(ValueError, match="Invalid 'on_conflict' option"): + example_database.write_protocol(protocol, on_conflict="replace") + + def test_delete_protocol(example_database: Database): # Write a protocol protocol = Protocol(name="bleh", id="a_protocol_to_be_deleted") From 00c54af8660c94137641a3a06eb15f72f76be815 Mon Sep 17 00:00:00 2001 From: samueljwu <56311527+samueljwu@users.noreply.github.com> Date: Mon, 25 May 2026 22:48:25 +0800 Subject: [PATCH 2/2] Update on_conflict type annotations #346 --- src/openlifu/db/database.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/openlifu/db/database.py b/src/openlifu/db/database.py index b1d89058..ec647fb3 100644 --- a/src/openlifu/db/database.py +++ b/src/openlifu/db/database.py @@ -53,7 +53,7 @@ def __init__(self, path: str | None = None): self.path = os.path.normpath(path) self.logger = logging.getLogger(__name__) - def write_gridweights(self, transducer_id: str, grid_hash: str, grid_weights, on_conflict: OnConflictOpts = OnConflictOpts.ERROR): + def write_gridweights(self, transducer_id: str, grid_hash: str, grid_weights, on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR): on_conflict = _normalize_on_conflict(on_conflict) grid_hashes = self.get_gridweight_hashes(transducer_id) if grid_hash in grid_hashes: @@ -70,7 +70,7 @@ def write_gridweights(self, transducer_id: str, grid_hash: str, grid_weights, on f.create_dataset("grid_weights", data=grid_weights) self.logger.info(f"Added grid weights with hash {grid_hash} for transducer {transducer_id} to the database.") - def write_user(self, user: User, on_conflict: OnConflictOpts = OnConflictOpts.ERROR) -> None: + def write_user(self, user: User, on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR) -> None: on_conflict = _normalize_on_conflict(on_conflict) # Check if the sonication user ID already exists in the database user_id = user.id @@ -96,7 +96,7 @@ def write_user(self, user: User, on_conflict: OnConflictOpts = OnConflictOpts.ER self.logger.info(f"Added User with ID {user_id} to the database.") - def delete_user(self, user_id: str, on_conflict: OnConflictOpts = OnConflictOpts.ERROR) -> None: + def delete_user(self, user_id: str, on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR) -> None: on_conflict = _normalize_on_conflict(on_conflict) # Check if the user ID already exists in the database user_ids = self.get_user_ids() @@ -121,7 +121,7 @@ def delete_user(self, user_id: str, on_conflict: OnConflictOpts = OnConflictOpts self.logger.info(f"Removed Sonication User with ID {user_id} from the database.") - def write_protocol(self, protocol: Protocol, on_conflict: OnConflictOpts = OnConflictOpts.ERROR): + def write_protocol(self, protocol: Protocol, on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR): on_conflict = _normalize_on_conflict(on_conflict) # Check if the sonication protocol ID already exists in the database protocol_id = protocol.id @@ -149,7 +149,7 @@ def write_protocol(self, protocol: Protocol, on_conflict: OnConflictOpts = OnCon self.logger.info(f"Added Sonication Protocol with ID {protocol_id} to the database.") - def delete_protocol(self, protocol_id: str, on_conflict: OnConflictOpts = OnConflictOpts.ERROR): + def delete_protocol(self, protocol_id: str, on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR): on_conflict = _normalize_on_conflict(on_conflict) # Check if the sonication protocol ID already exists in the database protocol_ids = self.get_protocol_ids() @@ -174,7 +174,7 @@ def delete_protocol(self, protocol_id: str, on_conflict: OnConflictOpts = OnConf self.logger.info(f"Removed Sonication Protocol with ID {protocol_id} from the database.") - def write_session(self, subject:Subject, session:Session, on_conflict=OnConflictOpts.ERROR): + def write_session(self, subject:Subject, session:Session, on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR): on_conflict = _normalize_on_conflict(on_conflict) # Generate session ID session_id = session.id @@ -240,7 +240,7 @@ def write_session(self, subject:Subject, session:Session, on_conflict=OnConflict self.logger.info(f"Added session with ID {session_id} for subject {subject.id} to the database.") - def delete_session(self, subject_id: str, session_id: str, on_conflict: OnConflictOpts = OnConflictOpts.ERROR): + def delete_session(self, subject_id: str, session_id: str, on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR): """Delete a session and its associated data from the database. Args: @@ -271,7 +271,7 @@ def delete_session(self, subject_id: str, session_id: str, on_conflict: OnConfli self.logger.info(f"Removed session with ID {session_id} from the database.") - def write_run(self, run:Run, session:Session = None, protocol:Protocol = None, on_conflict=OnConflictOpts.ERROR): + def write_run(self, run:Run, session:Session = None, protocol:Protocol = None, on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR): """Write a run with a snapshot of session and a snapshot of protocol if provided Args: @@ -313,7 +313,7 @@ def write_run(self, run:Run, session:Session = None, protocol:Protocol = None, o # Write snapshot of the protocol protocol.to_file(run_metadata_filepath.parent / f'{run.id}_protocol_snapshot.json') - def write_subject(self, subject, on_conflict=OnConflictOpts.ERROR): + def write_subject(self, subject, on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR): on_conflict = _normalize_on_conflict(on_conflict) subject_id = subject.id subject_ids = self.get_subject_ids() @@ -349,7 +349,7 @@ def write_transducer( transducer, registration_surface_model_filepath: PathLike | None = None, transducer_body_model_filepath: PathLike | None = None, - on_conflict: OnConflictOpts=OnConflictOpts.ERROR, + on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR, ) -> None: """ Writes a transducer object to database and copies the affiliated transducer data files to the database if provided. When a transducer that is already present in the database is being re-written, the associated model data files do not need to be provided if they have previously been added to the database. @@ -410,7 +410,7 @@ def write_transducer( self.logger.info(f"Added transducer with ID {transducer_id} to the database.") - def write_volume(self, subject_id, volume_id, volume_name, volume_data_filepath, on_conflict=OnConflictOpts.ERROR): + def write_volume(self, subject_id, volume_id, volume_name, volume_data_filepath, on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR): on_conflict = _normalize_on_conflict(on_conflict) if not Path(volume_data_filepath).exists(): raise ValueError(f'Volume data filepath does not exist: {volume_data_filepath}') @@ -467,7 +467,7 @@ def write_volume(self, subject_id, volume_id, volume_name, volume_data_filepath, if temp_nifti_path is not None and temp_nifti_path.exists(): temp_nifti_path.unlink() - def write_photocollection(self, subject_id, session_id, reference_number: str, photo_paths: List[PathLike], on_conflict=OnConflictOpts.ERROR): + def write_photocollection(self, subject_id, session_id, reference_number: str, photo_paths: List[PathLike], on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR): """ Writes a photocollection to database and copies the associated photos into the database, specified by the subject, session, and reference_number of the photocollection.""" @@ -504,7 +504,7 @@ def write_photocollection(self, subject_id, session_id, reference_number: str, p self.logger.info(f"Added photocollection with reference number {reference_number} for session {session_id} to the database.") - def write_photoscan(self, subject_id, session_id, photoscan: Photoscan, model_data_filepath: str | None = None, texture_data_filepath: str | None = None, mtl_data_filepath: str | None = None, on_conflict=OnConflictOpts.ERROR): + def write_photoscan(self, subject_id, session_id, photoscan: Photoscan, model_data_filepath: str | None = None, texture_data_filepath: str | None = None, mtl_data_filepath: str | None = None, on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR): """ Writes a photoscan object to database and copies the associated data filepaths into the database. While the model data file is required, the associated texture and .mtl files are optional and can be provided if present. When a photoscan that is already present in the database is being re-written,the associated data files do not need to be specified """ @@ -566,7 +566,7 @@ def write_photoscan(self, subject_id, session_id, photoscan: Photoscan, model_da self.logger.info(f"Added photoscan with ID {photoscan.id} for session {session_id} to the database.") - def write_solution(self, session:Session, solution:Solution, on_conflict: OnConflictOpts=OnConflictOpts.ERROR): + def write_solution(self, session:Session, solution:Solution, on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR): on_conflict = _normalize_on_conflict(on_conflict) solution_ids = self.get_solution_ids(session.subject_id, session.id)