Decision Trees in R using rpart

The rpart package in R provides a powerful framework for growing classification and regression trees. To see how it works, let’s get started with a minimal example.

First let's define a problem. There's a common scam amongst motorists where a person will slam on his brakes in heavy traffic with the intention of being rear-ended. The person will then file an insurance claim for personal injury and damage to his vehicle, alleging that the other driver was at fault. Suppose we want to predict which of an insurance company's claims are fraudulent using a decision tree.

To start, we need to build a training set of claims whose fraud status is already known.

train <- data.frame(ClaimID = c(1,2,3),
                    RearEnd = c(TRUE, FALSE, TRUE),
                    Fraud = c(TRUE, FALSE, TRUE))
train
  ClaimID RearEnd Fraud
1       1    TRUE  TRUE
2       2   FALSE FALSE
3       3    TRUE  TRUE

In order to grow our decision tree, we have to first load the rpart package. Then we can use the function rpart(), specifying the model formula, data, and method parameters. In this case, we want to predict the target Fraud using the single predictor RearEnd, so our call to rpart should look like

library(rpart) #load the rpart package

mytree <- rpart(Fraud ~ RearEnd, data = train, method = "class")
mytree
n= 3 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

1) root 3 1 TRUE (0.3333333 0.6666667) *

Notice the output shows only a root node. This is because rpart has some default parameters that prevented our tree from growing, namely minsplit and minbucket. minsplit is "the minimum number of observations that must exist in a node in order for a split to be attempted" (20 by default) and minbucket is "the minimum number of observations in any terminal node" (minsplit/3 by default). With only three training cases, neither condition can be met. See what happens when we override these parameters.

mytree <- rpart(Fraud ~ RearEnd, data = train, method = "class", minsplit = 2, minbucket = 1)
mytree
n= 3 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

1) root 3 1 TRUE (0.3333333 0.6666667)  
  2) RearEnd< 0.5 1 0 FALSE (1.0000000 0.0000000) *
  3) RearEnd>=0.5 2 0 TRUE (0.0000000 1.0000000) *
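
As an aside, these growth settings can also be bundled into the control argument via rpart.control() rather than passed individually; the following is a minimal, equivalent sketch of the call above.

# Equivalent call, bundling the growth parameters into rpart.control()
mytree <- rpart(Fraud ~ RearEnd, data = train, method = "class",
                control = rpart.control(minsplit = 2, minbucket = 1))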

Now our tree has a root node, one split, and two leaves (terminal nodes). Observe that rpart encoded our boolean variable as an integer (FALSE = 0, TRUE = 1). We can plot mytree by loading the rattle package (and some helper packages) and using the fancyRpartPlot() function.

library(rattle)
library(rpart.plot)
library(RColorBrewer)

fancyRpartPlot(mytree)

[Figure: fancyRpartPlot of mytree, which splits on RearEnd]
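
If rattle proves troublesome to install on your system, a comparable tree plot is available from the rpart.plot package alone (already loaded above); this is simply an alternative, not part of the original workflow.

# Alternative plot using only the rpart.plot package
rpart.plot(mytree)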

The decision tree correctly identified that if a claim involved a rear-end collision, the claim was most likely fraudulent.

By default, rpart uses Gini impurity to select splits when performing classification. (If you're unfamiliar with this metric, read this article.) You can use information gain instead by specifying it in the parms parameter.

mytree <- rpart(Fraud ~ RearEnd, data = train, method = "class", 
                parms = list(split = 'information'), minsplit = 2, minbucket = 1)
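
To give a feel for what these impurity measures compute: Gini impurity for a node with class proportions p is 1 - sum(p^2), so a pure node scores 0. A quick hand calculation for the root node of our first training set (this is plain arithmetic, not rpart output):

# Gini impurity of the root node: 1 of 3 claims is non-fraudulent, 2 of 3 are fraudulent
p <- c(1/3, 2/3)   # class proportions at the node
1 - sum(p^2)       # 0.4444..., the impurity rpart tries to reduce with each split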

Now suppose our training set looked like

train <- data.frame(ClaimID = c(1,2,3),
                    RearEnd = c(TRUE, FALSE, TRUE),
                    Fraud = c(TRUE, FALSE, FALSE))
train
  ClaimID RearEnd Fraud
1       1    TRUE  TRUE
2       2   FALSE FALSE
3       3    TRUE FALSE

If we try to build a decision tree on this data…

mytree <- rpart(Fraud ~ RearEnd, data = train, method = "class", minsplit = 2, minbucket = 1)
mytree
n= 3 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

1) root 3 1 FALSE (0.6666667 0.3333333) *

Once again we’re left with just a root node. Internally, rpart keeps track of something called the complexity of a tree. The complexity measure is a combination of the size of a tree and the ability of the tree to separate the classes of the target variable. If the next best split in growing a tree does not reduce the tree’s overall complexity by a certain amount, rpart will terminate the growing process. This amount is specified by the complexity parameter, cp, in the call to rpart. Setting cp to a negative amount ensures that the tree will be fully grown.

mytree <- rpart(Fraud ~ RearEnd, data = train, method = "class", minsplit = 2, minbucket = 1, cp=-1)
fancyRpartPlot(mytree)

[Figure: fancyRpartPlot of the fully grown tree (cp = -1)]

This is not always a good idea since it will typically produce over-fitted trees, but trees can be pruned back as discussed later in this article.

You can also weight each observation for the tree’s construction by specifying the weights argument to rpart.

mytree <- rpart(Fraud ~ RearEnd, data = train, method = "class", minsplit = 2, minbucket = 1,
                weights = c(.4, .4, .2))
fancyRpartPlot(mytree)

[Figure: fancyRpartPlot of the tree fit with observation weights]

One of the best ways to identify a fraudulent claim is to hire a private investigator to monitor the activities of a claimant. Since private investigators don't work for free, the insurance company will have to strategically decide which claims to investigate. To do this, they can use a decision tree model based on some initial features of the claim. If the insurance company wants to aggressively investigate claims (i.e. investigate a lot of claims), they can train their decision tree in a manner that penalizes labeling a fraudulent claim as non-fraudulent more heavily than labeling a non-fraudulent claim as fraudulent.

To alter the default, equal penalization of mislabeled target classes, set the loss component of the parms parameter to a matrix whose (i,j) element is the penalty for misclassifying an i as a j. (The loss matrix must have 0s on the diagonal.) For example, consider the following training data.

train <- data.frame(ClaimID = c(1,2,3,4,5,6,7),
                    RearEnd = c(TRUE, TRUE, FALSE, FALSE, FALSE, FALSE, FALSE),
                    Whiplash = c(TRUE, TRUE, TRUE, TRUE, TRUE, FALSE, FALSE),
                    Fraud = c(TRUE, TRUE, TRUE, FALSE, FALSE, FALSE, FALSE))
train
  ClaimID RearEnd Whiplash Fraud
1       1    TRUE     TRUE  TRUE
2       2    TRUE     TRUE  TRUE
3       3   FALSE     TRUE  TRUE
4       4   FALSE     TRUE FALSE
5       5   FALSE     TRUE FALSE
6       6   FALSE    FALSE FALSE
7       7   FALSE    FALSE FALSE

Now let’s grow our decision tree, restricting it to one split by setting the maxdepth argument to 1.

mytree <- rpart(Fraud ~ RearEnd + Whiplash, data = train, method = "class", 
                maxdepth = 1, minsplit = 2, minbucket = 1)
fancyRpartPlot(mytree)

[Figure: fancyRpartPlot of the depth-1 tree, which splits on RearEnd]

rpart has determined that RearEnd was the best variable for identifying a fraudulent claim. BUT there was one fraudulent claim in the training set that did not involve a rear-end collision. If the insurance company wants to identify a high percentage of fraudulent claims without worrying too much about investigating non-fraudulent claims, they can set the loss matrix to penalize a fraudulent claim mislabeled as non-fraudulent three times more heavily than a non-fraudulent claim mislabeled as fraudulent.

lossmatrix <- matrix(c(0,1,3,0), byrow=TRUE, nrow=2)
lossmatrix
     [,1] [,2]
[1,]    0    1
[2,]    3    0
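
One detail worth double-checking before writing a loss matrix: its rows and columns follow the order of the target's class levels. Here Fraud is a logical vector, which rpart treats as a factor with levels FALSE then TRUE, so element (2,1), the 3, is the cost of calling a truly fraudulent claim non-fraudulent. A quick way to confirm the ordering:

# Confirm the class-level order that the loss matrix rows/columns will follow
levels(factor(train$Fraud))   # "FALSE" "TRUE"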

mytree <- rpart(Fraud ~ RearEnd + Whiplash, data = train, method = "class", 
                maxdepth = 1, minsplit = 2, minbucket = 1,
                parms = list(loss = lossmatrix))
fancyRpartPlot(mytree)

[Figure: fancyRpartPlot of the depth-1 tree fit with the loss matrix, which splits on Whiplash]

Now our model suggests that Whiplash is the best variable to identify fraudulent claims. What I just described is known as an error measure, and it's up to the discretion of the insurance company to decide on it. Yaser Abu-Mostafa of Caltech has a great explanation of this topic here.

Now let’s see how rpart interacts with factor variables. Suppose the insurance company hires an investigator to assess the activity level of claimants. Activity levels can be Very Active, Active, Inactive, or Very Inactive.

Dataset 1…

train <- data.frame(ClaimID = c(1,2,3,4,5),
                    Activity = factor(c("active", "very active", "very active", "inactive", "very inactive"),
                                      levels=c("very inactive", "inactive", "active", "very active")),
                    Fraud = c(FALSE, TRUE, TRUE, FALSE, TRUE))
train
  ClaimID      Activity Fraud
1       1        active FALSE
2       2   very active  TRUE
3       3   very active  TRUE
4       4      inactive FALSE
5       5 very inactive  TRUE

mytree <- rpart(Fraud ~ Activity, data = train, method = "class", minsplit = 2, minbucket = 1)
fancyRpartPlot(mytree)

[Figure: fancyRpartPlot of the tree fit on the unordered Activity factor]

Dataset 2…

train <- data.frame(ClaimID = c(1,2,3,4,5),
                    Activity = factor(c("active", "very active", "very active", "inactive", "very inactive"),
                                      levels=c("very inactive", "inactive", "active", "very active"),
                                      ordered=TRUE),
                    Fraud = c(FALSE, TRUE, TRUE, FALSE, TRUE))
train
  ClaimID      Activity Fraud
1       1        active FALSE
2       2   very active  TRUE
3       3   very active  TRUE
4       4      inactive FALSE
5       5 very inactive  TRUE

mytree <- rpart(Fraud ~ Activity, data = train, method = "class", minsplit = 2, minbucket = 1)
fancyRpartPlot(mytree)

[Figure: fancyRpartPlot of the tree fit on the ordered Activity factor]

In the first dataset, we did not specify that the Activity vector was an ordered factor, so rpart tested every possible way to partition the levels of Activity into two groups. In the second dataset, Activity was specified as an ordered factor, so rpart only tested splits that respect the ordering of the Activity levels. (For more explanation of this, see this post and/or this post.)
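
The difference matters more as the number of levels grows: a 4-level unordered factor has 2^(4-1) - 1 = 7 candidate binary splits, while an ordered 4-level factor has only 4 - 1 = 3. To see exactly which Activity levels were sent down each branch of the most recent fit, you can print the full split details:

# Print the chosen splits, including which Activity levels go to each branch
summary(mytree)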

It's usually a good idea to prune a decision tree. Fully grown trees tend to be overfitted and don't perform well on data outside the training set, so pruning is used to reduce their complexity by keeping only the most important splits.

#Build a training set
train <- data.frame(ClaimID = c(1,2,3,4,5,6,7,8,9,10),
                    RearEnd = c(TRUE, TRUE, TRUE, FALSE, FALSE, FALSE, FALSE, TRUE, TRUE, FALSE),
                    Whiplash = c(TRUE, TRUE, TRUE, TRUE, TRUE, FALSE, FALSE, FALSE, FALSE, TRUE),
                    Activity = factor(c("active", "very active", "very active", "inactive", "very inactive", "inactive", "very inactive", "active", "active", "very active"),
                                      levels=c("very inactive", "inactive", "active", "very active"),
                                      ordered=TRUE),
                    Fraud = c(FALSE, TRUE, TRUE, FALSE, FALSE, TRUE, TRUE, FALSE, FALSE, TRUE))
train
   ClaimID RearEnd Whiplash      Activity Fraud
1        1    TRUE     TRUE        active FALSE
2        2    TRUE     TRUE   very active  TRUE
3        3    TRUE     TRUE   very active  TRUE
4        4   FALSE     TRUE      inactive FALSE
5        5   FALSE     TRUE very inactive FALSE
6        6   FALSE    FALSE      inactive  TRUE
7        7   FALSE    FALSE very inactive  TRUE
8        8    TRUE    FALSE        active FALSE
9        9    TRUE    FALSE        active FALSE
10      10   FALSE     TRUE   very active  TRUE

#Grow a full tree
mytree <- rpart(Fraud ~ RearEnd + Whiplash + Activity, data = train, method = "class", minsplit = 2, minbucket = 1, cp=-1)
fancyRpartPlot(mytree)

[Figure: fancyRpartPlot of the fully grown tree on the ten-claim training set]

You can view the importance of each variable in the model by calling the variable.importance attribute of the resulting rpart object. From the rpart documentation, “An overall measure of variable importance is the sum of the goodness of split measures for each split for which it was the primary variable…”

mytree$variable.importance
 Activity  Whiplash   RearEnd 
3.0000000 2.0000000 0.8571429 

By default, rpart performs 10-fold cross validation on the data as it grows a tree (the number of folds is controlled by the xval parameter). To see the cross validation results, use the printcp() function.

printcp(mytree)

Classification tree:
rpart(formula = Fraud ~ RearEnd + Whiplash + Activity, data = train, 
    method = "class", minsplit = 2, minbucket = 1, cp = -1)

Variables actually used in tree construction:
[1] Activity RearEnd  Whiplash

Root node error: 5/10 = 0.5

n= 10 

    CP nsplit rel error xerror    xstd
1  0.6      0       1.0    2.0 0.00000
2  0.2      1       0.4    0.4 0.25298
3 -1.0      3       0.0    0.4 0.25298

The rel error of each iteration of the tree is the fraction of mislabeled elements in the iteration relative to the fraction of mislabeled elements in the root. In this example, 50% of training cases are fraudulent. The first splitting criterion is "Is the claimant very active?", which separates the data into a set of three cases, all of which are fraudulent, and a set of seven cases, of which two are fraudulent. Labeling the cases at this point would produce an error rate of 20%, which is 40% of the root node error rate (i.e. it's 60% better). The cross validation error rates and standard deviations are displayed in the columns xerror and xstd, respectively.
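
To make that arithmetic explicit (a hand check, not rpart output):

# Hand check of the rel error after the first split
root_errors  <- 5           # 5 of the 10 training cases are mislabeled at the root
split_errors <- 2           # after splitting on "very active", only 2 cases are mislabeled
split_errors / 10           # absolute error rate: 0.2
split_errors / root_errors  # rel error reported by printcp: 0.4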

As a rule of thumb, it's best to prune a decision tree using the cp of the smallest tree whose xerror is within one standard deviation of the smallest xerror. In this example, the best xerror is 0.4 with standard deviation 0.25298. So, we want the smallest tree with xerror less than 0.65298. This is the tree with cp = 0.2, so we'll want to prune our tree with a cp slightly greater than 0.2.

mytree <- prune(mytree, cp=.21)
fancyRpartPlot(mytree)

[Figure: fancyRpartPlot of the pruned tree]
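
If you'd rather not read the value off the table by hand, the same rule of thumb can be automated from the cptable of the fully grown tree. Below is a sketch; fulltree is a stand-in name for the cp = -1 tree grown above (before the manual prune), and the small nudge added to cp mirrors the "slightly greater than 0.2" trick.

# A sketch of automating the one-standard-deviation pruning rule
cptab  <- fulltree$cptable
best   <- which.min(cptab[, "xerror"])                    # row with the smallest cross-validated error
thresh <- cptab[best, "xerror"] + cptab[best, "xstd"]     # one-standard-deviation threshold
pick   <- which(cptab[, "xerror"] <= thresh)[1]           # smallest tree under the threshold (rows go small to large)
pruned <- prune(fulltree, cp = cptab[pick, "CP"] + 1e-6)  # nudge cp just above the row's CP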

From here we can use our decision tree to predict fraudulent claims on an unseen dataset using the predict() function.

test <- data.frame(ClaimID = c(1,2,3,4,5,6,7,8,9,10),
                    RearEnd = c(FALSE, TRUE, TRUE, FALSE, FALSE, FALSE, FALSE, TRUE, TRUE, FALSE),
                    Whiplash = c(FALSE, TRUE, TRUE, TRUE, TRUE, FALSE, FALSE, FALSE, FALSE, TRUE),
                    Activity = factor(c("inactive", "very active", "very active", "inactive", "very inactive", "inactive", "very inactive", "active", "active", "very active"),
                                      levels=c("very inactive", "inactive", "active", "very active"),
                                      ordered=TRUE))
test
   ClaimID RearEnd Whiplash      Activity
1        1   FALSE    FALSE      inactive
2        2    TRUE     TRUE   very active
3        3    TRUE     TRUE   very active
4        4   FALSE     TRUE      inactive
5        5   FALSE     TRUE very inactive
6        6   FALSE    FALSE      inactive
7        7   FALSE    FALSE very inactive
8        8    TRUE    FALSE        active
9        9    TRUE    FALSE        active
10      10   FALSE     TRUE   very active

test$FraudClass <- predict(mytree, newdata = test, type="class") #Returns the predicted class
test$FraudProb <- predict(mytree, newdata = test, type="prob") #Returns a matrix of predicted probabilities

test
   ClaimID RearEnd Whiplash      Activity FraudClass FraudProb.FALSE FraudProb.TRUE
1        1   FALSE    FALSE      inactive      FALSE       0.7142857      0.2857143
2        2    TRUE     TRUE   very active       TRUE       0.0000000      1.0000000
3        3    TRUE     TRUE   very active       TRUE       0.0000000      1.0000000
4        4   FALSE     TRUE      inactive      FALSE       0.7142857      0.2857143
5        5   FALSE     TRUE very inactive      FALSE       0.7142857      0.2857143
6        6   FALSE    FALSE      inactive      FALSE       0.7142857      0.2857143
7        7   FALSE    FALSE very inactive      FALSE       0.7142857      0.2857143
8        8    TRUE    FALSE        active      FALSE       0.7142857      0.2857143
9        9    TRUE    FALSE        active      FALSE       0.7142857      0.2857143
10      10   FALSE     TRUE   very active       TRUE       0.0000000      1.0000000

In summary, the rpart package is pretty sweet. I tried to cover the most important features of the package, but I suggest you read through the rpart vignette to understand the things I skipped. Also, I'd like to point out that a single decision tree usually won't have much predictive power, but an ensemble of varied decision trees such as random forests and boosted models can perform extremely well.