Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion src/openlifu/db/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,22 @@
from .subject import Subject
from .user import User

OnConflictOpts = Enum('OnConflictOpts', ['ERROR', 'OVERWRITE', 'SKIP'])

class OnConflictOpts(str, Enum):
ERROR = "error"
OVERWRITE = "overwrite"
SKIP = "skip"


def _normalize_on_conflict(on_conflict: OnConflictOpts | str) -> OnConflictOpts:
if isinstance(on_conflict, OnConflictOpts):
return on_conflict
if isinstance(on_conflict, str):
try:
return OnConflictOpts(on_conflict.lower())
except ValueError as exc:
raise ValueError("Invalid 'on_conflict' option. Use 'error', 'overwrite', or 'skip'.") from exc
raise ValueError("Invalid 'on_conflict' option. Use 'error', 'overwrite', or 'skip'.")


class Database:
Expand All @@ -39,6 +54,7 @@ def __init__(self, path: str | None = None):
self.logger = logging.getLogger(__name__)

def write_gridweights(self, transducer_id: str, grid_hash: str, grid_weights, on_conflict: OnConflictOpts = OnConflictOpts.ERROR):
on_conflict = _normalize_on_conflict(on_conflict)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@samueljwu agreed with this: updating the signature needed for valid type annotations in the case that someone wants to pass a str. Even though OnConflictOpts now inherits from str, that doesn't address the type annotation: it allows you to pass an OnConflictOpts where a str is expected, but not vice versa.

