# Source code for tricor._plotting

from __future__ import annotations

from pathlib import Path as _Path
from typing import TYPE_CHECKING

import numpy as np
from ase.neighborlist import neighbor_list

from .g3 import _EPS, _TextProgressBar

if TYPE_CHECKING:
    from .shells import CoordinationShellTarget
    from .supercell import Supercell


_STATIC_DIR = _Path(__file__).parent / "static"
# HTML viewer templates shipped next to this module.  They are read with
# an explicit UTF-8 encoding because ``Path.read_text()`` otherwise uses
# the locale's preferred encoding, which mis-decodes (or fails on)
# non-ASCII template content on platforms whose default is not UTF-8.
# NOTE(review): assumes the template files are saved as UTF-8.
_TRAJECTORY_HTML_TEMPLATE = (_STATIC_DIR / "trajectory_viewer.html").read_text(encoding="utf-8")
_G3_HTML_TEMPLATE = (_STATIC_DIR / "g3_viewer.html").read_text(encoding="utf-8")
_G2_HTML_TEMPLATE = (_STATIC_DIR / "g2_viewer.html").read_text(encoding="utf-8")
_OVERVIEW_HTML_TEMPLATE = (_STATIC_DIR / "overview_viewer.html").read_text(encoding="utf-8")


def _detect_tetrahedra(
    atoms,
    *,
    center_symbol: str = "Si",
    vertex_symbol: str = "O",
    bond_length: float | None = None,
    bond_length_tol: float = 0.15,
    ideal_angle_deg: float = 109.47,
    angle_tol_deg: float = 25.0,
    center_species_filter: "np.ndarray | None" = None,
) -> list[dict]:
    """Find center atoms whose 4 nearest vertex atoms form a tetrahedron.

    For every ``center_symbol`` atom, the ``vertex_symbol`` neighbours
    whose distance lies within ``bond_length * (1 ± bond_length_tol)``
    are collected.  Centres with fewer than 4 such neighbours are
    skipped; otherwise the 4 nearest are tested, and the centre is
    accepted when all 6 pairwise (vertex)-(center)-(vertex) angles lie
    within ``angle_tol_deg`` of ``ideal_angle_deg``.

    Parameters
    ----------
    atoms
        ASE ``Atoms``-like object (needs ``numbers`` and whatever
        :func:`ase.neighborlist.neighbor_list` requires).
    center_symbol, vertex_symbol
        Chemical symbols of the centre and vertex species.
    bond_length
        Nominal centre-vertex distance.  When ``None`` it is estimated
        as the 10th percentile of all centre-vertex distances inside
        3.5 Å.
    bond_length_tol
        Relative half-width of the accepted distance window.
    ideal_angle_deg, angle_tol_deg
        Target vertex-centre-vertex angle and allowed deviation (deg).
    center_species_filter
        Optional boolean mask with one entry per atom; only atoms where
        it is ``True`` are considered as centres.

    Returns
    -------
    list[dict]
        One dict per accepted centre with keys ``center`` (int) and
        ``vertices`` (list of 4 ints, in order of increasing
        centre-to-vertex distance).  Empty when nothing qualifies.
    """
    from ase.data import atomic_numbers

    Zc = atomic_numbers[center_symbol]
    Zv = atomic_numbers[vertex_symbol]
    numbers = np.asarray(atoms.numbers)

    if bond_length is None:
        # Use the 10th percentile of center-vertex distances so the
        # "bond length" sits near the first-neighbour peak even in
        # disordered samples where many longer pairs exist.
        bi, bj, bd = neighbor_list("ijd", atoms, 3.5)
        mask = (numbers[bi] == Zc) & (numbers[bj] == Zv)
        if not np.any(mask):
            return []
        bond_length = float(np.percentile(bd[mask], 10))

    # Small 5 % margin on the neighbour-list cutoff; the exact window is
    # enforced by the ``keep`` mask below.
    cutoff = float(bond_length) * (1.0 + float(bond_length_tol)) * 1.05
    bi_all, bj_all, bd_all, bD_all = neighbor_list("ijdD", atoms, float(cutoff))

    keep = (
        (numbers[bi_all] == Zc)
        & (numbers[bj_all] == Zv)
        & (bd_all >= bond_length * (1.0 - bond_length_tol))
        & (bd_all <= bond_length * (1.0 + bond_length_tol))
    )
    if center_species_filter is not None:
        filt = np.asarray(center_species_filter, dtype=bool)
        keep &= filt[bi_all]
    if not np.any(keep):
        return []
    bi = bi_all[keep]
    bj = bj_all[keep]
    bd = bd_all[keep]
    bD = bD_all[keep]

    # Group pairs by centre, nearest neighbour first within each group.
    order = np.lexsort((bd, bi))
    bi_s = bi[order]
    bj_s = bj[order]
    bD_s = bD[order]

    ideal_rad = float(np.deg2rad(ideal_angle_deg))
    tol_rad = float(np.deg2rad(angle_tol_deg))
    triu = np.triu_indices(4, k=1)  # the 6 distinct vertex pairs

    unique_i, start_idx = np.unique(bi_s, return_index=True)
    end_idx = np.concatenate([start_idx[1:], [bi_s.size]])

    tetrahedra: list[dict] = []
    for u, s, e in zip(unique_i, start_idx, end_idx):
        if int(e - s) < 4:
            continue
        js = bj_s[s : s + 4]
        vs = bD_s[s : s + 4]
        norms = np.linalg.norm(vs, axis=1)
        if np.any(norms < 1e-6):
            continue
        unit = vs / norms[:, None]
        cos_ab = np.clip(unit @ unit.T, -1.0, 1.0)
        angles = np.arccos(cos_ab)
        if float(np.max(np.abs(angles[triu] - ideal_rad))) <= tol_rad:
            tetrahedra.append(
                {
                    "center": int(u),
                    "vertices": [int(j) for j in js],
                }
            )
    return tetrahedra


def _detect_triangles(
    atoms,
    *,
    center_symbol: str = "C",
    vertex_symbol: str = "C",
    bond_length: float | None = None,
    bond_length_tol: float = 0.15,
    ideal_angle_deg: float = 120.0,
    angle_tol_deg: float = 18.0,
    center_species_filter: "np.ndarray | None" = None,
) -> list[dict]:
    """Find centres whose 3 nearest neighbours form a trigonal planar motif.

    Mirrors :func:`_detect_tetrahedra` but enforces:

    - exactly 3 neighbours within the bond-length window (a centre with
      a 4th bonded neighbour is rejected; without this check an sp\u00b3
      centre, whose ~109.5\u00b0 angles fall inside the default 18\u00b0
      tolerance of 120\u00b0, would be reported as a triangle),
    - all 3 pairwise (vertex)-(centre)-(vertex) angles within
      ``angle_tol_deg`` of ``ideal_angle_deg`` (120\u00b0 for sp\u00b2
      graphene, 109.5\u00b0 is NOT what you want here).

    Three exactly-120\u00b0 angles imply coplanarity; with a non-zero
    tolerance the angle test only approximately enforces it, and no
    separate planarity test is performed.

    ``center_species_filter`` optionally restricts detection to a subset
    of atoms: it is a boolean mask with one entry per atom (e.g. the
    atoms flagged as sp\u00b2 via ``Supercell._atom_shell_species_index``).

    Returns a list of dicts with keys ``center`` (int) and ``vertices``
    (4 ints: ``[centre, j, k, l]``, neighbours ordered by increasing
    distance).  When ``bond_length`` is ``None`` it is estimated as the
    10th percentile of all centre-vertex distances inside 3.5 \u00c5.
    """
    from ase.data import atomic_numbers

    Zc = atomic_numbers[center_symbol]
    Zv = atomic_numbers[vertex_symbol]
    numbers = np.asarray(atoms.numbers)

    if bond_length is None:
        bi, bj, bd = neighbor_list("ijd", atoms, 3.5)
        mask = (numbers[bi] == Zc) & (numbers[bj] == Zv)
        if not np.any(mask):
            return []
        bond_length = float(np.percentile(bd[mask], 10))

    cutoff = float(bond_length) * (1.0 + float(bond_length_tol)) * 1.05
    bi_all, bj_all, bd_all, bD_all = neighbor_list("ijdD", atoms, float(cutoff))

    keep = (
        (numbers[bi_all] == Zc)
        & (numbers[bj_all] == Zv)
        & (bd_all >= bond_length * (1.0 - bond_length_tol))
        & (bd_all <= bond_length * (1.0 + bond_length_tol))
    )
    if center_species_filter is not None:
        filt = np.asarray(center_species_filter, dtype=bool)
        keep &= filt[bi_all]
    if not np.any(keep):
        return []
    bi = bi_all[keep]
    bj = bj_all[keep]
    bd = bd_all[keep]
    bD = bD_all[keep]

    order = np.lexsort((bd, bi))
    bi_s = bi[order]
    bj_s = bj[order]
    bD_s = bD[order]

    ideal_rad = float(np.deg2rad(ideal_angle_deg))
    tol_rad = float(np.deg2rad(angle_tol_deg))
    triu = np.triu_indices(3, k=1)  # the 3 distinct neighbour pairs

    unique_i, start_idx = np.unique(bi_s, return_index=True)
    end_idx = np.concatenate([start_idx[1:], [bi_s.size]])

    triangles: list[dict] = []
    for u, s, e in zip(unique_i, start_idx, end_idx):
        # Trigonal coordination means exactly three bonded neighbours.
        if int(e - s) != 3:
            continue
        js = bj_s[s:e]
        vs = bD_s[s:e]
        norms = np.linalg.norm(vs, axis=1)
        if np.any(norms < 1e-6):
            continue
        unit = vs / norms[:, None]
        cos_ab = np.clip(unit @ unit.T, -1.0, 1.0)
        angles = np.arccos(cos_ab)
        if float(np.max(np.abs(angles[triu] - ideal_rad))) <= tol_rad:
            # Emit 4 vertices: [centre, j, k, l].  _TRI_FACES then
            # fans three sub-triangles from the centre (vertex 0) to
            # each pair of consecutive neighbours so the rendered mesh
            # always anchors to the parent atom.
            triangles.append(
                {
                    "center": int(u),
                    "vertices": [int(u)] + [int(j) for j in js],
                }
            )
    return triangles


def _tetrahedra_vertex_coords(
    tetrahedra: list[dict],
    positions: np.ndarray,
    cell_matrix: np.ndarray,
    scale: float = 1.0,
) -> list[float]:
    """Flat list of vertex positions (min-image wrt centre, box-centred).

    When ``scale < 1`` the polyhedron is shrunk toward its centre -
    ``scale=0.5`` puts the rendered vertices exactly at the midpoints
    of the centre-vertex bonds, which is what we want for single-
    element polyhedra (Si tets, Cu cuboctahedra) where the vertex
    atoms are the same species as the centre.
    """
    if not tetrahedra:
        return []
    cell = np.asarray(cell_matrix, dtype=np.float64)
    cell_inv = np.linalg.inv(cell)
    centre = 0.5 * cell.sum(axis=0)
    out = []
    for t in tetrahedra:
        c_pos = positions[t["center"]]
        for v_idx in t["vertices"]:
            v_pos = positions[v_idx]
            disp = v_pos - c_pos
            frac = disp @ cell_inv
            frac -= np.round(frac)
            disp_w = frac @ cell
            adj = c_pos + float(scale) * disp_w - centre
            out.extend(float(x) for x in adj)
    return out


# _polyhedra_vertex_coords is the generalised name; kept as an alias so
# the rest of the module (and external callers) can use either form.
# The underlying function iterates each dict's ``vertices`` list, so it
# handles triangles, tetrahedra, octahedra and cuboctahedra alike.
_polyhedra_vertex_coords = _tetrahedra_vertex_coords


# Face / edge topology tables used by the viewers to build translucent
# polyhedra meshes.  Indices refer into the per-polyhedron vertex list
# produced by ``_detect_tetrahedra`` / ``_detect_triangles`` /
# ``_detect_octahedra``.  Cuboctahedra have no shared table: their
# topology is computed per polyhedron by ``_detect_cuboctahedra``.
# Tetrahedron: 4 faces (every vertex triple) and 6 edges (every pair).
_TET_FACES = [[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]]
_TET_EDGES = [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]
# Triangle motif (sp², 4 vertices).  Vertex 0 is the centre atom i,
# vertices 1-3 are the three bonded neighbours j, k, l.  Three
# sub-triangles all share the centre: (i,j,k), (i,k,l), (i,l,j), so
# the rendered face always touches the parent atom even when the
# bonds are at unequal scales after relaxation.  Edges are the three
# centre-neighbour spokes plus the outer j-k-l rim.
_TRI_FACES = [[0, 1, 2], [0, 2, 3], [0, 3, 1]]
_TRI_EDGES = [[0, 1], [0, 2], [0, 3], [1, 2], [2, 3], [3, 1]]
# Octahedra: ``_detect_octahedra`` orders vertices so (0,1), (2,3),
# (4,5) are antipodal; the 8 faces are all triples that take one vertex
# from each antipodal pair, the 12 edges are every pair *except* the
# three antipodal ones.
_OCT_FACES = [
    [0, 2, 4], [0, 2, 5], [0, 3, 4], [0, 3, 5],
    [1, 2, 4], [1, 2, 5], [1, 3, 4], [1, 3, 5],
]
_OCT_EDGES = [
    [0, 2], [0, 3], [0, 4], [0, 5],
    [1, 2], [1, 3], [1, 4], [1, 5],
    [2, 4], [2, 5], [3, 4], [3, 5],
]


def _resolve_group_cfg(entry: dict) -> dict:
    """Normalise one polyhedra_groups entry into the cfg dict shape.

    Each entry is of the form
    ``{"kind": "triangles"|"tetrahedra"|"octahedra"|"cuboctahedra",
       "center_symbol", "vertex_symbol", ..., "virtual_species",
       "color", "opacity"}``.
    """
    kind = entry.get("kind", "tetrahedra")
    # Slot into the one-of-four branches of _resolve_polyhedra_cfg.
    args = dict(triangles=None, tetrahedra=None, octahedra=None, cuboctahedra=None)
    args[kind] = dict(entry)
    return _resolve_polyhedra_cfg(**args)


