diff --git a/DESCRIPTION b/DESCRIPTION index 12987cab..65f2c240 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,5 +1,5 @@ Package: dymiumCore -Version: 0.1.9.9000 +Version: 0.1.9.9002 Title: A Toolkit for Building a Dynamic Microsimulation Model for Integrated Urban Modelling Description: A modular microsimulation modelling framework for integrated urban modelling. Authors@R: c( @@ -32,7 +32,7 @@ Imports: tryCatchLog (>= 1.1.0) Suggests: furrr (>= 0.1.0), - testthat (>= 2.1.0), + testthat (>= 3.0.0), fastmatch (>= 1.1.0), mlogit (>= 1.1.0), caret (>= 6.0.0), @@ -118,3 +118,5 @@ Collate: 'utils.R' 'validate.R' 'zzz.R' +Config/testthat/parallel: true +Config/testthat/edition: 3 diff --git a/NAMESPACE b/NAMESPACE index 69b8daa9..0207e406 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -18,6 +18,7 @@ S3method(predict,ModelMultinomialLogit) S3method(simulate_choice,Model) S3method(simulate_choice,WrappedModel) S3method(simulate_choice,data.frame) +S3method(simulate_choice,dymium.choice_table) S3method(simulate_choice,glm) S3method(simulate_choice,list) S3method(simulate_choice,train) diff --git a/NEWS.md b/NEWS.md index afacc55f..b6fec862 100644 --- a/NEWS.md +++ b/NEWS.md @@ -3,6 +3,8 @@ ## New features - `ModelMultinomialLogit` and `ModelBinaryChoice` now have S3 `predict` and `summary` methods. Note that `ModelMultinomialLogit` requires `newdata` to be in the same format that is required by `mlogit`. +- `Population$household_type()` now returns the number of members in household and remove individuals not belong to any household. +- When calling `World$add()` and `name` is missing, it will see if the object has a `name` field, if it is a `Target` or a `Model`. ## Changes @@ -10,6 +12,7 @@ - add a `name` argument to `Model`'s constructor function and expose it as an R6 active field. - `World$add()` can now be used to add a named `Model` without providing the `name` argument. - `World$add()` gained a `replace` argument with `TRUE` as its default value. +- `Generic` now has an active `name` field which will equal to `NULL` if no name is given. ## Bug fixes diff --git a/R/Entity-functions.R b/R/Entity-functions.R index 9ab56687..06abfc89 100644 --- a/R/Entity-functions.R +++ b/R/Entity-functions.R @@ -137,7 +137,7 @@ impute_history <- function(entity, ids, event = NULL) { #' add_history(Ind, ids = sample(Ind$get_ids(), 10), event = "event1", time = 1) #' combine_histories(world) combine_histories <- function(x) { - checkmate::expect_r6(x, classes = "Container") + checkmate::assert_r6(x, classes = "Container") get_history(x) %>% purrr::keep(., ~ !is.null(.x)) %>% purrr::map2(.x = ., .y = names(.), diff --git a/R/Generic.R b/R/Generic.R index b0aff8f2..e6fe478a 100644 --- a/R/Generic.R +++ b/R/Generic.R @@ -48,8 +48,11 @@ Generic <- R6Class( classname = "Generic", public = list( - initialize = function(...) { - + initialize = function(name) { + if (!missing(name)) { + self$name <- name + } + invisible(self) }, debug = function() { @@ -126,6 +129,17 @@ Generic <- R6Class( } ), + active = list( + name = function(x) { + if (missing(x)) { + private$.name + } else { + checkmate::assert_string(x, null.ok = T, na.ok = FALSE) + private$.name <- x + } + } + ), + private = list( .abstract = function(msg) { # this is a method for abstract methods @@ -150,7 +164,9 @@ Generic <- R6Class( tag = character(), desc = character(), value = list() - ) + ), + + .name = NULL ) ) diff --git a/R/Individual.R b/R/Individual.R index 94b3f069..88599574 100644 --- a/R/Individual.R +++ b/R/Individual.R @@ -295,8 +295,7 @@ Individual <- R6::R6Class( result <- private$get_relationship(ids, type = "children") %>% dt_group_and_sort(x = ., groupby_col = pid_col, group_col = "child_id", sort_order = ids) - checkmate::expect_set_equal(ids, result[['sort_col']], ordered = T, - info = "`ids` and the result are not equal.") + checkmate::assert_set_equal(ids, result[['sort_col']], ordered = TRUE) result[["group_col"]] }, @@ -307,8 +306,7 @@ Individual <- R6::R6Class( .[, living_together := self$living_together(self_ids = get(pid_col), target_ids = child_id)] %>% .[living_together == TRUE] %>% dt_group_and_sort(x = ., groupby_col = pid_col, group_col = "child_id", sort_order = ids) - checkmate::expect_set_equal(ids, result[['sort_col']], ordered = T, - info = "`ids` and the result are not equal.") + checkmate::assert_set_equal(ids, result[['sort_col']], ordered = TRUE) result[["group_col"]] }, diff --git a/R/MatchingMarket.R b/R/MatchingMarket.R index f100a774..ee12f713 100644 --- a/R/MatchingMarket.R +++ b/R/MatchingMarket.R @@ -99,8 +99,8 @@ MatchingMarket <- R6::R6Class( grouping_vars = NULL, max_market_size = 5000 ^ 2) { # CHECK INPUTS - checkmate::expect_data_table(agentset_A) - checkmate::expect_data_table(agentset_B) + checkmate::assert_data_table(agentset_A) + checkmate::assert_data_table(agentset_B) if (missing(id_col_A)) { stopifnot(uniqueN(agentset_A[, 1]) == nrow(agentset_A)) diff --git a/R/MatchingMarketOptimal.R b/R/MatchingMarketOptimal.R index 5f65fe7d..3cb4ffa4 100644 --- a/R/MatchingMarketOptimal.R +++ b/R/MatchingMarketOptimal.R @@ -52,7 +52,7 @@ MatchingMarketOptimal <- R6::R6Class( parallel_wrapper <- function(...) { if (parallel) { stopifnot(requireNamespace('furrr')) - furrr::future_map_dfr(..., .options = furrr::future_options(globals = "self")) + furrr::future_map_dfr(..., .options = furrr::furrr_options(globals = "self")) } else { purrr::map_dfr(...) } diff --git a/R/MatchingMarketStochastic.R b/R/MatchingMarketStochastic.R index d11b53f2..8f1b44d0 100644 --- a/R/MatchingMarketStochastic.R +++ b/R/MatchingMarketStochastic.R @@ -64,7 +64,7 @@ MatchingMarketStochastic <- R6::R6Class( parallel_wrapper <- function(...) { if (parallel) { stopifnot(requireNamespace('furrr')) - furrr::future_map_dfr(..., .options = furrr::future_options(globals = "self")) + furrr::future_map_dfr(..., .options = furrr::furrr_options(globals = "self")) } else { purrr::map_dfr(...) } diff --git a/R/Model.R b/R/Model.R index a420f682..cfdc004b 100644 --- a/R/Model.R +++ b/R/Model.R @@ -100,8 +100,8 @@ Model <- checkmate::assert_function(preprocessing_fn, nargs = 1, null.ok = TRUE) self$preprocessing_fn <- preprocessing_fn self$set(x) - self$name <- name - invisible() + super$initialize(name = name) + invisible(self) }, get = function() { private$.model @@ -133,19 +133,10 @@ Model <- return(data.table::copy(private$.model)) } get(".model", envir = private) - }, - name = function(value) { - if (missing(value)) { - private$.name - } else { - checkmate::assert_string(value, null.ok = T, na.ok = FALSE) - private$.name <- value - } } ), private = list( - .model = NULL, - .name = NULL + .model = NULL ) ) diff --git a/R/ModelMultinomialLogit.R b/R/ModelMultinomialLogit.R index 1f4dc861..7736d395 100644 --- a/R/ModelMultinomialLogit.R +++ b/R/ModelMultinomialLogit.R @@ -65,7 +65,7 @@ ModelMultinomialLogit <- R6::R6Class( #' choice_id (`integer()`), linear_comb (`numeric()`), prob (`numeric()`). Note #' that, 'linear_comb' stands for linear combination (i.e. $$B1 * x1 + B2 * x2$$). predict = function(newdata, chooser_id_col, choice_id_col) { - checkmate::expect_data_frame(newdata) + checkmate::assert_data_frame(newdata) data.table(chooser_id = newdata[[chooser_id_col]], choice_id = newdata[[choice_id_col]], linear_comb = private$.compute_linear_combination(newdata, chooser_id_col, choice_id_col)) %>% @@ -76,7 +76,7 @@ ModelMultinomialLogit <- R6::R6Class( private = list( .compute_linear_combination = function(newdata, chooser_id_col, choice_id_col) { if (inherits(newdata, "dfidx")) { - checkmate::expect_names(x = names(newdata$idx), + checkmate::assert_names(x = names(newdata$idx), identical.to = c(chooser_id_col, choice_id_col)) } else { newdata <- diff --git a/R/Population.R b/R/Population.R index a19d4f32..03024c33 100644 --- a/R/Population.R +++ b/R/Population.R @@ -242,7 +242,9 @@ Population <- R6Class( leave_household = function(ind_ids) { # check that ids in ind_ids and their household ids exist stopifnot(self$get("Individual")$ids_exist(ids = ind_ids)) - stopifnot(self$get("Household")$ids_exist(ids = self$get("Individual")$get_household_ids(ids = ind_ids))) + stopifnot( + self$get("Household")$ids_exist( + ids = self$get("Individual")$get_household_ids(ids = ind_ids))) # leave household self$get("Individual")$remove_household_id(ids = ind_ids) add_history(entity = self$get("Individual"), @@ -286,6 +288,7 @@ Population <- R6Class( ), by = c(Ind$get_hid_col())] %>% # identify relationships .[, `:=`( + n_members = sapply(members, length), couple_hh = purrr::map2_lgl(members, partners, ~ {any(.y %in% .x)}), with_children = purrr::map2_lgl(members, parents, ~ {any(.y %in% .x)}) )] %>% @@ -304,7 +307,12 @@ Population <- R6Class( by.y = Ind$get_hid_col(), sort = FALSE, allow.cartesian = FALSE - ) + ) %>% + # if there are individuals that don't belong to any household they would all + # be added into id:NA, so i thin + .[!is.na(id), ] + + checkmate::assert_character(household_type[["household_type"]], any.missing = FALSE) diff --git a/R/Target.R b/R/Target.R index a4381a75..af891846 100644 --- a/R/Target.R +++ b/R/Target.R @@ -13,12 +13,16 @@ #' @section Construction: #' #' ``` -#' Target$new(x) +#' Target$new(x, name) #' ``` #' #' * `x` :: any object that passes `check_target()`\cr #' A target object or `NULL`. #' +#' * `name` :: `character(1)`\cr +#' Name/Alias of the Target object. This will be used as the [Target] name when +#' it gets added to a [World]. +#' #' @section Active Field (read-only): #' #' * `data`:: a target object\cr @@ -61,7 +65,7 @@ Target <- R6::R6Class( classname = "Target", inherit = dymiumCore::Generic, public = list( - initialize = function(x) { + initialize = function(x, name) { assert_target(x, null.ok = TRUE) if (is.data.frame(x)) { if (!"time" %in% names(x)) { @@ -80,6 +84,7 @@ Target <- R6::R6Class( } else { private$.data <- x } + super$initialize(name = name) invisible(self) }, diff --git a/R/World.R b/R/World.R index 6e858f14..3b5d158f 100644 --- a/R/World.R +++ b/R/World.R @@ -137,10 +137,11 @@ World <- R6::R6Class( ) if (checkmate::test_r6(x, "World")) { - stop("Adding a World object is not permitted.") + stop("Adding a World object to another World object is not permitted.") } - if ((inherits(x, "Entity") | inherits(x, "Container")) & !inherits(x, "Model") & !inherits(x, "Target")) { + if ((inherits(x, "Entity") | inherits(x, "Container")) & + !inherits(x, "Model") & !inherits(x, "Target")) { stopifnot(x$is_dymium_class()) if (!missing(name)) { lg$warn("The given `name` will be ignored since the object in x \\ @@ -150,11 +151,12 @@ World <- R6::R6Class( name <- class(x)[[1]] } - if (inherits(x, "Model") && !is.null(x$name)) { - name = x$name + if (missing(name) && !is.null(x$name) && + (inherits(x, "Model") | inherits(x, "Target"))) { + name <- x$name } - # only allows letters and underscores + # only allows letters and underscores\ checkmate::assert_string(name, pattern = "^[a-zA-Z_]*$", na.ok = FALSE, @@ -197,10 +199,14 @@ World <- R6::R6Class( if (name_object_exists) { if (replace) { - lg$warn("Replacing the object named `{name}` of class `{.class_old}` \\ + warn_msg = + glue::glue( + "Replacing the object named `{name}` of class `{.class_old}` \\ with `{.class_new}`.", - .class_old = self$get(x = name)$class()[[1]], - .class_new = class(x)[[1]]) + .class_old = self$get(x = name)$class()[[1]], + .class_new = class(x)[[1]] + ) + warning(warn_msg) self$remove(name) } else { stop(glue::glue("{name} already exists in {.listname}. Only one instance \\ diff --git a/R/simulate_choice.R b/R/simulate_choice.R index d7b293d6..e1ad9be6 100644 --- a/R/simulate_choice.R +++ b/R/simulate_choice.R @@ -22,7 +22,7 @@ simulate_choice.train <- function(model, newdata, target = NULL, ...) { checkmate::assert_true(model$modelType == "Classification") probs <- predict(model, newdata, type = "prob") - simulate_choice(probs, target) + simulate_choice(create_choice_table(probs), target) } #' @rdname simulate_choice @@ -56,7 +56,7 @@ simulate_choice.glm <- function(model, newdata, target = NULL, ...) { {data.table::data.table(x1 = ., x2 = 1 - .)} %>% data.table::setnames(choices) - simulate_choice(probs, target) + simulate_choice(create_choice_table(probs), target) } #' @rdname simulate_choice @@ -96,30 +96,91 @@ simulate_choice.WrappedModel <- function(model, newdata, target = NULL, ...) { } else { probs <- mlr::getPredictionProbabilities(pred) } - simulate_choice(probs, target) + simulate_choice(create_choice_table(probs), target) } #' @rdname simulate_choice #' @export -simulate_choice.data.frame <- function(model, target = NULL, ...) { - probs <- model - checkmate::assert_data_frame( +simulate_choice.data.frame <- function(model, newdata, target = NULL, ...) { + + # convert to data.table + if (!is.data.table(model)) { + checkmate::assert_data_frame(model, min.rows = 1) + model <- as.data.table(model) + } + + if (!is.data.table(newdata)) { + checkmate::assert_data_frame(newdata, min.rows = 1) + newdata <- as.data.table(newdata) + } + + if (!xor("prob" %in% names(model), "probs" %in% names(model))) { + stop("`model` should contains a numeric probability column named `prob` in a binary", + " choice case or `probs` in a multiple choice case.") + } + + match_vars <- + names(model)[!names(model) %in% c("prob", "probs", "choices")] + + checkmate::assert_names(names(newdata), must.include = match_vars) + + if ("prob" %in% names(model)) { + probs <- + merge(newdata, model, match_vars, sort = FALSE) %>% + .[, .(yes = prob, no = 1 - prob)] + } + + if ("probs" %in% names(model)) { + if (!"choices" %in% names(model)) { + stop("`model` is missing a list column named `choices`.") + } + stop("`model` with multiple choices has not been developed yet :(.") + } + + # check cases + checkmate::assert_data_table( probs, types = 'double', min.cols = 2, + nrows = nrow(newdata), any.missing = FALSE, null.ok = FALSE, col.names = 'unique' ) - if (!is.data.table(probs)) { - setDT(probs) + + simulate_choice(create_choice_table(probs), target) +} + +#' @rdname simulate_choice +#' @param choice_table a `choice_table` object, created by `create_choice_table()`. +#' @export +simulate_choice.dymium.choice_table <- function(choice_table, target = NULL, ...) { + checkmate::assert_data_frame( + choice_table, + types = 'double', + min.cols = 2, + any.missing = FALSE, + null.ok = FALSE, + col.names = 'unique' + ) + if (!is.data.table(choice_table)) { + setDT(choice_table) } - choices <- names(probs) + choices <- names(choice_table) # random draw choices if (!is.null(target)) { - alignment(probs, target) + alignment(choice_table, target) } else { - purrr::pmap_chr(probs, ~ sample_choice(choices, 1, prob = (list(...)))) + purrr::pmap_chr(choice_table, ~ sample_choice(choices, 1, prob = (list(...)))) } } +#' prepend dymium.choice_table +#' +#' @param x any object. +#' +#' @return `x` +create_choice_table = function(x) { + class(x) <- c("dymium.choice_table", class(x)) + x +} diff --git a/R/transition-fnc.R b/R/transition-fnc.R index a76e2e56..aa8929cf 100644 --- a/R/transition-fnc.R +++ b/R/transition-fnc.R @@ -221,7 +221,6 @@ transition <- #' @export get_transition <- function(world, entity, model, target = NULL, targeted_ids = NULL, preprocessing_fn = NULL) { - checkmate::assert_r6(world, classes = "World") if(!checkmate::test_string(entity, na.ok = FALSE)) { @@ -254,6 +253,7 @@ get_transition <- function(world, entity, model, target = NULL, targeted_ids = N model$preprocessing_fn(e_data) } e_data <- dymiumCore::normalise_derived_vars(e_data) + # early return if no data if (nrow(e_data) == 0) { return(data.table(id = integer(), response = character())) } diff --git a/R/utils.R b/R/utils.R index 38c1c0c3..d7270c4c 100644 --- a/R/utils.R +++ b/R/utils.R @@ -348,7 +348,7 @@ dt_group_and_sort <- function(x, groupby_col, group_col, sort_order) { checkmate::assert_data_table(x) stopifnot(groupby_col %in% names(x)) stopifnot(group_col %in% names(x)) - checkmate::expect_integerish(sort_order, unique = T, lower = 0, min.len = 1) + checkmate::assert_integerish(sort_order, unique = TRUE, lower = 0, min.len = 1) # group and sort x_new <- diff --git a/man/Target.Rd b/man/Target.Rd index 512cdf7b..b419ff48 100644 --- a/man/Target.Rd +++ b/man/Target.Rd @@ -13,11 +13,14 @@ functions. If the target is dynamic then its \code{get} will return its target value at the current time or its closest time to the current time. } \section{Construction}{ -\preformatted{Target$new(x) +\preformatted{Target$new(x, name) } \itemize{ \item \code{x} :: any object that passes \code{check_target()}\cr A target object or \code{NULL}. +\item \code{name} :: \code{character(1)}\cr +Name/Alias of the Target object. This will be used as the \link{Target} name when +it gets added to a \link{World}. } } diff --git a/man/create_choice_table.Rd b/man/create_choice_table.Rd new file mode 100644 index 00000000..b49b89b2 --- /dev/null +++ b/man/create_choice_table.Rd @@ -0,0 +1,17 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/simulate_choice.R +\name{create_choice_table} +\alias{create_choice_table} +\title{prepend dymium.choice_table} +\usage{ +create_choice_table(x) +} +\arguments{ +\item{x}{any object.} +} +\value{ +\code{x} +} +\description{ +prepend dymium.choice_table +} diff --git a/man/sample_choice_table.Rd b/man/sample_choice_table.Rd new file mode 100644 index 00000000..d25a8b24 --- /dev/null +++ b/man/sample_choice_table.Rd @@ -0,0 +1,21 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/simulate_choice.R +\name{sample_choice_table} +\alias{sample_choice_table} +\title{Sample a choice table} +\usage{ +sample_choice_table(x, target = NULL, ...) +} +\arguments{ +\item{x}{a \code{choice_table} object.} + +\item{target}{a \link{Target} object.} + +\item{...dots}{} +} +\value{ +a character vector. +} +\description{ +Sample a choice table +} diff --git a/man/simulate_choice.Rd b/man/simulate_choice.Rd index dbf2150d..864960e3 100644 --- a/man/simulate_choice.Rd +++ b/man/simulate_choice.Rd @@ -8,6 +8,7 @@ \alias{simulate_choice.Model} \alias{simulate_choice.WrappedModel} \alias{simulate_choice.data.frame} +\alias{simulate_choice.dymium.choice_table} \title{Simulate a choice situation} \usage{ simulate_choice(model, ...) @@ -22,7 +23,9 @@ simulate_choice(model, ...) \method{simulate_choice}{WrappedModel}(model, newdata, target = NULL, ...) -\method{simulate_choice}{data.frame}(model, target = NULL, ...) +\method{simulate_choice}{data.frame}(model, newdata, target = NULL, ...) + +\method{simulate_choice}{dymium.choice_table}(choice_table, target = NULL, ...) } \arguments{ \item{model}{a \link{Model} object or an object in \code{\link[=SupportedTransitionModels]{SupportedTransitionModels()}}.} @@ -33,6 +36,8 @@ simulate_choice(model, ...) \item{target}{a \link{Target} object or a named list this is for aligning the simulation outcome to an external target.} + +\item{choice_table}{a \code{choice_table} object, created by \code{create_choice_table()}.} } \value{ a character vector diff --git a/tests/testthat/test-Model.R b/tests/testthat/test-Model.R index 81b944c5..728966ca 100644 --- a/tests/testthat/test-Model.R +++ b/tests/testthat/test-Model.R @@ -1,7 +1,6 @@ test_that("Model initialisation", { m <- Model$new(list(x = 1), name = "model") expect_true(m$name == "model") - m <- Model$new(list(x = 1)) expect_null(m$null) }) diff --git a/tests/testthat/test-Population.R b/tests/testthat/test-Population.R index f915d682..cdc0b388 100644 --- a/tests/testthat/test-Population.R +++ b/tests/testthat/test-Population.R @@ -180,3 +180,15 @@ test_that("`household_type` of two random hid vectors of the same set be equipva all(table(Pop$household_type(hid = sample(1:100))) == table(Pop$household_type(hid = sample(1:100)))) ) }) + +# $leave_household ------------- +test_that("Population$leave_household()", { + create_toy_world() + Pop <- world$get("Population") + Ind <- world$get("Individual") + ind_ids <- sample(Ind$get_ids(), 10) + Pop$leave_household(ind_ids) + # Once individuals left households they cannot be leaving households again + expect_error(Pop$leave_household(ind_ids), + regexp = "Contains missing values") +}) diff --git a/tests/testthat/test-World.R b/tests/testthat/test-World.R index 2224bfe1..b15552b9 100644 --- a/tests/testthat/test-World.R +++ b/tests/testthat/test-World.R @@ -1,4 +1,7 @@ test_that("add", { + + lg$set_threshold("warn") + w <- World$new() # add container @@ -10,7 +13,7 @@ test_that("add", { # change to capture the warning messages lg$set_threshold("warn") - expect_output(w$add( + expect_warning(w$add( Population$new( ind_data = toy_individuals, hh_data = toy_households, @@ -23,9 +26,9 @@ test_that("add", { w$add(Agent$new(toy_individuals, "pid")) w$add(Firm$new(toy_individuals, "pid")) expect_length(w$entities, 4) - expect_output(w$add(Individual$new(toy_individuals, "pid")), + expect_warning(w$add(Individual$new(toy_individuals, "pid")), "Replacing") - expect_output(w$add(Household$new(toy_households, "hid")), + expect_warning(w$add(Household$new(toy_households, "hid")), "Replacing") lg$set_threshold("fatal") @@ -38,8 +41,11 @@ test_that("add", { w$add(Model$new(list(x = 1), "namedModel")) w$add(list(x = 1), "namedModel") - # add world ? - expect_error(w$add(w), regexp = "Adding a World object is not permitted.") + # a world within another world is not permitted + expect_error( + w$add(w), + regexp = "Adding a World object to another World object is not permitted." + ) }) @@ -124,9 +130,22 @@ test_that("active fields", { }) test_that("add target", { - t <- Target$new(x = list(yes = 10, no = 20)) w <- World$new() + + # name using the name arg + t <- Target$new(x = list(yes = 10, no = 20)) w$add(x = t, name = "a_target") + + expect_target(w$targets[["a_target"]], null.ok = FALSE) + expect_warning(w$add(x = t, name = "a_target"), regexp = "Replacing the object named") + + # unnamed target + expect_error(w$add(x = t), regexp = "argument \"name\" is missing, with no default") + + # named target + t <- Target$new(x = list(yes = 10, no = 20), name = "a_target") + expect_warning(w$add(x = t), regexp = "Replacing the object named") + expect_target(w$targets[["a_target"]], null.ok = FALSE) checkmate::expect_r6(w$targets$a_target, "Target") }) diff --git a/tests/testthat/test-makeModel.R b/tests/testthat/test-makeModel.R index 6b3d8eb0..cad50ed2 100644 --- a/tests/testthat/test-makeModel.R +++ b/tests/testthat/test-makeModel.R @@ -7,6 +7,10 @@ test_that("makeModel", { checkmate::expect_numeric(predict(Mod, newdata = toy_individuals), finite = T, any.missing = FALSE) checkmate::expect_numeric(summary(Mod), names = "named") + # binary choice model without the dependent variable in newdata + # predict(Mod, newdata = toy_individuals[, -"sex"]) + # predict(Mod) + # mlogit model mlogit_model <- create_mlogit_model() Mod <- makeModel(mlogit_model) diff --git a/tests/testthat/test-simulate-choice.R b/tests/testthat/test-simulate-choice.R index 5a54babf..af9b5a87 100644 --- a/tests/testthat/test-simulate-choice.R +++ b/tests/testthat/test-simulate-choice.R @@ -27,11 +27,29 @@ test_that("simulate_choice.train multilabels works", { test_that("simulate_choice.data.frame works", { n_rows <- 10 + + # this used to work in commit:0602258ed2bdffb3827c0ac3870dbdd62b575d56, but now is defunct probs <- data.frame(yes = runif(n_rows), no = runif(n_rows), maybe = runif(n_rows)) - checkmate::expect_character(simulate_choice(probs), - pattern = "yes|no|maybe", - any.missing = FALSE, - len = n_rows) + expect_error(simulate_choice(probs)) + + model <- data.frame(age = c(0:99, 0:99), + sex = c(rep('male', 100), rep('female', 100)), + prob = 0.05) + checkmate::expect_character( + simulate_choice(model, newdata = toy_individuals), + pattern = "yes|no", len = nrow(toy_individuals) + ) + checkmate::expect_character( + simulate_choice(as.data.table(model), newdata = toy_individuals), + pattern = "yes|no", len = nrow(toy_individuals) + ) + + model <- data.frame(age = 1, + sex = "female", + prob = 0.05) + expect_error(simulate_choice(model, newdata = toy_individuals), + regexp = "Assertion on 'probs' failed") + }) test_that("simulate_choice.WrappedModel from mlr works", {