autowzry-agent

2025-11-24 路径管理与目录结构重构

概述

本次重构统一了项目的路径管理方式和目录结构,解决了以下核心问题:

  1. 参数不一致:不同脚本的 --output 参数含义不统一(有的是文件名,有的是完整路径)
  2. 目录分散:运行时数据文件(视频、模型、数据集)散落在项目根目录
  3. API冗余:DataManager 中存在已废弃的 save() 方法和不完整的 API
  4. 缺少调试信息:TrainingBuffer 样本中缺少帧名称追踪
  5. 文件覆盖风险:StreamingRecorder 无自动备份机制

变更详情

1. DataManager API 重构

1.1 删除已废弃方法

删除了 save() 方法(原 data/data_manager.py:178-211):

1.2 统一 load/save API

load() 方法重命名为 load_dataset(),并增强返回值:

# 旧版本
def load(self, filepath: str, keys: Optional[List[str]] = None,
         frame_range: Optional[tuple] = None) -> List[Dict]:
    # 仅返回 frames_data
    return frames_data

# 新版本
def load_dataset(self, filepath: str, keys: Optional[List[str]] = None,
                 frame_range: Optional[tuple] = None) -> Tuple[List[Dict], List[str], Dict]:
    """
    加载完整数据集(包括帧名称和文件属性)

    Returns:
        Tuple[List[Dict], List[str], Dict]: (frames_data, frame_names, file_attrs)
    """
    frames_data = []
    frame_names = []

    with h5py.File(filepath, 'r') as f:
        frame_list = self._get_frame_list(f)
        for frame_name in frame_list[start_idx:end_idx]:
            frame_group = f[frame_name]
            frame_data = {}
            # ... 加载数据 ...
            frames_data.append(frame_data)
            frame_names.append(frame_name)

        file_attrs = dict(f.attrs)

    return frames_data, frame_names, file_attrs

理由

1.3 移除 data_dir 参数

删除内容

改为完整路径

# 旧版本
def __init__(self, data_dir: str = "./data/episodes", ...):
    self.data_dir = data_dir
    os.makedirs(self.data_dir, exist_ok=True)

def collect(self, filename: str, ...) -> str:
    filepath = os.path.join(self.data_dir, filename)

# 新版本
def __init__(self, compatibility_layer=None, action_space=None, game_state=None):
    # 不创建任何目录

def collect(self, filepath: str, ...) -> str:
    # 调用者提供完整路径,脚本负责创建目录
    os.makedirs(os.path.dirname(filepath) or '.', exist_ok=True)

影响的方法

2. StreamingRecorder 自动备份

__init__ 中添加自动备份逻辑(data/data_manager.py:438-445):

def __init__(self, filepath: str):
    self.filepath = filepath
    os.makedirs(os.path.dirname(filepath) or '.', exist_ok=True)

    # 如果文件已存在,备份到 .backup(覆盖旧备份)
    if os.path.exists(filepath):
        backup_path = filepath + '.backup'
        if os.path.exists(backup_path):
            os.remove(backup_path)
        os.rename(filepath, backup_path)
        print(f"[StreamingRecorder] Backed up existing file to {backup_path}")

    self.file = h5py.File(filepath, 'w')
    # ... 其余初始化代码 ...

特性

3. TrainingBuffer 帧名称追踪

3.1 使用新的 load_dataset API

更新 load() 方法(data/training_buffer.py:82):

# 旧版本
frames = self.data_manager.load(filepath, keys=None, frame_range=frame_range)

# 新版本
frames, frame_names, _ = self.data_manager.load_dataset(
    filepath, keys=None, frame_range=frame_range
)

3.2 添加 frame_name 字段

在构建样本时添加帧名称(data/training_buffer.py:146-150):

sample = {
    'state': state,
    'next_state': next_state,
    'frame_name': frame_names[i]  # 新增字段
}

# 添加 actions 和 rewards
for action_name in self.enabled_actions:
    # ... 现有逻辑 ...

self.samples.append(sample)

3.3 修复类型处理

__getitem__ 中添加字符串类型处理(data/training_buffer.py:340-347):

for key, value in sample.items():
    if value is not None and isinstance(value, np.ndarray):
        tensor_sample[key] = torch.from_numpy(value).to(self.device)
    elif value is not None and isinstance(value, (int, float)):
        tensor_sample[key] = torch.tensor(value, dtype=torch.float32).to(self.device)
    else:
        # 保持字符串或 None 不变(用于 frame_name)
        tensor_sample[key] = value

4. 统一目录结构:workspace/

创建 workspace/ 顶层目录,集中管理所有运行时数据:

workspace/
├── episodes/          # 数据集(HDF5 文件)
├── checkpoints/       # 模型检查点(.pth 文件)
├── videos/            # 视频文件(.mp4 等)
├── buffer_check/      # buffer 检查输出
└── test_pipeline/     # 端到端测试输出

