"""Implementation agnostics visualisation functions."""
from itertools import product
from math import ceil, cos, radians, sin
from typing import List
import numpy as np
import ase_notebook.draw_elements as draw
def triangle_normal(a, b, c):
    """Compute the normal of three points.

    The normal is the cross product of the two edges that leave ``a``
    (``b - a`` and ``c - a``); it is not normalised.

    Parameters
    ----------
    a, b, c : array-like
        The three points.

    Returns
    -------
    list
        The components of the cross product.
    """
    corner, second, third = (np.array(point) for point in (a, b, c))
    edge_one = second - corner
    edge_two = third - corner
    normal = np.cross(edge_one, edge_two)
    return normal.tolist()
def compute_projection(element_group, wsize, rotation, whitespace=1.3):
    """Compute the center and scale of the projection.

    Parameters
    ----------
    element_group : object
        Must provide ``update_positions(rotation)`` and
        ``get_position_range()``; the latter returns a (min, max) pair of
        coordinate arrays.
    wsize : tuple
        The (width, height) of the target window.
    rotation : array-like
        Rotation matrix, passed to ``update_positions`` and applied to the
        midpoint of the position range.
    whitespace : float
        Factor by which the position extent is enlarged before scaling.

    Returns
    -------
    center : numpy.ndarray
        ``rotation`` applied to the midpoint of the position range.
    scale : float
        Factor that fits the enlarged extent into ``wsize``.
    """
    element_group.update_positions(rotation)
    lower, upper = element_group.get_position_range()
    midpoint = (lower + upper) / 2
    center = np.dot(rotation, midpoint)

    extent = whitespace * (upper - lower)
    width, height = wsize
    extent_x, extent_y = extent[0], extent[1]

    # The second extent component is the limiting one: fit it to the height.
    if extent_x * height < extent_y * width:
        return center, height / extent_y
    # Otherwise fit the first component to the width, unless it is degenerate.
    if extent_x > 0.0001:
        return center, width / extent_x
    return center, 1.0
def get_rotation_matrix(rotations, init_rotation=None):
    """Convert string of format '50x,-10y,120z' to a rotation matrix.

    Note that the order of rotation matters, i.e. '50x,40z' is different
    from '40z,50x'.

    Parameters
    ----------
    rotations : str
        Comma separated terms, each an angle in degrees followed by one of
        the axis letters ``x``, ``y`` or ``z``.
    init_rotation : array-like or None
        Matrix that the rotations are applied to (right-multiplied in the
        order given); the 3x3 identity is used if None.

    Returns
    -------
    array-like
        The combined rotation matrix; ``init_rotation`` itself (or the
        identity) is returned when ``rotations`` is the empty string.
    """
    rotation = np.identity(3) if init_rotation is None else init_rotation
    if rotations == "":
        return rotation

    for term in rotations.split(","):
        axis = "xyz".index(term[-1])
        angle = radians(float(term[:-1]))
        sin_a, cos_a = sin(angle), cos(angle)
        if axis == 0:
            step = [(1, 0, 0), (0, cos_a, sin_a), (0, -sin_a, cos_a)]
        elif axis == 1:
            step = [(cos_a, 0, -sin_a), (0, 1, 0), (sin_a, 0, cos_a)]
        else:
            step = [(cos_a, sin_a, 0), (-sin_a, cos_a, 0), (0, 0, 1)]
        rotation = np.dot(rotation, step)
    return rotation
