# Source code for ase_notebook.backend.svg

"""A module for creating an SVG visualisation of a structure."""
from itertools import cycle
import os
import tempfile

import numpy as np
from svgwrite import Drawing, path, shapes, text
from svgwrite.container import Group
from svgwrite.filters import Filter


def generate_svg_elements(element_group, element_colors=None, background_color="white"):
    """Create the SVG elements, related to the 3D objects.

    Elements are emitted in the order yielded by
    ``element_group.yield_zorder()``; in SVG, later elements are painted on
    top of earlier ones.

    Parameters
    ----------
    element_group : ase_notebook.draw_elements.DrawGroup
        Container of all element groups to be created.
    element_colors : sequence or None
        Colors indexed by atomic number (looked up as
        ``element_colors[atomic_numbers[symbol]]``). Only used, and then
        required, for atoms that define an ``occupancy`` mapping.
    background_color : str
        Fill color of the empty circle drawn beneath sites whose summed
        occupancy is below 1.

    Returns
    -------
    list[svgwrite.base.BaseElement]

    Raises
    ------
    ValueError
        If an atom defines an ``occupancy`` but ``element_colors`` is None.
    """
    svg_elements = []
    for _, element in element_group.yield_zorder():
        if element.name == "atoms":
            svg_elements.extend(
                _atom_svg_elements(element, element_colors, background_color)
            )
            # TODO add force/velocity vectors
            # TODO add ghost crosses
        elif element.name == "cell_lines":
            # TODO support dashed cell lines (e.g. stroke_dasharray="6,4")
            svg_elements.append(
                shapes.Line(
                    element.position[0][:2],
                    element.position[1][:2],
                    stroke=element.get("color", "black"),
                )
            )
        elif element.name == "bond_lines":
            svg_elements.extend(_bond_svg_elements(element))
        elif element.name == "miller_lines":
            svg_elements.append(
                shapes.Line(
                    element.position[0][:2],
                    element.position[1][:2],
                    stroke=element.get("stroke_color", "blue"),
                    stroke_width=element.get("stroke_width", 1),
                    stroke_opacity=element.get("stroke_opacity", 0.8),
                )
            )
        elif element.name == "miller_planes":
            svg_elements.append(
                shapes.Polygon(
                    points=element.position[:, :2],
                    fill=element.get("fill_color", "blue"),
                    fill_opacity=element.get("fill_opacity", 0.5),
                    stroke=element.get("stroke_color", "blue"),
                    stroke_width=element.get("stroke_width", 0),
                    stroke_opacity=element.get("stroke_opacity", 0.5),
                )
            )
    return svg_elements


def _atom_svg_elements(element, element_colors, background_color):
    """Return the site shape(s) and optional text label for one atom element."""
    if not element.get("visible", True):
        return []
    center = element.position[:2]
    # Opacity/outline styling shared by every shape drawn for this site
    # (previously the background circle read "stroke" instead of "stroke_color").
    style = {
        "fill_opacity": element.get("fill_opacity", 0.95),
        "stroke": element.get("stroke_color", "black"),
        "stroke_width": element.get("stroke_width", 1),
    }
    if element.occupancy is None:
        items = [shapes.Circle(center, r=element.sradius, fill=element.color, **style)]
    else:
        items = _occupancy_svg_elements(
            center,
            element.sradius,
            element.occupancy,
            element_colors,
            background_color,
            style,
        )
    if "label" in element and element.label is not None:
        items.append(
            text.Text(
                element.label,
                x=(int(element.position[0]),),
                y=(int(element.position[1]),),
                text_anchor="middle",
                dominant_baseline="middle",
                font_size=element.get("font_size", 20),
                fill=element.get("font_color", "black"),
            )
        )
    return items


def _occupancy_svg_elements(
    center, radius, occupancy, element_colors, background_color, style
):
    """Return shapes for a site with fractional occupancy (symbol -> fraction)."""
    from ase.data import atomic_numbers

    if element_colors is None:
        raise ValueError(
            "element_colors is required to draw sites with an occupancy"
        )
    items = []
    if sum(occupancy.values()) < 1.0:
        # first draw an empty circle if a site is not fully occupied
        items.append(
            shapes.Circle(center, r=radius, fill=background_color, **style)
        )
    angle_start = 0
    # start with the dominant species, then fill the remaining sectors
    for symbol, occ in sorted(occupancy.items(), key=lambda kv: kv[1], reverse=True):
        fill = element_colors[atomic_numbers[symbol]]
        if np.round(occ, decimals=4) == 1.0:
            # a full (360 degree) sector is drawn as a plain circle
            items.append(shapes.Circle(center, r=radius, fill=fill, **style))
        else:
            angle_extent = 360.0 * occ
            items.append(
                create_arc_element(
                    center,
                    angle_start,
                    angle_start + angle_extent,
                    radius,
                    fill=fill,
                    **style,
                )
            )
            angle_start += angle_extent
    return items


