Date: 2025-11-23
Type: Feature enhancement
Scope: configuration management, training data loading, model training, battle scripts
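Implemented a dual-frame input mode (use_last_state). In this mode the model receives a 6-channel input (laststate and state concatenated) that carries temporal information. The single-frame input has only 3 channels; the dual-frame input lets the model perceive motion, with the goal of improving its understanding of dynamic scenes.
In the previous training setup, the model received only a single RGB frame (3 channels), so it could not perceive temporal information such as the direction and speed of moving objects. For the model to understand dynamic scenes, the previous frame has to be supplied as an additional input.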
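During training data loading (TrainingBuffer.load), the previous frame (laststate) and the current frame (state) are concatenated along the channel axis (axis 0), so state becomes laststate(3) + state(3).
next_state is merged the same way and becomes state(3) + next_state(3).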
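A minimal NumPy sketch of the two merges. It assumes the frames are already preprocessed to channel-first (C, H, W) float32 arrays. The helper name `merge_frames` is hypothetical and not part of the project:

```python
import numpy as np


def merge_frames(last_state, state, next_state):
    """Hypothetical helper: stack consecutive CHW frames along the channel axis."""
    # Separate output names avoid the ordering hazard of overwriting `state` in place
    merged_next = np.concatenate([state, next_state], axis=0)  # state(3) + next_state(3)
    merged_state = np.concatenate([last_state, state], axis=0)  # laststate(3) + state(3)
    return merged_state, merged_next


h, w = 540, 960
last = np.zeros((3, h, w), dtype=np.float32)
cur = np.ones((3, h, w), dtype=np.float32)
nxt = np.full((3, h, w), 2.0, dtype=np.float32)

s, ns = merge_frames(last, cur, nxt)
print(s.shape, ns.shape)  # (6, 540, 960) (6, 540, 960)
```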
Added the use_last_state configuration parameter:
Field definition (line 38):
use_last_state: bool = True  # whether to use laststate+state merged mode
Loaded in from_yaml (line 78):
use_last_state=data.get('use_last_state', True),
Saved in save_yaml (line 110):
'use_last_state': self.use_last_state,
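The load/save pair can be sketched without a YAML dependency. `TrainConfig`, `from_dict` and `to_dict` are hypothetical stand-ins for the real config class and its from_yaml/save_yaml methods. The sketch only shows how the default behaves:

```python
from dataclasses import dataclass


@dataclass
class TrainConfig:
    """Hypothetical stand-in for the project's config class; only the new field is shown."""

    use_last_state: bool = True  # whether to use laststate+state merged mode

    @classmethod
    def from_dict(cls, data: dict) -> "TrainConfig":
        # Mirrors from_yaml: a missing key falls back to True (dual-frame mode)
        return cls(use_last_state=data.get('use_last_state', True))

    def to_dict(self) -> dict:
        # Mirrors save_yaml: the field is always written out explicitly
        return {'use_last_state': self.use_last_state}


legacy = TrainConfig.from_dict({})  # a config written before this change
print(legacy.use_last_state)  # True

roundtrip = TrainConfig.from_dict(TrainConfig(use_last_state=False).to_dict())
print(roundtrip.use_last_state)  # False
```

Because the fallback is True, a config file written before this change loads in dual-frame mode. Setting `use_last_state: false` explicitly presumably keeps single-frame behavior, for example to keep using a model trained on 3-channel input.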
Added the entry to the YAML config file (line 14):
use_last_state: true  # whether to use laststate+state merged mode
Changed the input_channels calculation (lines 50-53):
# Input channels (determined by config.use_last_state and train_resolution)
# use_last_state=True: channels * 2, False: channels
base_channels = config.train_resolution[2] if (config and hasattr(config, 'train_resolution')) else 3
self.input_channels = base_channels * 2 if (config and getattr(config, 'use_last_state', True)) else base_channels
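The same logic as a standalone function, for illustration. The free function `input_channels` is hypothetical. The tuple values passed in are illustrative; the code only reads index 2 (the channel count):

```python
from types import SimpleNamespace


def input_channels(config):
    """Hypothetical free-function version of the input_channels logic above."""
    # Base channels come from train_resolution[2] (e.g. 3 for RGB), else default to 3
    base = config.train_resolution[2] if (config and hasattr(config, 'train_resolution')) else 3
    # Dual-frame mode doubles the channel count; with no config, single-frame is used
    return base * 2 if (config and getattr(config, 'use_last_state', True)) else base


print(input_channels(SimpleNamespace(train_resolution=(540, 960, 3))))  # 6
print(input_channels(None))  # 3
```

One asymmetry worth noting: without a config, the buffer falls back to 3 channels, even though the config field itself defaults to True.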
Modified the load method (lines 120-148):
Added the laststate selection logic:
# Get last_state if use_last_state is enabled
use_last_state = self.config and getattr(self.config, 'use_last_state', True)
if use_last_state:
    if i == 0:
        last_state = frames[0]['image']
    else:
        # Check if i-1 is in battle
        if frames[i - 1].get('state_in_battle', False):
            last_state = frames[i - 1]['image']
        else:
            # Previous frame is outside battle: reuse the current frame
            last_state = frame['image']
Added the merge logic:
# Preprocess images if config is available
if self.config and hasattr(self.config, 'train_resolution'):
    state = preprocess_frame(state, self.config.train_resolution, normalize_img=True)
    next_state = preprocess_frame(next_state, self.config.train_resolution, normalize_img=True)
    if use_last_state:
        last_state = preprocess_frame(last_state, self.config.train_resolution, normalize_img=True)
        # Concatenate: merge next_state first, because it needs the unmerged state
        next_state = np.concatenate([state, next_state], axis=0)
        state = np.concatenate([last_state, state], axis=0)
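Key design points:
- The first frame (i == 0) has no predecessor, so it is used as its own last_state.
- If the previous frame is not in battle (state_in_battle is false), the current frame is reused as last_state, so frames from outside a battle never enter the history.
- All frames are preprocessed with train_resolution before concatenation, so they are channel-first arrays of the same shape and can be joined on axis 0.
- next_state is merged before state. Merging state first would turn it into 6 channels and give next_state 9.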
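These rules can be collected into a small, self-contained sketch. `select_last_state` and `build_pair` are hypothetical helpers. The sketch assumes the frames are already preprocessed to CHW arrays and that next_state comes from frames[i + 1], which the excerpt does not show:

```python
import numpy as np


def select_last_state(frames, i):
    """Hypothetical helper mirroring the last_state rules in TrainingBuffer.load."""
    if i == 0:
        return frames[0]['image']  # first frame: no history, reuse itself
    if frames[i - 1].get('state_in_battle', False):
        return frames[i - 1]['image']  # previous frame is in battle: use it
    return frames[i]['image']  # previous frame outside battle: reuse the current frame


def build_pair(frames, i):
    """Hypothetical: return (state, next_state) as 6-channel CHW arrays."""
    last_state = select_last_state(frames, i)
    state = frames[i]['image']
    next_state = frames[i + 1]['image']  # assumption: next_state is the following frame
    # Merge next_state first so it is built from the unmerged 3-channel state
    next_state = np.concatenate([state, next_state], axis=0)
    state = np.concatenate([last_state, state], axis=0)
    return state, next_state


# Toy frames: each image is filled with its index; frame 1 is outside battle
frames = [
    {'image': np.full((3, 4, 4), k, dtype=np.float32), 'state_in_battle': k != 1}
    for k in range(4)
]
state, next_state = build_pair(frames, 2)
print(state[:, 0, 0])  # [2. 2. 2. 2. 2. 2.]: frame 1 is outside battle, so frame 2 is reused
```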
Added history-frame tracking to the battle loop:
Initialization (lines 113-114):
last_state_processed = None  # no history frame yet
use_last_state = config.use_last_state
State handling (lines 143-155):
# Handle last_state if enabled
if use_last_state:
    if last_state_processed is None:
        # First step: use same frame
        last_state_processed = state_processed
    # Concatenate last_state and state
    state_input = np.concatenate([last_state_processed, state_processed], axis=0)
    # Update for next iteration
    last_state_processed = state_processed
else:
    state_input = state_processed
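Key design points:
- On the first step there is no history yet, so the current frame is duplicated, matching the i == 0 case in training.
- last_state_processed is updated only after the concatenation, so each step pairs the previous frame with the current one.
- With use_last_state disabled, state_processed is passed to the model unchanged (3 channels).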
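The loop logic can be wrapped into a small stateful helper, shown here as a sketch. `FrameHistory` is hypothetical, and so is its `reset()` method, which is not in the original loop. It would let a caller clear the history between battles, roughly mirroring the training-side fallback, if the loop ever resumes after non-battle frames:

```python
import numpy as np


class FrameHistory:
    """Hypothetical wrapper around the last_state_processed logic in the battle loop."""

    def __init__(self, use_last_state=True):
        self.use_last_state = use_last_state
        self.last_state_processed = None  # no history before the first step

    def reset(self):
        # Hypothetical addition: forget the history, e.g. between battles
        self.last_state_processed = None

    def step(self, state_processed):
        if not self.use_last_state:
            return state_processed  # single-frame mode: pass through unchanged
        if self.last_state_processed is None:
            # First step: duplicate the current frame, matching i == 0 in training
            self.last_state_processed = state_processed
        state_input = np.concatenate([self.last_state_processed, state_processed], axis=0)
        self.last_state_processed = state_processed  # history for the next step
        return state_input


history = FrameHistory()
first = history.step(np.zeros((3, 2, 2), dtype=np.float32))
second = history.step(np.ones((3, 2, 2), dtype=np.float32))
print(first.shape, second.shape)  # (6, 2, 2) (6, 2, 2)
```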
Added use_last_state to the generated test configuration (line 100):
'use_last_state': True,  # use laststate+state merged mode
The following files adapt automatically and need no changes:
Run the end-to-end test:
python scripts/test_pipeline.py --video moive/1.mp4 --max-frames 100 --epochs 10
Test results:
- Input channels: 6
- Total params: 2,211,599 (6-channel input)
- Model file: data/test_pipeline/test_model_mini.pth
Key log output:
[TrainingBuffer] Init
Input channels: 6
[Trainer] Creating trainer...
Input channels: 6 (from buffer)
[Train] Starting training...
Epoch 1/10, Avg Loss: 535.2894
Epoch 10/10, Avg Loss: 220.9174
Training Complete!
Sample format in each mode:
Single-frame mode (use_last_state=false):
{
'state': (3, 540, 960), # RGB image
'next_state': (3, 540, 960), # RGB image
'move': [4],
'reward': float
}
Dual-frame mode (use_last_state=true):
{
'state': (6, 540, 960), # laststate(3) + state(3)
'next_state': (6, 540, 960), # state(3) + nextstate(3)
'move': [4],
'reward': float
}
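A hypothetical sanity check for the two layouts above. The name `check_sample` and the default of 3 base channels are assumptions. In dual-frame mode it also confirms a property that follows from the merge order: the current frame appears both as state[3:] and as next_state[:3].

```python
import numpy as np


def check_sample(sample, use_last_state, channels=3):
    """Hypothetical sanity check for one buffer sample in either mode."""
    expected = channels * 2 if use_last_state else channels
    state, next_state = sample['state'], sample['next_state']
    if state.shape[0] != expected or next_state.shape != state.shape:
        raise ValueError(f"unexpected shapes {state.shape} / {next_state.shape}")
    # Dual-frame mode: the current frame is state[channels:] and also next_state[:channels]
    if use_last_state and not np.array_equal(state[channels:], next_state[:channels]):
        raise ValueError("current frame differs between state and next_state")
    return True


cur = np.ones((3, 8, 8), dtype=np.float32)
sample = {
    'state': np.concatenate([np.zeros_like(cur), cur], axis=0),    # laststate + state
    'next_state': np.concatenate([cur, np.full_like(cur, 2.0)], axis=0),  # state + nextstate
}
print(check_sample(sample, use_last_state=True))  # True
```

The overlap also means each dual-frame sample stores the current frame twice.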
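The following documents were updated:
This change implements the dual-frame input mode and gives the model temporal information through a simple channel-concatenation scheme. The implementation is small, existing scripts adapt to it without modification, and the end-to-end pipeline test passes. The cost is that the state and next_state tensors double in size (6 channels instead of 3), and so does their memory use, including on the GPU. In exchange, the logic stays clear and the implementation simple, which suits the project's current stage.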
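The memory cost can be put into numbers with a small calculation. It assumes float32 tensors after preprocessing (the excerpt does not state the dtype) at the 540×960 resolution listed in the sample format. It counts only the stored state/next_state tensors, not model weights or activations. `sample_bytes` is a hypothetical helper:

```python
import numpy as np


def sample_bytes(height, width, channels, dtype=np.float32):
    """Bytes for one state + next_state pair (assumes both share shape and dtype)."""
    per_tensor = channels * height * width * np.dtype(dtype).itemsize
    return 2 * per_tensor


single = sample_bytes(540, 960, 3)  # single-frame mode
dual = sample_bytes(540, 960, 6)    # dual-frame mode
print(single, dual, dual / single)  # 12441600 24883200 2.0
```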