Using xtal2png with denoising_diffusion_pytorch

[1]:
%pip install xtal2png mp-time-split denoising_diffusion_pytorch
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting xtal2png
  Downloading xtal2png-0.7.0-py3-none-any.whl (26 kB)
Collecting mp-time-split
  Downloading mp_time_split-0.1.4-py3-none-any.whl (38 kB)
Collecting denoising_diffusion_pytorch
  Downloading denoising_diffusion_pytorch-0.24.1-py3-none-any.whl (18 kB)
Requirement already satisfied: plotly in /usr/local/lib/python3.7/dist-packages (from xtal2png) (5.5.0)
Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from xtal2png) (4.11.4)
Collecting pymatgen
  Downloading pymatgen-2022.0.17.tar.gz (40.6 MB)
     |████████████████████████████████| 40.6 MB 1.1 MB/s
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Installing backend dependencies ... done
    Preparing wheel metadata ... done
Requirement already satisfied: pillow in /usr/local/lib/python3.7/dist-packages (from xtal2png) (7.1.2)
Collecting kaleido
  Downloading kaleido-0.2.1-py2.py3-none-manylinux1_x86_64.whl (79.9 MB)
     |████████████████████████████████| 79.9 MB 139 kB/s
Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from xtal2png) (1.21.6)
Collecting matminer
  Downloading matminer-0.7.8-py3-none-any.whl (1.4 MB)
     |████████████████████████████████| 1.4 MB 55.2 MB/s
Collecting pybtex
  Downloading pybtex-0.24.0-py2.py3-none-any.whl (561 kB)
     |████████████████████████████████| 561 kB 73.2 MB/s
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from mp-time-split) (4.1.1)
Requirement already satisfied: scikit-learn in /usr/local/lib/python3.7/dist-packages (from mp-time-split) (1.0.2)
Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (from denoising_diffusion_pytorch) (1.11.0+cu113)
Requirement already satisfied: torchvision in /usr/local/lib/python3.7/dist-packages (from denoising_diffusion_pytorch) (0.12.0+cu113)
Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from denoising_diffusion_pytorch) (4.64.0)
Collecting einops
  Downloading einops-0.4.1-py3-none-any.whl (28 kB)
Collecting ema-pytorch
  Downloading ema_pytorch-0.0.8-py3-none-any.whl (4.0 kB)
Collecting accelerate
  Downloading accelerate-0.10.0-py3-none-any.whl (117 kB)
     |████████████████████████████████| 117 kB 66.3 MB/s
Requirement already satisfied: psutil in /usr/local/lib/python3.7/dist-packages (from accelerate->denoising_diffusion_pytorch) (5.4.8)
Requirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (from accelerate->denoising_diffusion_pytorch) (3.13)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from accelerate->denoising_diffusion_pytorch) (21.3)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->accelerate->denoising_diffusion_pytorch) (3.0.9)
Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->xtal2png) (3.8.0)
Collecting future>=0.18.2
  Downloading future-0.18.2.tar.gz (829 kB)
     |████████████████████████████████| 829 kB 67.7 MB/s
Collecting jsonschema>=4.5.1
  Downloading jsonschema-4.6.2-py3-none-any.whl (80 kB)
     |████████████████████████████████| 80 kB 10.2 MB/s
Collecting matminer
  Downloading matminer-0.7.7.tar.gz (5.2 MB)
     |████████████████████████████████| 5.2 MB 67.3 MB/s
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Installing backend dependencies ... done
    Preparing wheel metadata ... done
  Downloading matminer-0.7.6-py3-none-any.whl (1.4 MB)
     |████████████████████████████████| 1.4 MB 60.5 MB/s
Requirement already satisfied: pymongo>=4.0.1 in /usr/local/lib/python3.7/dist-packages (from matminer->mp-time-split) (4.1.1)
Collecting monty>=2022.1.12.1
  Downloading monty-2022.4.26-py3-none-any.whl (65 kB)
     |████████████████████████████████| 65 kB 5.0 MB/s
Collecting sympy>=1.9
  Downloading sympy-1.10.1-py3-none-any.whl (6.4 MB)
     |████████████████████████████████| 6.4 MB 53.9 MB/s
Collecting matminer
  Downloading matminer-0.7.5-py3-none-any.whl (1.4 MB)
     |████████████████████████████████| 1.4 MB 56.6 MB/s
  Downloading matminer-0.7.4-py3-none-any.whl (1.4 MB)
     |████████████████████████████████| 1.4 MB 70.4 MB/s
Collecting pint>=0.17
  Downloading Pint-0.18-py2.py3-none-any.whl (209 kB)
     |████████████████████████████████| 209 kB 78.1 MB/s
Collecting six>=1.16.0
  Downloading six-1.16.0-py2.py3-none-any.whl (11 kB)
Requirement already satisfied: pandas>=1.3.1 in /usr/local/lib/python3.7/dist-packages (from matminer->mp-time-split) (1.3.5)
Collecting requests>=2.26.0
  Downloading requests-2.28.1-py3-none-any.whl (62 kB)
     |████████████████████████████████| 62 kB 1.6 MB/s
Requirement already satisfied: jsonschema>=3.2.0 in /usr/local/lib/python3.7/dist-packages (from matminer->mp-time-split) (4.3.3)
Requirement already satisfied: pyrsistent!=0.17.0,!=0.17.1,!=0.17.2,>=0.14.0 in /usr/local/lib/python3.7/dist-packages (from jsonschema>=3.2.0->matminer->mp-time-split) (0.18.1)
Requirement already satisfied: attrs>=17.4.0 in /usr/local/lib/python3.7/dist-packages (from jsonschema>=3.2.0->matminer->mp-time-split) (21.4.0)
Requirement already satisfied: importlib-resources>=1.4.0 in /usr/local/lib/python3.7/dist-packages (from jsonschema>=3.2.0->matminer->mp-time-split) (5.7.1)
Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas>=1.3.1->matminer->mp-time-split) (2.8.2)
Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas>=1.3.1->matminer->mp-time-split) (2022.1)
Collecting uncertainties>=3.1.4
  Downloading uncertainties-3.1.7-py2.py3-none-any.whl (98 kB)
     |████████████████████████████████| 98 kB 10.4 MB/s
Requirement already satisfied: matplotlib>=1.5 in /usr/local/lib/python3.7/dist-packages (from pymatgen->xtal2png) (3.2.2)
Requirement already satisfied: tabulate in /usr/local/lib/python3.7/dist-packages (from pymatgen->xtal2png) (0.8.9)
Collecting spglib>=1.9.9.44
  Downloading spglib-1.16.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (325 kB)
     |████████████████████████████████| 325 kB 74.8 MB/s
Requirement already satisfied: networkx>=2.2 in /usr/local/lib/python3.7/dist-packages (from pymatgen->xtal2png) (2.6.3)
Collecting scipy>=1.5.0
  Downloading scipy-1.7.3-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (38.1 MB)
     |████████████████████████████████| 38.1 MB 1.2 MB/s
Requirement already satisfied: palettable>=3.1.1 in /usr/local/lib/python3.7/dist-packages (from pymatgen->xtal2png) (3.3.0)
Collecting ruamel.yaml>=0.15.6
  Downloading ruamel.yaml-0.17.21-py3-none-any.whl (109 kB)
     |████████████████████████████████| 109 kB 64.8 MB/s
Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=1.5->pymatgen->xtal2png) (1.4.3)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=1.5->pymatgen->xtal2png) (0.11.0)
Requirement already satisfied: tenacity>=6.2.0 in /usr/local/lib/python3.7/dist-packages (from plotly->xtal2png) (8.0.1)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests>=2.26.0->matminer->mp-time-split) (2.10)
Requirement already satisfied: charset-normalizer<3,>=2 in /usr/local/lib/python3.7/dist-packages (from requests>=2.26.0->matminer->mp-time-split) (2.1.0)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests>=2.26.0->matminer->mp-time-split) (1.24.3)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests>=2.26.0->matminer->mp-time-split) (2022.6.15)
Collecting ruamel.yaml.clib>=0.2.6
  Downloading ruamel.yaml.clib-0.2.6-cp37-cp37m-manylinux1_x86_64.whl (546 kB)
     |████████████████████████████████| 546 kB 51.1 MB/s
Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->mp-time-split) (3.1.0)
Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->mp-time-split) (1.1.0)
Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.7/dist-packages (from sympy>=1.9->matminer->mp-time-split) (1.2.1)
Collecting latexcodec>=1.0.4
  Downloading latexcodec-2.0.1-py2.py3-none-any.whl (18 kB)
Building wheels for collected packages: future, pymatgen
  Building wheel for future (setup.py) ... done
  Created wheel for future: filename=future-0.18.2-py3-none-any.whl size=491070 sha256=49ebf2a775b4a9be03eab6479131f188e02a215106e66b509f57347927359e33
  Stored in directory: /root/.cache/pip/wheels/56/b0/fe/4410d17b32f1f0c3cf54cdfb2bc04d7b4b8f4ae377e2229ba0
  Building wheel for pymatgen (PEP 517) ... done
  Created wheel for pymatgen: filename=pymatgen-2022.0.17-cp37-cp37m-linux_x86_64.whl size=41841028 sha256=90fbffdb7b5b650c4f460b36c03bf51ec1133b67ae98948129106940cf0e28c1
  Stored in directory: /root/.cache/pip/wheels/cf/f6/22/58a9be23c5f1b452770e02ff42047175eaf0f9c2f15219fc76
Successfully built future pymatgen
Installing collected packages: six, ruamel.yaml.clib, future, uncertainties, sympy, spglib, scipy, ruamel.yaml, requests, monty, pymatgen, pint, latexcodec, pybtex, matminer, kaleido, ema-pytorch, einops, accelerate, xtal2png, mp-time-split, denoising-diffusion-pytorch
  Attempting uninstall: six
    Found existing installation: six 1.15.0
    Uninstalling six-1.15.0:
      Successfully uninstalled six-1.15.0
  Attempting uninstall: future
    Found existing installation: future 0.16.0
    Uninstalling future-0.16.0:
      Successfully uninstalled future-0.16.0
  Attempting uninstall: sympy
    Found existing installation: sympy 1.7.1
    Uninstalling sympy-1.7.1:
      Successfully uninstalled sympy-1.7.1
  Attempting uninstall: scipy
    Found existing installation: scipy 1.4.1
    Uninstalling scipy-1.4.1:
      Successfully uninstalled scipy-1.4.1
  Attempting uninstall: requests
    Found existing installation: requests 2.23.0
    Uninstalling requests-2.23.0:
      Successfully uninstalled requests-2.23.0
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
google-colab 1.0.0 requires requests~=2.23.0, but you have requests 2.28.1 which is incompatible.
google-colab 1.0.0 requires six~=1.15.0, but you have six 1.16.0 which is incompatible.
datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatible.
albumentations 0.1.12 requires imgaug<0.2.7,>=0.2.5, but you have imgaug 0.2.9 which is incompatible.
Successfully installed accelerate-0.10.0 denoising-diffusion-pytorch-0.24.1 einops-0.4.1 ema-pytorch-0.0.8 future-0.18.2 kaleido-0.2.1 latexcodec-2.0.1 matminer-0.7.4 monty-2022.4.26 mp-time-split-0.1.4 pint-0.18 pybtex-0.24.0 pymatgen-2022.0.17 requests-2.28.1 ruamel.yaml-0.17.21 ruamel.yaml.clib-0.2.6 scipy-1.7.3 six-1.16.0 spglib-1.16.5 sympy-1.10.1 uncertainties-3.1.7 xtal2png-0.7.0

[2]:
from os import path
from pathlib import Path
from uuid import uuid4

from denoising_diffusion_pytorch import GaussianDiffusion, Trainer, Unet
from mp_time_split.core import MPTimeSplit

from xtal2png.core import XtalConverter

MPTimeSplit

Get the last fold (fold 4) of the Materials Project time-split benchmark.

[3]:
dummy = True #@param {type:"boolean"}

mpt = MPTimeSplit()
mpt.load(dummy=dummy)

fold = 4 # last fold
train_inputs, val_inputs, train_outputs, val_outputs = mpt.get_train_and_val_data(fold)
Reading file /usr/local/lib/python3.7/dist-packages/mp_time_split/utils/mp_dummy_time_summary.json.gz: 0it [00:00, ?it/s]
Decoding objects from /usr/local/lib/python3.7/dist-packages/mp_time_split/utils/mp_dummy_time_summary.json.gz: 100%|##########| 11/11 [00:00<00:00, 830.17it/s]
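
For intuition about what a time-based fold looks like, here is a minimal plain-Python sketch of an expanding-window split. It is not mp_time_split's implementation: the helper name `expanding_window_splits` and the toy list of years are made up for illustration. The point is only that each later fold trains on more of the chronologically earlier entries and validates on the block that follows them.

```python
def expanding_window_splits(n_samples, n_splits):
    """Yield (train_idx, val_idx) pairs over chronologically sorted samples.

    Hypothetical helper: every fold validates on one block of samples and
    trains on everything that comes before that block.
    """
    block = n_samples // (n_splits + 1)
    if block == 0:
        raise ValueError("need at least n_splits + 1 samples")
    for k in range(n_splits):
        train_end = n_samples - (n_splits - k) * block
        yield list(range(train_end)), list(range(train_end, train_end + block))


# Toy discovery years, already sorted from oldest to newest (illustrative only).
years = [2005, 2008, 2010, 2011, 2013, 2015, 2016, 2018, 2019, 2020, 2021, 2022]
splits = list(expanding_window_splits(len(years), n_splits=5))
train_idx, val_idx = splits[4]  # last fold: largest training set, newest validation block
```

Because the samples are sorted by date, no validation entry is ever older than a training entry in the same fold.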

Convert via xtal2png

[36]:
channels = 3  #@param ["1", "3"] {type:"raw", allow-input: true}
[37]:
data_path = path.join("data", "preprocessed", "mp-time-split", f"fold={fold}")
xc = XtalConverter(
    save_dir=data_path,
    encode_cell_type="primitive_standard",
    decode_cell_type="primitive_standard",
    channels=channels,
)
xc.xtal2png(train_inputs.tolist())
100%|██████████| 7/7 [00:00<00:00, 16.84it/s]
[37]:
[<PIL.Image.Image image mode=RGB size=64x64 at 0x7F8F514D3C10>,
 <PIL.Image.Image image mode=RGB size=64x64 at 0x7F8F514D3A50>,
 <PIL.Image.Image image mode=RGB size=64x64 at 0x7F8F514D3CD0>,
 <PIL.Image.Image image mode=RGB size=64x64 at 0x7F8F514D3C50>,
 <PIL.Image.Image image mode=RGB size=64x64 at 0x7F8EB437C710>,
 <PIL.Image.Image image mode=RGB size=64x64 at 0x7F8F514CB5D0>,
 <PIL.Image.Image image mode=RGB size=64x64 at 0x7F8F514CBE90>]

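The images above are 8-bit RGB, so any numeric quantity stored in them has to fit into integer pixel values. The sketch below illustrates the general idea of linearly scaling a bounded quantity into the 0-255 range and back. It is not xtal2png's actual code: the helper names and the 0-15 range are assumptions chosen for illustration. A round trip loses at most half of one quantization step.

```python
def scale_to_uint8(value, low, high):
    """Map a value in [low, high] to an integer in [0, 255] (hypothetical helper)."""
    if not low <= value <= high:
        raise ValueError(f"{value} is outside [{low}, {high}]")
    return round((value - low) / (high - low) * 255)


def unscale_from_uint8(pixel, low, high):
    """Invert scale_to_uint8, up to quantization error."""
    return low + pixel / 255 * (high - low)


low, high = 0.0, 15.0  # assumed range, e.g. a lattice length in angstroms
a = 5.43
pixel = scale_to_uint8(a, low, high)
a_decoded = unscale_from_uint8(pixel, low, high)
max_error = (high - low) / 255 / 2  # half of one quantization step
```

Choosing a tighter `[low, high]` range reduces the quantization error but makes out-of-range values unrepresentable, which is the usual trade-off with this kind of encoding.
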
Denoising Diffusion Probabilistic Model Training

[38]:
if dummy:
  train_num_steps = 100
  timesteps = 10
  train_batch_size = 2
else:
  train_num_steps = 700000
  timesteps = 1000
  train_batch_size = 32

model = Unet(dim=64, dim_mults=(1, 2, 4, 8), channels=channels).cuda()

diffusion = GaussianDiffusion(
    model, channels=channels, image_size=64, timesteps=timesteps, loss_type="l1"
).cuda()

print("train_batch_size: ", train_batch_size)

results_folder = path.join(
    "data", "interim", "denoising_diffusion_pytorch", f"fold={fold}", str(uuid4())[0:4]
)
Path(results_folder).mkdir(exist_ok=True, parents=True)

trainer = Trainer(
    diffusion,
    data_path,
    train_batch_size=train_batch_size,
    train_lr=2e-5,
    train_num_steps=train_num_steps,  # total training steps
    gradient_accumulate_every=2,  # gradient accumulation steps
    ema_decay=0.995,  # exponential moving average decay
    amp=True,  # turn on mixed precision
    augment_horizontal_flip=False,
    results_folder=results_folder,
)

trainer.train()
train_batch_size:  2
training complete
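
Two of the Trainer settings above can be made concrete with a little arithmetic. With gradient accumulation, gradients from several batches are combined before each optimizer step, so the effective batch size is presumably `train_batch_size * gradient_accumulate_every`. The `ema_decay` argument refers to an exponential moving average of the model weights. The plain-Python sketch below applies the standard EMA update to a single scalar; it illustrates the formula only, not the library's implementation, and `ema_update` is a hypothetical name.

```python
def ema_update(ema_value, new_value, decay):
    """One step of a standard exponential moving average (hypothetical helper)."""
    return decay * ema_value + (1.0 - decay) * new_value


train_batch_size = 2  # dummy setting from the cell above
gradient_accumulate_every = 2
effective_batch_size = train_batch_size * gradient_accumulate_every

decay = 0.995
ema_weight = 0.0
for step in range(100):
    # The raw weight is held at 1.0 so the smoothing is easy to see.
    ema_weight = ema_update(ema_weight, 1.0, decay)
# Starting from 0, the EMA after n such steps equals 1 - decay**n.
```

With a decay this close to 1, the averaged weight lags well behind the raw weight after only 100 steps, which is one more reason the dummy run is not representative of a full training run.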

Unconditionally Sample New Images

Note that with the dummy dataset and the shortened training parameters the model is far from converged, so the sampled images carry little meaning.

[40]:
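import numpy as np
from PIL import Image

sampled_images = diffusion.sample(batch_size=100)
# sample() is expected to return floats in [0, 1] with shape (batch, channels, height, width)
arr = sampled_images[0].cpu().numpy()
# PIL needs 8-bit integers for "RGB" and "L" images
arr = (arr * 255).round().clip(0, 255).astype(np.uint8)
mode = "RGB" if channels == 3 else "L"
if mode == "RGB":
  arr = arr.transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)
else:
  arr = arr.squeeze(0)  # (1, H, W) -> (H, W)
Image.fromarray(arr, mode=mode)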
[40]:
../_images/notebooks_3.0-denoising-diffusion_12_1.png