文章主题:网上, stable diffusion, 代码解读, txt2img.py
666AI工具大全,助力做AI时代先行者!
前言
在网络上未能找到关于stable diffusion代码的详细解释,因此作者决定记录下自己的阅读过程和进度,以便日后参考和回顾。以下是作者在阅读过程中所做笔记的摘要。
官方代码
related contents:
stable Diffusion 代码(二)
Stable Diffusion 代码 (三)
Stable Diffusion 代码(四)
模型的初始化和载入
从prompt生成图片时的命令为
python scripts/txt2img.py –prompt “a photograph of an astronaut riding a horse” –plms
就从入口txt2img.py开始阅读。跳过传入参数的parser部分
# 设定随机seed
seed_everything(opt.seed)
config = OmegaConf.load(f“{opt.config}”)
model = load_model_from_config(config, f“{opt.ckpt}”)
其中 opt.config= configs/stable-diffusion/v1-inference.yaml,指向一个预定义好的配置文件ckpt是预先下载好的模型
在进一步探讨`txt2img.py`文件中的`load_model_from_config`函数之前,我们需要先了解该函数所依赖的`ldm.util`模块中的两种方法。尽管这两个方法在同一文件中定義,但它们在功能上各不相同,分别负责模型的加载与处理。首先,`ldm.util.load_model`方法用于从文件或网络中加载预训练模型,并将其保存在一个指定的目录下。而`ldm.util.preprocess_image`方法则专注于对输入图像进行预处理,例如缩放、裁剪等操作,以便于模型能够更好地适应特定任务的需求。在这个函数中,`load_model_from_config`调用了`ldm.util.load_model`和`ldm.util.preprocess_image`这两种方法,从而实现了对模型的加载和图像的预处理。这种设计使得函数能够在处理文本到图像的任务时,充分利用已有的预训练模型和相应的预处理方法,提高整体的效率和效果。
def instantiate_from_config(config):
return get_obj_from_str(config[“target”])(**config.get(“params”, dict()))
def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(“.”, 1)
return getattr(importlib.import_module(module, package=None), cls)
def load_model_from_config(config, ckpt):
pl_sd = torch.load(ckpt, map_location=“cpu”)
sd = pl_sd[“state_dict”]
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
model.cuda()
model.eval()
return model
实际上等效于
from ldm.models.diffusion.ddpm import LatentDiffusion
model = LatentDiffusion(**config.model.get(“params”, dict()))
model.load_state_dict(torch.load(ckpt, map_location=“cpu”)[“state_dict”], strict=False)
原code使用importlib.import_module,来读取字典中的模块名称进行灵活的import。从方便理解代码运行和算法原理的视角来看,在实际使用LatentDiffusion时,上下两种写法是完全等效的。
这里多说一句,Config字典类似于
Config = { target: path1.path2.module_1_name,
params: { para_1 : value_a,
para_2 : value_b,
module_2:{ target: path1.path2.module_2_name,
params: { para_3 : value_c,
module_3:{ target: path1.path2.module_3_name,
params : {para_4: value_d }
}}}}}
get_obj_from_str接收config字典中target对应的值来导入对应的模块,
在 instantiate_from_config 返回对应的类的实例,返回的实例是以params对应的值初始化的params对应的值是同等格式的字典。
换句话说,我们可以在config文件中,按照上述示例设定好各个嵌套模块,并在模块实例化时,通过传入的config参数进行初始化。在模块的__init__方法内,可以调用instantiate_from_config函数,从而实现各模块的嵌套式实例化。具体操作可参考第三篇文章。
# 初始化模型的全部逻辑:
from ldm.models.diffusion.ddpm import LatentDiffusion
import torch
from omegaconf import OmegaConf
# 读取config
config = OmegaConf.load(f“{opt.config}“)
# 初始化模型并传入config中的参数
model = LatentDiffusion(**config.model.get(“params”, dict()))
model.load_state_dict(torch.load(ckpt, map_location=“cpu”)[“state_dict”], strict=False)
device = torch.device(“cuda”)
model = model.to(device)
图像生成的准备和图像的生成
在成功引入模型(Model)之后,接下来便是样本生成器(Sampler)的初始化过程。这个过程主要依赖于命令行参数(–plms),并根据其执行相应的判断语句第一条。
sampler = PLMSSampler(model)
原代码立即提供了两种输入prompt的方式,一是通过命令行接口,二是从文件中读取,这两种方式并非关键。最终,prompt被成功导入到data变量中,实现了其功能。
data = [batch_size * [prompt]]
到这里,我们有了
model-[LatentDiffusion]sampler-[PLMSSampler] prompt
这样就可以开始生成图片了。
在本篇中,我们将重点讨论两个关键组件的定义,分别为PLMSSampler以及LatentDiffusion。为了更好地理解这些组件的功能,我们暂时将其视为神秘的内核,仅在必要时深入研究其源代码。
这里先简单回忆一下classifier-free guidance的方法:
ϵ(x,t)=ϵ(x,t|ϕ)+α⋅(ϵ(x,t|c)−ϵ(x,t|ϕ))\epsilon(x, t)= \epsilon(x,t ~| ~\phi) + \alpha\cdot (\epsilon(x,t~|~ c) -\epsilon(x,t~ |~ \phi))
因此除了prompt,也就是上式中c所对应的条件,还需要unconditional的ϕ\phi 。
c = model.get_learned_conditioning(prompts)
uc = model.get_learned_conditioning(batch_size * [“”])
这里可以看到model中的一个方法 get_learned_conditioning() : 输入text, 输出text的embedding 。
之后就是图像的生成了。图像的生成调用sampler实例的sample方法。这里为了直观的理解省略了几个参数,完整的参数和具体的各个参数的作用在后面sampler的代码解读部分再说。
samples_ddim, _ = sampler.sample(S=50,
conditioning=c,
batch_size=1,
shape=[4,64,64],
unconditional_guidance_scale=7.5,
unconditional_conditioning=uc,
eta=opt.ddim_eta)
x_samples_ddim = model.decode_first_stage(samples_ddim)
到这里为止,diffusion的任务已经结束了,x_samples_ddim 再经过基本的图像处理就是最终的结果。
以上就是txt2img.py文件的全部内容。这一部分绝大多数代码都是数据的读写和准备工作,核心逻辑部分比较少,还是比较好理解的。
接下来进入plms文件去看sampler的代码实现。
AI时代,拥有个人微信机器人AI助手!AI时代不落人后!
免费ChatGPT问答,办公、写作、生活好得力助手!
搜索微信号aigc666aigc999或上边扫码,即可拥有个人AI助手!