Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions aim/storage/artifacts/artifact_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from urllib.parse import urlparse

from .filesystem_storage import FilesystemArtifactStorage
from .gc_storage import GCArtifactStorage
from .s3_storage import S3ArtifactStorage


Expand Down Expand Up @@ -30,4 +31,5 @@ def get_storage(self, url: str) -> 'AbstractArtifactStorage':

registry = ArtifactStorageRegistry()
registry.register('s3', S3ArtifactStorage)
registry.register('gs', GCArtifactStorage)
registry.register('file', FilesystemArtifactStorage)
126 changes: 126 additions & 0 deletions aim/storage/artifacts/gc_storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import logging
import pathlib
import tempfile

from concurrent.futures import Future, ThreadPoolExecutor
from typing import Optional
from urllib.parse import urlparse

from aim.ext.cleanup import AutoClean
from aim.storage.artifacts.artifact_storage import AbstractArtifactStorage


def _import_gcs():
try:
from google.cloud import storage

return storage
except ImportError:
raise ImportError("Please install 'aim[gcs]' to use Google Cloud Storage.")


class GCArtifactStorageAutoClean(AutoClean['GCArtifactStorage']):
    """Makes sure all upload threads finish before the storage is destroyed."""

    def __init__(self, instance: 'GCArtifactStorage') -> None:
        # Hold a direct reference to the pool so _close() can reach it even
        # while the owning storage instance is being torn down.
        super().__init__(instance)
        self._thread_pool = instance._thread_pool

    def _close(self) -> None:
        # Block until every upload submitted to the pool has finished.
        self._thread_pool.shutdown(wait=True)


class GCArtifactStorage(AbstractArtifactStorage):
    """Artifact storage backend for Google Cloud Storage (``gs://`` URIs).

    Uploads run on a small background thread pool; a GCArtifactStorageAutoClean
    resource makes sure pending uploads finish before the object is destroyed.
    """

    def __init__(self, url: str):
        """
        Args:
            url: GCS bucket URL with optional path prefix (e.g., 'gs://my-bucket/path/prefix').
        """

        super().__init__(url)

        parsed_url = urlparse(self.url)  # e.g. 'gs://my-bucket/path/prefix'
        bucket_name = parsed_url.netloc  # my-bucket
        self._path_prefix = parsed_url.path.lstrip('/')  # path/prefix

        self._bucket = self._get_gcs_client().bucket(bucket_name)

        # thread pool for file upload workers
        self._thread_pool = ThreadPoolExecutor(max_workers=4, thread_name_prefix='gcs-upload')

        # make sure all threads finish uploading before shutting down
        self._resources = GCArtifactStorageAutoClean(self)

    def upload_artifact(self, file_path: str, artifact_path: str, block: bool = False):
        """
        Upload a local file to Google Cloud Storage, in either blocking or non-blocking mode.

        Args:
            file_path: Local path to the file to upload.
            artifact_path: Bucket destination suffix.
            block: If True, the method will block until the upload is complete.
        """

        dest_path = pathlib.Path(self._path_prefix) / artifact_path
        blob = self._bucket.blob(dest_path.as_posix())
        future = self._thread_pool.submit(blob.upload_from_filename, file_path)
        future.add_done_callback(self._upload_complete)

        if block:
            # result() re-raises any exception from the upload worker.
            future.result()

    def download_artifact(self, artifact_path: str, dest_dir: Optional[str] = None) -> str:
        """
        Download an artifact from Google Cloud Storage to a local directory in a blocking manner.
        Args:
            artifact_path: Bucket destination suffix.
            dest_dir: Local directory to download the artifact to. If None, a temporary directory is created.
        Returns:
            The local path to the downloaded artifact.
        """

        if dest_dir is None:
            dest_dir = pathlib.Path(tempfile.mkdtemp())
        else:
            dest_dir = pathlib.Path(dest_dir)
            dest_dir.mkdir(parents=True, exist_ok=True)

        source_path = pathlib.Path(self._path_prefix) / artifact_path
        # Keep only the file name locally; the bucket prefix is not mirrored.
        dest_path = dest_dir / source_path.name

        blob = self._bucket.blob(source_path.as_posix())
        blob.download_to_filename(dest_path.as_posix())

        return dest_path.as_posix()

    def delete_artifact(self, artifact_path: str):
        """Delete the artifact at `artifact_path` from the bucket (blocking)."""
        path = pathlib.Path(self._path_prefix) / artifact_path
        blob = self._bucket.blob(path.as_posix())
        blob.delete()

    def _get_gcs_client(self):
        """Create a GCS client with default credentials; overridden by the factory variant."""
        gcs = _import_gcs()
        return gcs.Client()

    def _upload_complete(self, future: Future):
        """Done-callback for upload futures.

        Non-blocking uploads never have their result() consumed, so without
        this check a failed background upload would be dropped silently.
        """
        exc = future.exception()
        if exc is not None:
            logging.getLogger(__name__).error('GCS artifact upload failed: %s', exc)


