Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 66 additions & 21 deletions faker/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import copy
import functools
import logging
import re

from collections import OrderedDict
Expand All @@ -19,10 +20,13 @@

RetType = TypeVar("RetType")

logger = logging.getLogger(__name__)


class Faker:
"""Proxy class capable of supporting multiple locales"""

cache_attr_name = "_cached_{method_name}_mapping"
cache_pattern: Pattern = re.compile(r"^_cached_\w*_mapping$")
generator_attrs = [
attr for attr in dir(Generator) if not attr.startswith("__") and attr not in ["seed", "seed_instance", "random"]
Expand All @@ -39,8 +43,7 @@ def __init__(
) -> None:
self._factory_map: OrderedDict[str, Generator | Faker] = OrderedDict()
self._weights = None
self._unique_proxy = UniqueProxy(self)
self._optional_proxy = OptionalProxy(self)
self._last_used_factory_map: dict[str, Generator | Faker] = {}

if isinstance(locale, str):
locales = [locale.replace("-", "_")]
Expand Down Expand Up @@ -146,21 +149,27 @@ def __deepcopy__(self, memodict):
result._factory_map = copy.deepcopy(self._factory_map, memodict)
result._factories = list(result._factory_map.values())
result._weights = copy.deepcopy(self._weights, memodict)
result._unique_proxy = UniqueProxy(result)
result._unique_proxy._seen = {k: {result._unique_proxy._sentinel} for k in self._unique_proxy._seen.keys()}
result._optional_proxy = OptionalProxy(result)
result._last_used_factory_map = copy.deepcopy(self._last_used_factory_map)
result.unique._seen = {k: {result.unique._sentinel} for k in self.unique._seen.keys()}
result.preferred_unique._seen = {
k: {result.preferred_unique._sentinel} for k in self.preferred_unique._seen.keys()
}
return result

def __setstate__(self, state: Any) -> None:
    """Restore the instance by merging ``state`` into its attribute dictionary.

    :param state: mapping of attribute names to values; every entry is
        copied onto the instance, overwriting attributes of the same name.
    """
    vars(self).update(state)

@property
@functools.cached_property
def unique(self) -> UniqueProxy:
return self._unique_proxy
return UniqueProxy(self)

@property
@functools.cached_property
def preferred_unique(self) -> UniqueProxy:
return UniqueProxy(self, only_prefer_uniqueness=True)

@functools.cached_property
def optional(self) -> OptionalProxy:
return self._optional_proxy
return OptionalProxy(self)

def _select_factory(self, method_name: str) -> Factory:
"""
Expand All @@ -176,12 +185,16 @@ def _select_factory(self, method_name: str) -> Factory:
msg = f"No generator object has attribute {method_name!r}"
raise AttributeError(msg)
elif len(factories) == 1:
self._last_used_factory_map[method_name] = factories[0]
return factories[0]

if weights:
factory = self._select_factory_distribution(factories, weights)
else:
factory = self._select_factory_choice(factories)

self._last_used_factory_map[method_name] = factory

return factory

def _select_factory_distribution(self, factories, weights):
Expand All @@ -202,7 +215,7 @@ def _map_provider_method(self, method_name: str) -> tuple[list[Factory], list[fl
"""

# Return cached mapping if it exists for given method
attr = f"_cached_{method_name}_mapping"
attr = self.cache_attr_name.format(method_name=method_name)
if hasattr(self, attr):
return getattr(self, attr)

Expand Down Expand Up @@ -299,11 +312,12 @@ def items(self) -> list[tuple[str, Generator | Faker]]:


class UniqueProxy:
def __init__(self, proxy: Faker, excluded_types: tuple[type, ...] = ()):
def __init__(self, proxy: Faker, excluded_types: tuple[type, ...] = (), only_prefer_uniqueness: bool = False):
self._proxy = proxy
self._seen: dict = {}
self._sentinel = object()
self._excluded_types = excluded_types
self._only_prefer_uniqueness = only_prefer_uniqueness

def clear(self) -> None:
    """Forget every recorded value by rebinding ``_seen`` to a new empty dict."""
    self._seen = dict()
Expand Down Expand Up @@ -339,9 +353,12 @@ def __getitem__(self, locale: str) -> UniqueProxy:
def __getattr__(self, name: str) -> Any:
obj = getattr(self._proxy, name)
if callable(obj):
return self._wrap(name, obj)
else:
raise TypeError("Accessing non-functions through .unique is not supported.")
if name.startswith("current_"):
self._force_next_current_factory(name)
return obj
elif not name.startswith("__"):
return self._wrap(name, obj)
return obj

def __getstate__(self):
# Copy the object's state from self.__dict__ which contains
Expand All @@ -366,6 +383,9 @@ def _make_hashable(self, value: Any) -> Any:
def _wrap(self, name: str, function: Callable) -> Callable:
@functools.wraps(function)
def wrapper(*args, **kwargs):
key = (name, args, tuple(sorted(kwargs.items())))
generated = self._seen.setdefault(key, {self._sentinel})

# If types are excluded, call function once to check return type
if self._excluded_types:
retval = function(*args, **kwargs)
Expand All @@ -375,8 +395,6 @@ def wrapper(*args, **kwargs):
# If not excluded, continue with normal uniqueness logic
# but we already have a value, so we'll use it if unique
hashable_retval = self._make_hashable(retval)
key = (name, args, tuple(sorted(kwargs.items())))
generated = self._seen.setdefault(key, {self._sentinel})

# Check if this first value is unique
if hashable_retval not in generated:
Expand All @@ -385,26 +403,53 @@ def wrapper(*args, **kwargs):
# Not unique, continue with normal loop below
else:
# No exclusions, use original logic
key = (name, args, tuple(sorted(kwargs.items())))
generated = self._seen.setdefault(key, {self._sentinel})
retval = self._sentinel
hashable_retval = self._make_hashable(retval)

# Original uniqueness logic (with potential first attempt already done)
for i in range(_UNIQUE_ATTEMPTS):
if hashable_retval not in generated:
for _ in range(_UNIQUE_ATTEMPTS):
if hashable_retval is None or hashable_retval not in generated:
break
retval = function(*args, **kwargs)
hashable_retval = self._make_hashable(retval)
else:
raise UniquenessException(f"Got duplicated values after {_UNIQUE_ATTEMPTS:,} iterations.")
if self._only_prefer_uniqueness:
logger.warning(
f'There seem to be no more unique values for generator "{name}". '
"Resetting store of generated values as uniqueness is not being enforced."
)
generated.clear()
else:
raise UniquenessException(f"Got duplicated values after {_UNIQUE_ATTEMPTS:,} iterations.")

generated.add(hashable_retval)

return retval

return wrapper

def _force_next_current_factory(self, name: str) -> None:
    """Shrink and eventually rebuild list of cached factories for generator method.

    Ensures that 'current_*' generators go through all possible provider options
    to make them at least initially unique: the factory recorded as last used
    for ``name`` is removed from the proxy's cached ``(mapping, weights)`` pair,
    and once the mapping is empty the cached attribute is deleted from the proxy.

    The call is a no-op when there is nothing to narrow: a single factory, no
    recorded last-used factory, no cached mapping, or a recorded factory that is
    no longer part of the cached mapping.

    :param name: generator method name whose cached factory mapping is narrowed.
    """
    proxy = self._proxy

    # NOTE(review): the wrapped proxy is assumed to be a multi-locale Faker; an
    # object without a ``factories`` attribute is treated as having nothing to
    # rotate instead of raising AttributeError -- confirm against callers.
    factories = getattr(proxy, "factories", None)
    # No need to re-roll factory list if only one is present
    if factories is None or len(factories) == 1:
        return

    # Compare against None explicitly so a recorded factory is never skipped
    # just because it happens to evaluate as falsy.
    last_used_factory = proxy._last_used_factory_map.get(name)
    if last_used_factory is None:
        return

    attr = proxy.cache_attr_name.format(method_name=name)
    cached = getattr(proxy, attr, None)
    if cached is None:
        # The cached mapping was already reset (or never built)
        return
    mapping, weights = cached

    try:
        last_index = mapping.index(last_used_factory)
    except ValueError:
        # Stale record: the factory was already removed from this mapping
        pass
    else:
        # Delete last used factory to force use of another
        del mapping[last_index]
        if weights:
            # Keep the weights aligned with the remaining factories
            del weights[last_index]

    # Reset provider mapping if no unique options left
    if not mapping:
        delattr(proxy, attr)


class OptionalProxy:
"""
Expand Down
3 changes: 1 addition & 2 deletions tests/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,8 +427,7 @@ def test_dir_include_all_providers_attribute_in_list(self):
"_locales",
"_factory_map",
"_weights",
"_unique_proxy",
"_optional_proxy",
"_last_used_factory_map",
]
)
for factory in fake.factories:
Expand Down
51 changes: 47 additions & 4 deletions tests/test_unique.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import logging

import pytest

from faker import Faker
from faker.config import AVAILABLE_LOCALES, DEFAULT_LOCALE
from faker.exceptions import UniquenessException

LOGGER = logging.getLogger(__name__)


class TestUniquenessClass:
def test_uniqueness(self):
Expand Down Expand Up @@ -52,14 +57,13 @@ def test_exclusive_arguments(self):
# this would throw a sanity exception
fake.unique.random_int(min=2, max=10)

def test_functions_only(self):
def test_accessing_non_function(self):
"""Accessing non-functions through the `.unique` attribute
will throw a TypeError."""
is allowed."""

fake = Faker()

with pytest.raises(TypeError, match="Accessing non-functions through .unique is not supported."):
fake.unique.locales
assert fake.unique.locales == [DEFAULT_LOCALE]

def test_complex_return_types_is_supported(self):
"""The unique decorator supports complex return types
Expand Down Expand Up @@ -95,3 +99,42 @@ def test_unique_locale_access(self):

with pytest.raises(UniquenessException, match=r"Got duplicated values after [\d,]+ iterations."):
fake.unique["ja_JP"].random_int(min=1, max=10)

def test_preferred_uniqueness(self, caplog):
    """Exhausting ``preferred_unique`` logs a warning instead of raising."""
    expected_warning = (
        'There seem to be no more unique values for generator "boolean". '
        "Resetting store of generated values as uniqueness is not being enforced."
    )
    fake = Faker()

    # Three calls are made so that the generator runs out of fresh values
    # (presumably boolean() only has two distinct results).
    with caplog.at_level(logging.WARNING):
        for _ in range(3):
            fake.preferred_unique.boolean()

    assert expected_warning in caplog.text

def test_current_values_exempt_from_unique_check(self):
    """Two ``current_country`` calls through ``.unique`` on a default Faker agree."""
    fake = Faker()

    first_country = fake.unique.current_country()
    second_country = fake.unique.current_country()

    assert first_country == second_country

def test_initial_current_values_with_multiple_locales_are_unique(self):
    """One ``.unique`` call per available locale yields every locale's country code."""
    fake = Faker(AVAILABLE_LOCALES)

    # Reference set: the country code reported by a dedicated Faker per locale
    expected_codes = set()
    for locale in AVAILABLE_LOCALES:
        expected_codes.add(Faker(locale).current_country_code())

    # Draw as many values through ``.unique`` as there are locales
    draws = len(AVAILABLE_LOCALES)
    generated_codes = set()
    for _ in range(draws):
        generated_codes.add(fake.unique.current_country_code())

    assert expected_codes == generated_codes

def test_current_values_start_repeating_after_locales_exhausted(self):
    """After one distinct country per locale, the next value repeats an earlier one."""
    fake = Faker({"en_US": 1, "fr_FR": 2}, use_weighting=True)
    locale_count = len(fake.locales)

    # The first ``locale_count`` draws must all differ
    generated_countries = set()
    for _ in range(locale_count):
        generated_countries.add(fake.unique.current_country())
    assert len(generated_countries) == locale_count

    # One more draw has to fall back on a country that was already produced
    repeated_country = fake.unique.current_country()
    assert repeated_country in generated_countries

def test_none_values_exempt_from_unique_check(self):
    """A call returning ``None`` can be repeated through ``.unique`` without error."""
    fake = Faker()

    first_result = fake.unique.seed_locale("en_US", 0)
    assert first_result is None

    second_result = fake.unique.seed_locale("en_US", 0)
    assert second_result is None
Loading