# Source code for xtal2png.core

"""Crystal to PNG conversion core functions and scripts."""
import logging
import sys
from functools import lru_cache
from itertools import chain

# from itertools import zip_longest
from os import PathLike, path
from pathlib import Path
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
from uuid import uuid4
from warnings import warn

import numpy as np
import numpy.typing as npt
import pandas as pd
from element_coder import decode_many, encode_many
from element_coder.utils import get_range
from numpy.typing import NDArray
from PIL import Image
from pymatgen.core.lattice import Lattice
from pymatgen.core.structure import Structure
from pymatgen.io.cif import CifWriter
from tqdm import tqdm

from xtal2png.utils.data import (
    _get_space_group,
    dummy_structures,
    element_wise_scaler,
    element_wise_unscaler,
    get_image_mode,
    rgb_scaler,
    rgb_unscaler,
    unit_cell_converter,
)

# from sklearn.preprocessing import MinMaxScaler


__author__ = "sgbaird"
__copyright__ = "sgbaird"
__license__ = "MIT"

_logger = logging.getLogger(__name__)


# ---- Python API ----
# The functions defined in this section can be imported by users in their
# Python scripts/interactive interpreter, e.g. via `from xtal2png.core import
# XtalConverter`, when using this Python module as a library.


# Integer labels stored in ``id_data`` to tag which block of the encoded array
# each entry belongs to (consecutive integers 1..9, in block order).
(
    ATOM_ID,
    FRAC_ID,
    A_ID,
    B_ID,
    C_ID,
    ANGLES_ID,
    NUM_SITES_ID,
    SPACE_GROUP_ID,
    DISTANCE_ID,
) = range(1, 10)

# Human-readable block names; used as ``id_mapper`` keys and as the values
# accepted by ``XtalConverter(mask_types=...)``.
ATOM_KEY = "atom"
FRAC_KEY = "frac"
A_KEY = "a"
B_KEY = "b"
C_KEY = "c"
ANGLES_KEY = "angles"
NUM_SITES_KEY = "num_sites"
SPACE_GROUP_KEY = "space_group"
DISTANCE_KEY = "distance"
# Not a block: requests zeroing of the lower triangle of the encoded array.
LOWER_TRI_KEY = "lower_tri"

# Every name that may appear in ``mask_types``.
SUPPORTED_MASK_KEYS = [
    ATOM_KEY,
    FRAC_KEY,
    A_KEY,
    B_KEY,
    C_KEY,
    ANGLES_KEY,
    NUM_SITES_KEY,
    SPACE_GROUP_KEY,
    DISTANCE_KEY,
    LOWER_TRI_KEY,
]


def construct_save_name(s: Structure) -> str:
    """Build a file-name stem for ``s`` from its formula, space group and a short uid.

    The result has the form ``"<formula>,space-group=<sg>,uid=<xxxx>"`` where
    ``<formula>`` is ``s.formula`` with every space removed and ``<xxxx>`` is the
    first four characters of a random UUID4, so names are only *relatively*
    unique.
    """
    compact_formula = s.formula.replace(" ", "")
    space_group = _get_space_group(s)
    short_uid = str(uuid4())[:4]
    return f"{compact_formula},space-group={space_group},uid={short_uid}"
def _as_hashable(obj):
    """Return a hashable equivalent of ``obj`` suitable as an ``lru_cache`` key.

    NumPy arrays are first converted to (nested) Python lists; lists and tuples are
    then converted recursively into tuples. Any other object is returned unchanged.
    """
    if isinstance(obj, np.ndarray):
        obj = obj.tolist()
    if isinstance(obj, (list, tuple)):
        return tuple(_as_hashable(item) for item in obj)
    return obj


@lru_cache(maxsize=None)
def _cached_get_range(elements, encoding_type):
    """Memoized ``get_range``; ``elements`` must already be hashable."""
    return get_range(elements, encoding_type)


def _element_encoding_range_cached(elements, encoding_type):
    """Return the memoized ``get_range(elements, encoding_type)`` result.

    ``functools.lru_cache`` hashes its arguments, so a list or NumPy array for
    ``elements`` (which ``XtalConverter.atom_range`` is allowed to be) would raise
    ``TypeError: unhashable type``. The elements are therefore normalized into
    (nested) tuples before the cached lookup.

    Parameters
    ----------
    elements : tuple, list or numpy.ndarray
        Elements (e.g. atomic numbers) forwarded to ``element_coder`` ``get_range``.
    encoding_type : str
        Element encoding name forwarded to ``get_range``.
    """
    return _cached_get_range(_as_hashable(elements), encoding_type)
