18. Decision Trees#

%matplotlib inline
import pandas as pd
import seaborn as sns
import numpy as np
from sklearn import tree
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn import metrics
sns.set(palette='colorblind') # this improves contrast

from sklearn.metrics import confusion_matrix, classification_report

18.1. Let’s look at a toy dataset#

Using a toy dataset here shows an easy-to-see challenge for the classifier that we have seen so far. Real datasets will be hard in different ways, and since they’re higher dimensional, it’s harder to visualize the cause.

# URL of the toy dataset, pinned to a specific commit of the course repository
corner_data = 'https://raw.githubusercontent.com/rhodyprog4ds/06-naive-bayes/f425ba121cc0c4dd8bcaa7ebb2ff0b40b0b03bff/data/dataset6.csv'
# load only the columns at positions 1-3 of the CSV (column 0 is skipped)
df6= pd.read_csv(corner_data,usecols=[1,2,3])
# create the (not yet fitted) Gaussian Naive Bayes classifier used below
gnb = GaussianNB()
# plot every pair of numeric columns, colored by the class label in 'char'
sns.pairplot(data=df6, hue='char',hue_order=['A','B'])
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[5], line 1
----> 1 sns.pairplot(data=df6, hue='char',hue_order=['A','B'])

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/seaborn/axisgrid.py:2148, in pairplot(data, hue, hue_order, palette, vars, x_vars, y_vars, kind, diag_kind, markers, height, aspect, corner, dropna, plot_kws, diag_kws, grid_kws, size)
   2146     diag_kws.setdefault("fill", True)
   2147     diag_kws.setdefault("warn_singular", False)
-> 2148     grid.map_diag(kdeplot, **diag_kws)
   2150 # Maybe plot on the off-diagonals
   2151 if diag_kind is not None:

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/seaborn/axisgrid.py:1507, in PairGrid.map_diag(self, func, **kwargs)
   1505     plot_kwargs.setdefault("hue_order", self._hue_order)
   1506     plot_kwargs.setdefault("palette", self._orig_palette)
-> 1507     func(x=vector, **plot_kwargs)
   1508     ax.legend_ = None
   1510 self._add_axis_labels()

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/seaborn/distributions.py:1717, in kdeplot(data, x, y, hue, weights, palette, hue_order, hue_norm, color, fill, multiple, common_norm, common_grid, cumulative, bw_method, bw_adjust, warn_singular, log_scale, levels, thresh, gridsize, cut, clip, legend, cbar, cbar_ax, cbar_kws, ax, **kwargs)
   1713 if p.univariate:
   1715     plot_kws = kwargs.copy()
-> 1717     p.plot_univariate_density(
   1718         multiple=multiple,
   1719         common_norm=common_norm,
   1720         common_grid=common_grid,
   1721         fill=fill,
   1722         color=color,
   1723         legend=legend,
   1724         warn_singular=warn_singular,
   1725         estimate_kws=estimate_kws,
   1726         **plot_kws,
   1727     )
   1729 else:
   1731     p.plot_bivariate_density(
   1732         common_norm=common_norm,
   1733         fill=fill,
   (...)
   1743         **kwargs,
   1744     )

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/seaborn/distributions.py:996, in _DistributionPlotter.plot_univariate_density(self, multiple, common_norm, common_grid, warn_singular, fill, color, legend, estimate_kws, **plot_kws)
    993 if "x" in self.variables:
    995     if fill:
--> 996         artist = ax.fill_between(support, fill_from, density, **artist_kws)
    998     else:
    999         artist, = ax.plot(support, density, **artist_kws)

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/matplotlib/__init__.py:1423, in _preprocess_data.<locals>.inner(ax, data, *args, **kwargs)
   1420 @functools.wraps(func)
   1421 def inner(ax, *args, data=None, **kwargs):
   1422     if data is None:
