diff --git a/medcat-trainer/webapp/.dockerignore b/medcat-trainer/webapp/.dockerignore
index 62c88576a..8ae4a023e 100644
--- a/medcat-trainer/webapp/.dockerignore
+++ b/medcat-trainer/webapp/.dockerignore
@@ -9,4 +9,15 @@ __pycache__/
*.md
.pytest_cache
.mypy_cache
-node_modules/
+
+# Frontend — rebuilt in the frontend-builder stage
+frontend/node_modules/
+frontend/dist/
+frontend/coverage/
+
+# Backend tests — not needed in the production image
+api/**/tests/
+
+# User-uploaded models and datasets — mounted at /home/api/media in compose
+api/media/*
+!api/media/.keep
diff --git a/medcat-trainer/webapp/Dockerfile b/medcat-trainer/webapp/Dockerfile
index 47c3c22af..831012915 100644
--- a/medcat-trainer/webapp/Dockerfile
+++ b/medcat-trainer/webapp/Dockerfile
@@ -1,57 +1,69 @@
-FROM python:3.12
+# -----------------------------------------------------------------------------
+# Stage 1: Build frontend assets (Node toolchain discarded after this stage)
+# -----------------------------------------------------------------------------
+FROM node:20-bookworm-slim AS frontend-builder
-# Update and upgrade everything
-RUN apt-get update -y && \
- apt-get upgrade -y
-
-# install vim as its annoying not to have an editor
-RUN apt-get install -y vim
-
-# install supervisor
-RUN apt-get install -y supervisor
+WORKDIR /build
+COPY frontend/package.json frontend/package-lock.json ./
+RUN --mount=type=cache,target=/root/.npm \
+ npm ci --prefer-offline --no-audit --no-fund
-# install gettext for envsubst (used to generate runtime config)
-RUN apt-get install -y gettext
+COPY frontend/ ./
+# CI test-frontend already runs type-check; build-only avoids a second vue-tsc pass.
+# No sourcemaps in the image — saves ~25MB+ and build I/O.
+RUN NODE_OPTIONS=--max-old-space-size=4096 \
+ npm run build-only -- --sourcemap false
-# install cron - and remove any default tabs
-RUN apt-get install -y cron && which cron && rm -rf /etc/cron.*/*
+# -----------------------------------------------------------------------------
+# Stage 2: Install Python deps (Rust/build tools discarded after this stage)
+# -----------------------------------------------------------------------------
+FROM python:3.12-bookworm AS python-builder
-# Get node and npm
-RUN apt install -y nodejs && apt install -y npm
+RUN apt-get update -y && \
+ apt-get install -y --no-install-recommends build-essential curl && \
+ rm -rf /var/lib/apt/lists/*
-# Install Rust - for tokenziers dep in medcat.
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
ENV PATH="/root/.cargo/bin:${PATH}"
-# Copy dependency files first for better layer caching
-WORKDIR /home/frontend
-COPY frontend/package.json frontend/package-lock.json ./
-RUN npm install
-
-# Install uv and Python dependencies
WORKDIR /home
-COPY pyproject.toml uv.lock* ./
-# Install dependencies using a buildkit cache mount for speed on repeat
+COPY pyproject.toml uv.lock ./
RUN --mount=type=cache,target=/root/.cache/uv \
- pip install uv && \
+ pip install --no-cache-dir uv && \
uv sync --frozen --cache-dir=/root/.cache/uv --no-install-project --extra observability
-# Ensure venv has pip (uv venvs don't include it; spacy download needs it)
RUN uv run python -m ensurepip --upgrade
-# Download spaCy models (only requires spaCy, not application code)
ARG SPACY_MODELS="en_core_web_md"
RUN for SPACY_MODEL in ${SPACY_MODELS}; do uv run python -m spacy download ${SPACY_MODEL}; done
-# Copy rest of project
+# -----------------------------------------------------------------------------
+# Stage 3: Runtime image — no Node, npm, Rust, or frontend devDependencies
+# -----------------------------------------------------------------------------
+FROM python:3.12-bookworm
+
+# Runtime OS tools: cron (default tabs removed), gettext for envsubst, supervisor, vim.
+RUN apt-get update -y && \
+    apt-get install -y --no-install-recommends \
+        cron \
+        gettext \
+        supervisor \
+        vim \
+    && rm -rf /var/lib/apt/lists/* /etc/cron.*/*
+
+RUN pip install --no-cache-dir uv
+
WORKDIR /home
-COPY ./ .
+COPY pyproject.toml uv.lock ./
+COPY --from=python-builder /home/.venv /home/.venv
+COPY api ./api
+COPY scripts ./scripts
+COPY templates ./templates
+COPY --from=frontend-builder /build/dist ./frontend/dist
-# Build frontend
-WORKDIR /home/frontend
-RUN npm run build
+# MEDIA_ROOT is a runtime volume; ensure the directory exists without baking in local uploads
+RUN mkdir -p /home/api/media
-# copy backup crontab and chmod scripts
RUN chmod u+x /home/scripts/entry.sh && \
chmod u+x /home/scripts/crontab && cp /home/scripts/crontab /etc/crontab && \
chmod a+x /home/scripts/run.sh && \
@@ -59,4 +71,3 @@ RUN chmod u+x /home/scripts/entry.sh && \
chmod a+x /home/scripts/nginx-entrypoint.sh
WORKDIR /home/api/
-
diff --git a/medcat-trainer/webapp/api/api/tests/_helpers.py b/medcat-trainer/webapp/api/api/tests/_helpers.py
new file mode 100644
index 000000000..44103e5ac
--- /dev/null
+++ b/medcat-trainer/webapp/api/api/tests/_helpers.py
@@ -0,0 +1,91 @@
+"""Shared helpers for backend tests.
+
+These utilities make it easier to construct lightweight model fixtures
+without triggering MedCAT model loading or expensive dataset parsing.
+"""
+
+import os
+import tempfile
+from contextlib import contextmanager
+
+import pandas as pd
+
+from django.contrib.auth.models import User
+
+from .. import signals as api_signals
+from ..models import (
+ ConceptDB,
+ Dataset,
+ Document,
+ Entity,
+ ProjectAnnotateEntities,
+ Vocabulary,
+)
+
+
+@contextmanager
+def dataset_signals_disconnected():
+    """Temporarily detach the Dataset pre_save / post_save receivers.
+
+    Lets a test insert a Dataset row without `dataset_from_file` parsing a real
+    CSV/XLSX. On exit only receivers that were actually detached are re-attached.
+    """
+    from django.db.models.signals import post_save, pre_save
+    hooks = ((post_save, api_signals.save_dataset), (pre_save, api_signals.pre_save_dataset))
+    # disconnect() returns False when nothing was registered; skip those on restore.
+    detached = [(sig, fn) for sig, fn in hooks if sig.disconnect(fn, sender=Dataset)]
+    try:
+        yield
+    finally:
+        for sig, fn in detached:
+            sig.connect(fn, sender=Dataset)
+
+
+def create_dataset(name='test-dataset', file_name='test-dataset.csv'):
+    """Insert and return a Dataset row while the file-parsing signals are detached."""
+    fields = {'name': name, 'original_file': file_name}
+    with dataset_signals_disconnected():
+        return Dataset.objects.create(**fields)
+
+
+def make_csv_file(tmp_dir, rows=None, file_name='dataset.csv'):
+    """Dump `rows` (default: two sample 'name'/'text' records) to a CSV in `tmp_dir`; return its path."""
+    default_rows = (
+        {'name': 'doc-a', 'text': 'Patient reports chest pain.'},
+        {'name': 'doc-b', 'text': 'No fever or cough.'},
+    )
+    records = list(default_rows) if rows is None else rows
+    target = os.path.join(tmp_dir, file_name)
+    pd.DataFrame(records).to_csv(target, index=False)
+    return target
+
+
+def create_basic_project(name='test-project'):
+    """Save and return a ProjectAnnotateEntities backed by a new CDB, Vocab and Dataset."""
+    concept_db = ConceptDB(name=f'{name}-cdb', cdb_file=f'{name}-cdb.dat')
+    vocabulary = Vocabulary(name=f'{name}-vocab', vocab_file=f'{name}-vocab.dat')
+    # skip_load=True: presumably stops save() from loading the model files (none exist here).
+    for model_row in (concept_db, vocabulary):
+        model_row.save(skip_load=True)
+    dataset = create_dataset(name=f'{name}-ds', file_name=f'{name}-ds.csv')
+
+    annotate_project = ProjectAnnotateEntities()
+    annotate_project.name = name
+    annotate_project.cuis = ''
+    annotate_project.dataset = dataset
+    annotate_project.concept_db = concept_db
+    annotate_project.vocab = vocabulary
+    annotate_project.save()
+    return annotate_project
+
+# Document row attached to the given project's dataset.
+def create_document(project, name='doc1', text='hello world'):
+ return Document.objects.create(name=name, text=text, dataset=project.dataset)
+
+# Django auth user; any `extra` keyword arguments are passed through to create_user().
+def create_user(username='testuser', password='pw', **extra):
+ return User.objects.create_user(username=username, password=password, **extra)
+
+# Entity row created from its label only.
+def create_entity(label='C001'):
+ return Entity.objects.create(label=label)
diff --git a/medcat-trainer/webapp/api/api/tests/test_admin_actions.py b/medcat-trainer/webapp/api/api/tests/test_admin_actions.py
new file mode 100644
index 000000000..54e635646
--- /dev/null
+++ b/medcat-trainer/webapp/api/api/tests/test_admin_actions.py
@@ -0,0 +1,150 @@
+"""Unit tests for api.admin.actions.
+
+These tests focus on retrieve_project_data and the download_* helpers since
+they back the JSON export feature that the upload tests already validate.
+"""
+
+import json
+
+from django.test import TestCase, override_settings
+
+from ..admin.actions import (
+ download_projects_with_text,
+ download_projects_without_text,
+ retrieve_project_data,
+)
+from ..models import (
+ AnnotatedEntity,
+ EntityRelation,
+ MetaAnnotation,
+ MetaTask,
+ MetaTaskValue,
+ ProjectAnnotateEntities,
+ Relation,
+)
+from ._helpers import (
+ create_basic_project,
+ create_document,
+ create_entity,
+ create_user,
+)
+
+
+@override_settings(MEDIA_ROOT='/tmp/mct-tests-admin')
+class RetrieveProjectDataTests(TestCase):
+    """Checks the export dict built by retrieve_project_data for one validated document."""
+
+    def setUp(self):
+        self.user = create_user(username='admin-actions-user')
+        self.project = create_basic_project(name='admin-actions-proj')
+        self.doc = create_document(self.project, name='doc-1', text='hello world')
+        self.entity = create_entity(label='C100')
+        self.entity_b = create_entity(label='C200')
+
+        # Two validated, correct annotations: 'hello' -> C100 and 'world' -> C200.
+        common = dict(user=self.user, project=self.project, document=self.doc,
+                      validated=True, correct=True)
+        self.ann_a = AnnotatedEntity.objects.create(
+            entity=self.entity, value='hello', start_ind=0, end_ind=5, acc=0.9, **common)
+        self.ann_b = AnnotatedEntity.objects.create(
+            entity=self.entity_b, value='world', start_ind=6, end_ind=11, acc=0.95, **common)
+
+        self.task = MetaTask.objects.create(name='Presence')
+        self.value = MetaTaskValue.objects.create(name='True')
+        MetaAnnotation.objects.create(
+            annotated_entity=self.ann_a, meta_task=self.task,
+            meta_task_value=self.value, validated=True,
+        )
+
+        self.project.validated_documents.add(self.doc)
+
+    def _export(self):
+        """Return the 'projects' list retrieve_project_data builds for self.project."""
+        only_this = ProjectAnnotateEntities.objects.filter(id=self.project.id)
+        return retrieve_project_data(only_this)['projects']
+
+    def _first_document(self):
+        """Export self.project and return the first document of the first project."""
+        return self._export()[0]['documents'][0]
+
+    def test_returns_basic_project_metadata(self):
+        exported = self._export()
+        self.assertEqual(len(exported), 1)
+        proj = exported[0]
+        self.assertEqual(proj['name'], 'admin-actions-proj')
+        self.assertEqual(proj['cuis'], self.project.cuis)
+        self.assertEqual(proj['project_status'], 'A')
+        self.assertEqual(len(proj['documents']), 1)
+
+    def test_includes_annotation_text_and_indices(self):
+        annotations = self._first_document()['annotations']
+        self.assertEqual(sorted(a['cui'] for a in annotations), ['C100', 'C200'])
+        # The C100 entry must carry the span, value and flags stored in setUp.
+        ann_a = next(a for a in annotations if a['cui'] == 'C100')
+        self.assertEqual(ann_a['start'], 0)
+        self.assertEqual(ann_a['end'], 5)
+        self.assertEqual(ann_a['value'], 'hello')
+        self.assertTrue(ann_a['validated'])
+        self.assertTrue(ann_a['correct'])
+
+    def test_includes_meta_annotations(self):
+        annotations = self._first_document()['annotations']
+        meta_anns = next(a for a in annotations if a['cui'] == 'C100')['meta_anns']
+        self.assertIn('Presence', meta_anns)
+        self.assertEqual(meta_anns['Presence']['value'], 'True')
+
+    def test_relations_included(self):
+        EntityRelation.objects.create(
+            user=self.user, project=self.project, document=self.doc,
+            relation=Relation.objects.create(label='hasFinding'),
+            start_entity=self.ann_a, end_entity=self.ann_b, validated=True,
+        )
+
+        rels = self._first_document()['relations']
+        self.assertEqual(len(rels), 1)
+        only_rel = rels[0]
+        self.assertEqual(only_rel['relation'], 'hasFinding')
+        self.assertEqual(only_rel['start_entity_cui'], 'C100')
+        self.assertEqual(only_rel['end_entity_cui'], 'C200')
+
+
+@override_settings(MEDIA_ROOT='/tmp/mct-tests-admin')
+class DownloadProjectsTests(TestCase):
+    """Checks the JSON bodies returned by the download_projects_* helpers."""
+
+    def setUp(self):
+        self.user = create_user(username='dl-action-user')
+        self.project = create_basic_project(name='dl-action-proj')
+        self.doc = create_document(self.project, name='doc-only', text='annotated text')
+        AnnotatedEntity.objects.create(
+            user=self.user, project=self.project, document=self.doc,
+            entity=create_entity(label='C-DL'), value='annotated',
+            start_ind=0, end_ind=9, acc=1.0, validated=True, correct=True,
+        )
+        self.project.validated_documents.add(self.doc)
+
+    def _projects(self):
+        """Queryset holding only the project created in setUp."""
+        return ProjectAnnotateEntities.objects.filter(id=self.project.id)
+
+    @staticmethod
+    def _first_document(response):
+        """Parse the response body as JSON; return the first project's first document."""
+        return json.loads(response.content)['projects'][0]['documents'][0]
+
+    def test_download_with_text_includes_document_text(self):
+        resp = download_projects_with_text(self._projects())
+        self.assertEqual(resp.status_code, 200)
+        self.assertEqual(self._first_document(resp)['text'], 'annotated text')
+
+    def test_download_without_text_omits_document_text(self):
+        resp = download_projects_without_text(self._projects(), with_doc_name=False)
+        self.assertEqual(resp.status_code, 200)
+        self.assertNotIn('text', self._first_document(resp))
+
+    def test_download_without_text_with_doc_name_includes_name(self):
+        resp = download_projects_without_text(self._projects(), with_doc_name=True)
+        # The name is expected in the export while the text stays out.
+        doc = self._first_document(resp)
+        self.assertEqual(doc['name'], 'doc-only')
+        self.assertNotIn('text', doc)
diff --git a/medcat-trainer/webapp/api/api/tests/test_data_utils_extras.py b/medcat-trainer/webapp/api/api/tests/test_data_utils_extras.py
new file mode 100644
index 000000000..ae312fcd2
--- /dev/null
+++ b/medcat-trainer/webapp/api/api/tests/test_data_utils_extras.py
@@ -0,0 +1,157 @@
+"""Additional unit tests for api.data_utils functions not already covered by
+test_data_utils.py.
+"""
+
+import os
+import tempfile
+
+import pandas as pd
+from django.test import TestCase, override_settings
+
+from ..data_utils import dataset_from_file, delete_orphan_docs, sanitise_input
+from ..models import Document
+from ._helpers import create_dataset, dataset_signals_disconnected, make_csv_file
+
+
+class SanitiseInputTests(TestCase):
+ def test_replaces_br_with_newline(self):
+ self.assertEqual(sanitise_input('a
b'), 'a\nb')
+
+ def test_replaces_paragraph_with_newline(self):
+ self.assertEqual(sanitise_input('
hi
'), '\nhi\n') + + def test_strips_span_tags_keeping_content(self): + self.assertEqual(sanitise_input('word'), 'word') + + def test_replaces_div_tags_with_newlines(self): + # Opening tag requires attributes in the regex; closing becomes a newline. + self.assertEqual(sanitise_input('Line1
hello
'}], + file_name='data.csv', + ) + ds = self._make_dataset_with_file('dff-2', csv_path) + + dataset_from_file(ds) + + doc = Document.objects.get(dataset=ds, name='d1') + self.assertEqual(doc.text, '\nhello\n') + + @override_settings(MEDIA_ROOT='/') + def test_raises_on_non_unique_names(self): + csv_path = make_csv_file( + self.tmp_dir, + rows=[ + {'name': 'dup', 'text': 'a'}, + {'name': 'dup', 'text': 'b'}, + ], + file_name='data.csv', + ) + ds = self._make_dataset_with_file('dff-3', csv_path) + + with self.assertRaises(Exception) as ctx: + dataset_from_file(ds) + self.assertIn('name column', str(ctx.exception)) + + @override_settings(MEDIA_ROOT='/') + def test_raises_when_exceeding_max_size(self): + old = os.environ.get('MAX_DATASET_SIZE') + os.environ['MAX_DATASET_SIZE'] = '1' + try: + csv_path = make_csv_file( + self.tmp_dir, + rows=[ + {'name': 'a', 'text': 't1'}, + {'name': 'b', 'text': 't2'}, + ], + file_name='data.csv', + ) + ds = self._make_dataset_with_file('dff-4', csv_path) + + with self.assertRaises(Exception) as ctx: + dataset_from_file(ds) + self.assertIn('Max dataset size', str(ctx.exception)) + finally: + if old is None: + os.environ.pop('MAX_DATASET_SIZE', None) + else: + os.environ['MAX_DATASET_SIZE'] = old + + @override_settings(MEDIA_ROOT='/') + def test_rejects_unsupported_extensions(self): + # The original_file path must end with neither .csv nor .xlsx + path = os.path.join(self.tmp_dir, 'bad_ext.tsv') + with open(path, 'w') as f: + f.write('name\ttext\n1\t2\n') + ds = self._make_dataset_with_file('dff-5', path) + + with self.assertRaises(Exception) as ctx: + dataset_from_file(ds) + self.assertIn('.csv or .xlsx', str(ctx.exception)) diff --git a/medcat-trainer/webapp/api/api/tests/test_metrics.py b/medcat-trainer/webapp/api/api/tests/test_metrics.py new file mode 100644 index 000000000..f938d529e --- /dev/null +++ b/medcat-trainer/webapp/api/api/tests/test_metrics.py @@ -0,0 +1,166 @@ +"""Unit tests for api.metrics.ProjectMetrics. 
+ +These tests exercise the pure-Python data-shaping logic in ProjectMetrics +using a synthetic MedCAT export. Code paths requiring an actual CAT model +are exercised by passing cat=None. +""" + +import pandas as pd +from django.test import TestCase + +from ..metrics import ProjectMetrics + + +def _build_export(num_projects=1, num_docs=2, num_anns=2): + """Build a synthetic MedCAT trainer export structure used for metrics tests.""" + projects = [] + next_id = 1 + for p_idx in range(num_projects): + proj = { + 'id': p_idx + 1, + 'name': f'proj-{p_idx + 1}', + 'meta_anno_defs': [ + {'name': 'Presence', 'values': ['True', 'False']}, + ], + 'documents': [], + } + for d_idx in range(num_docs): + doc = { + 'id': next_id, + 'name': f'doc-{next_id}', + 'text': 'some clinical text', + 'annotations': [], + } + next_id += 1 + for a_idx in range(num_anns): + doc['annotations'].append({ + 'id': 1000 + a_idx + d_idx * 10 + p_idx * 100, + 'cui': 'C001' if a_idx % 2 == 0 else 'C002', + 'value': 'token', + 'start': a_idx * 10, + 'end': a_idx * 10 + 5, + 'validated': True, + 'correct': True, + 'deleted': False, + 'alternative': False, + 'killed': False, + 'irrelevant': False, + 'manually_created': False, + 'acc': 1.0, + 'user': f'user{a_idx % 2}', + 'last_modified': '2024-01-01 10:00:00.000000', + 'meta_anns': { + 'Presence': { + 'name': 'Presence', + 'value': 'True' if a_idx % 2 == 0 else 'False', + 'acc': 1.0, + 'validated': True, + } + }, + }) + proj['documents'].append(doc) + projects.append(proj) + return {'projects': projects} + + +class ProjectMetricsInitTests(TestCase): + def test_annotations_extracted_with_project_and_doc_metadata(self): + export = _build_export(num_projects=1, num_docs=1, num_anns=2) + pm = ProjectMetrics(export, cat=None) + self.assertEqual(len(pm.annotations), 2) + ann = pm.annotations[0] + self.assertEqual(ann['project'], 'proj-1') + self.assertEqual(ann['project_id'], 1) + self.assertEqual(ann['document_name'], 'doc-1') + 
self.assertEqual(ann['document_id'], 1) + self.assertIn('Presence', ann) # meta annotations flattened + + def test_projects2names_and_doc_maps_populated(self): + export = _build_export(num_projects=2, num_docs=2, num_anns=1) + pm = ProjectMetrics(export, cat=None) + self.assertEqual(pm.projects2names[1], 'proj-1') + self.assertEqual(pm.projects2names[2], 'proj-2') + self.assertEqual(len(pm.projects2doc_ids[1]), 2) + self.assertEqual(len(pm.projects2doc_ids[2]), 2) + # docs2names contains all docs + self.assertEqual(len(pm.docs2names), 4) + + def test_meta_annotation_values_flattened_per_annotation(self): + export = _build_export(num_projects=1, num_docs=1, num_anns=2) + pm = ProjectMetrics(export, cat=None) + # Two annotations: one True, one False + presence_values = sorted(a['Presence'] for a in pm.annotations) + self.assertEqual(presence_values, ['False', 'True']) + + +class AnnotationDataFrameTests(TestCase): + def test_annotation_df_without_cat_does_not_add_concept_name(self): + export = _build_export(num_anns=2) + pm = ProjectMetrics(export, cat=None) + df = pm.annotation_df() + self.assertIsInstance(df, pd.DataFrame) + self.assertEqual(len(df), 2 * 2) # docs * anns + self.assertNotIn('concept_name', df.columns) + + def test_concept_summary_without_cat_returns_basic_records(self): + export = _build_export(num_anns=2) + pm = ProjectMetrics(export, cat=None) + summary = pm.concept_summary() + self.assertIsInstance(summary, list) + # All annotations are validated+correct, so all should appear + cuis = {row['cui'] for row in summary} + self.assertEqual(cuis, {'C001', 'C002'}) + + def test_user_stats_groups_by_user(self): + export = _build_export(num_docs=2, num_anns=2) + pm = ProjectMetrics(export, cat=None) + stats = pm.user_stats(by_user=True) + self.assertIsInstance(stats, pd.DataFrame) + users = set(stats['user'].tolist()) + self.assertEqual(users, {'user0', 'user1'}) + + def test_user_stats_by_date_includes_date_column(self): + export = 
_build_export(num_docs=1, num_anns=2) + pm = ProjectMetrics(export, cat=None) + stats = pm.user_stats(by_user=False) + self.assertIn('date', stats.columns) + self.assertIn('user', stats.columns) + self.assertIn('count', stats.columns) + + +class RenameMetaAnnsTests(TestCase): + def test_rename_meta_task_name(self): + export = _build_export(num_docs=1, num_anns=1) + pm = ProjectMetrics(export, cat=None) + # The annotation initially has 'Presence' key. + self.assertIn('Presence', pm.annotations[0]) + + pm.rename_meta_anns(meta_anns2rename={'Presence': 'Existence'}) + + # Original 'Presence' should be renamed to 'Existence' + # Note: the rename happens on the underlying mct_export then _annotations() + # is rebuilt. So check on the rebuilt annotations list. + self.assertNotIn('Presence', pm.annotations[0]) + self.assertIn('Existence', pm.annotations[0]) + + def test_rename_meta_value_when_specified(self): + export = _build_export(num_docs=1, num_anns=1) + pm = ProjectMetrics(export, cat=None) + # Rename 'True' to 'Yes' inside the renamed task 'Existence' + pm.rename_meta_anns( + meta_anns2rename={'Presence': 'Existence'}, + meta_ann_values2rename={'Existence': {'True': 'Yes'}}, + ) + self.assertEqual(pm.annotations[0]['Existence'], 'Yes') + + +class EmptyExportTests(TestCase): + def test_handles_empty_documents_without_error(self): + export = {'projects': [{'id': 1, 'name': 'empty', 'documents': [], + 'meta_anno_defs': []}]} + pm = ProjectMetrics(export, cat=None) + self.assertEqual(pm.annotations, []) + # annotation_df on empty annotations raises; just check the helpers + # that should still work. 
+ self.assertEqual(pm.projects2names[1], 'empty') + self.assertEqual(pm.projects2doc_ids[1], []) diff --git a/medcat-trainer/webapp/api/api/tests/test_model_cache.py b/medcat-trainer/webapp/api/api/tests/test_model_cache.py new file mode 100644 index 000000000..983cf3370 --- /dev/null +++ b/medcat-trainer/webapp/api/api/tests/test_model_cache.py @@ -0,0 +1,206 @@ +"""Unit tests for api.model_cache. + +We avoid loading actual MedCAT artifacts by mocking CDB.load / Vocab.load / +CAT.load_model_pack. Cache state is reset in setUp and tearDown. +""" + +from unittest.mock import MagicMock, patch + +from django.test import TestCase, override_settings + +from .. import model_cache +from ..models import ConceptDB, ModelPack, Vocabulary +from ._helpers import create_basic_project + + +@override_settings(MEDIA_ROOT='/tmp/mct-tests-model-cache') +class ModelCacheTests(TestCase): + def setUp(self): + self.cdb_map = {} + self.vocab_map = {} + self.cat_map = {} + self.project = create_basic_project(name='mc-proj') + + def test_is_model_loaded_returns_false_when_cache_empty(self): + self.assertFalse( + model_cache.is_model_loaded(self.project, cdb_map=self.cdb_map, cat_map=self.cat_map) + ) + + def test_is_model_loaded_returns_true_when_cdb_cached(self): + self.cdb_map[self.project.concept_db.id] = MagicMock() + self.assertTrue( + model_cache.is_model_loaded(self.project, cdb_map=self.cdb_map, cat_map=self.cat_map) + ) + + def test_get_cached_medcat_returns_none_when_missing(self): + self.assertIsNone( + model_cache.get_cached_medcat(self.project, cat_map=self.cat_map) + ) + + def test_get_cached_medcat_returns_value_when_present(self): + cat_id = f'{self.project.concept_db.id}-{self.project.vocab.id}' + sentinel = MagicMock() + self.cat_map[cat_id] = sentinel + self.assertIs( + model_cache.get_cached_medcat(self.project, cat_map=self.cat_map), + sentinel, + ) + + def test_get_cached_medcat_raises_when_no_cdb_and_not_remote(self): + # Project without CDB but not using a remote 
service - should raise + self.project.concept_db = None + self.project.use_model_service = False + with self.assertRaises(Exception) as ctx: + model_cache.get_cached_medcat(self.project, cat_map=self.cat_map) + self.assertIn('misconfigured', str(ctx.exception)) + + def test_get_cached_medcat_raises_for_remote_service_project(self): + self.project.use_model_service = True + self.project.model_service_url = 'http://x' + self.project.concept_db = None + self.project.vocab = None + self.project.save() + with self.assertRaises(ValueError): + model_cache.get_cached_medcat(self.project, cat_map=self.cat_map) + + def test_clear_cached_medcat_removes_cat_from_cat_map(self): + cdb_id = self.project.concept_db.id + vocab_id = self.project.vocab.id + cat_id = f'{cdb_id}-{vocab_id}' + self.cat_map[cat_id] = MagicMock() + + model_cache.clear_cached_medcat(self.project, cat_map=self.cat_map) + + self.assertNotIn(cat_id, self.cat_map) + + def test_clear_cached_cdb_no_op_when_missing(self): + # Should not raise even if cdb not in map + model_cache.clear_cached_cdb(99999, cdb_map=self.cdb_map) + + def test_clear_cached_vocab_no_op_when_missing(self): + model_cache.clear_cached_vocab(99999, vocab_map=self.vocab_map) + + +@override_settings(MEDIA_ROOT='/tmp/mct-tests-model-cache') +class GetCachedCdbTests(TestCase): + def setUp(self): + self.cdb = ConceptDB(name='cached-cdb', cdb_file='cached-cdb.dat') + self.cdb.save(skip_load=True) + self.cdb_map = {} + + @patch('api.utils.clear_cdb_cnf_addons') + @patch('api.model_cache.CDB.load') + def test_loads_when_not_cached(self, mock_load, mock_clear): + loaded = MagicMock() + mock_load.return_value = loaded + + cached = model_cache.get_cached_cdb(self.cdb.id, cdb_map=self.cdb_map) + self.assertIs(cached, loaded) + self.assertIn(self.cdb.id, self.cdb_map) + mock_clear.assert_called_once() + + @patch('api.model_cache.CDB.load') + def test_returns_existing_when_cached(self, mock_load): + sentinel = MagicMock() + self.cdb_map[self.cdb.id] = 
sentinel + result = model_cache.get_cached_cdb(self.cdb.id, cdb_map=self.cdb_map) + self.assertIs(result, sentinel) + mock_load.assert_not_called() + + +@override_settings(MEDIA_ROOT='/tmp/mct-tests-model-cache') +class IsModelPackLoadedTests(TestCase): + def test_returns_true_when_present(self): + cat_map = {'mp42': MagicMock()} + self.assertTrue(model_cache.is_model_pack_loaded(42, cat_map=cat_map)) + + def test_returns_false_when_absent(self): + self.assertFalse(model_cache.is_model_pack_loaded(42, cat_map={})) + + def test_clear_by_modelpack_id_removes_entry(self): + cat_map = {'mp42': MagicMock(), 'mp7': MagicMock()} + model_cache.clear_cached_medcat_by_model_pack_id(42, cat_map=cat_map) + self.assertNotIn('mp42', cat_map) + self.assertIn('mp7', cat_map) + + def test_clear_by_modelpack_id_no_op_when_missing(self): + cat_map = {} + # Should not raise + model_cache.clear_cached_medcat_by_model_pack_id(42, cat_map=cat_map) + + +@override_settings(MEDIA_ROOT='/tmp/mct-tests-model-cache') +class GetMedcatFromModelPackIdTests(TestCase): + def setUp(self): + from django.core.files.uploadedfile import SimpleUploadedFile + + self.modelpack = ModelPack( + name='mp-cached', + model_pack=SimpleUploadedFile('mp.zip', b'fake'), + ) + self.modelpack.save(skip_load=True) + + @patch('api.model_cache.CAT.load_model_pack') + def test_loads_when_not_cached(self, mock_load): + cat_map = {} + loaded = MagicMock() + mock_load.return_value = loaded + result = model_cache.get_medcat_from_model_pack_id(self.modelpack.id, cat_map=cat_map) + self.assertIs(result, loaded) + self.assertIn(f'mp{self.modelpack.id}', cat_map) + + @patch('api.model_cache.CAT.load_model_pack') + def test_returns_cached_when_present(self, mock_load): + sentinel = MagicMock() + cat_map = {f'mp{self.modelpack.id}': sentinel} + result = model_cache.get_medcat_from_model_pack_id(self.modelpack.id, cat_map=cat_map) + self.assertIs(result, sentinel) + mock_load.assert_not_called() + + 
+@override_settings(MEDIA_ROOT='/tmp/mct-tests-model-cache') +class GetMedcatFromCdbVocabTests(TestCase): + def setUp(self): + self.project = create_basic_project(name='cv-proj') + + @patch('api.model_cache.CAT') + @patch('api.model_cache.Vocab.load') + @patch('api.model_cache.CDB.load') + @patch('api.utils.clear_cdb_cnf_addons') + def test_loads_and_caches(self, mock_clear, mock_cdb_load, mock_vocab_load, mock_cat_cls): + mock_cdb = MagicMock() + mock_vocab = MagicMock() + mock_cdb_load.return_value = mock_cdb + mock_vocab_load.return_value = mock_vocab + mock_cat_instance = MagicMock() + mock_cat_cls.return_value = mock_cat_instance + + cdb_map = {} + vocab_map = {} + cat_map = {} + + result = model_cache.get_medcat_from_cdb_vocab( + self.project, cdb_map=cdb_map, vocab_map=vocab_map, cat_map=cat_map + ) + + cat_id = f'{self.project.concept_db.id}-{self.project.vocab.id}' + self.assertIn(cat_id, cat_map) + self.assertIn(self.project.concept_db.id, cdb_map) + self.assertIn(self.project.vocab.id, vocab_map) + self.assertIs(result, mock_cat_instance) + + @patch('api.model_cache.CAT') + @patch('api.model_cache.Vocab.load') + @patch('api.model_cache.CDB.load') + def test_returns_cached_cat_when_present(self, mock_cdb_load, mock_vocab_load, mock_cat_cls): + cat_id = f'{self.project.concept_db.id}-{self.project.vocab.id}' + sentinel = MagicMock() + cat_map = {cat_id: sentinel} + + result = model_cache.get_medcat_from_cdb_vocab( + self.project, cdb_map={}, vocab_map={}, cat_map=cat_map + ) + self.assertIs(result, sentinel) + mock_cdb_load.assert_not_called() + mock_vocab_load.assert_not_called() + mock_cat_cls.assert_not_called() diff --git a/medcat-trainer/webapp/api/api/tests/test_models.py b/medcat-trainer/webapp/api/api/tests/test_models.py new file mode 100644 index 000000000..d48a5f786 --- /dev/null +++ b/medcat-trainer/webapp/api/api/tests/test_models.py @@ -0,0 +1,203 @@ +"""Unit tests for api.models validation and string representations.""" + +from 
django.core.exceptions import ValidationError +from django.test import TestCase, override_settings + +from ..models import ( + AnnotatedEntity, + ConceptDB, + Document, + Entity, + EntityRelation, + MetaAnnotation, + MetaTask, + MetaTaskValue, + ProjectAnnotateEntities, + Relation, + Vocabulary, + cdb_name_validator, +) +from ._helpers import ( + create_basic_project, + create_dataset, + create_document, + create_entity, + create_user, +) + + +class StringRepresentationTests(TestCase): + def test_entity_str(self): + ent = Entity.objects.create(label='C001') + self.assertEqual(str(ent), 'C001') + + def test_relation_str(self): + rel = Relation.objects.create(label='hasFinding') + self.assertEqual(str(rel), 'hasFinding') + + def test_meta_task_value_str(self): + v = MetaTaskValue.objects.create(name='True') + self.assertEqual(str(v), 'True') + + def test_meta_task_str(self): + mt = MetaTask.objects.create(name='Presence') + self.assertEqual(str(mt), 'Presence') + + +@override_settings(MEDIA_ROOT='/tmp/mct-tests-models') +class CdbNameValidatorTests(TestCase): + def test_validator_accepts_alphanumeric_with_underscore(self): + cdb_name_validator('abc_123') # should not raise + + def test_validator_rejects_leading_digit(self): + with self.assertRaises(ValidationError): + cdb_name_validator('1abc') + + def test_validator_rejects_special_chars(self): + with self.assertRaises(ValidationError): + cdb_name_validator('a-b') + + def test_validator_rejects_empty(self): + with self.assertRaises(ValidationError): + cdb_name_validator('') + + +@override_settings(MEDIA_ROOT='/tmp/mct-tests-models') +class ProjectAnnotateEntitiesValidationTests(TestCase): + @classmethod + def setUpTestData(cls): + cdb = ConceptDB(name='val_cdb', cdb_file='val_cdb.dat') + cdb.save(skip_load=True) + vocab = Vocabulary(name='val_vocab', vocab_file='val_vocab.dat') + vocab.save(skip_load=True) + cls.cdb = cdb + cls.vocab = vocab + cls.dataset = create_dataset(name='val_ds', file_name='val_ds.csv') + + 
def _new_project(self, **kwargs): + proj = ProjectAnnotateEntities() + proj.name = kwargs.pop('name', 'p1') + proj.dataset = kwargs.pop('dataset', self.dataset) + proj.cuis = '' + for k, v in kwargs.items(): + setattr(proj, k, v) + return proj + + def test_save_requires_cdb_vocab_or_model_pack(self): + proj = self._new_project() + with self.assertRaises(ValidationError): + proj.save() + + def test_save_with_cdb_and_vocab_succeeds(self): + proj = self._new_project(concept_db=self.cdb, vocab=self.vocab) + proj.save() # should not raise + self.assertIsNotNone(proj.id) + + def test_use_model_service_requires_url(self): + proj = self._new_project(use_model_service=True) + with self.assertRaises(ValidationError): + proj.save() + + def test_use_model_service_with_url_skips_model_validation(self): + proj = self._new_project(use_model_service=True, model_service_url='http://x') + proj.save() + self.assertIsNotNone(proj.id) + + +@override_settings(MEDIA_ROOT='/tmp/mct-tests-models') +class AnnotatedEntitySaveUpdatesProjectTests(TestCase): + def test_saving_annotation_updates_project_last_modified(self): + user = create_user(username='auser') + project = create_basic_project(name='ae-proj') + doc = create_document(project, name='doc', text='hello') + ent = create_entity(label='C100') + + before = project.last_modified + AnnotatedEntity.objects.create( + user=user, + project=project, + document=doc, + entity=ent, + value='hello', + start_ind=0, + end_ind=5, + acc=0.5, + ) + project.refresh_from_db() + self.assertGreaterEqual(project.last_modified, before) + + +@override_settings(MEDIA_ROOT='/tmp/mct-tests-models') +class MetaAnnotationSaveUpdatesParentTests(TestCase): + def test_saving_meta_annotation_updates_annotated_entity_last_modified(self): + user = create_user(username='muser') + project = create_basic_project(name='ma-proj') + doc = create_document(project, name='doc', text='hello') + ent = create_entity(label='C200') + ann = AnnotatedEntity.objects.create( + 
user=user, + project=project, + document=doc, + entity=ent, + value='hello', + start_ind=0, + end_ind=5, + acc=0.5, + ) + task = MetaTask.objects.create(name='Presence') + val = MetaTaskValue.objects.create(name='True') + + before = ann.last_modified + MetaAnnotation.objects.create( + annotated_entity=ann, + meta_task=task, + meta_task_value=val, + validated=True, + ) + ann.refresh_from_db() + self.assertGreaterEqual(ann.last_modified, before) + + +@override_settings(MEDIA_ROOT='/tmp/mct-tests-models') +class EntityRelationSaveUpdatesProjectTests(TestCase): + def test_saving_relation_updates_project(self): + user = create_user(username='ruser') + project = create_basic_project(name='rel-proj') + doc = create_document(project, name='doc', text='hello world') + ent_a = create_entity(label='A') + ent_b = create_entity(label='B') + + start = AnnotatedEntity.objects.create( + user=user, project=project, document=doc, entity=ent_a, + value='hello', start_ind=0, end_ind=5, acc=1.0, + ) + end = AnnotatedEntity.objects.create( + user=user, project=project, document=doc, entity=ent_b, + value='world', start_ind=6, end_ind=11, acc=1.0, + ) + + rel = Relation.objects.create(label='has') + before = project.last_modified + EntityRelation.objects.create( + user=user, + project=project, + document=doc, + relation=rel, + start_entity=start, + end_entity=end, + ) + project.refresh_from_db() + self.assertGreaterEqual(project.last_modified, before) + + +class ConceptDbCannotChangeFilePathTests(TestCase): + @override_settings(MEDIA_ROOT='/tmp/mct-tests-models') + def test_change_of_cdb_file_after_first_save_raises(self): + cdb = ConceptDB(name='cant_change', cdb_file='orig.dat') + cdb.save(skip_load=True) + + # Simulate Django reload semantics + reloaded = ConceptDB.objects.get(id=cdb.id) + reloaded.cdb_file.name = 'other.dat' + with self.assertRaises(ValidationError): + reloaded.save(skip_load=True) diff --git a/medcat-trainer/webapp/api/api/tests/test_oidc_utils.py 
b/medcat-trainer/webapp/api/api/tests/test_oidc_utils.py new file mode 100644 index 000000000..275dbeb05 --- /dev/null +++ b/medcat-trainer/webapp/api/api/tests/test_oidc_utils.py @@ -0,0 +1,93 @@ +"""Unit tests for api.oidc_utils.""" + +from django.contrib.auth import get_user_model +from django.test import TestCase + +from ..oidc_utils import get_user_by_email + + +class GetUserByEmailTests(TestCase): + def setUp(self): + self.User = get_user_model() + + def test_creates_new_user_from_full_claims(self): + claims = { + 'preferred_username': 'jdoe', + 'email': 'jdoe@example.com', + 'given_name': 'John', + 'family_name': 'Doe', + } + user = get_user_by_email(request=None, id_token=claims) + self.assertEqual(user.username, 'jdoe') + self.assertEqual(user.email, 'jdoe@example.com') + self.assertEqual(user.first_name, 'John') + self.assertEqual(user.last_name, 'Doe') + self.assertFalse(user.is_superuser) + self.assertFalse(user.is_staff) + + def test_assigns_superuser_when_role_present(self): + claims = { + 'preferred_username': 'admin', + 'email': 'admin@example.com', + 'roles': ['medcattrainer_superuser'], + } + user = get_user_by_email(request=None, id_token=claims) + self.assertTrue(user.is_superuser) + self.assertFalse(user.is_staff) + + def test_assigns_staff_when_role_present(self): + claims = { + 'preferred_username': 'staffuser', + 'email': 'staff@example.com', + 'roles': ['medcattrainer_staff'], + } + user = get_user_by_email(request=None, id_token=claims) + self.assertTrue(user.is_staff) + self.assertFalse(user.is_superuser) + + def test_falls_back_to_sub_when_username_missing(self): + claims = {'sub': 'unique-sub-id', 'email': 'a@b.com'} + user = get_user_by_email(request=None, id_token=claims) + self.assertEqual(user.username, 'unique-sub-id') + + def test_falls_back_to_client_id_when_no_username_or_sub(self): + claims = {'client_id': 'svc-client', 'email': 'svc@example.com'} + user = get_user_by_email(request=None, id_token=claims) + 
self.assertEqual(user.username, 'svc-client') + + def test_falls_back_to_random_username_when_nothing_provided(self): + # No username, no sub, no client_id - should still create a user with a random username + user = get_user_by_email(request=None, id_token={'email': 'nouser@example.com'}) + self.assertTrue(user.username.startswith('oidc-')) + self.assertEqual(user.email, 'nouser@example.com') + + def test_email_falls_back_to_username_when_missing(self): + claims = {'preferred_username': 'just-username'} + user = get_user_by_email(request=None, id_token=claims) + self.assertEqual(user.email, 'just-username') + + def test_returns_existing_user_when_email_matches(self): + existing = self.User.objects.create_user( + username='oldname', email='dup@example.com', password='x') + + claims = {'preferred_username': 'newname', 'email': 'dup@example.com'} + returned = get_user_by_email(request=None, id_token=claims) + + self.assertEqual(returned.id, existing.id) + # Existing user should have updated profile fields + returned.refresh_from_db() + self.assertEqual(returned.username, 'newname') + + def test_role_updates_existing_user(self): + existing = self.User.objects.create_user( + username='roleuser', email='role@example.com', password='x') + self.assertFalse(existing.is_superuser) + + get_user_by_email(request=None, id_token={ + 'preferred_username': 'roleuser', + 'email': 'role@example.com', + 'roles': ['medcattrainer_superuser', 'medcattrainer_staff'], + }) + existing.refresh_from_db() + self.assertTrue(existing.is_superuser) + self.assertTrue(existing.is_staff) diff --git a/medcat-trainer/webapp/api/api/tests/test_permissions.py b/medcat-trainer/webapp/api/api/tests/test_permissions.py new file mode 100644 index 000000000..8c0ee7412 --- /dev/null +++ b/medcat-trainer/webapp/api/api/tests/test_permissions.py @@ -0,0 +1,101 @@ +"""Unit tests for api.permissions.""" + +from unittest.mock import MagicMock + +from django.contrib.auth.models import User +from django.test 
import TestCase, RequestFactory, override_settings + +from ..models import ( + ConceptDB, + ProjectAnnotateEntities, + ProjectGroup, + Vocabulary, +) +from ..permissions import IsReadOnly, is_project_admin +from ._helpers import create_dataset + + +@override_settings(MEDIA_ROOT='/tmp/mct-tests-perms') +class IsReadOnlyTests(TestCase): + def setUp(self): + self.permission = IsReadOnly() + self.factory = RequestFactory() + + def test_allows_safe_methods(self): + for method in ('GET', 'HEAD', 'OPTIONS'): + request = self.factory.generic(method, '/api/x/') + self.assertTrue( + self.permission.has_permission(request, view=MagicMock()), + f'Expected {method} to be allowed', + ) + + def test_denies_unsafe_methods(self): + for method in ('POST', 'PUT', 'PATCH', 'DELETE'): + request = self.factory.generic(method, '/api/x/') + self.assertFalse( + self.permission.has_permission(request, view=MagicMock()), + f'Expected {method} to be denied', + ) + + +@override_settings(MEDIA_ROOT='/tmp/mct-tests-perms') +class IsProjectAdminTests(TestCase): + @classmethod + def setUpTestData(cls): + cls.superuser = User.objects.create_superuser(username='su', password='pw', email='su@x') + cls.staff = User.objects.create_user(username='st', password='pw', is_staff=True) + cls.member = User.objects.create_user(username='m1', password='pw') + cls.group_admin = User.objects.create_user(username='ga', password='pw') + cls.outsider = User.objects.create_user(username='out', password='pw') + + cdb = ConceptDB(name='perm_cdb', cdb_file='perm_cdb.dat') + cdb.save(skip_load=True) + vocab = Vocabulary(name='perm_vocab', vocab_file='perm_vocab.dat') + vocab.save(skip_load=True) + dataset = create_dataset(name='perm_ds', file_name='perm_ds.csv') + + cls.group = ProjectGroup.objects.create( + name='grp1', + dataset=dataset, + concept_db=cdb, + vocab=vocab, + cuis='', + ) + cls.group.administrators.add(cls.group_admin) + + cls.project_no_group = ProjectAnnotateEntities() + cls.project_no_group.name = 
'p-no-group' + cls.project_no_group.dataset = dataset + cls.project_no_group.concept_db = cdb + cls.project_no_group.vocab = vocab + cls.project_no_group.cuis = '' + cls.project_no_group.save() + cls.project_no_group.members.add(cls.member) + + cls.project_with_group = ProjectAnnotateEntities() + cls.project_with_group.name = 'p-grouped' + cls.project_with_group.dataset = dataset + cls.project_with_group.concept_db = cdb + cls.project_with_group.vocab = vocab + cls.project_with_group.group = cls.group + cls.project_with_group.cuis = '' + cls.project_with_group.save() + + def test_superuser_is_always_admin(self): + self.assertTrue(is_project_admin(self.superuser, self.project_no_group)) + + def test_staff_user_is_always_admin(self): + self.assertTrue(is_project_admin(self.staff, self.project_no_group)) + + def test_member_user_is_admin(self): + self.assertTrue(is_project_admin(self.member, self.project_no_group)) + + def test_group_admin_is_admin_of_group_project(self): + self.assertTrue(is_project_admin(self.group_admin, self.project_with_group)) + + def test_group_admin_is_not_admin_of_unrelated_project(self): + self.assertFalse(is_project_admin(self.group_admin, self.project_no_group)) + + def test_outsider_is_not_admin(self): + self.assertFalse(is_project_admin(self.outsider, self.project_no_group)) + self.assertFalse(is_project_admin(self.outsider, self.project_with_group)) diff --git a/medcat-trainer/webapp/api/api/tests/test_serializers.py b/medcat-trainer/webapp/api/api/tests/test_serializers.py new file mode 100644 index 000000000..45883bc58 --- /dev/null +++ b/medcat-trainer/webapp/api/api/tests/test_serializers.py @@ -0,0 +1,196 @@ +"""Unit tests for api.serializers.""" + +import json +import os +import tempfile + +from django.contrib.auth.models import User +from django.test import TestCase, override_settings + +from ..models import ( + AnnotatedEntity, + ConceptDB, + Dataset, + Document, + Entity, + ProjectAnnotateEntities, + ProjectGroup, + Vocabulary, 
+) +from ..serializers import ( + AnnotatedEntitySerializer, + DatasetSerializer, + DocumentSerializer, + EntitySerializer, + ProjectAnnotateEntitiesSerializer, + ProjectGroupSerializer, + UserSerializer, +) +from ._helpers import create_dataset + + +@override_settings(MEDIA_ROOT='/tmp/mct-tests-serializers') +class UserSerializerTests(TestCase): + def test_serializes_expected_fields(self): + user = User.objects.create_user(username='alice', email='a@x.com', password='pw') + data = UserSerializer(user, context={'request': None}).data + self.assertEqual(data['username'], 'alice') + self.assertEqual(data['email'], 'a@x.com') + self.assertIn('id', data) + self.assertIn('is_staff', data) + self.assertIn('is_superuser', data) + + +@override_settings(MEDIA_ROOT='/tmp/mct-tests-serializers') +class EntitySerializerTests(TestCase): + def test_serializes_entity_label(self): + ent = Entity.objects.create(label='C123') + data = EntitySerializer(ent).data + self.assertEqual(data['label'], 'C123') + + +@override_settings(MEDIA_ROOT='/tmp/mct-tests-serializers') +class DocumentAndAnnotationSerializerTests(TestCase): + @classmethod + def setUpTestData(cls): + cls.user = User.objects.create_user(username='annu', password='pw') + cdb = ConceptDB(name='ser_cdb', cdb_file='ser_cdb.dat') + cdb.save(skip_load=True) + vocab = Vocabulary(name='ser_vocab', vocab_file='ser_vocab.dat') + vocab.save(skip_load=True) + cls.dataset = create_dataset(name='ser_ds', file_name='ser_ds.csv') + cls.document = Document.objects.create(name='doc', text='hello', dataset=cls.dataset) + + cls.project = ProjectAnnotateEntities() + cls.project.name = 'ser-proj' + cls.project.dataset = cls.dataset + cls.project.concept_db = cdb + cls.project.vocab = vocab + cls.project.cuis = '' + cls.project.save() + + cls.entity = Entity.objects.create(label='C001') + + def test_document_serializer(self): + data = DocumentSerializer(self.document).data + self.assertEqual(data['name'], 'doc') + self.assertEqual(data['text'], 
'hello') + self.assertEqual(data['dataset'], self.dataset.id) + + def test_annotated_entity_serializer(self): + ann = AnnotatedEntity.objects.create( + user=self.user, + project=self.project, + document=self.document, + entity=self.entity, + value='hello', + start_ind=0, + end_ind=5, + acc=0.5, + ) + data = AnnotatedEntitySerializer(ann).data + self.assertEqual(data['value'], 'hello') + self.assertEqual(data['start_ind'], 0) + self.assertEqual(data['end_ind'], 5) + self.assertEqual(data['acc'], 0.5) + + +@override_settings(MEDIA_ROOT='/tmp/mct-tests-serializers') +class ProjectAnnotateEntitiesSerializerTests(TestCase): + @classmethod + def setUpTestData(cls): + cdb = ConceptDB(name='pas_cdb', cdb_file='pas_cdb.dat') + cdb.save(skip_load=True) + vocab = Vocabulary(name='pas_vocab', vocab_file='pas_vocab.dat') + vocab.save(skip_load=True) + cls.dataset = create_dataset(name='pas_ds', file_name='pas_ds.csv') + + cls.project = ProjectAnnotateEntities() + cls.project.name = 'pas-proj' + cls.project.dataset = cls.dataset + cls.project.concept_db = cdb + cls.project.vocab = vocab + cls.project.cuis = 'A,B,C' + cls.project.save() + + def test_to_representation_includes_inline_cuis_only_when_no_file(self): + data = ProjectAnnotateEntitiesSerializer(self.project).data + # Should contain the original CUIs separated by ',' + self.assertEqual(set(data['cuis'].split(',')), {'A', 'B', 'C'}) + + def test_to_representation_merges_cuis_from_file(self): + media_root = '/tmp/mct-tests-serializers' + os.makedirs(media_root, exist_ok=True) + rel_path = 'pas_cuis_file.json' + abs_path = os.path.join(media_root, rel_path) + with open(abs_path, 'w') as f: + json.dump(['X', 'Y'], f) + try: + self.project.cuis_file.name = rel_path + self.project.save() + + data = ProjectAnnotateEntitiesSerializer(self.project).data + self.assertEqual(set(data['cuis'].split(',')), {'A', 'B', 'C', 'X', 'Y'}) + finally: + if os.path.isfile(abs_path): + os.unlink(abs_path) + + def 
test_to_representation_handles_missing_cuis_file_gracefully(self): + # Do not save() — post_save would try to read cuis_file on disk. + self.project.cuis_file.name = 'missing_cuis_file.json' + + data = ProjectAnnotateEntitiesSerializer(self.project).data + self.assertEqual(set(data['cuis'].split(',')), {'A', 'B', 'C'}) + + +@override_settings(MEDIA_ROOT='/tmp/mct-tests-serializers') +class ProjectGroupSerializerTests(TestCase): + @classmethod + def setUpTestData(cls): + cdb = ConceptDB(name='pg_cdb', cdb_file='pg_cdb.dat') + cdb.save(skip_load=True) + vocab = Vocabulary(name='pg_vocab', vocab_file='pg_vocab.dat') + vocab.save(skip_load=True) + cls.dataset = create_dataset(name='pg_ds', file_name='pg_ds.csv') + cls.cdb = cdb + cls.vocab = vocab + + def test_last_modified_is_null_when_group_has_no_projects(self): + group = ProjectGroup.objects.create( + name='empty-group', + dataset=self.dataset, + concept_db=self.cdb, + vocab=self.vocab, + cuis='', + ) + data = ProjectGroupSerializer(group).data + self.assertIsNone(data['last_modified']) + + def test_last_modified_is_set_to_latest_project_in_group(self): + group = ProjectGroup.objects.create( + name='active-group', + dataset=self.dataset, + concept_db=self.cdb, + vocab=self.vocab, + cuis='', + ) + p = ProjectAnnotateEntities() + p.name = 'p-in-group' + p.dataset = self.dataset + p.cuis = '' + p.concept_db = self.cdb + p.vocab = self.vocab + p.group = group + p.save() + + data = ProjectGroupSerializer(group).data + self.assertIsNotNone(data['last_modified']) + + +@override_settings(MEDIA_ROOT='/tmp/mct-tests-serializers') +class DatasetSerializerTests(TestCase): + def test_serializes_dataset(self): + dataset = create_dataset(name='ds-test', file_name='ds-test.csv') + data = DatasetSerializer(dataset).data + self.assertEqual(data['name'], 'ds-test') + self.assertIn('original_file', data) diff --git a/medcat-trainer/webapp/api/api/tests/test_solr_utils.py b/medcat-trainer/webapp/api/api/tests/test_solr_utils.py new 
file mode 100644 index 000000000..dc38789a3 --- /dev/null +++ b/medcat-trainer/webapp/api/api/tests/test_solr_utils.py @@ -0,0 +1,212 @@ +"""Unit tests for api.solr_utils using mocked HTTP calls.""" + +import json +from unittest.mock import MagicMock, patch + +from django.test import TestCase, override_settings + +from .. import solr_utils +from ..models import ConceptDB +from ._helpers import dataset_signals_disconnected # noqa: F401 (ensures helper module imports) + + +@override_settings(MEDIA_ROOT='/tmp/mct-tests-solr') +class CollectionsAvailableTests(TestCase): + def setUp(self): + # Clear schema cache to avoid leakage between tests + solr_utils.SOLR_INDEX_SCHEMA.clear() + + @patch('api.solr_utils.requests.get') + def test_returns_imported_map_when_cdbs_provided(self, mock_get): + # First call: list collections; subsequent: schema + def side_effect(url, *args, **kwargs): + if 'admin/collections' in url: + return MagicMock(status_code=200, text=json.dumps({'collections': ['my_id_1']})) + if 'schema' in url: + return MagicMock(text=json.dumps({'schema': {'fields': [{'name': 'cui', 'type': 'string'}]}})) + return MagicMock(status_code=404, text='') + + mock_get.side_effect = side_effect + + response = solr_utils.collections_available(['1', '2']) + self.assertEqual(response.status_code, 200) + body = response.data + self.assertTrue(body['results']['1']) + self.assertFalse(body['results']['2']) + + @patch('api.solr_utils.requests.get') + def test_returns_full_collection_info_when_no_cdbs(self, mock_get): + def side_effect(url, *args, **kwargs): + if 'admin/collections' in url: + return MagicMock(status_code=200, text=json.dumps({'collections': ['my_id_1']})) + return MagicMock(text=json.dumps({'schema': {'fields': [{'name': 'cui', 'type': 'string'}]}})) + + mock_get.side_effect = side_effect + + response = solr_utils.collections_available([]) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data['results']['1']['index_name'], 'my_id_1') + + 
@patch('api.solr_utils.requests.get') + def test_returns_500_when_solr_admin_unavailable(self, mock_get): + mock_get.return_value = MagicMock(status_code=500, text='boom') + response = solr_utils.collections_available(['1']) + self.assertEqual(response.status_code, 500) + + +@override_settings(MEDIA_ROOT='/tmp/mct-tests-solr') +class SearchCollectionTests(TestCase): + @classmethod + def setUpTestData(cls): + cdb = ConceptDB(name='solrcdb', cdb_file='solrcdb.dat') + cdb.save(skip_load=True) + cls.cdb = cdb + + def setUp(self): + solr_utils.SOLR_INDEX_SCHEMA.clear() + solr_utils.SOLR_INDEX_SCHEMA[f'solrcdb_id_{self.cdb.id}'] = {'cui': 'string'} + + def test_empty_query_returns_empty_results(self): + response = solr_utils.search_collection([self.cdb.id], '') + self.assertEqual(response.data, {'results': []}) + + @patch('api.solr_utils.requests.get') + def test_returns_documents_for_text_query(self, mock_get): + mock_get.return_value = MagicMock( + text=json.dumps({ + 'response': { + 'docs': [ + { + 'cui': ['C001'], + 'pretty_name': ['Concept 1'], + 'type_ids': ['T001'], + 'synonyms': ['c1', 'c-one'], + }, + { + 'cui': ['C002'], + 'pretty_name': ['Concept 2'], + 'type_ids': ['T002'], + 'synonyms': ['c2'], + }, + ] + } + }) + ) + + response = solr_utils.search_collection([self.cdb.id], 'foo') + results = response.data['results'] + cuis = sorted(r['cui'] for r in results) + self.assertEqual(cuis, ['C001', 'C002']) + + @patch('api.solr_utils.requests.get') + def test_falls_back_to_wildcard_when_no_results(self, mock_get): + calls = [] + + def fake_get(url, *args, **kwargs): + calls.append(url) + if len(calls) == 1: + return MagicMock(text=json.dumps({'response': {'docs': []}})) + return MagicMock(text=json.dumps({ + 'response': { + 'docs': [{ + 'cui': ['C999'], + 'pretty_name': ['Wildcard match'], + 'type_ids': [], + 'synonyms': ['wm'], + }] + } + })) + + mock_get.side_effect = fake_get + + response = solr_utils.search_collection([self.cdb.id], 'foo') + 
self.assertEqual(len(response.data['results']), 1) + self.assertEqual(response.data['results'][0]['cui'], 'C999') + + +@override_settings(MEDIA_ROOT='/tmp/mct-tests-solr') +class HelperFunctionTests(TestCase): + def test_process_result_response_deduplicates_by_cui(self): + resp = { + 'response': { + 'docs': [ + {'cui': ['C001'], 'pretty_name': ['a'], 'type_ids': ['T1'], 'synonyms': ['x']}, + {'cui': ['C001'], 'pretty_name': ['a-dup'], 'type_ids': ['T1'], 'synonyms': ['x']}, + {'cui': ['C002'], 'pretty_name': ['b'], 'type_ids': [], 'synonyms': []}, + ] + } + } + result_map = solr_utils._process_result_repsonse(resp) + self.assertEqual(set(result_map.keys()), {'C001', 'C002'}) + + def test_concept_dct_uses_pretty_name_when_no_synonyms(self): + cdb = MagicMock() + cdb.get_name.return_value = 'Pretty (qualifier)' + info = {'original_names': [], 'type_ids': ['T1'], 'description': 'desc'} + out = solr_utils._concept_dct('C001', cdb, info) + self.assertEqual(out['cui'], 'C001') + # synonyms fall back to pretty name when original_names is empty + self.assertEqual(out['synonyms'], ['Pretty (qualifier)']) + # parenthesised qualifier removed in `name` + self.assertEqual(out['name'], 'Pretty') + + def test_concept_dct_uses_original_names_as_synonyms(self): + cdb = MagicMock() + cdb.get_name.return_value = 'Hypertension' + info = {'original_names': {'HTN', 'High blood pressure'}, 'type_ids': ['T'], 'description': 'd'} + out = solr_utils._concept_dct('C100', cdb, info) + self.assertSetEqual(set(out['synonyms']), {'HTN', 'High blood pressure'}) + + +@override_settings(MEDIA_ROOT='/tmp/mct-tests-solr') +class DropCollectionTests(TestCase): + @patch('api.solr_utils.requests.get') + def test_drop_collection_calls_delete_endpoint(self, mock_get): + mock_get.return_value = MagicMock(status_code=200, text='{}') + cdb = ConceptDB(name='drop_cdb', cdb_file='drop_cdb.dat') + cdb.save(skip_load=True) + solr_utils.drop_collection(cdb) + # Should call the DELETE action URL + call_url = 
mock_get.call_args[0][0] + self.assertIn(f'name=drop_cdb_id_{cdb.id}', call_url) + self.assertIn('action=DELETE', call_url) + + +@override_settings(MEDIA_ROOT='/tmp/mct-tests-solr') +class EnsureConceptSearchableTests(TestCase): + @patch('api.solr_utils.requests.post') + @patch('api.solr_utils.requests.get') + def test_uploads_concept_when_collection_exists(self, mock_get, mock_post): + cdb = ConceptDB(name='ecs_cdb', cdb_file='ecs_cdb.dat') + cdb.save(skip_load=True) + + mock_get.return_value = MagicMock( + status_code=200, + text=json.dumps({'collections': [f'ecs_cdb_id_{cdb.id}']}), + ) + mock_post.return_value = MagicMock(status_code=200, text='{}') + + mc_cdb = MagicMock() + mc_cdb.get_name.return_value = 'X' + mc_cdb.cui2info = {'C': {'original_names': [], 'type_ids': [], 'description': ''}} + + solr_utils.ensure_concept_searchable('C', mc_cdb, cdb) + mock_post.assert_called_once() + payload = mock_post.call_args.kwargs.get('json') + self.assertEqual(payload[0]['cui'], 'C') + + @patch('api.solr_utils.requests.post') + @patch('api.solr_utils.requests.get') + def test_does_not_upload_when_collection_missing(self, mock_get, mock_post): + cdb = ConceptDB(name='ecs_cdb2', cdb_file='ecs_cdb2.dat') + cdb.save(skip_load=True) + mock_get.return_value = MagicMock( + status_code=200, + text=json.dumps({'collections': []}), + ) + mc_cdb = MagicMock() + mc_cdb.get_name.return_value = 'X' + mc_cdb.cui2info = {'C': {'original_names': [], 'type_ids': [], 'description': ''}} + + solr_utils.ensure_concept_searchable('C', mc_cdb, cdb) + mock_post.assert_not_called() diff --git a/medcat-trainer/webapp/api/api/tests/test_utils.py b/medcat-trainer/webapp/api/api/tests/test_utils.py new file mode 100644 index 000000000..d26a4cfed --- /dev/null +++ b/medcat-trainer/webapp/api/api/tests/test_utils.py @@ -0,0 +1,386 @@ +"""Unit tests for api.utils. 
+ +Covers the pure-Python helpers (RemoteEntity, RemoteSpacyDoc, SimpleFilters, +env_str_to_bool, call_remote_model_service_* and add_annotations/ +remove_annotations/create_annotation) without spinning up MedCAT itself. +""" + +import os +from unittest.mock import patch, MagicMock + +import requests +from django.contrib.auth.models import User +from django.test import TestCase, override_settings + +from .. import utils +from ..models import AnnotatedEntity, Entity + + +class RemoteEntityTests(TestCase): + """Tests for the RemoteEntity helper that mirrors spaCy's entity shape.""" + + def test_constructs_from_full_payload(self): + ent = utils.RemoteEntity( + { + 'cui': 'C001', + 'start': 5, + 'end': 12, + 'detected_name': 'fever', + 'context_similarity': 0.92, + 'meta_anns': {'Presence': {'value': 'True'}}, + }, + 'patient has fever now', + ) + self.assertEqual(ent.cui, 'C001') + self.assertEqual(ent.start_char_index, 5) + self.assertEqual(ent.end_char_index, 12) + self.assertEqual(ent.text, 'fever') + self.assertEqual(ent.context_similarity, 0.92) + self.assertEqual(ent.get_addon_data('meta_cat_meta_anns'), {'Presence': {'value': 'True'}}) + + def test_falls_back_to_source_value_and_acc(self): + ent = utils.RemoteEntity( + {'cui': 'C002', 'source_value': 'cough', 'acc': 0.42}, + 'cough', + ) + self.assertEqual(ent.text, 'cough') + self.assertEqual(ent.context_similarity, 0.42) + self.assertEqual(ent.start_char_index, 0) + self.assertEqual(ent.end_char_index, 0) + + def test_unknown_addon_key_returns_empty_dict(self): + ent = utils.RemoteEntity({'cui': 'C003'}, 'text') + self.assertEqual(ent.get_addon_data('some_other_key'), {}) + + def test_defaults_when_payload_empty(self): + ent = utils.RemoteEntity({}, '') + self.assertEqual(ent.cui, '') + self.assertEqual(ent.text, '') + self.assertEqual(ent.context_similarity, 0.0) + + +class RemoteSpacyDocTests(TestCase): + def test_wraps_linked_ents(self): + ents = [utils.RemoteEntity({'cui': 'A'}, 'text'), 
utils.RemoteEntity({'cui': 'B'}, 'text')] + doc = utils.RemoteSpacyDoc(ents) + self.assertEqual(doc.linked_ents, ents) + + +class SimpleFiltersTests(TestCase): + def test_default_empty_filters(self): + f = utils.SimpleFilters() + self.assertEqual(f.cuis, set()) + self.assertEqual(f.cuis_exclude, set()) + + def test_custom_filters_preserved(self): + f = utils.SimpleFilters(cuis={'X'}, cuis_exclude={'Y'}) + self.assertEqual(f.cuis, {'X'}) + self.assertEqual(f.cuis_exclude, {'Y'}) + + +class EnvStrToBoolTests(TestCase): + def _set_env(self, value): + os.environ['__MCT_TEST_FLAG__'] = value + self.addCleanup(os.environ.pop, '__MCT_TEST_FLAG__', None) + + def test_truthy_string_values(self): + for v in ('1', 'true', 't', 'y', 'TRUE', 'True'): + self._set_env(v) + self.assertIs(utils.env_str_to_bool('__MCT_TEST_FLAG__', False), True, f'Expected True for {v}') + + def test_falsy_string_values(self): + for v in ('0', 'false', 'f', 'n', 'False'): + self._set_env(v) + self.assertIs(utils.env_str_to_bool('__MCT_TEST_FLAG__', True), False, f'Expected False for {v}') + + def test_unknown_string_returns_value_unchanged(self): + self._set_env('maybe') + self.assertEqual(utils.env_str_to_bool('__MCT_TEST_FLAG__', True), 'maybe') + + def test_uses_default_when_not_set(self): + os.environ.pop('__MCT_TEST_FLAG__', None) + self.assertTrue(utils.env_str_to_bool('__MCT_TEST_FLAG__', True)) + self.assertFalse(utils.env_str_to_bool('__MCT_TEST_FLAG__', False)) + + +class CallRemoteModelServiceSpacyTests(TestCase): + """Tests for utils.call_remote_model_service_spacy with mocked requests.""" + + @patch('api.utils.requests.post') + def test_parses_spacy_response_into_remote_entities(self, mock_post): + mock_post.return_value = MagicMock( + status_code=200, + json=lambda: { + 'entities': { + '0': { + 'cui': 'C001', + 'start': 0, + 'end': 5, + 'detected_name': 'fever', + 'context_similarity': 0.9, + }, + '1': { + 'cui': 'C002', + 'start': 6, + 'end': 12, + 'detected_name': 'cough', + 
'context_similarity': 0.8, + }, + } + }, + raise_for_status=lambda: None, + ) + + doc = utils.call_remote_model_service_spacy('http://service:8000/', 'fever cough') + self.assertEqual(len(doc.linked_ents), 2) + cui_set = {e.cui for e in doc.linked_ents} + self.assertEqual(cui_set, {'C001', 'C002'}) + + mock_post.assert_called_once() + call_args, call_kwargs = mock_post.call_args + self.assertEqual(call_args[0], 'http://service:8000/api/process') + self.assertEqual(call_kwargs['json'], {'text': 'fever cough'}) + + @patch('api.utils.requests.post') + def test_request_failure_is_re_raised(self, mock_post): + mock_post.side_effect = requests.exceptions.ConnectionError('boom') + with self.assertRaises(Exception) as ctx: + utils.call_remote_model_service_spacy('http://service:8000', 'text') + self.assertIn('Failed to call remote model service', str(ctx.exception)) + + +class CallRemoteModelServiceMedcatTests(TestCase): + @patch('api.utils.requests.post') + def test_parses_medcat_response_into_remote_entities(self, mock_post): + mock_post.return_value = MagicMock( + status_code=200, + json=lambda: { + 'medcat_info': {'version': '1.0'}, + 'result': { + 'text': 'fever cough', + 'annotations': [ + { + '0': {'cui': 'C001', 'start': 0, 'end': 5, 'detected_name': 'fever', 'context_similarity': 0.9}, + '1': {'cui': 'C002', 'start': 6, 'end': 12, 'detected_name': 'cough', 'context_similarity': 0.8}, + } + ], + }, + }, + raise_for_status=lambda: None, + ) + doc = utils.call_remote_model_service_medcat('http://service:8000', 'fever cough') + self.assertEqual(len(doc.linked_ents), 2) + + @patch('api.utils.requests.post') + def test_raises_when_result_missing(self, mock_post): + mock_post.return_value = MagicMock( + status_code=200, + json=lambda: {'medcat_info': {}}, + raise_for_status=lambda: None, + ) + with self.assertRaises(Exception) as ctx: + utils.call_remote_model_service_medcat('http://service:8000', 'text') + self.assertIn("missing 'result'", str(ctx.exception)) + + 
@patch('api.utils.requests.post') + def test_raises_when_result_contains_errors(self, mock_post): + mock_post.return_value = MagicMock( + status_code=200, + json=lambda: {'result': {'errors': ['bad input']}}, + raise_for_status=lambda: None, + ) + with self.assertRaises(Exception) as ctx: + utils.call_remote_model_service_medcat('http://service:8000', 'text') + self.assertIn('errors', str(ctx.exception)) + + +class CallRemoteModelServiceDispatchTests(TestCase): + """Top-level dispatcher should route by REMOTE_MODEL_SERVICE_TYPE.""" + + def setUp(self): + # Ensure we don't leak across tests + self._original = os.environ.get('REMOTE_MODEL_SERVICE_TYPE') + + def tearDown(self): + if self._original is None: + os.environ.pop('REMOTE_MODEL_SERVICE_TYPE', None) + else: + os.environ['REMOTE_MODEL_SERVICE_TYPE'] = self._original + + @patch('api.utils.call_remote_model_service_spacy') + def test_dispatches_to_spacy_by_default(self, mock_spacy): + mock_spacy.return_value = 'doc' + os.environ.pop('REMOTE_MODEL_SERVICE_TYPE', None) + result = utils.call_remote_model_service('http://x', 'text') + mock_spacy.assert_called_once_with('http://x', 'text') + self.assertEqual(result, 'doc') + + @patch('api.utils.call_remote_model_service_medcat') + def test_dispatches_to_medcat_when_configured(self, mock_mc): + mock_mc.return_value = 'doc' + os.environ['REMOTE_MODEL_SERVICE_TYPE'] = 'medcat' + result = utils.call_remote_model_service('http://x', 'text') + mock_mc.assert_called_once_with('http://x', 'text') + self.assertEqual(result, 'doc') + + def test_unknown_service_type_raises(self): + os.environ['REMOTE_MODEL_SERVICE_TYPE'] = 'unknown' + with self.assertRaises(ValueError): + utils.call_remote_model_service('http://x', 'text') + + +@override_settings(MEDIA_ROOT='/tmp/mct-tests-utils') +class DbHelperTests(TestCase): + """Tests for DB-backed helpers (remove_annotations / create_annotation / + add_annotations) using lightweight in-memory fixtures. 
+ """ + + def setUp(self): + from ._helpers import create_basic_project, create_document, create_entity, create_user + + self.user = create_user(username='ann-user') + self.project = create_basic_project(name='utils-test-project') + self.document = create_document(self.project, name='doc1', text='fever and cough') + self.entity = create_entity(label='C001') + + def _make_ann(self, validated=False): + ann = AnnotatedEntity( + user=self.user, + project=self.project, + document=self.document, + entity=self.entity, + value='fever', + start_ind=0, + end_ind=5, + acc=0.9, + validated=validated, + ) + ann.save() + return ann + + def test_remove_annotations_full_clears_all(self): + self._make_ann(validated=True) + self._make_ann(validated=False) + utils.remove_annotations(self.document, self.project, partial=False) + self.assertFalse( + AnnotatedEntity.objects.filter(project=self.project, document=self.document).exists() + ) + + def test_remove_annotations_partial_keeps_validated(self): + kept = self._make_ann(validated=True) + unvalidated = self._make_ann(validated=False) + + utils.remove_annotations(self.document, self.project, partial=True) + + remaining = AnnotatedEntity.objects.filter(project=self.project, document=self.document) + ids = {a.id for a in remaining} + self.assertIn(kept.id, ids) + self.assertNotIn(unvalidated.id, ids) + + def test_create_annotation_persists_manually_created(self): + ann_id = utils.create_annotation( + source_val='fever', + selection_occurrence_index=0, + cui='C001', + user=self.user, + project=self.project, + document=self.document, + ) + self.assertIsNotNone(ann_id) + ann = AnnotatedEntity.objects.get(id=ann_id) + self.assertEqual(ann.start_ind, 0) + self.assertEqual(ann.end_ind, 5) + self.assertTrue(ann.manually_created) + self.assertTrue(ann.validated) + self.assertTrue(ann.correct) + + def test_create_annotation_creates_new_entity_when_missing(self): + self.assertFalse(Entity.objects.filter(label='C999').exists()) + 
utils.create_annotation( + source_val='cough', + selection_occurrence_index=0, + cui='C999', + user=self.user, + project=self.project, + document=self.document, + ) + self.assertTrue(Entity.objects.filter(label='C999').exists()) + + def test_create_annotation_returns_none_for_empty_cui(self): + ann_id = utils.create_annotation( + source_val='fever', + selection_occurrence_index=0, + cui='', + user=self.user, + project=self.project, + document=self.document, + ) + self.assertIsNone(ann_id) + + def test_add_annotations_with_simple_filters(self): + spacy_doc = utils.RemoteSpacyDoc([ + utils.RemoteEntity({'cui': 'C001', 'start': 0, 'end': 5, 'detected_name': 'fever', + 'context_similarity': 0.9}, 'fever and cough'), + utils.RemoteEntity({'cui': 'C002', 'start': 10, 'end': 15, 'detected_name': 'cough', + 'context_similarity': 0.85}, 'fever and cough'), + ]) + + utils.add_annotations( + spacy_doc=spacy_doc, + user=self.user, + project=self.project, + document=self.document, + existing_annotations=[], + cat=None, + filters=utils.SimpleFilters(cuis={'C001'}), + similarity_threshold=0.5, + ) + + anns = list(AnnotatedEntity.objects.filter(project=self.project, document=self.document)) + self.assertEqual(len(anns), 1) + self.assertEqual(anns[0].entity.label, 'C001') + + def test_add_annotations_marks_low_acc_as_deleted(self): + spacy_doc = utils.RemoteSpacyDoc([ + utils.RemoteEntity({'cui': 'C100', 'start': 0, 'end': 5, 'detected_name': 'fever', + 'context_similarity': 0.1}, 'fever and cough'), + ]) + + utils.add_annotations( + spacy_doc=spacy_doc, + user=self.user, + project=self.project, + document=self.document, + existing_annotations=[], + cat=None, + filters=utils.SimpleFilters(), + similarity_threshold=0.3, + ) + + ann = AnnotatedEntity.objects.get(entity__label='C100') + self.assertTrue(ann.deleted) + self.assertTrue(ann.validated) + + def test_add_annotations_respects_excludes(self): + spacy_doc = utils.RemoteSpacyDoc([ + utils.RemoteEntity({'cui': 'C200', 'start': 0, 
'end': 5, 'detected_name': 'fever', + 'context_similarity': 0.9}, 'fever and cough'), + utils.RemoteEntity({'cui': 'C201', 'start': 10, 'end': 15, 'detected_name': 'cough', + 'context_similarity': 0.9}, 'fever and cough'), + ]) + + utils.add_annotations( + spacy_doc=spacy_doc, + user=self.user, + project=self.project, + document=self.document, + existing_annotations=[], + cat=None, + filters=utils.SimpleFilters(cuis_exclude={'C200'}), + similarity_threshold=0.5, + ) + + labels = {a.entity.label for a in AnnotatedEntity.objects.filter(project=self.project, + document=self.document)} + self.assertIn('C201', labels) + self.assertNotIn('C200', labels) diff --git a/medcat-trainer/webapp/api/api/tests/test_views.py b/medcat-trainer/webapp/api/api/tests/test_views.py new file mode 100644 index 000000000..f2b625ba5 --- /dev/null +++ b/medcat-trainer/webapp/api/api/tests/test_views.py @@ -0,0 +1,383 @@ +"""Integration tests for api.views using DRF's APIClient. + +These tests focus on endpoints that don't require MedCAT to be loaded. 
+""" + +import json +import os +from unittest.mock import patch + +from django.contrib.auth.models import User +from django.test import TestCase, override_settings +from rest_framework.test import APIClient + +from ..models import ( + AnnotatedEntity, + ConceptDB, + Document, + Entity, + ProjectAnnotateEntities, + Vocabulary, +) +from ._helpers import ( + create_basic_project, + create_dataset, + create_document, + create_entity, + create_user, +) + + +def _auth_client(user): + client = APIClient() + client.force_authenticate(user=user) + return client + + +@override_settings(MEDIA_ROOT='/tmp/mct-tests-views') +class AuthenticationRequiredTests(TestCase): + def test_anonymous_users_cannot_list_projects(self): + client = APIClient() + resp = client.get('/api/project-annotate-entities/') + # 401 (Unauthorized) or 403 (Forbidden) acceptable + self.assertIn(resp.status_code, (401, 403)) + + def test_anonymous_users_cannot_list_users(self): + client = APIClient() + resp = client.get('/api/users/') + self.assertIn(resp.status_code, (401, 403)) + + +@override_settings(MEDIA_ROOT='/tmp/mct-tests-views') +class SimpleInfoEndpointsTests(TestCase): + def setUp(self): + self.user = create_user(username='infouser') + self.client = _auth_client(self.user) + + def test_version_returns_env_value(self): + old = os.environ.get('MCT_VERSION') + os.environ['MCT_VERSION'] = 'v9.9.9-test' + try: + resp = self.client.get('/api/version/') + self.assertEqual(resp.status_code, 200) + self.assertEqual(resp.data, 'v9.9.9-test') + finally: + if old is None: + os.environ.pop('MCT_VERSION', None) + else: + os.environ['MCT_VERSION'] = old + + def test_behind_reverse_proxy_returns_value(self): + old = os.environ.get('BEHIND_RP') + os.environ['BEHIND_RP'] = '1' + try: + resp = self.client.get('/api/behind-rp/') + self.assertEqual(resp.status_code, 200) + self.assertTrue(resp.data) + finally: + if old is None: + os.environ.pop('BEHIND_RP', None) + else: + os.environ['BEHIND_RP'] = old + + def 
test_anno_tool_conf_returns_environment_dict(self): + resp = self.client.get('/api/anno-conf/') + self.assertEqual(resp.status_code, 200) + # Just make sure it returns a dict-like JSON object + self.assertIsInstance(resp.json(), dict) + + +@override_settings(MEDIA_ROOT='/tmp/mct-tests-views') +class UserViewSetTests(TestCase): + def setUp(self): + self.user = create_user(username='listuser', password='pw') + self.other = create_user(username='otheruser', password='pw') + + def test_authenticated_user_can_list_users(self): + client = _auth_client(self.user) + resp = client.get('/api/users/') + self.assertEqual(resp.status_code, 200) + usernames = [u['username'] for u in resp.json()['results']] + self.assertIn('listuser', usernames) + self.assertIn('otheruser', usernames) + + def test_filter_by_username(self): + client = _auth_client(self.user) + resp = client.get('/api/users/?username=otheruser') + self.assertEqual(resp.status_code, 200) + results = resp.json()['results'] + self.assertEqual(len(results), 1) + self.assertEqual(results[0]['username'], 'otheruser') + + +@override_settings(MEDIA_ROOT='/tmp/mct-tests-views') +class ProjectAnnotateEntitiesViewSetTests(TestCase): + def setUp(self): + self.member = create_user(username='m1', password='pw') + self.outsider = create_user(username='o1', password='pw') + self.superuser = User.objects.create_superuser( + username='su1', password='pw', email='su1@x', + ) + self.project = create_basic_project(name='pl-proj') + self.project.members.add(self.member) + + def test_member_sees_only_their_projects(self): + client = _auth_client(self.member) + resp = client.get('/api/project-annotate-entities/') + self.assertEqual(resp.status_code, 200) + names = [p['name'] for p in resp.json()['results']] + self.assertIn('pl-proj', names) + + def test_outsider_sees_no_projects(self): + client = _auth_client(self.outsider) + resp = client.get('/api/project-annotate-entities/') + self.assertEqual(resp.status_code, 200) + 
self.assertEqual(resp.json()['results'], []) + + def test_superuser_sees_all_projects(self): + client = _auth_client(self.superuser) + resp = client.get('/api/project-annotate-entities/') + self.assertEqual(resp.status_code, 200) + names = [p['name'] for p in resp.json()['results']] + self.assertIn('pl-proj', names) + + +@override_settings(MEDIA_ROOT='/tmp/mct-tests-views') +class GetCreateEntityTests(TestCase): + def setUp(self): + self.user = create_user(username='entuser') + self.client = _auth_client(self.user) + + def test_creates_entity_when_label_does_not_exist(self): + self.assertFalse(Entity.objects.filter(label='C-NEW').exists()) + resp = self.client.post('/api/get-create-entity/', {'label': 'C-NEW'}, format='json') + self.assertEqual(resp.status_code, 200) + self.assertTrue(Entity.objects.filter(label='C-NEW').exists()) + self.assertIn('entity_id', resp.json()) + + def test_returns_existing_entity_id_when_label_exists(self): + ent = create_entity(label='C-EXIST') + resp = self.client.post('/api/get-create-entity/', {'label': 'C-EXIST'}, format='json') + self.assertEqual(resp.status_code, 200) + self.assertEqual(resp.json()['entity_id'], ent.id) + + +@override_settings(MEDIA_ROOT='/tmp/mct-tests-views') +class ProjectProgressTests(TestCase): + def setUp(self): + self.user = create_user(username='ppuser') + self.client = _auth_client(self.user) + self.project = create_basic_project(name='pp-proj') + + # Create 3 documents but no annotations + for i in range(3): + create_document(self.project, name=f'doc-{i}', text=f'text {i}') + + def test_returns_progress_for_project(self): + resp = self.client.get(f'/api/project-progress/?projects={self.project.id}') + self.assertEqual(resp.status_code, 200) + data = resp.json() + # JSON dict keys are strings + key = str(self.project.id) + self.assertIn(key, data) + self.assertEqual(data[key]['validated_count'], 0) + self.assertEqual(data[key]['dataset_count'], 3) + + def test_returns_400_when_no_projects_param(self): + 
resp = self.client.get('/api/project-progress/') + self.assertEqual(resp.status_code, 400) + + +@override_settings(MEDIA_ROOT='/tmp/mct-tests-views') +class PrepareDocsBgTaskTests(TestCase): + def setUp(self): + self.user = create_user(username='bguser') + self.client = _auth_client(self.user) + self.project = create_basic_project(name='bg-proj') + for i in range(2): + create_document(self.project, name=f'd-{i}', text=f't-{i}') + + def test_returns_400_for_unknown_project(self): + resp = self.client.get('/api/prep-docs-bg-tasks/999999/') + self.assertEqual(resp.status_code, 400) + + def test_returns_doc_counts_for_known_project(self): + resp = self.client.get(f'/api/prep-docs-bg-tasks/{self.project.id}/') + self.assertEqual(resp.status_code, 200) + data = resp.json() + self.assertEqual(data['dataset_len'], 2) + self.assertEqual(data['prepd_docs_len'], 0) + + +@override_settings(MEDIA_ROOT='/tmp/mct-tests-views') +class ProjectAdminProjectsTests(TestCase): + def setUp(self): + self.member = create_user(username='admin-mem') + self.outsider = create_user(username='admin-out') + self.project = create_basic_project(name='admin-proj') + self.project.members.add(self.member) + + def test_member_can_list_admin_projects(self): + client = _auth_client(self.member) + resp = client.get('/api/project-admin/projects/') + self.assertEqual(resp.status_code, 200) + names = [p['name'] for p in resp.json()] + self.assertIn('admin-proj', names) + + def test_outsider_has_no_admin_projects(self): + client = _auth_client(self.outsider) + resp = client.get('/api/project-admin/projects/') + self.assertEqual(resp.status_code, 200) + self.assertEqual(resp.json(), []) + + +@override_settings(MEDIA_ROOT='/tmp/mct-tests-views') +class ProjectAdminDetailTests(TestCase): + def setUp(self): + self.member = create_user(username='detail-mem') + self.outsider = create_user(username='detail-out') + self.project = create_basic_project(name='detail-proj') + self.project.members.add(self.member) + + def 
test_member_can_access_project_detail(self): + client = _auth_client(self.member) + resp = client.get(f'/api/project-admin/projects/{self.project.id}/') + self.assertEqual(resp.status_code, 200) + self.assertEqual(resp.json()['name'], 'detail-proj') + + def test_outsider_gets_403(self): + client = _auth_client(self.outsider) + resp = client.get(f'/api/project-admin/projects/{self.project.id}/') + self.assertEqual(resp.status_code, 403) + + def test_returns_404_for_unknown_project(self): + client = _auth_client(self.member) + resp = client.get('/api/project-admin/projects/9999999/') + self.assertEqual(resp.status_code, 404) + + def test_member_can_delete_project(self): + client = _auth_client(self.member) + resp = client.delete(f'/api/project-admin/projects/{self.project.id}/') + self.assertEqual(resp.status_code, 200) + self.assertFalse( + ProjectAnnotateEntities.objects.filter(id=self.project.id).exists() + ) + + def test_outsider_cannot_delete(self): + client = _auth_client(self.outsider) + resp = client.delete(f'/api/project-admin/projects/{self.project.id}/') + self.assertEqual(resp.status_code, 403) + + +@override_settings(MEDIA_ROOT='/tmp/mct-tests-views') +class ProjectAdminResetTests(TestCase): + def setUp(self): + self.member = create_user(username='reset-mem') + self.outsider = create_user(username='reset-out') + self.project = create_basic_project(name='reset-proj') + self.project.members.add(self.member) + + doc = create_document(self.project, name='doc', text='hello') + ent = create_entity(label='C-RESET') + AnnotatedEntity.objects.create( + user=self.member, project=self.project, document=doc, entity=ent, + value='hello', start_ind=0, end_ind=5, acc=1.0, validated=True, + ) + self.project.validated_documents.add(doc) + + def test_member_resets_annotations(self): + self.assertEqual( + AnnotatedEntity.objects.filter(project=self.project).count(), 1 + ) + client = _auth_client(self.member) + resp = 
client.post(f'/api/project-admin/projects/{self.project.id}/reset/') + self.assertEqual(resp.status_code, 200) + self.assertEqual( + AnnotatedEntity.objects.filter(project=self.project).count(), 0 + ) + self.project.refresh_from_db() + self.assertEqual(self.project.validated_documents.count(), 0) + + def test_outsider_cannot_reset(self): + client = _auth_client(self.outsider) + resp = client.post(f'/api/project-admin/projects/{self.project.id}/reset/') + self.assertEqual(resp.status_code, 403) + + +@override_settings(MEDIA_ROOT='/tmp/mct-tests-views') +class ProjectAdminCloneTests(TestCase): + def setUp(self): + self.member = create_user(username='clone-mem') + self.outsider = create_user(username='clone-out') + self.project = create_basic_project(name='clone-proj') + self.project.members.add(self.member) + + def test_clone_returns_new_project(self): + client = _auth_client(self.member) + resp = client.post( + f'/api/project-admin/projects/{self.project.id}/clone/', + {'name': 'my-clone'}, + format='json', + ) + self.assertEqual(resp.status_code, 201, msg=resp.content) + self.assertEqual(resp.json()['name'], 'my-clone') + self.assertTrue(ProjectAnnotateEntities.objects.filter(name='my-clone').exists()) + + def test_clone_default_name_when_unspecified(self): + client = _auth_client(self.member) + resp = client.post( + f'/api/project-admin/projects/{self.project.id}/clone/', + {}, + format='json', + ) + self.assertEqual(resp.status_code, 201) + self.assertEqual(resp.json()['name'], 'clone-proj (Clone)') + + def test_clone_returns_404_for_unknown_project(self): + client = _auth_client(self.member) + resp = client.post('/api/project-admin/projects/99999/clone/', {}, format='json') + self.assertEqual(resp.status_code, 404) + + def test_outsider_cannot_clone(self): + client = _auth_client(self.outsider) + resp = client.post( + f'/api/project-admin/projects/{self.project.id}/clone/', + {}, + format='json', + ) + self.assertEqual(resp.status_code, 403) + + 
+@override_settings(MEDIA_ROOT='/tmp/mct-tests-views') +class ModelLoadedTests(TestCase): + def setUp(self): + self.user = create_user(username='ml-user') + self.client = _auth_client(self.user) + self.project = create_basic_project(name='ml-proj') + + def test_returns_model_states_for_all_projects(self): + resp = self.client.get('/api/model-loaded/') + self.assertEqual(resp.status_code, 200) + data = resp.json() + self.assertIn('model_states', data) + self.assertIn(str(self.project.id), {str(k) for k in data['model_states']}) + + +@override_settings(MEDIA_ROOT='/tmp/mct-tests-views') +class DownloadAnnosTests(TestCase): + def setUp(self): + self.regular = create_user(username='reg') + self.superuser = User.objects.create_superuser(username='dl-su', password='pw', email='dl@x') + + def test_non_superuser_cannot_download(self): + client = _auth_client(self.regular) + resp = client.get('/api/download-annos/?project_ids=1') + self.assertEqual(resp.status_code, 400) + + def test_superuser_can_download_for_existing_project(self): + project = create_basic_project(name='dl-proj') + client = _auth_client(self.superuser) + resp = client.get(f'/api/download-annos/?project_ids={project.id}&with_text=true') + self.assertEqual(resp.status_code, 200) + # Response is a streaming JSON document + self.assertIn('Content-Disposition', resp) diff --git a/medcat-trainer/webapp/frontend/env.d.ts b/medcat-trainer/webapp/frontend/env.d.ts index 356c5d67a..2b96ad0bb 100644 --- a/medcat-trainer/webapp/frontend/env.d.ts +++ b/medcat-trainer/webapp/frontend/env.d.ts @@ -5,3 +5,13 @@ declare module '*.vue' { const component: DefineComponent