#-----------------------------------------------#
# Title: 03_other
#-----------------------------------------------#
# Project: Ordered Logit Empirical Applications
# Author: Carlos Gonzalez
# Date: November 2025
#-----------------------------------------------#
# Description:
# - Includes additional robustness checks and 
# complementary results that did not make it into
# the final manuscript
#----------------------------------------------#

#--------------------#
# RISK
#--------------------#

# 1. Seed Selection in Risk Nonparametric Analysis
#-------------------------------------------------#

# We observed some variation in the number of non-rationalizable quantiles across seeds.
# To provide a representative view of this number, we iterated the Bellman routine across 10 seeds
# and selected a seed whose number of violations is close to the 10-seed average (seed = 2, which is fixed in the manuscript).

# Bellman routine for a single seed
seed_iteration_bellman = function(seed){
  
  # Runs the Bellman routine on freshly generated choice data for one seed.
  #
  # seed: seed passed to choice_data_gen() for every sample size
  #
  # Returns the row-bound output of lapply_routine_bell() for the 20, 80,
  # 160 and 480 observation datasets, plus two columns:
  #   dataset: index (1-4) of the sample size each row belongs to
  #   iter:    the seed
  # NOTE(review): downstream counting treats each row as one
  # non-rationalizable quantile -- confirm against lapply_routine_bell().
  
  # Sample sizes; their position defines the `dataset` index
  sample_sizes = c(20, 80, 160, 480)
  
  # Generate all choice datasets first and only then run the routine, the
  # same call order as before (relevant if either step uses the global RNG)
  choice_list = lapply(sample_sizes,
                       function(n) choice_data_gen(menu_data, n, seed = seed))
  viols_list = lapply(choice_list, lapply_routine_bell,
                      quant_selection = quantile_vector)
  
  # Tag every row with its dataset explicitly, so callers do not have to
  # infer dataset boundaries from the ordering of the quantiles
  viols_list = Map(function(viols, k) mutate(viols, dataset = k),
                   viols_list, seq_along(viols_list))
  
  do.call(rbind, viols_list) |>
    mutate(iter = seed)
}

# Run seed_iteration_bellman() across 10 seeds and summarize the number of violations per dataset
seed_iteration_wrap = function(seeds = 1:10){
  
  # Runs seed_iteration_bellman() for several seeds and summarizes how many
  # violations (rows of its output) each dataset produces per seed.
  #
  # seeds: seeds to iterate over (default: the 10 seeds used to select the
  #        manuscript seed)
  #
  # Returns a list with
  #   dis_table: one row per seed x dataset with the number of violations
  #              (n_errors); combinations without any violation get 0
  #   sum_table: mean, min and max of n_errors by dataset across seeds
  # NOTE(review): counting rows assumes one row per non-rationalizable
  # quantile, as the previous n() count did -- confirm.
  
  stopifnot(length(seeds) > 0)
  
  # Labels for dataset indices 1-4 (sample sizes 20, 80, 160, 480)
  dataset_labels = c("020 obs", "080 obs", "160 obs", "480 obs")
  
  # Iterate across seeds. lapply keeps every seed's data frame intact,
  # whereas sapply simplifies the results into a matrix of column lists
  viols_iter = lapply(seeds, seed_iteration_bellman)
  
  # Stack all seeds, making sure every row carries a dataset index
  store_viols = lapply(viols_iter, function(viol_seed){
    viol_seed = as_tibble(viol_seed)
    if (!"dataset" %in% names(viol_seed)) {
      # Fallback when no explicit dataset column is supplied: start a new
      # dataset whenever the quantile sequence restarts. This misses a
      # boundary when the next dataset starts at a higher quantile and
      # shifts labels when a dataset has no violations at all
      viol_seed = viol_seed |>
        mutate(dataset = ifelse(quantile < lag(quantile), 1, 0),
               dataset = ifelse(is.na(dataset), 1, dataset),
               dataset = cumsum(dataset))
    }
    viol_seed |> select(quantile, iter, dataset)
  }) |>
    bind_rows()
  
  # Count violations per seed and dataset. Seed x dataset cells with no
  # violations get an explicit 0 instead of silently dropping out (which
  # would bias mean_errors and min_errors upwards)
  counts = store_viols |> count(iter, dataset, name = "n_errors")
  all_cells = expand.grid(iter = seeds, dataset = seq_along(dataset_labels))
  
  dis_table = all_cells |>
    full_join(counts, by = c("iter", "dataset")) |>
    mutate(n_errors = coalesce(n_errors, 0L),
           dataset = dataset_labels[dataset]) |>
    arrange(iter, dataset) |>
    as_tibble()
  
  # Summary table across seeds
  sum_table = dis_table |> group_by(dataset) |> 
              summarize(mean_errors = mean(n_errors),
                        min_errors = min(n_errors),
                        max_errors = max(n_errors))
  
  list(dis_table = dis_table,
       sum_table = sum_table)
  
}

