diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index d554b7d..0fa7f35 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -11,10 +11,10 @@ jobs: surpyval_ci: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set-up Python 3.x - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.10' @@ -33,7 +33,7 @@ jobs: coverage html - name: Upload coverage html report artifact - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: coverage-html-report path: htmlcov/ diff --git a/requirements.txt b/requirements.txt index f29c293..e176e44 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,8 @@ # For development, use requirements_dev.txt instead lifelines==0.27.4 numba==0.56.4 +numdifftools>=0.9.40 numpy-indexed==0.3.5 reliability==0.8.6 -matplotlib==3.6 \ No newline at end of file +matplotlib==3.6 +scipy<1.14.0 \ No newline at end of file diff --git a/surpyval/tests/test_parametric_security.py b/surpyval/tests/test_parametric_security.py new file mode 100644 index 0000000..2b51c2d --- /dev/null +++ b/surpyval/tests/test_parametric_security.py @@ -0,0 +1,28 @@ +import pytest +import surpyval as surv +from surpyval.univariate.parametric.parametric import Parametric + +def test_from_dict_security_validation(): + # Attempt to load a dictionary with a non-distribution attribute + payload = { + "parameterization": "parametric", + "distribution": "__builtins__", + "how": "MLE", + "offset": False, + "lfp": False, + "zi": False, + "params": [1, 2] + } + + with pytest.raises(ValueError, match="Invalid distribution: __builtins__"): + Parametric.from_dict(payload) + +def test_from_dict_valid_distribution(): + # Fit a simple model to get a valid dict + x = [1, 2, 3, 4, 5] + model = surv.Weibull.fit(x) + model_dict = model.to_dict() + + # Should not raise any error + loaded_model = Parametric.from_dict(model_dict) + assert loaded_model.dist.name == "Weibull" diff --git a/surpyval/univariate/parametric/parametric.py b/surpyval/univariate/parametric/parametric.py index f19b51f..8393163 100755 --- a/surpyval/univariate/parametric/parametric.py +++ b/surpyval/univariate/parametric/parametric.py @@ -16,6 +16,31 @@ CB_COLOUR = "#e94c54" +ALLOWED_DISTRIBUTIONS = [ + "Bernoulli", + "Beta", + "ExactEventTime", + "Exponential", + "ExpoWeibull", + "FixedEventProbability", + "Galton", + "Gamma", + "Gauss", + "Gumbel", + "GumbelLEV", + "InstantlyOccurs", + "Logistic", + "LogLogistic", + "LogNormal", + "MixtureModel", + "NeverOccurs", + "Normal", + "Parametric", + "Rayleigh", + "Uniform", + "Weibull", +] + class Parametric(Distribution): """ @@ -80,6 +105,11 @@ def from_dict(cls, model_dict): "Must create parametric model from parametric model dict" ) + if model_dict["distribution"] not in ALLOWED_DISTRIBUTIONS: + raise ValueError( + f"Invalid distribution: {model_dict['distribution']}" + ) + dist = getattr(surv, model_dict["distribution"]) how = model_dict["how"] if "data" in model_dict: