0%

【机器学习自学笔记3】决策树剪枝

决策树如果任其自由生长,很容易产生过拟合。因此,我们有必要对决策树进行剪枝。

CART 剪枝算法从"完全生长"的决策树的底端剪去一些子树,使决策树变小(模型变简单),从而能够对未知数据有更准确的预测。

CART 剪枝

CART 剪枝的思想:

  • 从完全生长的整体树 的最底端开始不断剪枝
  • 直至剪到 的根结点为止,形成子树序列 {}
  • 通过交叉验证法在独立的验证集上对子树序列进行测试,选出最优子树

观察上面的思想,我们很容易想到一些问题:

  • 如何挑选剪枝的结点?
  • 每个结点是否需要剪枝?

剪枝的损失函数

为了控制剪枝的位置和顺序,引入一个损失函数的概念:

  • T 代表某一子树

  • C(T) 代表训练数据的预测误差 (如基尼系数)

  • |T|代表子树的结点数

  • 是一个参数,随着 的改变,得到的最优子树的复杂度和拟合度也不同

可以想象:

  • 较大时,|T| 的权重更大,此时为使 较小,最优子树结点数越少(简单,拟合度低)
  • 较小时,C(T) 的权重更大,此时为使 较小,训练误差也越小,拟合度越高(复杂,拟合度高)
  • 时,最优子树为树本身
  • 时,最优子树为单独的根节点组成的树

剪枝的判断

通过损失函数,我们可以判断一个结点是否应该被剪枝,步骤如下:

对于某一个结点 t,计算以单独的结点 t 组成的树的损失函数: 对于该结点 t,计算以 t 为根结点的子树 的损失函数: 或充分小时,有不等式 随着 的增大,必有 使得 此时 (单结点树 t 的结点数 |t| = 1)

此时就可以对结点 t 进行剪枝,去掉 t 的子结点。

即为判断剪枝的临界值。

举例 (以基尼系数作为损失函数的情况):

graph TD
A[A,样本数=10,Gini=0.50]
B[B,样本数=7,Gini=0.25]
C[C,样本数=3,Gini=0.10]

A-->B
A-->C

先计算 A 组成的单结点树的损失函数 再计算 A 为根节点的树 的损失函数,其中 t 表示树的所有叶结点,N(t)表示各个叶结点的样本数(权重) 因此,随着 增大到 1.475 时,结点 A 应被剪枝。

剪枝的过程

  • 对于完整树 ,计算每一个非叶结点的 g(t) 值
  • 对得到的所有 g(t) 进行从小到大排序
  • g(t) 从小到大,分别对应第 1, 2, ..., n 个被剪枝的结点
  • 对于 根据第 1 个结点进行剪枝得到
  • 对于 根据第 2 个结点进行剪枝得到
  • 如此往复循环,对于 ,可以生成子树序列

最优子树的挑选

挑选最优子树需要使用独立的验证集,而不是之前的训练集。通过验证集计算出序列中每一个子树的损失函数 (如MSE、Gini) 等指标,选择损失最小的子树作为最优子树,这便可以得到最优决策树。

举例

graph TD
A[A,样本数=10,Gini=0.50]
B[B,样本数=7,Gini=0.25]
C[C,样本数=3,Gini=0.10]
D[D,样本数=4,Gini=0.20]
E[E,样本数=3,Gini=0.10]
F[F,样本数=3,Gini=0.05]
G[G,样本数=1,Gini=0.08]

A-->B
A-->C
B-->D
B-->E
D-->F
D-->G

计算所有非叶结点的 g(t)

所有非叶结点:ABD

计算每个结点对应单结点树的损失:

计算每个结点作为根结点对应子树的损失:

计算每个结点的 g(t):

对 g(t) 进行排序

得到的 g(t) 序列:

按照 g(t) 顺序逐个剪枝生成子树

T1:

graph TD
A[A,样本数=10,Gini=0.50]
B[B,样本数=7,Gini=0.25]
C[C,样本数=3,Gini=0.10]
D[D,样本数=4,Gini=0.20]
E[E,样本数=3,Gini=0.10]

A-->B
A-->C
B-->D
B-->E

T2:

graph TD
A[A,样本数=10,Gini=0.50]
B[B,样本数=7,Gini=0.25]
C[C,样本数=3,Gini=0.10]

A-->B
A-->C

T3:

graph TD
A[A,样本数=10,Gini=0.50]

最后,只需要用测试集对每一个子树进行交叉验证,就可以挑选出最优子树作为决策树。