Overfitting trap: Regularization and beyond
[ Last Updated: 2024-06-30 ]
Sigh, I just felt that a title like "Talking About Regularization" was too ordinary...
Actually, I didn't want to talk only about regularization; I also wanted to record a few of my thoughts on "parameter tuning."
The Overfitting Trap
When you first start with regression statistics, it's easy to fall into a common misconception: staring obsessively at the $R^2$ value. In reality, $R^2$ is a very misleading metric. It never decreases when you add a variable, so it can trick you into desperately adding more and more variables to your model. Sometimes we even start disregarding logic entirely because there are so many variables. That is the overfitting trap: a model that can only explain the training set itself. An extreme example is using five parameters (variables or higher-order terms) to fit a training set of only five data points. The system of equations then has a unique solution that passes through every point, giving $R^2 = 1$. Add one more parameter and there are infinitely many such solutions...
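Here is a minimal sketch of the trap on hypothetical toy data: five points drawn from a noisy straight line. Fitting polynomials of increasing degree never lowers $R^2$. Degree 4 (five coefficients for five points) interpolates every point exactly, so $R^2 = 1$ even though the true relationship is linear.

```python
import numpy as np

# Hypothetical toy data: 5 points from a noisy straight line y = 2x + 1
rng = np.random.default_rng(0)
x = np.linspace(0.0, 1.0, 5)
y = 2 * x + 1 + rng.normal(scale=0.3, size=x.size)

def r_squared(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot

# A polynomial of degree d has d + 1 coefficients
r2_by_degree = {}
for degree in range(1, 5):
    coefs = np.polyfit(x, y, degree)
    r2_by_degree[degree] = r_squared(y, np.polyval(coefs, x))
    print(f"degree {degree}: {degree + 1} parameters, R^2 = {r2_by_degree[degree]:.4f}")
```

The models are nested (each degree contains the previous one), so the training $R^2$ can only go up. That is exactly why it says nothing about how the model will do on new data.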
So, today let's talk about the general ideas for avoiding overfitting when facing a large amount of data:
Subset Selection
Since overfitting often comes from having too many independent variables, an intuitive idea is to keep only a subset of them.
1) The first, no-brainer idea is brute-force enumeration: try every combination of 1, 2, 3, ..., $n$ variables out of the $n$ total. There are $2^n - 1$ non-empty subsets, so the complexity is $O(2^n)$. Forget manual screening: with more than 30 variables there are over a billion subsets ($2^{30} \approx 1.07 \times 10^9$), and even a computer needs a long time to fit them all. So when time and hardware are limited, this approach is usually abandoned.
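A minimal sketch of brute-force (best-subset) search follows. It assumes ordinary least squares with an intercept and uses training RSS as the score. `best_subset` and the synthetic data are illustrative, not from the original post. For a fixed size $k$ it tries all $\binom{n}{k}$ combinations, so looping over every $k$ costs $2^n - 1$ fits.

```python
from itertools import combinations

import numpy as np

def best_subset(X, y, k):
    """Try every k-column subset; return the one with the lowest training RSS
    (residual sum of squares), its RSS, and how many models were fitted."""
    m, n = X.shape
    best_cols, best_rss, n_fits = None, np.inf, 0
    for cols in combinations(range(n), k):
        # Design matrix: intercept column + the chosen variables
        A = np.column_stack([np.ones(m), X[:, list(cols)]])
        coef, *_ = np.linalg.lstsq(A, y, rcond=None)
        rss = np.sum((y - A @ coef) ** 2)
        n_fits += 1
        if rss < best_rss:
            best_cols, best_rss = cols, rss
    return best_cols, best_rss, n_fits

# Hypothetical data: 6 variables, only columns 0 and 2 actually matter
rng = np.random.default_rng(1)
X = rng.normal(size=(100, 6))
y = 3 * X[:, 0] - 2 * X[:, 2] + rng.normal(scale=0.1, size=100)

cols, rss, fits = best_subset(X, y, 2)
print(cols)   # → (0, 2)
total = sum(best_subset(X, y, k)[2] for k in range(1, 7))
print(total)  # → 63, i.e. 2**6 - 1 fits in total
```

Training RSS is only a fair score *within* one subset size. Across sizes it always favors the largest subset (the $R^2$ trap again), so you would compare validation error or a criterion such as AIC/BIC instead.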
2) If brute force fails, try a greedy algorithm. Simply put, start from the null model and add, one at a time, the variable that most improves performance (i.e., decreases the cost the most), until no remaining variable gives a significant boost. This approach is known as Forward Selection. Conversely, we can start with all variables and remove the least significant one at each step, until removing another would significantly degrade performance (Backward Selection). The two can also be combined, allowing both additions and deletions at each step (often called stepwise selection). However, because these are greedy methods, they are not guaranteed to find the globally optimal subset: the best set of $k+1$ variables need not contain the best set of $k$.
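Forward selection can be sketched as follows. This is a hypothetical illustration. `forward_selection` uses training RSS with a 10% relative-improvement threshold as its stopping rule, standing in for whatever "significant" means in practice (a validation score, an F-test, AIC/BIC, ...). Each step fits at most $n$ models, so the whole search needs $O(n^2)$ fits instead of $O(2^n)$.

```python
import numpy as np

def rss_of(X, y, cols):
    """Training RSS of an OLS fit using an intercept plus the given columns."""
    A = np.column_stack([np.ones(len(y))] + [X[:, j] for j in cols])
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    return np.sum((y - A @ coef) ** 2)

def forward_selection(X, y, min_improvement=0.1):
    """Greedy forward selection: start from the null (intercept-only) model and
    repeatedly add the variable that lowers RSS the most; stop when the relative
    improvement falls below min_improvement (a hypothetical threshold)."""
    selected, remaining = [], list(range(X.shape[1]))
    current_rss = rss_of(X, y, selected)
    while remaining:
        # Score every candidate and keep the best one
        scores = {j: rss_of(X, y, selected + [j]) for j in remaining}
        best_j = min(scores, key=scores.get)
        if (current_rss - scores[best_j]) / current_rss < min_improvement:
            break  # no significant boost left: stop
        selected.append(best_j)
        remaining.remove(best_j)
        current_rss = scores[best_j]
    return selected

# Hypothetical data: only columns 0 and 2 matter
rng = np.random.default_rng(1)
X = rng.normal(size=(100, 6))
y = 3 * X[:, 0] - 2 * X[:, 2] + rng.normal(scale=0.1, size=100)

selected = forward_selection(X, y)
print(selected)  # the two true predictors, 0 and 2, are picked first
```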
While it might not seem "scientific," simply deleting variables remains one of the most convenient and practical ways to reduce overfitting.
Validation Set / Cross-Validation
Besides deleting variables, validation is a very common way to detect and avoid overfitting. The logic is simple: if a model overfits the training set, it will usually predict poorly on data it has not seen. So we typically split the data 80%/20%. The training set is used to fit the model, and the validation set to check its reliability. The split must ensure that the training and validation sets have similar distributions. To make the evaluation more representative than a single random split, we often use K-fold cross-validation:
Divide the dataset into $K$ equal-sized subsets (folds). Each round, train on $K-1$ folds and validate on the remaining one. Repeat this $K$ times so that every fold serves as the validation set exactly once, then average the $K$ scores for a more stable evaluation. For example, with 1,000 samples and 5-fold cross-validation, each fold has 200 samples: the model is trained 5 times, each time on 800 samples, and validated on the other 200.
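The splitting logic above can be sketched with NumPy alone; scikit-learn's `sklearn.model_selection.KFold` offers the same functionality. `k_fold_indices` is a hypothetical helper name. It shuffles the indices once, cuts them into $K$ folds, and lets each fold take a turn as the validation set.

```python
import numpy as np

def k_fold_indices(n_samples, k, seed=0):
    """Shuffle sample indices once, then yield (train_idx, val_idx) for each of k folds."""
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(n_samples), k)
    for i in range(k):
        val_idx = folds[i]
        # Every other fold goes into the training set
        train_idx = np.concatenate([folds[j] for j in range(k) if j != i])
        yield train_idx, val_idx

# The example from the text: 1,000 samples, 5 folds
for fold, (train_idx, val_idx) in enumerate(k_fold_indices(1000, 5)):
    print(fold, len(train_idx), len(val_idx))  # 800 training / 200 validation each time
```

In a real run you would fit the model on `train_idx`, score it on `val_idx`, and average the five scores.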
Of course, the specifics of cross-validation aren't the focus today; I might write a detailed post on it later. Now, let's finally introduce regularization.
Regularization
I used to be confused by this concept and its translation. Most online resources don't explain the "why" behind implementing it. It wasn't until I read a proper textbook recently that I realized: isn't this just automatic parameter tuning? (If variable deletion is "manual," this is the "automatic" version).
L1 Regularization (Lasso) and L2 Regularization (Ridge):
Both add a so-called penalty term to the original cost function. L1 adds the sum of the absolute values of the weights, while L2 adds the sum of their squares. The choice of the regularization strength $\lambda$ is crucial: one must balance goodness of fit and interpretability against overfitting (the optimal $\lambda$ can be selected with K-fold cross-validation).
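Written out, the two regularized costs can be expressed as follows. $J_0(\mathbf{w}, b)$ is the original (unregularized) cost, $m$ the number of samples, and $n$ the number of weights. The bias $b$ is not penalized. The L1 form matches the cost code later in this post. The factor $\frac{1}{2}$ in the L2 term is a common convention, assumed here, that makes the L2 gradient exactly $\frac{\lambda}{m} w_j$. At $w_j = 0$, where $|w_j|$ has no derivative, $\operatorname{sign}(0) = 0$ serves as a valid subgradient.

$$
J_{\text{L1}}(\mathbf{w}, b) = J_0(\mathbf{w}, b) + \frac{\lambda}{m}\sum_{j=1}^{n} |w_j|,
\qquad
\frac{\partial J_{\text{L1}}}{\partial w_j} = \frac{\partial J_0}{\partial w_j} + \frac{\lambda}{m}\operatorname{sign}(w_j)
$$

$$
J_{\text{L2}}(\mathbf{w}, b) = J_0(\mathbf{w}, b) + \frac{\lambda}{2m}\sum_{j=1}^{n} w_j^2,
\qquad
\frac{\partial J_{\text{L2}}}{\partial w_j} = \frac{\partial J_0}{\partial w_j} + \frac{\lambda}{m} w_j
$$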
Next, using the three variables from my previous post on predicting admission probability, let's introduce the implementation and differences between L1 and L2.
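First, let's create some higher-order variables to give the model a tendency to overfit. From the three original variables (the columns of `X_input`), we generate squared and cubic terms (NumPy has built-in functions for this):

```python
import numpy as np

# X_input: (m, 3) array holding the three original variables
squared_columns = np.square(X_input)     # x^2 terms
cubic_columns = np.power(X_input, 3)     # x^3 terms
X_train_poly = np.concatenate((X_input, squared_columns, cubic_columns), axis=1)  # shape (m, 9)
```

Now we have a new dataset with nine variables. Because $x$, $x^2$ and $x^3$ live on very different scales, we standardize them (z-score) before running gradient descent:

```python
# z_score: standardization helper, assumed defined earlier (not shown in this post)
X_train_poly = z_score(X_train_poly)
```

We also need to rewrite the cost function and the gradient function (using L1 as an example; L2 replaces the absolute value with a squared term):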
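```python
import numpy as np

def compute_cost_reg_L1(X, y, w, b, lambda_=1):
    """Original cost plus the L1 penalty (lambda_/m) * sum(|w_j|); b is not penalized."""
    m, n = X.shape
    cost_without_reg = compute_cost(X, y, w, b)  # unregularized cost (helper assumed defined earlier)
    reg_cost = np.sum(np.abs(w))  # L1: sum of |w_j|; L2 would use np.sum(w**2) / 2
    reg_cost = (lambda_ / m) * reg_cost
    total_cost = cost_without_reg + reg_cost
    return total_cost
```

```python
def compute_gradient_reg_L1(X, y, w, b, lambda_=1):
    """Gradient of the L1-regularized cost; the bias gradient dj_db is unchanged."""
    m, n = X.shape
    dj_dw, dj_db = compute_gradient(X, y, w, b)  # unregularized gradients (helper assumed defined earlier)
    for j in range(n):
        # L1 (sub)gradient uses the sign function; L2 would use lambda_ / m * w[j]
        dj_dw_reg = lambda_ / m * np.sign(w[j])
        dj_dw[j] = dj_dw[j] + dj_dw_reg
    return dj_dw, dj_db
```

Next, we sweep $\lambda$ over a range of values (0.001, 0.01, ..., up to 1000). For each value we run gradient descent with L1 and with L2 and look at how the nine coefficients change.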
As we can see, L1 regularization tends to drive the coefficients of unimportant variables to zero, making the coefficient vector sparse. In the chart, once $\lambda$ reaches a certain value, several light blue/green lines have already dropped to 0 and vanished from the model. Meanwhile the cost curve has barely risen, suggesting those variables were mostly contributing to overfitting. This resembles manual variable deletion, except that L1 automates it. However, when $\lambda$ is large (reaching 100), many variables are forced to 0 and the model loses its predictive power, so $\lambda$ must be chosen very carefully.
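A practical caveat applies to the sign-based gradient above. With plain gradient descent, the "dropped" coefficients typically jitter in a tiny band around zero rather than landing exactly on it. A common alternative, and not the method used in this post, is proximal gradient descent (ISTA). It takes an ordinary gradient step on the unpenalized cost, then applies soft-thresholding, which sets small coefficients exactly to zero. The sketch below assumes a squared-error cost and hypothetical synthetic data in which only variables 0 and 2 matter.

```python
import numpy as np

def soft_threshold(z, t):
    """Proximal operator of t * |z|: shrink toward 0 by t, and clip to exactly 0."""
    return np.sign(z) * np.maximum(np.abs(z) - t, 0.0)

def lasso_ista(X, y, lambda_, alpha=0.1, iters=2000):
    """L1-regularized linear regression via proximal gradient descent (ISTA).
    Cost (assumed): (1/2m) * ||X w + b - y||^2 + (lambda_/m) * sum(|w_j|); b is not penalized."""
    m, n = X.shape
    w, b = np.zeros(n), 0.0
    for _ in range(iters):
        err = X @ w + b - y
        dj_dw = X.T @ err / m      # gradient of the squared-error part only
        dj_db = err.mean()
        # Gradient step on the smooth part, then soft-threshold the weights
        w = soft_threshold(w - alpha * dj_dw, alpha * lambda_ / m)
        b = b - alpha * dj_db
    return w, b

# Hypothetical synthetic data: 6 variables, only columns 0 and 2 matter
rng = np.random.default_rng(3)
X = rng.normal(size=(200, 6))
y = 3 * X[:, 0] - 2 * X[:, 2] + rng.normal(scale=0.5, size=200)

w, b = lasso_ista(X, y, lambda_=50.0)
print(np.round(w, 3))  # the four irrelevant coefficients come out exactly 0
```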
Looking at L2 regularization, we notice that even at large values of $\lambda$, no coefficient disappears completely. Instead, all coefficients are scaled down almost proportionally. This still reduces the risk of overfitting, even though all nine explanatory variables are retained.
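For comparison, L2 with a squared-error cost has a closed-form solution, so no gradient descent is needed. The sketch below uses a hypothetical helper name and synthetic data, and assumes a plain linear model rather than the post's admission model. It shows every coefficient shrinking as $\lambda$ grows while none reaches exactly zero.

```python
import numpy as np

def ridge_closed_form(X, y, lambda_):
    """L2-regularized least squares with an unpenalized intercept. Minimizes
    (1/2m) * ||X w + b - y||^2 + (lambda_/2m) * sum(w_j^2)  (assumed cost)."""
    x_mean, y_mean = X.mean(axis=0), y.mean()
    Xc, yc = X - x_mean, y - y_mean      # centering takes the intercept out of the problem
    n = X.shape[1]
    # Setting the gradient to zero gives (Xc^T Xc + lambda_ I) w = Xc^T yc
    w = np.linalg.solve(Xc.T @ Xc + lambda_ * np.eye(n), Xc.T @ yc)
    b = y_mean - x_mean @ w
    return w, b

# Hypothetical synthetic data: 6 variables, only columns 0 and 2 matter
rng = np.random.default_rng(3)
X = rng.normal(size=(200, 6))
y = 3 * X[:, 0] - 2 * X[:, 2] + rng.normal(scale=0.5, size=200)

for lam in [0.0, 10.0, 100.0, 1000.0]:
    w, b = ridge_closed_form(X, y, lam)
    print(lam, np.round(w, 3))  # every coefficient shrinks; none becomes exactly 0
```

With standardized, nearly uncorrelated columns, $X_c^\top X_c \approx m I$. Each coefficient is then roughly multiplied by the same factor $m/(m+\lambda)$, which is the near-proportional scaling described above.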
Why are the results so different? It comes down to the shape of the penalty term:
Take a two-variable example ($w_1$, $w_2$) and visualize the two parts of the cost function. The contours of the original cost are ellipses, like the contour plots we saw with gradient descent. The contours of the penalty are diamonds for L1 ($|w_1| + |w_2| = c$) and circles for L2 ($w_1^2 + w_2^2 = c$).
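We want to find the minimum of $J_0(w_1, w_2) + \text{penalty}$, where the penalty is $\frac{\lambda}{m}(|w_1| + |w_2|)$ for L1 or $\frac{\lambda}{2m}(w_1^2 + w_2^2)$ for L2.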
This can be understood as minimizing the original cost $J_0$ subject to the constraint $|w_1| + |w_2| \le t$ (L1) or $w_1^2 + w_2^2 \le t$ (L2). Here, $t$ is a "budget": the larger $\lambda$, the smaller $t$. The solution is where the lowest cost ellipse just touches the diamond (L1) or the circle (L2). For L1, this contact often happens at a corner of the diamond, where one coordinate is exactly zero. The circle has no corners, so for L2 the contact is a tangency point that shrinks the coefficients without generally setting any of them to zero.
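The geometry can be made concrete in the simplest case. Assume, as a textbook simplification rather than the setup of this post, that:

- the data are centered;
- the features are orthonormal in the sense $\mathbf{X}^\top\mathbf{X} = m\mathbf{I}$;
- the original cost is squared error $J_0 = \frac{1}{2m}\lVert\mathbf{X}\mathbf{w} - \mathbf{y}\rVert^2$, with unregularized least-squares solution $\hat w_j$.

Then $J_0 = \frac{1}{2}\sum_j (w_j - \hat w_j)^2 + \text{const}$. Both penalized problems split into independent one-variable problems with closed-form solutions:

$$
w_j^{\text{L1}} = \operatorname{sign}(\hat w_j)\,\max\!\left(|\hat w_j| - \frac{\lambda}{m},\ 0\right),
\qquad
w_j^{\text{L2}} = \frac{\hat w_j}{1 + \lambda/m}
$$

L1 subtracts the fixed amount $\lambda/m$ from every coefficient and clips anything smaller to exactly zero (the corner solution). L2 divides every coefficient by the same factor, so it shrinks them proportionally but never reaches zero for a finite $\lambda$.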
In practice, the choice between L1 and L2 depends on the situation: if many variables have similar explanatory power, L2 is useful for proportional scaling. Conversely, if there is a clear hierarchy among variables (e.g., deciding between linear, squared, and cubic terms), L1 is more appropriate to prune marginal variables.
Conclusion
This post briefly introduced some methods I've recently used to reduce overfitting. I also encountered related concepts like multicollinearity and high-dimensional regression; while related, they aren't included here yet (as I haven't finished studying them...).
Ultimately, overfitting is a product of blindly chasing prediction accuracy (especially on the training set). In data analysis, besides accuracy, we must consider interpretability and generalizability. This is why simple linear regression remains popular despite complex machine learning algorithms—it’s logical and easy to explain. If a relationship can be explained with a straight line, why use a curve just because GPUs are cheap now? (jk)
This week I've covered most of the concepts around linear regression. I haven't finished the neural network material yet, so I'll probably start writing about that next week.
Learning is so slow.
— E Zai
2024/06/30 18:30