Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/actions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ jobs:
surpyval_ci:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4

- name: Set-up Python 3.x
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: '3.10'

Expand All @@ -33,7 +33,7 @@ jobs:
coverage html

- name: Upload coverage html report artifact
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v4
with:
name: coverage-html-report
path: htmlcov/
Expand Down
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# For development, use requirements_dev.txt instead
lifelines==0.27.4
numba==0.56.4
numdifftools>=0.9.40
numpy-indexed==0.3.5
reliability==0.8.6
matplotlib==3.6
matplotlib==3.6
scipy<1.14.0
28 changes: 28 additions & 0 deletions surpyval/tests/test_parametric_security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import pytest
import surpyval as surv
from surpyval.univariate.parametric.parametric import Parametric

def test_from_dict_security_validation():
# Attempt to load a dictionary with a non-distribution attribute
payload = {
"parameterization": "parametric",
"distribution": "__builtins__",
"how": "MLE",
"offset": False,
"lfp": False,
"zi": False,
"params": [1, 2]
}

with pytest.raises(ValueError, match="Invalid distribution: __builtins__"):
Parametric.from_dict(payload)

def test_from_dict_valid_distribution():
# Fit a simple model to get a valid dict
x = [1, 2, 3, 4, 5]
model = surv.Weibull.fit(x)
model_dict = model.to_dict()

# Should not raise any error
loaded_model = Parametric.from_dict(model_dict)
assert loaded_model.dist.name == "Weibull"
30 changes: 30 additions & 0 deletions surpyval/univariate/parametric/parametric.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,31 @@

CB_COLOUR = "#e94c54"

ALLOWED_DISTRIBUTIONS = [
"Bernoulli",
"Beta",
"ExactEventTime",
"Exponential",
"ExpoWeibull",
"FixedEventProbability",
"Galton",
"Gamma",
"Gauss",
"Gumbel",
"GumbelLEV",
"InstantlyOccurs",
"Logistic",
"LogLogistic",
"LogNormal",
"MixtureModel",
"NeverOccurs",
"Normal",
"Parametric",
Comment on lines +35 to +38
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Exclude non-fitters from allowed distribution whitelist

ALLOWED_DISTRIBUTIONS currently includes symbols like MixtureModel, NeverOccurs, and Parametric that are not valid Parametric fitters. In from_dict, these names pass the new whitelist check and then flow into out = cls(dist, ...), where Parametric.__init__ expects dist.k/dist.bounds/dist.param_map; this raises an internal AttributeError instead of the intended ValueError for unrecognized input. A crafted JSON payload using one of these entries can still crash deserialization, so the new validation is not fully enforced for untrusted input.

Useful? React with πŸ‘Β / πŸ‘Ž.

"Rayleigh",
"Uniform",
"Weibull",
]


class Parametric(Distribution):
"""
Expand Down Expand Up @@ -80,6 +105,11 @@ def from_dict(cls, model_dict):
"Must create parametric model from parametric model dict"
)

if model_dict["distribution"] not in ALLOWED_DISTRIBUTIONS:
raise ValueError(
f"Invalid distribution: {model_dict['distribution']}"
)

dist = getattr(surv, model_dict["distribution"])
how = model_dict["how"]
if "data" in model_dict:
Expand Down
Loading