Paper Reading: Diffusion Autoencoders: Toward a Meaningful and Decodable Representation

2025/4/5 20:13:56  Source: https://blog.csdn.net/vivi_cin/article/details/147003221

Framework diagram from the original paper:

Official code: https://github.com/phizaz/diffae/blob/master/interpolate.ipynb

This post mainly records the model's inference process:

%load_ext autoreload
%autoreload 2
from templates import *
device = 'cuda:1'
conf = ffhq256_autoenc()
# print(conf.name)
model = LitModel(conf)
state = torch.load(f'checkpoints/{conf.name}/last.ckpt', map_location='cpu')
model.load_state_dict(state['state_dict'], strict=False)
model.ema_model.eval()
model.ema_model.to(device);
Output:
Global seed set to 0
Model params: 160.69 M
data = ImageDataset('imgs_interpolate', image_size=conf.img_size, exts=['jpg', 'JPG', 'png'], do_augment=False)
batch = torch.stack([data[0]['img'],data[1]['img'],
])
import matplotlib.pyplot as plt
plt.imshow(batch[0].permute([1, 2, 0]) / 2 + 0.5)

cond = model.encode(batch.to(device))
xT = model.encode_stochastic(batch.to(device), cond, T=250)

import matplotlib.pyplot as plt
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ori = (batch + 1) / 2
ax[0].imshow(ori[0].permute(1, 2, 0).cpu())
ax[1].imshow(xT[0].permute(1, 2, 0).cpu())

 

Interpolate

Semantic codes (cond) are interpolated with a convex combination, while stochastic codes (xT) are interpolated with spherical linear interpolation (slerp).
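The two interpolation rules can be written out as follows. This is read off the notebook code rather than quoted from the paper, and the symbols are notation assumed here: z^{(1)}, z^{(2)} stand for the two semantic codes (cond[0], cond[1]), x_T^{(1)}, x_T^{(2)} for the two stochastic codes (xT[0], xT[1], flattened to vectors), and α ∈ [0, 1] for the interpolation weight.

```latex
\begin{aligned}
z(\alpha) &= (1-\alpha)\, z^{(1)} + \alpha\, z^{(2)}, \\
\theta &= \arccos\!\left(
  \frac{\langle x_T^{(1)},\, x_T^{(2)} \rangle}
       {\lVert x_T^{(1)} \rVert \, \lVert x_T^{(2)} \rVert}
\right), \\
x_T(\alpha) &= \frac{\sin\big((1-\alpha)\,\theta\big)}{\sin\theta}\, x_T^{(1)}
             + \frac{\sin(\alpha\,\theta)}{\sin\theta}\, x_T^{(2)}.
\end{aligned}
```

At α = 0 both rules return the first image's codes, and at α = 1 the second image's codes.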

import numpy as np
alpha = torch.tensor(np.linspace(0, 1, 10, dtype=np.float32)).to(cond.device)
intp = cond[0][None] * (1 - alpha[:, None]) + cond[1][None] * alpha[:, None]

def cos(a, b):
    a = a.view(-1)
    b = b.view(-1)
    a = F.normalize(a, dim=0)
    b = F.normalize(b, dim=0)
    return (a * b).sum()

theta = torch.arccos(cos(xT[0], xT[1]))
x_shape = xT[0].shape
intp_x = (torch.sin((1 - alpha[:, None]) * theta) * xT[0].flatten(0, 2)[None] + torch.sin(alpha[:, None] * theta) * xT[1].flatten(0, 2)[None]) / torch.sin(theta)
intp_x = intp_x.view(-1, *x_shape)

pred = model.render(intp_x, intp, T=20)

import matplotlib.pyplot as plt
# torch.manual_seed(1)
fig, ax = plt.subplots(1, 10, figsize=(5*10, 5))
for i in range(len(alpha)):
    ax[i].imshow(pred[i].permute(1, 2, 0).cpu())
# plt.savefig('imgs_manipulated/compare.png')
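To sanity-check the slerp formula used for xT without loading the model, here is a minimal pure-Python sketch on two toy 2-D vectors. The helper names lerp, slerp and norm are made up for this illustration and are not part of the diffae repository. The near-zero-angle guard is also an addition; the notebook code has no such check.

```python
import math


def lerp(a, b, t):
    """Convex combination (1 - t) * a + t * b, element-wise."""
    return [(1 - t) * x + t * y for x, y in zip(a, b)]


def norm(v):
    """Euclidean norm of a vector given as a list of floats."""
    return math.sqrt(sum(x * x for x in v))


def slerp(a, b, t):
    """Spherical linear interpolation between vectors a and b."""
    dot = sum(x * y for x, y in zip(a, b))
    # Clamp so rounding error cannot push the cosine outside [-1, 1].
    cos_theta = max(-1.0, min(1.0, dot / (norm(a) * norm(b))))
    theta = math.acos(cos_theta)
    sin_theta = math.sin(theta)
    if abs(sin_theta) < 1e-8:
        # Degenerate case (vectors nearly parallel or antiparallel):
        # fall back to lerp. This guard is an assumption of this sketch.
        return lerp(a, b, t)
    w_a = math.sin((1 - t) * theta) / sin_theta
    w_b = math.sin(t * theta) / sin_theta
    return [w_a * x + w_b * y for x, y in zip(a, b)]


# Two orthogonal unit vectors as stand-ins for flattened stochastic codes.
a, b = [1.0, 0.0], [0.0, 1.0]
mid_lerp = lerp(a, b, 0.5)
mid_slerp = slerp(a, b, 0.5)
print("lerp midpoint norm: ", norm(mid_lerp))
print("slerp midpoint norm:", norm(mid_slerp))
```

For these toy inputs the lerp midpoint has norm about 0.707, while the slerp midpoint stays at about 1.0. A plausible reading of why the notebook uses slerp for xT is that it keeps the interpolated noise at a magnitude similar to the endpoints, whereas a plain convex combination would shrink it.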
