#!/usr/bin/env python
"""Standalone reproducer: Cu Amorphous → orient → MACE+wall → movies.

Generated by scripts/regen_mace_examples.py.  Requires tricor
(install from GitHub), ASE, torch and mace-torch:
    pip install torch "mace-torch>=0.3"
Structure CIFs (for the oxide / carbon cases) live in the tricor-docs
docs/structures/ directory — run those scripts from there or edit the
path.  This copper case builds its reference crystal with ase.build.bulk
and needs no CIF.  Run with:
    python copper_amorphous_generate.py
"""
import numpy as np
from ase.optimize import LBFGS
from ase.neighborlist import neighbor_list
from ase.calculators.calculator import Calculator, all_changes
from mace.calculators import mace_mp
import tricor as tc
from tricor.shells import CoordinationShellTarget

# ---- run configuration ----
BOX = 40.0  # supercell edge length in Å (used as cell_dim_angstroms on all three axes)
RNG_SEED = 2026  # seed handed to tc.Supercell.from_atoms

# Keyword arguments forwarded to cell.generate(); the weight/scale entries are
# also reused for the orientation refinement when grain_size is not None.
KW = {
    'grain_size': None,
    'displacement_sigma': 0.03,
    'bond_weight': 3.0,
    'angle_weight': 0.0,
    'repulsion_weight': 1.5,
    'hard_core_scale': 0.95,
    'nonbond_push_scale': 1.0,
}

# Extra keyword arguments for cell.export_trajectory_html().
VIZ = {
    'cuboctahedra': {'center_symbol': 'Cu', 'vertex_symbol': 'Cu'},
    'cuboctahedra_color': (0.85, 0.45, 0.2),
    'cuboctahedra_opacity': 0.22,
}

TITLE = 'Cu Amorphous'          # prefix of every movie title
OUT_PREFIX = "copper_amorphous"  # prefix of every output file name

STEPS = 50               # maximum LBFGS steps in the MACE+wall relaxation
BOND_RELAX_N_ITER = 80   # n_iter for cell.bond_relax
HARD_CORE_N_ITER = 40    # n_iter for cell.enforce_hard_core
MOVIE_FRAMES = 24        # upper bound on frames written per movie

MACE_MODEL = 'medium-mpa-0'  # model name passed to mace_mp()
WALL_K = 1000.0              # wall stiffness k
WALL_EXPONENT = 4            # wall penalty exponent n

# ---- per-pair soft wall (silent above each per-pair minimum) ----
def per_pair_min_from_atoms(atoms, margin=0.0, cutoff=8.0):
    """Return the shortest observed distance for every element pair in *atoms*.

    Pair distances come from ASE's ``neighbor_list`` out to *cutoff*, so
    periodic images are included when the cell is periodic.  The result maps
    ``(Z_low, Z_high)`` atomic-number tuples to ``min_distance - margin``.
    Element pairs with no neighbours inside *cutoff* are absent from the
    result.  The dict has the shape expected by the ``rmin`` argument of
    ``MinDistanceWallCalculator``.

    Parameters
    ----------
    atoms : ase.Atoms
        Structure to scan.
    margin : float
        Subtracted from every minimum distance.
    cutoff : float
        Neighbour-search radius (ASE length units, Å).  Defaults to 8.0,
        the value previously hard-coded here.

    Returns
    -------
    dict[tuple[int, int], float]
        Possibly empty mapping from sorted atomic-number pairs to distances.
    """
    i, j, d = neighbor_list("ijd", atoms, cutoff, self_interaction=False)
    if len(d) == 0:
        # Nothing within the cutoff: no pair type to build a wall for.  Also
        # avoids calling np.unique(axis=0) on an empty (0, 2) array.
        return {}
    z = atoms.numbers
    # Sort each pair's atomic numbers so (Cu, O) and (O, Cu) share one key.
    lo = np.minimum(z[i], z[j])
    hi = np.maximum(z[i], z[j])
    out = {}
    for za, zb in np.unique(np.stack([lo, hi], 1), axis=0):
        # Non-empty by construction: (za, zb) came from these very arrays.
        m = (lo == za) & (hi == zb)
        out[(int(za), int(zb))] = float(d[m].min()) - margin
    return out


