From ac3fb0318cbecfb8f7a04af6048df9a9a33178af Mon Sep 17 00:00:00 2001 From: Max Risuhin Date: Thu, 14 Nov 2019 03:26:36 -0800 Subject: [PATCH 1/8] Support GDrive as remote --- dvc/config.py | 15 +- dvc/remote/__init__.py | 2 + dvc/remote/gdrive/__init__.py | 277 ++++++++++++++++++++++++ dvc/remote/gdrive/pydrive.py | 104 +++++++++ dvc/remote/gdrive/utils.py | 32 +++ dvc/scheme.py | 1 + setup.py | 4 +- tests/func/test_data_cloud.py | 76 +++++++ tests/unit/remote/gdrive/__init__.py | 0 tests/unit/remote/gdrive/conftest.py | 9 + tests/unit/remote/gdrive/test_gdrive.py | 9 + 11 files changed, 525 insertions(+), 4 deletions(-) create mode 100644 dvc/remote/gdrive/__init__.py create mode 100644 dvc/remote/gdrive/pydrive.py create mode 100644 dvc/remote/gdrive/utils.py create mode 100644 tests/unit/remote/gdrive/__init__.py create mode 100644 tests/unit/remote/gdrive/conftest.py create mode 100644 tests/unit/remote/gdrive/test_gdrive.py diff --git a/dvc/config.py b/dvc/config.py index bd569864cb..f3621d907b 100644 --- a/dvc/config.py +++ b/dvc/config.py @@ -153,6 +153,8 @@ class Config(object): # pylint: disable=too-many-instance-attributes CONFIG = "config" CONFIG_LOCAL = "config.local" + CREDENTIALPATH = "credentialpath" + LEVEL_LOCAL = 0 LEVEL_REPO = 1 LEVEL_GLOBAL = 2 @@ -221,7 +223,7 @@ class Config(object): # pylint: disable=too-many-instance-attributes # backward compatibility SECTION_AWS = "aws" SECTION_AWS_STORAGEPATH = "storagepath" - SECTION_AWS_CREDENTIALPATH = "credentialpath" + SECTION_AWS_CREDENTIALPATH = CREDENTIALPATH SECTION_AWS_ENDPOINT_URL = "endpointurl" SECTION_AWS_LIST_OBJECTS = "listobjects" SECTION_AWS_REGION = "region" @@ -244,7 +246,7 @@ class Config(object): # pylint: disable=too-many-instance-attributes # backward compatibility SECTION_GCP = "gcp" SECTION_GCP_STORAGEPATH = SECTION_AWS_STORAGEPATH - SECTION_GCP_CREDENTIALPATH = SECTION_AWS_CREDENTIALPATH + SECTION_GCP_CREDENTIALPATH = CREDENTIALPATH SECTION_GCP_PROJECTNAME = "projectname" SECTION_GCP_SCHEMA = { SECTION_GCP_STORAGEPATH: str, @@ -261,6 +263,10 @@ class Config(object): # pylint: disable=too-many-instance-attributes SECTION_OSS_ACCESS_KEY_ID = "oss_key_id" SECTION_OSS_ACCESS_KEY_SECRET = "oss_key_secret" SECTION_OSS_ENDPOINT = "oss_endpoint" + # GDrive options + SECTION_GDRIVE_CLIENT_ID = "gdrive_client_id" + SECTION_GDRIVE_CLIENT_SECRET = "gdrive_client_secret" + SECTION_GDRIVE_USER_CREDENTIALS_FILE = "gdrive_user_credentials_file" SECTION_REMOTE_REGEX = r'^\s*remote\s*"(?P.*)"\s*$' SECTION_REMOTE_FMT = 'remote "{}"' @@ -277,7 +283,7 @@ class Config(object): # pylint: disable=too-many-instance-attributes SECTION_REMOTE_URL: str, Optional(SECTION_AWS_REGION): str, Optional(SECTION_AWS_PROFILE): str, - Optional(SECTION_AWS_CREDENTIALPATH): str, + Optional(CREDENTIALPATH): str, Optional(SECTION_AWS_ENDPOINT_URL): str, Optional(SECTION_AWS_LIST_OBJECTS, default=False): BOOL_SCHEMA, Optional(SECTION_AWS_USE_SSL, default=True): BOOL_SCHEMA, @@ -297,6 +303,9 @@ class Config(object): # pylint: disable=too-many-instance-attributes Optional(SECTION_OSS_ACCESS_KEY_ID): str, Optional(SECTION_OSS_ACCESS_KEY_SECRET): str, Optional(SECTION_OSS_ENDPOINT): str, + Optional(SECTION_GDRIVE_CLIENT_ID): str, + Optional(SECTION_GDRIVE_CLIENT_SECRET): str, + Optional(SECTION_GDRIVE_USER_CREDENTIALS_FILE): str, Optional(PRIVATE_CWD): str, Optional(SECTION_REMOTE_NO_TRAVERSE, default=True): BOOL_SCHEMA, } diff --git a/dvc/remote/__init__.py b/dvc/remote/__init__.py index 3ff90d365b..e2c20a2168 100644 --- a/dvc/remote/__init__.py +++ b/dvc/remote/__init__.py @@ -2,6 +2,7 @@ from .config import RemoteConfig from dvc.remote.azure import RemoteAZURE +from dvc.remote.gdrive import RemoteGDrive from dvc.remote.gs import RemoteGS from dvc.remote.hdfs import RemoteHDFS from dvc.remote.http import RemoteHTTP @@ -14,6 +15,7 @@ REMOTES = [ RemoteAZURE, + RemoteGDrive, RemoteGS, RemoteHDFS, RemoteHTTP, diff --git a/dvc/remote/gdrive/__init__.py b/dvc/remote/gdrive/__init__.py new file mode 100644 index 0000000000..567c36158e --- /dev/null +++ b/dvc/remote/gdrive/__init__.py @@ -0,0 +1,277 @@ +from __future__ import unicode_literals + +import os +import posixpath +import logging + +from funcy import cached_property +from backoff import on_exception, expo + +from dvc.scheme import Schemes +from dvc.path_info import CloudURLInfo +from dvc.remote.base import RemoteBASE +from dvc.config import Config +from dvc.exceptions import DvcException +from dvc.remote.gdrive.pydrive import ( + RequestListFile, + RequestListFilePaginated, + RequestCreateFolder, + RequestUploadFile, + RequestDownloadFile, +) +from dvc.remote.gdrive.utils import FOLDER_MIME_TYPE + +logger = logging.getLogger(__name__) + + +class GDriveURLInfo(CloudURLInfo): + @property + def netloc(self): + return self.parsed.netloc + + +class RemoteGDrive(RemoteBASE): + scheme = Schemes.GDRIVE + path_cls = GDriveURLInfo + REGEX = r"^gdrive://.*$" + REQUIRES = {"pydrive": "pydrive"} + GDRIVE_USER_CREDENTIALS_DATA = "GDRIVE_USER_CREDENTIALS_DATA" + DEFAULT_USER_CREDENTIALS_FILE = ".dvc/tmp/gdrive-user-credentials.json" + + def __init__(self, repo, config): + super(RemoteGDrive, self).__init__(repo, config) + self.no_traverse = False + self.path_info = self.path_cls(config[Config.SECTION_REMOTE_URL]) + self.config = config + self.init_drive() + + def init_drive(self): + self.gdrive_client_id = self.config.get( + Config.SECTION_GDRIVE_CLIENT_ID, None + ) + self.gdrive_client_secret = self.config.get( + Config.SECTION_GDRIVE_CLIENT_SECRET, None + ) + if not self.gdrive_client_id or not self.gdrive_client_secret: + raise DvcException( + "Please specify Google Drive's client id and " + "secret in DVC's config. Learn more at " + "https://man.dvc.org/remote/add." + ) + self.gdrive_user_credentials_path = self.config.get( + Config.SECTION_GDRIVE_USER_CREDENTIALS_FILE, + self.DEFAULT_USER_CREDENTIALS_FILE, + ) + + self.root_id = self.get_path_id(self.path_info, create=True) + self.cached_dirs, self.cached_ids = self.cache_root_dirs() + + @on_exception(expo, DvcException, max_tries=8) + def execute_request(self, request): + try: + result = request.execute() + except Exception as exception: + retry_codes = ["403", "500", "502", "503", "504"] + if any(code in str(exception) for code in retry_codes): + raise DvcException("Google API request failed") + raise + return result + + def list_drive_item(self, query): + list_request = RequestListFilePaginated(self.drive, query) + page_list = self.execute_request(list_request) + while page_list: + for item in page_list: + yield item + page_list = self.execute_request(list_request) + + def cache_root_dirs(self): + cached_dirs = {} + cached_ids = {} + for dir1 in self.list_drive_item( + "'{}' in parents and trashed=false".format(self.root_id) + ): + cached_dirs.setdefault(dir1["title"], []).append(dir1["id"]) + cached_ids[dir1["id"]] = dir1["title"] + return cached_dirs, cached_ids + + @cached_property + def drive(self): + from pydrive.auth import GoogleAuth + from pydrive.drive import GoogleDrive + + if os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA): + with open( + self.gdrive_user_credentials_path, "w" + ) as credentials_file: + credentials_file.write( + os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA) + ) + + GoogleAuth.DEFAULT_SETTINGS["client_config_backend"] = "settings" + GoogleAuth.DEFAULT_SETTINGS["client_config"] = { + "client_id": self.gdrive_client_id, + "client_secret": self.gdrive_client_secret, + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + "revoke_uri": "https://oauth2.googleapis.com/revoke", + "redirect_uri": "", + } + GoogleAuth.DEFAULT_SETTINGS["save_credentials"] = True + GoogleAuth.DEFAULT_SETTINGS["save_credentials_backend"] = "file" + GoogleAuth.DEFAULT_SETTINGS[ + "save_credentials_file" + ] = self.gdrive_user_credentials_path + GoogleAuth.DEFAULT_SETTINGS["get_refresh_token"] = True + GoogleAuth.DEFAULT_SETTINGS["oauth_scope"] = [ + "https://www.googleapis.com/auth/drive", + "https://www.googleapis.com/auth/drive.appdata", + ] + + # Pass non existent settings path to force DEFAULT_SETTINGS loading + gauth = GoogleAuth(settings_file="") + gauth.CommandLineAuth() + + if os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA): + os.remove(self.gdrive_user_credentials_path) + + gdrive = GoogleDrive(gauth) + return gdrive + + def create_drive_item(self, parent_id, title): + upload_request = RequestCreateFolder( + {"drive": self.drive, "title": title, "parent_id": parent_id} + ) + result = self.execute_request(upload_request) + return result + + def get_drive_item(self, name, parents_ids): + if not parents_ids: + return None + query = " or ".join( + "'{}' in parents".format(parent_id) for parent_id in parents_ids + ) + + query += " and trashed=false and title='{}'".format(name) + + list_request = RequestListFile(self.drive, query) + item_list = self.execute_request(list_request) + return next(iter(item_list), None) + + def resolve_remote_file(self, parents_ids, path_parts, create): + for path_part in path_parts: + item = self.get_drive_item(path_part, parents_ids) + if not item and create: + item = self.create_drive_item(parents_ids[0], path_part) + elif not item: + return None + parents_ids = [item["id"]] + return item + + def subtract_root_path(self, parts): + if not hasattr(self, "root_id"): + return parts, [self.path_info.netloc] + + for part in self.path_info.path.split("/"): + if parts and parts[0] == part: + parts.pop(0) + else: + break + return parts, [self.root_id] + + def get_path_id_from_cache(self, path_info): + files_ids = [] + parts, parents_ids = self.subtract_root_path(path_info.path.split("/")) + if ( + hasattr(self, "cached_dirs") + and path_info != self.path_info + and parts + and (parts[0] in self.cached_dirs) + ): + parents_ids = self.cached_dirs[parts[0]] + files_ids = self.cached_dirs[parts[0]] + parts.pop(0) + + return files_ids, parents_ids, parts + + def get_path_id(self, path_info, create=False): + files_ids, parents_ids, parts = self.get_path_id_from_cache(path_info) + + if not parts and files_ids: + return files_ids[0] + + file1 = self.resolve_remote_file(parents_ids, parts, create) + return file1["id"] if file1 else "" + + def exists(self, path_info): + return self.get_path_id(path_info) != "" + + def _upload(self, from_file, to_info, name, no_progress_bar): + dirname = to_info.parent + if dirname: + parent_id = self.get_path_id(dirname, True) + else: + parent_id = to_info.netloc + + upload_request = RequestUploadFile( + { + "drive": self.drive, + "title": to_info.name, + "parent_id": parent_id, + }, + no_progress_bar, + from_file, + name, + ) + self.execute_request(upload_request) + + def _download(self, from_info, to_file, name, no_progress_bar): + file_id = self.get_path_id(from_info) + download_request = RequestDownloadFile( + { + "drive": self.drive, + "file_id": file_id, + "to_file": to_file, + "progress_name": name, + "no_progress_bar": no_progress_bar, + } + ) + self.execute_request(download_request) + + def list_cache_paths(self): + file_id = self.get_path_id(self.path_info) + prefix = self.path_info.path + for path in self.list_path(file_id): + yield posixpath.join(prefix, path) + + def list_file_path(self, drive_file): + if drive_file["mimeType"] == FOLDER_MIME_TYPE: + for i in self.list_path(drive_file["id"]): + yield posixpath.join(drive_file["title"], i) + else: + yield drive_file["title"] + + def list_path(self, parent_id): + for file1 in self.list_drive_item( + "'{}' in parents and trashed=false".format(parent_id) + ): + for path in self.list_file_path(file1): + yield path + + def all(self): + if not hasattr(self, "cached_ids") or not self.cached_ids: + return + + query = " or ".join( + "'{}' in parents".format(dir_id) for dir_id in self.cached_ids + ) + + query += " and trashed=false" + for file1 in self.list_drive_item(query): + parent_id = file1["parents"][0]["id"] + path = posixpath.join(self.cached_ids[parent_id], file1["title"]) + try: + yield self.path_to_checksum(path) + except ValueError: + # We ignore all the non-cache looking files + logger.debug('Ignoring path as "non-cache looking"') diff --git a/dvc/remote/gdrive/pydrive.py b/dvc/remote/gdrive/pydrive.py new file mode 100644 index 0000000000..46836d04fc --- /dev/null +++ b/dvc/remote/gdrive/pydrive.py @@ -0,0 +1,104 @@ +import os + +from dvc.remote.gdrive.utils import TrackFileReadProgress, FOLDER_MIME_TYPE + + +class RequestBASE: + def __init__(self, drive): + self.drive = drive + + def execute(self): + raise NotImplementedError + + +class RequestListFile(RequestBASE): + def __init__(self, drive, query): + super(RequestListFile, self).__init__(drive) + self.query = query + + def execute(self): + return self.drive.ListFile( + {"q": self.query, "maxResults": 1000} + ).GetList() + + +class RequestListFilePaginated(RequestBASE): + def __init__(self, drive, query): + super(RequestListFilePaginated, self).__init__(drive) + self.query = query + self.iter = None + + def execute(self): + if not self.iter: + self.iter = iter( + self.drive.ListFile({"q": self.query, "maxResults": 1000}) + ) + return next(self.iter, None) + + +class RequestCreateFolder(RequestBASE): + def __init__(self, args): + super(RequestCreateFolder, self).__init__(args["drive"]) + self.title = args["title"] + self.parent_id = args["parent_id"] + + def execute(self): + item = self.drive.CreateFile( + { + "title": self.title, + "parents": [{"id": self.parent_id}], + "mimeType": FOLDER_MIME_TYPE, + } + ) + item.Upload() + return item + + +class RequestUploadFile(RequestBASE): + def __init__( + self, args, no_progress_bar=True, from_file="", progress_name="" + ): + super(RequestUploadFile, self).__init__(args["drive"]) + self.title = args["title"] + self.parent_id = args["parent_id"] + self.no_progress_bar = no_progress_bar + self.from_file = from_file + self.progress_name = progress_name + + def upload(self, item): + with open(self.from_file, "rb") as from_file: + if not self.no_progress_bar: + from_file = TrackFileReadProgress( + self.progress_name, from_file + ) + if os.stat(self.from_file).st_size: + item.content = from_file + item.Upload() + + def execute(self): + item = self.drive.CreateFile( + {"title": self.title, "parents": [{"id": self.parent_id}]} + ) + self.upload(item) + return item + + +class RequestDownloadFile(RequestBASE): + def __init__(self, args): + super(RequestDownloadFile, self).__init__(args["drive"]) + self.file_id = args["file_id"] + self.to_file = args["to_file"] + self.progress_name = args["progress_name"] + self.no_progress_bar = args["no_progress_bar"] + + def execute(self): + from dvc.progress import Tqdm + + gdrive_file = self.drive.CreateFile({"id": self.file_id}) + if not self.no_progress_bar: + tqdm = Tqdm( + desc=self.progress_name, total=int(gdrive_file["fileSize"]) + ) + gdrive_file.GetContentFile(self.to_file) + if not self.no_progress_bar: + tqdm.close() diff --git a/dvc/remote/gdrive/utils.py b/dvc/remote/gdrive/utils.py new file mode 100644 index 0000000000..0f3cf02cd0 --- /dev/null +++ b/dvc/remote/gdrive/utils.py @@ -0,0 +1,32 @@ +import os + +from dvc.progress import Tqdm + + +FOLDER_MIME_TYPE = "application/vnd.google-apps.folder" + + +class TrackFileReadProgress(object): + UPDATE_AFTER_READ_COUNT = 30 + + def __init__(self, progress_name, fobj): + self.progress_name = progress_name + self.fobj = fobj + self.file_size = os.fstat(fobj.fileno()).st_size + self.tqdm = Tqdm(desc=self.progress_name, total=self.file_size) + self.update_counter = 0 + + def read(self, size): + if self.update_counter == 0: + self.tqdm.update_to(self.fobj.tell()) + self.update_counter = self.UPDATE_AFTER_READ_COUNT + else: + self.update_counter -= 1 + return self.fobj.read(size) + + def close(self): + self.fobj.close() + self.tqdm.close() + + def __getattr__(self, attr): + return getattr(self.fobj, attr) diff --git a/dvc/scheme.py b/dvc/scheme.py index e12b768f58..5f7a8d1a28 100644 --- a/dvc/scheme.py +++ b/dvc/scheme.py @@ -9,5 +9,6 @@ class Schemes: HTTP = "http" HTTPS = "https" GS = "gs" + GDRIVE = "gdrive" LOCAL = "local" OSS = "oss" diff --git a/setup.py b/setup.py index 00bd03bf66..5dddccfc16 100644 --- a/setup.py +++ b/setup.py @@ -90,6 +90,7 @@ def run(self): # Extra dependencies for remote integrations gs = ["google-cloud-storage==1.19.0"] +gdrive = ["pydrive==1.3.1", "backoff>=1.8.1"] s3 = ["boto3==1.9.115"] azure = ["azure-storage-blob==2.1.0"] oss = ["oss2==2.6.1"] @@ -100,7 +101,7 @@ def run(self): # we can start shipping it by default. ssh_gssapi = ["paramiko[gssapi]>=2.5.0"] hdfs = ["pyarrow==0.14.0"] -all_remotes = gs + s3 + azure + ssh + oss +all_remotes = gs + s3 + azure + ssh + oss + gdrive if os.name != "nt" or sys.version_info[0] != 2: # NOTE: there are no pyarrow wheels for python2 on windows @@ -150,6 +151,7 @@ def run(self): extras_require={ "all": all_remotes, "gs": gs, + "gdrive": gdrive, "s3": s3, "azure": azure, "oss": oss, diff --git a/tests/func/test_data_cloud.py b/tests/func/test_data_cloud.py index 6eb94d7eae..ad0d33020b 100644 --- a/tests/func/test_data_cloud.py +++ b/tests/func/test_data_cloud.py @@ -18,6 +18,7 @@ from dvc.data_cloud import DataCloud from dvc.main import main from dvc.remote import RemoteAZURE +from dvc.remote import RemoteGDrive from dvc.remote import RemoteGS from dvc.remote import RemoteHDFS from dvc.remote import RemoteHTTP @@ -58,6 +59,11 @@ # Ensure that absolute path is used os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = TEST_GCP_CREDS_FILE +TEST_GDRIVE_CLIENT_ID = ( + "719861249063-v4an78j9grdtuuuqg3lnm0sugna6v3lh.apps.googleusercontent.com" +) +TEST_GDRIVE_CLIENT_SECRET = "2fy_HyzSwkxkGzEken7hThXb" + def _should_test_aws(): do_test = env2bool("DVC_TEST_AWS", undefined=None) @@ -70,6 +76,13 @@ def _should_test_aws(): return False +def _should_test_gdrive(): + if os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA): + return True + + return False + + def _should_test_gcp(): do_test = env2bool("DVC_TEST_GCP", undefined=None) if do_test is not None: @@ -202,6 +215,10 @@ def get_aws_url(): return "s3://" + get_aws_storagepath() +def get_gdrive_url(): + return "gdrive://root/" + str(uuid.uuid4()) + + def get_gcp_storagepath(): return TEST_GCP_REPO_BUCKET + "/" + str(uuid.uuid4()) @@ -375,6 +392,35 @@ def _get_cloud_class(self): return RemoteS3 +class TestRemoteGDrive(TestDataCloudBase): + def _should_test(self): + return _should_test_gdrive() + + def _setup_cloud(self): + self._ensure_should_run() + + repo = self._get_url() + + config = copy.deepcopy(TEST_CONFIG) + config[TEST_SECTION][Config.SECTION_REMOTE_URL] = repo + config[TEST_SECTION][ + Config.SECTION_GDRIVE_CLIENT_ID + ] = TEST_GDRIVE_CLIENT_ID + config[TEST_SECTION][ + Config.SECTION_GDRIVE_CLIENT_SECRET + ] = TEST_GDRIVE_CLIENT_SECRET + self.dvc.config.config = config + self.cloud = DataCloud(self.dvc) + + self.assertIsInstance(self.cloud.get_remote(), self._get_cloud_class()) + + def _get_url(self): + return get_gdrive_url() + + def _get_cloud_class(self): + return RemoteGDrive + + class TestRemoteGS(TestDataCloudBase): def _should_test(self): return _should_test_gcp() @@ -621,6 +667,36 @@ def _test(self): self._test_cloud(TEST_REMOTE) +class TestRemoteGDriveCLI(TestDataCloudCLIBase): + def _should_test(self): + return _should_test_gdrive() + + def _test(self): + url = get_gdrive_url() + + self.main(["remote", "add", TEST_REMOTE, url]) + self.main( + [ + "remote", + "modify", + TEST_REMOTE, + Config.SECTION_GDRIVE_CLIENT_ID, + TEST_GDRIVE_CLIENT_ID, + ] + ) + self.main( + [ + "remote", + "modify", + TEST_REMOTE, + Config.SECTION_GDRIVE_CLIENT_SECRET, + TEST_GDRIVE_CLIENT_SECRET, + ] + ) + + self._test_cloud(TEST_REMOTE) + + class TestRemoteGSCLI(TestDataCloudCLIBase): def _should_test(self): return _should_test_gcp() diff --git a/tests/unit/remote/gdrive/__init__.py b/tests/unit/remote/gdrive/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/remote/gdrive/conftest.py b/tests/unit/remote/gdrive/conftest.py new file mode 100644 index 0000000000..035ca15094 --- /dev/null +++ b/tests/unit/remote/gdrive/conftest.py @@ -0,0 +1,9 @@ +import pytest + +from dvc.remote.gdrive import RemoteGDrive + + +@pytest.fixture +def gdrive(repo): + ret = RemoteGDrive(None, {"url": "gdrive://root/data"}) + return ret diff --git a/tests/unit/remote/gdrive/test_gdrive.py b/tests/unit/remote/gdrive/test_gdrive.py new file mode 100644 index 0000000000..28e003748c --- /dev/null +++ b/tests/unit/remote/gdrive/test_gdrive.py @@ -0,0 +1,9 @@ +import mock +from dvc.remote.gdrive import RemoteGDrive + + +@mock.patch("dvc.remote.gdrive.RemoteGDrive.init_drive") +def test_init_drive(repo): + url = "gdrive://root/data" + gdrive = RemoteGDrive(repo, {"url": url}) + assert str(gdrive.path_info) == url From e8171124ea2bdbc4726e8db5fc10c7bbeeaa1c3a Mon Sep 17 00:00:00 2001 From: Max Risuhin Date: Mon, 18 Nov 2019 06:40:41 -0800 Subject: [PATCH 2/8] Use funcy retry; Simplify upload file progress implementation; Remove GDriveURLInfo; Update Google API errors handling --- dvc/remote/gdrive/__init__.py | 28 ++++++++++++++-------------- dvc/remote/gdrive/utils.py | 13 +++---------- setup.py | 2 +- 3 files changed, 18 insertions(+), 25 deletions(-) diff --git a/dvc/remote/gdrive/__init__.py b/dvc/remote/gdrive/__init__.py index 567c36158e..a1c611f16f 100644 --- a/dvc/remote/gdrive/__init__.py +++ b/dvc/remote/gdrive/__init__.py @@ -4,8 +4,7 @@ import posixpath import logging -from funcy import cached_property -from backoff import on_exception, expo +from funcy import cached_property, retry from dvc.scheme import Schemes from dvc.path_info import CloudURLInfo @@ -23,16 +22,13 @@ logger = logging.getLogger(__name__) - -class GDriveURLInfo(CloudURLInfo): - @property - def netloc(self): - return self.parsed.netloc - +class GDriveRetriableError(DvcException): + def __init__(self, msg): + super(GDriveRetriableError, self).__init__(msg) class RemoteGDrive(RemoteBASE): scheme = Schemes.GDRIVE - path_cls = GDriveURLInfo + path_cls = CloudURLInfo REGEX = r"^gdrive://.*$" REQUIRES = {"pydrive": "pydrive"} GDRIVE_USER_CREDENTIALS_DATA = "GDRIVE_USER_CREDENTIALS_DATA" @@ -66,14 +62,18 @@ def init_drive(self): self.root_id = self.get_path_id(self.path_info, create=True) self.cached_dirs, self.cached_ids = self.cache_root_dirs() - @on_exception(expo, DvcException, max_tries=8) + # 8 tries, start at 0.5s, multiply by golden ratio, cap at 10s + @retry(8, + errors=(GDriveRetriableError), + timeout=lambda a: min(0.5 * 1.618 ** a, 10)) def execute_request(self, request): + from pydrive.files import ApiRequestError try: result = request.execute() except Exception as exception: retry_codes = ["403", "500", "502", "503", "504"] - if any(code in str(exception) for code in retry_codes): - raise DvcException("Google API request failed") + if any("HttpError {}".format(code) in str(exception) for code in retry_codes): + raise GDriveRetriableError("Google API request failed") raise return result @@ -170,7 +170,7 @@ def resolve_remote_file(self, parents_ids, path_parts, create): def subtract_root_path(self, parts): if not hasattr(self, "root_id"): - return parts, [self.path_info.netloc] + return parts, [self.path_info.bucket] for part in self.path_info.path.split("/"): if parts and parts[0] == part: @@ -211,7 +211,7 @@ def _upload(self, from_file, to_info, name, no_progress_bar): if dirname: parent_id = self.get_path_id(dirname, True) else: - parent_id = to_info.netloc + parent_id = to_info.bucket upload_request = RequestUploadFile( { diff --git a/dvc/remote/gdrive/utils.py b/dvc/remote/gdrive/utils.py index 0f3cf02cd0..781af811a5 100644 --- a/dvc/remote/gdrive/utils.py +++ b/dvc/remote/gdrive/utils.py @@ -7,21 +7,14 @@ class TrackFileReadProgress(object): - UPDATE_AFTER_READ_COUNT = 30 - def __init__(self, progress_name, fobj): self.progress_name = progress_name self.fobj = fobj - self.file_size = os.fstat(fobj.fileno()).st_size - self.tqdm = Tqdm(desc=self.progress_name, total=self.file_size) - self.update_counter = 0 + file_size = os.fstat(fobj.fileno()).st_size + self.tqdm = Tqdm(desc=self.progress_name, total=file_size) def read(self, size): - if self.update_counter == 0: - self.tqdm.update_to(self.fobj.tell()) - self.update_counter = self.UPDATE_AFTER_READ_COUNT - else: - self.update_counter -= 1 + self.tqdm.update(size) return self.fobj.read(size) def close(self): diff --git a/setup.py b/setup.py index 5dddccfc16..d12563401b 100644 --- a/setup.py +++ b/setup.py @@ -90,7 +90,7 @@ def run(self): # Extra dependencies for remote integrations gs = ["google-cloud-storage==1.19.0"] -gdrive = ["pydrive==1.3.1", "backoff>=1.8.1"] +gdrive = ["pydrive==1.3.1"] s3 = ["boto3==1.9.115"] azure = ["azure-storage-blob==2.1.0"] oss = ["oss2==2.6.1"] From ab5e9712030c56adbca71c8f65663574da391c79 Mon Sep 17 00:00:00 2001 From: Max Risuhin Date: Mon, 18 Nov 2019 16:13:26 -0800 Subject: [PATCH 3/8] Fancy funcy usage; Removed PyDrive request classes. --- dvc/remote/gdrive/__init__.py | 152 +++++++++++++++++++++------------- dvc/remote/gdrive/pydrive.py | 104 ----------------------- 2 files changed, 95 insertions(+), 161 deletions(-) delete mode 100644 dvc/remote/gdrive/pydrive.py diff --git a/dvc/remote/gdrive/__init__.py b/dvc/remote/gdrive/__init__.py index a1c611f16f..6703cf8011 100644 --- a/dvc/remote/gdrive/__init__.py +++ b/dvc/remote/gdrive/__init__.py @@ -4,28 +4,48 @@ import posixpath import logging -from funcy import cached_property, retry +from funcy import cached_property, retry, compose, decorator +from funcy.py3 import cat +from dvc.remote.gdrive.utils import TrackFileReadProgress, FOLDER_MIME_TYPE from dvc.scheme import Schemes from dvc.path_info import CloudURLInfo from dvc.remote.base import RemoteBASE from dvc.config import Config from dvc.exceptions import DvcException -from dvc.remote.gdrive.pydrive import ( - RequestListFile, - RequestListFilePaginated, - RequestCreateFolder, - RequestUploadFile, - RequestDownloadFile, -) -from dvc.remote.gdrive.utils import FOLDER_MIME_TYPE logger = logging.getLogger(__name__) + class GDriveRetriableError(DvcException): def __init__(self, msg): super(GDriveRetriableError, self).__init__(msg) + +@decorator +def _wrap_pydrive_retriable(call): + try: + result = call() + except Exception as exception: + retry_codes = ["403", "500", "502", "503", "504"] + if any( + "HttpError {}".format(code) in str(exception) + for code in retry_codes + ): + raise GDriveRetriableError(msg="Google API request failed") + raise + return result + + +gdrive_retry = compose( + # 8 tries, start at 0.5s, multiply by golden ratio, cap at 10s + retry( + 8, GDriveRetriableError, timeout=lambda a: min(0.5 * 1.618 ** a, 10) + ), + _wrap_pydrive_retriable, +) + + class RemoteGDrive(RemoteBASE): scheme = Schemes.GDRIVE path_cls = CloudURLInfo @@ -62,28 +82,57 @@ def init_drive(self): self.root_id = self.get_path_id(self.path_info, create=True) self.cached_dirs, self.cached_ids = self.cache_root_dirs() - # 8 tries, start at 0.5s, multiply by golden ratio, cap at 10s - @retry(8, - errors=(GDriveRetriableError), - timeout=lambda a: min(0.5 * 1.618 ** a, 10)) - def execute_request(self, request): - from pydrive.files import ApiRequestError - try: - result = request.execute() - except Exception as exception: - retry_codes = ["403", "500", "502", "503", "504"] - if any("HttpError {}".format(code) in str(exception) for code in retry_codes): - raise GDriveRetriableError("Google API request failed") - raise - return result + def request_list_file(self, query): + return self.drive.ListFile({"q": query, "maxResults": 1000}).GetList() + + def request_create_folder(self, title, parent_id): + item = self.drive.CreateFile( + { + "title": title, + "parents": [{"id": parent_id}], + "mimeType": FOLDER_MIME_TYPE, + } + ) + item.Upload() + return item + + def request_upload_file( + self, args, no_progress_bar=True, from_file="", progress_name="" + ): + item = self.drive.CreateFile( + {"title": args["title"], "parents": [{"id": args["parent_id"]}]} + ) + self.upload_file(item, no_progress_bar, from_file, progress_name) + return item + + def upload_file(self, item, no_progress_bar, from_file, progress_name): + with open(from_file, "rb") as opened_file: + if not no_progress_bar: + opened_file = TrackFileReadProgress(progress_name, opened_file) + if os.stat(from_file).st_size: + item.content = opened_file + item.Upload() + + def request_download_file( + self, file_id, to_file, progress_name, no_progress_bar + ): + from dvc.progress import Tqdm + + gdrive_file = self.drive.CreateFile({"id": file_id}) + if not no_progress_bar: + tqdm = Tqdm(desc=progress_name, total=int(gdrive_file["fileSize"])) + gdrive_file.GetContentFile(to_file) + if not no_progress_bar: + tqdm.close() def list_drive_item(self, query): - list_request = RequestListFilePaginated(self.drive, query) - page_list = self.execute_request(list_request) - while page_list: - for item in page_list: - yield item - page_list = self.execute_request(list_request) + file_list = self.drive.ListFile({"q": query, "maxResults": 1000}) + + # Isolate and decorate fetching of remote drive items in pages + get_list = gdrive_retry(lambda: next(file_list, None)) + + # Fetch pages until None is received, lazily flatten the thing + return cat(iter(get_list, None)) def cache_root_dirs(self): cached_dirs = {} @@ -139,11 +188,9 @@ def drive(self): return gdrive def create_drive_item(self, parent_id, title): - upload_request = RequestCreateFolder( - {"drive": self.drive, "title": title, "parent_id": parent_id} - ) - result = self.execute_request(upload_request) - return result + return gdrive_retry( + lambda: self.request_create_folder(title, parent_id) + )() def get_drive_item(self, name, parents_ids): if not parents_ids: @@ -154,8 +201,7 @@ def get_drive_item(self, name, parents_ids): query += " and trashed=false and title='{}'".format(name) - list_request = RequestListFile(self.drive, query) - item_list = self.execute_request(list_request) + item_list = gdrive_retry(lambda: self.request_list_file(query))() return next(iter(item_list), None) def resolve_remote_file(self, parents_ids, path_parts, create): @@ -213,30 +259,22 @@ def _upload(self, from_file, to_info, name, no_progress_bar): else: parent_id = to_info.bucket - upload_request = RequestUploadFile( - { - "drive": self.drive, - "title": to_info.name, - "parent_id": parent_id, - }, - no_progress_bar, - from_file, - name, - ) - self.execute_request(upload_request) + gdrive_retry( + lambda: self.request_upload_file( + {"title": to_info.name, "parent_id": parent_id}, + no_progress_bar, + from_file, + name, + ) + )() def _download(self, from_info, to_file, name, no_progress_bar): file_id = self.get_path_id(from_info) - download_request = RequestDownloadFile( - { - "drive": self.drive, - "file_id": file_id, - "to_file": to_file, - "progress_name": name, - "no_progress_bar": no_progress_bar, - } - ) - self.execute_request(download_request) + gdrive_retry( + lambda: self.request_download_file( + file_id, to_file, name, no_progress_bar + ) + )() def list_cache_paths(self): file_id = self.get_path_id(self.path_info) diff --git a/dvc/remote/gdrive/pydrive.py b/dvc/remote/gdrive/pydrive.py deleted file mode 100644 index 46836d04fc..0000000000 --- a/dvc/remote/gdrive/pydrive.py +++ /dev/null @@ -1,104 +0,0 @@ -import os - -from dvc.remote.gdrive.utils import TrackFileReadProgress, FOLDER_MIME_TYPE - - -class RequestBASE: - def __init__(self, drive): - self.drive = drive - - def execute(self): - raise NotImplementedError - - -class RequestListFile(RequestBASE): - def __init__(self, drive, query): - super(RequestListFile, self).__init__(drive) - self.query = query - - def execute(self): - return self.drive.ListFile( - {"q": self.query, "maxResults": 1000} - ).GetList() - - -class RequestListFilePaginated(RequestBASE): - def __init__(self, drive, query): - super(RequestListFilePaginated, self).__init__(drive) - self.query = query - self.iter = None - - def execute(self): - if not self.iter: - self.iter = iter( - self.drive.ListFile({"q": self.query, "maxResults": 1000}) - ) - return next(self.iter, None) - - -class RequestCreateFolder(RequestBASE): - def __init__(self, args): - super(RequestCreateFolder, self).__init__(args["drive"]) - self.title = args["title"] - self.parent_id = args["parent_id"] - - def execute(self): - item = self.drive.CreateFile( - { - "title": self.title, - "parents": [{"id": self.parent_id}], - "mimeType": FOLDER_MIME_TYPE, - } - ) - item.Upload() - return item - - -class RequestUploadFile(RequestBASE): - def __init__( - self, args, no_progress_bar=True, from_file="", progress_name="" - ): - super(RequestUploadFile, self).__init__(args["drive"]) - self.title = args["title"] - self.parent_id = args["parent_id"] - self.no_progress_bar = no_progress_bar - self.from_file = from_file - self.progress_name = progress_name - - def upload(self, item): - with open(self.from_file, "rb") as from_file: - if not self.no_progress_bar: - from_file = TrackFileReadProgress( - self.progress_name, from_file - ) - if os.stat(self.from_file).st_size: - item.content = from_file - item.Upload() - - def execute(self): - item = self.drive.CreateFile( - {"title": self.title, "parents": [{"id": self.parent_id}]} - ) - self.upload(item) - return item - - -class RequestDownloadFile(RequestBASE): - def __init__(self, args): - super(RequestDownloadFile, self).__init__(args["drive"]) - self.file_id = args["file_id"] - self.to_file = args["to_file"] - self.progress_name = args["progress_name"] - self.no_progress_bar = args["no_progress_bar"] - - def execute(self): - from dvc.progress import Tqdm - - gdrive_file = self.drive.CreateFile({"id": self.file_id}) - if not self.no_progress_bar: - tqdm = Tqdm( - desc=self.progress_name, total=int(gdrive_file["fileSize"]) - ) - gdrive_file.GetContentFile(self.to_file) - if not self.no_progress_bar: - tqdm.close() From 151a499866b49c551b1815992d81623e0525dad2 Mon Sep 17 00:00:00 2001 From: Max Risuhin Date: Tue, 19 Nov 2019 16:18:08 -0800 Subject: [PATCH 4/8] Load user creds from tmp file; Set count of requested items to 1 to resolve specific item by path; Rename gdrive query methods. --- dvc/remote/gdrive/__init__.py | 41 +++++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/dvc/remote/gdrive/__init__.py b/dvc/remote/gdrive/__init__.py index 6703cf8011..ae95e8e262 100644 --- a/dvc/remote/gdrive/__init__.py +++ b/dvc/remote/gdrive/__init__.py @@ -3,6 +3,7 @@ import os import posixpath import logging +import uuid from funcy import cached_property, retry, compose, decorator from funcy.py3 import cat @@ -46,6 +47,10 @@ def _wrap_pydrive_retriable(call): ) +def get_tmp_filepath(): + return posixpath.join(".dvc", "tmp", str(uuid.uuid4())) + + class RemoteGDrive(RemoteBASE): scheme = Schemes.GDRIVE path_cls = CloudURLInfo @@ -74,18 +79,22 @@ def init_drive(self): "secret in DVC's config. Learn more at " "https://man.dvc.org/remote/add." ) - self.gdrive_user_credentials_path = self.config.get( - Config.SECTION_GDRIVE_USER_CREDENTIALS_FILE, - self.DEFAULT_USER_CREDENTIALS_FILE, + self.gdrive_user_credentials_path = ( + get_tmp_filepath() + if os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA) + else self.config.get( + Config.SECTION_GDRIVE_USER_CREDENTIALS_FILE, + self.DEFAULT_USER_CREDENTIALS_FILE, + ) ) self.root_id = self.get_path_id(self.path_info, create=True) self.cached_dirs, self.cached_ids = self.cache_root_dirs() - def request_list_file(self, query): - return self.drive.ListFile({"q": query, "maxResults": 1000}).GetList() + def gdrive_list_file(self, query): + return self.drive.ListFile({"q": query, "maxResults": 1}).GetList() - def request_create_folder(self, title, parent_id): + def gdrive_create_folder(self, title, parent_id): item = self.drive.CreateFile( { "title": title, @@ -96,7 +105,7 @@ def request_create_folder(self, title, parent_id): item.Upload() return item - def request_upload_file( + def gdrive_upload_file( self, args, no_progress_bar=True, from_file="", progress_name="" ): item = self.drive.CreateFile( @@ -113,7 +122,7 @@ def upload_file(self, item, no_progress_bar, from_file, progress_name): item.content = opened_file item.Upload() - def request_download_file( + def gdrive_download_file( self, file_id, to_file, progress_name, no_progress_bar ): from dvc.progress import Tqdm @@ -125,7 +134,7 @@ def request_download_file( if not no_progress_bar: tqdm.close() - def list_drive_item(self, query): + def gdrive_list_item(self, query): file_list = self.drive.ListFile({"q": query, "maxResults": 1000}) # Isolate and decorate fetching of remote drive items in pages @@ -137,7 +146,7 @@ def list_drive_item(self, query): def cache_root_dirs(self): cached_dirs = {} cached_ids = {} - for dir1 in self.list_drive_item( + for dir1 in self.gdrive_list_item( "'{}' in parents and trashed=false".format(self.root_id) ): cached_dirs.setdefault(dir1["title"], []).append(dir1["id"]) @@ -189,7 +198,7 @@ def drive(self): def create_drive_item(self, parent_id, title): return gdrive_retry( - lambda: self.request_create_folder(title, parent_id) + lambda: self.gdrive_create_folder(title, parent_id) )() def get_drive_item(self, name, parents_ids): @@ -201,7 +210,7 @@ def get_drive_item(self, name, parents_ids): query += " and trashed=false and title='{}'".format(name) - item_list = gdrive_retry(lambda: self.request_list_file(query))() + item_list = gdrive_retry(lambda: self.gdrive_list_file(query))() return next(iter(item_list), None) def resolve_remote_file(self, parents_ids, path_parts, create): @@ -260,7 +269,7 @@ def _upload(self, from_file, to_info, name, no_progress_bar): parent_id = to_info.bucket gdrive_retry( - lambda: self.request_upload_file( + lambda: self.gdrive_upload_file( {"title": to_info.name, "parent_id": parent_id}, no_progress_bar, from_file, @@ -271,7 +280,7 @@ def _upload(self, from_file, to_info, name, no_progress_bar): def _download(self, from_info, to_file, name, no_progress_bar): file_id = self.get_path_id(from_info) gdrive_retry( - lambda: self.request_download_file( + lambda: self.gdrive_download_file( file_id, to_file, name, no_progress_bar ) )() @@ -290,7 +299,7 @@ def list_file_path(self, drive_file): yield drive_file["title"] def list_path(self, parent_id): - for file1 in self.list_drive_item( + for file1 in self.gdrive_list_item( "'{}' in parents and trashed=false".format(parent_id) ): for path in self.list_file_path(file1): @@ -305,7 +314,7 @@ def all(self): ) query += " and trashed=false" - for file1 in self.list_drive_item(query): + for file1 in self.gdrive_list_item(query): parent_id = file1["parents"][0]["id"] path = posixpath.join(self.cached_ids[parent_id], file1["title"]) try: From 5294cff3962103a35d1a4a785b4263d7a5ca27e1 Mon Sep 17 00:00:00 2001 From: Max Risuhin Date: Wed, 20 Nov 2019 02:01:32 -0800 Subject: [PATCH 5/8] Use tmp_fname from utils --- dvc/remote/gdrive/__init__.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/dvc/remote/gdrive/__init__.py b/dvc/remote/gdrive/__init__.py index ae95e8e262..914999aac0 100644 --- a/dvc/remote/gdrive/__init__.py +++ b/dvc/remote/gdrive/__init__.py @@ -3,7 +3,6 @@ import os import posixpath import logging -import uuid from funcy import cached_property, retry, compose, decorator from funcy.py3 import cat @@ -14,6 +13,7 @@ from dvc.remote.base import RemoteBASE from dvc.config import Config from dvc.exceptions import DvcException +from dvc.utils import tmp_fname logger = logging.getLogger(__name__) @@ -47,10 +47,6 @@ def _wrap_pydrive_retriable(call): ) -def get_tmp_filepath(): - return posixpath.join(".dvc", "tmp", str(uuid.uuid4())) - - class RemoteGDrive(RemoteBASE): scheme = Schemes.GDRIVE path_cls = CloudURLInfo @@ -80,7 +76,7 @@ def init_drive(self): "https://man.dvc.org/remote/add." ) self.gdrive_user_credentials_path = ( - get_tmp_filepath() + tmp_fname(".dvc/tmp/") if os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA) else self.config.get( Config.SECTION_GDRIVE_USER_CREDENTIALS_FILE, From b883bb973dd79b3204b24b54a9a9f471aac97da5 Mon Sep 17 00:00:00 2001 From: Max Risuhin Date: Wed, 20 Nov 2019 15:11:51 -0800 Subject: [PATCH 6/8] Use gdrive_retry as decorator;Use Tqdm with disable param; Rename methods. --- dvc/remote/gdrive/__init__.py | 86 ++++++++++++++++------------------- 1 file changed, 39 insertions(+), 47 deletions(-) diff --git a/dvc/remote/gdrive/__init__.py b/dvc/remote/gdrive/__init__.py index 914999aac0..9bc9cfc2da 100644 --- a/dvc/remote/gdrive/__init__.py +++ b/dvc/remote/gdrive/__init__.py @@ -84,23 +84,9 @@ def init_drive(self): ) ) - self.root_id = self.get_path_id(self.path_info, create=True) + self.root_id = self.get_remote_id(self.path_info, create=True) self.cached_dirs, self.cached_ids = self.cache_root_dirs() - def gdrive_list_file(self, query): - return self.drive.ListFile({"q": query, "maxResults": 1}).GetList() - - def gdrive_create_folder(self, title, parent_id): - item = self.drive.CreateFile( - { - "title": title, - "parents": [{"id": parent_id}], - "mimeType": FOLDER_MIME_TYPE, - } - ) - item.Upload() - return item - def gdrive_upload_file( self, args, no_progress_bar=True, from_file="", progress_name="" ): @@ -124,11 +110,8 @@ def gdrive_download_file( from dvc.progress import Tqdm gdrive_file = self.drive.CreateFile({"id": file_id}) - if not no_progress_bar: - tqdm = Tqdm(desc=progress_name, total=int(gdrive_file["fileSize"])) - gdrive_file.GetContentFile(to_file) - if not no_progress_bar: - tqdm.close() + with Tqdm(desc=progress_name, total=int(gdrive_file["fileSize"]), disable=no_progress_bar): + gdrive_file.GetContentFile(to_file) def gdrive_list_item(self, query): file_list = self.drive.ListFile({"q": query, "maxResults": 1000}) @@ -192,12 +175,20 @@ def drive(self): gdrive = GoogleDrive(gauth) return gdrive - def create_drive_item(self, parent_id, title): - return gdrive_retry( - lambda: self.gdrive_create_folder(title, parent_id) - )() + @gdrive_retry + def create_remote_dir(self, parent_id, title): + item = self.drive.CreateFile( + { + "title": title, + "parents": [{"id": parent_id}], + "mimeType": FOLDER_MIME_TYPE, + } + ) + item.Upload() + return item - def get_drive_item(self, name, parents_ids): + @gdrive_retry + def get_remote_item(self, name, parents_ids): if not parents_ids: return None query = " or ".join( @@ -206,14 +197,15 @@ def get_drive_item(self, name, parents_ids): query += " and trashed=false and title='{}'".format(name) - item_list = gdrive_retry(lambda: self.gdrive_list_file(query))() + # Limit found remote items count to 1 in response + item_list = self.drive.ListFile({"q": query, "maxResults": 1}).GetList() return next(iter(item_list), None) - def resolve_remote_file(self, parents_ids, path_parts, create): + def resolve_remote_item_from_path(self, parents_ids, path_parts, create): for path_part in path_parts: - item = self.get_drive_item(path_part, parents_ids) + item = self.get_remote_item(path_part, parents_ids) if not item and create: - item = self.create_drive_item(parents_ids[0], path_part) + item = self.create_remote_dir(parents_ids[0], path_part) elif not item: return None parents_ids = [item["id"]] @@ -230,7 +222,7 @@ def subtract_root_path(self, parts): break return parts, [self.root_id] - def get_path_id_from_cache(self, path_info): + def get_remote_id_from_cache(self, path_info): files_ids = [] parts, parents_ids = self.subtract_root_path(path_info.path.split("/")) if ( @@ -245,22 +237,22 @@ def get_path_id_from_cache(self, path_info): return files_ids, parents_ids, parts - def get_path_id(self, path_info, create=False): - files_ids, parents_ids, parts = self.get_path_id_from_cache(path_info) + def get_remote_id(self, path_info, create=False): + files_ids, parents_ids, parts = self.get_remote_id_from_cache(path_info) if not parts and files_ids: return files_ids[0] - file1 = self.resolve_remote_file(parents_ids, parts, create) + file1 = self.resolve_remote_item_from_path(parents_ids, parts, create) return file1["id"] if file1 else "" def exists(self, path_info): - return self.get_path_id(path_info) != "" + return self.get_remote_id(path_info) != "" def _upload(self, from_file, to_info, name, no_progress_bar): dirname = to_info.parent if dirname: - parent_id = self.get_path_id(dirname, True) + parent_id = self.get_remote_id(dirname, True) else: parent_id = to_info.bucket @@ -274,7 +266,7 @@ def _upload(self, from_file, to_info, name, no_progress_bar): )() def _download(self, from_info, to_file, name, no_progress_bar): - file_id = self.get_path_id(from_info) + file_id = self.get_remote_id(from_info) gdrive_retry( lambda: self.gdrive_download_file( file_id, to_file, name, no_progress_bar @@ -282,25 +274,25 @@ def _download(self, from_info, to_file, name, no_progress_bar): )() def list_cache_paths(self): - file_id = self.get_path_id(self.path_info) + file_id = self.get_remote_id(self.path_info) prefix = self.path_info.path - for path in self.list_path(file_id): + for path in self.list_children(file_id): yield posixpath.join(prefix, path) - def list_file_path(self, drive_file): - if drive_file["mimeType"] == FOLDER_MIME_TYPE: - for i in self.list_path(drive_file["id"]): - yield posixpath.join(drive_file["title"], i) - else: - yield drive_file["title"] - - def list_path(self, parent_id): + def list_children(self, parent_id): for file1 in self.gdrive_list_item( "'{}' in parents and trashed=false".format(parent_id) ): - for path in self.list_file_path(file1): + for path in self.list_remote_item(file1): yield path + def list_remote_item(self, drive_file): + if drive_file["mimeType"] == FOLDER_MIME_TYPE: + for i in self.list_children(drive_file["id"]): + yield posixpath.join(drive_file["title"], i) + else: + yield drive_file["title"] + def all(self): if not hasattr(self, "cached_ids") or not self.cached_ids: return From 9f6d4696cf6fd298c4ce25ce550b8ca3f1327312 Mon Sep 17 00:00:00 2001 From: Max Risuhin Date: Thu, 21 Nov 2019 15:00:32 -0800 Subject: [PATCH 7/8] gdrive lazy init --- dvc/remote/gdrive/__init__.py | 144 +++++++++++++++++++--------------- 1 file changed, 80 insertions(+), 64 deletions(-) diff --git a/dvc/remote/gdrive/__init__.py b/dvc/remote/gdrive/__init__.py index 9bc9cfc2da..c28b6d4535 100644 --- a/dvc/remote/gdrive/__init__.py +++ b/dvc/remote/gdrive/__init__.py @@ -3,8 +3,9 @@ import os import posixpath import logging +import threading -from funcy import cached_property, retry, compose, decorator +from funcy import cached_property, retry, compose, decorator, wrap_with from funcy.py3 import cat from dvc.remote.gdrive.utils import TrackFileReadProgress, FOLDER_MIME_TYPE @@ -84,9 +85,6 @@ def init_drive(self): ) ) - self.root_id = self.get_remote_id(self.path_info, create=True) - self.cached_dirs, self.cached_ids = self.cache_root_dirs() - def gdrive_upload_file( self, args, no_progress_bar=True, from_file="", progress_name="" ): @@ -132,48 +130,66 @@ def cache_root_dirs(self): cached_ids[dir1["id"]] = dir1["title"] return cached_dirs, cached_ids - @cached_property + @property + def cached_dirs(self): + if not hasattr(self, '_cached_dirs'): + self.drive + return self._cached_dirs + + @property + def cached_ids(self): + if not hasattr(self, '_cached_ids'): + self.drive + return self._cached_ids + + @property + @wrap_with(threading.RLock()) def drive(self): - from pydrive.auth import GoogleAuth - from pydrive.drive import GoogleDrive - - if os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA): - with open( - self.gdrive_user_credentials_path, "w" - ) as credentials_file: - credentials_file.write( - os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA) - ) - - GoogleAuth.DEFAULT_SETTINGS["client_config_backend"] = "settings" - GoogleAuth.DEFAULT_SETTINGS["client_config"] = { - "client_id": self.gdrive_client_id, - "client_secret": self.gdrive_client_secret, - "auth_uri": "https://accounts.google.com/o/oauth2/auth", - "token_uri": "https://oauth2.googleapis.com/token", - "revoke_uri": "https://oauth2.googleapis.com/revoke", - "redirect_uri": "", - } - GoogleAuth.DEFAULT_SETTINGS["save_credentials"] = True - GoogleAuth.DEFAULT_SETTINGS["save_credentials_backend"] = "file" - GoogleAuth.DEFAULT_SETTINGS[ - "save_credentials_file" - ] = self.gdrive_user_credentials_path - GoogleAuth.DEFAULT_SETTINGS["get_refresh_token"] = True - GoogleAuth.DEFAULT_SETTINGS["oauth_scope"] = [ - "https://www.googleapis.com/auth/drive", - "https://www.googleapis.com/auth/drive.appdata", - ] - - # Pass non existent settings path to force DEFAULT_SETTINGS loading - gauth = GoogleAuth(settings_file="") - gauth.CommandLineAuth() - - if os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA): - os.remove(self.gdrive_user_credentials_path) - - gdrive = GoogleDrive(gauth) - return gdrive + if not hasattr(self, '_gdrive'): + from pydrive.auth import GoogleAuth + from pydrive.drive import GoogleDrive + + if os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA): + with open( + self.gdrive_user_credentials_path, "w" + ) as credentials_file: + credentials_file.write( + os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA) + ) + + GoogleAuth.DEFAULT_SETTINGS["client_config_backend"] = "settings" + GoogleAuth.DEFAULT_SETTINGS["client_config"] = { + "client_id": self.gdrive_client_id, + "client_secret": self.gdrive_client_secret, + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + "revoke_uri": "https://oauth2.googleapis.com/revoke", + "redirect_uri": "", + } + GoogleAuth.DEFAULT_SETTINGS["save_credentials"] = True + GoogleAuth.DEFAULT_SETTINGS["save_credentials_backend"] = "file" + GoogleAuth.DEFAULT_SETTINGS[ + "save_credentials_file" + ] = self.gdrive_user_credentials_path + GoogleAuth.DEFAULT_SETTINGS["get_refresh_token"] = True + GoogleAuth.DEFAULT_SETTINGS["oauth_scope"] = [ + "https://www.googleapis.com/auth/drive", + "https://www.googleapis.com/auth/drive.appdata", + ] + + # Pass non existent settings path to force DEFAULT_SETTINGS loading + gauth = GoogleAuth(settings_file="") + gauth.CommandLineAuth() + + if os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA): + os.remove(self.gdrive_user_credentials_path) + + self._gdrive = GoogleDrive(gauth) + + self.root_id = self.get_remote_id(self.path_info, create=True) + self._cached_dirs, self._cached_ids = self.cache_root_dirs() + + return self._gdrive @gdrive_retry def create_remote_dir(self, parent_id, title): @@ -211,39 +227,39 @@ def resolve_remote_item_from_path(self, parents_ids, path_parts, create): parents_ids = [item["id"]] return item - def subtract_root_path(self, parts): + def subtract_root_path(self, path_parts): if not hasattr(self, "root_id"): - return parts, [self.path_info.bucket] + return path_parts, [self.path_info.bucket] for part in self.path_info.path.split("/"): - if parts and parts[0] == part: - parts.pop(0) + if path_parts and path_parts[0] == part: + path_parts.pop(0) else: break - return parts, [self.root_id] + return path_parts, [self.root_id] def get_remote_id_from_cache(self, path_info): - files_ids = [] - parts, parents_ids = self.subtract_root_path(path_info.path.split("/")) + remote_ids = [] + path_parts, parents_ids = self.subtract_root_path(path_info.path.split("/")) if ( - hasattr(self, "cached_dirs") + hasattr(self, "_cached_dirs") and path_info != self.path_info - and parts - and (parts[0] in self.cached_dirs) + and path_parts + and (path_parts[0] in self.cached_dirs) ): - parents_ids = self.cached_dirs[parts[0]] - files_ids = self.cached_dirs[parts[0]] - parts.pop(0) + parents_ids = self.cached_dirs[path_parts[0]] + remote_ids = self.cached_dirs[path_parts[0]] + path_parts.pop(0) - return files_ids, parents_ids, parts + return remote_ids, parents_ids, path_parts def get_remote_id(self, path_info, create=False): - files_ids, parents_ids, parts = self.get_remote_id_from_cache(path_info) + remote_ids, parents_ids, path_parts = self.get_remote_id_from_cache(path_info) - if not parts and files_ids: - return files_ids[0] + if not path_parts and remote_ids: + return remote_ids[0] - file1 = self.resolve_remote_item_from_path(parents_ids, parts, create) + file1 = self.resolve_remote_item_from_path(parents_ids, path_parts, create) return file1["id"] if file1 else "" def exists(self, path_info): @@ -294,7 +310,7 @@ def list_remote_item(self, drive_file): yield drive_file["title"] def all(self): - if not hasattr(self, "cached_ids") or not self.cached_ids: + if not self.cached_ids: return query = " or ".join( From b4a8377582a27acc16d87f545aa55e1ff6bae1fc Mon Sep 17 00:00:00 2001 From: "Restyled.io" Date: Fri, 22 Nov 2019 00:03:33 +0000 Subject: [PATCH 8/8] Restyled by black --- dvc/remote/gdrive/__init__.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/dvc/remote/gdrive/__init__.py b/dvc/remote/gdrive/__init__.py index c28b6d4535..5c2c27dab5 100644 --- a/dvc/remote/gdrive/__init__.py +++ b/dvc/remote/gdrive/__init__.py @@ -108,7 +108,11 @@ def gdrive_download_file( from dvc.progress import Tqdm gdrive_file = self.drive.CreateFile({"id": file_id}) - with Tqdm(desc=progress_name, total=int(gdrive_file["fileSize"]), disable=no_progress_bar): + with Tqdm( + desc=progress_name, + total=int(gdrive_file["fileSize"]), + disable=no_progress_bar, + ): gdrive_file.GetContentFile(to_file) def gdrive_list_item(self, query): @@ -132,20 +136,20 @@ def cache_root_dirs(self): @property def cached_dirs(self): - if not hasattr(self, '_cached_dirs'): + if not hasattr(self, "_cached_dirs"): self.drive return self._cached_dirs @property def cached_ids(self): - if not hasattr(self, '_cached_ids'): + if not hasattr(self, "_cached_ids"): self.drive return self._cached_ids @property @wrap_with(threading.RLock()) def drive(self): - if not hasattr(self, '_gdrive'): + if not hasattr(self, "_gdrive"): from pydrive.auth import GoogleAuth from pydrive.drive import GoogleDrive @@ -214,7 +218,9 @@ def get_remote_item(self, name, parents_ids): query += " and trashed=false and title='{}'".format(name) # Limit found remote items count to 1 in response - item_list = self.drive.ListFile({"q": query, "maxResults": 1}).GetList() + item_list = self.drive.ListFile( + {"q": query, "maxResults": 1} + ).GetList() return next(iter(item_list), None) def resolve_remote_item_from_path(self, parents_ids, path_parts, create): @@ -240,7 +246,9 @@ def subtract_root_path(self, path_parts): def get_remote_id_from_cache(self, path_info): remote_ids = [] - path_parts, parents_ids = self.subtract_root_path(path_info.path.split("/")) + path_parts, parents_ids = self.subtract_root_path( + path_info.path.split("/") + ) if ( hasattr(self, "_cached_dirs") and path_info != self.path_info @@ -254,12 +262,16 @@ def get_remote_id_from_cache(self, path_info): return remote_ids, parents_ids, path_parts def get_remote_id(self, path_info, create=False): - remote_ids, parents_ids, path_parts = self.get_remote_id_from_cache(path_info) + remote_ids, parents_ids, path_parts = self.get_remote_id_from_cache( + path_info + ) if not path_parts and remote_ids: return remote_ids[0] - file1 = self.resolve_remote_item_from_path(parents_ids, path_parts, create) + file1 = self.resolve_remote_item_from_path( + parents_ids, path_parts, create + ) return file1["id"] if file1 else "" def exists(self, path_info):