diff --git a/src/landingai_ade/_client.py b/src/landingai_ade/_client.py index fba7be3..562fee1 100644 --- a/src/landingai_ade/_client.py +++ b/src/landingai_ade/_client.py @@ -85,8 +85,12 @@ def _get_input_filename( ) -> str: """Extract base filename (without extension) from file or URL input.""" if file_input is not None and not isinstance(file_input, Omit): - if isinstance(file_input, (Path, str)): + if isinstance(file_input, (Path, os.PathLike)): return Path(file_input).stem + elif isinstance(file_input, str): + # Strings are always treated as raw content, not file paths. + # File inputs should use Path objects, tuples, or IO objects. + pass elif isinstance(file_input, tuple) and len(file_input) > 0: # Tuple format: (filename, content, mime_type) return Path(str(file_input[0])).stem @@ -111,12 +115,25 @@ def _save_response( method_name: str, result: Any, ) -> None: - """Save API response to a JSON file in the specified folder.""" + """Save API response to a JSON file. + + If save_to ends with '.json', it is treated as a full file path and the + response is written there directly. Otherwise it is treated as a directory + and the file is auto-named '{filename}_{method_name}_output.json' + (or '{method_name}_output.json' when filename is 'output'). + """ try: - folder = Path(save_to) - folder.mkdir(parents=True, exist_ok=True) - output_path = folder / f"{filename}_{method_name}_output.json" - output_path.write_text(result.to_json()) + save_path = Path(save_to) + if str(save_to).endswith(".json"): + save_path.parent.mkdir(parents=True, exist_ok=True) + save_path.write_text(result.to_json()) + else: + save_path.mkdir(parents=True, exist_ok=True) + if filename == "output": + output_path = save_path / f"{method_name}_output.json" + else: + output_path = save_path / f"{filename}_{method_name}_output.json" + output_path.write_text(result.to_json()) except OSError as exc: raise LandingAiadeError(f"Failed to save {method_name} response to {save_to}: {exc}") from exc @@ -328,9 +345,10 @@ def extract( strict: If True, reject schemas with unsupported fields (HTTP 422). If False, prune unsupported fields and continue. Only applies to extract versions that support schema validation. - save_to: Optional output folder path. If provided, the response will be saved as - JSON to this folder with the filename format: {input_file}_extract_output.json. - The folder will be created if it doesn't exist. + save_to: Optional output path. If a directory, auto-generates the filename + (e.g. {input_file}_extract_output.json, or extract_output.json when no + input filename is available). If a full path ending in .json, saves there + directly. Parent directories are created automatically. extra_headers: Send extra headers @@ -429,9 +447,10 @@ def parse( parameter. Set the parameter to page to split documents at the page level. The splits object in the API output will contain a set of data for each page. - save_to: Optional output folder path. If provided, the response will be saved as - JSON to this folder with the filename format: {input_file}_parse_output.json. - The folder will be created if it doesn't exist. + save_to: Optional output path. If a directory, auto-generates the filename + (e.g. {input_file}_parse_output.json, or parse_output.json when no + input filename is available). If a full path ending in .json, saves there + directly. Parent directories are created automatically. extra_headers: Send extra headers @@ -518,9 +537,10 @@ def split( model: Model version to use for split classification. Defaults to the latest version. - save_to: Optional output folder path. If provided, the response will be saved as - JSON to this folder with the filename format: {input_file}_split_output.json. - The folder will be created if it doesn't exist. + save_to: Optional output path. If a directory, auto-generates the filename + (e.g. {input_file}_split_output.json, or split_output.json when no + input filename is available). If a full path ending in .json, saves there + directly. Parent directories are created automatically. extra_headers: Send extra headers @@ -768,6 +788,7 @@ async def extract( markdown_url: Optional[str] | Omit = omit, model: Optional[str] | Omit = omit, strict: bool | Omit = omit, + save_to: str | Path | None = None, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -801,6 +822,11 @@ async def extract( unsupported fields and continue. Only applies to extract versions that support schema validation. + save_to: Optional output path. If a directory, auto-generates the filename + (e.g. {input_file}_extract_output.json, or extract_output.json when no + input filename is available). If a full path ending in .json, saves there + directly. Parent directories are created automatically. + extra_headers: Send extra headers extra_query: Add additional query parameters to the request @@ -809,6 +835,9 @@ async def extract( timeout: Override the client-level default timeout for this request, in seconds """ + # Store original inputs for filename extraction before conversion + original_markdown = markdown + original_markdown_url = markdown_url # Convert local file paths to file parameters markdown, markdown_url = convert_url_to_file_if_local(markdown, markdown_url) @@ -830,7 +859,7 @@ async def extract( "runtime_tag": f"ade-python-v{_LIB_VERSION}", **(extra_headers or {}), } - return await self.post( + result = await self.post( "/v1/ade/extract", body=await async_maybe_transform(body, client_extract_params.ClientExtractParams), files=files, @@ -842,6 +871,10 @@ async def extract( ), cast_to=ExtractResponse, ) + if save_to: + filename = _get_input_filename(original_markdown, original_markdown_url) + _save_response(save_to, filename, "extract", result) + return result async def parse( self, @@ -852,6 +885,7 @@ async def parse( model: Optional[str] | Omit = omit, password: Optional[str] | Omit = omit, split: Optional[Literal["page"]] | Omit = omit, + save_to: str | Path | None = None, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -890,6 +924,11 @@ async def parse( parameter. Set the parameter to page to split documents at the page level. The splits object in the API output will contain a set of data for each page. + save_to: Optional output path. If a directory, auto-generates the filename + (e.g. {input_file}_parse_output.json, or parse_output.json when no + input filename is available). If a full path ending in .json, saves there + directly. Parent directories are created automatically. + extra_headers: Send extra headers extra_query: Add additional query parameters to the request @@ -898,6 +937,9 @@ async def parse( timeout: Override the client-level default timeout for this request, in seconds """ + # Store original inputs for filename extraction before conversion + original_document = document + original_document_url = document_url # Convert local file paths to file parameters document, document_url = convert_url_to_file_if_local(document, document_url) @@ -920,7 +962,7 @@ async def parse( "runtime_tag": f"ade-python-v{_LIB_VERSION}", **(extra_headers or {}), } - return await self.post( + result = await self.post( "/v1/ade/parse", body=await async_maybe_transform(body, client_parse_params.ClientParseParams), files=files, @@ -932,6 +974,10 @@ async def parse( ), cast_to=ParseResponse, ) + if save_to: + filename = _get_input_filename(original_document, original_document_url) + _save_response(save_to, filename, "parse", result) + return result async def split( self, @@ -940,6 +986,7 @@ async def split( markdown: Union[FileTypes, str, None] | Omit = omit, markdown_url: Optional[str] | Omit = omit, model: Optional[str] | Omit = omit, + save_to: str | Path | None = None, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -967,6 +1014,11 @@ async def split( model: Model version to use for split classification. Defaults to the latest version. + save_to: Optional output path. If a directory, auto-generates the filename + (e.g. {input_file}_split_output.json, or split_output.json when no + input filename is available). If a full path ending in .json, saves there + directly. Parent directories are created automatically. + extra_headers: Send extra headers extra_query: Add additional query parameters to the request @@ -975,6 +1027,9 @@ async def split( timeout: Override the client-level default timeout for this request, in seconds """ + # Store original inputs for filename extraction + original_markdown = markdown + original_markdown_url = markdown_url body = deepcopy_minimal( { "split_class": split_class, @@ -988,7 +1043,7 @@ async def split( # sent to the server will contain a `boundary` parameter, e.g. # multipart/form-data; boundary=---abc-- extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} - return await self.post( + result = await self.post( "/v1/ade/split", body=await async_maybe_transform(body, client_split_params.ClientSplitParams), files=files, @@ -997,6 +1052,10 @@ async def split( ), cast_to=SplitResponse, ) + if save_to: + filename = _get_input_filename(original_markdown, original_markdown_url) + _save_response(save_to, filename, "split", result) + return result @override def _make_status_error( diff --git a/tests/test_save_to.py b/tests/test_save_to.py index af5c00d..7b14259 100644 --- a/tests/test_save_to.py +++ b/tests/test_save_to.py @@ -8,6 +8,7 @@ import pytest +from landingai_ade import AsyncLandingAIADE from landingai_ade._client import _save_response, _get_input_filename from landingai_ade._exceptions import LandingAiadeError @@ -20,10 +21,15 @@ def test_path_input(self) -> None: result = _get_input_filename(Path("/path/to/document.pdf"), None) assert result == "document" - def test_string_path_input(self) -> None: - """Test with string path input.""" + def test_string_input_returns_default(self) -> None: + """Test that string inputs always return 'output' (strings are content, not paths).""" result = _get_input_filename("/path/to/document.pdf", None) - assert result == "document" + assert result == "output" + + def test_string_with_dots_returns_default(self) -> None: + """Test that strings containing periods return 'output'.""" + result = _get_input_filename("Visit example.com for details", None) + assert result == "output" def test_tuple_input(self) -> None: """Test with tuple (filename, content, mime_type) input.""" @@ -73,6 +79,22 @@ def test_file_takes_precedence_over_url(self) -> None: result = _get_input_filename(Path("local.pdf"), "https://example.com/remote.pdf") assert result == "local" + def test_raw_markdown_string_returns_default(self) -> None: + """Test that raw markdown content (not a file path) returns 'output'.""" + result = _get_input_filename("# Hello World\n\nSome content here", None) + assert result == "output" + + def test_multiline_markdown_string_returns_default(self) -> None: + """Test that multi-line markdown content returns 'output'.""" + markdown = "Form completed on September 3, 2025\nReference Number: RT-2025-0847" + result = _get_input_filename(markdown, None) + assert result == "output" + + def test_short_string_without_extension_returns_default(self) -> None: + """Test that a short string without a file extension returns 'output'.""" + result = _get_input_filename("no_extension", None) + assert result == "output" + class TestSaveResponse: """Tests for _save_response helper function.""" @@ -137,3 +159,120 @@ def test_accepts_string_path(self, tmp_path: Path) -> None: expected_file = tmp_path / "strpath_split_output.json" assert expected_file.exists() + + def test_output_filename_skips_redundant_prefix(self, tmp_path: Path) -> None: + """Test that 'output' stem produces '{method}_output.json', not 'output_{method}_output.json'.""" + mock_result = MagicMock() + mock_result.to_json.return_value = "{}" + + for method in ["parse", "extract", "split"]: + _save_response(tmp_path, "output", method, mock_result) + expected = tmp_path / f"{method}_output.json" + assert expected.exists(), f"Expected {expected} to exist" + redundant = tmp_path / f"output_{method}_output.json" + assert not redundant.exists(), f"Redundant {redundant} should NOT exist" + + def test_full_json_path_saves_to_exact_location(self, tmp_path: Path) -> None: + """Test that a path ending in .json is used as the exact output file.""" + output_file = tmp_path / "custom_name.json" + mock_result = MagicMock() + mock_result.to_json.return_value = '{"key": "value"}' + + _save_response(output_file, "ignored_filename", "extract", mock_result) + + assert output_file.exists() + assert output_file.read_text() == '{"key": "value"}' + assert not (tmp_path / "ignored_filename_extract_output.json").exists() + + def test_full_json_path_creates_parent_dirs(self, tmp_path: Path) -> None: + """Test that parent directories are created for full .json path.""" + output_file = tmp_path / "nested" / "deep" / "result.json" + mock_result = MagicMock() + mock_result.to_json.return_value = '{"nested": true}' + + _save_response(output_file, "file", "parse", mock_result) + + assert output_file.exists() + assert output_file.read_text() == '{"nested": true}' + + def test_full_json_path_as_string(self, tmp_path: Path) -> None: + """Test that a string path ending in .json works as full path mode.""" + output_file = str(tmp_path / "my_output.json") + mock_result = MagicMock() + mock_result.to_json.return_value = '{"string": true}' + + _save_response(output_file, "file", "split", mock_result) + + assert Path(output_file).exists() + assert Path(output_file).read_text() == '{"string": true}' + + +class TestAsyncSaveTo: + """Tests that async client methods accept save_to and save correctly.""" + + @pytest.fixture + def mock_response(self) -> MagicMock: + mock = MagicMock() + mock.to_json.return_value = '{"result": "ok"}' + return mock + + @pytest.mark.asyncio + async def test_async_extract_save_to_directory(self, tmp_path: Path, mock_response: MagicMock) -> None: + from unittest.mock import AsyncMock, patch + + client = AsyncLandingAIADE(apikey="test-key", base_url="http://localhost") + with patch.object(client, "post", new_callable=AsyncMock, return_value=mock_response): + result = await client.extract( + schema="{}", + markdown=Path("/path/to/doc.pdf"), + save_to=tmp_path, + ) + + assert (tmp_path / "doc_extract_output.json").exists() + assert result is mock_response + + @pytest.mark.asyncio + async def test_async_extract_save_to_json_path(self, tmp_path: Path, mock_response: MagicMock) -> None: + from unittest.mock import AsyncMock, patch + + output_file = tmp_path / "custom.json" + client = AsyncLandingAIADE(apikey="test-key", base_url="http://localhost") + with patch.object(client, "post", new_callable=AsyncMock, return_value=mock_response): + await client.extract(schema="{}", markdown=Path("/doc.pdf"), save_to=output_file) + + assert output_file.exists() + + @pytest.mark.asyncio + async def test_async_parse_save_to(self, tmp_path: Path, mock_response: MagicMock) -> None: + from unittest.mock import AsyncMock, patch + + client = AsyncLandingAIADE(apikey="test-key", base_url="http://localhost") + with patch.object(client, "post", new_callable=AsyncMock, return_value=mock_response): + await client.parse(document=Path("/path/to/doc.pdf"), save_to=tmp_path) + + assert (tmp_path / "doc_parse_output.json").exists() + + @pytest.mark.asyncio + async def test_async_split_save_to(self, tmp_path: Path, mock_response: MagicMock) -> None: + from unittest.mock import AsyncMock, patch + + client = AsyncLandingAIADE(apikey="test-key", base_url="http://localhost") + with patch.object(client, "post", new_callable=AsyncMock, return_value=mock_response): + await client.split( + split_class=[{"name": "type1"}], + markdown=Path("/path/to/doc.md"), + save_to=tmp_path, + ) + + assert (tmp_path / "doc_split_output.json").exists() + + @pytest.mark.asyncio + async def test_async_no_save_when_save_to_none(self, tmp_path: Path, mock_response: MagicMock) -> None: + from unittest.mock import AsyncMock, patch + + client = AsyncLandingAIADE(apikey="test-key", base_url="http://localhost") + with patch.object(client, "post", new_callable=AsyncMock, return_value=mock_response): + result = await client.extract(schema="{}") + + assert result is mock_response + assert not list(tmp_path.iterdir())