import os
from os import path
from uuid import uuid4

import torch
from imagen_pytorch import BaseUnet64, Imagen, ImagenTrainer, Unet
from mp_time_split.core import MPTimeSplit

from xtal2png.core import XtalConverter
9
# Run switches.
# low_mem: selects the compact Unet for the second stage instead of BaseUnet64.
# max_batch_size: forwarded to the trainer call for each training step.
low_mem = False
max_batch_size = 16

# Cross-validation fold of the MPTimeSplit data to train on.
fold = 0

mpt = MPTimeSplit()
mpt.load()
(
    train_inputs,
    val_inputs,
    train_outputs,
    val_outputs,
) = mpt.get_train_and_val_data(fold)

# Encode the training structures as arrays (only the first of the three
# returned values is used) and move them to the GPU as float32 tensors.
xc = XtalConverter(save_dir="tmp", encode_as_primitive=True, decode_as_primitive=True)
train_structures = train_inputs.tolist()
arrays, _, _ = xc.structures_to_arrays(train_structures, rgb_scaling=False)
training_images = torch.from_numpy(arrays).float().cuda()
22
# unets for unconditional imagen

# Settings of the compact Unet. The base unet always uses them; the second
# unet uses them too when low_mem is enabled.
small_unet_kwargs = dict(
    dim=32,
    dim_mults=(1, 2, 4),
    num_resnet_blocks=3,
    layer_attns=(False, True, True),
    layer_cross_attns=(False, True, True),
    use_linear_attn=True,
)

unet1 = Unet(**small_unet_kwargs)

# Second-stage unet: compact variant in low-memory mode, otherwise the
# library's BaseUnet64 with its defaults (alternative left disabled: SRUnet256()).
unet2 = Unet(**small_unet_kwargs) if low_mem else BaseUnet64()
46
# Cascade wrapper holding both unets: the base unet works at the first entry
# of image_sizes and the second unet at the last one.
imagen = Imagen(
    unets=(unet1, unet2),
    image_sizes=(32, 64),
    channels=1,
    timesteps=1000,
    condition_on_text=False,  # required for unconditional Imagen
)

# Wrap the model in its trainer and place the trainer on the GPU.
trainer = ImagenTrainer(imagen)
trainer = trainer.cuda()
58
# One optimisation step per unet, base unet first. The unets can be trained
# together like this or separately (recommended) to completion.
for unet_number in (1, 2):
    # Forward pass over the training images; max_batch_size is handed to the
    # trainer so it can limit how much it processes at once.
    loss = trainer(
        training_images,
        unet_number=unet_number,
        max_batch_size=max_batch_size,
    )
    # Apply the update for the unet that was just run.
    trainer.update(unet_number=unet_number)
64
# do the above for many many many many steps
# now you can sample images unconditionally from the cascading unet(s)

# Sixteen samples are requested as PIL images. The Imagen model in this script
# is configured with channels=1 and image_sizes=(32, 64), so the samples are
# expected to be single-channel 64x64 images (TODO confirm against the
# installed imagen_pytorch version).
images = trainer.sample(batch_size=16, return_pil_images=True)

# NOTE: despite its name, results_folder is the checkpoint *file* path:
# data/interim/imagen-pytorch/fold=<fold>/<4-char random tag>.pt
results_folder = path.join(
    "data", "interim", "imagen-pytorch", f"fold={fold}", str(uuid4())[0:4] + ".pt"
)

# Create the fold directory up front so saving does not depend on the trainer
# creating missing parent directories on a fresh checkout.
os.makedirs(path.dirname(results_folder), exist_ok=True)
trainer.save(results_folder)