ML 学习站
跳到正文

实战:CartPole 训练

用 PyTorch + Gym 训一个 DQN agent。

45 分钟5 / 51,225
加载中...

本章通过训练 CartPole-v1 环境中的 DQN agent,展示了强化学习的基本流程和核心概念。核心内容包括:环境配置、状态和动作空间理解、DQN 算法实现、训练过程及结果可视化。读者将学习如何设置强化学习任务,包括处理 4 维连续状态空间和 2 个离散动作,以及如何应用经验回放、目标网络和神经网络等 DQN 的关键组件。学完后,读者能够独立训练一个 CartPole agent,使其达到平均奖励 ≥ 475 的目标,并掌握调整超参数(如学习率、gamma 值、epsilon 值等)以优化训练效果的方法。此外,本章还介绍了让训练更稳定的 7 个技巧,如奖励缩放、梯度裁剪和 Huber Loss 等。最后,读者将了解如何保存和加载模型,并尝试在更复杂的任务(如 LunarLander)中应用更高级的算法(如 DDPG、PPO 和 SAC)。

实战:CartPole 训练

CartPole 是强化学习的"Hello World"——简单到几分钟能跑通,但又能体现 RL 的所有核心思想。这一章带你端到端训练一个 DQN agent

项目目标

  • 环境:CartPole-v1(OpenAI Gym)
  • 状态:4 维连续向量(小车位置、速度、杆子角度、角速度)
  • 动作:2 个离散(0=左推, 1=右推)
  • 奖励:每平衡 1 步 +1
  • 目标:平均奖励 ≥ 475(满分 500)

预计训练时间:5-15 分钟(GPU)/ 10-20 分钟(CPU)

第一步:环境准备

pip install torch gymnasium numpy matplotlib
# 注: Gymnasium 是 Gym 的现代维护版,API 一样

第二步:理解环境

import gymnasium as gym

env = gym.make("CartPole-v1")
print(f"状态空间: {env.observation_space}")   # Box(4,)
print(f"动作空间: {env.action_space}")        # Discrete(2)
print(f"奖励范围: {env.reward_range}")        # (-inf, inf)

# 跑一局
state, _ = env.reset()
print(f"初始状态: {state}")
total_reward = 0
done = False
while not done:
    action = env.action_space.sample()  # 随机动作
    state, reward, terminated, truncated, _ = env.step(action)
    total_reward += reward
    done = terminated or truncated
print(f"随机策略总奖励: {total_reward}")

第三步:实现 DQN

把上一章的 DQN 实现搬过来,加一点针对 CartPole 的优化:

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque
import gymnasium as gym

class QNetwork(nn.Module):
    def __init__(self, state_dim=4, n_actions=2, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_actions)
        )

    def forward(self, x):
        return self.net(x)


class ReplayBuffer:
    def __init__(self, capacity=10000):
        self.buffer = deque(maxlen=capacity)

    def push(self, s, a, r, s_, done):
        self.buffer.append((s, a, r, s_, done))

    def sample(self, batch_size=64):
        batch = random.sample(self.buffer, batch_size)
        s, a, r, s_, done = zip(*batch)
        return (
            torch.tensor(np.array(s), dtype=torch.float32),
            torch.tensor(a, dtype=torch.long),
            torch.tensor(r, dtype=torch.float32),
            torch.tensor(np.array(s_), dtype=torch.float32),
            torch.tensor(done, dtype=torch.float32)
        )

    def __len__(self):
        return len(self.buffer)


class DQNAgent:
    def __init__(self, state_dim=4, n_actions=2):
        self.n_actions = n_actions
        self.gamma = 0.99
        self.batch_size = 64
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.target_update = 10

        self.q_net = QNetwork(state_dim, n_actions)
        self.target_net = QNetwork(state_dim, n_actions)
        self.target_net.load_state_dict(self.q_net.state_dict())
        self.target_net.eval()

        self.optimizer = optim.Adam(self.q_net.parameters(), lr=1e-3)
        self.buffer = ReplayBuffer(capacity=10000)
        self.steps = 0

    def select_action(self, state):
        if random.random() < self.epsilon:
            return random.randrange(self.n_actions)
        with torch.no_grad():
            q = self.q_net(torch.tensor(state, dtype=torch.float32).unsqueeze(0))
            return q.argmax().item()

    def train_step(self):
        if len(self.buffer) < self.batch_size:
            return None

        s, a, r, s_, done = self.buffer.sample(self.batch_size)
        q_pred = self.q_net(s).gather(1, a.unsqueeze(1)).squeeze()
        with torch.no_grad():
            q_next = self.target_net(s_).max(1)[0]
            q_target = r + self.gamma * q_next * (1 - done)
        loss = nn.functional.mse_loss(q_pred, q_target)
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), 10.0)
        self.optimizer.step()

        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
        self.steps += 1
        if self.steps % self.target_update == 0:
            self.target_net.load_state_dict(self.q_net.state_dict())
        return loss.item()

第四步:训练

env = gym.make("CartPole-v1")
agent = DQNAgent(state_dim=4, n_actions=2)

episodes = 500
rewards_history = []

