"""A module for visualising structures.

The module subclasses ase (v3.18.0) classes to add additional functionality.
"""
from time import time
import tkinter
from tkinter.font import Font
from ase.data import atomic_numbers
from ase.data import covalent_radii as default_covalent_radii
from ase.gui import ui
from ase.gui.gui import GUI
from ase.gui.images import Images
from ase.gui.view import GREEN, PURPLE, View
import attr
import numpy as np
from ase_notebook.atom_info import create_info_lines
from ase_notebook.color import lighten_webcolor
from ase_notebook.draw_utils import initialise_element_groups
class AtomImages(Images):
    """A subclass of the ase ``Images``, but with additional functionality, for setting radii."""

    def __init__(self, atoms_list, element_radii=None, radii_scale=0.89):
        """Initialise the atom images.

        Parameters
        ----------
        atoms_list : list[ase.Atoms]
        element_radii : list[float]
            radii for each atomic number (default to ase covalent)
        radii_scale : float
            scale all atomic_radii

        """
        # an explicit length check (rather than truthiness) so that numpy arrays
        # are accepted too: ``bool(array)`` raises for multi-element arrays
        if element_radii is not None and len(element_radii) > 0:
            self.covalent_radii = np.array(element_radii, dtype=float)
        else:
            self.covalent_radii = default_covalent_radii.copy()
        # The base class ``__init__`` is not called: it sets ``self.config``,
        # but that is only used there to obtain the radii scale,
        # so the scale is set directly here instead.
        self.atom_scale = radii_scale
        self.initialize(atoms_list)
