Source code for ase_notebook.backend.threejs

"""A module for creating a pythreejs visualisation of a structure."""
from math import radians, sqrt, tan

import numpy as np

from ase_notebook.color import Color
from ase_notebook.draw_utils import triangle_normal


[docs]class RenderContainer(object): """Container for the renderer, with attribute access for key elements.""" def __init__(self, top_level, **kwargs): """Initialise container.""" self._kwargs = kwargs self.top_level = top_level def __dir__(self): """Get the attributes.""" return list(self._kwargs.keys()) def __iter__(self): """Iterate keys.""" for key in self._kwargs: yield key def __len__(self): """Return number of keys.""" return len(self._kwargs) def __getitem__(self, key): """Return key.""" return self._kwargs[key] def __setitem__(self, key, value): """Set key.""" if key == "top_level": self.top_level = value else: self._kwargs[key] = value def __getattr__(self, key): """Return attribute.""" if key not in self._kwargs: raise AttributeError(key) return self._kwargs[key] def __setattr__(self, name, value): """Set attribute.""" if name == "top_level": if not hasattr(value, "_ipython_display_"): raise ValueError("top_level must have an `_ipython_display_` method") self._kwargs["top_level"] = value return if name != "_kwargs": raise AttributeError("Attributes are frozen") return super().__setattr__(name, value) def __contains__(self, key): """Test if key in container.""" return key in self._kwargs
[docs] def _ipython_display_(self): """Display the top level rendered in the notebook.""" return self.top_level._ipython_display_()
[docs]def generate_3js_render( element_groups, canvas_size, zoom, camera_fov=30, background_color="white", background_opacity=1.0, reuse_objects=False, use_atom_arrays=False, use_label_arrays=False, ): """Create a pythreejs scene of the elements. Regarding initialisation performance, see: https://github.com/jupyter-widgets/pythreejs/issues/154 """ import pythreejs as pjs key_elements = {} group_elements = pjs.Group() key_elements["group_elements"] = group_elements unique_atom_sets = {} for el in element_groups["atoms"]: element_hash = ( ("radius", el.sradius), ("color", el.color), ("fill_opacity", el.fill_opacity), ("stroke_color", el.get("stroke_color", "black")), ("ghost", el.ghost), ) unique_atom_sets.setdefault(element_hash, []).append(el) group_atoms = pjs.Group() group_ghosts = pjs.Group() atom_geometries = {} atom_materials = {} outline_materials = {} for el_hash, els in unique_atom_sets.items(): el = els[0] data = dict(el_hash) if reuse_objects: atom_geometry = atom_geometries.setdefault( el.sradius, pjs.SphereBufferGeometry( radius=el.sradius, widthSegments=30, heightSegments=30 ), ) else: atom_geometry = pjs.SphereBufferGeometry( radius=el.sradius, widthSegments=30, heightSegments=30 ) if reuse_objects: atom_material = atom_materials.setdefault( (el.color, el.fill_opacity), pjs.MeshLambertMaterial( color=el.color, transparent=True, opacity=el.fill_opacity ), ) else: atom_material = pjs.MeshLambertMaterial( color=el.color, transparent=True, opacity=el.fill_opacity ) if use_atom_arrays: atom_mesh = pjs.Mesh(geometry=atom_geometry, material=atom_material) atom_array = pjs.CloneArray( original=atom_mesh, positions=[e.position.tolist() for e in els], merge=False, ) else: atom_array = [ pjs.Mesh( geometry=atom_geometry, material=atom_material, position=e.position.tolist(), name=e.info_string, ) for e in els ] data["geometry"] = atom_geometry data["material_body"] = atom_material if el.ghost: key_elements["group_ghosts"] = group_ghosts group_ghosts.add(atom_array) else: key_elements["group_atoms"] = group_atoms group_atoms.add(atom_array) if el.get("stroke_width", 1) > 0: if reuse_objects: outline_material = outline_materials.setdefault( el.get("stroke_color", "black"), pjs.MeshBasicMaterial( color=el.get("stroke_color", "black"), side="BackSide", transparent=True, opacity=el.get("stroke_opacity", 1.0), ), ) else: outline_material = pjs.MeshBasicMaterial( color=el.get("stroke_color", "black"), side="BackSide", transparent=True, opacity=el.get("stroke_opacity", 1.0), ) # TODO use stroke width to dictate scale if use_atom_arrays: outline_mesh = pjs.Mesh( geometry=atom_geometry, material=outline_material, scale=(1.05, 1.05, 1.05), ) outline_array = pjs.CloneArray( original=outline_mesh, positions=[e.position.tolist() for e in els], merge=False, ) else: outline_array = [ pjs.Mesh( geometry=atom_geometry, material=outline_material, position=e.position.tolist(), scale=(1.05, 1.05, 1.05), ) for e in els ] data["material_outline"] = outline_material if el.ghost: group_ghosts.add(outline_array) else: group_atoms.add(outline_array) key_elements.setdefault("atom_arrays", []).append(data) group_elements.add(group_atoms) group_elements.add(group_ghosts) group_labels = add_labels(element_groups, key_elements, use_label_arrays) group_elements.add(group_labels) if len(element_groups["cell_lines"]) > 0: cell_line_mat = pjs.LineMaterial( linewidth=1, color=element_groups["cell_lines"].group_properties["color"] ) cell_line_geo = pjs.LineSegmentsGeometry( positions=[el.position.tolist() for el in element_groups["cell_lines"]] ) cell_lines = pjs.LineSegments2(geometry=cell_line_geo, material=cell_line_mat) key_elements["cell_lines"] = cell_lines group_elements.add(cell_lines) if len(element_groups["bond_lines"]) > 0: bond_line_mat = pjs.LineMaterial( linewidth=element_groups["bond_lines"].group_properties["stroke_width"], vertexColors="VertexColors", ) bond_line_geo = pjs.LineSegmentsGeometry( positions=[el.position.tolist() for el in element_groups["bond_lines"]], colors=[ [Color(c).rgb for c in el.color] for el in element_groups["bond_lines"] ], ) bond_lines = pjs.LineSegments2(geometry=bond_line_geo, material=bond_line_mat) key_elements["bond_lines"] = bond_lines group_elements.add(bond_lines) group_millers = pjs.Group() if len(element_groups["miller_lines"]) or len(element_groups["miller_planes"]): key_elements["group_millers"] = group_millers if len(element_groups["miller_lines"]) > 0: miller_line_mat = pjs.LineMaterial( linewidth=3, vertexColors="VertexColors" # TODO use stroke_width ) miller_line_geo = pjs.LineSegmentsGeometry( positions=[el.position.tolist() for el in element_groups["miller_lines"]], colors=[ [Color(el.stroke_color).rgb] * 2 for el in element_groups["miller_lines"] ], ) miller_lines = pjs.LineSegments2( geometry=miller_line_geo, material=miller_line_mat ) group_millers.add(miller_lines) for el in element_groups["miller_planes"]: vertices = el.position.tolist() faces = [ ( 0, 1, 2, triangle_normal(vertices[0], vertices[1], vertices[2]), "black", 0, ) ] if len(vertices) == 4: faces.append( ( 2, 3, 0, triangle_normal(vertices[2], vertices[3], vertices[0]), "black", 0, ) ) elif len(vertices) != 3: raise NotImplementedError("polygons with more than 4 points") plane_geom = pjs.Geometry(vertices=vertices, faces=faces) plane_mat = pjs.MeshBasicMaterial( color=el.fill_color, transparent=True, opacity=el.fill_opacity, side="DoubleSide", ) plane_mesh = pjs.Mesh(geometry=plane_geom, material=plane_mat) group_millers.add(plane_mesh) group_elements.add(group_millers) scene = pjs.Scene(background=None) scene.add([group_elements]) view_width, view_height = canvas_size minp, maxp = element_groups.get_position_range() # compute a minimum camera distance, that is guaranteed to encapsulate all elements camera_dist = maxp[2] + sqrt(maxp[0] ** 2 + maxp[1] ** 2) / tan( radians(camera_fov / 2) ) camera = pjs.PerspectiveCamera( fov=camera_fov, position=[0, 0, camera_dist], aspect=view_width / view_height, zoom=zoom, ) scene.add([camera]) ambient_light = pjs.AmbientLight(color="lightgray") key_elements["ambient_light"] = ambient_light direct_light = pjs.DirectionalLight(position=(maxp * 2).tolist()) key_elements["direct_light"] = direct_light scene.add([camera, ambient_light, direct_light]) camera_control = pjs.OrbitControls(controlling=camera, screenSpacePanning=True) atom_picker = pjs.Picker(controlling=group_atoms, event="dblclick") key_elements["atom_picker"] = atom_picker material = pjs.SpriteMaterial( map=create_arrow_texture(right=False), transparent=True, depthWrite=False, depthTest=False, ) atom_pointer = pjs.Sprite(material=material, scale=(4, 3, 1), visible=False) scene.add(atom_pointer) key_elements["atom_pointer"] = atom_pointer renderer = pjs.Renderer( camera=camera, scene=scene, controls=[camera_control, atom_picker], width=view_width, height=view_height, alpha=True, clearOpacity=background_opacity, clearColor=background_color, ) return renderer, key_elements
[docs]def add_labels(element_groups, key_elements, use_label_arrays): """Create label elements for the scene.""" import pythreejs as pjs group_labels = pjs.Group() unique_label_sets = {} for el in element_groups["atoms"]: if "label" in el and el.label is not None: unique_label_sets.setdefault( (("label", el.label), ("color", el.get("font_color", "black"))), [] ).append(el) if unique_label_sets: key_elements["group_labels"] = group_labels for el_hash, els in unique_label_sets.items(): el = els[0] data = dict(el_hash) # depthWrite=depthTest=False is required, for the sprite to remain on top, # and not have the whitespace obscure objects behind, see: # https://stackoverflow.com/questions/11165345/three-js-webgl-transparent-planes-hiding-other-planes-behind-them # TODO can this be improved? text_material = pjs.SpriteMaterial( map=pjs.TextTexture( string=el.label, color=el.get("font_color", "black"), size=2000, # this texttexture size seems to work, not sure why? ), opacity=1.0, transparent=True, depthWrite=False, depthTest=False, ) data["material"] = text_material key_elements.setdefault("label_arrays", []).append(data) if use_label_arrays: text_sprite = pjs.Sprite(material=text_material) label_array = pjs.CloneArray( original=text_sprite, positions=[e.position.tolist() for e in els], merge=False, ) else: label_array = [ pjs.Sprite(material=text_material, position=e.position.tolist()) for e in els ] group_labels.add(label_array) return group_labels
[docs]def create_world_axes( camera, controls, initial_rotation=np.eye(3), length=30, width=3, camera_fov=10 ): """Create a renderer, containing an axes and camera that is synced to another camera. adapted from http://jsfiddle.net/aqnL1mx9/ Parameters ---------- camera : pythreejs.PerspectiveCamera controls : pythreejs.OrbitControls initial_rotation : list or numpy.array initial rotation of the axes length : int length of axes lines width : int line width of axes Returns ------- pythreejs.Renderer """ import pythreejs as pjs canvas_width = length * 2 canvas_height = length * 2 ax_scene = pjs.Scene() group_ax = pjs.Group() # NOTE: could use AxesHelper, but this does not allow for linewidth seletion # TODO: add arrow heads (ArrowHelper doesn't seem to work) ax_line_mat = pjs.LineMaterial(linewidth=width, vertexColors="VertexColors") ax_line_geo = pjs.LineSegmentsGeometry( positions=[ [[0, 0, 0], length * r / np.linalg.norm(r)] for r in initial_rotation ], colors=[[Color(c).rgb] * 2 for c in ("red", "green", "blue")], ) ax_lines = pjs.LineSegments2(geometry=ax_line_geo, material=ax_line_mat) group_ax.add(ax_lines) ax_scene.add([group_ax]) camera_dist = length / tan(radians(camera_fov / 2)) ax_camera = pjs.PerspectiveCamera( fov=camera_fov, aspect=canvas_width / canvas_height, near=1, far=1000 ) ax_camera.up = camera.up ax_renderer = pjs.Renderer( scene=ax_scene, camera=ax_camera, width=canvas_width, height=canvas_height, alpha=True, clearOpacity=0.0, clearColor="white", ) def align_axes(change=None): """Align axes to world.""" # TODO: this is not working correctly for TrackballControls, when rotated upside-down # (OrbitControls enforces the camera up direction, # so does not allow the camera to rotate upside-down). # TODO how could this be implemented on the client (js) side? new_position = np.array(camera.position) - np.array(controls.target) new_position = camera_dist * new_position / np.linalg.norm(new_position) ax_camera.position = new_position.tolist() ax_camera.lookAt(ax_scene.position) align_axes() camera.observe(align_axes, names="position") controls.observe(align_axes, names="target") ax_scene.observe(align_axes, names="position") return ax_renderer
[docs]def make_basic_gui(container): """Create a basic GUI layout. Parameters ---------- container : RenderContainer Returns ------- ipywidgets.GridspecLayout """ import ipywidgets as ipyw element_controls = [ ipyw.HTML(value="<b>Elements</b>", layout=ipyw.Layout(align_self="center")) ] for key, descript in [ ("group_atoms", "Atoms"), ("cell_lines", "Unit Cell"), ("group_labels", "Labels"), ("bond_lines", "Bonds"), ("group_millers", "Planes"), ("group_ghosts", "Ghosts"), ]: if key not in container: continue toggle = ipyw.ToggleButton( description=descript, icon="eye", button_style="primary", value=False if key == "group_labels" else container[key].visible, layout=ipyw.Layout(width="auto"), ) ipyw.jslink((toggle, "value"), (container[key], "visible")) element_controls.append(toggle) control_box_elements = ipyw.Box( element_controls, layout=ipyw.Layout(flex_flow="column") ) container["control_box_elements"] = control_box_elements background_controls = [ ipyw.HTML(value="<b>Background</b>", layout=ipyw.Layout(align_self="center")) ] background_color = ipyw.ColorPicker( concise=True, description="Color", description_tooltip="Background Color", value=container.element_renderer.clearColor, layout=ipyw.Layout(align_items="center"), ) background_color.style.description_width = "40px" ipyw.jslink((background_color, "value"), (container.element_renderer, "clearColor")) background_controls.append(background_color) background_opacity = ipyw.FloatSlider( value=container.element_renderer.clearOpacity, min=0, max=1, step=0.1, orientation="horizontal", readout=False, description_tooltip="Background Opacity", ) background_opacity.layout.max_width = "100px" ipyw.jslink( (background_opacity, "value"), (container.element_renderer, "clearOpacity") ) background_controls.append(background_opacity) # other_controls.append(ipyw.Label(value="Opacity", layout=ipyw.Layout(align_self="center"))) control_box_background = ipyw.Box( background_controls, layout=ipyw.Layout(flex_flow="column") ) container["control_box_background"] = control_box_background axes = [container.axes_renderer] if "axes_renderer" in container else [] info_box = ipyw.HTML( value="", # "Double-click atom for info (requires active kernel).", color="grey", layout=ipyw.Layout( max_height="10px", margin="0px 0px 0px 0px", align_self="flex-start" ), ) def on_click(change): obj = change["new"] if obj is None: container.atom_pointer.visible = False info_box.value = "" else: info_box.value = obj.name # container.atom_pointer.position = container.atom_picker.point container.atom_pointer.position = obj.position container.atom_pointer.visible = True container.atom_picker.observe(on_click, names=["object"]) if axes and container.element_renderer.height > 200: grid = ipyw.GridspecLayout( 2, 2, width=f"{container.element_renderer.width + 100}px", height=f"{container.element_renderer.height + 35}px", ) grid[0, 0] = container.element_renderer grid[1, 0] = info_box grid[:, 1] = ipyw.Box( axes + [control_box_elements, control_box_background], layout=ipyw.Layout(align_self="flex-start", flex_flow="column"), ) else: grid = ipyw.GridspecLayout( 2, 3, width=f"{container.element_renderer.width + 200}px", height=f"{container.element_renderer.height + 35}px", ) grid[:, 0] = ipyw.Box( axes, layout=ipyw.Layout(align_self="flex-end", flex_flow="column") ) grid[0, 1] = container.element_renderer grid[1, 1] = info_box grid[:, 2] = ipyw.Box( [control_box_elements, control_box_background], layout=ipyw.Layout(align_self="flex-start", flex_flow="column"), ) return grid
[docs]def gather_3d_objects(obj, objects=None): """Recurse through objects children, to gather the set of 3D objects.""" # TODO create more complete method import pythreejs as pjs if objects is None: objects = set() if isinstance(obj, pjs.Renderer): gather_3d_objects(obj.scene, objects) elif isinstance(obj, pjs.Scene): for child in obj.children: gather_3d_objects(child, objects) elif isinstance(obj, pjs.Object3DBase): objects.add(obj) for child in obj.children: gather_3d_objects(child, objects) if "geometry" in obj.trait_names(): objects.add(obj.geometry) if "material" in obj.trait_names(): objects.add(obj.material) if isinstance(obj, pjs.CloneArray): gather_3d_objects(obj.original, objects) return objects
[docs]def create_arrow_texture(width=2 ** 9, height=2 ** 9, color="red", right=True): """Create an array map of an arrow.""" import pythreejs as pjs color_rgba = list(Color(color).rgb) + [1.0] array = np.zeros((width, height, 4), dtype="float32") if right: # facing right for y in range(0, int(width / 4)): for x in range(int(height * 9 / 24), int(height * 15 / 24)): array[x, y, :] = color_rgba for y in range(int(width / 4), int(width / 2)): for x in range( int(0 + height * y / (width)), int(height - height * y / (width)) ): array[x, y, :] = color_rgba else: # facing left for y in range(int(width * 3 / 4), int(width)): for x in range(int(height * 9 / 24), int(height * 15 / 24)): array[x, y, :] = color_rgba for y in range(int(width / 2), int(width * 3 / 4)): for x in range( int(height - ((height / 2) * y / (width / 2))), int(0 + ((height / 2) * y / (width / 2))), ): array[x, y, :] = color_rgba return pjs.DataTexture(data=array, format="RGBAFormat", type="FloatType")