for ep in range(episodes):
    state, _ = env.reset()
    total_reward = 0
    done = False
    while not done:
        action = agent.select_action(state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        agent.buffer.push(state, action, reward, next_state, done)
        agent.train_step()
        state = next_state
        total_reward += reward

    rewards_history.append(total_reward)

    # 打印进度
    if (ep + 1) % 20 == 0:
        recent = rewards_history[-20:]
        avg = np.mean(recent)
        print(f"Episode {ep+1:4d} | "
              f"Reward: {total_reward:3.0f} | "
              f"Avg(20): {avg:6.1f} | "
              f"ε: {agent.epsilon:.3f}")

预期输出:

Episode   20 | Reward:  18 | Avg(20):   14.5 | ε: 0.905
Episode   40 | Reward:  32 | Avg(20):   21.3 | ε: 0.819
Episode   60 | Reward:  45 | Avg(20):   38.7 | ε: 0.741
Episode   80 | Reward:  78 | Avg(20):   65.2 | ε: 0.671
Episode  100 | Reward:  120 | Avg(20):  103.5 | ε: 0.607
Episode  200 | Reward:  500 | Avg(20):  485.0 | ε: 0.299
Episode  300 | Reward:  500 | Avg(20):  500.0 | ε: 0.149

大概 150-250 轮就能稳到 500 分(满分)。

第五步:可视化训练曲线

import matplotlib.pyplot as plt

def smooth(values, window=20):
    return [np.mean(values[max(0, i-window):i+1]) for i in range(len(values))]

plt.figure(figsize=(10, 5))
plt.plot(rewards_history, alpha=0.3, label='每轮奖励')
plt.plot(smooth(rewards_history, 20), label='滑动平均(20 轮)', linewidth=2)
plt.axhline(475, color='r', linestyle='--', label='目标线 (475)')
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.title('DQN on CartPole')
plt.legend()
plt.grid(alpha=0.3)
plt.savefig('cartpole_training.png', dpi=120)
plt.show()

第六步:看 agent 玩耍

import time
env = gym.make("CartPole-v1", render_mode="human")
state, _ = env.reset()
done = False
while not done:
    action = agent.select_action(state)
    state, _, terminated, truncated, _ = env.step(action)
    done = terminated or truncated
    time.sleep(0.02)
env.close()

第七步:保存和加载模型

# 保存
torch.save({
    'q_net': agent.q_net.state_dict(),
    'target_net': agent.target_net.state_dict(),
    'epsilon': agent.epsilon
}, 'cartpole_dqn.pth')

# 加载
checkpoint = torch.load('cartpole_dqn.pth')
agent.q_net.load_state_dict(checkpoint['q_net'])
agent.target_net.load_state_dict(checkpoint['target_net'])
agent.epsilon = checkpoint['epsilon']
agent.epsilon = 0  # 测试时关掉探索

调参进阶

跑通之后,可以试试这些改进看能不能更快收敛:

# 1. Double DQN
def train_step_double_dqn(self):
    s, a, r, s_, done = self.buffer.sample(self.batch_size)
    q_pred = self.q_net(s).gather(1, a.unsqueeze(1)).squeeze()
    with torch.no_grad():
        # 用 Q 网络选动作, target 网络估 Q
        best_actions = self.q_net(s_).argmax(1).unsqueeze(1)
        q_next = self.target_net(s_).gather(1, best_actions).squeeze()
        q_target = r + self.gamma * q_next * (1 - done)
    loss = nn.functional.mse_loss(q_pred, q_target)
    # ... 同前

# 2. Dueling DQN
class DuelingQNetwork(nn.Module):
    def __init__(self, state_dim, n_actions, hidden=128):
        super().__init__()
        self.feature = nn.Sequential(
            nn.Linear(state_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU()
        )
        self.value = nn.Linear(hidden, 1)        # V(s)
        self.advantage = nn.Linear(hidden, n_actions)  # A(s, a)

    def forward(self, x):
        f = self.feature(x)
        v = self.value(f)
        a = self.advantage(f)
        # Q = V + (A - mean(A))
        return v + a - a.mean(dim=1, keepdim=True)

实战技巧汇总

进阶:换更难的环境

CartPole 只是开胃菜,试着挑战更难的:

# 1. Acrobot - 双节摆, 目标甩到顶
env = gym.make("Acrobot-v1")

# 2. MountainCar - 小车要爬上山
env = gym.make("MountainCar-v0")

# 3. LunarLander - 登月(连续动作)
env = gym.make("LunarLander-v2")  # 需要 PPO/SAC

LunarLander 用了连续动作,DQN 搞不定,需要 DDPG/PPO/SAC。

小结

  • CartPole 4 维状态 + 2 离散动作,完美的 DQN 入门
  • 经验回放 + 目标网络 + 神经网络 = DQN 三件套
  • 100-200 轮基本能训到满分
  • 调参:lr、gamma、epsilon、buffer size、target_update
  • 想玩更难的环境,得换算法(DDPG/PPO/SAC)

练习思考

  1. 把 γ 改成 0.5 和 0.999,分别跑一下,看哪个更慢/更不稳。
  2. 不开经验回放(buffer size = 1),能训出来吗?为什么?
  3. 试试换成 Acrobot 环境,改改超参能不能训出来?

章末小测验

检验你对《实战:CartPole 训练》的掌握程度。

1

以下关于 CartPole 环境的描述,哪些是正确的?

2

关于 DQN 的核心组件,以下哪些说法是正确的?

3

以下哪些环境需要使用 DDPG/PPO/SAC 等算法?

4

在 CartPole 环境中,以下哪些调参选项是常见的?

5

关于 CartPole 训练,以下哪些说法是正确的?

讨论区(0)

加载评论中...