旧结构问题

新结构优势

5. 脚本参数标准化

5.1 统一 –output 语义

原则:所有 --output 参数接受完整路径(绝对或相对)

修改的脚本

脚本 旧行为 新行为
collect_from_video.py filename + –output-dir 完整路径,默认 workspace/episodes/episode_{timestamp}.hdf5
collect_from_device.py filename + –output-dir 完整路径,默认 workspace/episodes/episode_{timestamp}.hdf5
battle.py 不明确(data_dir 拼接) 完整路径,默认 workspace/episodes/battle_{timestamp}.hdf5
extract_data.py 自动生成(同目录) 完整路径,默认 {原文件}_F{start}_E{end}.hdf5
check_buffer.py 自动生成 完整路径,默认 workspace/buffer_check/buffer_samples_{timestamp}.hdf5
train.py 自动生成 完整路径,默认 workspace/checkpoints/model_{model_mode}.pth

5.2 删除 –output-dir 参数

从以下脚本移除:

理由--output 已经是完整路径,不需要额外的目录参数。

5.3 示例变更

collect_from_video.py

# 旧版本
parser.add_argument('--output', type=str, default='episode.hdf5', help='Output filename')
parser.add_argument('--output-dir', type=str, default=None, help='Output directory')

# 执行逻辑
output_dir = args.output_dir or config.data_dir
filename = args.output
filepath = os.path.join(output_dir, filename)

# 新版本
parser.add_argument('--output', type=str, default=None,
    help='Output file path (default: workspace/episodes/episode_YYYYMMDD_HHMMSS.hdf5)')

# 执行逻辑
if not args.output:
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    args.output = f"./workspace/episodes/episode_{timestamp}.hdf5"

os.makedirs(os.path.dirname(args.output) or '.', exist_ok=True)
filepath = compat.data_manager.collect(filepath=args.output, ...)

battle.py

# 新增自动生成默认路径
if args.record_interval and not args.output:
    from datetime import datetime
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    args.output = f"./workspace/episodes/battle_{timestamp}.hdf5"

6. 配置文件更新

6.1 config/config.py

更新默认路径(lines 65, 69, 107, 109):

@dataclass
class Config:
    # 训练配置
    data_dir: str = "./workspace/episodes"  # 旧: "./data/episodes"
    checkpoint_dir: str = "./workspace/checkpoints"  # 旧: "./checkpoints"

    @classmethod
    def from_yaml(cls, yaml_path: str):
        with open(yaml_path, 'r', encoding='utf-8') as f:
            data = yaml.safe_load(f)

        return cls(
            data_dir=data.get('data', {}).get('dir', './workspace/episodes'),
            checkpoint_dir=data.get('checkpoint', {}).get('dir', './workspace/checkpoints'),
            # ...
        )

6.2 config/agent.config.yaml

更新所有路径引用(lines 28-34):

data:
  dir: ./workspace/episodes  # 旧: ./data/episodes

checkpoint:
  dir: ./workspace/checkpoints  # 旧: ./checkpoints
  resume_model: ./workspace/checkpoints/model_mini.pth  # 旧: ./checkpoints/model_mini.pth

7. 文档和测试更新

7.1 docs/guides/quickstart.md

更新所有命令示例:

# 步骤 1:收集数据(旧)
python scripts/collect_from_video.py --video moive/1.mp4 --interval 1.0 --max-frames 400 --output data1.hdf5

# 步骤 1:收集数据(新)
python scripts/collect_from_video.py --video workspace/videos/1.mp4 --interval 1.0 --max-frames 400 --output workspace/episodes/data1.hdf5

# 类似更新应用于所有步骤(提取、标记、训练、对战)

7.2 run.bat

更新所有 9 条测试命令:

REM 旧版本
python scripts/collect_from_video.py --video moive/1.mp4 --interval 1.0 --max-frames 400 --output data1.hdf5
python scripts/battle.py --model checkpoints/model_mini.pth --video moive/1.mp4 --max-steps 50

REM 新版本
python scripts/collect_from_video.py --video workspace/videos/1.mp4 --interval 1.0 --max-frames 400 --output workspace/episodes/data1.hdf5
python scripts/battle.py --model workspace/checkpoints/model_mini.pth --video workspace/videos/1.mp4 --max-steps 50

7.3 .gitignore

简化忽略规则(line 170):

# 删除(分散的旧规则)
# moive/
# data/episodes/
# data/test_pipeline/
# data/buffer_check/
# checkpoints/

# 新增(统一规则)
workspace/

8. 其他脚本更新

8.1 scripts/extract_data.py

简化数据加载(lines 93-101):

# 旧版本:手动读取 frame_names
frames_data = compat.data_manager.load(args.file, keys=None, frame_range=(start_frame, end_frame))
with h5py.File(args.file, 'r') as f:
    frame_list = sorted([k for k in f.keys() if k.startswith('frame_')])
    frame_names = frame_list[start_idx:end_idx]

