
A Simple and Practical Guide to Linear Regression

Updated: Feb 1


Linear regression is a classic regression algorithm that underpins countless prediction tasks. It is distinct from classification models such as decision trees, support vector machines, and neural networks. In a nutshell, linear regression finds the optimal linear relationship between the independent variables and the dependent variable, and makes predictions accordingly.


Most of us have encountered the function y = b0 + b1x in math class. It is essentially the form of simple linear regression, where b0 is the intercept and b1 is the slope of the line.


Simple linear regression models the relationship between one independent variable x and one dependent variable y, for instance, the classic height-weight correlation. As more features/independent variables are introduced, it evolves into multiple linear regression y = b0 + b1x1 + b2x2 + ... + bnxn, which can no longer be plotted as a single line in two-dimensional space. I will explain more of the theory behind the algorithm in the section "Model Implementation", but the aim of this article is to go practical! A short video walkthrough is also available if you are interested.
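To make the simple case concrete, here is a minimal sketch of fitting y = b0 + b1x with NumPy's least-squares polyfit. The data is synthetic and generated for illustration only (the true line y = 2 + 3x and the noise level are made-up values):

```python
import numpy as np

# Synthetic data from a known line: y = 2 + 3x plus Gaussian noise (made up for illustration)
rng = np.random.default_rng(42)
x = rng.uniform(0, 10, size=200)
y = 2 + 3 * x + rng.normal(0, 0.5, size=200)

# polyfit with degree 1 returns [slope, intercept] of the least-squares line
b1, b0 = np.polyfit(x, y, 1)
print(f"intercept b0 ~ {b0:.2f}, slope b1 ~ {b1:.2f}")
```

With 200 points and modest noise, the fitted coefficients land close to the true intercept 2 and slope 3.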


Define Objectives

I use the Kaggle public dataset "Insurance Premium Prediction" in this exercise. The data includes the independent variables age, sex, bmi, children, smoker, and region, and the target variable expenses. First, let's load the data and take a preliminary look using df.info()

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from pandas.api.types import is_string_dtype, is_numeric_dtype

df = pd.read_csv('../input/insurance-premium-prediction/insurance.csv')
df.info()

Key take-away:

  1. categorical variables: sex, smoker, region

  2. numerical variables: age, bmi, children, expenses

  3. no missing data among 1338 records
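The categorical/numerical split above can also be derived programmatically with the dtype helpers already imported in the snippet. A sketch below, using a made-up two-row stand-in for the insurance data so it runs on its own:

```python
import pandas as pd
from pandas.api.types import is_string_dtype, is_numeric_dtype

# Two made-up rows with the same columns and dtypes as the insurance data
df = pd.DataFrame({
    'age': [19, 18],
    'sex': ['female', 'male'],
    'bmi': [27.9, 33.8],
    'children': [0, 1],
    'smoker': ['yes', 'no'],
    'region': ['southwest', 'southeast'],
    'expenses': [16884.92, 1725.55],
})

# Partition columns by dtype
categorical = [c for c in df.columns if is_string_dtype(df[c])]
numerical = [c for c in df.columns if is_numeric_dtype(df[c])]
print(categorical)
print(numerical)
```

On the real dataset this reproduces the same split: sex, smoker, and region are categorical; age, bmi, children, and expenses are numerical.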


Exploratory Data Analysis (EDA)

EDA is essential to both investigate the data quality and reveal hidden correlations among variables. To have a comprehensive view of EDA, check out my article on "Semi-Automated Exploratory Data Analysis (EDA) in Python".


1. Univariate Analysis

Visualize the data distribution using histograms for numeric variables and bar charts for categorical variables.
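A compact way to do this is to loop over the columns and pick the chart type from the dtype. This sketch uses a made-up stand-in for the data and the headless Agg backend so it runs without a display:

```python
import matplotlib
matplotlib.use('Agg')  # headless backend; drop this line when running interactively
import matplotlib.pyplot as plt
import pandas as pd
from pandas.api.types import is_numeric_dtype

# Made-up rows standing in for the insurance data
df = pd.DataFrame({
    'age': [19, 18, 28], 'sex': ['female', 'male', 'male'],
    'bmi': [27.9, 33.8, 33.0], 'smoker': ['yes', 'no', 'no'],
    'expenses': [16884.92, 1725.55, 4449.46],
})

chart_types = {}
for col in df.columns:
    fig, ax = plt.subplots()
    if is_numeric_dtype(df[col]):
        df[col].plot(kind='hist', ax=ax, title=col)                # histogram for numeric
        chart_types[col] = 'hist'
    else:
        df[col].value_counts().plot(kind='bar', ax=ax, title=col)  # bar chart for categorical
        chart_types[col] = 'bar'
    plt.close(fig)
```

One figure per column keeps each distribution readable; on the real data the same loop covers all seven variables.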

Why do we need univariate analysis?

  • identify whether the dataset contains outliers

  • identify whether data transformation or feature engineering is needed

In this case, we find that "expenses" follows a heavily right-skewed, power-law-like distribution, which means a log transformation is required as a feature engineering step to bring it closer to a normal distribution.
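A minimal sketch of that transformation, using a few made-up expense values in place of the real column (np.log is fine here since expenses are strictly positive; np.log1p is a safe variant when zeros are possible):

```python
import numpy as np
import pandas as pd

# Made-up right-skewed expense values standing in for the real column
expenses = pd.Series([1121.87, 1725.55, 4449.46, 8240.59, 16884.92, 63770.43])
log_expenses = np.log(expenses)

# The log compresses the long right tail, pulling skewness toward zero
print(f"skew before: {expenses.skew():.2f}, after: {log_expenses.skew():.2f}")
```

Plotting a histogram of log_expenses should show a far more symmetric, bell-like shape than the raw column.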


2. Multivariate Analysis


When thinking of linear regression, the first visualization technique that comes to mind is the scatterplot. By plotting the target variable against the independent variables with a single line of code, sns.pairplot(df), any underlying linear relationship becomes more evident.

As shown, there seems to be a linear relationship between age and expenses, which does make sense.

Now, how about adding the categorical variable as the legend?
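With seaborn, sns.pairplot(df, hue='smoker') colours every scatter by the category in one line. Below is a plain-matplotlib sketch of the same idea for a single age-expenses pair, with made-up data and the headless Agg backend so it runs anywhere:

```python
import matplotlib
matplotlib.use('Agg')  # headless backend; drop this line when running interactively
import matplotlib.pyplot as plt
import pandas as pd

# Made-up rows standing in for the insurance data
df = pd.DataFrame({
    'age': [19, 23, 31, 45, 52, 60],
    'expenses': [16884.92, 2775.19, 3756.62, 42111.66, 9644.25, 28101.33],
    'smoker': ['yes', 'no', 'no', 'yes', 'no', 'yes'],
})

# One scatter series per category level, so the legend separates the groups
fig, ax = plt.subplots()
for level, group in df.groupby('smoker'):
    ax.scatter(group['age'], group['expenses'], label=level)
ax.set_xlabel('age')
ax.set_ylabel('expenses')
ax.legend(title='smoker')
```

Splitting the points by smoker typically reveals two distinct bands of expenses, which is exactly the kind of interaction a legend makes visible.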