import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd
from sklearn import datasets
from sklearn import cluster
from sklearn import svm, datasets
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import train_test_split
from sklearn import tree
Train, test and Validation data¶
We will work with the iris data again.
iris_df = sns.load_dataset('iris')
iris_X = iris_df.drop(columns=['species'])
iris_y = iris_df['species']
We will still use the train-test split to keep our test data separate from the data that we use to find our preferred parameters.
iris_X_train, iris_X_test, iris_y_train, iris_y_test = train_test_split(iris_X,iris_y, random_state=0)
We will be doing cross validation later, but we still use train_test_split at the start so that we have true test data.
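As a quick check of what the split produces, here is a minimal sketch. It assumes `load_iris` from scikit-learn as an offline stand-in for the seaborn copy of the data, and the variable names are illustrative only.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Assumption: load_iris stands in for sns.load_dataset('iris') so this runs offline
X, y = load_iris(return_X_y=True, as_frame=True)

# With no test_size given, train_test_split holds out 25% of the rows
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

print(X_train.shape, X_test.shape)
```

The held-out rows are set aside here and are not touched again until the final scoring step.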
Think Ahead¶
What would you need to be able to find the best parameter settings?
What would be the inputs to an algorithm to optimize a model?
What might the steps include?
Setting up model optimization¶
Today we will optimize a decision tree over three parameters.
One is the criterion, which is how the tree decides where to place thresholds on the features. Gini is the default; it measures how concentrated each class is at a node. The alternative is entropy, which measures how random something is. Intuitively these do similar things, which makes sense because they are two ways to make the same choice, but they use slightly different calculations.
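To make the two calculations concrete, the following sketch computes both quantities from a list of class proportions at a single node. The helper names `gini` and `entropy` are hypothetical and written for illustration; they are not scikit-learn functions.

```python
import numpy as np

def gini(class_probs):
    """Gini impurity: 1 minus the sum of squared class proportions."""
    p = np.asarray(class_probs, dtype=float)
    return float(1.0 - np.sum(p ** 2))

def entropy(class_probs):
    """Shannon entropy in bits, skipping classes with zero proportion."""
    p = np.asarray(class_probs, dtype=float)
    p = p[p > 0]
    return float(-np.sum(p * np.log2(p)))

# Both measures grow as the node becomes more mixed
for probs in ([0.9, 0.1], [0.7, 0.3], [0.5, 0.5]):
    print(probs, gini(probs), entropy(probs))
```

Both measures are smallest for a node dominated by one class and largest for an evenly mixed node, which is why they tend to pick similar splits.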
The other two parameters relate to the structure of the decision tree that is produced and their values are numbers.
max_depth is the height of the tree
min_samples_leaf is the minimum number of samples per leaf, so it keeps leaves from getting too small
dt = tree.DecisionTreeClassifier()
params_dt = {'criterion':['gini','entropy'],
             'max_depth':[2,3,4],
             'min_samples_leaf':list(range(2,20,2))}
Grid Search¶
We will first do an exhaustive optimization on that parameter grid, params_dt.
The dictionary is called a parameter grid because it will be used to create a “grid” of different values, by taking every possible combination.
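One way to see the "grid" directly is scikit-learn's `ParameterGrid`, which expands a dictionary like this one into every combination. This is a small sketch for inspection only; the search itself does not require building the grid by hand.

```python
from sklearn.model_selection import ParameterGrid

params_dt = {'criterion': ['gini', 'entropy'],
             'max_depth': [2, 3, 4],
             'min_samples_leaf': list(range(2, 20, 2))}

# Each element of the grid is one dictionary of parameter values
grid = ParameterGrid(params_dt)
print(len(grid))

# Look at the first few combinations
for combo in list(grid)[:3]:
    print(combo)
```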
The GridSearchCV object will then also do cross validation, with the same default values we saw for cross_val_score of 5 fold cross validation (Kfold with K=5).
Solution to Exercise 1 #
GridSearchCV will cross validate the model for every combination of parameter values from the parameter grid.
To compute the number of fits, we first get the length of each list of values:
num_param_values = {k:len(v) for k,v in params_dt.items()}
num_param_values
{'criterion': 2, 'max_depth': 3, 'min_samples_leaf': 9}
So we have 9 for min_samples_leaf because the range is inclusive of the start and exclusive of the stop, or in math notation it is the half-open interval [2, 20), taken in steps of 2.
Then we multiply to get the total number of combinations:
combos = np.prod([v for v in num_param_values.values()])
combos
np.int64(54)
We have a total of 54 combinations that will be tested, and since cv=5 it will fit each of those 5 times, so the total number of fitted models is 270.
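The counting above can be checked with a hand-written version of the search loop. This is only a rough sketch of the idea, not how `GridSearchCV` is implemented; it assumes `load_iris` as an offline stand-in for the data, and `n_fits`, `best_score` and `best_params` are illustrative names.

```python
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.model_selection import ParameterGrid, cross_val_score, train_test_split

# Assumption: load_iris stands in for the seaborn copy of the data
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

params_dt = {'criterion': ['gini', 'entropy'],
             'max_depth': [2, 3, 4],
             'min_samples_leaf': list(range(2, 20, 2))}

n_fits = 0
best_score, best_params = -1.0, None
for combo in ParameterGrid(params_dt):
    model = tree.DecisionTreeClassifier(random_state=0, **combo)
    # one score per fold, so one fit per fold
    fold_scores = cross_val_score(model, X_train, y_train, cv=5)
    n_fits += len(fold_scores)
    if fold_scores.mean() > best_score:
        best_score, best_params = fold_scores.mean(), combo

print(n_fits, best_params, best_score)
```

Only the training portion is passed to the loop, so the test data stays untouched.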
We will instantiate it with default CV settings.
dt_opt = GridSearchCV(dt,params_dt)
The GridSearchCV keeps the same basic interface of estimator objects; we run it with the fit method.
dt_opt.fit(iris_X_train,iris_y_train)
We can also get predictions from the model with the highest score out of all of the combinations:
y_pred = dt_opt.predict(iris_X_test)
We can also score it as normal.
test_score = dt_opt.score(iris_X_test,iris_y_test)
test_score
0.9473684210526315
This is our true test accuracy because this data (iris_X_test, iris_y_test) was not used at all for training or for optimizing the parameters.
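For comparison, the fitted search object also stores the mean cross validation score of the winning setting in `best_score_`. The sketch below prints it next to the held-out test score; it assumes `load_iris` as an offline stand-in for the data and a deliberately small, hypothetical grid (`small_grid`) to keep it quick.

```python
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, train_test_split

# Assumption: load_iris stands in for the seaborn copy of the data
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

small_grid = {'max_depth': [2, 3, 4]}  # hypothetical reduced grid
search = GridSearchCV(tree.DecisionTreeClassifier(random_state=0), small_grid)
search.fit(X_train, y_train)

cv_score = search.best_score_                  # mean validation accuracy of the best setting
held_out_score = search.score(X_test, y_test)  # accuracy on data the search never saw
print(cv_score, held_out_score)
```

The two numbers need not agree, since one is an average over validation folds and the other comes from a single held-out set.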
We can also see the best parameters.
dt_opt.best_params_
{'criterion': 'gini', 'max_depth': 4, 'min_samples_leaf': 2}
Grid Search Results¶
The optimizer saves a lot of details of its process in a dictionary
Long output
dt_opt.cv_results_
{'mean_fit_time': array([0.00556393, 0.01108193, 0.00836782, 0.01082835, 0.0096602 ,
0.01146016, 0.01239662, 0.01423249, 0.00842075, 0.01091595,
0.00727549, 0.00824037, 0.0103353 , 0.00863438, 0.01010695,
0.00721569, 0.00706072, 0.00727439, 0.00805449, 0.00800195,
0.00790009, 0.00913134, 0.00789819, 0.00690184, 0.00748725,
0.00939136, 0.00676408, 0.00967827, 0.00656519, 0.00700526,
0.01039691, 0.0067945 , 0.0084537 , 0.00980482, 0.0112596 ,
0.00657005, 0.01148438, 0.00993586, 0.00764174, 0.00959158,
0.01428456, 0.01063943, 0.01292725, 0.00584931, 0.01096969,
0.01036873, 0.00924635, 0.00686235, 0.00730443, 0.00719438,
0.00711536, 0.00826945, 0.00717797, 0.00845666]),
'std_fit_time': array([0.00104859, 0.00467209, 0.00488872, 0.00449627, 0.0025553 ,
0.00512982, 0.00574487, 0.00637621, 0.00526116, 0.00480431,
0.00343601, 0.00521062, 0.00418091, 0.00485772, 0.00437501,
0.00475051, 0.00263535, 0.00486059, 0.00372457, 0.00515958,
0.00416323, 0.00354444, 0.00505693, 0.0040882 , 0.0043604 ,
0.00209445, 0.00506572, 0.0038188 , 0.00389477, 0.0013015 ,
0.00143236, 0.00465869, 0.00143521, 0.00468834, 0.00058274,
0.00612317, 0.00477859, 0.00702165, 0.00561581, 0.00653371,
0.00760997, 0.00475086, 0.00589428, 0.00500827, 0.00434682,
0.00230181, 0.00358727, 0.00413464, 0.00442031, 0.00451731,
0.00323603, 0.00364638, 0.00418752, 0.00074529]),
'mean_score_time': array([0.0043541 , 0.01018205, 0.00706873, 0.00464911, 0.00691705,
0.0107286 , 0.00819516, 0.01061416, 0.00832911, 0.01017704,
0.00723572, 0.00968351, 0.00533085, 0.00961628, 0.00727997,
0.00476322, 0.00469165, 0.00870957, 0.01043324, 0.00805259,
0.00962353, 0.00835204, 0.00510597, 0.00705853, 0.00712686,
0.00754223, 0.00738592, 0.00577888, 0.00624752, 0.0057147 ,
0.0049799 , 0.00699844, 0.00663114, 0.00604944, 0.00806751,
0.01167359, 0.00686994, 0.00585818, 0.00935736, 0.00756249,
0.01157517, 0.00588923, 0.00992284, 0.00907855, 0.00782967,
0.09356718, 0.00698056, 0.00640044, 0.00694876, 0.00690894,
0.00702868, 0.00427184, 0.00446315, 0.00555096]),
'std_score_time': array([0.00355889, 0.00768632, 0.00460646, 0.00414295, 0.00453055,
0.00648163, 0.0063 , 0.00588742, 0.00564748, 0.00464921,
0.00495385, 0.00428572, 0.00433756, 0.00396552, 0.00458429,
0.00409567, 0.00263075, 0.00141644, 0.00292771, 0.00537506,
0.00333061, 0.00594733, 0.00434678, 0.00462183, 0.00453016,
0.00390081, 0.00302669, 0.00514065, 0.0041283 , 0.00230777,
0.00407663, 0.0044185 , 0.00484107, 0.00401512, 0.00505211,
0.00527431, 0.00704981, 0.00557497, 0.00665956, 0.00650267,
0.00531911, 0.00549329, 0.00477628, 0.00394251, 0.00512711,
0.17606516, 0.00448506, 0.00367257, 0.00446495, 0.00434137,
0.00288783, 0.00381217, 0.00247172, 0.00327762]),
'param_criterion': masked_array(data=['gini', 'gini', 'gini', 'gini', 'gini', 'gini', 'gini',
'gini', 'gini', 'gini', 'gini', 'gini', 'gini', 'gini',
'gini', 'gini', 'gini', 'gini', 'gini', 'gini', 'gini',
'gini', 'gini', 'gini', 'gini', 'gini', 'gini',
'entropy', 'entropy', 'entropy', 'entropy', 'entropy',
'entropy', 'entropy', 'entropy', 'entropy', 'entropy',
'entropy', 'entropy', 'entropy', 'entropy', 'entropy',
'entropy', 'entropy', 'entropy', 'entropy', 'entropy',
'entropy', 'entropy', 'entropy', 'entropy', 'entropy',
'entropy', 'entropy'],
mask=[False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False],
fill_value=np.str_('?'),
dtype=object),
'param_max_depth': masked_array(data=[2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3,
4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 2, 2, 2, 2, 2, 2, 2, 2,
3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4],
mask=[False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False],
fill_value=999999),
'param_min_samples_leaf': masked_array(data=[2, 4, 6, 8, 10, 12, 14, 16, 18, 2, 4, 6, 8, 10, 12, 14,
16, 18, 2, 4, 6, 8, 10, 12, 14, 16, 18, 2, 4, 6, 8, 10,
12, 14, 16, 18, 2, 4, 6, 8, 10, 12, 14, 16, 18, 2, 4,
6, 8, 10, 12, 14, 16, 18],
mask=[False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False],
fill_value=999999),
'params': [{'criterion': 'gini', 'max_depth': 2, 'min_samples_leaf': 2},
{'criterion': 'gini', 'max_depth': 2, 'min_samples_leaf': 4},
{'criterion': 'gini', 'max_depth': 2, 'min_samples_leaf': 6},
{'criterion': 'gini', 'max_depth': 2, 'min_samples_leaf': 8},
{'criterion': 'gini', 'max_depth': 2, 'min_samples_leaf': 10},
{'criterion': 'gini', 'max_depth': 2, 'min_samples_leaf': 12},
{'criterion': 'gini', 'max_depth': 2, 'min_samples_leaf': 14},
{'criterion': 'gini', 'max_depth': 2, 'min_samples_leaf': 16},
{'criterion': 'gini', 'max_depth': 2, 'min_samples_leaf': 18},
{'criterion': 'gini', 'max_depth': 3, 'min_samples_leaf': 2},
{'criterion': 'gini', 'max_depth': 3, 'min_samples_leaf': 4},
{'criterion': 'gini', 'max_depth': 3, 'min_samples_leaf': 6},
{'criterion': 'gini', 'max_depth': 3, 'min_samples_leaf': 8},
{'criterion': 'gini', 'max_depth': 3, 'min_samples_leaf': 10},
{'criterion': 'gini', 'max_depth': 3, 'min_samples_leaf': 12},
{'criterion': 'gini', 'max_depth': 3, 'min_samples_leaf': 14},
{'criterion': 'gini', 'max_depth': 3, 'min_samples_leaf': 16},
{'criterion': 'gini', 'max_depth': 3, 'min_samples_leaf': 18},
{'criterion': 'gini', 'max_depth': 4, 'min_samples_leaf': 2},
{'criterion': 'gini', 'max_depth': 4, 'min_samples_leaf': 4},
{'criterion': 'gini', 'max_depth': 4, 'min_samples_leaf': 6},
{'criterion': 'gini', 'max_depth': 4, 'min_samples_leaf': 8},
{'criterion': 'gini', 'max_depth': 4, 'min_samples_leaf': 10},
{'criterion': 'gini', 'max_depth': 4, 'min_samples_leaf': 12},
{'criterion': 'gini', 'max_depth': 4, 'min_samples_leaf': 14},
{'criterion': 'gini', 'max_depth': 4, 'min_samples_leaf': 16},
{'criterion': 'gini', 'max_depth': 4, 'min_samples_leaf': 18},
{'criterion': 'entropy', 'max_depth': 2, 'min_samples_leaf': 2},
{'criterion': 'entropy', 'max_depth': 2, 'min_samples_leaf': 4},
{'criterion': 'entropy', 'max_depth': 2, 'min_samples_leaf': 6},
{'criterion': 'entropy', 'max_depth': 2, 'min_samples_leaf': 8},
{'criterion': 'entropy', 'max_depth': 2, 'min_samples_leaf': 10},
{'criterion': 'entropy', 'max_depth': 2, 'min_samples_leaf': 12},
{'criterion': 'entropy', 'max_depth': 2, 'min_samples_leaf': 14},
{'criterion': 'entropy', 'max_depth': 2, 'min_samples_leaf': 16},
{'criterion': 'entropy', 'max_depth': 2, 'min_samples_leaf': 18},
{'criterion': 'entropy', 'max_depth': 3, 'min_samples_leaf': 2},
{'criterion': 'entropy', 'max_depth': 3, 'min_samples_leaf': 4},
{'criterion': 'entropy', 'max_depth': 3, 'min_samples_leaf': 6},
{'criterion': 'entropy', 'max_depth': 3, 'min_samples_leaf': 8},
{'criterion': 'entropy', 'max_depth': 3, 'min_samples_leaf': 10},
{'criterion': 'entropy', 'max_depth': 3, 'min_samples_leaf': 12},
{'criterion': 'entropy', 'max_depth': 3, 'min_samples_leaf': 14},
{'criterion': 'entropy', 'max_depth': 3, 'min_samples_leaf': 16},
{'criterion': 'entropy', 'max_depth': 3, 'min_samples_leaf': 18},
{'criterion': 'entropy', 'max_depth': 4, 'min_samples_leaf': 2},
{'criterion': 'entropy', 'max_depth': 4, 'min_samples_leaf': 4},
{'criterion': 'entropy', 'max_depth': 4, 'min_samples_leaf': 6},
{'criterion': 'entropy', 'max_depth': 4, 'min_samples_leaf': 8},
{'criterion': 'entropy', 'max_depth': 4, 'min_samples_leaf': 10},
{'criterion': 'entropy', 'max_depth': 4, 'min_samples_leaf': 12},
{'criterion': 'entropy', 'max_depth': 4, 'min_samples_leaf': 14},
{'criterion': 'entropy', 'max_depth': 4, 'min_samples_leaf': 16},
{'criterion': 'entropy', 'max_depth': 4, 'min_samples_leaf': 18}],
'split0_test_score': array([0.95652174, 0.95652174, 0.95652174, 0.95652174, 0.95652174,
0.95652174, 0.95652174, 0.95652174, 0.95652174, 1. ,
0.95652174, 0.95652174, 0.95652174, 0.95652174, 0.95652174,
0.95652174, 0.95652174, 0.95652174, 1. , 0.95652174,
0.95652174, 0.95652174, 0.95652174, 0.95652174, 0.95652174,
0.95652174, 0.95652174, 0.95652174, 0.95652174, 0.95652174,
0.95652174, 0.95652174, 0.95652174, 0.95652174, 0.95652174,
0.95652174, 1. , 0.95652174, 0.95652174, 0.95652174,
0.95652174, 0.95652174, 0.95652174, 0.95652174, 0.95652174,
1. , 0.95652174, 0.95652174, 0.95652174, 0.95652174,
0.95652174, 0.95652174, 0.95652174, 0.95652174]),
'split1_test_score': array([0.91304348, 0.91304348, 0.91304348, 0.91304348, 0.91304348,
0.91304348, 0.91304348, 0.91304348, 0.91304348, 0.91304348,
0.91304348, 0.91304348, 0.91304348, 0.91304348, 0.91304348,
0.91304348, 0.91304348, 0.91304348, 0.95652174, 0.91304348,
0.91304348, 0.91304348, 0.91304348, 0.91304348, 0.91304348,
0.91304348, 0.91304348, 0.91304348, 0.91304348, 0.91304348,
0.91304348, 0.91304348, 0.91304348, 0.91304348, 0.91304348,
0.91304348, 0.91304348, 0.91304348, 0.91304348, 0.91304348,
0.91304348, 0.91304348, 0.91304348, 0.91304348, 0.91304348,
0.91304348, 0.91304348, 0.91304348, 0.91304348, 0.91304348,
0.91304348, 0.91304348, 0.91304348, 0.91304348]),
'split2_test_score': array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1.]),
'split3_test_score': array([0.90909091, 0.90909091, 0.90909091, 0.90909091, 0.90909091,
0.90909091, 0.90909091, 0.90909091, 0.90909091, 0.95454545,
0.90909091, 0.90909091, 0.90909091, 0.90909091, 0.90909091,
0.90909091, 0.90909091, 0.90909091, 0.95454545, 0.90909091,
0.90909091, 0.90909091, 0.90909091, 0.90909091, 0.90909091,
0.90909091, 0.90909091, 0.90909091, 0.90909091, 0.90909091,
0.90909091, 0.90909091, 0.90909091, 0.90909091, 0.90909091,
0.90909091, 0.95454545, 0.90909091, 0.90909091, 0.90909091,
0.90909091, 0.90909091, 0.90909091, 0.90909091, 0.90909091,
0.95454545, 0.90909091, 0.90909091, 0.90909091, 0.90909091,
0.90909091, 0.90909091, 0.90909091, 0.90909091]),
'split4_test_score': array([0.95454545, 0.95454545, 0.95454545, 0.95454545, 0.95454545,
0.95454545, 0.95454545, 0.95454545, 0.95454545, 0.95454545,
0.95454545, 0.95454545, 0.95454545, 0.95454545, 0.95454545,
0.95454545, 0.95454545, 0.95454545, 0.95454545, 0.95454545,
0.95454545, 0.95454545, 0.95454545, 0.95454545, 0.95454545,
0.95454545, 0.95454545, 0.95454545, 0.95454545, 0.95454545,
0.95454545, 0.95454545, 0.95454545, 0.95454545, 0.95454545,
0.95454545, 0.95454545, 0.95454545, 0.95454545, 0.95454545,
0.95454545, 0.95454545, 0.95454545, 0.95454545, 0.95454545,
0.95454545, 0.95454545, 0.95454545, 0.95454545, 0.95454545,
0.95454545, 0.95454545, 0.95454545, 0.95454545]),
'mean_test_score': array([0.94664032, 0.94664032, 0.94664032, 0.94664032, 0.94664032,
0.94664032, 0.94664032, 0.94664032, 0.94664032, 0.96442688,
0.94664032, 0.94664032, 0.94664032, 0.94664032, 0.94664032,
0.94664032, 0.94664032, 0.94664032, 0.97312253, 0.94664032,
0.94664032, 0.94664032, 0.94664032, 0.94664032, 0.94664032,
0.94664032, 0.94664032, 0.94664032, 0.94664032, 0.94664032,
0.94664032, 0.94664032, 0.94664032, 0.94664032, 0.94664032,
0.94664032, 0.96442688, 0.94664032, 0.94664032, 0.94664032,
0.94664032, 0.94664032, 0.94664032, 0.94664032, 0.94664032,
0.96442688, 0.94664032, 0.94664032, 0.94664032, 0.94664032,
0.94664032, 0.94664032, 0.94664032, 0.94664032]),
'std_test_score': array([0.03330494, 0.03330494, 0.03330494, 0.03330494, 0.03330494,
0.03330494, 0.03330494, 0.03330494, 0.03330494, 0.03276105,
0.03330494, 0.03330494, 0.03330494, 0.03330494, 0.03330494,
0.03330494, 0.03330494, 0.03330494, 0.02195722, 0.03330494,
0.03330494, 0.03330494, 0.03330494, 0.03330494, 0.03330494,
0.03330494, 0.03330494, 0.03330494, 0.03330494, 0.03330494,
0.03330494, 0.03330494, 0.03330494, 0.03330494, 0.03330494,
0.03330494, 0.03276105, 0.03330494, 0.03330494, 0.03330494,
0.03330494, 0.03330494, 0.03330494, 0.03330494, 0.03330494,
0.03276105, 0.03330494, 0.03330494, 0.03330494, 0.03330494,
0.03330494, 0.03330494, 0.03330494, 0.03330494]),
'rank_test_score': array([5, 5, 5, 5, 5, 5, 5, 5, 5, 2, 5, 5, 5, 5, 5, 5, 5, 5, 1, 5, 5, 5,
5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 2, 5, 5, 5, 5, 5, 5, 5,
       5, 2, 5, 5, 5, 5, 5, 5, 5, 5], dtype=int32)}
It is easier to work with if we use a DataFrame:
dt_5cv_df = pd.DataFrame(dt_opt.cv_results_)
First let’s inspect its shape:
dt_5cv_df.shape
(54, 16)
Notice that it has one row for each of the 54 combinations we computed above.
It has a lot of columns; we can use head to see them:
dt_5cv_df.head()
the fit_time is the time it takes for the .fit method to run for a single setting of hyperparameter values. It computes the mean and std of these over the K-fold cross validation (so here, over 5 separate times)
the score_time is the time to make predictions. It computes the mean and std of these over the K-fold cross validation (so here, over 5 separate times)
there is one param_ column for each key in the parameter grid (criterion, max_depth, and min_samples_leaf)
the params column contains a dictionary of all of the parameters used for that row
the test_scores are better termed the validation accuracy, because each one is truly the score for the validation data: the “test” set from the splits in the cross validation loop. It records the value for each fold, plus the mean and the std of them.
rank_test_score is the rank for that hyperparameter setting’s mean_test_score
Since we used a classifier here, the score is accuracy; if it were regression it would be the R² score, and if it were KMeans it would be the opposite of the KMeans objective.
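To see how the score columns relate to each other, this sketch recomputes the mean and std columns from the per-fold columns. It assumes `load_iris` as an offline stand-in for the data and a small hypothetical grid; the names `split_cols`, `recomputed_mean` and `recomputed_std` are illustrative.

```python
import numpy as np
import pandas as pd
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV

# Assumption: load_iris stands in for the seaborn copy of the data
X, y = load_iris(return_X_y=True)
small_grid = {'max_depth': [2, 3, 4]}  # hypothetical reduced grid
search = GridSearchCV(tree.DecisionTreeClassifier(random_state=0), small_grid, cv=5)
search.fit(X, y)

results = pd.DataFrame(search.cv_results_)

# One splitN_test_score column per fold
split_cols = [c for c in results.columns
              if c.startswith('split') and c.endswith('_test_score')]

recomputed_mean = results[split_cols].mean(axis=1)
recomputed_std = results[split_cols].std(axis=1, ddof=0)  # population std over the folds

print(np.allclose(recomputed_mean, results['mean_test_score']))
print(np.allclose(recomputed_std, results['std_test_score']))
```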
We can also plot the data and look at the performance.
sns.catplot(data=dt_5cv_df,x='param_min_samples_leaf',y='mean_test_score',
            col='param_criterion', row= 'param_max_depth', kind='bar',)
<seaborn.axisgrid.FacetGrid at 0x7f194a4f8ec0>
This makes it clear that none of these stick out much in terms of performance.
The best model here is not much better than the others, but for less simple tasks there are more things to choose from.
Impact of CV parameters¶
Let’s fit again with cv=10 to see what changes with 10-fold cross validation.
dt_opt10 = GridSearchCV(dt,params_dt,cv=10)
dt_opt10.fit(iris_X_train,iris_y_train)
and get the DataFrame for the results:
dt_10cv_df = pd.DataFrame(dt_opt10.cv_results_)
We can stack the columns we want from the two results together with a new indicator column cv:
plot_cols = ['param_min_samples_leaf','std_test_score','mean_test_score',
'param_criterion','param_max_depth','cv']
dt_10cv_df['cv'] = 10
dt_5cv_df['cv'] = 5
dt_cv_df = pd.concat([dt_5cv_df[plot_cols],dt_10cv_df[plot_cols]])
dt_cv_df.head()
This can be used to plot.
sns.catplot(data=dt_cv_df,x='param_min_samples_leaf',y='mean_test_score',
col='param_criterion', row= 'param_max_depth', kind='bar',
            hue = 'cv')
<seaborn.axisgrid.FacetGrid at 0x7f1948924530>
We see that the mean scores are not very different, but that 10 is a little higher in some cases. This makes sense: each fit has more data to learn from, so it found something that applied better, on average, to the validation set.
sns.catplot(data=dt_cv_df,x='param_min_samples_leaf',y='std_test_score',
col='param_criterion', row= 'param_max_depth', kind='bar',
            hue = 'cv')
<seaborn.axisgrid.FacetGrid at 0x7f194895ec90>
However, here we see that the variability in those scores is much higher with 10 folds, so maybe 5 is better.
A really small number of samples was used to compute each of those scores, so some of them will vary a lot more.
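To put rough numbers on the fold sizes, the sketch below counts the validation samples per fold for 5 and 10 folds on a training set of 112 rows. It uses plain `KFold` on placeholder data purely for counting; this is an assumption made for simplicity, since the search uses a stratified splitter for classifiers, though the fold sizes come out about the same.

```python
import numpy as np
from sklearn.model_selection import KFold

n_train = 112                       # 75% of the 150 iris samples
X_dummy = np.zeros((n_train, 1))    # placeholder rows; only the count matters here

# Number of validation samples in each fold, for 5 and 10 folds
fold_sizes = {}
for k in (5, 10):
    fold_sizes[k] = [len(val_idx) for _, val_idx in KFold(n_splits=k).split(X_dummy)]
    print(k, fold_sizes[k])
```

With only about a dozen samples in each of the 10 folds, one misclassified sample moves that fold's accuracy by roughly 8 to 9 percentage points, which helps explain the larger spread.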
.75*150
112.5
112/5
22.4
We can compare to see if it finds the same model as best:
dt_opt.best_params_
{'criterion': 'gini', 'max_depth': 4, 'min_samples_leaf': 2}
dt_opt10.best_params_
{'criterion': 'gini', 'max_depth': 4, 'min_samples_leaf': 2}
In some cases they will and in others they will not.
dt_opt.score(iris_X_test,iris_y_test)
0.9473684210526315
dt_opt10.score(iris_X_test,iris_y_test)
0.9473684210526315
In some cases they will find the same model and score the same, but at other times they will not.
The takeaway is that the cross validation parameters impact our ability to measure the score and possibly how closely the cross validation mean score will match the true test score. Mostly they change the variability in the estimate of the score. They do not necessarily change which model is best; that is up to the data itself (the original test/train split would impact this).
Other searches¶
from sklearn import model_selection
from sklearn.model_selection import LeaveOneOut
rand_opt = model_selection.RandomizedSearchCV(dt,params_dt,)
rand_opt.fit(iris_X_train, iris_y_train)
rand_opt.score(iris_X_test,iris_y_test)
0.8947368421052632
It might find the same solution, but it also might not. If you try a few searches and see that the parameters overall do not impact the scores much, then you can trust whichever one, or consider other criteria to choose the best model to use.
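The main difference from the grid search is that a randomized search only tries a sample of the combinations. A minimal sketch, assuming `load_iris` as an offline stand-in for the data and with `n_iter` and `random_state` spelled out here for illustration:

```python
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.model_selection import RandomizedSearchCV, train_test_split

# Assumption: load_iris stands in for the seaborn copy of the data
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

params_dt = {'criterion': ['gini', 'entropy'],
             'max_depth': [2, 3, 4],
             'min_samples_leaf': list(range(2, 20, 2))}

# n_iter controls how many of the 54 combinations are sampled (10 is the default)
rand_search = RandomizedSearchCV(tree.DecisionTreeClassifier(random_state=0),
                                 params_dt, n_iter=10, random_state=0)
rand_search.fit(X_train, y_train)

n_tried = len(rand_search.cv_results_['params'])
print(n_tried, rand_search.best_params_)
print(rand_search.score(X_test, y_test))
```

Because only some combinations are evaluated, the best setting found can differ from the grid search result and from run to run unless `random_state` is fixed.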
Choosing a model to use¶
The Grid search finds the hyperparameter values that result in the best mean score. But what if more than one does that?
dt_5cv_df['rank_test_score'].value_counts()
rank_test_score
5    50
2     3
1     1
Name: count, dtype: int64
Let’s look at the ones sharing a rank of 1:
dt_5cv_df[dt_5cv_df['rank_test_score']==1]
We can compare on other aspects, like the time. In particular a lower or more consistent score_time could impact how expensive it is to run your model in production.
dt_5cv_df[['mean_fit_time', 'std_fit_time', 'mean_score_time', 'std_score_time']].mean()
mean_fit_time      0.008934
std_fit_time       0.004200
mean_score_time    0.009026
std_score_time     0.007721
dtype: float64
dt_5cv_df[['mean_fit_time', 'std_fit_time', 'mean_score_time', 'std_score_time']].head(3)
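One possible way to use the timing columns when several settings share a rank is to sort by rank first and by mean_score_time second. This is a sketch of that idea under stated assumptions (`load_iris` as an offline stand-in for the data, and a small hypothetical grid), not a recommended selection rule.

```python
import pandas as pd
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV

# Assumption: load_iris stands in for the seaborn copy of the data
X, y = load_iris(return_X_y=True)
small_grid = {'max_depth': [2, 3, 4], 'min_samples_leaf': [2, 10, 18]}  # hypothetical grid
search = GridSearchCV(tree.DecisionTreeClassifier(random_state=0), small_grid)
search.fit(X, y)

results = pd.DataFrame(search.cv_results_)

# Best rank first; among equal ranks, the fastest scoring time first
ranked = results.sort_values(['rank_test_score', 'mean_score_time'])
print(ranked[['params', 'rank_test_score', 'mean_score_time']].head())
```

The timings here are tiny and noisy, so on a dataset this small the ordering among tied settings can change between runs.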