Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 58 additions & 49 deletions erfa_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import functools
import re
from pathlib import Path
from typing import Final
from typing import Final, final

DEFAULT_ERFA_LOC = Path(__file__).with_name("liberfa") / "erfa" / "src"
DEFAULT_TEMPLATE_LOC = Path(__file__).with_name("erfa")
Expand Down Expand Up @@ -130,14 +130,66 @@ def __init__(self, doc):

class Variable:
"""Properties shared by Argument, Return and StatusCode."""

def __init__(self, ctype: str, name: str | None = None) -> None:
self.ctype: Final = ctype
self.name: Final = "c_retval" if name is None else name

@final
@property
def npy_type(self):
def npy_type(self) -> str:
"""Predefined type used by numpy ufuncs to indicate a given ctype.

Eg., NPY_DOUBLE for double.
"""
return "NPY_" + self.ctype.upper()

@property
def dtype(self) -> str:
return "dt_" + self.ctype

@property
def signature_shape(self) -> str:
return "()"


class Argument(Variable):
def __init__(self, definition: str, doc: FunctionDoc) -> None:
self.doc = doc
ctype, ptr_name_arr = definition.strip().rsplit(" ", 1)
name_arr = ptr_name_arr.removeprefix("*").removesuffix("[]")
self.is_ptr = name_arr != ptr_name_arr
if "[" in name_arr:
name_arr, arr = name_arr.split("[", 1)
self.shape = tuple([int(size) for size in arr[:-1].split("][")])
else:
self.shape = ()
super().__init__(ctype, name_arr)

@functools.cached_property
def inout_state(self) -> str:
inout_state = ""
for i in self.doc.input:
if self.name in i.name.split(","):
inout_state = "in"
for o in self.doc.output:
if self.name in o.name.split(","):
inout_state = "inout" if inout_state == "in" else "out"
return inout_state

@property
def name_for_call(self) -> str:
"""How the argument should be used in the call to the ERFA function.

This takes care of ensuring that inputs are passed by value,
as well as adding back the number of bodies for any LDBODY argument.
The latter presumes that in the ufunc inner loops, that number is
called 'nb'.
"""
if self.ctype == "eraLDBODY":
return "nb, _" + self.name
return ("_" if self.is_ptr else "*_") + self.name

@property
def dtype(self):
"""Name of dtype corresponding to the ctype.
Expand Down Expand Up @@ -173,7 +225,7 @@ def dtype(self):
case "double", (2,):
return "dt_pvdpv"
case _, ():
return "dt_" + self.ctype
return super().dtype
raise ValueError(f"ctype {self.ctype} with shape {self.shape} not recognized.")

@property
Expand Down Expand Up @@ -215,54 +267,13 @@ def signature_shape(self):
return "(3)"
case "double", (3, 3):
return "(3, 3)"
return "()"


class Argument(Variable):

def __init__(self, definition, doc):
self.doc = doc
self.ctype, ptr_name_arr = definition.strip().rsplit(" ", 1)
name_arr = ptr_name_arr.removeprefix("*").removesuffix("[]")
self.is_ptr = name_arr != ptr_name_arr
if "[" in name_arr:
self.name, arr = name_arr.split("[", 1)
self.shape = tuple([int(size) for size in arr[:-1].split("][")])
else:
self.name = name_arr
self.shape = ()

@functools.cached_property
def inout_state(self):
inout_state = ""
for i in self.doc.input:
if self.name in i.name.split(","):
inout_state = "in"
for o in self.doc.output:
if self.name in o.name.split(","):
inout_state = "inout" if inout_state == "in" else "out"
return inout_state

@property
def name_for_call(self):
"""How the argument should be used in the call to the ERFA function.

This takes care of ensuring that inputs are passed by value,
as well as adding back the number of bodies for any LDBODY argument.
The latter presumes that in the ufunc inner loops, that number is
called 'nb'.
"""
if self.ctype == 'eraLDBODY':
return 'nb, _' + self.name
return ("_" if self.is_ptr else "*_") + self.name
return super().signature_shape


class StatusCode(Variable):
def __init__(self, ctype: str, doc: FunctionDoc, funcname: str) -> None:
self.name = "c_retval"
self.inout_state = "stat"
self.ctype = "int"
self.shape = ()
super().__init__(ctype)

status = re.search(
r"Returned \(function value\):\n\s+\w+\s+status.*?:(.+?)\s+Notes?:",
Expand Down Expand Up @@ -291,10 +302,8 @@ def to_python(self) -> str:
class Return(Variable):

def __init__(self, ctype, doc):
self.name = 'c_retval'
self.inout_state = "ret"
self.ctype = ctype
self.shape = ()
super().__init__(ctype)


class Function:
Expand Down