def get_cell_coordinates(
    cell, origin=(0.0, 0.0, 0.0), show_repeats=None, dash_pattern=None
):
    """Get start and end points of the line segments used to draw unit cells.

    We also add an origin option, to allow for different cells to be created.

    Parameters
    ----------
    cell : array-like
        The three cell vectors (one per row). Vectors with a norm of at most
        1e-9 are treated as absent and produce no lines.
    origin : array-like
        The corner from which the cell(s) are drawn.
    show_repeats : tuple or None
        Number of sub-cells (a, b, c) that ``cell`` is divided into along
        each vector; ``None`` (or any falsy value) means ``(1, 1, 1)``.
    dash_pattern : tuple or None
        If given, ``(dash_length, gap_length)`` used to split every line
        into dashes.

    Returns
    -------
    starts : numpy.ndarray
        Start point of every line segment, shape (n, 3).
    ends : numpy.ndarray
        End point of every line segment, shape (n, 3).
    """
    reps_a, reps_b, reps_c = show_repeats or (1, 1, 1)
    # Coerce to a float array, so plain nested sequences are also accepted.
    vec_a, vec_b, vec_c = np.asarray(cell, dtype=float)
    has_a = np.linalg.norm(vec_a) > 1e-9
    has_b = np.linalg.norm(vec_b) > 1e-9
    has_c = np.linalg.norm(vec_c) > 1e-9
    vec_a = vec_a / reps_a
    vec_b = vec_b / reps_b
    vec_c = vec_c / reps_c
    base = np.array(origin)

    lines = []
    for rep_a, rep_b, rep_c in product(
        range(reps_a), range(reps_b), range(reps_c)
    ):
        rep_origin = base + rep_a * vec_a + rep_b * vec_b + rep_c * vec_c
        # edges leaving the origin corner of this sub-cell
        if has_a:
            lines.append([rep_origin, rep_origin + vec_a])
        if has_b:
            lines.append([rep_origin, rep_origin + vec_b])
        if has_c:
            lines.append([rep_origin, rep_origin + vec_c])
        # remaining edges of each face
        if has_a and has_b:
            lines.extend(
                [
                    [rep_origin + vec_a, rep_origin + vec_a + vec_b],
                    [rep_origin + vec_a + vec_b, rep_origin + vec_b],
                ]
            )
        if has_a and has_c:
            lines.extend(
                [
                    [rep_origin + vec_a, rep_origin + vec_a + vec_c],
                    [rep_origin + vec_c, rep_origin + vec_c + vec_a],
                ]
            )
        if has_b and has_c:
            lines.extend(
                [
                    [rep_origin + vec_b, rep_origin + vec_b + vec_c],
                    [rep_origin + vec_c, rep_origin + vec_c + vec_b],
                ]
            )
        # edges meeting at the corner opposite the origin
        if has_a and has_b and has_c:
            lines.extend(
                [
                    [rep_origin + vec_a + vec_b, rep_origin + vec_a + vec_b + vec_c],
                    [rep_origin + vec_c + vec_a, rep_origin + vec_c + vec_a + vec_b],
                    [rep_origin + vec_c + vec_a + vec_b, rep_origin + vec_c + vec_b],
                ]
            )
    # The reshape keeps the (n, 2, 3) layout even when no line was created
    # (e.g. all cell vectors are zero), so the final slicing cannot fail.
    lines = np.array(lines, dtype=float).reshape((-1, 2, 3))

    if dash_pattern:
        # split lines into a dash pattern
        dlength, dgap = dash_pattern
        period = dlength + dgap
        dashed = []
        for start, end in lines:
            direction = end - start
            total_length = np.linalg.norm(direction)
            dash_fraction = period / total_length
            length_fraction = dlength / total_length
            ndashes = int(ceil(total_length / period))
            new_start = start
            for n in range(ndashes - 1):
                dash_end = start + direction * (
                    (dash_fraction * n) + length_fraction
                )
                dashed.append([new_start, dash_end])
                new_start = start + direction * dash_fraction * (n + 1)
            # TODO remove last gap fraction (if present)
            # or, better, start with a fraction of dlength, so start/end are symmetric
            dashed.append([new_start, end])
        lines = np.array(dashed, dtype=float).reshape((-1, 2, 3))

    return lines[:, 0], lines[:, 1]