-> 1423         return func(ax, *map(sanitize_sequence, args), **kwargs)
   1425     bound = new_sig.bind(ax, *args, **kwargs)
   1426     auto_label = (bound.arguments.get(label_namer)
   1427                   or bound.kwargs.get(label_namer))

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/matplotlib/axes/_axes.py:5367, in Axes.fill_between(self, x, y1, y2, where, interpolate, step, **kwargs)
   5365 def fill_between(self, x, y1, y2=0, where=None, interpolate=False,
   5366                  step=None, **kwargs):
-> 5367     return self._fill_between_x_or_y(
   5368         "x", x, y1, y2,
   5369         where=where, interpolate=interpolate, step=step, **kwargs)

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/matplotlib/axes/_axes.py:5272, in Axes._fill_between_x_or_y(self, ind_dir, ind, dep1, dep2, where, interpolate, step, **kwargs)
   5268         kwargs["facecolor"] = \
   5269             self._get_patches_for_fill.get_next_color()
   5271 # Handle united data, such as dates
-> 5272 ind, dep1, dep2 = map(
   5273     ma.masked_invalid, self._process_unit_info(
   5274         [(ind_dir, ind), (dep_dir, dep1), (dep_dir, dep2)], kwargs))
   5276 for name, array in [
   5277         (ind_dir, ind), (f"{dep_dir}1", dep1), (f"{dep_dir}2", dep2)]:
   5278     if array.ndim > 1:

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/numpy/ma/core.py:2360, in masked_invalid(a, copy)
   2332 def masked_invalid(a, copy=True):
   2333     """
   2334     Mask an array where invalid values occur (NaNs or infs).
   2335 
   (...)
   2357 
   2358     """
-> 2360     return masked_where(~(np.isfinite(getdata(a))), a, copy=copy)

TypeError: ufunc 'isfinite' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''
../_images/2022-10-21_6_1.png

As we can see in this dataset, these classes are quite separated.

X_train, X_test, y_train, y_test = train_test_split(df6[['x0','x1']],
                          df6['char'],
                          random_state = 4)
gnb.fit(X_train,y_train)
gnb.score(X_test,y_test)
0.72

But we do not get a very good classification score.

To see why, we can look at what it learned.

gnb.__dict__
{'priors': None,
 'var_smoothing': 1e-09,
 'classes_': array(['A', 'B'], dtype='<U1'),
 'feature_names_in_': array(['x0', 'x1'], dtype=object),
 'n_features_in_': 2,
 'epsilon_': 4.294249888888889e-09,
 'theta_': array([[3.91910256, 3.9624359 ],
        [4.42861111, 3.54222222]]),
 'var_': array([[4.0355774 , 3.99742099],
        [4.43948696, 4.14491451]]),
 'class_count_': array([78., 72.]),
 'class_prior_': array([0.52, 0.48])}
# Draw N synthetic points per class from the Gaussians that the fitted model
# learned, so the plot shows the model's own description of the data.
N = 100
# `var_` holds the learned per-class, per-feature variances (the fitted
# estimator has no `sigma_` attribute). np.diag turns each row of variances
# into a diagonal covariance matrix, for any number of features.
gnb_df = pd.DataFrame(np.concatenate([np.random.multivariate_normal(th, np.diag(var), N)
         for th, var in zip(gnb.theta_, gnb.var_)]),
         columns = ['x0','x1'])
# label the samples in the same class order in which they were generated:
# N rows of the first class, then N rows of the second
gnb_df['char'] = [c for c in gnb.classes_ for _ in range(N)]

sns.pairplot(data =gnb_df, hue='char',hue_order=['A','B'])
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[9], line 3
      1 N = 100
      2 gnb_df = pd.DataFrame(np.concatenate([np.random.multivariate_normal(th, sig*np.eye(2),N)
----> 3          for th, sig in zip(gnb.theta_,gnb.sigma_)]),
      4          columns = ['x0','x1'])
      5 gnb_df['char'] = [ci for cl in [[c]*N for c in gnb.classes_] for ci in cl]
      7 sns.pairplot(data =gnb_df, hue='char',hue_order=['A','B'])

