[DecisionTree] Building a decision tree from scratch - a beginner tutorial

by Patrick L. Lê

Suppose you have a population. You want to divide this population into relevant subgroups based on specific features characterizing each subgroup, so that you can accurately predict outcomes associated with each subgroup. For instance, you could:

  • Use the list of people aboard the Titanic, divide them into subgroups according to specific criteria (e.g. female vs male, passengers in 1st class vs 2nd and 3rd class, age group...), and determine whether they were (probably) going to survive or not.
  • Look at the people who bought a product on your e-commerce website, divide this population into segments according to specific features (e.g. returning vs new visitors, location...), and determine for future visitors whether they are (probably) going to buy your product or not.

In sum, you want to create a model that predicts the value of a target variable (e.g. survive/die; buy/not buy) based on simple decision rules inferred from the data features (e.g. female vs male, age, etc.).

The result is a decision tree. A decision tree has the great advantage of being easy to visualize and simple to understand. For instance, the picture below, from Wikipedia, shows the probability that Titanic passengers survived depending on their sex, age and number of spouses or siblings aboard. Note how each branching is based on answering a question (the decision rule), and how the graph looks like an inverted tree.
[Figure: decision tree of Titanic passenger survival, from Wikipedia]

In this tutorial we will build a decision tree from scratch in Python and use it to classify future observations.

0. The battle plan

This tutorial is based on chapter 7 of Programming Collective Intelligence. The original code (for Python 2) is available on GitHub here. The code you will find here has been adapted to Python 3 and includes comments I've added. This tutorial is self-contained, and you don't need the book to understand it (though I highly recommend buying it). I only assume a basic understanding of Python, and all the code needed to reproduce the results is included here.

For more explanations on decision trees, I recommend reading chapter 3 of Data Science for Business or watching this excellent video by Prof. De Freitas.

Suppose we have a list of visitors. Our target attribute is whether they bought a subscription to our service (possible values: None, Basic or Premium). To predict their behavior in a transparent way, we will use decision trees.

We collected data on 16 visitors. The data is represented like this:

Referer | Country | Read FAQ | # of webpages visited | Subscription (TARGET attribute)
Slashdot | USA | yes | 18 | None
Google | France | yes | 23 | Premium
... | ... | ... | ... | ...

Note that Python treats the 1st column, Referer, as element 0 of each row, Country as element 1, and so on.

In [1]:

my_data=[['slashdot','USA','yes',18,'None'],
        ['google','France','yes',23,'Premium'],
        ['digg','USA','yes',24,'Basic'],
        ['kiwitobes','France','yes',23,'Basic'],
        ['google','UK','no',21,'Premium'],
        ['(direct)','New Zealand','no',12,'None'],
        ['(direct)','UK','no',21,'Basic'],
        ['google','USA','no',24,'Premium'],
        ['slashdot','France','yes',19,'None'],
        ['digg','USA','no',18,'None'],
        ['google','UK','no',18,'None'],
        ['kiwitobes','UK','no',19,'None'],
        ['digg','New Zealand','yes',12,'Basic'],
        ['slashdot','UK','no',21,'None'],
        ['google','UK','yes',18,'Basic'],
        ['kiwitobes','France','yes',19,'Basic']]

In a nutshell: using our current data, we want to build a predictive model that takes the form of a tree, as shown below. This decision tree will help us classify future observations. For instance, suppose a new visitor has Google as Referer (1st decision node, called 0:google) and has visited 21 or more pages (2nd decision node, on the right, called 3:21). According to this tree, that visitor will probably buy a Premium subscription, as 3 previous visitors did (leaf Premium:3).

NB: The answer to each question is False on the left branch and True on the right branch. The first number is the index of the column the question is about (starting with column 0 = Referer, as Python starts counting at zero).

Note that not all features are used to classify observations (e.g. Country is not used), and some features are used multiple times (e.g. Referer). As we will see, this is because the algorithm picks the best decision rule at each split.
[Figure: the decision tree built in this tutorial]
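To make the idea of "decision rules" concrete, here is a hand-written sketch of the same tree expressed as plain nested if statements. It assumes the tree that buildtree() produces later in this tutorial (its printtree() output in section 4). The function name predict_by_rules is hypothetical and not part of the original code.

```python
def predict_by_rules(visitor):
    """Hypothetical hand-coded copy of the learned tree.

    visitor is a list [referer, country, read_faq, pages_visited];
    the country (column 1) is never consulted, as in the learned tree.
    """
    referer, _country, read_faq, pages = visitor
    if referer == 'google':                # node 0:google
        if pages >= 21:                    # node 3:21
            return 'Premium'
        return 'Basic' if read_faq == 'yes' else 'None'   # node 2:yes
    if referer == 'slashdot':              # node 0:slashdot
        return 'None'
    if read_faq == 'yes':                  # node 2:yes
        return 'Basic'
    return 'Basic' if pages >= 21 else 'None'             # node 3:21


print(predict_by_rules(['google', 'France', 'yes', 23]))  # Premium
print(predict_by_rules(['(direct)', 'USA', 'no', 23]))    # Basic
```

The point of the algorithm below is that nobody has to write these rules by hand: they are inferred from the data.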

In this tutorial we will :

  1. Learn how to split the dataset into child sets, i.e. from a list of all customers, return 2 subgroups based on a criterion we give (e.g. the criterion "did they read the FAQ?")
  2. Learn what entropy is, as it gives us a criterion to decide where to split. Basically, the idea is to split into subgroups (called child sets) that are homogeneous with regard to the target attribute (e.g. a set containing only visitors who chose the Basic subscription)
  3. Build a tree recursively. We cut, then cut the subgroups, and cut the sub-subgroups, etc. This is what the graph shows!
  4. Represent the decision trees graphically. Look at our beautiful picture :)
  5. Use the built tree to classify new observations

1. Dividing the set 

We now write a function to divide a set into 2 child sets. We will then try several divisions, keeping in mind that our goal is to obtain groups that are homogeneous with regard to the target attribute (e.g. a group of customers buying None, another buying Basic, etc.).

In [2]:

# Divides a set on a specific column. Can handle numeric or nominal values
def divideset(rows,column,value):
   # Make a function that tells us if a row is in the first group (true) or the second group (false)
   split_function=None
   if isinstance(value,int) or isinstance(value,float): # check if the value is a number i.e int or float
      split_function=lambda row:row[column]>=value
   else:
      split_function=lambda row:row[column]==value
   
   # Divide the rows into two sets and return them
   set1=[row for row in rows if split_function(row)]
   set2=[row for row in rows if not split_function(row)]
   return (set1,set2)
  • divideset() divides the set rows into 2 child sets, based on a column number and a value. E.g. divideset(my_data,2,'yes') divides the set into 2 subsets based on the content of column number 2, Read FAQ (i.e. the 3rd column, as Python starts counting at 0), depending on whether the value is yes or no.
  • If value is numeric, a row goes into the True set when its value in this column is greater than or equal to the given value. If value is not numeric, split_function simply checks whether the column's value is equal to value.
  • The data is divided into two sets, one where split_function returns True (set1) and one where it returns False (set2).
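Because the numeric test is >=, a row whose value equals the threshold ends up in the first (True) set. The sketch below repeats a lightly condensed copy of divideset() so it runs on its own, and checks this on three rows copied from my_data.

```python
def divideset(rows, column, value):
    """Condensed copy of the tutorial's divideset(), repeated for self-containment."""
    if isinstance(value, (int, float)):
        # Numeric value: True means "greater than or equal to the threshold"
        split_function = lambda row: row[column] >= value
    else:
        # Nominal value: True means "equal to the value"
        split_function = lambda row: row[column] == value
    set1 = [row for row in rows if split_function(row)]
    set2 = [row for row in rows if not split_function(row)]
    return set1, set2


# Three rows taken from my_data; column 3 is "# of webpages visited"
rows = [['google', 'UK', 'no', 21, 'Premium'],
        ['digg', 'USA', 'no', 18, 'None'],
        ['google', 'USA', 'no', 24, 'Premium']]

at_least_21, under_21 = divideset(rows, 3, 21)
print([r[3] for r in at_least_21])  # [21, 24]: exactly 21 pages counts as True
print([r[3] for r in under_21])     # [18]
```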
In [3]:
divideset(my_data,2,'yes')
Out[3]:
([['slashdot', 'USA', 'yes', 18, 'None'],
  ['google', 'France', 'yes', 23, 'Premium'],
  ['digg', 'USA', 'yes', 24, 'Basic'],
  ['kiwitobes', 'France', 'yes', 23, 'Basic'],
  ['slashdot', 'France', 'yes', 19, 'None'],
  ['digg', 'New Zealand', 'yes', 12, 'Basic'],
  ['google', 'UK', 'yes', 18, 'Basic'],
  ['kiwitobes', 'France', 'yes', 19, 'Basic']],
 [['google', 'UK', 'no', 21, 'Premium'],
  ['(direct)', 'New Zealand', 'no', 12, 'None'],
  ['(direct)', 'UK', 'no', 21, 'Basic'],
  ['google', 'USA', 'no', 24, 'Premium'],
  ['digg', 'USA', 'no', 18, 'None'],
  ['google', 'UK', 'no', 18, 'None'],
  ['kiwitobes', 'UK', 'no', 19, 'None'],
  ['slashdot', 'UK', 'no', 21, 'None']])
  • The function returns the initial dataset divided into 2 groups, depending on their Read FAQ attribute (yes or no). Note that there are 2 lists, one for each child set.
  • We observe that, with regard to the target attribute, the group of people who read the FAQ does not stand out much from the group who did not. Both subgroups contain a mix of None, Basic and Premium.
In [4]:
divideset(my_data,3,20)
Out[4]:
([['google', 'France', 'yes', 23, 'Premium'],
  ['digg', 'USA', 'yes', 24, 'Basic'],
  ['kiwitobes', 'France', 'yes', 23, 'Basic'],
  ['google', 'UK', 'no', 21, 'Premium'],
  ['(direct)', 'UK', 'no', 21, 'Basic'],
  ['google', 'USA', 'no', 24, 'Premium'],
  ['slashdot', 'UK', 'no', 21, 'None']],
 [['slashdot', 'USA', 'yes', 18, 'None'],
  ['(direct)', 'New Zealand', 'no', 12, 'None'],
  ['slashdot', 'France', 'yes', 19, 'None'],
  ['digg', 'USA', 'no', 18, 'None'],
  ['google', 'UK', 'no', 18, 'None'],
  ['kiwitobes', 'UK', 'no', 19, 'None'],
  ['digg', 'New Zealand', 'yes', 12, 'Basic'],
  ['google', 'UK', 'yes', 18, 'Basic'],
  ['kiwitobes', 'France', 'yes', 19, 'Basic']])
  • Here the division is # of pages visited >= 20 vs < 20.
  • The division seems slightly better, but it's still difficult to draw conclusions. Which attribute should we choose? And based on which value (why 20 and not 21?)?
  • It would be nice to have an objective criterion to split a group. For that we introduce the concept of entropy.

2. Introducing Entropy 

If you look at our decision tree, you will notice that each leaf node contains only one category of subscription. This is quite logical, as otherwise we would not be able to make good predictions!
Thus, at each split our goal is to maximize the homogeneity/purity of each child set with regard to the target attribute. This is what allows us to classify future observations well.

In [5]:
# Create counts of possible results (the last column of each row is the result)
def uniquecounts(rows):
   results={}
   for row in rows:
      # The result is the last column
      r=row[len(row)-1]
      if r not in results: results[r]=0
      results[r]+=1
   return results

For a given set, uniquecounts() counts how many units take each value of the target attribute (the last column).
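For comparison, the standard library's collections.Counter does the same tallying in one line. This is an alternative sketch, not the code used in the rest of the tutorial. The function name uniquecounts_counter is hypothetical, and it returns a Counter (a dict subclass) rather than a plain dict.

```python
from collections import Counter

def uniquecounts_counter(rows):
    """Count how many rows take each value of the target (last) column."""
    return Counter(row[-1] for row in rows)

# A few rows copied from my_data
rows = [['slashdot', 'USA', 'yes', 18, 'None'],
        ['google', 'France', 'yes', 23, 'Premium'],
        ['digg', 'USA', 'yes', 24, 'Basic'],
        ['slashdot', 'France', 'yes', 19, 'None']]
print(dict(uniquecounts_counter(rows)))  # {'None': 2, 'Premium': 1, 'Basic': 1}
```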

In [6]:
print(uniquecounts(my_data))
{'Premium': 3, 'Basic': 6, 'None': 7}

For instance, let's look at how the 2 child sets obtained by splitting on the number of pages visited differ with regard to their subscription:

In [7]:
print(divideset(my_data,3,20)[0])
print(uniquecounts(divideset(my_data,3,20)[0]))
print("")
print(divideset(my_data,3,20)[1])
print(uniquecounts(divideset(my_data,3,20)[1]))
[['google', 'France', 'yes', 23, 'Premium'], ['digg', 'USA', 'yes', 24, 'Basic'], ['kiwitobes', 'France', 'yes', 23, 'Basic'], ['google', 'UK', 'no', 21, 'Premium'], ['(direct)', 'UK', 'no', 21, 'Basic'], ['google', 'USA', 'no', 24, 'Premium'], ['slashdot', 'UK', 'no', 21, 'None']]
{'Premium': 3, 'Basic': 3, 'None': 1}

[['slashdot', 'USA', 'yes', 18, 'None'], ['(direct)', 'New Zealand', 'no', 12, 'None'], ['slashdot', 'France', 'yes', 19, 'None'], ['digg', 'USA', 'no', 18, 'None'], ['google', 'UK', 'no', 18, 'None'], ['kiwitobes', 'UK', 'no', 19, 'None'], ['digg', 'New Zealand', 'yes', 12, 'Basic'], ['google', 'UK', 'yes', 18, 'Basic'], ['kiwitobes', 'France', 'yes', 19, 'Basic']]
{'Basic': 3, 'None': 6}

We observe that the 1st set is less homogeneous than the 2nd set. Still, we need a metric to measure this. For that we use entropy.
Entropy is basically the opposite of purity/homogeneity within a set, with regard to the target attribute. As the graph below shows, the more mixed up a set is, the higher its entropy. With 2 classes, entropy is maximal when a set contains 50% of each class and is zero when the set is pure. When we split a set, our goal is for the child sets to have lower entropy than the parent set.
[Figure: entropy of a two-class set as a function of the class proportions. Illustration from Data Science for Business]

We won't delve into the mathematics of entropy (the formula is in the code below if needed); just keep in mind the graph above.
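To make the shape of that graph concrete, here is a small sketch that computes the entropy of a two-class set for a few class proportions. It is an illustration only. The function name binary_entropy is hypothetical, and it uses the same formula as the entropy() function below, restricted to two classes.

```python
from math import log2

def binary_entropy(p):
    """Entropy (in bits) of a set with a share p of class A and 1 - p of class B."""
    ent = 0.0
    for q in (p, 1 - p):
        if q > 0:              # a class with share 0 contributes nothing (0 * log2(0) taken as 0)
            ent -= q * log2(q)
    return ent

for p in (0.0, 0.1, 0.25, 0.5, 0.75, 0.9, 1.0):
    print(f"share of class A = {p:.2f} -> entropy = {binary_entropy(p):.3f}")
```

The printed values start at 0 for a pure set, rise to exactly 1 bit for a 50/50 mix, and fall back to 0, symmetrically, which is the curve in the graph.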

In [8]:
# Entropy is minus the sum of p(x)*log2(p(x)) across all
# the different possible results x, where p(x) is the share of rows with result x
def entropy(rows):
   from math import log
   log2=lambda x:log(x)/log(2)  
   results=uniquecounts(rows)
   # Now calculate the entropy
   ent=0.0
   for r in results.keys():
      p=float(results[r])/len(rows)
      ent=ent-p*log2(p)
   return ent

The computation of entropy confirms what we saw above: the 1st set is less homogeneous than the 2nd set:

In [9]:
set1,set2=divideset(my_data,3,20)
entropy(set1), entropy(set2)
Out[9]:
(1.4488156357251847, 0.9182958340544896)

Still, we made some progress, as the whole group was even less homogeneous:

In [10]:
entropy(my_data)
Out[10]:
1.5052408149441479

3. Building the tree recursively 

  • To see how good an attribute is, the algorithm first calculates the entropy of the whole group.
  • Then it tries dividing up the group by the possible values of each attribute and calculates the entropy of the two new groups. To determine which attribute is the best to divide on, it calculates the information gain (IG). IG is the difference between the current entropy and the weighted-average entropy of the two new groups, each group weighted by its share of the units. Intuitively, IG measures how much the split reduced entropy, i.e. how much more homogeneous the new groups are than the group before the split. Comparing the IG of splits based on different decision rules lets us choose the "best" split.
  • The algorithm calculates the information gain for every attribute and value, and chooses the split with the highest information gain.
  • It then repeats the process on each of the two new groups, and so on, until no split reduces entropy any further.
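As a worked example, the sketch below recomputes the information gain of the split on column 3 at value 20 from section 2. It works directly from the class counts printed there. The helper names entropy_from_counts and information_gain are hypothetical; they mirror the arithmetic inside buildtree() rather than reuse it.

```python
from math import log2

def entropy_from_counts(counts):
    """Entropy (bits) of a set described by its class counts, e.g. {'None': 7, ...}."""
    total = sum(counts.values())
    return -sum(c / total * log2(c / total) for c in counts.values() if c > 0)

def information_gain(parent, child1, child2):
    """Parent entropy minus the size-weighted entropy of the two child sets."""
    n = sum(parent.values())
    p1 = sum(child1.values()) / n          # share of the parent's units in child 1
    weighted = p1 * entropy_from_counts(child1) + (1 - p1) * entropy_from_counts(child2)
    return entropy_from_counts(parent) - weighted

# Class counts from divideset(my_data, 3, 20), as printed in section 2
parent = {'None': 7, 'Basic': 6, 'Premium': 3}
pages_at_least_20 = {'Premium': 3, 'Basic': 3, 'None': 1}
pages_below_20 = {'Basic': 3, 'None': 6}
print(round(information_gain(parent, pages_at_least_20, pages_below_20), 3))  # 0.355
```

The split lowers entropy from about 1.505 (whole dataset) to about 1.150 on average (7/16 × 1.449 + 9/16 × 0.918), a gain of about 0.355 bits. buildtree() compares this number with the gain of every other candidate split.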
In [11]:
class decisionnode:
  def __init__(self,col=-1,value=None,results=None,tb=None,fb=None):
    self.col=col
    self.value=value
    self.results=results
    self.tb=tb
    self.fb=fb
  • col is the column index of the criterion to be tested (e.g. Country is col 1)
  • value is the value that the column must match to get a True result (e.g. if Country == value == 'France' then True). For numeric columns, the column's value must instead be greater than or equal to value.
  • tb and fb are decisionnodes: the next nodes in the tree if the result is True or False, respectively.
  • results stores a dictionary of results for this branch. It is None for decision nodes. For leaf nodes (endpoints), it maps each value of the target attribute to its number of units (e.g. {'Basic': 3}). Look at the recursive part of buildtree() and printtree() below to better understand this class.
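To see how decision nodes and leaves fit together, the sketch below builds a one-question tree (a "stump") by hand for the question "# of webpages visited >= 20?". Its leaf counts come from the divideset(my_data, 3, 20) output in section 2. The stump is illustrative only: as we will see, buildtree() chooses different splits.

```python
class decisionnode:
    """Same fields as the tutorial's class, repeated so this sketch runs on its own."""
    def __init__(self, col=-1, value=None, results=None, tb=None, fb=None):
        self.col = col          # column tested at this node
        self.value = value      # value (or numeric threshold) the column is compared with
        self.results = results  # None for decision nodes; {target value: count} for leaves
        self.tb = tb            # child followed when the test is True
        self.fb = fb            # child followed when the test is False

stump = decisionnode(col=3, value=20,
                     tb=decisionnode(results={'Premium': 3, 'Basic': 3, 'None': 1}),
                     fb=decisionnode(results={'Basic': 3, 'None': 6}))

visitor = ['google', 'UK', 'no', 23]
# Numeric test, as in divideset(): >= goes to the True branch
branch = stump.tb if visitor[stump.col] >= stump.value else stump.fb
print(stump.results)   # None: the root is a decision node
print(branch.results)  # {'Premium': 3, 'Basic': 3, 'None': 1}
```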
In [12]:
def buildtree(rows,scoref=entropy): #rows is the set, either whole dataset or part of it in the recursive call, 
                                    #scoref is the method to measure heterogeneity. By default it's entropy.
  if len(rows)==0: return decisionnode() #len(rows) is the number of units in a set
  current_score=scoref(rows)

  # Set up some variables to track the best criteria
  best_gain=0.0
  best_criteria=None
  best_sets=None
  
  column_count=len(rows[0])-1   #count the # of attributes/columns. 
                                #It's -1 because the last one is the target attribute and it does not count.
  for col in range(0,column_count):
    # Generate the list of all possible different values in the considered column
    column_values={}
    for row in rows:
       column_values[row[col]]=1   
    # Now try dividing the rows up for each value in this column
    for value in column_values.keys(): #the 'values' here are the keys of the dictionnary
      (set1,set2)=divideset(rows,col,value) #define set1 and set2 as the 2 children set of a division
      
      # Information gain
      p=float(len(set1))/len(rows) #p is the size of a child set relative to its parent
      gain=current_score-p*scoref(set1)-(1-p)*scoref(set2) #cf. formula information gain
      if gain>best_gain and len(set1)>0 and len(set2)>0: #set must not be empty
        best_gain=gain
        best_criteria=(col,value)
        best_sets=(set1,set2)
        
  # Create the sub branches   
  if best_gain>0:
    # Pass scoref down so that a custom score function is used at every level
    trueBranch=buildtree(best_sets[0],scoref)
    falseBranch=buildtree(best_sets[1],scoref)
    return decisionnode(col=best_criteria[0],value=best_criteria[1],
                        tb=trueBranch,fb=falseBranch)
  else:
    return decisionnode(results=uniquecounts(rows))

Let's just look at the non-recursive part:

  • rows refers to the considered set: for instance the whole dataset my_data, or later a subgroup.
  • for col in range(0,column_count): we loop through each attribute column, except the target attribute.

    • for row in rows: row is a row of the table, e.g. ['slashdot', 'USA', 'yes', 18, 'None'], so the loop iterates through each row: row=my_data[0], then row=my_data[1], and so on.
    • column_values[row[col]]=1: this simply creates an entry in the dictionary. Only the keys matter, not the values. We stay in one column and loop through all its cells, using each cell's content as a key with the value 1 (e.g. slashdot=1, google=1...). When an entry already exists, it is overwritten with the same value, so slashdot appears only once even though several rows contain it. The whole procedure just builds a dictionary whose keys are all the possible values of the attribute in that column.
    • Once all rows have been visited, we have a dictionary of the form {1st possible value in that column: 1, 2nd possible value in that column: 1, etc.}, e.g. {'slashdot': 1, 'google': 1, '(direct)': 1, ...} for the 1st column/attribute.
    • for value in column_values.keys(): iterates through each key and uses it to split (e.g. Google vs non-Google).
    • The formula to compute information gain is:

      $$\mathrm{IG}(\text{parent}, \text{children}) = H(\text{parent}) - \left( p_1 \, H(\text{child}_1) + p_2 \, H(\text{child}_2) \right)$$

      where $H$ is the entropy and $p_i$ is the proportion of the parent's instances falling into child $i$.
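The dictionary in buildtree() is only a way to collect the distinct values of a column. A set comprehension does the same job more directly. The sketch below uses a few rows copied from my_data to show that both give the same candidate split values. One caveat: buildtree() only replaces the best split when the gain is strictly greater (gain>best_gain), so the first candidate wins a tie. Set iteration order for strings can change from run to run, so iterating over sorted(...) would be the reproducible choice if you swapped a set in.

```python
rows = [['slashdot', 'USA', 'yes', 18, 'None'],
        ['google', 'France', 'yes', 23, 'Premium'],
        ['slashdot', 'France', 'yes', 19, 'None'],
        ['google', 'UK', 'no', 21, 'Premium']]

col = 0  # Referer column

# Dictionary trick used in buildtree(): the keys are the distinct values, the 1s are ignored
column_values = {}
for row in rows:
    column_values[row[col]] = 1

# Equivalent set comprehension: each distinct value is kept once
distinct_values = {row[col] for row in rows}

print(sorted(column_values))    # ['google', 'slashdot']
print(sorted(distinct_values))  # ['google', 'slashdot']
```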

To sum up: the algorithm looks at one column and lists all possible values in that column. It then tries splitting on each value and compares the information gain with that of the best split found so far. Then it moves to the next column and does the same. At the end we keep the best split: the one with the highest IG, which is the one that reduced entropy the most and formed the most homogeneous groups. The algorithm is then applied recursively to build the different branches.

To learn more about entropy and information gain, you can have a look at this Stackoverflow post.

In [13]:
tree=buildtree(my_data)

4. Displaying the tree 

Before we jump in, let's have a look at how the tree has been built by exploring some branches. Note that results is None for decision nodes, as they do not hold units; only the leaves have values. Read the following output as follows:

col: the column concerned, e.g. 0
value: the value used for the split, e.g. google. In the graph, both col and value are displayed above each decision node.
results: None, or {value: #}, i.e. the value of the target attribute and the number of units classified in this leaf node, e.g. {'Premium': 3}. This is what is displayed on each leaf node.

In [14]:
print(tree.col)
print(tree.value)
print(tree.results)
print("")
print(tree.tb.col)
print(tree.tb.value)
print(tree.tb.results)
print("")
print(tree.tb.tb.col)
print(tree.tb.tb.value)
print(tree.tb.tb.results)
print("")
print(tree.tb.fb.col)
print(tree.tb.fb.value)
print(tree.tb.fb.results)
0
google
None

3
21
None

-1
None
{'Premium': 3}

2
yes
None

We now introduce 2 visualizations, one with text, printtree(), then one with graphics, drawtree().

In [15]:
def printtree(tree,indent=''):
   # Is this a leaf node?
    if tree.results!=None:
        print(str(tree.results))
    else:
        print(str(tree.col)+':'+str(tree.value)+'? ')
        # Print the branches
        print(indent+'T->', end=" ")
        printtree(tree.tb,indent+'  ')
        print(indent+'F->', end=" ")
        printtree(tree.fb,indent+'  ')
In [16]:
printtree(tree)
0:google? 
T-> 3:21? 
  T-> {'Premium': 3}
  F-> 2:yes? 
    T-> {'Basic': 1}
    F-> {'None': 1}
F-> 0:slashdot? 
  T-> {'None': 3}
  F-> 2:yes? 
    T-> {'Basic': 4}
    F-> 3:21? 
      T-> {'Basic': 1}
      F-> {'None': 3}

PS: In Python 3, the trailing comma that suppresses the newline in Python 2's print statement has been replaced by the end=" " argument of print(). See the explanation in the docs here.
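Here is the same mechanism in isolation. To make the result easy to inspect, it writes to an in-memory buffer (io.StringIO, passed through print()'s file argument) instead of the screen. In Python 2, the first call would have been written print 'T->', with a trailing comma.

```python
import io

buf = io.StringIO()
# end=' ' replaces the default newline, so the next print continues on the same line
print('T->', end=' ', file=buf)
print("{'Premium': 3}", file=buf)
print(repr(buf.getvalue()))  # "T-> {'Premium': 3}\n"
```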

While this first representation is OK for practical purposes, let's do a graph :)

In [17]:
def getwidth(tree):
  if tree.tb==None and tree.fb==None: return 1
  return getwidth(tree.tb)+getwidth(tree.fb)

def getdepth(tree):
  if tree.tb==None and tree.fb==None: return 0
  return max(getdepth(tree.tb),getdepth(tree.fb))+1


from PIL import Image,ImageDraw

def drawtree(tree,jpeg='tree.jpg'):
  w=getwidth(tree)*100
  h=getdepth(tree)*100+120

  img=Image.new('RGB',(w,h),(255,255,255))
  draw=ImageDraw.Draw(img)

  drawnode(draw,tree,w/2,20)
  img.save(jpeg,'JPEG')
  
def drawnode(draw,tree,x,y):
  if tree.results==None:
    # Get the width of each branch
    w1=getwidth(tree.fb)*100
    w2=getwidth(tree.tb)*100

    # Determine the total space required by this node
    left=x-(w1+w2)/2
    right=x+(w1+w2)/2

    # Draw the condition string
    draw.text((x-20,y-10),str(tree.col)+':'+str(tree.value),(0,0,0))

    # Draw links to the branches
    draw.line((x,y,left+w1/2,y+100),fill=(255,0,0))
    draw.line((x,y,right-w2/2,y+100),fill=(255,0,0))
    
    # Draw the branch nodes
    drawnode(draw,tree.fb,left+w1/2,y+100)
    drawnode(draw,tree.tb,right-w2/2,y+100)
  else:
    txt=' \n'.join(['%s:%d'%v for v in tree.results.items()])
    draw.text((x-20,y),txt,(0,0,0))
In [18]:
drawtree(tree,jpeg='treeview.jpg')

[Figure: the tree drawn by drawtree(), saved as treeview.jpg]
How to read the tree:

  • Above each decision node are the column concerned and the criterion that yields True. E.g. 0:google means "Is the referer Google?"; 3:21 means "Is the # of pages visited >= 21?"
  • In the generated tree diagrams, the True branch (i.e. the branch leading to the child set that fulfills the condition) is always the right-hand branch, so you can follow the reasoning process through. Note that here, True can mean answering "no" to a question, as in 2:no.
  • The leaf nodes display the value of the target attribute with the number of units in the leaf. E.g. Premium:3 means that 3 of the original 16 observations landed in this leaf.
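drawtree() sizes the image with two helpers. getwidth() counts the leaves, at 100 pixels each. getdepth() counts the decision levels on the longest path, at 100 pixels each plus a 120-pixel margin. As a quick check of that arithmetic, the sketch below rebuilds by hand the tree shown by printtree() in section 4 and computes the image size. The leaf() helper is hypothetical shorthand for a leaf node.

```python
class decisionnode:
    """Same fields as the tutorial's class, repeated so the sketch is self-contained."""
    def __init__(self, col=-1, value=None, results=None, tb=None, fb=None):
        self.col, self.value, self.results, self.tb, self.fb = col, value, results, tb, fb

def leaf(results):
    # Hypothetical shorthand: a node with results and no children
    return decisionnode(results=results)

def getwidth(tree):
    # Number of leaves under this node
    if tree.tb is None and tree.fb is None:
        return 1
    return getwidth(tree.tb) + getwidth(tree.fb)

def getdepth(tree):
    # Number of decision levels on the longest root-to-leaf path
    if tree.tb is None and tree.fb is None:
        return 0
    return max(getdepth(tree.tb), getdepth(tree.fb)) + 1

# Hand-built copy of the tree printed by printtree(tree)
tree = decisionnode(
    0, 'google',
    tb=decisionnode(
        3, 21,
        tb=leaf({'Premium': 3}),
        fb=decisionnode(2, 'yes', tb=leaf({'Basic': 1}), fb=leaf({'None': 1}))),
    fb=decisionnode(
        0, 'slashdot',
        tb=leaf({'None': 3}),
        fb=decisionnode(
            2, 'yes',
            tb=leaf({'Basic': 4}),
            fb=decisionnode(3, 21, tb=leaf({'Basic': 1}), fb=leaf({'None': 3})))))

w = getwidth(tree) * 100        # same sizing rule as drawtree()
h = getdepth(tree) * 100 + 120
print(getwidth(tree), getdepth(tree), (w, h))  # 7 4 (700, 520)
```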

5. Classifying new observations 

Now that we have built our tree, we can feed it new observations and classify them. The following code does what we could do manually by following the tree and answering its questions.

In [19]:
def classify(observation,tree):
  if tree.results!=None:
    return tree.results
  else:
    v=observation[tree.col]
    branch=None
    if isinstance(v,int) or isinstance(v,float):
      if v>=tree.value: branch=tree.tb
      else: branch=tree.fb
    else:
      if v==tree.value: branch=tree.tb
      else: branch=tree.fb
    return classify(observation,branch)
In [20]:
classify(['(direct)','USA','yes',5],tree)
Out[20]:
{'Basic': 4}

This means that the new visitor will probably take a Basic subscription, ending up in the leaf called "Basic:4". The full path is:

  • Did it use Google as Referer? False (left branch)
  • Did it use Slashdot as Referer? False (left branch)
  • Was the answer to reading the FAQ "No"? False (left branch => Basic subscription)
In [21]:
classify(['(direct)','USA','no',23],tree)
Out[21]:
{'Basic': 1}

This means that the new visitor will probably take a Basic subscription, ending up in the leaf called "Basic:1". The full path is:

  • Did it use Google as Referer? False (left branch)
  • Did it use Slashdot as Referer? False (left branch)
  • Was the answer to reading the FAQ "No"? True (right branch)
  • Did it visit 21 pages or more? True (=> Basic subscription)
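Note that classify() returns the leaf's counts rather than a single label. Every leaf in this tree is pure, so reading off the label is trivial. Still, a small helper makes that step explicit and also turns the counts into shares. The helper names are hypothetical. The mixed leaf in the last two lines is a made-up example of what a tree that stops splitting early could return.

```python
def predicted_label(results):
    """Most frequent target value in a leaf returned by classify()."""
    return max(results, key=results.get)

def class_probabilities(results):
    """Share of the leaf's training units that took each target value."""
    total = sum(results.values())
    return {label: count / total for label, count in results.items()}

print(predicted_label({'Basic': 4}))                 # Basic
print(class_probabilities({'Basic': 4}))             # {'Basic': 1.0}
# Hypothetical mixed leaf
print(predicted_label({'Basic': 3, 'None': 1}))      # Basic
print(class_probabilities({'Basic': 3, 'None': 1}))  # {'Basic': 0.75, 'None': 0.25}
```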

6. Conclusion

We only scratched the surface of how to build decision trees. Other topics to explore include how to prune the tree (to avoid overfitting) and how to deal with missing data.

What you should keep in mind are the advantages of decision trees:

  • You get a transparent model: you actually understand the decision rules, and they can be easily interpreted
  • You can easily mix categorical and numerical data
  • You don't need much data preparation (no strong assumptions about what the data should look like)

An important weakness you can see from this tutorial is that the decision rules could have been quite different if the dataset had been slightly different (try it yourself!). In other words, a decision tree is prone to overfitting, may not generalize well and can be quite unstable. This issue can be mitigated by combining multiple decision trees, in a technique called random forest.

I hope you enjoyed this tutorial!

Patrick
