Imagen PyTorch Example (script)

from os import makedirs, path
from uuid import uuid4

import torch
from imagen_pytorch import BaseUnet64, Imagen, ImagenTrainer, Unet
from mp_time_split.core import MPTimeSplit

from xtal2png.core import XtalConverter

# --- run configuration --------------------------------------------------------
# low_mem: when True, the second unet is built with the same small ``Unet``
#   settings as the first; otherwise the ``BaseUnet64`` preset is used.
# max_batch_size: forwarded to the trainer call as ``max_batch_size``.
low_mem = False
max_batch_size = 16

# --- data ---------------------------------------------------------------------
# Load the time-split dataset and pick the train/validation partition of one
# fold. Only ``train_inputs`` is consumed by the rest of this script.
mpt = MPTimeSplit()
mpt.load()

fold = 0
(
    train_inputs,
    val_inputs,
    train_outputs,
    val_outputs,
) = mpt.get_train_and_val_data(fold)

# Encode the training structures as arrays; the two extra return values of
# ``structures_to_arrays`` are not needed here and are discarded.
xc = XtalConverter(
    save_dir="tmp",
    encode_as_primitive=True,
    decode_as_primitive=True,
)
arrays, _, _ = xc.structures_to_arrays(
    train_inputs.tolist(),
    rgb_scaling=False,
)

# The whole training set is converted to a float tensor and moved to the GPU.
# NOTE(review): presumably ``arrays`` is (n_structures, 1, 64, 64) — confirm
# against the xtal2png documentation.
training_images = torch.from_numpy(arrays).float().cuda()

# Unets for unconditional Imagen.

# Settings of the small, hand-configured Unet. They are shared by ``unet1``
# and, in low-memory mode, by ``unet2`` as well.
small_unet_kwargs = dict(
    dim=32,
    dim_mults=(1, 2, 4),
    num_resnet_blocks=3,
    layer_attns=(False, True, True),
    layer_cross_attns=(False, True, True),
    use_linear_attn=True,
)

unet1 = Unet(**small_unet_kwargs)

# Second unet of the cascade: another small Unet when memory is tight,
# otherwise the library's ``BaseUnet64`` preset.
# Alternative kept for reference: unet2 = SRUnet256()
unet2 = Unet(**small_unet_kwargs) if low_mem else BaseUnet64()

# Imagen wraps the unets defined above (the base unet and the
# super-resolution one). Two unets are passed together with two image sizes.
imagen = Imagen(
    unets=(unet1, unet2),
    image_sizes=(32, 64),
    channels=1,
    timesteps=1000,
    # Must be False for unconditional Imagen.
    condition_on_text=False,
)

# Build the trainer around the model, then move it to the GPU.
trainer = ImagenTrainer(imagen)
trainer = trainer.cuda()

# Unets can be trained in concert or separately to completion (the latter is
# recommended). Here each unet gets exactly one forward pass followed by one
# update; ``loss`` ends up holding the value returned for the last unet.
for unet_number in (1, 2):
    loss = trainer(
        training_images,
        unet_number=unet_number,
        max_batch_size=max_batch_size,
    )
    trainer.update(unet_number=unet_number)

# Do the above for many, many steps; afterwards images can be sampled
# unconditionally from the cascading unets.
#
# NOTE(review): with ``return_pil_images=True`` the result is expected to be
# a list of 16 PIL images rather than a tensor, at the final cascade
# resolution configured above (image_sizes=(32, 64), channels=1), i.e.
# single-channel 64x64 and not (16, 3, 128, 128) — confirm against the
# installed imagen-pytorch version.
images = trainer.sample(batch_size=16, return_pil_images=True)

# Despite its name, ``results_folder`` is the path of the checkpoint *file*:
# data/interim/imagen-pytorch/fold=<fold>/<4 characters of a random UUID>.pt
results_folder = path.join(
    "data", "interim", "imagen-pytorch", f"fold={fold}", str(uuid4())[0:4] + ".pt"
)

# Create the nested output directory up front so that saving does not depend
# on the directory already existing or on the trainer creating it.
makedirs(path.dirname(results_folder), exist_ok=True)
trainer.save(results_folder)
# The trailing no-op ``1 + 1`` statement was removed; it had no effect.