###########################
# WARNING LONG TIME TO RUN!
###########################
# seed_iteration_wrap() # Uncomment to run

# 2. Weight matrix and graph for reduced set of menus
#----------------------------------------------------#
# Plot a weight matrix for a reduced set of weights
# NOTE(review): n160_data must exist in the global environment (the
# n160_data built inside seed_iteration_bellman() is local to that function);
# presumably it is created by an earlier script -- confirm.
weight_matrix_n160_mini = gen_weight_matrix(n160_data, 
                                            menu_selection = c(3, 8, 11, 16, 20),
                                            plot = TRUE, weights = TRUE)

# 3. CLA property as regression
#------------------------------#

# The CLA property, as displayed in Figure 4, can be used to recover estimators of \tau and \sigma
# via linear regression.
cla_regression = function(n, df){
  
  # Recovers tau and sigma from the CLA property by linear regression for a
  # single sample size.
  #
  # n:  n_obs to be considered (one dataset at a time)
  # df: an object (or merge of objects) created via cla_data() /
  #     cla_data_alt(); must contain n_obs, sum_log_odds and sum_t
  #
  # Returns c(tau = ..., sigma = ...), rounded to 3 decimals. Inverting the
  # fitted line sum_log_odds = a + b * sum_t gives sigma = 1 / b and
  # tau = -a * sigma / 2.
  
  # Select dataset. Base subsetting avoids data masking (a column called `n`
  # would shadow the argument inside filter()); which() drops NA n_obs rows
  df_n = df[which(df$n_obs == n), ]
  
  # Run regression
  model = lm(sum_log_odds ~ sum_t, data = df_n)
  
  # unname() so the output is named "tau"/"sigma" rather than
  # "tau.(Intercept)"/"sigma.sum_t"
  coefs = unname(coef(model))
  sigma = 1 / coefs[2]
  tau = -coefs[1] * sigma / 2
  
  # Output
  round(c(tau = tau, sigma = sigma), 3)
  
}

# Apply cla_regression() to every sample size (n_obs) in a list of CLA datasets
cla_regression_sapply = function(df_list){
  
  # Runs cla_regression() once per sample size found in the data.
  #
  # df_list: a list of objects created via cla_data() or cla_data_alt()
  #
  # Returns a 2 x k numeric matrix: rows "tau" and "sigma", one column per
  # distinct n_obs (in order of first appearance), labelled "n = <n_obs>".
  
  # Bind datasets together
  df = do.call(rbind, df_list)
  
  # Distinct sample sizes; NA would select no rows and make lm() fail
  sizes = unique(df$n_obs)
  sizes = sizes[!is.na(sizes)]
  
  # Regression; vapply guarantees a numeric 2 x k matrix
  estimates = vapply(sizes, cla_regression, numeric(2), df = df)
  dimnames(estimates) = list(c("tau", "sigma"), paste0("n = ", sizes))
  estimates
}

# Run regression for each dataset
# Risk CLA datasets, one per sample size
cla_regression_sapply(df_list = list(cla_20, cla_80, cla_160, cla_480))

#--------------------#
# ALTRUISM
#--------------------#


