Source code for ase_notebook.draw_elements

"""Module defining backend agnostic containers for visualisation elements."""
from collections import OrderedDict
from collections.abc import Mapping
from copy import deepcopy
from typing import List

import numpy as np


class Element(object):
    """Representation of a single element.

    Implemented as a frozen dictionary with attribute access: the keyword
    arguments given at construction can be read either as ``element[key]``
    or ``element.key``, and cannot be reassigned afterwards.
    """

    def __init__(self, **kwargs):
        """Initialise element, storing the keyword arguments as its properties."""
        self._kwargs = kwargs

    def __dir__(self):
        """Get the attributes (the property names, plus ``get``)."""
        return list(self._kwargs.keys()) + ["get"]

    def get(self, key, default=None):
        """Return the value stored under ``key``, or ``default`` if absent."""
        if key in self:
            return self[key]
        return default

    def __repr__(self):
        """Represent object, listing the properties sorted by name."""
        sig = ", ".join([f"{k}={self._kwargs[k]}" for k in sorted(self._kwargs)])
        return f"Element({sig})"

    def __getitem__(self, key):
        """Return the value stored under ``key`` (raises KeyError if absent)."""
        return self._kwargs[key]

    def __iter__(self):
        """Iterate property keys."""
        yield from self._kwargs

    def __getattr__(self, key):
        """Return the property ``key`` (raises AttributeError if absent)."""
        # ``__getattr__`` is only called when normal lookup fails, so reaching
        # here with "_kwargs" means the instance was created without running
        # ``__init__`` (e.g. by copy/pickle). Reading ``self._kwargs`` below
        # would then re-enter this method and recurse without end.
        if key == "_kwargs":
            raise AttributeError(key)
        if key not in self._kwargs:
            raise AttributeError(str(key))
        return self._kwargs[key]

    def __setattr__(self, name, key):
        """Reject attribute assignment, except for the internal ``_kwargs``."""
        if name != "_kwargs":
            raise AttributeError("Element attributes are frozen")
        return super().__setattr__(name, key)

    def __contains__(self, key):
        """Test if key in object."""
        return key in self._kwargs
class DrawElementsBase:
    """Abstract base class to store a set of 3D-visualisation elements.

    Subclasses set ``etype`` and implement the ``unstack_*``,
    ``update_positions`` and ``get_max_zposition`` methods.
    """

    etype = None
    # names that are always passed to ``Element`` by ``__getitem__`` (or used
    # by ``Element`` itself), so may not be used as property names
    _protected_keys = ("name", "type", "position", "get")

    def __init__(
        self, name, coordinates, element_properties=None, group_properties=None
    ):
        """Initialise the element group.

        :param name: the name of the group
        :param coordinates: the coordinates of the elements;
            its length defines the number of elements
        :param element_properties: mapping of property name to a sequence
            holding one value per element
        :param group_properties: mapping of property name to a single value,
            shared by all elements
        """
        self.name = name
        self._coordinates = coordinates
        # positions start out equal to the coordinates
        self._positions = coordinates
        self._axes = np.identity(3)
        self._offset = np.zeros(3)
        self._el_props = {}
        self._grp_props = {}
        for key, val in (element_properties or {}).items():
            self.set_property(key, val, element=True)
        for key, val in (group_properties or {}).items():
            self.set_property(key, val, element=False)

    @property
    def element_properties(self):
        """Return a copy of the per element properties, plus the ``positions``."""
        output = deepcopy(self._el_props)
        output["positions"] = np.array(self._positions)
        return output

    @property
    def group_properties(self):
        """Return a copy of the element group properties."""
        return deepcopy(self._grp_props)

    def set_property(self, name, value, element=False):
        """Set a group or per element property.

        :param name: the property name
        :param value: the property value; if ``element`` is True,
            a sequence with one value per element
        :param element: whether this is a per element property
        :raises KeyError: if ``name`` is a protected key
        :raises AssertionError: if a per element value has the wrong length,
            or ``name`` is already set as the other kind of property
        """
        if name in self._protected_keys:
            raise KeyError(f"{name} is a protected key name")
        # The checks below raise explicitly rather than using ``assert``,
        # so that they are still enforced when running under ``python -O``.
        if element:
            if len(value) != len(self._coordinates):
                raise AssertionError(
                    f"property '{name}' does not have the same length "
                    "as the number of elements"
                )
            if name in self._grp_props:
                raise AssertionError(f"{name} is already set as a group property")
            self._el_props[name] = value
        else:
            if name in self._el_props:
                raise AssertionError(f"{name} is already set as an element property")
            self._grp_props[name] = value

    def set_property_many(self, properties, element=False):
        """Set multiple group or per element properties, from a mapping."""
        for key, val in properties.items():
            self.set_property(key, val, element=element)

    def get_elements_property(self, name):
        """Return a single property, with one value per element.

        Group properties are repeated for every element.

        :raises KeyError: if ``name`` is not a known property
        """
        if name == "position":
            return np.array(self._positions)
        if name in self._el_props:
            return list(self._el_props[name])
        if name in self._grp_props:
            return [self._grp_props[name] for _ in range(len(self))]
        raise KeyError(f"{name} not in properties")

    def __len__(self):
        """Return the number of elements."""
        return len(self._coordinates)

    def __getitem__(self, index):
        """Return a single element.

        :raises TypeError: if ``index`` cannot be converted to an integer
        """
        try:
            index = int(index)
        except ValueError as err:
            raise TypeError(f"index must be an integer: {index}") from err
        properties = {str(k): v[index] for k, v in self._el_props.items()}
        properties.update({str(k): v for k, v in self._grp_props.items()})
        return Element(
            name=self.name,
            type=self.etype,
            position=self._positions[index],
            **properties,
        )

    def __iter__(self):
        """Iterate over elements."""
        for i in range(len(self)):
            yield self[i]

    def __repr__(self) -> str:
        """Return representation string."""
        return (
            f"{self.__class__.__name__}(name={self.name}, elements={len(self)}, "
            f"el_properties=({', '.join(self._el_props.keys())}), "
            f"grp_properties=({', '.join(self._grp_props.keys())}))"
        )

    def unstack_coordinates(self):
        """Return a list of all coordinates in the group."""
        raise NotImplementedError

    def unstack_positions(self):
        """Return a list of all positions in the group."""
        raise NotImplementedError

    def update_positions(self, axes, offset, **kwargs):
        """Update element positions, given an axes basis and centre offset."""
        raise NotImplementedError

    def get_max_zposition(self):
        """Return the maximum z-coordinate."""
        raise NotImplementedError
