GridSearchCV——网格搜索与交叉验证详解
GridSearchCV——网格搜索与交叉验证详解
1. 什么是 GridSearchCV
GridSearchCV 是 sklearn 中 model_selection 模块提供的超参数调优工具,其名称揭示了两个核心步骤:
- Grid Search(网格搜索):穷举参数组合——将各参数的候选值做笛卡尔积,形成一张”参数网格”,逐一尝试;
- CV(Cross-Validation,交叉验证):对每一组参数组合,使用 K 折交叉验证评估模型性能,取平均得分作为该组合的评价指标。
最终选出交叉验证得分最高的那组参数,并在全部训练数据上重新训练出最终模型。
2. 核心原理
3.1 网格穷举
假设要调优 SVM 的两个参数:C 和 gamma:
1 | param_grid = { |
笛卡尔积产生 $3 \times 3 = 9$ 组参数组合:
| 编号 | C | gamma |
|---|---|---|
| 1 | 0.1 | 0.01 |
| 2 | 0.1 | 0.1 |
| 3 | 0.1 | 1 |
| 4 | 1 | 0.01 |
| … | … | … |
| 9 | 10 | 1 |
GridSearchCV 会逐一训练并评估这 9 组参数。
3.2 交叉验证评估
对每一组参数,使用 K 折交叉验证计算得分:
- 将训练集等分为 $K$ 份;
- 依次以其中 $K-1$ 份训练、1 份验证,重复 $K$ 次;
- 取 $K$ 次得分的平均值作为该组参数的最终得分。
1 | 折数 K=5 示意: |
3.3 总训练次数
若参数组合数为 $N$,交叉验证折数为 $K$,则总训练次数为:
$$
\text{Total fits} = N \times K
$$
上例中 $9 \times 5 = 45$ 次训练。若参数维度增加,训练次数会指数级增长——这是网格搜索的主要瓶颈。
3. 核心参数详解
1 | from sklearn.model_selection import GridSearchCV |
4.1 estimator —— 基础模型
传入一个 sklearn 兼容的 estimator 实例。GridSearchCV 内部会克隆该对象用于各参数组合的训练。
1 | from sklearn.svm import SVC |
4.2 param_grid —— 参数网格
支持两种形式:
形式一:单个字典(笛卡尔积搜索):
1 | param_grid = { |
形式二:字典列表(分段搜索,不取笛卡尔积):
1 | param_grid = [ |
第二种形式在搜索不同模型或不同参数子集时非常有用,避免了无意义的组合。
4.3 scoring —— 评分指标
指定用于评估和选择最优参数的指标。可以是字符串或自定义 scorer。
常用字符串:
| 任务 | 指标 | scoring 参数值 |
|---|---|---|
| 分类 | 准确率 | 'accuracy' |
| 分类 | F1 分数 | 'f1'、'f1_macro'、'f1_weighted' |
| 分类 | AUC | 'roc_auc' |
| 回归 | 均方误差 | 'neg_mean_squared_error' |
| 回归 | R² | 'r2' |
注意:sklearn 约定 scorer 值越大越好,因此 MSE 等误差指标会加 neg_ 前缀取负数。
支持多个评分指标:
1 | scoring = ['accuracy', 'f1_weighted', 'roc_auc_ovr'] |
4.4 cv —— 交叉验证策略
cv 参数不仅可以是整数,还可以传入各种 splitter:
1 | from sklearn.model_selection import StratifiedKFold, TimeSeriesSplit, GroupKFold |
4.5 fit —— 执行搜索与训练
fit(X, y, **fit_params) 是 GridSearchCV 的核心方法,调用它才会真正启动网格搜索。其内部流程为:
- 遍历
param_grid中的每一组参数组合; - 对每组参数,执行 $K$ 折交叉验证,记录每折的得分与训练耗时;
- 汇总所有结果到
cv_results_; - 若
refit=True,用best_params_在全部X上再训练一次,得到best_estimator_。
1 | grid = GridSearchCV(SVC(), param_grid, cv=5) |
**fit_params 会透传给底层 estimator 的 fit() 方法。典型场景是某些模型需要额外参数:
1 | # XGBoost 的 eval_set、early_stopping_rounds 等 |
fit() 返回 self(即 GridSearchCV 实例本身),因此支持链式调用,也意味着调用后可以直接访问 grid.best_params_ 等属性。
4.6 refit —— 重训练
默认 refit=True,在 fit() 的网格搜索阶段找到最优参数后,会用全部训练数据重新训练一次模型,得到 best_estimator_,可直接用于预测。
若使用了多指标评分,可指定 refit='f1_weighted' 告知以哪个指标选出最优参数。
4. 搜索结果分析——fit 后的属性与 cv_results_ 详解
GridSearchCV.fit() 完成后,搜索对象上会挂载一系列以 _ 结尾的属性,供后续分析和使用。以下逐一说明每个属性的含义、类型及访问方式。
4.1 顶层属性 —— 直接挂载在 grid 对象上
以下属性在 fit() 后直接挂载在 GridSearchCV 实例上,通过 grid.属性名 访问:
| 属性 | 类型 | 访问方式 | 说明 |
|---|---|---|---|
best_params_ |
dict |
grid.best_params_ |
交叉验证得分最高的那组参数 |
best_score_ |
float |
grid.best_score_ |
最优参数对应的交叉验证平均得分 |
best_estimator_ |
estimator 对象 | grid.best_estimator_ |
用最优参数在全量训练数据上重新训练好的模型,可直接 .predict() |
best_index_ |
int |
grid.best_index_ |
最优参数在 cv_results_ 各数组中的索引位置 |
cv_results_ |
dict of ndarray |
grid.cv_results_ |
所有参数组合的详细结果,其内部结构见 4.2 节 |
scorer_ |
scorer 对象 | grid.scorer_ |
实际使用的评分器 |
n_splits_ |
int |
grid.n_splits_ |
实际执行的交叉验证折数 |
refit_time_ |
float |
grid.refit_time_ |
全量 refit 所消耗的时间(秒) |
multimetric_ |
bool |
grid.multimetric_ |
是否使用了多指标评分 |
访问示例:
1 | grid = GridSearchCV(SVC(), param_grid, cv=5).fit(X_train, y_train) |
4.2 cv_results_ 内部结构 —— 它本身是一个 dict
下面按类别逐一说明 dict 中的 key。
4.2.1 参数列 —— key 以 param_ 为前缀
每个被搜索的参数在 dict 中对应一个 key,命名规则为 param_ + 参数名。这些 key 的值是长度为 $N$(参数组合总数)的 numpy 数组,记录了该组参数组合的候选值。
1 | # 例:param_grid = {'C': [0.1, 1], 'gamma': [0.01, 0.1]} |
当参数以双下划线嵌套时(如 Pipeline),key 名原样保留双下划线:
1 | # param_grid = {'svm__C': [0.1, 1]} |
4.2.2 得分列 —— mean_test_score、std_test_score、rank_test_score 与每折明细
这些 key 的数量取决于 scoring 的配置和 return_train_score 的值。
始终存在的 key:
| dict 中的 key | 含义 |
|---|---|
mean_test_score |
每组参数在 $K$ 折验证集上的平均得分 |
std_test_score |
每组参数在 $K$ 折验证集上得分的标准差 |
rank_test_score |
按 mean_test_score 降序排名(1 = 最优) |
每折明细 key(cv=K 时有 $K$ 个):
| dict 中的 key | 含义 |
|---|---|
split0_test_score |
第 1 折验证集得分 |
split1_test_score |
第 2 折验证集得分 |
| … | … |
split{K-1}_test_score |
第 K 折验证集得分 |
仅当 return_train_score=True 时才出现的 key:
| dict 中的 key | 含义 |
|---|---|
mean_train_score |
每组参数在 $K$ 折训练集上的平均得分 |
std_train_score |
训练集得分的标准差 |
split0_train_score … |
每折训练集得分明细 |
多指标评分时(scoring=['accuracy', 'f1']),以上 key 会按指标名展开:
| dict 中的 key | 含义 |
|---|---|
mean_test_accuracy |
准确率的交叉验证均值 |
mean_test_f1 |
F1 的交叉验证均值 |
rank_test_accuracy |
按准确率排名 |
rank_test_f1 |
按 F1 排名 |
1 | # 示例:访问 dict 中的具体 key |
4.2.3 时间列
| dict 中的 key | 含义 |
|---|---|
mean_fit_time |
每组参数平均训练耗时(秒) |
std_fit_time |
训练耗时标准差 |
mean_score_time |
每组参数平均评分耗时(秒) |
std_score_time |
评分耗时标准差 |
4.3 cv_results_ 的典型分析操作
4.3.1 转为 DataFrame 并按排名查看
1 | import pandas as pd |
4.3.2 按自定义条件筛选
1 | # 筛出 C=10 且平均得分 > 0.95 的组合 |
4.3.3 从中提取最优参数
1 | res = grid.cv_results_ |
4.3.4 诊断过拟合:对比训练集与测试集得分
1 | grid = GridSearchCV( |
若 gap 过大(如训练得分 0.99,验证得分仅 0.85),说明该组参数存在过拟合风险。
4.3.5 可视化参数空间
1 | import matplotlib.pyplot as plt |
4.4 best_estimator_ 深入使用
best_estimator_ 是一个已经训练好的模型实例,可以直接调用其所有方法:
1 | best_model = grid.best_estimator_ |
4.5 多指标评分下的属性行为
当 scoring 为列表或字典时,best_score_ 和 best_index_ 的行为取决于 refit:
1 | grid = GridSearchCV( |
5. 常见问题与避坑指南
5.1 数据泄露
错误做法:在拆分训练集/测试集之前对全量数据做标准化,再用 GridSearchCV。
1 | # 错误!已导致数据泄露 |
正确做法:使用 Pipeline 将预处理嵌入搜索流程。
1 | from sklearn.pipeline import Pipeline |
使用 Pipeline 后,每个交叉验证折内的标准化只会用该折的训练数据来 fit,彻底杜绝数据泄露。
5.2 参数组合爆炸
假设 5 个参数,每个 5 个候选值,cv=5:
$$
\text{总训练次数} = 5^5 \times 5 = 15{,}625
$$
应对策略:
- 先用粗粒度大范围搜索,锁定范围后精细搜索;
- 用
RandomizedSearchCV替代穷举; - 用
HalvingGridSearchCV加速; - 基于经验缩小候选值范围。
5.3 评分指标选择不当
分类任务中,当类别不均衡时,accuracy 会高估模型表现——例如正样本仅占 1%,全预测负类也有 99% 准确率。此时应使用 f1 或 roc_auc。
1 | # 类别不均衡时 |
5.4 best_score_ 与测试集得分的落差
best_score_ 是交叉验证的平均得分,由于在参数搜索中选择了”最好”的那组,该值可能略偏乐观。最终应以独立测试集的结果为准。
5.5 n_jobs 与内存
n_jobs=-1 会并行运行,但每个任务会克隆一份 estimator 和数据。如果模型或数据很大,可能导致内存溢出。此时应适当减少并行数。
6. 嵌套交叉验证——无偏性能估计
当需要同时调参和估计模型泛化性能时,简单的 train/test split 会导致乐观偏差。嵌套交叉验证提供无偏的性能估计:
1 | from sklearn.model_selection import cross_val_score, StratifiedKFold |
外层 K 折的每一折中,GridSearchCV 都会在内层做一次完整的网格搜索——计算开销很大,但结果更可靠。
7. 小结
GridSearchCV穷举参数组合并配合交叉验证评估,是寻找最优超参数的系统化方法;- 参数组合数随维度指数增长,建议控制候选值数量或使用
RandomizedSearchCV替代; - 务必使用
Pipeline将数据预处理嵌入搜索流程,避免交叉验证中的数据泄露; - 类别不均衡时使用分层交叉验证和合适的评分指标(
f1、roc_auc); best_score_可作为参考,但最终性能应以独立测试集为准;- 对性能估计要求严格的场景,可使用嵌套交叉验证。