def GCArtifactStorage_factory(**gcs_client_kwargs):
    """Build a GCArtifactStorage subclass whose GCS client uses fixed kwargs.

    Args:
        gcs_client_kwargs: Keyword arguments forwarded to the
            google.cloud.storage.Client constructor.

    Returns:
        A GCArtifactStorage subclass bound to the given client configuration.
    """

    class GCArtifactStorageCustom(GCArtifactStorage):
        def _get_gcs_client(self):
            # Same as the base implementation, but with the captured kwargs.
            return _import_gcs().Client(**gcs_client_kwargs)

    return GCArtifactStorageCustom


def GCArtifactStorage_clientconfig(**gcs_client_kwargs):
    """
    Register GCArtifactStorage with custom GCS client configuration.
    Args:
        gcs_client_kwargs: Keyword arguments to pass to the google.cloud.storage.Client constructor.
    """
    # Imported here to avoid a circular import at module load time.
    from aim.storage.artifacts import registry

    # NOTE(review): this writes directly into the registry's backing dict
    # rather than calling registry.register('gs', ...) — presumably to allow
    # overwriting the default 'gs' registration; confirm register()'s
    # duplicate-key semantics before changing this.
    registry.registry['gs'] = GCArtifactStorage_factory(**gcs_client_kwargs)
20 changes: 20 additions & 0 deletions docs/source/using/artifacts.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ run = aim.Run()
# Use S3 as artifacts storage
run.set_artifacts_uri('s3://aim/artifacts/')

# Use GCS as artifacts storage
run.set_artifacts_uri('gs://aim/artifacts/')

# Use file-system as artifacts storage
run.set_artifacts_uri('file:///home/user/aim/artifacts/')
```
Expand All @@ -44,6 +47,7 @@ When the artifacts URI is set, Aim will detect storage backend based on the URI
Currently supported backends for artifacts storage are:
- S3
- File System
- GCS

#### S3 Artifacts Storage Backend

Expand All @@ -62,6 +66,22 @@ run = aim.Run(...)
run.set_artifacts_uri('s3://...')
run.log_artifact(..., name=...)
```
#### GCS Artifacts Storage Backend
In order to use Google Cloud Storage with Aim, install `aim[gcs]`.

Aim uses the `google-cloud-storage` Python package for accessing GCS resources. Connection and credential validation is handled by `google-cloud-storage`. More information is available [here](https://cloud.google.com/docs/authentication/client-libraries#python).

If you require direct control of how the `google-cloud-storage` Client is created, you may use `aim.storage.artifacts.gc_storage.GCArtifactStorage_clientconfig(...)`. This method accepts any keyword-arguments that the `google.cloud.storage.Client` constructor accepts.

```python
import aim
from aim.storage.artifacts.gc_storage import GCArtifactStorage_clientconfig

GCArtifactStorage_clientconfig(project=..., credentials=...)
run = aim.Run(...)
run.set_artifacts_uri('gs://...')
run.log_artifact(..., name=...)
```

#### File-system Artifacts Storage Backend

Expand Down
4 changes: 4 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ def package_files(directory):
'boto3',
]

EXTRAS_REQUIRE = {
'gcs': ['google-cloud-storage>=3.9.0'],
}

if sys.version_info.minor < 9:
REQUIRED += ['astunparse']

Expand Down