🎨FLUX 图像修复双雄源码解析
把 12B 参数的预训练扩散模型,改造成通用图像修复器的两种正交路线
起因:刷短视频刷到"FLUX 和 Wan 预训练扩散模型秒变图像修复神器",想验真伪。
结论:FLUX 这条路线实打实,Wan 是视频生成模型,跟图像修复无关——标题党。
本文:深啃两个 FLUX 修复代表作的源码,带文件路径 + 行号证据。
§0概述 & 标题党验证
两个仓库都 git clone 下来全量扫过,grep -r "[Ww]an2\.?[12]" 零命中。Wan 跟这两个工作毫无关系。真正能把 FLUX 拽成修复器的,是下面这两条完全不同的技术路线:
- LucidFlux(HKUST-GZ,W2GenAI-Lab)——架构创新路线:冻结 FLUX,加双分支 conditioner + SigLIP caption-free 语义对齐
- FLUX-IR(HKUST,Zhu Zhiyu)——训练算法创新路线:ControlNet 改造 + 强化 ODE 轨迹对齐 + 轨迹蒸馏
§1LucidFlux 深度解析
1. 整体架构
核心策略:冻结 FLUX.1 主干,只训练双分支 conditioner + SigLIP 语义对齐模块。这是最"温和"的改造——对原生 FLUX 能力破坏极小。
- src/flux/lucidflux.py:312-343 加载时对 FLUX 主干
requires_grad_(False),只有 condition_branch 进入训练 - src/flux/model.py:138-215 FLUX 前向接收
block_controlnet_hidden_states参数,166-194 行把 conditioner 输出直接加到每个 DoubleStreamBlock - 原 FLUX 的 double/single stream block 保持原样(src/flux/modules/layers.py),未魔改。hidden_size=3072, num_heads=24, depth=19 都保留
双分支 conditioner
src/flux/lucidflux.py:189-244
condition_lq: 处理低质量(LQ)原图condition_pre: 处理 SwinIR 预超分结果- 两支各含 2 个 DoubleStreamBlock +
zero_module初始化的 ControlNet block(78-80 行) - 通过 Modulation 模块自适应融合(232-243 行)
Modulation(本文深挖 1)
src/flux/lucidflux.py:161-186
- timestep + control_index 联合编码
- LayerNorm 后输出 shift/scale 系数做特征调制
- control_index 在 0-19 范围遍历,支持分层自适应
SigLIP 语义对齐
src/flux/flux_prior_redux_ir.py:44-104
- SigLIP 视觉编码器吃 SwinIR 前置输出 → Redux 编码器转 1024 token FLUX 兼容嵌入
- 训练时(train.py:476-487)和推理时(inference.py:160-195)与文本 embedding 拼接
2. 训练入口
- train.py:300-548 主流程,
train_configs/train_LucidFlux.yaml配置 - 超参:批大小 1, 最大 100k 步, lr=2e-5, Adafactor 优化器
- train.py:324-351 模型初始化 — FLUX 主干 + VAE 冻结,加载预训 conditioner 权重,只 conditioner 进
.train() - 损失:L2 MSE 预测噪声残差,target =
noise - packed_latents(train.py:95-100, 502) - 数据:image_datasets/lq_gt_dataset.py:34-68 LQ/GT 配对,中心裁剪 512×512
训练 pipeline 步骤:
- VAE 编码 LQ/GT 到 latent
- SwinIR 8× 前置超分
- SigLIP + Redux 提取语义特征
- DualConditionBranch 输出残差
- FLUX 接收残差 → 预测噪声
- MSE loss 回传,只更新 conditioner
3. 推理入口
标准推理 · inference.py
- inference.py:93
load_flow_model加载 FLUX - inference.py:94 VAE 来自 FLUX 官方
- inference.py:112 输入分辨率 16 倍对齐
2K 推理 · inference-2k.py
from src.ultraflux.autoencoder_kl import AutoencoderKL
ae = AutoencoderKL.from_pretrained("./weights/ultraflux", subfolder="vae")
- 分辨率对齐 32 倍(inference-2k.py:148)
- 输入预处理
width//2, height//2(inference-2k.py:180)
推理 Pipeline
inference.py:139-224
- LQ → 标准化到 [-1, 1]
- SwinIR 超分(151-158 行)
- SigLIP + Redux 特征(173-190 行)
- 生成噪声 + embedding(162-198 行)
denoise_lucidflux()逐步去噪(sampling.py:96-151):每步 DualConditionBranch 输出残差加到 FLUX- VAE 解码 + 小波重构对齐
4. 与原 FLUX 的差异
新增文件
- src/flux/condition.py:33-223 SingleConditionBranch
- src/flux/lucidflux.py DualConditionBranch + Modulation + ConditionBranchWithRedux
- src/flux/flux_prior_redux_ir.py SigLIP + Redux 集成
- src/flux/swinir.py SwinIR 前置
原 FLUX 魔改范围
block_controlnet_hidden_states 参数支持。DoubleStreamBlock、SingleStreamBlock 内部结构完全不动。这是 LucidFlux 最大的优势:未来 FLUX 官方升级,迁移成本极低。
5. 可复现性
| 项目 | 值 |
|---|---|
| 主模型 | black-forest-labs/FLUX.1-dev (HF) |
| SigLIP | siglip2-so400m-patch16-512 |
| Conditioner | weights/lucidflux/lucidflux.pth |
| 2K VAE | weights/ultraflux/vae/(额外) |
| 训练数据 | HF W2GenAI/LucidFlux → LucidFlux-Training-Data.tar.gz |
| 数据过滤 | tools/filtering_pipeline.py |
| 依赖 | diffusers==0.32.2, transformers==4.43.3, liger_kernel==0.6.1, deepspeed==0.18.8 |
| 显存 · 标准 | ~28GB(offload) |
| 显存 · 2K | ~38GB |
| 显存 · ComfyUI | 8-12GB(量化) |
§2FLUX-IR 深度解析
1. 核心方法论 · ODE 轨迹学习
论文宣称的两个核心创新:Reinforced ODE Alignment 和 Distillation Cost-Aware ODE Acceleration。这才是 FLUX-IR 的灵魂——不是架构创新,是训练目标创新。
Reinforced ODE Alignment
Unified_restoration/src/flux/sampling.py:241-356 的 denoise_controlnet_rein():
- 双轨迹 reward:
X_ode_t纯 ODE 求解轨迹X_sde_t在sde_step控制的步骤注入随机的 SDE 轨迹
- 用 SDE 作为 reward 信号对齐 ODE 轨迹
Distillation Cost-Aware ODE Acceleration
sampling.py:690-789 的 denoise_controlnet_distill():
- 把 10 步蒸馏到 5-8 步
- 损失设计:
loss_t2O(蒸馏 vs 原轨迹)+loss_t2G(vs GT)+ 基础像素损失 - 成本感知:
t_stride ∈ {10, 8, 6}加权不同阶段(model.py:99-193)
两阶段训练
Task_specific_restoration/model/model.py:203-227
optimize_parameters_reinforce()→ 第一阶段:训练 ODE/SDE 对齐optimize_parameters_distill()→ 第二阶段:蒸馏加速
2. Unified 模型结构 · ControlNet 改造 FLUX-dev
ControlNet 集成
src/flux/controlnet.py:33-223
ControlNetFlux继承 FLUX 架构,加 2 层 DoubleStreamBlock- controlnet.py:27-30
zero_module()零初始化保证初期无扰(ControlNet 标准做法) - controlnet.py:174-179 条件图像压缩到 latent + 位置编码融合
- model.py:199-200 残差加法:
img = img + block_controlnet_hidden_states[index_block % 2]
双 VAE encoder(关键创新)
xflux_pipeline.py:43-44
self.ae2 = load_ae(model_type, device=...)
self.ae2.encoder.load_state_dict(torch.load('checkpoints/encoder_lq.bin'))
ae吃 GTae2吃 LQ,独立训练的 encoder- 目的:学习品质差异。原生 FLUX VAE 对 LQ 图像编码效果差,专门训一个 LQ encoder 能把退化信息更好地映射到 latent 空间
RAM 标签(可选增强)
main.py:30-41 / main_.py:92-104
- Recognize Anything Model 自动打图像标签
prompts = f"{ram_prompt}, {args.prompt}"(main_.py版本)main.py里被注释,作为可选增强
3. 任务条件化
无显式任务 embedding —— 全靠 prompt 文本 + control_weight 隐式区分:
| 任务 | Prompt | control_weight |
|---|---|---|
| 超分 | high-resolution, ultra-sharp, detailed | 0.8 |
| 去噪 | noise-free, clean, smooth | 0.8 |
| 低光 | bright, clear, vivid | 0.9 |
| 去雨 | remove raindrops, clean | 0.9 |
控制权重缩放
sampling.py:312
block_controlnet_hidden_states = [i * controlnet_gs for i in ...]
controlnet_gs 直接映射命令行 --control_weight。任务越"重"(去雨、低光),weight 越大。
4. 训练与推理架构
两个 main 脚本
main.py中心裁剪 1024×1024 单图推理(main.py:220-221)main_.py分块推理 + 加权融合,支持任意尺寸(main_.py:23-77)
推理流程
xflux_pipeline.py:239-500
- 图像 → VAE 编码 → 4×64×64 latent
denoise_controlnet()ODE 求解(21 步)- latent → VAE 解码 → 输出图像
- 支持
offload=True降显存
损失函数
Task_specific_restoration/model/model.py:99-193
- 蒸馏阶段:
MSE(distilled_traj, GT_traj) + MSE(output, target) - 强化阶段:
MSE(X_ode, X_sde) + MSE(final_output, GT) - 像素损失占主导,轨迹损失辅助收敛
5. Task-specific vs Unified 对比
Task-specific
- 每任务独立模型(LOLv1/Raindrop/Underwater)
- 纯 DDPM U-Net,未改 FLUX
core/只有指标计算代码- 轻量但要训 N 个模型
Unified
- 单 FLUX-dev + ControlNet 适配
- 基于 X-FLUX 魔改(README.md:102 致谢)
- 参数共享,一模型搞 10+ 任务
- 新增:
xflux_pipeline.py管道 ·controlnet.py新类 ·model.py:165-200ControlNet 残差注入
6. 可复现性
| 项目 | 值 |
|---|---|
| 权重 | checkpoints/FluxIR.bin + checkpoints/encoder_lq.bin |
| 下载 | Google Drive |
| 显存 @1024² (bf16) | FLUX-dev ~21GB + VAE 解码 ~8GB |
| 建议 GPU | H100 / A100 / L40S(40GB+) |
| 依赖 | PyTorch 2.4+, CUDA 12+, diffusers / transformers / accelerate / deepspeed |
| 推理步数 | 21(蒸馏后) |
§3两篇正面对比
| 维度 | LucidFlux | FLUX-IR |
|---|---|---|
| 核心创新 | 双分支 conditioner + SigLIP caption-free | RL-ODE 对齐 + 轨迹蒸馏 |
| 训练目标 | 只训 conditioner | ControlNet + LQ VAE encoder |
| 任务区分 | 不区分(通用) | Prompt + control_weight |
| 预处理 | SwinIR 前置超分 | RAM 标签(可选) |
| 基础代码 | 黑森林 FLUX.1 官方 | X-FLUX 魔改 |
| 推理步数 | 标准 | 21 步(蒸馏后) |
| 场次 | ICLR 2026 | TPAMI 2025 |
| 对原 FLUX 侵入 | 几乎零 | 中等(双 VAE + ControlNet) |
§4深挖 1 · LucidFlux Modulation 怎么实现分层自适应
这段代码是 LucidFlux 能用"只有 2 个 DoubleStreamBlock 输出的 conditioner"喂"FLUX 19 层主干"的魔法所在。
完整实现
src/flux/lucidflux.py:161-186
class Modulation(nn.Module):
def __init__(self, dim, bias=True):
self.silu = nn.SiLU()
self.linear = nn.Linear(dim, 2 * dim, bias=bias)
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=dim)
self.control_index_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=dim)
# ↑ 关键:多加了一个 control_index 嵌入器
def forward(self, x, timestep, control_index):
timesteps_proj = self.time_proj(timestep * 1000)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=x.dtype))
# control_index 也走正弦位置编码 + MLP
control_index_proj = self.time_proj(control_index)
control_index_emb = self.control_index_embedder(control_index_proj.to(dtype=x.dtype))
# 两个 embedding 相加,一起产生调制参数
timesteps_emb = timesteps_emb + control_index_emb
emb = self.linear(self.silu(timesteps_emb))
shift_msa, scale_msa = emb.chunk(2, dim=1)
return self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
关键洞察
control_index 嵌入——这个 index 就是要注入的目标 block 编号(0-18)。两个 embedding 相加后过 linear 产生 shift/scale,实现:「在第
t 步、针对第 k 个 block,调 conditioner 信号的尺度和偏移」
调用处:拉伸 2 → 19
src/flux/lucidflux.py:232-243
out = []
num_blocks = 19
for i in range(num_blocks // 2 + 1): # i=0..9
for control_index, (lq, pre) in enumerate(zip(out_lq, out_pre)):
control_index = torch.tensor(control_index, device=timesteps.device, dtype=timesteps.dtype)
lq = self.modulation_lq(lq, timesteps, i * 2 + control_index)
if len(out) == num_blocks:
break
pre = self.modulation_pre(pre, timesteps, i * 2 + control_index)
out.append(lq + pre) # 两支相加,按层位置注入 FLUX
return out
解决方案:同一个 conditioner 输出 + 19 个不同的
(timestep, control_index) 组合 → 19 份不同调制结果,相当于用 control_index 把 2 个原始信号"拉伸"成 19 个位置感知的信号。收益:参数量极省——如果直接让 conditioner 输出 19 个残差,训练参数至少翻 9 倍。
§5深挖 2 · FLUX-IR Reinforce ODE 用 SDE 当 reward
数据流概览
Unified_restoration/src/flux/sampling.py:241-355
x_0 = img纯噪声x_1 = controlnet_cond[1:2]训练时是 GT 图像 latent,推理用 LQimg = (1-t)*x_1 + t*x_0线性插值初始化 —— rectified flow 的标准套路X_ode_t/X_sde_t两个列表记录两条轨迹
采样循环
sampling.py:291-353
for t_curr, t_prev in zip(timesteps[t_start:], timesteps[t_start+1:]):
block_res_samples = controlnet(...) # ControlNet 残差
pred = model(...,
block_controlnet_hidden_states=[i * controlnet_gs for i in block_res_samples])
# CFG 分支被硬关(if ... and False),节省显存
if i == sde_step and sample_sde:
img = sde_sampling(t_curr, t_curr - t_prev, img, pred, seed) # 注入 SDE 噪声
else:
img = img + (t_prev - t_curr) * pred # 普通 ODE 步
X_sde_t.append(img)
i += 1
SDE 注入细节
sampling.py:471-496 的 sde_sampling():
def sde_sampling(t_curr, deltaT, x_curr, pred_noise_residual, seed):
eplson = get_noise(1, 1024, 1024, device=x_curr.device, dtype=torch.float32, seed=seed)
eplson = rearrange(eplson, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
alpha = torch.tensor([1 + np.random.random() * 1]) # α ∈ [1, 2]
# 方差补偿公式(保证 SDE 边际分布与 ODE 一致)
beta_ = ((t_curr - deltaT)**2 * (1 - (t_curr - alpha * deltaT))**2 /
(1 - (t_curr - deltaT))**2) - (t_curr - alpha * deltaT)**2
beta = torch.sqrt(beta_)
while beta_ < 0: # 数值稳定:重采样 alpha
alpha = torch.tensor([np.random.randint(2, 10)])
beta_ = ((t_curr - deltaT)**2 * (1 - (t_curr - alpha * deltaT))**2 /
(1 - (t_curr - deltaT))**2) - (t_curr - alpha * deltaT)**2
beta = torch.sqrt(beta_)
if beta_ > 0:
break
# ... 后续:alpha 控制漂移项,beta 控制噪声项
方法论精髓
- ODE 确定性采样:精度高但可能卡在局部最优
- SDE 随机注入:能跳出局部最优探索更好轨迹
- Loss =
MSE(X_sde_t, X_ode_t)+ 其他项 - 用 SDE 的随机探索当自监督 teacher,蒸馏回确定性 ODE
if i >= timestep_to_start_cfg and False: —— CFG 分支被硬关了,只剩条件路径。这是为了减显存(去雨/去噪不太需要负 prompt)。代码里留着原始 CFG 实现以防需要重启。
§6两个创新点的本质
| LucidFlux Modulation | FLUX-IR Reinforce | |
|---|---|---|
| 解决的问题 | 2 个 conditioner 输出怎么喂 19 个 FLUX block | ODE 确定性采样怎么学到 SDE 的鲁棒性 |
| 代价 | 增加 2 个 Modulation 模块(几 MB 参数) | 训练时 2 份前向,推理不变 |
| 可迁移性 | 可插入任何 DiT-style 模型 | 可用于任何 rectified flow 模型 |
| 能否组合 | ✓ 理论正交,可拼 | |
grep 一下源码。
grep 验证