Decision Trees with Scikit-Learn

Introduction
The police know that the murderer is likely a female whose height is about 165cm, and that the murder weapon appears to be either a cricket bat or a rolling pin. Thirty-one suspects are connected to the crime scene. Who should the police arrest?
Let’s play with decision trees using Scikit-Learn to find out!
Learning Objective
Decision trees are typically trained on large data sets with the aim of boiling down complex features into a limited number of decision points (i.e., ‘splits’).
This tutorial’s objective is simply to introduce the notion of decision trees as a machine learning tool, using an intuitive example.
Obligatory Imports and Boilerplate
import numpy as np
import pandas as pd
import re
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.tree import DecisionTreeClassifier
The Suspects
The police have identified 31 suspects. In addition to recording each suspect’s gender and height, the police have searched their possessions and found one potential murder weapon per suspect. Download the spreadsheet here.
df_raw = pd.read_excel('suspects.xlsx')
df_raw.head()
| | Suspect | Gender | Height | Weapon |
|---|---|---|---|---|
| 0 | Jorge | M | 183 | Bottle |
| 1 | Rajesh | M | 170 | Cricket Bat |
| 2 | Lucy | F | 162 | Bottle |
| 3 | Heidi | F | 161 | Rolling Pin |
| 4 | Juliette | F | 163 | Cricket Bat |
Decision Tree-Friendly Data
Decision trees understand numerical values. The suspects’ heights are continuous values, so they are fine, but what about gender and weapon?
Unlike with some other machine learning methods, we can’t simply replace these with arbitrary numbers without creating further challenges. For example, if we were to replace ‘M’ (male) with 1 and ‘F’ (female) with 2, what would be the interpretation of a gender of value 1.5, or 3?
Variables such as gender and weapon are so-called categorical variables which, as far as decision trees are concerned, are best treated as sets of binary (0/1) indicator columns. This means that rather than something like M=1 and F=2, we need to create two columns, Gender_M and Gender_F, as follows:
| Gender_M | Gender_F | Meaning |
|---|---|---|
| 1 | 0 | Male |
| 0 | 1 | Female |
In essence, each category becomes a new feature, as far as decision trees are concerned. What if we have many options, like in the case of the murder weapon? Do we require a new column for each murder weapon too? Yes, madam.
But data scientists hate overworking themselves, so Pandas has a nifty function called get_dummies() which does all of this heavy lifting for us. We only need to specify which columns should be subjected to this process:
df = pd.get_dummies(df_raw, columns=['Gender','Weapon'])
df.head()
| | Suspect | Height | Gender_F | Gender_M | Weapon_Bottle | Weapon_Cricket Bat | Weapon_Rolling Pin |
|---|---|---|---|---|---|---|---|
| 0 | Jorge | 183 | 0 | 1 | 1 | 0 | 0 |
| 1 | Rajesh | 170 | 0 | 1 | 0 | 1 | 0 |
| 2 | Lucy | 162 | 1 | 0 | 1 | 0 | 0 |
| 3 | Heidi | 161 | 1 | 0 | 0 | 0 | 1 |
| 4 | Juliette | 163 | 1 | 0 | 0 | 1 | 0 |
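A quick note: depending on your Pandas version, get_dummies() may return boolean (True/False) columns rather than the 0/1 integers shown above. Decision trees handle either just fine, but if you prefer integer dummies you can pass the dtype argument:

df = pd.get_dummies(df_raw, columns=['Gender','Weapon'], dtype=int)  # force 0/1 columns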
Now we have a data set that can be consumed by Scikit-Learn, but first we need to split it into features (X) and classes (y):
X = df.iloc[:, 1:]   # every column except 'Suspect'
y = df.index.values  # the row index (0-30) doubles as the class label
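As a quick sanity check, X should contain one row per suspect and six feature columns, while y should simply run from 0 to 30 (the expected values below follow directly from the tables above):

X.shape  # expected: (31, 6)
y[:5]    # expected: array([0, 1, 2, 3, 4])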
Given that we are using numbers for our labels (our suspects), it will come in handy later to be able to turn those numbers back into names:
labels = df['Suspect']
labels.head()
0 Jorge
1 Rajesh
2 Lucy
3 Heidi
4 Juliette
Name: Suspect, dtype: object
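For example, a numeric class can be mapped back to a name directly:

labels[0]  # -> 'Jorge'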
Model Training
It would be an exaggeration to say that we are ‘training’ the model, given that we only have 31 instances. In reality, the model is simply memorising the data set, but the objective of this tutorial is simply to introduce you to the notion of decision trees.
There are a number of arguments that can be provided to DecisionTreeClassifier(). Here, we have just specified a fixed random state.
model = DecisionTreeClassifier(random_state=12)
model.fit(X.values, y)
model.score(X.values, y)
0.9354838709677419
Note that even though the model is memorising the data, the score is not 1.0: 29 of the 31 suspects are classified correctly, which suggests that two suspects share exactly the same features and cannot be told apart by the tree.
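For reference, here is a sketch of the more commonly tuned DecisionTreeClassifier() arguments; apart from random_state, the values below are Scikit-Learn’s defaults and are shown only for illustration:

model = DecisionTreeClassifier(
    criterion='gini',      # impurity measure used to rank candidate splits ('gini' or 'entropy')
    max_depth=None,        # None lets the tree grow until every leaf is pure
    min_samples_split=2,   # minimum samples needed to split an internal node
    min_samples_leaf=1,    # minimum samples required at each leaf
    random_state=12)       # fixed seed so results are reproducible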
Visualisation
Business people, especially auditors, love decision trees. Why? Because they can understand what the model has actually learned! Let’s see:
fnames = list(X.columns.values)
fig = plt.figure(figsize=(45,30))
_ = tree.plot_tree(model,
                   feature_names=fnames,
                   class_names=labels,
                   filled=True,
                   rounded=True)
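The figure is enormous. If you want to study it properly, one option (our suggestion, not something this tutorial depends on) is to write it to disk at a higher resolution and zoom in with an image viewer:

fig.savefig('tree.png', dpi=150, bbox_inches='tight')  # save the figure for offline inspection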
Even so, you can’t read a thing inline, can you? No worries, we will show a smaller tree in a moment. For now, consider the text representation below:
print(tree.export_text(model, feature_names=fnames))
|--- Height <= 155.00
| |--- class: 28
|--- Height > 155.00
| |--- Weapon_Rolling Pin <= 0.50
| | |--- Gender_F <= 0.50
| | | |--- Weapon_Bottle <= 0.50
| | | | |--- Height <= 162.50
| | | | | |--- class: 5
| | | | |--- Height > 162.50
| | | | | |--- Height <= 168.50
| | | | | | |--- class: 11
| | | | | |--- Height > 168.50
| | | | | | |--- Height <= 174.00
| | | | | | | |--- class: 1
| | | | | | |--- Height > 174.00
| | | | | | | |--- Height <= 180.50
| | | | | | | | |--- class: 17
| | | | | | | |--- Height > 180.50
| | | | | | | | |--- Height <= 184.50
| | | | | | | | | |--- class: 14
| | | | | | | | |--- Height > 184.50
| | | | | | | | | |--- Height <= 186.50
| | | | | | | | | | |--- class: 25
| | | | | | | | | |--- Height > 186.50
| | | | | | | | | | |--- class: 19
| | | |--- Weapon_Bottle > 0.50
| | | | |--- Height <= 164.00
| | | | | |--- class: 8
| | | | |--- Height > 164.00
| | | | | |--- Height <= 166.00
| | | | | | |--- class: 30
| | | | | |--- Height > 166.00
| | | | | | |--- Height <= 171.50
| | | | | | | |--- class: 16
| | | | | | |--- Height > 171.50
| | | | | | | |--- Height <= 179.50
| | | | | | | | |--- class: 23
| | | | | | | |--- Height > 179.50
| | | | | | | | |--- Height <= 183.50
| | | | | | | | | |--- class: 0
| | | | | | | | |--- Height > 183.50
| | | | | | | | | |--- class: 21
| | |--- Gender_F > 0.50
| | | |--- Weapon_Cricket Bat <= 0.50
| | | | |--- Height <= 164.50
| | | | | |--- class: 2
| | | | |--- Height > 164.50
| | | | | |--- Height <= 167.50
| | | | | | |--- class: 13
| | | | | |--- Height > 167.50
| | | | | | |--- Height <= 176.50
| | | | | | | |--- class: 18
| | | | | | |--- Height > 176.50
| | | | | | | |--- class: 29
| | | |--- Weapon_Cricket Bat > 0.50
| | | | |--- Height <= 166.00
| | | | | |--- class: 4
| | | | |--- Height > 166.00
| | | | | |--- Height <= 171.00
| | | | | | |--- class: 22
| | | | | |--- Height > 171.00
| | | | | | |--- class: 9
| |--- Weapon_Rolling Pin > 0.50
| | |--- Gender_F <= 0.50
| | | |--- Height <= 157.50
| | | | |--- class: 20
| | | |--- Height > 157.50
| | | | |--- Height <= 162.00
| | | | | |--- class: 24
| | | | |--- Height > 162.00
| | | | | |--- Height <= 173.50
| | | | | | |--- class: 6
| | | | | |--- Height > 173.50
| | | | | | |--- class: 12
| | |--- Gender_F > 0.50
| | | |--- Height <= 158.50
| | | | |--- class: 10
| | | |--- Height > 158.50
| | | | |--- Height <= 160.00
| | | | | |--- class: 15
| | | | |--- Height > 160.00
| | | | | |--- Height <= 168.00
| | | | | | |--- class: 3
| | | | | |--- Height > 168.00
| | | | | | |--- class: 7
Model Predictions (Arresting Suspects)
At the beginning we said that the police believe the murderer is likely a female whose height is about 165cm, and that the murder weapon was either a cricket bat or a rolling pin.
How do we provide this data to our decision tree? Let’s define each dimension as a property:
base_suspect = { 'Height' : 165,
                 'Gender' : 'F',
                 'Weapon' : None}
Now we’ll define a utility function to translate the above properties into features, make a prediction, and name the suspect that should be arrested:
def predict(m, props):
    # One-hot encode the properties by hand, in the same column order as X
    features = { 'Height' : props['Height'],
                 'Gender_F' : 1 if props['Gender'] == 'F' else 0,
                 'Gender_M' : 1 if props['Gender'] == 'M' else 0,
                 'Weapon_Bottle' : 1 if props['Weapon'] == 'Bottle' else 0,
                 'Weapon_Cricket Bat' : 1 if props['Weapon'] == 'Cricket Bat' else 0,
                 'Weapon_Rolling Pin' : 1 if props['Weapon'] == 'Rolling Pin' else 0 }
    print("Features: {}".format(features))
    # The dict preserves insertion order, so the values line up with X's columns
    p = m.predict([list(features.values())])
    suspect = labels[p].values[0]
    print("Suspect: {}".format(suspect))
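A slightly more defensive variant, sketched below and not part of the original tutorial, takes the column order from X itself so the features cannot drift out of sync with the training data; it should behave exactly like predict() above:

def predict_df(m, props):
    # Start with every feature set to 0, in X's own column order
    features = {c: 0 for c in X.columns}
    features['Height'] = props['Height']
    # Flip on the matching one-hot columns, e.g. 'Gender_F' or 'Weapon_Bottle'
    for key in ('Gender', 'Weapon'):
        col = '{}_{}'.format(key, props[key])
        if col in features:
            features[col] = 1
    row = pd.DataFrame([features], columns=X.columns)
    return labels[m.predict(row.values)].values[0]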
Assuming only that the suspect’s height is 165cm and her gender female, with no weapon specified, the model suggests that the police should arrest Lada:
predict(model, base_suspect)
Features: {'Height': 165, 'Gender_F': 1, 'Gender_M': 0, 'Weapon_Bottle': 0, 'Weapon_Cricket Bat': 0, 'Weapon_Rolling Pin': 0}
Suspect: Lada
However, when we specify a murder weapon, the model selects different suspects: Heidi for the rolling pin, and Juliette for the cricket bat.
suspect_1 = base_suspect  # caution: an alias of base_suspect, not a copy (use base_suspect.copy() for an independent dict)
suspect_1['Weapon'] = 'Rolling Pin'
predict(model, suspect_1)
Features: {'Height': 165, 'Gender_F': 1, 'Gender_M': 0, 'Weapon_Bottle': 0, 'Weapon_Cricket Bat': 0, 'Weapon_Rolling Pin': 1}
Suspect: Heidi
suspect_2 = base_suspect  # also an alias: suspect_1, suspect_2 and base_suspect are the same dict
suspect_2['Weapon'] = 'Cricket Bat'
predict(model, suspect_2)
Features: {'Height': 165, 'Gender_F': 1, 'Gender_M': 0, 'Weapon_Bottle': 0, 'Weapon_Cricket Bat': 1, 'Weapon_Rolling Pin': 0}
Suspect: Juliette
Analysis
The police are only certain that the murderer is a female. Her height is ‘about’ 165cm, and the murder weapon is believed not to be a bottle. Let’s see which other suspects fit the description:
closest_suspects = df_raw[ (df_raw['Gender'] == 'F') &
                           (df_raw['Height'] <= 175) &
                           (df_raw['Weapon'] != 'Bottle') ]
closest_suspects
| | Suspect | Gender | Height | Weapon |
|---|---|---|---|---|
| 3 | Heidi | F | 161 | Rolling Pin |
| 4 | Juliette | F | 163 | Cricket Bat |
| 7 | Radka | F | 175 | Rolling Pin |
| 9 | Sunita | F | 173 | Cricket Bat |
| 10 | Claudia | F | 158 | Rolling Pin |
| 15 | Lori | F | 159 | Rolling Pin |
| 22 | Isidora | F | 169 | Cricket Bat |
The decision tree suggested the arrest of Heidi and Juliette, assuming that the weapon was a rolling pin or a cricket bat, respectively.
The model discarded the other suspects because of their height difference. Below we order the suspects by their absolute difference from the target height of 165cm:
by_height = closest_suspects.copy(deep=True)
by_height['Height Difference'] = (closest_suspects['Height'] - 165).abs()
by_height.sort_values(by=['Height Difference'])
| | Suspect | Gender | Height | Weapon | Height Difference |
|---|---|---|---|---|---|
| 4 | Juliette | F | 163 | Cricket Bat | 2 |
| 3 | Heidi | F | 161 | Rolling Pin | 4 |
| 22 | Isidora | F | 169 | Cricket Bat | 4 |
| 15 | Lori | F | 159 | Rolling Pin | 6 |
| 10 | Claudia | F | 158 | Rolling Pin | 7 |
| 9 | Sunita | F | 173 | Cricket Bat | 8 |
| 7 | Radka | F | 175 | Rolling Pin | 10 |
The table above shows that the model’s choice of Juliette and Heidi was the right call, given their proximity to the target height. That said, Isidora might just as well be arrested, given that a 4cm difference can easily be the result of, say, wearing high heels.
Limitations
Our model is atypical in that there are no repeated labels. In most conventional data sets, like the Iris data set, many instances point to a small number of labels.
In addition, our decision tree is really an “expert system” (an old term from the 80s) in the sense that it has simply memorised the data as a collection of rules, rather than making a statistical inference. For larger data sets, decision trees cannot be allowed to grow out of control; they require pruning. Pruning, of course, is a trade-off between the size of the tree and its accuracy on the training data.
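Besides pre-pruning with max_depth, as shown below, Scikit-Learn also supports post-pruning via minimal cost-complexity pruning. As a sketch (the alpha value here is illustrative, not tuned), a larger ccp_alpha prunes the tree more aggressively:

# Candidate alpha values for this particular tree
path = model.cost_complexity_pruning_path(X.values, y)
print(path.ccp_alphas)
# Refit with a (hypothetical) pruning strength
model_post_pruned = DecisionTreeClassifier(random_state=12, ccp_alpha=0.01)
model_post_pruned.fit(X.values, y)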
For example, let’s set the depth to three:
model_pre_pruned = DecisionTreeClassifier(random_state=12, max_depth=3)
model_pre_pruned.fit(X.values, y)
print("Scores")
print("Depth = None: {}".format(model.score(X.values, y)))
print("Depth = 3: {}".format(model_pre_pruned.score(X.values, y)))
Scores
Depth = None: 0.9354838709677419
Depth = 3: 0.16129032258064516
Let’s visualise the resulting tree:
fig = plt.figure(figsize=(25,20))
_ = tree.plot_tree(model_pre_pruned,
                   feature_names=fnames,
                   class_names=labels,
                   filled=True,
                   rounded=True)
As expected, for our particular case, pre-pruning leads to wrong results, and the arrest of an innocent lady! Note that both calls below print identical features: because suspect_1 and suspect_2 are aliases of the same dictionary, the final ‘Cricket Bat’ assignment overwrote the rolling pin.
predict(model_pre_pruned, suspect_1)
predict(model_pre_pruned, suspect_2)
Features: {'Height': 165, 'Gender_F': 1, 'Gender_M': 0, 'Weapon_Bottle': 0, 'Weapon_Cricket Bat': 1, 'Weapon_Rolling Pin': 0}
Suspect: Lucy
Features: {'Height': 165, 'Gender_F': 1, 'Gender_M': 0, 'Weapon_Bottle': 0, 'Weapon_Cricket Bat': 1, 'Weapon_Rolling Pin': 0}
Suspect: Lucy
Conclusion
Decision trees are an effective model for data sets in which labels (like the suspects in our example) can ultimately be selected by splitting features into decision branches.
For example, if the suspect’s height were known to be, say, 151cm, the murderer could only be Ximeno (class = 28), and the gender and murder weapon wouldn’t even matter, as far as the model is concerned:
|--- Height <= 155.00
| |--- class: 28
|--- Height > 155.00
| |--- Weapon_Rolling Pin <= 0.50
...
Decision trees require categorical values, such as murder weapons, to be translated into numerical values, for which a one-hot encoding (‘get dummies’) process is applied, as shown in this tutorial.
While our example is one in which the decision tree memorises the data set, bear in mind that the typical application is to draw inferences from large data sets, which requires pruning the decision tree. In that case, the decision tree behaves more like a statistical model, closer in spirit to logistic regression, than like our ‘binary choice’ one.