ML 学习站
跳到正文

决策树

信息熵、ID3/C4.5 概念、树的剪枝与可视化。

35 分钟2 / 51,267
加载中...

决策树是一种直观且易于解释的分类/回归算法,其结构类似于倒置的树,包含内部节点(特征判断)和叶子节点(决策结果)。核心概念包括信息熵和信息增益,信息熵用于衡量数据纯度,而信息增益则表示切分后熵的下降量,ID3、C4.5和CART算法分别采用不同的指标进行特征选择。读者将学会使用sklearn库构建和可视化决策树,并理解如何通过预剪枝和后剪枝方法防止过拟合。决策树具有可解释性强、不需特征缩放、能处理类别型特征和缺失值等优点,但也存在易过拟合、不稳定和难以学习复杂关系的缺点。掌握决策树后,读者能够运用其进行数据分类和回归分析,并通过集成方法(如随机森林和XGBoost)提升模型性能。

决策树

决策树(Decision Tree)是最直观的分类/回归算法——它就是一系列"if-else"的判断流程,长得像一棵倒过来的树。

长什么样

一个判断"今天适不适合打球"的简单决策树:

              天气?
             /  |  \
           晴   阴   雨
           |    |    |
        湿度?   ✓    ✗
        /  \
      高   低
      |    |
      ✗    ✓

每个内部节点是一个特征判断,每个叶子节点是一个决策结果。

信息熵:为什么这么切?

决策树的关键问题是:按哪个特征切分?用哪个阈值?

衡量"切分后数据有多纯"的指标是信息熵(Information Entropy):

H(S) = -sum( p_c * log2(p_c) )
  • 全部一类(p=1):熵 = 0(完全纯)
  • 两类各半(p=0.5):熵 = 1(最混乱)

我们希望切分后熵尽量低,这个"切分后熵下降的量"叫做信息增益(Information Gain):

IG(S, A) = H(S) - sum_v (|S_v|/|S|) * H(S_v)

ID3 算法每次选信息增益最大的特征切分。C4.5 改用信息增益率,处理偏向多取值特征的问题。CART 用基尼系数(Gini),计算更快。

实战 sklearn

from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# 1. 加载数据(经典鸢尾花)
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# 2. 训练(限制深度防止过拟合)
model = DecisionTreeClassifier(max_depth=3, random_state=42)
model.fit(X_train, y_train)

# 3. 评估
print(f"测试集准确率: {model.score(X_test, y_test):.3f}")

# 4. 可视化决策树
import matplotlib.pyplot as plt
plt.figure(figsize=(12, 8))
plot_tree(model, feature_names=['花萼长', '花萼宽', '花瓣长', '花瓣宽'],
          class_names=['山鸢尾', '变色鸢尾', '维吉尼亚鸢尾'],
          filled=True, rounded=True)
plt.show()

防止过拟合:剪枝

决策树特别容易过拟合——只要深度够深,它能把每个训练样本都分对(纯度 100%)。两种剪枝:

预剪枝(Pre-pruning)

提前限制树的生长:

DecisionTreeClassifier(
    max_depth=5,           # 最大深度
    min_samples_split=10,  # 节点至少 10 个样本才继续切
    min_samples_leaf=5,    # 叶子节点至少 5 个样本
    max_leaf_nodes=20,     # 最多 20 个叶子
)

后剪枝(Post-pruning)

先让树长满,再从下往上剪掉没用的分支。sklearn 用 cost-complexity pruning:

# 获取不同 alpha 下的剪枝路径
path = model.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas = path.ccp_alphas

# 挑最优 alpha
from sklearn.model_selection import cross_val_score
scores = [cross_val_score(DecisionTreeClassifier(ccp_alpha=a),
                          X_train, y_train, cv=5).mean()
          for a in ccp_alphas]
best_alpha = ccp_alphas[np.argmax(scores)]
model_pruned = DecisionTreeClassifier(ccp_alpha=best_alpha)

决策树的优缺点

优点:

  • 可解释性极强(白盒模型,银行/医疗常用)
  • 不需要特征缩放
  • 能处理类别型特征和缺失值
  • 训练快,推理更快(O(log n))

缺点:

  • 极易过拟合(需要剪枝)
  • 不稳定——训练数据小幅变动可能让树结构大变
  • 难学到复杂关系(单棵树能力有限)

小结

  • 决策树是一系列 if-else,基于信息增益 / 基尼系数选切分
  • sklearn: DecisionTreeClassifier,可视化用 plot_tree
  • 防止过拟合:限制深度 + 限制叶子后剪枝
  • 单棵树效果一般,集成后(随机森林/XGBoost)威力大增(下一章)

练习思考

  1. load_breast_cancer 数据集,跑一个不限制深度的决策树和一个 max_depth=3 的,对比两者的训练集和测试集准确率,理解过拟合。
  2. 为什么说"决策树不需要特征缩放",而逻辑回归/SVM 强烈建议缩放?
  3. 决策树的"不稳定"是什么意思?为什么随机森林能解决它?

章末小测验

检验你对《决策树》的掌握程度。

1

决策树用哪个指标选择在哪个特征上切分?

2

决策树最常用的防止过拟合方法是?

讨论区(0)

加载评论中...