本次重构统一了项目的路径管理方式和目录结构,解决了以下核心问题:
--output 参数含义不统一(有的是文件名,有的是完整路径)save() 方法和不完整的 API删除了 save() 方法(原 data/data_manager.py:178-211):
save_dataset() 替代collect() 作为替代品将 load() 方法重命名为 load_dataset(),并增强返回值:
# 旧版本
def load(self, filepath: str, keys: Optional[List[str]] = None,
frame_range: Optional[tuple] = None) -> List[Dict]:
# 仅返回 frames_data
return frames_data
# 新版本
def load_dataset(self, filepath: str, keys: Optional[List[str]] = None,
frame_range: Optional[tuple] = None) -> Tuple[List[Dict], List[str], Dict]:
"""
加载完整数据集(包括帧名称和文件属性)
Returns:
Tuple[List[Dict], List[str], Dict]: (frames_data, frame_names, file_attrs)
"""
frames_data = []
frame_names = []
with h5py.File(filepath, 'r') as f:
frame_list = self._get_frame_list(f)
for frame_name in frame_list[start_idx:end_idx]:
frame_group = f[frame_name]
frame_data = {}
# ... 加载数据 ...
frames_data.append(frame_data)
frame_names.append(frame_name)
file_attrs = dict(f.attrs)
return frames_data, frame_names, file_attrs
理由:
save_dataset() 形成完整的 API 对(都支持任意字段)删除内容:
__init__ 方法的 data_dir 参数self.data_dir 属性改为完整路径:
# 旧版本
def __init__(self, data_dir: str = "./data/episodes", ...):
self.data_dir = data_dir
os.makedirs(self.data_dir, exist_ok=True)
def collect(self, filename: str, ...) -> str:
filepath = os.path.join(self.data_dir, filename)
# 新版本
def __init__(self, compatibility_layer=None, action_space=None, game_state=None):
# 不创建任何目录
def collect(self, filepath: str, ...) -> str:
# 调用者提供完整路径,脚本负责创建目录
os.makedirs(os.path.dirname(filepath) or '.', exist_ok=True)
影响的方法:
collect(filename → filepath)create_streaming_recorder(filename → filepath)在 __init__ 中添加自动备份逻辑(data/data_manager.py:438-445):
def __init__(self, filepath: str):
self.filepath = filepath
os.makedirs(os.path.dirname(filepath) or '.', exist_ok=True)
# 如果文件已存在,备份到 .backup(覆盖旧备份)
if os.path.exists(filepath):
backup_path = filepath + '.backup'
if os.path.exists(backup_path):
os.remove(backup_path)
os.rename(filepath, backup_path)
print(f"[StreamingRecorder] Backed up existing file to {backup_path}")
self.file = h5py.File(filepath, 'w')
# ... 其余初始化代码 ...
特性:
更新 load() 方法(data/training_buffer.py:82):
# 旧版本
frames = self.data_manager.load(filepath, keys=None, frame_range=frame_range)
# 新版本
frames, frame_names, _ = self.data_manager.load_dataset(
filepath, keys=None, frame_range=frame_range
)
在构建样本时添加帧名称(data/training_buffer.py:146-150):
sample = {
'state': state,
'next_state': next_state,
'frame_name': frame_names[i] # 新增字段
}
# 添加 actions 和 rewards
for action_name in self.enabled_actions:
# ... 现有逻辑 ...
self.samples.append(sample)
在 __getitem__ 中添加字符串类型处理(data/training_buffer.py:340-347):
for key, value in sample.items():
if value is not None and isinstance(value, np.ndarray):
tensor_sample[key] = torch.from_numpy(value).to(self.device)
elif value is not None and isinstance(value, (int, float)):
tensor_sample[key] = torch.tensor(value, dtype=torch.float32).to(self.device)
else:
# 保持字符串或 None 不变(用于 frame_name)
tensor_sample[key] = value
创建 workspace/ 顶层目录,集中管理所有运行时数据:
workspace/
├── episodes/ # 数据集(HDF5 文件)
├── checkpoints/ # 模型检查点(.pth 文件)
├── videos/ # 视频文件(.mp4 等)
├── buffer_check/ # buffer 检查输出
└── test_pipeline/ # 端到端测试输出
旧结构问题:
moive/ 目录(拼写错误)在根目录checkpoints/ 在根目录data/episodes/ 混淆了数据模块和数据文件新结构优势:
.gitignore 只需一行 workspace/原则:所有 --output 参数接受完整路径(绝对或相对)
修改的脚本:
| 脚本 | 旧行为 | 新行为 |
|---|---|---|
| collect_from_video.py | filename + –output-dir | 完整路径,默认 workspace/episodes/episode_{timestamp}.hdf5 |
| collect_from_device.py | filename + –output-dir | 完整路径,默认 workspace/episodes/episode_{timestamp}.hdf5 |
| battle.py | 不明确(data_dir 拼接) | 完整路径,默认 workspace/episodes/battle_{timestamp}.hdf5 |
| extract_data.py | 自动生成(同目录) | 完整路径,默认 {原文件}_F{start}_E{end}.hdf5 |
| check_buffer.py | 自动生成 | 完整路径,默认 workspace/buffer_check/buffer_samples_{timestamp}.hdf5 |
| train.py | 自动生成 | 完整路径,默认 workspace/checkpoints/model_{model_mode}.pth |
从以下脚本移除:
collect_from_video.pycollect_from_device.py理由:--output 已经是完整路径,不需要额外的目录参数。
collect_from_video.py:
# 旧版本
parser.add_argument('--output', type=str, default='episode.hdf5', help='Output filename')
parser.add_argument('--output-dir', type=str, default=None, help='Output directory')
# 执行逻辑
output_dir = args.output_dir or config.data_dir
filename = args.output
filepath = os.path.join(output_dir, filename)
# 新版本
parser.add_argument('--output', type=str, default=None,
help='Output file path (default: workspace/episodes/episode_YYYYMMDD_HHMMSS.hdf5)')
# 执行逻辑
if not args.output:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
args.output = f"./workspace/episodes/episode_{timestamp}.hdf5"
os.makedirs(os.path.dirname(args.output) or '.', exist_ok=True)
filepath = compat.data_manager.collect(filepath=args.output, ...)
battle.py:
# 新增自动生成默认路径
if args.record_interval and not args.output:
from datetime import datetime
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
args.output = f"./workspace/episodes/battle_{timestamp}.hdf5"
更新默认路径(lines 65, 69, 107, 109):
@dataclass
class Config:
# 训练配置
data_dir: str = "./workspace/episodes" # 旧: "./data/episodes"
checkpoint_dir: str = "./workspace/checkpoints" # 旧: "./checkpoints"
@classmethod
def from_yaml(cls, yaml_path: str):
with open(yaml_path, 'r', encoding='utf-8') as f:
data = yaml.safe_load(f)
return cls(
data_dir=data.get('data', {}).get('dir', './workspace/episodes'),
checkpoint_dir=data.get('checkpoint', {}).get('dir', './workspace/checkpoints'),
# ...
)
更新所有路径引用(lines 28-34):
data:
dir: ./workspace/episodes # 旧: ./data/episodes
checkpoint:
dir: ./workspace/checkpoints # 旧: ./checkpoints
resume_model: ./workspace/checkpoints/model_mini.pth # 旧: ./checkpoints/model_mini.pth
更新所有命令示例:
# 步骤 1:收集数据(旧)
python scripts/collect_from_video.py --video moive/1.mp4 --interval 1.0 --max-frames 400 --output data1.hdf5
# 步骤 1:收集数据(新)
python scripts/collect_from_video.py --video workspace/videos/1.mp4 --interval 1.0 --max-frames 400 --output workspace/episodes/data1.hdf5
# 类似更新应用于所有步骤(提取、标记、训练、对战)
更新所有 9 条测试命令:
REM 旧版本
python scripts/collect_from_video.py --video moive/1.mp4 --interval 1.0 --max-frames 400 --output data1.hdf5
python scripts/battle.py --model checkpoints/model_mini.pth --video moive/1.mp4 --max-steps 50
REM 新版本
python scripts/collect_from_video.py --video workspace/videos/1.mp4 --interval 1.0 --max-frames 400 --output workspace/episodes/data1.hdf5
python scripts/battle.py --model workspace/checkpoints/model_mini.pth --video workspace/videos/1.mp4 --max-steps 50
简化忽略规则(line 170):
# 删除(分散的旧规则)
# moive/
# data/episodes/
# data/test_pipeline/
# data/buffer_check/
# checkpoints/
# 新增(统一规则)
workspace/
简化数据加载(lines 93-101):
# 旧版本:手动读取 frame_names
frames_data = compat.data_manager.load(args.file, keys=None, frame_range=(start_frame, end_frame))
with h5py.File(args.file, 'r') as f:
frame_list = sorted([k for k in f.keys() if k.startswith('frame_')])
frame_names = frame_list[start_idx:end_idx]
# 新版本:直接获取所有数据
frames_data, frame_names, file_attrs = compat.data_manager.load_dataset(
args.file, keys=None, frame_range=(start_frame, end_frame)
)
更新默认输出路径(lines 52-56):
if args.output is None:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = "workspace/buffer_check" # 旧: "data/buffer_check"
os.makedirs(output_dir, exist_ok=True)
args.output = os.path.join(output_dir, f"buffer_samples_{timestamp}.hdf5")
更新默认视频路径(line 58):
parser.add_argument('--video', type=str,
default='workspace/videos/1.mp4', # 旧: 'moive/1.mp4'
help='Video file path')
--output data.hdf5 --output-dir ./data → --output ./workspace/episodes/data.hdf5不保留向后兼容,理由:
用户需要执行以下操作:
mkdir -p workspace/videos workspace/episodes workspace/checkpoints
mv moive/* workspace/videos/ # 如果存在
mv checkpoints/* workspace/checkpoints/ # 如果存在
mv data/episodes/* workspace/episodes/ # 如果存在
# 如果使用自定义 config.yaml,更新路径
sed -i 's|./data/episodes|./workspace/episodes|g' config/custom.yaml
sed -i 's|./checkpoints|./workspace/checkpoints|g' config/custom.yaml
# 1. 收集数据
python scripts/collect_from_video.py --video workspace/videos/1.mp4 --interval 1.0 --max-frames 100 --output workspace/episodes/test.hdf5
# 2. 标记数据
python scripts/label_data.py --file workspace/episodes/test.hdf5
# 3. 训练模型
python scripts/train.py --files workspace/episodes/test.hdf5 --output workspace/checkpoints/test_model.pth
# 4. 离线测试
python scripts/battle.py --model workspace/checkpoints/test_model.pth --video workspace/videos/1.mp4 --max-steps 10
python scripts/test_pipeline.py --video workspace/videos/1.mp4 --max-frames 50 --epochs 5
# 运行两次相同命令,验证备份功能
python scripts/battle.py --model workspace/checkpoints/model_mini.pth --video workspace/videos/1.mp4 --max-steps 5 --record-interval 1 --output workspace/episodes/overwrite_test.hdf5
python scripts/battle.py --model workspace/checkpoints/model_mini.pth --video workspace/videos/1.mp4 --max-steps 5 --record-interval 1 --output workspace/episodes/overwrite_test.hdf5
# 应存在 overwrite_test.hdf5 和 overwrite_test.hdf5.backup
使用 run.bat 验证所有关键功能:
# 执行所有 9 个命令,确保无错误
./run.bat
本次重构实现了以下目标:
核心原则:无歧义、可预测、易维护。