def _shell_target_bond_length(cell_obj, center_symbol: str, vertex_symbol: str) -> float | None:
    """Return the positive centre-vertex ``pair_peak`` from ``cell_obj._shell_target``, else ``None``."""
    from ase.data import atomic_numbers as _an

    st = getattr(cell_obj, "_shell_target", None)
    if st is None:
        return None
    sp = np.asarray(st.species, dtype=np.int64)
    try:
        i = int(np.where(sp == _an[center_symbol])[0][0])
        j = int(np.where(sp == _an[vertex_symbol])[0][0])
        peak = float(np.asarray(st.pair_peak, dtype=np.float64)[i, j])
    except (IndexError, KeyError):
        # Species absent from the shell target, or unknown chemical symbol.
        return None
    return peak if peak > 0 else None


def _detector_accepts_kwarg(detector, name: str) -> bool:
    """Return True when ``detector``'s signature accepts keyword ``name``."""
    import inspect

    try:
        params = inspect.signature(detector).parameters
    except (TypeError, ValueError):
        # Signature not introspectable; treat the keyword as unsupported.
        return False
    param = params.get(name)
    if param is not None:
        return param.kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        )
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


def _render_polyhedra_group(
    atoms,
    cfg: dict,
    cell_obj=None,
    *,
    default_color=(0.25, 0.65, 0.95),
    default_opacity: float = 0.35,
) -> dict:
    """Run the detector for ``cfg`` and produce a JSON-ready group dict.

    ``cfg`` is a dict as returned by :func:`_resolve_polyhedra_cfg`.
    ``cell_obj`` is the owning :class:`Supercell`; when set the
    function will:

    - use ``cell_obj._shell_target.pair_peak`` to auto-resolve
      ``bond_length`` if not given (only a positive peak is used);
    - apply ``cfg['virtual_species']`` to the detector's
      ``center_species_filter`` argument using
      ``cell_obj._atom_shell_species_index``.  If the detector does not
      accept ``center_species_filter`` the filter is ignored and a
      :class:`UserWarning` is emitted.

    Returns a dict with the flattened vertex coordinates (rounded to
    3 decimals), the polyhedron count, shared or per-polyhedron
    face/edge tables, colour and opacity.
    """
    eff_bond_length = cfg["bond_length"]
    if eff_bond_length is None and cell_obj is not None:
        eff_bond_length = _shell_target_bond_length(
            cell_obj, cfg["center_symbol"], cfg["vertex_symbol"]
        )

    species_filter = None
    vsp = cfg.get("virtual_species")
    if vsp is not None and cell_obj is not None:
        asi = getattr(cell_obj, "_atom_shell_species_index", None)
        if asi is not None:
            species_filter = (np.asarray(asi, dtype=np.intp) == int(vsp))

    detector = cfg["detector"]
    detector_kwargs = dict(
        center_symbol=cfg["center_symbol"],
        vertex_symbol=cfg["vertex_symbol"],
        bond_length=eff_bond_length,
        bond_length_tol=cfg["bond_length_tol"],
        ideal_angle_deg=cfg["ideal_angle_deg"],
        angle_tol_deg=cfg["angle_tol_deg"],
    )
    # Decide by signature rather than by catching TypeError, so genuine
    # TypeErrors raised inside a detector are not masked by a silent retry
    # that also drops the filter.
    if species_filter is not None:
        if _detector_accepts_kwarg(detector, "center_species_filter"):
            detector_kwargs["center_species_filter"] = species_filter
        else:
            import warnings

            warnings.warn(
                f"{getattr(detector, '__name__', detector)} does not accept "
                f"'center_species_filter'; ignoring virtual_species={vsp!r}.",
                stacklevel=2,
            )
    polys = detector(atoms, **detector_kwargs)

    vertices_flat = _polyhedra_vertex_coords(
        polys, atoms.positions, atoms.cell.array, scale=cfg["scale"]
    )
    vertices_flat = [round(v, 3) for v in vertices_flat]

    if cfg["per_polyhedron_topology"]:
        per_poly_faces = [p["faces"] for p in polys]
        per_poly_edges = [p["edges"] for p in polys]
    else:
        per_poly_faces = []
        per_poly_edges = []

    color = cfg.get("color") or default_color
    opacity = cfg.get("opacity")
    if opacity is None:
        opacity = default_opacity

    return {
        "n_vertices": int(cfg["n_vertices"]),
        "vertices": vertices_flat,
        "num": int(len(polys)),
        "faces": cfg["faces"] if cfg["faces"] is not None else _TET_FACES,
        "edges": cfg["edges"] if cfg["edges"] is not None else _TET_EDGES,
        "color": list(color),
        "opacity": float(opacity),
        "per_polyhedron_topology": bool(cfg["per_polyhedron_topology"]),
        "polyhedra_faces_per_poly": per_poly_faces,
        "polyhedra_edges_per_poly": per_poly_edges,
    }


def _resolve_polyhedra_cfg(
    tetrahedra: "dict | None",
    octahedra: "dict | None",
    cuboctahedra: "dict | None" = None,
    triangles: "dict | None" = None,
) -> "dict | None":
    """Normalise user-provided polyhedra dicts into a single config.

    At most one of ``tetrahedra`` / ``octahedra`` / ``cuboctahedra`` may
    be provided.  Returns ``None`` when none are given.  The returned
    dict carries:

    - ``center_symbol``, ``vertex_symbol``, ``bond_length``,
      ``bond_length_tol``, ``ideal_angle_deg``, ``angle_tol_deg`` -
      detector-level config
    - ``scale`` - render-time polyhedron shrinkage (1.0 = vertices at
      actual atom positions, 0.5 = vertices at bond midpoints; handy for
      single-element polyhedra where corners otherwise sit directly on
      neighbouring atoms)
    - ``n_vertices`` (4 / 6 / 12)
    - ``faces``, ``edges`` - shared face / edge tables (``None`` when
      topology is per-polyhedron, e.g. cuboctahedra, in which case the
      detector embeds its own tables)
    - ``detector``, ``per_polyhedron_topology`` (bool)
    """
    provided = [x for x in (tetrahedra, octahedra, cuboctahedra, triangles) if x is not None]
    if len(provided) > 1:
        raise ValueError(
            "Pass at most one of 'tetrahedra', 'octahedra', 'cuboctahedra', 'triangles'.",
        )
    per_poly = False
    if triangles is not None:
        cfg = triangles
        # 4 vertices: [centre, j, k, l] - see _TRI_FACES comment.
        n_vertices = 4
        faces, edges = _TRI_FACES, _TRI_EDGES
        detector = _detect_triangles
        ideal_default = 120.0
        angle_tol_default = 18.0
        scale_default = 0.5  # outer vertices at bond midpoints
    elif tetrahedra is not None:
        cfg = tetrahedra
        n_vertices = 4
        faces, edges = _TET_FACES, _TET_EDGES
        detector = _detect_tetrahedra
        ideal_default = 109.47
        angle_tol_default = 25.0
        # Single-element tetrahedra (e.g. Si-Si-Si in diamond): vertices
        # sit directly on neighbour atoms when scale=1.0, which makes
        # the polyhedra visually overlap.  Default to bond-midpoint
        # vertices for cleaner rendering.  Multi-element tetrahedra
        # (e.g. SiO₄: Si centre, O vertices) keep scale=1.0 so the
        # vertices land on the actual O atoms.
        same_element = cfg.get("center_symbol") == cfg.get("vertex_symbol") \
            and cfg.get("center_symbol") is not None
        scale_default = 0.5 if same_element else 1.0
    elif octahedra is not None:
        cfg = octahedra
        n_vertices = 6
        faces, edges = _OCT_FACES, _OCT_EDGES
        detector = _detect_octahedra
        ideal_default = 90.0
        angle_tol_default = 18.0
        # Same single-element logic as tetrahedra.
        same_element = cfg.get("center_symbol") == cfg.get("vertex_symbol") \
            and cfg.get("center_symbol") is not None
        scale_default = 0.5 if same_element else 1.0
    elif cuboctahedra is not None:
        cfg = cuboctahedra
        n_vertices = 12
        faces, edges = None, None  # per-polyhedron (embedded in each dict)
        detector = _detect_cuboctahedra
        ideal_default = 60.0   # nominal nearest-neighbour angle
        angle_tol_default = 18.0
        scale_default = 0.5     # default to half-size for single-element
        per_poly = True
    else:
        return None
    return {
        "center_symbol": cfg.get("center_symbol", "Si"),
        "vertex_symbol": cfg.get("vertex_symbol", "O"),
        "bond_length": cfg.get("bond_length"),
        "bond_length_tol": float(cfg.get("bond_length_tol", 0.15)),
        "ideal_angle_deg": float(cfg.get("ideal_angle_deg", ideal_default)),
        "angle_tol_deg": float(cfg.get("angle_tol_deg", angle_tol_default)),
        "scale": float(cfg.get("scale", scale_default)),
        "n_vertices": n_vertices,
        "faces": faces,
        "edges": edges,
        "detector": detector,
        "per_polyhedron_topology": per_poly,
        # Optional: restrict detection to atoms whose
        # ``_atom_shell_species_index == virtual_species``.  Used for
        # carbon sp²/sp³ blends where the sp² triangles should only
        # decorate sp² atoms and sp³ tets only sp³ atoms.
        "virtual_species": cfg.get("virtual_species"),
        "color": cfg.get("color"),
        "opacity": cfg.get("opacity"),
    }


def _detect_octahedra(
    atoms,
    *,
    center_symbol: str = "Ti",
    vertex_symbol: str = "O",
    bond_length: float | None = None,
    bond_length_tol: float = 0.18,
    ideal_angle_deg: float = 90.0,
    angle_tol_deg: float = 18.0,
    center_species_filter: "np.ndarray | None" = None,
) -> list[dict]:
    """Find center atoms whose 6 nearest vertex atoms form an octahedron.

    Returns a list of dicts ``{center: int, vertices: list[int]}``.  The
    six vertex indices are written so pairs (0, 1), (2, 3), (4, 5) are
    antipodal - i.e. the octahedron is oriented as (±x, ±y, ±z) about
    the center, which the viewer uses to build the 8 triangular faces
    ``[[0,2,4],[0,2,5],[0,3,4],[0,3,5],[1,2,4],[1,2,5],[1,3,4],[1,3,5]]``
    and 12 edges (every pair except the 3 antipodal pairs).

    Acceptance criteria for a candidate centre:

    * At least 6 vertex neighbours sit within
      ``bond_length * (1 \u00b1 tol)``; the 6 nearest are tested.
    * Among the 15 pairwise centre-vertex unit-vector angles, exactly 3
      are within ``angle_tol_deg`` of 180\u00b0 and they form 3 disjoint
      antipodal pairs covering all 6 vertices; the remaining 12 are
      within ``angle_tol_deg`` of ``ideal_angle_deg`` (90\u00b0 by default).

    ``center_species_filter`` optionally restricts detection to a subset
    of atoms: a boolean mask with one entry per atom (same convention as
    :func:`_detect_tetrahedra`).

    When ``bond_length`` is ``None``, the 10th percentile of all
    centre-vertex distances inside 3.5 \u00c5 is used as an estimate.
    """
    from ase.data import atomic_numbers

    Zc = atomic_numbers[center_symbol]
    Zv = atomic_numbers[vertex_symbol]
    numbers = np.asarray(atoms.numbers)

    if bond_length is None:
        bi, bj, bd = neighbor_list("ijd", atoms, 3.5)
        mask = (numbers[bi] == Zc) & (numbers[bj] == Zv)
        if not np.any(mask):
            return []
        bond_length = float(np.percentile(bd[mask], 10))

    cutoff = float(bond_length) * (1.0 + float(bond_length_tol)) * 1.05
    bi_all, bj_all, bd_all, bD_all = neighbor_list("ijdD", atoms, float(cutoff))

    keep = (
        (numbers[bi_all] == Zc)
        & (numbers[bj_all] == Zv)
        & (bd_all >= bond_length * (1.0 - bond_length_tol))
        & (bd_all <= bond_length * (1.0 + bond_length_tol))
    )
    if center_species_filter is not None:
        filt = np.asarray(center_species_filter, dtype=bool)
        keep &= filt[bi_all]
    if not np.any(keep):
        return []
    bi = bi_all[keep]
    bj = bj_all[keep]
    bd = bd_all[keep]
    bD = bD_all[keep]

    order = np.lexsort((bd, bi))
    bi_s = bi[order]
    bj_s = bj[order]
    bD_s = bD[order]

    ideal_rad_90 = float(np.deg2rad(ideal_angle_deg))
    ideal_rad_180 = float(np.deg2rad(180.0))
    tol_rad = float(np.deg2rad(angle_tol_deg))
    triu_i, triu_j = np.triu_indices(6, k=1)  # the 15 distinct vertex pairs

    unique_i, start_idx = np.unique(bi_s, return_index=True)
    end_idx = np.concatenate([start_idx[1:], [bi_s.size]])

    octahedra: list[dict] = []
    for u, s, e in zip(unique_i, start_idx, end_idx):
        if int(e - s) < 6:
            continue
        js = bj_s[s : s + 6]
        vs = bD_s[s : s + 6]
        norms = np.linalg.norm(vs, axis=1)
        if np.any(norms < 1e-6):
            continue
        unit = vs / norms[:, None]
        cos_ab = np.clip(unit @ unit.T, -1.0, 1.0)
        angles = np.arccos(cos_ab)
        pair_angles = angles[triu_i, triu_j]

        # Identify the 3 antipodal pairs (angle ~ 180).
        anti_mask = np.abs(pair_angles - ideal_rad_180) <= tol_rad
        if int(np.sum(anti_mask)) != 3:
            continue

        anti_pairs = list(zip(triu_i[anti_mask].tolist(), triu_j[anti_mask].tolist()))
        # Every vertex must appear in exactly one antipodal pair.
        anti_of = [-1] * 6
        ok = True
        for a, b in anti_pairs:
            if anti_of[a] != -1 or anti_of[b] != -1:
                ok = False
                break
            anti_of[a] = b
            anti_of[b] = a
        if not ok or -1 in anti_of:
            continue

        # Remaining 12 angles must all be ~ 90.
        rem_mask = ~anti_mask
        if float(np.max(np.abs(pair_angles[rem_mask] - ideal_rad_90))) > tol_rad:
            continue

        # Re-order so (0,1), (2,3), (4,5) are the antipodal pairs.
        order_local: list[int] = []
        visited = [False] * 6
        for v in range(6):
            if visited[v]:
                continue
            w = anti_of[v]
            order_local.extend([v, w])
            visited[v] = True
            visited[w] = True

        verts = [int(js[k]) for k in order_local]
        octahedra.append({"center": int(u), "vertices": verts})
    return octahedra


