Selecting Default Parameter Ranges via Materials Project (Conventional Unit Cell)
In this notebook, we’ll go over how we selected parameter ranges for some hyperparameters of xtal2png, namely the lower and upper bounds of lattice parameter lengths (\(a\), \(b\), and \(c\)), cell volume, and site pairwise distances.
After we’ve downloaded the data from Materials Project (or loaded it if running the notebook again), we’ll extract the parameters from each of the compounds and do some exploratory data analysis. Based on the analysis, we use a quantile as an upper bound on each parameter range in order to exclude outliers. Removing the highest 1% of values for each parameter retains ~96% of the compounds that have at least two elements and at most 52 (primitive) sites. This gives us our final parameter ranges. Finally, we make publication-ready histogram figures and save them.
Setup
Let’s keep this notebook compatible both as a Google Colab notebook and running locally as a Jupyter notebook.
[1]:
from os import path
try:
    import google.colab
    IN_COLAB = True
    base_dir = "/content/drive/MyDrive/"
except ImportError:
    IN_COLAB = False
    base_dir = path.join("data", "external")
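An alternative way to detect Colab, sketched below, is to ask the import system whether `google.colab` can be found instead of importing it. The helper name `detect_colab` is hypothetical; it relies only on the standard-library `importlib.util.find_spec`, which raises `ModuleNotFoundError` when the parent package `google` is missing and returns `None` when `google` exists but `colab` does not.

```python
import importlib.util
from os import path


def detect_colab() -> bool:
    """Return True if the google.colab package is importable (i.e., likely running on Colab)."""
    try:
        # find_spec imports the parent package "google" first and raises if it is absent
        return importlib.util.find_spec("google.colab") is not None
    except ModuleNotFoundError:
        return False


IN_COLAB = detect_colab()
base_dir = "/content/drive/MyDrive/" if IN_COLAB else path.join("data", "external")
```

This keeps the check narrow: unrelated errors raised while importing a package are not silently treated as "not on Colab".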
[2]:
if IN_COLAB:
    %pip install pymatgen kaleido
Successfully installed kaleido-0.2.1 monty-2022.4.26 pymatgen-2022.0.17 ruamel.yaml-0.17.21 ruamel.yaml.clib-0.2.6 scipy-1.7.3 spglib-1.16.5 uncertainties-3.1.7
Data
Materials Project API Key
Get your Materials Project API key in one of three ways: (1) on Google Colab, store it in a file in your Google Drive (see below); (2) on Colab, enter it manually in the api_key form field; or (3) locally, register it with pymatgen by running pmg config --add PMG_MAPI_KEY <USER_API_KEY> (e.g. pmg config --add PMG_MAPI_KEY 123abc456def) in a terminal, such as a Miniconda prompt, with an environment that has pymatgen installed activated. For the last option, see the pymatgen docs: https://pymatgen.org/usage.html#setting-the-pmg-mapi-key-in-the-config-file.
The file, named mp-api-key.json and placed directly in your MyDrive folder, would look like the following:
{
"API_KEY": "YOUR_API_KEY"
}
Note that this file is not necessary locally if you use the pmg config option above.
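If you would like a single lookup that checks several locations in order (for example your MyDrive folder and then the current directory), a hypothetical helper could look like the sketch below. The name `load_api_key` and the search order are assumptions, not part of the notebook. Returning `None` when no file is found matches the local branch of the next cell, where `MPRester(None)` falls back to the key registered via `pmg config`.

```python
import json
from pathlib import Path
from typing import Iterable, Optional


def load_api_key(search_dirs: Iterable[str], fname: str = "mp-api-key.json") -> Optional[str]:
    """Return API_KEY from the first `fname` found in `search_dirs`, else None (hypothetical helper)."""
    for d in search_dirs:
        fpath = Path(d) / fname
        if fpath.is_file():
            with open(fpath, "r") as f:
                return json.load(f)["API_KEY"]
    # No file found: None lets MPRester fall back to the PMG_MAPI_KEY pymatgen setting
    return None
```

Usage might be `api_key = load_api_key(["/content/drive/MyDrive", "."])`.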
[3]:
import json
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    apikey_fpath = "/content/drive/MyDrive/mp-api-key.json"
    try:
        # https://stackoverflow.com/a/68442279/13697228
        with open(apikey_fpath, 'r') as f:
            json_data = json.load(f)
        api_key = json_data["API_KEY"]
    except Exception as e:
        print(e)
        api_key = "" #@param {type:"string"}
        if api_key == "":
            print(f"Couldn't load API key from {apikey_fpath}, and user-input API key is also empty.")
        else:
            # don't echo the secret key itself
            print("defaulting to user-input API key")
else:
    api_key = None
    print("make sure that you have run `pmg config --add PMG_MAPI_KEY <USER_API_KEY>`")
Mounted at /content/drive
Download
Let’s either download the data directly from Materials Project using the MPRester API or load the data that’s been saved previously to your device as structures.pkl in your base_dir.
[4]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import pickle
from pymatgen.ext.matproj import MPRester
[5]:
pkl_path = path.join(base_dir, "structures.pkl")
try:
    with open(pkl_path, "rb") as f:
        results = pickle.load(f)
except Exception as e:
    print(e)
    with MPRester(api_key) as m:
        results = m.query(
            {"nelements": {"$gte": 2},
             "nsites": {"$lte": 52}},
            properties=["structure"],
        )
    with open(pkl_path, "wb") as f:
        pickle.dump(results, f)
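The load-or-download pattern above can be factored into a small caching helper. The sketch below is a hypothetical variant (`load_or_compute` is not part of the notebook): it checks for the file explicitly instead of catching every exception, so a corrupted pickle raises rather than silently triggering a new download, and it creates the parent directory before writing, which the cell above assumes already exists.

```python
import pickle
from pathlib import Path
from typing import Any, Callable


def load_or_compute(pkl_path: str, compute: Callable[[], Any]) -> Any:
    """Load a pickled result if present; otherwise compute it, cache it to disk, and return it."""
    p = Path(pkl_path)
    if p.is_file():
        with open(p, "rb") as f:
            return pickle.load(f)
    result = compute()
    p.parent.mkdir(parents=True, exist_ok=True)  # e.g. data/external may not exist yet
    with open(p, "wb") as f:
        pickle.dump(result, f)
    return result
```

Here `compute` would wrap the `MPRester` query, e.g. `load_or_compute(pkl_path, lambda: m.query(...))` inside the `with MPRester(api_key) as m:` block.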
Extract Lattice and Distances
From here, we’ll loop through each of the structures and grab the lattice parameter lengths (a, b, and c) as well as the cell volume (volume) and pairwise distance matrices between each of the sites for a given structure (distance).
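For intuition about what `lattice.a`, `lattice.b`, `lattice.c`, and `lattice.volume` represent, the sketch below computes the same quantities from a 3×3 lattice matrix whose rows are the lattice vectors (the row convention of pymatgen's `Lattice.matrix`). The helper `lattice_params` is illustrative only; the notebook uses pymatgen's own properties.

```python
import numpy as np


def lattice_params(matrix: np.ndarray) -> tuple:
    """Return (a, b, c, volume) for a 3x3 lattice matrix with lattice vectors as rows."""
    m = np.asarray(matrix, dtype=float)
    a, b, c = np.linalg.norm(m, axis=1)  # lengths of the three lattice vectors
    volume = abs(np.linalg.det(m))  # cell volume = |determinant| of the lattice matrix
    return a, b, c, volume
```

For a hexagonal cell with \(a = 3\) Å and \(c = 5\) Å, this gives a volume of \(\frac{\sqrt{3}}{2} a^2 c \approx 38.97\) Å³.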
[6]:
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
a = []
b = []
c = []
volume = []
distance = []
num_sites = []
for s in tqdm(results):
    spa = SpacegroupAnalyzer(s["structure"], symprec=0.1, angle_tolerance=5.0)
    s = spa.get_conventional_standard_structure()
    lattice = s.lattice
    a.append(lattice.a)
    b.append(lattice.b)
    c.append(lattice.c)
    volume.append(lattice.volume)
    distance.append(s.distance_matrix)
    num_sites.append(len(s.sites))
print('range of a is: ', min(a), '-', max(a))
print('range of b is: ', min(b), '-', max(b))
print('range of c is: ', min(c), '-', max(c))
print('range of volume is: ', min(volume), '-', max(volume))
print('range of num_sites is: ', min(num_sites), '-', max(num_sites))
dis_min_tmp = []
dis_max_tmp = []
for d in tqdm(range(len(distance))):
    dis_min_tmp.append(min(distance[d][np.nonzero(distance[d])]))
    dis_max_tmp.append(max(distance[d][np.nonzero(distance[d])]))
print('range of pair-wise distance is: ', min(dis_min_tmp), '-', max(dis_max_tmp))
100%|██████████| 106127/106127 [46:26<00:00, 38.09it/s]
range of a is: 1.648906 - 72.590284
range of b is: 2.263836 - 83.00690244
range of c is: 2.131537 - 194.82034183895985
range of volume is: 15.216262193734085 - 56652.78402596729
range of num_sites is: 2 - 208
100%|██████████| 106127/106127 [00:27<00:00, 3900.18it/s]
range of pair-wise distance is: 0.7249349602879995 - 97.40907797736223
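The min/max distance loop above keeps only the nonzero entries of each distance matrix to skip the zero diagonal. A variant, sketched below with a hypothetical helper `min_max_pairwise`, instead takes the strictly upper triangle (distinct pairs \(i < j\)). The design difference: the `np.nonzero` approach would also drop a genuinely zero off-diagonal distance (two overlapping sites), whereas this version keeps it so such a structure would show up as a minimum of 0. It assumes at least two sites per structure, which holds here since the minimum `num_sites` is 2.

```python
import numpy as np


def min_max_pairwise(dmat: np.ndarray) -> tuple:
    """Min and max over distinct site pairs (i < j) of a symmetric distance matrix."""
    d = np.asarray(dmat, dtype=float)
    iu = np.triu_indices(d.shape[0], k=1)  # strictly upper triangle excludes the diagonal
    pair_d = d[iu]
    return pair_d.min(), pair_d.max()
```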
Exploratory Data Analysis
Setup
First, we store the data as a DataFrame to make it easier to visualize and apply operations to it.
[7]:
import plotly.express as px
df = pd.DataFrame(dict(a=a, b=b, c=c, volume=volume, min_distance=dis_min_tmp, max_distance=dis_max_tmp, num_sites=num_sites))
Min/Max
Next, we take a look at the minimum and maximum for each of the parameters.
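As a compact alternative to calling `df.apply(np.min)` and `df.apply(np.max)` separately, pandas can tabulate both extremes in one call with `DataFrame.agg`. The toy DataFrame below is made up purely for illustration and stands in for the notebook's `df`.

```python
import pandas as pd

# Toy stand-in for the notebook's `df` (illustrative values only)
toy = pd.DataFrame({"a": [3.0, 1.0, 7.0], "volume": [10.0, 2.0, 50.0]})

# One row per parameter, with "min" and "max" columns
summary = toy.agg(["min", "max"]).T
print(summary)
```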
[8]:
low_df = df.apply(np.min).drop("max_distance")
low_df
[8]:
a 1.648906
b 2.263836
c 2.131537
volume 15.216262
min_distance 0.724935
num_sites 2.000000
dtype: float64
[9]:
df.apply(np.max)
[9]:
a 72.590284
b 83.006902
c 194.820342
volume 56652.784026
min_distance 8.650650
max_distance 97.409078
num_sites 208.000000
dtype: float64
The maxima here can be quite large; for example, the largest unit cell volume is ~57,000 cubic ångströms (Å³).
Histogram
Let’s take a quick look at one of the parameters involved, in this case the a lattice parameter length.
[10]:
import plotly.express as px
px.histogram(df, x="a", marginal="rug")
Clearly, there are outliers: a long, sparse tail extends out to the maximum of ~72.6 Å.
Quantile Maximum
Because each parameter is scaled between its lower and upper bounds before being stored as discrete pixel values, such large ranges would inflate the round-off error of xtal2png. Let’s see whether we can narrow them by considering values only up to a certain percentile (the q quantile) of each relevant parameter.
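To make the round-off argument concrete, the sketch below assumes that a parameter in \([\text{low}, \text{upp}]\) is mapped linearly onto the 256 integer levels of an 8-bit pixel and rounded to the nearest level, so the worst-case error is half a step, \((\text{upp} - \text{low}) / 255 / 2\). This linear 8-bit encoding is an assumption here, not a quotation of the xtal2png source; `max_roundoff` is a hypothetical helper. The bounds plugged in are the minimum and maximum of \(a\) printed above, plus the 99% quantile computed below.

```python
def max_roundoff(low: float, upp: float, n_levels: int = 256) -> float:
    """Worst-case absolute error when values in [low, upp] are snapped to n_levels even levels.

    Assumes a linear mapping onto integers 0..n_levels-1 followed by rounding to the nearest one.
    """
    step = (upp - low) / (n_levels - 1)
    return step / 2


# Lattice parameter a (angstroms): full observed range vs. range capped at the q = 0.99 quantile
print(max_roundoff(1.648906, 72.590284))
print(max_roundoff(1.648906, 18.875420))
```

Under this assumption, capping \(a\) at the 99% quantile shrinks the worst-case error from roughly 0.14 Å to roughly 0.034 Å, about a factor of four.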
[11]:
q = 0.99
df.apply(lambda col: np.quantile(col, 1 - q)).drop("max_distance")
[11]:
a 2.918263
b 2.903764
c 3.407548
volume 49.107078
min_distance 0.981721
num_sites 3.000000
dtype: float64
[12]:
upp_df = df.apply(lambda col: np.quantile(col, q))
upp_df = upp_df.drop("min_distance")
upp_df
[12]:
a 18.875420
b 18.596167
c 39.051768
volume 2800.497775
max_distance 19.332560
num_sites 114.000000
dtype: float64
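The per-column quantile above can also be written with pandas' built-in `DataFrame.quantile`, which, like `np.quantile`, uses linear interpolation by default. The sketch below checks that the two agree on toy data (made up for illustration) and names the lambda argument `col` so it does not shadow the notebook's `a` list.

```python
import numpy as np
import pandas as pd

# Toy stand-in for the notebook's `df` (illustrative values only)
toy = pd.DataFrame({"a": np.linspace(1.0, 100.0, 200), "num_sites": np.arange(200)})
q = 0.99

via_apply = toy.apply(lambda col: np.quantile(col, q))  # pattern used in the notebook
via_pandas = toy.quantile(q)  # pandas built-in, one value per column
```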
Data Retention
The ranges are a lot more reasonable now. Let’s see how many compounds are retained after applying an upper bound filtering step using this upper quantile.
[13]:
qstr = " and ".join([f"{lbl} < @upp_df.{lbl}" for lbl in upp_df.index])  # one strict upper-bound condition per parameter
qstr
[13]:
'a < @upp_df.a and b < @upp_df.b and c < @upp_df.c and volume < @upp_df.volume and max_distance < @upp_df.max_distance and num_sites < @upp_df.num_sites'
[14]:
filt_df = df.query(qstr)
filt_df
[14]:
|   | a | b | c | volume | min_distance | max_distance | num_sites |
|---|---|---|---|---|---|---|---|
| 0 | 3.582386 | 3.582386 | 9.058916 | 116.257503 | 2.906562 | 5.189676 | 6 |
| 1 | 3.728104 | 3.728104 | 9.398536 | 130.627991 | 3.018739 | 5.388181 | 6 |
| 2 | 4.667758 | 4.667758 | 4.667758 | 101.700947 | 2.333879 | 4.042397 | 8 |
| 3 | 4.947200 | 4.947200 | 4.947200 | 121.081670 | 2.142200 | 3.498199 | 8 |
| 4 | 3.510234 | 3.510234 | 3.510234 | 43.252200 | 3.039952 | 3.039952 | 2 |
| ... | ... | ... | ... | ... | ... | ... | ... |
| 106121 | 5.177231 | 7.413476 | 10.297735 | 364.791151 | 0.990069 | 5.707257 | 36 |
| 106122 | 8.466314 | 8.603384 | 8.606069 | 469.512423 | 1.502355 | 6.098565 | 40 |
| 106124 | 5.188079 | 5.477565 | 7.317851 | 181.850271 | 0.997791 | 4.101902 | 18 |
| 106125 | 5.211291 | 7.406056 | 10.707537 | 370.569429 | 0.983919 | 5.761587 | 36 |
| 106126 | 5.397578 | 7.469631 | 10.283887 | 371.461746 | 0.978758 | 5.621589 | 36 |
101456 rows × 7 columns
[15]:
frac_retained = filt_df.shape[0] / df.shape[0]
print(f"{100*frac_retained:.1f}% retained")
95.6% retained
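The same filter can be expressed without building a query string, by comparing each column against its own bound and requiring every comparison to pass. In the sketch below, `toy` and `upp` are made-up stand-ins for `df` and `upp_df`; the comparison aligns the Series labels with the DataFrame columns, and the strict `<` matches `qstr`. A boolean mask also works for column names that are not valid Python identifiers, which `df.query` would otherwise require backticks for.

```python
import pandas as pd

# Toy stand-ins for the notebook's `df` and `upp_df` (illustrative values only)
toy = pd.DataFrame({"a": [3.0, 4.0, 50.0, 5.0], "volume": [40.0, 60.0, 900.0, 5000.0]})
upp = pd.Series({"a": 20.0, "volume": 3000.0})

# Column-wise comparison against each bound, then keep rows where all conditions hold
mask = (toy[upp.index] < upp).all(axis=1)
filt = toy[mask]
frac_retained = mask.mean()  # fraction of True values = fraction of rows retained
```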
We have also retained ~96% of the original compounds. The remaining ~4% will be much less likely to be represented during generation (i.e., they have effectively been masked from the distribution), although since they were outliers to begin with, it’s unclear whether most generative models would generate these kinds of compounds anyway. This may be an interesting topic for future study.
Selected Parameter Ranges
We’ll leave the lower bound as the minimum over all downloaded Materials Project entries (i.e., those with at most 52 sites). Alternatively, each lower bound could be set to 0.
[16]:
low_df # i.e. minima
[16]:
a 1.648906
b 2.263836
c 2.131537
volume 15.216262
min_distance 0.724935
num_sites 2.000000
dtype: float64
[17]:
upp_df # based on `q` quantile
[17]:
a 18.875420
b 18.596167
c 39.051768
volume 2800.497775
max_distance 19.332560
num_sites 114.000000
dtype: float64
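To carry these bounds forward, it can help to pair each lower bound with its upper bound in one structure. The sketch below is hypothetical: the key names and the pairing of `min_distance` with `max_distance` into a single `distance` range are illustrative choices, not an xtal2png API. `num_sites` is left out because it is a single cap rather than a range.

```python
from typing import Dict, Tuple

# Hypothetical mapping: output name -> (lower-bound label, upper-bound label)
RANGE_KEYS = {
    "a": ("a", "a"),
    "b": ("b", "b"),
    "c": ("c", "c"),
    "volume": ("volume", "volume"),
    "distance": ("min_distance", "max_distance"),
}


def build_ranges(low: Dict[str, float], upp: Dict[str, float]) -> Dict[str, Tuple[float, float]]:
    """Combine per-parameter minima and upper bounds into (low, upp) tuples."""
    return {name: (low[lo], upp[up]) for name, (lo, up) in RANGE_KEYS.items()}
```

With the notebook's objects, usage would be `build_ranges(low_df.to_dict(), upp_df.to_dict())`.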
Plotting Histogram Distributions
Let’s plot and save the distributions for the parameters in upp_df. First, we define some helper functions to make the figures more compatible with academic publishing and to save them.
[18]:
from typing import Union
import plotly.graph_objs as go
from plotly import offline
def matplotlibify(
    fig: go.Figure,
    size: int = 24,
    width_inches: Union[float, int] = 3.5,
    height_inches: Union[float, int] = 3.5,
    dpi: int = 142,
    return_scale: bool = False,
) -> go.Figure:
    """Make plotly figures look more like matplotlib for academic publishing.

    modified from: https://medium.com/swlh/formatting-a-plotly-figure-with-matplotlib-style-fa56ddd97539

    Parameters
    ----------
    fig : go.Figure
        Plotly figure to be matplotlibified
    size : int, optional
        Font size for layout and axes, by default 24
    width_inches : Union[float, int], optional
        Width of matplotlib figure in inches, by default 3.5
    height_inches : Union[float, int], optional
        Height of matplotlib figure in inches, by default 3.5
    dpi : int, optional
        Dots per inch (resolution) of matplotlib figure, by default 142. Leave as
        default unless you're willing to verify nothing strange happens with the output.
    return_scale : bool, optional
        If true, then return `scale` which is a quantity that helps with creating a
        high-resolution image at the specified absolute width and height in inches.
        More specifically:
        >>> width_default_px = fig.layout.width
        >>> targ_dpi = 300
        >>> scale = width_inches / (width_default_px / dpi) * (targ_dpi / dpi)
        Feel free to ignore this parameter.

    Returns
    -------
    fig : go.Figure
        The matplotlibified plotly figure.

    Examples
    --------
    >>> import plotly.express as px
    >>> df = px.data.tips()
    >>> fig = px.histogram(df, x="day")
    >>> fig.show()
    >>> fig = matplotlibify(fig, size=24, width_inches=3.5, height_inches=3.5, dpi=142)
    >>> fig.show()

    Compare the figure before and after matplotlibification.
    """
    font_dict = dict(family="Arial", size=size, color="black")
    # app = QApplication(sys.argv)
    # screen = app.screens()[0]
    # dpi = screen.physicalDotsPerInch()
    # app.quit()
    fig.update_layout(
        font=font_dict,
        plot_bgcolor="white",
        width=width_inches * dpi,
        height=height_inches * dpi,
        margin=dict(r=40, t=20, b=10),
    )
    fig.update_yaxes(
        showline=True,  # draw the axis line
        linecolor="black",  # line color
        linewidth=2.4,  # line size
        ticks="inside",  # ticks inside axis
        tickfont=font_dict,  # tick label font
        mirror="allticks",  # add ticks to top/right axes
        tickwidth=2.4,  # tick width
        tickcolor="black",  # tick color
    )
    fig.update_xaxes(
        showline=True,
        showticklabels=True,
        linecolor="black",
        linewidth=2.4,
        ticks="inside",
        tickfont=font_dict,
        mirror="allticks",
        tickwidth=2.4,
        tickcolor="black",
    )
    fig.update(layout_coloraxis_showscale=False)
    width_default_px = fig.layout.width
    targ_dpi = 300
    scale = width_inches / (width_default_px / dpi) * (targ_dpi / dpi)
    if return_scale:
        return fig, scale
    else:
        return fig
def plot_and_save(fig_path, fig, mpl_kwargs=None, show=False, update_legend=False):
    if mpl_kwargs is None:
        mpl_kwargs = {}  # avoid a shared mutable default argument
    if show:
        try:
            fig.show()
        except Exception as e:
            print(e)
            offline.plot(fig)
    fig.write_html(fig_path + ".html")
    fig.write_json(fig_path + ".json")  # to_json only returns a string; write_json saves the file
    if update_legend:
        fig.update_layout(
            legend=dict(
                font=dict(size=16),
                yanchor="bottom",
                y=0.99,
                xanchor="right",
                x=0.99,
                bgcolor="rgba(0,0,0,0)",
                # orientation="h",
            )
        )
    fig = matplotlibify(fig, **mpl_kwargs)
    fig.write_image(fig_path + ".png")
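The `scale` returned by `matplotlibify(..., return_scale=True)` simplifies nicely: because `matplotlibify` sets the layout width to `width_inches * dpi` pixels, the width terms cancel and `scale` reduces to `targ_dpi / dpi` = 300 / 142 ≈ 2.11. The hypothetical helper below just reproduces that arithmetic outside of plotly so it can be checked.

```python
def png_scale(width_inches: float, dpi: int = 142, targ_dpi: int = 300) -> float:
    """Reproduce matplotlibify's `scale` formula (hypothetical standalone helper)."""
    width_default_px = width_inches * dpi  # what fig.layout.width equals after matplotlibify
    return width_inches / (width_default_px / dpi) * (targ_dpi / dpi)
```

Note that `plot_and_save` does not currently use this value, so its PNGs are written at 142 px per inch. To export at ~300 px per inch, one option is `fig, scale = matplotlibify(fig, return_scale=True)` followed by `fig.write_image(fig_path + ".png", scale=scale)`; for a 3.5-inch figure this yields an image roughly 1050 px wide.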
From here, we just loop through the various parameters, plotting and saving histograms as we go. If running on Google Colab, these will be saved to the current directory, which is temporary storage that is purged once the session ends.
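If you want the figures to persist on Colab, one option is to save them under `base_dir` (your MyDrive folder) instead of the current directory. The helper `fig_base_path` and the `figures` subfolder name below are hypothetical; the helper only builds the extension-less path that `plot_and_save` expects and creates the folder if needed.

```python
from pathlib import Path


def fig_base_path(lbl: str, out_dir: str = "figures") -> str:
    """Return an extension-less path such as 'figures/a_hist', creating out_dir if missing."""
    d = Path(out_dir)
    d.mkdir(parents=True, exist_ok=True)
    return str(d / f"{lbl}_hist")
```

Inside the loop, the call would become `plot_and_save(fig_base_path(lbl, path.join(base_dir, "figures")), fig, show=False)`.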
[19]:
figs = []
for lbl in df.columns.drop("min_distance"):
    fig = px.histogram(df, x=lbl, marginal="rug")
    fig = matplotlibify(fig)
    figs.append(fig)
    plot_and_save(lbl+"_hist", fig, show=False)
Here’s an example of what the first figure looks like (compare with the histogram from an earlier section in terms of formatting).
[20]:
figs[0]