/**
 * @file triangular.h
 * @brief Header including MCMC algorithm
 */

#ifndef BVHAR_BAYES_TRIANGULAR_DEPRECATED_TRIANGULAR_H
#define BVHAR_BAYES_TRIANGULAR_DEPRECATED_TRIANGULAR_H

#include "./config.h"
#include <type_traits>

namespace bvhar {

// MCMC algorithms
class McmcTriangular;
class McmcReg;
class McmcSv;
// Shrinkage priors
template <typename BaseMcmc> class McmcMinn;
template <typename BaseMcmc> class McmcHierminn;
template <typename BaseMcmc> class McmcSsvs;
template <typename BaseMcmc, bool isGroup> class McmcHorseshoe;
template <typename BaseMcmc, bool isGroup> class McmcNg;
template <typename BaseMcmc, bool isGroup> class McmcDl;
template <typename BaseMcmc> class McmcGdp;
// Running MCMC
template <typename BaseMcmc, bool isGroup> class CtaRun;

/**
 * @brief Corrected Triangular Algorithm (CTA)
 * 
 * This class is a base class to conduct corrected triangular algorithm.
 * 
 */
class McmcTriangular : public McmcAlgo {
public:
	McmcTriangular(const RegParams& params, const RegInits& inits, unsigned int seed)
	: McmcAlgo(params, seed),
		include_mean(params._mean), x(params._x), y(params._y),
		dim(params._dim), dim_design(params._dim_design), num_design(params._num_design),
		num_lowerchol(params._num_lowerchol), num_coef(params._num_coef), num_alpha(params._num_alpha), nrow_coef(params._nrow),
		own_id(params._own_id), grp_id(params._grp_id), grp_vec(params._grp_mat.reshaped()), num_grp(grp_id.size()),
		// reg_record(std::make_unique<RegRecords>(num_iter, dim, num_design, num_coef, num_lowerchol)),
		sparse_record(num_iter, dim, num_design, num_coef, num_lowerchol),
		coef_vec(Eigen::VectorXd::Zero(num_coef)), contem_coef(inits._contem),
		prior_alpha_mean(Eigen::VectorXd::Zero(num_coef)),
		prior_alpha_prec(Eigen::VectorXd::Zero(num_coef)),
		alpha_penalty(Eigen::VectorXd::Zero(num_alpha)),
		prior_chol_mean(Eigen::VectorXd::Zero(num_lowerchol)),
		prior_chol_prec(Eigen::VectorXd::Ones(num_lowerchol)),
		coef_mat(inits._coef), contem_id(0),
		sparse_coef(Eigen::MatrixXd::Zero(dim_design, dim)), sparse_contem(Eigen::VectorXd::Zero(num_lowerchol)),
		chol_lower(build_inv_lower(dim, contem_coef)),
		latent_innov(y - x * coef_mat),
		response_contem(Eigen::VectorXd::Zero(num_design)),
		sqrt_sv(Eigen::MatrixXd::Zero(num_design, dim)),
		prior_sig_shp(params._sig_shp), prior_sig_scl(params._sig_scl) {
		if (include_mean) {
			prior_alpha_mean.tail(dim) = params._mean_non;
			prior_alpha_prec.tail(dim) = 1 / (params._sd_non * Eigen::VectorXd::Ones(dim)).array().square();
		}
		coef_vec.head(num_alpha) = coef_mat.topRows(nrow_coef).reshaped();
		if (include_mean) {
			coef_vec.tail(dim) = coef_mat.bottomRows(1).transpose();
		}
		// reg_record->assignRecords(0, coef_vec, contem_coef, diag_vec);
		sparse_record.assignRecords(0, sparse_coef, sparse_contem);
	}
	virtual ~McmcTriangular() = default;

	/**
	 * @brief Append each class's additional record to the result `LIST`
	 * 
	 * @param list `LIST` containing MCMC record result
	 */
	virtual void appendRecords(LIST& list) = 0;

	void doWarmUp() override {
		std::lock_guard<std::mutex> lock(mtx);
		updateCoefPrec();
		updatePenalty();
		updateSv();
		updateCoef();
		updateImpactPrec();
		updateLatent();
		updateImpact();
		updateChol();
		updateState();
	}

	void doPosteriorDraws() override {
		std::lock_guard<std::mutex> lock(mtx);
		addStep();
		updateCoefPrec();
		updatePenalty();
		updateSv(); // D before coef
		updateCoef();
		updateImpactPrec();
		updateLatent(); // E_t before a
		updateImpact();
		updateChol(); // L before d_i
		updateState();
		updateRecords();
	}

	LIST returnRecords(int num_burn, int thin) override {
		LIST res = gatherRecords();
		appendRecords(res);
		for (auto& record : res) {
			if (IS_MATRIX(ACCESS_LIST(record, res))) {
				ACCESS_LIST(record, res) = thin_record(CAST<Eigen::MatrixXd>(ACCESS_LIST(record, res)), num_iter, num_burn, thin);
			} else {
				ACCESS_LIST(record, res) = thin_record(CAST<Eigen::VectorXd>(ACCESS_LIST(record, res)), num_iter, num_burn, thin);
			}
		}
		return res;
	}

	/**
	 * @brief Return `LdltRecords`
	 * 
	 * @param num_burn Number of burn-in
	 * @param thin Thinning
	 * @param sparse If `true`, return sparsified draws.
	 * @return LdltRecords `LdltRecords` object
	 */
	LdltRecords returnLdltRecords(int num_burn, int thin, bool sparse = false) const {
		return reg_record->returnLdltRecords(sparse_record, num_iter, num_burn, thin, sparse);
	}

	/**
	 * @brief Return `SvRecords`
	 * 
	 * @param num_burn Number of burn-in
	 * @param thin Thinning
	 * @param sparse If `true`, return sparsified draws.
	 * @return SvRecords `SvRecords` object
	 */
	SvRecords returnSvRecords(int num_burn, int thin, bool sparse = false) const {
		return reg_record->returnSvRecords(sparse_record, num_iter, num_burn, thin, sparse);
	}

	/**
	 * @brief Return `LdltRecords` or `SvRecords`
	 * 
	 * @tparam RecordType `LdltRecords` or `SvRecords` 
	 * @param num_burn Number of burn-in
	 * @param thin Thinning
	 * @param sparse If `true`, return sparsified draws.
	 * @return RecordType `LdltRecords` or `SvRecords` 
	 */
	template <typename RecordType>
	RecordType returnStructRecords(int num_burn, int thin, bool sparse = false) const {
		return reg_record->returnRecords<RecordType>(sparse_record, num_iter, num_burn, thin, sparse);
	}

protected:
	bool include_mean;
	Eigen::MatrixXd x;
	Eigen::MatrixXd y;
	int dim; // k
  int dim_design; // kp(+1)
  int num_design; // n = T - p
  int num_lowerchol;
  int num_coef;
	int num_alpha;
	int nrow_coef;
	std::set<int> own_id;
	Eigen::VectorXi grp_id;
	Eigen::VectorXi grp_vec;
	int num_grp;
	std::unique_ptr<RegRecords> reg_record;
	SparseRecords sparse_record;
	Eigen::VectorXd coef_vec;
	Eigen::VectorXd contem_coef;
	Eigen::VectorXd prior_alpha_mean; // prior mean vector of alpha
	Eigen::VectorXd prior_alpha_prec; // Diagonal of alpha prior precision
	Eigen::VectorXd alpha_penalty; // SAVS penalty vector
	Eigen::VectorXd prior_chol_mean; // prior mean vector of a = 0
	Eigen::VectorXd prior_chol_prec; // Diagonal of prior precision of a = I
	Eigen::MatrixXd coef_mat;
	int contem_id;
	Eigen::MatrixXd sparse_coef;
	Eigen::VectorXd sparse_contem;
	Eigen::MatrixXd chol_lower; // L in Sig_t^(-1) = L D_t^(-1) LT
	Eigen::MatrixXd latent_innov; // Z0 = Y0 - X0 A = (eps_p+1, eps_p+2, ..., eps_n+p)^T
	Eigen::VectorXd response_contem; // j-th column of Z0 = Y0 - X0 * A: n-dim
	Eigen::MatrixXd sqrt_sv; // stack sqrt of exp(h_t) = (exp(-h_1t / 2), ..., exp(-h_kt / 2)), t = 1, ..., n => n x k
	Eigen::VectorXd prior_sig_shp;
	Eigen::VectorXd prior_sig_scl;

