DQN(深度Q网络)通过使用神经网络逼近Q函数,解决了传统Q表在处理大规模状态空间(如Atari游戏画面)时的存储问题。其核心思想是用一个神经网络Q(s, a; θ)代替Q表,输入状态s,输出所有动作的Q值。DQN面临的主要挑战包括样本相关性问题、目标不稳定问题以及奖励稀疏问题。为解决这些问题,DQN引入了两个关键技巧:经验回放和目标网络。经验回放通过将智能体的经历存储在回放缓冲区中,并在训练时随机采样小批量数据,从而打破了样本相关性并提高了数据利用率。目标网络则通过维护两个网络(Q网络和目标网络),使得目标在短期内保持稳定,帮助Q网络更有效地进行拟合。DQN的三大改进包括Double DQN、Prioritized Experience Replay和Dueling DQN,分别解决了高估问题、样本采样效率和状态-动作值分解问题。读者学完本章后,能够理解DQN的基本原理、实现方法及其改进方向,并能够应用这些知识解决大规模状态空间下的强化学习问题。
深度 Q 网络 DQN
当状态空间太大(比如 Atari 游戏画面 = 256^160000 种可能),Q 表根本存不下。DQN(Deep Q-Network,2015)用神经网络逼近 Q 函数,把这个限制打破。
DQN 的核心思想
用一个神经网络 Q(s, a; θ) 代替 Q 表,输入是状态 s,输出是所有动作的 Q 值。
状态 s (比如游戏画面) → 神经网络 → Q(s, "左") Q(s, "右") Q(s, "上") Q(s, "下")
↓ ↓ ↓ ↓
每个动作一个分数
直觉:网络学一个"评估函数",给定状态,告诉你每个动作有多好。
为什么朴素做法会失败?
直接把 Q-Learning 套到神经网络上,训练会爆炸:
- 样本相关:连续帧的状态高度相关,违反 SGD 假设(独立同分布)
- 目标不稳定:target = R + γ*max Q(s', a') 里的 Q 也在变,目标在"跑",网络追不上
- 奖励稀疏:1 局游戏可能几千步才有 1 次得分,大部分步奖励 0
DQN 用了两个关键技巧解决这些:
技巧 1:经验回放(Experience Replay)
把智能体的经历 (s, a, r, s', done) 存到回放缓冲区(replay buffer),训练时随机采样小批量。
from collections import deque
import random
class ReplayBuffer:
def __init__(self, capacity=100000):
self.buffer = deque(maxlen=capacity)
def push(self, s, a, r, s_, done):
self.buffer.append((s, a, r, s_, done))
def sample(self, batch_size=32):
return random.sample(self.buffer, batch_size)
好处:
- 打破样本相关性(随机采样)
- 重复利用数据(样本效率 ↑ 10x)
- 简单稳定
技巧 2:目标网络(Target Network)
维护两个网络:
- Q 网络(
θ):实时更新,用来选动作 - 目标网络(
θ_target):定期从 Q 网络复制,用来算 target
# 更新公式
target = r + γ * max Q_target(s', a')
# 每隔 C 步同步
if step % C == 0:
target_net.load_state_dict(q_net.state_dict())
为什么有效?target 短时间内不变,Q 网络可以稳定地朝它拟合;然后再"换"一个新 target。
DQN 完整实现
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque
class DQN(nn.Module):
def __init__(self, state_dim, n_actions, hidden=128):
super().__init__()
self.net = nn.Sequential(
nn.Linear(state_dim, hidden),
nn.ReLU(),
nn.Linear(hidden, hidden),
nn.ReLU(),
nn.Linear(hidden, n_actions)
)
def forward(self, x):
return self.net(x)
class DQNAgent:
def __init__(self, state_dim, n_actions, lr=1e-3, gamma=0.99,
epsilon=1.0, epsilon_min=0.01, epsilon_decay=0.995,
buffer_size=10000, batch_size=64, target_update=10):
self.n_actions = n_actions
self.gamma = gamma
self.epsilon = epsilon
self.epsilon_min = epsilon_min
self.epsilon_decay = epsilon_decay
self.batch_size = batch_size
self.target_update = target_update
# 两个网络
self.q_net = DQN(state_dim, n_actions)
self.target_net = DQN(state_dim, n_actions)
self.target_net.load_state_dict(self.q_net.state_dict())
self.target_net.eval()
self.optimizer = optim.Adam(self.q_net.parameters(), lr=lr)
self.buffer = deque(maxlen=buffer_size)
self.steps = 0
def select_action(self, state):
if random.random() < self.epsilon:
return random.randrange(self.n_actions)
with torch.no_grad():
return self.q_net(state).argmax().item()
def push(self, s, a, r, s_, done):
self.buffer.append((s, a, r, s_, done))
def update(self):
if len(self.buffer) < self.batch_size:
return
batch = random.sample(self.buffer, self.batch_size)
s, a, r, s_, done = zip(*batch)
s = torch.stack(s)
a = torch.tensor(a)
r = torch.tensor(r, dtype=torch.float32)
s_ = torch.stack(s_)
done = torch.tensor(done, dtype=torch.float32)
# 当前 Q 值
q_pred = self.q_net(s).gather(1, a.unsqueeze(1)).squeeze()
# 目标 Q 值(用目标网络, 且不传梯度)
with torch.no_grad():
q_next = self.target_net(s_).max(1)[0]
q_target = r + self.gamma * q_next * (1 - done)
# 损失 + 反向传播
loss = nn.functional.mse_loss(q_pred, q_target)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
# 衰减 epsilon
self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
# 定期同步目标网络
self.steps += 1
if self.steps % self.target_update == 0:
self.target_net.load_state_dict(self.q_net.state_dict())
return loss.item()
DQN 三大改进(论文必读)
DQN 2015 年发表,后来又有几个里程碑式的改进:
Double DQN(2016)
问题:DQN 的 target max Q(s', a') 容易高估——网络会"虚报"某些动作的 Q 值,贪心选错的。
解决:用 Q 网络选动作,用 target 网络估值:
target = r + γ * Q_target(s', argmax_a Q(s', a)) # 用 Q 网络选, target 网络估
Prioritized Experience Replay(2016)
不是均匀采样,TD 误差大的样本(更"惊讶"的)被采样概率更高——学得更快。
Dueling DQN(2016)
把 Q 拆成状态价值 V(s) + 动作优势 A(s, a):
Q(s, a) = V(s) + A(s, a) - mean_a' A(s, a')
直觉:有些状态不管做什么都好(比如快赢了),有些状态不管做什么都差(比如快输了)。把"状态本身的好坏"和"动作的相对好坏"分开学,通常更稳定。
DQN 的局限
虽然很经典,DQN 也有明显问题:
- 只支持离散动作:CartPole 的 左 / 右 行,但机器人有连续关节角度
- 样本效率不够高:玩一局 Atari 要看 200 帧,但只训练 1 次
- 不稳定:同样的代码,两次结果可能差很多
- 不能处理部分可观测:只看一帧画面,没有记忆
改进方向:
- 连续动作 → DDPG、TD3、SAC
- 样本效率 → Model-based RL(学习环境模型)
- 记忆 → LSTM + RL
训练稳定性技巧
实际训练 DQN 时,这些技巧非常有用:
# 1. 奖励裁剪(对 Atari 特别有效)
reward = np.clip(reward, -1, 1)
# 2. 梯度裁剪
torch.nn.utils.clip_grad_norm_(q_net.parameters(), 10.0)
# 3. Huber Loss 替代 MSE(对异常值更稳健)
loss = nn.functional.smooth_l1_loss(q_pred, q_target)
# 4. 软更新 target 网络(Polyak averaging)
target_param.data.copy_(tau * q_param.data + (1 - tau) * target_param.data)
# tau 一般 0.001 ~ 0.01
小结
- DQN = 神经网络 + 经验回放 + 目标网络——三大支柱
- 解决了 Q 表在大状态空间下的"装不下"问题
- Double DQN / Prioritized / Dueling 是常用增强
- DQN 只能处理离散动作
- 改进方向:DDPG / PPO / SAC
练习思考
- 为什么经验回放能"打破样本相关性"?举例说明。
- 目标网络不更新会怎样?更新太快又会怎样?
- 把 Double DQN 改进加到上面的 DQN 实现里,跑 CartPole 看看有没有差异。
章末小测验
检验你对《深度 Q 网络 DQN》的掌握程度。
dqn 的核心概念是?