class AtomGui(GUI):
    """A subclass of the ase ``GUI``, but with additional functionality."""

    def __init__(self, config, images=None, element_colors=None):
        """Initialise the GUI.

        Parameters
        ----------
        config: ViewConfig
            initial configuration settings
        images : ase.gui.images.Images
            list of ase.Atoms, with some settings for visualisations (mainly radii)
        element_colors: list[tuple]
            hex colour for each atomic number (defaults to 'jmol' scheme)

        """
        # NOTE: ``GUI.__init__`` is not called;
        # the window and viewer state are set up directly here, from ``config``.
        if not isinstance(images, Images):
            images = Images(images)
        self.images = images
        self.observers = []

        self.config = attr.asdict(config)
        # aliases required by ui.ASEGUIWindow
        self.config["gui_foreground_color"] = self.config["canvas_color_foreground"]
        self.config["gui_background_color"] = self.config["canvas_color_background"]
        self.config["swap_mouse"] = self.config["gui_swap_mouse"]

        menu = self.get_menu_data()

        self.window = ui.ASEGUIWindow(
            close=self.exit,
            menu=menu,
            config=self.config,
            scroll=self.scroll,
            scroll_event=self.scroll_event,
            press=self.press,
            move=self.move,
            release=self.release,
            resize=self.resize,
        )

        # used by ``View.update_labels``
        label_sites = {"index": 1, "magmom": 2, "element": 3, "charge": 4}.get(
            self.config["atom_label_by"], 0
        )
        if not self.config["atom_show_label"]:
            label_sites = 0
        self.window["show-labels"] = label_sites

        View.__init__(self, self.config["rotations"])

        # an explicit length check (rather than truthiness),
        # so that a numpy array of colors is also accepted
        if element_colors is not None and len(element_colors) > 0:
            self.colors = dict(enumerate(element_colors))

        self.subprocesses = []  # list of external processes
        self.movie_window = None
        self.vulnerable_windows = []
        self.simulation = {}  # Used by modules on Calculate menu.
        self.module_state = {}  # Used by modules to store their state.

        self.arrowkey_mode = self.ARROWKEY_SCAN
        self.move_atoms_mask = None

        self.set_frame(len(self.images) - 1, focus=True)

        # Used to move the structure with the mouse
        self.prev_pos = None
        self.last_scroll_time = self.t0 = time()
        self.orig_scale = self.scale

        # last pointer position; None until the first mouse event
        self.xy = None

        if len(self.images) > 1:
            self.movie()

    def release(self, event):
        """Handle release event."""
        # ``self.xy`` starts as None; seed it from the event so the
        # inherited handler does not raise (fixes an error in the GUI class)
        if self.xy is None:
            self.xy = (event.x, event.y)
        super().release(event)

    def move(self, event):
        """Handle move event."""
        # ``self.xy`` starts as None; seed it from the event so the
        # inherited handler does not raise (fixes an error in the GUI class)
        if self.xy is None:
            self.xy = (event.x, event.y)
        super().move(event)

    def showing_millers(self):
        """Return whether to display planes."""
        return self.config["show_miller_planes"]

    def set_atoms(self, atoms):
        """Set the atoms, unit cell(s) and bonds to draw.

        This is overridden from ``View``, in order to

        - set bond colors, specific to the atom types at each end of the bond.
        - use a modified ``get_cell_coordinates`` function,
          which returns cartesian coordinates, rather than fractional
          (since they were just converted to cartesian anyway)
          and can create multiple cells (for each repeat)
        - compute miller index planes, by points that intercept with the unit cell

        """
        elements = initialise_element_groups(
            atoms,
            atom_radii=self.get_covalent_radii(),
            show_unit_cell=self.showing_cell(),
            uc_dash_pattern=self.config["uc_dash_pattern"],
            show_bonds=self.showing_bonds(),
            bond_supercell=self.images.repeat,
            miller_planes=self.config["miller_planes"]
            if self.showing_millers()
            else None,
            miller_planes_as_lines=self.config["miller_as_lines"],
        )
        self.elements = elements
        # record all positions (atoms first) with legacy array name, for use by View.focus
        self.X = elements.get_all_coordinates()
        # record atom positions with legacy array name, used by View.move
        self.X_pos = self.elements["atoms"]._positions.copy()

    def _lighten_colors_by_depth(self, atom_colors, z_positions):
        """Return ``atom_colors``, each lightened in proportion to its atom's depth.

        The depth fraction is 0 for the atom with the largest z position and 1 for
        the smallest, and is multiplied by ``config["atom_lighten_by_depth"]``.
        When all atoms have the same z position (e.g. a single atom),
        every depth is 0, rather than dividing by a zero z range.
        """
        z_positions = np.asarray(z_positions, dtype=float)
        if z_positions.size == 0:
            return list(atom_colors)
        zmin, zmax = z_positions.min(), z_positions.max()
        z_range = zmax - zmin
        max_lighten = self.config["atom_lighten_by_depth"]
        new_atom_colors = []
        for atom_color, z_position in zip(atom_colors, z_positions):
            atom_depth = (zmax - z_position) / z_range if z_range > 0 else 0.0
            new_atom_colors.append(
                lighten_webcolor(atom_color, atom_depth * max_lighten)
            )
        return new_atom_colors

    def draw(self, status=True):
        """Draw all required objects on the canvas.

        This is overridden from ``View``, in order to:

        - set bond colors, specific to the atom types at each end of the bond.
        - add a cross to 'ghost' atoms (specified by an array on the atoms)
        - add a dash pattern to the unit cell lines
        - allow miller index planes to be drawn
        - optionally lighten atom colors by their depth

        """
        self.window.clear()

        # compute orientation, position and scale of axes
        axes = self.scale * self.axes * (1, -1, 1)
        offset = np.dot(self.center, axes)
        offset[:2] -= 0.5 * self.window.size

        element_groups = self.elements
        element_groups.update_positions(
            axes,
            offset,
            radii_scale=self.scale * 0.65
            if self.window["toggle-show-bonds"]
            else self.scale,
        )

        # required by View.release
        self.P = element_groups["atoms"].unstack_positions()[:, :2].round().astype(int)
        self.indices = np.array([i for i, _ in element_groups.yield_zorder()])

        if "ghost" in self.atoms.arrays:
            ghost_atoms = self.atoms.get_array("ghost")
        else:
            ghost_atoms = [False for _ in self.atoms]

        self.update_labels()  # set self.labels for atoms
        # TODO use occupancy keys for label, if showing symbol

        atom_colors = self.get_colors()
        if self.config["atom_lighten_by_depth"]:
            atom_colors = self._lighten_colors_by_depth(
                atom_colors, element_groups["atoms"].get_max_zposition()
            )

        celldisp = (
            (np.dot(self.atoms.get_celldisp().reshape((3,)), axes)).round().astype(int)
        )

        vector_arrays = []
        if self.window["toggle-show-velocities"]:
            # Scale ugly?
            vector = self.atoms.get_velocities()
            if vector is not None:
                vector_arrays.append(vector * 10.0 * self.velocity_vector_scale)
        if self.window["toggle-show-forces"]:
            f = self.get_forces()
            vector_arrays.append(f * self.force_vector_scale)

        # convert each vector to the canvas position of its end point
        for array in vector_arrays:
            array[:] = (
                (np.dot(array, axes) + element_groups["atoms"].unstack_positions())
                .round()
                .astype(int)
            )

        element_groups["atoms"].set_property_many(
            {
                "lbound": (
                    element_groups["atoms"].unstack_positions()[:, :2]
                    - element_groups["atoms"].scaled_radii[:, None]
                )
                .round()
                .astype(int),
                "color": atom_colors,
                "label": [
                    None if g and not self.config["ghost_show_label"] else l
                    for l, g in zip(
                        self.labels or [None for _ in ghost_atoms], ghost_atoms
                    )
                ],
                "tag": self.atoms.get_tags(),
                "ghost": ghost_atoms,
                "selected": self.images.selected,
                "visible": self.images.visible,
                "constrained": ~self.images.get_dynamic(self.atoms),
                "moving": [m if self.moving else False for m in self.move_atoms_mask]
                if self.move_atoms_mask is not None
                else [False for _ in self.atoms],
                "occupancy": [
                    self.atoms.info["occupancy"][t]
                    if "occupancy" in self.atoms.info
                    else None
                    for t in self.atoms.get_tags()
                ],
                "stroke_width": [
                    self.config["ghost_stroke_width"]
                    if g
                    else self.config["atom_stroke_width"]
                    for g in ghost_atoms
                ],
            },
            element=True,
        )
        element_groups["atoms"].set_property_many(
            {
                "font_size": self.config["atom_font_size"],
                "font_color": self.config["atom_font_color"],
            },
            element=False,
        )

        element_groups["cell_lines"].set_property_many(
            {"color": self.config["uc_color"]}, element=False
        )

        # one color per bond end, taken from the atom at that end
        element_groups["bond_lines"].set_property(
            "color",
            [
                (atom_colors[i], atom_colors[j])
                for i, j in element_groups["bond_lines"].get_elements_property(
                    "atom_index"
                )
            ],
            element=True,
        )
        element_groups["bond_lines"].set_property_many(
            {"stroke_width": self.scale * 0.15}, element=False
        )

        for miller_type in ["miller_lines", "miller_planes"]:
            # the configured settings of the plane each element belongs to
            plane_settings = [
                self.config["miller_planes"][i]
                for i in element_groups[miller_type].get_elements_property("index")
            ]
            element_groups[miller_type].set_property_many(
                {
                    "fill_color": [
                        s.get("fill_color", "blue") for s in plane_settings
                    ],
                    "stroke_color": [
                        s.get("stroke_color", "blue") for s in plane_settings
                    ],
                    "stroke_width": [s.get("stroke_width", 1) for s in plane_settings],
                },
                element=True,
            )

        if self.arrowkey_mode == self.ARROWKEY_MOVE:
            movecolor = GREEN
        elif self.arrowkey_mode == self.ARROWKEY_ROTATE:
            movecolor = PURPLE
        else:
            movecolor = None

        draw_elements(
            element_groups,
            canvas=self.window.canvas,
            celldisp=celldisp,
            vector_arrays=vector_arrays,
            movecolor=movecolor,
            scale=self.scale,
            element_colors=self.colors,
            ghost_cross_out=self.config["ghost_cross_out"],
            background_color=self.config["gui_background_color"],
        )

        if self.window["toggle-show-axes"]:
            draw_axes(
                self.window.canvas,
                self.axes,
                self.window.size,
                length=self.config["axes_length"],
                font_size=self.config["axes_font_size"],
                line_color=self.config["axes_line_color"],
            )

        if len(self.images) > 1:
            self.draw_frame_number()

        self.window.update()

        if status:
            num_atoms = len(self.atoms)
            indices = np.arange(num_atoms)[self.images.selected[:num_atoms]]
            ordered_indices = [i for i in self.images.selected_ordered if i < num_atoms]
            status_lines = create_info_lines(self.atoms, indices, ordered_indices)
            self.window.update_status_line(" " + "; ".join(status_lines))