# 新版本:直接获取所有数据
frames_data, frame_names, file_attrs = compat.data_manager.load_dataset(
    args.file, keys=None, frame_range=(start_frame, end_frame)
)

8.2 scripts/check_buffer.py

更新默认输出路径(lines 52-56):

if args.output is None:
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = "workspace/buffer_check"  # 旧: "data/buffer_check"
    os.makedirs(output_dir, exist_ok=True)
    args.output = os.path.join(output_dir, f"buffer_samples_{timestamp}.hdf5")

8.3 scripts/test_pipeline.py

更新默认视频路径(line 58):

parser.add_argument('--video', type=str,
    default='workspace/videos/1.mp4',  # 旧: 'moive/1.mp4'
    help='Video file path')

影响分析

破坏性变更

  1. 不兼容旧命令
    • 使用 filename + –output-dir 的命令需要改为完整路径
    • 例:--output data.hdf5 --output-dir ./data--output ./workspace/episodes/data.hdf5
  2. 配置文件路径
    • 旧的 agent.config.yaml 需要更新 data.dir 和 checkpoint.dir
    • 使用相对路径的用户需要调整
  3. 目录位置变更
    • moive/ → workspace/videos/
    • checkpoints/ → workspace/checkpoints/
    • data/episodes/ → workspace/episodes/

向后兼容性

不保留向后兼容,理由:

迁移指南

用户需要执行以下操作:

  1. 移动文件
    mkdir -p workspace/videos workspace/episodes workspace/checkpoints
    mv moive/* workspace/videos/  # 如果存在
    mv checkpoints/* workspace/checkpoints/  # 如果存在
    mv data/episodes/* workspace/episodes/  # 如果存在
    
  2. 更新配置
    # 如果使用自定义 config.yaml,更新路径
    sed -i 's|./data/episodes|./workspace/episodes|g' config/custom.yaml
    sed -i 's|./checkpoints|./workspace/checkpoints|g' config/custom.yaml
    
  3. 更新命令
    • 所有 –output 改为完整路径
    • 删除 –output-dir 参数
    • 视频路径从 moive/ 改为 workspace/videos/
    • 模型路径从 checkpoints/ 改为 workspace/checkpoints/

测试建议

单元测试

  1. DataManager.load_dataset()
    • 验证返回三元组 (frames_data, frame_names, file_attrs)
    • 测试 frame_range 切片的 frame_names 正确性
  2. StreamingRecorder 备份
    • 验证首次写入不创建 .backup
    • 验证覆盖写入时创建 .backup
    • 验证多次覆盖时只保留最新 .backup
  3. TrainingBuffer.load()
    • 验证 samples 中包含 frame_name 字段
    • 验证 getitem 正确处理字符串类型

集成测试

  1. 完整训练流程
    # 1. 收集数据
    python scripts/collect_from_video.py --video workspace/videos/1.mp4 --interval 1.0 --max-frames 100 --output workspace/episodes/test.hdf5
    
    # 2. 标记数据
    python scripts/label_data.py --file workspace/episodes/test.hdf5
    
    # 3. 训练模型
    python scripts/train.py --files workspace/episodes/test.hdf5 --output workspace/checkpoints/test_model.pth
    
    # 4. 离线测试
    python scripts/battle.py --model workspace/checkpoints/test_model.pth --video workspace/videos/1.mp4 --max-steps 10
    
  2. 自动化测试脚本
    python scripts/test_pipeline.py --video workspace/videos/1.mp4 --max-frames 50 --epochs 5
    
  3. StreamingRecorder 覆盖测试
    # 运行两次相同命令,验证备份功能
    python scripts/battle.py --model workspace/checkpoints/model_mini.pth --video workspace/videos/1.mp4 --max-steps 5 --record-interval 1 --output workspace/episodes/overwrite_test.hdf5
    python scripts/battle.py --model workspace/checkpoints/model_mini.pth --video workspace/videos/1.mp4 --max-steps 5 --record-interval 1 --output workspace/episodes/overwrite_test.hdf5
    # 应存在 overwrite_test.hdf5 和 overwrite_test.hdf5.backup
    

回归测试

使用 run.bat 验证所有关键功能:

# 执行所有 9 个命令,确保无错误
./run.bat

总结

本次重构实现了以下目标:

  1. API 清晰性:统一 DataManager 的 load/save 接口,删除冗余方法
  2. 参数一致性:所有脚本的 –output 参数语义统一
  3. 目录规范化:集中管理运行时数据,分离源码和数据
  4. 可调试性:TrainingBuffer 样本包含帧名称,方便追踪
  5. 数据安全:StreamingRecorder 自动备份,防止意外覆盖

核心原则:无歧义、可预测、易维护