autowzry-agent

双帧输入模式实现

日期: 2025-11-23 类型: 功能增强 影响范围: 配置管理、训练数据加载、模型训练、对战脚本


概述

实现了双帧输入模式(use_last_state),允许模型接收包含时序信息的6通道输入(laststate+state合并),相比单帧3通道输入能够感知运动信息,提升模型对动态场景的理解能力。


需求背景

原有的训练方案中,模型只接收单帧RGB图像作为输入(3通道),无法感知物体运动方向和速度等时序信息。为了让模型能够理解动态场景,需要提供前一时刻的图像作为额外输入。


设计方案

核心思路

在训练数据加载阶段(TrainingBuffer.load),将前一帧(laststate)与当前帧(state)在通道维度合并:

同样,next_state也进行合并:

优点

  1. 实现简单 - 只需在buffer阶段合并,其他模块自动适配
  2. 向后兼容 - 模型保存input_channels,可加载3通道或6通道模型
  3. 配置驱动 - 通过use_last_state开关灵活切换单帧/双帧模式

权衡


代码修改

1. config/config.py

添加use_last_state配置参数:

字段定义(第38行):

use_last_state: bool = True  # 是否使用laststate+state合并模式

from_yaml加载(第78行):

use_last_state=data.get('use_last_state', True),

save_yaml保存(第110行):

'use_last_state': self.use_last_state,

2. config/agent.config.yaml

添加配置项(第14行):

use_last_state: true  # 是否使用laststate+state合并模式

3. data/training_buffer.py

修改input_channels计算(第50-53行):

# Input channels (determined by config.use_last_state and train_resolution)
# use_last_state=True: channels * 2, False: channels
base_channels = config.train_resolution[2] if (config and hasattr(config, 'train_resolution')) else 3
self.input_channels = base_channels * 2 if (config and getattr(config, 'use_last_state', True)) else base_channels

修改load方法(第120-148行):

增加laststate读取逻辑:

# Get last_state if use_last_state is enabled
use_last_state = self.config and getattr(self.config, 'use_last_state', True)
if use_last_state:
    if i == 0:
        last_state = frames[0]['image']
    else:
        # Check if i-1 is in battle
        if frames[i - 1].get('state_in_battle', False):
            last_state = frames[i - 1]['image']
        else:
            last_state = frame['image']

增加合并逻辑:

# Preprocess images if config is available
if self.config and hasattr(self.config, 'train_resolution'):
    state = preprocess_frame(state, self.config.train_resolution, normalize_img=True)
    next_state = preprocess_frame(next_state, self.config.train_resolution, normalize_img=True)
    if use_last_state:
        last_state = preprocess_frame(last_state, self.config.train_resolution, normalize_img=True)
        # Concatenate: first merge next_state, then merge state
        next_state = np.concatenate([state, next_state], axis=0)
        state = np.concatenate([last_state, state], axis=0)

关键设计点


4. scripts/battle.py

在对战循环中添加历史帧维护逻辑:

初始化(第113-114行):

last_state_processed = None  # 初始化历史帧
use_last_state = config.use_last_state

状态处理(第143-155行):

# Handle last_state if enabled
if use_last_state:
    if last_state_processed is None:
        # First step: use same frame
        last_state_processed = state_processed

    # Concatenate last_state and state
    state_input = np.concatenate([last_state_processed, state_processed], axis=0)

    # Update for next iteration
    last_state_processed = state_processed
else:
    state_input = state_processed

关键设计点


5. scripts/test_pipeline.py

在测试配置生成中添加use_last_state参数(第100行):

'use_last_state': True,  # 使用laststate+state合并模式

无需修改的文件

以下文件自动适配,无需修改:


测试验证

运行端到端测试:

python scripts/test_pipeline.py --video moive/1.mp4 --max-frames 100 --epochs 10

测试结果

关键日志

[TrainingBuffer] Init
  Input channels: 6

[Trainer] Creating trainer...
  Input channels: 6 (from buffer)

[Train] Starting training...
  Epoch 1/10, Avg Loss: 535.2894
  Epoch 10/10, Avg Loss: 220.9174

Training Complete!

数据格式变化

TrainingBuffer样本格式

单帧模式 (use_last_state=false):

{
    'state': (3, 540, 960),      # RGB图像
    'next_state': (3, 540, 960), # RGB图像
    'move': [4],
    'reward': float
}

双帧模式 (use_last_state=true):

{
    'state': (6, 540, 960),      # laststate(3) + state(3)
    'next_state': (6, 540, 960), # state(3) + nextstate(3)
    'move': [4],
    'reward': float
}

兼容性说明

向后兼容

不兼容情况


使用建议

  1. 新训练项目 - 建议使用use_last_state=true,充分利用时序信息
  2. 旧项目迁移 - 需要删除旧checkpoint,设置use_last_state=true重新训练
  3. 显存受限 - 可以设置use_last_state=false,使用单帧模式节省显存
  4. 灰度模式 - train_resolution设为[540, 960, 1]时,双帧模式通道数为2

文档更新

已更新以下文档:

  1. docs/design/ARCHITECTURE.md - 更新Config属性、input_channels说明、buffer.load说明、样本格式
  2. docs/guides/quickstart.md - 简化配置文件说明,引导用户查看配置文件注释
  3. docs/logs/development_log.md - 添加本次开发日志索引

总结

本次开发成功实现了双帧输入模式,通过简单的通道合并方案为模型提供时序信息。实现简洁、兼容性好,测试验证通过。显存占用翻倍的代价换来了逻辑的清晰和实现的简单,符合项目当前阶段的需求。