"""A module for creating visualisations of a structure."""
from copy import copy
import json
import subprocess
import sys
from time import sleep
from typing import Union
from ase.data import covalent_radii as ase_covalent_radii
from ase.data.colors import jmol_colors as ase_element_colors
import attr
import numpy as np
from .atom_info import create_info_lines
from .atoms_convert import convert_to_atoms, deserialize_atoms, serialize_atoms
from .backend.gui import AtomGui, AtomImages
from .backend.svg import (
create_axes_elements,
create_svg_document,
generate_svg_elements,
)
from .backend.threejs import (
create_world_axes,
generate_3js_render,
make_basic_gui,
RenderContainer,
)
from .color import lighten_webcolor
from .configuration import ViewConfig
from .data import load_data_file
from .draw_utils import (
compute_projection,
get_rotation_matrix,
initialise_element_groups,
)
class AseView:
    """Class for visualising ``ase.Atoms`` or ``pymatgen.Structure``."""

    def __init__(self, config: Union[None, "ViewConfig"] = None, **kwargs):
        """This is replaced by the ``ViewConfig`` docstring."""  # noqa: D401
        # NOTE: when ``config`` is given, any ``kwargs`` are ignored.
        if config is not None:
            self._config = config
        else:
            self._config = ViewConfig(**kwargs)

    def __copy__(self):
        """Return a copy of this instance (holding a shallow copy of its config)."""
        return self.__class__(copy(self._config))

    def copy(self):
        """Return a copy of this instance (holding a shallow copy of its config)."""
        return self.__copy__()

    @property
    def config(self):
        """Return the visualisation configuration."""
        return self._config

    def get_config_as_dict(self):
        """Return the configuration as a JSONable dictionary."""
        return attr.asdict(self._config)

    def add_miller_plane(
        self,
        h,
        k,
        l,  # noqa: E741
        *,
        color="blue",
        stroke_color=None,
        stroke_width=1,
        fill_opacity=0.5,
        stroke_opacity=0.9,
        reset=False,
    ):
        """Add a miller plane to the config.

        Parameters
        ----------
        h : int or float
        k : int or float
        l : int or float
            Miller indices of the plane.
        color : str
            fill color of plane
        stroke_color : str or None
            color of outline (if None, ``color`` is used)
        stroke_width : int
            width of outline
        fill_opacity : float
            opacity of the plane fill
        stroke_opacity : float
            opacity of the outline
        reset : bool
            if True, remove any previously set miller planes

        """
        plane = {
            "h": h,
            "k": k,
            "l": l,
            "fill_color": color,
            "stroke_color": stroke_color or color,
            "stroke_width": stroke_width,
            "fill_opacity": fill_opacity,
            "stroke_opacity": stroke_opacity,
        }
        if reset:
            self.config.miller_planes = [plane]
        else:
            # build a new list, rather than mutating the existing config value
            self.config.miller_planes = list(self.config.miller_planes) + [plane]

    def get_element_colors(self):
        """Return mapping of element atomic number to (hex) color.

        Raises
        ------
        ValueError
            if ``config.element_colors`` is not "ase" or "vesta".

        """
        if self.config.element_colors == "ase":
            rgb_values = ase_element_colors
        elif self.config.element_colors == "vesta":
            data = load_data_file("vesta_element_data.json")
            rgb_values = zip(data["r"], data["g"], data["b"])
        else:
            raise ValueError(self.config.element_colors)
        # channels are fractions in [0, 1]; convert to "#RRGGBB"
        return [
            "#{0:02X}{1:02X}{2:02X}".format(*(int(x * 255) for x in rgb))
            for rgb in rgb_values
        ]

    def get_element_radii(self):
        """Return mapping of element atomic number to atom radii.

        Raises
        ------
        ValueError
            if ``config.element_radii`` is not "ase" or "vesta".

        """
        if self.config.element_radii == "ase":
            return ase_covalent_radii.copy()
        if self.config.element_radii == "vesta":
            data = load_data_file("vesta_element_data.json")
            return data["radius"]
        raise ValueError(self.config.element_radii)

    def get_atom_colors(self, atoms):
        """Return mapping of atom index to (hex) color.

        Raises
        ------
        ValueError
            if ``config.atom_color_by`` is not a recognised option.

        """
        color_by = self.config.atom_color_by
        if color_by == "element":
            element_colors = self.get_element_colors()
            return [element_colors[z] for z in atoms.numbers]
        if color_by == "color_array":
            # the array is assumed to already contain colors
            return atoms.get_array(self.config.atom_color_array)
        if color_by == "index":
            values = range(len(atoms))
        elif color_by == "tag":
            values = atoms.get_tags()
        elif color_by == "magmom":
            values = atoms.get_initial_magnetic_moments()
        elif color_by == "charge":
            values = atoms.get_initial_charges()
        elif color_by == "velocity":
            values = (atoms.get_velocities() ** 2).sum(1) ** 0.5
        elif color_by == "value_array":
            values = atoms.get_array(self.config.atom_color_array)
        else:
            raise ValueError(color_by)
        return self.values_to_colors(
            values, self.config.atom_colormap, self.config.atom_colormap_range
        )

    @staticmethod
    def values_to_colors(values, cmap, cmap_range=(None, None)):
        """Map a list of values to hex colors, using a matplotlib colormap.

        Parameters
        ----------
        values : iterable of float
        cmap : str or matplotlib.colors.Colormap or None
            the colormap (None selects matplotlib's default colormap)
        cmap_range : tuple
            (min, max) normalisation bounds; a None bound is taken from
            the minimum/maximum of ``values``

        Returns
        -------
        list of str

        """
        values = list(values)
        if not values:
            # min()/max() of an empty sequence would raise
            return []
        from matplotlib.colors import Normalize, rgb2hex

        try:
            # matplotlib >= 3.6 (``matplotlib.cm.get_cmap`` was removed in 3.9)
            from matplotlib import colormaps

            colormap = colormaps.get_cmap(cmap)
        except (ImportError, AttributeError):
            from matplotlib.cm import get_cmap

            colormap = get_cmap(cmap)
        cmin, cmax = cmap_range
        norm = Normalize(
            vmin=min(values) if cmin is None else cmin,
            vmax=max(values) if cmax is None else cmax,
        )
        return [rgb2hex(colormap(norm(v))[:3]) for v in values]

    def get_atom_radii(self, atoms):
        """Return mapping of atom index to sphere radii (scaled by ``radii_scale``)."""
        element_radii = self.get_element_radii()
        # explicit float dtype, so scaling never fails on integer-valued radii
        radii = np.array([element_radii[z] for z in atoms.numbers], dtype=float)
        return radii * self.config.radii_scale

    def get_atom_labels(self, atoms):
        """Return mapping of atom index to text label.

        Raises
        ------
        ValueError
            if ``config.atom_label_by`` is not a recognised option.

        """
        label_by = self.config.atom_label_by
        labels = None
        if label_by == "element":
            if "occupancy" in atoms.info:
                # partially occupied sites: join all symbols sharing the tag
                labels = [
                    ",".join(atoms.info["occupancy"][t].keys())
                    for t in atoms.get_tags()
                ]
            else:
                labels = atoms.get_chemical_symbols()
        elif label_by == "index":
            labels = list(range(len(atoms)))
        elif label_by == "tag":
            labels = atoms.get_tags()
        elif label_by == "magmom":
            labels = atoms.get_initial_magnetic_moments()
        elif label_by == "charge":
            labels = atoms.get_initial_charges()
        elif label_by == "array":
            labels = atoms.get_array(self.config.atom_label_array)
        if labels is None:
            raise ValueError(label_by)
        return [str(label) for label in labels]

    def _initialise_elements(self, atoms, center_in_uc=False, repeat_uc=(1, 1, 1)):
        """Prepare visualisation elements, in a backend agnostic manner.

        Returns
        -------
        tuple
            (converted/repeated atoms, element groups)

        """
        config = self._config
        atoms = convert_to_atoms(atoms)
        if center_in_uc:
            atoms.center()
        atoms = atoms.repeat(repeat_uc)
        if config.show_uc_repeats:
            # ``True`` means "use repeat_uc", otherwise the config value is used
            atoms.info["unit_cell_repeat"] = (
                repeat_uc
                if isinstance(config.show_uc_repeats, bool)
                else config.show_uc_repeats
            )
        element_groups = initialise_element_groups(
            atoms,
            atom_radii=self.get_atom_radii(atoms),
            show_unit_cell=config.show_unit_cell,
            uc_dash_pattern=config.uc_dash_pattern,
            show_bonds=config.show_bonds,
            bond_radii_scale=config.bond_radii_scale,
            bond_array_name=config.bond_array_name,
            bond_pairs_filter=config.bond_pairs_filter,
            miller_planes=config.miller_planes if config.show_miller_planes else None,
            miller_planes_as_lines=config.miller_as_lines,
        )
        return atoms, element_groups

    @staticmethod
    def _depth_fractions(z_positions):
        """Return the relative depth of each z position, in [0, 1].

        The maximum z maps to 0 and the minimum z maps to 1. If all positions
        are equal (zero range), every depth is 0; an empty input returns an
        empty array.
        """
        z_positions = np.asarray(z_positions, dtype=float)
        if z_positions.size == 0:
            return z_positions
        zmin, zmax = z_positions.min(), z_positions.max()
        span = zmax - zmin
        if span == 0:
            # avoid division by zero (e.g. a single atom or a planar structure)
            return np.zeros_like(z_positions)
        return (zmax - z_positions) / span

    def _add_element_properties(
        self, atoms, element_groups, bond_thickness, lighten_by_depth=True
    ):
        """Add initial properties to the element groups."""
        config = self._config
        atom_colors = self.get_atom_colors(atoms)
        atom_labels = self.get_atom_labels(atoms)
        ghost_atoms = (
            atoms.get_array("ghost")
            if "ghost" in atoms.arrays
            else [False for _ in atoms]
        )

        if config.atom_lighten_by_depth and lighten_by_depth:
            depths = self._depth_fractions(
                element_groups["atoms"].get_max_zposition()
            )
            atom_colors = [
                lighten_webcolor(color, depth * config.atom_lighten_by_depth)
                for color, depth in zip(atom_colors, depths)
            ]

        has_occupancy = "occupancy" in atoms.info
        element_groups["atoms"].set_property_many(
            {
                "color": atom_colors,
                "label": [
                    None
                    if (ghost and not config.ghost_show_label)
                    or (not config.atom_show_label)
                    else label
                    for label, ghost in zip(atom_labels, ghost_atoms)
                ],
                "ghost": ghost_atoms,
                "fill_opacity": [
                    config.ghost_opacity if g else config.atom_opacity
                    for g in ghost_atoms
                ],
                "occupancy": [
                    atoms.info["occupancy"][t] if has_occupancy else None
                    for t in atoms.get_tags()
                ],
                "stroke_width": [
                    config.ghost_stroke_width if g else config.atom_stroke_width
                    for g in ghost_atoms
                ],
                "stroke_opacity": [
                    config.ghost_stroke_opacity if g else config.atom_stroke_opacity
                    for g in ghost_atoms
                ],
                "info_string": [
                    "; ".join(create_info_lines(atoms, [i])) for i in range(len(atoms))
                ],
            },
            element=True,
        )
        element_groups["atoms"].set_property_many(
            {"font_size": config.atom_font_size, "font_color": config.atom_font_color},
            element=False,
        )

        element_groups["cell_lines"].set_property_many(
            {"color": config.uc_color}, element=False
        )

        if config.bond_color_by == "atoms":
            # each bond is colored by the (possibly lightened) colors of its two atoms
            element_groups["bond_lines"].set_property(
                "color",
                [
                    (atom_colors[i], atom_colors[j])
                    for i, j in element_groups["bond_lines"].get_elements_property(
                        "atom_index"
                    )
                ],
                element=True,
            )
        elif config.bond_color_by == "length":
            bond_colors = self.values_to_colors(
                element_groups["bond_lines"].get_elements_property("bond_length"),
                self.config.bond_colormap,
                self.config.bond_colormap_range,
            )
            element_groups["bond_lines"].set_property(
                "color", [(c, c) for c in bond_colors], element=True
            )
        element_groups["bond_lines"].set_property_many(
            {"stroke_width": bond_thickness, "stroke_opacity": config.bond_opacity},
            element=False,
        )

        for miller_type in ("miller_lines", "miller_planes"):
            # look up the source plane definition once per element
            planes = [
                config.miller_planes[i]
                for i in element_groups[miller_type].get_elements_property("index")
            ]
            element_groups[miller_type].set_property_many(
                {
                    "fill_color": [p.get("fill_color", "blue") for p in planes],
                    "stroke_color": [p.get("stroke_color", "blue") for p in planes],
                    "stroke_width": [p.get("stroke_width", 1) for p in planes],
                    "fill_opacity": [p.get("fill_opacity", 1) for p in planes],
                    "stroke_opacity": [p.get("stroke_opacity", 1) for p in planes],
                },
                element=True,
            )

    def make_svg(self, atoms, center_in_uc=False, repeat_uc=(1, 1, 1)):
        """Create an SVG of the atoms or structure."""
        config = self.config
        atoms, element_groups = self._initialise_elements(
            atoms, center_in_uc=center_in_uc, repeat_uc=repeat_uc
        )
        rotation_matrix = get_rotation_matrix(config.rotations)
        center, scale = compute_projection(
            element_groups, config.canvas_size, rotation_matrix
        )
        scale *= config.zoom
        # flip the y-axis (multiplying by (1, -1, 1))
        axes = scale * rotation_matrix * (1, -1, 1)
        offset = np.dot(center, axes)
        offset[:2] -= 0.5 * np.array(config.canvas_size)
        element_groups.update_positions(
            axes, offset, radii_scale=scale * 0.65 if config.show_bonds else scale
        )

        self._add_element_properties(atoms, element_groups, bond_thickness=scale * 0.15)

        svg_elements = generate_svg_elements(
            element_groups,
            element_colors=self.get_element_colors(),
            background_color=config.canvas_color_background,
        )

        if config.canvas_crop:
            left, right, top, bottom = config.canvas_crop
            # (left, right, top, bottom) -> (minx, miny, width, height)
            viewbox = (
                left,
                top,
                config.canvas_size[0] - left - right,
                config.canvas_size[1] - top - bottom,
            )
        else:
            left = right = top = bottom = 0
            viewbox = (0, 0, config.canvas_size[0], config.canvas_size[1])

        if config.show_axes:
            flip_matrix = rotation_matrix * (1, -1, 1)
            axes_vectors = flip_matrix
            labels = ("X", "Y", "Z")
            if config.axes_uc:
                # TODO add config.axes_uc to threejs render
                # project the (normalised) cell vectors instead of the x/y/z axes
                axes_vectors = np.einsum(
                    "...jk,...k->...j", flip_matrix.T, atoms.cell
                )
                axes_vectors = np.divide(
                    axes_vectors.T, np.linalg.norm(axes_vectors, axis=1)
                ).T
                labels = ("a", "b", "c")
            svg_elements.extend(
                create_axes_elements(
                    axes_vectors,
                    config.canvas_size,
                    # TODO compute offset based on axes xs and ys
                    inset=(
                        config.axes_offset[0] + left,
                        config.axes_offset[1] + bottom,
                    ),
                    length=config.axes_length,
                    font_size=config.axes_font_size,
                    line_color=config.axes_line_color,
                    labels=labels,
                )
            )

        return create_svg_document(
            svg_elements,
            config.canvas_size,
            viewbox if config.canvas_crop else None,
            background_color=config.canvas_color_background,
            background_opacity=config.canvas_background_opacity,
        )

    def make_gui(
        self,
        atoms,
        center_in_uc=False,
        repeat_uc=(1, 1, 1),
        bring_to_top=True,
        launch=True,
    ):
        """Launch a (blocking) GUI to view the atoms or structure.

        If ``launch`` is False, the GUI is created but not run, and is returned.
        """
        atoms, _ = self._initialise_elements(
            atoms, center_in_uc=center_in_uc, repeat_uc=repeat_uc
        )
        images = AtomImages(
            [atoms],
            element_radii=np.array(self.get_element_radii()).tolist(),
            radii_scale=self.config.radii_scale,
        )
        gui = AtomGui(
            config=self.config, images=images, element_colors=self.get_element_colors()
        )
        if bring_to_top:
            tk_window = gui.window.win  # tkinter.Toplevel
            # toggling "-topmost" raises the window without pinning it on top
            tk_window.attributes("-topmost", 1)
            tk_window.attributes("-topmost", 0)
        if launch:
            gui.run()
        else:
            return gui

    def launch_gui_subprocess(
        self, atoms, center_in_uc=False, repeat_uc=(1, 1, 1), test_init=2
    ):
        """Launch a GUI to view the atoms or structure, in a (non-blocking) subprocess.

        We encode all the data into a json object,
        then parse this to a console executable via stdin.

        Parameters
        ----------
        test_init : int or float
            wait for this many seconds, then test whether the process
            initialized without error

        Returns
        -------
        subprocess.Popen

        Raises
        ------
        RuntimeError
            if the process has exited with a non-zero code after ``test_init``
            seconds (the message is the process's stderr)

        """
        atoms = convert_to_atoms(atoms)
        data_str = json.dumps(
            {
                "atoms": serialize_atoms(atoms),
                "config": self.get_config_as_dict(),
                "kwargs": {"center_in_uc": center_in_uc, "repeat_uc": repeat_uc},
            }
        )
        # NOTE(review): stderr is piped but only read on failure; a long-running
        # GUI that writes a lot to stderr may block once the pipe buffer fills.
        process = subprocess.Popen(
            ["ase-notebook.view_atoms"], stdin=subprocess.PIPE, stderr=subprocess.PIPE
        )
        process.stdin.write(data_str.encode())
        process.stdin.close()
        sleep(test_init)
        # poll() is None while still running, and 0 after a clean exit
        if process.poll():
            raise RuntimeError(process.stderr.read().decode())
        return process

    def make_render(
        self,
        atoms,
        center_in_uc=False,
        repeat_uc=(1, 1, 1),
        reuse_objects=True,
        use_atom_arrays=False,
        use_label_arrays=True,
        create_gui=True,
    ):
        """Create a pythreejs render of the atoms or structure."""
        config = self.config
        atoms, element_groups = self._initialise_elements(
            atoms, center_in_uc=center_in_uc, repeat_uc=repeat_uc
        )
        rotation_matrix = get_rotation_matrix(config.rotations)
        # first pass computes the rotated extent, second pass re-centres on it
        element_groups.update_positions(axes=rotation_matrix)
        pos_min, pos_max = element_groups.get_position_range()
        element_groups.update_positions(
            axes=rotation_matrix,
            offset=pos_min + (pos_max - pos_min) / 2,
            radii_scale=0.65 if config.show_bonds else 1.0,
        )

        self._add_element_properties(
            atoms, element_groups, bond_thickness=5, lighten_by_depth=False
        )

        renderer, key_elements = generate_3js_render(
            element_groups,
            canvas_size=config.canvas_size,
            zoom=config.zoom,
            background_color=config.canvas_color_background,
            background_opacity=config.canvas_background_opacity,
            camera_fov=config.camera_fov,
            reuse_objects=reuse_objects,
            use_atom_arrays=use_atom_arrays,
            use_label_arrays=use_label_arrays,
        )

        if config.show_axes:
            axes_renderer = create_world_axes(
                renderer.camera, renderer.controls[0], initial_rotation=rotation_matrix
            )
            key_elements["axes_renderer"] = axes_renderer

        container = RenderContainer(renderer, element_renderer=renderer, **key_elements)

        if create_gui:
            gui = make_basic_gui(container)
            container.top_level = gui

        return container
