diff --git a/dash/_pages.py b/dash/_pages.py index ab97ba80df..74ef219796 100644 --- a/dash/_pages.py +++ b/dash/_pages.py @@ -1,6 +1,7 @@ import collections import importlib import os +import pkgutil import re import sys from fnmatch import fnmatch @@ -426,6 +427,15 @@ def _page_meta_tags(app): ] +def _ensure_layout_is_loaded(module_name, page_module): + if ( + module_name in PAGE_REGISTRY + and not PAGE_REGISTRY[module_name]["supplied_layout"] + ): + _validate.validate_pages_layout(module_name, page_module) + PAGE_REGISTRY[module_name]["layout"] = getattr(page_module, "layout") + + def _import_layouts_from_pages(pages_folder): for root, dirs, files in os.walk(pages_folder): dirs[:] = [d for d in dirs if not d.startswith(".") and not d.startswith("_")] @@ -443,10 +453,13 @@ def _import_layouts_from_pages(pages_folder): page_module = importlib.util.module_from_spec(spec) spec.loader.exec_module(page_module) sys.modules[module_name] = page_module + _ensure_layout_is_loaded(module_name, page_module) - if ( - module_name in PAGE_REGISTRY - and not PAGE_REGISTRY[module_name]["supplied_layout"] - ): - _validate.validate_pages_layout(module_name, page_module) - PAGE_REGISTRY[module_name]["layout"] = getattr(page_module, "layout") + +def _import_layouts_from_package(pages_package): + modules = pkgutil.walk_packages( + pages_package.__path__, prefix=pages_package.__name__ + "." + ) + for module in modules: + page_module = importlib.import_module(module.name) + _ensure_layout_is_loaded(module.name, page_module) diff --git a/dash/dash.py b/dash/dash.py index 4e38059aa1..31a550781c 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -73,6 +73,7 @@ _page_meta_tags, _path_to_page, _import_layouts_from_pages, + _import_layouts_from_package, ) from ._jupyter import jupyter_dash, JupyterDisplayMode from .types import RendererHooks @@ -384,6 +385,7 @@ def __init__( # pylint: disable=too-many-statements server=True, assets_folder="assets", pages_folder="pages", + pages_package=None, use_pages=None, assets_url_path="assets", assets_ignore="", @@ -490,6 +492,7 @@ def __init__( # pylint: disable=too-many-statements _get_paths.CONFIG = self.config _pages.CONFIG = self.config + self.pages_package = pages_package self.pages_folder = str(pages_folder) self.use_pages = (pages_folder != "pages") if use_pages is None else use_pages self.routing_callback_inputs = routing_callback_inputs or {} @@ -2164,6 +2167,8 @@ def enable_pages(self): return if self.pages_folder: _import_layouts_from_pages(self.config.pages_folder) + if self.pages_package: + _import_layouts_from_package(self.pages_package) @self.server.before_request def router(): diff --git a/tests/conftest.py b/tests/conftest.py index c5801b3ddc..0eaf715607 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ import os +import sys import pytest import dash @@ -21,6 +22,9 @@ def clear_pages_state(): def init_pages_state(): """Clear all global state that is used by pages feature.""" + for page in dash._pages.PAGE_REGISTRY.values(): + if page["module"] in sys.modules: + sys.modules.pop(page["module"]) dash._pages.PAGE_REGISTRY.clear() dash._pages.CONFIG.clear() dash._pages.CONFIG.__dict__.clear() diff --git a/tests/unit/pages/test_pages.py b/tests/unit/pages/test_pages.py index df943c8f9b..1ed94260f6 100644 --- a/tests/unit/pages/test_pages.py +++ b/tests/unit/pages/test_pages.py @@ -76,3 +76,12 @@ def test_import_layouts_from_pages( page_entry = list(dash.page_registry.values())[0] assert page_entry["module"] == expected_module_name + + +def test_import_layouts_from_package(clear_pages_state): + from . import custom_pages + + _ = Dash(__package__, use_pages=True, pages_folder="", pages_package=custom_pages) + page_entries = list(dash.page_registry.values()) + assert len(page_entries) == 1 + assert page_entries[0]["module"] == "pages.custom_pages.page"