Refactor trans_models_t to use mlr3: replace fit_fun/gof_fun with Learner/AutoTuner interface #24

Draft
Copilot wants to merge 3 commits into main from copilot/integrate-mlr3-library

Copilot AI commented Apr 16, 2026

Replaces the ad-hoc fit_fun/gof_fun function-passing interface with first-class mlr3 integration. Learner identity, hyperparameters, and a serialized untrained spec are stored natively in DuckDB; cross-validation uses mlr3 tasks and measures throughout.

Schema (trans_models_t)

| Old | New | Type | Notes |
|---|---|---|---|
| model_family | learner_id | VARCHAR | mlr3 learner key |
| model_params | learner_params | MAP(VARCHAR,VARCHAR) | Atomic scalar params only |
| fit_call | learner_spec | BLOB | Serialized untrained Learner |
| goodness_of_fit | crossval_measures | MAP(VARCHAR,DOUBLE) | prediction$score(measures) |
| model_obj_part | crossval_predictions | BLOB | Serialized PredictionClassif |
| model_obj_full | learner_full | BLOB | Serialized trained Learner |

Primary key: (id_run, id_trans, fit_call) → (id_run, id_trans, learner_id)
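
To illustrate the "Atomic scalar params only" note on learner_params, here is a minimal sketch of how a hyperparameter list might be flattened into the string map stored in a MAP(VARCHAR,VARCHAR) column. The helper name flatten_params is hypothetical and not taken from this PR; in practice the input would presumably be learner$param_set$values.

```r
# Hypothetical helper (not part of this PR): keep only atomic, length-1
# hyperparameters and coerce them to character so they fit a
# MAP(VARCHAR, VARCHAR) column. Vectors, lists and NULLs are dropped.
flatten_params <- function(values) {
  is_scalar <- vapply(
    values,
    function(v) is.atomic(v) && length(v) == 1L,
    logical(1)
  )
  vapply(values[is_scalar], as.character, character(1))
}

# Stand-in for learner$param_set$values
values <- list(
  num.trees = 500L,
  mtry = 3,
  respect.unordered.factors = "order",
  class.weights = c(1, 2)  # not a scalar, so it is dropped
)

flat <- flatten_params(values)
```

Non-scalar settings such as class.weights are lost in this map, which is presumably why the serialized learner_spec BLOB is stored alongside it.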

API

# New signatures — no backward compatibility
fit_partial_models(self, learner, measures, sample_frac = 0.7, seed = NULL, cluster = NULL)
fit_full_models(self, learner, measures, gof_criterion, gof_maximize, cluster = NULL)

# Example
db$trans_models_t <- db$fit_partial_models(
  learner  = mlr3::lrn("classif.ranger", num.trees = 500, predict_type = "prob"),
  measures = list(mlr3::msr("classif.auc")),
  seed     = 42
)
db$trans_models_t <- db$fit_full_models(
  learner       = mlr3::lrn("classif.ranger", predict_type = "prob"),
  measures      = list(mlr3::msr("classif.auc")),
  gof_criterion = "classif.auc",
  gof_maximize  = TRUE
)

Worker logic

  • fit_partial_model_worker: Builds as_task_classif(..., positive = "TRUE"), deep-clones and trains the learner, scores held-out split via prediction$score(measures). For AutoTuner, extracts the optimal inner learner for learner_id/learner_params/learner_spec.
  • fit_full_model_worker: Reconstructs from learner_spec BLOB; falls back to do.call(mlr3::lrn, c(list(learner_id), as.list(learner_params))) on deserialization failure.
  • predict_trans_pot: Deserializes learner_full and calls learner$predict_newdata(pred_data)$prob[, "TRUE"]; removes family-specific dispatch.
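
The worker steps above can be sketched roughly as follows. This is an illustration under assumptions, not the code in this PR: it uses toy data, classif.rpart instead of classif.ranger so that it runs with mlr3 alone, mlr3::partition() for the 70/30 split, and base R serialize()/unserialize() for the BLOB round trip.

```r
library(mlr3)

set.seed(42)

# Toy transition data (assumption); the real workers read from DuckDB.
n <- 200
toy <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
toy$changed <- factor(toy$x1 + rnorm(n, sd = 0.5) > 0,
                      levels = c("FALSE", "TRUE"))

task <- as_task_classif(toy, target = "changed", positive = "TRUE",
                        id = "toy_trans")
learner <- lrn("classif.rpart", predict_type = "prob")
measures <- list(msr("classif.auc"))

# Partial fit: train a deep clone on 70% of the rows, score the rest.
split <- partition(task, ratio = 0.7)
fitted <- learner$clone(deep = TRUE)
fitted$train(task, row_ids = split$train)
prediction <- fitted$predict(task, row_ids = split$test)
scores <- prediction$score(measures)  # named numeric, e.g. classif.auc

# Untrained spec as a raw vector, suitable for a BLOB column.
spec_blob <- serialize(learner, connection = NULL)

# Full fit: rebuild from the blob, falling back to id + params on failure.
restored <- tryCatch(
  unserialize(spec_blob),
  error = function(e) {
    do.call(lrn, c(list(learner$id), learner$param_set$values))
  }
)
full <- restored$clone(deep = TRUE)
full$train(task)

# Prediction: probability of the positive class for new rows.
new_rows <- toy[1:5, c("x1", "x2")]
prob_true <- full$predict_newdata(new_rows)$prob[, "TRUE"]
```

Note that the fallback path restores only the learner key and its hyperparameters; settings held outside the parameter set, such as predict_type, would need to be reapplied separately.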

New method

get_crossval_plots(id_run, id_trans) deserializes all crossval_predictions BLOBs and returns mlr3viz::autoplot() results for visual GoF inspection.
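
A rough sketch of the idea behind get_crossval_plots, assuming the BLOBs are raw vectors produced by base R serialize() and have already been fetched into a list. The DuckDB query is omitted and the toy data is illustrative only.

```r
library(mlr3)
library(mlr3viz)

set.seed(1)

# Toy held-out prediction standing in for one crossval_predictions entry.
n <- 120
toy <- data.frame(x = rnorm(n))
toy$y <- factor(toy$x + rnorm(n) > 0, levels = c("FALSE", "TRUE"))
task <- as_task_classif(toy, target = "y", positive = "TRUE")
split <- partition(task, ratio = 0.7)
learner <- lrn("classif.rpart", predict_type = "prob")
learner$train(task, row_ids = split$train)
prediction <- learner$predict(task, row_ids = split$test)

# Assumed storage format: one raw vector per stored prediction.
blobs <- list(serialize(prediction, connection = NULL))

# Deserialize each blob and build one ggplot per prediction.
plots <- lapply(blobs, function(b) mlr3viz::autoplot(unserialize(b)))
```

Each element of plots is a ggplot object, so it can be printed or combined with other plots as usual.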

Dependencies

mlr3 and mlr3viz added to Suggests.