def get_miller_coordinates(cell, miller):
    """Compute the points at which a miller index intercepts with a unit cell boundary.

    Parameters
    ----------
    cell : array-like
        The three cell vectors (one per row).
    miller : sequence
        The (h, k, l) indices; none of them may be negative.

    Returns
    -------
    numpy.ndarray
        Three points when all indices are non-zero, otherwise four points.

    Raises
    ------
    NotImplementedError
        If any index is less than zero.
    ValueError
        If all indices are (close to) zero.
    """
    vec_a, vec_b, vec_c = np.array(cell, dtype=float)
    h_val, k_val, l_val = miller
    if any(index < 0 for index in (h_val, k_val, l_val)):
        # TODO compute negative miller intercepts
        # look at script in https://www.doitpoms.ac.uk/tlplib/miller_indices/printall.php
        # they appear to use a transpose
        raise NotImplementedError("h, k or l less than zero")

    no_h, no_k, no_l = (bool(flag) for flag in np.isclose(miller, 0))
    if no_h and no_k and no_l:
        raise ValueError("h, k, l all 0")

    # A zero index has no intercept along that vector; it is never used below.
    cut_a = None if no_h else vec_a / h_val
    cut_b = None if no_k else vec_b / k_val
    cut_c = None if no_l else vec_c / l_val

    if no_k and no_l:
        corners = [cut_a, cut_a + vec_b, cut_a + vec_b + vec_c, cut_a + vec_c]
    elif no_h and no_l:
        corners = [cut_b, cut_b + vec_a, cut_b + vec_a + vec_c, cut_b + vec_c]
    elif no_h and no_k:
        corners = [cut_c, cut_c + vec_a, cut_c + vec_a + vec_b, cut_c + vec_b]
    elif no_h:
        corners = [cut_b, cut_c, cut_c + vec_a, cut_b + vec_a]
    elif no_k:
        corners = [cut_a, cut_c, cut_c + vec_b, cut_a + vec_b]
    elif no_l:
        corners = [cut_a, cut_b, cut_b + vec_c, cut_a + vec_c]
    else:
        corners = [cut_a, cut_b, cut_c]
    return np.array(corners)
def compute_bonds(atoms, atom_radii, scale_radii=1.5):
    """Compute bonds for atoms.

    Neighbours are found with an ``ase.neighborlist.NeighborList`` whose
    cutoffs are ``atom_radii * scale_radii``.

    Parameters
    ----------
    atoms : ase.Atoms
    atom_radii : numpy.ndarray
        Radius of each atom (must support multiplication by a scalar).
    scale_radii : float
        Factor applied to the radii to obtain the neighbour cutoffs.

    Returns
    -------
    numpy.ndarray
        Integer array with one row per bond and five columns: the atom
        index, the neighbour index and the three cell offsets of the
        neighbour.
    """
    from ase.neighborlist import NeighborList

    neighbor_list = NeighborList(
        atom_radii * scale_radii, skin=0, self_interaction=False
    )
    neighbor_list.update(atoms)
    total = neighbor_list.nneighbors + neighbor_list.npbcneighbors
    bonds = np.empty((total, 5), int)
    if total == 0:
        return bonds

    # Fill the leading rows with the neighbours reported for each atom.
    filled = 0
    for atom_index in range(len(atoms)):
        neighbors, cell_offsets = neighbor_list.get_neighbors(atom_index)
        stop = filled + len(neighbors)
        bonds[filled:stop, 0] = atom_index
        bonds[filled:stop, 1] = neighbors
        bonds[filled:stop, 2:] = cell_offsets
        filled = stop

    # Bonds with a non-zero cell offset are mirrored into the remaining rows:
    # the two atom indices are swapped and the offset is negated.
    reported = bonds[:filled]
    crossing = reported[reported[:, 2:].any(1)]
    bonds[filled:, 0] = crossing[:, 1]
    bonds[filled:, 1] = crossing[:, 0]
    bonds[filled:, 2:] = -crossing[:, 2:]
    return bonds
