diff --git a/svgpathtools/document.py b/svgpathtools/document.py index 87b0521..e90c17d 100644 --- a/svgpathtools/document.py +++ b/svgpathtools/document.py @@ -101,6 +101,8 @@ def flatten_all_paths(group, group_filter=lambda x: True, # Stop right away if the group_selector rejects this group if not group_filter(group): + warnings.warn('The input group [{}] (id attribute: {}) was rejected by the group filter' + .format(group, group.get('id'))) return [] # To handle the transforms efficiently, we'll traverse the tree of @@ -174,10 +176,48 @@ def flatten_group(group_to_flatten, root, recursive=True, else: desired_groups.add(id(group_to_flatten)) + ignore_paths = set() + # Use breadth-first search to find the path to the group that we care about + if root is not group_to_flatten: + search = [[root]] + route = None + while search: + top = search.pop(0) + frontier = top[-1] + for child in frontier.iterfind(group_search_xpath, SVG_NAMESPACE): + if child is group_to_flatten: + route = top + break + future_top = list(top) + future_top.append(child) + search.append(future_top) + + if route is not None: + for group in route: + # Add each group from the root to the parent of the desired group + # to the list of groups that we should traverse. This makes sure + # that flatten_all_paths will not stop before reaching the desired + # group. + desired_groups.add(id(group)) + for key in path_conversions.keys(): + for path_elem in group.iterfind('svg:'+key, SVG_NAMESPACE): + # Add each path in the parent groups to the list of paths + # that should be ignored. The user has not requested to + # flatten the paths of the parent groups, so we should not + # include any of these in the result. + ignore_paths.add(id(path_elem)) + break + + if route is None: + raise ValueError('The group_to_flatten is not a descendant of the root!') + def desired_group_filter(x): return (id(x) in desired_groups) and group_filter(x) - return flatten_all_paths(root, desired_group_filter, path_filter, + def desired_path_filter(x): + return (id(x) not in ignore_paths) and path_filter(x) + + return flatten_all_paths(root, desired_group_filter, desired_path_filter, path_conversions, group_search_xpath) @@ -223,13 +263,17 @@ class Document: if all(isinstance(s, str) for s in group): # If we're given a list of strings, assume it represents a # nested sequence - group = self.get_or_add_group(group) + group = self.get_group(group) elif not isinstance(group, Element): raise TypeError( 'Must provide a list of strings that represent a nested ' 'group name, or provide an xml.etree.Element object. ' 'Instead you provided {0}'.format(group)) + if group is None: + warnings.warn("Could not find the requested group!") + return [] + return flatten_group(group, self.tree.getroot(), recursive, group_filter, path_filter, path_conversions) @@ -282,6 +326,37 @@ class Document: def contains_group(self, group): return any(group is owned for owned in self.tree.iter()) + def get_group(self, nested_names, name_attr='id'): + """Get a group from the tree, or None if the requested group + does not exist. Use get_or_add_group(~) if you want a new group + to be created if it did not already exist. + + `nested_names` is a list of strings which represent group names. + Each group name will be nested inside of the previous group name. + + `name_attr` is the group attribute that is being used to + represent the group's name. Default is 'id', but some SVGs may + contain custom name labels, like 'inkscape:label'. + + Returns the request group. If the requested group did not + exist, this function will return a None value. + """ + group = self.tree.getroot() + # Drill down through the names until we find the desired group + while len(nested_names): + prev_group = group + next_name = nested_names.pop(0) + for elem in group.iterfind(SVG_GROUP_TAG, SVG_NAMESPACE): + if elem.get(name_attr) == next_name: + group = elem + break + + if prev_group is group: + # The nested group could not be found, so we return None + return None + + return group + def get_or_add_group(self, nested_names, name_attr='id'): """Get a group from the tree, or add a new one with the given name structure. diff --git a/test/test_groups.py b/test/test_groups.py index f493524..d33aea2 100644 --- a/test/test_groups.py +++ b/test/test_groups.py @@ -165,6 +165,14 @@ class TestGroups(unittest.TestCase): self.assertEqual(expected_count, count) + def test_nested_group(self): + # A bug in the flatten_group() implementation made it so that only top-level + # groups could have their paths flattened. This is a regression test to make + # sure that when a nested group is requested, its paths can also be flattened. + doc = Document(join(dirname(__file__), 'groups.svg')) + result = doc.flatten_group(['matrix group', 'scale group']) + self.assertEqual(len(result), 5) + def test_add_group(self): # Test `Document.add_group()` function and related Document functions. doc = Document(None)