
Classification on matbench_mp_is_metal using xtal2png representation of crystal structures

Description

In this notebook, a convolutional neural network is applied to the matbench_mp_is_metal classification task using `xtal2png <https://xtal2png.readthedocs.io/en/latest/>`__ representations of crystal structures. Crystal structures are encoded as grayscale PNG images. Because the conversion is restricted to structures with 52 or fewer sites, the network is trained only on structures with num_sites <= 52. For test-set structures with more than 52 sites, we simply predict the mode of the training labels, i.e. the most common class in y_train (where X_train and y_train are the training inputs and labels restricted to num_sites <= 52).
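The fallback rule above can be sketched in a few lines of NumPy, assuming the site counts, model predictions and training labels are plain arrays. `predict_with_fallback` is a hypothetical helper for illustration only; the benchmark loop later in this notebook does the same thing inline with pandas boolean indexing and `Series.mode()`.

```python
import numpy as np


def predict_with_fallback(num_sites, subset_preds, y_train, max_sites=52):
    """Combine model predictions with a majority-class fallback (hypothetical helper).

    num_sites:    site count of every test structure
    subset_preds: predictions for the structures with num_sites <= max_sites,
                  in the order those structures appear in num_sites
    y_train:      training labels the model was fit on
    """
    num_sites = np.asarray(num_sites)
    subset_preds = np.asarray(subset_preds)
    mask = num_sites <= max_sites
    if mask.sum() != subset_preds.size:
        raise ValueError("expected one prediction per structure with <= max_sites sites")
    # Most common training label; np.unique sorts values, so a tie goes to the smaller label
    values, counts = np.unique(np.asarray(y_train), return_counts=True)
    mode = values[np.argmax(counts)]
    # Start with the fallback everywhere, then overwrite the small structures
    y_full = np.full(num_sites.shape, mode, dtype=float)
    y_full[mask] = subset_preds
    return y_full


print(predict_with_fallback([4, 60, 12, 100], [1, 0], [0, 1, 1, 1, 0]))  # [1. 1. 0. 1.]
```

Structures with 60 and 100 sites receive the training mode (1), while the two small structures keep the model's predictions.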

Benchmark Name

Matbench v0.1

Package Versions

Algorithm Description

A fairly simple CNN is created in vanilla PyTorch, very loosely based on the PyTorch implementation of AlexNet. Model surgery is then performed on the max-pooling and certain convolutional layers using MosaicML’s Composer library.

[1]:
%pip install matbench skorch xtal2png pytorch-lightning mosaicml
Successfully installed PyYAML-6.0 aiohttp-3.8.1 aiosignal-1.2.0 async-timeout-4.0.2 asynctest-0.13.0 coolname-1.1.0 docstring-parser-0.14.1 frozenlist-1.3.0 fsspec-2022.5.0 future-0.18.2 kaleido-0.2.1 matbench-0.5 matminer-0.7.4 monty-2021.8.17 mosaicml-0.8.0 multidict-6.0.2 pint-0.18 psutil-5.9.1 py-cpuinfo-8.0.0 pyDeprecate-0.3.2 pymatgen-2022.0.17 pytorch-lightning-1.6.4 pytorch-ranger-0.1.1 requests-2.28.1 ruamel.yaml-0.17.21 ruamel.yaml.clib-0.2.6 scikit-learn-1.0 scipy-1.7.3 six-1.16.0 skorch-0.11.0 spglib-1.16.5 sympy-1.10.1 torch-optimizer-0.1.0 torchmetrics-0.7.3 uncertainties-3.1.7 xtal2png-0.7.0 yahp-0.1.1 yarl-1.7.2


Imports

[2]:
# %pip install skorch xtal2png matbench pytorch-lightning mosaicml

import composer.functional as cf
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.optim as optim
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from matbench.bench import MatbenchBenchmark
from skorch.callbacks import EarlyStopping
from skorch.classifier import NeuralNetBinaryClassifier
from torch import nn
from xtal2png.core import XtalConverter

# Set all random seeds as specified by Matbench
pl.seed_everything(18012019)
Global seed set to 18012019
[2]:
18012019

CNN Architecture

For the vanilla PyTorch model, the architecture of the convolutional layers is as follows:

self.convolutions = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),  # (64, 64, 1) --> (64, 64, 8)
    nn.BatchNorm2d(8),
    nn.Mish(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2),          # --> (31, 31, 8)

    nn.Conv2d(8, 16, kernel_size=3, padding=1),     # --> (31, 31, 16)
    nn.BatchNorm2d(16),
    nn.Mish(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2),          # --> (15, 15, 16)

    nn.Conv2d(16, 32, kernel_size=3, padding=1),    # --> (15, 15, 32)
    nn.BatchNorm2d(32),
    nn.Mish(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2),          # --> (7, 7, 32)
)
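The spatial sizes in these comments follow from PyTorch's output-size formula, floor((n + 2 * padding - kernel) / stride) + 1, with dilation 1 and ceil_mode=False. The short sketch below, assuming a 64 x 64 single-channel input as in the comments, traces the three conv/pool stages and recovers the 7 x 7 x 32 = 1568 features that the first fully connected layer expects. `conv2d_out`, `maxpool2d_out` and `trace_shapes` are illustrative helpers, not part of the notebook.

```python
def conv2d_out(n, kernel=3, stride=1, padding=1):
    """Output height/width of a Conv2d (dilation=1)."""
    return (n + 2 * padding - kernel) // stride + 1


def maxpool2d_out(n, kernel=3, stride=2, padding=0):
    """Output height/width of a MaxPool2d (dilation=1, ceil_mode=False)."""
    return (n + 2 * padding - kernel) // stride + 1


def trace_shapes(size=64, channels=(8, 16, 32)):
    """Return the (H, W, C) shape after each conv + pool stage."""
    shapes = []
    for c in channels:
        size = conv2d_out(size)     # 3x3 conv with padding 1 keeps H and W
        size = maxpool2d_out(size)  # 3x3 pool with stride 2 roughly halves H and W
        shapes.append((size, size, c))
    return shapes


shapes = trace_shapes()
print(shapes)      # [(31, 31, 8), (15, 15, 16), (7, 7, 32)]
h, w, c = shapes[-1]
print(h * w * c)   # 1568
```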

The full CNNClassifier class is defined below:

[ ]:
class CNNClassifier(nn.Module):
    def __init__(self, dropout: float = 0.5) -> None:
        super().__init__()
        self.convolutions = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1),
            nn.BatchNorm2d(8),
            nn.Mish(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(8, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.Mish(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.Mish(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.fullyconnected = nn.Sequential(
            nn.Linear(7 * 7 * 32, 512),
            nn.Mish(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(512, 256),
            nn.Mish(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(256, 256),
            nn.Mish(inplace=True),
            nn.Linear(256, 1),
            # No need for sigmoid here if using BCEWithLogitsLoss
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.convolutions(x)
        x = torch.flatten(x, 1)  # flatten all but batch dim
        x = self.fullyconnected(x)
        return x

To get slightly better generalization performance, BlurPool and squeeze-and-excite (SE) modifications were applied to the model using Composer. BlurPool replaces every max-pooling layer with an anti-aliased BlurMaxPool2d layer, and squeeze-excite wraps convolutional layers whose channel counts meet the min_channels threshold; in the printout below, only the Conv2d(16, 32) layer is wrapped in a SqueezeExciteConv2d. Below is the full architecture of the model:

>>> model = CNNClassifier()
>>> composer.functional.apply_squeeze_excite(model, min_channels=16)
>>> composer.functional.apply_blurpool(model)
CNNClassifier(
  (convolutions): Sequential(
    (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Mish(inplace=True)
    (3): BlurMaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Mish(inplace=True)
    (7): BlurMaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): SqueezeExciteConv2d(
      (conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (se): SqueezeExcite2d(
        (pool_and_mlp): Sequential(
          (0): AdaptiveAvgPool2d(output_size=1)
          (1): Flatten(start_dim=1, end_dim=-1)
          (2): Linear(in_features=32, out_features=64, bias=False)
          (3): ReLU()
          (4): Linear(in_features=64, out_features=32, bias=False)
          (5): Sigmoid()
        )
      )
    )
    (9): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): Mish(inplace=True)
    (11): BlurMaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fullyconnected): Sequential(
    (0): Linear(in_features=1568, out_features=512, bias=True)
    (1): Mish(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=512, out_features=256, bias=True)
    (4): Mish(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=256, out_features=256, bias=True)
    (7): Mish(inplace=True)
    (8): Linear(in_features=256, out_features=1, bias=True)
  )
)
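To make the two Composer modifications more concrete, here is a rough NumPy sketch of what each block computes. It illustrates the techniques and is not Composer's implementation. The SE part mirrors the layer list printed above (global average pool, bias-free Linear to 64 latent channels, ReLU, bias-free Linear back, Sigmoid, then channel-wise rescaling). In `blur_max_pool2d`, the normalized [1, 2, 1] binomial filter and the reflect padding are assumptions chosen to show the dense max, blur, then subsample idea behind BlurPool.

```python
import numpy as np


def squeeze_excite(x, w1, w2):
    """Squeeze-and-excite on one (C, H, W) feature map.

    w1 has shape (latent, C) and w2 has shape (C, latent), like the two
    bias-free Linear layers in the printout (latent = 64, C = 32).
    """
    squeezed = x.mean(axis=(1, 2))                # global average pool -> (C,)
    hidden = np.maximum(w1 @ squeezed, 0.0)       # Linear(C, latent) + ReLU
    gates = 1.0 / (1.0 + np.exp(-(w2 @ hidden)))  # Linear(latent, C) + Sigmoid, in (0, 1)
    return x * gates[:, None, None]               # rescale each channel by its gate


def blur_max_pool2d(x, kernel=3, stride=2):
    """Anti-aliased max pooling on one (H, W) channel: dense max, blur, subsample."""
    h, w = x.shape[0] - kernel + 1, x.shape[1] - kernel + 1
    dense = np.empty((h, w))
    for i in range(h):
        for j in range(w):
            dense[i, j] = x[i:i + kernel, j:j + kernel].max()  # max pool with stride 1
    taps = np.array([1.0, 2.0, 1.0])
    blur = np.outer(taps, taps) / 16.0            # 3x3 binomial low-pass filter, sums to 1
    padded = np.pad(dense, 1, mode="reflect")     # assumption: reflect padding keeps the size
    blurred = np.empty_like(dense)
    for i in range(h):
        for j in range(w):
            blurred[i, j] = (padded[i:i + 3, j:j + 3] * blur).sum()
    return blurred[::stride, ::stride]            # subsample only after blurring


rng = np.random.default_rng(0)
features = rng.standard_normal((32, 15, 15))
out = squeeze_excite(features, rng.standard_normal((64, 32)), rng.standard_normal((32, 64)))
print(out.shape)                                             # (32, 15, 15)
print(blur_max_pool2d(rng.standard_normal((64, 64))).shape)  # (31, 31)
```

For a 64 x 64 input this sketch yields the same 31 x 31 size as the original MaxPool2d(kernel_size=3, stride=2), which is consistent with the fully connected head still receiving 1568 features after the BlurPool swap.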

Benchmark on Matbench Folds

Training is done using skorch, which abstracts away the typical training loop and the need for DataLoaders. Images are preprocessed in the following manner:

- Convert each PIL.Image to a torch.Tensor, scaling all pixel values from [0, 255] to [0.0, 1.0].
- Compute the mean and standard deviation of the scaled pixel values, then normalize to zero mean and unit variance.

For normalization, note that the pixel mean and standard deviation are computed separately for each fold, using only that fold's training set. The same training-set statistics are then used to normalize the corresponding test set.
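A minimal NumPy sketch of this per-fold normalization, assuming the xtal2png images are available as uint8 arrays. `fit_normalizer` and `normalize` are hypothetical helpers; the notebook itself uses `TF.to_tensor`, `TF.normalize` and `T.Normalize`. The point it illustrates is that the statistics come from the training images only and are then reused unchanged for that fold's test images.

```python
import numpy as np


def fit_normalizer(train_images):
    """Scale uint8 training images to [0, 1] and return the global pixel mean and std."""
    scaled = np.stack(train_images).astype(np.float32) / 255.0
    # ddof=1 mirrors the unbiased estimate that torch.Tensor.std() uses by default
    return scaled.mean(), scaled.std(ddof=1)


def normalize(images, mean, std):
    """Apply the same [0, 1] scaling, then shift and scale with the given statistics."""
    scaled = np.stack(images).astype(np.float32) / 255.0
    return (scaled - mean) / std


rng = np.random.default_rng(18012019)
train_imgs = [rng.integers(0, 256, size=(64, 64)).astype(np.uint8) for _ in range(8)]
test_imgs = [rng.integers(0, 256, size=(64, 64)).astype(np.uint8) for _ in range(4)]

mean, std = fit_normalizer(train_imgs)         # statistics from the training fold only
train_norm = normalize(train_imgs, mean, std)  # approximately zero mean, unit variance
test_norm = normalize(test_imgs, mean, std)    # test fold reuses the training statistics
```

Computing the statistics on the test images as well would leak test-set information into preprocessing, which is why only the training split is used.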

Matbench will keep track of the final test results for the full test set, but because we’re only able to train and predict on structures with 52 sites or fewer, it would be good to keep track of how well we do on just the subset of the data with num_sites <= 52. To that end, a simple helper function is defined below.

[ ]:
from sklearn.metrics import (
    accuracy_score,
    balanced_accuracy_score,
    f1_score,
    roc_auc_score,
)


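def scoring(y_true, y_pred, fold):
    """Return a pd.Series of classification metrics for one fold."""
    acc = accuracy_score(y_true, y_pred)
    bal_acc = balanced_accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    # y_pred holds hard 0/1 labels, so this ROC AUC equals balanced accuracy;
    # pass positive-class probabilities instead for a threshold-free AUC
    roc_auc = roc_auc_score(y_true, y_pred)
    scores = {
        "Accuracy": acc,
        "Balanced Accuracy": bal_acc,
        "F1 Score": f1,
        "ROC AUC": roc_auc,
    }
    return pd.Series(scores, name=fold)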


# track test_scores for subset of test_data with num_sites <= 52
subset_test_scores = []
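One caveat about `scoring`: `net.predict` returns hard 0/1 labels, and with hard labels the ROC curve has a single interior point, so its area reduces to (TPR + TNR) / 2, which is exactly balanced accuracy. The NumPy check below illustrates this on a toy example by computing AUC in its rank (Mann-Whitney) form; `balanced_accuracy` and `roc_auc_rank` are illustrative helpers, not sklearn functions. Passing positive-class probabilities (for example `net.predict_proba(X_test)[:, 1]` in skorch) to `roc_auc_score` would give a threshold-free AUC instead.

```python
import numpy as np


def balanced_accuracy(y_true, y_pred):
    """Mean of the recall on the positive class (TPR) and on the negative class (TNR)."""
    y_true, y_pred = np.asarray(y_true, bool), np.asarray(y_pred, bool)
    tpr = np.mean(y_pred[y_true])    # fraction of true positives predicted positive
    tnr = np.mean(~y_pred[~y_true])  # fraction of true negatives predicted negative
    return (tpr + tnr) / 2


def roc_auc_rank(y_true, scores):
    """ROC AUC in rank form: P(score_pos > score_neg) + 0.5 * P(score_pos == score_neg)."""
    y_true = np.asarray(y_true, bool)
    scores = np.asarray(scores, float)
    pos, neg = scores[y_true], scores[~y_true]
    greater = (pos[:, None] > neg[None, :]).mean()  # over all positive/negative pairs
    ties = (pos[:, None] == neg[None, :]).mean()
    return greater + 0.5 * ties


y_true = [1, 1, 1, 0, 0, 0, 0]
y_hard = [1, 0, 1, 0, 1, 0, 0]  # hard labels, like the output of net.predict
print(balanced_accuracy(y_true, y_hard), roc_auc_rank(y_true, y_hard))  # both about 0.7083 (17/24)
```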

The Matbench benchmark is run below. To save every PNG image that xtal2png generates during the structure-to-image conversion, set the save flag to True.

[ ]:
xc = XtalConverter()
mb = MatbenchBenchmark(autoload=False, subset=["matbench_mp_is_metal"])
save = False  # set to True to write each generated PNG via xc.xtal2png(..., save=save)

for task in mb.tasks:
    task.load()
    for fold in task.folds:
        # Get training data
        train_inputs, train_outputs = task.get_train_and_val_data(fold)

        # Train on structures with num_sites <= 52
        site_counter = lambda x: x.num_sites
        idx = train_inputs.apply(site_counter) <= 52
        X_train = train_inputs[idx]
        y_train = train_outputs[idx]

        # Convert crystal structures to images
        X_train = xc.xtal2png(X_train, save=save)

        # Convert PIL Images to torch.Tensor
        # Note that this scales from [0, 255] to [0.0, 1.0]
        X_train = [TF.to_tensor(img) for img in X_train]
        # Normalize images (subtract mean, divide by std)
        mean = torch.cat(X_train).mean()
        std = torch.cat(X_train).std()
        X_train = [TF.normalize(i, mean=mean, std=std) for i in X_train]

        # Change X from a list of tensors to a single tensor, and y from bool to float
        X_train = torch.stack(X_train)
        y_train = y_train.astype(np.float32)

        # Apply Composer methods to vanilla PyTorch classifier before training
        model = CNNClassifier()
        cf.apply_squeeze_excite(model, min_channels=16)
        cf.apply_blurpool(model)

        # Train and validate classifier with skorch
        net = NeuralNetBinaryClassifier(
            model,
            criterion=nn.BCEWithLogitsLoss,
            max_epochs=50,
            optimizer=optim.AdamW,
            optimizer__amsgrad=True,
            optimizer__lr=0.0005,
            callbacks=[EarlyStopping(patience=15)],
            device="cuda" if torch.cuda.is_available() else "cpu",
            batch_size=64,
        )
        net.fit(X_train, y_train)

        # Get test data and keep structures with num_sites <= 52
        test_inputs, test_outputs = task.get_test_data(fold, include_target=True)
        idx = test_inputs.apply(site_counter) <= 52
        X_test = test_inputs[idx]

        # Convert to images, preprocess using mean and std from training data
        X_test = xc.xtal2png(X_test, save=save)
        preprocess = T.Compose([T.ToTensor(), T.Normalize(mean, std),])
        X_test = torch.stack([preprocess(img) for img in X_test])

        # Predict on X_test
        # For structures with num_sites > 52, predict mode of y_train
        y_pred = net.predict(X_test)
        y_pred_full = np.empty(test_inputs.size)
        y_pred_full[idx] = y_pred
        y_pred_full[~idx] = y_train.mode().item()

        # Record data
        task.record(fold, y_pred_full)
        # Also record test scores on subset of data with num_sites <= 52
        subset_test_scores.append(scoring(test_outputs[idx], y_pred, f"fold-{fold}"))

# Save benchmark results
mb.to_file("new-results.json.gz")
2022-07-08 16:49:57 INFO     Initialized benchmark 'matbench_v0.1' with 1 tasks:
['matbench_mp_is_metal']
2022-07-08 16:49:57 INFO     Loading dataset 'matbench_mp_is_metal'...
2022-07-08 16:51:29 INFO     Dataset 'matbench_mp_is_metal loaded.
100%|██████████| 71602/71602 [06:18<00:00, 189.39it/s]
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:268: UserWarning: lower RGB value(s) OOB (-65 less than 0). thresholding to 0.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
  warn(
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:274: UserWarning: upper RGB value(s) OOB (1317 greater than 255). thresholding to 255.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
  warn(
  epoch    train_loss    valid_acc    valid_loss      dur
-------  ------------  -----------  ------------  -------
      1        0.6073       0.6224        0.6698  14.7695
      2        0.5356       0.6208        0.7271  10.9460
      3        0.5118       0.6481        0.7283  10.8090
      4        0.5024       0.6467        0.7483  10.8945
      5        0.4809       0.6736        0.6756  10.7665
      6        0.4701       0.6716        0.6849  10.7760
      7        0.4650       0.6680        0.6896  10.8740
      8        0.4538       0.6705        0.6834  11.3155
      9        0.4480       0.6702        0.6814  10.9335
     10        0.4457       0.6669        0.6971  10.7680
     11        0.4374       0.6736        0.6725  10.9445
     12        0.4343       0.6710        0.6902  10.9260
     13        0.4292       0.6684        0.7142  10.8770
     14        0.4228       0.6750        0.6595  10.9695
     15        0.4233       0.6766        0.6654  12.5390
     16        0.4158       0.6770        0.6682  11.0055
     17        0.4113       0.6726        0.6832  11.0400
     18        0.4078       0.6740        0.6664  11.0470
     19        0.4066       0.6740        0.6757  11.0325
     20        0.4020       0.6709        0.6870  11.0223
     21        0.4000       0.6742        0.6787  10.8605
     22        0.3971       0.6716        0.6871  11.6585
     23        0.3959       0.6674        0.7022  11.3410
     24        0.3919       0.6824        0.6650  11.3810
     25        0.3893       0.6716        0.6919  11.0685
     26        0.3876       0.6728        0.6850  10.9170
     27        0.3837       0.6667        0.7144  11.5390
     28        0.3821       0.6794        0.6567  11.0540
     29        0.3792       0.6782        0.6833  10.9590
     30        0.3780       0.6736        0.7035  10.8730
     31        0.3755       0.6835        0.6717  10.9410
     32        0.3717       0.6709        0.7351  10.8490
     33        0.3728       0.6807        0.6963  10.8965
     34        0.3677       0.6865        0.6806  10.9120
     35        0.3664       0.6921        0.6754  10.9500
     36        0.3647       0.6887        0.6957  10.9740
     37        0.3606       0.6745        0.7035  10.8780
     38        0.3598       0.6930        0.6834  10.8200
     39        0.3583       0.7027        0.6818  10.9130
     40        0.3554       0.6891        0.6914  10.9930
     41        0.3526       0.7003        0.6865  10.9970
     42        0.3527       0.6967        0.7010  11.0350
Stopping since valid_loss has not improved in the last 15 epochs.
100%|██████████| 17925/17925 [01:38<00:00, 181.45it/s]
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:268: UserWarning: lower RGB value(s) OOB (-91 less than 0). thresholding to 0.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
  warn(
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:274: UserWarning: upper RGB value(s) OOB (995 greater than 255). thresholding to 255.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
  warn(
2022-07-08 17:08:07 INFO     Recorded fold matbench_mp_is_metal-0 successfully.
100%|██████████| 71644/71644 [06:18<00:00, 189.14it/s]
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:268: UserWarning: lower RGB value(s) OOB (-91 less than 0). thresholding to 0.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
  warn(
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:274: UserWarning: upper RGB value(s) OOB (1317 greater than 255). thresholding to 255.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
  warn(
  epoch    train_loss    valid_acc    valid_loss      dur
-------  ------------  -----------  ------------  -------
      1        0.5929       0.6283        0.6789  11.1935
      2        0.5243       0.6296        0.6866  11.0007
      3        0.5031       0.6284        0.7424  10.8800
      4        0.4796       0.6359        0.7158  10.8855
      5        0.4669       0.6471        0.7155  10.8195
      6        0.4543       0.6410        0.7543  10.9535
      7        0.4490       0.6537        0.7031  10.8030
      8        0.4395       0.6616        0.7039  10.7705
      9        0.4372       0.6529        0.7108  10.9330
     10        0.4305       0.6606        0.7028  10.9995
     11        0.4266       0.6587        0.7050  10.9490
     12        0.4226       0.6655        0.7067  10.8910
     13        0.4202       0.6531        0.7094  10.8475
     14        0.4165       0.6544        0.7416  10.9615
     15        0.4120       0.6597        0.7153  10.9285
Stopping since valid_loss has not improved in the last 15 epochs.
100%|██████████| 17883/17883 [01:48<00:00, 164.18it/s]
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:268: UserWarning: lower RGB value(s) OOB (-27 less than 0). thresholding to 0.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
  warn(
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:274: UserWarning: upper RGB value(s) OOB (956 greater than 255). thresholding to 255.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
  warn(
2022-07-08 17:19:58 INFO     Recorded fold matbench_mp_is_metal-1 successfully.
100%|██████████| 71604/71604 [06:31<00:00, 182.87it/s]
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:268: UserWarning: lower RGB value(s) OOB (-91 less than 0). thresholding to 0.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
  warn(
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:274: UserWarning: upper RGB value(s) OOB (1317 greater than 255). thresholding to 255.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
  warn(
  epoch    train_loss    valid_acc    valid_loss      dur
-------  ------------  -----------  ------------  -------
      1        0.5957       0.6205        0.6803  11.3735
      2        0.5308       0.6339        0.7092  11.3055
      3        0.5120       0.6208        0.7561  11.3290
      4        0.4911       0.6561        0.7014  11.4160
      5        0.4756       0.6729        0.6585  11.1370
      6        0.4651       0.6716        0.6809  11.1925
      7        0.4564       0.6707        0.7054  11.2350
      8        0.4511       0.6703        0.7149  11.3085
      9        0.4455       0.6683        0.7382  11.2265
     10        0.4437       0.6769        0.7111  11.2260
     11        0.4384       0.6573        0.7956  11.1905
     12        0.4339       0.6650        0.7453  11.2540
     13        0.4292       0.6642        0.7540  11.3160
     14        0.4252       0.6659        0.7473  11.3150
     15        0.4237       0.6657        0.7517  11.1695
     16        0.4182       0.6588        0.8085  11.3225
     17        0.4166       0.6729        0.7352  11.0960
     18        0.4108       0.6627        0.7942  11.6485
     19        0.4060       0.6704        0.7474  11.7510
Stopping since valid_loss has not improved in the last 15 epochs.
100%|██████████| 17923/17923 [01:42<00:00, 174.96it/s]
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:268: UserWarning: lower RGB value(s) OOB (-19 less than 0). thresholding to 0.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
  warn(
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:274: UserWarning: upper RGB value(s) OOB (1248 greater than 255). thresholding to 255.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
  warn(
2022-07-08 17:32:45 INFO     Recorded fold matbench_mp_is_metal-2 successfully.
100%|██████████| 71599/71599 [06:27<00:00, 184.65it/s]
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:268: UserWarning: lower RGB value(s) OOB (-91 less than 0). thresholding to 0.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
  warn(
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:274: UserWarning: upper RGB value(s) OOB (1248 greater than 255). thresholding to 255.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
  warn(
  epoch    train_loss    valid_acc    valid_loss      dur
-------  ------------  -----------  ------------  -------
      1        0.6078       0.6161        0.6687  11.1820
      2        0.5410       0.5843        0.7135  10.9360
      3        0.5274       0.6334        0.7075  10.8940
      4        0.5167       0.6174        0.7411  10.7925
      5        0.4938       0.6011        0.8099  10.8460
      6        0.4825       0.6122        0.7850  11.2060
      7        0.4710       0.6226        0.7418  10.8695
      8        0.4610       0.6394        0.7188  10.9045
      9        0.4559       0.6555        0.7164  10.8600
     10        0.4507       0.6457        0.7033  10.8250
     11        0.4424       0.6524        0.6752  10.7745
     12        0.4394       0.6547        0.6737  10.7125
     13        0.4307       0.6628        0.6637  10.8445
     14        0.4262       0.6777        0.6504  10.9635
     15        0.4207       0.6691        0.6339  10.9045
     16        0.4172       0.6732        0.6542  10.8685
     17        0.4156       0.6772        0.6430  10.8270
     18        0.4141       0.6867        0.6256  10.9465
     19        0.4075       0.6753        0.6330  10.8752
     20        0.4037       0.6712        0.6490  10.8860
     21        0.4029       0.6823        0.6322  10.9735
     22        0.4007       0.6848        0.6341  10.8885
     23        0.3975       0.6858        0.6181  10.8625
     24        0.3941       0.6945        0.6044  10.8530
     25        0.3926       0.6827        0.6352  11.1605
     26        0.3895       0.6971        0.5993  10.9810
     27        0.3855       0.6943        0.6160  11.4595
     28        0.3861       0.6887        0.6159  11.6460
     29        0.3838       0.7043        0.6016  11.9855
     30        0.3804       0.6940        0.6222  11.6535
     31        0.3779       0.7064        0.5930  10.8920
     32        0.3751       0.7138        0.5745  10.8160
     33        0.3726       0.7103        0.6030  10.8610
     34        0.3705       0.7078        0.5839  10.8755
     35        0.3684       0.7139        0.5945  10.8450
     36        0.3691       0.7154        0.5903  10.8120
     37        0.3669       0.7081        0.6114  10.8455
     38        0.3660       0.7164        0.5878  10.7310
     39        0.3655       0.7106        0.6092  10.8840
     40        0.3604       0.7124        0.6066  10.9505
     41        0.3601       0.7151        0.6166  10.8055
     42        0.3575       0.7146        0.5996  10.9835
     43        0.3546       0.7146        0.6096  10.9930
     44        0.3523       0.7220        0.5837  10.9275
     45        0.3505       0.7296        0.5714  10.9500
     46        0.3486       0.7165        0.6097  10.8555
     47        0.3481       0.7159        0.5908  10.8650
     48        0.3457       0.7199        0.6040  10.8525
     49        0.3429       0.7325        0.5719  10.9535
     50        0.3429       0.7290        0.5789  10.7510
100%|██████████| 17928/17928 [01:40<00:00, 178.62it/s]
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:268: UserWarning: lower RGB value(s) OOB (-65 less than 0). thresholding to 0.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
  warn(
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:274: UserWarning: upper RGB value(s) OOB (1317 greater than 255). thresholding to 255.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
  warn(
2022-07-08 17:50:48 INFO     Recorded fold matbench_mp_is_metal-3 successfully.
100%|██████████| 71659/71659 [06:20<00:00, 188.16it/s]
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:268: UserWarning: lower RGB value(s) OOB (-91 less than 0). thresholding to 0.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
  warn(
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:274: UserWarning: upper RGB value(s) OOB (1317 greater than 255). thresholding to 255.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
  warn(
  epoch    train_loss    valid_acc    valid_loss      dur
-------  ------------  -----------  ------------  -------
      1        0.5997       0.6197        0.6638  10.9605
      2        0.5519       0.6425        0.6501  10.7185
      3        0.5317       0.6442        0.6350  10.7315
      4        0.5290       0.6300        0.6813  10.7890
      5        0.5041       0.6557        0.6740  10.5690
      6        0.5022       0.6289        0.7428  10.4915
      7        0.4817       0.6442        0.7406  10.6470
      8        0.4694       0.6246        0.7653  10.7805
      9        0.4636       0.6217        0.7734  10.6635
     10        0.4534       0.6572        0.7254  10.4560
     11        0.4469       0.6469        0.7296  10.4670
     12        0.4401       0.6500        0.7351  10.5830
     13        0.4351       0.6544        0.7014  10.7160
     14        0.4309       0.6629        0.6672  10.7205
     15        0.4276       0.6634        0.6505  10.7445
     16        0.4262       0.6624        0.6455  10.4870
     17        0.4175       0.6636        0.6666  10.6385
Stopping since valid_loss has not improved in the last 15 epochs.
100%|██████████| 17868/17868 [01:41<00:00, 175.96it/s]
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:268: UserWarning: lower RGB value(s) OOB (-15 less than 0). thresholding to 0.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
  warn(
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:274: UserWarning: upper RGB value(s) OOB (1096 greater than 255). thresholding to 255.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
  warn(
2022-07-08 18:02:48 INFO     Recorded fold matbench_mp_is_metal-4 successfully.
2022-07-08 18:02:49 INFO     Successfully wrote MatbenchBenchmark to file 'new-results.json.gz'.
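The repeated UserWarning lines in the training output come from xtal2png thresholding values that fall outside the 0 to 255 range of an 8-bit grayscale image. As the warning text states, values below 0 are set to 0 and values above 255 are set to 255. A clipped value cannot be decoded back exactly, which is why the warning cautions that crystal structure parameters (e.g. lattice parameters) may be thrown off. The sketch below mimics that thresholding step with NumPy to make the effect concrete. `threshold_to_uint8` is a hypothetical helper, not part of xtal2png. The input values are illustrative, although the extremes -91 and 995 are taken from the log above.

```python
import warnings

import numpy as np


def threshold_to_uint8(values):
    """Hypothetical helper: clip values into [0, 255] and cast to uint8.

    Mimics the thresholding described in the xtal2png warnings; this is
    not xtal2png's own code. Returns the clipped array and how many
    entries were out of bounds.
    """
    arr = np.asarray(values)
    if arr.min() < 0:
        warnings.warn(f"lower value(s) OOB ({arr.min()} less than 0), thresholding to 0")
    if arr.max() > 255:
        warnings.warn(f"upper value(s) OOB ({arr.max()} greater than 255), thresholding to 255")
    # Count every entry that will lose information when clipped.
    n_clipped = int(((arr < 0) | (arr > 255)).sum())
    return np.clip(arr, 0, 255).astype(np.uint8), n_clipped


# Illustrative scaled values using the extremes seen in the log above.
pixels, n_clipped = threshold_to_uint8([-91, 0, 128, 255, 995])
print(pixels.tolist())  # [0, 0, 128, 255, 255]
print(n_clipped)        # 2
```

Both -91 and 995 collapse onto the edges of the pixel range, so the original values can no longer be recovered from the image.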
[ ]:
import pprint

# Make sure our benchmark is valid
valid = mb.is_valid
print(f"is valid: {valid}")

# Check out how our algorithm is doing using scores
pprint.pprint(mb.scores)

# Get some more info about the benchmark
mb.get_info()
is valid: True
{'matbench_mp_is_metal': {'accuracy': {'max': 0.8046838186787296,
                                       'mean': 0.78905512981546,
                                       'min': 0.7796729962776233,
                                       'std': 0.011047855221953238},
                          'balanced_accuracy': {'max': 0.7889602638667743,
                                                'mean': 0.7670721073159733,
                                                'min': 0.7535760749130695,
                                                'std': 0.015144418534814375},
                          'f1': {'max': 0.7484677468292978,
                                 'mean': 0.7104270714737543,
                                 'min': 0.6843673801137903,
                                 'std': 0.026518595116783353},
                          'rocauc': {'max': 0.7889602638667743,
                                     'mean': 0.7670721073159734,
                                     'min': 0.7535760749130695,
                                     'std': 0.015144418534814357}}}
2022-07-08 18:02:49 INFO
Matbench package 0.5 running benchmark 'matbench_v0.1'
        is complete: False
        is recorded: True
        is valid: True

Results:
        - 'matbench_mp_is_metal' ROCAUC mean: 0.7670721073159734
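The nested `mb.scores` output is easier to compare across metrics as a table. The sketch below copies the values printed above into a plain dict so that it runs on its own. In the notebook you could pass `mb.scores["matbench_mp_is_metal"]` directly, assuming it behaves like a nested dict, which its pprint output suggests.

```python
import pandas as pd

# Values copied from the pprint(mb.scores) output above.
scores = {
    "accuracy": {"max": 0.8046838186787296, "mean": 0.78905512981546,
                 "min": 0.7796729962776233, "std": 0.011047855221953238},
    "balanced_accuracy": {"max": 0.7889602638667743, "mean": 0.7670721073159733,
                          "min": 0.7535760749130695, "std": 0.015144418534814375},
    "f1": {"max": 0.7484677468292978, "mean": 0.7104270714737543,
           "min": 0.6843673801137903, "std": 0.026518595116783353},
    "rocauc": {"max": 0.7889602638667743, "mean": 0.7670721073159734,
               "min": 0.7535760749130695, "std": 0.015144418534814357},
}

# Outer keys (metrics) become columns; inner keys (statistics) become rows.
summary = pd.DataFrame(scores)
print(summary.round(4))
```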

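Finally, let’s display the score summary for only the test-set structures with 52 or fewer sites, i.e. the structures the CNN actually made predictions for (larger structures were assigned the mode of the training labels).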
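The `subset_test_scores` list is built in an earlier cell that is not part of this excerpt. The sketch below shows one way a single per-fold entry could be computed. It assumes 0/1 labels (1 = metal), a per-structure site count, and the 52-site cutoff. `subset_scores` and `MAX_SITES` are hypothetical names, and the notebook's own code may compute these metrics with scikit-learn instead. ROC AUC is omitted to keep the sketch short.

```python
import pandas as pd

MAX_SITES = 52  # site-count cutoff used for the xtal2png images in this notebook


def subset_scores(num_sites, y_true, y_pred, fold_name):
    """Hypothetical helper: score one fold on test structures with <= MAX_SITES sites.

    Labels are assumed to be 0/1 (or bool). Returns a Series named after the
    fold, so pd.concat(list_of_series, axis=1).T gives one row per fold.
    """
    kept = [(int(t), int(p)) for n, t, p in zip(num_sites, y_true, y_pred) if n <= MAX_SITES]
    tp = sum(t == 1 and p == 1 for t, p in kept)
    tn = sum(t == 0 and p == 0 for t, p in kept)
    fp = sum(t == 0 and p == 1 for t, p in kept)
    fn = sum(t == 1 and p == 0 for t, p in kept)
    tpr = tp / (tp + fn)  # recall on metals (class 1)
    tnr = tn / (tn + fp)  # recall on non-metals (class 0)
    return pd.Series(
        {
            "Accuracy": (tp + tn) / len(kept),
            "Balanced Accuracy": (tpr + tnr) / 2,
            "F1 Score": 2 * tp / (2 * tp + fp + fn),
        },
        name=fold_name,
    )


# Toy fold: the 60- and 100-site structures are excluded from the subset scores.
s = subset_scores(
    num_sites=[4, 8, 60, 12, 20, 100],
    y_true=[1, 0, 1, 1, 0, 0],
    y_pred=[1, 0, 0, 0, 0, 1],
    fold_name="fold-0",
)
table = pd.concat([s], axis=1).T  # one row per fold, one column per metric
print(table)
```

Naming each Series after its fold is what makes the `pd.concat(..., axis=1).T` call produce the `fold-0` to `fold-4` row labels. The sketch divides by zero if a fold's subset lacks either class.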

[ ]:
from IPython.display import display

df = pd.concat(subset_test_scores, axis=1).T
display(df)
df.describe()
        Accuracy  Balanced Accuracy  F1 Score   ROC AUC
fold-0  0.793752           0.790358  0.761561  0.790358
fold-1  0.771627           0.766685  0.724315  0.766685
fold-2  0.773029           0.766690  0.715604  0.766690
fold-3  0.801818           0.798735  0.776358  0.798735
fold-4  0.769756           0.764638  0.710648  0.764638

        Accuracy  Balanced Accuracy  F1 Score   ROC AUC
count   5.000000           5.000000  5.000000  5.000000
mean    0.781996           0.777421  0.737697  0.777421
std     0.014738           0.015933  0.029423  0.015933
min     0.769756           0.764638  0.710648  0.764638
25%     0.771627           0.766685  0.715604  0.766685
50%     0.773029           0.766690  0.724315  0.766690
75%     0.793752           0.790358  0.761561  0.790358
max     0.801818           0.798735  0.776358  0.798735
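In both the matbench scores and the subset table, the ROC AUC column matches Balanced Accuracy to within floating-point rounding. That is what happens when ROC AUC is computed from hard 0/1 class predictions instead of predicted probabilities. With a single threshold, the ROC curve has one interior point (1 - TNR, TPR), and the area under the two straight segments through it is (TPR + TNR) / 2, which is exactly balanced accuracy. The sketch below checks this numerically with pure-Python versions of both metrics. `balanced_accuracy` and `roc_auc` are hypothetical helpers, not matbench's own scoring code.

```python
from itertools import product


def balanced_accuracy(y_true, y_pred):
    """Mean of the true-positive rate and the true-negative rate."""
    pos = [p for t, p in zip(y_true, y_pred) if t == 1]
    neg = [p for t, p in zip(y_true, y_pred) if t == 0]
    tpr = sum(p == 1 for p in pos) / len(pos)
    tnr = sum(p == 0 for p in neg) / len(neg)
    return (tpr + tnr) / 2


def roc_auc(y_true, scores):
    """ROC AUC as the probability that a random positive outscores a random
    negative, with ties counted as one half (Mann-Whitney formulation)."""
    pos = [s for t, s in zip(y_true, scores) if t == 1]
    neg = [s for t, s in zip(y_true, scores) if t == 0]
    total = 0.0
    for sp, sn in product(pos, neg):
        if sp > sn:
            total += 1.0
        elif sp == sn:
            total += 0.5
    return total / (len(pos) * len(neg))


y_true = [1, 1, 1, 0, 0, 0, 0, 1]
y_pred = [1, 0, 0, 0, 1, 0, 0, 1]  # hard 0/1 labels, not probabilities
print(balanced_accuracy(y_true, y_pred))  # 0.625
print(roc_auc(y_true, y_pred))            # 0.625
```

Passing predicted probabilities as `scores` would generally give a different ROC AUC, because the curve then has many thresholds instead of one.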