	/**
	 * @brief Draw state vector
	 * 
	 */
	virtual void updateState() = 0;

	/**
	 * @brief Compute D
	 * 
	 */
	virtual void updateSv() = 0;

	/**
	 * @brief Save coefficient records
	 * 
	 */
	virtual void updateCoefRecords() = 0;

	/**
	 * @brief Draw precision of coefficient based on each shrinkage priors
	 * 
	 */
	virtual void updateCoefPrec() = 0;

	/**
	 * @brief Update SAVS penalty
	 * 
	 */
	void updatePenalty() {
		for (int i = 0; i < num_alpha; ++i) {
			if (own_id.find(grp_vec[i]) != own_id.end()) {
				alpha_penalty[i] = 0;
			} else {
				alpha_penalty[i] = 1;
			}
		}
	}
	// virtual void updatePenalty() = 0;

	/**
	 * @brief Draw precision of contemporaneous coefficient based on each shrinkage priors
	 * 
	 */
	virtual void updateImpactPrec() = 0;

	/**
	 * @brief Save MCMC records
	 * 
	 */
	virtual void updateRecords() = 0;

	/**
	 * @brief Draw coefficients
	 * 
	 */
	void updateCoef() {
		for (int j = 0; j < dim; ++j) {
			coef_mat.col(j).setZero(); // j-th column of A = 0
			Eigen::MatrixXd chol_lower_j = chol_lower.bottomRows(dim - j); // L_(j:k) = a_jt to a_kt for t = 1, ..., j - 1
			Eigen::MatrixXd sqrt_sv_j = sqrt_sv.rightCols(dim - j); // use h_jt to h_kt for t = 1, .. n => (k - j + 1) x k
			Eigen::MatrixXd design_coef = kronecker_eigen(chol_lower_j.col(j), x).array().colwise() / sqrt_sv_j.reshaped().array(); // L_(j:k, j) otimes X0 scaled by D_(1:n, j:k): n(k - j + 1) x kp
			Eigen::VectorXd prior_mean_j(dim_design);
			Eigen::VectorXd prior_prec_j(dim_design);
			Eigen::VectorXd penalty_j = Eigen::VectorXd::Zero(dim_design);
			if (include_mean) {
				prior_mean_j << prior_alpha_mean.segment(j * nrow_coef, nrow_coef), prior_alpha_mean.tail(dim)[j];
				prior_prec_j << prior_alpha_prec.segment(j * nrow_coef, nrow_coef), prior_alpha_prec.tail(dim)[j];
				// penalty_j << alpha_penalty.segment(j * nrow_coef, nrow_coef), alpha_penalty.tail(dim)[j];
				penalty_j.head(nrow_coef) = alpha_penalty.segment(j * nrow_coef, nrow_coef);
				draw_coef(
					coef_mat.col(j), design_coef,
					(((y - x * coef_mat) * chol_lower_j.transpose()).array() / sqrt_sv_j.array()).reshaped(), // Hadamard product between: (Y - X0 A(-j))L_(j:k)^T and D_(1:n, j:k)
					prior_mean_j, prior_prec_j, rng
				);
				coef_vec.head(num_alpha) = coef_mat.topRows(nrow_coef).reshaped();
				coef_vec.tail(dim) = coef_mat.bottomRows(1).transpose();
			} else {
				prior_mean_j = prior_alpha_mean.segment(dim_design * j, dim_design);
				prior_prec_j = prior_alpha_prec.segment(dim_design * j, dim_design);
				penalty_j = alpha_penalty.segment(dim_design * j, dim_design);
				draw_coef(
					coef_mat.col(j),
					design_coef,
					(((y - x * coef_mat) * chol_lower_j.transpose()).array() / sqrt_sv_j.array()).reshaped(),
					prior_mean_j, prior_prec_j, rng
				);
				coef_vec = coef_mat.reshaped();
			}
			draw_mn_savs(sparse_coef.col(j), coef_mat.col(j), x, penalty_j);
		}
	}

	/**
	 * @brief Draw contemporaneous coefficients
	 * 
	 */
	void updateImpact() {
		for (int j = 1; j < dim; ++j) {
			response_contem = latent_innov.col(j).array() / sqrt_sv.col(j).array(); // n-dim
			Eigen::MatrixXd design_contem = latent_innov.leftCols(j).array().colwise() / sqrt_sv.col(j).reshaped().array(); // n x (j - 1)
			contem_id = j * (j - 1) / 2;
			draw_coef(
				contem_coef.segment(contem_id, j),
				design_contem, response_contem,
				prior_chol_mean.segment(contem_id, j),
				prior_chol_prec.segment(contem_id, j),
				rng
			);
			draw_savs(sparse_contem.segment(contem_id, j), contem_coef.segment(contem_id, j), latent_innov.leftCols(j));
		}
	}

	/**
	 * @brief Compute residual matrix for orthogonalization
	 * 
	 */
	void updateLatent() { latent_innov = y - x * coef_mat; }

	/**
	 * @brief Compute L
	 * 
	 */
	void updateChol() { chol_lower = build_inv_lower(dim, contem_coef); }

