autowzry-agent

Target Network 优化

日期: 2025-11-22 类型: 训练算法优化


背景

在强化学习训练中,传统的DQN算法存在训练不稳定的问题。当使用同一个网络既计算当前Q值又计算目标Q值时,容易产生目标值不断变化(moving target)的问题,导致训练震荡甚至发散。

原始DQN论文(Mnih et al., 2015)提出了Target Network解决方案,通过使用一个延迟更新的目标网络来计算目标Q值,从而提高训练稳定性。autowzry-agent项目在之前的实现中未使用Target Network,本次开发中引入了该重要特性。


修改内容

1. 配置文件新增

文件: config/config.py

新增属性:

# Target Network
self.use_target_network: bool = True  # 是否使用目标网络
self.target_update_epochs: int = 5    # 每N个epoch更新一次目标网络

文件: config/agent.config.yaml

新增配置项:

training:
  use_target_network: true  # 是否使用目标网络
  target_update_epochs: 5   # 每N个epoch更新一次目标网络

2. Trainer类重构

文件: core/trainer.py

新增初始化逻辑:

# Target Network
if config.use_target_network:
    import copy
    self.target_model = copy.deepcopy(self.model)
    self.target_model.eval()
    print(f"  Using Target Network: True")
    print(f"  Target update every {config.target_update_epochs} epochs")
else:
    self.target_model = self.model
    print(f"  Using Target Network: False")

模型加载后同步:

# 加载resume_model后同步目标网络
if config.resume_model is not None and config.resume_model != '':
    if os.path.exists(config.resume_model):
        self.load_model(config.resume_model)
        # Update target network after loading
        if config.use_target_network:
            self.target_model.load_state_dict(self.model.state_dict())

修改loss计算逻辑:

# 目标Q值(使用目标网络)
with torch.no_grad():
    next_q_values = self.target_model(next_states)  # 使用目标网络计算
    max_next_q = next_q_values.max(dim=1)[0]
    target_q = rewards + self.config.gamma * max_next_q

更新训练循环:

3. 移除 best_loss 机制

修改原因:

删除内容:


影响范围

训练稳定性提升

使用Target Network前:

target_q = reward + gamma * max(self.model(next_state))
                              
                      同一模型在不断更新

使用Target Network后:

target_q = reward + gamma * max(self.target_model(next_state))
                              
                      稳定的目标网络

预期效果:

模型保存策略变更

新策略:

优势:


设计决策

为什么选择硬更新而非软更新?

硬更新(我们的实现):

if epoch % target_update_epochs == 0:
    self.target_model.load_state_dict(self.model.state_dict())

软更新(可选方案):

tau = 0.005  # 软更新系数
for target_param, param in zip(self.target_model.parameters(), self.model.parameters()):
    target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

选择硬更新的原因:

  1. 实现简单:直接复制参数,无需额外计算
  2. 性能稳定:目标网络在更新间隔内完全固定
  3. 调试友好:每个epoch的目标网络状态清晰
  4. 业界认可:DQN原论文使用硬更新

为什么选择5个epoch作为更新间隔?

参考标准:

计算逻辑:

假设:buffer_size=10000, batch_size=32
每个epoch ≈ 312 steps
10,000 / 312 ≈ 32 epochs

考虑离线训练的稳定性需求:
取更保守的值:5 epochs(平衡稳定性和时效性)

调参建议:


验证方法

训练过程观察

Loss曲线:

输出示例:

Epoch 5/32, Loss: 0.1234 [Target Network Updated]
Epoch 10/32, Loss: 0.0567 [Target Network Updated]

实际性能测试

离线测试:

python scripts/battle.py --model checkpoints/model_epoch_10.pth --video test.mp4 --max-steps 20

评估指标:


相关变更

docs/design/ARCHITECTURE.md

docs/guides/quickstart.md


总结

Target Network的引入是强化学习训练稳定性的重要改进。通过使用延迟更新的目标网络计算目标Q值,有效解决了传统DQN中的”移动目标”问题,显著提升了训练稳定性。

同时,移除best_loss机制符合强化学习的最佳实践,强调通过实际表现评估模型而非依赖loss指标。新的定期检查点保存策略为开发者提供了更大的灵活性和安全性。

这一优化将大幅提升autowzry-agent项目的训练效率和可靠性。