AttributeError: 'GaussianNB' object has no attribute 'sigma_'

This does not look much like the data and it’s hard to tell which is higher at any given point in the 2D space. We know though, that it has missed the mark. We can also look at the actual predictions.

df6pred = X_test.copy()
df6pred['pred'] =gnb.predict(X_test)

sns.pairplot(df6pred, hue = 'pred',hue_order=['A','B'])
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[10], line 4
      1 df6pred = X_test.copy()
      2 df6pred['pred'] =gnb.predict(X_test)
----> 4 sns.pairplot(df6pred, hue = 'pred',hue_order=['A','B'])

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/seaborn/axisgrid.py:2148, in pairplot(data, hue, hue_order, palette, vars, x_vars, y_vars, kind, diag_kind, markers, height, aspect, corner, dropna, plot_kws, diag_kws, grid_kws, size)
   2146     diag_kws.setdefault("fill", True)
   2147     diag_kws.setdefault("warn_singular", False)
-> 2148     grid.map_diag(kdeplot, **diag_kws)
   2150 # Maybe plot on the off-diagonals
   2151 if diag_kind is not None:

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/seaborn/axisgrid.py:1507, in PairGrid.map_diag(self, func, **kwargs)
   1505     plot_kwargs.setdefault("hue_order", self._hue_order)
   1506     plot_kwargs.setdefault("palette", self._orig_palette)
-> 1507     func(x=vector, **plot_kwargs)
   1508     ax.legend_ = None
   1510 self._add_axis_labels()

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/seaborn/distributions.py:1717, in kdeplot(data, x, y, hue, weights, palette, hue_order, hue_norm, color, fill, multiple, common_norm, common_grid, cumulative, bw_method, bw_adjust, warn_singular, log_scale, levels, thresh, gridsize, cut, clip, legend, cbar, cbar_ax, cbar_kws, ax, **kwargs)
   1713 if p.univariate:
   1715     plot_kws = kwargs.copy()
-> 1717     p.plot_univariate_density(
   1718         multiple=multiple,
   1719         common_norm=common_norm,
   1720         common_grid=common_grid,
   1721         fill=fill,
   1722         color=color,
   1723         legend=legend,
   1724         warn_singular=warn_singular,
   1725         estimate_kws=estimate_kws,
   1726         **plot_kws,
   1727     )
   1729 else:
   1731     p.plot_bivariate_density(
   1732         common_norm=common_norm,
   1733         fill=fill,
   (...)
   1743         **kwargs,
   1744     )

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/seaborn/distributions.py:996, in _DistributionPlotter.plot_univariate_density(self, multiple, common_norm, common_grid, warn_singular, fill, color, legend, estimate_kws, **plot_kws)
    993 if "x" in self.variables:
    995     if fill:
--> 996         artist = ax.fill_between(support, fill_from, density, **artist_kws)
    998     else:
    999         artist, = ax.plot(support, density, **artist_kws)

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/matplotlib/__init__.py:1423, in _preprocess_data.<locals>.inner(ax, data, *args, **kwargs)
   1420 @functools.wraps(func)
   1421 def inner(ax, *args, data=None, **kwargs):
   1422     if data is None:
-> 1423         return func(ax, *map(sanitize_sequence, args), **kwargs)
   1425     bound = new_sig.bind(ax, *args, **kwargs)
   1426     auto_label = (bound.arguments.get(label_namer)
   1427                   or bound.kwargs.get(label_namer))

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/matplotlib/axes/_axes.py:5367, in Axes.fill_between(self, x, y1, y2, where, interpolate, step, **kwargs)
   5365 def fill_between(self, x, y1, y2=0, where=None, interpolate=False,
   5366                  step=None, **kwargs):
