diff --git a/src/fileLoaders.py b/src/fileLoaders.py
index 366ab7c6..b6c6bacd 100644
--- a/src/fileLoaders.py
+++ b/src/fileLoaders.py
@@ -23,10 +23,36 @@
 # util variables
 glogger = glogging.getGrlcLogger(__name__)
 
-
 class BaseLoader:
     """Base class for File Loaders"""
 
+    def _getLicenseFileFromPath(self, path):
+        """Returns the URL of a license file in the specified path, or None if not found."""
+        try:
+            files = self._fetchFilesFromPath(path)
+            for f in files:
+                if f["name"].lower() in ("license", "licence"):
+                    return f["download_url"]
+        except Exception as e:
+            glogger.debug("License lookup failed for path '%s': %s", path, e)
+        return None
+
+    def _fetchFilesFromPath(self, path):
+        """To be implemented by sub-classes. Returns a list of file items from the specified path."""
+        raise NotImplementedError("Subclasses must override _fetchFilesFromPath()!")
+
+    def getLicenceURL(self):
+        """Returns the URL of the license file in this repository if one exists.
+        Looks in the loader's subdirectory first, then falls back to the root folder."""
+        # A missing, empty or slash-only subdir means "root": skip the extra lookup
+        subdir = (getattr(self, "subdir", None) or "").strip("/")
+        if subdir:
+            licence_url = self._getLicenseFileFromPath(subdir)
+            if licence_url:
+                return licence_url
+        # If no license found in subdirectory, check root folder
+        return self._getLicenseFileFromPath("")
+
     def getTextForName(self, query_name):
         """Return the query text and query type for the given query name.
         Note that file extention is not part of the query name.
For example,
@@ -92,7 +118,12 @@ def __init__(self, user, repo, subdir=None, sha=None, prov=None):
 
     def fetchFiles(self):
         """Returns a list of file items contained on the github repo."""
-        contents = self.gh_repo.get_contents(self.subdir.strip("/"), ref=self.sha)
+        return self._fetchFilesFromPath(self.subdir)
+
+    def _fetchFilesFromPath(self, path):
+        """Returns a list of file items from the specified path in the github repo."""
+        path = (path or "").strip("/")  # None/empty means the repository root
+        contents = self.gh_repo.get_contents(path, ref=self.sha)
         files = []
         for content_file in contents:
             if content_file.type == "file":
@@ -161,13 +192,6 @@ def getEndpointText(self):
         """Return content of endpoint file (endpoint.txt)"""
         return self._getText("endpoint.txt")
 
-    def getLicenceURL(self):
-        """Returns the URL of the license file in this repository if one exists."""
-        for f in self.fetchFiles():
-            if f["name"].lower() == "license" or f["name"].lower() == "licence":
-                return f["download_url"]
-        return None
-
     def getRepoDescription(self):
         """Return the description of the repository"""
         return self.gh_repo.description
@@ -206,9 +230,14 @@ def __init__(self, user, repo, subdir=None, sha=None, prov=None, branch=None):
             raise Exception("Repo not found: " + user + "/" + repo)
 
     def fetchFiles(self):
-        """Returns a list of file items contained on the github repo."""
+        """Returns a list of file items contained on the gitlab repo."""
+        return self._fetchFilesFromPath(self.subdir)
+
+    def _fetchFilesFromPath(self, path):
+        """Returns a list of file items from the specified path in the gitlab repo."""
+        path = (path or "").strip("/")  # None/empty means the repository root
         gitlab_files = self.gl_repo.repository_tree(
-            path=self.subdir.strip("/"), ref=self.branch, all=True
+            path=path, ref=self.branch, all=True
         )
         files = []
         for gitlab_file in gitlab_files:
@@ -217,7 +246,7 @@ def fetchFiles(self):
                 files.append(
                     {
-                        "download_url": path.join(
-                            self.getRawRepoUri(), self.subdir, name
+                        "download_url": "/".join(
+                            filter(None, (self.getRawRepoUri().rstrip("/"), path, name))
                         ),
                         "name": name,
                         "decoded_content": str.encode(
@@ -281,13
+310,6 @@ def getEndpointText(self):
         """Return content of endpoint file (endpoint.txt)"""
         return self._getText("endpoint.txt")
 
-    def getLicenceURL(self):
-        """Returns the URL of the license file in this repository if one exists."""
-        for f in self.fetchFiles():
-            if f["name"].lower() == "license" or f["name"].lower() == "licence":
-                return f["download_url"]
-        return None
-
     def getRepoDescription(self):
         """Return the description of the repository"""
         return self.gl_repo.description