def filter_bond_indices(bonds, to_keep: List[bool]):
    """Filter bonds by required indices.

    Parameters
    ----------
    bonds : numpy.ndarray
        Bond rows; columns 0 and 1 hold the atom indices and columns 2
        onwards the three cell offsets of the second atom.
    to_keep : list[bool]
        One flag per atom index; True marks atoms whose bonds are kept.

    Returns
    -------
    numpy.ndarray
        The rows whose first atom is kept, or whose second atom is kept and
        is not a periodic image (all offsets are zero).
    """
    kept_atoms = [index for index, flag in enumerate(to_keep) if flag]
    first_is_kept = np.isin(bonds[:, 0], kept_atoms)
    second_is_kept = np.isin(bonds[:, 1], kept_atoms)
    # the second index only counts if it is not a periodic image
    within_cell = np.equal(bonds[:, 2:], [0, 0, 0]).all(axis=1)
    selected = first_is_kept | (second_is_kept & within_cell)
    return bonds[selected]
def _unit_cell_segments(atoms, show_unit_cell, uc_dash_pattern):
    """Return the unit cell lines as stacked (start, end) point pairs."""
    if show_unit_cell:
        starts, ends = get_cell_coordinates(
            atoms.cell,
            show_repeats=atoms.info.get("unit_cell_repeat", None),
            dash_pattern=uc_dash_pattern,
        )
    else:
        starts = ends = np.zeros((0, 3))
    return np.stack((starts, ends), axis=1)


def _miller_elements(cell, miller_planes, as_lines):
    """Return (line segments, line indices, polygons, polygon indices) for the miller planes."""
    line_starts, line_ends, line_index = [], [], []
    polygons, polygon_index = [], []
    if miller_planes is not None:
        for i, plane in enumerate(miller_planes):
            points = get_miller_coordinates(
                cell, [plane[n] for n in "hkl"]
            ).tolist()
            if as_lines:
                # join each point to the next one, and the last back to the first
                line_starts.extend(points)
                line_ends.extend(points[1:] + [points[0]])
                line_index.extend([i] * len(points))
            else:
                polygons.append(points)
                polygon_index.append(i)
    segments = np.stack(
        (line_starts or np.zeros((0, 3)), line_ends or np.zeros((0, 3))), axis=1
    )
    return segments, line_index, polygons, polygon_index


def _filtered_bonds(
    atoms,
    atom_radii,
    bond_radii_scale,
    bond_array_name,
    bond_pairs_filter,
    bond_supercell,
):
    """Compute the bond rows for the atoms and apply the optional atom / element-pair filters."""
    supercell_atoms = atoms.copy()
    supercell_atoms.cell *= np.array(bond_supercell)[:, np.newaxis]
    bonds = compute_bonds(supercell_atoms, atom_radii, bond_radii_scale)
    if bond_array_name is not None:
        bonds = filter_bond_indices(
            bonds, atoms.get_array(bond_array_name).tolist()
        )
    if bond_pairs_filter is not None:
        # ensure bi-directional (built in a single pass over the filter)
        allowed = set()
        for first, second in bond_pairs_filter:
            allowed.add((first, second))
            allowed.add((second, first))
        symbols = atoms.get_chemical_symbols()
        mask = np.array(
            [(symbols[i], symbols[j]) in allowed for i, j in bonds[:, 0:2]],
            dtype=bool,
        )
        bonds = bonds[mask]
    return bonds


def _bond_segments(atoms, bonds, atom_radii, bond_supercell):
    """Return (starts, ends, lengths) of the bond lines described by the bond rows."""
    if len(bonds) == 0:
        return np.empty((0, 3)), np.empty((0, 3)), np.empty((0,))
    positions = atoms.positions
    cell = np.array(bond_supercell)[:, np.newaxis] * atoms.cell
    first = positions[bonds[:, 0]]
    # vector from the first atom to the (possibly periodic image of the) second
    separation = positions[bonds[:, 1]] + np.dot(bonds[:, 2:], cell) - first
    lengths = (separation ** 2).sum(1) ** 0.5
    # the lines start/end at 0.65 of each atom's radius from its position
    inset = 0.65 * atom_radii
    frac_first = (inset[bonds[:, 0]] / lengths).reshape((-1, 1))
    frac_second = (inset[bonds[:, 1]] / lengths).reshape((-1, 1))
    starts = first + separation * frac_first
    # NOTE: the core ase viewer halves the bond lines of periodic images;
    # that is deliberately not done here, since it is confusing when
    # comparing bond lengths.
    ends = starts + separation * (1.0 - frac_first - frac_second)
    return starts, ends, lengths