-> 5367     return self._fill_between_x_or_y(
   5368         "x", x, y1, y2,
   5369         where=where, interpolate=interpolate, step=step, **kwargs)

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/matplotlib/axes/_axes.py:5272, in Axes._fill_between_x_or_y(self, ind_dir, ind, dep1, dep2, where, interpolate, step, **kwargs)
   5268         kwargs["facecolor"] = \
   5269             self._get_patches_for_fill.get_next_color()
   5271 # Handle united data, such as dates
-> 5272 ind, dep1, dep2 = map(
   5273     ma.masked_invalid, self._process_unit_info(
   5274         [(ind_dir, ind), (dep_dir, dep1), (dep_dir, dep2)], kwargs))
   5276 for name, array in [
   5277         (ind_dir, ind), (f"{dep_dir}1", dep1), (f"{dep_dir}2", dep2)]:
   5278     if array.ndim > 1:

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/numpy/ma/core.py:2360, in masked_invalid(a, copy)
   2332 def masked_invalid(a, copy=True):
   2333     """
   2334     Mask an array where invalid values occur (NaNs or infs).
   2335 
   (...)
   2357 
   2358     """
-> 2360     return masked_where(~(np.isfinite(getdata(a))), a, copy=copy)

TypeError: ufunc 'isfinite' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''
../_images/2022-10-21_14_1.png

This makes it more clear. It basically learns one group that covers 3 blobs and only 1 for the second color. If we train again with a different random seed, it makes a different specific error, but basically the same idea.

X_train, X_test, y_train, y_test = train_test_split(df6[['x0','x1']],
                          df6['char'], random_state = 5)
gnb.fit(X_train,y_train)
gnb.score(X_test,y_test)
0.34
df6pred = X_test.copy()
df6pred['pred'] =gnb.predict(X_test)

sns.pairplot(df6pred, hue = 'pred')
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[12], line 4
      1 df6pred = X_test.copy()
      2 df6pred['pred'] =gnb.predict(X_test)
----> 4 sns.pairplot(df6pred, hue = 'pred')

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/seaborn/axisgrid.py:2148, in pairplot(data, hue, hue_order, palette, vars, x_vars, y_vars, kind, diag_kind, markers, height, aspect, corner, dropna, plot_kws, diag_kws, grid_kws, size)
   2146     diag_kws.setdefault("fill", True)
   2147     diag_kws.setdefault("warn_singular", False)
-> 2148     grid.map_diag(kdeplot, **diag_kws)
   2150 # Maybe plot on the off-diagonals
   2151 if diag_kind is not None:

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/seaborn/axisgrid.py:1507, in PairGrid.map_diag(self, func, **kwargs)
   1505     plot_kwargs.setdefault("hue_order", self._hue_order)
   1506     plot_kwargs.setdefault("palette", self._orig_palette)
-> 1507     func(x=vector, **plot_kwargs)
   1508     ax.legend_ = None
   1510 self._add_axis_labels()

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/seaborn/distributions.py:1717, in kdeplot(data, x, y, hue, weights, palette, hue_order, hue_norm, color, fill, multiple, common_norm, common_grid, cumulative, bw_method, bw_adjust, warn_singular, log_scale, levels, thresh, gridsize, cut, clip, legend, cbar, cbar_ax, cbar_kws, ax, **kwargs)
   1713 if p.univariate:
   1715     plot_kws = kwargs.copy()
-> 1717     p.plot_univariate_density(
   1718         multiple=multiple,
   1719         common_norm=common_norm,
   1720         common_grid=common_grid,
   1721         fill=fill,
   1722         color=color,
   1723         legend=legend,
   1724         warn_singular=warn_singular,
   1725         estimate_kws=estimate_kws,
   1726         **plot_kws,
   1727     )
   1729 else:
   1731     p.plot_bivariate_density(
   1732         common_norm=common_norm,
   1733         fill=fill,
   (...)
   1743         **kwargs,
   1744     )

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/seaborn/distributions.py:996, in _DistributionPlotter.plot_univariate_density(self, multiple, common_norm, common_grid, warn_singular, fill, color, legend, estimate_kws, **plot_kws)
    993 if "x" in self.variables:
    995     if fill:
--> 996         artist = ax.fill_between(support, fill_from, density, **artist_kws)
    998     else:
    999         artist, = ax.plot(support, density, **artist_kws)

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/matplotlib/__init__.py:1423, in _preprocess_data.<locals>.inner(ax, data, *args, **kwargs)
   1420 @functools.wraps(func)
   1421 def inner(ax, *args, data=None, **kwargs):
   1422     if data is None:
-> 1423         return func(ax, *map(sanitize_sequence, args), **kwargs)
   1425     bound = new_sig.bind(ax, *args, **kwargs)
   1426     auto_label = (bound.arguments.get(label_namer)
   1427                   or bound.kwargs.get(label_namer))

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/matplotlib/axes/_axes.py:5367, in Axes.fill_between(self, x, y1, y2, where, interpolate, step, **kwargs)
   5365 def fill_between(self, x, y1, y2=0, where=None, interpolate=False,
   5366                  step=None, **kwargs):
-> 5367     return self._fill_between_x_or_y(
   5368         "x", x, y1, y2,
   5369         where=where, interpolate=interpolate, step=step, **kwargs)

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/matplotlib/axes/_axes.py:5272, in Axes._fill_between_x_or_y(self, ind_dir, ind, dep1, dep2, where, interpolate, step, **kwargs)
   5268         kwargs["facecolor"] = \
   5269             self._get_patches_for_fill.get_next_color()
   5271 # Handle united data, such as dates
-> 5272 ind, dep1, dep2 = map(
   5273     ma.masked_invalid, self._process_unit_info(
   5274         [(ind_dir, ind), (dep_dir, dep1), (dep_dir, dep2)], kwargs))
   5276 for name, array in [
   5277         (ind_dir, ind), (f"{dep_dir}1", dep1), (f"{dep_dir}2", dep2)]:
   5278     if array.ndim > 1:

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/numpy/ma/core.py:2360, in masked_invalid(a, copy)
   2332 def masked_invalid(a, copy=True):
   2333     """
   2334     Mask an array where invalid values occur (NaNs or infs).
   2335 
   (...)
   2357 
   2358     """
-> 2360     return masked_where(~(np.isfinite(getdata(a))), a, copy=copy)

TypeError: ufunc 'isfinite' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''
../_images/2022-10-21_17_1.png
X_train, X_test, y_train, y_test = train_test_split(df6[['x0','x1']],
                          df6['char'], random_state = 7)
gnb.fit(X_train,y_train)
gnb.score(X_test,y_test)
0.58
df6pred = X_test.copy()
df6pred['pred'] =gnb.predict(X_test)

sns.pairplot(df6pred, hue = 'pred')
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[14], line 4
      1 df6pred = X_test.copy()
      2 df6pred['pred'] =gnb.predict(X_test)
----> 4 sns.pairplot(df6pred, hue = 'pred')

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/seaborn/axisgrid.py:2148, in pairplot(data, hue, hue_order, palette, vars, x_vars, y_vars, kind, diag_kind, markers, height, aspect, corner, dropna, plot_kws, diag_kws, grid_kws, size)
   2146     diag_kws.setdefault("fill", True)
   2147     diag_kws.setdefault("warn_singular", False)
-> 2148     grid.map_diag(kdeplot, **diag_kws)
   2150 # Maybe plot on the off-diagonals
   2151 if diag_kind is not None:

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/seaborn/axisgrid.py:1507, in PairGrid.map_diag(self, func, **kwargs)
   1505     plot_kwargs.setdefault("hue_order", self._hue_order)
   1506     plot_kwargs.setdefault("palette", self._orig_palette)
-> 1507     func(x=vector, **plot_kwargs)
   1508     ax.legend_ = None
   1510 self._add_axis_labels()

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/seaborn/distributions.py:1717, in kdeplot(data, x, y, hue, weights, palette, hue_order, hue_norm, color, fill, multiple, common_norm, common_grid, cumulative, bw_method, bw_adjust, warn_singular, log_scale, levels, thresh, gridsize, cut, clip, legend, cbar, cbar_ax, cbar_kws, ax, **kwargs)
   1713 if p.univariate:
   1715     plot_kws = kwargs.copy()