class DrawElementsSphere(DrawElementsBase):
    """Store a set of 3D-visualisation sphere elements."""

    etype = "sphere"
    _protected_keys = ("name", "type", "position", "sradius", "get")

    def __init__(
        self,
        name,
        coordinates,
        radii,
        element_properties=None,
        group_properties=None,
        radii_scale=1.0,
    ):
        """Initialise the sphere group.

        The coordinates are converted to an array, which must have the
        shape (N, 3); an empty sequence is treated as zero spheres.
        """
        coords = np.array(coordinates)
        if coords.shape == (0,):
            # an empty input has no second axis, so give it one explicitly
            coords = np.empty((0, 3))
        dims = coords.shape
        if not (len(dims) == 2 and dims[1] == 3):
            raise ValueError(f"coordinates must be of the shape (N, 3) not {dims}")
        super().__init__(name, coords, element_properties, group_properties)
        self._radii = np.array(radii)
        self._radii_scale = radii_scale

    def __getitem__(self, index):
        """Return a single sphere element, including its scaled radius."""
        try:
            idx = int(index)
        except ValueError:
            raise TypeError(f"index must be an integer: {index}")
        position = self._positions[idx]
        radius = self.scaled_radii[idx]
        # per element values are looked up by index, group values are shared
        props = {}
        for key, values in self._el_props.items():
            props[str(key)] = values[idx]
        for key, value in self._grp_props.items():
            props[str(key)] = value
        return Element(
            name=self.name,
            type=self.etype,
            position=position,
            sradius=radius,
            **props,
        )

    @property
    def scaled_radii(self):
        """Return the radius of each sphere, multiplied by the radii scale."""
        return self._radii * self._radii_scale

    def unstack_coordinates(self):
        """Return the stored coordinates of all spheres."""
        return self._coordinates

    def unstack_positions(self):
        """Return the current positions of all spheres."""
        return self._positions

    def update_positions(self, axes, offset, radii_scale, **kwargs):
        """Recompute the positions from an axes basis and centre offset.

        The new radii scale is stored along with the axes and offset.
        """
        self._axes = axes
        self._offset = offset
        self._radii_scale = radii_scale
        transformed = np.dot(self._coordinates, axes)
        self._positions = transformed - offset

    def get_max_zposition(self):
        """Return, for each sphere, its z-position plus its scaled radius."""
        z_centres = self._positions[:, 2]
        return z_centres + self.scaled_radii
class DrawElementsLine(DrawElementsBase):
    """Store a set of 3D-visualisation line elements."""

    etype = "line"

    def __init__(
        self, name, coordinates, element_properties=None, group_properties=None
    ):
        """Initialise the line group.

        The coordinates are converted to an array, which must have the
        shape (N, 2, 3); an empty sequence is treated as zero lines.
        """
        coords = np.array(coordinates)
        if coords.shape == (0,):
            # an empty input has no trailing axes, so give them explicitly
            coords = np.empty((0, 2, 3))
        dims = coords.shape
        if len(dims) != 3 or dims[1:] != (2, 3):
            raise ValueError(f"coordinates must be of the shape (N, 2, 3) not {dims}")
        super().__init__(name, coords, element_properties, group_properties)

    def unstack_coordinates(self):
        """Return the coordinates of all first points, then all second points."""
        first_points = self._coordinates[:, 0, :]
        second_points = self._coordinates[:, 1, :]
        return np.concatenate((first_points, second_points))

    def unstack_positions(self):
        """Return the positions of all first points, then all second points."""
        first_points = self._positions[:, 0, :]
        second_points = self._positions[:, 1, :]
        return np.concatenate((first_points, second_points))

    def update_positions(self, axes, offset, **kwargs):
        """Recompute the positions from an axes basis and centre offset."""
        self._axes = axes
        self._offset = offset
        # contract the last axis of the coordinates with the axes basis
        transformed = np.einsum("ijk, km -> ijm", self._coordinates, axes)
        self._positions = transformed - offset

    def get_max_zposition(self):
        """Return, for each line, the larger z-position of its two points."""
        z_values = self._positions[:, :, 2]
        return z_values.max(axis=1)