def draw_arrow(canvas, coords, width, scale):
    """Draw an arrow element.

    Parameters
    ----------
    canvas : tkinter.Canvas
    coords : sequence
        ``(x0, y0, x1, y1)``: the arrow runs from ``(x0, y0)`` to its head at ``(x1, y1)``
    width : int
        line width of the shaft and head
    scale : float
        canvas scale; the arrow-head length is capped at ``0.3 * scale``

    """
    begin = np.array((coords[0], coords[1]), dtype=float)
    end = np.array((coords[2], coords[3]), dtype=float)
    # ``width`` must be passed by keyword: passed positionally, tkinter treats it
    # as an extra coordinate (giving an odd number of coordinates)
    canvas.create_line(*(int(x) for x in coords), width=width)

    # the two head barbs point back along the shaft, 0.3 rad either side of it,
    # and are no longer than the shaft itself
    length = min(np.hypot(*(end - begin)), 0.3 * scale)
    angle = np.arctan2(end[1] - begin[1], end[0] - begin[0]) + np.pi
    tip = (int(end[0]), int(end[1]))
    for spread in (-0.3, 0.3):
        barb_x = int(np.round(end[0] + length * np.cos(angle + spread)))
        barb_y = int(np.round(end[1] + length * np.sin(angle + spread)))
        canvas.create_line(barb_x, barb_y, *tip, width=width)