# 4. A valid non-parametric test
#------------------------------#

# Function to identify WARP crossings given menu and choice data
# (first part of altr_warp())
crossings_matrix = function(menu_df, big_df){
  
  # For every pair of menus whose budget lines cross, finds the crossing
  # point x_{j j'} = (c_1, c_2) and stores v_j(x_{j j'}) = 1 - exp(-c_2 / c_1).
  # (first part of altr_warp())
  #
  # menu_df: An object created via data_gen_alt(); must contain pi_1, pi_2
  #          and pi_j. The crossing formulas use 1 / pi_j as the absolute
  #          slope of budget line j (presumably pi_j = pi_2 / pi_1 -- TODO
  #          confirm against data_gen_alt())
  # big_df: An object created via choice_data_alt_gen(); not used here,
  #         kept for interface compatibility
  #
  # Returns a J x J matrix: entry [j, j'] holds v_j(x_{j j'}) when the two
  # budget lines cross and pi_j[j] > pi_j[j']; every other entry is NA.
  
  stopifnot(all(c("pi_1", "pi_2", "pi_j") %in% names(menu_df)))
  
  # Number of menus
  J = nrow(menu_df)
  
  # Axis intercepts of each budget line pi_1 c_1 + pi_2 c_2 = 1
  y_point = 1 / menu_df$pi_2
  x_point = 1 / menu_df$pi_1
  pi_j = menu_df$pi_j
  
  # Intersection matrix (NA when budget lines do not cross)
  inter_matrix = matrix(NA, J, J)
  
  for (i in seq_len(J)){
    for (j in i:J){
      
      # Do not compare a menu to itself
      if (i == j) next
      
      # Weakly nested budgets (one line on or inside the other, including
      # identical menus or a shared intercept) have no interior crossing;
      # treating them as crossing gave degenerate v values (NaN, -Inf, 0, 1)
      nested = (y_point[i] >= y_point[j] && x_point[i] >= x_point[j]) ||
               (y_point[i] <= y_point[j] && x_point[i] <= x_point[j])
      if (nested) next
      
      # Coordinates of x_{j j'}
      c_1_cross = (y_point[i] - y_point[j]) / (1 / pi_j[i] - 1 / pi_j[j])
      c_2_cross = y_point[i] - 1 / pi_j[i] * c_1_cross
      
      # Find v_j(x_{j j'})
      v_j = 1 - exp(-c_2_cross / c_1_cross)
      
      # ONLY a potential violation when F_j < F_j' given pi_j > pi_j'.
      # Hence v(x_{j j'}) is stored in the row of the menu with the larger
      # pi_j (crossing lines cannot have equal slopes)
      if (pi_j[i] > pi_j[j]){
        inter_matrix[i, j] = v_j
      } else{
        inter_matrix[j, i] = v_j
      }
    }
  }
  
  inter_matrix
}

# Function to identify violations of WARP for a single coordinate
# (second part of altr_warp())
test_warp_fast = function(j, j_prime, inter_matrix, df){
  
  # Difference of empirical CDFs F_j - F_j' evaluated at v(x_{j,j'}) for one
  # pair of menus. A negative value flags a potential WARP violation. With
  # inter_matrix from crossings_matrix(), the row menu j is the one with the
  # larger pi_j.
  #
  # j: row coordinate (menu j)
  # j_prime: column coordinate (menu j')
  # inter_matrix: a matrix created via crossings_matrix()
  # df: an object created via choice_data_alt_gen(), with columns menu, v_j
  #     and quantile (the ECDF value of each observation within its menu)
  #
  # Returns NA when the two menus are not comparable (no crossing).
  
  v_cross = inter_matrix[j, j_prime]
  
  # Only check comparable menus
  if (is.na(v_cross)) return(NA)
  
  # Empirical CDF of menu m at v_cross: the largest quantile among menu m's
  # observations with v_j <= v_cross, or 0 when there are none (max() of an
  # empty vector would return -Inf with a warning). which() skips NA v_j
  ecdf_at = function(m){
    below = which(df$menu == m & df$v_j <= v_cross)
    if (length(below) == 0) 0 else max(df$quantile[below])
  }
  
  ecdf_at(j) - ecdf_at(j_prime)
  
}