-> 1717     p.plot_univariate_density(
   1718         multiple=multiple,
   1719         common_norm=common_norm,
   1720         common_grid=common_grid,
   1721         fill=fill,
   1722         color=color,
   1723         legend=legend,
   1724         warn_singular=warn_singular,
   1725         estimate_kws=estimate_kws,
   1726         **plot_kws,
   1727     )
   1729 else:
   1731     p.plot_bivariate_density(
   1732         common_norm=common_norm,
   1733         fill=fill,
   (...)
   1743         **kwargs,
   1744     )

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/seaborn/distributions.py:996, in _DistributionPlotter.plot_univariate_density(self, multiple, common_norm, common_grid, warn_singular, fill, color, legend, estimate_kws, **plot_kws)
    993 if "x" in self.variables:
    995     if fill:
--> 996         artist = ax.fill_between(support, fill_from, density, **artist_kws)
    998     else:
    999         artist, = ax.plot(support, density, **artist_kws)

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/matplotlib/__init__.py:1423, in _preprocess_data.<locals>.inner(ax, data, *args, **kwargs)
   1420 @functools.wraps(func)
   1421 def inner(ax, *args, data=None, **kwargs):
   1422     if data is None:
-> 1423         return func(ax, *map(sanitize_sequence, args), **kwargs)
   1425     bound = new_sig.bind(ax, *args, **kwargs)
   1426     auto_label = (bound.arguments.get(label_namer)
   1427                   or bound.kwargs.get(label_namer))

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/matplotlib/axes/_axes.py:5367, in Axes.fill_between(self, x, y1, y2, where, interpolate, step, **kwargs)
   5365 def fill_between(self, x, y1, y2=0, where=None, interpolate=False,
   5366                  step=None, **kwargs):
-> 5367     return self._fill_between_x_or_y(
   5368         "x", x, y1, y2,
   5369         where=where, interpolate=interpolate, step=step, **kwargs)

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/matplotlib/axes/_axes.py:5272, in Axes._fill_between_x_or_y(self, ind_dir, ind, dep1, dep2, where, interpolate, step, **kwargs)
   5268         kwargs["facecolor"] = \
   5269             self._get_patches_for_fill.get_next_color()
   5271 # Handle united data, such as dates
-> 5272 ind, dep1, dep2 = map(
   5273     ma.masked_invalid, self._process_unit_info(
   5274         [(ind_dir, ind), (dep_dir, dep1), (dep_dir, dep2)], kwargs))
   5276 for name, array in [
   5277         (ind_dir, ind), (f"{dep_dir}1", dep1), (f"{dep_dir}2", dep2)]:
   5278     if array.ndim > 1:

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/numpy/ma/core.py:2360, in masked_invalid(a, copy)
   2332 def masked_invalid(a, copy=True):
   2333     """
   2334     Mask an array where invalid values occur (NaNs or infs).
   2335 
   (...)
   2357 
   2358     """
-> 2360     return masked_where(~(np.isfinite(getdata(a))), a, copy=copy)

TypeError: ufunc 'isfinite' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''
../_images/2022-10-21_19_1.png

If you try this again, split, fit, plot, it will learn different decisions, but always at least about 25% of the data will have to be classified incorrectly.

18.2. Decision Trees#

This data does not fit the assumptions of the Naive Bayes model, but a decision tree has a different rule. It can be more complex, but the scikit-learn implementation relies on splitting the data at a series of points along one axis at a time.

It is a discriminative model, because it describes how to discriminate (in the sense of differentiate) between the classes.

dt = tree.DecisionTreeClassifier()
dt.fit(X_train,y_train)
DecisionTreeClassifier()
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
dt.score(X_test,y_test)
1.0

