Decision Trees with Scikit-Learn


Introduction

The police know the murderer is likely to be a female whose height is about 165cm, and the murder weapon appears to be either a cricket bat or a rolling pin. Thirty-one suspects are connected to the crime scene. Who should the police arrest?

Let’s play with decision trees using Scikit-Learn to find out!

Learning Objective

Decision trees are typically trained over large data sets with the aim of boiling down complex features into a limited number of decision points (i.e., ‘splits’).

This tutorial’s objective is simply to introduce the notion of decision trees as a machine learning tool, using an intuitive example.

Obligatory Imports and Boilerplate

import numpy as np
import pandas as pd
import re
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.tree import DecisionTreeClassifier

The Suspects

The police have identified 31 suspects. In addition to noting their gender and height, the police have searched their possessions and found one potential murder weapon per suspect. Download the spreadsheet here.

df_raw = pd.read_excel('suspects.xlsx')
df_raw.head()

    Suspect Gender  Height       Weapon
0     Jorge      M     183       Bottle
1    Rajesh      M     170  Cricket Bat
2      Lucy      F     162       Bottle
3     Heidi      F     161  Rolling Pin
4  Juliette      F     163  Cricket Bat

Decision Tree-Friendly Data

Decision trees understand numerical values. The suspects’ heights are continuous values, so they are fine as they are, but what about gender and weapon?

Unlike with some other machine learning methods, we can’t simply replace these with numbers without creating further problems. For example, if we were to replace ‘M’ (male) with 1 and ‘F’ (female) with 2, what would a gender of 1.5, or of 3, even mean?

Variables such as gender and weapon are so-called categorical variables, which are meant to be treated as pseudo-binary values, as far as decision trees are concerned.

This means that rather than something like M=1 and F=2, we need to create two columns: Gender_M, and Gender_F, as follows:

Gender_M  Gender_F  Meaning
       1         0  Male
       0         1  Female

In essence, each category becomes a new feature, as far as decision trees are concerned. What if we have many options, like in the case of the murder weapon? Do we require a new column for each murder weapon too? Yes, madam.
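Done by hand, the encoding for the Gender column alone would look something like the minimal sketch below (df_manual is a throwaway name of ours):

# Manual one-hot encoding of a single column, for illustration only
df_manual = df_raw.copy()
df_manual['Gender_M'] = (df_manual['Gender'] == 'M').astype(int)
df_manual['Gender_F'] = (df_manual['Gender'] == 'F').astype(int)
df_manual = df_manual.drop(columns=['Gender'])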

But data scientists hate overworking themselves, so Pandas provides a nifty function called get_dummies() which does all of this heavy lifting for us. We only need to specify which columns should be subjected to this process:

df = pd.get_dummies(df_raw, columns=['Gender','Weapon'])
df.head()

    Suspect  Height  Gender_F  Gender_M  Weapon_Bottle  Weapon_Cricket Bat  Weapon_Rolling Pin
0     Jorge     183         0         1              1                   0                   0
1    Rajesh     170         0         1              0                   1                   0
2      Lucy     162         1         0              1                   0                   0
3     Heidi     161         1         0              0                   0                   1
4  Juliette     163         1         0              0                   1                   0

Now we have a data set that can be consumed by Scikit-Learn, but first we need to split it into features (X) and classes (y):

X = df.iloc[:, 1:]   # every column except 'Suspect'
y = df.index.values  # each suspect's row index serves as its class
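As a quick sanity check, assuming the spreadsheet holds all 31 suspects:

print(X.shape)  # expected: (31, 6), i.e. 31 suspects and 6 feature columns
print(y[:5])    # expected: [0 1 2 3 4]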

Given that we are using numbers to label our suspects, it will come in handy later to be able to turn those numbers back into names:

labels = df['Suspect']
labels.head()
0       Jorge
1      Rajesh
2        Lucy
3       Heidi
4    Juliette
Name: Suspect, dtype: object

Model Training

It would be an exaggeration to say that we are ‘training’ the model, given that we only have 31 instances. In reality, the model is simply memorising the data set, but remember that the objective of this tutorial is merely to introduce the notion of decision trees.
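For context, conventional training would hold back a test set along the lines of the sketch below; this is pointless here, because each suspect is a class of their own, so a held-out suspect could never be predicted correctly:

from sklearn.model_selection import train_test_split

# What conventional training would look like (not useful for this data set)
X_train, X_test, y_train, y_test = train_test_split(
    X.values, y, test_size=0.2, random_state=12)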

There are a number of arguments that can be provided to DecisionTreeClassifier(). Here, we have just specified a fixed random state.

model = DecisionTreeClassifier(random_state=12)
model.fit(X.values, y)
model.score(X.values, y)
0.9354838709677419
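Notice that the score is not 1.0: it is 29/31 ≈ 0.935, which suggests the tree cannot tell two suspects apart, most likely because they share identical features. A quick check (the names printed depend on the spreadsheet):

# List the suspects the tree confuses with somebody else
for actual, predicted in zip(y, model.predict(X.values)):
    if actual != predicted:
        print("{} is mistaken for {}".format(labels[actual], labels[predicted]))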

Visualisation

Business people, especially auditors, love decision trees. Why? Because they can understand what the model has actually learned! Let’s see:

fnames = list(X.columns.values)
fig = plt.figure(figsize=(45,30))
_ = tree.plot_tree(model,
                   feature_names=fnames,
                   class_names=labels,
                   filled=True,
                   rounded=True)

You can’t read a thing, can you? No worries, we will look at a smaller tree in a moment. For now, consider the text representation below:

print(tree.export_text(model, feature_names=fnames))
|--- Height <= 155.00
|   |--- class: 28
|--- Height >  155.00
|   |--- Weapon_Rolling Pin <= 0.50
|   |   |--- Gender_F <= 0.50
|   |   |   |--- Weapon_Bottle <= 0.50
|   |   |   |   |--- Height <= 162.50
|   |   |   |   |   |--- class: 5
|   |   |   |   |--- Height >  162.50
|   |   |   |   |   |--- Height <= 168.50
|   |   |   |   |   |   |--- class: 11
|   |   |   |   |   |--- Height >  168.50
|   |   |   |   |   |   |--- Height <= 174.00
|   |   |   |   |   |   |   |--- class: 1
|   |   |   |   |   |   |--- Height >  174.00
|   |   |   |   |   |   |   |--- Height <= 180.50
|   |   |   |   |   |   |   |   |--- class: 17
|   |   |   |   |   |   |   |--- Height >  180.50
|   |   |   |   |   |   |   |   |--- Height <= 184.50
|   |   |   |   |   |   |   |   |   |--- class: 14
|   |   |   |   |   |   |   |   |--- Height >  184.50
|   |   |   |   |   |   |   |   |   |--- Height <= 186.50
|   |   |   |   |   |   |   |   |   |   |--- class: 25
|   |   |   |   |   |   |   |   |   |--- Height >  186.50
|   |   |   |   |   |   |   |   |   |   |--- class: 19
|   |   |   |--- Weapon_Bottle >  0.50
|   |   |   |   |--- Height <= 164.00
|   |   |   |   |   |--- class: 8
|   |   |   |   |--- Height >  164.00
|   |   |   |   |   |--- Height <= 166.00
|   |   |   |   |   |   |--- class: 30
|   |   |   |   |   |--- Height >  166.00
|   |   |   |   |   |   |--- Height <= 171.50
|   |   |   |   |   |   |   |--- class: 16
|   |   |   |   |   |   |--- Height >  171.50
|   |   |   |   |   |   |   |--- Height <= 179.50
|   |   |   |   |   |   |   |   |--- class: 23
|   |   |   |   |   |   |   |--- Height >  179.50
|   |   |   |   |   |   |   |   |--- Height <= 183.50
|   |   |   |   |   |   |   |   |   |--- class: 0
|   |   |   |   |   |   |   |   |--- Height >  183.50
|   |   |   |   |   |   |   |   |   |--- class: 21
|   |   |--- Gender_F >  0.50
|   |   |   |--- Weapon_Cricket Bat <= 0.50
|   |   |   |   |--- Height <= 164.50
|   |   |   |   |   |--- class: 2
|   |   |   |   |--- Height >  164.50
|   |   |   |   |   |--- Height <= 167.50
|   |   |   |   |   |   |--- class: 13
|   |   |   |   |   |--- Height >  167.50
|   |   |   |   |   |   |--- Height <= 176.50
|   |   |   |   |   |   |   |--- class: 18
|   |   |   |   |   |   |--- Height >  176.50
|   |   |   |   |   |   |   |--- class: 29
|   |   |   |--- Weapon_Cricket Bat >  0.50
|   |   |   |   |--- Height <= 166.00
|   |   |   |   |   |--- class: 4
|   |   |   |   |--- Height >  166.00
|   |   |   |   |   |--- Height <= 171.00
|   |   |   |   |   |   |--- class: 22
|   |   |   |   |   |--- Height >  171.00
|   |   |   |   |   |   |--- class: 9
|   |--- Weapon_Rolling Pin >  0.50
|   |   |--- Gender_F <= 0.50
|   |   |   |--- Height <= 157.50
|   |   |   |   |--- class: 20
|   |   |   |--- Height >  157.50
|   |   |   |   |--- Height <= 162.00
|   |   |   |   |   |--- class: 24
|   |   |   |   |--- Height >  162.00
|   |   |   |   |   |--- Height <= 173.50
|   |   |   |   |   |   |--- class: 6
|   |   |   |   |   |--- Height >  173.50
|   |   |   |   |   |   |--- class: 12
|   |   |--- Gender_F >  0.50
|   |   |   |--- Height <= 158.50
|   |   |   |   |--- class: 10
|   |   |   |--- Height >  158.50
|   |   |   |   |--- Height <= 160.00
|   |   |   |   |   |--- class: 15
|   |   |   |   |--- Height >  160.00
|   |   |   |   |   |--- Height <= 168.00
|   |   |   |   |   |   |--- class: 3
|   |   |   |   |   |--- Height >  168.00
|   |   |   |   |   |   |--- class: 7
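Out of curiosity, the size of the tree can also be queried directly:

print("Depth:  {}".format(model.get_depth()))
print("Leaves: {}".format(model.get_n_leaves()))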

Model Predictions (Arresting Suspects)

In the beginning we said that the police believe the murderer is likely to be a female whose height is about 165cm, and that the murder weapon was either a cricket bat or a rolling pin.

How do we provide this data to our decision tree? Let’s define each dimension as a few properties:

base_suspect = { 'Height' : 165,
                 'Gender' : 'F',
                 'Weapon' : None }

Now we’ll define a utility function to translate the above properties into features, make a prediction, and name the suspect that should be arrested:

def predict(m, props):
    features = { 'Height' : props['Height'],
                 'Gender_F' : 1 if props['Gender'] == 'F' else 0,
                 'Gender_M' : 1 if props['Gender'] == 'M' else 0,
                 'Weapon_Bottle' : 1 if props['Weapon'] == 'Bottle' else 0,
                 'Weapon_Cricket Bat' : 1 if props['Weapon'] == 'Cricket Bat' else 0,
                 'Weapon_Rolling Pin' : 1 if props['Weapon'] == 'Rolling Pin' else 0 }
    print("Features: {}".format(features))
    p = m.predict([list(features.values())])
    suspect = labels[p].values[0]
    print("Suspect: {}".format(suspect))

With the height set to 165cm, the gender to female, and no weapon specified, the model suggests that the police should arrest Lada:

predict(model, base_suspect)
Features: {'Height': 165, 'Gender_F': 1, 'Gender_M': 0, 'Weapon_Bottle': 0, 'Weapon_Cricket Bat': 0, 'Weapon_Rolling Pin': 0}
Suspect: Lada

However, when we specify a murder weapon, the model selects different suspects: Heidi for the rolling pin, and Juliette for the cricket bat.

suspect_1 = dict(base_suspect)  # take a copy, so base_suspect stays intact
suspect_1['Weapon'] = 'Rolling Pin'
predict(model, suspect_1)
Features: {'Height': 165, 'Gender_F': 1, 'Gender_M': 0, 'Weapon_Bottle': 0, 'Weapon_Cricket Bat': 0, 'Weapon_Rolling Pin': 1}
Suspect: Heidi
suspect_2 = dict(base_suspect)
suspect_2['Weapon'] = 'Cricket Bat'
predict(model, suspect_2)
Features: {'Height': 165, 'Gender_F': 1, 'Gender_M': 0, 'Weapon_Bottle': 0, 'Weapon_Cricket Bat': 1, 'Weapon_Rolling Pin': 0}
Suspect: Juliette

Analysis

The police are only certain that the murderer is a female. Her height is ‘about’ 165cm, and the murder weapon is believed not to be a bottle. Let’s see which other suspects fit the description:

closest_suspects = df_raw[ (df_raw['Gender'] == 'F') &
                           (df_raw['Height'] <= 175) &
                           (df_raw['Weapon'] != 'Bottle') ]
closest_suspects

     Suspect Gender  Height       Weapon
3      Heidi      F     161  Rolling Pin
4   Juliette      F     163  Cricket Bat
7      Radka      F     175  Rolling Pin
9     Sunita      F     173  Cricket Bat
10   Claudia      F     158  Rolling Pin
15      Lori      F     159  Rolling Pin
22   Isidora      F     169  Cricket Bat

The decision tree suggested the arrest of Heidi and Juliette, assuming that the weapon was either a rolling pin, or a cricket bat, respectively.

The model discarded the other suspects because of their height difference. Here we order the suspects by their absolute difference from the target height of 165cm:

by_height = closest_suspects.copy(deep=True)
by_height['Height Difference'] = (closest_suspects['Height'] - 165).abs()
by_height.sort_values(by=['Height Difference'])

     Suspect Gender  Height       Weapon  Height Difference
4   Juliette      F     163  Cricket Bat                  2
3      Heidi      F     161  Rolling Pin                  4
22   Isidora      F     169  Cricket Bat                  4
15      Lori      F     159  Rolling Pin                  6
10   Claudia      F     158  Rolling Pin                  7
9     Sunita      F     173  Cricket Bat                  8
7      Radka      F     175  Rolling Pin                 10

In the table above we see that the model’s choice of Juliette and Heidi was the right call, given their proximity to the target height. That said, Isidora might as well be arrested, given that a 4cm difference can easily be the result of, say, wearing high heels.

Limitations

Our model is atypical in that there are no repeated labels. Most conventional data sets, such as the Iris data set, contain many instances that map to a small number of labels.

In addition, our decision tree is actually an “expert system” (an old term from the 80s) in the sense that it has simply memorised the data as a collection of rules, rather than making a statistical inference. For larger data sets, decision trees cannot be allowed to grow out of control; they require pruning. Pruning, of course, is a trade-off between scale and accuracy.

For example, let’s set the depth to three:

model_pre_pruned = DecisionTreeClassifier(random_state=12, max_depth=3)
model_pre_pruned.fit(X.values, y)
print("Scores")
print("Depth = None: {}".format(model.score(X.values, y)))
print("Depth = 3:    {}".format(model_pre_pruned.score(X.values, y)))
Scores
Depth = None: 0.9354838709677419
Depth = 3:    0.16129032258064516

Let’s visualise the resulting tree:

fig = plt.figure(figsize=(25,20))
_ = tree.plot_tree(model_pre_pruned,
                   feature_names=fnames,
                   class_names=labels,
                   filled=True,
                   rounded=True)

As expected, for our particular case, pre-pruning leads to wrong results, and the arrest of an innocent lady! For our cricket-bat suspect, the pruned tree picks Lucy rather than Juliette:

predict(model_pre_pruned, suspect_2)
Features: {'Height': 165, 'Gender_F': 1, 'Gender_M': 0, 'Weapon_Bottle': 0, 'Weapon_Cricket Bat': 1, 'Weapon_Rolling Pin': 0}
Suspect: Lucy
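Capping the depth up front is only one option. Scikit-Learn also supports cost-complexity post-pruning through the ccp_alpha argument; a minimal sketch, with an arbitrary alpha value:

# Post-pruning: larger ccp_alpha values prune more aggressively
model_post_pruned = DecisionTreeClassifier(random_state=12, ccp_alpha=0.02)
model_post_pruned.fit(X.values, y)
model_post_pruned.score(X.values, y)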

Conclusion

Decision trees are an effective model for data sets in which labels (like the suspects in our example) can ultimately be selected by splitting features into decision branches.

For example, if the suspect’s height were known to be, say, 151cm, the murderer could only be Ximeno (class = 28); the gender and murder weapon wouldn’t even matter, as far as the model is concerned:

|--- Height <= 155.00
|   |--- class: 28
|--- Height >  155.00
|   |--- Weapon_Rolling Pin <= 0.50
...

Categorical values, such as the murder weapons, must be translated into numerical columns before a decision tree can use them; the ‘get dummies’ process shown in this tutorial does exactly that.

While our example is one in which the decision tree memorises the data set, bear in mind that the typical application is to draw inferences from large data sets, which requires pruning the decision tree. In that case, the decision tree behaves more like a statistical model, such as logistic regression, rather than our ‘binary choice’ one.