	/**
	 * @brief Gather MCMC records
	 * 
	 * @return LIST 
	 */
	LIST gatherRecords() {
		LIST res = reg_record->returnListRecords(dim, num_alpha, include_mean);
		reg_record->appendRecords(res);
		sparse_record.appendRecords(res, dim, num_alpha, include_mean);
		return res;
	}
};

/**
 * @brief MCMC for homoskedastic LDLT parameterization
 * 
 */
class McmcReg : public McmcTriangular {
public:
	McmcReg(const RegParams& params, const LdltInits& inits, unsigned int seed)
	: McmcTriangular(params, inits, seed), diag_vec(inits._diag) {
		reg_record = std::make_unique<LdltRecords>(num_iter, dim, num_design, num_coef, num_lowerchol);
		reg_record->assignRecords(0, coef_vec, contem_coef, diag_vec);
	}
	virtual ~McmcReg() = default;

protected:
	void updateState() override { reg_ldlt_diag(diag_vec, prior_sig_shp, prior_sig_scl, latent_innov * chol_lower.transpose(), rng); }
	void updateSv() override { sqrt_sv = diag_vec.cwiseSqrt().transpose().replicate(num_design, 1); }
	void updateCoefRecords() override {
		reg_record->assignRecords(mcmc_step, coef_vec, contem_coef, diag_vec);
		sparse_record.assignRecords(mcmc_step, num_alpha, dim, nrow_coef, sparse_coef, sparse_contem);
	}

private:
	Eigen::VectorXd diag_vec; // inverse of d_i
};

/**
 * @brief MCMC for stochastic volatility
 * 
 */
class McmcSv : public McmcTriangular {
public:
	McmcSv(const SvParams& params, const SvInits& inits, unsigned int seed)
	: McmcTriangular(params, inits, seed),
		ortho_latent(Eigen::MatrixXd::Zero(num_design, dim)),
		lvol_draw(inits._lvol), lvol_init(inits._lvol_init), lvol_sig(inits._lvol_sig),
		prior_init_mean(params._init_mean), prior_init_prec(params._init_prec) {
		reg_record = std::make_unique<SvRecords>(num_iter, dim, num_design, num_coef, num_lowerchol);
		reg_record->assignRecords(0, coef_vec, contem_coef, lvol_draw, lvol_sig, lvol_init);
		sparse_record.assignRecords(0, sparse_coef, sparse_contem);
	}
	virtual ~McmcSv() = default;

protected:
	void updateState() override {
		ortho_latent = latent_innov * chol_lower.transpose(); // L eps_t <=> Z0 U
		ortho_latent = (ortho_latent.array().square() + .0001).array().log(); // adjustment log(e^2 + c) for some c = 10^(-4) against numerical problems
		for (int t = 0; t < dim; t++) {
			varsv_ht(lvol_draw.col(t), lvol_init[t], lvol_sig[t], ortho_latent.col(t), rng);
		}
		varsv_sigh(lvol_sig, prior_sig_shp, prior_sig_scl, lvol_init, lvol_draw, rng);
		varsv_h0(lvol_init, prior_init_mean, prior_init_prec, lvol_draw.row(0), lvol_sig, rng);
	}
	void updateSv() override { sqrt_sv = (lvol_draw / 2).array().exp(); }
	void updateCoefRecords() override {
		reg_record->assignRecords(mcmc_step, coef_vec, contem_coef, lvol_draw, lvol_sig, lvol_init);
		sparse_record.assignRecords(mcmc_step, num_alpha, dim, nrow_coef, sparse_coef, sparse_contem);
	}

private:
	Eigen::MatrixXd ortho_latent; // orthogonalized Z0
	Eigen::MatrixXd lvol_draw; // h_j = (h_j1, ..., h_jn)
	Eigen::VectorXd lvol_init;
	Eigen::VectorXd lvol_sig;
	Eigen::VectorXd prior_init_mean;
	Eigen::MatrixXd prior_init_prec;
};

/**
 * @brief Minnesota prior
 * 
 * @tparam BaseMcmc `McmcReg` or `McmcSv`
 */
template <typename BaseMcmc = McmcReg>
class McmcMinn : public BaseMcmc {
public:
	McmcMinn(
		const MinnParams<typename std::conditional<std::is_same<BaseMcmc, McmcReg>::value, RegParams, SvParams>::type>& params,
		const typename std::conditional<std::is_same<BaseMcmc, McmcReg>::value, LdltInits, SvInits>::type& inits,
		unsigned int seed
	)
	: BaseMcmc(params, inits, seed) {
		prior_alpha_mean.head(num_alpha) = params._prior_mean.reshaped();
		prior_alpha_prec.head(num_alpha) = kronecker_eigen(params._prec_diag, params._prior_prec).diagonal();
		if (include_mean) {
			prior_alpha_mean.tail(dim) = params._mean_non;
		}
	}
	virtual ~McmcMinn() = default;
	void appendRecords(LIST& list) override {}

protected:
	using BaseMcmc::dim;
	using BaseMcmc::num_alpha;
	using BaseMcmc::include_mean;
	using BaseMcmc::prior_alpha_mean;
	using BaseMcmc::prior_alpha_prec;
	// using BaseMcmc::alpha_penalty;
	using BaseMcmc::updateCoefRecords;
	void updateCoefPrec() override {};
	// void updatePenalty() override {};
	void updateImpactPrec() override {};
	void updateRecords() override { updateCoefRecords(); }
};

/**
 * @brief Hierarchical Minnesota prior
 * 
 * @tparam BaseMcmc `McmcReg` or `McmcSv`
 */
template <typename BaseMcmc = McmcReg>
class McmcHierminn : public BaseMcmc {
public:
	McmcHierminn(
		const HierminnParams<typename std::conditional<std::is_same<BaseMcmc, McmcReg>::value, RegParams, SvParams>::type>& params,
		const HierminnInits<typename std::conditional<std::is_same<BaseMcmc, McmcReg>::value, LdltInits, SvInits>::type>& inits,
		unsigned int seed
	)
	: BaseMcmc(params, inits, seed),
		cross_id(params._cross_id),
		coef_minnesota(params._minnesota), grp_mat(params._grp_mat),
		grid_size(params._grid_size),
		own_lambda(inits._own_lambda), cross_lambda(inits._cross_lambda), contem_lambda(inits._contem_lambda),
		own_shape(params.shape), own_rate(params.rate),
		cross_shape(params.shape), cross_rate(params.rate),
		contem_shape(params.shape), contem_rate(params.rate) {
		prior_alpha_mean.head(num_alpha) = params._prior_mean.reshaped();
		prior_alpha_prec.head(num_alpha) = kronecker_eigen(params._prec_diag, params._prior_prec).diagonal();
		prior_alpha_prec.head(num_alpha).array() /= own_lambda;
		for (int i = 0; i < num_alpha; ++i) {
			if (cross_id.find(grp_vec[i]) != cross_id.end()) {
				prior_alpha_prec[i] /= cross_lambda; // nu
			}
		}
		if (include_mean) {
			prior_alpha_mean.tail(dim) = params._mean_non;
		}
		prior_chol_prec.array() /= contem_lambda; // divide because it is precision
	}
	virtual ~McmcHierminn() = default;
	void appendRecords(LIST& list) override {}

protected:
	using BaseMcmc::own_id;
	using BaseMcmc::grp_vec;
	using BaseMcmc::dim;
	using BaseMcmc::num_alpha;
	using BaseMcmc::include_mean;
	using BaseMcmc::rng;
	using BaseMcmc::prior_alpha_mean;
	using BaseMcmc::prior_alpha_prec;
	// using BaseMcmc::alpha_penalty;
	using BaseMcmc::prior_chol_mean;
	using BaseMcmc::prior_chol_prec;
	using BaseMcmc::coef_vec;
	using BaseMcmc::contem_coef;
	using BaseMcmc::updateCoefRecords;
	void updateCoefPrec() override {
		minnesota_lambda(
			own_lambda, own_shape, own_rate,
			coef_vec.head(num_alpha), prior_alpha_mean.head(num_alpha), prior_alpha_prec.head(num_alpha),
			rng
		);
		minnesota_nu_griddy(
			cross_lambda, grid_size,
			coef_vec.head(num_alpha), prior_alpha_mean.head(num_alpha), prior_alpha_prec.head(num_alpha),
			grp_vec, cross_id, rng
		);
	}
	void updateImpactPrec() override {
		minnesota_lambda(
			contem_lambda, contem_shape, contem_rate,
			contem_coef, prior_chol_mean, prior_chol_prec,
			rng
		);
	};
	void updateRecords() override { updateCoefRecords(); }

private:
	std::set<int> cross_id;
	bool coef_minnesota;
	Eigen::MatrixXi grp_mat;
	int grid_size;
	double own_lambda;
	double cross_lambda;
	double contem_lambda;
	double own_shape;
	double own_rate;
	double cross_shape;
	double cross_rate;
	double contem_shape;
	double contem_rate;
};

/**
 * @brief Stochastic Search Variable Selection (SSVS) prior
 * 
 * @tparam BaseMcmc `McmcReg` or `McmcSv`
 */
template <typename BaseMcmc = McmcReg>
class McmcSsvs : public BaseMcmc {
public:
	/**
	 * @brief Construct a new McmcSsvs object
	 * 
	 * Copies grids and hyperparameters from `params`, initial dummies, weights, slabs and
	 * spike scale from `inits`, and stores the initial SSVS state at record index 0.
	 * 
	 * @param params SSVS prior settings
	 * @param inits Initial values of the SSVS parameters plus those of `BaseMcmc`
	 * @param seed Seed passed to `BaseMcmc`
	 */
	McmcSsvs(
		const SsvsParams<typename std::conditional<std::is_same<BaseMcmc, McmcReg>::value, RegParams, SvParams>::type>& params,
		const SsvsInits<typename std::conditional<std::is_same<BaseMcmc, McmcReg>::value, LdltInits, SvInits>::type>& inits,
		unsigned int seed
	)
	: BaseMcmc(params, inits, seed),
		ssvs_record(num_iter, num_alpha, num_grp, num_lowerchol),
		coef_grid(params._coef_grid), contem_grid(params._contem_grid),
		coef_dummy(inits._coef_dummy), coef_weight(inits._coef_weight),
		contem_dummy(Eigen::VectorXd::Ones(num_lowerchol)), contem_weight(inits._contem_weight),
		coef_slab(inits._coef_slab),
		// NOTE(review): contem_spike_scl is initialized from the coefficient field `_coef_spike_scl`,
		// the same source as spike_scl. Confirm whether `SsvsInits` has a separate contemporaneous value that was intended.
		spike_scl(inits._coef_spike_scl), contem_spike_scl(inits._coef_spike_scl),
		ig_shape(params._coef_slab_shape), ig_scl(params._coef_slab_scl),
		contem_ig_shape(params._contem_slab_shape), contem_ig_scl(params._contem_slab_scl),
		contem_slab(inits._contem_slab),
		coef_s1(params._coef_s1), coef_s2(params._coef_s2),
		contem_s1(params._contem_s1), contem_s2(params._contem_s2),
		slab_weight(Eigen::VectorXd::Ones(num_alpha)) {
		ssvs_record.assignRecords(0, coef_dummy, coef_weight, contem_dummy, contem_weight);
	}
	virtual ~McmcSsvs() = default;
	/**
	 * @brief Add the coefficient dummy record to the result under the key `gamma_record`
	 * 
	 * @param list `LIST` containing MCMC record result
	 */
	void appendRecords(LIST& list) override {
		list["gamma_record"] = ssvs_record.coef_dummy_record;
	}

protected:
	using BaseMcmc::own_id;
	using BaseMcmc::grp_id;
	using BaseMcmc::grp_vec;
	using BaseMcmc::num_grp;
	using BaseMcmc::num_iter;
	using BaseMcmc::num_alpha;
	using BaseMcmc::num_lowerchol;
	using BaseMcmc::mcmc_step;
	using BaseMcmc::rng;
	using BaseMcmc::prior_alpha_prec;
	// using BaseMcmc::alpha_penalty;
	using BaseMcmc::prior_chol_prec;
	using BaseMcmc::coef_vec;
	using BaseMcmc::contem_coef;
	using BaseMcmc::updateCoefRecords;
	/**
	 * @brief Update the SSVS parameters of the coefficients and rebuild their prior precision
	 * 
	 * Order of the steps: slab, group weights expanded to coefficients, spike scale (grid),
	 * dummies, group weights, then `prior_alpha_prec`.
	 */
	void updateCoefPrec() override {
		ssvs_local_slab(coef_slab, coef_dummy, coef_vec.head(num_alpha), ig_shape, ig_scl, spike_scl, rng);
		// Expand the per-group weight to one entry per coefficient
		for (int j = 0; j < num_grp; j++) {
			slab_weight = (grp_vec.array() == grp_id[j]).select(
				coef_weight[j],
				slab_weight
			);
		}
		ssvs_scl_griddy(spike_scl, coef_grid, coef_vec.head(num_alpha), coef_slab, rng);
		ssvs_dummy(
			coef_dummy,
			coef_vec.head(num_alpha),
			coef_slab, spike_scl * coef_slab, slab_weight,
			rng
		);
		ssvs_mn_weight(coef_weight, grp_vec, grp_id, coef_dummy, coef_s1, coef_s2, rng);
		// Reciprocal of: coef_slab where the dummy is 1, spike_scl * coef_slab where it is 0
		prior_alpha_prec.head(num_alpha).array() = 1 / (spike_scl * (1 - coef_dummy.array()) * coef_slab.array() + coef_dummy.array() * coef_slab.array());
	}
	/**
	 * @brief Update the SSVS parameters of the contemporaneous coefficients and rebuild their prior precision
	 * 
	 */
	void updateImpactPrec() override {
		ssvs_local_slab(contem_slab, contem_dummy, contem_coef, contem_ig_shape, contem_ig_scl, contem_spike_scl, rng);
		ssvs_scl_griddy(contem_spike_scl, contem_grid, contem_coef, contem_slab, rng);
		ssvs_dummy(contem_dummy, contem_coef, contem_slab, contem_spike_scl * contem_slab, contem_weight, rng);
		ssvs_weight(contem_weight, contem_dummy, contem_s1, contem_s2, rng);
		// NOTE(review): the result of build_ssvs_sd() is squared here while updateCoefPrec() uses
		// the slab terms unsquared; confirm both follow the intended sd/variance convention.
		prior_chol_prec = 1 / build_ssvs_sd(contem_spike_scl * contem_slab, contem_slab, contem_dummy).array().square();
	}
	/**
	 * @brief Save the coefficient records and the SSVS state at `mcmc_step`
	 * 
	 */
	void updateRecords() override {
		updateCoefRecords();
		ssvs_record.assignRecords(mcmc_step, coef_dummy, coef_weight, contem_dummy, contem_weight);
	}

private:
	SsvsRecords ssvs_record;
	int coef_grid, contem_grid; // grid arguments passed to ssvs_scl_griddy()
	Eigen::VectorXd coef_dummy;
	Eigen::VectorXd coef_weight; // indexed by group in updateCoefPrec()
	Eigen::VectorXd contem_dummy;
	Eigen::VectorXd contem_weight;
	Eigen::VectorXd coef_slab;
	double spike_scl, contem_spike_scl; // scaling factor between 0 and 1: spike_sd = c * slab_sd
	double ig_shape, ig_scl, contem_ig_shape, contem_ig_scl; // IG hyperparameter for spike sd
	Eigen::VectorXd contem_slab;
	Eigen::VectorXd coef_s1, coef_s2;
	double contem_s1, contem_s2;
	Eigen::VectorXd slab_weight; // pij vector
};

/**
 * @brief Horseshoe prior
 * 
 * @tparam BaseMcmc McmcReg or McmcSv
 * @tparam isGroup If `true`, use group shrinkage parameter
 */
template <typename BaseMcmc = McmcReg, bool isGroup = true>
class McmcHorseshoe : public BaseMcmc {
public:
	McmcHorseshoe(
		const HorseshoeParams<typename std::conditional<std::is_same<BaseMcmc, McmcReg>::value, RegParams, SvParams>::type>& params,
		const HsInits<typename std::conditional<std::is_same<BaseMcmc, McmcReg>::value, LdltInits, SvInits>::type>& inits,
		unsigned int seed
	)
	: BaseMcmc(params, inits, seed),
		hs_record(num_iter, num_alpha, num_grp),
		local_lev(inits._init_local), group_lev(inits._init_group), global_lev(isGroup ? inits._init_global : 1.0),
		shrink_fac(Eigen::VectorXd::Zero(num_alpha)),
		latent_local(Eigen::VectorXd::Zero(num_alpha)), latent_group(Eigen::VectorXd::Zero(num_grp)), latent_global(0.0),
		coef_var(Eigen::VectorXd::Zero(num_alpha)),
		contem_local_lev(inits._init_contem_local), contem_global_lev(inits._init_conetm_global),
		contem_var(Eigen::VectorXd::Zero(num_lowerchol)),
		latent_contem_local(Eigen::VectorXd::Zero(num_lowerchol)), latent_contem_global(Eigen::VectorXd::Zero(1)) {
		hs_record.assignRecords(0, shrink_fac, local_lev, group_lev, global_lev);
	}
	virtual ~McmcHorseshoe() = default;
	void appendRecords(LIST& list) override {
		list["lambda_record"] = hs_record.local_record;
		list["eta_record"] = hs_record.group_record;
		list["tau_record"] = hs_record.global_record;
		list["kappa_record"] = hs_record.shrink_record;
	}

protected:
	using BaseMcmc::own_id;
	using BaseMcmc::grp_id;
	using BaseMcmc::grp_vec;
	using BaseMcmc::num_grp;
	using BaseMcmc::num_iter;
	using BaseMcmc::num_alpha;
	using BaseMcmc::num_lowerchol;
	using BaseMcmc::mcmc_step;
	using BaseMcmc::rng;
	using BaseMcmc::prior_alpha_prec;
	using BaseMcmc::prior_chol_prec;
	using BaseMcmc::coef_vec;
	using BaseMcmc::contem_coef;
	using BaseMcmc::updateCoefRecords;
	void updateCoefPrec() override {
		horseshoe_latent(latent_group, group_lev, rng);
		horseshoe_mn_sparsity(group_lev, grp_vec, grp_id, latent_group, global_lev, local_lev, coef_vec.head(num_alpha), 1, rng);
		for (int j = 0; j < num_grp; j++) {
			coef_var = (grp_vec.array() == grp_id[j]).select(
				group_lev[j],
				coef_var
			);
		}
		horseshoe_latent(latent_local, local_lev, rng);
		using is_group = std::integral_constant<bool, isGroup>;
		if (is_group::value) {
			horseshoe_latent(latent_global, global_lev, rng);
			global_lev = horseshoe_global_sparsity(latent_global, coef_var.array() * local_lev.array(), coef_vec.head(num_alpha), 1, rng);
		}
		horseshoe_local_sparsity(local_lev, latent_local, coef_var, coef_vec.head(num_alpha), global_lev * global_lev, rng);
		prior_alpha_prec.head(num_alpha) = 1 / (global_lev * coef_var.array() * local_lev.array()).square();
		shrink_fac = 1 / (1 + prior_alpha_prec.head(num_alpha).array());
	}
	void updateImpactPrec() override {
		horseshoe_latent(latent_contem_local, contem_local_lev, rng);
		horseshoe_latent(latent_contem_global, contem_global_lev, rng);
		contem_var = contem_global_lev.replicate(1, num_lowerchol).reshaped();
		horseshoe_local_sparsity(contem_local_lev, latent_contem_local, contem_var, contem_coef, 1, rng);
		contem_global_lev[0] = horseshoe_global_sparsity(latent_contem_global[0], latent_contem_local, contem_coef, 1, rng);
		prior_chol_prec.setZero();
		prior_chol_prec = 1 / (contem_var.array() * contem_local_lev.array()).square();
	}
	void updateRecords() override {
		updateCoefRecords();
		hs_record.assignRecords(mcmc_step, shrink_fac, local_lev, group_lev, global_lev);
	}

private:
	HorseshoeRecords hs_record;
	Eigen::VectorXd local_lev;
	Eigen::VectorXd group_lev;
	double global_lev;
	Eigen::VectorXd shrink_fac;
	Eigen::VectorXd latent_local;
	Eigen::VectorXd latent_group;
	double latent_global;
	Eigen::VectorXd coef_var;
	Eigen::VectorXd contem_local_lev;
	Eigen::VectorXd contem_global_lev; // -> double
	Eigen::VectorXd contem_var;
	Eigen::VectorXd latent_contem_local;
	Eigen::VectorXd latent_contem_global; // -> double
};

/**
 * @brief Normal-Gamma Prior
 * 
 * @tparam BaseMcmc `McmcReg` or `McmcSv`
 * @tparam isGroup If `true`, use group shrinkage parameter
 */
template <typename BaseMcmc = McmcReg, bool isGroup = true>
class McmcNg : public BaseMcmc {
public:
	/**
	 * @brief Construct a new McmcNg object
	 * 
	 * Copies hyperparameters from `params` and initial shapes and levels from `inits`,
	 * then stores the initial levels at record index 0.
	 * When `isGroup` is `false` the global level is fixed at 1.
	 * 
	 * @param params Normal-Gamma prior settings
	 * @param inits Initial values of the Normal-Gamma parameters plus those of `BaseMcmc`
	 * @param seed Seed passed to `BaseMcmc`
	 */
	McmcNg(
		const NgParams<typename std::conditional<std::is_same<BaseMcmc, McmcReg>::value, RegParams, SvParams>::type>& params,
		const NgInits<typename std::conditional<std::is_same<BaseMcmc, McmcReg>::value, LdltInits, SvInits>::type>& inits,
		unsigned int seed
	)
	: BaseMcmc(params, inits, seed),
		ng_record(num_iter, num_alpha, num_grp),
		mh_sd(params._mh_sd),
		local_shape(inits._init_local_shape), local_shape_fac(Eigen::VectorXd::Ones(num_alpha)),
		contem_shape(inits._init_contem_shape),
		// NOTE(review): group_scl is initialized from `_global_scl`, the same field as global_scl below.
		// Confirm whether `NgParams` has a separate group scale that was intended.
		group_shape(params._group_shape), group_scl(params._global_scl),
		global_shape(params._global_shape), global_scl(params._global_scl),
		contem_global_shape(params._contem_global_shape), contem_global_scl(params._contem_global_scl),
		local_lev(inits._init_local), group_lev(inits._init_group), global_lev(isGroup ? inits._init_global : 1.0),
		coef_var(Eigen::VectorXd::Zero(num_alpha)),
		contem_global_lev(inits._init_conetm_global),
		contem_fac(contem_global_lev[0] * inits._init_contem_local) {
		ng_record.assignRecords(0, local_lev, group_lev, global_lev);
	}
	virtual ~McmcNg() = default;
	/**
	 * @brief Add the local, group and global records to the result
	 * 
	 * @param list `LIST` containing MCMC record result
	 */
	void appendRecords(LIST& list) override {
		list["lambda_record"] = ng_record.local_record;
		list["eta_record"] = ng_record.group_record;
		list["tau_record"] = ng_record.global_record;
	}

protected:
	using BaseMcmc::own_id;
	using BaseMcmc::grp_id;
	using BaseMcmc::grp_vec;
	using BaseMcmc::num_grp;
	using BaseMcmc::num_iter;
	using BaseMcmc::num_alpha;
	using BaseMcmc::num_lowerchol;
	using BaseMcmc::mcmc_step;
	using BaseMcmc::rng;
	using BaseMcmc::prior_alpha_prec;
	using BaseMcmc::prior_chol_prec;
	using BaseMcmc::coef_vec;
	using BaseMcmc::contem_coef;
	using BaseMcmc::updateCoefRecords;
	/**
	 * @brief Update the shapes and the group, global and local levels of the coefficients, then rebuild their prior precision
	 * 
	 * The global level is only updated when `isGroup` is `true`.
	 */
	void updateCoefPrec() override {
		ng_mn_shape_jump(local_shape, local_lev, group_lev, grp_vec, grp_id, global_lev, mh_sd, rng);
		ng_mn_sparsity(group_lev, grp_vec, grp_id, local_shape, global_lev, local_lev, group_shape, group_scl, rng);
		// Expand the per-group level and per-group shape to one entry per coefficient
		for (int j = 0; j < num_grp; j++) {
			coef_var = (grp_vec.array() == grp_id[j]).select(
				group_lev[j],
				coef_var
			);
			local_shape_fac = (grp_vec.array() == grp_id[j]).select(
				local_shape[j],
				local_shape_fac
			);
		}
		using is_group = std::integral_constant<bool, isGroup>;
		if (is_group::value) {
			global_lev = ng_global_sparsity(local_lev.array() / coef_var.array(), local_shape_fac, global_shape, global_scl, rng);
		}
		ng_local_sparsity(local_lev, local_shape_fac, coef_vec.head(num_alpha), global_lev * coef_var, rng);
		// The precision is built from local_lev alone
		prior_alpha_prec.head(num_alpha) = 1 / local_lev.array().square();
	}
	/**
	 * @brief Update the shape, global level and local factors of the contemporaneous coefficients, then rebuild their prior precision
	 * 
	 */
	void updateImpactPrec() override {
		contem_shape = ng_shape_jump(contem_shape, contem_fac, contem_global_lev[0], mh_sd, rng);
		contem_global_lev[0] = ng_global_sparsity(contem_fac, contem_shape, contem_global_shape, contem_global_scl, rng);
		// The single global level is repeated num_lowerchol times for the local update
		ng_local_sparsity(contem_fac, contem_shape, contem_coef, contem_global_lev.replicate(1, num_lowerchol).reshaped(), rng);
		prior_chol_prec = 1 / contem_fac.array().square();
	}
	/**
	 * @brief Save the coefficient records and the Normal-Gamma levels at `mcmc_step`
	 * 
	 */
	void updateRecords() override {
		updateCoefRecords();
		ng_record.assignRecords(mcmc_step, local_lev, group_lev, global_lev);
	}

private:
	NgRecords ng_record;
	double mh_sd; // passed to the shape-jump helpers
	Eigen::VectorXd local_shape, local_shape_fac; // per-group shape and its expansion to one entry per coefficient
	double contem_shape;
	double group_shape, group_scl, global_shape, global_scl, contem_global_shape, contem_global_scl;
	Eigen::VectorXd local_lev;
	Eigen::VectorXd group_lev;
	double global_lev;
	Eigen::VectorXd coef_var; // group_lev expanded to one entry per coefficient
	Eigen::VectorXd contem_global_lev; // only element 0 is used
	Eigen::VectorXd contem_fac; // initialized as contem_global_lev[0] * initial contemporaneous local level
};

/**
 * @brief Dirichlet-Laplace prior
 * 
 * @tparam BaseMcmc `McmcReg` or `McmcSv`
 * @tparam isGroup If `true`, use group shrinkage parameter
 */
template <typename BaseMcmc = McmcReg, bool isGroup = true>
class McmcDl : public BaseMcmc {
public:
	/**
	 * @brief Construct a new McmcDl object
	 * 
	 * Copies hyperparameters from `params` and initial levels from `inits`, starts both
	 * Dirichlet concentrations and the group levels at zero, and stores the initial levels at record index 0.
	 * When `isGroup` is `false` the global level is fixed at 1.
	 * 
	 * @param params Dirichlet-Laplace prior settings
	 * @param inits Initial values of the levels plus those of `BaseMcmc`
	 * @param seed Seed passed to `BaseMcmc`
	 */
	McmcDl(
		const DlParams<typename std::conditional<std::is_same<BaseMcmc, McmcReg>::value, RegParams, SvParams>::type>& params,
		const GlInits<typename std::conditional<std::is_same<BaseMcmc, McmcReg>::value, LdltInits, SvInits>::type>& inits,
		unsigned int seed
	)
	: BaseMcmc(params, inits, seed),
		dl_record(num_iter, num_alpha),
		dir_concen(0.0), contem_dir_concen(0.0),
		shape(params._shape), scl(params._scl),
		grid_size(params._grid_size),
		local_lev(inits._init_local), group_lev(Eigen::VectorXd::Zero(num_grp)), global_lev(isGroup ? inits._init_global : 1.0),
		latent_local(Eigen::VectorXd::Zero(num_alpha)),
		coef_var(Eigen::VectorXd::Zero(num_alpha)),
		contem_local_lev(inits._init_contem_local), contem_global_lev(inits._init_conetm_global),
		latent_contem_local(Eigen::VectorXd::Zero(num_lowerchol)) {
		dl_record.assignRecords(0, local_lev, global_lev);
	}
	virtual ~McmcDl() = default;
	/**
	 * @brief Add the local and global records to the result
	 * 
	 * @param list `LIST` containing MCMC record result
	 */
	void appendRecords(LIST& list) override {
		list["lambda_record"] = dl_record.local_record;
		list["tau_record"] = dl_record.global_record;
	}

protected:
	using BaseMcmc::own_id;
	using BaseMcmc::grp_id;
	using BaseMcmc::grp_vec;
	using BaseMcmc::num_grp;
	using BaseMcmc::num_iter;
	using BaseMcmc::num_alpha;
	using BaseMcmc::num_lowerchol;
	using BaseMcmc::mcmc_step;
	using BaseMcmc::rng;
	using BaseMcmc::prior_alpha_prec;
	using BaseMcmc::prior_chol_prec;
	using BaseMcmc::coef_vec;
	using BaseMcmc::contem_coef;
	using BaseMcmc::updateCoefRecords;
	/**
	 * @brief Update the group levels, Dirichlet concentration, local and global levels and latent variables of the coefficients, then rebuild their prior precision
	 * 
	 * The global level is only updated when `isGroup` is `true`.
	 */
	void updateCoefPrec() override {
		dl_mn_sparsity(group_lev, grp_vec, grp_id, global_lev, local_lev, shape, scl, coef_vec.head(num_alpha), rng);
		// Expand the per-group level to one entry per coefficient
		for (int j = 0; j < num_grp; j++) {
			coef_var = (grp_vec.array() == grp_id[j]).select(
				group_lev[j],
				coef_var
			);
		}
		dl_dir_griddy(dir_concen, grid_size, local_lev, global_lev, rng);
		dl_local_sparsity(local_lev, dir_concen, coef_vec.head(num_alpha).array() / coef_var.array(), rng);
		using is_group = std::integral_constant<bool, isGroup>;
		if (is_group::value) {
			global_lev = dl_global_sparsity(local_lev.array() * coef_var.array(), dir_concen, coef_vec.head(num_alpha), rng);
		}
		dl_latent(latent_local, global_lev * local_lev.array() * coef_var.array(), coef_vec.head(num_alpha), rng);
		prior_alpha_prec.head(num_alpha) = 1 / ((global_lev * local_lev.array() * coef_var.array()).square() * latent_local.array());
	}
	/**
	 * @brief Update the Dirichlet concentration, local and global levels and latent variables of the contemporaneous coefficients, then rebuild their prior precision
	 * 
	 */
	void updateImpactPrec() override {
		dl_dir_griddy(contem_dir_concen, grid_size, contem_local_lev, contem_global_lev[0], rng);
		dl_local_sparsity(contem_local_lev, contem_dir_concen, contem_coef, rng);
		contem_global_lev[0] = dl_global_sparsity(contem_local_lev, contem_dir_concen, contem_coef, rng);
		// NOTE(review): dl_latent() receives contem_local_lev without the global factor, whereas
		// updateCoefPrec() passes the product including global_lev and the precision below does
		// include contem_global_lev[0] — confirm this asymmetry is intended.
		dl_latent(latent_contem_local, contem_local_lev, contem_coef, rng);
		prior_chol_prec = 1 / ((contem_global_lev[0] * contem_local_lev.array()).square() * latent_contem_local.array());
	}
	/**
	 * @brief Save the coefficient records and the Dirichlet-Laplace levels at `mcmc_step`
	 * 
	 */
	void updateRecords() override {
		updateCoefRecords();
		dl_record.assignRecords(mcmc_step, local_lev, global_lev);
	}

private:
	GlobalLocalRecords dl_record;
	double dir_concen, contem_dir_concen, shape, scl; // Dirichlet concentrations (start at 0) and hyperparameters passed to dl_mn_sparsity()
	int grid_size; // grid argument passed to dl_dir_griddy()
	Eigen::VectorXd local_lev;
	Eigen::VectorXd group_lev;
	double global_lev;
	Eigen::VectorXd latent_local;
	Eigen::VectorXd coef_var; // group_lev expanded to one entry per coefficient
	Eigen::VectorXd contem_local_lev;
	Eigen::VectorXd contem_global_lev; // only element 0 is used
	Eigen::VectorXd latent_contem_local;
};

/**
 * @brief Generalized Double Pareto (GDP) prior
 * 
 * @tparam BaseMcmc `McmcReg` or `McmcSv`
 */
template <typename BaseMcmc = McmcReg>
class McmcGdp : public BaseMcmc {
public:
	/**
	 * @brief Construct a GDP sampler on top of `BaseMcmc`
	 *
	 * @param params GDP configuration wrapping `RegParams` when `BaseMcmc` is `McmcReg`, `SvParams` otherwise
	 * @param inits GDP initial values wrapping `LdltInits` when `BaseMcmc` is `McmcReg`, `SvInits` otherwise
	 * @param seed Seed forwarded to `BaseMcmc`
	 */
	McmcGdp(
		const GdpParams<typename std::conditional<std::is_same<BaseMcmc, McmcReg>::value, RegParams, SvParams>::type>& params,
		const GdpInits<typename std::conditional<std::is_same<BaseMcmc, McmcReg>::value, LdltInits, SvInits>::type>& inits,
		unsigned int seed
	)
	: BaseMcmc(params, inits, seed),
		group_rate(inits._init_group_rate),
		group_rate_fac(Eigen::VectorXd::Ones(num_alpha)),
		coef_gamma_shape(inits._init_gamma_shape),
		coef_gamma_rate(inits._init_gamma_rate),
		shape_grid(params._grid_shape),
		rate_grid(params._grid_rate),
		local_lev(inits._init_local),
		contem_rate(inits._init_contem_rate),
		contem_gamma_shape(inits._init_contem_gamma_shape),
		contem_gamma_rate(inits._init_contem_gamma_rate),
		contem_fac(inits._init_contem_local)
	{
		// No GDP-specific record object exists yet, so there is no initial record to store.
	}