def _bond_svg_elements(element):
    """Return a bond as two half-lines, colored element.color[0] and [1]."""
    start, end = element.position[0][:2], element.position[1][:2]
    middle = start + 0.5 * (end - start)
    line_style = {
        "stroke_width": element.get("stroke_width", 1),
        "stroke_linecap": "round",
        "stroke_opacity": element.get("stroke_opacity", 0.8),
    }
    return [
        shapes.Line(start, middle, stroke=element.color[0], **line_style),
        shapes.Line(middle, end, stroke=element.color[1], **line_style),
    ]
def cart2polar(x, y):
    """Convert cartesian to polar coordinates.

    Returns a ``(radius, angle)`` pair, where the angle is the
    ``np.arctan2(y, x)`` result converted to degrees. Works element-wise on
    numpy arrays as well as on scalars.
    """
    radius = np.sqrt(x ** 2 + y ** 2)
    angle_degrees = np.rad2deg(np.arctan2(y, x))
    return radius, angle_degrees
def polar2cart(radius, angle):
    """Convert polar to cartesian coordinates.

    ``angle`` is measured from the x axis in degrees; returns an ``(x, y)``
    pair (element-wise for numpy array inputs).
    """
    theta = np.radians(angle)
    return radius * np.cos(theta), radius * np.sin(theta)
def create_arc_element(center, start, end, radius, **kwargs):
    """Create an arc (circle section) path element.

    The path moves to ``center``, draws a straight edge to the point on the
    circle at angle ``start``, follows the circle to the point at angle
    ``end``, and is then closed (``z``) back to ``center``. Closing the path
    gives the outline a proper line join at the center instead of a line cap.

    Parameters
    ----------
    center: tuple
        (x, y)
    start: float
        starting angle from x axis (in degrees)
    end: float
        final angle from x axis (in degrees)
    radius: float
    **kwargs
        Passed through to ``svgwrite.path.Path`` (e.g. fill, stroke).

    Returns
    -------
    svgwrite.path.Path
    """
    c = np.array(center)
    p1 = np.array(polar2cart(radius, start)) + c
    p2 = np.array(polar2cart(radius, end)) + c
    radial = p1 - c  # relative move from the center out to the arc start
    chord = p2 - p1  # relative end point of the arc command
    sweep = end - start
    # SVG sweep-flag: 1 = positive-angle direction (start < end), else 0
    angle_dir = 1 if sweep > 0 else 0
    # SVG large-arc-flag: needed when the swept angle is at least 180 degrees
    large_arc = 1 if abs(sweep) >= 180 else 0
    return path.Path(
        [
            f"m{c[0]},{c[1]}",
            f"l{radial[0]},{radial[1]}",
            f"a{radius},{radius},0,{large_arc},{angle_dir},{chord[0]},{chord[1]}",
            "z",
        ],
        **kwargs,
    )
def create_axes_elements(
    axes,
    window_size,
    *,
    length=15,
    font_size=14,
    inset=(20, 20),
    font_offset=1.0,
    line_width=1,
    line_color="black",
    labels=("X", "Y", "Z"),
    colors=("red", "green", "blue"),
):
    """Create the SVG elements, related to the axes.

    The axes indicator is anchored ``inset`` pixels from the bottom-left of
    the window. For each axis (processed in ascending order of its third
    component, ``axes[:, 2]``) a line and a text label are emitted.
    """
    origin_x = inset[0]
    origin_y = window_size[1] - inset[1]
    elements = []
    for idx in axes[:, 2].argsort():
        dx, dy = axes[idx][0], axes[idx][1]
        tip = [int(dx * length + origin_x), int(dy * length + origin_y)]
        label_x = int(dx * length * font_offset + origin_x)
        label_y = int(dy * length * font_offset + origin_y)
        elements.extend(
            [
                shapes.Line(
                    [origin_x, origin_y],
                    tip,
                    stroke=line_color,
                    stroke_width=line_width,
                ),
                text.Text(
                    labels[idx],
                    x=(label_x,),
                    y=(label_y,),
                    fill=colors[idx],
                    text_anchor="middle",
                    dominant_baseline="middle",
                    font_size=font_size,
                ),
            ]
        )
    return elements
