Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ classifiers = [
]

[project.optional-dependencies]
test = ["pytest"]
test = ["pytest", "numpy"]

[project.urls]
Documentation = "https://pals-project.readthedocs.io"
Expand Down
67 changes: 65 additions & 2 deletions src/pals/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,25 @@ def load_file_to_dict(filename: str) -> dict:
return pals_data


def _numpy_to_native(obj):
"""Convert a numpy scalar/array to its Python-native equivalent.

Returns ``None`` when the object is not a numpy type or when numpy is not
installed; callers use that to decide whether to fall back to the default
serializer behavior. numpy is an optional dependency.
"""
try:
import numpy as np
except ImportError:
return None

if isinstance(obj, np.ndarray):
return obj.tolist()
if isinstance(obj, np.generic):
return obj.item()
return None


def store_dict_to_file(filename: str, pals_dict: dict):
file_noext, extension, file_noext_noext, extension_inner = inspect_file_extensions(
filename
Expand All @@ -63,14 +82,58 @@ def store_dict_to_file(filename: str, pals_dict: dict):
if extension == ".json":
import json

json_data = json.dumps(pals_dict, sort_keys=False, indent=2)
def _json_default(obj):
native = _numpy_to_native(obj)
if native is not None:
return native
raise TypeError(
f"Object of type {type(obj).__name__} is not JSON serializable"
)

json_data = json.dumps(
pals_dict, sort_keys=False, indent=2, default=_json_default
)
with open(filename, "w") as file:
file.write(json_data)

elif extension == ".yaml":
import yaml

yaml_data = yaml.dump(pals_dict, default_flow_style=False, sort_keys=False)
# Subclass the safe dumper so numpy representers are scoped to PALS
# serialization and do not leak into the global pyyaml state used by
# other code in the same process.
class _PALSDumper(yaml.SafeDumper):
pass

try:
import numpy as np
except ImportError:
np = None

if np is not None:

def _represent_numpy_scalar(dumper, value):
native = value.item()
if isinstance(native, bool):
return dumper.represent_bool(native)
if isinstance(native, int):
return dumper.represent_int(native)
if isinstance(native, float):
return dumper.represent_float(native)
return dumper.represent_data(native)

def _represent_numpy_array(dumper, value):
return dumper.represent_list(value.tolist())

_PALSDumper.add_multi_representer(np.generic, _represent_numpy_scalar)
_PALSDumper.add_representer(np.ndarray, _represent_numpy_array)

yaml_data = yaml.dump(
pals_dict,
Dumper=_PALSDumper,
default_flow_style=False,
sort_keys=False,
)
with open(filename, "w") as file:
file.write(yaml_data)

Expand Down
74 changes: 74 additions & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os

import pytest

import pals


Expand Down Expand Up @@ -332,3 +334,75 @@ def test_comprehensive_lattice():
# Clean up temporary files
os.remove(yaml_file)
os.remove(json_file)


def _build_numpy_lattice(np):
"""Build a small lattice using numpy-typed scalar values throughout."""
quad = pals.Quadrupole(
name="q_np",
length=np.float64(0.061),
MagneticMultipoleP=pals.MagneticMultipoleParameters(
Bn1=np.float64(-26.0), Bs1=np.float32(0.5), Kn0=np.int64(-1)
),
)
oct_ = pals.Octupole(
name="o_np",
length=np.float64(0.25),
ElectricMultipoleP=pals.ElectricMultipoleParameters(
En3=np.float64(0.75), Es3=np.float32(0.125)
),
)
return pals.BeamLine(name="line_np", line=[quad, oct_])


def test_yaml_roundtrip_with_numpy():
"""Regression test for issue #67: writing YAML with numpy-typed values
must not produce !!python/object tags, and round-tripping must yield
Python-native floats with the correct numeric values."""
np = pytest.importorskip("numpy")

line = _build_numpy_lattice(np)
yaml_file = "numpy_roundtrip.pals.yaml"
line.to_file(yaml_file)
try:
with open(yaml_file, "r") as f:
text = f.read()

# The bug symptom: YAML contains opaque numpy object tags.
assert "!!python/object" not in text, (
f"YAML output still contains unsafe numpy object tags:\n{text}"
)
assert "numpy" not in text, f"YAML output still references numpy:\n{text}"

loaded = pals.BeamLine.from_file(yaml_file)
loaded_quad = loaded.line[0]
assert loaded_quad.MagneticMultipoleP.Bn1 == -26.0
assert type(loaded_quad.MagneticMultipoleP.Bn1) is float
assert loaded_quad.MagneticMultipoleP.Bs1 == 0.5
assert loaded_quad.MagneticMultipoleP.Kn0 == -1

loaded_oct = loaded.line[1]
assert loaded_oct.ElectricMultipoleP.En3 == 0.75
assert type(loaded_oct.ElectricMultipoleP.En3) is float
finally:
if os.path.exists(yaml_file):
os.remove(yaml_file)
Comment thread
EZoni marked this conversation as resolved.
Outdated


def test_json_roundtrip_with_numpy():
"""JSON path also needs to handle numpy values cleanly (defense-in-depth)."""
np = pytest.importorskip("numpy")

line = _build_numpy_lattice(np)
json_file = "numpy_roundtrip.pals.json"
line.to_file(json_file)
try:
loaded = pals.BeamLine.from_file(json_file)
loaded_quad = loaded.line[0]
assert loaded_quad.MagneticMultipoleP.Bn1 == -26.0
assert type(loaded_quad.MagneticMultipoleP.Bn1) is float
loaded_oct = loaded.line[1]
assert loaded_oct.ElectricMultipoleP.En3 == 0.75
finally:
if os.path.exists(json_file):
os.remove(json_file)
Comment thread
EZoni marked this conversation as resolved.
Outdated
Loading