class: title-slide, center, bottom # Tune better models ## Tidymodels, virtually — Session 03 ### Alison Hill --- class: middle, frame, center # Decision Trees To predict the outcome of a new data point: Uses rules learned from splits Each split maximizes information gain --- background-image: url(images/aus-standard-animals.png) background-size: cover .footnote[[Australian Computing Academy](https://aca.edu.au/resources/decision-trees-classifying-animals/)] --- background-image: url(images/aus-standard-tree.png) background-size: cover .footnote[[Australian Computing Academy](https://aca.edu.au/resources/decision-trees-classifying-animals/)] --- background-image: url(images/annotated-tree/annotated-tree.001.png) background-size: cover --- background-image: url(images/annotated-tree/annotated-tree.002.png) background-size: cover --- background-image: url(images/annotated-tree/annotated-tree.003.png) background-size: cover --- background-image: url(images/annotated-tree/annotated-tree.004.png) background-size: cover --- background-image: url(images/annotated-tree/annotated-tree.005.png) background-size: cover --- class: middle, frame # .center[To specify a model with parsnip] .right-column[ 1\. Pick a .display[model] 2\. Set the .display[engine] 3\. 
Set the .display[mode] (if needed)
]

---
class: middle, frame

# .center[To specify a decision tree model with parsnip]

```r
tree_mod <-
  decision_tree() %>%
  set_engine(engine = "rpart") %>%
  set_mode("classification")
```

---
class: middle, center

<img src="figs/rmed03-tune/alz-tree-01-1.png" width="40%" style="display: block; margin: auto;" />

```
#  nn Class    Imp Con                                            cover
#   4 Impaired [.82 .18] when tau >= 5.9        & VEGF <  17        19%
#  10 Impaired [.75 .25] when tau >= 6.7        & VEGF >= 17         4%
#  11 Control  [.16 .84] when tau is 5.9 to 6.7 & VEGF >= 17        19%
#   3 Control  [.10 .90] when tau <  5.9                            58%
```

---

.pull-left[
<img src="figs/rmed03-tune/unnamed-chunk-2-1.png" width="504" style="display: block; margin: auto;" />
]

.pull-right[
<img src="figs/rmed03-tune/unnamed-chunk-3-1.png" width="504" style="display: block; margin: auto;" />
]

---
class: your-turn

# Your turn 1

Here is our very-vanilla parsnip model specification for a decision tree (also in your Rmd)...

```r
tree_mod <-
  decision_tree() %>%
  set_engine(engine = "rpart") %>%
  set_mode("classification")
```

And a workflow:

```r
tree_wf <-
  workflow() %>%
  add_formula(Class ~ .) %>%
  add_model(tree_mod)
```

For decision trees, no recipe really required

---
class: your-turn

# Your turn 1

Fill in the blanks to return the accuracy and ROC AUC for this model using 10-fold cross-validation.
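---
class: middle

A minimal base-R sketch of how per-fold estimates are summarized into the `mean`, `n`, and `std_err` columns that `collect_metrics()` reports. The fold values are hypothetical, and the standard error is assumed here to be the standard deviation across folds divided by `sqrt(n)`.

```r
# Hypothetical per-fold accuracies, one value per fold (10 folds).
fold_accuracy <- c(0.70, 0.80, 0.75, 0.85, 0.70, 0.80, 0.75, 0.70, 0.80, 0.75)

n_folds    <- length(fold_accuracy)
cv_mean    <- mean(fold_accuracy)
# Assumption: std_err = standard deviation across folds / sqrt(number of folds).
cv_std_err <- sd(fold_accuracy) / sqrt(n_folds)

# One summary row per metric, similar in shape to the collect_metrics() output.
data.frame(.metric = "accuracy", mean = cv_mean, n = n_folds, std_err = cv_std_err)
```

A smaller `std_err` means the folds agree more closely with each other.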
---

```r
set.seed(100)
tree_wf %>%
  fit_resamples(resamples = alz_folds) %>%
  collect_metrics()
# # A tibble: 2 x 5
#   .metric  .estimator  mean     n std_err
#   <chr>    <chr>      <dbl> <int>   <dbl>
# 1 accuracy binary     0.756    10  0.0245
# 2 roc_auc  binary     0.770    10  0.0255
```

---
class: middle, center

# `args()`

Print the arguments for a **parsnip** model specification.

```r
args(decision_tree)
```

---
class: middle, center

# `decision_tree()`

Specifies a decision tree model

```r
decision_tree(tree_depth = 30, min_n = 20, cost_complexity = .01)
```

--

*either* mode works!

---
class: middle

.center[
# `decision_tree()`

Specifies a decision tree model
]

```r
decision_tree(
  tree_depth = 30,       # max tree depth
  min_n = 20,            # smallest node allowed
  cost_complexity = .01  # typically 0 < cp < 0.1
)
```

---
class: middle, center

# `set_args()`

Change the arguments for a **parsnip** model specification.

```r
tree_mod %>% set_args(tree_depth = 3)
```

---
class: middle

```r
decision_tree() %>%
  set_engine("rpart") %>%
  set_mode("classification") %>%
* set_args(tree_depth = 3)
# Decision Tree Model Specification (classification)
#
# Main Arguments:
#   tree_depth = 3
#
# Computational engine: rpart
```

---
class: middle

```r
*decision_tree(tree_depth = 3) %>%
  set_engine("rpart") %>%
  set_mode("classification")
# Decision Tree Model Specification (classification)
#
# Main Arguments:
#   tree_depth = 3
#
# Computational engine: rpart
```

---
class: middle, center

# `tree_depth`

Cap the maximum tree depth.

A method to stop the tree early. Used to prevent overfitting.

```r
tree_mod %>% set_args(tree_depth = 30)
```

---
class: middle, center
exclude: true

---
class: middle, center

<img src="figs/rmed03-tune/unnamed-chunk-15-1.png" width="864" style="display: block; margin: auto;" />

---
class: middle, center

<img src="figs/rmed03-tune/unnamed-chunk-16-1.png" width="864" style="display: block; margin: auto;" />

---
class: middle, center

# `min_n`

Set minimum `n` to split at any node.

Another early stopping method.
Used to prevent overfitting.

```r
tree_mod %>% set_args(min_n = 20)
```

---
class: middle, center

# Quiz

What value of `min_n` would lead to the *most overfit* tree?

--

`min_n` = 1

---
class: middle, center, frame

# Recap: early stopping

| `parsnip` arg | `rpart` arg | default | overfit? |
|---------------|-------------|:-------:|:--------:|
| `tree_depth`  | `maxdepth`  |   30    |    ⬆️    |
| `min_n`       | `minsplit`  |   20    |    ⬇️    |

---
class: middle, center

# `cost_complexity`

Adds a cost or penalty to error rates of more complex trees.

A way to prune a tree. Used to prevent overfitting.

```r
tree_mod %>% set_args(cost_complexity = .01)
```

--

Closer to zero ➡️ larger trees.

Higher penalty ➡️ smaller trees.

---
class: middle, center

<img src="figs/rmed03-tune/unnamed-chunk-19-1.png" width="720" style="display: block; margin: auto;" />

---
class: middle, center

<img src="figs/rmed03-tune/unnamed-chunk-20-1.png" width="864" style="display: block; margin: auto;" />

---
class: middle, center

<img src="figs/rmed03-tune/unnamed-chunk-21-1.png" width="864" style="display: block; margin: auto;" />

---
name: bonsai
background-image: url(images/kari-shea-AVqh83jStMA-unsplash.jpg)
background-position: left
background-size: contain
class: middle

---
template: bonsai

.pull-right[
# Consider the bonsai

1. Small pot

1. Strong shears
]

---
template: bonsai

.pull-right[
# Consider the bonsai

1. ~~Small pot~~ .display[Early stopping]

1. ~~Strong shears~~ .display[Pruning]
]

---
class: middle, center, frame

# Recap: early stopping & pruning

| `parsnip` arg     | `rpart` arg | default | overfit? |
|-------------------|-------------|:-------:|:--------:|
| `tree_depth`      | `maxdepth`  |   30    |    ⬆️    |
| `min_n`           | `minsplit`  |   20    |    ⬇️    |
| `cost_complexity` | `cp`        |   .01   |    ⬇️    |

---
class: middle, center

<table>
 <thead>
  <tr>
   <th style="text-align:left;"> engine </th>
   <th style="text-align:left;"> parsnip </th>
   <th style="text-align:left;"> original </th>
  </tr>
 </thead>
<tbody>
  <tr>
   <td style="text-align:left;"> rpart </td>
   <td style="text-align:left;"> tree_depth </td>
   <td style="text-align:left;"> maxdepth </td>
  </tr>
  <tr>
   <td style="text-align:left;"> rpart </td>
   <td style="text-align:left;"> min_n </td>
   <td style="text-align:left;"> minsplit </td>
  </tr>
  <tr>
   <td style="text-align:left;"> rpart </td>
   <td style="text-align:left;"> cost_complexity </td>
   <td style="text-align:left;"> cp </td>
  </tr>
</tbody>
</table>

<https://rdrr.io/cran/rpart/man/rpart.control.html>

---
class: middle, center

<img src="figs/rmed03-tune/unnamed-chunk-23-1.png" width="504" style="display: block; margin: auto;" />

---

<img src="figs/rmed03-tune/unnamed-chunk-24-1.png" width="504" style="display: block; margin: auto;" />

---

<img src="figs/rmed03-tune/unnamed-chunk-25-1.png" width="504" style="display: block; margin: auto;" />

---

<img src="figs/rmed03-tune/unnamed-chunk-26-1.png" width="504" style="display: block; margin: auto;" />

---
class: middle, center

<img src="figs/rmed03-tune/unnamed-chunk-27-1.png" width="504" style="display: block; margin: auto;" />

---
class: middle, center

<img src="figs/rmed03-tune/unnamed-chunk-28-1.png" width="504" style="display: block; margin: auto;" />

---

<img src="figs/rmed02-workflows/big-alz-tree-1.png" width="672" style="display: block; margin: auto;" />

---
class: middle, frame, center

# Axiom

There is an inverse relationship between model *accuracy* and model *interpretability*.

---
class: middle, center

# `rand_forest()`

Specifies a random forest model

```r
rand_forest(mtry = 4, trees = 500, min_n = 1)
```

--

*either* mode works!
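---
class: middle

A short sketch of "either mode works": the same `rand_forest()` arguments can be paired with `set_mode("classification")` or `set_mode("regression")`. The object names `rf_class` and `rf_reg` are made up for illustration, and the **ranger** engine is assumed, as elsewhere in this deck.

```r
library(parsnip)

# Same arguments, two modes. Nothing is fit here; these are only specifications.
rf_class <- rand_forest(mtry = 4, trees = 500, min_n = 1) %>%
  set_engine("ranger") %>%
  set_mode("classification")

rf_reg <- rand_forest(mtry = 4, trees = 500, min_n = 1) %>%
  set_engine("ranger") %>%
  set_mode("regression")

# The mode is stored on the specification object.
rf_class$mode
rf_reg$mode
```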
--- class: middle .center[ # `rand_forest()` Specifies a random forest model ] ```r rand_forest( mtry = 4, # predictors seen at each node trees = 500, # trees per forest min_n = 1 # smallest node allowed ) ``` --- class: your-turn # Your turn 2 Create a new parsnip model called `rf_mod`, which will learn an ensemble of classification trees from our training data using the **ranger** package. Update your `tree_wf` with this new model. Fit your workflow with 10-fold cross-validation and compare the ROC AUC of the random forest to your single decision tree model --- which predicts the test set better? *Hint: you'll need https://www.tidymodels.org/find/parsnip/*
--- ```r rf_mod <- rand_forest() %>% set_engine("ranger") %>% set_mode("classification") rf_wf <- tree_wf %>% update_model(rf_mod) set.seed(100) rf_wf %>% fit_resamples(resamples = alz_folds) %>% collect_metrics() # # A tibble: 2 x 5 # .metric .estimator mean n std_err # <chr> <chr> <dbl> <int> <dbl> # 1 accuracy binary 0.827 10 0.0151 # 2 roc_auc binary 0.887 10 0.0157 ``` --- class: middle, center # `mtry` The number of predictors that will be randomly sampled at each split when creating the tree models. ```r rand_forest(mtry = 11) ``` **ranger** default = `floor(sqrt(num_predictors))` --- class: your-turn # Your turn 3 Challenge: Fit 3 more random forest models, each using 3, 8, and 30 variables at each split. Update your `rf_wf` with each new model. Which value maximizes the area under the ROC curve?
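---
class: middle

For context when choosing `mtry` values, here is a quick base-R check of the **ranger** default rule quoted earlier, `floor(sqrt(num_predictors))`. The predictor count of 130 is an assumption (131 columns in `ad_data`, one of which is `Class`).

```r
# Assumption: 130 predictors (every column of ad_data except Class).
num_predictors <- 130

# Default number of predictors sampled at each split.
default_mtry <- floor(sqrt(num_predictors))
default_mtry

# The challenge values, as a share of all predictors available at each split.
candidate_mtry <- c(3, 8, 30)
round(candidate_mtry / num_predictors, 2)
```

Under this assumption the default works out to 11, so the challenge values 3 and 8 sit below it and 30 sits above it.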
--- ```r rf3_mod <- rf_mod %>% * set_args(mtry = 3) rf8_mod <- rf_mod %>% * set_args(mtry = 8) rf30_mod <- rf_mod %>% * set_args(mtry = 30) ``` --- ```r rf3_wf <- rf_wf %>% update_model(rf3_mod) set.seed(100) rf3_wf %>% fit_resamples(resamples = alz_folds) %>% collect_metrics() # # A tibble: 2 x 5 # .metric .estimator mean n std_err # <chr> <chr> <dbl> <int> <dbl> # 1 accuracy binary 0.780 10 0.0115 # 2 roc_auc binary 0.862 10 0.0173 ``` --- ```r rf8_wf <- rf_wf %>% update_model(rf8_mod) set.seed(100) rf8_wf %>% fit_resamples(resamples = alz_folds) %>% collect_metrics() # # A tibble: 2 x 5 # .metric .estimator mean n std_err # <chr> <chr> <dbl> <int> <dbl> # 1 accuracy binary 0.813 10 0.0139 # 2 roc_auc binary 0.880 10 0.0180 ``` --- ```r rf30_wf <- rf_wf %>% update_model(rf30_mod) set.seed(100) rf30_wf %>% fit_resamples(resamples = alz_folds) %>% collect_metrics() # # A tibble: 2 x 5 # .metric .estimator mean n std_err # <chr> <chr> <dbl> <int> <dbl> # 1 accuracy binary 0.837 10 0.0181 # 2 roc_auc binary 0.897 10 0.0137 ``` --- class: middle, center, frame # tune Functions for fitting and tuning models <https://tune.tidymodels.org> <iframe src="https://tune.tidymodels.org" width="100%" height="400px"></iframe> --- class: middle, center # `tune()` A placeholder for hyper-parameters to be "tuned" ```r nearest_neighbor(neighbors = tune()) ``` --- .center[ # `tune_grid()` A version of `fit_resamples()` that performs a grid search for the best combination of tuned hyper-parameters. ] .pull-left[ ```r tune_grid( object, resamples, ..., grid = 10, metrics = NULL, control = control_grid() ) ``` ] --- .center[ # `tune_grid()` A version of `fit_resamples()` that performs a grid search for the best combination of tuned hyper-parameters. 
] .pull-left[ ```r tune_grid( * object, resamples, ..., grid = 10, metrics = NULL, control = control_grid() ) ``` ] -- .pull-right[ One of: + A parsnip `model` object + A `workflow` ] --- .center[ # `tune_grid()` A version of `fit_resamples()` that performs a grid search for the best combination of tuned hyper-parameters. ] .pull-left[ ```r tune_grid( * object, * preprocessor, resamples, ..., grid = 10, metrics = NULL, control = control_grid() ) ``` ] .pull-right[ A `model` + `recipe` ] --- .center[ # `tune_grid()` A version of `fit_resamples()` that performs a grid search for the best combination of tuned hyper-parameters. ] .pull-left[ ```r tune_grid( object, resamples, ..., * grid = 10, metrics = NULL, control = control_grid() ) ``` ] .pull-right[ One of: + A positive integer. + A data frame of tuning combinations. ] --- .center[ # `tune_grid()` A version of `fit_resamples()` that performs a grid search for the best combination of tuned hyper-parameters. ] .pull-left[ ```r tune_grid( object, resamples, ..., * grid = 10, metrics = NULL, control = control_grid() ) ``` ] .pull-right[ Number of candidate parameter sets to be created automatically; `10` is the default. ] --- ```r data("ad_data") alz <- ad_data # data splitting set.seed(100) # Important! alz_split <- initial_split(alz, strata = Class, prop = .9) alz_train <- training(alz_split) alz_test <- testing(alz_split) # data resampling set.seed(100) alz_folds <- vfold_cv(alz_train, v = 10, strata = Class) ``` --- class: your-turn # Your Turn 4 Here's our random forest model plus workflow to work with. ```r rf_mod <- rand_forest() %>% set_engine("ranger") %>% set_mode("classification") rf_wf <- workflow() %>% add_formula(Class ~ .) %>% add_model(rf_mod) ``` --- class: your-turn # Your Turn 4 Here is the output from `fit_resamples()`... ```r set.seed(100) # Important! 
rf_results <- rf_wf %>% fit_resamples(resamples = alz_folds, metrics = metric_set(roc_auc)) rf_results %>% collect_metrics() # # A tibble: 1 x 5 # .metric .estimator mean n std_err # <chr> <chr> <dbl> <int> <dbl> # 1 roc_auc binary 0.884 10 0.0157 ``` --- class: your-turn # Your Turn 4 Edit the random forest model to tune the `mtry` and `min_n` hyperparameters. Update your workflow to use the tuned model. Then use `tune_grid()` to find the best combination of hyper-parameters to maximize `roc_auc`; let tune set up the grid for you. How does it compare to the average ROC AUC across folds from `fit_resamples()`?
---

```r
rf_tuner <-
  rand_forest(mtry = tune(),
              min_n = tune()) %>%
  set_engine("ranger") %>%
  set_mode("classification")

rf_wf <-
  rf_wf %>%
  update_model(rf_tuner)

set.seed(100) # Important!
rf_results <-
  rf_wf %>%
  tune_grid(resamples = alz_folds,
            metrics = metric_set(roc_auc))
```

---

```r
rf_results %>%
  collect_metrics()
# # A tibble: 10 x 8
#     mtry min_n .metric .estimator  mean     n std_err
#    <int> <int> <chr>   <chr>      <dbl> <int>   <dbl>
#  1    53    28 roc_auc binary     0.904    10  0.0111
#  2    37    36 roc_auc binary     0.902    10  0.0134
#  3    74    21 roc_auc binary     0.907    10  0.0108
#  4    10    13 roc_auc binary     0.885    10  0.0171
#  5    80     8 roc_auc binary     0.905    10  0.0101
#  6   100    31 roc_auc binary     0.906    10  0.0118
#  7    16    15 roc_auc binary     0.892    10  0.0162
#  8   113     5 roc_auc binary     0.909    10  0.0111
#  9   127    38 roc_auc binary     0.898    10  0.0121
# 10    43    20 roc_auc binary     0.901    10  0.0114
# # … with 1 more variable: .config <fct>
```

---

```r
rf_results %>%
  collect_metrics(summarize = FALSE)
# # A tibble: 100 x 7
#    id      mtry min_n .metric .estimator .estimate .config
#    <chr>  <int> <int> <chr>   <chr>          <dbl> <fct>
#  1 Fold01    53    28 roc_auc binary         0.894 Model01
#  2 Fold01    37    36 roc_auc binary         0.859 Model02
#  3 Fold01    74    21 roc_auc binary         0.914 Model03
#  4 Fold01    10    13 roc_auc binary         0.869 Model04
#  5 Fold01    80     8 roc_auc binary         0.914 Model05
#  6 Fold01   100    31 roc_auc binary         0.919 Model06
#  7 Fold01    16    15 roc_auc binary         0.848 Model07
#  8 Fold01   113     5 roc_auc binary         0.914 Model08
#  9 Fold01   127    38 roc_auc binary         0.909 Model09
# 10 Fold01    43    20 roc_auc binary         0.894 Model10
# # … with 90 more rows
```

---

.center[
# `tune_grid()`

A version of `fit_resamples()` that performs a grid search for the best combination of tuned hyper-parameters.
]

.pull-left[
```r
tune_grid(
  object,
  resamples,
  ...,
* grid = df,
  metrics = NULL,
  control = control_grid()
)
```
]

.pull-right[
A data frame of tuning combinations.
]

---
class: middle, center

# `expand_grid()`

Takes one or more vectors, and returns a data frame holding all combinations of their values.

```r
expand_grid(mtry = c(1, 5), min_n = 1:3)
# # A tibble: 6 x 2
#    mtry min_n
#   <dbl> <int>
# 1     1     1
# 2     1     2
# 3     1     3
# 4     5     1
# 5     5     2
# 6     5     3
```

--

.footnote[tidyr package; see also base `expand.grid()`]

---
class: middle
name: show-best

.center[
# `show_best()`

Shows the .display[n] best combinations of hyper-parameters
]

```r
rf_results %>%
  show_best(metric = "roc_auc", n = 5)
```

---
template: show-best

```
# # A tibble: 5 x 8
#    mtry min_n .metric .estimator  mean     n std_err .config
#   <int> <int> <chr>   <chr>      <dbl> <int>   <dbl> <fct>
# 1   113     5 roc_auc binary     0.909    10  0.0111 Model08
# 2    74    21 roc_auc binary     0.907    10  0.0108 Model03
# 3   100    31 roc_auc binary     0.906    10  0.0118 Model06
# 4    80     8 roc_auc binary     0.905    10  0.0101 Model05
# 5    53    28 roc_auc binary     0.904    10  0.0111 Model01
```

---
class: middle, center

# `autoplot()`

Quickly visualize tuning results

```r
rf_results %>% autoplot()
```

<img src="figs/rmed03-tune/rf-plot-1.png" width="504" style="display: block; margin: auto;" />

---
class: middle, center

<img src="figs/rmed03-tune/unnamed-chunk-56-1.png" width="504" style="display: block; margin: auto;" />

---
class: middle
name: select-best

.center[
# `select_best()`

Shows the .display[top] combination of hyper-parameters.
]

```r
alz_best <-
  rf_results %>%
  select_best(metric = "roc_auc")

alz_best
```

---
template: select-best

```
# # A tibble: 1 x 3
#    mtry min_n .config
#   <int> <int> <fct>
# 1   113     5 Model08
```

---
class: middle

.center[
# `finalize_workflow()`

Replaces `tune()` placeholders in a model/recipe/workflow with a set of hyper-parameter values.
]

```r
last_rf_workflow <-
  rf_wf %>%
  finalize_workflow(alz_best)
```

---
background-image: url(images/diamonds.jpg)
background-size: contain
background-position: left
class: middle, center
background-color: #f5f5f5

.pull-right[
## We are ready to touch the jewels...
## The .display[testing set]!
]

---
class: middle

.center[
# `last_fit()`
]

```r
last_rf_fit <-
  last_rf_workflow %>%
  last_fit(split = alz_split)
```

---

```r
last_rf_fit
# # Resampling results
# # Monte Carlo cross-validation (0.9/0.099) with 1 resamples
# # A tibble: 1 x 6
#   splits   id      .metrics   .notes  .predictions .workflow
#   <list>   <chr>   <list>     <list>  <list>       <list>
# 1 <split … train/… <tibble [… <tibbl… <tibble [33… <workflo…
```

---
class: your-turn

# Your Turn 5

Use `select_best()`, `finalize_workflow()`, and `last_fit()` to take the best combination of hyper-parameters from `rf_results` and use them to predict the test set.

How does our actual test ROC AUC compare to our cross-validated estimate?
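---
class: middle

Conceptually, the `select_best()` step picks the candidate with the highest mean `roc_auc`. Here is a base-R sketch using the five rows from the `show_best()` output shown earlier. It illustrates the idea only and is not how **tune** implements it; `best_five` and `top_row` are names made up for this sketch.

```r
# Values copied from the show_best() output (top 5 candidates by mean roc_auc).
best_five <- data.frame(
  mtry    = c(113, 74, 100, 80, 53),
  min_n   = c(5, 21, 31, 8, 28),
  mean    = c(0.909, 0.907, 0.906, 0.905, 0.904),
  .config = c("Model08", "Model03", "Model06", "Model05", "Model01")
)

# Keep only the row with the largest mean, and only the tuning columns.
top_row <- best_five[which.max(best_five$mean), c("mtry", "min_n", ".config")]
top_row
```

The selected row matches the `select_best()` result shown earlier: `mtry = 113`, `min_n = 5`.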
---

```r
alz_best <-
  rf_results %>%
  select_best(metric = "roc_auc")

last_rf_workflow <-
  rf_wf %>%
  finalize_workflow(alz_best)

last_rf_fit <-
  last_rf_workflow %>%
  last_fit(split = alz_split)

last_rf_fit %>%
  collect_metrics()
```

---
class: middle, frame

.center[
# Final metrics
]

```r
last_rf_fit %>%
  collect_metrics()
# # A tibble: 2 x 3
#   .metric  .estimator .estimate
#   <chr>    <chr>          <dbl>
# 1 accuracy binary         0.818
# 2 roc_auc  binary         0.894
```

---
class: middle

.center[
# Final test predictions
]

```r
last_rf_fit %>%
  collect_predictions()
# # A tibble: 33 x 6
#    id    .pred_Impaired .pred_Control  .row .pred_class
#    <chr>          <dbl>         <dbl> <int> <fct>
#  1 trai…         0.0908         0.909     4 Control
#  2 trai…         0.582          0.418     8 Impaired
#  3 trai…         0.802          0.198    10 Impaired
#  4 trai…         0.325          0.675    34 Control
#  5 trai…         0.285          0.715    48 Control
#  6 trai…         0.0654         0.935    71 Control
#  7 trai…         0.662          0.338    82 Impaired
#  8 trai…         0.0143         0.986    83 Control
#  9 trai…         0.530          0.470    91 Impaired
# 10 trai…         0.0295         0.970   108 Control
# # … with 23 more rows, and 1 more variable: Class <fct>
```

---

```r
roc_values <-
  last_rf_fit %>%
  collect_predictions() %>%
  roc_curve(truth = Class, estimate = .pred_Impaired)

autoplot(roc_values)
```

<img src="figs/rmed03-tune/unnamed-chunk-65-1.png" width="50%" style="display: block; margin: auto;" />
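---
class: middle

To make the area under the ROC curve concrete, here is a base-R sketch that computes it as the share of (Impaired, Control) pairs in which the Impaired case gets the higher predicted probability. The six predictions are hypothetical, not taken from the test set, and this pairwise formula is a swapped-in illustration, not the code **yardstick** runs.

```r
# Hypothetical true classes and predicted probabilities of "Impaired".
truth <- factor(
  c("Impaired", "Impaired", "Control", "Impaired", "Control", "Control"),
  levels = c("Impaired", "Control")
)
pred_impaired <- c(0.9, 0.8, 0.7, 0.6, 0.3, 0.2)

pos <- pred_impaired[truth == "Impaired"]
neg <- pred_impaired[truth == "Control"]

# Share of (Impaired, Control) pairs ranked correctly; ties count as half.
auc <- mean(outer(pos, neg, ">") + 0.5 * outer(pos, neg, "=="))
auc
```

Here 8 of the 9 pairs are ranked correctly, so the sketch gives an AUC of 8/9. A value of 0.5 would mean the ranking is no better than chance.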