# Reuse ViewConfig's docstring to document the keyword arguments of AseView().
AseView.__init__.__doc__ = "kwargs are used to initialise ViewConfig:\n\n{}".format(
    ViewConfig.__doc__
)
def launch_gui_exec(json_string=None):
    """Launch a GUI, with a json string as input.

    Parameters
    ----------
    json_string : str or None
        A json string containing all data required for running
        ``AseView.make_gui``, with the (optional) keys:
        "atoms" (serialized atoms, passed to ``deserialize_atoms``),
        "config" (keyword arguments for ``AseView``) and
        "kwargs" (keyword arguments for ``AseView.make_gui``).
        If None, the string is read from ``stdin``.

    Returns
    -------
    The return value of ``AseView.make_gui``.

    Raises
    ------
    IOError
        if ``json_string`` is None and ``stdin`` is an interactive terminal,
        i.e. no data has been piped in.

    """
    if json_string is None:
        if sys.stdin.isatty():
            # reading a terminal would wait for keyboard input, not piped data
            raise IOError("stdin is empty")
        json_string = sys.stdin.read()
    data = json.loads(json_string)
    atoms = deserialize_atoms(data.pop("atoms", {}))
    ase_view = AseView(**data.pop("config", {}))
    return ase_view.make_gui(atoms, **data.pop("kwargs", {}))
# Note: original commands, from when images were exported via the tkinter canvas's PostScript output:
# gui.window.win.withdraw() # hide window
# canvas = gui.window.canvas
# canvas.config(width=100, height=100); gui.draw() # resize canvas
# gui.scale *= zoom; gui.draw() # zoom
# canvas.postscript(file=fname) # save canvas