本章探讨了如何利用控制图(Control Chart)对机器学习模型的训练过程进行实时监控,以解决传统监控方法中无法及时发现训练异常的问题。核心概念包括:1)控制图的基本原理,即通过3σ法则判断数据是否在控制范围内;2)四种主要控制图类型(均值控制图、个值控制图、移动极差图、比例控制图),分别适用于不同类型的监控指标;3)Western Electric规则,提供比单一3σ法则更敏感的异常检测方法。学习本章后,读者能够设计并应用控制图来实时监控训练过程中的多个关键指标,如loss、accuracy、梯度范数等,从而及时发现并应对训练中的异常情况,减少训练时间并提高模型稳定性。
训练过程的统计监控
本章问题: 训练跑了一晚, 第二天发现 loss 早就 plateau 了 5 小时。怎么"实时"知道训练是否正常? 答案: 控制图 (Control Chart), 工业质量管理用了一个世纪的方法。
1. 训练监控的现状
大多数 ML 工程师的监控:
- ❌ 只看 loss/accuracy 曲线, 等训练完才发现问题
- ❌ 没量化"正常波动"和"异常波动"
- ❌ 多卡训练时, 不知道哪张卡异常
工业质量管理 (SPC) 用 控制图 解决了完全相同的问题: 监控"过程"是否在控。
2. 控制图基础 (Control Chart)
2.1 3σ 法则
如果过程在控, 数据点应该在 μ ± 3σ 范围内。超出 = 异常。
UCL (Upper Control Limit) = μ + 3σ
│
┌─────────┼─────────┐
│ 正常区域 │ 异常 │ ← 1 个点出 UCL → 异常
───────┼─────────┼─────────┼─────
│ │ │
│ (中线) │ │
│ │ │
└─────────┼─────────┘
│
LCL = μ - 3σ
2.2 控制图的 2 步构建
- 收集基线: 在"已知正常"的训练运行中, 记录 20+ 个数据点 (如每个 epoch 的 loss)
- 计算控制限: μ, σ, UCL = μ+3σ, LCL = μ-3σ
- 持续监控: 实时点 vs 控制限, 出界 = 告警
import numpy as np
import matplotlib.pyplot as plt
# 1. 基线 (假设前 20 个 epoch 是"正常训练")
np.random.seed(42)
baseline_loss = np.random.normal(0.5, 0.05, 20) # 20 个正常值
mu = baseline_loss.mean()
sigma = baseline_loss.std(ddof=1)
ucl = mu + 3 * sigma
lcl = mu - 3 * sigma
# 2. 实时数据 (训练中)
new_loss = np.concatenate([
np.random.normal(0.5, 0.05, 30), # 继续正常
np.random.normal(1.0, 0.1, 5), # 突然发散!
])
# 3. 画控制图
fig, ax = plt.subplots(figsize=(12, 5))
ax.plot(baseline_loss, "o-", color="blue", label="基线 (正常)")
ax.axhline(mu, color="green", linestyle="--", label=f"μ = {mu:.3f}")
ax.axhline(ucl, color="red", linestyle="--", label=f"UCL = {ucl:.3f}")
ax.axhline(lcl, color="red", linestyle="--", label=f"LCL = {lcl:.3f}")
ax.plot(range(20, 20+len(new_loss)), new_loss, "s-", color="orange", label="实时")
# 标异常
ax.axvline(20 + 30, color="red", alpha=0.3, label="异常开始")
ax.set_xlabel("Epoch"); ax.set_ylabel("Loss")
ax.set_title("控制图: 监控训练 Loss")
ax.legend(); ax.grid(True, alpha=0.3)
plt.show()
3. 4 种控制图
3.1 均值控制图 (X-bar Chart)
监控连续变量的均值
3.2 个值控制图 (I Chart / X Chart)
监控单个值 (适合小批量训练, 每步一个数据点)
3.3 移动极差图 (MR Chart)
监控相邻差值 (波动大小), 跟 X Chart 配对
3.4 比例控制图 (P Chart)
监控比例 (如每批的"准确率")
class ControlChart:
"""完整的控制图工具"""
def __init__(self, baseline):
self.mu = np.mean(baseline)
self.sigma = np.std(baseline, ddof=1)
self.ucl = self.mu + 3 * self.sigma
self.lcl = self.mu - 3 * self.sigma
# 移动极差
mr = np.abs(np.diff(baseline))
self.mr_bar = mr.mean()
# I Chart σ 估计
self.sigma_i = self.mr_bar / 1.128 # d2 (n=2)
def check_point(self, value):
"""检查单点是否异常"""
if value > self.ucl or value < self.lcl:
return "OUT_OF_CONTROL", abs(value - self.mu) / self.sigma_i
return "OK", abs(value - self.mu) / self.sigma_i
def check_run(self, values):
"""检查整个序列, 返回异常点索引"""
return [i for i, v in enumerate(values) if self.check_point(v)[0] == "OUT_OF_CONTROL"]
def western_electric_rules(self, values):
"""Western Electric 4 规则: 更敏感的异常检测"""
alerts = []
for i in range(len(values)):
v = values[i]
# 规则 1: 1 个点出 3σ
if v > self.ucl or v < self.lcl:
alerts.append(i)
# 规则 2: 连续 9 个点在中心同侧
if i >= 8:
last9 = values[i-8:i+1]
if all(x > self.mu for x in last9) or all(x < self.mu for x in last9):
alerts.append(i)
# 规则 3: 连续 6 个点递增/递减
if i >= 5:
last6 = values[i-5:i+1]
if all(last6[j] < last6[j+1] for j in range(5)):
alerts.append(i) # 上升趋势
if all(last6[j] > last6[j+1] for j in range(5)):
alerts.append(i) # 下降趋势
return list(set(alerts))
4. 训练监控实战
4.1 监控指标设计
| 指标 | 控制图 | 异常含义 |
|---|---|---|
| 训练 loss (每个 epoch) | I Chart | 发散/NaN |
| 验证 accuracy | I Chart | 过拟合 |
| 验证 loss - 训练 loss | I Chart | 泛化能力 |
| 梯度范数 | I Chart | 爆炸/消失 |
| 权重范数 | I Chart | 初始化异常 |
| 学习率 (实际) | I Chart | scheduler 异常 |
| 训练步耗时 | I Chart | I/O 瓶颈 |
| 显存使用 | I Chart | OOM 风险 |
| NaN/Inf 比例 | P Chart | 数值不稳定 |
| 梯度爆炸比例 | P Chart | 梯度截断失效 |
class TrainingMonitor:
"""训练过程实时监控"""
def __init__(self, baseline_file=None):
self.charts = {}
if baseline_file:
self.load_baseline(baseline_file)
def register(self, name, baseline):
self.charts[name] = ControlChart(baseline)
def step(self, name, value, epoch):
if name not in self.charts:
# 第一次, 用前 10 个点当基线
if not hasattr(self, f"_buffer_{name}"):
setattr(self, f"_buffer_{name}", [])
buffer = getattr(self, f"_buffer_{name}")
buffer.append(value)
if len(buffer) >= 10:
self.register(name, buffer)
return "INITIALIZING"
status, z = self.charts[name].check_point(value)
if status == "OUT_OF_CONTROL":
print(f"⚠️ [Epoch {epoch}] {name} = {value:.4f} 异常 (z={z:.2f})")
return status
def check(self, name, values):
"""训练完一次性检查"""
if name not in self.charts:
return []
return self.charts[name].check_run(values)
4.2 实时监控示例
# 模拟一次训练, 边训边监控
monitor = TrainingMonitor()
# 假设这是前 10 个 epoch 的 loss (当基线)
losses = []
for epoch in range(100):
# 模拟 loss
if epoch < 10:
loss = 0.5 + np.random.normal(0, 0.02) - epoch * 0.005 # 正常下降
elif epoch < 30:
loss = 0.45 + np.random.normal(0, 0.02) # 稳定
elif epoch == 30:
loss = 1.5 # 突然发散 (模拟)
else:
loss = 0.45 + np.random.normal(0, 0.02)
losses.append(loss)
status = monitor.step("loss", loss, epoch)
if status == "OUT_OF_CONTROL":
print(f" 建议: 停止训练, 检查数据/超参")
break
5. Western Electric 规则: 更敏感的异常检测
1 个点出 3σ 算太宽松, 实际工业用 Western Electric 4 大规则:
| 规则 | 检测 |
|---|---|
| 1 | 1 个点出 3σ |
| 2 | 连续 9 个点在中心同侧 (趋势) |
| 3 | 连续 6 个点递增/递减 (单调趋势) |
| 4 | 连续 14 个点交替上下 (振荡) |
# 集成监控
def smart_monitor(losses, baseline_chart, verbose=True):
"""智能训练监控"""
alerts = baseline_chart.western_electric_rules(losses)
if alerts:
if verbose:
print(f"⚠️ 检测到异常 epoch: {alerts}")
return alerts
return []
# 训练: 每个 epoch 调一次
# 异常时: 发 Slack 告警 / 自动回滚 checkpoint / 调小学习率
6. 高级: 多元控制图 (Hotelling T²)
单变量控制图只能看一个指标。多变量 用 Hotelling T²:
from statsmodels.stats.multivariate import test_mvmean
# 例: 同时监控 loss, accuracy, grad_norm
def multivariate_monitor(metrics_history, baseline, alpha=0.01):
"""Hotelling T² 多变量控制"""
n = len(baseline)
k = baseline.shape[1]
diff = np.array(metrics_history[-n:]) - baseline.mean(axis=0)
cov = np.cov(baseline.T)
T2 = n * diff @ np.linalg.inv(cov) @ diff
# F 分布临界值
from scipy.stats import f
f_crit = f.ppf(1-alpha, k, n-k)
ucl = ((n-1)**2 / n) * f_crit / n
return T2 > ucl, T2
7. 控制图在生产环境中的"完整告警链"
import requests # Slack/钉钉 webhook
class TrainingAlertSystem:
"""完整训练告警"""
def __init__(self, webhook_url=None):
self.webhook_url = webhook_url
self.charts = {}
self.epoch = 0
self.checkpoints = []
def on_epoch_end(self, metrics, model):
"""每个 epoch 结束调用"""
self.epoch += 1
alerts = []
for name, value in metrics.items():
if name in self.charts:
status, z = self.charts[name].check_point(value)
if status == "OUT_OF_CONTROL":
alerts.append(f"{name}={value:.4f} (z={z:.2f})")
if alerts:
self._handle_alert(alerts, model)
def _handle_alert(self, alerts, model):
"""告警处理: 1) 通知 2) 回滚 3) 暂停"""
# 1. Slack 通知
if self.webhook_url:
requests.post(self.webhook_url, json={
"text": f"⚠️ 训练异常 (epoch {self.epoch}): {', '.join(alerts)}"
})
# 2. 加载上一个稳定 checkpoint
if self.checkpoints:
model.load_state_dict(self.checkpoints[-1])
print(f" 已回滚到 epoch {self.epoch - 1} 的权重")
# 3. 可选: 暂停训练, 让人工介入
8. 现代替代: TensorBoard + 自定义插件
# TensorBoard + 自定义 HParams + 控制图插件
from torch.utils.tensorboard import SummaryWriter
import numpy as np
writer = SummaryWriter("runs/exp1")
# HParams (记录超参)
writer.add_hparams(
{"lr": 0.001, "batch_size": 32, "model": "resnet50"},
{"hparam/accuracy": 0.92}
)
# 控制图
for epoch in range(100):
loss = train_one_epoch()
writer.add_scalar("Loss/train", loss, epoch)
# 自定义告警
if loss > 1.0 and epoch > 10:
writer.add_scalar("Alerts/anomaly", 1, epoch)
W&B, MLflow, ClearML 等平台有内置的"训练监控告警"。
9. 实战: 完整训练监控脚本
import numpy as np
import matplotlib.pyplot as plt
from collections import deque
class RealTimeMonitor:
"""实时训练监控器"""
def __init__(self, window=20):
self.window = window
self.history = deque(maxlen=1000)
self.baseline = None
self.alerts = []
def add(self, value):
self.history.append(value)
if self.baseline is None and len(self.history) >= self.window:
self.baseline = np.array(list(self.history))[:self.window]
print(f"[监控] 基线建立: μ={self.baseline.mean():.4f}, σ={self.baseline.std():.4f}")
def check(self):
if self.baseline is None or len(self.history) < self.window:
return None
mu, sigma = self.baseline.mean(), self.baseline.std(ddof=1)
current = list(self.history)[-1]
if abs(current - mu) > 3 * sigma:
alert = f"异常: 当前={current:.4f}, 基线 μ={mu:.4f}, 3σ 范围=[{mu-3*sigma:.4f}, {mu+3*sigma:.4f}]"
self.alerts.append(alert)
return alert
return None
def plot(self):
fig, ax = plt.subplots(figsize=(12, 5))
if self.baseline is not None:
mu = self.baseline.mean()
sigma = self.baseline.std(ddof=1)
ax.axhline(mu, color="green", linestyle="--")
ax.axhline(mu + 3*sigma, color="red", linestyle="--", label="UCL")
ax.axhline(mu - 3*sigma, color="red", linestyle="--", label="LCL")
ax.plot(list(self.history), "b-", alpha=0.7)
# 标异常
for i, _ in enumerate(self.history):
if self.baseline is not None and i >= self.window:
v = list(self.history)[i]
mu = self.baseline.mean()
sigma = self.baseline.std(ddof=1)
if abs(v - mu) > 3 * sigma:
ax.plot(i, v, "rx", markersize=12)
ax.set_xlabel("Step"); ax.set_ylabel("Loss")
ax.set_title("训练 Loss 实时控制图")
ax.legend(); ax.grid(True, alpha=0.3)
plt.tight_layout(); plt.show()
# 用法
monitor = RealTimeMonitor(window=20)
for step in range(100):
if step < 20:
loss = 0.5 + np.random.normal(0, 0.02) # 正常
elif step == 50:
loss = 2.0 # 异常
else:
loss = 0.45 + np.random.normal(0, 0.02)
monitor.add(loss)
alert = monitor.check()
if alert:
print(f" Step {step}: {alert}")
10. 小结
| 你学到了 | 关键点 |
|---|---|
| 控制图 3σ | 1 个点出 3σ 算异常 |
| 4 种控制图 | I / X-bar / MR / P, 选对监控类型 |
| Western Electric | 4 大规则比单 3σ 更敏感 |
| 训练指标 | loss, acc, grad, 权重, 时间, 显存, NaN 比例 |
| 自动告警 | Slack/钉钉 + 自动回滚 checkpoint |
| Hotelling T² | 多变量联合监控 |
| ML 平台 | W&B, MLflow, ClearML 都内置告警 |
| 业务价值 | 减少 50% 训练时间, 提前发现 NaN |
11. 习题
-
模拟 100 个 epoch 训练, 其中第 50-55 epoch 出现"loss spike":
- 建立控制图 (前 20 epoch 当基线)
- 用 Western Electric 4 规则检测异常
- 报告: 哪些 epoch 被标异常? 用什么规则?
-
写一个
TrainingMonitor类:- 同时监控 loss, gradient_norm, learning_rate
- 任何一个指标出 3σ, 打印告警
- 模拟训练, 验证类能工作
👉 查看参考答案
-
提示: 用前 20 epoch 当基线, 算 mu, sigma, 然后用 I Chart + Western Electric 规则。 第 50-55 epoch 会被规则 1 (出 3σ) 和规则 3 (连续递增) 标记。 Western Electric 比单 3σ 更早发现异常。
-
提示: 维护 3 个 ControlChart 实例, 每个 epoch 调一次 check_point。 模拟 loss spike / gradient explosion / scheduler 异常, 验证都能被捕获。
12. 下一章
- 机器学习入门 → EDA: 训练前的数据探索
- 统计学基础 → 用图表探索数据: 图表的更多玩法
- 统计学基础 → 描述统计: 集中趋势与离散度
📚 本章综合: 改编自 Triola《基础统计学》第 14 章统计过程控制 (SPC), 加入 ML 训练监控实战。
章末小测验
检验你对《训练过程的统计监控》的掌握程度。
关于控制图在机器学习训练监控中的作用,以下哪些说法是正确的?
以下哪些控制图类型适用于监控机器学习训练中的单个值?
关于Western Electric规则,以下哪些说法是正确的?
以下哪些指标适用于使用I Chart进行监控?
在机器学习训练监控中,以下哪些方法可以用于减少训练时间或提前发现问题?