	virtual ~McmcGdp() = default;

	/**
	 * @brief Append GDP-specific records to `list`
	 *
	 * Currently a no-op: this class keeps no hyperparameter records, so `list` is left untouched.
	 */
	void appendRecords(LIST& list) override {
		// TODO: export local/group/global records once a record object is added to this class.
	}

protected:
	// Model layout inherited from the base sampler
	using BaseMcmc::num_iter;
	using BaseMcmc::num_alpha;
	using BaseMcmc::num_lowerchol;
	using BaseMcmc::num_grp;
	using BaseMcmc::own_id;
	using BaseMcmc::grp_id;
	using BaseMcmc::grp_vec;
	// Sampler state inherited from the base sampler
	using BaseMcmc::mcmc_step;
	using BaseMcmc::rng;
	using BaseMcmc::coef_vec;
	using BaseMcmc::contem_coef;
	using BaseMcmc::prior_alpha_prec;
	using BaseMcmc::prior_chol_prec;
	using BaseMcmc::updateCoefRecords;

	/**
	 * @brief Refresh the GDP hyperparameters and the prior precision of the first `num_alpha` coefficients
	 *
	 * Every sampler call receives the same `rng`, so the call order is part of the behavior.
	 */
	void updateCoefPrec() override {
		gdp_shape_griddy(coef_gamma_shape, coef_gamma_rate, shape_grid, coef_vec.head(num_alpha), rng);
		gdp_rate_griddy(coef_gamma_rate, coef_gamma_shape, rate_grid, coef_vec.head(num_alpha), rng);
		gdp_exp_rate(group_rate, coef_gamma_shape, coef_gamma_rate, coef_vec.head(num_alpha), grp_vec, grp_id, rng);
		// Broadcast each group's rate onto the coefficients that belong to that group.
		for (int grp = 0; grp < num_grp; ++grp) {
			group_rate_fac = (grp_vec.array() == grp_id[grp]).select(group_rate[grp], group_rate_fac);
		}
		gdp_local_sparsity(local_lev, group_rate_fac, coef_vec.head(num_alpha), rng);
		// precision = 1 / local level
		prior_alpha_prec.head(num_alpha) = 1 / local_lev.array();
	}