def draw_circle(lbound, diameter, canvas, color, selected, tags=(), stroke_width=1):
    """Draw a circle element, given a lower bound and diameter.

    The circle is inscribed in the square whose corner is ``lbound`` and whose
    side is ``diameter``; coordinates are truncated to integers.
    Selected circles get a dark green outline, three times ``stroke_width`` wide.
    """
    outline, width = (
        ("#004500", stroke_width * 3) if selected else ("black", stroke_width)
    )
    x_low, y_low = lbound[0], lbound[1]
    corners = (x_low, y_low, x_low + diameter, y_low + diameter)
    canvas.create_oval(
        *map(int, corners),
        fill=color,
        outline=outline,
        width=width,
        tags=tags,
    )
def draw_arc(
    lbound, diameter, canvas, color, selected, start, extent, stroke_width=1, tags=()
):
    """Draw an arc element.

    Parameters
    ----------
    lbound : sequence
        lower-bound (corner) coordinates of the arc's bounding square
    diameter : float
        side length of the bounding square
    canvas : tkinter.Canvas
    color : str
        fill color
    selected : bool
        if True, draw a dark green outline, three times ``stroke_width`` wide
    start : float
        start angle, in degrees
    extent : float
        angular extent, in degrees
    stroke_width : float
        outline width for an unselected arc (consistent with ``draw_circle``)
    tags : tuple
        canvas tags for the created item

    """
    if selected:
        outline = "#004500"
        width = stroke_width * 3
    else:
        outline = "black"
        width = stroke_width
    bbox = (lbound[0], lbound[1], lbound[0] + diameter, lbound[1] + diameter)
    canvas.create_arc(
        *tuple(int(x) for x in bbox),
        start=start,
        extent=extent,
        fill=color,
        outline=outline,
        width=width,
        tags=tags,
    )