class MinDistanceWallCalculator(Calculator):
    """ASE calculator that adds a per-element-pair repulsive wall to *base*.

    Each neighbour pair (i, j) can activate the wall.  Let ``(Za, Zb)`` be
    the pair's element types and ``r = rmin[(Za, Zb)]``.  The wall is active
    when ``r`` is positive and the distance ``d`` is below it.  An active
    pair adds ``(k / n) * (r - d) ** n`` to the energy of ``base``, with
    every unordered pair counted once, together with the matching repulsive
    force.  At or above ``r``, and for element pairs missing from *rmin*,
    the wall contributes nothing.

    Parameters
    ----------
    base : Calculator
        Calculator supplying the underlying energy and forces.
    rmin : dict[tuple[int, int], float]
        Wall onset distance per atomic-number pair.  Key order is
        irrelevant: (a, b) and (b, a) are treated alike.
    k : float
        Wall stiffness.
    exponent : int
        Power ``n`` of the wall penalty.

    Notes
    -----
    ``free_energy`` is reported equal to ``energy``.  ``stress`` is a
    placeholder of zeros: neither the base stress nor the wall virial is
    included, so this wrapper should not drive cell-shape relaxations.
    """

    implemented_properties = ["energy", "free_energy", "forces", "stress"]

    def __init__(self, base, rmin, k=1000.0, exponent=4):
        super().__init__()
        self.base = base
        # Canonicalise keys to (low Z, high Z) so lookups are order-independent.
        self.rmin = {(min(a, b), max(a, b)): float(r) for (a, b), r in rmin.items()}
        self.k = k
        self.n = exponent
        # Neighbour-search radius: no wall can be active beyond the largest rmin.
        self.maxr = max(self.rmin.values()) if self.rmin else 0.0

    def calculate(self, atoms=None, properties=("energy",),
                  system_changes=all_changes):
        super().calculate(atoms, properties, system_changes)
        if atoms is None:
            # Calculator.calculate keeps a copy of the last Atoms it was given.
            atoms = self.atoms
        # The combined result always reads base energy AND forces, so request
        # both even when the caller asked only for the energy; otherwise a
        # lazily-evaluating base calculator would leave "forces" unset.
        base_props = list(dict.fromkeys([*properties, "energy", "forces"]))
        self.base.calculate(atoms, base_props, system_changes)
        E = float(self.base.results["energy"])
        F = np.asarray(self.base.results["forces"], float).copy()
        if self.maxr > 0:
            i, j, d, D = neighbor_list("ijdD", atoms, self.maxr,
                                       self_interaction=False)
            if len(d):
                z = atoms.numbers
                lo = np.minimum(z[i], z[j])
                hi = np.maximum(z[i], z[j])
                # Wall onset for each listed pair; stays 0 (= no wall) for
                # element pairs that have no entry in self.rmin.
                rmp = np.zeros(len(d))
                for (za, zb), r in self.rmin.items():
                    rmp[(lo == za) & (hi == zb)] = r
                act = (rmp > 0) & (d < rmp)
                if act.any():
                    gap = rmp[act] - d[act]
                    # neighbor_list returns every pair twice (i->j and j->i);
                    # the 0.5 counts each unordered pair once.
                    E += float(0.5 * np.sum((self.k / self.n) * gap ** self.n))
                    # Force magnitude per ordered pair.  D points from i to j,
                    # so -D/|D| pushes atom i away from j; the mirrored (j, i)
                    # entry applies the opposite force to atom j.
                    f = self.k * gap ** (self.n - 1)
                    np.add.at(F, i[act], -(f[:, None]) * (D[act] / d[act, None]))
        # NOTE: stress is a placeholder (see class docstring).
        self.results = dict(energy=E, free_energy=E, forces=F,
                            stress=np.zeros(6))

# ---- build cell ----
from ase.build import bulk

# Reference crystal: fcc Cu with lattice constant 3.615 Å.  Its coordination
# shells (36 angular bins) are the target for the generated supercell.
ref = bulk("Cu", "fcc", a=3.615)
shell = CoordinationShellTarget.from_atoms(ref, phi_num_bins=36)