	/**
	 * @brief Refresh the GDP hyperparameters and the prior precision of the contemporaneous coefficients
	 *
	 * Uses the same grid sizes (`shape_grid`, `rate_grid`) as the coefficient update.
	 */
	void updateImpactPrec() override {
		gdp_shape_griddy(contem_gamma_shape, contem_gamma_rate, shape_grid, contem_coef, rng);
		gdp_rate_griddy(contem_gamma_rate, contem_gamma_shape, rate_grid, contem_coef, rng);
		gdp_exp_rate(contem_rate, contem_gamma_shape, contem_gamma_rate, contem_coef, rng);
		gdp_local_sparsity(contem_fac, contem_rate, contem_coef, rng);
		// precision = 1 / local factor
		prior_chol_prec = 1 / contem_fac.array();
	}

	/**
	 * @brief Store the current draw; only the coefficient records of the base sampler are updated
	 */
	void updateRecords() override {
		updateCoefRecords();
	}

private:
	// Per-group rates and their expansion to one entry per coefficient (length `num_alpha`)
	Eigen::VectorXd group_rate, group_rate_fac;
	// Gamma shape/rate hyperparameters for the coefficient part
	double coef_gamma_shape, coef_gamma_rate;
	// Grid sizes handed to the griddy shape/rate samplers
	int shape_grid, rate_grid;
	// Local levels; their reciprocal is the coefficient prior precision
	Eigen::VectorXd local_lev;
	// Rates for the contemporaneous coefficients
	Eigen::VectorXd contem_rate;
	// Gamma shape/rate hyperparameters for the contemporaneous part
	double contem_gamma_shape, contem_gamma_rate;
	// Local factors; their reciprocal is the contemporaneous prior precision
	Eigen::VectorXd contem_fac;
};

/**
 * @brief Function to initialize `McmcReg` or `McmcSv`
 * 
 * @tparam BaseMcmc `McmcReg` or `McmcSv`
 * @tparam isGroup If `true`, use group shrinkage parameter
 * @param num_chains Number of MCMC chains
 * @param num_iter MCMC iteration
 * @param x Design matrix in multivariate regression form
 * @param y Response matrix in multivariat regression form
 * @param param_reg Covariance configuration
 * @param param_prior Shrinkage prior configuration
 * @param param_intercept Constant term configuration
 * @param param_init MCMC initial values
 * @param prior_type Prior number to use
 * @param grp_id Minnesota group unique ids
 * @param own_id Own-lag id
 * @param cross_id Cross-lag id
 * @param grp_mat Minnesota group matrix
 * @param include_mean If `true`, include constant term
 * @param seed_chain Seed for each chain
 * @param num_design Number of samples
 * @return std::vector<std::unique_ptr<BaseMcmc>> 
 */
template <typename BaseMcmc = McmcReg, bool isGroup = true>
inline std::vector<std::unique_ptr<BaseMcmc>> initialize_mcmc(
	int num_chains, int num_iter, const Eigen::MatrixXd& x, const Eigen::MatrixXd& y,
	LIST& param_reg, LIST& param_prior, LIST& param_intercept, LIST_OF_LIST& param_init, int prior_type,
  const Eigen::VectorXi& grp_id, const Eigen::VectorXi& own_id, const Eigen::VectorXi& cross_id, const Eigen::MatrixXi& grp_mat,
  bool include_mean, Eigen::Ref<const Eigen::VectorXi> seed_chain, Optional<int> num_design = NULLOPT
) {
	using PARAMS = typename std::conditional<std::is_same<BaseMcmc, McmcReg>::value, RegParams, SvParams>::type;
	using INITS = typename std::conditional<std::is_same<BaseMcmc, McmcReg>::value, LdltInits, SvInits>::type;
	std::vector<std::unique_ptr<BaseMcmc>> mcmc_ptr(num_chains);
	switch (prior_type) {
		case 1: {
			MinnParams<PARAMS> minn_params(
				num_iter, x, y,
				param_reg,
				own_id, cross_id,
				grp_id, grp_mat,
				param_prior,
				param_intercept, include_mean
			);
			for (int i = 0; i < num_chains; ++i) {
				LIST init_spec = param_init[i];
				// INITS ldlt_inits(init_spec);
				INITS ldlt_inits = num_design ? INITS(init_spec, *num_design) : INITS(init_spec);
				mcmc_ptr[i] = std::make_unique<McmcMinn<BaseMcmc>>(minn_params, ldlt_inits, static_cast<unsigned int>(seed_chain[i]));
			}
			return mcmc_ptr;
		}
		case 2: {
			SsvsParams<PARAMS> ssvs_params(
				num_iter, x, y,
				param_reg,
				own_id, cross_id,
				grp_id, grp_mat,
				param_prior,
				param_intercept,
				include_mean
			);
			for (int i = 0; i < num_chains; ++i) {
				LIST init_spec = param_init[i];
				// SsvsInits<INITS> ssvs_inits(init_spec);
				SsvsInits<INITS> ssvs_inits = num_design ? SsvsInits<INITS>(init_spec, *num_design) : SsvsInits<INITS>(init_spec);
				mcmc_ptr[i] = std::make_unique<McmcSsvs<BaseMcmc>>(ssvs_params, ssvs_inits, static_cast<unsigned int>(seed_chain[i]));
			}
			return mcmc_ptr;
		}
		case 3: {
			HorseshoeParams<PARAMS> hs_params(
				num_iter, x, y,
				param_reg,
				own_id, cross_id,
				grp_id, grp_mat,
				param_intercept, include_mean
			);
			for (int i = 0; i < num_chains; ++i) {
				LIST init_spec = param_init[i];
				// HsInits<INITS> hs_inits(init_spec);
				HsInits<INITS> hs_inits = num_design ? HsInits<INITS>(init_spec, *num_design) : HsInits<INITS>(init_spec);
				mcmc_ptr[i] = std::make_unique<McmcHorseshoe<BaseMcmc, isGroup>>(hs_params, hs_inits, static_cast<unsigned int>(seed_chain[i]));
			}
			return mcmc_ptr;
		}
		case 4: {
			HierminnParams<PARAMS> minn_params(
				num_iter, x, y,
				param_reg,
				own_id, cross_id,
				grp_id, grp_mat,
				param_prior,
				param_intercept, include_mean
			);
			for (int i = 0; i < num_chains; ++i) {
				LIST init_spec = param_init[i];
				// HierminnInits<INITS> minn_inits(init_spec);
				HierminnInits<INITS> minn_inits = num_design ? HierminnInits<INITS>(init_spec, *num_design) : HierminnInits<INITS>(init_spec);
				mcmc_ptr[i] = std::make_unique<McmcHierminn<BaseMcmc>>(minn_params, minn_inits, static_cast<unsigned int>(seed_chain[i]));
			}
			return mcmc_ptr;
		}
		case 5: {
			NgParams<PARAMS> ng_params(
				num_iter, x, y,
				param_reg,
				own_id, cross_id,
				grp_id, grp_mat,
				param_prior,
				param_intercept,
				include_mean
			);
			for (int i = 0; i < num_chains; ++i) {
				LIST init_spec = param_init[i];
				// NgInits<INITS> ng_inits(init_spec);
				NgInits<INITS> ng_inits = num_design ? NgInits<INITS>(init_spec, *num_design) : NgInits<INITS>(init_spec);
				mcmc_ptr[i] = std::make_unique<McmcNg<BaseMcmc, isGroup>>(ng_params, ng_inits, static_cast<unsigned int>(seed_chain[i]));
			}
			return mcmc_ptr;
		}
		case 6: {
			DlParams<PARAMS> dl_params(
				num_iter, x, y,
				param_reg,
				own_id, cross_id,
				grp_id, grp_mat,
				param_prior,
				param_intercept,
				include_mean
			);
			for (int i = 0; i < num_chains; ++i) {
				LIST init_spec = param_init[i];
				// GlInits<INITS> dl_inits(init_spec);
				GlInits<INITS> dl_inits = num_design ? GlInits<INITS>(init_spec, *num_design) : GlInits<INITS>(init_spec);
				mcmc_ptr[i] = std::make_unique<McmcDl<BaseMcmc, isGroup>>(dl_params, dl_inits, static_cast<unsigned int>(seed_chain[i]));
			}
			return mcmc_ptr;
		}
		case 7: {
			GdpParams<PARAMS> gdp_params(
				num_iter, x, y,
				param_reg,
				own_id, cross_id,
				grp_id, grp_mat,
				param_prior,
				param_intercept,
				include_mean
			);
			for (int i = 0; i < num_chains; ++i) {
				LIST init_spec = param_init[i];
				GdpInits<INITS> gdp_inits = num_design ? GdpInits<INITS>(init_spec, *num_design) : GdpInits<INITS>(init_spec);
				mcmc_ptr[i] = std::make_unique<McmcGdp<BaseMcmc>>(gdp_params, gdp_inits, static_cast<unsigned int>(seed_chain[i]));
			}
			return mcmc_ptr;
		}
	}
	return mcmc_ptr;
}

/**
 * @brief Class that conducts MCMC using CTA
 * 
 * @tparam BaseMcmc `McmcReg` or `McmcSv`
 * @tparam isGroup If `true`, use group shrinkage parameter
 */
template <typename BaseMcmc = McmcReg, bool isGroup = true>
class CtaRun : public McmcRun {
public:
	/**
	 * @brief Set up the run and create one sampler per chain
	 *
	 * The run settings (`num_chains`, `num_iter`, `num_burn`, `thin`, `display_progress`, `nthreads`)
	 * go to `McmcRun`; the remaining arguments are forwarded to `initialize_mcmc()`.
	 * `num_design` is not passed, so `initialize_mcmc()` uses its default.
	 */
	CtaRun(
		int num_chains, int num_iter, int num_burn, int thin,
    const Eigen::MatrixXd& x, const Eigen::MatrixXd& y,
		LIST& param_cov, LIST& param_prior, LIST& param_intercept,
		LIST_OF_LIST& param_init, int prior_type,
    const Eigen::VectorXi& grp_id, const Eigen::VectorXi& own_id, const Eigen::VectorXi& cross_id, const Eigen::MatrixXi& grp_mat,
    bool include_mean, const Eigen::VectorXi& seed_chain, bool display_progress, int nthreads
	)
	: McmcRun(num_chains, num_iter, num_burn, thin, display_progress, nthreads) {
		// The samplers are created with `num_iter - num_burn` as their iteration count.
		auto chain_samplers = initialize_mcmc<BaseMcmc, isGroup>(
			num_chains, num_iter - num_burn, x, y,
			param_cov, param_prior, param_intercept, param_init, prior_type,
			grp_id, own_id, cross_id, grp_mat,
			include_mean, seed_chain
		);
		// Hand ownership of every sampler over to the base class storage.
		for (int chain = 0; chain < num_chains; ++chain) {
			mcmc_ptr[chain] = std::move(chain_samplers[chain]);
		}
	}

	virtual ~CtaRun() = default;
};

} // namespace bvhar

#endif // BVHAR_BAYES_TRIANGULAR_DEPRECATED_TRIANGULAR_H