# Computes the WARP test statistic (min over menu pairs of F_j - F_j') with test_warp_fast(), optionally after randomly permuting v_j across observations
test_warp_mapply = function(inter_matrix, df, permute = FALSE){
  
  # WARP test statistic T = min over menu pairs (j, j') of F_j - F_j',
  # evaluated with test_warp_fast(); pairs without a crossing are ignored.
  #
  # inter_matrix: a matrix created via crossings_matrix()
  # df: an object created via choice_data_alt_gen(); assumed to hold N
  #     observations for each menu 1..J -- TODO confirm
  # permute: logical indicating if v_j is to be randomly reassigned across
  #          all observations (ignoring menus) prior to the WARP test
  
  # Locals
  J = length(unique(df$menu))
  N = unique(df$n_obs)
  
  if (permute){
    # Pool v_j across menus and reassign at random. Indexing with
    # sample.int() avoids sample()'s 1:x behaviour for a length-one input
    # and draws from the RNG exactly as sample(x, length(x)) does
    pooled_v = df$v_j[sample.int(length(df$v_j))]
    permuted_df = 
      tibble(menu = rep(seq_len(J), each = N),
             v_j = pooled_v) |>
      group_by(menu) |>
      mutate(quantile = signif(row_number(v_j) / N)) |> ungroup()
  } else{
    # Observed data: quantiles are stored in percent
    permuted_df = df |> mutate(quantile = quantile / 100)
  }
  
  # Evaluate every (j, j') pair and keep the smallest difference
  grid = expand.grid(j = seq_len(J), j_prime = seq_len(J))
  T_stat = mapply(test_warp_fast, j = grid$j, j_prime = grid$j_prime,
                  MoreArgs = list(inter_matrix = inter_matrix,
                                  df = permuted_df)) |> min(na.rm = TRUE)
  # min F_j - F_{j_prime}
  T_stat
}

# Permutation test built on test_warp_mapply(): observed statistic plus B permuted statistics
test_warp_wrap = function(menu_df, big_df, B){
  
  # Observed WARP statistic and its permutation distribution for one dataset.
  #
  # menu_df: An object created via data_gen_alt()
  # big_df: An object created via choice_data_alt_gen()
  # B: number of permuted samples
  #
  # Returns list(true_T_stat, T_stats (length B), N_obs).
  
  # Crossing points depend only on the menus, so they are computed once and
  # shared by the observed and every permuted statistic
  crossings = crossings_matrix(menu_df = menu_df, big_df = big_df)
  
  observed_stat = test_warp_mapply(inter_matrix = crossings,
                                   df = big_df, permute = FALSE)
  
  permuted_stats = replicate(B, test_warp_mapply(inter_matrix = crossings,
                                                 df = big_df, permute = TRUE))
  
  sample_size = unique(big_df$n_obs)
  
  list(true_T_stat = observed_stat,
       T_stats = permuted_stats,
       N_obs = sample_size)
}