class XtalConverter:
    """Convert between pymatgen Structure object and PNG-encoded representation.

    Note that if you modify the ranges to be different than their defaults, you have
    effectively created a new representation. In the future, anytime you use
    :func:`XtalConverter` with a dataset that used modified range(s), you will need
    to specify the same ranges; otherwise, your data will be decoded (unscaled)
    incorrectly. In other words, make sure you're using the same
    :func:`XtalConverter` object for both encoding and decoding.

    We encourage you to use the default ranges, which were carefully selected based
    on a trade-off between keeping the range as low as possible and trying to
    incorporate as much of what's been observed on Materials Project with no more
    than 52 sites. For more details, see the corresponding notebook in the
    ``notebooks`` directory:
    https://github.com/sparks-baird/xtal2png/tree/main/notebooks

    Parameters
    ----------
    atom_range : Union[Tuple[int, int], ArrayLike], optional
        Atomic numbers from which the element-encoding range is derived; passed,
        together with ``element_encoding``, to ``element_coder.utils.get_range``.
        By default (1, 118). Note that :func:`fit` overwrites this attribute.
    frac_range : Tuple[float, float], optional
        Expected range for fractional coordinates, by default (0.0, 1.0)
    a_range : Tuple[float, float], optional
        Expected range for lattice parameter length a, by default (2.0, 15.3)
    b_range : Tuple[float, float], optional
        Expected range for lattice parameter length b, by default (2.0, 15.0)
    c_range : Tuple[float, float], optional
        Expected range for lattice parameter length c, by default (2.0, 36.0)
    angles_range : Tuple[float, float], optional
        Expected range for lattice parameter angles, by default (0.0, 180.0)
    num_sites_range : Tuple[float, float], optional
        Expected range for unit cell num_sites, by default (0, 52)
    space_group_range : Tuple[int, int], optional
        Expected range for space group numbers, by default (1, 230)
    distance_range : Tuple[float, float], optional
        Expected range for pairwise distances between sites, by default (0.0, 18.0)
    max_sites : int, optional
        Maximum number of sites to accommodate in encoding, by default 52
    save_dir : Union[str, 'PathLike[str]']
        Directory to save PNG files via :func:`xtal2png` (and CIF files via
        :func:`png2xtal`), by default path.join("data", "preprocessed"). The
        directory is created on instantiation if it does not already exist.
    symprec : Union[float, Tuple[float, float]], optional
        The symmetry precision used when converting cell types (see
        ``encode_cell_type`` and ``decode_cell_type``) and when writing CIF files in
        :func:`png2xtal`. If specified as a tuple, then ``symprec[0]`` applies to
        encoding and ``symprec[1]`` applies to decoding. By default 0.1.
    angle_tolerance : Union[float, int, Tuple[float, float], Tuple[int, int]], optional
        The angle tolerance (degrees), used in the same places as ``symprec``. If
        specified as a tuple, then ``angle_tolerance[0]`` applies to encoding and
        ``angle_tolerance[1]`` applies to decoding. By default 5.0.
    encode_cell_type : Optional[str], optional
        Encode structures as-is (None), or after applying a certain transformation.
        Uses the encoding ``symprec``/``angle_tolerance`` (the scalar, or element
        ``[0]`` of a tuple). One of "primitive_standard", "conventional_standard",
        "refined", "reduced", and None. By default None
    decode_cell_type : Optional[str], optional
        Decode structures as-is (None), or after applying a certain transformation.
        Uses the decoding ``symprec``/``angle_tolerance`` (the scalar, or element
        ``[1]`` of a tuple). One of "primitive_standard", "conventional_standard",
        "refined", "reduced", and None. By default None
    relax_on_decode : bool, optional
        Use m3gnet to relax the decoded crystal structures, by default False.
    channels : int, optional
        Number of channels, a positive integer. Typically choices would be 1
        (grayscale) or 3 (RGB), and are the only compatible choices when using
        :func:`XtalConverter().xtal2png` and :func:`XtalConverter().png2xtal`. For
        positive integers other than 1 or 3, use
        :func:`XtalConverter().structures_to_arrays` and
        :func:`XtalConverter().arrays_to_structures` directly instead.
    verbose : bool, optional
        Whether to print verbose debugging information (and progress bars) or not,
        by default True.
    element_encoding : str
        How to encode the element. Can be one of
        `element_coder.data.coding_data._PROPERTY_KEYS` (e.g., `mod_pettifor`,
        `atomic`, `pettifor`, `X`). Defaults to `atomic` (which encodes elements as
        atomic numbers).
    element_decoding_metric : Union[str, callable]
        Metric to measure distance between (noisy) input encoding and tabulated
        encodings. If a string, the distance function can be 'braycurtis',
        'canberra', 'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice',
        'euclidean', 'hamming', 'jaccard', 'jensenshannon', 'kulsinski',
        'kulczynski1', 'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto',
        'russellrao', 'seuclidean', 'sokalmichener', 'sokalsneath', 'sqeuclidean',
        'yule'. Defaults to "euclidean".
    mask_types : Optional[List[str]], optional
        List of information types to mask out (assign as 0) from the array/image.
        Supported values are "atom", "frac", "a", "b", "c", "angles", "num_sites",
        "space_group", "distance", and "lower_tri". If "lower_tri" is present, then
        the lower triangle (including the diagonal) is zeroed out. If None or
        empty, no masking is applied. By default, None.

    Examples
    --------
    >>> xc = XtalConverter()
    >>> xc = XtalConverter(atom_range=(0, 83)) # assumes no radioactive elements in data
    """

    def __init__(
        self,
        atom_range: Union[Tuple[int, int], npt.ArrayLike] = (1, 118),
        frac_range: Tuple[float, float] = (0.0, 1.0),
        a_range: Tuple[float, float] = (2.0, 15.3),
        b_range: Tuple[float, float] = (2.0, 15.0),
        c_range: Tuple[float, float] = (2.0, 36.0),
        angles_range: Tuple[float, float] = (0.0, 180.0),
        num_sites_range: Tuple[float, float] = (0, 52),
        space_group_range: Tuple[int, int] = (1, 230),
        distance_range: Tuple[float, float] = (0.0, 18.0),
        max_sites: int = 52,
        save_dir: Union[str, "PathLike[str]"] = path.join("data", "preprocessed"),
        symprec: Union[float, Tuple[float, float]] = 0.1,
        angle_tolerance: Union[float, int, Tuple[float, float], Tuple[int, int]] = 5.0,
        encode_cell_type: Optional[str] = None,
        decode_cell_type: Optional[str] = None,
        relax_on_decode: bool = False,
        channels: int = 1,
        verbose: bool = True,
        element_encoding: str = "atomic",
        element_decoding_metric: Union[str, Callable] = "euclidean",
        mask_types: Optional[List[str]] = None,
    ):
        """Instantiate an XtalConverter object with desired ranges and ``max_sites``.

        Raises
        ------
        TypeError
            If ``symprec`` or ``angle_tolerance`` is neither a number nor a pair.
        ValueError
            If ``mask_types`` contains an entry not in ``SUPPORTED_MASK_KEYS``.
        """
        self.atom_range = atom_range
        self.frac_range = frac_range
        self.a_range = a_range
        self.b_range = b_range
        self.c_range = c_range
        self.angles_range = angles_range
        self.num_sites_range = num_sites_range
        self.space_group_range = space_group_range
        self.distance_range = distance_range
        self.max_sites = max_sites
        self.save_dir = save_dir
        self.element_encoding = element_encoding
        self.element_decoding_metric = element_decoding_metric

        # A scalar applies to both encoding and decoding; a pair is (encode, decode).
        self.encode_symprec, self.decode_symprec = self._split_encode_decode(
            symprec, "symprec"
        )
        (
            self.encode_angle_tolerance,
            self.decode_angle_tolerance,
        ) = self._split_encode_decode(angle_tolerance, "angle_tolerance")

        self.encode_cell_type = encode_cell_type
        self.decode_cell_type = decode_cell_type
        self.relax_on_decode = relax_on_decode
        self.channels = channels
        self.verbose = verbose

        if self.verbose:
            self.tqdm_if_verbose = tqdm
        else:
            self.tqdm_if_verbose = lambda x: x

        # None (the default) means "no masking"; copy so that instances never share
        # or alias a caller-owned list.
        mask_types = [] if mask_types is None else list(mask_types)
        unsupported_mask_types = np.setdiff1d(mask_types, SUPPORTED_MASK_KEYS).tolist()
        if unsupported_mask_types != []:
            raise ValueError(
                f"{unsupported_mask_types} is/are not a valid mask type. Expected one of {SUPPORTED_MASK_KEYS}. Received {mask_types}"  # noqa: E501
            )
        self.mask_types = mask_types

        Path(save_dir).mkdir(exist_ok=True, parents=True)

    @staticmethod
    def _split_encode_decode(value, name: str):
        """Split a scalar-or-pair setting into ``(encode_value, decode_value)``.

        Raises
        ------
        TypeError
            If ``value`` is neither a number nor a tuple/list of two values.
        """
        if isinstance(value, (float, int, np.number)):
            return value, value
        if isinstance(value, (tuple, list)) and len(value) == 2:
            return value[0], value[1]
        raise TypeError(
            f"{name} should be a number or a pair of (encode, decode) values, "
            f"received {value!r}"
        )

    @property
    def _element_encoding_range(self):
        # We do *not* use cached_property as the result might change as
        # the user calls fit. However, we still want to cache, as we reuse
        # the result of this method for both encoding and decoding.
        return _element_encoding_range_cached(self.atom_range, self.element_encoding)
