diff --git a/docker-compose-dev.yml b/docker-compose-dev.yml index 6b110cb17..e104243c9 100644 --- a/docker-compose-dev.yml +++ b/docker-compose-dev.yml @@ -154,12 +154,21 @@ services: JWT_SECRET: DH8kSxcflUVfNRdkEiJJCn2dOOKI3qfw POSTGRES_URI: postgresql+psycopg://zimfarm:zimpass@zimfarm-db:5432/zimfarm ALEMBIC_UPGRADE_HEAD_ON_START: true + DEBUG: true INIT_USERNAME: admin INIT_PASSWORD: admin ALLOWED_ORIGINS: http://localhost:8003 ARTIFACTS_UPLOAD_URI: s3+http://minio:9000/?keyId=minio_key&secretAccessKey=minio_secret&bucketName=org-kiwix-dev-artifacts LOGS_UPLOAD_URI: s3+http://minio:9000/?keyId=minio_key&secretAccessKey=minio_secret&bucketName=org-kiwix-dev-logs ZIM_UPLOAD_URI: s3+http://minio:9000/?keyId=minio_key&secretAccessKey=minio_secret&bucketName=org-kiwix-dev-zims + OAUTH_JWKS_URI: https://ory.login-staging.kiwix.org/.well-known/jwks.json + OAUTH_ISSUER: https://ory.login-staging.kiwix.org + OAUTH_OIDC_CLIENT_ID: 38302485-7f3d-4e88-b1b3-ec02fb92ec8c + OAUTH_OIDC_LOGIN_REQUIRE_2FA: true + OAUTH_SESSION_AUDIENCE_ID: 309693e7-ad5e-4379-bf93-ba89314230fd + OAUTH_SESSION_LOGIN_REQUIRE_2FA: true + CREATE_NEW_OAUTH_ACCOUNT: false + AUTH_MODES: local,oauth-session networks: - wp1bot-dev depends_on: diff --git a/docker/zimfarm/zimfarm_ui_dev/config.json b/docker/zimfarm/zimfarm_ui_dev/config.json index cbea83688..67cba86d7 100644 --- a/docker/zimfarm/zimfarm_ui_dev/config.json +++ b/docker/zimfarm/zimfarm_ui_dev/config.json @@ -1,3 +1,6 @@ { - "ZIMFARM_WEBAPI": "http://localhost:8004/v2" + "ZIMFARM_WEBAPI": "http://localhost:8004/v2", + "OAUTH_BASE_URL": "https://ory.login-staging.kiwix.org", + "OAUTH_MODE": "session", + "LOGIN_MODES": ["local"] } diff --git a/wp1/credentials.py.dev.e2e b/wp1/credentials.py.dev.e2e index 1e5954bc1..b1da60344 100644 --- a/wp1/credentials.py.dev.e2e +++ b/wp1/credentials.py.dev.e2e @@ -40,12 +40,24 @@ CREDENTIALS = { 'secret': '', 'bucket': 'org-kiwix-dev-wp1', }, - + 'MAILGUN': { 'url': 'https://api.eu.mailgun.net/v3/mg.wp1.openzim.org/messages', 'api_key': 'INSERT_YOUR_MAILGUN_API_KEY_HERE', }, - + "ZIMFARM": { + "auth_mode": "local", + "url": "https://fake.farm/v2", + "s3_url": "https://fake.wasabisys.com/org-kiwix-zimit", + "user": "farmuser", + "password": "farmpass", + "hook_token": "hook-token-abc", + # if auth_mode is set to "oauth", then, uncomment these + # "oauth_issuer": "https://ory.login-staging.kiwix.org", + # "oauth_client_id": "oauth-client-id", + # "oauth_client_secret": "oauth-client-secret", + # "oauth_audience_id": "oauth-audience-id", + }, 'FILE_PATH': { # Path where pageviews.bz2 file (~3GB) will be downloaded. 'pageviews': '/tmp/pageviews', diff --git a/wp1/credentials.py.dev.example b/wp1/credentials.py.dev.example index e687af204..d60895fbc 100644 --- a/wp1/credentials.py.dev.example +++ b/wp1/credentials.py.dev.example @@ -101,6 +101,7 @@ CREDENTIALS = { # Server URL and credentials for the Zim Farm that will be used to create # ZIM files from materialized selections. 'ZIMFARM': { + 'auth_mode': 'local', 'url': 'http://zimfarm-api/v2', 's3_url': 'http://localhost:9000/org-kiwix-dev-zims', # if using minio 'user': 'admin', @@ -111,6 +112,11 @@ CREDENTIALS = { # Update this to the latest version at the time of your deployment. 'image': 'ghcr.io/openzim/mwoffliner:1.17.2', 'definition_version': '1.17.2', + # if auth_mode is set to "oauth", then, uncomment these + # "oauth_issuer": "https://ory.login-staging.kiwix.org", + # "oauth_client_id": "oauth-client-id", + # "oauth_client_secret": "oauth-client-secret", + # "oauth_audience_id": "oauth-audience-id", }, # Logging directives. Keys are the names of the loggers, values are dictionaries diff --git a/wp1/credentials.py.e2e b/wp1/credentials.py.e2e index 37cfa6579..5e3023a9e 100644 --- a/wp1/credentials.py.e2e +++ b/wp1/credentials.py.e2e @@ -76,11 +76,17 @@ CREDENTIALS = { 'backend': 'http://test.server.fake' }, 'ZIMFARM': { + 'auth_mode': 'local', 'url': 'https://fake.farm/v2', 's3_url': 'https://fake.wasabisys.com/org-kiwix-zimit', 'user': 'farmuser', 'password': 'farmpass', 'hook_token': 'hook-token-abc', + # if auth_mode is set to "oauth", then, uncomment these + # "oauth_issuer": "https://ory.login-staging.kiwix.org", + # "oauth_client_id": "oauth-client-id", + # "oauth_client_secret": "oauth-client-secret", + # "oauth_audience_id": "oauth-audience-id", }, 'MAILGUN': { 'url': 'https://api.eu.mailgun.net/v3/mg.wp1.openzim.org/messages', diff --git a/wp1/credentials.py.example b/wp1/credentials.py.example index 2119d1f03..f5aeb90f1 100644 --- a/wp1/credentials.py.example +++ b/wp1/credentials.py.example @@ -129,6 +129,7 @@ CREDENTIALS = { # Server URL and credentials for the Zim Farm that will be used to create # ZIM files from materialized selections. 'ZIMFARM': { + 'auth_mode': 'local', 'url': 'http://zimfarm-api/v2', 's3_url': 'https://localhost:9000/org-kiwix-dev-zims', # if using minio 'user': 'admin', @@ -139,6 +140,11 @@ CREDENTIALS = { # Update this to the latest version at the time of your deployment. 'image': 'ghcr.io/openzim/mwoffliner:1.17.2', 'definition_version': '1.17.2', + # if auth_mode is set to "oauth", then, uncomment these + # "oauth_issuer": "https://ory.login-staging.kiwix.org", + # "oauth_client_id": "oauth-client-id", + # "oauth_client_secret": "oauth-client-secret", + # "oauth_audience_id": "oauth-audience-id", }, # Credentials for the Mailgun service, used to send emails. This is not diff --git a/wp1/logic/builder_test.py b/wp1/logic/builder_test.py index 962e6f93b..141096255 100644 --- a/wp1/logic/builder_test.py +++ b/wp1/logic/builder_test.py @@ -1115,7 +1115,7 @@ def test_handle_zim_generation( "wp1.logic.builder.utcnow", return_value=datetime.datetime(2022, 12, 25, 0, 1, 2), ) - @patch("wp1.logic.builder.zimfarm.get_zimfarm_token", return_value="test_token") + @patch("wp1.logic.builder.zimfarm.token_provider", return_value="test_token") def test_handle_zim_generation_long_title( self, mock_utcnow, mock_get_zimfarm_token ): diff --git a/wp1/timestamp.py b/wp1/timestamp.py index 2313d4849..5f68ba339 100644 --- a/wp1/timestamp.py +++ b/wp1/timestamp.py @@ -1,5 +1,10 @@ -from datetime import datetime, timezone +from datetime import datetime, timezone, UTC def utcnow() -> datetime: return datetime.now(timezone.utc) + + +def naive_utcnow(): + """naive UTC now""" + return datetime.now(UTC).replace(tzinfo=None) diff --git a/wp1/web/builders_test.py b/wp1/web/builders_test.py index 8b628a57a..1c367c5b5 100644 --- a/wp1/web/builders_test.py +++ b/wp1/web/builders_test.py @@ -587,7 +587,7 @@ def test_create_zim_file_for_builder( patched_request_zimfarm_task.assert_called_once() with self.wp10db.cursor() as cursor: cursor.execute( - "SELECT z_task_id, z_status FROM zim_tasks " "WHERE z_selection_id = 3" + "SELECT z_task_id, z_status FROM zim_tasks WHERE z_selection_id = 3" ) data = cursor.fetchone() @@ -665,7 +665,7 @@ def test_create_zim_file_for_builder_400(self, patched_request_zimfarm_task): @patch("wp1.zimfarm.request_zimfarm_task") # Mock requests to avoid actual HTTP calls @patch("wp1.zimfarm.requests") - @patch("wp1.zimfarm.get_zimfarm_token") + @patch("wp1.zimfarm.token_provider.get_access_token") def test_create_zim_file_for_builder_no_title( self, mock_get_token, mock_requests, mock_request_zimfarm_task ): @@ -685,7 +685,7 @@ def test_create_zim_file_for_builder_no_title( @patch("wp1.zimfarm.request_zimfarm_task") # Mock requests to avoid actual HTTP calls @patch("wp1.zimfarm.requests") - @patch("wp1.zimfarm.get_zimfarm_token") + @patch("wp1.zimfarm.token_provider.get_access_token") def test_create_zim_file_for_builder_too_long_title( self, mock_get_token, mock_requests, mock_request_zimfarm_task ): @@ -706,7 +706,7 @@ def test_create_zim_file_for_builder_too_long_title( @patch("wp1.zimfarm.request_zimfarm_task") # Mock requests to avoid actual HTTP calls @patch("wp1.zimfarm.requests") - @patch("wp1.zimfarm.get_zimfarm_token") + @patch("wp1.zimfarm.token_provider.get_access_token") def test_create_zim_file_for_builder_too_long_description( self, mock_get_token, mock_requests, mock_request_zimfarm_task ): @@ -727,7 +727,7 @@ def test_create_zim_file_for_builder_too_long_description( @patch("wp1.zimfarm.request_zimfarm_task") # Mock requests to avoid actual HTTP calls @patch("wp1.zimfarm.requests") - @patch("wp1.zimfarm.get_zimfarm_token") + @patch("wp1.zimfarm.token_provider.get_access_token") def test_create_zim_file_for_builder_too_long_long_description( self, mock_get_token, mock_requests, mock_request_zimfarm_task ): @@ -752,7 +752,7 @@ def test_create_zim_file_for_builder_too_long_long_description( @patch("wp1.zimfarm.request_zimfarm_task") # Mock requests to avoid actual HTTP calls @patch("wp1.zimfarm.requests") - @patch("wp1.zimfarm.get_zimfarm_token") + @patch("wp1.zimfarm.token_provider.get_access_token") def test_create_zim_file_for_builder_too_short_long_description( self, mock_get_token, mock_requests, mock_request_zimfarm_task ): @@ -775,7 +775,7 @@ def test_create_zim_file_for_builder_too_short_long_description( # Mock requests to avoid actual HTTP calls @patch("wp1.zimfarm.requests") - @patch("wp1.zimfarm.get_zimfarm_token") + @patch("wp1.zimfarm.token_provider.get_access_token") def test_create_zim_file_for_builder_too_many_articles( self, token_mock, requests_mock ): diff --git a/wp1/zimfarm.py b/wp1/zimfarm.py index 5b0d82a08..6f1f12ce6 100644 --- a/wp1/zimfarm.py +++ b/wp1/zimfarm.py @@ -1,10 +1,12 @@ import logging import urllib.parse import uuid -from datetime import UTC, datetime +from typing import cast, Any +from datetime import datetime, UTC, timedelta import regex import requests +from requests.auth import HTTPBasicAuth import wp1.logic.builder as logic_builder import wp1.logic.selection as logic_selection @@ -24,7 +26,7 @@ from wp1.models.wp10.builder import Builder from wp1.models.wp10.selection import Selection from wp1.models.wp10.zim_schedule import ZimSchedule -from wp1.time import get_current_datetime +from wp1.timestamp import naive_utcnow REDIS_AUTH_KEY = "zimfarm.auth" @@ -45,87 +47,153 @@ def store_zimfarm_token(redis, data): redis.hset(REDIS_AUTH_KEY, mapping=data) -def request_zimfarm_token(redis): - user = CREDENTIALS[ENV].get("ZIMFARM", {}).get("user") - password = CREDENTIALS[ENV].get("ZIMFARM", {}).get("password") - - if user is None or password is None: - raise ZimFarmError( - "Could not log into zimfarm, user/password not found in " "site credentials" - ) - - logger.debug( - "Requesting auth token from %s with username/password", get_zimfarm_url() - ) - r = requests.post( - "%s/auth/authorize" % get_zimfarm_url(), - headers={"User-Agent": WP1_USER_AGENT}, - json={"username": user, "password": password}, - ) - try: - r.raise_for_status() - except requests.exceptions.HTTPError as e: - logger.exception(r.text) - raise ZimFarmError("Error getting authentication token for Zimfarm") from e +class ZimfarmClientTokenProvider: + """Client to generate access tokens to authenticate with Zimfarm API""" + + def __init__(self): + self._access_token: str | None = None + self._refresh_token: str | None = None + self._expires_at: datetime = datetime.fromtimestamp(0, UTC).replace(tzinfo=None) + self._zimfarm_creds: dict[str, Any] = CREDENTIALS[ENV].get("ZIMFARM", {}) + + def _validate_creds(self): + if self._zimfarm_creds.get("auth_mode", "local") == "local": + if not ( + self._zimfarm_creds.get("user") and self._zimfarm_creds.get("password") + ): + raise ZimFarmError( + "user and password must be set in Zimfarm site credentials " + "when auth mode is 'local' " + ) + elif self._zimfarm_creds.get("auth_mode") == "oauth": + if not ( + self._zimfarm_creds.get("oauth_issuer") + and self._zimfarm_creds.get("oauth_client_id") + and self._zimfarm_creds.get("oauth_client_secret") + and self._zimfarm_creds.get("oauth_audience_id") + ): + raise ZimFarmError( + "oauth_client_secret, oauth_client_id and oauth_audience_id must be set " + "in Zimfarm site credentials when auth mode is 'oauth'" + ) + else: + raise ZimFarmError( + f"Unknown auth mode {self._zimfarm_creds.get('auth_mode')}. " + "Allowed values are 'local' and 'oauth'." + ) - data = r.json() - store_zimfarm_token(redis, data) + def _generate_oauth_access_token(self) -> None: + """Generate oauth access token and update expires_at.""" - access_token = data.get("access_token") - if access_token is None: - logger.warning( - "Access token from zimfarm API was None, full response: %s", data + logger.debug( + "Requesting auth token from %s with oauth credentials", + self._zimfarm_creds.get("oauth_issuer"), ) - return access_token - - -def refresh_zimfarm_token(redis, refresh_token): - logger.debug( - "Requesting access_token from %s using refresh_token", get_zimfarm_url() - ) - r = requests.post( - "%s/auth/refresh" % get_zimfarm_url(), - headers={ - "User-Agent": WP1_USER_AGENT, - }, - json={ - "refresh_token": refresh_token, - }, - ) - try: - r.raise_for_status() - except requests.exceptions.HTTPError as e: - logger.exception(r.text) - raise ZimFarmError("Error getting authentication token for Zimfarm") from e - - data = r.json() - access_token = data.get("access_token") - if access_token is None: - logger.warning( - "Access token from zimfarm API was None, full response: %s", data + response = requests.post( + f"{self._zimfarm_creds.get('oauth_issuer')}/oauth2/token", + data={ + "grant_type": "client_credentials", + "audience": self._zimfarm_creds.get("oauth_audience_id"), + }, + auth=HTTPBasicAuth( + self._zimfarm_creds.get("oauth_client_id"), + self._zimfarm_creds.get("oauth_client_secret"), + ), + timeout=self._zimfarm_creds.get("requests_timeout", 30), + headers={"User-Agent": WP1_USER_AGENT}, ) - return access_token + try: + response.raise_for_status() + except requests.exceptions.HTTPError as e: + logger.exception(response.text) + raise ZimFarmError("Error getting authentication token for Zimfarm") from e + payload = response.json() + self._access_token = cast(str, payload["access_token"]) + self._expires_at = naive_utcnow() + timedelta(seconds=payload["expires_in"]) -def get_zimfarm_token(redis): - data = redis.hgetall(REDIS_AUTH_KEY) - if data is None or data.get("refresh_token") is None: - logger.debug("No saved zimfarm refresh_token, requesting") - return request_zimfarm_token(redis) + def _generate_local_access_token(self) -> None: + if self._refresh_token: + logger.debug( + "Requesting access_token from %s using refresh_token", get_zimfarm_url() + ) + response = requests.post( + f"{get_zimfarm_url()}/auth/refresh", + json={ + "refresh_token": self._refresh_token, + }, + timeout=self._zimfarm_creds.get("requests_timeout", 30), + headers={"User-Agent": WP1_USER_AGENT}, + ) + else: + logger.debug( + "Requesting auth token from %s with username/password", + get_zimfarm_url(), + ) + response = requests.post( + f"{get_zimfarm_url()}/auth/authorize", + json={ + "username": self._zimfarm_creds.get("user"), + "password": self._zimfarm_creds.get("password"), + }, + timeout=self._zimfarm_creds.get("requests_timeout", 30), + headers={"User-Agent": WP1_USER_AGENT}, + ) - access_expired = ( - datetime.strptime( - data.get("expires_time", "1970-01-01T00:00:00Z"), "%Y-%m-%dT%H:%M:%SZ" + try: + response.raise_for_status() + except requests.exceptions.HTTPError as e: + logger.exception(response.text) + raise ZimFarmError("Error getting authentication token for Zimfarm") from e + + payload = response.json() + self._access_token = cast(str, payload["access_token"]) + self._refresh_token = cast(str, payload["refresh_token"]) + self._expires_at = datetime.fromisoformat(payload["expires_time"]).replace( + tzinfo=None ) - < get_current_datetime() - ) - - if access_expired: - logger.debug("Zimfarm access_token is expired, refreshing") - return refresh_zimfarm_token(redis, data["refresh_token"]) - return data.get("access_token") + def get_access_token(self, redis) -> str: + """Retrieve or generate access token depending on if token has expired.""" + self._validate_creds() + + data = redis.hgetall(REDIS_AUTH_KEY) + if data is not None: + self._access_token = data.get("access_token") + self._refresh_token = data.get("refresh_token") + if data.get("expires_at"): + self._expires_at = datetime.fromisoformat(data["expires_at"]).replace( + tzinfo=None + ) + + now = naive_utcnow() + if self._access_token is None or now >= ( + self._expires_at + - timedelta(seconds=self._zimfarm_creds.get("token_renewal_window", 300)) + ): + logger.debug("Refreshing Zimfarm acess token") + if self._zimfarm_creds.get("auth_mode") == "oauth": + self._generate_oauth_access_token() + elif self._zimfarm_creds.get("auth_mode") == "local": + self._generate_local_access_token() + + if self._access_token: + store_zimfarm_token( + redis, + { + "access_token": self._access_token, + "refresh_token": self._refresh_token or "", + "expires_at": self._expires_at.isoformat(), + }, + ) + + if self._access_token is None: + raise ZimFarmError("Failed to generate access token.") + return self._access_token + + +token_provider = ZimfarmClientTokenProvider() def get_zimfarm_url(): @@ -297,9 +365,7 @@ def _get_zimfarm_headers(token): def zimfarm_schedule_exists(redis, builder_id: str) -> bool: """Checks if a ZimSchedule exists in the zimfarm""" - token = get_zimfarm_token(redis) - if token is None: - raise ZimFarmError("Error retrieving auth token for request") + token = token_provider.get_access_token(redis) base_url = get_zimfarm_url() headers = _get_zimfarm_headers(token) @@ -339,9 +405,7 @@ def create_or_update_zimfarm_schedule( """ Requests a ZIM file schedule from the Zimfarm for the given builder. """ - token = get_zimfarm_token(redis) - if token is None: - raise ZimFarmError("Error retrieving auth token for request") + token = token_provider.get_access_token(redis) if builder is None: raise ObjectNotFoundError("Cannot schedule for None builder") @@ -425,10 +489,7 @@ def request_zimfarm_task(redis, wp10db, builder): """ Requests a ZIM file task from the Zimfarm for the given builder. """ - token = get_zimfarm_token(redis) - if token is None: - raise ZimFarmError("Error retrieving auth token for request") - + token = token_provider.get_access_token(redis) if builder is None: raise ObjectNotFoundError("Cannot schedule for None builder") @@ -524,9 +585,7 @@ def cancel_zim_by_task_id(redis, task_id): if isinstance(task_id, bytes): task_id = task_id.decode("utf-8") - token = get_zimfarm_token(redis) - if token is None: - raise ZimFarmError("Error retrieving auth token for request") + token = token_provider.get_access_token(redis) base_url = get_zimfarm_url() headers = _get_zimfarm_headers(token) @@ -547,7 +606,7 @@ def cancel_zim_by_task_id(redis, task_id): try: r.raise_for_status() - except requests.exceptions.HTTPError as e: + except requests.exceptions.HTTPError: raise ZimFarmError("Task could not be deleted/canceled (task_id=%s)" % task_id) @@ -558,9 +617,7 @@ def delete_zimfarm_schedule_by_builder_id(redis, builder_id): if isinstance(builder_id, bytes): builder_id = builder_id.decode("utf-8") - token = get_zimfarm_token(redis) - if token is None: - raise ZimFarmError("Error retrieving auth token for request") + token = token_provider.get_access_token(redis) base_url = get_zimfarm_url() headers = _get_zimfarm_headers(token) diff --git a/wp1/zimfarm_test.py b/wp1/zimfarm_test.py index 78e6b9870..2f283049c 100644 --- a/wp1/zimfarm_test.py +++ b/wp1/zimfarm_test.py @@ -21,6 +21,7 @@ ZIM_DESCRIPTION_MAX_LENGTH, ZIM_LONG_DESCRIPTION_MAX_LENGTH, ZIM_TITLE_MAX_LENGTH, + ZimfarmClientTokenProvider, ) @@ -238,157 +239,458 @@ def test_get_params_missing_builder(self): with self.assertRaises(ValueError): zimfarm._get_params(None, self.selection, "Tile", "Desc", "Long Desc") - @patch("wp1.zimfarm.requests") - def test_request_zimfarm_token(self, mock_requests): - redis = MagicMock() - mock_response = MagicMock() - mock_response.json.return_value = {"access_token": "abcdef"} - mock_requests.post.return_value = mock_response + @patch("wp1.zimfarm.CREDENTIALS") + def test_token_provider_init_oauth_valid(self, mock_credentials): + mock_credentials.__getitem__.return_value = { + "ZIMFARM": { + "auth_mode": "oauth", + "oauth_client_id": "test_client_id", + "oauth_client_secret": "test_client_secret", + "oauth_audience_id": "test_audience", + "oauth_issuer": "https://oauth.example.com", + } + } + + provider = ZimfarmClientTokenProvider() + provider._validate_creds() + + self.assertIsNone(provider._access_token) + self.assertIsNone(provider._refresh_token) + + @patch("wp1.zimfarm.CREDENTIALS") + def test_token_provider_init_local_valid(self, mock_credentials): + mock_credentials.__getitem__.return_value = { + "ZIMFARM": { + "auth_mode": "local", + "user": "test_user", + "password": "test_pass", + } + } + + provider = ZimfarmClientTokenProvider() + provider._validate_creds() + + self.assertIsNone(provider._access_token) + self.assertIsNone(provider._refresh_token) + + @patch("wp1.zimfarm.CREDENTIALS") + def test_token_provider_init_oauth_missing_credentials(self, mock_credentials): + mock_credentials.__getitem__.return_value = { + "ZIMFARM": { + "auth_mode": "oauth", + "oauth_client_id": "test_client_id", + } + } + + with self.assertRaises(ZimFarmError): + ZimfarmClientTokenProvider()._validate_creds() + + @patch("wp1.zimfarm.CREDENTIALS") + def test_token_provider_init_local_missing_credentials(self, mock_credentials): + mock_credentials.__getitem__.return_value = { + "ZIMFARM": { + "auth_mode": "local", + "user": "test_user", + } + } + + with self.assertRaises(ZimFarmError): + ZimfarmClientTokenProvider()._validate_creds() - actual = zimfarm.request_zimfarm_token(redis) + @patch("wp1.zimfarm.CREDENTIALS") + def test_token_provider_init_unknown_auth_mode(self, mock_credentials): + mock_credentials.__getitem__.return_value = { + "ZIMFARM": {"auth_mode": "unknown"} + } - self.assertEqual("abcdef", actual) + with self.assertRaises(ZimFarmError): + ZimfarmClientTokenProvider()._validate_creds() + @patch("wp1.zimfarm.CREDENTIALS") @patch("wp1.zimfarm.requests") - def test_request_zimfarm_token_posts_with_correct_data(self, mock_requests): - redis = MagicMock() + @patch("wp1.zimfarm.naive_utcnow") + def test_token_provider_generate_oauth_access_token_success( + self, mock_naive_utcnow, mock_requests, mock_credentials + ): + """Test _generate_oauth_access_token successfully generates token""" + mock_credentials.__getitem__.return_value = { + "ZIMFARM": { + "auth_mode": "oauth", + "oauth_client_id": "test_client_id", + "oauth_client_secret": "test_client_secret", + "oauth_audience_id": "test_audience", + "oauth_issuer": "https://oauth.example.com", + "requests_timeout": 30, + } + } + + mock_naive_utcnow.return_value = datetime.datetime( + 2023, 1, 1, 0, 0, 0, tzinfo=None + ) mock_response = MagicMock() - mock_response.json.return_value = {"access_token": "abcdef"} + mock_response.json.return_value = { + "access_token": "oauth_token_123", + "expires_in": 3600, + } mock_requests.post.return_value = mock_response - zimfarm.request_zimfarm_token(redis) + provider = ZimfarmClientTokenProvider() + provider._generate_oauth_access_token() - mock_requests.post.assert_called_once_with( - "https://fake.farm/v2/auth/authorize", - headers={"User-Agent": "WP 1.0 bot 1.0.0/Audiodude "}, - json={"username": "farmuser", "password": "farmpass"}, - ) + self.assertEqual(provider._access_token, "oauth_token_123") + self.assertEqual(provider._expires_at, datetime.datetime(2023, 1, 1, 1, 0, 0)) + @patch("wp1.zimfarm.CREDENTIALS") @patch("wp1.zimfarm.requests") - def test_request_zimfarm_token_raises_for_status(self, mock_requests): - redis = MagicMock() + def test_token_provider_generate_oauth_access_token_http_error( + self, mock_requests, mock_credentials + ): + mock_credentials.__getitem__.return_value = { + "ZIMFARM": { + "auth_mode": "oauth", + "oauth_client_id": "test_client_id", + "oauth_client_secret": "test_client_secret", + "oauth_audience_id": "test_audience", + "oauth_issuer": "https://oauth.example.com", + } + } + mock_response = MagicMock() mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError - mock_requests.exceptions.HTTPError = requests.exceptions.HTTPError mock_requests.post.return_value = mock_response + mock_requests.exceptions.HTTPError = requests.exceptions.HTTPError - with self.assertRaises(ZimFarmError): - zimfarm.request_zimfarm_token(redis) - - @patch("wp1.zimfarm.CREDENTIALS", {Environment.TEST: {}}) - def test_request_zimfarm_token_no_creds(self): - redis = MagicMock() + provider = ZimfarmClientTokenProvider() with self.assertRaises(ZimFarmError): - zimfarm.request_zimfarm_token(redis) + provider._generate_oauth_access_token() + @patch("wp1.zimfarm.CREDENTIALS") @patch("wp1.zimfarm.requests") - def test_refresh_zimfarm_token(self, mock_requests): - redis = MagicMock() + @patch("wp1.zimfarm.get_zimfarm_url") + def test_token_provider_generate_local_access_token_no_refresh( + self, mock_get_url, mock_requests, mock_credentials + ): + mock_credentials.__getitem__.return_value = { + "ZIMFARM": { + "auth_mode": "local", + "user": "test_user", + "password": "test_pass", + "requests_timeout": 30, + } + } + mock_get_url.return_value = "https://fake.farm/v2" + mock_response = MagicMock() - mock_response.json.return_value = {"access_token": "abcdef"} + mock_response.json.return_value = { + "access_token": "local_token_123", + "refresh_token": "refresh_token_123", + "expires_time": "2023-01-01T12:00:00Z", + } mock_requests.post.return_value = mock_response - refresh_token = "12345" - actual = zimfarm.refresh_zimfarm_token(redis, refresh_token) + provider = ZimfarmClientTokenProvider() + provider._generate_local_access_token() - self.assertEqual("abcdef", actual) + self.assertEqual(provider._access_token, "local_token_123") + self.assertEqual(provider._refresh_token, "refresh_token_123") + self.assertEqual(provider._expires_at, datetime.datetime(2023, 1, 1, 12, 0, 0)) + @patch("wp1.zimfarm.CREDENTIALS") @patch("wp1.zimfarm.requests") - def test_refresh_zimfarm_token_posts_with_correct_data(self, mock_requests): - redis = MagicMock() + @patch("wp1.zimfarm.get_zimfarm_url") + def test_token_provider_generate_local_access_token_with_refresh( + self, mock_get_url, mock_requests, mock_credentials + ): + mock_credentials.__getitem__.return_value = { + "ZIMFARM": { + "auth_mode": "local", + "user": "test_user", + "password": "test_pass", + "requests_timeout": 30, + } + } + mock_get_url.return_value = "https://fake.farm/v2" + mock_response = MagicMock() - mock_response.json.return_value = {"access_token": "abcdef"} + mock_response.json.return_value = { + "access_token": "new_token_456", + "refresh_token": "new_refresh_456", + "expires_time": "2023-01-01T13:00:00Z", + } mock_requests.post.return_value = mock_response - refresh_token = "12345" - zimfarm.refresh_zimfarm_token(redis, refresh_token) + provider = ZimfarmClientTokenProvider() + provider._refresh_token = "old_refresh_token" + provider._generate_local_access_token() + self.assertEqual(provider._access_token, "new_token_456") + self.assertEqual(provider._refresh_token, "new_refresh_456") mock_requests.post.assert_called_once_with( "https://fake.farm/v2/auth/refresh", - headers={ - "User-Agent": "WP 1.0 bot 1.0.0/Audiodude ", - }, - json={"refresh_token": "12345"}, + json={"refresh_token": "old_refresh_token"}, + timeout=30, + headers={"User-Agent": "WP 1.0 bot 1.0.0/Audiodude "}, ) + @patch("wp1.zimfarm.CREDENTIALS") @patch("wp1.zimfarm.requests") - def test_refresh_zimfarm_token_raises_for_status(self, mock_requests): - redis = MagicMock() + @patch("wp1.zimfarm.get_zimfarm_url") + def test_token_provider_generate_local_access_token_http_error( + self, mock_get_url, mock_requests, mock_credentials + ): + mock_credentials.__getitem__.return_value = { + "ZIMFARM": { + "auth_mode": "local", + "user": "test_user", + "password": "test_pass", + } + } + mock_get_url.return_value = "https://fake.farm/v2" + mock_response = MagicMock() mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError - mock_requests.exceptions.HTTPError = requests.exceptions.HTTPError mock_requests.post.return_value = mock_response + mock_requests.exceptions.HTTPError = requests.exceptions.HTTPError + + provider = ZimfarmClientTokenProvider() with self.assertRaises(ZimFarmError): - zimfarm.refresh_zimfarm_token(redis, "12345") + provider._generate_local_access_token() - @patch("wp1.zimfarm.CREDENTIALS", {Environment.TEST: {}}) - def test_refresh_zimfarm_token_no_creds(self): - redis = MagicMock() + @patch("wp1.zimfarm.CREDENTIALS") + @patch("wp1.zimfarm.naive_utcnow") + def test_token_provider_get_access_token_not_expired( + self, mock_naive_utcnow, mock_credentials + ): + mock_credentials.__getitem__.return_value = { + "ZIMFARM": { + "auth_mode": "local", + "user": "test_user", + "password": "test_pass", + "token_renewal_window": 300, + } + } - with self.assertRaises(ZimFarmError): - zimfarm.refresh_zimfarm_token(redis, "12345") + mock_naive_utcnow.return_value = datetime.datetime( + 2023, 1, 1, 11, 50, 0, tzinfo=None + ) - @patch("wp1.zimfarm.request_zimfarm_token") - def test_get_zimfarm_token_no_data(self, request_token_mock): redis = MagicMock() - redis.hgetall.return_value = None + redis.hgetall.return_value = { + "access_token": "existing_token", + "refresh_token": "existing_refresh", + "expires_at": "2023-01-01T12:00:00Z", + } + + provider = ZimfarmClientTokenProvider() + token = provider.get_access_token(redis) - request_token_mock.return_value = "bcdefg" - actual = zimfarm.get_zimfarm_token(redis) - self.assertEqual(actual, "bcdefg") + self.assertEqual(token, "existing_token") + redis.hset.assert_not_called() + + @patch("wp1.zimfarm.CREDENTIALS") + @patch("wp1.zimfarm.naive_utcnow") + @patch("wp1.zimfarm.requests") + def test_token_provider_get_access_token_expired_oauth( + self, mock_requests, mock_naive_utcnow, mock_credentials + ): + mock_credentials.__getitem__.return_value = { + "ZIMFARM": { + "auth_mode": "oauth", + "oauth_client_id": "test_client_id", + "oauth_client_secret": "test_client_secret", + "oauth_audience_id": "test_audience", + "oauth_issuer": "https://oauth.example.com", + "token_renewal_window": 300, + } + } + + mock_naive_utcnow.return_value = datetime.datetime( + 2023, 1, 1, 12, 0, 0, tzinfo=None + ) - @patch("wp1.zimfarm.request_zimfarm_token") - def test_get_zimfarm_token_no_refresh_token(self, request_token_mock): redis = MagicMock() redis.hgetall.return_value = { - "expires_time": "2023-01-01T00:00:01Z", - "access_token": "abcdef", + "access_token": "old_token", + "refresh_token": "", + "expires_at": "2023-01-01T11:50:00Z", # Expired } - request_token_mock.return_value = "bcdefg" - actual = zimfarm.get_zimfarm_token(redis) - self.assertEqual(actual, "bcdefg") + mock_response = MagicMock() + mock_response.json.return_value = { + "access_token": "new_oauth_token", + "expires_in": 3600, + } + mock_requests.post.return_value = mock_response + + provider = ZimfarmClientTokenProvider() + token = provider.get_access_token(redis) + + self.assertEqual(token, "new_oauth_token") + redis.hset.assert_called_once() + + @patch("wp1.zimfarm.CREDENTIALS") + @patch("wp1.zimfarm.naive_utcnow") + @patch("wp1.zimfarm.requests") + @patch("wp1.zimfarm.get_zimfarm_url") + def test_token_provider_get_access_token_expired_local( + self, mock_get_url, mock_requests, mock_naive_utcnow, mock_credentials + ): + """Test get_access_token refreshes token when expired (local mode)""" + mock_credentials.__getitem__.return_value = { + "ZIMFARM": { + "auth_mode": "local", + "user": "test_user", + "password": "test_pass", + "token_renewal_window": 300, + } + } + mock_get_url.return_value = "https://fake.farm/v2" - @patch("wp1.zimfarm.get_current_datetime") - def test_get_zimfarm_token_access_token_not_expired(self, current_datetime_mock): - current_datetime_mock.return_value = datetime.datetime(2022, 12, 25, 5, 5, 55) + mock_naive_utcnow.return_value = datetime.datetime( + 2023, 1, 1, 12, 0, 0, tzinfo=None + ) redis = MagicMock() redis.hgetall.return_value = { - "expires_time": "2023-01-01T00:00:01Z", - "refresh_token": "12345", - "access_token": "abcdef", + "access_token": "old_token", + "refresh_token": "old_refresh", + "expires_at": "2023-01-01T11:50:00Z", # Expired } - actual = zimfarm.get_zimfarm_token(redis) + mock_response = MagicMock() + mock_response.json.return_value = { + "access_token": "new_local_token", + "refresh_token": "new_refresh", + "expires_time": "2023-01-01T13:00:00Z", + } + mock_requests.post.return_value = mock_response + + provider = ZimfarmClientTokenProvider() + token = provider.get_access_token(redis) - self.assertEqual(actual, "abcdef") + self.assertEqual(token, "new_local_token") + redis.hset.assert_called_once() - @patch("wp1.zimfarm.get_current_datetime") - @patch("wp1.zimfarm.refresh_zimfarm_token") - def test_get_zimfarm_token_access_token_expired( - self, refresh_token_mock, current_datetime_mock + @patch("wp1.zimfarm.CREDENTIALS") + @patch("wp1.zimfarm.naive_utcnow") + @patch("wp1.zimfarm.requests") + def test_token_provider_get_access_token_no_redis_data_oauth( + self, mock_requests, mock_naive_utcnow, mock_credentials ): - current_datetime_mock.return_value = datetime.datetime(2022, 12, 25, 5, 5, 55) - refresh_token_mock.return_value = "bcdefg" + mock_credentials.__getitem__.return_value = { + "ZIMFARM": { + "auth_mode": "oauth", + "oauth_client_id": "test_client_id", + "oauth_client_secret": "test_client_secret", + "oauth_audience_id": "test_audience", + "oauth_issuer": "https://oauth.example.com", + } + } + + mock_naive_utcnow.return_value = datetime.datetime( + 2023, 1, 1, 12, 0, 0, tzinfo=None + ) redis = MagicMock() - redis.hgetall.return_value = { - "expires_time": "2022-12-01T00:00:01Z", - "refresh_token": "12345", - "access_token": "abcdef", + redis.hgetall.return_value = None + + mock_response = MagicMock() + mock_response.json.return_value = { + "access_token": "fresh_oauth_token", + "expires_in": 3600, + } + mock_requests.post.return_value = mock_response + + provider = ZimfarmClientTokenProvider() + token = provider.get_access_token(redis) + + self.assertEqual(token, "fresh_oauth_token") + redis.hset.assert_called_once() + + @patch("wp1.zimfarm.CREDENTIALS") + @patch("wp1.zimfarm.naive_utcnow") + @patch("wp1.zimfarm.requests") + @patch("wp1.zimfarm.get_zimfarm_url") + def test_token_provider_get_access_token_no_redis_data_local( + self, mock_get_url, mock_requests, mock_naive_utcnow, mock_credentials + ): + mock_credentials.__getitem__.return_value = { + "ZIMFARM": { + "auth_mode": "local", + "user": "test_user", + "password": "test_pass", + } + } + mock_get_url.return_value = "https://fake.farm/v2" + + mock_naive_utcnow.return_value = datetime.datetime( + 2023, 1, 1, 12, 0, 0, tzinfo=None + ) + + redis = MagicMock() + redis.hgetall.return_value = None + + mock_response = MagicMock() + mock_response.json.return_value = { + "access_token": "fresh_local_token", + "refresh_token": "fresh_refresh", + "expires_time": "2023-01-01T13:00:00Z", + } + mock_requests.post.return_value = mock_response + + provider = ZimfarmClientTokenProvider() + token = provider.get_access_token(redis) + + self.assertEqual(token, "fresh_local_token") + redis.hset.assert_called_once() + + @patch("wp1.zimfarm.CREDENTIALS") + @patch("wp1.zimfarm.naive_utcnow") + def test_token_provider_get_access_token_stores_in_redis( + self, mock_naive_utcnow, mock_credentials + ): + mock_credentials.__getitem__.return_value = { + "ZIMFARM": { + "auth_mode": "oauth", + "oauth_client_id": "test_client_id", + "oauth_client_secret": "test_client_secret", + "oauth_audience_id": "test_audience", + "oauth_issuer": "https://oauth.example.com", + } } - actual = zimfarm.get_zimfarm_token(redis) + mock_naive_utcnow.return_value = datetime.datetime( + 2023, 1, 1, 12, 0, 0, tzinfo=None + ) - self.assertEqual(actual, "bcdefg") + redis = MagicMock() + redis.hgetall.return_value = None - @patch("wp1.zimfarm.get_zimfarm_token") - def test_create_or_update_zimfarm_schedule_missing_token(self, get_token_mock): + with patch("wp1.zimfarm.requests") as mock_requests: + mock_response = MagicMock() + mock_response.json.return_value = { + "access_token": "token_to_store", + "expires_in": 3600, + } + mock_requests.post.return_value = mock_response + + provider = ZimfarmClientTokenProvider() + provider.get_access_token(redis) + + redis.hset.assert_called_once() + call_args = redis.hset.call_args + self.assertEqual(call_args[0][0], "zimfarm.auth") + self.assertIn("access_token", call_args[1]["mapping"]) + self.assertEqual(call_args[1]["mapping"]["access_token"], "token_to_store") + + @patch("wp1.zimfarm.token_provider") + def test_create_or_update_zimfarm_schedule_missing_token(self, mock_token_provider): redis = MagicMock() - get_token_mock.return_value = None + mock_token_provider.get_access_token.side_effect = ZimFarmError( + "Failed to generate access token." + ) with self.assertRaises(ZimFarmError): zimfarm.create_or_update_zimfarm_schedule( @@ -400,10 +702,12 @@ def test_create_or_update_zimfarm_schedule_missing_token(self, get_token_mock): long_description=None, ) - @patch("wp1.zimfarm.get_zimfarm_token") - def test_create_or_update_zimfarm_schedule_missing_builder(self, get_token_mock): + @patch("wp1.zimfarm.token_provider") + def test_create_or_update_zimfarm_schedule_missing_builder( + self, mock_token_provider + ): redis = MagicMock() - get_token_mock.return_value = "test-token" + mock_token_provider.get_access_token.return_value = "test-token" with self.assertRaises(ObjectNotFoundError): zimfarm.create_or_update_zimfarm_schedule( @@ -416,15 +720,15 @@ def test_create_or_update_zimfarm_schedule_missing_builder(self, get_token_mock) ) @patch("wp1.zimfarm.requests") - @patch("wp1.zimfarm.get_zimfarm_token") + @patch("wp1.zimfarm.token_provider") @patch("wp1.zimfarm._get_params") def test_create_or_update_zimfarm_schedule_creates( - self, get_params_mock, get_token_mock, mock_requests + self, get_params_mock, mock_token_provider, mock_requests ): """Test creating a new schedule when no existing schedule is found""" redis = MagicMock() get_params_mock.return_value = {"name": "bar"} - get_token_mock.return_value = "abcdef" + mock_token_provider.get_access_token.return_value = "abcdef" mock_response = MagicMock() mock_response.json.return_value = {"requested": ["9876"]} mock_requests.post.side_effect = (MagicMock(), mock_response) @@ -459,16 +763,20 @@ def test_create_or_update_zimfarm_schedule_creates( self.assertEqual(long_desc.encode("utf-8"), result["s_long_description"]) @patch("wp1.zimfarm.requests") - @patch("wp1.zimfarm.get_zimfarm_token") + @patch("wp1.zimfarm.token_provider") @patch("wp1.zimfarm._get_params") @patch("wp1.zimfarm.zimfarm_schedule_exists", return_value=True) def test_create_or_update_zimfarm_schedule_updates( - self, get_params_mock, get_token_mock, mock_requests, zimfarm_schedule_exists + self, + zimfarm_schedule_exists, + get_params_mock, + mock_token_provider, + mock_requests, ): """Test that an existing schedule is updated and persisted in the DB.""" redis = MagicMock() get_params_mock.return_value = {"name": "bar"} - get_token_mock.return_value = "abcdef" + mock_token_provider.get_access_token.return_value = "abcdef" # Insert an existing schedule schedule_id = self._insert_zim_schedule( schedule_id=b"existing-123", @@ -481,7 +789,8 @@ def test_create_or_update_zimfarm_schedule_updates( mock_response = MagicMock() mock_response.json.return_value = {"requested": ["9876"]} - mock_requests.post.side_effect = (MagicMock(), mock_response) + mock_requests.patch.side_effect = (MagicMock(),) + mock_requests.post.side_effect = (mock_response,) # Call the function zimfarm.create_or_update_zimfarm_schedule( @@ -506,14 +815,14 @@ def test_create_or_update_zimfarm_schedule_updates( self.assertIsNone(result["s_remaining_generations"]) @patch("wp1.zimfarm.requests") - @patch("wp1.zimfarm.get_zimfarm_token") + @patch("wp1.zimfarm.token_provider") @patch("wp1.zimfarm._get_params") def test_create_or_update_zimfarm_schedule_http_error( - self, get_params_mock, get_token_mock, mock_requests + self, get_params_mock, mock_token_provider, mock_requests ): redis = MagicMock() get_params_mock.return_value = {"name": "bar"} - get_token_mock.return_value = "abcdef" + mock_token_provider.get_access_token.return_value = "abcdef" create_schedule_response = MagicMock() create_schedule_response.raise_for_status.side_effect = ( requests.exceptions.HTTPError @@ -531,10 +840,12 @@ def test_create_or_update_zimfarm_schedule_http_error( None, ) - @patch("wp1.zimfarm.get_zimfarm_token") - def test_create_or_update_zimfarm_schedule_too_long_title(self, get_token_mock): + @patch("wp1.zimfarm.token_provider") + def test_create_or_update_zimfarm_schedule_too_long_title( + self, mock_token_provider + ): redis = MagicMock() - get_token_mock.return_value = "test-token" + mock_token_provider.get_access_token.return_value = "test-token" wrong_title = "a" * (ZIM_TITLE_MAX_LENGTH + 1) with self.assertRaises(InvalidZimTitleError): @@ -542,12 +853,12 @@ def test_create_or_update_zimfarm_schedule_too_long_title(self, get_token_mock): redis, self.wp10db, self.builder, wrong_title, "Test Description", None ) - @patch("wp1.zimfarm.get_zimfarm_token") + @patch("wp1.zimfarm.token_provider") def test_create_or_update_zimfarm_schedule_too_long_description( - self, get_token_mock + self, mock_token_provider ): redis = MagicMock() - get_token_mock.return_value = "test-token" + mock_token_provider.get_access_token.return_value = "test-token" too_long_description = "z" * (ZIM_DESCRIPTION_MAX_LENGTH + 1) with self.assertRaises(InvalidZimDescriptionError): @@ -560,12 +871,12 @@ def test_create_or_update_zimfarm_schedule_too_long_description( None, ) - @patch("wp1.zimfarm.get_zimfarm_token") + @patch("wp1.zimfarm.token_provider") def test_create_or_update_zimfarm_schedule_too_long_long_description( - self, get_token_mock + self, mock_token_provider ): redis = MagicMock() - get_token_mock.return_value = "test-token" + mock_token_provider.get_access_token.return_value = "test-token" too_long_long_description = "z" * (ZIM_LONG_DESCRIPTION_MAX_LENGTH + 1) with self.assertRaises(InvalidZimLongDescriptionError): @@ -578,12 +889,12 @@ def test_create_or_update_zimfarm_schedule_too_long_long_description( long_description=too_long_long_description, ) - @patch("wp1.zimfarm.get_zimfarm_token") + @patch("wp1.zimfarm.token_provider") def test_create_or_update_zimfarm_schedule_too_short_long_description( - self, get_token_mock + self, mock_token_provider ): redis = MagicMock() - get_token_mock.return_value = "test-token" + mock_token_provider.get_access_token.return_value = "test-token" with self.assertRaises(InvalidZimLongDescriptionError): zimfarm.create_or_update_zimfarm_schedule( @@ -596,15 +907,15 @@ def test_create_or_update_zimfarm_schedule_too_short_long_description( ) @patch("wp1.zimfarm.requests") - @patch("wp1.zimfarm.get_zimfarm_token") + @patch("wp1.zimfarm.token_provider") @patch("wp1.zimfarm._get_params") def test_create_or_update_zimfarm_schedule_create_empty_long_desc_ok( - self, get_params_mock, get_token_mock, mock_requests + self, get_params_mock, mock_token_provider, mock_requests ): """Test creating a new schedule when no existing schedule is found""" redis = MagicMock() get_params_mock.return_value = {"name": "bar"} - get_token_mock.return_value = "abcdef" + mock_token_provider.get_access_token.return_value = "abcdef" zimfarm.create_or_update_zimfarm_schedule( redis, @@ -626,15 +937,15 @@ def test_create_or_update_zimfarm_schedule_create_empty_long_desc_ok( self.assertEqual(None, result["s_long_description"]) @patch("wp1.zimfarm.requests") - @patch("wp1.zimfarm.get_zimfarm_token") + @patch("wp1.zimfarm.token_provider") @patch("wp1.zimfarm._get_params") def test_create_or_update_zimfarm_schedule_create_missing_long_desc_ok( - self, get_params_mock, get_token_mock, mock_requests + self, get_params_mock, mock_token_provider, mock_requests ): """Test creating a new schedule when no existing schedule is found""" redis = MagicMock() get_params_mock.return_value = {"name": "bar"} - get_token_mock.return_value = "abcdef" + mock_token_provider.get_access_token.return_value = "abcdef" zimfarm.create_or_update_zimfarm_schedule( redis, @@ -655,10 +966,12 @@ def test_create_or_update_zimfarm_schedule_create_missing_long_desc_ok( self.assertEqual(b"Test Description", result["s_description"]) self.assertEqual(None, result["s_long_description"]) - @patch("wp1.zimfarm.get_zimfarm_token") - def test_create_or_update_zimfarm_schedule_equal_descriptions(self, get_token_mock): + @patch("wp1.zimfarm.token_provider") + def test_create_or_update_zimfarm_schedule_equal_descriptions( + self, mock_token_provider + ): redis = MagicMock() - get_token_mock.return_value = "test-token" + mock_token_provider.get_access_token.return_value = "test-token" with self.assertRaises(InvalidZimLongDescriptionError): zimfarm.create_or_update_zimfarm_schedule( @@ -670,13 +983,13 @@ def test_create_or_update_zimfarm_schedule_equal_descriptions(self, get_token_mo long_description="Same description", ) - @patch("wp1.zimfarm.get_zimfarm_token") + @patch("wp1.zimfarm.token_provider") @patch("wp1.logic.builder.latest_selection_for") def test_create_or_update_zimfarm_schedule_too_many_articles( - self, mock_latest_selection, get_token_mock + self, mock_latest_selection, mock_token_provider ): redis = MagicMock() - get_token_mock.return_value = "test-token" + mock_token_provider.get_access_token.return_value = "test-token" mock_selection = MagicMock() mock_selection.s_article_count = 60000 # Above MAX_ZIMFARM_ARTICLE_COUNT @@ -692,13 +1005,13 @@ def test_create_or_update_zimfarm_schedule_too_many_articles( None, ) - @patch("wp1.zimfarm.get_zimfarm_token") + @patch("wp1.zimfarm.token_provider") @patch("wp1.logic.builder.latest_selection_for") def test_create_or_update_zimfarm_schedule_none_article_count( - self, mock_latest_selection, get_token_mock + self, mock_latest_selection, mock_token_provider ): redis = MagicMock() - get_token_mock.return_value = "test-token" + mock_token_provider.get_access_token.return_value = "test-token" mock_selection = MagicMock() mock_selection.s_article_count = None @@ -715,14 +1028,14 @@ def test_create_or_update_zimfarm_schedule_none_article_count( ) @patch("wp1.zimfarm.requests") - @patch("wp1.zimfarm.get_zimfarm_token") + @patch("wp1.zimfarm.token_provider") @patch("wp1.zimfarm._get_params") def test_create_or_update_zimfarm_schedule_valid_graphemes( - self, get_params_mock, get_token_mock, mock_requests + self, get_params_mock, mock_token_provider, mock_requests ): redis = MagicMock() get_params_mock.return_value = {"name": "bar"} - get_token_mock.return_value = "abcdef" + mock_token_provider.get_access_token.return_value = "abcdef" mock_response = MagicMock() mock_response.json.return_value = {"requested": ["9876"]} mock_requests.post.side_effect = (MagicMock(), mock_response, MagicMock()) @@ -738,23 +1051,25 @@ def test_create_or_update_zimfarm_schedule_valid_graphemes( ) mock_requests.post.assert_called_once() - @patch("wp1.zimfarm.get_zimfarm_token") - def test_request_zimfarm_task_missing_token(self, get_token_mock): + @patch("wp1.zimfarm.token_provider") + def test_request_zimfarm_task_missing_token(self, mock_token_provider): redis = MagicMock() - get_token_mock.return_value = None + mock_token_provider.get_access_token.side_effect = ZimFarmError( + "Failed to generate access token." + ) with self.assertRaises(ZimFarmError): zimfarm.request_zimfarm_task(redis, self.wp10db, self.builder) @patch("wp1.zimfarm.requests") - @patch("wp1.zimfarm.get_zimfarm_token") + @patch("wp1.zimfarm.token_provider") @patch("wp1.zimfarm.get_zimfarm_schedule_name") def test_request_zimfarm_task( - self, get_zimfarm_schedule_name_mock, get_token_mock, mock_requests + self, get_zimfarm_schedule_name_mock, mock_token_provider, mock_requests ): redis = MagicMock() get_zimfarm_schedule_name_mock.return_value = "bar" - get_token_mock.return_value = "abcdef" + mock_token_provider.get_access_token.return_value = "abcdef" mock_response = MagicMock() mock_response.json.return_value = {"requested": ["9876"]} mock_requests.post.return_value = mock_response @@ -764,22 +1079,24 @@ def test_request_zimfarm_task( self.assertEqual("9876", actual) @patch("wp1.zimfarm.requests") - @patch("wp1.zimfarm.get_zimfarm_token") - def test_request_zimfarm_task_missing_builder(self, get_token_mock, mock_requests): + @patch("wp1.zimfarm.token_provider") + def test_request_zimfarm_task_missing_builder( + self, mock_token_provider, mock_requests + ): redis = MagicMock() with self.assertRaises(ObjectNotFoundError): zimfarm.request_zimfarm_task(redis, self.wp10db, None) @patch("wp1.zimfarm.requests") - @patch("wp1.zimfarm.get_zimfarm_token") + @patch("wp1.zimfarm.token_provider") @patch("wp1.zimfarm.get_zimfarm_schedule_name") def test_request_zimfarm_task_post_requests( - self, get_zimfarm_schedule_name_mock, get_token_mock, mock_requests + self, get_zimfarm_schedule_name_mock, mock_token_provider, mock_requests ): redis = MagicMock() get_zimfarm_schedule_name_mock.return_value = "bar" - get_token_mock.return_value = "abcdef" + mock_token_provider.get_access_token.return_value = "abcdef" mock_response = MagicMock() mock_response.json.return_value = {"requested": ["9876"]} mock_requests.post.side_effect = (MagicMock(), mock_response) @@ -796,10 +1113,10 @@ def test_request_zimfarm_task_post_requests( ) @patch("wp1.zimfarm.requests") - @patch("wp1.zimfarm.get_zimfarm_token") + @patch("wp1.zimfarm.token_provider") @patch("wp1.zimfarm.get_zimfarm_schedule_name") def test_request_zimfarm_task_missing_task_id( - self, get_zimfarm_schedule_name_mock, get_token_mock, mock_requests + self, get_zimfarm_schedule_name_mock, mock_token_provider, mock_requests ): redis = MagicMock() get_zimfarm_schedule_name_mock.return_value = "bar" @@ -812,10 +1129,10 @@ def test_request_zimfarm_task_missing_task_id( zimfarm.request_zimfarm_task(redis, self.wp10db, self.builder) @patch("wp1.zimfarm.requests") - @patch("wp1.zimfarm.get_zimfarm_token") + @patch("wp1.zimfarm.token_provider") @patch("wp1.zimfarm.get_zimfarm_schedule_name") def test_request_zimfarm_task_missing_article_count( - self, get_zimfarm_schedule_name_mock, get_token_mock, mock_requests + self, get_zimfarm_schedule_name_mock, mock_token_provider, mock_requests ): redis = MagicMock() get_zimfarm_schedule_name_mock.return_value = "bar" @@ -828,10 +1145,10 @@ def test_request_zimfarm_task_missing_article_count( zimfarm.request_zimfarm_task(redis, self.wp10db, self.builder) @patch("wp1.zimfarm.requests") - @patch("wp1.zimfarm.get_zimfarm_token") + @patch("wp1.zimfarm.token_provider") @patch("wp1.zimfarm.get_zimfarm_schedule_name") def test_request_zimfarm_task_too_many_articles( - self, get_zimfarm_schedule_name_mock, get_zimfarm_token_mock, mock_requests + self, get_zimfarm_schedule_name_mock, mock_token_provider, mock_requests ): redis = MagicMock() get_zimfarm_schedule_name_mock.return_value = "bar" @@ -901,10 +1218,10 @@ def test_zim_file_url_for_task_id_missing_s3_url(self, patched_get): with self.assertRaises(ZimFarmError): zimfarm.zim_file_url_for_task_id("foo-bar") - @patch("wp1.zimfarm.get_zimfarm_token") + @patch("wp1.zimfarm.token_provider") @patch("wp1.zimfarm.requests.delete") - def test_cancel_zim_by_task_id(self, patched_delete, patched_get_zimfarm_token): - patched_get_zimfarm_token.return_value = "foo-token" + def test_cancel_zim_by_task_id(self, patched_delete, mock_token_provider): + mock_token_provider.get_access_token.return_value = "foo-token" redis = MagicMock() zimfarm.cancel_zim_by_task_id(redis, "task-abc-123") @@ -916,13 +1233,13 @@ def test_cancel_zim_by_task_id(self, patched_delete, patched_get_zimfarm_token): }, ) - @patch("wp1.zimfarm.get_zimfarm_token") + @patch("wp1.zimfarm.token_provider") @patch("wp1.zimfarm.requests.delete") @patch("wp1.zimfarm.requests.post") def test_cancel_zim_by_task_id_first_delete_404( - self, patched_post, patched_delete, patched_get_zimfarm_token + self, patched_post, patched_delete, mock_token_provider ): - patched_get_zimfarm_token.return_value = "foo-token" + mock_token_provider.get_access_token.return_value = "foo-token" redis = MagicMock() response_404 = MagicMock() response_404.status_code = 404 @@ -945,13 +1262,13 @@ def test_cancel_zim_by_task_id_first_delete_404( }, ) - @patch("wp1.zimfarm.get_zimfarm_token") + @patch("wp1.zimfarm.token_provider") @patch("wp1.zimfarm.requests.delete") @patch("wp1.zimfarm.requests.post") def test_cancel_zim_by_task_id_first_delete_other_error( - self, patched_post, patched_delete, patched_get_zimfarm_token + self, patched_post, patched_delete, mock_token_provider ): - patched_get_zimfarm_token.return_value = "foo-token" + mock_token_provider.get_access_token.return_value = "foo-token" redis = MagicMock() response_500 = MagicMock() response_500.status_code = 500 @@ -961,13 +1278,13 @@ def test_cancel_zim_by_task_id_first_delete_other_error( with self.assertRaises(ZimFarmError): zimfarm.cancel_zim_by_task_id(redis, "task-abc-123") - @patch("wp1.zimfarm.get_zimfarm_token") + @patch("wp1.zimfarm.token_provider") @patch("wp1.zimfarm.requests.delete") @patch("wp1.zimfarm.requests.post") def test_cancel_zim_by_task_id_second_delete_error( - self, patched_post, patched_delete, patched_get_zimfarm_token + self, patched_post, patched_delete, mock_token_provider ): - patched_get_zimfarm_token.return_value = "foo-token" + mock_token_provider.get_access_token.return_value = "foo-token" redis = MagicMock() response_404 = MagicMock() response_404.status_code = 404 @@ -985,11 +1302,11 @@ def test_cancel_zim_by_task_id_second_delete_error( with self.assertRaises(ZimFarmError): zimfarm.cancel_zim_by_task_id(redis, "task-abc-123") - @patch("wp1.zimfarm.get_zimfarm_token") + @patch("wp1.zimfarm.token_provider") @patch("wp1.zimfarm.requests") - def test_zimfarm_schedule_exists_true(self, mock_requests, mock_get_token): + def test_zimfarm_schedule_exists_true(self, mock_requests, mock_token_provider): """Test zimfarm_schedule_exists returns True when schedule exists (200 status)""" - mock_get_token.return_value = "test-token" + mock_token_provider.get_access_token.return_value = "test-token" mock_response = MagicMock() mock_response.status_code = 200 mock_requests.get.return_value = mock_response @@ -1008,11 +1325,11 @@ def test_zimfarm_schedule_exists_true(self, mock_requests, mock_get_token): }, ) - @patch("wp1.zimfarm.get_zimfarm_token") + @patch("wp1.zimfarm.token_provider") @patch("wp1.zimfarm.requests") - def test_zimfarm_schedule_exists_false(self, mock_requests, mock_get_token): + def test_zimfarm_schedule_exists_false(self, mock_requests, mock_token_provider): """Test zimfarm_schedule_exists returns False when schedule doesn't exist (404 status)""" - mock_get_token.return_value = "test-token" + mock_token_provider.get_access_token.return_value = "test-token" mock_response = MagicMock() mock_response.status_code = 404 mock_requests.get.return_value = mock_response @@ -1024,24 +1341,26 @@ def test_zimfarm_schedule_exists_false(self, mock_requests, mock_get_token): self.assertFalse(result) - @patch("wp1.zimfarm.get_zimfarm_token") - def test_zimfarm_schedule_exists_no_token(self, mock_get_token): + @patch("wp1.zimfarm.token_provider") + def test_zimfarm_schedule_exists_no_token(self, mock_token_provider): """Test zimfarm_schedule_exists raises error when token is None""" - mock_get_token.return_value = None + mock_token_provider.get_access_token.side_effect = ZimFarmError( + "Failed to generate access token." + ) redis = MagicMock() builder_id = "1a-2b-3c-4d" - with self.assertRaises(ZimFarmError) as cm: + with self.assertRaises(ZimFarmError): zimfarm.zimfarm_schedule_exists(redis, builder_id) - self.assertEqual(str(cm.exception), "Error retrieving auth token for request") - - @patch("wp1.zimfarm.get_zimfarm_token") + @patch("wp1.zimfarm.token_provider") @patch("wp1.zimfarm.requests") - def test_zimfarm_schedule_exists_http_error(self, mock_requests, mock_get_token): + def test_zimfarm_schedule_exists_http_error( + self, mock_requests, mock_token_provider + ): """Test zimfarm_schedule_exists raises error when HTTP request fails""" - mock_get_token.return_value = "test-token" + mock_token_provider.get_access_token.return_value = "test-token" mock_response = MagicMock() mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError mock_requests.get.return_value = mock_response @@ -1111,12 +1430,12 @@ def test_find_existing_schedule_in_db_no_schedules(self): result = zimfarm.find_existing_schedule_in_db(self.wp10db, self.builder.b_id) self.assertIsNone(result) - @patch("wp1.zimfarm.get_zimfarm_token") + @patch("wp1.zimfarm.token_provider") @patch("wp1.zimfarm.requests.delete") def test_delete_zimfarm_schedule_by_builder_id_success( - self, patched_delete, patched_get_zimfarm_token + self, patched_delete, mock_token_provider ): - patched_get_zimfarm_token.return_value = "foo-token" + mock_token_provider.get_access_token.return_value = "foo-token" redis = MagicMock() response = MagicMock() response.status_code = 204 @@ -1134,12 +1453,12 @@ def test_delete_zimfarm_schedule_by_builder_id_success( }, ) - @patch("wp1.zimfarm.get_zimfarm_token") + @patch("wp1.zimfarm.token_provider") @patch("wp1.zimfarm.requests.delete") def test_delete_zimfarm_schedule_by_builder_id_not_found( - self, patched_delete, patched_get_zimfarm_token + self, patched_delete, mock_token_provider ): - patched_get_zimfarm_token.return_value = "foo-token" + mock_token_provider.get_access_token.return_value = "foo-token" redis = MagicMock() response = MagicMock() response.status_code = 404 @@ -1157,12 +1476,12 @@ def test_delete_zimfarm_schedule_by_builder_id_not_found( }, ) - @patch("wp1.zimfarm.get_zimfarm_token") + @patch("wp1.zimfarm.token_provider") @patch("wp1.zimfarm.requests.delete") def test_delete_zimfarm_schedule_by_builder_id_error( - self, patched_delete, patched_get_zimfarm_token + self, patched_delete, mock_token_provider ): - patched_get_zimfarm_token.return_value = "foo-token" + mock_token_provider.get_access_token.return_value = "foo-token" redis = MagicMock() response = MagicMock() response.status_code = 500 @@ -1172,11 +1491,11 @@ def test_delete_zimfarm_schedule_by_builder_id_error( with self.assertRaises(ZimFarmError): zimfarm.delete_zimfarm_schedule_by_builder_id(redis, "1a-2b-3c-4d") - @patch("wp1.zimfarm.get_zimfarm_token") - def test_delete_zimfarm_schedule_by_builder_id_no_token( - self, patched_get_zimfarm_token - ): - patched_get_zimfarm_token.return_value = None + @patch("wp1.zimfarm.token_provider") + def test_delete_zimfarm_schedule_by_builder_id_no_token(self, mock_token_provider): + mock_token_provider.get_access_token.side_effect = ZimFarmError( + "Failed to generate access token." + ) redis = MagicMock() with self.assertRaises(ZimFarmError):