The world’s leading publication for data science, AI, and ML professionals.

Bank Customer Churn with Tidymodels - Part 2: Decision Threshold Analysis

Decision Threshold and Scenario Analysis with Tidymodels

Welcome back to Part 2 of our exploration of the Bank Customer Churn problem, where we now examine decision threshold analysis.

In Part 1 we developed a candidate workflow that achieved strong results across a variety of classification metrics, and discussed the impacts of different up- and downsampling techniques to manage the 4:1 class imbalance in the bank customer churn dataset (taken from https://www.kaggle.com/shivan118/churn-modeling-dataset (License CC0: Public Domain)).

Part 1 is available here, and I recommend reading it first to better understand the context and model development.

This article focuses on explaining the consequences of model output to a non-technical audience. We will complete a decision threshold analysis, generate a cost function, and present two scenarios: the threshold that best differentiates between classes, or the threshold with the lowest cost. We will use the probably package from tidymodels for this analysis. Our aim is to identify and present a costed decision threshold analysis to our stakeholders, weighing the cost of customer churn against the cost of intervention strategies.

Load Packages

library(tidymodels) #ML Metapackage
library(probably) #Threshold Analysis
library(forcats) #Working with factors
library(patchwork) #ggplot grids
tidymodels_prefer()
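#Treat the second factor level ('1' = churned) as the event of interest.
#This global option is deprecated in newer yardstick releases; there, pass event_level = 'second' to metric functions instead.
options(yardstick.event_first = FALSE)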
class_metric <- metric_set(accuracy, f_meas, j_index, kap, precision, sensitivity, specificity, mcc)

Finalize and Fit Model

Picking up from where we left off in Part 1, we’ll identify the best performing workflow and finalise the model.

best_result <- wf_sample_exp %>% 
  extract_workflow_set_result("UPSAMPLE_Boosted_Trees") %>% 
  select_best(metric = 'j_index')
xgb_fit <- wf_sample_exp %>% 
  extract_workflow("UPSAMPLE_Boosted_Trees") %>% 
  finalize_workflow(best_result) %>%
  fit(training(cust_split))

workflowsets::extract_workflow_set_result() takes the ID of a workflow in the set and returns its tuning results, covering every trialled hyperparameter combination; tune::select_best() then picks the best combination according to the metric supplied (here, the J-index). extract_workflow() likewise takes a workflow ID and returns the corresponding unfitted workflow, and tune::finalize_workflow() sets its hyperparameters to the values stored in best_result. The resulting workflow is then fit to the training data.

Threshold Analysis

xgb_fit %>% 
  predict(new_data = testing(cust_split), type = 'prob') %>% 
  bind_cols(testing(cust_split)) %>% 
  ggplot(aes(x=.pred_1, fill = Exited, color = Exited)) +
    geom_histogram(bins = 40, alpha = 0.5) +
    theme_minimal() +
    scale_fill_viridis_d(aesthetics = c('color', 'fill'), end = 0.8) +
    labs(title = 'Distribution of Prediction Probabilities by Exited Status', x = 'Probability Prediction', y = 'Count')

By predicting probabilities, we can visualise the distribution of predicted churn probability for each churn status. At the default threshold of 0.5, customers with a predicted churn probability above 0.5 are classified as churning and the rest as retained. Threshold analysis identifies the threshold that optimises a chosen metric, and the probably package lets us carry it out: probably::threshold_perf() takes the truth and estimate columns, steps through a sequence of thresholds, and calculates sensitivity, specificity and the J-index at each one (along with a distance metric, which we filter out when plotting).

#Generate Probability Prediction Dataset
xgb_pred <- xgb_fit %>% 
  predict(new_data = testing(cust_split), type = 'prob') %>% 
  bind_cols(testing(cust_split)) %>% 
  select(Exited, .pred_0, .pred_1)
#Generate Sequential Threshold Tibble
threshold_data <- xgb_pred %>% 
  threshold_perf(truth = Exited, estimate = .pred_1, thresholds = seq(0.1, 1, by = 0.01))
#Identify Threshold for Maximum J-Index
max_j_index <- threshold_data %>% 
  filter(.metric == 'j_index') %>% 
  filter(.estimate == max(.estimate)) %>% 
  select(.threshold) %>% 
  as_vector()
#Visualise Threshold Analysis
threshold_data %>% 
  filter(.metric != 'distance') %>% 
  ggplot(aes(x=.threshold, y=.estimate, color = .metric)) +
   geom_line(size = 2) +
   geom_vline(xintercept = max_j_index, lty = 5, alpha = .6) +
   theme_minimal() +
   scale_colour_viridis_d(end = 0.8) +
   labs(x='Threshold', 
        y='Estimate', 
        title = 'Balancing Performance by Varying Threshold',
        subtitle = 'Vertical Line = Max J-Index',
        color = 'Metric')

Our analysis indicates that the threshold with the highest J-index is 0.47.
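To make the mechanics concrete, here is a minimal base-R sketch of the kind of sweep threshold_perf() performs. It is not the probably implementation: the eight observations are hypothetical toy values, and threshold_metrics() is a helper introduced only for illustration. Churn (1) is treated as the event, and an observation is classified as churn when its probability is at or above the threshold.

```r
# Hypothetical toy data (not the churn dataset): 1 = churned, 0 = retained
truth <- c(1, 1, 1, 0, 0, 0, 0, 0)
prob  <- c(0.92, 0.63, 0.45, 0.58, 0.34, 0.27, 0.12, 0.05)

# Hypothetical helper: sensitivity, specificity and J-index at one threshold
threshold_metrics <- function(truth, prob, threshold) {
  pred <- as.integer(prob >= threshold)  # churn if probability >= threshold
  sens <- sum(pred == 1 & truth == 1) / sum(truth == 1)
  spec <- sum(pred == 0 & truth == 0) / sum(truth == 0)
  data.frame(.threshold = threshold, sens = sens, spec = spec,
             j_index = sens + spec - 1)
}

# Sweep a grid of thresholds and stack the results into one data frame
thresholds <- seq(0.1, 0.9, by = 0.1)
sweep <- do.call(rbind, lapply(thresholds, function(t) threshold_metrics(truth, prob, t)))
print(sweep)

# Threshold with the highest J-index (the analogue of max_j_index above)
best_threshold <- sweep$.threshold[which.max(sweep$j_index)]
best_threshold
```

On these toy values the J-index peaks at a threshold of 0.4 (sensitivity 1, specificity 0.8).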

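To extend this analysis, we can examine how the threshold affects any available classification metric, as below. I couldn't get yardstick::metric_set() and probably::threshold_perf() to work together while including pr_auc and roc_auc, so I used a few purrr tricks instead. Note that roc_auc and pr_auc here are computed on the thresholded 0/1 predictions, so they summarise a single operating point rather than the model's full curve.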

#Threshold Analysis by Several Classification Metrics
#(a tibble of thresholds replaces purrr::cross_df(), which is deprecated in purrr 1.0.0)
tibble(threshold = seq(0.03, 0.99, by = 0.01)) %>% 
  mutate(pred_data = map(threshold, ~mutate(xgb_pred,
                         #Hard class at this threshold; fixed levels avoid errors if only one class is predicted
                         .prob_class = factor(if_else(.pred_1 < .x, 0, 1), levels = levels(Exited)),
                         #Numeric 0/1 version passed as the 'probability' to roc_auc and pr_auc
                         .prob_metric = if_else(.pred_1 < .x, 0, 1))),
         pred_metric = map(pred_data, ~class_metric(.x, truth = Exited, estimate = .prob_class)),
         roc_auc = map(pred_data, ~roc_auc(.x, truth = Exited, .prob_metric)),
         pr_auc = map(pred_data, ~pr_auc(.x, truth = Exited, .prob_metric)),
         pred_metric = pmap(list(pred_metric, roc_auc, pr_auc), ~bind_rows(..1, ..2, ..3))) %>%
  select(pred_metric, threshold) %>%                                                            
  unnest(pred_metric) %>%                                                                        
  ggplot(aes(x=threshold, y=.estimate, color = .metric)) +
    geom_line(size = 1) +
    scale_color_viridis_d() +
    theme_minimal() +
    theme(axis.text.x = element_text(angle = 45)) +
    facet_wrap(~.metric, nrow = 2) +
    labs(title = 'Impact of Decision Threshold on Classification Metrics', x= 'Threshold', y = 'Estimate', color = 'Metric')
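One way to see why the roc_auc panel behaves as it does: when the "probability" passed in is a hard 0/1 prediction, the ROC curve has a single interior point, so its area reduces to (sensitivity + specificity) / 2, the balanced accuracy at that threshold. The base-R sketch below checks this with a hand-rolled trapezoidal AUC on hypothetical toy data; roc_auc_trapezoid() is an illustrative helper, not yardstick's implementation.

```r
# Hypothetical toy data (not the churn dataset): 1 = churned, 0 = retained
truth <- c(1, 1, 1, 0, 0, 0, 0, 0)
prob  <- c(0.92, 0.63, 0.45, 0.58, 0.34, 0.27, 0.12, 0.05)

# Illustrative trapezoidal ROC AUC; higher score = more likely to churn
roc_auc_trapezoid <- function(truth, score) {
  cuts <- sort(unique(score), decreasing = TRUE)
  # TPR and FPR when predicting churn for score >= each cut, anchored at (0, 0) and (1, 1)
  tpr <- c(0, sapply(cuts, function(t) mean(score[truth == 1] >= t)), 1)
  fpr <- c(0, sapply(cuts, function(t) mean(score[truth == 0] >= t)), 1)
  sum(diff(fpr) * (head(tpr, -1) + tail(tpr, -1)) / 2)
}

# AUC using the full probabilities
auc_prob <- roc_auc_trapezoid(truth, prob)

# AUC using hard 0/1 predictions at a threshold of 0.5
hard <- as.integer(prob >= 0.5)
sens <- mean(hard[truth == 1] == 1)
spec <- mean(hard[truth == 0] == 0)
auc_hard <- roc_auc_trapezoid(truth, hard)

c(auc_prob = auc_prob, auc_hard = auc_hard, balanced_accuracy = (sens + spec) / 2)
```

On these toy values the probability-based AUC is 14/15 (about 0.93), while the hard-prediction AUC is 11/15 (about 0.73), identical to the balanced accuracy at 0.5.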

That aside, everything we need comes from the output of probably::threshold_perf(). Sensitivity and specificity give us the false negative rate (FNR = 1 - sensitivity) and false positive rate (FPR = 1 - specificity) at each threshold, and hence a cost function.

Cost Function

Here we enter a hypothetical situation not captured within the original dataset. To calculate the total cost of false negatives (FN) and false positives (FP), we need an approximation of Customer Lifetime Value (CLV) and a cost of intervention. Let's suppose the cost of intervention is $99, the annual fee for a standard account: if we suspect a customer is churning, customer service will waive one year's annual fee.

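For annualised CLV, we take the sum of account fees and credit card fees, assuming that each product carries a $99 annual fee except credit cards, which carry a $149 fee. As below, we calculate an approximate CLV per customer and take the median as our CLV estimate, which happens to be $149. A false negative (a churner we fail to flag) therefore costs us $149 in lost fees, while a false positive (an intervention offered to a customer who would have stayed) costs $99.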

N.B. I know this is a very basic view of what constitutes CLV; a real-life scenario would no doubt also account for the interest a customer pays on credit or loans, among other sources of bank revenue.

train %>% 
  mutate(CreditCardFees = HasCrCard*149,
         AccountFees = (NumOfProducts - HasCrCard)*99,
         CLV = CreditCardFees + AccountFees) %>% 
  ggplot(aes(CLV)) +
   geom_histogram() +
   theme_minimal() +
   labs(title = 'Distribution of Annual CLV', x='CLV', y = 'Count')
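The plot shows the distribution but not the median itself. As a minimal base-R sketch of the same per-customer calculation plus the median step, assuming the $99 and $149 fees above (the five customers are hypothetical toy rows, not the Kaggle data):

```r
# Hypothetical toy customers (not the Kaggle data)
customers <- data.frame(
  NumOfProducts = c(1, 1, 2, 2, 3),
  HasCrCard     = c(1, 0, 1, 0, 1)
)
credit_card_fee <- 149  # assumed annual credit card fee ($)
account_fee     <- 99   # assumed annual fee for every other product ($)

# Per-customer annual fees, mirroring the mutate() step above
customers$CreditCardFees <- customers$HasCrCard * credit_card_fee
customers$AccountFees    <- (customers$NumOfProducts - customers$HasCrCard) * account_fee
customers$CLV            <- customers$CreditCardFees + customers$AccountFees
print(customers)

# Median annual CLV across customers
clv_median <- median(customers$CLV)
clv_median
```

In the tidyverse pipeline, the equivalent final step would be a summarise(median(CLV)) after the mutate().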

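Applying this logic to our threshold_data tibble, we can visualise the cost functions. In the code below, 510 and 1991 correspond to the numbers of churned and retained customers in the test set, so (1 - sens) × 510 is the number of false negatives and (1 - spec) × 1991 the number of false positives at each threshold.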
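Written out, the total cost at threshold t computed in the code below is:

```latex
\text{Cost}(t) = \underbrace{\bigl(1 - \text{sens}(t)\bigr) \times 510 \times \$149}_{\text{Cost\_FN}}
 + \underbrace{\bigl(1 - \text{spec}(t)\bigr) \times 1991 \times \$99}_{\text{Cost\_FP}}
```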

threshold_data %>% 
  filter(.metric %in% c('sens', 'spec')) %>% 
  pivot_wider(id_cols = .threshold, values_from = .estimate, names_from = .metric) %>% 
  mutate(Cost_FN = ((1-sens) * 510 * 149), 
         Cost_FP = ((1-spec) * 1991 * 99),
         Total_Cost = Cost_FN + Cost_FP) %>% 
 select(.threshold, Cost_FN, Cost_FP, Total_Cost) %>% 
 pivot_longer(2:4, names_to = 'Cost_Function', values_to = 'Cost') %>% 
  ggplot(aes(x = .threshold, y = Cost, color = Cost_Function)) +
    geom_line(size = 1.5) +
    theme_minimal() +
    scale_colour_viridis_d(end = 0.8) +
    labs(title = 'Threshold Cost Function', x = 'Threshold')

Scenario Analysis – Minimising Cost or Maximising Differentiation

With the cost functions established, we can identify the decision threshold that minimises total cost. As noted in the introduction, this gives us two scenarios: the threshold that maximises the J-index (0.47, identified above) or the threshold that minimises cost. Both are shown below.

threshold_data %>% 
  filter(.metric %in% c('sens', 'spec')) %>% 
  pivot_wider(id_cols = .threshold, values_from = .estimate, names_from = .metric) %>% 
  mutate(Cost = ((1-sens) * 510 * 149) + ((1-spec) * 1991 * 99),
         j_index = (sens+spec)-1) %>% 
  ggplot(aes(y=Cost, x = .threshold)) +
    geom_line() +
    geom_point(aes(size = j_index, color = j_index)) +
    geom_vline(xintercept = 0.47, lty = 2) +
    annotate(x = 0.36, y=100000, geom = 'text', label = 'Best Class Differentiation\nJ-Index = 0.56,\nCost = $57,629,\nThreshold = 0.47') +
    geom_vline(xintercept = 0.69, lty = 2) +
    annotate(x = 0.81, y = 100000, geom = 'text', label = 'Lowest Cost Model\nJ-Index = 0.48,\nCost = $48,329,\nThreshold = 0.69') +    
    theme_minimal() +
    scale_colour_viridis_c() +
    labs(title = 'Decision Threshold Attrition Cost Function', 
         subtitle = 'Where Cost(FN) = $149 & Cost(FP) = $99',
         x = 'Classification Threshold', size = 'J-Index', color = 'J-Index')
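Rather than reading the two thresholds off the plot, they can be selected programmatically. The base-R sketch below does this on a small hypothetical sens/spec table (illustrative values, not the article's threshold_data), reusing the cost assumptions above: each row's cost is FNR × 510 × $149 + FPR × 1991 × $99, and the two scenarios are the which.max() of the J-index and the which.min() of cost.

```r
# Hypothetical sens/spec values for a few thresholds (illustrative only)
perf <- data.frame(
  .threshold = c(0.3, 0.4, 0.5, 0.6, 0.7),
  sens       = c(0.90, 0.85, 0.80, 0.70, 0.60),
  spec       = c(0.60, 0.70, 0.76, 0.85, 0.90)
)

n_churn  <- 510   # churned customers in the test set (positives)
n_retain <- 1991  # retained customers in the test set (negatives)
cost_fn  <- 149   # cost of a missed churner: annual CLV ($)
cost_fp  <- 99    # cost of an unnecessary intervention: one year's fee ($)

# Expected cost = FN count x cost_fn + FP count x cost_fp
perf$Cost    <- (1 - perf$sens) * n_churn * cost_fn + (1 - perf$spec) * n_retain * cost_fp
perf$j_index <- perf$sens + perf$spec - 1

# Scenario 1: best class differentiation; Scenario 2: lowest cost
best_j    <- perf$.threshold[which.max(perf$j_index)]
best_cost <- perf$.threshold[which.min(perf$Cost)]
print(perf)
c(best_j = best_j, best_cost = best_cost)
```

With these toy values the J-index peaks at 0.5 while cost is lowest at 0.7, the same qualitative pattern as the 0.47 versus 0.69 annotations in the plot.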

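Given our cost assumptions, the lowest-cost threshold is 0.69. Unsurprisingly, raising the threshold increases specificity (TNR) at the expense of sensitivity (TPR): because retained customers outnumber churners roughly 4:1, the many cheaper false positives can outweigh the fewer, costlier false negatives. The confusion matrices below compare the default threshold (0.50) with the maximum J-index (0.47) and lowest-cost (0.69) thresholds.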

#Helper: confusion matrix heatmap for a churn threshold applied to .pred_1.
#make_two_class_pred() treats its estimate as the probability of the first level ('0'),
#so predicting churn when .pred_1 exceeds t means using a threshold of 1 - t on .pred_0.
plot_conf_mat <- function(churn_threshold, title) {
  xgb_pred %>% 
    mutate(.pred = make_two_class_pred(.pred_0, levels(Exited), threshold = 1 - churn_threshold)) %>%
    conf_mat(truth = Exited, estimate = .pred) %>% 
    autoplot(type = 'heatmap') + 
    scale_fill_gradient2() +
    labs(title = title)
}
t1 <- plot_conf_mat(0.50, 'Default Decision Threshold = 0.50')
t2 <- plot_conf_mat(0.47, 'With Adjusted Decision Threshold = 0.47')
t3 <- plot_conf_mat(0.69, 'With Adjusted Decision Threshold = 0.69')
t2 / t1 / t3 +
  plot_annotation(title = 'Confusion Matrices for UPSAMPLE_Boosted_Trees')
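The shift between the matrices can be reproduced on a toy example. This base-R sketch (hypothetical data; conf_matrix() is an illustrative helper, not yardstick::conf_mat()) tabulates predictions at a lower and a higher churn threshold, showing how raising the threshold removes a false positive but adds a false negative.

```r
# Hypothetical toy data (not the churn dataset): 1 = churned, 0 = retained
truth <- factor(c(1, 1, 1, 0, 0, 0, 0, 0), levels = c(0, 1))
prob  <- c(0.92, 0.63, 0.45, 0.58, 0.34, 0.27, 0.12, 0.05)

# Illustrative helper: confusion matrix (rows = prediction, columns = truth)
conf_matrix <- function(threshold) {
  pred <- factor(as.integer(prob >= threshold), levels = c(0, 1))
  table(Prediction = pred, Truth = truth)
}

low  <- conf_matrix(0.40)
high <- conf_matrix(0.60)
print(low)
print(high)

# False positives: predicted churn, actually retained
c(fp_low = low["1", "0"], fp_high = high["1", "0"])
# False negatives: predicted retained, actually churned
c(fn_low = low["0", "1"], fn_high = high["0", "1"])
```

At 0.40 the toy model produces one false positive and no false negatives; at 0.60 it produces no false positives and one false negative, mirroring the move from the 0.47 to the 0.69 threshold.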

Concluding Remarks

We've completed a decision threshold analysis using the probably package and constructed a hypothetical scenario analysis. Given our assumptions, the lowest-cost threshold reduces the model's ability to differentiate classes (the J-index falls from 0.56 to 0.48). This is the trade-off the business needs to consider: productionalise the threshold that best differentiates churners from retained customers, or go with the lower-cost one, accepting that more churning customers will be missed (more false negatives) in exchange for fewer unnecessary interventions (fewer false positives).


Thank you for reading this article, and I hope you enjoyed it. I write these to teach myself something, and I hope you’ve learnt something too. If you’re not a Medium member – use my referral link below and get regular updates on new publications from myself and other fantastic Medium authors.

Join Medium with my referral link – Murray Gillin

