some renames and add __repr__ and pretty()

pull/114/head
Andy Port 2020-06-23 21:54:58 -07:00
parent 445899b2eb
commit 1f7503aabd
2 changed files with 51 additions and 38 deletions

View File

@ -15,7 +15,7 @@ Example:
>> from svgpathtools import * >> from svgpathtools import *
>> doc = Document('my_file.html') >> doc = Document('my_file.html')
>> for path in doc.flattened_paths(): >> for path in doc.paths():
>> # Do something with the transformed Path object. >> # Do something with the transformed Path object.
>> foo(path) >> foo(path)
>> # Inspect the raw SVG element, e.g. change its attributes >> # Inspect the raw SVG element, e.g. change its attributes
@ -39,6 +39,7 @@ import os
import collections import collections
import xml.etree.ElementTree as etree import xml.etree.ElementTree as etree
from xml.etree.ElementTree import Element, SubElement, register_namespace from xml.etree.ElementTree import Element, SubElement, register_namespace
from xml.dom.minidom import parseString
import warnings import warnings
from tempfile import gettempdir from tempfile import gettempdir
from time import time from time import time
@ -75,7 +76,7 @@ CONVERT_ONLY_PATHS = {'path': path2pathd}
SVG_GROUP_TAG = 'svg:g' SVG_GROUP_TAG = 'svg:g'
def flatten_all_paths(group, group_filter=lambda x: True, def flattened_paths(group, group_filter=lambda x: True,
path_filter=lambda x: True, path_conversions=CONVERSIONS, path_filter=lambda x: True, path_conversions=CONVERSIONS,
group_search_xpath=SVG_GROUP_TAG): group_search_xpath=SVG_GROUP_TAG):
"""Returns the paths inside a group (recursively), expressing the """Returns the paths inside a group (recursively), expressing the
@ -84,7 +85,7 @@ def flatten_all_paths(group, group_filter=lambda x: True,
Note that if the group being passed in is nested inside some parent Note that if the group being passed in is nested inside some parent
group(s), we cannot take the parent group(s) into account, because group(s), we cannot take the parent group(s) into account, because
xml.etree.Element has no pointer to its parent. You should use xml.etree.Element has no pointer to its parent. You should use
Document.flatten_group(group) to flatten a specific nested group into Document.flattened_paths_from_group(group) to flatten a specific nested group into
the root coordinates. the root coordinates.
Args: Args:
@ -149,8 +150,9 @@ def flatten_all_paths(group, group_filter=lambda x: True,
return paths return paths
def flatten_group(group_to_flatten, root, recursive=True, def flattened_paths_from_group(group_to_flatten, root, recursive=True,
group_filter=lambda x: True, path_filter=lambda x: True, group_filter=lambda x: True,
path_filter=lambda x: True,
path_conversions=CONVERSIONS, path_conversions=CONVERSIONS,
group_search_xpath=SVG_GROUP_TAG): group_search_xpath=SVG_GROUP_TAG):
"""Flatten all the paths in a specific group. """Flatten all the paths in a specific group.
@ -196,7 +198,7 @@ def flatten_group(group_to_flatten, root, recursive=True,
for group in route: for group in route:
# Add each group from the root to the parent of the desired group # Add each group from the root to the parent of the desired group
# to the list of groups that we should traverse. This makes sure # to the list of groups that we should traverse. This makes sure
# that flattened_paths will not stop before reaching the desired # that paths will not stop before reaching the desired
# group. # group.
desired_groups.add(id(group)) desired_groups.add(id(group))
for key in path_conversions.keys(): for key in path_conversions.keys():
@ -217,12 +219,12 @@ def flatten_group(group_to_flatten, root, recursive=True,
def desired_path_filter(x): def desired_path_filter(x):
return (id(x) not in ignore_paths) and path_filter(x) return (id(x) not in ignore_paths) and path_filter(x)
return flatten_all_paths(root, desired_group_filter, desired_path_filter, return flattened_paths(root, desired_group_filter, desired_path_filter,
path_conversions, group_search_xpath) path_conversions, group_search_xpath)
class Document: class Document:
def __init__(self, filepath): def __init__(self, filepath=None):
"""A container for a DOM-style SVG document. """A container for a DOM-style SVG document.
The `Document` class provides a simple interface to modify and analyze The `Document` class provides a simple interface to modify and analyze
@ -241,22 +243,21 @@ class Document:
if filepath is not None and os.path.dirname(filepath) == '': if filepath is not None and os.path.dirname(filepath) == '':
self.original_filepath = os.path.join(os.getcwd(), filepath) self.original_filepath = os.path.join(os.getcwd(), filepath)
if filepath is not None: if filepath is None:
self.tree = etree.ElementTree(Element('svg'))
else:
# parse svg to ElementTree object # parse svg to ElementTree object
self.tree = etree.parse(filepath) self.tree = etree.parse(filepath)
else:
self.tree = etree.ElementTree(Element('svg'))
self.root = self.tree.getroot() self.root = self.tree.getroot()
def flattened_paths(self, group_filter=lambda x: True, def paths(self, group_filter=lambda x: True,
path_filter=lambda x: True, path_conversions=CONVERSIONS): path_filter=lambda x: True, path_conversions=CONVERSIONS):
"""Forward the tree of this document into the more general """Returns a list of all paths in the document."""
flattened_paths function and return the result.""" return flattened_paths(self.tree.getroot(), group_filter,
return flatten_all_paths(self.tree.getroot(), group_filter,
path_filter, path_conversions) path_filter, path_conversions)
def flatten_group(self, group, recursive=True, group_filter=lambda x: True, def paths_from_group(self, group, recursive=True, group_filter=lambda x: True,
path_filter=lambda x: True, path_conversions=CONVERSIONS): path_filter=lambda x: True, path_conversions=CONVERSIONS):
if all(isinstance(s, str) for s in group): if all(isinstance(s, str) for s in group):
# If we're given a list of strings, assume it represents a # If we're given a list of strings, assume it represents a
@ -272,7 +273,7 @@ class Document:
warnings.warn("Could not find the requested group!") warnings.warn("Could not find the requested group!")
return [] return []
return flatten_group(group, self.tree.getroot(), recursive, return flattened_paths_from_group(group, self.tree.getroot(), recursive,
group_filter, path_filter, path_conversions) group_filter, path_filter, path_conversions)
def add_path(self, path, attribs=None, group=None): def add_path(self, path, attribs=None, group=None):
@ -410,21 +411,33 @@ class Document:
return SubElement(parent, '{{{0}}}g'.format( return SubElement(parent, '{{{0}}}g'.format(
SVG_NAMESPACE['svg']), group_attribs) SVG_NAMESPACE['svg']), group_attribs)
def save(self, filepath): def __repr__(self):
return etree.tostring(self.tree.getroot()).decode()
def pretty(self, **kwargs):
return parseString(repr(self)).toprettyxml(**kwargs)
def save(self, filepath, prettify=False, **kwargs):
with open(filepath, 'w') as output_svg: with open(filepath, 'w') as output_svg:
output_svg.write(etree.tostring(self.tree.getroot())) if prettify:
output_svg.write(self.pretty(**kwargs))
else:
output_svg.write(repr(self))
def display(self, filepath=None): def display(self, filepath=None):
"""Displays/opens the doc using the OS's default application.""" """Displays/opens the doc using the OS's default application."""
if filepath is None: if filepath is None:
if self.original_filepath is None: # created from empty Document
orig_name, ext = 'unnamed', '.svg'
else:
orig_name, ext = \ orig_name, ext = \
os.path.splitext(os.path.basename(self.original_filepath)) os.path.splitext(os.path.basename(self.original_filepath))
filepath = os.path.join(gettempdir(), tmp_name = orig_name + '_' + str(time()).replace('.', '-') + ext
orig_name + '_' + str(time()).replace('.', '-') + ext) filepath = os.path.join(gettempdir(), tmp_name)
# write to a (by default temporary) file # write to a (by default temporary) file
with open(filepath, 'w') as output_svg: with open(filepath, 'w') as output_svg:
output_svg.write(etree.tostring(self.tree.getroot()).decode()) output_svg.write(self.as_string())
open_in_browser(filepath) open_in_browser(filepath)

View File

@ -26,7 +26,7 @@ class TestGroups(unittest.TestCase):
# end point relative to the start point # end point relative to the start point
# * name is the path name (value of the test:name attribute in # * name is the path name (value of the test:name attribute in
# the SVG document) # the SVG document)
# * paths is the output of doc.flattened_paths() # * paths is the output of doc.paths()
v_s_vals.append(1.0) v_s_vals.append(1.0)
v_e_relative_vals.append(0.0) v_e_relative_vals.append(0.0)
v_s = np.array(v_s_vals) v_s = np.array(v_s_vals)
@ -38,7 +38,7 @@ class TestGroups(unittest.TestCase):
self.check_values(tf.dot(v_e), actual.path.end) self.check_values(tf.dot(v_e), actual.path.end)
def test_group_flatten(self): def test_group_flatten(self):
# Test the Document.flattened_paths() function against the # Test the Document.paths() function against the
# groups.svg test file. # groups.svg test file.
# There are 12 paths in that file, with various levels of being # There are 12 paths in that file, with various levels of being
# nested inside of group transforms. # nested inside of group transforms.
@ -48,7 +48,7 @@ class TestGroups(unittest.TestCase):
# that are specified by the SVG standard. # that are specified by the SVG standard.
doc = Document(join(dirname(__file__), 'groups.svg')) doc = Document(join(dirname(__file__), 'groups.svg'))
result = doc.flattened_paths() result = doc.paths()
self.assertEqual(12, len(result)) self.assertEqual(12, len(result))
tf_matrix_group = np.array([[1.5, 0.0, -40.0], tf_matrix_group = np.array([[1.5, 0.0, -40.0],
@ -166,11 +166,11 @@ class TestGroups(unittest.TestCase):
self.assertEqual(expected_count, count) self.assertEqual(expected_count, count)
def test_nested_group(self): def test_nested_group(self):
# A bug in the flatten_group() implementation made it so that only top-level # A bug in the flattened_paths_from_group() implementation made it so that only top-level
# groups could have their paths flattened. This is a regression test to make # groups could have their paths flattened. This is a regression test to make
# sure that when a nested group is requested, its paths can also be flattened. # sure that when a nested group is requested, its paths can also be flattened.
doc = Document(join(dirname(__file__), 'groups.svg')) doc = Document(join(dirname(__file__), 'groups.svg'))
result = doc.flatten_group(['matrix group', 'scale group']) result = doc.paths_from_group(['matrix group', 'scale group'])
self.assertEqual(len(result), 5) self.assertEqual(len(result), 5)
def test_add_group(self): def test_add_group(self):