def _detect_cuboctahedra(
    atoms,
    *,
    center_symbol: str = "Cu",
    vertex_symbol: str = "Cu",
    bond_length: float | None = None,
    bond_length_tol: float = 0.12,
    ideal_angle_deg: float | None = None,   # unused (API parity)
    angle_tol_deg: float = 22.0,
    distance_tol: float = 0.10,
    center_species_filter: "np.ndarray | None" = None,
) -> list[dict]:
    """Find centers with 12 nearest vertex neighbours forming a
    close-packed (FCC) cuboctahedron.

    Acceptance:

    * At least 12 vertex neighbours within ``bond_length * (1 \u00b1 tol)``;
      the 12 nearest are tested.
    * The spread ``(max - min) / mean`` of those 12 centre-vertex
      distances must not exceed ``2 * distance_tol`` (catches
      13th-nearest-neighbour interlopers).
    * All 66 pairwise centre-vertex unit-vector angles must lie within
      ``angle_tol_deg`` of one of the four canonical FCC values
      {60\u00b0, 90\u00b0, 120\u00b0, 180\u00b0}.  This is what distinguishes a real
      cuboctahedron from a disordered close-packed shell whose atoms
      happen to fall within the radial tolerance - without this test
      the detector counts amorphous clusters as cuboctahedra because
      they too have ~12 nearest neighbours.

    ``ideal_angle_deg`` is accepted only for signature parity with the
    other detectors and is ignored.  ``center_species_filter`` optionally
    restricts detection to a subset of atoms: a boolean mask with one
    entry per atom (same convention as :func:`_detect_tetrahedra`).
    When ``bond_length`` is ``None`` it is estimated as the 10th
    percentile of centre-vertex distances inside 4.0 \u00c5.

    Each returned dict carries its own ``faces`` (20 triangles, two per
    square face plus the eight corner triangles) and ``edges`` (24
    cuboctahedron edges, square-face diagonals filtered out by length).
    """
    from ase.data import atomic_numbers
    from scipy.spatial import ConvexHull

    Zc = atomic_numbers[center_symbol]
    Zv = atomic_numbers[vertex_symbol]
    numbers = np.asarray(atoms.numbers)

    if bond_length is None:
        bi0, bj0, bd0 = neighbor_list("ijd", atoms, 4.0)
        mask0 = (numbers[bi0] == Zc) & (numbers[bj0] == Zv)
        if not np.any(mask0):
            return []
        bond_length = float(np.percentile(bd0[mask0], 10))

    cutoff = float(bond_length) * (1.0 + float(bond_length_tol)) * 1.05
    bi_all, bj_all, bd_all, bD_all = neighbor_list("ijdD", atoms, float(cutoff))
    keep = (
        (numbers[bi_all] == Zc)
        & (numbers[bj_all] == Zv)
        & (bd_all >= bond_length * (1.0 - bond_length_tol))
        & (bd_all <= bond_length * (1.0 + bond_length_tol))
    )
    if center_species_filter is not None:
        filt = np.asarray(center_species_filter, dtype=bool)
        keep &= filt[bi_all]
    if not np.any(keep):
        return []
    bi = bi_all[keep]
    bj = bj_all[keep]
    bd = bd_all[keep]
    bD = bD_all[keep]

    order = np.lexsort((bd, bi))
    bi_s = bi[order]
    bj_s = bj[order]
    bd_s = bd[order]
    bD_s = bD[order]

    unique_i, start_idx = np.unique(bi_s, return_index=True)
    end_idx = np.concatenate([start_idx[1:], [bi_s.size]])

    # Loop invariants: the 66 vertex pairs and the canonical FCC angles.
    triu_i, triu_j = np.triu_indices(12, k=1)
    target_angles = np.array([60.0, 90.0, 120.0, 180.0])

    cubocta: list[dict] = []
    for u, s, e in zip(unique_i, start_idx, end_idx):
        if int(e - s) < 12:
            continue
        js = bj_s[s : s + 12]
        vs = bD_s[s : s + 12]
        ds = bd_s[s : s + 12]
        norms = np.linalg.norm(vs, axis=1)
        if np.any(norms < 1e-6):
            continue
        mean_d = float(np.mean(ds))
        if (ds.max() - ds.min()) / max(mean_d, 1e-6) > 2.0 * distance_tol:
            continue

        unit = vs / norms[:, None]
        # Angular spectrum check: every pairwise angle must be close
        # to one of {60, 90, 120, 180} deg.  Amorphous close-packed
        # clusters fail this test; real cuboctahedra pass.
        cos_pairs = np.clip(unit @ unit.T, -1.0, 1.0)
        pair_angles_deg = np.rad2deg(np.arccos(cos_pairs[triu_i, triu_j]))
        # distance (in degrees) to the nearest target angle, per pair
        dev = np.abs(pair_angles_deg[:, None] - target_angles[None, :]).min(axis=1)
        if float(dev.max()) > angle_tol_deg:
            continue
        # Convex-hull triangulation of the 12 unit directions gives 20
        # triangles (8 corner tris + 12 split-square tris) on an ideal
        # cuboctahedron.  Works for slightly-distorted shells too.
        try:
            hull = ConvexHull(unit)
        except Exception:
            # Qhull rejects degenerate point sets; skip such centres
            # rather than aborting the whole detection pass.
            continue
        simplices = hull.simplices.tolist()

        # Build edge set from simplices, filter out square-face
        # diagonals by length (diagonals ~ sqrt(2) vs real edges ~ 1).
        edge_set: dict[tuple[int, int], float] = {}
        for tri in simplices:
            for a, b in [
                (tri[0], tri[1]),
                (tri[1], tri[2]),
                (tri[0], tri[2]),
            ]:
                key = (min(a, b), max(a, b))
                if key in edge_set:
                    continue
                edge_set[key] = float(np.linalg.norm(unit[a] - unit[b]))
        if not edge_set:
            continue
        min_edge = min(edge_set.values())
        edges = [
            [int(a), int(b)]
            for (a, b), length in edge_set.items()
            if length <= min_edge * 1.25
        ]

        cubocta.append({
            "center": int(u),
            "vertices": [int(j) for j in js],
            "faces": [[int(x) for x in tri] for tri in simplices],
            "edges": edges,
        })
    return cubocta


