Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 46 additions & 7 deletions src/layup/orbitfit.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
Observation,
gauss,
get_ephem,
run_bk_native_fit,
run_from_vector_with_initial_guess,
)

from layup.convert import convert

from layup.utilities.astrometric_uncertainty import data_weight_Veres2017
Expand Down Expand Up @@ -65,6 +67,25 @@
AU_M = 149597870700
SPEED_OF_LIGHT = 2.99792458e8 * 86400.0 / AU_M

# Heliocentric GM in AU^3 / day^2 (k^2, k = Gaussian gravitational constant).
# Used by the BK-native fit for the bound-orbit energy prior on gdot.
_MU_SUN = 0.00029591220828559104


def _run_fit(assist_ephem, initial_guess, observations, engine):
"""Dispatch a single LM fit step to the configured engine.

Centralizing the dispatch here keeps do_fit's IOD-then-fit pipeline
parameterization-agnostic and lets us add new engines (e.g., a
future distance-dispatched 'auto') with a single edit instead of
threading the choice through every call site.
"""
if engine == "cartesian":
return run_from_vector_with_initial_guess(assist_ephem, initial_guess, observations)
if engine == "bk_native":
return run_bk_native_fit(assist_ephem, initial_guess, observations, _MU_SUN)
raise ValueError(f"Unknown engine {engine!r}; expected one of 'cartesian', 'bk_native'.")


def _get_result_dtypes(primary_id_column_name: str):
"""Helper function to create the result dtype with the correct primary ID column name."""
Expand Down Expand Up @@ -349,7 +370,7 @@ def do_gauss_iod(observations, seq):
return solns


def do_fit(observations, seq, cache_dir, iod="gauss"):
def do_fit(observations, seq, cache_dir, iod="gauss", engine="cartesian"):
"""Carry out an orbit fit to the observations in a
series of steps. A list of lists of observation indices
specifies the order in which the fit proceeds.
Expand Down Expand Up @@ -378,6 +399,12 @@ def do_fit(observations, seq, cache_dir, iod="gauss"):
iod : str
The IOD used to generate an initial guess orbit. Currently supports ['gauss'].
Default is 'gauss'.
engine : str
Which LM fitter to dispatch to. Supported:
- 'cartesian' (default): the existing 6D Cartesian-state fit.
- 'bk_native': the universal Bernstein-Khushalani fit
(run_bk_native_fit), with a fixed bound-orbit energy prior
on gdot. Recovers the Cartesian state at the same epoch.

Returns
-------
Expand All @@ -403,24 +430,24 @@ def do_fit(observations, seq, cache_dir, iod="gauss"):
# Fit primary interval, starting with gauss solution
x = solns[0]
obs = [observations[i] for i in seq[0]]
x = run_from_vector_with_initial_guess(assist_ephem, x, obs)
x = _run_fit(assist_ephem, x, obs, engine)

if (x.flag != 0) and len(solns) > 1:
x = solns[1]
obs = [observations[i] for i in seq[0]]
x = run_from_vector_with_initial_guess(assist_ephem, x, obs)
x = _run_fit(assist_ephem, x, obs, engine)
elif (x.flag != 0) and len(solns) > 2:
x = solns[2]
obs = [observations[i] for i in seq[0]]
x = run_from_vector_with_initial_guess(assist_ephem, x, obs)
x = _run_fit(assist_ephem, x, obs, engine)
if x.flag != 0:
logger.debug(f"Primary interval failed. Total observations: {len(obs)}")
x.flag = 3 # caution
return x

# Attempt to fit all the data, given the fit of the primary interval
obs = observations
x = run_from_vector_with_initial_guess(assist_ephem, x, obs)
x = _run_fit(assist_ephem, x, obs, engine)

# If that failed, build up the solution slowly
if x.flag != 0:
Expand All @@ -429,7 +456,7 @@ def do_fit(observations, seq, cache_dir, iod="gauss"):
for i, sq in enumerate(seq):
obs += [observations[i] for i in sq]
print(i, "of", len(seq), obs[0], sq)
x = run_from_vector_with_initial_guess(assist_ephem, x, obs)
x = _run_fit(assist_ephem, x, obs, engine)
print("flag:", x.flag)
if x.flag != 0:
x.flag = 4
Expand Down Expand Up @@ -458,6 +485,7 @@ def _orbitfit(
sort_array: bool = True,
weight_data: bool = False,
iod: str = "gauss",
engine: str = "cartesian",
):
"""This function will contain all of the calls to the c++ code that will
calculate an orbit given a set of observations. Note that all observations
Expand Down Expand Up @@ -590,7 +618,13 @@ def _orbitfit(
# Perform the orbit fitting
if initial_guess is None or initial_guess["flag"] != 0:
if iod.lower() in ["gauss"]:
res = do_fit(observations=observations, seq=sequence, cache_dir=kernels_loc, iod=iod.lower())
res = do_fit(
observations=observations,
seq=sequence,
cache_dir=kernels_loc,
iod=iod.lower(),
engine=engine,
)
else:
res = do_other_fit(iod=iod.lower())
else:
Expand Down Expand Up @@ -631,6 +665,7 @@ def orbitfit(
debias=False,
weight_data=False,
iod="gauss",
engine="cartesian",
):
"""This is the function that you would call interactively. i.e. from a notebook

Expand Down Expand Up @@ -680,6 +715,7 @@ def orbitfit(
bias_dict=bias_dict,
weight_data=weight_data,
iod=iod,
engine=engine,
)


Expand Down Expand Up @@ -719,13 +755,15 @@ def orbitfit_cli(
weight_data = cli_args.weight_data
output_orbit_format = cli_args.output_orbit_format
iod = cli_args.iod
engine = getattr(cli_args, "engine", "cartesian")
else:
cache_dir = None
debias = False
guess_file = None
weight_data = False
output_orbit_format = "COM" # Default output orbit format.
iod = "gauss"
engine = "cartesian"

_primary_id_column_name = cli_args.primary_id_column_name

Expand Down Expand Up @@ -836,6 +874,7 @@ def orbitfit_cli(
debias=debias,
weight_data=weight_data,
iod=iod,
engine=engine,
)

# Convert the fit_orbits to the preferred output format
Expand Down
13 changes: 13 additions & 0 deletions src/layup_cmdline/orbitfit.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,19 @@ def main():
default="gauss",
required=False,
)
optional.add_argument(
"--engine",
help=(
"LM fitter to use after IOD: 'cartesian' (default; classic "
"barycentric-Cartesian LM) or 'bk_native' (universal "
"Bernstein-Khushalani fit with energy prior, better-conditioned "
"for distant short-arc targets and at least as good elsewhere)."
),
dest="engine",
choices=["cartesian", "bk_native"],
default="cartesian",
required=False,
)
optional.add_argument(
"-o",
"--output",
Expand Down
2 changes: 2 additions & 0 deletions src/lib/detection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#include <cmath>

#include <pybind11/pybind11.h>
#include <pybind11/eigen.h>
#include <pybind11/stl.h>
namespace py = pybind11;

// --- Observation Variant Types ---
Expand Down
Loading
Loading