def initialise_element_groups(
    atoms,
    atom_radii,
    show_unit_cell=True,
    uc_dash_pattern=None,
    show_bonds=False,
    bond_radii_scale=1.5,
    bond_array_name=None,
    bond_pairs_filter=None,
    bond_supercell=(1, 1, 1),
    miller_planes=None,
    miller_planes_as_lines=False,
):
    """Compute (untransformed) coordinates, for elements in the visualisation.

    Parameters
    ----------
    atoms : ase.Atoms
    atom_radii : array-like
        mapping of atom index to atomic radii
        (converted to a numpy array when bonds are computed)
    show_unit_cell : bool
        show the unit cell
    uc_dash_pattern : tuple or None
        split unit cell lines into dash pattern (line_length, gap_length)
    show_bonds : bool
        show the atomic bonds
    bond_radii_scale : float
        Factor to scale atomic radii by, when computing bonds (via overlapping radii)
    bond_array_name : str
        The name of a boolean array on the Atoms, specifying which atoms that bonds
        should be drawn for (if None, then all bonds are drawn).
    bond_pairs_filter : list
        A list of bond element pairs to filter by, e.g. [("Fe", "O"), ("Fe", "Fe")]
    bond_supercell : tuple
        the supercell of unit cell used for computing bonds
    miller_planes: list[dict] or None
        list of miller planes to project onto the unit cell,
        each providing the keys "h", "k" and "l"
    miller_planes_as_lines: bool
        whether to create miller planes as a group of lines or a solid plane

    Returns
    -------
    draw.DrawGroup
        Group holding the elements named "atoms", "cell_lines", "bond_lines",
        "miller_lines" and "miller_planes".
    """
    cell_lines = _unit_cell_segments(atoms, show_unit_cell, uc_dash_pattern)
    miller_lines, miller_line_index, miller_polys, miller_poly_index = _miller_elements(
        atoms.cell, miller_planes, miller_planes_as_lines
    )

    if show_bonds:
        # the bond computations need array arithmetic and fancy indexing,
        # so a plain list of radii is converted first
        bond_radii = np.asarray(atom_radii)
        bonds = _filtered_bonds(
            atoms,
            bond_radii,
            bond_radii_scale,
            bond_array_name,
            bond_pairs_filter,
            bond_supercell,
        )
        bond_atom_indices = [(bond[0], bond[1]) for bond in bonds]
    else:
        bond_radii = atom_radii
        bonds = np.empty((0, 5), int)
        bond_atom_indices = []
    bond_starts, bond_ends, bond_lengths = _bond_segments(
        atoms, bonds, bond_radii, bond_supercell
    )

    return draw.DrawGroup(
        [
            draw.DrawElementsSphere("atoms", atoms.positions[:], atom_radii),
            draw.DrawElementsLine("cell_lines", cell_lines),
            draw.DrawElementsLine(
                "bond_lines",
                np.stack((bond_starts, bond_ends), axis=1),
                element_properties={
                    "atom_index": bond_atom_indices,
                    "bond_length": bond_lengths,
                },
            ),
            draw.DrawElementsLine(
                "miller_lines",
                miller_lines,
                element_properties={"index": miller_line_index},
            ),
            draw.DrawElementsPoly(
                "miller_planes",
                miller_polys,
                element_properties={"index": miller_poly_index},
            ),
        ]
    )