#include "funs.h"
#include "gen_models.h"

// Posterior predictions from a single tree ensemble.
//
// tree_draws  : list with one element per posterior draw; every element is a character
//               vector that must contain exactly M tree strings (parsed by read_tree).
// tX_cont     : continuous predictors; rows are counted as predictors (p_cont) and
//               columns as observations (n). May be empty.
// tX_cat      : categorical predictors in the same layout. May be empty, but at least
//               one of tX_cont / tX_cat must be non-empty.
// M           : number of tree strings expected in every draw.
// family/link : supported pairs are binomial/probit, binomial/logit, poisson/log and
//               gaussian/identity; any other pair raises an R error.
// verbose     : print progress messages to the R console.
// print_every : user interrupts are checked (and progress printed when verbose) every
//               print_every draws; non-positive values are treated as 1.
//
// Returns an nd x n matrix (nd = number of draws) holding, for each draw and each
// observation, the ensemble fit passed through the inverse link of the chosen model.
// [[Rcpp::export(".single_ensm_predict")]]
Rcpp::NumericMatrix single_predict(Rcpp::List tree_draws,
                                   Rcpp::NumericMatrix tX_cont,
                                   Rcpp::IntegerMatrix tX_cat,
                                   int M,
                                   std::string family,
                                   std::string link,
                                   bool verbose, int print_every)
{
  set_str_conversion set_str; // for converting sets of integers into strings
  int n = 0;
  int p_cont = 0;
  int p_cat = 0;
  
  if(tX_cont.size() > 0) p_cont = tX_cont.rows();
  if(tX_cat.size() > 0) p_cat = tX_cat.rows();
  
  // the continuous matrix determines n whenever it is present
  if(p_cont > 0) n = tX_cont.cols();
  else if(p_cat > 0) n = tX_cat.cols();
  else Rcpp::stop("Need to supply tX_cont or tX_cat!");
  
  // both matrices are handed to fit_ensemble together with the single count n,
  // so they have to describe the same number of observations
  if(p_cont > 0 && p_cat > 0 && tX_cat.cols() != n){
    Rcpp::stop("tX_cont and tX_cat must have the same number of columns!");
  }
  
  int p = p_cont + p_cat;
  data_info di;
  di.n = n;
  di.p_cont = p_cont;
  di.p_cat = p_cat;
  di.p = p;
  if(p_cont > 0) di.x_cont = tX_cont.begin();
  if(p_cat > 0) di.x_cat = tX_cat.begin();
  
  int nd = tree_draws.size();

  std::vector<double> allfit(n);
  Rcpp::NumericMatrix pred_out(nd,n);
  
  // iter % print_every is undefined for print_every == 0 (e.g. when the caller derives
  // it from a small number of draws), so fall back to reporting on every draw
  const int stride = (print_every > 0) ? print_every : 1;
  
  // Handles one posterior draw: reports progress, checks that the draw holds M tree
  // strings, parses them and calls fit_ensemble, which leaves its output in allfit.
  auto fit_draw = [&](int iter){
    if(iter % stride == 0){
      Rcpp::checkUserInterrupt();
      if(verbose){
        if(iter == 0 || iter == nd-1) Rcpp::Rcout << "  Iteration: " << iter+1 << " of " << nd << std::endl;
        else Rcpp::Rcout << "  Iteration: " << iter << " of " << nd << std::endl;
      }
    }
    
    Rcpp::CharacterVector tmp_string_vec = tree_draws[iter];
    if(tmp_string_vec.size() != M){
      Rcpp::Rcout << "iter = " << iter << " # tree strings = " << tmp_string_vec.size() << std::endl;
      Rcpp::stop("Unexpected number of tree strings!");
    }
    
    std::vector<tree> t_vec(M);
    for(int m = 0; m < M; ++m){
      // extract a single element of the CharacterVector as a std::string for read_tree
      std::string tmp_string = Rcpp::as<std::string>(tmp_string_vec[m]);
      read_tree(t_vec[m], tmp_string, set_str);
    } // closes loop populating vector of trees
    fit_ensemble(allfit, t_vec, di);
  };
  
  // Runs every draw and maps the fit through the inverse link of the supplied model.
  auto predict_with_model = [&](GenModel* gmp){
    for(int iter = 0; iter < nd; ++iter){
      fit_draw(iter);
      for(int i = 0; i < n; ++i) pred_out(iter,i) = gmp->inv_link(allfit[i]);
    } // closes loop over tree samples
  };
  
  if(family == "binomial" && link == "probit"){
    for(int iter = 0; iter < nd; ++iter){
      fit_draw(iter);
      for(int i = 0; i < n; ++i) pred_out(iter,i) = R::pnorm(allfit[i], 0.0, 1.0, true, false);
    } // closes loop over tree samples
  } else if(family == "binomial" && link == "logit"){
    // automatic storage: released on every exit path, including when Rcpp::stop throws
    Logit logit_model{};
    predict_with_model(&logit_model);
  } else if (family == "poisson" && link == "log"){
    Poisson poisson_model{};
    predict_with_model(&poisson_model);
  } else if(family == "gaussian" && link == "identity"){
    for(int iter = 0; iter < nd; ++iter){
      fit_draw(iter);
      for(int i = 0; i < n; ++i) pred_out(iter,i) = allfit[i];
    } // closes loop over tree samples
  } else{
    Rcpp::Rcout << "Model fit object cointains family = " << family << " and link = " << link << std::endl;
    Rcpp::stop("Unsupported family and link combination!");
  }
  return pred_out;
}