def draw_axes(
    canvas,
    axes,
    window_size,
    *,
    length=15,
    line_color="black",
    line_width=1,
    font_size=14,
):
    """Draw the axes element: an X/Y/Z direction indicator near the bottom-left corner.

    Parameters
    ----------
    canvas : tkinter.Canvas
    axes : numpy.ndarray
        view orientation; row ``i`` supplies the x/y components used to draw
        axis ``i``, and the z components determine the drawing order
    window_size : sequence
        canvas (width, height); only the height is used, to locate the bottom edge
    length : float
        scale factor applied to each axis' x/y components
    line_color : str
        color of the axis lines
    line_width : int
        width of the axis lines
    font_size : int
        font size of the "X", "Y", "Z" labels

    """
    label_colors = ["red", "green", "blue"]
    # origin of the indicator, 20 units in from the bottom-left corner
    x0 = 20
    y0 = window_size[1] - 20
    # previously hard-coded to size 20, ignoring ``font_size``;
    # a single font is shared by the three labels
    font = Font(size=font_size)
    # draw in ascending order of each axis' z component (depth ordering)
    for i in axes[:, 2].argsort():
        x1 = int(axes[i][0] * length + x0)
        y1 = int(-axes[i][1] * length + y0)
        canvas.create_line(
            *tuple(int(v) for v in (x0, y0, x1, y1)),
            width=line_width,
            fill=line_color,
        )
        canvas.create_text(
            (x1, y1),
            text="XYZ"[i],
            fill=label_colors[i],
            font=font,
            anchor=tkinter.CENTER,
        )
def _offset_segment(position, celldisp):
    """Return ``(x0, y0, x1, y1)`` for a two-point line, shifted by ``celldisp``."""
    return (
        position[0, 0] + celldisp[0],
        position[0, 1] + celldisp[1],
        position[1, 0] + celldisp[0],
        position[1, 1] + celldisp[1],
    )


def _draw_cross(canvas, lbound, diameter):
    """Draw an 'X' over the circle with the given lower bound and diameter."""
    # (1 -/+ 1/sqrt(2)) / 2 of the diameter: the cross ends lie on the circle at 45 degrees
    rad1 = int(0.14644 * diameter)
    rad2 = int(0.85355 * diameter)
    canvas.create_line(
        lbound[0] + rad1,
        lbound[1] + rad1,
        lbound[0] + rad2,
        lbound[1] + rad2,
        width=1,
    )
    canvas.create_line(
        lbound[0] + rad2,
        lbound[1] + rad1,
        lbound[0] + rad1,
        lbound[1] + rad2,
        width=1,
    )


def _draw_occupancy_disc(element, diameter, canvas, element_colors, background_color):
    """Draw an atom site as arcs, one per species, with extents proportional to occupancy."""
    # first draw an empty circle if a site is not fully occupied
    if (np.sum([o for o in element.occupancy.values()])) < 1.0:
        draw_circle(
            element.lbound,
            diameter,
            canvas,
            background_color,
            element.selected,
            stroke_width=element.stroke_width,
        )
    start = 0
    # start with the dominant species
    for sym, occ in sorted(
        element.occupancy.items(), key=lambda x: x[1], reverse=True
    ):
        species_color = element_colors[atomic_numbers[sym]]
        if np.round(occ, decimals=4) == 1.0:
            draw_circle(
                element.lbound,
                diameter,
                canvas,
                species_color,
                element.selected,
                stroke_width=element.stroke_width,
                tags=("atom-circle",),
            )
        else:
            # TODO alter for ghost
            extent = 360.0 * occ
            draw_arc(
                element.lbound,
                diameter,
                canvas,
                species_color,
                element.selected,
                start,
                extent,
            )
            start += extent


