diff --git a/erfa_generator.py b/erfa_generator.py index 32ec0a7..31159e1 100644 --- a/erfa_generator.py +++ b/erfa_generator.py @@ -13,7 +13,7 @@ import functools import re from pathlib import Path -from typing import Final +from typing import Final, final DEFAULT_ERFA_LOC = Path(__file__).with_name("liberfa") / "erfa" / "src" DEFAULT_TEMPLATE_LOC = Path(__file__).with_name("erfa") @@ -130,14 +130,66 @@ def __init__(self, doc): class Variable: """Properties shared by Argument, Return and StatusCode.""" + + def __init__(self, ctype: str, name: str | None = None) -> None: + self.ctype: Final = ctype + self.name: Final = "c_retval" if name is None else name + + @final @property - def npy_type(self): + def npy_type(self) -> str: """Predefined type used by numpy ufuncs to indicate a given ctype. Eg., NPY_DOUBLE for double. """ return "NPY_" + self.ctype.upper() + @property + def dtype(self) -> str: + return "dt_" + self.ctype + + @property + def signature_shape(self) -> str: + return "()" + + +class Argument(Variable): + def __init__(self, definition: str, doc: FunctionDoc) -> None: + self.doc = doc + ctype, ptr_name_arr = definition.strip().rsplit(" ", 1) + name_arr = ptr_name_arr.removeprefix("*").removesuffix("[]") + self.is_ptr = name_arr != ptr_name_arr + if "[" in name_arr: + name_arr, arr = name_arr.split("[", 1) + self.shape = tuple([int(size) for size in arr[:-1].split("][")]) + else: + self.shape = () + super().__init__(ctype, name_arr) + + @functools.cached_property + def inout_state(self) -> str: + inout_state = "" + for i in self.doc.input: + if self.name in i.name.split(","): + inout_state = "in" + for o in self.doc.output: + if self.name in o.name.split(","): + inout_state = "inout" if inout_state == "in" else "out" + return inout_state + + @property + def name_for_call(self) -> str: + """How the argument should be used in the call to the ERFA function. + + This takes care of ensuring that inputs are passed by value, + as well as adding back the number of bodies for any LDBODY argument. + The latter presumes that in the ufunc inner loops, that number is + called 'nb'. + """ + if self.ctype == "eraLDBODY": + return "nb, _" + self.name + return ("_" if self.is_ptr else "*_") + self.name + @property def dtype(self): """Name of dtype corresponding to the ctype. @@ -173,7 +225,7 @@ def dtype(self): case "double", (2,): return "dt_pvdpv" case _, (): - return "dt_" + self.ctype + return super().dtype raise ValueError(f"ctype {self.ctype} with shape {self.shape} not recognized.") @property @@ -215,54 +267,13 @@ def signature_shape(self): return "(3)" case "double", (3, 3): return "(3, 3)" - return "()" - - -class Argument(Variable): - - def __init__(self, definition, doc): - self.doc = doc - self.ctype, ptr_name_arr = definition.strip().rsplit(" ", 1) - name_arr = ptr_name_arr.removeprefix("*").removesuffix("[]") - self.is_ptr = name_arr != ptr_name_arr - if "[" in name_arr: - self.name, arr = name_arr.split("[", 1) - self.shape = tuple([int(size) for size in arr[:-1].split("][")]) - else: - self.name = name_arr - self.shape = () - - @functools.cached_property - def inout_state(self): - inout_state = "" - for i in self.doc.input: - if self.name in i.name.split(","): - inout_state = "in" - for o in self.doc.output: - if self.name in o.name.split(","): - inout_state = "inout" if inout_state == "in" else "out" - return inout_state - - @property - def name_for_call(self): - """How the argument should be used in the call to the ERFA function. - - This takes care of ensuring that inputs are passed by value, - as well as adding back the number of bodies for any LDBODY argument. - The latter presumes that in the ufunc inner loops, that number is - called 'nb'. - """ - if self.ctype == 'eraLDBODY': - return 'nb, _' + self.name - return ("_" if self.is_ptr else "*_") + self.name + return super().signature_shape class StatusCode(Variable): def __init__(self, ctype: str, doc: FunctionDoc, funcname: str) -> None: - self.name = "c_retval" self.inout_state = "stat" - self.ctype = "int" - self.shape = () + super().__init__(ctype) status = re.search( r"Returned \(function value\):\n\s+\w+\s+status.*?:(.+?)\s+Notes?:", @@ -291,10 +302,8 @@ def to_python(self) -> str: class Return(Variable): def __init__(self, ctype, doc): - self.name = 'c_retval' self.inout_state = "ret" - self.ctype = ctype - self.shape = () + super().__init__(ctype) class Function: