Classification on matbench_mp_is_metal using xtal2png representation of crystal structures
Description
In this notebook, a convolutional neural network is applied to the matbench_mp_is_metal classification task using xtal2png (https://xtal2png.readthedocs.io/en/latest/) representations of crystal structures. Crystal structures are encoded as grayscale PNG images. The conversion only supports structures with at most 52 sites, so the network is trained only on structures with num_sites <= 52. For test structures with more than 52 sites, we simply predict the mode of the training labels: the most common class in y_train, where X_train and y_train are the training inputs and labels restricted to num_sites <= 52.
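The fallback rule can be sketched with numpy as follows. This is a minimal illustration, assuming binary 0/1 labels. The function name `predict_with_fallback` and its arguments are hypothetical, and the toy data is made up. Note that `np.bincount(...).argmax()` breaks ties in favor of class 0.

```python
import numpy as np


def predict_with_fallback(num_sites, subset_pred, y_train_subset, max_sites=52):
    """Assemble full-test-set predictions from a model that only saw small structures.

    num_sites      -- site count of every test structure
    subset_pred    -- model predictions for test structures with num_sites <= max_sites,
                      in their original order
    y_train_subset -- binary training labels for structures with num_sites <= max_sites
    """
    num_sites = np.asarray(num_sites)
    mask = num_sites <= max_sites
    # Most common class among the filtered training labels (ties go to class 0)
    mode = np.bincount(np.asarray(y_train_subset, dtype=int)).argmax()
    # Start with the fallback everywhere, then overwrite the entries the model covers
    y_pred_full = np.full(num_sites.shape, mode, dtype=int)
    y_pred_full[mask] = subset_pred
    return y_pred_full


# Toy example: two of the five test structures are too large for the CNN
num_sites = [10, 60, 52, 8, 100]
subset_pred = [0, 0, 0]         # model predictions for the three structures with <= 52 sites
y_train_subset = [1, 1, 0, 1]   # class 1 is the mode
print(predict_with_fallback(num_sites, subset_pred, y_train_subset))  # [0 1 0 0 1]
```

Starting from `np.full` rather than `np.empty` guarantees that no entry is left uninitialized if the mask and the prediction count ever disagree.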
Benchmark Name
Matbench v0.1
Package Versions
Algorithm Description
A fairly simple CNN is created in vanilla PyTorch, very loosely based on the PyTorch implementation of AlexNet. Model surgery is then performed on the max-pooling and certain convolutional layers using MosaicML’s Composer library.
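The max-pooling surgery swaps each MaxPool2d for an anti-aliased BlurMaxPool2d. As we understand it, the idea is to low-pass filter the pooled feature map so that small shifts of the input change the output less. The numpy sketch below illustrates that idea on a single 2D map: ordinary 3x3, stride-2 max pooling followed by a 3x3 binomial blur with stride 1, so the output keeps the max-pooled shape. The function names, the binomial filter, and the edge padding are our own assumptions for illustration. This is not Composer's implementation.

```python
import numpy as np


def max_pool2d(x, kernel_size=3, stride=2):
    """Plain max pooling over a 2D array with no padding (like nn.MaxPool2d(3, stride=2))."""
    h_out = (x.shape[0] - kernel_size) // stride + 1
    w_out = (x.shape[1] - kernel_size) // stride + 1
    out = np.empty((h_out, w_out), dtype=x.dtype)
    for i in range(h_out):
        for j in range(w_out):
            r, c = i * stride, j * stride
            out[i, j] = x[r:r + kernel_size, c:c + kernel_size].max()
    return out


def binomial_blur(x):
    """3x3 binomial low-pass filter, stride 1, edge-replicated padding (keeps x's shape)."""
    k1d = np.array([1.0, 2.0, 1.0])
    kernel = np.outer(k1d, k1d) / 16.0  # weights sum to 1
    padded = np.pad(x, 1, mode="edge")
    out = np.zeros_like(x, dtype=float)
    for di in range(3):
        for dj in range(3):
            out += kernel[di, dj] * padded[di:di + x.shape[0], dj:dj + x.shape[1]]
    return out


def blur_max_pool2d(x, kernel_size=3, stride=2):
    """Anti-aliased max pooling: max-pool first, then smooth the pooled map."""
    return binomial_blur(max_pool2d(x, kernel_size, stride))


img = np.random.default_rng(0).random((64, 64))
print(blur_max_pool2d(img).shape)  # (31, 31), same as plain max pooling
```

Because the blur has stride 1 and its weights sum to 1, each output is a weighted average of neighboring max-pooled values. The spatial size is unchanged, and no output exceeds the largest pooled value.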
[1]:
%pip install matbench skorch xtal2png pytorch-lightning mosaicml
Successfully installed PyYAML-6.0 aiohttp-3.8.1 aiosignal-1.2.0 async-timeout-4.0.2 asynctest-0.13.0 coolname-1.1.0 docstring-parser-0.14.1 frozenlist-1.3.0 fsspec-2022.5.0 future-0.18.2 kaleido-0.2.1 matbench-0.5 matminer-0.7.4 monty-2021.8.17 mosaicml-0.8.0 multidict-6.0.2 pint-0.18 psutil-5.9.1 py-cpuinfo-8.0.0 pyDeprecate-0.3.2 pymatgen-2022.0.17 pytorch-lightning-1.6.4 pytorch-ranger-0.1.1 requests-2.28.1 ruamel.yaml-0.17.21 ruamel.yaml.clib-0.2.6 scikit-learn-1.0 scipy-1.7.3 six-1.16.0 skorch-0.11.0 spglib-1.16.5 sympy-1.10.1 torch-optimizer-0.1.0 torchmetrics-0.7.3 uncertainties-3.1.7 xtal2png-0.7.0 yahp-0.1.1 yarl-1.7.2
Imports
[2]:
# %pip install skorch xtal2png matbench pytorch-lightning mosaicml
import composer.functional as cf
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.optim as optim
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from matbench.bench import MatbenchBenchmark
from skorch.callbacks import EarlyStopping
from skorch.classifier import NeuralNetBinaryClassifier
from torch import nn
from xtal2png.core import XtalConverter
# Set all random seeds as specified by Matbench
pl.seed_everything(18012019)
Global seed set to 18012019
[2]:
18012019
CNN Architecture
For the vanilla PyTorch model, the convolutional layers are defined as follows. The comments give each layer's output shape as (height, width, channels) for a 64x64 single-channel input image. PyTorch itself stores these tensors channels-first, as (channels, height, width).
self.convolutions = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),  # (64, 64, 1) --> (64, 64, 8)
    nn.BatchNorm2d(8),
    nn.Mish(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2),  # --> (31, 31, 8)
    nn.Conv2d(8, 16, kernel_size=3, padding=1),  # --> (31, 31, 16)
    nn.BatchNorm2d(16),
    nn.Mish(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2),  # --> (15, 15, 16)
    nn.Conv2d(16, 32, kernel_size=3, padding=1),  # --> (15, 15, 32)
    nn.BatchNorm2d(32),
    nn.Mish(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2),  # --> (7, 7, 32)
)
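The shape comments follow PyTorch's output-size formula for Conv2d and MaxPool2d with dilation 1: floor((n + 2p - k) / s) + 1. The pure-Python check below walks through the chain and recovers the 7 * 7 * 32 = 1568 input size of the first fully connected layer. The helper name `conv2d_out` is our own.

```python
def conv2d_out(size, kernel_size, stride=1, padding=0):
    """Spatial output size of nn.Conv2d / nn.MaxPool2d (dilation=1, ceil_mode=False)."""
    return (size + 2 * padding - kernel_size) // stride + 1


size, channels = 64, 1
for out_channels in (8, 16, 32):
    size = conv2d_out(size, kernel_size=3, padding=1)  # 3x3 conv with padding 1 keeps the size
    channels = out_channels
    size = conv2d_out(size, kernel_size=3, stride=2)   # MaxPool2d(kernel_size=3, stride=2)
    print(f"after block: ({size}, {size}, {channels})")
# after block: (31, 31, 8)
# after block: (15, 15, 16)
# after block: (7, 7, 32)

print("flattened features:", size * size * channels)  # flattened features: 1568
```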
The full CNNClassifier class is defined below:
[ ]:
class CNNClassifier(nn.Module):
    def __init__(self, dropout: float = 0.5) -> None:
        super().__init__()
        self.convolutions = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1),
            nn.BatchNorm2d(8),
            nn.Mish(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(8, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.Mish(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.Mish(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.fullyconnected = nn.Sequential(
            nn.Linear(7 * 7 * 32, 512),
            nn.Mish(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(512, 256),
            nn.Mish(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(256, 256),
            nn.Mish(inplace=True),
            nn.Linear(256, 1),
            # No need for sigmoid here if using BCEWithLogitsLoss
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.convolutions(x)
        x = torch.flatten(x, 1)  # flatten all but batch dim
        x = self.fullyconnected(x)
        return x
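A quick count of the trainable parameters in the vanilla CNNClassifier, before any Composer surgery, shows that almost all of them sit in the fully connected head. This is worked arithmetic, not output from the notebook. It counts a weight and a bias for every Conv2d and Linear layer, and a scale and a shift for every BatchNorm2d (running statistics are buffers, not parameters). Mish, Dropout, and MaxPool2d have no parameters, so the total should match `sum(p.numel() for p in model.parameters())`. The helper names are our own.

```python
def conv_params(c_in, c_out, k=3):
    return c_out * c_in * k * k + c_out  # weights + bias


def bn_params(c):
    return 2 * c  # learnable scale (weight) and shift (bias)


def linear_params(n_in, n_out):
    return n_in * n_out + n_out  # weights + bias


conv_part = (conv_params(1, 8) + bn_params(8)
             + conv_params(8, 16) + bn_params(16)
             + conv_params(16, 32) + bn_params(32))
fc_part = (linear_params(7 * 7 * 32, 512) + linear_params(512, 256)
           + linear_params(256, 256) + linear_params(256, 1))
total = conv_part + fc_part
print(conv_part, fc_part, total)                      # 6000 1000705 1006705
print(f"fully connected share: {fc_part / total:.1%}")  # fully connected share: 99.4%
```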
To get slightly better generalization performance, BlurPool and squeeze-and-excite operations were applied to the model using Composer. BlurPool replaces every max-pooling layer with an anti-aliased BlurMaxPool2d layer. Squeeze-and-excite wraps each convolutional layer that meets the min_channels threshold (16 here) in a SqueezeExciteConv2d module; in this model, only the 16 -> 32 convolution qualifies. Below is the full architecture of the resulting model:
>>> model = CNNClassifier()
>>> composer.functional.apply_squeeze_excite(model, min_channels=16)
>>> composer.functional.apply_blurpool(model)
CNNClassifier(
  (convolutions): Sequential(
    (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Mish(inplace=True)
    (3): BlurMaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Mish(inplace=True)
    (7): BlurMaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): SqueezeExciteConv2d(
      (conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (se): SqueezeExcite2d(
        (pool_and_mlp): Sequential(
          (0): AdaptiveAvgPool2d(output_size=1)
          (1): Flatten(start_dim=1, end_dim=-1)
          (2): Linear(in_features=32, out_features=64, bias=False)
          (3): ReLU()
          (4): Linear(in_features=64, out_features=32, bias=False)
          (5): Sigmoid()
        )
      )
    )
    (9): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): Mish(inplace=True)
    (11): BlurMaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fullyconnected): Sequential(
    (0): Linear(in_features=1568, out_features=512, bias=True)
    (1): Mish(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=512, out_features=256, bias=True)
    (4): Mish(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=256, out_features=256, bias=True)
    (7): Mish(inplace=True)
    (8): Linear(in_features=256, out_features=1, bias=True)
  )
)
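Squeeze-and-excite rescales each output channel of a convolution by a gate computed from that channel's global average. The numpy sketch below mirrors the pool_and_mlp stack printed above (AdaptiveAvgPool2d, Linear(32, 64) without bias, ReLU, Linear(64, 32) without bias, Sigmoid) for a single unbatched feature map. The random weights stand in for learned ones, and the function name is hypothetical. It is meant to show the data flow, not to reproduce Composer's code.

```python
import numpy as np


def squeeze_excite(x, w1, w2):
    """Squeeze-and-excite on a (channels, height, width) feature map.

    w1: (latent, channels) and w2: (channels, latent) weight matrices, no biases.
    """
    squeezed = x.mean(axis=(1, 2))                # global average pool -> (channels,)
    hidden = np.maximum(w1 @ squeezed, 0.0)       # ReLU
    gates = 1.0 / (1.0 + np.exp(-(w2 @ hidden)))  # sigmoid: one gate in (0, 1) per channel
    return x * gates[:, None, None]               # rescale each channel by its gate


rng = np.random.default_rng(0)
feature_map = rng.standard_normal((32, 15, 15))  # shape of the 16 -> 32 convolution's output
w1 = rng.standard_normal((64, 32)) * 0.1
w2 = rng.standard_normal((32, 64)) * 0.1
out = squeeze_excite(feature_map, w1, w2)
print(out.shape)  # (32, 15, 15)
```

Because every gate lies strictly between 0 and 1, the block can only damp channels, never amplify them. Which channels it damps depends on the input, which is what distinguishes it from a fixed per-channel scale.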
Benchmark on Matbench Folds
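Training is done with skorch, which abstracts away the usual training loop and the need for DataLoaders. By default, skorch also holds out 20% of each training fold as an internal validation set. This split produces the valid_acc and valid_loss columns in the training logs, and the EarlyStopping callback monitors valid_loss on it. Images are preprocessed as follows:
- Convert each PIL.Image to a torch.Tensor, which scales pixel values from [0, 255] to [0.0, 1.0].
- Compute the mean and standard deviation of the scaled pixel values, then normalize to zero mean and unit variance.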
For normalization, the mean and standard deviation are single scalars computed over all pixels of the training images, and they are recalculated separately for each fold. Within each fold, the training-set statistics are also used to normalize the corresponding test set.
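The preprocessing and per-fold normalization can be sketched in numpy as follows, assuming 8-bit grayscale images like the ones xtal2png produces. The random arrays stand in for real images, and the helper names are hypothetical. `ddof=1` is used to mirror `torch.Tensor.std()`, which computes the unbiased estimate by default.

```python
import numpy as np


def fit_pixel_stats(train_images_uint8):
    """Scalar mean and std over every pixel of every training image, after scaling to [0, 1]."""
    scaled = np.stack(train_images_uint8).astype(np.float32) / 255.0
    # ddof=1 mirrors torch.Tensor.std(), which is unbiased by default
    return scaled.mean(), scaled.std(ddof=1)


def normalize(images_uint8, mean, std):
    """Scale to [0, 1], then standardize with the given (training-set) statistics."""
    scaled = np.stack(images_uint8).astype(np.float32) / 255.0
    return (scaled - mean) / std


rng = np.random.default_rng(0)
train = [rng.integers(0, 256, size=(64, 64)).astype(np.uint8) for _ in range(8)]
test = [rng.integers(0, 256, size=(64, 64)).astype(np.uint8) for _ in range(3)]

mean, std = fit_pixel_stats(train)      # computed on the training fold only
train_norm = normalize(train, mean, std)
test_norm = normalize(test, mean, std)  # test images reuse the training statistics
print(train_norm.mean(), train_norm.std(ddof=1))
```

The test images end up only approximately zero-mean and unit-variance. That is intended: reusing the training statistics keeps test-set information out of the preprocessing.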
Matbench tracks the final test results for the full test set. Because we can only train and predict on structures with 52 or fewer sites, it is also useful to track performance on just the subset of the test data with num_sites <= 52. A simple helper function for this is defined below.
[ ]:
from sklearn.metrics import (
    accuracy_score,
    balanced_accuracy_score,
    f1_score,
    roc_auc_score,
)


def scoring(y_true, y_pred, fold):
    """Return classification metrics for one fold as a pd.Series named after the fold."""
    acc = accuracy_score(y_true, y_pred)
    bal_acc = balanced_accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    # y_pred holds hard 0/1 labels, so this ROC AUC equals the balanced accuracy;
    # pass predicted probabilities instead for a threshold-independent AUC
    roc_auc = roc_auc_score(y_true, y_pred)
    scores = {
        "Accuracy": acc,
        "Balanced Accuracy": bal_acc,
        "F1 Score": f1,
        "ROC AUC": roc_auc,
    }
    return pd.Series(scores, name=fold)


# track test scores for the subset of test data with num_sites <= 52
subset_test_scores = []
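For binary labels, the ROC AUC of hard 0/1 predictions reduces to balanced accuracy. The ROC curve then has a single interior point (FPR, TPR), so the trapezoidal area is (1 + TPR - FPR) / 2 = (TPR + TNR) / 2. The pure-Python sketch below checks this using the pairwise (Mann-Whitney) definition of AUC, which matches the trapezoidal area when ties count as one half. The function names and data are illustrative. To report a threshold-independent AUC, scoring() would need predicted probabilities instead of hard labels, for example the positive-class column of skorch's net.predict_proba(X_test).

```python
def balanced_accuracy(y_true, y_pred):
    """Mean of true-positive rate and true-negative rate for 0/1 labels."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    pos = sum(1 for t in y_true if t == 1)
    neg = len(y_true) - pos
    return (tp / pos + tn / neg) / 2


def roc_auc_from_scores(y_true, scores):
    """AUC as the probability that a random positive outscores a random negative (ties = 1/2)."""
    pos = [s for t, s in zip(y_true, scores) if t == 1]
    neg = [s for t, s in zip(y_true, scores) if t == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


y_true = [1, 1, 1, 0, 0, 0, 0, 1]
y_hard = [1, 0, 0, 0, 1, 0, 0, 1]  # 0/1 class predictions, as passed to scoring()
print(balanced_accuracy(y_true, y_hard), roc_auc_from_scores(y_true, y_hard))  # 0.625 0.625
```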
The Matbench benchmark is run below. To save the PNG images that xtal2png creates during crystal-to-image conversion, set the save flag to True.
[ ]:
xc = XtalConverter()
mb = MatbenchBenchmark(autoload=False, subset=["matbench_mp_is_metal"])
save = False  # xtal2png(data, save=save)

for task in mb.tasks:
    task.load()
    for fold in task.folds:
        # Get training data
        train_inputs, train_outputs = task.get_train_and_val_data(fold)

        # Train on structures with num_sites <= 52
        site_counter = lambda x: x.num_sites
        idx = train_inputs.apply(site_counter) <= 52
        X_train = train_inputs[idx]
        y_train = train_outputs[idx]

        # Convert crystal structures to images
        X_train = xc.xtal2png(X_train, save=save)

        # Convert PIL Images to torch.Tensor
        # Note that this scales from [0, 255] to [0.0, 1.0]
        X_train = [TF.to_tensor(img) for img in X_train]

        # Normalize images (subtract mean, divide by std)
        mean = torch.cat(X_train).mean()
        std = torch.cat(X_train).std()
        X_train = [TF.normalize(i, mean=mean, std=std) for i in X_train]

        # Change X from a list of tensors to a single tensor, and y from bool to float
        X_train = torch.stack(X_train)
        y_train = y_train.astype(np.float32)

        # Apply Composer methods to vanilla PyTorch classifier before training
        model = CNNClassifier()
        cf.apply_squeeze_excite(model, min_channels=16)
        cf.apply_blurpool(model)

        # Train and validate classifier with skorch
        net = NeuralNetBinaryClassifier(
            model,
            criterion=nn.BCEWithLogitsLoss,
            max_epochs=50,
            optimizer=optim.AdamW,
            optimizer__amsgrad=True,
            optimizer__lr=0.0005,
            callbacks=[EarlyStopping(patience=15)],
            device="cuda" if torch.cuda.is_available() else "cpu",
            batch_size=64,
        )
        net.fit(X_train, y_train)

        # Get test data and keep structures with num_sites <= 52
        test_inputs, test_outputs = task.get_test_data(fold, include_target=True)
        idx = test_inputs.apply(site_counter) <= 52
        X_test = test_inputs[idx]

        # Convert to images, preprocess using mean and std from training data
        X_test = xc.xtal2png(X_test, save=save)
        preprocess = T.Compose([T.ToTensor(), T.Normalize(mean, std)])
        X_test = torch.stack([preprocess(img) for img in X_test])

        # Predict on X_test
        # For structures with num_sites > 52, predict mode of y_train
        y_pred = net.predict(X_test)
        y_pred_full = np.empty(test_inputs.size)
        y_pred_full[idx] = y_pred
        y_pred_full[~idx] = y_train.mode().item()

        # Record data
        task.record(fold, y_pred_full)

        # Also record test scores on subset of data with num_sites <= 52
        subset_test_scores.append(scoring(test_outputs[idx], y_pred, f"fold-{fold}"))

# Save benchmark results
mb.to_file("new-results.json.gz")
2022-07-08 16:49:57 INFO Initialized benchmark 'matbench_v0.1' with 1 tasks:
['matbench_mp_is_metal']
2022-07-08 16:49:57 INFO Loading dataset 'matbench_mp_is_metal'...
2022-07-08 16:51:29 INFO Dataset 'matbench_mp_is_metal loaded.
100%|██████████| 71602/71602 [06:18<00:00, 189.39it/s]
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:268: UserWarning: lower RGB value(s) OOB (-65 less than 0). thresholding to 0.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
warn(
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:274: UserWarning: upper RGB value(s) OOB (1317 greater than 255). thresholding to 255.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
warn(
epoch train_loss valid_acc valid_loss dur
------- ------------ ----------- ------------ -------
1 0.6073 0.6224 0.6698 14.7695
2 0.5356 0.6208 0.7271 10.9460
3 0.5118 0.6481 0.7283 10.8090
4 0.5024 0.6467 0.7483 10.8945
5 0.4809 0.6736 0.6756 10.7665
6 0.4701 0.6716 0.6849 10.7760
7 0.4650 0.6680 0.6896 10.8740
8 0.4538 0.6705 0.6834 11.3155
9 0.4480 0.6702 0.6814 10.9335
10 0.4457 0.6669 0.6971 10.7680
11 0.4374 0.6736 0.6725 10.9445
12 0.4343 0.6710 0.6902 10.9260
13 0.4292 0.6684 0.7142 10.8770
14 0.4228 0.6750 0.6595 10.9695
15 0.4233 0.6766 0.6654 12.5390
16 0.4158 0.6770 0.6682 11.0055
17 0.4113 0.6726 0.6832 11.0400
18 0.4078 0.6740 0.6664 11.0470
19 0.4066 0.6740 0.6757 11.0325
20 0.4020 0.6709 0.6870 11.0223
21 0.4000 0.6742 0.6787 10.8605
22 0.3971 0.6716 0.6871 11.6585
23 0.3959 0.6674 0.7022 11.3410
24 0.3919 0.6824 0.6650 11.3810
25 0.3893 0.6716 0.6919 11.0685
26 0.3876 0.6728 0.6850 10.9170
27 0.3837 0.6667 0.7144 11.5390
28 0.3821 0.6794 0.6567 11.0540
29 0.3792 0.6782 0.6833 10.9590
30 0.3780 0.6736 0.7035 10.8730
31 0.3755 0.6835 0.6717 10.9410
32 0.3717 0.6709 0.7351 10.8490
33 0.3728 0.6807 0.6963 10.8965
34 0.3677 0.6865 0.6806 10.9120
35 0.3664 0.6921 0.6754 10.9500
36 0.3647 0.6887 0.6957 10.9740
37 0.3606 0.6745 0.7035 10.8780
38 0.3598 0.6930 0.6834 10.8200
39 0.3583 0.7027 0.6818 10.9130
40 0.3554 0.6891 0.6914 10.9930
41 0.3526 0.7003 0.6865 10.9970
42 0.3527 0.6967 0.7010 11.0350
Stopping since valid_loss has not improved in the last 15 epochs.
100%|██████████| 17925/17925 [01:38<00:00, 181.45it/s]
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:268: UserWarning: lower RGB value(s) OOB (-91 less than 0). thresholding to 0.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
warn(
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:274: UserWarning: upper RGB value(s) OOB (995 greater than 255). thresholding to 255.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
warn(
2022-07-08 17:08:07 INFO Recorded fold matbench_mp_is_metal-0 successfully.
100%|██████████| 71644/71644 [06:18<00:00, 189.14it/s]
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:268: UserWarning: lower RGB value(s) OOB (-91 less than 0). thresholding to 0.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
warn(
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:274: UserWarning: upper RGB value(s) OOB (1317 greater than 255). thresholding to 255.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
warn(
epoch train_loss valid_acc valid_loss dur
------- ------------ ----------- ------------ -------
1 0.5929 0.6283 0.6789 11.1935
2 0.5243 0.6296 0.6866 11.0007
3 0.5031 0.6284 0.7424 10.8800
4 0.4796 0.6359 0.7158 10.8855
5 0.4669 0.6471 0.7155 10.8195
6 0.4543 0.6410 0.7543 10.9535
7 0.4490 0.6537 0.7031 10.8030
8 0.4395 0.6616 0.7039 10.7705
9 0.4372 0.6529 0.7108 10.9330
10 0.4305 0.6606 0.7028 10.9995
11 0.4266 0.6587 0.7050 10.9490
12 0.4226 0.6655 0.7067 10.8910
13 0.4202 0.6531 0.7094 10.8475
14 0.4165 0.6544 0.7416 10.9615
15 0.4120 0.6597 0.7153 10.9285
Stopping since valid_loss has not improved in the last 15 epochs.
100%|██████████| 17883/17883 [01:48<00:00, 164.18it/s]
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:268: UserWarning: lower RGB value(s) OOB (-27 less than 0). thresholding to 0.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
warn(
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:274: UserWarning: upper RGB value(s) OOB (956 greater than 255). thresholding to 255.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
warn(
2022-07-08 17:19:58 INFO Recorded fold matbench_mp_is_metal-1 successfully.
100%|██████████| 71604/71604 [06:31<00:00, 182.87it/s]
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:268: UserWarning: lower RGB value(s) OOB (-91 less than 0). thresholding to 0.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
warn(
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:274: UserWarning: upper RGB value(s) OOB (1317 greater than 255). thresholding to 255.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
warn(
epoch train_loss valid_acc valid_loss dur
------- ------------ ----------- ------------ -------
1 0.5957 0.6205 0.6803 11.3735
2 0.5308 0.6339 0.7092 11.3055
3 0.5120 0.6208 0.7561 11.3290
4 0.4911 0.6561 0.7014 11.4160
5 0.4756 0.6729 0.6585 11.1370
6 0.4651 0.6716 0.6809 11.1925
7 0.4564 0.6707 0.7054 11.2350
8 0.4511 0.6703 0.7149 11.3085
9 0.4455 0.6683 0.7382 11.2265
10 0.4437 0.6769 0.7111 11.2260
11 0.4384 0.6573 0.7956 11.1905
12 0.4339 0.6650 0.7453 11.2540
13 0.4292 0.6642 0.7540 11.3160
14 0.4252 0.6659 0.7473 11.3150
15 0.4237 0.6657 0.7517 11.1695
16 0.4182 0.6588 0.8085 11.3225
17 0.4166 0.6729 0.7352 11.0960
18 0.4108 0.6627 0.7942 11.6485
19 0.4060 0.6704 0.7474 11.7510
Stopping since valid_loss has not improved in the last 15 epochs.
100%|██████████| 17923/17923 [01:42<00:00, 174.96it/s]
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:268: UserWarning: lower RGB value(s) OOB (-19 less than 0). thresholding to 0.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
warn(
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:274: UserWarning: upper RGB value(s) OOB (1248 greater than 255). thresholding to 255.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
warn(
2022-07-08 17:32:45 INFO Recorded fold matbench_mp_is_metal-2 successfully.
100%|██████████| 71599/71599 [06:27<00:00, 184.65it/s]
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:268: UserWarning: lower RGB value(s) OOB (-91 less than 0). thresholding to 0.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
warn(
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:274: UserWarning: upper RGB value(s) OOB (1248 greater than 255). thresholding to 255.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
warn(
epoch train_loss valid_acc valid_loss dur
------- ------------ ----------- ------------ -------
1 0.6078 0.6161 0.6687 11.1820
2 0.5410 0.5843 0.7135 10.9360
3 0.5274 0.6334 0.7075 10.8940
4 0.5167 0.6174 0.7411 10.7925
5 0.4938 0.6011 0.8099 10.8460
6 0.4825 0.6122 0.7850 11.2060
7 0.4710 0.6226 0.7418 10.8695
8 0.4610 0.6394 0.7188 10.9045
9 0.4559 0.6555 0.7164 10.8600
10 0.4507 0.6457 0.7033 10.8250
11 0.4424 0.6524 0.6752 10.7745
12 0.4394 0.6547 0.6737 10.7125
13 0.4307 0.6628 0.6637 10.8445
14 0.4262 0.6777 0.6504 10.9635
15 0.4207 0.6691 0.6339 10.9045
16 0.4172 0.6732 0.6542 10.8685
17 0.4156 0.6772 0.6430 10.8270
18 0.4141 0.6867 0.6256 10.9465
19 0.4075 0.6753 0.6330 10.8752
20 0.4037 0.6712 0.6490 10.8860
21 0.4029 0.6823 0.6322 10.9735
22 0.4007 0.6848 0.6341 10.8885
23 0.3975 0.6858 0.6181 10.8625
24 0.3941 0.6945 0.6044 10.8530
25 0.3926 0.6827 0.6352 11.1605
26 0.3895 0.6971 0.5993 10.9810
27 0.3855 0.6943 0.6160 11.4595
28 0.3861 0.6887 0.6159 11.6460
29 0.3838 0.7043 0.6016 11.9855
30 0.3804 0.6940 0.6222 11.6535
31 0.3779 0.7064 0.5930 10.8920
32 0.3751 0.7138 0.5745 10.8160
33 0.3726 0.7103 0.6030 10.8610
34 0.3705 0.7078 0.5839 10.8755
35 0.3684 0.7139 0.5945 10.8450
36 0.3691 0.7154 0.5903 10.8120
37 0.3669 0.7081 0.6114 10.8455
38 0.3660 0.7164 0.5878 10.7310
39 0.3655 0.7106 0.6092 10.8840
40 0.3604 0.7124 0.6066 10.9505
41 0.3601 0.7151 0.6166 10.8055
42 0.3575 0.7146 0.5996 10.9835
43 0.3546 0.7146 0.6096 10.9930
44 0.3523 0.7220 0.5837 10.9275
45 0.3505 0.7296 0.5714 10.9500
46 0.3486 0.7165 0.6097 10.8555
47 0.3481 0.7159 0.5908 10.8650
48 0.3457 0.7199 0.6040 10.8525
49 0.3429 0.7325 0.5719 10.9535
50 0.3429 0.7290 0.5789 10.7510
100%|██████████| 17928/17928 [01:40<00:00, 178.62it/s]
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:268: UserWarning: lower RGB value(s) OOB (-65 less than 0). thresholding to 0.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
warn(
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:274: UserWarning: upper RGB value(s) OOB (1317 greater than 255). thresholding to 255.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
warn(
2022-07-08 17:50:48 INFO Recorded fold matbench_mp_is_metal-3 successfully.
100%|██████████| 71659/71659 [06:20<00:00, 188.16it/s]
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:268: UserWarning: lower RGB value(s) OOB (-91 less than 0). thresholding to 0.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
warn(
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:274: UserWarning: upper RGB value(s) OOB (1317 greater than 255). thresholding to 255.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
warn(
epoch train_loss valid_acc valid_loss dur
------- ------------ ----------- ------------ -------
1 0.5997 0.6197 0.6638 10.9605
2 0.5519 0.6425 0.6501 10.7185
3 0.5317 0.6442 0.6350 10.7315
4 0.5290 0.6300 0.6813 10.7890
5 0.5041 0.6557 0.6740 10.5690
6 0.5022 0.6289 0.7428 10.4915
7 0.4817 0.6442 0.7406 10.6470
8 0.4694 0.6246 0.7653 10.7805
9 0.4636 0.6217 0.7734 10.6635
10 0.4534 0.6572 0.7254 10.4560
11 0.4469 0.6469 0.7296 10.4670
12 0.4401 0.6500 0.7351 10.5830
13 0.4351 0.6544 0.7014 10.7160
14 0.4309 0.6629 0.6672 10.7205
15 0.4276 0.6634 0.6505 10.7445
16 0.4262 0.6624 0.6455 10.4870
17 0.4175 0.6636 0.6666 10.6385
Stopping since valid_loss has not improved in the last 15 epochs.
100%|██████████| 17868/17868 [01:41<00:00, 175.96it/s]
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:268: UserWarning: lower RGB value(s) OOB (-15 less than 0). thresholding to 0.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
warn(
c:\Users\khanm\anaconda3\envs\xtal2png\lib\site-packages\xtal2png\core.py:274: UserWarning: upper RGB value(s) OOB (1096 greater than 255). thresholding to 255.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)
warn(
2022-07-08 18:02:48 INFO Recorded fold matbench_mp_is_metal-4 successfully.
2022-07-08 18:02:49 INFO Successfully wrote MatbenchBenchmark to file 'new-results.json.gz'.
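The `Stopping since valid_loss has not improved in the last 15 epochs.` lines appear to come from a patience-based early-stopping callback; the message format matches skorch's `EarlyStopping`. The training cell is not shown in this section, so the settings in the sketch below are assumptions: a patience of 15, a relative improvement threshold of 1e-4, and lower-is-better on `valid_loss`. The sketch replays that rule in pure Python on a list of validation losses and reports the best epoch, the current run of non-improving epochs, and the epoch at which the rule fires (if it does).

```python
from typing import List, Optional, Tuple


def replay_early_stopping(
    valid_losses: List[float],
    patience: int = 15,
    rel_threshold: float = 1e-4,
) -> Tuple[int, int, Optional[int]]:
    """Replay an (assumed) patience-based early-stopping rule.

    An epoch counts as an improvement only if its loss is below
    best_loss * (1 - rel_threshold); otherwise the miss counter grows.
    Training stops at the epoch where the counter reaches `patience`.

    Returns (best_epoch, trailing_misses, stop_epoch) with 1-based epochs;
    stop_epoch is None if the rule never fires within the given losses.
    """
    threshold = float("inf")
    best_epoch = 0
    misses = 0
    for epoch, loss in enumerate(valid_losses, start=1):
        if loss < threshold:
            # Improvement: remember the epoch, reset the counter, tighten the bar
            best_epoch = epoch
            misses = 0
            threshold = loss * (1 - rel_threshold)
        else:
            misses += 1
        if misses == patience:
            return best_epoch, misses, epoch
    return best_epoch, misses, None


# valid_loss column of the training log printed before
# "Recorded fold matbench_mp_is_metal-1" above
fold1_valid_loss = [
    0.6789, 0.6866, 0.7424, 0.7158, 0.7155, 0.7543, 0.7031, 0.7039,
    0.7108, 0.7028, 0.7050, 0.7067, 0.7094, 0.7416, 0.7153,
]
print(replay_early_stopping(fold1_valid_loss))  # → (1, 14, None)
```

On that log the sketch finds the best epoch at 1 followed by 14 non-improving epochs, one short of the patience. The logs before the fold-2 and fold-4 records show the same pattern (best epochs 5 and 3, each followed by 14 misses), which suggests the epoch that actually triggered the stop was not printed as its own row; with skorch this would fit the log-printing callback running after `EarlyStopping`, but that ordering is an assumption here. The log before the fold-3 record never reaches 15 consecutive misses and runs all 50 epochs.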
[ ]:
import pprint
# Make sure our benchmark is valid
valid = mb.is_valid
print(f"is valid: {valid}")
# Check how our algorithm is doing using the recorded scores
pprint.pprint(mb.scores)
# Get some more info about the benchmark
mb.get_info()
is valid: True
{'matbench_mp_is_metal': {'accuracy': {'max': 0.8046838186787296,
'mean': 0.78905512981546,
'min': 0.7796729962776233,
'std': 0.011047855221953238},
'balanced_accuracy': {'max': 0.7889602638667743,
'mean': 0.7670721073159733,
'min': 0.7535760749130695,
'std': 0.015144418534814375},
'f1': {'max': 0.7484677468292978,
'mean': 0.7104270714737543,
'min': 0.6843673801137903,
'std': 0.026518595116783353},
'rocauc': {'max': 0.7889602638667743,
'mean': 0.7670721073159734,
'min': 0.7535760749130695,
'std': 0.015144418534814357}}}
2022-07-08 18:02:49 INFO
Matbench package 0.5 running benchmark 'matbench_v0.1'
is complete: False
is recorded: True
is valid: True
Results:
- 'matbench_mp_is_metal' ROCAUC mean: 0.7670721073159734
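These scores cover the full test folds, where structures with more than 52 sites receive the mode of the training labels instead of a CNN prediction. The minimal sketch below shows that fallback together with the corresponding subset evaluation on toy data. The helper names (`predict_with_fallback`, `accuracy`) and the stand-in model are hypothetical and are not the notebook's own code.

```python
from collections import Counter
from typing import Callable, List, Sequence


def predict_with_fallback(
    num_sites: Sequence[int],
    features: Sequence[object],
    y_train_subset: Sequence[int],
    model_predict: Callable[[List[object]], List[int]],
    max_sites: int = 52,
) -> List[int]:
    """Use the model for structures with <= max_sites sites, else the training mode."""
    # Most common class among the training labels the model was trained on
    mode_label = Counter(y_train_subset).most_common(1)[0][0]
    small = [i for i, n in enumerate(num_sites) if n <= max_sites]
    preds = [mode_label] * len(num_sites)
    # Only the small structures are passed to the model
    for i, p in zip(small, model_predict([features[i] for i in small])):
        preds[i] = p
    return preds


def accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Fraction of matching labels."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def toy_model(xs: List[object]) -> List[int]:
    """Stand-in classifier: predicts 1 when the toy feature exceeds 0.5."""
    return [int(x > 0.5) for x in xs]


# Toy data (not matbench data)
num_sites = [4, 10, 60, 8, 120]
features = [0.9, 0.2, 0.7, 0.6, 0.1]
y_true = [1, 1, 1, 1, 0]
y_train_subset = [0, 0, 1, 0, 1]  # mode is 0

y_pred = predict_with_fallback(num_sites, features, y_train_subset, toy_model)
mask = [n <= 52 for n in num_sites]
full_acc = accuracy(y_true, y_pred)
subset_acc = accuracy(
    [t for t, m in zip(y_true, mask) if m],
    [p for p, m in zip(y_pred, mask) if m],
)
print(y_pred, full_acc, subset_acc)
```

Scoring the masked subset separately isolates the CNN's performance from the constant fallback, which is what the per-fold summary in the next cell reports.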
Finally, let’s also display the score summary for only those test-set structures with 52 or fewer sites, i.e. the structures that were actually passed to the CNN rather than assigned the mode-class fallback.
[ ]:
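import pandas as pd
from IPython.display import display
# Stack the per-fold subset scores so that each row is one fold
df = pd.concat(subset_test_scores, axis=1).T
display(df)
# Summary statistics across the five folds
df.describe()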
| | Accuracy | Balanced Accuracy | F1 Score | ROC AUC |
|---|---|---|---|---|
| fold-0 | 0.793752 | 0.790358 | 0.761561 | 0.790358 |
| fold-1 | 0.771627 | 0.766685 | 0.724315 | 0.766685 |
| fold-2 | 0.773029 | 0.766690 | 0.715604 | 0.766690 |
| fold-3 | 0.801818 | 0.798735 | 0.776358 | 0.798735 |
| fold-4 | 0.769756 | 0.764638 | 0.710648 | 0.764638 |

| | Accuracy | Balanced Accuracy | F1 Score | ROC AUC |
|---|---|---|---|---|
| count | 5.000000 | 5.000000 | 5.000000 | 5.000000 |
| mean | 0.781996 | 0.777421 | 0.737697 | 0.777421 |
| std | 0.014738 | 0.015933 | 0.029423 | 0.015933 |
| min | 0.769756 | 0.764638 | 0.710648 | 0.764638 |
| 25% | 0.771627 | 0.766685 | 0.715604 | 0.766685 |
| 50% | 0.773029 | 0.766690 | 0.724315 | 0.766690 |
| 75% | 0.793752 | 0.790358 | 0.761561 | 0.790358 |
| max | 0.801818 | 0.798735 | 0.776358 | 0.798735 |
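In both the matbench scores and the tables above, ROC AUC equals Balanced Accuracy (up to floating-point rounding). That is what one would expect if ROC AUC is computed from hard 0/1 class predictions rather than probabilities: the ROC curve then has a single interior point (FPR, TPR), and the area under the piecewise-linear curve through (0, 0), (FPR, TPR), (1, 1) is (1 + TPR - FPR)/2 = (TPR + TNR)/2, which is balanced accuracy. That hard labels were scored here is an inference from the identical values, not something shown in this section. A small pure-Python check on toy labels:

```python
from typing import List, Tuple


def rates(y_true: List[int], y_pred: List[int]) -> Tuple[float, float]:
    """Return (TPR, FPR) for binary labels where 1 is the positive class."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    return tp / (tp + fn), fp / (fp + tn)


def auc_from_hard_predictions(y_true: List[int], y_pred: List[int]) -> float:
    """Trapezoidal area under the ROC curve (0,0) -> (FPR,TPR) -> (1,1)."""
    tpr, fpr = rates(y_true, y_pred)
    points = [(0.0, 0.0), (fpr, tpr), (1.0, 1.0)]
    return sum(
        (x1 - x0) * (y0 + y1) / 2
        for (x0, y0), (x1, y1) in zip(points, points[1:])
    )


def balanced_accuracy(y_true: List[int], y_pred: List[int]) -> float:
    """Mean of sensitivity (TPR) and specificity (TNR = 1 - FPR)."""
    tpr, fpr = rates(y_true, y_pred)
    return (tpr + (1 - fpr)) / 2


# Toy labels (not matbench data): 1 = metal, 0 = non-metal
y_true = [1, 1, 1, 0, 0, 0, 0, 0]
y_pred = [1, 1, 0, 0, 0, 0, 1, 0]
print(auc_from_hard_predictions(y_true, y_pred), balanced_accuracy(y_true, y_pred))
```

Here TPR = 2/3 and FPR = 1/5, so both functions give (2/3 + 4/5)/2 = 11/15 (about 0.7333). Scoring predicted probabilities instead of class labels would generally give an AUC different from balanced accuracy.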