1import os
2import re
3from os import path
4from pathlib import Path
5
6import numpy as np
7from denoising_diffusion_pytorch import GaussianDiffusion, Trainer, Unet
8from mp_time_split.core import MPTimeSplit
9from PIL import Image
10from pymatgen.core.composition import Composition
11from pymatviz.elements import ptable_heatmap_plotly
12
13from xtal2png.core import XtalConverter
14from xtal2png.utils.data import rgb_scaler
15
# Index of the mp-time-split fold whose data and checkpoints this script uses.
fold = 0

# Single-channel U-Net denoiser, moved onto the GPU.
model = Unet(dim=64, dim_mults=(1, 2, 4, 8), channels=1)
model = model.cuda()

# Diffusion wrapper around the denoiser: 64x64 single-channel images,
# 1000 timesteps, L1 loss. Also moved onto the GPU.
diffusion = GaussianDiffusion(
    model,
    channels=1,
    image_size=64,
    timesteps=1000,
    loss_type="l1",
)
diffusion = diffusion.cuda()

# Batch size later handed to the Trainer; echoed for the run log.
train_batch_size = 32
print("train_batch_size: ", train_batch_size)
26
# Identifier of the run whose checkpoints are stored under `results_folder`.
uid = "5eab"

# Folder holding this run's artifacts for the selected fold; created
# (including any missing parents) if it does not exist yet.
results_parts = (
    "data",
    "interim",
    "denoising_diffusion_pytorch",
    f"fold={fold}",
    uid,
)
results_folder = path.join(*results_parts)
os.makedirs(results_folder, exist_ok=True)

# Folder passed to the Trainer as its data location for the selected fold.
data_parts = ("data", "preprocessed", "mp-time-split", f"fold={fold}")
data_path = path.join(*data_parts)
34
# Find the most recent checkpoint in `results_folder`.
#
# Checkpoint files are named "model-<N>.pt", e.g. "model-1.pt" --> 1.
# Only names that match this pattern exactly are considered, so unrelated
# files in the folder (anything without a "-", or with a non-numeric stem)
# no longer crash the parse or get mistaken for a checkpoint.
checkpoint_pattern = re.compile(r"model-(\d+)\.pt")

fnames = os.listdir(results_folder)

checkpoints = [
    int(match.group(1))
    for match in map(checkpoint_pattern.fullmatch, fnames)
    if match is not None
]

# Fail with an actionable message instead of an opaque reduction error
# when the folder holds no checkpoint (e.g. it was only just created).
if not checkpoints:
    raise FileNotFoundError(
        f"No 'model-<N>.pt' checkpoint files found in {results_folder!r}"
    )

# Highest checkpoint number wins; builtin max keeps this a plain int.
checkpoint = max(checkpoints)

print(f"checkpoint: {checkpoint}")
42
# Keyword settings for the Trainer, gathered in one place.
trainer_kwargs = {
    "image_size": 64,
    "train_batch_size": train_batch_size,
    "train_lr": 2e-5,
    "train_num_steps": 700000,  # total training steps
    "gradient_accumulate_every": 2,  # gradient accumulation steps
    "ema_decay": 0.995,  # exponential moving average decay
    "amp": True,  # turn on mixed precision
    "augment_horizontal_flip": False,
    "results_folder": results_folder,
}

# In the visible script the Trainer is only used to restore a saved
# checkpoint; no training method is called on it here.
trainer = Trainer(diffusion, data_path, **trainer_kwargs)

# Restore the state saved under the selected checkpoint number.
trainer.load(checkpoint)

# Continue with the diffusion model held by the trainer after loading.
diffusion = trainer.model
60
# PIL mode for the generated images: "L" (grayscale) or "RGB".
mode = "L"

# Draw a batch of 16 samples from the loaded diffusion model.
img_arrays_torch = diffusion.sample(batch_size=16)

# Move to host memory and drop singleton axes
# (presumably the channel axis, since channels=1 -- TODO confirm).
host_arrays = img_arrays_torch.cpu().numpy()
unscaled_arrays = np.squeeze(host_arrays)

# Rescale from the (0, 1) data range via xtal2png's rgb_scaler.
rgb_arrays = rgb_scaler(unscaled_arrays, data_range=(0, 1))

# For RGB output, move the leading axis of each array to the end.
if mode == "RGB":
    rgb_arrays = [arr.transpose(1, 2, 0) for arr in rgb_arrays]

# Convert each array to an 8-bit PIL image.
sampled_images = []
for arr in rgb_arrays:
    pixels = np.uint8(arr)
    sampled_images.append(Image.fromarray(pixels, mode))
69
# Folder passed to the converter as its save directory for this fold/run.
gen_parts = (
    "data",
    "preprocessed",
    "mp-time-split",
    "denoising_diffusion_pytorch",
    f"fold={fold}",
    uid,
)
gen_path = path.join(*gen_parts)

# Converter configured to use primitive cells for both encoding and decoding.
xc = XtalConverter(
    save_dir=gen_path,
    encode_as_primitive=True,
    decode_as_primitive=True,
)

# Decode the sampled images back into structures, with saving enabled.
structures = xc.png2xtal(sampled_images, save=True)
82
# One entry per decoded structure: element [1] of get_space_group_info's
# result (the space-group number per the pymatgen docs), or None if the
# call failed.
space_group = []
# Exceptions raised during the space-group lookup, kept for later inspection.
W = []
for structure in structures:
    try:
        sg_entry = structure.get_space_group_info(symprec=0.1)[1]
    except Exception as err:
        # Best effort: remember the failure and keep the lists aligned.
        W.append(err)
        sg_entry = None
    space_group.append(sg_entry)
print(space_group)
92
# Load the mp-time-split data and take the train/validation split of `fold`.
mpt = MPTimeSplit()
mpt.load()
fold_splits = mpt.get_train_and_val_data(fold)
train_inputs, val_inputs, train_outputs, val_outputs = fold_splits

# Matches a single digit; used to strip the amounts from a formula string.
_digit = re.compile(r"\d")


def _to_equimolar(structure):
    """Return the structure's composition rebuilt from its digit-free formula."""
    return Composition(_digit.sub("", structure.formula))


# Digit-free compositions of the training inputs, shown as a periodic-table
# heatmap.
equimolar_compositions = train_inputs.apply(_to_equimolar)
fig = ptable_heatmap_plotly(equimolar_compositions)
fig.show()

# No-op expression (presumably a convenient line for a debugger breakpoint).
1 + 1
104
105# %% Code Graveyard
106# compositions = train_inputs.apply(lambda s: s.composition)
107# atomic_numbers = train_inputs.apply(lambda s: np.unique(s.atomic_numbers))