DreamZero (WAM) 模型架构
1. 整体架构概览
DreamZero 是一个世界动作模型 (World Action Model, WAM),通过联合预测未来视频帧与机器人动作序列,实现对未见任务的零样本泛化。其核心创新在于:基于 Wan2.1 视频扩散模型构建 Causal WAN DiT,以 Flow Matching 框架在视频 latent 空间和动作空间上同步去噪,配合多机器人体 Category-specific MLP 支持 DROID、AgiBot、YAM 等多种机器人平台,在 MolmoSpaces 和 RoboArena 双榜均位列第一(截至 2026 年 2 月)。
[B, 1, C, H, W]
320×176"] TXT["语言指令
(text prompt)"] EID["机器人体 ID
(embodiment_id)"] end subgraph Encoders["感知编码器"] CLIP["CLIP ViT 图像编码器
open-clip-xlm-roberta-large-vit-huge-14
→ 1536-dim 特征"] T5["umt5-xxl 文本编码器
→ 1536-dim 特征"] end subgraph VAE["视频 VAE (Wan2.1)"] ENC["VAE 编码器
视频 → 16-dim latent"] DEC["VAE 解码器
latent → 视频帧"] end subgraph Core["Causal WAN DiT (核心)"] DIT["40层因果扩散 Transformer
dim=5120, 40 heads
Flow Matching + RoPE + Flash Attn 3"] ACT_MLP["多机器人体 MLP
Category-specific Linear"] end subgraph Output["输出"] ACT_OUT["动作序列
[B, 24, action_dim]"] VID_OUT["预测视频帧
[B, 33, H, W, 3]"] end IMG --> CLIP TXT --> T5 IMG --> ENC CLIP -->|"视觉特征"| DIT T5 -->|"语言特征"| DIT ENC -->|"视频 latent"| DIT EID --> ACT_MLP --> DIT DIT -->|"动作 latent"| ACT_OUT DIT -->|"视频 latent"| DEC --> VID_OUT style Input fill:#e8f4fd,stroke:#2196F3 style Encoders fill:#fff3e0,stroke:#FF9800 style VAE fill:#f3e5f5,stroke:#9C27B0 style Core fill:#e8f5e9,stroke:#4CAF50 style Output fill:#fce4ec,stroke:#E91E63
2. 核心组件详解
2.1 感知编码器
DreamZero 使用两个独立的预训练编码器将视觉和语言信息映射到统一的 1536-dim 特征空间:
- 图像编码器:
open-clip-xlm-roberta-large-vit-huge-14(CLIP ViT-H/14),对输入视频帧逐帧提取视觉特征 - 文本编码器:
umt5-xxl(Google 多语言 T5),将自然语言指令编码为序列特征
[B, T_img, C, H, W]"] --> CLIP_ENC["CLIP ViT 编码
wan_video_image_encoder.py"] CLIP_ENC --> CLIP_FEAT["图像特征
[B, T_img, N_patch, 1536]"] end subgraph TextPath["文本编码路径"] PROMPT["语言提示
(string)"] --> T5_TOK["T5 分词器"] T5_TOK --> T5_ENC["umt5-xxl 编码
wan_video_text_encoder.py"] T5_ENC --> TXT_FEAT["文本特征
[B, T_txt, 1536]"] end subgraph Fusion["特征融合 → DiT"] CLIP_FEAT --> CAT["拼接为条件序列"] TXT_FEAT --> CAT CAT --> DIT_IN["输入 Causal WAN DiT"] end style ImagePath fill:#e3f2fd,stroke:#2196F3 style TextPath fill:#fff3e0,stroke:#FF9800 style Fusion fill:#e8f5e9,stroke:#4CAF50
2.2 视频 VAE
沿用 Wan2.1 的视频 VAE,将高维视频压缩至低维 latent 空间再进行扩散,大幅降低计算量。
[B, 33, H, W, 3]"] --> VAE_E["VAE 编码器
wan_video_vae.py"] VAE_E --> LAT["视频 latent
[B, 33, h, w, 16]"] end subgraph Decode["推理:解码"] LAT2["扩散后 latent
[B, 33, h, w, 16]"] --> VAE_D["VAE 解码器"] VAE_D --> V_OUT["重建视频帧
[B, 33, H, W, 3]"] end subgraph Params["关键参数"] P1["latent 通道数: 16"] P2["时域压缩: 4×"] P3["空域压缩: 8×"] end style Encode fill:#e3f2fd,stroke:#2196F3 style Decode fill:#e8f5e9,stroke:#4CAF50 style Params fill:#fff9c4,stroke:#FFC107
2.3 Causal WAN DiT(核心扩散模型)
Causal WAN DiT 是 DreamZero 的核心,基于 Wan2.1 的 DiT 架构改造而来:40 层因果 Transformer,通过 Causal Masking 保证时序因果性,同时在 latent 序列中嵌入动作 token,实现视频与动作的联合生成。
[B, T_vid, 16]"] AC["含噪动作
[B, 24, action_dim]"] VIS["视觉条件
[B, T_img, 1536]"] LNG["语言条件
[B, T_txt, 1536]"] TS["时间步 t"] VL --> ROPE["RoPE 位置编码"] AC --> ACT_PROJ["动作投影 MLP"] TS --> TIME_EMB["时间步嵌入
(正弦编码)"] end subgraph Block["单层 Causal DiT Block (×40)"] direction TB IN_B["输入 [B, T, 5120]"] --> LN1_B["RMSNorm"] LN1_B --> MOD1["Modulation
(shift, scale by timestep emb)"] MOD1 --> CATTN["因果自注意力
40 heads, Flash Attn 3
Causal Mask"] CATTN --> ADD1["+ 残差"] IN_B --> ADD1 ADD1 --> LN2_B["RMSNorm"] LN2_B --> XATTN["交叉注意力
(视觉 + 语言条件)"] XATTN --> ADD2["+ 残差"] ADD1 --> ADD2 ADD2 --> LN3_B["RMSNorm"] LN3_B --> FFN_B["FFN (SwiGLU)
dim 5120 → 13824 → 5120"] FFN_B --> ADD3["+ 残差"] ADD2 --> ADD3 ADD3 --> OUT_B["输出 [B, T, 5120]"] end subgraph Output["输出解码"] OUT_B --> SPLIT["分离 video / action token"] SPLIT --> VP["视频 latent 预测"] SPLIT --> AP["动作预测
[B, 24, action_dim]"] end ROPE --> Block ACT_PROJ --> Block TIME_EMB --> Block VIS --> Block LNG --> Block style InputSeq fill:#e3f2fd,stroke:#2196F3 style Block fill:#e8f5e9,stroke:#4CAF50 style Output fill:#fce4ec,stroke:#E91E63
因果注意力掩码设计: 视频帧 token 只能 attend 到过去帧和当前帧,动作 token attend 到所有视频 token(全局条件),保证生成的时序一致性。
2.4 多机器人体支持
DreamZero 通过每个机器人体独立的线性层(Category-specific MLP)将不同维度的动作空间映射到统一的 DiT 隐空间,支持在单一模型中处理多种机器人平台。
[B, 24, dim_droid]"] A_ACT["AgiBot 动作
[B, 24, dim_agibot]"] Y_ACT["YAM 动作
[B, 24, dim_yam]"] O_ACT["其他机器人体..."] D_ACT --> D_MLP["DROID Linear
dim_droid → 5120"] A_ACT --> A_MLP["AgiBot Linear
dim_agibot → 5120"] Y_ACT --> Y_MLP["YAM Linear
dim_yam → 5120"] O_ACT --> O_MLP["... Linear"] end subgraph Shared["共享 DiT 处理"] D_MLP --> UNIFIED["统一动作 token
[B, 24, 5120]"] A_MLP --> UNIFIED Y_MLP --> UNIFIED O_MLP --> UNIFIED UNIFIED --> DIT_PROC["Causal WAN DiT
(共享权重)"] end subgraph Decode["多机器人体动作解码"] DIT_PROC --> SPLIT_E["按 embodiment_id 路由"] SPLIT_E --> D_DEC["DROID Decoder Linear
5120 → dim_droid"] SPLIT_E --> A_DEC["AgiBot Decoder Linear
5120 → dim_agibot"] SPLIT_E --> Y_DEC["YAM Decoder Linear
5120 → dim_yam"] end style Embodiments fill:#e3f2fd,stroke:#2196F3 style Shared fill:#e8f5e9,stroke:#4CAF50 style Decode fill:#fce4ec,stroke:#E91E63
相对动作计算: 各机器人体在数据加载时将绝对动作转为相对值(action - reference_state),使模型学到的动作表示更加泛化。
3. 训练流水线
DreamZero 基于 Flow Matching 框架进行训练,同时对视频 latent 和动作 token 施加扩散损失。训练使用 DeepSpeed ZeRO-2 分布式优化,并通过 LoRA 高效微调 14B 参数基础模型。
· 视频 resize → 320×176
· 状态/动作归一化
· 相对动作计算
· 语言编码"] MOD --> BATCH_OUT["训练 Batch"] end subgraph Encode["编码阶段"] BATCH_OUT --> IMG_E["CLIP 图像编码
→ 视觉特征"] BATCH_OUT --> TXT_E["T5 文本编码
→ 语言特征"] BATCH_OUT --> VAE_E2["VAE 编码
视频 → latent"] end subgraph FlowMatch["Flow Matching 加噪"] T_SAMP["时间步采样
t ~ Uniform[0, 1000]"] NOISE2["高斯噪声 ε ~ N(0, I)"] ACT_GT["真实动作 / 视频 latent"] T_SAMP --> INTERP2["线性插值
x_t = ε·t/T + x_0·(1 - t/T)"] NOISE2 --> INTERP2 ACT_GT --> INTERP2 end subgraph Forward["前向传播"] IMG_E --> DIT_FWD["Causal WAN DiT"] TXT_E --> DIT_FWD VAE_E2 --> DIT_FWD INTERP2 --> DIT_FWD DIT_FWD --> PRED_ALL["联合预测
视频 latent + 动作"] end subgraph Loss["损失计算"] PRED_ALL --> L_VID["视频 Flow Matching Loss
(MSE on video latent)"] PRED_ALL --> L_ACT["动作 Flow Matching Loss
(MSE on action)"] L_VID --> L_TOTAL["总损失
L = λ_vid · L_vid + λ_act · L_act"] L_ACT --> L_TOTAL end subgraph Optim["优化"] L_TOTAL --> BACK["反向传播"] BACK --> LORA["LoRA 梯度更新
rank=4, alpha=4"] LORA --> OPT["AdamW
β₁=0.95, β₂=0.999
Cosine LR + Warmup"] OPT --> DS_ZERO["DeepSpeed ZeRO-2
梯度累积"] end DataLoad --> Encode Encode --> FlowMatch FlowMatch --> Forward Forward --> Loss Loss --> Optim style DataLoad fill:#e3f2fd,stroke:#2196F3 style Encode fill:#fff3e0,stroke:#FF9800 style FlowMatch fill:#f3e5f5,stroke:#9C27B0 style Forward fill:#e8f5e9,stroke:#4CAF50 style Loss fill:#fce4ec,stroke:#E91E63 style Optim fill:#e0f7fa,stroke:#00BCD4
LoRA 配置: 基础模型(Wan2.1,约 14B 参数)完全冻结,仅在 DiT 的注意力层插入 LoRA 适配器(rank=4, alpha=4),大幅降低显存占用。
4. 推理流水线
推理时模型从纯随机噪声出发,通过 Flow Matching 迭代去噪恢复动作序列和视频帧。DreamZero 支持通过 WebSocket 的分布式多 GPU 推理服务。
[B, 1, C, H, W]"] LANG_IN["语言指令"] X_NOISE["纯随机噪声
x_T ~ N(0, I)
(视频 latent + 动作)"] end subgraph EncoderOnce["编码(仅执行一次)"] OBS --> CLIP_INF["CLIP 编码"] LANG_IN --> T5_INF["T5 编码"] CLIP_INF --> COND["条件特征缓存"] T5_INF --> COND end subgraph DenoiseLoop["迭代去噪循环 (N 步)"] direction TB S_T["步骤 T: x_T → DiT → x̂₀"] S_T1["步骤 T-1: x_{T-1} = scheduler(x̂₀) → DiT → x̂₀"] SDOTS["..."] S_1["步骤 1: x₁ → DiT → x̂₀"] S_T --> S_T1 --> SDOTS --> S_1 end subgraph PostProcess["后处理"] S_1 --> SPLIT_OUT["分离 video latent / action"] SPLIT_OUT --> VAE_DEC2["VAE 解码 → 视频帧"] SPLIT_OUT --> ACT_POST["动作后处理
(反归一化, 相对→绝对)"] ACT_POST --> ACT_EXEC["执行前 N 步动作
(Action Chunking)"] end Init --> EncoderOnce COND --> DenoiseLoop X_NOISE --> DenoiseLoop DenoiseLoop --> PostProcess style Init fill:#e3f2fd,stroke:#2196F3 style EncoderOnce fill:#fff3e0,stroke:#FF9800 style DenoiseLoop fill:#e8f5e9,stroke:#4CAF50 style PostProcess fill:#fce4ec,stroke:#E91E63
分布式推理服务
发送观测 + 指令"] end subgraph Server["推理服务器 (socket_test_optimized_AR.py)"] WS_SRV["WebSocket 服务器
(Flask-SocketIO)"] REDIS["Redis
会话状态管理"] RAY_W["Ray Worker Pool
多 GPU 并行推理"] MODEL["DreamZero 模型
(GB200 / H100)"] WS_SRV --> REDIS WS_SRV --> RAY_W RAY_W --> MODEL end WS_CLI -->|"观测数据"| WS_SRV MODEL -->|"动作预测"| WS_SRV WS_SRV -->|"动作序列"| WS_CLI WS_CLI -->|"执行动作"| CLI style Client fill:#e3f2fd,stroke:#2196F3 style Server fill:#e8f5e9,stroke:#4CAF50
5. 关键超参数表
| 参数 | 值 | 说明 |
|---|---|---|
| 模型总参数 | ~14B | 基于 Wan2.1,LoRA 微调 |
| DiT 层数 | 40 | Causal WAN DiT |
| DiT hidden dim | 5120 | 每层隐藏层维度 |
| 注意力头数 | 40 | head_dim = 128 |
| Embedding dim | 1536 | 图像/文本编码器输出维度 |
| VAE latent 通道 | 16 | 视频压缩维度 |
| Action horizon | 24 | 单次预测动作步数 |
| 视频帧数 | 33 | 输入/预测帧数 |
| 图像分辨率 | 320×176 | 训练分辨率 |
| LoRA rank / alpha | 4 / 4 | 参数高效微调配置 |
| 批归一化策略 | DeepSpeed ZeRO-2 | 分布式训练 |
| 优化器 | AdamW | β₁=0.95, β₂=0.999 |
| 学习率调度 | Cosine + Warmup | 分布式多卡训练 |
| 精度 | bfloat16 | 混合精度训练 |
6. 关键源文件表
| 组件 | 文件路径(相对 /home/zhuyilong/dreamzero/) |
|---|---|
| 核心 VLA 模型 | groot/vla/model/dreamzero/base_vla.py |
| Action Head (Flow Matching) | groot/vla/model/dreamzero/action_head/wan_flow_matching_action_tf.py |
| Causal WAN DiT | groot/vla/model/dreamzero/modules/wan_video_dit_action_casual_chunk.py |
| 视频 VAE | groot/vla/model/dreamzero/modules/wan_video_vae.py |
| 文本编码器 | groot/vla/model/dreamzero/modules/wan_video_text_encoder.py |
| 图像编码器 | groot/vla/model/dreamzero/modules/wan_video_image_encoder.py |
| 数据变换 | groot/vla/model/dreamzero/transform/ |
| 数据集加载 | groot/vla/data/dataset/lerobot.py |
| 分片数据集 | groot/vla/data/dataset/lerobot_sharded.py |
| 训练器 | groot/vla/experiment/experiment.py |
| 训练基类 | groot/vla/experiment/base.py |
| 分布式推理服务 | socket_test_optimized_AR.py |
| 推理客户端 | test_client_AR.py |
| Hydra 模型配置 | groot/vla/configs/model/ |
| 训练脚本 | scripts/train/ |