def _draw_atom(
    element,
    idx,
    canvas,
    *,
    vector_arrays,
    scale,
    element_colors,
    ghost_cross_out,
    movecolor,
    background_color,
):
    """Draw one atom: its disc(s), label, constraint/ghost cross and vector arrows."""
    diameter = int(round(element.sradius * 2))
    if element.occupancy is not None:
        _draw_occupancy_disc(
            element, diameter, canvas, element_colors, background_color
        )
    else:
        if element.moving:
            # a slightly larger circle behind the atom, highlighting it as moving
            draw_circle(
                (element.lbound[0] - 4, element.lbound[1] - 4),
                diameter + 8,
                canvas,
                movecolor,
                False,
            )
        draw_circle(
            element.lbound,
            diameter,
            canvas,
            element.color,
            element.selected,
            stroke_width=element.stroke_width,
            tags=("atom-circle",),
        )
    if element.label is not None:
        canvas.create_text(
            (
                element.lbound[0] + diameter / 2,
                element.lbound[1] + diameter / 2,
            ),
            text=str(element.label),
            fill=element.font_color,
            font=Font(size=element.font_size),
            anchor=tkinter.CENTER,
        )
    # Draw cross on constrained or ghost atoms
    if element.constrained or (element.ghost and ghost_cross_out):
        _draw_cross(canvas, element.lbound, diameter)
    # Draw velocities and/or forces, from the atom to the vector's end point
    # TODO vector data should be added to element
    for vector in vector_arrays:
        draw_arrow(
            canvas,
            (
                element.position[0],
                element.position[1],
                vector[idx, 0],
                vector[idx, 1],
            ),
            width=2,
            scale=scale,
        )


def draw_elements(
    element_groups,
    canvas,
    celldisp,
    vector_arrays,
    scale,
    element_colors,
    ghost_cross_out=False,
    movecolor=None,
    background_color="#ffffff",
):
    """Draw elements on a ``tkinter.Canvas``.

    Elements are drawn in the order yielded by ``element_groups.yield_zorder()``,
    so later elements are painted over earlier ones.

    Parameters
    ----------
    element_groups : ase_notebook.draw_elements.DrawGroup
    canvas : tkinter.Canvas
    celldisp : numpy.array
        cell displacement (applied to unit cell and miller plane elements)
    vector_arrays : list
        arrays of vector end points (e.g. velocities and/or forces),
        indexed by the atom index yielded by ``yield_zorder``
    scale : float
        canvas scale (used for drawing vector arrows)
    element_colors : dict
        mapping of element colors (used for partial occupancies)
    ghost_cross_out : bool
        whether to cross out ghost atoms
    movecolor : str or None
        color for moving atom
    background_color : str or None
        color used for coloring partial atom occupancies

    """
    # checked once up-front, rather than re-checking every array for every atom
    for vector in vector_arrays:
        assert not np.isnan(vector).any()

    for idx, element in element_groups.yield_zorder():
        if element.name == "atoms":
            if element.visible:
                _draw_atom(
                    element,
                    idx,
                    canvas,
                    vector_arrays=vector_arrays,
                    scale=scale,
                    element_colors=element_colors,
                    ghost_cross_out=ghost_cross_out,
                    movecolor=movecolor,
                    background_color=background_color,
                )
        elif element.name == "cell_lines":
            canvas.create_line(
                _offset_segment(element.position, celldisp),
                fill=element.color,
                width=1,
                # dash=(6, 4), # dash pattern = (line length, gap length, ..)
                tags=("cell-line",),
            )
        elif element.name == "bond_lines":
            x0, y0 = element.position[0, 0], element.position[0, 1]
            x1, y1 = element.position[1, 0], element.position[1, 1]
            x_mid = x0 + 0.5 * (x1 - x0)
            y_mid = y0 + 0.5 * (y1 - y0)
            # the bond is split at its midpoint; each half gets its own color
            canvas.create_line(
                (x0, y0, x_mid, y_mid),
                width=element.stroke_width,
                fill=element.color[0],
                tags=("bond-line",),
            )
            canvas.create_line(
                (x_mid, y_mid, x1, y1),
                width=element.stroke_width,
                fill=element.color[1],
                tags=("bond-line",),
            )
        elif element.name == "miller_lines":
            canvas.create_line(
                _offset_segment(element.position, celldisp),
                width=element.stroke_width,
                fill=element.stroke_color,
                tags=("miller-line",),
            )
        elif element.name == "miller_planes":
            # flattened [x0, y0, x1, y1, ...] polygon vertices, shifted by celldisp
            plane_pts = [
                pt[i] + celldisp[i]
                for pt in element.position.round().astype(int)
                for i in [0, 1]
            ]
            canvas.create_polygon(
                plane_pts,
                width=element.stroke_width,
                outline=element.stroke_color,
                fill=element.fill_color,
                tags=("miller-plane",),
            )