Linear Regression Theory

Linear regression is a fundamental statistical and machine learning technique used to model the relationship between a dependent variable (target) and one or more independent variables (features). The goal is to find the best-fitting line (or hyperplane in higher dimensions) that minimizes the error between the predicted and actual values.

In this article, we will cover:

  • The basic form of linear regression.
  • The Ordinary Least Squares (OLS) method used to estimate parameters.
  • Key concepts like the cost function and gradient descent.
  • Evaluating the model’s performance.
  • Introduction to regularization techniques (Ridge, Lasso).

1. The Linear Regression Model

The linear regression model assumes that the relationship between the input features $x_1, x_2, \dots, x_n$ and the output $y$ can be expressed as a linear combination of those features:

$$y = \beta_0 + \beta_1 x_1 + \beta_2 x_2 + \dots + \beta_n x_n + \epsilon$$

Where:

  • $y$ is the dependent variable (the value we want to predict).
  • $\beta_0$ is the intercept (the predicted value when all independent variables are zero).
  • $\beta_1, \dots, \beta_n$ are the coefficients (weights) assigned to the independent variables.
  • $x_1, x_2, \dots, x_n$ are the independent variables (features).
  • $\epsilon$ is the error term (residuals), representing the noise or variability in the prediction that cannot be explained by the model.

The goal of linear regression is to find the best-fitting line (or hyperplane) that minimizes the difference between the actual and predicted values of $y$.
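
To make this concrete, a prediction is just the intercept plus a dot product between a feature vector and the coefficients. The short NumPy sketch below uses made-up coefficient values purely for illustration:

```python
import numpy as np

# Hypothetical model: y = 2.0 + 0.5*x1 - 1.2*x2 + 3.0*x3
beta_0 = 2.0                          # intercept
beta = np.array([0.5, -1.2, 3.0])     # coefficients beta_1..beta_3

# One observation with three feature values
x = np.array([1.0, 4.0, 0.5])

# The prediction is a linear combination of the features
y_hat = beta_0 + x @ beta
print(y_hat)  # 2.0 + 0.5*1.0 - 1.2*4.0 + 3.0*0.5 = -0.8
```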


2. Ordinary Least Squares (OLS) Method

2.1. What is OLS?

Ordinary Least Squares (OLS) is the most common method used to estimate the parameters $\beta_0, \beta_1, \dots, \beta_n$ in linear regression. The OLS method works by minimizing the sum of squared residuals: the squared differences between the actual values and the values predicted by the model.

The cost function (or loss function) that OLS minimizes is called the Mean Squared Error (MSE):

$$\text{MSE} = \frac{1}{n} \sum_{i=1}^{n} \left( y_i - \hat{y}_i \right)^2$$

Where:

  • $y_i$ is the actual value for the $i$-th data point.
  • $\hat{y}_i$ is the predicted value for the $i$-th data point.
  • $n$ is the number of observations.
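
To make the cost function concrete, here is a minimal NumPy sketch that evaluates the MSE for a set of predictions; the data values are made up for illustration:

```python
import numpy as np

def mse(y_true, y_pred):
    """Mean squared error: the average of the squared residuals."""
    residuals = y_true - y_pred
    return np.mean(residuals ** 2)

# Illustrative actual values and predictions from some candidate model
y_true = np.array([3.0, 5.0, 7.0, 9.0])
y_pred = np.array([2.5, 5.5, 6.5, 9.5])

print(mse(y_true, y_pred))  # every residual is +/-0.5, so MSE = 0.25
```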

2.2. OLS Formula

The OLS estimator finds the values of the coefficients that minimize the sum of squared residuals. This can be computed analytically as:

$$\hat{\beta} = (X^T X)^{-1} X^T y$$

Where:

  • $X$ is the matrix of input features (independent variables), typically with a leading column of ones so that the intercept $\beta_0$ is estimated alongside the other coefficients.
  • $y$ is the vector of observed values (dependent variable).
  • $\hat{\beta}$ is the vector of estimated coefficients.

This formula assumes that the matrix $X^T X$ is invertible, which requires that there is no perfect multicollinearity between the independent variables.
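
The closed-form estimator translates almost directly into NumPy. The sketch below generates synthetic data (the "true" coefficients are assumptions made up for the example), prepends a column of ones for the intercept, and solves the normal equations; np.linalg.solve is used instead of forming the inverse explicitly, which is numerically more stable:

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic data: 100 samples, 2 features, true model y = 1.0 + 2.0*x1 - 3.0*x2 + noise
n_samples = 100
X = rng.normal(size=(n_samples, 2))
y = 1.0 + 2.0 * X[:, 0] - 3.0 * X[:, 1] + rng.normal(scale=0.1, size=n_samples)

# Add a column of ones so the intercept beta_0 is estimated as well
X_design = np.column_stack([np.ones(n_samples), X])

# Solve (X^T X) beta = X^T y rather than computing the inverse directly
beta_hat = np.linalg.solve(X_design.T @ X_design, X_design.T @ y)
print(beta_hat)  # should be close to [1.0, 2.0, -3.0]
```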

2.3. Why Minimize Squared Errors?

The decision to minimize squared errors (instead of absolute errors or other measures) is motivated by the fact that squared errors give more weight to larger deviations. This makes OLS particularly sensitive to outliers, but it also makes the cost function smooth and differentiable, which is what allows for analytical solutions like the one shown above.


3. Gradient Descent (When OLS is Not Feasible)

For very large datasets or complex models, directly calculating the OLS solution can be computationally expensive. In such cases, gradient descent is often used as an alternative to minimize the cost function.

3.1. What is Gradient Descent?

Gradient Descent is an iterative optimization algorithm used to minimize a function by adjusting its parameters (in this case, the coefficients $\beta_0, \beta_1, \dots, \beta_n$) in the direction of the negative gradient of the cost function.

The general update rule for gradient descent is:

$$\beta_j = \beta_j - \alpha \frac{\partial}{\partial \beta_j} \text{MSE}$$

Where:

  • $\alpha$ is the learning rate, which controls the step size of each iteration.
  • $\frac{\partial}{\partial \beta_j} \text{MSE}$ is the gradient of the cost function with respect to the coefficient $\beta_j$.

3.2. Convergence

The algorithm iteratively adjusts the coefficients until it reaches a minimum of the cost function (or until the updates fall below a specified tolerance). The choice of the learning rate $\alpha$ is crucial: if it is too large, the algorithm may overshoot the minimum; if it is too small, convergence will be very slow.
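
The following is a minimal batch gradient descent sketch for the MSE cost, using the same design-matrix convention as the OLS example above (a leading column of ones for the intercept). The learning rate and iteration count are arbitrary choices for illustration, not recommended defaults:

```python
import numpy as np

def gradient_descent(X, y, alpha=0.05, n_iters=5000):
    """Minimize the MSE of a linear model with batch gradient descent.

    X is assumed to already contain a leading column of ones for the intercept.
    """
    n_samples, n_features = X.shape
    beta = np.zeros(n_features)                        # start from all-zero coefficients
    for _ in range(n_iters):
        residuals = y - X @ beta                       # y_i - y_hat_i
        grad = -(2.0 / n_samples) * (X.T @ residuals)  # gradient of the MSE w.r.t. beta
        beta = beta - alpha * grad                     # step in the negative gradient direction
    return beta

# Illustrative data: true model y = 1.0 + 2.0*x1 - 3.0*x2 + noise
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 2))
y = 1.0 + 2.0 * X[:, 0] - 3.0 * X[:, 1] + rng.normal(scale=0.1, size=100)
X_design = np.column_stack([np.ones(len(X)), X])

print(gradient_descent(X_design, y))  # should be close to the OLS solution [1.0, 2.0, -3.0]
```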


4. Model Evaluation Metrics

Once we have trained a linear regression model, it is important to evaluate its performance. Some common metrics for regression tasks include:

4.1. R-Squared ($R^2$)

The R-squared value (also called the coefficient of determination) measures the proportion of the variance in the dependent variable that is predictable from the independent variables.

$$R^2 = 1 - \frac{\sum_{i=1}^{n} (y_i - \hat{y}_i)^2}{\sum_{i=1}^{n} (y_i - \bar{y})^2}$$

Where $\bar{y}$ is the mean of the actual values. An $R^2$ value of 1 means the model perfectly predicts the dependent variable, while an $R^2$ value of 0 means the model does no better than simply predicting the mean.
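
The formula translates directly into code. A small sketch with made-up values:

```python
import numpy as np

def r_squared(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    ss_res = np.sum((y_true - y_pred) ** 2)            # residual sum of squares
    ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)   # total sum of squares
    return 1.0 - ss_res / ss_tot

y_true = np.array([3.0, 5.0, 7.0, 9.0])
y_pred = np.array([2.8, 5.1, 7.3, 8.9])
print(r_squared(y_true, y_pred))  # about 0.9925, i.e. close to 1 for a good fit
```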

4.2. Mean Absolute Error (MAE)

The Mean Absolute Error (MAE) is the average of the absolute differences between the actual and predicted values:

$$\text{MAE} = \frac{1}{n} \sum_{i=1}^{n} |y_i - \hat{y}_i|$$

Unlike the squared error metrics, MAE is less sensitive to outliers because it doesn’t square the residuals.

4.3. Mean Squared Error (MSE) and Root Mean Squared Error (RMSE)

As mentioned before, MSE is a common metric in linear regression, but another closely related metric is Root Mean Squared Error (RMSE):

$$\text{RMSE} = \sqrt{\text{MSE}}$$

RMSE is often easier to interpret because it is in the same units as the dependent variable.
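
In practice these metrics are rarely coded by hand; scikit-learn provides them directly. A brief sketch (the arrays are made-up example values):

```python
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

y_true = np.array([3.0, 5.0, 7.0, 9.0])
y_pred = np.array([2.8, 5.1, 7.3, 8.9])

mae = mean_absolute_error(y_true, y_pred)
mse = mean_squared_error(y_true, y_pred)
rmse = np.sqrt(mse)                      # same units as the dependent variable
r2 = r2_score(y_true, y_pred)

print(f"MAE={mae:.3f}  MSE={mse:.4f}  RMSE={rmse:.3f}  R^2={r2:.4f}")
```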


5. Limitations of Linear Regression

While linear regression is powerful, it has several limitations that make it unsuitable for certain types of data:

5.1. Linearity Assumption

  • Linear regression assumes that the relationship between the independent and dependent variables is linear, which may not always be true. If the data exhibits nonlinear patterns, linear regression may perform poorly.

5.2. Sensitive to Outliers

  • OLS minimizes the sum of squared residuals, making it highly sensitive to outliers. A single outlier can disproportionately affect the model’s coefficients.

5.3. Multicollinearity

  • When independent variables are highly correlated with each other, the model may suffer from multicollinearity. This can inflate the variance of the coefficient estimates and make the model unstable.

6. Regularization: Ridge and Lasso Regression

To address some of the limitations of basic linear regression, regularization techniques like Ridge and Lasso are often used.

6.1. Ridge Regression

Ridge Regression adds a penalty term to the cost function to shrink the size of the coefficients, helping to prevent overfitting:

$$\text{Cost Function} = \sum (y_i - \hat{y}_i)^2 + \lambda \sum \beta_j^2$$

Where $\lambda$ is the regularization parameter. The larger $\lambda$ is, the more the coefficients are penalized.
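
In scikit-learn, ridge regression is available as Ridge, where the alpha argument plays the role of $\lambda$. A minimal sketch on synthetic data (the data-generating coefficients are assumptions made for the example):

```python
import numpy as np
from sklearn.linear_model import LinearRegression, Ridge

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
y = 1.0 + 2.0 * X[:, 0] - 3.0 * X[:, 1] + 0.5 * X[:, 2] + rng.normal(scale=0.5, size=100)

ols = LinearRegression().fit(X, y)
ridge = Ridge(alpha=10.0).fit(X, y)   # alpha is the regularization strength (lambda)

print("OLS coefficients:  ", ols.coef_)
print("Ridge coefficients:", ridge.coef_)  # shrunk toward zero relative to OLS
```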

6.2. Lasso Regression

Lasso Regression adds a penalty based on the absolute value of the coefficients, which can result in some coefficients being reduced to zero, effectively selecting a subset of features:

$$\text{Cost Function} = \sum (y_i - \hat{y}_i)^2 + \lambda \sum |\beta_j|$$

This makes Lasso particularly useful when dealing with high-dimensional data or when we want to perform feature selection.
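
Scikit-learn's Lasso implements this L1 penalty; with enough regularization it drives some coefficients exactly to zero, which is how it performs feature selection. A sketch on synthetic data where, by construction, only the first two of ten features carry signal:

```python
import numpy as np
from sklearn.linear_model import Lasso

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 10))
# Only the first two features matter; the remaining eight are pure noise
y = 3.0 * X[:, 0] - 2.0 * X[:, 1] + rng.normal(scale=0.5, size=200)

lasso = Lasso(alpha=0.1).fit(X, y)   # alpha is the regularization strength (lambda)
print(lasso.coef_)                   # most irrelevant coefficients end up exactly 0
```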


Summary

Linear regression is a powerful and widely used algorithm, but it comes with several key assumptions and limitations. Understanding the theory behind Ordinary Least Squares, the cost function, and methods like gradient descent is essential for building accurate and interpretable models. Furthermore, incorporating regularization techniques like Ridge and Lasso can help address some of the pitfalls associated with standard linear regression, such as multicollinearity and overfitting.

In the next section, we will explore practical examples of applying linear regression to real-world datasets, and learn how to implement the model using popular machine learning libraries like scikit-learn, TensorFlow, and PyTorch.