# Analysis and plot
analysis_test_warp = function(test_warp_results){
  
  # Prints the observed statistic, the 5th percentile of the permutation
  # distribution and the lower-tail p-value, then plots the ECDF of the
  # permuted statistics and saves it to output_final/.
  #
  # test_warp_results: an object created via test_warp_wrap()
  
  # Recover locals from list object
  T_stats = test_warp_results$T_stats
  true_T_stat = test_warp_results$true_T_stat
  N_obs = test_warp_results$N_obs
  
  # Side information
  print(paste("The T_stat for N =", N_obs, "is", true_T_stat))
  
  # Calculate the 5th percentile
  p5 = quantile(T_stats, 0.05)
  print(paste("The p5 for N =", N_obs, "is", p5))
  
  # Lower-tail permutation p-value, (1 + #{T_b <= T}) / (B + 1).
  # Statistics are differences of ECDF values and so sit on a coarse grid;
  # ties with the observed value must count towards the p-value (the
  # previous rank-of-first-tie formula only counted T_b < T, which is
  # anti-conservative). Without ties both formulas agree.
  p_value = (1 + sum(T_stats <= true_T_stat)) / (length(T_stats) + 1)
  print(paste("The p-value for N =", N_obs, "is", p_value))
  
  
  # Plot
  df = tibble(T_stats = T_stats)
  
  test_plot =
    df |> ggplot(aes(x = T_stats)) +
          stat_ecdf(geom = "step", color = "blue", linewidth = 1) +
          geom_vline(xintercept = p5, linetype = "dashed", color = "black") +
          geom_vline(xintercept = true_T_stat, linetype = "solid", color = "red") +
          labs(title = paste("N = ", N_obs),
               x = "T_N^b", y = "ECDF(T_N^b)") +
          theme_minimal()
  
  plot(test_plot)
  ggsave(paste0("output_final/nonparamtest_altruism", N_obs,".png"),
         test_plot, height = 5 , width = 5)
  
  
}

# Run test for each dataset
test_warp_results20 = test_warp_wrap(menu_df = menu_alt_data, big_df = n20_alt_data, B = 1000)
test_warp_results80 = test_warp_wrap(menu_df = menu_alt_data, big_df = n80_alt_data, B = 1000)
test_warp_results160 = test_warp_wrap(menu_df = menu_alt_data, big_df = n160_alt_data, B = 1000)
test_warp_results480 = test_warp_wrap(menu_df = menu_alt_data, big_df = n480_alt_data, B = 1000)

# Plot results and print key statistics, smallest sample first
invisible(lapply(list(test_warp_results20, test_warp_results80,
                      test_warp_results160, test_warp_results480),
                 analysis_test_warp))


# 5. Distribution of alpha
#-------------------------#
# Boxplot for inputted alpha distribution
plot_alpha = function(df_list, true_alpha){
  
  # Boxplot of the recovered alpha by number of observations, saved to
  # output_final/, and median alpha per sample size.
  #
  # df_list: A list of one or more objects created via recover_alpha()
  # true_alpha: alpha used to generate the data; drawn as a dashed red line
  #
  # Returns a tibble with the median alpha_inputed for each n_obs.
  
  # Bind datasets together
  df_all = do.call(rbind, df_list)
  
  # Trim extreme estimates for display purposes only
  df_plot = df_all |> filter(alpha_inputed < 3, alpha_inputed > -3)
  
  alpha_boxplot = 
    df_plot |> ggplot() +
    geom_boxplot(aes(x = as.factor(n_obs), y = alpha_inputed)) +
    # yintercept is set as a parameter: mapping a constant inside aes()
    # draws one identical line per data row
    geom_hline(yintercept = true_alpha, 
               linetype = "dashed", color = "red", linewidth = 1) +
    labs(#title = TeX(r"(Boxplot of $\alpha$ by Number of Observations)"), 
      x = "Number of Observations", 
      y = TeX(r"($\alpha$)"),
      linetype = "", color = "") +
    theme_minimal() +
    theme(axis.text.x = element_text(size = 16),
          axis.title.y = element_text(size = 16),
          axis.title.x = element_text(size = 16, 
                                      margin = margin(t = 10, r = 0, b = 0, l = 0)))
  
  plot(alpha_boxplot)
  ggsave("output_final/altruism_step4_alpha_boxplot.png", alpha_boxplot, height = 5 , width = 5)
  
  # Output median alpha, computed on the untrimmed estimates (the trimming
  # above is only for display); NA estimates are ignored
  df_all |> group_by(n_obs) |> summarize(alpha = median(alpha_inputed, na.rm = TRUE))
  
}

plot_alpha(df_list = list(alpha20, alpha80, alpha160, alpha480),
           true_alpha = 0.4)

# 6. CLA Regression
#------------------#
# Run regression for each dataset
# Altruism CLA datasets, one per sample size
cla_regression_sapply(df_list = list(cla_20_alt, cla_80_alt, cla_160_alt, cla_480_alt))