diff --git a/open_mastr/mastr.py b/open_mastr/mastr.py index 23af6f5b..918b375e 100644 --- a/open_mastr/mastr.py +++ b/open_mastr/mastr.py @@ -30,7 +30,6 @@ create_data_dir, get_data_version_dir, get_project_home_dir, - get_output_dir, setup_logger, ) import open_mastr.utils.orm as orm @@ -74,9 +73,8 @@ class Mastr: def __init__(self, engine="sqlite", connect_to_translated_db=False) -> None: validate_parameter_format_for_mastr_init(engine) - self.output_dir = get_output_dir() self.home_directory = get_project_home_dir() - self._sqlite_folder_path = os.path.join(self.output_dir, "data", "sqlite") + self._sqlite_folder_path = os.path.join(self.home_directory, "data", "sqlite") os.makedirs(self._sqlite_folder_path, exist_ok=True) self.is_translated = connect_to_translated_db @@ -194,7 +192,7 @@ def download( # Find the name of the zipped xml folder bulk_download_date = parse_date_string(date) - xml_folder_path = os.path.join(self.output_dir, "data", "xml_download") + xml_folder_path = os.path.join(self.home_directory, "data", "xml_download") os.makedirs(xml_folder_path, exist_ok=True) zipped_xml_file_path = os.path.join( xml_folder_path, diff --git a/open_mastr/utils/config.py b/open_mastr/utils/config.py index 40f67ec8..c1d3a6dc 100644 --- a/open_mastr/utils/config.py +++ b/open_mastr/utils/config.py @@ -20,7 +20,7 @@ import os import yaml import shutil -import pathlib +from pathlib import Path from datetime import date import logging @@ -35,7 +35,7 @@ log = logging.getLogger(__name__) -def get_project_home_dir(): +def get_project_home_dir() -> Path: """Get root dir of project data On linux this path equals `$HOME/.open-MaStR/`, respectively `~/.open-MaStR/` @@ -46,23 +46,18 @@ def get_project_home_dir(): path-like object Absolute path to root dir of open-MaStR project home """ - - return os.path.join(os.path.expanduser("~"), ".open-MaStR") - - -def get_output_dir(): - """Get output directory for csv data, xml file and database. Defaults to get_project_home_dir() - - Returns - ------- - path-like object - Absolute path to output path - """ - - if "OUTPUT_PATH" in os.environ: - return os.environ.get("OUTPUT_PATH") - - return get_project_home_dir() + default_dir = Path.home() / ".open-MaStR" + env_path = os.getenv("OUTPUT_PATH") + if env_path: + try: + candidate = Path(env_path).expanduser().resolve(strict=False) + return candidate + except Exception as exc: + print( + f"Warning: could not use OUTPUT_PATH='{env_path}' ({exc}). " + f"Falling back to default '{default_dir}'." + ) + return default_dir def get_data_version_dir(): @@ -78,10 +73,7 @@ def get_data_version_dir(): """ data_version = get_data_config() - if "OUTPUT_PATH" in os.environ: - return os.path.join(os.environ.get("OUTPUT_PATH"), "data", data_version) - - return os.path.join(get_project_home_dir(), "data", data_version) + return get_project_home_dir() / "data" / data_version def get_filenames(): @@ -93,9 +85,7 @@ def get_filenames(): dict File names used in open-MaStR """ - with open( - os.path.join(get_project_home_dir(), "config", "filenames.yml") - ) as filename_fh: + with open(get_project_home_dir() / "config" / "filenames.yml") as filename_fh: filenames = yaml.safe_load(filename_fh) return filenames @@ -113,7 +103,7 @@ def get_data_config(): today = date.today() - data_config = f'dataversion-{today.strftime("%Y-%m-%d")}' + data_config = f"dataversion-{today.strftime('%Y-%m-%d')}" return data_config @@ -123,7 +113,7 @@ def create_project_home_dir(): project_home = get_project_home_dir() # Create root project home path - if not os.path.isdir(project_home): + if not project_home.is_dir(): # Create project home log.info(f"Create {project_home} used for config, parameters and data.") os.mkdir(project_home) @@ -135,12 +125,10 @@ def create_project_home_dir(): os.mkdir(subdir) # copy default config files - config_path = os.path.join(get_project_home_dir(), "config") + config_path = get_project_home_dir() / "config" log.info(f"I will create a default set of config files in {config_path}") - internal_config_dir = os.path.join( - pathlib.Path(__file__).parent.absolute(), "config" - ) + internal_config_dir = os.path.join(Path(__file__).parent.absolute(), "config") files = ["logging.yml"] for file in files: @@ -165,7 +153,7 @@ def create_data_dir(): def _filenames_generator(): """Write default file names .yml to project home dir""" - filenames_file = os.path.join(get_project_home_dir(), "config", "filenames.yml") + filenames_file = get_project_home_dir() / "config" / "filenames.yml" # How files are prefixed prefix = "bnetza_mastr" @@ -207,7 +195,6 @@ def _filenames_generator(): for section, section_filenames in filenames_template.items(): filenames[section] = {} for tech in TECHNOLOGIES: - # Files for all technologies files = ["joined", "basic", "extended", "extended_fail"] @@ -282,14 +269,12 @@ def setup_logger(): """ # Read logging config - with open( - os.path.join(get_project_home_dir(), "config", "logging.yml") - ) as filename_fh: + with open(get_project_home_dir() / "config" / "logging.yml") as filename_fh: logging_config = yaml.safe_load(filename_fh) # Add logfile location - logging_config["handlers"]["file"]["filename"] = os.path.join( - get_project_home_dir(), "logs", "open_mastr.log" + logging_config["handlers"]["file"]["filename"] = ( + get_project_home_dir() / "logs" / "open_mastr.log" ) logging.config.dictConfig(logging_config)