def export_g2_compare_html(
    cells_and_labels,
    output_path: "str | None" = None,
    *,
    r_max: float = 10.0,
    r_step: float = 0.05,
    background_color: str = "#f7f8f5",
    title: str = "",
    show_progress: bool = False,
    sample_fraction: float = 1.0,
    sample_rng_seed: int | None = None,
) -> str:
    """Export a g(r) overlay viewer comparing multiple supercells.

    Every cell must share the same set of species (same reference
    material).  One g(r) curve is drawn per cell for the currently
    selected species pair; the dropdown in the viewer switches which
    pair is shown and a legend identifies each cell by its label.

    Parameters
    ----------
    cells_and_labels
        Accepts any of:

        * ``dict[str, Supercell]`` - keys become legend labels
        * ``list[tuple[Supercell, str]]``
        * ``list[Supercell]`` - each ``cell.label`` is used (``"cell"``
          when the attribute is missing)
    output_path
        Path to write the HTML file (UTF-8).  When ``None`` the HTML
        string is returned instead (for :func:`IPython.display.HTML` /
        inline display); see :func:`plot_g2_compare` for a ready-made
        Jupyter wrapper.
    r_max, r_step
        Radial grid for the measurements.  Finer ``r_step`` gives
        smoother curves; 0.05 A is a good default.
    background_color, title
        Cosmetic.
    show_progress, sample_fraction, sample_rng_seed
        Forwarded to each :meth:`G3Distribution.measure_g3` call (used
        under the hood; ``phi_num_bins`` is set low for speed).

    Returns
    -------
    str
        Resolved output path when ``output_path`` was provided,
        otherwise the HTML source string.

    Raises
    ------
    ValueError
        If ``cells_and_labels`` is empty, the cells do not share the
        same species, or a measurement yields no g2.
    """
    import json

    from ase.data import chemical_symbols

    from .g3 import G3Distribution

    # Normalise input into [(cell, label), ...]
    pairs: list[tuple] = []
    if isinstance(cells_and_labels, dict):
        for label, cell in cells_and_labels.items():
            pairs.append((cell, str(label)))
    else:
        for item in cells_and_labels:
            if isinstance(item, tuple) and len(item) == 2:
                cell, label = item
                pairs.append((cell, str(label)))
            else:
                pairs.append((item, str(getattr(item, "label", "cell"))))
    if not pairs:
        raise ValueError("cells_and_labels is empty.")

    # Species from the first cell define the pair grid.  Verify that
    # every subsequent cell uses the same species set so the curves
    # line up.
    def _species_of(cell) -> np.ndarray:
        return np.unique(np.asarray(cell.atoms.numbers, dtype=np.int64))

    ref_species = _species_of(pairs[0][0])
    for cell, lab in pairs[1:]:
        if not np.array_equal(_species_of(cell), ref_species):
            raise ValueError(
                f"Cell '{lab}' has species {_species_of(cell)} which "
                f"differs from the first cell's {ref_species}. "
                "export_g2_compare_html requires all supercells to share "
                "the same species."
            )
    sp_labels = [chemical_symbols[int(z)] for z in ref_species]
    num_species = len(ref_species)

    # Measure every cell on the same grid.
    r_edges_global = None
    r_centres_global = None

    def _norm_profile(y: np.ndarray, r_arr: np.ndarray) -> np.ndarray:
        # Divide out r^2, then scale so the mean of the finite tail
        # (r >= 0.7 * r_max) is 1.
        y = y / np.maximum(r_arr * r_arr, _EPS)
        tail_start = 0.7 * float(r_arr[-1])
        tail_mask = r_arr >= tail_start
        if not np.any(tail_mask):
            tail_mask = np.zeros_like(r_arr, dtype=bool)
            tail_mask[-max(1, r_arr.size // 4):] = True
        tail = y[tail_mask]
        finite = tail[np.isfinite(tail)]
        scale = float(np.mean(finite)) if finite.size else 1.0
        if scale <= _EPS:
            scale = 1.0
        return (y / scale).astype(np.float32)

    series: list[dict] = []
    # Build pair labels (upper triangle, ci <= vi) from the first
    # cell's species ordering.
    pair_labels: list[str] = [
        f"{sp_labels[ci]}-{sp_labels[vi]}"
        for ci in range(num_species)
        for vi in range(ci, num_species)
    ]

    for cell, lab in pairs:
        dist = G3Distribution(cell.atoms, label=f"{lab}-g2-compare")
        dist.measure_g3(
            r_max=r_max,
            r_step=r_step,
            phi_num_bins=12,
            show_progress=show_progress,
            sample_fraction=sample_fraction,
            sample_rng_seed=sample_rng_seed,
        )
        if dist.g2 is None:
            raise ValueError(f"Distribution measurement for '{lab}' has no g2.")
        r = np.asarray(dist.r, dtype=np.float64)
        if r_centres_global is None:
            r_centres_global = r.tolist()
            r_edges_global = np.asarray(dist.bin_edges, dtype=np.float64).tolist()
        g2 = np.asarray(dist.g2, dtype=np.float64)
        profiles: list[list[float]] = [
            _norm_profile(g2[ci, vi], r).tolist()
            for ci in range(num_species)
            for vi in range(ci, num_species)
        ]
        series.append({"label": lab, "profiles": profiles})

    # Per-pair peak markers pulled from the first cell's shell_target
    # (if present).
    pair_peaks: list[float] = [0.0] * len(pair_labels)
    shell_target = getattr(pairs[0][0], "_shell_target", None)
    if shell_target is not None:
        st_species = np.asarray(shell_target.species, dtype=np.int64)
        pair_peak_mat = np.asarray(shell_target.pair_peak, dtype=np.float64)
        # Map dist species -> shell_target species index
        idx_map: dict[int, int] = {}
        for ci_dist, z in enumerate(ref_species):
            m = np.where(st_species == int(z))[0]
            if m.size:
                idx_map[ci_dist] = int(m[0])
        out_idx = 0
        for ci in range(num_species):
            for vi in range(ci, num_species):
                ci_st = idx_map.get(ci)
                vi_st = idx_map.get(vi)
                if ci_st is not None and vi_st is not None:
                    val = float(pair_peak_mat[ci_st, vi_st])
                    if np.isfinite(val) and val > 0:
                        pair_peaks[out_idx] = val
                out_idx += 1

    # Default pair: smallest positive peak (typically the short bond).
    default_pair = 0
    positive = [(i, p) for i, p in enumerate(pair_peaks) if p > 0]
    if positive:
        default_pair = min(positive, key=lambda x: x[1])[0]

    data = {
        "num_r": len(r_centres_global),
        "num_pairs": len(pair_labels),
        "r_centers": r_centres_global,
        "r_edges": r_edges_global,
        "pair_labels": pair_labels,
        "pair_peaks": pair_peaks,
        "series": series,
        "default_pair": int(default_pair),
        "background_color": background_color,
        "title": title,
    }
    # Escape HTML-significant characters so a label or title containing
    # e.g. "</script>" cannot terminate the embedding element.  These
    # characters only occur inside JSON strings, where \uXXXX escapes
    # are equivalent, so the decoded data is unchanged.
    payload = (
        json.dumps(data)
        .replace("&", "\\u0026")
        .replace("<", "\\u003c")
        .replace(">", "\\u003e")
    )
    html = _G2_HTML_TEMPLATE.replace("__TRICOR_DATA_PLACEHOLDER__", payload)
    if output_path is None:
        return html
    output_path = str(output_path)
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(html)
    return output_path
def plot_g2_compare(
    cells_and_labels,
    *,
    r_max: float = 10.0,
    r_step: float = 0.05,
    title: str = "",
    height: int = 480,
    show_progress: bool = False,
    sample_fraction: float = 1.0,
    sample_rng_seed: int | None = None,
):
    """Inline Jupyter display of the g(r) overlay-compare viewer.

    Convenience wrapper around :func:`export_g2_compare_html` that
    packages the HTML as an :class:`IPython.display.HTML` object so you
    can do ``tc.plot_g2_compare(cells)`` in a notebook cell.

    Parameters
    ----------
    cells_and_labels : list of (Supercell, str) or dict
        Either a list of ``(supercell, label)`` pairs or a
        ``{label: supercell}`` dict. Each cell's g(r) becomes one curve
        in the overlay. See :func:`export_g2_compare_html` for full
        input-shape details.
    r_max : float, optional
        Maximum radial distance (Å) plotted on the x-axis.
        Default ``10.0``.
    r_step : float, optional
        Bin width (Å) of the g(r) histogram. Default ``0.05``.
    title : str, optional
        Title text rendered above the viewer. Default ``""``.
    height : int, optional
        Iframe height (pixels) of the rendered viewer in Jupyter.
        Default ``480``.
    show_progress : bool, optional
        Show a progress indicator while each cell's distribution is
        measured (forwarded to ``measure_g3``). Default ``False``.
    sample_fraction : float, optional
        Forwarded unchanged to each cell's ``measure_g3`` call; see that
        method for its exact semantics. Default ``1.0``.
    sample_rng_seed : int or None, optional
        Forwarded unchanged to each cell's ``measure_g3`` call alongside
        ``sample_fraction``. Default ``None``.

    Returns
    -------
    IPython.display.HTML
        The viewer wrapped in an iframe and ready for inline display in
        a Jupyter cell.

    Raises
    ------
    ValueError
        Propagated from :func:`export_g2_compare_html`, e.g. when the
        supercells do not all share the same species set.
    """
    import html as _html

    from IPython.display import HTML

    # ``None`` as the output path makes the exporter return the HTML
    # string instead of writing a file.
    viewer_html = export_g2_compare_html(
        cells_and_labels,
        None,
        r_max=r_max,
        r_step=r_step,
        title=title,
        show_progress=show_progress,
        sample_fraction=sample_fraction,
        sample_rng_seed=sample_rng_seed,
    )
    # The viewer is a complete HTML document embedded through the
    # iframe's ``srcdoc`` attribute, so it must be attribute-escaped
    # (quotes included) to survive inside ``srcdoc="..."``.
    escaped = _html.escape(viewer_html, quote=True)
    return HTML(
        f'<div class="tricor-g2-compare-wrapper" style="width:100%">'
        f'<iframe srcdoc="{escaped}" '
        f'style="width:100%; height:{int(height)}px; '
        f'border:1px solid rgba(0,0,0,0.1); border-radius:6px;"></iframe>'
        f'</div>'
    )
def _overview_bond_pairs(
    atoms,
    pair_peak: float,
    *,
    bond_length_tol: float,
    bond_cutoff_scale: float,
    max_bonds_per_atom: int,
    ideal_angle_deg: float,
    bond_angle_tol_deg: float,
) -> tuple[np.ndarray, np.ndarray]:
    """Select the bonds drawn by :func:`export_overview_html` (bond mode).

    A candidate neighbour ``j`` of atom ``i`` must lie at a distance
    within ``pair_peak * (1 ± bond_length_tol)``. For each atom, the
    ``max_bonds_per_atom`` nearest candidates are taken. The atom is
    only considered when it has at least that many candidates. Its bonds
    are accepted when every pairwise angle between those neighbour
    vectors lies within ``bond_angle_tol_deg`` of ``ideal_angle_deg``.
    A bond is kept if either endpoint accepts it.

    Returns
    -------
    (bond_i, bond_j) : tuple of int32 ndarray
        Each bond appears once as ``(min, max)``, sorted ascending.
    """
    empty = np.zeros(0, dtype=np.int32)
    needed = int(max_bonds_per_atom)
    if needed <= 0:
        # No neighbour slots means nothing can be drawn. Without this
        # guard, the angle check below fails on an empty neighbour list.
        return empty, empty.copy()

    cutoff_lo = pair_peak * (1.0 - bond_length_tol)
    cutoff_hi = pair_peak * (1.0 + bond_length_tol)
    search_cutoff = pair_peak * bond_cutoff_scale
    bi_all, bj_all, bd_all, bD_all = neighbor_list(
        "ijdD", atoms, float(search_cutoff),
    )
    length_ok = (bd_all >= cutoff_lo) & (bd_all <= cutoff_hi)
    bi_all = bi_all[length_ok]
    bj_all = bj_all[length_ok]
    bd_all = bd_all[length_ok]
    bD_all = bD_all[length_ok]
    if bi_all.size == 0:
        return empty, empty.copy()

    # Sort by centre atom, then by distance. Each atom's candidates then
    # form one contiguous run, ordered nearest-first.
    order = np.lexsort((bd_all, bi_all))
    bi_s = bi_all[order]
    bj_s = bj_all[order]
    bD_s = bD_all[order]
    centres, starts = np.unique(bi_s, return_index=True)
    ends = np.concatenate([starts[1:], [bi_s.size]])

    ideal_angle = float(np.deg2rad(ideal_angle_deg))
    angle_tol = float(np.deg2rad(bond_angle_tol_deg))
    # Indices of the distinct neighbour pairs. These are the same for
    # every atom, so compute them once.
    triu = np.triu_indices(needed, k=1)

    good_pairs: set[tuple[int, int]] = set()
    for i_, s, e in zip(centres.tolist(), starts.tolist(), ends.tolist()):
        if e - s < needed:
            continue  # under-coordinated: this atom contributes no bonds
        js = bj_s[s : s + needed].tolist()
        vecs = np.asarray(bD_s[s : s + needed], dtype=np.float64)
        norms = np.linalg.norm(vecs, axis=1)
        if np.any(norms <= 1e-9):
            continue
        unit = vecs / norms[:, None]
        cos = np.clip(unit @ unit.T, -1.0, 1.0)
        dev = np.abs(np.arccos(cos)[triu] - ideal_angle)
        # With a single neighbour there is no angle to test. Accept it
        # instead of calling ``np.max`` on an empty array.
        if dev.size == 0 or np.max(dev) <= angle_tol:
            for j_ in js:
                good_pairs.add((min(i_, j_), max(i_, j_)))

    if not good_pairs:
        return empty, empty.copy()
    pair_arr = np.asarray(sorted(good_pairs), dtype=np.int32)
    return pair_arr[:, 0], pair_arr[:, 1]


def export_overview_html(
    output_path: str,
    cells_and_labels,
    *,
    grid_cols: int = 3,
    atom_scale: float = 0.17,
    bond_radius: float = 0.07,
    bond_color=(0.95, 0.1, 0.1),
    background_color: str = "#f7f8f5",
    title: str = "",
    subtitle: str = "",
    bond_cutoff_scale: float = 1.2,
    max_bonds_per_atom: int = 4,
    bond_length_tol: float = 0.10,
    ideal_angle_deg: float = 109.47,
    bond_angle_tol_deg: float = 18.0,
    tetrahedra: dict | None = None,
    tetrahedra_color=(0.35, 0.45, 0.95),
    tetrahedra_opacity: float = 0.45,
    octahedra: dict | None = None,
    octahedra_color=(0.95, 0.55, 0.25),
    octahedra_opacity: float = 0.4,
    cuboctahedra: dict | None = None,
    cuboctahedra_color=(0.55, 0.35, 0.85),
    cuboctahedra_opacity: float = 0.4,
    polyhedra_groups: "list[dict] | None" = None,
) -> str:
    """Export a 3D grid of supercells as a self-contained, auto-rotating
    HTML viewer.

    Each panel renders the final atom positions of one :class:`Supercell`
    using ASE element colours and black outlines. All panels share a
    single auto-rotating camera. Dragging any panel pauses the rotation
    and lets the user orbit manually. Bonds are rendered by default.
    Passing one of the polyhedra kwargs (``tetrahedra``, ``octahedra``,
    ``cuboctahedra``, ``polyhedra_groups``) replaces bonds with
    translucent polyhedra.

    Parameters
    ----------
    output_path : str
        Filesystem path for the written HTML file.
    cells_and_labels : list of (Supercell, str) or dict
        Either ``(supercell, label)`` pairs or a ``{label: supercell}``
        dict. Panels are rendered left-to-right, then top-to-bottom,
        into a ``grid_cols``-wide grid.
    grid_cols : int, optional
        Number of panels per row; must be >= 1. Default ``3``.
    atom_scale : float, optional
        Multiplier applied to ASE covalent radii when sizing atom
        spheres. Default ``0.17``.
    bond_radius : float, optional
        Cylinder radius (Å) for rendered bonds. Default ``0.07``.
    bond_color : tuple of float, optional
        RGB triplet (each in [0, 1]) for bond colour. Default red
        ``(0.95, 0.1, 0.1)``.
    background_color : str, optional
        CSS colour string for the panel background. Default
        ``"#f7f8f5"`` (off-white).
    title, subtitle : str, optional
        Headline and sub-line text rendered above the grid.
    bond_cutoff_scale : float, optional
        Neighbour search radius is ``pair_peak * bond_cutoff_scale``.
        Here ``pair_peak`` is the largest entry of the cell's
        ``shell_target.pair_peak``, or 2.35 Å when the cell has no shell
        target. Default ``1.2``.
    max_bonds_per_atom : int, optional
        Number of nearest in-window neighbours examined per atom. An
        atom with fewer such neighbours contributes no bonds of its own.
        Values < 1 disable bonds. Default ``4``.
    bond_length_tol : float, optional
        Acceptance window around ``pair_peak`` for the bond filter,
        as a fraction. Default ``0.10`` (±10 %).
    ideal_angle_deg, bond_angle_tol_deg : float, optional
        An atom's bonds are drawn only if every angle between its
        ``max_bonds_per_atom`` nearest in-window neighbours is within
        ``bond_angle_tol_deg`` of ``ideal_angle_deg``. A bond is drawn
        when either endpoint passes. Default ``109.47°`` ± ``18°``
        (tetrahedral). Set ``bond_angle_tol_deg=180`` to disable the
        angle filter (e.g. FCC metals, where 60°/90°/120° angles all
        matter).
    tetrahedra : dict, optional
        Switch panels to tetrahedron rendering. Dict keys:
        ``center_symbol`` (default ``"Si"``), ``vertex_symbol``
        (default ``"O"``), ``bond_length`` (default taken from the
        cell's ``shell_target.pair_peak`` when available, otherwise
        auto-detected from atoms), ``bond_length_tol`` (default
        ``0.15``), ``ideal_angle_deg`` (default ``109.47``),
        ``angle_tol_deg`` (default ``25.0``), ``scale`` (vertex shrink
        factor; default ``0.5`` for same-element, ``1.0`` for
        cross-species).
    tetrahedra_color : tuple of float, optional
        RGB triplet for tetrahedron faces. Default navy
        ``(0.35, 0.45, 0.95)``.
    tetrahedra_opacity : float, optional
        Face opacity in [0, 1]. Default ``0.45``.
    octahedra, octahedra_color, octahedra_opacity : dict / tuple / float, optional
        Octahedron variant; takes the same dict keys as ``tetrahedra``.
    cuboctahedra, cuboctahedra_color, cuboctahedra_opacity : dict / tuple / float, optional
        Cuboctahedron variant (12-vertex FCC close-packed shell).
    polyhedra_groups : list of dict, optional
        Multi-group polyhedra (e.g. sp²/sp³ carbon). This is a list of
        per-group dicts, each with ``kind`` (one of ``"triangles"``,
        ``"tetrahedra"``, ``"octahedra"``, ``"cuboctahedra"``),
        ``center_symbol``, ``vertex_symbol``, ``color``, ``opacity``,
        and optionally ``virtual_species`` for filtering by
        ``Supercell._atom_shell_species_index``. When given, the
        single-kind polyhedra payload of each panel is cleared in favour
        of these groups.

    Returns
    -------
    str
        ``output_path`` (as a string), the file the HTML was written to.

    Raises
    ------
    ValueError
        If ``grid_cols`` is smaller than 1.

    Examples
    --------
    >>> import tricor as tc
    >>> cells = [(cell_amorphous, "amorphous"),
    ...          (cell_mro, "MRO"),
    ...          (cell_nc, "NC")]
    >>> tc.export_overview_html("overview.html", cells, grid_cols=3,
    ...                         tetrahedra=dict(center_symbol="Si",
    ...                                         vertex_symbol="Si"))
    """
    import json

    from ase.data import covalent_radii
    from ase.data.colors import jmol_colors

    # Validate before any expensive neighbour-list work. This used to
    # surface as a ZeroDivisionError at the very end.
    n_cols = int(grid_cols)
    if n_cols < 1:
        raise ValueError(f"grid_cols must be >= 1; got {grid_cols!r}")

    # Accept ``{label: supercell}`` as well as (supercell, label) pairs.
    # Iterating a dict directly would yield bare label strings.
    if isinstance(cells_and_labels, dict):
        pairs = [(cell, label) for label, cell in cells_and_labels.items()]
    else:
        pairs = list(cells_and_labels)

    tetra_cfg = _resolve_polyhedra_cfg(tetrahedra, octahedra, cuboctahedra)
    # Pick the colour/opacity that belongs to whichever kwarg was supplied.
    if cuboctahedra is not None:
        poly_color = cuboctahedra_color
        poly_opacity = cuboctahedra_opacity
    elif octahedra is not None:
        poly_color = octahedra_color
        poly_opacity = octahedra_opacity
    else:
        poly_color = tetrahedra_color
        poly_opacity = tetrahedra_opacity

    # Multi-group polyhedra mode overrides the single-cfg path above.
    # Each entry is a dict {kind, center_symbol, vertex_symbol, ...,
    # color, opacity, virtual_species}.
    groups_cfg: list[dict] | None = None
    if polyhedra_groups:
        groups_cfg = [_resolve_group_cfg(g) for g in polyhedra_groups]

    def _bond_length_from_shell_target(cell, center_sym, vertex_sym):
        # Centre-vertex reference distance from the cell's shell target.
        # Returns None when there is no shell target, a species is
        # missing, or the stored peak is not positive.
        from ase.data import atomic_numbers
        st = getattr(cell, "_shell_target", None)
        if st is None:
            return None
        sp = np.asarray(st.species, dtype=np.int64)
        try:
            i = int(np.where(sp == atomic_numbers[center_sym])[0][0])
            j = int(np.where(sp == atomic_numbers[vertex_sym])[0][0])
        except (IndexError, KeyError):
            return None
        v = float(np.asarray(st.pair_peak, dtype=np.float64)[i, j])
        return v if v > 0 else None

    structures = []
    for cell, label in pairs:
        atoms = cell.atoms
        shell_target = getattr(cell, "_shell_target", None)
        if shell_target is not None:
            pair_peak = float(np.max(
                np.asarray(shell_target.pair_peak, dtype=np.float64),
            ))
        else:
            # Fallback reference bond length (Å) when no shell target
            # is attached (presumably chosen for Si-Si).
            pair_peak = 2.35

        # Centre positions on the cell centre; round to keep JSON small.
        cell_mat = np.asarray(atoms.cell.array, dtype=np.float32)
        centre = 0.5 * np.sum(cell_mat, axis=0)
        pos = np.round((atoms.positions - centre).astype(np.float32), 3)
        numbers = atoms.numbers
        colors = np.array([jmol_colors[z] for z in numbers], dtype=np.float32)
        radii = np.array([covalent_radii[z] for z in numbers], dtype=np.float32)

        tet_vertices_flat: list[float] = []
        num_tets = 0
        per_poly_faces: list[list[list[int]]] = []
        per_poly_edges: list[list[list[int]]] = []
        if tetra_cfg is None:
            # Bond mode: length + angle filtered bonds.
            bi, bj = _overview_bond_pairs(
                atoms,
                pair_peak,
                bond_length_tol=bond_length_tol,
                bond_cutoff_scale=bond_cutoff_scale,
                max_bonds_per_atom=max_bonds_per_atom,
                ideal_angle_deg=ideal_angle_deg,
                bond_angle_tol_deg=bond_angle_tol_deg,
            )
        else:
            # Polyhedra mode: skip bonds, detect tets / octa / cubocta.
            effective_bond_length = tetra_cfg["bond_length"]
            if effective_bond_length is None:
                effective_bond_length = _bond_length_from_shell_target(
                    cell, tetra_cfg["center_symbol"], tetra_cfg["vertex_symbol"],
                )
            polys = tetra_cfg["detector"](
                atoms,
                center_symbol=tetra_cfg["center_symbol"],
                vertex_symbol=tetra_cfg["vertex_symbol"],
                bond_length=effective_bond_length,
                bond_length_tol=tetra_cfg["bond_length_tol"],
                ideal_angle_deg=tetra_cfg["ideal_angle_deg"],
                angle_tol_deg=tetra_cfg["angle_tol_deg"],
            )
            tet_vertices_flat = [
                round(v, 3)
                for v in _polyhedra_vertex_coords(
                    polys, atoms.positions, atoms.cell.array,
                    scale=tetra_cfg["scale"],
                )
            ]
            num_tets = len(polys)
            if tetra_cfg["per_polyhedron_topology"]:
                per_poly_faces = [p["faces"] for p in polys]
                per_poly_edges = [p["edges"] for p in polys]
            bi = np.zeros(0, dtype=np.int32)
            bj = np.zeros(0, dtype=np.int32)

        # --- multi-group polyhedra payload (optional) ---
        groups_payload: list[dict] = []
        if groups_cfg:
            groups_payload = [
                _render_polyhedra_group(
                    atoms, gcfg, cell_obj=cell,
                    default_color=(0.25, 0.65, 0.95),
                    default_opacity=0.35,
                )
                for gcfg in groups_cfg
            ]
            # The viewer reads ``polyhedra_groups`` in preference. Clear
            # the whole single-group payload (vertices AND per-polyhedron
            # topology) so it stays internally consistent with
            # ``num_tetrahedra == 0``.
            tet_vertices_flat = []
            num_tets = 0
            per_poly_faces = []
            per_poly_edges = []

        structures.append({
            "label": label,
            "num_atoms": int(len(atoms)),
            "num_bonds": int(len(bi)),
            "positions": pos.ravel().tolist(),
            "atom_colors": colors.ravel().tolist(),
            "atom_radii": np.round(radii, 3).tolist(),
            "bond_i": bi.tolist(),
            "bond_j": bj.tolist(),
            "cell_matrix": np.round(cell_mat, 3).ravel().tolist(),
            "tetrahedra_vertices": tet_vertices_flat,
            "num_tetrahedra": int(num_tets),
            "polyhedra_faces_per_poly": per_poly_faces,
            "polyhedra_edges_per_poly": per_poly_edges,
            "polyhedra_groups": groups_payload,
        })

    data = {
        "structures": structures,
        "grid_cols": n_cols,
        # Ceiling division: enough rows to hold every panel.
        "grid_rows": int(-(-len(structures) // n_cols)),
        "atom_scale": float(atom_scale),
        "bond_radius": float(bond_radius),
        "bond_color": list(bond_color),
        "background_color": background_color,
        "title": title,
        "subtitle": subtitle,
        "tetrahedra_mode": tetra_cfg is not None,
        "tetrahedra_color": list(poly_color),
        "tetrahedra_opacity": float(poly_opacity),
        # Generic polyhedra topology so the viewer can render either
        # tets (n=4) or octahedra (n=6) without hard-coded tables.
        "polyhedra_n_vertices": int(tetra_cfg["n_vertices"]) if tetra_cfg else 4,
        "polyhedra_faces": (
            tetra_cfg["faces"]
            if (tetra_cfg and tetra_cfg["faces"] is not None)
            else _TET_FACES
        ),
        "polyhedra_edges": (
            tetra_cfg["edges"]
            if (tetra_cfg and tetra_cfg["edges"] is not None)
            else _TET_EDGES
        ),
        "polyhedra_per_polyhedron_topology": bool(
            tetra_cfg and tetra_cfg["per_polyhedron_topology"]
        ),
        "polyhedra_scale": float(tetra_cfg["scale"]) if tetra_cfg else 1.0,
        "polyhedra_multi_mode": bool(groups_cfg),
    }
    html = _OVERVIEW_HTML_TEMPLATE.replace(
        "__TRICOR_DATA_PLACEHOLDER__", json.dumps(data),
    )
    output_path = str(output_path)
    with open(output_path, "w") as f:
        f.write(html)
    return output_path
def _detect_shell_mask( triplet_data: np.ndarray, r: np.ndarray, *, pair_peak: float | None = None, smooth_sigma_r: float = 0.25, dip_fraction: float = 0.5, hi_cap_factor: float = 1.25, ) -> np.ndarray: """Auto-detect the first NN shell window over r for one triplet of g3. Simple + robust variant: smooth the root-bond profile, seed the peak on ``pair_peak`` (or a low-r-biased argmax otherwise), pin the left edge at the first non-zero bin, walk right until the smoothed profile drops below ``dip_fraction × peak``. Always capped on the right at ``hi_cap_factor × pair_peak`` so a near-flat valley (common for crystalline nanocrystalline cells) can't runaway into the second shell. Parameters ---------- triplet_data One channel of the raw g3 histogram, shape ``(num_r, num_r, num_phi)``. r Bin-centre radii (Å), shape ``(num_r,)``. pair_peak Optional hint - reference first-neighbour distance (Å), typically ``shell_target.pair_peak[center, neighbour]``. smooth_sigma_r Gaussian standard deviation applied to g(r) before detection, in Å. dip_fraction The smoothed profile has to drop below ``dip_fraction × peak_val`` to count as "past the first shell". 0.5 works well for both crystalline (narrow peak) and disordered (broad peak) cases. hi_cap_factor Absolute right-boundary cap, in units of ``pair_peak`` (or the detected peak radius when pair_peak isn't given). 1.25 sits comfortably before the second NN shell for all tested materials. """ # Pair profile: collapse both angular-partner and phi dimensions. profile = triplet_data.sum(axis=(1, 2)) + triplet_data.sum(axis=(0, 2)) profile = profile / np.maximum(r * r, _EPS) finite = np.nan_to_num(profile, nan=0.0, posinf=0.0, neginf=0.0) mask = np.zeros_like(r, dtype=bool) positive = np.flatnonzero(finite > 0) if positive.size == 0: return mask # Smooth. 
r_step = float(r[1] - r[0]) if r.size > 1 else 1.0 sigma_bins = max(0.0, float(smooth_sigma_r) / max(r_step, _EPS)) if sigma_bins > 0.0: radius = max(1, int(np.ceil(3.0 * sigma_bins))) xk = np.arange(-radius, radius + 1, dtype=np.float64) kernel = np.exp(-0.5 * (xk / sigma_bins) ** 2) kernel /= float(kernel.sum()) pad = np.pad(finite, (radius, radius), mode="edge") smoothed = np.convolve(pad, kernel, mode="valid") else: smoothed = finite # --- Pick the peak --- if pair_peak is not None and np.isfinite(pair_peak) and pair_peak > 0: seed = int(np.argmin(np.abs(r - float(pair_peak)))) # Search a ±15 % window around the reference peak so a weak # low-r bump (atoms still close-packed from a random start) # can't out-vote the true first-neighbour peak. window_half = max( 3, int(np.ceil(0.15 * float(pair_peak) / max(r_step, _EPS))) ) lo_search = max(0, seed - window_half) hi_search = min(smoothed.size, seed + window_half + 1) peak_bin = int(lo_search + int(np.argmax(smoothed[lo_search:hi_search]))) else: # Without pair_peak, bias the peak search toward low r so the # second NN shell can't win the argmax in a clean crystal. The # exp(-r / 2 Å) weight falls off roughly 1.6x between first and # second NN for typical materials, enough to pick the right peak. 
start = int(positive[0]) biased = smoothed * np.exp(-np.asarray(r, dtype=np.float64) / 2.0) peak_bin = int(start + int(np.argmax(biased[start:]))) peak_val = float(smoothed[peak_bin]) peak_r = float(r[peak_bin]) if peak_val <= _EPS: return mask # nothing above the floor # --- Right edge: first bin where smoothed dips below threshold --- if pair_peak is not None and np.isfinite(pair_peak) and pair_peak > 0: hi_cap_r = float(hi_cap_factor) * float(pair_peak) else: hi_cap_r = float(hi_cap_factor) * peak_r hi_cap_bin = int(min(r.size - 1, np.searchsorted(r, hi_cap_r))) dip_thresh = peak_val * float(dip_fraction) right_bin = hi_cap_bin for idx in range(peak_bin + 1, hi_cap_bin + 1): if smoothed[idx] <= dip_thresh: right_bin = idx break # --- Left edge: walk LEFT from the peak until the smoothed # profile dips below ``dip_thresh``, mirroring the right-edge # logic. Floor at ``0.75 × pair_peak`` so a broad low-r tail # (e.g. grain-boundary atoms smearing the rising edge in a # nanocrystalline g(r), or Gaussian-smoothing bleed) can't drag # the mask down past the true first-NN peak. # # The previous heuristic pinned the left edge at the first # non-zero bin (capped at the 0.75 × pair_peak floor). That # captured the FULL rising edge for clean crystals, but for # nanocrystalline cells where GB atoms fill the [0.75 × peak, # peak] region with low-density counts, it left a wide "bleed" # band of low-count data inside the shell window. The dip-walk # mirrors the right edge and gives a clean, symmetric cut. if pair_peak is not None and np.isfinite(pair_peak) and pair_peak > 0: lo_floor_r = 0.75 * float(pair_peak) else: lo_floor_r = 0.75 * peak_r lo_floor_bin = int(max(0, np.searchsorted(r, lo_floor_r) - 1)) left_bin = lo_floor_bin for idx in range(peak_bin - 1, lo_floor_bin - 1, -1): if smoothed[idx] <= dip_thresh: left_bin = idx break # Never start past the peak. 
left_bin = min(left_bin, peak_bin) mask[left_bin : right_bin + 1] = True return mask def _g3_pair_profile( dist: "Any", triplet_idx: int, r: np.ndarray, ) -> np.ndarray: """Return the per-triplet ROOT-bond pair profile g(r), normalised so the tail -> 1.0. For a triplet ``A | B C`` the root bond is ``A-B`` (center to first neighbour). The heatmap above the profile integrates over this same A-B shell to expose how the third atom C is distributed about the A-B backbone. """ g2 = getattr(dist, "g2", None) g3_index = getattr(dist, "g3_index", None) if g2 is not None and g3_index is not None: center_ind, neigh1_ind, _neigh2_ind = g3_index[triplet_idx] profile = np.asarray(g2[center_ind, neigh1_ind], dtype=np.float64).copy() else: # Fall back to integrating r2 and phi out of g3, which leaves # the r1 profile (the root bond). triplet_data = np.asarray(dist.g3[triplet_idx], dtype=np.float64) profile = triplet_data.sum(axis=(1, 2)) profile = profile / np.maximum(r * r, _EPS) # Scale so the tail converges to 1.0 tail_start = 0.7 * float(r[-1]) tail_mask = r >= tail_start if not np.any(tail_mask): tail_mask = np.zeros_like(r, dtype=bool) tail_mask[-max(1, r.size // 4) :] = True tail = profile[tail_mask] finite = tail[np.isfinite(tail)] scale = float(np.mean(finite)) if finite.size else 1.0 if scale <= _EPS: scale = 1.0 return (profile / scale).astype(np.float32) def _g3_slice_image( triplet_data: np.ndarray, shell_mask: np.ndarray, r: np.ndarray, phi_deg: np.ndarray, ) -> np.ndarray: """Compute the (num_phi, num_r) reduced-density slice for one triplet. Integrates the root bond (axis 0, ``r1``) over ``shell_mask``; the remaining ``(r2, phi)`` plane shows where the *third* atom sits relative to the root bond. Normalised so that the uniform far-field tends to 1.0. 
""" image = triplet_data[shell_mask, :, :].sum(axis=0) image = image.T # (num_phi, num_r) phi_rad = np.deg2rad(phi_deg) phi_factor = np.maximum(np.sin(phi_rad), 1e-3)[:, None] radial_factor = np.maximum(r * r, _EPS)[None, :] image = image / (phi_factor * radial_factor) tail_start = 0.7 * float(r[-1]) tail_mask = r >= tail_start if not np.any(tail_mask): tail_mask = np.zeros_like(r, dtype=bool) tail_mask[-max(1, r.size // 4) :] = True tail = image[:, tail_mask] finite = tail[np.isfinite(tail)] scale = float(np.mean(finite)) if finite.size else 1.0 if scale <= _EPS: scale = 1.0 return (image / scale).astype(np.float32) def _nice_round_up(v: float) -> float: """Round *v* up to a ``nice`` number whose half is also clean. Picks the smallest value in [1, 2, 4, 5, 10] x 10**k that is >= v. E.g. 3.85 -> 4.0, 0.87 -> 1.0, 12.5 -> 20.0, 2.1 -> 4.0. """ import math if v <= 0 or not math.isfinite(v): return 1.0 exp = math.floor(math.log10(v)) magnitude = 10.0 ** exp mantissa = v / magnitude for m in (1.0, 2.0, 4.0, 5.0, 10.0): if mantissa <= m + 1e-9: return m * magnitude return 10.0 * magnitude class _PlottingMixin: """Plotting methods extracted from Supercell.""" def view_structure( self: "Supercell", shell_target: "CoordinationShellTarget | None" = None, *, polyhedra: "dict | bool | None" = True, **kwargs, ): """Return an interactive 3D structure viewer widget. Renders atoms as spheres (coloured by element). Two overlay modes, independently toggle-able in the widget: * Bonds - cylinders between atoms within a radial cutoff. * Polyhedra - translucent tetrahedra / octahedra / cuboctahedra around atoms that pass a distance + angle tolerance check (see :meth:`export_trajectory_html` / :func:`_detect_tetrahedra` for the underlying algorithm). Enabled by default for materials whose coordination polyhedron we can auto-detect (Si / C tetrahedra at half-scale, Cu cuboctahedra, SiO2 tetrahedra, SrTiO3 octahedra). 
Sliders in the side panel let you tune the radial tolerance, angular tolerance, centre-vertex bond length, and polyhedra scale (0.5 places vertices at bond midpoints, 1.0 at atoms). Parameters ---------- shell_target Sets the default bond cutoff and bond_length from ``shell_target.max_pair_outer`` / ``pair_peak``. If ``None``, uses the shell_target stored from the last :meth:`generate` call. polyhedra Polyhedra config. ``True`` / ``None`` (default) auto-pick kind + settings from species; ``False`` disables polyhedra; a ``dict`` overrides individual settings - e.g. ``{'kind': 'octahedra', 'center_symbol': 'Ti', 'vertex_symbol': 'O', 'bond_length': 1.96}``. **kwargs Forwarded to :class:`StructureWidget` (e.g. ``atom_scale``, ``bond_cutoff``, ``show_bonds``, ``slab_x``, etc.). Returns ------- StructureWidget An anywidget instance for display in Jupyter. """ from .structure_widget import StructureWidget if shell_target is None: shell_target = getattr(self, "_shell_target", None) return StructureWidget( self.atoms, shell_target=shell_target, grain_ids=self._grain_ids, polyhedra=polyhedra, **kwargs, ) def plot_g3( self: "Supercell", pair: int | str = 0, *, normalize: bool = True, ): """Return an interactive explorer for the supercell's measured g3. Requires :meth:`measure_g3` to have been called first. Parameters ---------- pair Triplet index or label (e.g. ``0`` or ``"Si-Si-Si"``). normalize If ``True``, display the reduced (density-normalised) g3. """ if self.current_distribution is None: raise ValueError("Call measure_g3() before plot_g3().") dist = self.current_distribution return dist.plot_g3(pair=pair, normalize=normalize) def plot_g3_compare( self: "Supercell", pair: int | str = 0, *, normalize: bool = True, ): """Interactive side-by-side comparison of the current supercell's g3 against its target g3. 
Renders an anywidget-based two-panel viewer in Jupyter: left panel is the supercell's measured g3 for the chosen species-pair triplet channel, right panel is the corresponding target g3 (set via ``Supercell.target_distribution``). Drag the radial-shell slider below either panel to inspect a g3 slice at fixed root-bond radius. Parameters ---------- pair : int or str, optional Which triplet channel to display. Either an integer index into ``target_distribution.angle_index`` or a string label like ``"Si-Si-Si"`` resolved by :meth:`G3Distribution._resolve_pair_index`. Default ``0`` (first channel). normalize : bool, optional If ``True`` (default), normalise both g3 values by the uniform-random reference so bins read as enhancements (>1) or depletions (<1). If ``False``, raw counts. Returns ------- G3CompareWidget Interactive anywidget instance. Display it inline in Jupyter (returning the widget from a cell auto-renders). """ current = self.measure_g3() pair_index = self.target_distribution._resolve_pair_index(pair) from .g3_compare_widget import G3CompareWidget return G3CompareWidget( current_distribution=current, target_distribution=self.target_distribution, triplet_index=pair_index, normalize=normalize, supercell_title=f"{self.label} g3 slice", target_title=f"{self.target_distribution.label} g3 slice", status_prefix=( f"density {self.relative_density:.3f} | " "cell_dim_angstroms " f"{self.cell_dim_angstroms[0]:.2f}x" f"{self.cell_dim_angstroms[1]:.2f}x" f"{self.cell_dim_angstroms[2]:.2f} A" ), ) def _display_compare_widget(self: "Supercell") -> None: """Display the comparison widget immediately when running in IPython.""" try: from IPython.display import display except Exception: return display(self.plot_g3_compare()) def plot_monte_carlo( self: "Supercell", *, log_y: bool = False, show_run_boundaries: bool = True, ): """Plot the Monte-Carlo cost history captured by the most recent :meth:`monte_carlo` call. 
Two curves are drawn on a single matplotlib axis: instantaneous cost (current MC state) and best-so-far cost (envelope). Vertical dashed markers separate consecutive ``monte_carlo`` runs when ``show_run_boundaries=True``. Parameters ---------- log_y : bool, optional Display the cost on a logarithmic y-axis. Useful for anneal schedules that span several orders of magnitude. Default ``False``. show_run_boundaries : bool, optional If ``True``, draw vertical dashed lines at the start of each new MC run when ``mc_history["run_index"]`` increments (e.g. when chained calls extend the history). Default ``True``. Returns ------- matplotlib.figure.Figure The created figure. Raises ------ ValueError If :meth:`monte_carlo` has not been run yet (``self.mc_history is None``). """ if self.mc_history is None: raise ValueError("Run monte_carlo() before plotting the history.") import matplotlib.pyplot as plt fig, ax = plt.subplots(figsize=(6.8, 4.2)) ax.plot(self.mc_history["step"], self.mc_history["cost"], lw=1.8, label="cost") ax.plot(self.mc_history["step"], self.mc_history["best_cost"], lw=1.4, label="best") if show_run_boundaries and "run_index" in self.mc_history: run_index = np.asarray(self.mc_history["run_index"], dtype=np.int32) step = np.asarray(self.mc_history["step"], dtype=np.int32) change_points = np.flatnonzero(np.diff(run_index) > 0) for point_index in change_points: ax.axvline( float(step[point_index + 1]), color="0.65", lw=0.9, ls="--", alpha=0.7, ) if log_y: positive_cost = np.asarray(self.mc_history["cost"], dtype=np.float64) positive_cost = positive_cost[positive_cost > 0.0] if positive_cost.size: ax.set_yscale("log") ax.set_xlabel("step") ax.set_ylabel("cost") ax.set_title("Monte Carlo cost history") ax.legend() ax.grid(alpha=0.25) fig.tight_layout() return fig, ax def export_trajectory_html( self: "Supercell", output_path: str, *, bond_cutoff: float | None = None, atom_scale: float = 0.17, bond_radius: float = 0.06, background_color: str = "#f7f8f5", title: 
str = "", tetrahedra: dict | None = None, tetrahedra_color=(0.35, 0.45, 0.95), tetrahedra_opacity: float = 0.45, octahedra: dict | None = None, octahedra_color=(0.95, 0.55, 0.25), octahedra_opacity: float = 0.4, cuboctahedra: dict | None = None, cuboctahedra_color=(0.55, 0.35, 0.85), cuboctahedra_opacity: float = 0.4, polyhedra_groups: "list[dict] | None" = None, show_bonds: bool | None = None, history: "dict | str | None" = None, ) -> str: """Export an interactive 3D trajectory viewer as a self-contained HTML file. Requires :meth:`shell_relax` or :meth:`generate` to have been run with ``capture_trajectory=True``. The resulting HTML embeds the full position trajectory, uses Three.js (from CDN) for rendering, and provides play/pause/slider controls. Parameters ---------- output_path Path to write the HTML file. bond_cutoff Maximum bond length in Angstrom. If ``None`` and the supercell was generated with a shell_target, uses ``shell_target.max_pair_outer * 1.2``. Otherwise 3.0. atom_scale Radius scale for atom spheres (multiplied by covalent radii). bond_radius Radius of bond cylinders in Angstrom. background_color CSS colour for the viewer background. title Optional title displayed above the viewer. show_bonds Whether to emit bond cylinders. ``None`` (default) auto-picks: ``True`` when no ``tetrahedra`` are requested, ``False`` when they are (tetrahedra supersede bonds). Pass ``True`` / ``False`` explicitly to override. history Which trajectory history to render. ``None`` (default) uses ``self.shell_relax_history``. Pass ``"thermal_relax"`` to render ``self.thermal_relax_history``, or pass an explicit history dict. Both shell-relax and thermal-relax dicts have a ``"trajectory"`` key when their parent call was run with ``capture_trajectory=True``. Returns ------- str The output path. 
""" import json from ase.data import covalent_radii from ase.data.colors import jmol_colors if history is None: resolved_history = self.shell_relax_history history_label = "shell_relax_history" elif isinstance(history, str): attr = ( "thermal_relax_history" if history in ("thermal", "thermal_relax") else "shell_relax_history" if history in ("shell", "shell_relax") else "refine_grains_history" if history in ("refine", "refine_grains") else None ) if attr is None: raise ValueError( f"history string must be one of " f"'shell_relax' / 'thermal_relax' / 'refine_grains'; " f"got {history!r}" ) resolved_history = getattr(self, attr, None) history_label = attr else: resolved_history = history history_label = "(provided)" if resolved_history is None or "trajectory" not in resolved_history: raise ValueError( f"No trajectory data available on {history_label}. Run " "shell_relax() / generate() / thermal_relax() with " "capture_trajectory=True first." ) history = resolved_history trajectory = np.asarray(history["trajectory"], dtype=np.float32) n_frames, n_atoms, _ = trajectory.shape atom_cost = history.get("atom_cost") if atom_cost is not None: atom_cost = np.asarray(atom_cost, dtype=np.float32) # Global colour range (constant across frames). We scale # to the STEADY-STATE cost - the 99th percentile of the # last quarter of frames - rather than percentile-of-all, # which is dominated by the initial random-position chaos # for liquid-path runs (Cu liquid starting at random # positions has early per-atom costs of ~100 even with # bond_weight=0.05, which then saturates the colour scale # long after those atoms have relaxed). 
cost_min = 0.0 n_frames_cost = atom_cost.shape[0] tail_start = max(0, int(n_frames_cost * 0.75)) tail = atom_cost[tail_start:].ravel() positive = tail[tail > 0.0] if positive.size == 0: positive = atom_cost.ravel()[atom_cost.ravel() > 0] if positive.size: raw_max = float(np.percentile(positive, 99.0)) if raw_max <= 0.0: raw_max = float(positive.max()) else: raw_max = 1.0 cost_max = _nice_round_up(raw_max) # Bond cutoff shell_target = getattr(self, "_shell_target", None) if bond_cutoff is None: if shell_target is not None: pair_peak_max = float(np.max( np.asarray(shell_target.pair_peak, dtype=np.float64), )) bond_cutoff = pair_peak_max * 1.2 else: bond_cutoff = 3.0 # Polyhedra topology (from final frame) if requested; disables bonds. tetra_cfg = _resolve_polyhedra_cfg(tetrahedra, octahedra, cuboctahedra) if cuboctahedra is not None: poly_color = cuboctahedra_color poly_opacity = cuboctahedra_opacity elif octahedra is not None: poly_color = octahedra_color poly_opacity = octahedra_opacity else: poly_color = tetrahedra_color poly_opacity = tetrahedra_opacity tet_centers: list[int] = [] tet_vertex_idx: list[int] = [] per_poly_faces: list[list[list[int]]] = [] per_poly_edges: list[list[list[int]]] = [] # Resolve show_bonds auto-default: bonds are NEVER drawn by # default — they're an opt-in render mode now. Polyhedra are # the preferred visualisation when configured; otherwise the # viewer shows atoms only. Pass ``show_bonds=True`` to bring # back the old bond cylinders. 
if show_bonds is None: show_bonds_eff = False else: show_bonds_eff = bool(show_bonds) if tetra_cfg is not None: effective_bond_length = tetra_cfg["bond_length"] if effective_bond_length is None: from ase.data import atomic_numbers _st = getattr(self, "_shell_target", None) if _st is not None: _sp = np.asarray(_st.species, dtype=np.int64) try: ci = int(np.where(_sp == atomic_numbers[tetra_cfg["center_symbol"]])[0][0]) vi = int(np.where(_sp == atomic_numbers[tetra_cfg["vertex_symbol"]])[0][0]) v = float(np.asarray(_st.pair_peak, dtype=np.float64)[ci, vi]) if v > 0: effective_bond_length = v except (IndexError, KeyError): pass polys = tetra_cfg["detector"]( self.atoms, center_symbol=tetra_cfg["center_symbol"], vertex_symbol=tetra_cfg["vertex_symbol"], bond_length=effective_bond_length, bond_length_tol=tetra_cfg["bond_length_tol"], ideal_angle_deg=tetra_cfg["ideal_angle_deg"], angle_tol_deg=tetra_cfg["angle_tol_deg"], ) tet_centers = [t["center"] for t in polys] tet_vertex_idx = [v for t in polys for v in t["vertices"]] if tetra_cfg["per_polyhedron_topology"]: per_poly_faces = [p["faces"] for p in polys] per_poly_edges = [p["edges"] for p in polys] if show_bonds_eff: bi_all, bj_all, bd_all = neighbor_list( "ijd", self.atoms, float(bond_cutoff), ) mask = bi_all < bj_all bi = bi_all[mask].astype(np.int32) bj = bj_all[mask].astype(np.int32) else: bi = np.zeros(0, dtype=np.int32) bj = np.zeros(0, dtype=np.int32) # Atom metadata numbers = self.atoms.numbers colors = np.array([jmol_colors[z] for z in numbers], dtype=np.float32) radii = np.array([covalent_radii[z] for z in numbers], dtype=np.float32) # Cell matrix (for min-image bond wrapping in JS) cell_mat = np.asarray(self.atoms.cell.array, dtype=np.float32) # Centre all frames at the cell centre centre = 0.5 * np.sum(cell_mat, axis=0) trajectory_centered = trajectory - centre # Pack data data = { "num_frames": int(n_frames), "num_atoms": int(n_atoms), "num_bonds": int(len(bi)), "atom_colors": colors.ravel().tolist(), 
"atom_radii": radii.tolist(), "atom_scale": float(atom_scale), "bond_radius": float(bond_radius), "bond_i": bi.tolist(), "bond_j": bj.tolist(), "cell_matrix": cell_mat.ravel().tolist(), "positions": trajectory_centered.ravel().tolist(), "background_color": background_color, "title": title, } if atom_cost is not None: data["atom_cost"] = atom_cost.ravel().tolist() data["cost_min"] = float(cost_min) data["cost_max"] = float(cost_max) data["cost_label"] = "per-atom cost" data["tetrahedra_mode"] = tetra_cfg is not None data["tetrahedra_centers"] = list(tet_centers) data["tetrahedra_vertex_indices"] = list(tet_vertex_idx) data["num_tetrahedra"] = len(tet_centers) data["tetrahedra_color"] = list(poly_color) data["tetrahedra_opacity"] = float(poly_opacity) # Generic polyhedra topology so the viewer can render tets # (n=4), octahedra (n=6), or cuboctahedra (n=12). data["polyhedra_n_vertices"] = ( int(tetra_cfg["n_vertices"]) if tetra_cfg else 4 ) data["polyhedra_faces"] = ( tetra_cfg["faces"] if (tetra_cfg and tetra_cfg["faces"] is not None) else _TET_FACES ) data["polyhedra_edges"] = ( tetra_cfg["edges"] if (tetra_cfg and tetra_cfg["edges"] is not None) else _TET_EDGES ) data["polyhedra_per_polyhedron_topology"] = bool( tetra_cfg and tetra_cfg["per_polyhedron_topology"] ) data["polyhedra_faces_per_poly"] = per_poly_faces data["polyhedra_edges_per_poly"] = per_poly_edges data["polyhedra_scale"] = float(tetra_cfg["scale"]) if tetra_cfg else 1.0 # Multi-group polyhedra: emit a parallel ``polyhedra_groups`` # list, where each entry carries its own centers / # vertex_indices / faces / edges / color / opacity. Viewer # prefers this when present. groups_payload_traj: list[dict] = [] if polyhedra_groups: from ase.data import atomic_numbers as _an_traj for g in polyhedra_groups: gcfg = _resolve_group_cfg(g) # Resolve bond_length from shell_target if absent. 
eff_bl = gcfg["bond_length"] if eff_bl is None: _st = getattr(self, "_shell_target", None) if _st is not None: _sp = np.asarray(_st.species, dtype=np.int64) try: ci = int(np.where(_sp == _an_traj[gcfg["center_symbol"]])[0][0]) vi = int(np.where(_sp == _an_traj[gcfg["vertex_symbol"]])[0][0]) _v = float(np.asarray(_st.pair_peak, dtype=np.float64)[ci, vi]) if _v > 0: eff_bl = _v except (IndexError, KeyError): pass # Species filter. species_filter = None vsp = gcfg.get("virtual_species") if vsp is not None: asi = getattr(self, "_atom_shell_species_index", None) if asi is not None: species_filter = ( np.asarray(asi, dtype=np.intp) == int(vsp) ) detector_kwargs = dict( center_symbol=gcfg["center_symbol"], vertex_symbol=gcfg["vertex_symbol"], bond_length=eff_bl, bond_length_tol=gcfg["bond_length_tol"], ideal_angle_deg=gcfg["ideal_angle_deg"], angle_tol_deg=gcfg["angle_tol_deg"], ) try: polys_g = gcfg["detector"]( self.atoms, center_species_filter=species_filter, **detector_kwargs, ) except TypeError: polys_g = gcfg["detector"](self.atoms, **detector_kwargs) g_centers = [t["center"] for t in polys_g] g_vertex_idx = [v for t in polys_g for v in t["vertices"]] if gcfg["per_polyhedron_topology"]: g_per_poly_faces = [p["faces"] for p in polys_g] g_per_poly_edges = [p["edges"] for p in polys_g] else: g_per_poly_faces = [] g_per_poly_edges = [] groups_payload_traj.append( { "n_vertices": int(gcfg["n_vertices"]), "centers": g_centers, "vertex_indices": g_vertex_idx, "num": len(polys_g), "faces": gcfg["faces"] if gcfg["faces"] is not None else _TET_FACES, "edges": gcfg["edges"] if gcfg["edges"] is not None else _TET_EDGES, "per_polyhedron_topology": bool(gcfg["per_polyhedron_topology"]), "faces_per_poly": g_per_poly_faces, "edges_per_poly": g_per_poly_edges, "scale": float(gcfg["scale"]), "color": list(gcfg.get("color") or (0.25, 0.65, 0.95)), "opacity": float( gcfg.get("opacity") if gcfg.get("opacity") is not None else 0.35 ), } ) data["polyhedra_groups"] = groups_payload_traj 
data["polyhedra_multi_mode"] = bool(polyhedra_groups) html = _TRAJECTORY_HTML_TEMPLATE.replace( "__TRICOR_DATA_PLACEHOLDER__", json.dumps(data), ) output_path = str(output_path) with open(output_path, "w") as f: f.write(html) return output_path def export_g3_html( self: "Supercell", output_path: str, *, r_max: float = 10.0, r_step: float = 0.1, phi_num_bins: int = 45, background_color: str = "#f7f8f5", title: str = "", show_progress: bool = False, show_all_triplets: bool = False, ) -> str: """Export a static 2D g3 viewer as a self-contained HTML file. Renders one heatmap per species-triplet of the reduced three-body distribution (density / uniform, where ``1.0`` = white). The viewer uses a diverging RdBu_r colormap centred at ``1.0`` and lets the user pick the triplet and adjust the upper colour limit. A fresh :class:`G3Distribution` is measured from the current atoms on the coarse export grid (by default 50 x 45 bins per triplet) so the embedded JSON stays small (~500 KB) without affecting anything the supercell itself has cached. Parameters ---------- output_path Path to write the HTML file. r_max, r_step, phi_num_bins Measurement grid for the exported distribution. background_color, title Cosmetic. show_progress Forwarded to the g3 measurement call. show_all_triplets If True, the viewer renders a grid of all triplet heatmaps simultaneously (sharing one legend and colour scale) instead of the interactive single-panel view. Useful for multi- species cases like SiO\u2082 where it's otherwise unclear which triplet channel is being displayed. 
""" import json from .g3 import G3Distribution dist = G3Distribution(self.atoms, label=f"{self.label}-export-g3") dist.measure_g3( r_max=r_max, r_step=r_step, phi_num_bins=phi_num_bins, show_progress=show_progress, ) if dist.g3 is None: raise ValueError("Measured distribution has no g3 array.") r = np.asarray(dist.r, dtype=np.float64) r_edges = np.asarray(dist.bin_edges, dtype=np.float64) phi_deg = np.asarray(dist.phi_deg, dtype=np.float64) phi_edges = np.asarray(dist.phi_edges, dtype=np.float64) phi_edges_deg = np.rad2deg(phi_edges) labels = list(dist.pair_labels) num_triplets = int(dist.g3.shape[0]) # Per-triplet pair_peak from the supercell's shell target (if present). # Used to seed the shell-mask peak search; much more robust than # falling back to "first local maximum" on a noisy 20x20x20 g(r). shell_target = getattr(self, "_shell_target", None) pair_peak_matrix = None if shell_target is not None: pair_peak_matrix = np.asarray(shell_target.pair_peak, dtype=np.float64) g3_index = getattr(dist, "g3_index", None) # Per-triplet root-bond label ("Si-O" etc.) for the g(r) panel # below each heatmap, derived from g3_index + species_labels so it # matches however the user labels each species. 
species_labels = [str(s) for s in getattr(dist, "species_labels", None) or []] if not species_labels: from ase.data import chemical_symbols as _chemical_symbols species_labels = [_chemical_symbols[int(z)] for z in dist.species] profile_labels: list[str] = [] for ti in range(num_triplets): if g3_index is not None: center_ind, neigh1_ind, _neigh2_ind = g3_index[ti] profile_labels.append( f"g(r) {species_labels[int(center_ind)]}-" f"{species_labels[int(neigh1_ind)]}" ) else: profile_labels.append("g(r)") slices = [] pair_profiles = [] shell_ranges = [] for ti in range(num_triplets): triplet_data = np.asarray(dist.g3[ti], dtype=np.float64) triplet_pair_peak = None if pair_peak_matrix is not None and g3_index is not None: center_ind, neigh1_ind, _neigh2_ind = g3_index[ti] val = float(pair_peak_matrix[int(center_ind), int(neigh1_ind)]) if np.isfinite(val) and val > 0: triplet_pair_peak = val shell_mask = _detect_shell_mask( triplet_data, r, pair_peak=triplet_pair_peak, ) if not np.any(shell_mask): shell_mask = np.zeros_like(r, dtype=bool) shell_mask[: max(1, r.size // 4)] = True img = _g3_slice_image(triplet_data, shell_mask, r, phi_deg) slices.append(img.ravel().astype(np.float32).tolist()) profile = _g3_pair_profile(dist, ti, r) pair_profiles.append(profile.tolist()) idx = np.flatnonzero(shell_mask) shell_min = float(r_edges[int(idx[0])]) shell_max = float(r_edges[int(idx[-1]) + 1]) shell_ranges.append([shell_min, shell_max]) # Default triplet: prefer a canonical tetrahedral channel when # present (Si|O O for silica, Si|Si Si for silicon, etc.). # Labels are like "Si | O O"; compare the center/neighbour # symbols after normalising whitespace. 
default_triplet = 0 preferred = [ "Si | O O", "Si | Si Si", "C | C C", "Cu | Cu Cu", "Ti | O O", "Sr | O O", ] norm_labels = [str(lab).replace(" ", " ").strip() for lab in labels] for candidate in preferred: if candidate in norm_labels: default_triplet = norm_labels.index(candidate) break data = { "num_r": int(r.size), "num_phi": int(phi_deg.size), "num_triplets": num_triplets, "r_centers": r.tolist(), "r_edges": r_edges.tolist(), "phi_centers_deg": phi_deg.tolist(), "phi_edges_deg": phi_edges_deg.tolist(), "triplet_labels": labels, "profile_labels": profile_labels, "default_triplet": int(default_triplet), "slices": slices, "pair_profiles": pair_profiles, "shell_ranges": shell_ranges, "background_color": background_color, "title": title, "show_all": bool(show_all_triplets), } html = _G3_HTML_TEMPLATE.replace( "__TRICOR_DATA_PLACEHOLDER__", json.dumps(data), ) output_path = str(output_path) with open(output_path, "w") as f: f.write(html) return output_path def export_g2_html( self: "Supercell", output_path: "str | None" = None, *, r_max: float = 10.0, r_step: float = 0.05, background_color: str = "#f7f8f5", title: str = "", show_progress: bool = False, ) -> str: """Export a standalone interactive 2D g(r) viewer. Shows the per-species-pair pair-correlation function g_{AB}(r) as a single g(r) plot with a dropdown to switch species pair, and an "overlay all" checkbox to compare all pairs on one axis. Essentially the bottom panel of :meth:`export_g3_html` lifted out on its own with a species-pair selector - useful as a quick 2-body PDF viewer without any angular content. Parameters ---------- output_path Path to write the HTML file. When ``None`` the HTML is not written to disk; the raw HTML string is still returned so the caller can embed it directly (for example via :func:`IPython.display.HTML`) - see also :meth:`plot_g2` for a ready-made Jupyter wrapper. r_max, r_step Radial grid for the measurement. 
A finer ``r_step`` (default 0.05 \u00c5) gives smoother curves than the coarse 0.1 \u00c5 grid used by the g3 viewer. background_color, title Cosmetic. show_progress Forwarded to the underlying :meth:`measure_g3` call (this exporter reuses the g3 machinery because the g2 array is a by-product; ``phi_num_bins`` is set low for speed). Returns ------- str The resolved output path when ``output_path`` was provided, otherwise the HTML source string. """ import json from .g3 import G3Distribution from ase.data import chemical_symbols dist = G3Distribution(self.atoms, label=f"{self.label}-export-g2") # Low phi_num_bins because we don't use the angular data here; # the speedup is significant for large boxes. dist.measure_g3( r_max=r_max, r_step=r_step, phi_num_bins=12, show_progress=show_progress, ) if dist.g2 is None: raise ValueError("Measured distribution has no g2 array.") r = np.asarray(dist.r, dtype=np.float64) r_edges = np.asarray(dist.bin_edges, dtype=np.float64) species = np.asarray(dist.species, dtype=np.int64) sp_labels = [chemical_symbols[int(z)] for z in species] # g2[ci, vi] is the pair-specific RDF; (ci, vi) and (vi, ci) # differ by a constant prefactor but are otherwise identical. # Emit one curve per unique (ci <= vi) pair and drop the # symmetric duplicate. num_species = len(species) g2 = np.asarray(dist.g2, dtype=np.float64) # Normalise each pair profile so the tail converges to 1.0 # (consistent with the g3 viewer's profile panel). 
def _norm_profile(y: np.ndarray, r_arr: np.ndarray) -> np.ndarray: y = y / np.maximum(r_arr * r_arr, _EPS) tail_start = 0.7 * float(r_arr[-1]) tail_mask = r_arr >= tail_start if not np.any(tail_mask): tail_mask = np.zeros_like(r_arr, dtype=bool) tail_mask[-max(1, r_arr.size // 4):] = True tail = y[tail_mask] finite = tail[np.isfinite(tail)] scale = float(np.mean(finite)) if finite.size else 1.0 if scale <= _EPS: scale = 1.0 return (y / scale).astype(np.float32) pair_labels: list[str] = [] profiles: list[list[float]] = [] pair_peaks: list[float] = [] shell_target = getattr(self, "_shell_target", None) pair_peak_matrix = None if shell_target is not None: pair_peak_matrix = np.asarray( shell_target.pair_peak, dtype=np.float64 ) # Resolve a shell-target species index for each dist species. st_species_index: dict[int, int] = {} if shell_target is not None: st_species = np.asarray(shell_target.species, dtype=np.int64) for ci_dist, z in enumerate(species): match = np.where(st_species == int(z))[0] if match.size: st_species_index[ci_dist] = int(match[0]) for ci in range(num_species): for vi in range(ci, num_species): prof = _norm_profile(g2[ci, vi], r) pair_labels.append(f"{sp_labels[ci]}-{sp_labels[vi]}") profiles.append(prof.tolist()) peak = 0.0 if pair_peak_matrix is not None: ci_st = st_species_index.get(ci) vi_st = st_species_index.get(vi) if ci_st is not None and vi_st is not None: v = float(pair_peak_matrix[ci_st, vi_st]) if np.isfinite(v) and v > 0: peak = v pair_peaks.append(peak) # Default pair: prefer shortest-bond cross-species if available, # else first pair. default_pair = 0 if pair_peaks: # Exclude self-pairs with zero peak; pick the smallest positive peak. 
positive = [(i, p) for i, p in enumerate(pair_peaks) if p > 0] if positive: default_pair = min(positive, key=lambda x: x[1])[0] data = { "num_r": int(r.size), "num_pairs": len(pair_labels), "r_centers": r.tolist(), "r_edges": r_edges.tolist(), "pair_labels": pair_labels, "pair_peaks": pair_peaks, "profiles": profiles, "default_pair": int(default_pair), "background_color": background_color, "title": title, } html = _G2_HTML_TEMPLATE.replace( "__TRICOR_DATA_PLACEHOLDER__", json.dumps(data), ) if output_path is None: return html output_path = str(output_path) with open(output_path, "w") as f: f.write(html) return output_path def plot_g2( self: "Supercell", *, r_max: float = 10.0, r_step: float = 0.05, title: str = "", height: int = 420, show_progress: bool = False, ): """Return an inline Jupyter display of the g(r) pair-correlation viewer. Convenience wrapper around :meth:`export_g2_html` that packages the HTML as an :class:`IPython.display.HTML` object so you can just do ``cells['MRO'].plot_g2()`` in a notebook cell. The viewer is embedded via a ``srcdoc`` iframe so it renders isolated from the surrounding notebook CSS / JS. Parameters ---------- r_max, r_step, title, show_progress Forwarded to :meth:`export_g2_html`. height Iframe height in pixels. """ from IPython.display import HTML html = self.export_g2_html( None, r_max=r_max, r_step=r_step, title=title, show_progress=show_progress, ) import html as _html escaped = _html.escape(html, quote=True) # Wrap the iframe in a <div> so IPython's HTML helper doesn't # trigger its "Consider using IPython.display.IFrame instead" # warning - IFrame requires a URL or file path and can't embed # our self-contained HTML directly, so srcdoc is what we want. 
return HTML( f'<div class="tricor-g2-wrapper" style="width:100%">' f'<iframe srcdoc="{escaped}" ' f'style="width:100%; height:{int(height)}px; ' f'border:1px solid rgba(0,0,0,0.1); border-radius:6px;"></iframe>' f'</div>' ) def plot_structure( self: "Supercell", shell_target: "CoordinationShellTarget | None" = None, *, output: str | None = None, width: int = 1024, height: int = 1024, fps: int = 60, duration: float = 6.0, elevation: float = 15.0, atom_size: float = 10.0, bond_cutoff: float | None = None, show_cell: bool = True, show_atoms: bool = True, background: str = "white", colormap: str = "Reds", tetrahedral_thresh: float = 0.4, show_progress: bool = True, ): """Render a bond-centric rotating 3D view of the atomic structure. Bonds are the primary visual: crystalline (tetrahedral) bonds are drawn thick and coloured by depth; boundary / amorphous bonds are drawn faint. Atoms are optional small dots. The animation performs a full periodic 360-degree rotation. Classification follows the MATLAB ``plotAtoms02`` convention: an atom is *crystalline* if it has exactly K nearest neighbours within *bond_cutoff* **and** the mean displacement of those neighbours is less than *tetrahedral_thresh* (i.e. the local coordination is symmetric / tetrahedral). Parameters ---------- shell_target First-shell targets. Used to set *bond_cutoff* and the coordination number K automatically. output File path for a ``.mp4`` (recommended) or ``.gif``. ``None`` shows a static figure. width, height Frame size in pixels. fps Frames per second (GIF only). duration Total GIF length in seconds. Rotation is always exactly 360 degrees so the loop is seamless. elevation Camera elevation in degrees. atom_size Matplotlib scatter marker size. Set to 0 to hide atoms. bond_cutoff NN bond length cutoff in Angstrom. show_cell Draw the periodic cell outline. show_atoms Draw atom dots. background Figure background colour. 
colormap Matplotlib colormap for crystalline bonds (coloured by depth / y-coordinate after rotation, like MATLAB ``bone``). tetrahedral_thresh Maximum norm of mean NN displacement vector for an atom to be classified as crystalline. Smaller = stricter. show_progress Print frame counter during GIF rendering. """ import matplotlib.pyplot as plt from mpl_toolkits.mplot3d.art3d import Line3DCollection pos = self.atoms.positions.copy() cell_mat = np.asarray(self.atoms.cell.array, dtype=np.float64) cell_inv = np.linalg.inv(cell_mat) num_atoms = len(self.atoms) # --- bond cutoff --- if bond_cutoff is None: if shell_target is not None: # Use a generous cutoff (pair_peak + 3*sigma or 1.2x) pair_peak_max = float(np.max( np.asarray(shell_target.pair_peak, dtype=np.float64), )) bond_cutoff = pair_peak_max * 1.2 else: bond_cutoff = 3.0 if shell_target is not None: coord_target = np.asarray(shell_target.coordination_target, dtype=np.float64) species_idx = ( self._atom_shell_species_index if getattr(self, "_atom_shell_species_index", None) is not None else self._atom_species_index ) k_per_atom = np.array([ int(np.round(coord_target[species_idx[a]].sum())) for a in range(num_atoms) ], dtype=np.intp) else: k_per_atom = np.full(num_atoms, 4, dtype=np.intp) # --- find bonds --- bi_all, bj_all, bd_all = neighbor_list("ijd", self.atoms, bond_cutoff) def _min_image(delta: np.ndarray) -> np.ndarray: frac = delta @ cell_inv frac -= np.rint(frac) return frac @ cell_mat # --- classify atoms as crystalline --- # Two criteria (either makes an atom crystalline): # 1) Tetrahedral check (MATLAB style): K NN within cutoff and # symmetric coordination (mean displacement < thresh). # Allow K-1 to K+1 neighbors for tolerance. # 2) Grain interior: deep inside a crystalline Voronoi cell. 
is_crystalline_atom = np.zeros(num_atoms, dtype=bool) # Criterion 1: tetrahedral / symmetric coordination for a in range(num_atoms): mask = bi_all == a nn_count = int(np.sum(mask)) k_target = int(k_per_atom[a]) if nn_count < max(k_target - 1, 1) or nn_count > k_target + 1: continue # Use K nearest for the displacement check dists_a = bd_all[mask] js_a = bj_all[mask] order = np.argsort(dists_a)[:k_target] dxyz = _min_image(pos[js_a[order]] - pos[a]) mean_disp = np.linalg.norm(np.mean(dxyz, axis=0)) if mean_disp < tetrahedral_thresh: is_crystalline_atom[a] = True # Criterion 2: grain interior atoms (always crystalline) grain_ids = self._grain_ids grain_seeds = self._grain_seeds if grain_ids is not None and grain_seeds is not None: pp_max = float(np.max( np.asarray(shell_target.pair_peak, dtype=np.float64), )) if shell_target is not None else 2.5 bw = pp_max * 0.5 delta_seeds = pos[:, None, :] - grain_seeds[None, :, :] frac_ds = delta_seeds @ cell_inv frac_ds -= np.rint(frac_ds) cart_ds = frac_ds @ cell_mat dist_to_seeds = np.sqrt(np.sum(cart_ds ** 2, axis=2)) for ia in range(num_atoms): gid = grain_ids[ia] if gid < 0: continue d_own = dist_to_seeds[ia, gid] dists_copy = dist_to_seeds[ia].copy() dists_copy[gid] = np.inf d_other = float(np.min(dists_copy)) if (d_other - d_own) * 0.5 > bw: is_crystalline_atom[ia] = True # Keep i < j for unique bonds mask_ij = bi_all < bj_all bi, bj = bi_all[mask_ij], bj_all[mask_ij] # Bond is crystalline if BOTH endpoints are crystalline bond_is_cryst = is_crystalline_atom[bi] & is_crystalline_atom[bj] # Bond segment endpoints (minimum-image) bond_vecs = _min_image(pos[bj] - pos[bi]) bond_starts = pos[bi] bond_ends = bond_starts + bond_vecs # --- center everything --- a_vec, b_vec, c_vec = cell_mat[0], cell_mat[1], cell_mat[2] center = 0.5 * (a_vec + b_vec + c_vec) pos_c = pos - center bstart_c = bond_starts - center bend_c = bond_ends - center # Cell outline edges o = -center cell_corners = [ o, o + a_vec, o + b_vec, o + c_vec, o + 
a_vec + b_vec, o + a_vec + c_vec, o + b_vec + c_vec, o + a_vec + b_vec + c_vec, ] cell_edge_pairs = [ (0, 1), (0, 2), (0, 3), (1, 4), (1, 5), (2, 4), (2, 6), (3, 5), (3, 6), (4, 7), (5, 7), (6, 7), ] cell_segs = [(cell_corners[i], cell_corners[j]) for i, j in cell_edge_pairs] # --- colormap for crystalline bonds (depth-coloured) --- cmap = plt.get_cmap(colormap) cryst_mask = bond_is_cryst bnd_mask = ~bond_is_cryst extent = float(np.max(np.abs(pos_c))) * 1.15 dpi = 100 figsize = (width / dpi, height / dpi) def _rotate_2d(pts: np.ndarray, theta: float) -> np.ndarray: """Rotate x-y columns by theta radians (in-place friendly).""" c, s = np.cos(theta), np.sin(theta) x_new = pts[:, 0] * c - pts[:, 1] * s y_new = pts[:, 0] * s + pts[:, 1] * c out = pts.copy() out[:, 0] = x_new out[:, 1] = y_new return out def _draw_frame(theta_rad: float) -> "plt.Figure": fig = plt.figure(figsize=figsize, dpi=dpi) ax = fig.add_subplot(111, projection="3d") ax.set_facecolor(background) fig.patch.set_facecolor(background) # Perspective projection (like MATLAB camproj('perspective')) try: ax.set_proj_type("persp", focal_length=0.25) except (TypeError, AttributeError): pass # older matplotlib # Rotate bond endpoints in x-y plane (like MATLAB) bs_r = _rotate_2d(bstart_c, theta_rad) be_r = _rotate_2d(bend_c, theta_rad) # --- boundary bonds: very faint --- if np.any(bnd_mask): segs_b = list(zip(bs_r[bnd_mask], be_r[bnd_mask])) lc_b = Line3DCollection( segs_b, linewidths=0.3, colors=(0.0, 0.0, 0.0, 0.05), ) ax.add_collection3d(lc_b) # --- crystalline bonds: depth-coloured + depth-width --- # Camera at +x (azim=0): larger rotated-x = closer. # Linear depth with floor so far bonds remain visible. 
if np.any(cryst_mask): segs_cr = list(zip(bs_r[cryst_mask], be_r[cryst_mask])) mid_x_rot = 0.5 * (bs_r[cryst_mask, 0] + be_r[cryst_mask, 0]) norm_depth = (mid_x_rot + extent) / max(2.0 * extent, _EPS) norm_depth = np.clip(norm_depth, 0, 1) # Colormap: 0.15 at back (faint but visible), 0.95 at front cryst_colors = cmap(0.15 + 0.8 * norm_depth) # Linewidth: 0.4 at back, 2.0 at front cryst_lw = 0.4 + 1.6 * norm_depth lc_c = Line3DCollection( segs_cr, linewidths=cryst_lw, colors=cryst_colors, ) ax.add_collection3d(lc_c) # --- cell outline --- if show_cell: cell_segs_r = [] for s, e in cell_segs: s_r = _rotate_2d(s.reshape(1, 3), theta_rad)[0] e_r = _rotate_2d(e.reshape(1, 3), theta_rad)[0] cell_segs_r.append((s_r, e_r)) lc_cell = Line3DCollection( cell_segs_r, linewidths=1.5, colors="k", alpha=0.5, ) ax.add_collection3d(lc_cell) # --- atoms (tiny dots) --- if show_atoms and atom_size > 0: pos_r = _rotate_2d(pos_c, theta_rad) ax.scatter( pos_r[:, 0], pos_r[:, 1], pos_r[:, 2], s=atom_size, c="k", alpha=0.15, edgecolors="none", depthshade=False, ) ax.set_xlim(-extent, extent) ax.set_ylim(-extent, extent) ax.set_zlim(-extent, extent) ax.set_box_aspect([1, 1, 1]) ax.view_init(elev=elevation, azim=0) # azim fixed; we rotate data ax.axis("off") fig.subplots_adjust(left=-0.05, right=1.05, bottom=-0.05, top=1.05) return fig # --- static display or animation --- if output is None: return _draw_frame(theta_rad=np.deg2rad(45.0)) # Full 360-degree periodic rotation n_frames = int(fps * duration) thetas = np.linspace(0, 2 * np.pi, n_frames, endpoint=False) if show_progress: progress = _TextProgressBar(n_frames, label="Rendering", width=28) else: progress = None is_mp4 = str(output).lower().endswith(".mp4") if is_mp4: # MP4 via ffmpeg subprocess - true 60fps import subprocess import io # Frame dimensions from figure size (no bbox_inches="tight" # for the raw pipe - keeps pixel count deterministic). 
fw, fh = width, height ffmpeg_cmd = [ "ffmpeg", "-y", "-loglevel", "error", "-f", "rawvideo", "-pix_fmt", "rgba", "-s", f"{fw}x{fh}", "-r", str(fps), "-i", "pipe:0", "-c:v", "libx264", "-pix_fmt", "yuv420p", "-preset", "fast", "-crf", "18", str(output), ] proc = subprocess.Popen( ffmpeg_cmd, stdin=subprocess.PIPE, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, ) for i, th in enumerate(thetas): fig_i = _draw_frame(th) buf_i = io.BytesIO() fig_i.savefig(buf_i, format="raw", dpi=dpi, facecolor=background) plt.close(fig_i) proc.stdin.write(buf_i.getvalue()) if progress is not None: progress.update(i + 1) proc.stdin.close() proc.wait() else: # GIF fallback via Pillow from PIL import Image import io frames: list[Image.Image] = [] for i, th in enumerate(thetas): fig = _draw_frame(th) buf = io.BytesIO() fig.savefig( buf, format="png", dpi=dpi, bbox_inches="tight", pad_inches=0, facecolor=background, ) plt.close(fig) buf.seek(0) frames.append(Image.open(buf).copy()) buf.close() if progress is not None: progress.update(i + 1) frame_duration_ms = int(1000 / fps) frames[0].save( output, save_all=True, append_images=frames[1:], duration=frame_duration_ms, loop=0, optimize=True, ) if progress is not None: progress.update(n_frames) return output