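# Denoising Diffusion PyTorch pretrained sample (script): load the newest saved
# checkpoint, sample images from the model, and decode them into crystal structures.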

import os
import re
from os import path
from pathlib import Path

import numpy as np
from denoising_diffusion_pytorch import GaussianDiffusion, Trainer, Unet
from mp_time_split.core import MPTimeSplit
from PIL import Image
from pymatgen.core.composition import Composition
from pymatviz.elements import ptable_heatmap_plotly

from xtal2png.core import XtalConverter
from xtal2png.utils.data import rgb_scaler

 16fold = 0
 17
 18model = Unet(dim=64, dim_mults=(1, 2, 4, 8), channels=1).cuda()
 19
 20diffusion = GaussianDiffusion(
 21    model, channels=1, image_size=64, timesteps=1000, loss_type="l1"
 22).cuda()
 23
 24train_batch_size = 32
 25print("train_batch_size: ", train_batch_size)
 26
 27uid = "5eab"
 28results_folder = path.join(
 29    "data", "interim", "denoising_diffusion_pytorch", f"fold={fold}", uid
 30)
 31Path(results_folder).mkdir(exist_ok=True, parents=True)
 32
 33data_path = path.join("data", "preprocessed", "mp-time-split", f"fold={fold}")
 34
 35fnames = os.listdir(results_folder)
 36
 37# i.e. "model-1.pt" --> "1.pt" --> "1" --> 1
 38checkpoints = [int(name.split("-")[1].split(".")[0]) for name in fnames]
 39checkpoint = np.max(checkpoints)
 40
 41print(f"checkpoint: {checkpoint}")
 42
 43trainer = Trainer(
 44    diffusion,
 45    data_path,
 46    image_size=64,
 47    train_batch_size=train_batch_size,
 48    train_lr=2e-5,
 49    train_num_steps=700000,  # total training steps
 50    gradient_accumulate_every=2,  # gradient accumulation steps
 51    ema_decay=0.995,  # exponential moving average decay
 52    amp=True,  # turn on mixed precision
 53    augment_horizontal_flip=False,
 54    results_folder=results_folder,
 55)
 56
 57trainer.load(checkpoint)
 58
 59diffusion = trainer.model
 60
 61img_arrays_torch = diffusion.sample(batch_size=16)
 62unscaled_arrays = np.squeeze(img_arrays_torch.cpu().numpy())
 63rgb_arrays = rgb_scaler(unscaled_arrays, data_range=(0, 1))
 64
 65mode = "L"
 66if mode == "RGB":
 67    rgb_arrays = [arr.transpose(1, 2, 0) for arr in rgb_arrays]
 68sampled_images = [Image.fromarray(np.uint8(arr), mode) for arr in rgb_arrays]
 69
 70gen_path = path.join(
 71    "data",
 72    "preprocessed",
 73    "mp-time-split",
 74    "denoising_diffusion_pytorch",
 75    f"fold={fold}",
 76    uid,
 77)
 78xc = XtalConverter(
 79    save_dir=gen_path, encode_as_primitive=True, decode_as_primitive=True
 80)
 81structures = xc.png2xtal(sampled_images, save=True)
 82
 83space_group = []
 84W = []
 85for s in structures:
 86    try:
 87        space_group.append(s.get_space_group_info(symprec=0.1)[1])
 88    except Exception as e:
 89        W.append(e)
 90        space_group.append(None)
 91print(space_group)
 92
 93mpt = MPTimeSplit()
 94mpt.load()
 95train_inputs, val_inputs, train_outputs, val_outputs = mpt.get_train_and_val_data(fold)
 96
 97equimolar_compositions = train_inputs.apply(
 98    lambda s: Composition(re.sub(r"\d", "", s.formula))
 99)
100fig = ptable_heatmap_plotly(equimolar_compositions)
101fig.show()
102
1031 + 1
104
105# %% Code Graveyard
106# compositions = train_inputs.apply(lambda s: s.composition)
107# atomic_numbers = train_inputs.apply(lambda s: np.unique(s.atomic_numbers))