class DrawElementsPoly(DrawElementsBase):
    """Store a set of 3D-visualisation polygon elements.

    Each element is a plane, given as a sequence of points.
    """

    etype = "poly"

    # TODO validate init coordinate shapes

    def unstack_coordinates(self):
        """Return a list of all coordinates in the group.

        The points of all planes are concatenated;
        an empty group gives an array of shape (0, 3).
        """
        planes = [np.asarray(plane) for plane in self._coordinates]
        if not planes:
            return np.empty((0, 3))
        return np.concatenate(planes)

    def unstack_positions(self):
        """Return a list of all positions in the group.

        The points of all planes are concatenated;
        an empty group gives an array of shape (0, 3).
        """
        planes = [np.asarray(plane) for plane in self._positions]
        if not planes:
            return np.empty((0, 3))
        return np.concatenate(planes)

    def update_positions(self, axes, offset, **kwargs):
        """Update element positions, given an axes basis and centre offset."""
        # TODO ideally would apply transform to all planes at once
        self._positions = [np.dot(plane, axes) - offset for plane in self._coordinates]
        self._axes = axes
        self._offset = offset

    def get_max_zposition(self):
        """Return the maximum z-coordinate of each plane."""
        # Until ``update_positions`` is called, the positions are the
        # coordinates exactly as given at initialisation, which may be nested
        # sequences rather than arrays, so coerce before slicing.
        return np.array([np.asarray(plane)[:, 2].max() for plane in self._positions])
class DrawGroup(Mapping):
    """Store and manipulate 3-D visualisation element groups.

    This is a read-only mapping of group name to element group,
    which preserves the order the groups were supplied in.
    """

    def __init__(self, elements: List["DrawElementsBase"]):
        """Store the element groups, keyed by their ``name``."""
        self._elements = OrderedDict([(el.name, el) for el in elements])

    def __getitem__(self, key):
        """Return an element group by name."""
        return self._elements[key]

    def __iter__(self):
        """Iterate over the element group names."""
        yield from self._elements

    def __len__(self):
        """Return the number of element groups."""
        return len(self._elements)

    def __repr__(self) -> str:
        """Return representation string."""
        return (
            f"{self.__class__.__name__}(groups=({', '.join(str(n) for n in self)}), "
            f"elements=({', '.join(str(len(self[n])) for n in self)}))"
        )

    def get_all_coordinates(self):
        """Return a list of all coordinates.

        If there are no element groups, an array of shape (0, 3) is returned.
        """
        coordinates = [el.unstack_coordinates() for el in self._elements.values()]
        if not coordinates:
            # np.concatenate raises on an empty sequence
            return np.empty((0, 3))
        return np.concatenate(coordinates)

    def get_all_positions(self):
        """Return a list of all positions.

        If there are no element groups, an array of shape (0, 3) is returned.
        """
        positions = [el.unstack_positions() for el in self._elements.values()]
        if not positions:
            # np.concatenate raises on an empty sequence
            return np.empty((0, 3))
        return np.concatenate(positions)

    def update_positions(self, axes=None, offset=None, radii_scale=1):
        """Update element positions, given an axes basis and centre offset.

        :param axes: the axes basis (defaults to the 3x3 identity)
        :param offset: the centre offset (defaults to zeros)
        :param radii_scale: passed to every element group as a keyword
        """
        if axes is None:
            axes = np.identity(3)
        if offset is None:
            offset = np.zeros(3)
        for element in self._elements.values():
            element.update_positions(axes, offset, radii_scale=radii_scale)

    def get_position_range(self):
        """Return the (minimum, maximum) positions, along each axis.

        For sphere groups, the range is extended by the scaled radii.
        """
        min_positions = []
        max_positions = []
        for element in self._elements.values():
            positions = element.unstack_positions()
            if isinstance(element, DrawElementsSphere):
                # TODO make more general
                radii = element.scaled_radii[:, None]
                min_positions.append(positions - radii)
                max_positions.append(positions + radii)
            else:
                min_positions.append(positions)
                max_positions.append(positions)
        return (
            np.concatenate(min_positions).min(0),
            np.concatenate(max_positions).max(0),
        )

    def yield_zorder(self):
        """Yield elements, in order of the z-coordinate.

        Each item is a tuple of the element's index (counting across all
        groups, in group order) and the element itself.
        Nothing is yielded if the group contains no elements.
        """
        keys = [(el.name, i) for el in self.values() for i in range(len(el))]
        if not keys:
            # np.concatenate raises on an empty sequence
            return
        z_positions = np.concatenate([el.get_max_zposition() for el in self.values()])
        for i in z_positions.argsort():
            name, index = keys[i]
            yield i, self[name][index]