# Cubic BOX x BOX x BOX supercell seeded from the reference.
cell = tc.Supercell.from_atoms(
    ref,
    cell_dim_angstroms=(BOX, BOX, BOX),
    r_max=10.0,
    r_step=0.1,
    phi_num_bins=36,
    rng_seed=RNG_SEED,
)

# Shallow copy of KW forwarded as generate() options.  num_steps=0 is passed
# explicitly; NOTE(review): presumably this means no optimisation steps during
# generation (cleanup happens below) -- confirm against tricor.
gen = {**KW}
cell.generate(shell, num_steps=0, show_progress=False, **gen)

# ---- orientation refinement (captures the orient movie) ----
# Runs only when KW["grain_size"] is set.  Otherwise orient_traj stays None
# and no orientation movie is written.
orient_traj = None
if KW.get("grain_size") is not None:
    orient_kw = {
        name: KW[name]
        for name in (
            "bond_weight",
            "angle_weight",
            "repulsion_weight",
            "hard_core_scale",
            "nonbond_push_scale",
        )
    }
    cell.refine_initial_orientations(
        shell, capture_trajectory=True, show_progress=False, **orient_kw)
    history = cell.refine_initial_orientations_history
    orient_traj = history.get("trajectory")

# ---- cleanup ----
# Two tricor passes on the generated cell, in this order: bond_relax
# (step size capped at 0.1), then enforce_hard_core.
cell.bond_relax(
    shell,
    n_iter=BOND_RELAX_N_ITER,
    max_step=0.1,
)
cell.enforce_hard_core(
    shell,
    n_iter=HARD_CORE_N_ITER,
)

# ---- MACE + wall ----
calc = mace_mp(model=MACE_MODEL, device="cpu", default_dtype="float32")
atoms = cell.atoms.copy()
# Each element pair's wall sits at the shortest distance already present in
# the cleaned-up starting structure.  It therefore contributes (essentially)
# nothing at the start and only resists pairs moving closer than that.
atoms.calc = MinDistanceWallCalculator(calc, per_pair_min_from_atoms(atoms),
                                       k=WALL_K, exponent=WALL_EXPONENT)
mace_traj = [atoms.positions.copy()]


def _record_mace_frame():
    """LBFGS observer: append the current positions to ``mace_traj``.

    ASE optimizers also call observers once before the first step.  That
    call would duplicate the initial frame stored above, so a frame equal
    to the previous one is skipped.  This works whether or not the installed
    ASE makes the initial call.
    """
    if not np.array_equal(atoms.positions, mace_traj[-1]):
        mace_traj.append(atoms.positions.copy())


opt = LBFGS(atoms, maxstep=0.1, logfile="-")
opt.attach(_record_mace_frame, interval=1)
# fmax=1e-8 is effectively unreachable (float32 MACE), so in practice the run
# stops after STEPS steps.
opt.run(fmax=1e-8, steps=STEPS)


def _subsample(seq, n=MOVIE_FRAMES):
    """Return at most *n* evenly spaced items of *seq* as a list.

    Sequences no longer than *n* are returned whole.  Otherwise indices are
    picked by rounding an even linspace over ``[0, len(seq) - 1]``.  For
    ``n >= 2`` this keeps both the first and the last item.
    """
    if len(seq) > n:
        picks = np.linspace(0, len(seq) - 1, n).round().astype(int)
        return [seq[k] for k in np.unique(picks)]
    return list(seq)


def _write_movie(frames, file_suffix, title_suffix):
    """Write *frames* as an HTML trajectory movie for ``cell``.

    Side effect: ``cell.atoms.positions`` is set to the first frame before
    export.
    """
    cell.atoms.positions = np.asarray(frames[0])
    cell.export_trajectory_html(
        OUT_PREFIX + file_suffix,
        title=TITLE + title_suffix,
        history={"trajectory": np.asarray(frames, np.float32)},
        **VIZ,
    )


# Orientation movie only exists when the refinement captured a trajectory.
if orient_traj is not None and len(orient_traj):
    fr = _subsample(list(orient_traj))
    _write_movie(fr, "_orient_movie.html", " — orientation refinement")

fr = _subsample(mace_traj)
_write_movie(fr, "_mace_movie.html", " — MACE+wall LBFGS")
print("wrote movies for", TITLE)