[docs] def xtal2png( self, structures: List[Union[Structure, str, "PathLike[str]"]], show: bool = False, save: bool = True, ): """Encode crystal (via CIF filepath or Structure object) as PNG file. Parameters ---------- structures : List[Union[Structure, str, PathLike[str]]] pymatgen Structure objects or path to CIF files. show : bool, optional Whether to display the PNG-encoded file, by default False save : bool, optional Whether to save the PNG-encoded file, by default True Returns ------- imgs : List[Image.Image] PIL images that (approximately) encode the supplied crystal structures. Raises ------ ValueError structures should be of same datatype ValueError structures should be of same datatype ValueError structures should be of type `str`, `os.PathLike` or `pymatgen.core.structure.Structure` Examples -------- >>> coords = [[0, 0, 0], [0.75,0.5,0.75]] >>> lattice = Lattice.from_parameters( ... a=3.84, b=3.84, c=3.84, alpha=120, beta=90, gamma=60 ... ) >>> structures = [Structure(lattice, ["Si", "Si"], coords), ... Structure(lattice, ["Ni", "Ni"], coords)] >>> xc = XtalConverter() >>> xc.xtal2png(structures, show=False, save=True) """ self.savenames, S = self.process_filepaths_or_structures(structures) # convert structures to 3D NumPy Matrices self.data, self.id_data, self.id_mapper = self.structures_to_arrays(S) mn, mx = self.data.min(), self.data.max() if mn < 0: warn( f"lower RGB value(s) OOB ({mn} less than 0). thresholding to 0.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)" # noqa: E501 ) # noqa self.data[self.data < 0] = 0 if mx > 255: warn( f"upper RGB value(s) OOB ({mx} greater than 255). thresholding to 255.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)" # noqa: E501 ) # noqa self.data[self.data > 255] = 255 self.data = self.data.astype(np.uint8) # convert to PNG images. 
Save and/or show, if applicable imgs: List[Image.Image] = [] for d, save_name in zip(self.data, self.savenames): mode = get_image_mode(d) d = np.squeeze(d) if mode == "RGB": d = d.transpose(1, 2, 0) img = Image.fromarray(d, mode=mode) imgs.append(img) if save: savepath = path.join(self.save_dir, save_name + ".png") img.save(savepath) if show: img.show() return imgs
[docs] def fit( self, structures: List[Union[Structure, str, "PathLike[str]"]], y=None, fit_quantiles: Tuple[float, float] = (0.00, 0.99), verbose: Optional[bool] = None, ): """Find optimal range parameters for encoding crystal structures. Parameters ---------- structures : List[Union[Structure, str, "PathLike[str]"]] List of pymatgen Structure objects. y : NoneType, optional No effect, for compatibility only, by default None fit_quantiles : Tuple[float,float], optional The lower and upper quantiles to use for fitting ranges to the data, by default (0.00, 0.99) verbose : Optional[bool], optional Whether to print information about the fitted ranges. If None, then defaults to ``self.verbose``. By default None Examples -------- >>> fit(structures, , y=None, fit_quantiles=(0.00, 0.99), verbose=None, ) OUTPUT """ verbose = self.verbose if verbose is None else verbose _, S = self.process_filepaths_or_structures(structures) # TODO: deal with arbitrary site_properties atomic_numbers = [s.atomic_numbers for s in S] a = [s.lattice.a for s in S] b = [s.lattice.b for s in S] c = [s.lattice.c for s in S] space_group = [_get_space_group(s) for s in S] num_sites = [s.num_sites for s in S] distance = [s.distance_matrix for s in S] if verbose: print("range of atomic_numbers is: ", min(a), "-", max(a)) print("range of a is: ", min(a), "-", max(a)) print("range of b is: ", min(b), "-", max(b)) print("range of c is: ", min(c), "-", max(c)) print("range of space_group is: ", min(space_group), "-", max(space_group)) print("range of num_sites is: ", min(num_sites), "-", max(num_sites)) dis_min_tmp = [] dis_max_tmp = [] for d in tqdm(range(len(distance))): dis_min_tmp.append(min(distance[d][np.nonzero(distance[d])])) dis_max_tmp.append(max(distance[d][np.nonzero(distance[d])])) atoms = np.array(atomic_numbers, dtype="object") uniq_atoms = np.unique(list(chain(*atomic_numbers))) self._atom_range = [np.min(uniq_atoms), np.max(uniq_atoms)] self.atom_range = atoms self.space_group_range = 
(np.min(space_group), np.max(space_group)) self.num_sites_range = (np.min(num_sites), np.max(num_sites)) self.num_sites = np.max(num_sites) df = pd.DataFrame( dict( a=a, b=b, c=c, min_distance=dis_min_tmp, max_distance=dis_max_tmp, ) ) low_quantile, upp_quantile = fit_quantiles low_df = ( df.apply(lambda a: np.quantile(a, low_quantile)) .drop(["max_distance"]) .rename(index={"min_distance": "distance"}) ) upp_df = ( df.apply(lambda a: np.quantile(a, upp_quantile)) .drop(["min_distance"]) .rename(index={"max_distance": "distance"}) ) low_df.name = "low" upp_df.name = "upp" range_df = pd.concat((low_df, upp_df), axis=1) for name, bounds in range_df.iterrows(): setattr(self, name + "_range", tuple(bounds))
[docs] def process_filepaths_or_structures( self, structures: List[Union[Structure, str, "PathLike[str]"]], ) -> Tuple[List[str], List[Structure]]: """Extract (or create) save names and convert/passthrough the structures. Parameters ---------- structures : Union[PathLike, Structure] List of filepaths or list of structures to be processed. Returns ------- savenames : List[str] Save names of the files if filepaths are passed, otherwise some relatively unique names (due to 4 random characters being appended at the end) for each structure. See ``construct_save_name``. S : List[Structure] Processed structures. Raises ------ ValueError "structures should be of same datatype, either strs or pymatgen Structures. structures[0] is {type(structures[0])}, but got type {type(s)} for entry {i}" ValueError "structures should be of same datatype, either strs or pymatgen Structures. structures[0] is {type(structures[0])}, but got type {type(s)} for entry {i}" ValueError "structures should be of type `str`, `os.PathLike` or `pymatgen.core.structure.Structure`, not {type(structures[i])} (entry {i})" Examples -------- >>> savenames, structures = process_filepaths_or_structures(structures) """ savenames: List[str] = [] first_is_structure = isinstance(structures[0], Structure) for i, s in enumerate(structures): if isinstance(s, str) or isinstance(s, PathLike): if first_is_structure: raise ValueError( f"structures should be of same datatype, either strs or pymatgen Structures. structures[0] is {type(structures[0])}, but got type {type(s)} for entry {i}" # noqa: E501 ) structures[i] = Structure.from_file(s) savenames.append(Path(str(s)).stem) elif isinstance(s, Structure): if not first_is_structure: raise ValueError( f"structures should be of same datatype, either strs or pymatgen Structures. 
structures[0] is {type(structures[0])}, but got type {type(s)} for entry {i}" # noqa ) structures[i] = s savenames.append(construct_save_name(s)) else: raise ValueError( f"structures should be of type `str`, `os.PathLike` or `pymatgen.core.structure.Structure`, not {type(structures[i])} (entry {i})" # noqa ) for i, s in enumerate(structures): assert isinstance( s, Structure ), f"structures[{i}]: {type(s)}, expected: Structure" assert not isinstance(s, str) and not isinstance(s, PathLike) if not s.is_ordered: raise ValueError( "xtal2png does not support disordered structures. " "Your input structure seems to contain partial occupancies. " "Please resolve those and try again." ) return savenames, structures # type: ignore
[docs] def png2xtal( self, images: List[Union[Image.Image, "PathLike"]], save: bool = False ) -> List[Structure]: """Decode PNG files as Structure objects. Parameters ---------- images : List[Union[Image.Image, 'PathLike']] PIL images that (approximately) encode crystal structures. Examples -------- >>> from xtal2png.utils.data import example_structures >>> xc = XtalConverter() >>> imgs = xc.xtal2png(example_structures) >>> xc.png2xtal(imgs) OUTPUT """ if not isinstance(images, list): raise ValueError( f"images (or filepaths) should be of type list, received {type(images)}" ) data_tmp = [] if self.channels == 1: mode = "L" elif self.channels == 3: mode = "RGB" else: raise ValueError( f"expected grayscale (1-channel) or RGB (3-channels) image, but got {self.channels}-channels. Either set channels to 1 or 3 or use xc.structures_to_arrays and xc.arrays_to_structures directly instead of xc.xtal2png and xc.png2xtal" # noqa: E501 ) for img in images: if isinstance(img, str): # load image from file with Image.open(img).convert(mode) as im: arr = np.asarray(im) elif isinstance(img, Image.Image): arr = np.asarray(img.convert(mode)) if mode == "RGB": arr = arr.transpose(2, 0, 1) data_tmp.append(arr) data = np.stack(data_tmp, axis=0) if mode == "L": data = np.expand_dims(data, 1) S = self.arrays_to_structures(data) if save: for s in self.tqdm_if_verbose(S): fpath = path.join(self.save_dir, construct_save_name(s) + ".cif") CifWriter( s, symprec=self.decode_symprec, angle_tolerance=self.decode_angle_tolerance, ).write_file(fpath) return S
# unscale values
[docs] def structures_to_arrays( self, structures: Sequence[Structure], rgb_scaling=True, ) -> Tuple[NDArray, NDArray, Dict[str, int]]: """Convert pymatgen Structure to scaled 3D array of crystallographic info. ``atomic_numbers`` and ``distance_matrix`` get padded or cropped as appropriate, as these depend on the number of sites in the structure. Parameters ---------- structures : Sequence[Structure] Sequence (e.g. list) of pymatgen Structure object(s) rgb_scaling : Whether to scale the arrays to RGB values (0-255), otherwise assume scaled between (0-1), by default True. Returns ------- data : ArrayLike RGB-scaled arrays with first dimension corresponding to each crystal structure. id_data : ArrayLike Same shape as ``data``, except one-hot encoded to distinguish between the various types of information contained in ``data``. See ``id_mapper`` for the "legend" for this data. id_mapper : ArrayLike Dictionary containing the legend/key between the names of the blocks and the corresponding numbers in ``id_data``. Raises ------ ValueError "`structures` should be a list of pymatgen Structure(s)" ValueError "crystal supplied with {n_sites} sites, which is more than {self.max_sites} sites. Remove crystal or increase `max_sites`." ValueError "len(atomic_numbers) {n_sites} and distance_matrix.shape[0] {s.distance_matrix.shape[0]} do not match" Examples -------- >>> xc = XtalConverter() >>> data, id_data, id_mapper = xc.structures_to_arrays(structures) OUTPUT """ for s in structures: if not isinstance(s, Structure): raise ValueError( "`structures` should be a list of pymatgen Structure(s)" ) if not s.is_ordered: raise ValueError( "xtal2png does not support disordered structures. " "Your input structure seems to contain partial occupancies. " "Please resolve those and try again." 
) # extract crystallographic information element_encoding: List[List[int]] = [] frac_coords_tmp: List[NDArray] = [] latt_a: List[float] = [] latt_b: List[float] = [] latt_c: List[float] = [] angles: List[List[float]] = [] num_sites: List[float] = [] space_group: List[int] = [] distance_matrix_tmp: List[NDArray[np.float64]] = [] for s in self.tqdm_if_verbose(structures): s = unit_cell_converter( s, self.encode_cell_type, symprec=self.encode_symprec, angle_tolerance=self.encode_angle_tolerance, ) # noqa: E501 n_sites = len(s.atomic_numbers) if n_sites > self.max_sites: raise ValueError( f"crystal supplied with {n_sites} sites, which is more than {self.max_sites} sites. Remove the offending crystal(s), increase `max_sites`, or use a more compact cell_type (see encode_cell_type and decode_cell_type kwargs)." # noqa: E501 ) element_encoding.append( np.pad( encode_many(s.atomic_numbers, self.element_encoding), (0, self.max_sites - n_sites), ).tolist() ) frac_coords_tmp.append( np.pad(s.frac_coords, ((0, self.max_sites - n_sites), (0, 0))) ) latt_a.append(s._lattice.a) latt_b.append(s._lattice.b) latt_c.append(s._lattice.c) angles.append(list(s._lattice.angles)) num_sites.append(s.num_sites) space_group.append(_get_space_group(s)) dm = s.distance_matrix # avoid repeat calculation if n_sites != dm.shape[0]: raise ValueError( f"len(atomic_numbers) {n_sites} and distance_matrix.shape[0] {dm.shape[0]} do not match" # noqa ) # noqa # assume that distance matrix is square padwidth = (0, self.max_sites - n_sites) distance_matrix_tmp.append(np.pad(dm, padwidth)) # [0:max_sites, 0:max_sites] frac_coords = np.stack(frac_coords_tmp) distance_matrix = np.stack(distance_matrix_tmp) if rgb_scaling: # REVIEW: consider using modified pettifor scale instead of atomic numbers # REVIEW: consider using feature_range=atom_range or 2*atom_range # REVIEW: since it introduces a sort of non-linearity b.c. of rounding # ToDo: the range below is not optimal. 
For this the fit should return a # list of all the elements atom_scaled = rgb_scaler( element_encoding, data_range=self._element_encoding_range, ) # noqa frac_scaled = rgb_scaler(frac_coords, data_range=self.frac_range) a_scaled = rgb_scaler(latt_a, data_range=self.a_range) b_scaled = rgb_scaler(latt_b, data_range=self.b_range) c_scaled = rgb_scaler(latt_c, data_range=self.c_range) angles_scaled = rgb_scaler(angles, data_range=self.angles_range) num_sites_scaled = rgb_scaler(num_sites, data_range=self.num_sites_range) space_group_scaled = rgb_scaler( space_group, data_range=self.space_group_range ) # NOTE: max_distance could be added as another (repeated) value/row to infer # NOTE: kind of like frac_distance_matrix, not sure if would be effective # NOTE: Or could normalize distance_matix by cell volume # NOTE: and possibly include cell volume as a (repeated) value/row to infer # NOTE: It's possible extra info like this isn't so bad, instilling the # physics # NOTE: but it could also just be extraneous work to predict/infer distance_scaled = rgb_scaler( distance_matrix, data_range=self.distance_range ) else: feature_range = (0.0, 1.0) atom_scaled = element_wise_scaler( element_encoding, feature_range=feature_range, data_range=self._element_encoding_range, ) frac_scaled = element_wise_scaler( frac_coords, feature_range=feature_range, data_range=self.frac_range ) a_scaled = element_wise_scaler( latt_a, feature_range=feature_range, data_range=self.a_range ) b_scaled = element_wise_scaler( latt_b, feature_range=feature_range, data_range=self.b_range ) c_scaled = element_wise_scaler( latt_c, feature_range=feature_range, data_range=self.c_range ) angles_scaled = element_wise_scaler( angles, feature_range=feature_range, data_range=self.angles_range ) num_sites_scaled = element_wise_scaler( num_sites, feature_range=feature_range, data_range=self.num_sites_range ) space_group_scaled = element_wise_scaler( space_group, feature_range=feature_range, 
data_range=self.space_group_range, ) distance_scaled = element_wise_scaler( distance_matrix, feature_range=feature_range, data_range=self.distance_range, ) atom_arr = np.expand_dims(atom_scaled, 2) frac_arr = frac_scaled a_arr = np.repeat(np.expand_dims(a_scaled, (1, 2)), self.max_sites, axis=1) b_arr = np.repeat(np.expand_dims(b_scaled, (1, 2)), self.max_sites, axis=1) c_arr = np.repeat(np.expand_dims(c_scaled, (1, 2)), self.max_sites, axis=1) angles_arr = np.repeat(np.expand_dims(angles_scaled, 1), self.max_sites, axis=1) num_sites_arr = np.repeat( np.expand_dims(num_sites_scaled, (1, 2)), self.max_sites, axis=1 ) space_group_arr = np.repeat( np.expand_dims(space_group_scaled, (1, 2)), self.max_sites, axis=1 ) distance_arr = distance_scaled data = self.assemble_blocks( atom_arr, frac_arr, a_arr, b_arr, c_arr, angles_arr, num_sites_arr, space_group_arr, distance_arr, ) id_mapper = { ATOM_KEY: ATOM_ID, FRAC_KEY: FRAC_ID, A_KEY: A_ID, B_KEY: B_ID, C_KEY: C_ID, ANGLES_KEY: ANGLES_ID, NUM_SITES_KEY: NUM_SITES_ID, SPACE_GROUP_KEY: SPACE_GROUP_ID, DISTANCE_KEY: DISTANCE_ID, } id_blocks = [ np.ones_like(atom_arr) * ATOM_ID, np.ones_like(frac_arr) * FRAC_ID, np.ones_like(a_arr) * A_ID, np.ones_like(b_arr) * B_ID, np.ones_like(c_arr) * C_ID, np.ones_like(angles_arr) * ANGLES_ID, np.ones_like(num_sites_arr) * NUM_SITES_ID, np.ones_like(space_group_arr) * SPACE_GROUP_ID, np.ones_like(distance_arr) * DISTANCE_ID, ] id_data = self.assemble_blocks(*id_blocks) # apply num_sites mask (zero out bottom and RHS blocks past num_sites) data = self.apply_num_sites_mask(data, num_sites) id_data = self.apply_num_sites_mask(id_data, num_sites) data = np.expand_dims(data, 1) id_data = np.expand_dims(id_data, 1) for mask_type in self.mask_types: if mask_type == LOWER_TRI_KEY: for d in data: if d.shape[1] != d.shape[2]: raise ValueError( f"Expected square matrix in last two dimensions, received {d.shape}" # noqa: E501 ) for sq in d: sq[np.tril_indices(d.shape[1])] = 0.0 else: data[id_data == 
id_mapper[mask_type]] = 0.0 data = np.repeat(data, self.channels, 1) id_data = np.repeat(id_data, self.channels, 1) return data, id_data, id_mapper
[docs] def assemble_blocks( self, atom_arr, frac_arr, a_arr, b_arr, c_arr, angles_arr, num_sites, space_group_arr, distance_arr, ) -> NDArray: arrays = [ atom_arr, frac_arr, a_arr, b_arr, c_arr, angles_arr, num_sites, space_group_arr, ] zero_pad = sum([arr.shape[2] for arr in arrays]) n_structures = atom_arr.shape[0] zero = np.zeros((n_structures, zero_pad, zero_pad), dtype=int) vertical_arr = np.block( [ [zero], [ atom_arr, frac_arr, a_arr, b_arr, c_arr, angles_arr, num_sites, space_group_arr, ], ] ) horizontal_arr = np.block( [ atom_arr, frac_arr, a_arr, b_arr, c_arr, angles_arr, num_sites, space_group_arr, ] ) horizontal_arr = np.moveaxis(horizontal_arr, 1, 2) left_arr = vertical_arr right_arr = np.block([[horizontal_arr], [distance_arr]]) data = np.block([left_arr, right_arr]) return data
[docs] def disassemble_blocks( self, data, # id_data: Optional[NDArray] = None, # id_mapper: Optional[dict] = None, ): # TODO: implement a more robust solution using id_data and id_mapper # if (id_data is None) is not (id_mapper is None): # raise ValueError( # f"id_data (type: {type(id_data)}) and id_mapper (type: {type(id_mapper)}) should either both be assigned or both be None." # noqa # ) # elif id_data is None and id_mapper is None: # _, id_data, id_mapper = self.structures_to_arrays(dummy_structures) # assert ( # id_data is not None and id_mapper is not None # ), "id_data and id_mapper should not be None at this point" zero_pad = 12 left_arr, right_arr = np.array_split(data, [zero_pad], axis=1) _, bottom_left = np.array_split(left_arr, [zero_pad], axis=2) lengths = [1, 3, 1, 1, 1, 3, 1] verts = np.array_split(bottom_left, np.cumsum(lengths), axis=1) top_right, bottom_right = np.array_split(right_arr, [zero_pad], axis=2) distance_arr = bottom_right horzs = np.array_split(top_right, np.cumsum(lengths), axis=2) def average_vert_horz(vert, horz): vert = np.float64(vert) horz = np.float64(horz) avg = (vert.swapaxes(1, 2) + horz) / 2 return avg avgs = [average_vert_horz(v, h) for v, h in zip(verts, horzs)] ( atom_arr, frac_arr, a_arr, b_arr, c_arr, angles_arr, num_sites_arr, space_group_arr, ) = avgs return ( atom_arr, frac_arr, a_arr, b_arr, c_arr, angles_arr, num_sites_arr, space_group_arr, distance_arr, )
def arrays_to_structures(
    self,
    data: np.ndarray,
    id_data: Optional[np.ndarray] = None,
    id_mapper: Optional[dict] = None,
    rgb_scaling: bool = True,
) -> List[Structure]:
    """Convert scaled crystal (xtal) arrays to pymatgen Structures.

    Parameters
    ----------
    data : np.ndarray
        Stack of encoded crystal images. Axis 1 is treated as the channel axis
        and averaged away, leaving one 2D array per structure.
    id_data : np.ndarray, optional
        Same shape as ``data``, except one-hot encoded to distinguish between
        the various types of information contained in ``data``. See
        ``id_mapper`` for the "legend" for this data. If both ``id_data`` and
        ``id_mapper`` are None, they are regenerated by encoding
        ``dummy_structures`` with ``self.structures_to_arrays``.
    id_mapper : dict, optional
        Dictionary containing the legend/key between the names of the blocks
        and the corresponding numbers in ``id_data``.
    rgb_scaling : bool, optional
        Whether the input arrays are scaled to RGB values (0-255), otherwise
        assume scaled between (0-1), by default True.

    Returns
    -------
    List[Structure]
        One decoded structure per entry along the first axis of ``data``,
        relaxed first when ``self.relax_on_decode`` is set, then passed through
        ``unit_cell_converter`` with ``self.decode_cell_type``.

    Raises
    ------
    ValueError
        If ``data`` is not an ``np.ndarray``, or if exactly one of ``id_data``
        and ``id_mapper`` is provided.
    ImportError
        If ``self.relax_on_decode`` is set but ``tensorflow`` or
        ``m3gnet.models.Relaxer`` cannot be imported.
    """
    if not isinstance(data, np.ndarray):
        raise ValueError(
            f"`data` should be of type `np.ndarray`. Received type {type(data)}. Maybe you passed a tuple of (data, id_data, id_mapper) returned from `structures_to_arrays()` by accident?"  # noqa: E501
        )

    def unscale(arr, data_range):
        # Invert whichever scaling was used to build the image (RGB 0-255 or
        # unit 0-1), mapping values back onto ``data_range``.
        if rgb_scaling:
            return rgb_unscaler(arr, data_range=data_range)
        return element_wise_unscaler(
            arr, feature_range=(0.0, 1.0), data_range=data_range
        )

    # convert to single channel and remove singleton dimension before disassembly
    data = np.mean(data, axis=1)

    # to extract num_sites for preprocess masking of data
    if id_data is None and id_mapper is None:
        _, id_data, id_mapper = self.structures_to_arrays(dummy_structures)

    if id_data is None or id_mapper is None:  # for mypy
        raise ValueError("id_data and id_mapper should not be None at this point")

    id_data = np.mean(id_data, axis=1)
    assert id_data is not None, "id_data should not be None at this point"

    # Pick out each image's num_sites block (located via the first id_data
    # layout), ignore zero pixels, and average the remaining values.
    # NOTE(review): if a structure has no positive num_sites pixels, np.mean
    # of an empty array gives nan, and casting nan to int is not meaningful.
    num_sites = [d[id_data[0] == id_mapper[NUM_SITES_KEY]] for d in data]
    num_sites = [ns[np.where(ns > 0)] for ns in num_sites]
    num_sites = [np.mean(ns) for ns in num_sites]
    num_sites = unscale(num_sites, self.num_sites_range)
    assert isinstance(num_sites, np.ndarray)
    num_sites = np.round(num_sites).astype(int)
    data = self.apply_num_sites_mask(data, num_sites)

    # for decoding final crystal structure
    arrays = self.disassemble_blocks(data)

    # Blocks that are a single column wide lose their trailing singleton axis.
    (
        atom_scaled,
        frac_scaled,
        a_scaled_tmp,
        b_scaled_tmp,
        c_scaled_tmp,
        angles_scaled_tmp,
        _,
        space_group_scaled_tmp,
        distance_scaled,
    ) = [np.squeeze(arr, axis=2) if arr.shape[2] == 1 else arr for arr in arrays]

    # Average only the non-zero (i.e. not masked) entries of each lattice block.
    a_scaled = np.mean(a_scaled_tmp, axis=1, where=a_scaled_tmp != 0)
    b_scaled = np.mean(b_scaled_tmp, axis=1, where=b_scaled_tmp != 0)
    c_scaled = np.mean(c_scaled_tmp, axis=1, where=c_scaled_tmp != 0)
    angles_scaled = np.mean(angles_scaled_tmp, axis=1, where=angles_scaled_tmp != 0)

    space_group_scaled = np.round(np.mean(space_group_scaled_tmp, axis=1)).astype(
        int
    )

    # ToDo: expose the distance options for the decoding
    atomic_symbols = [
        decode_many(
            encoding, self.element_encoding, metric=self.element_decoding_metric
        )
        for encoding in unscale(atom_scaled, self._element_encoding_range)
    ]
    frac_coords = unscale(frac_scaled, self.frac_range)
    latt_a = unscale(a_scaled, self.a_range)
    latt_b = unscale(b_scaled, self.b_range)
    latt_c = unscale(c_scaled, self.c_range)
    angles = unscale(angles_scaled, self.angles_range)
    # space_group and distance_matrix are not needed to build a Structure.
    space_group = unscale(space_group_scaled, self.space_group_range)
    distance_matrix = unscale(distance_scaled, self.distance_range)

    for dm, ns in zip(distance_matrix, num_sites):
        np.fill_diagonal(dm, 0.0)
        # mask bottom and RHS via num_sites
        dm[ns:, :] = 0.0
        dm[:, ns:] = 0.0

    # technically unused, but to avoid issue with pre-commit for now:
    space_group, distance_matrix

    # TODO: tweak lattice parameters to match predicted space group rules

    if self.relax_on_decode:
        try:
            import tensorflow as tf
            from m3gnet.models import Relaxer
        except ImportError as e:
            print(e)
            print(
                "For Windows users on Anaconda, you need to `pip install m3gnet` or set relax_on_decode=False."  # noqa: E501
            )
            # Re-raise: continuing would only fail below with an
            # UnboundLocalError on `tf`/`Relaxer`, hiding the real cause.
            raise

        if not self.verbose:
            tf.get_logger().setLevel(logging.ERROR)
        relaxer = Relaxer()  # This loads the default pre-trained model

    # build Structure-s
    S: List[Structure] = []
    num_structures = len(atomic_symbols)
    for i in self.tqdm_if_verbose(range(num_structures)):
        ns = num_sites[i]
        at = atomic_symbols[i][:ns]
        fr = frac_coords[i][:ns]
        a, b, c = latt_a[i], latt_b[i], latt_c[i]
        alpha, beta, gamma = angles[i]

        lattice = Lattice.from_parameters(
            a=a, b=b, c=c, alpha=alpha, beta=beta, gamma=gamma
        )
        s = Structure(lattice, at, fr)

        # REVIEW: round fractional coordinates to nearest multiple?
        if self.relax_on_decode:
            relaxed_results = relaxer.relax(s, verbose=self.verbose)
            s = relaxed_results["final_structure"]

        s = unit_cell_converter(
            s,
            self.decode_cell_type,
            symprec=self.decode_symprec,
            angle_tolerance=self.decode_angle_tolerance,
        )
        S.append(s)

    return S
def apply_num_sites_mask(self, data, num_sites):
    """Zero out, in place, every row and column past each structure's sites.

    The last ``self.max_sites`` rows/columns of each image are site slots.
    Everything at or beyond ``(width - self.max_sites) + n`` is cleared, where
    ``n`` is that structure's site count. Width is taken from
    ``data.shape[-1]`` and used for both axes.

    Parameters
    ----------
    data : np.ndarray
        Stack of images; each element is indexed as a 2D array. Modified in
        place.
    num_sites : sequence of int
        Number of occupied sites for each image in ``data``.

    Returns
    -------
    np.ndarray
        The same ``data`` object that was passed in, after masking.
    """
    site_offset = data.shape[-1] - self.max_sites
    for image, n_sites in zip(data, num_sites):
        cutoff = site_offset + n_sites
        image[cutoff:, :] = 0.0  # bottom block
        image[:, cutoff:] = 0.0  # right-hand block
    return data
def setup_logging(loglevel):
    """Configure root logging to write timestamped records to stdout.

    Args:
      loglevel (int): minimum loglevel for emitting messages
    """
    basic_config_kwargs = dict(
        level=loglevel,
        stream=sys.stdout,
        format="[%(asctime)s] %(levelname)s:%(name)s:%(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    logging.basicConfig(**basic_config_kwargs)