// Posterior predictions from several tree ensembles combined through tZ.
//
// tree_draws      : list with one element per posterior draw; every element is itself a
//                   list of ensembles, each ensemble a character vector of tree strings.
// tZ              : matrix with one row per ensemble (R rows) and one column per
//                   observation; entry (r, i) weights ensemble r for observation i.
// tX_cont, tX_cat : continuous / categorical predictors; rows are counted as predictors
//                   and columns as observations (n). At least one must be non-empty.
// M_vec           : M_vec[r] is the number of tree strings expected for ensemble r; in
//                   the heteroskedastic gaussian case M_vec[R] is the count for the
//                   additional sigma ensemble.
// family/link     : supported pairs are binomial/logit, poisson/log and
//                   gaussian/identity; any other pair raises an R error.
// heteroskedastic : for gaussian/identity, every draw must then carry R + 1 ensembles
//                   and the last one is used to fill "sigma".
// verbose         : print progress messages to the R console.
// print_every     : user interrupts are checked (and progress printed when verbose)
//                   every print_every draws; non-positive values are treated as 1.
//
// Returns a list with "fit" (nd x n), "raw_beta" (nd x n x R) and, when heteroskedastic
// is true, "sigma" (nd x n; only filled by the gaussian/identity branch).
// [[Rcpp::export(".multi_ensm_predict")]]
Rcpp::List multi_predict(Rcpp::List tree_draws,
                         Rcpp::NumericMatrix tZ,
                         Rcpp::NumericMatrix tX_cont,
                         Rcpp::IntegerMatrix tX_cat,
                         Rcpp::IntegerVector M_vec,
                         std::string family,
                         std::string link,
                         bool heteroskedastic,
                         bool verbose, int print_every)
{
  set_str_conversion set_str; // for converting sets of integers into strings
  int n = 0;
  int p_cont = 0;
  int p_cat = 0;
  int R = tZ.rows(); // how many ensembles

  if(tX_cont.size() > 0) p_cont = tX_cont.rows();
  if(tX_cat.size() > 0) p_cat = tX_cat.rows();
  
  // the continuous matrix determines n whenever it is present
  if(p_cont > 0) n = tX_cont.cols();
  else if(p_cat > 0) n = tX_cat.cols();
  else Rcpp::stop("Need to supply tX_cont or tX_cat!");
  
  const bool is_logit = (family == "binomial" && link == "logit");
  const bool is_poisson = (family == "poisson" && link == "log");
  const bool is_gaussian = (family == "gaussian" && link == "identity");
  if(!is_logit && !is_poisson && !is_gaussian){
    Rcpp::Rcout << "Model fit object cointains family = " << family << " and link = " << link << std::endl;
    Rcpp::stop("Unsupported family and link combination!");
  }
  
  // only the gaussian model reads an extra (sigma) ensemble from each draw
  const bool fit_sigma = is_gaussian && heteroskedastic;
  const int n_ensm = fit_sigma ? R + 1 : R;
  
  // guard the raw indexing done below: M_vec[r], di.z[r + i*R] and the predictor matrices
  if(M_vec.size() < n_ensm) Rcpp::stop("M_vec must contain a tree count for every ensemble!");
  if(R > 0 && tZ.cols() != n) Rcpp::stop("tZ must have one column per observation!");
  if(p_cont > 0 && p_cat > 0 && tX_cat.cols() != n){
    Rcpp::stop("tX_cont and tX_cat must have the same number of columns!");
  }
  
  int p = p_cont + p_cat;
  data_info di;
  di.n = n;
  di.p_cont = p_cont;
  di.p_cat = p_cat;
  di.p = p;
  di.R = R;
  di.z = tZ.begin();
  if(p_cont > 0) di.x_cont = tX_cont.begin();
  if(p_cat > 0) di.x_cat = tX_cat.begin();
  
  int nd = tree_draws.size();
  std::vector<double> allfit(n);
  arma::cube raw_beta_out = arma::zeros<arma::cube>(nd, n, R);
  arma::mat pred_out = arma::zeros<arma::mat>(nd,n);
  arma::mat sigma_out = arma::zeros<arma::mat>(1,1);
  if(heteroskedastic) sigma_out.resize(nd,n);
  
  // iter % print_every is undefined for print_every == 0 (e.g. when the caller derives
  // it from a small number of draws), so fall back to reporting on every draw
  const int stride = (print_every > 0) ? print_every : 1;
  
  // Reports progress for draw `iter`, then returns that draw after checking that it
  // holds exactly n_expected ensembles.
  auto load_draw = [&](int iter, int n_expected) -> Rcpp::List {
    if(iter % stride == 0){
      Rcpp::checkUserInterrupt();
      if(verbose){
        if(iter == 0 || iter == nd-1) Rcpp::Rcout << "  Iteration: " << iter+1 << " of " << nd << std::endl;
        else Rcpp::Rcout << "  Iteration: " << iter << " of " << nd << std::endl;
      }
    }
    Rcpp::List tmp_draw = tree_draws[iter];
    if(tmp_draw.size() != n_expected){
      Rcpp::Rcout << "iter = " << iter << " found " << tmp_draw.size() << " ensembles." << std::endl;
      Rcpp::Rcout << " Expected " << n_expected << " ensembles" << std::endl;
      Rcpp::stop("Unexpected number of tree ensembles detected");
    }
    return tmp_draw;
  };
  
  // Parses ensemble r of a draw (which must hold M_vec[r] tree strings) and calls
  // fit_ensemble, which leaves its output in allfit.
  auto fit_ensm = [&](Rcpp::List &tmp_draw, int r, int iter){
    const int M_r = M_vec[r];
    Rcpp::CharacterVector tmp_string_vec = tmp_draw[r];
    if(tmp_string_vec.size() != M_r){
      Rcpp::Rcout << "iter = " << iter << " # tree strings = " << tmp_string_vec.size() << std::endl;
      Rcpp::stop("Unexpected number of tree strings!");
    }
    std::vector<tree> t_vec(M_r);
    for(int m = 0; m < M_r; ++m){
      std::string tmp_string = Rcpp::as<std::string>(tmp_string_vec[m]);
      read_tree(t_vec[m], tmp_string, set_str);
    } // closes loop populating vector of trees
    fit_ensemble(allfit, t_vec, di);
  };
  
  if(is_logit){
    // automatic storage: released on every exit path, including when Rcpp::stop throws
    Logit logit_model{};
    GenModel* gmp = &logit_model; // generalized model pointer
    for(int iter = 0; iter < nd; ++iter){
      Rcpp::List tmp_draw = load_draw(iter, R);
      std::vector<double> tmp_lambda(n); // linear predictor, starts at zero
      for(int r = 0; r < R; ++r){
        fit_ensm(tmp_draw, r, iter);
        for(int i = 0; i < n; ++i){
          raw_beta_out(iter, i, r) = allfit[i];
          tmp_lambda[i] += di.z[r + i*R] * allfit[i];
        }
      } // closes loop over ensembles
      for(int i = 0; i < n; ++i) pred_out(iter,i) = gmp->inv_link(tmp_lambda[i]);
    } // closes loop over tree ensemble samples
  } else if(is_poisson){
    Poisson poisson_model{};
    GenModel* gmp = &poisson_model; // generalized model pointer
    for(int iter = 0; iter < nd; ++iter){
      Rcpp::List tmp_draw = load_draw(iter, R);
      std::vector<double> tmp_lambda(n); // linear predictor, starts at zero
      for(int r = 0; r < R; ++r){
        fit_ensm(tmp_draw, r, iter);
        // NOTE(review): unlike the logit and gaussian branches, the fit is neither
        // weighted by di.z nor stored in raw_beta_out (which stays zero) - confirm
        // that this is intended.
        for(int i = 0; i < n; ++i){
          tmp_lambda[i] += allfit[i];
        }
      } // closes loop over ensembles
      for(int i = 0; i < n; ++i) pred_out(iter,i) = gmp->inv_link(tmp_lambda[i]);
    } // closes loop over tree ensemble samples
  } else{
    // gaussian / identity
    for(int iter = 0; iter < nd; ++iter){
      // a heteroskedastic draw carries R mean ensembles followed by the sigma ensemble
      Rcpp::List tmp_draw = load_draw(iter, n_ensm);
      for(int r = 0; r < R; ++r){
        fit_ensm(tmp_draw, r, iter);
        for(int i = 0; i < n; ++i){
          raw_beta_out(iter, i, r) = allfit[i];
          pred_out(iter,i) += di.z[r + i*R] * allfit[i];
        }
      } // closes loop over ensembles
      if(heteroskedastic){
        // fit the sigma ensemble, which is stored after the R mean ensembles
        Sigma sigma_model{};
        GenModel* gmp = &sigma_model; // generalized model pointer
        fit_ensm(tmp_draw, R, iter);
        for(int i = 0; i < n; ++i) sigma_out(iter,i) = sqrt(gmp->inv_link(allfit[i]));
      }
    } // closes loop over tree ensemble samples
  }

  Rcpp::List results;
  results["fit"] = pred_out;
  results["raw_beta"] = raw_beta_out;
  if(heteroskedastic) results["sigma"] = sigma_out;
  return results;
}

