diff --git a/svgpathtools/document.py b/svgpathtools/document.py index ec04d8b..96406ca 100644 --- a/svgpathtools/document.py +++ b/svgpathtools/document.py @@ -15,7 +15,7 @@ Example: >> from svgpathtools import * >> doc = Document('my_file.html') - >> for path in doc.flattened_paths(): + >> for path in doc.paths(): >> # Do something with the transformed Path object. >> foo(path) >> # Inspect the raw SVG element, e.g. change its attributes @@ -39,6 +39,7 @@ import os import collections import xml.etree.ElementTree as etree from xml.etree.ElementTree import Element, SubElement, register_namespace +from xml.dom.minidom import parseString import warnings from tempfile import gettempdir from time import time @@ -75,16 +76,16 @@ CONVERT_ONLY_PATHS = {'path': path2pathd} SVG_GROUP_TAG = 'svg:g' -def flatten_all_paths(group, group_filter=lambda x: True, - path_filter=lambda x: True, path_conversions=CONVERSIONS, - group_search_xpath=SVG_GROUP_TAG): +def flattened_paths(group, group_filter=lambda x: True, + path_filter=lambda x: True, path_conversions=CONVERSIONS, + group_search_xpath=SVG_GROUP_TAG): """Returns the paths inside a group (recursively), expressing the paths in the base coordinates. Note that if the group being passed in is nested inside some parent group(s), we cannot take the parent group(s) into account, because xml.etree.Element has no pointer to its parent. You should use - Document.flatten_group(group) to flatten a specific nested group into + Document.flattened_paths_from_group(group) to flatten a specific nested group into the root coordinates. Args: @@ -149,10 +150,11 @@ def flatten_all_paths(group, group_filter=lambda x: True, return paths -def flatten_group(group_to_flatten, root, recursive=True, - group_filter=lambda x: True, path_filter=lambda x: True, - path_conversions=CONVERSIONS, - group_search_xpath=SVG_GROUP_TAG): +def flattened_paths_from_group(group_to_flatten, root, recursive=True, + group_filter=lambda x: True, + path_filter=lambda x: True, + path_conversions=CONVERSIONS, + group_search_xpath=SVG_GROUP_TAG): """Flatten all the paths in a specific group. The paths will be flattened into the 'root' frame. Note that root @@ -196,7 +198,7 @@ def flatten_group(group_to_flatten, root, recursive=True, for group in route: # Add each group from the root to the parent of the desired group # to the list of groups that we should traverse. This makes sure - # that flattened_paths will not stop before reaching the desired + # that paths will not stop before reaching the desired # group. desired_groups.add(id(group)) for key in path_conversions.keys(): @@ -217,12 +219,12 @@ def flatten_group(group_to_flatten, root, recursive=True, def desired_path_filter(x): return (id(x) not in ignore_paths) and path_filter(x) - return flatten_all_paths(root, desired_group_filter, desired_path_filter, - path_conversions, group_search_xpath) + return flattened_paths(root, desired_group_filter, desired_path_filter, + path_conversions, group_search_xpath) class Document: - def __init__(self, filepath): + def __init__(self, filepath=None): """A container for a DOM-style SVG document. The `Document` class provides a simple interface to modify and analyze @@ -241,23 +243,22 @@ class Document: if filepath is not None and os.path.dirname(filepath) == '': self.original_filepath = os.path.join(os.getcwd(), filepath) - if filepath is not None: + if filepath is None: + self.tree = etree.ElementTree(Element('svg')) + else: # parse svg to ElementTree object self.tree = etree.parse(filepath) - else: - self.tree = etree.ElementTree(Element('svg')) self.root = self.tree.getroot() - def flattened_paths(self, group_filter=lambda x: True, - path_filter=lambda x: True, path_conversions=CONVERSIONS): - """Forward the tree of this document into the more general - flattened_paths function and return the result.""" - return flatten_all_paths(self.tree.getroot(), group_filter, - path_filter, path_conversions) + def paths(self, group_filter=lambda x: True, + path_filter=lambda x: True, path_conversions=CONVERSIONS): + """Returns a list of all paths in the document.""" + return flattened_paths(self.tree.getroot(), group_filter, + path_filter, path_conversions) - def flatten_group(self, group, recursive=True, group_filter=lambda x: True, - path_filter=lambda x: True, path_conversions=CONVERSIONS): + def paths_from_group(self, group, recursive=True, group_filter=lambda x: True, + path_filter=lambda x: True, path_conversions=CONVERSIONS): if all(isinstance(s, str) for s in group): # If we're given a list of strings, assume it represents a # nested sequence @@ -272,8 +273,8 @@ class Document: warnings.warn("Could not find the requested group!") return [] - return flatten_group(group, self.tree.getroot(), recursive, - group_filter, path_filter, path_conversions) + return flattened_paths_from_group(group, self.tree.getroot(), recursive, + group_filter, path_filter, path_conversions) def add_path(self, path, attribs=None, group=None): """Add a new path to the SVG.""" @@ -410,21 +411,33 @@ class Document: return SubElement(parent, '{{{0}}}g'.format( SVG_NAMESPACE['svg']), group_attribs) - def save(self, filepath): + def __repr__(self): + return etree.tostring(self.tree.getroot()).decode() + + def pretty(self, **kwargs): + return parseString(repr(self)).toprettyxml(**kwargs) + + def save(self, filepath, prettify=False, **kwargs): with open(filepath, 'w') as output_svg: - output_svg.write(etree.tostring(self.tree.getroot())) + if prettify: + output_svg.write(self.pretty(**kwargs)) + else: + output_svg.write(repr(self)) def display(self, filepath=None): """Displays/opens the doc using the OS's default application.""" if filepath is None: - orig_name, ext = \ - os.path.splitext(os.path.basename(self.original_filepath)) - filepath = os.path.join(gettempdir(), - orig_name + '_' + str(time()).replace('.', '-') + ext) + if self.original_filepath is None: # created from empty Document + orig_name, ext = 'unnamed', '.svg' + else: + orig_name, ext = \ + os.path.splitext(os.path.basename(self.original_filepath)) + tmp_name = orig_name + '_' + str(time()).replace('.', '-') + ext + filepath = os.path.join(gettempdir(), tmp_name) # write to a (by default temporary) file with open(filepath, 'w') as output_svg: - output_svg.write(etree.tostring(self.tree.getroot()).decode()) + output_svg.write(self.as_string()) open_in_browser(filepath) diff --git a/test/test_groups.py b/test/test_groups.py index 3eec79c..8e39cee 100644 --- a/test/test_groups.py +++ b/test/test_groups.py @@ -26,7 +26,7 @@ class TestGroups(unittest.TestCase): # end point relative to the start point # * name is the path name (value of the test:name attribute in # the SVG document) - # * paths is the output of doc.flattened_paths() + # * paths is the output of doc.paths() v_s_vals.append(1.0) v_e_relative_vals.append(0.0) v_s = np.array(v_s_vals) @@ -38,7 +38,7 @@ class TestGroups(unittest.TestCase): self.check_values(tf.dot(v_e), actual.path.end) def test_group_flatten(self): - # Test the Document.flattened_paths() function against the + # Test the Document.paths() function against the # groups.svg test file. # There are 12 paths in that file, with various levels of being # nested inside of group transforms. @@ -48,7 +48,7 @@ class TestGroups(unittest.TestCase): # that are specified by the SVG standard. doc = Document(join(dirname(__file__), 'groups.svg')) - result = doc.flattened_paths() + result = doc.paths() self.assertEqual(12, len(result)) tf_matrix_group = np.array([[1.5, 0.0, -40.0], @@ -166,11 +166,11 @@ class TestGroups(unittest.TestCase): self.assertEqual(expected_count, count) def test_nested_group(self): - # A bug in the flatten_group() implementation made it so that only top-level + # A bug in the flattened_paths_from_group() implementation made it so that only top-level # groups could have their paths flattened. This is a regression test to make # sure that when a nested group is requested, its paths can also be flattened. doc = Document(join(dirname(__file__), 'groups.svg')) - result = doc.flatten_group(['matrix group', 'scale group']) + result = doc.paths_from_group(['matrix group', 'scale group']) self.assertEqual(len(result), 5) def test_add_group(self):