diff --git a/src/stan/mcmc/covar_adaptation.hpp b/src/stan/mcmc/covar_adaptation.hpp index 8c0acebd022..7b0928fafb0 100644 --- a/src/stan/mcmc/covar_adaptation.hpp +++ b/src/stan/mcmc/covar_adaptation.hpp @@ -14,13 +14,11 @@ class covar_adaptation : public windowed_adaptation { explicit covar_adaptation(int n) : windowed_adaptation("covariance"), estimator_(n) {} - bool learn_covariance(Eigen::MatrixXd& covar, const Eigen::VectorXd& q) { - if (adaptation_window()) + void learn_covariance(Eigen::MatrixXd& covar, const Eigen::VectorXd& q) { + if (in_phase2_window()) estimator_.add_sample(q); - if (end_adaptation_window()) { - compute_next_window(); - + if (end_phase2_window()) { estimator_.sample_covariance(covar); double n = static_cast(estimator_.num_samples()); @@ -35,15 +33,7 @@ class covar_adaptation : public windowed_adaptation { "unconstrained space; this may happen when the posterior density " "function is too wide or improper. " "There may be problems with your model specification."); - - estimator_.restart(); - - ++adapt_window_counter_; - return true; } - - ++adapt_window_counter_; - return false; } protected: diff --git a/src/stan/mcmc/hmc/nuts/adapt_dense_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/adapt_dense_e_nuts.hpp index 05b6c80523f..8898d7b968f 100644 --- a/src/stan/mcmc/hmc/nuts/adapt_dense_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/adapt_dense_e_nuts.hpp @@ -29,15 +29,18 @@ class adapt_dense_e_nuts : public dense_e_nuts, this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_, s.accept_stat()); - bool update = this->covar_adaptation_.learn_covariance( - this->z_.inv_e_metric_, this->z_.q); + if (this->covar_adaptation_.in_phase2_window()) { + this->covar_adaptation_.learn_covariance(this->z_.inv_e_metric_, + this->z_.q); + } - if (update) { + if (this->covar_adaptation_.end_phase2_window()) { this->init_stepsize(logger); - this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); this->stepsize_adaptation_.restart(); + this->covar_adaptation_.compute_next_window(); } + this->covar_adaptation_.cur_iter_++; } return s; } diff --git a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp index 45e92380f57..f0d5d4b8a29 100644 --- a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp @@ -29,15 +29,18 @@ class adapt_diag_e_nuts : public diag_e_nuts, this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_, s.accept_stat()); - bool update = this->var_adaptation_.learn_variance(this->z_.inv_e_metric_, - this->z_.q); + if (this->var_adaptation_.in_phase2_window()) { + this->var_adaptation_.learn_variance(this->z_.inv_e_metric_, + this->z_.q); + } - if (update) { + if (this->var_adaptation_.end_phase2_window()) { this->init_stepsize(logger); - this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); this->stepsize_adaptation_.restart(); + this->var_adaptation_.compute_next_window(); } + this->var_adaptation_.cur_iter_++; } return s; } diff --git a/src/stan/mcmc/hmc/nuts_classic/adapt_dense_e_nuts_classic.hpp b/src/stan/mcmc/hmc/nuts_classic/adapt_dense_e_nuts_classic.hpp index 4838118f8a3..bf0155ba547 100644 --- a/src/stan/mcmc/hmc/nuts_classic/adapt_dense_e_nuts_classic.hpp +++ b/src/stan/mcmc/hmc/nuts_classic/adapt_dense_e_nuts_classic.hpp @@ -29,15 +29,18 @@ class adapt_dense_e_nuts_classic : public dense_e_nuts_classic, this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_, s.accept_stat()); - bool update = this->covar_adaptation_.learn_covariance( - this->z_.inv_e_metric_, this->z_.q); + if (this->covar_adaptation_.in_phase2_window()) { + this->covar_adaptation_.learn_covariance(this->z_.inv_e_metric_, + this->z_.q); + } - if (update) { + if (this->covar_adaptation_.end_phase2_window()) { this->init_stepsize(logger); - this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); this->stepsize_adaptation_.restart(); + this->covar_adaptation_.compute_next_window(); } + this->covar_adaptation_.cur_iter_++; } return s; } diff --git a/src/stan/mcmc/hmc/nuts_classic/adapt_diag_e_nuts_classic.hpp b/src/stan/mcmc/hmc/nuts_classic/adapt_diag_e_nuts_classic.hpp index 77e146119b1..5429eceac44 100644 --- a/src/stan/mcmc/hmc/nuts_classic/adapt_diag_e_nuts_classic.hpp +++ b/src/stan/mcmc/hmc/nuts_classic/adapt_diag_e_nuts_classic.hpp @@ -30,15 +30,18 @@ class adapt_diag_e_nuts_classic : public diag_e_nuts_classic, this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_, s.accept_stat()); - bool update = this->var_adaptation_.learn_variance(this->z_.inv_e_metric_, - this->z_.q); + if (this->var_adaptation_.in_phase2_window()) { + this->var_adaptation_.learn_variance(this->z_.inv_e_metric_, + this->z_.q); + } - if (update) { + if (this->var_adaptation_.end_phase2_window()) { this->init_stepsize(logger); - this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); this->stepsize_adaptation_.restart(); + this->var_adaptation_.compute_next_window(); } + this->var_adaptation_.cur_iter_++; } return s; } diff --git a/src/stan/mcmc/hmc/static/adapt_dense_e_static_hmc.hpp b/src/stan/mcmc/hmc/static/adapt_dense_e_static_hmc.hpp index 5b9c88e7bd4..614d10043e7 100644 --- a/src/stan/mcmc/hmc/static/adapt_dense_e_static_hmc.hpp +++ b/src/stan/mcmc/hmc/static/adapt_dense_e_static_hmc.hpp @@ -32,16 +32,17 @@ class adapt_dense_e_static_hmc : public dense_e_static_hmc, s.accept_stat()); this->update_L_(); - bool update = this->covar_adaptation_.learn_covariance( - this->z_.inv_e_metric_, this->z_.q); + if (this->covar_adaptation_.in_phase2_window()) + this->covar_adaptation_.learn_covariance(this->z_.inv_e_metric_, + this->z_.q); - if (update) { + if (this->covar_adaptation_.end_phase2_window()) { this->init_stepsize(logger); this->update_L_(); - this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); - this->stepsize_adaptation_.restart(); + this->covar_adaptation_.compute_next_window(); } + this->covar_adaptation_.cur_iter_++; } return s; } diff --git a/src/stan/mcmc/hmc/static/adapt_diag_e_static_hmc.hpp b/src/stan/mcmc/hmc/static/adapt_diag_e_static_hmc.hpp index a2098ef95f0..2cafa6d7797 100644 --- a/src/stan/mcmc/hmc/static/adapt_diag_e_static_hmc.hpp +++ b/src/stan/mcmc/hmc/static/adapt_diag_e_static_hmc.hpp @@ -32,16 +32,18 @@ class adapt_diag_e_static_hmc : public diag_e_static_hmc, s.accept_stat()); this->update_L_(); - bool update = this->var_adaptation_.learn_variance(this->z_.inv_e_metric_, - this->z_.q); + if (this->var_adaptation_.in_phase2_window()) + this->var_adaptation_.learn_variance(this->z_.inv_e_metric_, + this->z_.q); - if (update) { + if (this->var_adaptation_.end_phase2_window()) { this->init_stepsize(logger); this->update_L_(); this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); - this->stepsize_adaptation_.restart(); + this->var_adaptation_.compute_next_window(); } + this->var_adaptation_.cur_iter_++; } return s; } diff --git a/src/stan/mcmc/hmc/static_uniform/adapt_dense_e_static_uniform.hpp b/src/stan/mcmc/hmc/static_uniform/adapt_dense_e_static_uniform.hpp index 716d200cb46..afcf8781a56 100644 --- a/src/stan/mcmc/hmc/static_uniform/adapt_dense_e_static_uniform.hpp +++ b/src/stan/mcmc/hmc/static_uniform/adapt_dense_e_static_uniform.hpp @@ -31,14 +31,18 @@ class adapt_dense_e_static_uniform this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_, s.accept_stat()); - bool update = this->covar_adaptation_.learn_covariance( - this->z_.inv_e_metric_, this->z_.q); + if (this->covar_adaptation_.in_phase2_window()) { + this->covar_adaptation_.learn_covariance(this->z_.inv_e_metric_, + this->z_.q); + } - if (update) { + if (this->covar_adaptation_.end_phase2_window()) { this->init_stepsize(logger); this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); this->stepsize_adaptation_.restart(); + this->covar_adaptation_.compute_next_window(); } + this->covar_adaptation_.cur_iter_++; } return s; } diff --git a/src/stan/mcmc/hmc/static_uniform/adapt_diag_e_static_uniform.hpp b/src/stan/mcmc/hmc/static_uniform/adapt_diag_e_static_uniform.hpp index 6a068dff830..a3874a60396 100644 --- a/src/stan/mcmc/hmc/static_uniform/adapt_diag_e_static_uniform.hpp +++ b/src/stan/mcmc/hmc/static_uniform/adapt_diag_e_static_uniform.hpp @@ -31,13 +31,18 @@ class adapt_diag_e_static_uniform this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_, s.accept_stat()); - bool update = this->var_adaptation_.learn_variance(this->z_.inv_e_metric_, - this->z_.q); - if (update) { + if (this->var_adaptation_.in_phase2_window()) { + this->var_adaptation_.learn_variance(this->z_.inv_e_metric_, + this->z_.q); + } + + if (this->var_adaptation_.end_phase2_window()) { this->init_stepsize(logger); this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); this->stepsize_adaptation_.restart(); + this->var_adaptation_.compute_next_window(); } + this->var_adaptation_.cur_iter_++; } return s; } diff --git a/src/stan/mcmc/hmc/xhmc/adapt_dense_e_xhmc.hpp b/src/stan/mcmc/hmc/xhmc/adapt_dense_e_xhmc.hpp index fb473ff7471..e69c8392f34 100644 --- a/src/stan/mcmc/hmc/xhmc/adapt_dense_e_xhmc.hpp +++ b/src/stan/mcmc/hmc/xhmc/adapt_dense_e_xhmc.hpp @@ -29,15 +29,18 @@ class adapt_dense_e_xhmc : public dense_e_xhmc, this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_, s.accept_stat()); - bool update = this->covar_adaptation_.learn_covariance( - this->z_.inv_e_metric_, this->z_.q); + if (this->covar_adaptation_.in_phase2_window()) { + this->covar_adaptation_.learn_covariance(this->z_.inv_e_metric_, + this->z_.q); + } - if (update) { + if (this->covar_adaptation_.end_phase2_window()) { this->init_stepsize(logger); - this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); this->stepsize_adaptation_.restart(); + this->covar_adaptation_.compute_next_window(); } + this->covar_adaptation_.cur_iter_++; } return s; } diff --git a/src/stan/mcmc/hmc/xhmc/adapt_diag_e_xhmc.hpp b/src/stan/mcmc/hmc/xhmc/adapt_diag_e_xhmc.hpp index 5ddee86725b..5c4153a3c21 100644 --- a/src/stan/mcmc/hmc/xhmc/adapt_diag_e_xhmc.hpp +++ b/src/stan/mcmc/hmc/xhmc/adapt_diag_e_xhmc.hpp @@ -29,15 +29,18 @@ class adapt_diag_e_xhmc : public diag_e_xhmc, this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_, s.accept_stat()); - bool update = this->var_adaptation_.learn_variance(this->z_.inv_e_metric_, - this->z_.q); + if (this->var_adaptation_.in_phase2_window()) { + this->var_adaptation_.learn_variance(this->z_.inv_e_metric_, + this->z_.q); + } - if (update) { + if (this->var_adaptation_.end_phase2_window()) { this->init_stepsize(logger); - this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); this->stepsize_adaptation_.restart(); + this->var_adaptation_.compute_next_window(); } + this->var_adaptation_.cur_iter_++; } return s; } diff --git a/src/stan/mcmc/stepsize_adaptation.hpp b/src/stan/mcmc/stepsize_adaptation.hpp index 033d84a5a44..ca7d43066d8 100644 --- a/src/stan/mcmc/stepsize_adaptation.hpp +++ b/src/stan/mcmc/stepsize_adaptation.hpp @@ -70,7 +70,10 @@ class stepsize_adaptation : public base_adaptation { epsilon = std::exp(x); } - void complete_adaptation(double& epsilon) { epsilon = std::exp(x_bar_); } + void complete_adaptation(double& epsilon) { + if (x_bar_ > 0) + epsilon = std::exp(x_bar_); + } protected: double counter_; // Adaptation iteration diff --git a/src/stan/mcmc/var_adaptation.hpp b/src/stan/mcmc/var_adaptation.hpp index 840eb77239b..a844c4c22ca 100644 --- a/src/stan/mcmc/var_adaptation.hpp +++ b/src/stan/mcmc/var_adaptation.hpp @@ -14,13 +14,11 @@ class var_adaptation : public windowed_adaptation { explicit var_adaptation(int n) : windowed_adaptation("variance"), estimator_(n) {} - bool learn_variance(Eigen::VectorXd& var, const Eigen::VectorXd& q) { - if (adaptation_window()) + void learn_variance(Eigen::VectorXd& var, const Eigen::VectorXd& q) { + if (in_phase2_window()) estimator_.add_sample(q); - if (end_adaptation_window()) { - compute_next_window(); - + if (end_phase2_window()) { estimator_.sample_variance(var); double n = static_cast(estimator_.num_samples()); @@ -34,15 +32,7 @@ class var_adaptation : public windowed_adaptation { "unconstrained space; this may happen when the posterior density " "function is too wide or improper. " "There may be problems with your model specification."); - - estimator_.restart(); - - ++adapt_window_counter_; - return true; } - - ++adapt_window_counter_; - return false; } protected: diff --git a/src/stan/mcmc/windowed_adaptation.hpp b/src/stan/mcmc/windowed_adaptation.hpp index 36bf1d1d986..75269453328 100644 --- a/src/stan/mcmc/windowed_adaptation.hpp +++ b/src/stan/mcmc/windowed_adaptation.hpp @@ -8,46 +8,55 @@ namespace stan { namespace mcmc { - +/* Warmup schedule for NUTS-HMC windowed adaptation has 3 phases. + * Phases 1 & 3: have a fixed size. + * Phase 2 iterations are divided into "windows" which double in length so that + * phase 2 iterations fill out the total number of warmup iterations. + */ class windowed_adaptation : public base_adaptation { public: explicit windowed_adaptation(std::string name) : estimator_name_(name) { num_warmup_ = 0; - adapt_init_buffer_ = 0; - adapt_term_buffer_ = 0; - adapt_base_window_ = 0; - - restart(); - } - - void restart() { - adapt_window_counter_ = 0; - adapt_window_size_ = adapt_base_window_; - adapt_next_window_ = adapt_init_buffer_ + adapt_window_size_ - 1; + cur_iter_ = 0; + cur_phase2_ = 1; + cur_phase2_end_ = 0; + end_phase1_ = 0; + start_phase3_ = 0; } + /* Record user requested number of warmup iterations and adjust size of + * warmup phases as needed. + * + * @param[in] num_warmup number of warmup draws + * @param[in] init_buffer width of initial fast adaptation interval + * @param[in] term_buffer width of final fast adaptation interval + * @param[in] base_window initial width of slow adaptation interval + * @param[in,out] logger Logger for messages + */ void set_window_params(unsigned int num_warmup, unsigned int init_buffer, unsigned int term_buffer, unsigned int base_window, callbacks::logger& logger) { - if (num_warmup < 20) { - logger.info("WARNING: No " + estimator_name_ + " estimation is"); - logger.info(" performed for num_warmup < 20"); - logger.info(""); - return; - } + num_warmup_ = num_warmup - 1; // count from 0 + end_phase1_ = (init_buffer > 0) ? 0 : init_buffer - 1; + start_phase3_ = num_warmup - term_buffer; + cur_phase2_ = base_window; + cur_phase2_end_ = end_phase1_ + cur_phase2_; if (init_buffer + base_window + term_buffer > num_warmup) { logger.info( "WARNING: There aren't enough warmup " "iterations to fit the"); - logger.info(" three stages of adaptation as currently" - + std::string(" configured.")); + logger.info( + " three stages of adaptation as currently" + " configured."); num_warmup_ = num_warmup; - adapt_init_buffer_ = 0.15 * num_warmup; - adapt_term_buffer_ = 0.10 * num_warmup; - adapt_base_window_ - = num_warmup - (adapt_init_buffer_ + adapt_term_buffer_); + end_phase1_ = 0.15 * num_warmup; // C++ rounds down + start_phase3_ = num_warmup - (0.10 * num_warmup); + cur_phase2_ = (base_window <= 0.75 * num_warmup) + ? base_window + : start_phase3_ - end_phase1_; + cur_phase2_end_ = end_phase1_ + cur_phase2_; logger.info( " Reducing each adaptation stage to " @@ -55,71 +64,51 @@ class windowed_adaptation : public base_adaptation { logger.info(" the given number of warmup iterations:"); std::stringstream init_buffer_msg; - init_buffer_msg << " init_buffer = " << adapt_init_buffer_; + init_buffer_msg << " init_buffer = " << end_phase1_; logger.info(init_buffer_msg); std::stringstream adapt_window_msg; - adapt_window_msg << " adapt_window = " << adapt_base_window_; + adapt_window_msg << " adapt_window = " + << start_phase3_ - end_phase1_; logger.info(adapt_window_msg); std::stringstream term_buffer_msg; - term_buffer_msg << " term_buffer = " << adapt_term_buffer_; + term_buffer_msg << " term_buffer = " + << num_warmup - start_phase3_; logger.info(term_buffer_msg); logger.info(""); - return; } - - num_warmup_ = num_warmup; - adapt_init_buffer_ = init_buffer; - adapt_term_buffer_ = term_buffer; - adapt_base_window_ = base_window; - restart(); } - bool adaptation_window() { - return (adapt_window_counter_ >= adapt_init_buffer_) - && (adapt_window_counter_ < num_warmup_ - adapt_term_buffer_) - && (adapt_window_counter_ != num_warmup_); + bool in_phase2_window() { + return (cur_iter_ > end_phase1_ && cur_iter_ < start_phase3_); } - bool end_adaptation_window() { - return (adapt_window_counter_ == adapt_next_window_) - && (adapt_window_counter_ != num_warmup_); - } + bool end_phase2_window() { return (cur_iter_ == cur_phase2_end_); } + // find next window endpoint + // double window size if possible, else use remaining phase2 iters void compute_next_window() { - if (adapt_next_window_ == num_warmup_ - adapt_term_buffer_ - 1) - return; - - adapt_window_size_ *= 2; - adapt_next_window_ = adapt_window_counter_ + adapt_window_size_; - - if (adapt_next_window_ == num_warmup_ - adapt_term_buffer_ - 1) - return; - - // Boundary of the following window, not the window just computed - unsigned int next_window_boundary - = adapt_next_window_ + 2 * adapt_window_size_; - - // If the following window overtakes the full adaptation window, - // then stretch the current window to the end of the full window - if (next_window_boundary >= num_warmup_ - adapt_term_buffer_) { - adapt_next_window_ = num_warmup_ - adapt_term_buffer_ - 1; + int next_phase2_size = cur_phase2_ * 2; + if (next_phase2_size + cur_iter_ <= start_phase3_) { + cur_phase2_ = next_phase2_size; + cur_phase2_end_ = cur_iter_ + cur_phase2_; + } else { + cur_phase2_end_ = start_phase3_; } } + unsigned int cur_iter_; + protected: std::string estimator_name_; unsigned int num_warmup_; - unsigned int adapt_init_buffer_; - unsigned int adapt_term_buffer_; - unsigned int adapt_base_window_; - - unsigned int adapt_window_counter_; - unsigned int adapt_next_window_; - unsigned int adapt_window_size_; + unsigned int cur_phase2_; + unsigned int cur_phase2_end_; + unsigned int end_phase1_; + unsigned int start_phase3_; }; } // namespace mcmc diff --git a/src/test/unit/mcmc/windowed_adaptation_test.cpp b/src/test/unit/mcmc/windowed_adaptation_test.cpp index c30455b0ff0..7c48455e4a7 100644 --- a/src/test/unit/mcmc/windowed_adaptation_test.cpp +++ b/src/test/unit/mcmc/windowed_adaptation_test.cpp @@ -2,20 +2,49 @@ #include #include -TEST(McmcWindowedAdaptation, set_window_params1) { +TEST(McmcWindowedAdaptation, sampler_defaults) { stan::test::unit::instrumented_logger logger; - stan::mcmc::windowed_adaptation adapter("test"); + adapter.set_window_params(1000, 75, 50, 25, logger); + ASSERT_EQ(0, logger.call_count()); + ASSERT_EQ(0, logger.call_count_info()); + EXPECT_FALSE(adapter.in_phase2_window()); + EXPECT_FALSE(adapter.end_phase2_window()); +} + +TEST(McmcWindowedAdaptation, num_warmup_10) { + stan::test::unit::instrumented_logger logger; + stan::mcmc::windowed_adaptation adapter("test"); adapter.set_window_params(10, 1, 1, 1, logger); + ASSERT_EQ(0, logger.call_count()); + ASSERT_EQ(0, logger.call_count_info()); + + EXPECT_FALSE(adapter.in_phase2_window()); + EXPECT_FALSE(adapter.end_phase2_window()); +} + +TEST(McmcWindowedAdaptation, num_warmup_10_reduce) { + stan::test::unit::instrumented_logger logger; + stan::mcmc::windowed_adaptation adapter("test"); + adapter.set_window_params(10, 75, 25, 50, logger); ASSERT_EQ(logger.call_count(), logger.call_count_info()); - EXPECT_EQ(1, logger.find_info("WARNING: No test estimation is")); - EXPECT_EQ(1, logger.find_info("performed for num_warmup < 20")); + + EXPECT_FALSE(adapter.in_phase2_window()); + EXPECT_FALSE(adapter.end_phase2_window()); } -TEST(McmcWindowedAdaptation, set_window_params2) { +TEST(McmcWindowedAdaptation, num_warmup_1) { stan::test::unit::instrumented_logger logger; + stan::mcmc::windowed_adaptation adapter("test"); + adapter.set_window_params(1, 75, 25, 50, logger); + + EXPECT_FALSE(adapter.in_phase2_window()); + EXPECT_TRUE(adapter.end_phase2_window()); +} +TEST(McmcWindowedAdaptation, sampler_defaults_100) { + stan::test::unit::instrumented_logger logger; stan::mcmc::windowed_adaptation adapter("test"); adapter.set_window_params(100, 75, 50, 25, logger); @@ -35,14 +64,3 @@ TEST(McmcWindowedAdaptation, set_window_params2) { EXPECT_EQ(1, logger.find_info(" adapt_window = 75")); EXPECT_EQ(1, logger.find_info(" term_buffer = 10")); } - -TEST(McmcWindowedAdaptation, set_window_params3) { - stan::test::unit::instrumented_logger logger; - - stan::mcmc::windowed_adaptation adapter("test"); - - adapter.set_window_params(1000, 75, 50, 25, logger); - - ASSERT_EQ(0, logger.call_count()); - ASSERT_EQ(0, logger.call_count_info()); -} diff --git a/src/test/unit/services/sample/hmc_nuts_dense_e_adapt_test.cpp b/src/test/unit/services/sample/hmc_nuts_dense_e_adapt_test.cpp index 8638966d5dd..ee02dd70232 100644 --- a/src/test/unit/services/sample/hmc_nuts_dense_e_adapt_test.cpp +++ b/src/test/unit/services/sample/hmc_nuts_dense_e_adapt_test.cpp @@ -184,3 +184,40 @@ TEST_F(ServicesSampleHmcNutsDenseEAdapt, output_regression) { EXPECT_EQ(1, logger.find_info("seconds (Total)")); EXPECT_EQ(0, logger.call_count_error()); } + +TEST_F(ServicesSampleHmcNutsDenseEAdapt, term_buffer_0) { + unsigned int random_seed = 0; + unsigned int chain = 1; + double init_radius = 0; + int num_warmup = 150; + int num_samples = 10; + int num_thin = 1; + bool save_warmup = true; + int refresh = 0; + double stepsize = 1.0; + double stepsize_jitter = 0.0; + int max_depth = 10; + double delta = .8; + double gamma = .05; + double kappa = .75; + double t0 = 10; + unsigned int init_buffer = 50; + unsigned int term_buffer = 0; + unsigned int window = 100; + stan::test::unit::instrumented_interrupt interrupt; + EXPECT_EQ(interrupt.call_count(), 0); + + stan::services::sample::hmc_nuts_dense_e_adapt( + model, context, random_seed, chain, init_radius, num_warmup, num_samples, + num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth, + delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt, + logger, init, parameter, diagnostic); + + EXPECT_EQ(0, logger.call_count_error()); + std::vector messages = parameter.string_values(); + for (auto msg : messages) { + if (msg.find("Step size") != std::string::npos) { + EXPECT_NE("Step size = 1", msg); + } + } +} diff --git a/src/test/unit/services/sample/hmc_nuts_diag_e_adapt_test.cpp b/src/test/unit/services/sample/hmc_nuts_diag_e_adapt_test.cpp index dce2945acf6..bfcdc14d6a7 100644 --- a/src/test/unit/services/sample/hmc_nuts_diag_e_adapt_test.cpp +++ b/src/test/unit/services/sample/hmc_nuts_diag_e_adapt_test.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include class ServicesSampleHmcNutsDiagEAdapt : public testing::Test { @@ -82,13 +83,13 @@ TEST_F(ServicesSampleHmcNutsDiagEAdapt, parameter_checks) { delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt, logger, init, parameter, diagnostic); - std::vector > parameter_names; + std::vector> parameter_names; parameter_names = parameter.vector_string_values(); - std::vector > parameter_values; + std::vector> parameter_values; parameter_values = parameter.vector_double_values(); - std::vector > diagnostic_names; + std::vector> diagnostic_names; diagnostic_names = diagnostic.vector_string_values(); - std::vector > diagnostic_values; + std::vector> diagnostic_values; diagnostic_values = diagnostic.vector_double_values(); // Expectations of parameter parameter names. @@ -143,13 +144,13 @@ TEST_F(ServicesSampleHmcNutsDiagEAdapt, output_sizes) { delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt, logger, init, parameter, diagnostic); - std::vector > parameter_names; + std::vector> parameter_names; parameter_names = parameter.vector_string_values(); - std::vector > parameter_values; + std::vector> parameter_values; parameter_values = parameter.vector_double_values(); - std::vector > diagnostic_names; + std::vector> diagnostic_names; diagnostic_names = diagnostic.vector_string_values(); - std::vector > diagnostic_values; + std::vector> diagnostic_values; diagnostic_values = diagnostic.vector_double_values(); EXPECT_EQ(return_code, 0); @@ -194,3 +195,267 @@ TEST_F(ServicesSampleHmcNutsDiagEAdapt, output_regression) { EXPECT_EQ(1, logger.find_info("seconds (Total)")); EXPECT_EQ(0, logger.call_count_error()); } + +TEST_F(ServicesSampleHmcNutsDiagEAdapt, term_buffer_0) { + unsigned int random_seed = 0; + unsigned int chain = 1; + double init_radius = 0; + int num_warmup = 150; + int num_samples = 10; + int num_thin = 1; + bool save_warmup = true; + int refresh = 0; + double stepsize = 1.0; + double stepsize_jitter = 0.0; + int max_depth = 10; + double delta = .8; + double gamma = .05; + double kappa = .75; + double t0 = 10; + unsigned int init_buffer = 50; + unsigned int term_buffer = 0; + unsigned int window = 100; + stan::test::unit::instrumented_interrupt interrupt; + EXPECT_EQ(interrupt.call_count(), 0); + + stan::services::sample::hmc_nuts_diag_e_adapt( + model, context, random_seed, chain, init_radius, num_warmup, num_samples, + num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth, + delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt, + logger, init, parameter, diagnostic); + + EXPECT_EQ(0, logger.call_count_error()); + int num_output_lines = (num_warmup + num_samples) / num_thin; + EXPECT_EQ(num_output_lines, parameter.call_count("vector_double")); + + std::vector messages = parameter.string_values(); + for (auto msg : messages) { + if (msg.find("Step size") != std::string::npos) { + EXPECT_NE("Step size = 1", msg); + } + } +} + +TEST_F(ServicesSampleHmcNutsDiagEAdapt, term_buffer_1) { + unsigned int random_seed = 0; + unsigned int chain = 1; + double init_radius = 0; + int num_warmup = 150; + int num_samples = 10; + int num_thin = 1; + bool save_warmup = true; + int refresh = 0; + double stepsize = 1.0; + double stepsize_jitter = 0.0; + int max_depth = 10; + double delta = .8; + double gamma = .05; + double kappa = .75; + double t0 = 10; + unsigned int init_buffer = 49; + unsigned int term_buffer = 1; + unsigned int window = 100; + stan::test::unit::instrumented_interrupt interrupt; + EXPECT_EQ(interrupt.call_count(), 0); + + stan::services::sample::hmc_nuts_diag_e_adapt( + model, context, random_seed, chain, init_radius, num_warmup, num_samples, + num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth, + delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt, + logger, init, parameter, diagnostic); + + EXPECT_EQ(0, logger.call_count_error()); + int num_output_lines = (num_warmup + num_samples) / num_thin; + EXPECT_EQ(num_output_lines, parameter.call_count("vector_double")); + + std::vector> draws = parameter.vector_double_values(); + auto draw = draws[draws.size() - 1]; + + std::vector messages = parameter.string_values(); + for (auto msg : messages) { + if (msg.find("Step size") != std::string::npos) { + std::vector toks; + boost::split(toks, msg, boost::is_any_of(" ")); + auto adapted = std::stod(toks[toks.size() - 1]); + EXPECT_NEAR(draw[2], adapted, 1e-5); + } + } +} + +TEST_F(ServicesSampleHmcNutsDiagEAdapt, no_stepsize_adapt) { + unsigned int random_seed = 0; + unsigned int chain = 1; + double init_radius = 0; + int num_warmup = 150; + int num_samples = 10; + int num_thin = 1; + bool save_warmup = true; + int refresh = 0; + double stepsize = 1.0; + double stepsize_jitter = 0.0; + int max_depth = 10; + double delta = .8; + double gamma = .05; + double kappa = .75; + double t0 = 10; + unsigned int init_buffer = 0; + unsigned int term_buffer = 0; + unsigned int window = 50; + stan::test::unit::instrumented_interrupt interrupt; + EXPECT_EQ(interrupt.call_count(), 0); + + stan::services::sample::hmc_nuts_diag_e_adapt( + model, context, random_seed, chain, init_radius, num_warmup, num_samples, + num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth, + delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt, + logger, init, parameter, diagnostic); + + EXPECT_EQ(0, logger.call_count_error()); + int num_output_lines = (num_warmup + num_samples) / num_thin; + EXPECT_EQ(num_output_lines, parameter.call_count("vector_double")); + + std::vector messages = parameter.string_values(); + for (auto msg : messages) { + if (msg.find("Step size") != std::string::npos) { + EXPECT_NE("Step size = 1", msg); + } + } +} + +TEST_F(ServicesSampleHmcNutsDiagEAdapt, schedule_a) { + unsigned int random_seed = 0; + unsigned int chain = 1; + double init_radius = 0; + int num_warmup = 35; + int num_samples = 2; + int num_thin = 1; + bool save_warmup = true; + int refresh = 0; + double stepsize = 1.0; + double stepsize_jitter = 0.0; + int max_depth = 10; + double delta = .8; + double gamma = .05; + double kappa = .75; + double t0 = 10; + unsigned int init_buffer = 5; + unsigned int term_buffer = 0; + unsigned int window = 20; + stan::test::unit::instrumented_interrupt interrupt; + EXPECT_EQ(interrupt.call_count(), 0); + + stan::services::sample::hmc_nuts_diag_e_adapt( + model, context, random_seed, chain, init_radius, num_warmup, num_samples, + num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth, + delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt, + logger, init, parameter, diagnostic); + + EXPECT_EQ(0, logger.call_count_error()); + int num_output_lines = (num_warmup + num_samples) / num_thin; + EXPECT_EQ(num_output_lines, parameter.call_count("vector_double")); + + std::vector> draws = parameter.vector_double_values(); + auto draw = draws[draws.size() - 1]; + + std::vector messages = parameter.string_values(); + for (auto msg : messages) { + if (msg.find("Step size") != std::string::npos) { + std::vector toks; + boost::split(toks, msg, boost::is_any_of(" ")); + auto adapted = std::stod(toks[toks.size() - 1]); + EXPECT_NEAR(draw[2], adapted, 1e-5); + } + } +} + +TEST_F(ServicesSampleHmcNutsDiagEAdapt, schedule_b) { + unsigned int random_seed = 0; + unsigned int chain = 1; + double init_radius = 0; + int num_warmup = 36; + int num_samples = 2; + int num_thin = 1; + bool save_warmup = true; + int refresh = 0; + double stepsize = 1.0; + double stepsize_jitter = 0.0; + int max_depth = 10; + double delta = .8; + double gamma = .05; + double kappa = .75; + double t0 = 10; + unsigned int init_buffer = 5; + unsigned int term_buffer = 1; + unsigned int window = 30; + stan::test::unit::instrumented_interrupt interrupt; + EXPECT_EQ(interrupt.call_count(), 0); + + stan::services::sample::hmc_nuts_diag_e_adapt( + model, context, random_seed, chain, init_radius, num_warmup, num_samples, + num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth, + delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt, + logger, init, parameter, diagnostic); + + EXPECT_EQ(0, logger.call_count_error()); + int num_output_lines = (num_warmup + num_samples) / num_thin; + EXPECT_EQ(num_output_lines, parameter.call_count("vector_double")); + + std::vector> draws = parameter.vector_double_values(); + auto draw = draws[draws.size() - 1]; + + std::vector messages = parameter.string_values(); + for (auto msg : messages) { + if (msg.find("Step size") != std::string::npos) { + std::vector toks; + boost::split(toks, msg, boost::is_any_of(" ")); + auto adapted = std::stod(toks[toks.size() - 1]); + EXPECT_NEAR(draw[2], adapted, 1e-5); + } + } +} + +TEST_F(ServicesSampleHmcNutsDiagEAdapt, schedule_c) { + unsigned int random_seed = 0; + unsigned int chain = 1; + double init_radius = 0; + int num_warmup = 35; + int num_samples = 2; + int num_thin = 1; + bool save_warmup = true; + int refresh = 0; + double stepsize = 1.0; + double stepsize_jitter = 0.0; + int max_depth = 10; + double delta = .8; + double gamma = .05; + double kappa = .75; + double t0 = 10; + unsigned int init_buffer = 0; + unsigned int term_buffer = 0; + unsigned int window = 25; + stan::test::unit::instrumented_interrupt interrupt; + EXPECT_EQ(interrupt.call_count(), 0); + + stan::services::sample::hmc_nuts_diag_e_adapt( + model, context, random_seed, chain, init_radius, num_warmup, num_samples, + num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth, + delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt, + logger, init, parameter, diagnostic); + + EXPECT_EQ(0, logger.call_count_error()); + int num_output_lines = (num_warmup + num_samples) / num_thin; + EXPECT_EQ(num_output_lines, parameter.call_count("vector_double")); + + std::vector> draws = parameter.vector_double_values(); + auto draw = draws[draws.size() - 1]; + + std::vector messages = parameter.string_values(); + for (auto msg : messages) { + if (msg.find("Step size") != std::string::npos) { + std::vector toks; + boost::split(toks, msg, boost::is_any_of(" ")); + auto adapted = std::stod(toks[toks.size() - 1]); + EXPECT_NEAR(draw[2], adapted, 1e-5); + } + } +}