grid_hashes = self.get_gridweight_hashes(transducer_id)
if grid_hash in grid_hashes:
if on_conflict == OnConflictOpts.ERROR:
Expand All @@ -55,6 +71,7 @@ def write_gridweights(self, transducer_id: str, grid_hash: str, grid_weights, on
self.logger.info(f"Added grid weights with hash {grid_hash} for transducer {transducer_id} to the database.")

def write_user(self, user: User, on_conflict: OnConflictOpts = OnConflictOpts.ERROR) -> None:
on_conflict = _normalize_on_conflict(on_conflict)
# Check if the sonication user ID already exists in the database
user_id = user.id
user_ids = self.get_user_ids()
Expand All @@ -80,6 +97,7 @@ def write_user(self, user: User, on_conflict: OnConflictOpts = OnConflictOpts.ER
self.logger.info(f"Added User with ID {user_id} to the database.")

def delete_user(self, user_id: str, on_conflict: OnConflictOpts = OnConflictOpts.ERROR) -> None:
on_conflict = _normalize_on_conflict(on_conflict)
# Check if the user ID already exists in the database
Comment on lines +99 to 101
user_ids = self.get_user_ids()

Expand All @@ -104,6 +122,7 @@ def delete_user(self, user_id: str, on_conflict: OnConflictOpts = OnConflictOpts
self.logger.info(f"Removed Sonication User with ID {user_id} from the database.")

def write_protocol(self, protocol: Protocol, on_conflict: OnConflictOpts = OnConflictOpts.ERROR):
on_conflict = _normalize_on_conflict(on_conflict)
# Check if the sonication protocol ID already exists in the database
protocol_id = protocol.id
protocol_ids = self.get_protocol_ids()
Expand Down Expand Up @@ -131,6 +150,7 @@ def write_protocol(self, protocol: Protocol, on_conflict: OnConflictOpts = OnCon
self.logger.info(f"Added Sonication Protocol with ID {protocol_id} to the database.")

def delete_protocol(self, protocol_id: str, on_conflict: OnConflictOpts = OnConflictOpts.ERROR):
on_conflict = _normalize_on_conflict(on_conflict)
# Check if the sonication protocol ID already exists in the database
Comment on lines +152 to 154
protocol_ids = self.get_protocol_ids()

Expand All @@ -155,6 +175,7 @@ def delete_protocol(self, protocol_id: str, on_conflict: OnConflictOpts = OnConf
self.logger.info(f"Removed Sonication Protocol with ID {protocol_id} from the database.")

def write_session(self, subject:Subject, session:Session, on_conflict=OnConflictOpts.ERROR):
on_conflict = _normalize_on_conflict(on_conflict)
# Generate session ID
session_id = session.id

Expand Down Expand Up @@ -227,6 +248,7 @@ def delete_session(self, subject_id: str, session_id: str, on_conflict: OnConfli
session_id: ID of the session to delete
on_conflict: Behavior when session doesn't exist ('error' or 'skip')
"""
on_conflict = _normalize_on_conflict(on_conflict)
Comment on lines 249 to +251
# Check if the session ID exists in the database for this subject
session_ids = self.get_session_ids(subject_id)

Expand Down Expand Up @@ -260,6 +282,7 @@ def write_run(self, run:Run, session:Session = None, protocol:Protocol = None, o
Returns:
None: This method does not return a value
"""
on_conflict = _normalize_on_conflict(on_conflict)
# Check whether the run already exist in the session
run_ids = self.get_run_ids(session.subject_id, session.id)

Expand Down Expand Up @@ -291,6 +314,7 @@ def write_run(self, run:Run, session:Session = None, protocol:Protocol = None, o
protocol.to_file(run_metadata_filepath.parent / f'{run.id}_protocol_snapshot.json')

def write_subject(self, subject, on_conflict=OnConflictOpts.ERROR):
on_conflict = _normalize_on_conflict(on_conflict)
subject_id = subject.id
subject_ids = self.get_subject_ids()

Expand Down Expand Up @@ -336,6 +360,7 @@ def write_transducer(
Returns:
None: This method does not return a value
"""
on_conflict = _normalize_on_conflict(on_conflict)
transducer_id = transducer.id
transducer_ids = self.get_transducer_ids()

Expand Down Expand Up @@ -386,6 +411,7 @@ def write_transducer(
self.logger.info(f"Added transducer with ID {transducer_id} to the database.")

def write_volume(self, subject_id, volume_id, volume_name, volume_data_filepath, on_conflict=OnConflictOpts.ERROR):
on_conflict = _normalize_on_conflict(on_conflict)
if not Path(volume_data_filepath).exists():
raise ValueError(f'Volume data filepath does not exist: {volume_data_filepath}')

Expand Down Expand Up @@ -445,6 +471,7 @@ def write_photocollection(self, subject_id, session_id, reference_number: str, p
""" Writes a photocollection to database and copies the associated
photos into the database, specified by the subject, session, and
reference_number of the photocollection."""
on_conflict = _normalize_on_conflict(on_conflict)

photocollection_dir = Path(self.get_session_dir(subject_id, session_id)) / 'photocollections' / reference_number

Expand Down Expand Up @@ -481,6 +508,7 @@ def write_photoscan(self, subject_id, session_id, photoscan: Photoscan, model_da
""" Writes a photoscan object to database and copies the associated data filepaths into the database.
While the model data file is required, the associated texture and .mtl files are optional and can be provided if present.
When a photoscan that is already present in the database is being re-written,the associated data files do not need to be specified """
on_conflict = _normalize_on_conflict(on_conflict)

photoscan_ids = self.get_photoscan_ids(subject_id, session_id)
if photoscan.id in photoscan_ids:
Expand Down Expand Up @@ -539,6 +567,7 @@ def write_photoscan(self, subject_id, session_id, photoscan: Photoscan, model_da
self.logger.info(f"Added photoscan with ID {photoscan.id} for session {session_id} to the database.")

def write_solution(self, session:Session, solution:Solution, on_conflict: OnConflictOpts=OnConflictOpts.ERROR):
on_conflict = _normalize_on_conflict(on_conflict)
solution_ids = self.get_solution_ids(session.subject_id, session.id)

if solution.id in solution_ids:
Expand Down
39 changes: 39 additions & 0 deletions tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,45 @@ def test_write_protocol(example_database: Database):
reloaded_protocol = example_database.load_protocol(protocol.id)
assert reloaded_protocol.name == "new_name"


def test_on_conflict_accepts_enum_and_strings(example_database: Database):
assert OnConflictOpts.OVERWRITE.value == "overwrite"

protocol = Protocol(name="bleh", id="a_protocol_with_string_conflict_option")
example_database.write_protocol(protocol)

with pytest.raises(ValueError, match="already exists"):
example_database.write_protocol(protocol, on_conflict="ERROR")
Comment on lines +97 to +104

protocol.name = "skipped_name"
example_database.write_protocol(protocol, on_conflict="SkIp")
reloaded_protocol = example_database.load_protocol(protocol.id)
assert reloaded_protocol.name == "bleh"

protocol.name = "overwritten_name"
example_database.write_protocol(protocol, on_conflict="OVERWRITE")
reloaded_protocol = example_database.load_protocol(protocol.id)
assert reloaded_protocol.name == "overwritten_name"

example_database.delete_protocol("non_existent_protocol", on_conflict="skip")

user = User(name="initial_name", id="a_user_with_string_conflict_option")
example_database.write_user(user)

user.name = "skipped_name"
example_database.write_user(user, on_conflict="skip")
reloaded_user = example_database.load_user(user.id)
assert reloaded_user.name == "initial_name"

user.name = "overwritten_name"
example_database.write_user(user, on_conflict="overwrite")
reloaded_user = example_database.load_user(user.id)
assert reloaded_user.name == "overwritten_name"

with pytest.raises(ValueError, match="Invalid 'on_conflict' option"):
example_database.write_protocol(protocol, on_conflict="replace")


def test_delete_protocol(example_database: Database):
# Write a protocol
protocol = Protocol(name="bleh", id="a_protocol_to_be_deleted")
Expand Down
Loading