## Building a decision tree from scratch - a beginner tutorial

by Patrick L. Lê

Suppose you have a population. You want to divide this population into relevant subgroups based on specific features characterizing each subgroup, so that you can accurately predict outcomes associated with each subgroup. For instance, you could:

- Use the list of the people on the Titanic and, by dividing them into subgroups depending on specific criteria (e.g. female vs male, passengers in 1st class vs 2nd and 3rd class, age class...), determine whether they were (probably) going to survive or not.
- Look at the people who bought a product on your e-commerce website, divide this population into segments depending on specific features (e.g. returning visitors vs new visitors, location, ...) and determine whether future visitors are (probably) going to buy your product or not.
In sum, you want to create a model that predicts the value of a target variable (e.g. survive/die; buy/not buy) based on simple decision rules inferred from the data features (e.g. female vs male, age, etc.). The result is a decision tree, which offers the great advantage of being easily visualized and simple to understand. For instance, the picture below, from Wikipedia, shows the probability that passengers of the Titanic survived depending on their sex, age and number of spouses or siblings aboard. Note how each branching is based on answering a question (the decision rule) and how the graph looks like an inverted tree. In this tutorial we will build a decision tree from scratch in Python and use it to classify future observations.

## 0. The battle plan

This tutorial is based on chapter 7 of Programming Collective Intelligence. The original code (for Python 2) is available on GitHub here. The code you find here is adapted to Python 3 and includes comments I've added. This tutorial is self-contained and you don't need the book to understand it (though I highly recommend buying it). I only assume a basic understanding of Python, and all the code needed to reproduce the results is included. For more explanations on decision trees, I recommend reading chapter 3 of Data Science for Business or watching this excellent video by Prof. De Freitas.

Suppose we have a list of visitors. Our target attribute is whether they bought a subscription to our service (possible values are `None`, `Basic` and `Premium`). We collected data on 16 visitors. The data is represented like this:
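(The dataset is reconstructed here from the two splits printed later in this tutorial; columns are Referer, Country, Read FAQ, # of pages visited, and the target attribute, the service chosen.)

```python
# Each row is one visitor.
# Columns: Referer (0), Country (1), Read FAQ (2), Pages visited (3),
# and the target attribute, Service chosen (4).
my_data = [
    ['slashdot', 'USA', 'yes', 18, 'None'],
    ['google', 'France', 'yes', 23, 'Premium'],
    ['digg', 'USA', 'yes', 24, 'Basic'],
    ['kiwitobes', 'France', 'yes', 23, 'Basic'],
    ['google', 'UK', 'no', 21, 'Premium'],
    ['(direct)', 'New Zealand', 'no', 12, 'None'],
    ['(direct)', 'UK', 'no', 21, 'Basic'],
    ['google', 'USA', 'no', 24, 'Premium'],
    ['slashdot', 'France', 'yes', 19, 'None'],
    ['digg', 'USA', 'no', 18, 'None'],
    ['google', 'UK', 'no', 18, 'None'],
    ['kiwitobes', 'UK', 'no', 19, 'None'],
    ['digg', 'New Zealand', 'yes', 12, 'Basic'],
    ['slashdot', 'UK', 'no', 21, 'None'],
    ['google', 'UK', 'yes', 18, 'Basic'],
    ['kiwitobes', 'France', 'yes', 19, 'Basic'],
]
```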
Note that Python considers the 1st column, Referer, as element 0 of the list, Country as element 1, etc.

In [1]:

In a nutshell: using our current data, we want to build a predictive model that will take the form of a tree, as shown below. This decision tree will help us classify future observations. For instance, according to this tree, if a new observation/visitor has Google as referer (1st decision node, called 0:google) and has read 21 pages or more (2nd decision node, on the right, called 3:21), it will probably buy a Premium subscription - as 3 previous observations already have (leaf Premium:3).

NB: The answer to each question is "False" on the left branch and "True" on the right branch. The first number refers to the number of the column concerned by the question (starting with column 0 = Referer, as Python starts to count with zero). Note that not all features are used to classify observations (e.g. country is not used) and some features are used multiple times (e.g. referer). Indeed, we will see that our algorithm picks the best decision rules to split groups.

In this tutorial we will:

- Learn how to split the dataset into children sets, i.e. from a list of all customers, return 2 subgroups based on a criterion we give (e.g. the criterion "did they read the FAQ?")
- Learn what entropy is, as it gives us a criterion to decide where to split. Basically, the idea is to split into subgroups (called child sets) that are homogeneous with regard to the target attribute (e.g. a set regrouping only visitors having chosen a Basic subscription)
- Build a tree recursively. We cut, then cut the subgroups, and cut the sub-subgroups, etc. This is what the graph shows !
- Represent the decision trees graphically. Look at our beautiful picture :)
- Use the built tree to classify new observations
## 1. Splitting the dataset

`divideset()` divides the set `rows` into 2 children sets, based on the number of the `column` to test and the `value` that it takes. E.g. `divideset(my_data,2,'yes')` divides the set into 2 subsets based on the content of column number 2 (i.e. the 3rd column, Read FAQ, as Python starts with 0), depending on whether its `value` is `yes` or `no`.

- If the data is numeric, the True criterion is that the value in this column is greater than or equal to the given `value`. If the data is not numeric, `split_function` simply determines whether the column's value is the same as `value`.
- The data is divided into two sets, one where `split_function` returns True (`set1`) and one where it returns False (`set2`).
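A Python 3 version of `divideset()` consistent with the description above might look like this (note the numeric test uses >=, greater than or equal, which is what produces the splits shown in the outputs below):

```python
def divideset(rows, column, value):
    # Pick the comparison: >= for numeric values, equality otherwise
    if isinstance(value, int) or isinstance(value, float):
        split_function = lambda row: row[column] >= value
    else:
        split_function = lambda row: row[column] == value
    # Divide the rows into two sets and return them
    set1 = [row for row in rows if split_function(row)]
    set2 = [row for row in rows if not split_function(row)]
    return (set1, set2)
```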
In [3]:
divideset(my_data,2,'yes') Out[3]:
([['slashdot', 'USA', 'yes', 18, 'None'], ['google', 'France', 'yes', 23, 'Premium'], ['digg', 'USA', 'yes', 24, 'Basic'], ['kiwitobes', 'France', 'yes', 23, 'Basic'], ['slashdot', 'France', 'yes', 19, 'None'], ['digg', 'New Zealand', 'yes', 12, 'Basic'], ['google', 'UK', 'yes', 18, 'Basic'], ['kiwitobes', 'France', 'yes', 19, 'Basic']], [['google', 'UK', 'no', 21, 'Premium'], ['(direct)', 'New Zealand', 'no', 12, 'None'], ['(direct)', 'UK', 'no', 21, 'Basic'], ['google', 'USA', 'no', 24, 'Premium'], ['digg', 'USA', 'no', 18, 'None'], ['google', 'UK', 'no', 18, 'None'], ['kiwitobes', 'UK', 'no', 19, 'None'], ['slashdot', 'UK', 'no', 21, 'None']])

- The function returns the initial dataset divided into 2 groups, depending on their attribute Read FAQ (`yes` or `no`). Note that there are 2 lists, one for each child set.
- We observe that, with regard to the target attribute, the group of people having read the FAQ does not distinguish itself much from the group that has not. In both subgroups we have a mix of `None`, `Basic` and `Premium`.
In [4]:
divideset(my_data,3,20) Out[4]:
([['google', 'France', 'yes', 23, 'Premium'], ['digg', 'USA', 'yes', 24, 'Basic'], ['kiwitobes', 'France', 'yes', 23, 'Basic'], ['google', 'UK', 'no', 21, 'Premium'], ['(direct)', 'UK', 'no', 21, 'Basic'], ['google', 'USA', 'no', 24, 'Premium'], ['slashdot', 'UK', 'no', 21, 'None']], [['slashdot', 'USA', 'yes', 18, 'None'], ['(direct)', 'New Zealand', 'no', 12, 'None'], ['slashdot', 'France', 'yes', 19, 'None'], ['digg', 'USA', 'no', 18, 'None'], ['google', 'UK', 'no', 18, 'None'], ['kiwitobes', 'UK', 'no', 19, 'None'], ['digg', 'New Zealand', 'yes', 12, 'Basic'], ['google', 'UK', 'yes', 18, 'Basic'], ['kiwitobes', 'France', 'yes', 19, 'Basic']])

- Here the division is # of pages visited >= 20 vs < 20.
- The division seems slightly better, but it's still difficult to draw conclusions. Which attribute should we choose? And based on what values (why 20 and not 21?)?
- It would be nice to have an objective criterion to split a group. For that we introduce the concept of entropy.
## 2. Introducing Entropy

If you look at our decision tree, you will notice that in each of the leaf nodes there is only one category of subscription. This is quite logical, as otherwise we would not be able to make good predictions!

In [5]:
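A sketch of `uniquecounts()`, which tallies the target attribute (the last column of each row):

```python
def uniquecounts(rows):
    # Count how many rows fall into each value of the target attribute,
    # which sits in the last column of each row
    results = {}
    for row in rows:
        r = row[len(row) - 1]
        results[r] = results.get(r, 0) + 1
    return results
```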
For a given set, `uniquecounts()` returns a dictionary counting how many rows fall into each possible value of the target attribute (the last column).

In [6]:
print(uniquecounts(my_data))

{'Premium': 3, 'Basic': 6, 'None': 7}

For instance, let's look at how 2 children sets, split based on their number of pages visited, differ with regard to their subscription:

In [7]:
We observe that the 1st set is less homogeneous than the 2nd set. Still, we need a metric to measure this. For that we use entropy. We won't delve into the mathematics of entropy (the formula is in the code below if necessary); just keep in mind the graph above.

In [8]:
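A self-contained sketch of `entropy()` (the class counting is inlined here so the cell runs on its own; the formula is the usual Shannon entropy, -Σ p·log2(p) over the class proportions p):

```python
from math import log2

def entropy(rows):
    # Count occurrences of each value of the target attribute (last column)
    counts = {}
    for row in rows:
        counts[row[-1]] = counts.get(row[-1], 0) + 1
    # Entropy is -sum(p * log2(p)) over the proportion p of each class
    ent = 0.0
    for count in counts.values():
        p = count / len(rows)
        ent -= p * log2(p)
    return ent
```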
The computation of entropy confirms what we observed above: the 1st set is less homogeneous than the 2nd set:

In [9]:
set1,set2=divideset(my_data,3,20) entropy(set1), entropy(set2) Out[9]:
(1.4488156357251847, 0.9182958340544896) Still, we made some progress as the whole group was even less homogeneous : In [10]:
entropy(my_data) Out[10]:
1.5052408149441479

## 3. Building the tree recursively

- To see how good an attribute is, the algorithm first calculates the entropy of the whole group.
- Then it tries dividing up the group by the possible values of each attribute and calculates the entropy of the two new groups. To determine which attribute is the best to divide on, the information gain (IG) is calculated: the difference between the current entropy and the weighted-average entropy of the two new groups. Intuitively, IG represents the extent to which a split reduced entropy / produced more homogeneous groups compared with the group you had before the split. Comparing the IG of various splits based on different decision rules enables us to choose the "best" split.
- The algorithm calculates the information gain for every attribute and chooses the one with the highest information gain.
- We repeat this process on each new group, again and again.
In [11]:
`col` is the column index of the criteria to be tested (e.g. Country is col 1)`value` is the value that the column must match to get a true result. (e.g. if Country=value='France' then True)`tb` and`fb` are decisionnodes, which are the next nodes in the tree if the result is true or false, respectively (e.g. go to node tb or fb).`results` stores a dictionary of results for this branch. This is None for decision nodes and only contains the target attribute and the number of units for endpoints. (e.g. Basic:3) Look at the recursive part of`buildtree()` and`printtree()` below to better understand this class.
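Following the description above, the `decisionnode` class can be sketched as:

```python
class decisionnode:
    def __init__(self, col=-1, value=None, results=None, tb=None, fb=None):
        self.col = col          # column index of the criterion to be tested
        self.value = value      # value the column must match for a True result
        self.results = results  # dict of outcomes for a leaf; None for decision nodes
        self.tb = tb            # next node if the result is True
        self.fb = fb            # next node if the result is False
```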
In [12]:
Let's just look at the non-recursive part:

- `rows` refers to the considered set, for instance the whole dataset `my_data`, or later subgroups.
- `for col in range(0,column_count)`: we loop through each attribute column, except the target attribute.
- `for row in rows`: `row` is a row of the table, e.g. `['slashdot', 'USA', 'yes', 18, 'None']`. The for loop iterates through each row: `row=my_data[0]`, then `row=my_data[1]`, etc.
- `column_values[row[col]]=1`: this creates an entry in the dictionary. The values do not matter, only the keys. We stay in one column and iterate through all rows, i.e. we loop through all the cells of that column and give each cell content the value 1, e.g. `slashdot=1, google=1, ...`. When an entry already exists, it is simply replaced with the same value, so `slashdot=1` is set several times. All this procedure does is create a dictionary whose keys are all the possible values of the attribute within a column.
- At the end of an iteration of the loop on `col`, we have a dictionary of the form `{1st possible value in that column: 1, 2nd possible value in that column: 1, etc.}`, e.g. `{slashdot: 1, google: 1, (direct): 1, etc}` for the 1st column/attribute.
- `for value in column_values.keys()`: iterates through each key and uses it to split (e.g. google/non-google).
The formula to compute information gain is:

IG(parent, children) = entropy(parent) - (entropy(child1) × prop(child1) + entropy(child2) × prop(child2))

where prop(child i) is the proportion of instances falling into child i.
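Putting the pieces together, `buildtree()` can be sketched as follows (this cell restates compact versions of the helpers defined earlier, so it stands on its own):

```python
from math import log2

class decisionnode:
    def __init__(self, col=-1, value=None, results=None, tb=None, fb=None):
        self.col, self.value, self.results, self.tb, self.fb = col, value, results, tb, fb

def divideset(rows, column, value):
    if isinstance(value, int) or isinstance(value, float):
        split_function = lambda row: row[column] >= value
    else:
        split_function = lambda row: row[column] == value
    return ([r for r in rows if split_function(r)],
            [r for r in rows if not split_function(r)])

def uniquecounts(rows):
    results = {}
    for row in rows:
        results[row[-1]] = results.get(row[-1], 0) + 1
    return results

def entropy(rows):
    ent = 0.0
    for count in uniquecounts(rows).values():
        p = count / len(rows)
        ent -= p * log2(p)
    return ent

def buildtree(rows):
    if len(rows) == 0:
        return decisionnode()
    current_score = entropy(rows)
    best_gain, best_criteria, best_sets = 0.0, None, None
    column_count = len(rows[0]) - 1  # skip the target attribute
    for col in range(0, column_count):
        # Collect the distinct values appearing in this column
        column_values = {}
        for row in rows:
            column_values[row[col]] = 1
        # Try splitting on each distinct value and keep the best gain
        for value in column_values.keys():
            set1, set2 = divideset(rows, col, value)
            # Information gain = parent entropy - weighted child entropies
            p = len(set1) / len(rows)
            gain = current_score - p * entropy(set1) - (1 - p) * entropy(set2)
            if gain > best_gain and len(set1) > 0 and len(set2) > 0:
                best_gain, best_criteria, best_sets = gain, (col, value), (set1, set2)
    if best_gain > 0:
        # Recurse on the two child sets
        return decisionnode(col=best_criteria[0], value=best_criteria[1],
                            tb=buildtree(best_sets[0]), fb=buildtree(best_sets[1]))
    # No split improves entropy: this is a leaf
    return decisionnode(results=uniquecounts(rows))
```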
To sum up: the algorithm looks at one column. It lists all possible values in that column. It then attempts each split and compares its information gain to the best split so far. Then it goes to the next column and does the same. At the end we keep the best split, i.e. the one that gave us the highest IG, i.e. that reduced entropy the most, i.e. that formed the two most homogeneous groups. The algorithm is then applied recursively to build the different branches. To learn more about entropy and information gain, you can have a look at this Stackoverflow post.

In [13]:
tree=buildtree(my_data)

Before we jump in, let's have a look at how the tree has been built, exploring some branches. Note that `results` is `None` for decision nodes and holds the outcome counts only for leaf nodes.
In [14]:
0 google None
3 21 None
-1 None {'Premium': 3}
2 yes None

## 4. Displaying the tree

We now introduce 2 visualizations, one with text and one as a graph.

In [15]:
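A Python 3 sketch of `printtree()` (a minimal `decisionnode` is repeated here so the cell is self-contained):

```python
class decisionnode:  # minimal node, matching the earlier definition
    def __init__(self, col=-1, value=None, results=None, tb=None, fb=None):
        self.col, self.value, self.results, self.tb, self.fb = col, value, results, tb, fb

def printtree(tree, indent=''):
    if tree.results is not None:
        # Leaf node: print the outcome counts
        print(str(tree.results))
    else:
        # Decision node: print the rule, then both branches, indented
        print(str(tree.col) + ':' + str(tree.value) + '?')
        print(indent + 'T->', end=' ')
        printtree(tree.tb, indent + '  ')
        print(indent + 'F->', end=' ')
        printtree(tree.fb, indent + '  ')
```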
In [16]:
printtree(tree)

PS: In Python 3, the trailing comma that suppressed the newline in Python 2 has been replaced by the `end` keyword argument of `print()`. While this first representation is OK for practical purposes, let's do a graph :)

In [17]:
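`drawtree()` renders the tree to a JPEG using an imaging library (PIL/Pillow in the book); its layout is driven by two small recursive helpers that measure the tree, sketched here with a minimal `decisionnode` for self-containment:

```python
class decisionnode:  # minimal node, matching the earlier definition
    def __init__(self, col=-1, value=None, results=None, tb=None, fb=None):
        self.col, self.value, self.results, self.tb, self.fb = col, value, results, tb, fb

def getwidth(tree):
    # A leaf occupies one slot; a decision node is as wide as its two branches
    if tree.tb is None and tree.fb is None:
        return 1
    return getwidth(tree.tb) + getwidth(tree.fb)

def getdepth(tree):
    # Depth of the longest branch below this node
    if tree.tb is None and tree.fb is None:
        return 0
    return max(getdepth(tree.tb), getdepth(tree.fb)) + 1
```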
In [18]:
drawtree(tree,jpeg='treeview.jpg')

How to read the tree:

- Above each decision node, there is the concerned column as well as the criterion generating True. E.g. 0:google means "Is the referer google?"; 3:21 means "Is the # of pages visited >= 21?"
- In the generated tree diagrams, the "True branch" (i.e. the branch leading to the child set fulfilling the condition) is always the right-hand branch, so you can follow the reasoning process through. Note that here, True can mean answering "no" to a question, as in 2:no.
- The leaf nodes display the target attribute with the number of units in them. E.g. Basic:3 means that among the original 16 observations, 3 of them landed in this leaf.
## 5. Classifying new observations

In [19]:
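`classify()` walks the tree with the same test logic as `divideset()` until it reaches a leaf, then returns that leaf's `results` dictionary. A sketch (again with a minimal `decisionnode` so the cell stands alone):

```python
class decisionnode:  # minimal node, matching the earlier definition
    def __init__(self, col=-1, value=None, results=None, tb=None, fb=None):
        self.col, self.value, self.results, self.tb, self.fb = col, value, results, tb, fb

def classify(observation, tree):
    # Leaf node: return the dictionary of outcome counts
    if tree.results is not None:
        return tree.results
    # Decision node: apply the same test as divideset, then recurse
    v = observation[tree.col]
    if isinstance(v, int) or isinstance(v, float):
        branch = tree.tb if v >= tree.value else tree.fb
    else:
        branch = tree.tb if v == tree.value else tree.fb
    return classify(observation, branch)
```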
In [20]:
classify(['(direct)','USA','yes',5],tree) Out[20]:
{'Basic': 4}

This means that the new unit will take a `Basic` subscription. Following the tree:

- Did it use Google as Referer? False (going left branch)
- Did it use Slashdot as Referer ? False (going left branch)
- Was the answer to reading the FAQ "No"? False (going left => subscribing `Basic`)
In [21]:
classify(['(direct)','USA','no',23],tree) Out[21]:
{'Basic': 1}

This means that the new unit will take a `Basic` subscription. Following the tree:

- Did it use Google as Referer? False (going left branch)
- Did it use Slashdot as Referer ? False (going left branch)
- Was the answer to reading the FAQ "No"? True (going right)
- Did it read 21 pages or more? True (=> subscribing `Basic`)
## 6. Conclusion

We only scratched the surface of how to build decision trees. Other topics to explore include how to prune the tree (to avoid overfitting) and how to deal with missing data. What you should keep in mind are the advantages of decision trees:

- You get a transparent model whose decision rules you actually understand and that can be easily interpreted
- You can easily mix categorical and numerical data
- You don't need much data preparation (no strong assumptions on what the data should look like)
An important weakness you can derive from this tutorial is that the decision rules could have been quite different if the data set had been slightly different (try it yourself!). In other words, a decision tree is prone to overfit and may not generalize well - it can be quite unstable. This issue can be mitigated by combining multiple decision trees in a technique called random forest.

I hope you enjoyed this tutorial!

Patrick