Decision Tree Regression in Python in 10 lines

The decision tree algorithm builds its model as a tree of conditional control statements (if/else splits on the features), hence the name decision tree.
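Here is a minimal sketch of the "conditional control statements" idea. It uses made-up toy data (an assumption, not the post's dataset) and shows that a regression tree limited to one split (`max_depth=1`) behaves exactly like a single if/else rule. The rule is read back from the fitted estimator's low-level `tree_` attribute.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Hypothetical toy data: y jumps from about 1 to about 5 between x = 4 and x = 6
X = np.array([[1], [2], [3], [4], [6], [7], [8], [9]], dtype=float)
y = np.array([1.0, 1.2, 0.8, 1.0, 5.0, 5.2, 4.8, 5.0])

# A tree with a single split is equivalent to one if/else statement
stump = DecisionTreeRegressor(max_depth=1, random_state=0).fit(X, y)

# Read the learned rule out of the fitted tree structure (node 0 is the root)
threshold = stump.tree_.threshold[0]
left = stump.tree_.children_left[0]
right = stump.tree_.children_right[0]
left_value = stump.tree_.value[left][0][0]    # mean of y in the left leaf
right_value = stump.tree_.value[right][0][0]  # mean of y in the right leaf

def rule(x):
    # The same prediction written as an explicit conditional
    # (scikit-learn sends samples with x <= threshold to the left child)
    return left_value if x <= threshold else right_value

print(f"if x <= {threshold:.1f}: predict {left_value:.2f} else: predict {right_value:.2f}")
```

A deeper tree is simply more of these conditions nested inside each other.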

The decision tree algorithm can be used to solve both regression and classification problems.
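In scikit-learn the two cases use two estimators with the same interface: `DecisionTreeRegressor` for continuous targets and `DecisionTreeClassifier` for class labels. The short sketch below fits both on hypothetical toy data to show the difference in what `predict` returns.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

# Hypothetical toy feature matrix: one feature, six samples
X = np.array([[1], [2], [3], [10], [11], [12]], dtype=float)

# Regression: the target is a continuous number
y_continuous = np.array([1.1, 1.9, 3.2, 10.1, 11.2, 11.8])
reg = DecisionTreeRegressor(random_state=0).fit(X, y_continuous)

# Classification: the target is a discrete class label
y_labels = np.array(["low", "low", "low", "high", "high", "high"])
clf = DecisionTreeClassifier(random_state=0).fit(X, y_labels)

new_x = np.array([[2.5], [10.5]])
print(reg.predict(new_x))  # numeric predictions (mean of y in the matching leaf)
print(clf.predict(new_x))  # class-label predictions
```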

In this post, we will implement a simple decision tree regression model using Python and scikit-learn (sklearn).

You may also like to watch a video on Decision Tree from Scratch in Python.

First things first, let us import the required libraries.

import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import r2_score, mean_squared_error

Next, we load the data into a Jupyter notebook. You can find the data here.

df = pd.read_csv('Linear-Regression-Data.csv')
df.head()
Decision Tree Regression Data Load
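If you do not have Linear-Regression-Data.csv at hand, the following sketch builds a hypothetical stand-in DataFrame with the same two columns, `x` and `y`, so the remaining steps can be run. The values are randomly generated: the noisy linear relationship is an assumption suggested only by the file name, and this is not the post's actual data.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for Linear-Regression-Data.csv with columns x and y
rng = np.random.default_rng(42)
x = rng.uniform(0, 100, size=200)
y = 3.0 * x + 10.0 + rng.normal(0, 15, size=200)  # assumed noisy linear trend
df = pd.DataFrame({"x": x, "y": y})

print(df.head())
```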

The data has one feature, x, and one label, y; the goal is to use the values of x to predict y. The next step is to split the data into training and test sets, as below.

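# scikit-learn expects a 2D feature matrix of shape (n_samples, n_features)
x = df.x.values.reshape(-1, 1)
y = df.y.values.reshape(-1, 1)
# Hold out 30% of the rows for testing; random_state makes the split reproducible
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.30, random_state=42)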
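The reason for `reshape(-1, 1)` is that `df.x.values` is a 1D array, while scikit-learn estimators want a 2D matrix with one row per sample and one column per feature. The sketch below checks the shapes on a small hypothetical DataFrame with the post's column names; with `test_size=0.30`, 3 of the 10 rows end up in the test set.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Small hypothetical DataFrame with the same column names as the post
df = pd.DataFrame({"x": np.arange(10, dtype=float), "y": np.arange(10, dtype=float) * 2})

x = df.x.values.reshape(-1, 1)  # shape (10,) becomes a single column of shape (10, 1)
y = df.y.values                  # a 1D target is also accepted by regressors
print(df.x.values.shape, x.shape)

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.30, random_state=42)
print(x_train.shape, x_test.shape)  # 70% of rows for training, 30% for testing
```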

Now let us train the model.

# Default hyperparameters: max_depth=None, so the tree keeps splitting until leaves are pure
DecisionTreeRegModel = DecisionTreeRegressor()
DecisionTreeRegModel.fit(x_train, y_train)
Decision Tree Regression Model Architecture
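With the default settings, the fitted tree can become very deep. The sketch below (on hypothetical noisy data, not the post's dataset) compares a default tree with one limited by `max_depth=3`, using the fitted estimator's `get_depth()` and `get_n_leaves()` methods. Limiting depth, or raising `min_samples_leaf`, is a common way to reduce overfitting; the value 3 here is an arbitrary illustration.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Hypothetical noisy data standing in for the post's training set
rng = np.random.default_rng(0)
x_train = rng.uniform(0, 100, size=(140, 1))
y_train = 3.0 * x_train[:, 0] + rng.normal(0, 15, size=140)

# Default settings: the tree grows until every leaf is pure
full_tree = DecisionTreeRegressor(random_state=0).fit(x_train, y_train)

# A constrained tree with at most 3 levels of splits
small_tree = DecisionTreeRegressor(max_depth=3, random_state=0).fit(x_train, y_train)

for name, model in [("default", full_tree), ("max_depth=3", small_tree)]:
    print(name, "depth:", model.get_depth(), "leaves:", model.get_n_leaves())
```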

Now we can make predictions on the test set.

y_pred = DecisionTreeRegModel.predict(x_test)
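The same `predict` call works for a single new value of x, but the input must still be 2D. Passing a bare number such as `50` raises "ValueError: Expected 2D array, got scalar array instead". Wrap the value as one row with one column instead. The model and training data below are hypothetical stand-ins.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Hypothetical training data with one feature
x_train = np.array([[10], [20], [30], [40], [50], [60]], dtype=float)
y_train = np.array([15.0, 25.0, 35.0, 45.0, 55.0, 65.0])
model = DecisionTreeRegressor(random_state=0).fit(x_train, y_train)

# model.predict(50) would fail: a scalar is not a 2D array.
# Pass one sample (row) with one feature (column) instead:
single = model.predict([[50]])                          # nested list
same = model.predict(np.array([50.0]).reshape(1, -1))  # explicit reshape to shape (1, 1)
print(single, same)
```

For several new values at once, `np.array([...]).reshape(-1, 1)` produces one row per value.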

Once we have the predictions, we can evaluate the model using R squared (R²) and the root mean squared error (RMSE).

Decision Tree Model Evaluation using R2 Square and RMSE
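Since the evaluation step appears only as an image, here is one way to compute both metrics with the already-imported `r2_score` and `mean_squared_error`. It runs end to end on hypothetical generated data in place of the CSV. RMSE is taken as the square root of the MSE with NumPy, which works across scikit-learn versions (the older `squared=False` argument has been removed in recent releases).

```python
import numpy as np
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

# Hypothetical data standing in for Linear-Regression-Data.csv
rng = np.random.default_rng(42)
x = rng.uniform(0, 100, size=(200, 1))
y = 3.0 * x[:, 0] + 10.0 + rng.normal(0, 15, size=200)

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.30, random_state=42)
model = DecisionTreeRegressor(random_state=42).fit(x_train, y_train)
y_pred = model.predict(x_test)

# R squared: 1.0 is a perfect fit; 0.0 is no better than always predicting the mean of y_test
r2 = r2_score(y_test, y_pred)

# RMSE: square root of the mean squared error, in the same units as y
rmse = np.sqrt(mean_squared_error(y_test, y_pred))

print(f"R2: {r2:.3f}  RMSE: {rmse:.3f}")
```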

Note: you may also want to try Linear Regression on this problem, and to check the assumptions of Linear Regression as well.
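A quick way to try this is to fit both models on the same split and compare test-set R². The sketch below uses hypothetical data generated from a noisy straight line, so its numbers say nothing about the post's dataset.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

# Hypothetical linear-looking data (not the post's dataset)
rng = np.random.default_rng(0)
x = rng.uniform(0, 100, size=(200, 1))
y = 3.0 * x[:, 0] + 10.0 + rng.normal(0, 15, size=200)
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.30, random_state=42)

# Fit both models on the same split and compare test-set R squared
scores = {}
for name, model in [("decision tree", DecisionTreeRegressor(random_state=0)),
                    ("linear regression", LinearRegression())]:
    model.fit(x_train, y_train)
    scores[name] = r2_score(y_test, model.predict(x_test))
    print(f"{name}: R2 = {scores[name]:.3f}")
```

On data that really follows a straight line, linear regression will usually score higher, because a regression tree can only approximate a line with a staircase of constant steps.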

The decision tree algorithm has various advantages and disadvantages.

Last but not least, you can visualize the fitted decision tree regression model as below.

Decision Tree Regression Model Visualization
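The post does not show the visualization code. One option that needs no extra plotting library is `sklearn.tree.export_text`, which prints the splits as indented text. For a graphical version, `sklearn.tree.plot_tree(model, feature_names=["x"], filled=True)` draws the same tree with matplotlib. The data below is hypothetical, and the depth limit is only there to keep the printout short.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor, export_text

# Hypothetical training data with a single feature named "x"
rng = np.random.default_rng(0)
x_train = rng.uniform(0, 100, size=(100, 1))
y_train = 3.0 * x_train[:, 0] + rng.normal(0, 15, size=100)

# A shallow tree keeps the output readable; a fully grown tree can be very large
model = DecisionTreeRegressor(max_depth=2, random_state=0).fit(x_train, y_train)

# Text rendering: each line is either a split condition on x or a leaf's predicted value
report = export_text(model, feature_names=["x"])
print(report)
```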

I hope you enjoyed this article and can start using some of the techniques described here in your own projects soon. Cheers!

One thought on “Decision Tree Regression in Python in 10 lines”

  1. Thanks for your post!
    One question….the instruction: y_pred = DecisionTreeRegModel.predict(x_test)
    If I want to predict an exact value I’m going to get an error.

    For example : y_pred = DecisionTreeRegModel.predict(50)..
    Error: ValueError: Expected 2D array, got scalar array instead:
    array=50.0.
    Reshape your data either using array.reshape(-1, 1) if your data has a single feature or array.reshape(1, -1) if it contains a single sample.
