"""First-shell coordination targets extracted from crystalline reference cells."""
from __future__ import annotations
from dataclasses import dataclass
import numpy as np
from ase.atoms import Atoms
from ase.data import chemical_symbols
from ase.neighborlist import neighbor_list
from .g3 import _EPS
def _cell_face_spacings(cell_matrix: np.ndarray) -> np.ndarray:
"""Return periodic face spacings for a cell spanned by row vectors."""
inverse = np.linalg.inv(np.asarray(cell_matrix, dtype=np.float64))
return 1.0 / np.maximum(np.linalg.norm(inverse, axis=0), _EPS)
def _gaussian_kernel(sigma_bins: float) -> np.ndarray:
"""Return a normalized 1D Gaussian kernel."""
sigma_bins = float(sigma_bins)
if sigma_bins <= _EPS:
return np.array([1.0], dtype=np.float64)
radius = max(1, int(np.ceil(3.0 * sigma_bins)))
x = np.arange(-radius, radius + 1, dtype=np.float64)
kernel = np.exp(-0.5 * (x / max(sigma_bins, _EPS)) ** 2)
kernel /= max(np.sum(kernel), _EPS)
return kernel
def _smooth_histogram(values: np.ndarray, sigma_bins: float) -> np.ndarray:
"""Smooth a histogram using a small Gaussian kernel."""
kernel = _gaussian_kernel(sigma_bins)
return np.convolve(np.asarray(values, dtype=np.float64), kernel, mode="same")
def _first_local_maximum(values: np.ndarray) -> int:
"""Return the first prominent local maximum, falling back to the global one."""
values = np.asarray(values, dtype=np.float64)
if values.size == 0:
return 0
if values.size == 1:
return 0
peak_threshold = 0.25 * float(np.max(values))
for index in range(1, values.size - 1):
if values[index] >= peak_threshold and values[index] >= values[index - 1] and values[index] >= values[index + 1]:
return int(index)
return int(np.argmax(values))
def _infer_shell_window(
distances: np.ndarray,
*,
hist_step: float,
smooth_sigma_bins: float,
) -> tuple[float, float, float, float, float]:
"""Infer a first-shell window from reference pair distances."""
distances = np.sort(np.asarray(distances, dtype=np.float64))
if distances.size == 0:
return 0.0, 0.0, 0.0, hist_step, 0.0
upper = float(np.max(distances) + hist_step)
edges = np.arange(0.0, upper + hist_step, hist_step, dtype=np.float64)
if edges.size < 3:
edges = np.array([0.0, hist_step, 2.0 * hist_step], dtype=np.float64)
hist, _ = np.histogram(distances, bins=edges)
occupied = np.flatnonzero(hist > 0)
if occupied.size == 0:
return 0.0, 0.0, 0.0, hist_step, 0.0
first_occ = int(occupied[0])
left_index = max(first_occ - 1, 0)
right_index = hist.size - 1
zero_run = 0
zero_right_index: int | None = None
for index in range(first_occ + 1, hist.size):
if hist[index] == 0:
zero_run += 1
if zero_run >= 2:
zero_right_index = index - zero_run + 1
break
else:
zero_run = 0
if zero_right_index is not None:
right_index = min(hist.size - 1, zero_right_index)
else:
smooth = _smooth_histogram(hist, sigma_bins=smooth_sigma_bins)
peak_index = _first_local_maximum(smooth[first_occ:]) + first_occ
peak_value = float(max(smooth[peak_index], _EPS))
cutoff_value = 0.12 * peak_value
right_index = peak_index
while right_index < smooth.size - 1 and smooth[right_index] > cutoff_value:
right_index += 1
r_inner = float(edges[max(left_index, 0)])
r_outer = float(edges[min(right_index + 1, edges.size - 1)])
in_shell = distances[(distances >= r_inner) & (distances <= r_outer)]
if in_shell.size == 0:
in_shell = distances
r_peak = float(np.mean(in_shell))
sigma_r = float(max(np.std(in_shell), hist_step))
hard_min = float(max(0.0, r_inner - 1.5 * sigma_r))
return hard_min, r_inner, r_peak, sigma_r, r_outer
[docs]
@dataclass(frozen=True)
class CoordinationShellTarget:
"""Species-aware first-shell coordination targets extracted from a crystal."""
atoms: Atoms
label: str
species: np.ndarray
species_labels: tuple[str, ...]
phi_num_bins: int
phi_edges: np.ndarray
phi: np.ndarray
phi_deg: np.ndarray
angle_index: np.ndarray
angle_lookup: np.ndarray
pair_hard_min: np.ndarray
pair_inner: np.ndarray
pair_peak: np.ndarray
pair_sigma: np.ndarray
pair_outer: np.ndarray
pair_mask: np.ndarray
coordination_target: np.ndarray
coordination_std: np.ndarray
angle_target: np.ndarray
angle_pair_mass_target: np.ndarray
angle_mode_deg: np.ndarray
# Per-triplet mask that controls whether ``shell_relax`` installs an
# angle spring for that triplet type. Bond-distance springs are
# unaffected. Useful for multi-modal shells (SrO\u2081\u2082
# cuboctahedron, Sr-O-Sr triplets in SrTiO\u2083) where forcing a
# single ``angle_mode_deg`` would strain pairs at the other modes.
angle_enabled_mask: np.ndarray
motif_center_species: np.ndarray
motif_neighbor_species: tuple[np.ndarray, ...]
motif_neighbor_vectors: tuple[np.ndarray, ...]
max_pair_outer: float
max_pair_outer_by_center: np.ndarray
summary: dict[str, object]
[docs]
@classmethod
def from_atoms(
cls,
atoms: Atoms,
*,
phi_num_bins: int = 72,
shell_hist_step: float = 0.05,
shell_smooth_sigma_bins: float = 1.2,
extract_cutoff: float | None = None,
auto_filter_lattice_artifacts: bool = True,
label: str | None = None,
) -> "CoordinationShellTarget":
"""Extract first-shell coordination, distance, and angle targets from a reference crystal.
For every species pair in ``atoms`` the method fits the first
peak of the radial pair distribution g(r) — its inner edge,
peak position, Gaussian width, and outer edge — and counts the
average number of neighbours each centre atom has within that
window. For every triplet (centre, neighbour-A, neighbour-B)
species combination it builds a histogram of bond angles
ranged 0–180° using ``phi_num_bins`` bins, identifying the
dominant angle mode. All quantities are stored as 2-D /
3-D numpy arrays indexed by species index in ``species``.
Parameters
----------
atoms : ase.Atoms
Reference crystal whose g(r) and bond-angle distributions
define the target geometry. Must be a periodic cell with
at least one neighbour pair within
``extract_cutoff`` (auto if ``None``).
phi_num_bins : int, optional
Number of bins used to discretise the [0, 180°] bond-angle
axis. Default ``72`` (2.5° per bin). Higher values
sharpen the angle target but slow down the angle
measurement loop in :meth:`Supercell.measure_g3`.
shell_hist_step : float, optional
Bin width (Å) of the per-pair radial histogram used to
locate the first peak. Default ``0.05`` Å.
shell_smooth_sigma_bins : float, optional
Gaussian smoothing width (in bins) applied to the radial
histogram before peak detection. Default ``1.2`` bins.
extract_cutoff : float, optional
Maximum centre-neighbour distance (Å) considered when
building the neighbour list. If ``None`` the routine picks
``min(default, max(3.8 × NN, NN + 2))`` based on the
inferred nearest-neighbour distance.
auto_filter_lattice_artifacts : bool, optional
If ``True`` (default), zero out ``coordination_target`` for
species pairs that represent lattice artefacts rather than
real chemical bonds. A pair ``(i, j)`` is kept only when
``pair_peak[i, j]`` is the smallest enabled peak in
either row ``i`` or column ``j``. This automatically
silences the second-shell ``Si-Si`` / ``O-O`` springs in
``SiO2``, the ``Sr-Sr`` / ``Ti-Ti`` / ``O-O`` / ``Sr-Ti``
springs in ``SrTiO3``, etc. Set to ``False`` to keep every
extracted pair (pre-2026 behaviour); callers can also
override the filter by chaining
:meth:`with_bonded_species_pairs` after extraction.
label : str, optional
Free-text identifier carried along on the returned target
(used in plot legends and HTML viewer titles). Default
uses the chemical formula of ``atoms``.
Returns
-------
CoordinationShellTarget
Frozen dataclass populated with every per-species and
per-triplet field listed in the class header
(``coordination_target``, ``pair_peak``, ``pair_inner``,
``angle_target``, ``angle_mode_deg``, etc.). All ndarray
fields are pre-symmetrised over species pairs and the
angle-enabled mask is initialised to ``True`` everywhere.
Notes
-----
Single-element references (Si, Cu, …) produce a 1×1
coordination matrix and one self-self angle channel.
Multi-element references (SiO₂, SrTiO₃, …) populate every
cross-pair entry — see
:meth:`with_cross_species_bonds_only` and
:meth:`with_bonded_species_pairs` for masking helpers when
only a subset of pairs represents real chemical bonds.
Examples
--------
>>> from ase.build import bulk
>>> import tricor as tc
>>> atoms = bulk("Si", "diamond", a=5.431)
>>> shell = tc.CoordinationShellTarget.from_atoms(atoms,
... phi_num_bins=90)
>>> shell.coordination_target # 4 NN per Si
array([[4.]])
>>> float(shell.pair_peak[0, 0])
2.352
"""
atoms = atoms.copy()
species = np.unique(np.asarray(atoms.numbers, dtype=np.int64))
num_species = int(species.size)
species_index = np.searchsorted(species, np.asarray(atoms.numbers, dtype=np.int64))
species_labels = tuple(chemical_symbols[int(spec)] for spec in species)
phi_num_bins = int(phi_num_bins)
if phi_num_bins <= 0:
raise ValueError("phi_num_bins must be positive.")
phi_edges = np.linspace(0.0, np.pi, phi_num_bins + 1, dtype=np.float64)
phi = 0.5 * (phi_edges[:-1] + phi_edges[1:])
phi_deg = np.rad2deg(phi)
angle_index = []
angle_lookup = -np.ones((num_species, num_species, num_species), dtype=np.intp)
for center_index in range(num_species):
for neigh1_index in range(num_species):
for neigh2_index in range(neigh1_index, num_species):
triplet_index = len(angle_index)
angle_index.append((center_index, neigh1_index, neigh2_index))
angle_lookup[center_index, neigh1_index, neigh2_index] = triplet_index
angle_lookup[center_index, neigh2_index, neigh1_index] = triplet_index
angle_index = np.asarray(angle_index, dtype=np.intp)
cell_matrix = np.asarray(atoms.cell.array, dtype=np.float64)
cell_face_spacings = _cell_face_spacings(cell_matrix)
cell_lengths = np.linalg.norm(cell_matrix, axis=1)
default_probe_cutoff = max(8.0, 1.5 * float(np.max(cell_lengths)), 1.2 * float(np.max(cell_face_spacings)))
if extract_cutoff is None:
probe_cutoff = default_probe_cutoff
else:
probe_cutoff = float(extract_cutoff)
if probe_cutoff <= 0:
raise ValueError("extract_cutoff must be positive when provided.")
i, j, d, D = neighbor_list(
"ijdD",
atoms,
probe_cutoff,
self_interaction=False,
)
if d.size == 0:
raise ValueError("Could not detect any periodic neighbors in the reference cell.")
nearest = np.full(len(atoms), np.inf, dtype=np.float64)
np.minimum.at(nearest, i.astype(np.intp), d.astype(np.float64))
finite_nearest = nearest[np.isfinite(nearest)]
if finite_nearest.size == 0:
raise ValueError("Failed to infer nearest-neighbor distances from the reference cell.")
nearest_reference = float(np.median(finite_nearest))
if extract_cutoff is None:
probe_cutoff = min(default_probe_cutoff, max(3.8 * nearest_reference, nearest_reference + 2.0))
i, j, d, D = neighbor_list(
"ijdD",
atoms,
probe_cutoff,
self_interaction=False,
)
center_species = species_index[np.asarray(i, dtype=np.intp)]
neighbor_species = species_index[np.asarray(j, dtype=np.intp)]
pair_hard_min = np.zeros((num_species, num_species), dtype=np.float64)
pair_inner = np.zeros((num_species, num_species), dtype=np.float64)
pair_peak = np.zeros((num_species, num_species), dtype=np.float64)
pair_sigma = np.full((num_species, num_species), float(shell_hist_step), dtype=np.float64)
pair_outer = np.zeros((num_species, num_species), dtype=np.float64)
pair_mask = np.zeros((num_species, num_species), dtype=bool)
for species_a in range(num_species):
for species_b in range(species_a, num_species):
mask = (
((center_species == species_a) & (neighbor_species == species_b))
| ((center_species == species_b) & (neighbor_species == species_a))
)
distances = np.asarray(d[mask], dtype=np.float64)
if distances.size == 0:
continue
hard_min, r_inner, r_peak, sigma_r, r_outer = _infer_shell_window(
distances,
hist_step=float(shell_hist_step),
smooth_sigma_bins=float(shell_smooth_sigma_bins),
)
pair_hard_min[species_a, species_b] = hard_min
pair_hard_min[species_b, species_a] = hard_min
pair_inner[species_a, species_b] = r_inner
pair_inner[species_b, species_a] = r_inner
pair_peak[species_a, species_b] = r_peak
pair_peak[species_b, species_a] = r_peak
pair_sigma[species_a, species_b] = sigma_r
pair_sigma[species_b, species_a] = sigma_r
pair_outer[species_a, species_b] = r_outer
pair_outer[species_b, species_a] = r_outer
pair_mask[species_a, species_b] = True
pair_mask[species_b, species_a] = True
coordination_target = np.zeros((num_species, num_species), dtype=np.float64)
coordination_std = np.zeros((num_species, num_species), dtype=np.float64)
for center_ind in range(num_species):
centers = np.flatnonzero(species_index == center_ind)
if centers.size == 0:
continue
for neigh_ind in range(num_species):
if not pair_mask[center_ind, neigh_ind]:
continue
r_inner = pair_inner[center_ind, neigh_ind]
r_outer = pair_outer[center_ind, neigh_ind]
mask = (
(center_species == center_ind)
& (neighbor_species == neigh_ind)
& (d >= r_inner)
& (d <= r_outer)
)
counts = np.bincount(np.asarray(i[mask], dtype=np.intp), minlength=len(atoms))
centered_counts = counts[centers]
coordination_target[center_ind, neigh_ind] = float(np.mean(centered_counts))
coordination_std[center_ind, neigh_ind] = float(np.std(centered_counts))
# ----------------------------------------------------------------
# Auto-filter lattice-artefact bond pairs.
#
# In multi-element crystals like SiO2 and SrTiO3 the
# neighbour-list extraction also picks up the *second* shell
# (Si-Si at 3.06 Å in alpha-quartz, O-O at 2.64 Å, Sr-Sr at
# 3.91 Å in SrTiO3, etc.) and emits a non-zero
# ``coordination_target`` for those pairs. Treating them as
# bonds during shell relaxation puts geometrically-incompatible
# springs on the same atom (each Si simultaneously gets pulled
# to 4 O at 1.61 Å AND 4 Si at 3.06 Å, etc.) and FIRE thrashes
# without converging.
#
# Heuristic: pair (i, j) is a real chemical bond iff
# ``pair_peak[i, j]`` is the smallest enabled peak for at least
# one of {row i, column j}. In words: the pair must represent
# a direct atom-atom contact for at least one of the two
# species. Lattice artefacts (which always go through a
# bridging atom or a longer lattice vector) are larger than
# both sides' minimum and get zeroed out.
#
# Concrete examples:
# SiO2: Si-O is min for Si and O → real
# Si-Si > min(Si row) → artefact (zeroed)
# O-O > min(O row) → artefact (zeroed)
# SrTiO3: Ti-O is min for Ti and O → real
# Sr-O is min for Sr only → real
# Sr-Sr > min(Sr row) → artefact (zeroed)
# Ti-Ti > min(Ti row) → artefact (zeroed)
# Sr-Ti > min(Sr) > min(Ti) → artefact (zeroed)
# Cu/Si: only one species, sole pair → real
#
# Set ``auto_filter_lattice_artifacts=False`` to disable and
# recover the pre-fix behaviour (every pair_mask entry kept).
# Callers can always override afterwards via
# :meth:`with_bonded_species_pairs` /
# :meth:`with_cross_species_bonds_only`.
if auto_filter_lattice_artifacts and num_species >= 2:
enabled_peaks = np.where(pair_mask, pair_peak, np.inf)
min_per_row = np.min(enabled_peaks, axis=1)
min_per_col = np.min(enabled_peaks, axis=0)
tol = 1e-6 # numerical slack for symmetric peaks
for i_sp in range(num_species):
for j_sp in range(num_species):
if not pair_mask[i_sp, j_sp]:
continue
peak_ij = pair_peak[i_sp, j_sp]
is_min_for_row = peak_ij <= min_per_row[i_sp] + tol
is_min_for_col = peak_ij <= min_per_col[j_sp] + tol
if not (is_min_for_row or is_min_for_col):
coordination_target[i_sp, j_sp] = 0.0
angle_target = np.zeros((angle_index.shape[0], phi_num_bins), dtype=np.float64)
angle_pair_mass_target = np.zeros(angle_index.shape[0], dtype=np.float64)
angle_mode_deg = np.zeros(angle_index.shape[0], dtype=np.float64)
# Default: angle springs enabled for every triplet type. Use
# ``with_angle_triplets`` / ``without_angle_triplets`` to mask
# specific triplets for multi-modal shells.
angle_enabled_mask = np.ones(angle_index.shape[0], dtype=bool)
motif_center_species: list[int] = []
motif_neighbor_species: list[np.ndarray] = []
motif_neighbor_vectors: list[np.ndarray] = []
neighbors_by_center: list[dict[str, np.ndarray]] = []
for atom_index in range(len(atoms)):
mask = np.asarray(i, dtype=np.intp) == int(atom_index)
neighbors_by_center.append(
{
"neighbor_index": np.asarray(j[mask], dtype=np.intp),
"neighbor_species": neighbor_species[mask].astype(np.intp, copy=False),
"vectors": np.asarray(D[mask], dtype=np.float64),
"distance": np.asarray(d[mask], dtype=np.float64),
}
)
for center_atom in range(len(atoms)):
center_species_index = int(species_index[center_atom])
local = neighbors_by_center[center_atom]
if local["neighbor_index"].size == 0:
motif_center_species.append(center_species_index)
motif_neighbor_species.append(np.empty(0, dtype=np.intp))
motif_neighbor_vectors.append(np.empty((0, 3), dtype=np.float64))
continue
keep = np.zeros(local["neighbor_index"].shape[0], dtype=bool)
for neighbor_ind, neighbor_species_index in enumerate(local["neighbor_species"]):
if not pair_mask[center_species_index, int(neighbor_species_index)]:
continue
radius = float(local["distance"][neighbor_ind])
keep[neighbor_ind] = (
radius >= float(pair_inner[center_species_index, int(neighbor_species_index)])
and radius <= float(pair_outer[center_species_index, int(neighbor_species_index)])
)
local_species = local["neighbor_species"][keep].astype(np.intp, copy=False)
local_vectors = local["vectors"][keep].astype(np.float64, copy=False)
if local_species.size:
radius = np.linalg.norm(local_vectors, axis=1)
order = np.lexsort(
(
local_vectors[:, 2],
local_vectors[:, 1],
local_vectors[:, 0],
radius,
local_species,
)
)
local_species = local_species[order]
local_vectors = local_vectors[order]
motif_center_species.append(center_species_index)
motif_neighbor_species.append(np.array(local_species, dtype=np.intp, copy=True))
motif_neighbor_vectors.append(np.array(local_vectors, dtype=np.float64, copy=True))
for center_atom in range(len(atoms)):
center_species_index = int(species_index[center_atom])
local = neighbors_by_center[center_atom]
if local["neighbor_index"].size == 0:
continue
for triplet_index, (_, species_1, species_2) in enumerate(angle_index):
if angle_index[triplet_index, 0] != center_species_index:
continue
inner_1 = pair_inner[center_species_index, species_1]
outer_1 = pair_outer[center_species_index, species_1]
inner_2 = pair_inner[center_species_index, species_2]
outer_2 = pair_outer[center_species_index, species_2]
mask_1 = (
(local["neighbor_species"] == species_1)
& (local["distance"] >= inner_1)
& (local["distance"] <= outer_1)
)
mask_2 = (
(local["neighbor_species"] == species_2)
& (local["distance"] >= inner_2)
& (local["distance"] <= outer_2)
)
if not np.any(mask_1) or not np.any(mask_2):
continue
v1 = local["vectors"][mask_1]
v2 = local["vectors"][mask_2]
r1_sq = np.einsum("ij,ij->i", v1, v1)
r2_sq = np.einsum("ij,ij->i", v2, v2)
if species_1 == species_2:
if v1.shape[0] < 2:
continue
dot = v1 @ v2.T
denom = np.sqrt(np.maximum(r1_sq[:, None] * r2_sq[None, :], _EPS))
cos_phi = np.clip(dot / denom, -1.0, 1.0)
phi_bin = np.floor(np.arccos(cos_phi) / (phi_edges[1] - phi_edges[0])).astype(np.intp)
np.clip(phi_bin, 0, phi_num_bins - 1, out=phi_bin)
upper = np.triu_indices(phi_bin.shape[0], k=1)
bins = phi_bin[upper]
else:
dot = v1 @ v2.T
denom = np.sqrt(np.maximum(r1_sq[:, None] * r2_sq[None, :], _EPS))
cos_phi = np.clip(dot / denom, -1.0, 1.0)
phi_bin = np.floor(np.arccos(cos_phi) / (phi_edges[1] - phi_edges[0])).astype(np.intp)
np.clip(phi_bin, 0, phi_num_bins - 1, out=phi_bin)
bins = phi_bin.ravel()
if bins.size == 0:
continue
angle_target[triplet_index] += np.bincount(bins, minlength=phi_num_bins)
angle_pair_mass_target[triplet_index] += float(bins.size)
centers_per_species = np.bincount(species_index, minlength=num_species).astype(np.float64)
for triplet_index, (center_ind, _, _) in enumerate(angle_index):
mass = float(angle_pair_mass_target[triplet_index])
if mass > 0.0:
angle_target[triplet_index] /= mass
angle_pair_mass_target[triplet_index] = mass / max(float(centers_per_species[center_ind]), 1.0)
angle_mode_deg[triplet_index] = float(phi_deg[int(np.argmax(angle_target[triplet_index]))])
max_pair_outer = float(np.max(pair_outer[pair_mask])) if np.any(pair_mask) else 0.0
max_pair_outer_by_center = np.zeros(num_species, dtype=np.float64)
for center_ind in range(num_species):
row = pair_outer[center_ind][pair_mask[center_ind]]
max_pair_outer_by_center[center_ind] = float(np.max(row)) if row.size else 0.0
summary = {
"num_atoms": len(atoms),
"num_species": num_species,
"phi_num_bins": phi_num_bins,
"extract_cutoff": float(probe_cutoff),
"nearest_reference": nearest_reference,
"max_pair_outer": max_pair_outer,
"num_triplets": int(angle_index.shape[0]),
"num_motifs": int(len(motif_center_species)),
}
return cls(
atoms=atoms,
label=label or "coordination-shell-target",
species=species.astype(np.int64, copy=False),
species_labels=species_labels,
phi_num_bins=phi_num_bins,
phi_edges=phi_edges,
phi=phi,
phi_deg=phi_deg,
angle_index=angle_index,
angle_lookup=angle_lookup,
pair_hard_min=pair_hard_min,
pair_inner=pair_inner,
pair_peak=pair_peak,
pair_sigma=pair_sigma,
pair_outer=pair_outer,
pair_mask=pair_mask,
coordination_target=coordination_target,
coordination_std=coordination_std,
angle_target=angle_target,
angle_pair_mass_target=angle_pair_mass_target,
angle_mode_deg=angle_mode_deg,
angle_enabled_mask=angle_enabled_mask,
motif_center_species=np.asarray(motif_center_species, dtype=np.intp),
motif_neighbor_species=tuple(motif_neighbor_species),
motif_neighbor_vectors=tuple(motif_neighbor_vectors),
max_pair_outer=max_pair_outer,
max_pair_outer_by_center=max_pair_outer_by_center,
summary=summary,
)
[docs]
@classmethod
def from_targets(
cls,
targets: "dict[str, CoordinationShellTarget]",
*,
cross_pair_peak: "dict[tuple[str, str], float] | None" = None,
cross_pair_outer_scale: float = 1.15,
label: str | None = None,
) -> "CoordinationShellTarget":
"""Stack multiple shell targets into one with a widened species axis.
Used for blended materials where atoms share an atomic number but
want different local coordination (e.g. graphite sp\u00b2 + diamond
sp\u00b3 carbon). Each input target contributes a *virtual species
slot* per element of its ``species`` array; the composite target's
``species`` is the concatenation of all inputs, with
``species_labels`` rewritten as ``f"{key}_{element}"``.
Cross-target pairs default to:
- ``coordination_target = 0`` (no bonds form across virtual
species boundaries; the repulsion term still keeps them apart),
- ``pair_peak = mean(peak_a, peak_b)`` unless overridden by
``cross_pair_peak``,
- ``pair_outer = max(outer_a, outer_b) * cross_pair_outer_scale``,
- ``pair_hard_min`` / ``pair_inner`` pro-rated from the two
source values.
Cross-target triplets (any of the three species drawn from a
different source than the other two) get
``coordination_target = 0`` and zero ``angle_mode_deg`` - the
relaxer will never enumerate these triplets because no such
bonds form.
Parameters
----------
targets
Mapping ``{key: CoordinationShellTarget}``. Insertion order
defines the virtual-species order.
cross_pair_peak
Optional overrides for ``pair_peak`` between elements drawn
from different source targets. Keys are ``(key_a, key_b)``
with symbol lookup done via ``atomic_numbers`` when pair
contains non-tuple element labels.
cross_pair_outer_scale
Multiplier applied to the larger of the two source
``pair_outer`` values when populating cross-target entries.
label
Optional label; defaults to ``"composite(" + keys + ")"``.
"""
from dataclasses import replace as _dc_replace # noqa: F401
if len(targets) == 0:
raise ValueError("from_targets requires at least one target.")
keys = list(targets.keys())
first = targets[keys[0]]
phi_num_bins = int(first.phi_num_bins)
for key in keys[1:]:
t = targets[key]
if int(t.phi_num_bins) != phi_num_bins:
raise ValueError(
f"All targets must share phi_num_bins; got "
f"{phi_num_bins} (from {keys[0]!r}) and "
f"{int(t.phi_num_bins)} (from {key!r})."
)
# Per-source species counts and global offsets.
src_species_counts = [int(np.asarray(targets[k].species).size) for k in keys]
offsets = np.zeros(len(keys) + 1, dtype=np.intp)
offsets[1:] = np.cumsum(src_species_counts)
num_species = int(offsets[-1])
# Track which source (key index) each global species belongs to.
species_source = np.zeros(num_species, dtype=np.intp)
for ki, count in enumerate(src_species_counts):
species_source[offsets[ki] : offsets[ki + 1]] = ki
# --- concat species + labels ---
species = np.zeros(num_species, dtype=np.int64)
species_labels: list[str] = []
for ki, key in enumerate(keys):
t = targets[key]
species[offsets[ki] : offsets[ki + 1]] = np.asarray(t.species, dtype=np.int64)
for sym in t.species_labels:
species_labels.append(f"{key}_{sym}")
# --- block-diagonal pair-array scaffold ---
pair_hard_min = np.zeros((num_species, num_species), dtype=np.float64)
pair_inner = np.zeros_like(pair_hard_min)
pair_peak = np.zeros_like(pair_hard_min)
pair_sigma = np.zeros_like(pair_hard_min)
pair_outer = np.zeros_like(pair_hard_min)
pair_mask = np.zeros_like(pair_hard_min, dtype=bool)
coordination_target = np.zeros_like(pair_hard_min)
coordination_std = np.zeros_like(pair_hard_min)
for ki, key in enumerate(keys):
t = targets[key]
a, b = int(offsets[ki]), int(offsets[ki + 1])
pair_hard_min[a:b, a:b] = np.asarray(t.pair_hard_min, dtype=np.float64)
pair_inner[a:b, a:b] = np.asarray(t.pair_inner, dtype=np.float64)
pair_peak[a:b, a:b] = np.asarray(t.pair_peak, dtype=np.float64)
pair_sigma[a:b, a:b] = np.asarray(t.pair_sigma, dtype=np.float64)
pair_outer[a:b, a:b] = np.asarray(t.pair_outer, dtype=np.float64)
pair_mask[a:b, a:b] = np.asarray(t.pair_mask, dtype=bool)
coordination_target[a:b, a:b] = np.asarray(t.coordination_target, dtype=np.float64)
coordination_std[a:b, a:b] = np.asarray(t.coordination_std, dtype=np.float64)
# --- cross-target pair entries (repulsion-only by default) ---
cross_peak_lookup: dict[tuple[str, str], float] = {}
if cross_pair_peak is not None:
for (ka, kb), v in cross_pair_peak.items():
cross_peak_lookup[(ka, kb)] = float(v)
cross_peak_lookup[(kb, ka)] = float(v)
for i in range(num_species):
for j in range(num_species):
if species_source[i] == species_source[j]:
continue
# Use the source's own self-pair as a proxy for the
# same-element repulsion wall on each side.
ki, kj = int(species_source[i]), int(species_source[j])
key_a, key_b = keys[ki], keys[kj]
peak_a = float(pair_peak[i, i])
peak_b = float(pair_peak[j, j])
inner_a = float(pair_inner[i, i])
inner_b = float(pair_inner[j, j])
outer_a = float(pair_outer[i, i])
outer_b = float(pair_outer[j, j])
hmin_a = float(pair_hard_min[i, i])
hmin_b = float(pair_hard_min[j, j])
sig_a = float(pair_sigma[i, i])
sig_b = float(pair_sigma[j, j])
if (key_a, key_b) in cross_peak_lookup:
peak_ij = cross_peak_lookup[(key_a, key_b)]
elif peak_a > 0 and peak_b > 0:
peak_ij = 0.5 * (peak_a + peak_b)
else:
peak_ij = max(peak_a, peak_b)
pair_peak[i, j] = peak_ij
pair_inner[i, j] = 0.5 * (inner_a + inner_b)
pair_outer[i, j] = max(outer_a, outer_b) * float(cross_pair_outer_scale)
pair_hard_min[i, j] = max(hmin_a, hmin_b)
pair_sigma[i, j] = max(sig_a, sig_b) if (sig_a + sig_b) > 0 else 0.05
pair_mask[i, j] = True
# coordination_target[i, j] stays 0 - no cross bonds.
# --- rebuild triplet index over the widened species set ---
angle_index_list: list[tuple[int, int, int]] = []
angle_lookup_new = -np.ones((num_species, num_species, num_species), dtype=np.intp)
for c in range(num_species):
for n1 in range(num_species):
for n2 in range(n1, num_species):
idx = len(angle_index_list)
angle_index_list.append((c, n1, n2))
angle_lookup_new[c, n1, n2] = idx
angle_lookup_new[c, n2, n1] = idx
angle_index_new = np.asarray(angle_index_list, dtype=np.intp)
angle_target = np.zeros((angle_index_new.shape[0], phi_num_bins), dtype=np.float64)
angle_pair_mass_target = np.zeros(angle_index_new.shape[0], dtype=np.float64)
angle_mode_deg = np.zeros(angle_index_new.shape[0], dtype=np.float64)
# Default: True everywhere; cross-source triplets will never fire
# anyway because their coordination_target is zero.
angle_enabled_mask = np.ones(angle_index_new.shape[0], dtype=bool)
# Copy triplets where all three species come from the same source.
for ki, key in enumerate(keys):
t = targets[key]
a = int(offsets[ki])
src_angle_index = np.asarray(t.angle_index, dtype=np.intp)
src_angle_lookup = np.asarray(t.angle_lookup, dtype=np.intp)
src_angle_target = np.asarray(t.angle_target, dtype=np.float64)
src_angle_mass = np.asarray(t.angle_pair_mass_target, dtype=np.float64)
src_angle_mode = np.asarray(t.angle_mode_deg, dtype=np.float64)
src_angle_mask = np.asarray(
getattr(t, "angle_enabled_mask",
np.ones(src_angle_index.shape[0], dtype=bool)),
dtype=bool,
)
for local_t, (lc, ln1, ln2) in enumerate(src_angle_index):
gc = int(a + lc)
gn1 = int(a + ln1)
gn2 = int(a + ln2)
new_t = int(angle_lookup_new[gc, gn1, gn2])
angle_target[new_t] = src_angle_target[local_t]
angle_pair_mass_target[new_t] = float(src_angle_mass[local_t])
angle_mode_deg[new_t] = float(src_angle_mode[local_t])
angle_enabled_mask[new_t] = bool(src_angle_mask[local_t])
del src_angle_lookup # silence unused
# --- concatenate motif arrays with species remapping ---
motif_center_species_list: list[int] = []
motif_neighbor_species_list: list[np.ndarray] = []
motif_neighbor_vectors_list: list[np.ndarray] = []
for ki, key in enumerate(keys):
t = targets[key]
a = int(offsets[ki])
src_center = np.asarray(t.motif_center_species, dtype=np.intp)
for local_i, cs in enumerate(src_center):
motif_center_species_list.append(int(a + int(cs)))
ns = np.asarray(t.motif_neighbor_species[local_i], dtype=np.intp) + a
vs = np.asarray(t.motif_neighbor_vectors[local_i], dtype=np.float64)
motif_neighbor_species_list.append(ns)
motif_neighbor_vectors_list.append(vs)
max_pair_outer = float(np.max(pair_outer[pair_mask])) if np.any(pair_mask) else 0.0
max_pair_outer_by_center = np.zeros(num_species, dtype=np.float64)
for center_ind in range(num_species):
row = pair_outer[center_ind][pair_mask[center_ind]]
max_pair_outer_by_center[center_ind] = float(np.max(row)) if row.size else 0.0
phi_edges = np.asarray(first.phi_edges, dtype=np.float64)
phi = np.asarray(first.phi, dtype=np.float64)
phi_deg = np.asarray(first.phi_deg, dtype=np.float64)
summary = {
"composite": True,
"keys": tuple(keys),
"source_species_counts": tuple(int(c) for c in src_species_counts),
"phi_num_bins": phi_num_bins,
"num_species": num_species,
"num_triplets": int(angle_index_new.shape[0]),
"max_pair_outer": float(max_pair_outer),
}
composite_label = label or f"composite({'+'.join(keys)})"
return cls(
atoms=first.atoms, # placeholder; not used by shell_relax
label=composite_label,
species=species,
species_labels=tuple(species_labels),
phi_num_bins=phi_num_bins,
phi_edges=phi_edges,
phi=phi,
phi_deg=phi_deg,
angle_index=angle_index_new,
angle_lookup=angle_lookup_new,
pair_hard_min=pair_hard_min,
pair_inner=pair_inner,
pair_peak=pair_peak,
pair_sigma=pair_sigma,
pair_outer=pair_outer,
pair_mask=pair_mask,
coordination_target=coordination_target,
coordination_std=coordination_std,
angle_target=angle_target,
angle_pair_mass_target=angle_pair_mass_target,
angle_mode_deg=angle_mode_deg,
angle_enabled_mask=angle_enabled_mask,
motif_center_species=np.asarray(motif_center_species_list, dtype=np.intp),
motif_neighbor_species=tuple(motif_neighbor_species_list),
motif_neighbor_vectors=tuple(motif_neighbor_vectors_list),
max_pair_outer=max_pair_outer,
max_pair_outer_by_center=max_pair_outer_by_center,
summary=summary,
)
[docs]
def with_bonded_species_pairs(
self,
pairs: "list[tuple[str, str]]",
) -> "CoordinationShellTarget":
"""Return a copy whose ``coordination_target`` is zero everywhere except for the listed species pairs.
Useful for materials with spectator ions: perovskites like
SrTiO\u2083 want only Ti-O bonds considered by
:meth:`Supercell.shell_relax`, since Sr-O, Sr-Ti, O-O, Ti-Ti, etc.
would either install spurious angle springs (``angle_mode_deg``
is a geometric artefact for non-bond triplets) or pin atoms via
bond springs to distances that are really second-shell
separations, not chemical bonds.
Parameters
----------
pairs : list of tuple of str
Each ``(symbol_a, symbol_b)`` pair is treated symmetrically
\u2014 both directions in the ``coordination_target`` matrix
are preserved. Pairs whose symbols don't appear in
``self.species_labels`` are silently skipped.
Returns
-------
CoordinationShellTarget
A new ``CoordinationShellTarget`` (the original is left
unmodified) whose ``coordination_target`` keeps only the
listed species-pair entries; every other slot is zeroed
so :meth:`Supercell.shell_relax` won't try to enforce
bonds there.
Examples
--------
.. code-block:: python
# SrTiO3: preserve only TiO6 octahedra
st.with_bonded_species_pairs([('Ti', 'O')])
# SiO2: equivalent to ``with_cross_species_bonds_only`` for a
# binary, but explicit about what a bond is
st.with_bonded_species_pairs([('Si', 'O')])
"""
from dataclasses import replace as _dc_replace
from ase.data import atomic_numbers as _an
ct = np.zeros_like(np.asarray(self.coordination_target, dtype=np.float64))
sp = np.asarray(self.species, dtype=np.int64)
orig = np.asarray(self.coordination_target, dtype=np.float64)
for sa, sb in pairs:
za = int(_an[sa])
zb = int(_an[sb])
ia_arr = np.where(sp == za)[0]
ib_arr = np.where(sp == zb)[0]
if ia_arr.size == 0 or ib_arr.size == 0:
continue
ia, ib = int(ia_arr[0]), int(ib_arr[0])
ct[ia, ib] = orig[ia, ib]
ct[ib, ia] = orig[ib, ia]
return _dc_replace(self, coordination_target=ct)
[docs]
def with_cross_species_bonds_only(self) -> "CoordinationShellTarget":
"""Return a copy where same-species ``coordination_target`` entries are zeroed.
Useful for network-former compounds such as SiO\u2082 where only
cross-species pairs (Si-O) are real chemical bonds; the same-species
"shell" peaks (Si-Si, O-O) come from the second coordination shell
through the bridging atom and should not be treated as bonds by
:meth:`Supercell.shell_relax` (which would otherwise install
spurious angle springs on triplets like Si-Si-Si or O-O-O whose
``angle_mode_deg`` is just a geometric artefact of the reference
sampling, not a physical target).
Returns
-------
CoordinationShellTarget
New target whose ``coordination_target`` diagonal is
zeroed (off-diagonal cross-species entries preserved).
The original target is unmodified.
"""
from dataclasses import replace as _dc_replace
ct = np.asarray(self.coordination_target, dtype=np.float64).copy()
for i in range(ct.shape[0]):
ct[i, i] = 0.0
return _dc_replace(self, coordination_target=ct)
[docs]
def with_angle_triplets(
self,
triplets: "list[tuple[str, str, str]]",
) -> "CoordinationShellTarget":
"""Return a copy whose angle-spring mask is enabled *only* for
the listed triplet types; all other angle springs are disabled.
Each triplet is ``(centre_symbol, neighbour_1_symbol,
neighbour_2_symbol)``; both (n1, n2) and (n2, n1) are enabled
automatically. Bond-distance springs are untouched - only the
angle springs installed during ``shell_relax`` are filtered.
Useful for multi-modal shells where the extracted
``angle_mode_deg`` picks one peak of a bimodal / quadrimodal
distribution; enforcing it would strain the other modes.
SrTiO\u2083's SrO\u2081\u2082 cuboctahedron (O-Sr-O angles at
60°/90°/120°/180°) is the canonical example.
Parameters
----------
triplets : list of tuple of str
Each ``(centre, n1, n2)`` triplet enables the angle-spring
term for that combination of species. Ordering of ``n1``
and ``n2`` is symmetrised internally.
Returns
-------
CoordinationShellTarget
New target whose ``angle_enabled_mask`` is ``True`` only
for the listed triplets; every other triplet's angle
spring is silenced.
Examples
--------
.. code-block:: python
# Keep Ti-centered 90° and linear O-Ti-Ti 180° angle
# springs; silence every Sr-centered or Sr-in-triplet
# angle spring.
st.with_angle_triplets([
('Ti', 'O', 'O'),
('O', 'Ti', 'Ti'),
])
"""
from dataclasses import replace as _dc_replace
from ase.data import atomic_numbers as _an
sp = np.asarray(self.species, dtype=np.int64)
ai = np.asarray(self.angle_index, dtype=np.intp)
mask = np.zeros(ai.shape[0], dtype=bool)
def _species_slots(sym: str) -> np.ndarray:
return np.where(sp == int(_an[sym]))[0]
for centre_sym, n1_sym, n2_sym in triplets:
c_idx = _species_slots(centre_sym)
n1_idx = _species_slots(n1_sym)
n2_idx = _species_slots(n2_sym)
if c_idx.size == 0 or n1_idx.size == 0 or n2_idx.size == 0:
continue
for ci in c_idx:
for a_i in n1_idx:
for b_i in n2_idx:
# Canonical order: neigh_1 <= neigh_2.
lo, hi = (int(a_i), int(b_i)) if a_i <= b_i else (int(b_i), int(a_i))
t = int(self.angle_lookup[int(ci), lo, hi])
if t >= 0:
mask[t] = True
return _dc_replace(self, angle_enabled_mask=mask)
[docs]
def without_angle_triplets(
self,
triplets: "list[tuple[str, str, str]]",
) -> "CoordinationShellTarget":
"""Return a copy with the angle mask disabled for the listed triplets.
Inverse of :meth:`with_angle_triplets`: starts from the
current ``angle_enabled_mask`` and turns OFF the listed
triplets, leaving every other triplet's angle spring intact.
Parameters
----------
triplets : list of tuple of str
Each ``(centre, n1, n2)`` triplet disables the angle
spring for that species combination.
Returns
-------
CoordinationShellTarget
New target whose ``angle_enabled_mask`` matches the
original except the listed triplets are now ``False``.
"""
from dataclasses import replace as _dc_replace
from ase.data import atomic_numbers as _an
sp = np.asarray(self.species, dtype=np.int64)
mask = np.asarray(self.angle_enabled_mask, dtype=bool).copy()
def _species_slots(sym: str) -> np.ndarray:
return np.where(sp == int(_an[sym]))[0]
for centre_sym, n1_sym, n2_sym in triplets:
c_idx = _species_slots(centre_sym)
n1_idx = _species_slots(n1_sym)
n2_idx = _species_slots(n2_sym)
for ci in c_idx:
for a_i in n1_idx:
for b_i in n2_idx:
lo, hi = (int(a_i), int(b_i)) if a_i <= b_i else (int(b_i), int(a_i))
t = int(self.angle_lookup[int(ci), lo, hi])
if t >= 0:
mask[t] = False
return _dc_replace(self, angle_enabled_mask=mask)
@property
def pair_labels(self) -> list[str]:
"""Human-readable species-pair labels for every present pair.
Returns
-------
list of str
One ``"<centre>-<neighbour>"`` string per ``(centre,
neighbour)`` slot in ``self.pair_mask`` that is ``True``.
Order matches the row-major flatten of the species table
(centre, neighbour) — useful for legend labels in
multi-pair g(r) plots.
"""
labels = []
for center_ind, center_label in enumerate(self.species_labels):
for neigh_ind, neigh_label in enumerate(self.species_labels):
if self.pair_mask[center_ind, neigh_ind]:
labels.append(f"{center_label}-{neigh_label}")
return labels
@property
def angle_labels(self) -> list[str]:
"""Human-readable rooted-angle labels for every triplet channel.
Returns
-------
list of str
One ``"<n1>-<centre>-<n2>"`` string per row of
``self.angle_index``. Useful for labelling angle-channel
histograms or filtering the per-triplet output of
:meth:`Supercell.measure_g3`.
"""
labels = []
for center_ind, neigh1_ind, neigh2_ind in self.angle_index:
labels.append(
f"{self.species_labels[neigh1_ind]}-{self.species_labels[center_ind]}-{self.species_labels[neigh2_ind]}"
)
return labels