def create_svg_document(
    elements, size, viewbox=None, background_color="white", background_opacity=1.0
):
    """Assemble the full SVG document.

    All content lives in a single ``<g id="root">`` group: first a rectangle
    of ``size`` filled with the background color, then each of ``elements``
    in the order given.

    :param elements: iterable of svgwrite elements to add to the root group
    :param size: drawing size (also used for the background rectangle)
    :param viewbox: (minx, miny, width, height); not set when falsy
    :param background_color: fill color of the background rectangle
    :param background_opacity: fill opacity of the background rectangle
    :return: svgwrite.Drawing
    """
    drawing = Drawing("ase.svg", profile="tiny", size=size)
    root_group = Group(id="root")
    drawing.add(root_group)
    # A full-size rect is the most portable way to give an SVG a background,
    # see: https://stackoverflow.com/a/11293812/5033292
    root_group.add(
        shapes.Rect(size=size, fill=background_color, fill_opacity=background_opacity)
    )
    for child in elements:
        root_group.add(child)
    if viewbox:
        drawing.viewbox(*viewbox)
    return drawing
def create_svg_document_with_light(
    elements, size, viewbox=None, background_color="white", background_opacity=1.0
):
    """Assemble the full SVG document, with a lighting filter applied to it.

    A filter made of an ``feDiffuseLighting`` primitive (with an
    ``fePointLight`` source at ``(size[0], 0, 1000)``) followed by an
    arithmetic ``feComposite`` is defined in ``<defs>`` and referenced by the
    ``<g id="root">`` group that holds the background rectangle and
    ``elements``.

    Resources:

    - https://www.w3.org/TR/SVG11/filters.html#LightSourceDefinitions
    - https://svgwrite.readthedocs.io/en/master/classes/filters.html
    - http://www.svgbasics.com/filters2.html
    - https://css-tricks.com/look-svg-light-source-filters/

    :param viewbox: (minx, miny, width, height)
    """
    # TODO work in progress
    # TODO have a look at how threejs is converted to SVG:
    # https://github.com/mrdoob/three.js/blob/master/examples/jsm/renderers/SVGRenderer.js
    drawing = Drawing("ase.svg", profile="full", size=size)

    # Filter primitives are chained in creation order, so keep this sequence.
    lighting = drawing.defs.add(Filter(size=("100%", "100%")))
    diffuse = lighting.feDiffuseLighting(
        size=size,
        surfaceScale=10,
        diffuseConstant=1,
        kernelUnitLength=1,
        color="white",
    )
    diffuse.fePointLight(source=(size[0], 0, 1000))
    lighting.feComposite(operator="arithmetic", k1=1)

    root_group = Group(id="root", filter=lighting.get_funciri())
    drawing.add(root_group)
    # A full-size rect is the most portable way to give an SVG a background,
    # see: https://stackoverflow.com/a/11293812/5033292
    root_group.add(
        shapes.Rect(size=size, fill=background_color, fill_opacity=background_opacity)
    )
    for child in elements:
        root_group.add(child)
    if viewbox:
        drawing.viewbox(*viewbox)
    return drawing
def string_to_compose(string):
    """Convert an SVG string to a ``svgutils.compose.SVG``.

    Returns a ``(compose_svg, [width, height])`` pair, where the size is the
    figure's reported size converted with ``float``.
    """
    from svgutils.compose import SVG
    from svgutils.transform import fromstring

    figure = fromstring(string)
    compose_svg = SVG()
    # Re-use the parsed tree directly rather than loading from a file.
    compose_svg.root = figure.getroot().root
    dimensions = [float(dim) for dim in figure.get_size()]
    return compose_svg, dimensions