The sklearn estimator objects (that correspond to different models) all have the same API, so the fit, predict, and score methods are the same as above. We will see this also in regression and clustering. What each method does in terms of the specific calculations will vary depending on the model, but they’re always there.

The tree module also allows you to plot the tree to examine it.

plt.figure(figsize=(15,20))
tree.plot_tree(dt, rounded =True, class_names = ['A','B'],
      proportion=True, filled =True, impurity=False,fontsize=10);
../_images/2022-10-21_25_0.png

18.3. Setting Classifier Parameters#

The decision tree we had above has a lot more layers than we would expect. This is really simple data so we still got perfect classification. However, the more complex the model, the more risk that it will learn something noisy about the training data that doesn’t hold up in the test set.

Fortunately, we can control the parameters to make it find a simpler decision boundary.

dt2 = tree.DecisionTreeClassifier(max_depth=2)
dt2.fit(X_train,y_train)
DecisionTreeClassifier(max_depth=2)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
plt.figure(figsize=(15,20))
tree.plot_tree(dt2, rounded =True, class_names = ['A','B'],
      proportion=True, filled =True, impurity=False,fontsize=10);
../_images/2022-10-21_28_0.png
dt2.score(X_test,y_test)
0.86

We might need to play with different parameters to get it just how we want it. A simpler model is better because it will be more reliable in general.

dt2.score(X_test,y_test)
0.86
dt2_20 = tree.DecisionTreeClassifier(max_depth=2)
dt2_20.fit(X_train,y_train)
DecisionTreeClassifier(max_depth=2)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
dt2_20.score(X_test, y_test)
0.86
plt.figure(figsize=(15,20))
tree.plot_tree(dt2_20, rounded =True, class_names = ['A','B'],
      proportion=True, filled =True, impurity=False,fontsize=10);
../_images/2022-10-21_34_0.png

18.4. Questions After Class#

18.4.1. What do the dots and lines represent in the graph?#

In the scatter plots, each dot is one sample (a row from the DataFrame); the different subplots are different views of the data, determined by different pairs of variables.

The lines are the density: when it’s high, it means values are likely; when it’s low, it means they are unlikely.

18.4.2. What does the GNB model do?#

When we fit, it learns a description of the data. When we predict it uses that description to predict the probability that each test sample belongs to each of the classes and returns the most probable one. See the last notes for a visualization of the probabilities.

18.4.3. Is xtrain ytrain, xtest ytest the same every time?#

Generally no. With the random_state parameter set, on a given dataset they will be the same subset every time, but without it we would get different rows. For example:

X_train, X_test, y_train, y_test = train_test_split(df6[['x0','x1']],
                          df6['char'],
                          random_state = 4)
X_train.head()
x0 x1
110 5.91 2.39
154 6.07 6.47
16 6.84 2.61
19 6.27 1.99
2 2.27 5.44
X_train, X_test, y_train, y_test = train_test_split(df6[['x0','x1']],
                          df6['char'],
                          random_state = 4)
X_train.head()
x0 x1
110 5.91 2.39
154 6.07 6.47
16 6.84 2.61
19 6.27 1.99
2 2.27 5.44

With random state, we get the same samples each time.

X_train, X_test, y_train, y_test = train_test_split(df6[['x0','x1']],
                          df6['char'])
X_train.head()
x0 x1
45 2.23 2.55
93 2.35 2.06
47 2.50 5.99
36 1.85 6.50
109 6.07 2.65
X_train, X_test, y_train, y_test = train_test_split(df6[['x0','x1']],
                          df6['char'])
X_train.head()
x0 x1
174 1.81 2.09
27 5.21 2.11
147 6.23 5.56
157 2.34 6.34
98 2.86 2.73

Without, we get different ones.

18.4.4. I still want to know why it is worse to have a model that gets 100% accuracy with more depth to the decision tree flowchart than to have a model with less accuracy and less decisions in the decision tree.#

This accuracy is only on one set of training data; we might not want the big accuracy drop we saw here, but we could do other things to improve it.