ML 学习站
跳到正文

实战:图像分类

用 PyTorch 训一个 ResNet,识别 CIFAR-10。

45 分钟5 / 61,046
加载中...

本章通过PyTorch端到端地训练一个图像分类器,涵盖从数据准备到模型部署的完整流程。核心概念包括使用CIFAR-10数据集、ResNet-18模型的微调(迁移学习)以及数据增强的重要性。读者将学习如何通过数据增强技术(如Mixup、CutMix、AutoAugment)提升模型性能,并掌握微调预训练模型的方法,该方法比从头训练快10倍且精度高20%。完成本章后,读者能够独立完成图像分类任务,达到90%以上的测试准确率,并了解如何进行错误分析和模型调优(如使用更深的模型、更长的训练时间、学习率调度和TTA等)。此外,读者还将掌握将训练好的模型导出为ONNX或TorchScript格式,以便在各种平台上进行部署。

实战:图像分类

这一章用 PyTorch 端到端训练一个图像分类器,从数据准备到模型部署。强烈建议你跟着敲一遍

项目目标

  • 数据集:CIFAR-10(10 类,每类 6000 张 32×32 彩色图)
  • 模型:ResNet-18 微调(迁移学习)
  • 目标:测试准确率 90%+
  • 时间:GPU 上 5-10 分钟

第一步:环境准备

pip install torch torchvision tensorboard matplotlib

第二步:数据准备

import torch
import torchvision
import torchvision.transforms as transforms

# 训练集增强(数据增强能大幅提高泛化)
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),       # 随机裁剪
    transforms.RandomHorizontalFlip(),         # 随机水平翻转
    transforms.ColorJitter(0.2, 0.2, 0.2),    # 随机调色
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 测试集不做增强,只做归一化
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=train_transform
)
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=test_transform
)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

第三步:模型定义

import torch.nn as nn
import torchvision.models as models

def get_model(num_classes=10):
    # 加载预训练 ResNet-18
    model = models.resnet18(pretrained=True)

    # 替换最后的全连接层
    # ImageNet 是 1000 类, CIFAR-10 是 10 类
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

model = get_model()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)
print(f"使用设备: {device}")

第四步:训练循环

import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
scheduler = CosineAnnealingLR(optimizer, T_max=20)
scaler = torch.amp.GradScaler('cuda')  # 混合精度训练, 提速 2-3 倍

best_acc = 0
for epoch in range(20):
    # === 训练阶段 ===
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    for inputs, targets in trainloader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        with torch.amp.autocast('cuda'):
            outputs = model(inputs)
            loss = criterion(outputs, targets)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    scheduler.step()

    # === 验证阶段 ===
    model.eval()
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            test_total += targets.size(0)
            test_correct += predicted.eq(targets).sum().item()

    test_acc = 100. * test_correct / test_total
    print(f"Epoch {epoch+1}/20 | "
          f"Train Loss: {train_loss/len(trainloader):.3f}, "
          f"Train Acc: {100.*correct/total:.2f}% | "
          f"Test Acc: {test_acc:.2f}%")

    # 保存最好的模型
    if test_acc > best_acc:
        best_acc = test_acc
        torch.save(model.state_dict(), 'best_model.pth')

print(f"\n最佳测试准确率: {best_acc:.2f}%")

预期结果:5-10 分钟,90%+ 准确率(用 GPU)。

第五步:可视化与错误分析

看看哪些样本被分错了,往往能给出改进的方向:

import matplotlib.pyplot as plt
import numpy as np

# 加载最佳模型
model.load_state_dict(torch.load('best_model.pth'))
model.eval()

# 收集错误预测
misclassified = []
with torch.no_grad():
    for inputs, targets in testloader:
        outputs = model(inputs.to(device))
        _, predicted = outputs.max(1)
        mask = predicted.cpu() != targets
        for i in range(mask.sum()):
            misclassified.append((
                inputs[mask.nonzero()[0][i]],
                targets[mask.nonzero()[0][i]].item(),
                predicted[mask.nonzero()[0][i]].item()
            ))

# 可视化前 10 个错误
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
for i, (img, true, pred) in enumerate(misclassified[:10]):
    ax = axes[i // 5, i % 5]
    img = img / 2 + 0.5  # 反归一化
    npimg = img.numpy()
    ax.imshow(np.transpose(npimg, (1, 2, 0)))
    ax.set_title(f"真: {classes[true]}\n预测: {classes[pred]}", fontsize=10)
    ax.axis('off')
plt.tight_layout()
plt.savefig('errors.png')
print("错误样本已保存到 errors.png")

第六步:加载模型做预测

def predict(model, image_tensor, classes):
    """对单张图片做预测"""
    model.eval()
    with torch.no_grad():
        output = model(image_tensor.unsqueeze(0).to(device))
        probs = torch.nn.functional.softmax(output, dim=1)[0]
        top5_prob, top5_idx = torch.topk(probs, 5)
        return [(classes[i], p.item()) for i, p in zip(top5_idx, top5_prob)]

# 测试
from PIL import Image
img = Image.open('my_cat.jpg')
img_tensor = test_transform(img)
predictions = predict(model, img_tensor, classes)
for label, prob in predictions:
    print(f"{label}: {prob:.2%}")

调优进阶:再榨几个点

如果 90% 不够,这些技巧能帮你到 93-95%:

  1. 更强的数据增强:Mixup、CutMix、AutoAugment
  2. 更深的模型:ResNet-50 / EfficientNet-B0
  3. 更长的训练:50-100 轮
  4. 学习率调度:Warmup + Cosine
  5. TTA(Test-Time Augmentation):测试时对多张增强图预测取平均
  6. 模型集成:训 5 个模型,投票
# Mixup 增强示例
def mixup(x, y, alpha=0.2):
    lam = np.random.beta(alpha, alpha)
    idx = torch.randperm(x.size(0))
    return lam * x + (1 - lam) * x[idx], y, y[idx], lam

部署:导出为 ONNX / TorchScript

训完怎么用?导出为标准格式,在任何平台跑:

# ONNX 格式
dummy = torch.randn(1, 3, 32, 32).to(device)
torch.onnx.export(model, dummy, "model.onnx",
                  input_names=["input"], output_names=["output"],
                  dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}})

# ONNX 推理(任何语言)
import onnxruntime as ort
session = ort.InferenceSession("model.onnx")
outputs = session.run(None, {"input": image_tensor.numpy()})

小结

  • 微调预训练模型是工业默认,比从头训快 10 倍,精度高 20%
  • 数据增强是免费的午餐
  • 混合精度训练(CUDA 专用)能提速 2-3 倍
  • 错误分析往往比改模型更能提精度
  • 部署:ONNX / TorchScript 是跨平台标准

练习思考

  1. 把 batch size 从 128 调到 32 和 512,观察训练速度和最终精度的变化。
  2. 试不同 optimizer(SGD、Adam、AdamW),结果有什么不同?
  3. 试试用 EfficientNet-B0 替换 ResNet-18,看精度能到多少。

章末小测验

检验你对《实战:图像分类》的掌握程度。

1

关于数据增强的作用,下列哪些说法是正确的?

2

关于提高图像分类模型性能的技巧,下列哪些说法是正确的?

3

关于模型训练和部署,下列哪些说法是正确的?

4

关于错误分析的作用,下列哪些说法是正确的?

5

关于模型训练中的batch size调整,下列哪些说法是正确的?

讨论区(0)

加载评论中...