def tessellate_rectangles(sizes, max_columns=None):
    """Compute the minimum size grid, required to fit a list of rectangles.

    Rectangles are placed row by row (left to right, then top to bottom) in
    a grid with at most ``max_columns`` columns. Each column is as wide as
    its widest rectangle and each row as tall as its tallest rectangle.

    Parameters
    ----------
    sizes : sequence of (width, height)
        Size of each rectangle.
    max_columns : int or None
        Maximum number of columns; if None (or 0) all rectangles go in a
        single row.

    Returns
    -------
    tuple
        ``([total_width, total_height], positions)`` where ``positions`` holds
        the (x, y) offset of each rectangle's grid cell, in input order.
        Empty ``sizes`` gives ``([0.0, 0.0], [])``.
    """
    num_rects = len(sizes)
    if num_rects == 0:
        # previously raised ZeroDivisionError (modulo by zero columns)
        return [0.0, 0.0], []
    sizes = np.asarray(sizes, dtype=float).reshape(num_rects, 2)
    num_cols = min(max_columns, num_rects) if max_columns else num_rects
    num_rows = -(-num_rects // num_cols)  # ceiling division

    # Pad with NaN so the sizes fill a complete (rows, cols) grid; the padding
    # is ignored by the nan-aware reductions below.
    padded = np.full((num_rows * num_cols, 2), np.nan)
    padded[:num_rects] = sizes
    grid = padded.reshape(num_rows, num_cols, 2)

    col_widths = np.nanmax(grid[:, :, 0], axis=0)
    row_heights = np.nanmax(grid[:, :, 1], axis=1)
    x_ends = np.cumsum(col_widths)
    y_ends = np.cumsum(row_heights)
    x_offsets = x_ends - col_widths
    y_offsets = y_ends - row_heights

    positions = [
        (x_offsets[i % num_cols], y_offsets[i // num_cols]) for i in range(num_rects)
    ]
    return [np.nanmax(x_ends), np.nanmax(y_ends)], positions
def get_svg_string(svg):
    """Return the raw SVG markup held by ``svg``.

    ``svg`` may already be a ``str`` (returned unchanged), or any object with
    a ``tostring`` method (e.g. ``svgwrite.drawing.Drawing``) or a ``to_str``
    method (e.g. ``svgutils.transform.SVGFigure``). ``tostring`` takes
    precedence when both exist, and its result is returned as-is.

    :raises TypeError: if ``svg`` is none of the above
    """
    if isinstance(svg, str):
        return svg
    for method_name in ("tostring", "to_str"):
        if hasattr(svg, method_name):
            return getattr(svg, method_name)()
    raise TypeError(f"SVG cannot be converted to a raw string: {svg}")
def concatenate_svgs(
    svgs,
    max_columns=None,
    scale=None,
    label=False,
    size=12,
    weight="bold",
    inset=(0.1, 0.1),
):
    """Create a grid of SVGs, with a maximum number of columns.

    Parameters
    ----------
    svgs : iterable
        Items may be raw SVG strings, or any objects with a ``tostring``
        or ``to_str`` method.
    max_columns : int or None
        max number of columns, or if None, only use one row
    scale : float or None
        scale the entire composition (ignored when falsy)
    label : bool
        whether to add a label for each SVG (cycle through upper case letters)
    size : int
        label font size
    weight : str
        label font weight
    inset : tuple
        inset the label by x times the SVG width and y times the SVG height

    Returns
    -------
    svgutils.compose.Figure

    Raises
    ------
    ValueError
        If ``svgs`` is empty.
    """
    svgs = list(svgs)
    if not svgs:
        raise ValueError("At least one SVG is required to concatenate")

    # TODO could replace svgutils with use of lxml primitives
    from svgutils.compose import Figure, Text

    converted = [string_to_compose(get_svg_string(svg)) for svg in svgs]
    svg_composes = [compose for compose, _ in converted]
    dimensions = [dims for _, dims in converted]

    if scale:
        for svg in svg_composes:
            svg.scale(scale)
        dimensions = [(w * scale, h * scale) for (w, h) in dimensions]

    (width, height), positions = tessellate_rectangles(dimensions, max_columns)

    label_iter = cycle("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
    elements = []
    for svg, (x, y), (w, h) in zip(svg_composes, positions, dimensions):
        elements.append(svg.move(x, y))
        if label:
            elements.append(
                Text(
                    next(label_iter),
                    x=x + inset[0] * w,
                    y=y + inset[1] * h,
                    size=size,
                    weight=weight,
                )
            )
    return Figure(width, height, *elements)
def svg_to_pdf(svg, file_name=None):
    """Convert SVG to PDF.

    The SVG markup is written to a temporary ``.svg`` file (removed
    afterwards) and parsed with ``svglib.svglib.svg2rlg``. If ``file_name``
    is given, the resulting drawing is also rendered to that PDF file.

    To view in notebook::

        from IPython.display import display_pdf
        rlg_drawing = svg_to_pdf(svg)
        display_pdf(rlg_drawing.asString("pdf"), raw=True)

    Parameters
    ----------
    svg : str or object
        Raw SVG string, or any object with a ``tostring`` or ``to_str``
        method (whose result may be ``str`` or ``bytes``).
    file_name : str or None
        Optional output path for the PDF.

    Returns
    -------
    The object returned by ``svg2rlg`` (a reportlab drawing).
    """
    from svglib.svglib import svg2rlg
    from reportlab.graphics import renderPDF

    content = get_svg_string(svg)
    # ``to_str`` implementations may return bytes; always write bytes so both
    # kinds work and str content is not subject to the locale encoding.
    if isinstance(content, str):
        content = content.encode("utf-8")

    fd, fname = tempfile.mkstemp(suffix=".svg")
    try:
        # Write through the descriptor from mkstemp so it is closed (the
        # previous version leaked it, which also blocks removal on Windows).
        with os.fdopen(fd, "wb") as handle:
            handle.write(content)
        rlg_drawing = svg2rlg(fname)
    finally:
        if os.path.exists(fname):
            os.remove(fname)

    if file_name:
        renderPDF.drawToFile(rlg_drawing, file_name)
    return rlg_drawing