diff --git a/model/map/src/main/java/org/keycloak/models/map/storage/tree/DefaultTreeNode.java b/model/map/src/main/java/org/keycloak/models/map/storage/tree/DefaultTreeNode.java index 23075a1693..80bead89e1 100644 --- a/model/map/src/main/java/org/keycloak/models/map/storage/tree/DefaultTreeNode.java +++ b/model/map/src/main/java/org/keycloak/models/map/storage/tree/DefaultTreeNode.java @@ -24,11 +24,15 @@ import java.util.LinkedList; import java.util.List; import java.util.ListIterator; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.Queue; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Predicate; +import java.util.stream.Stream; +import java.util.stream.Stream.Builder; /** * Generic implementation of a node in a tree. @@ -39,26 +43,29 @@ import java.util.function.Predicate; */ public class DefaultTreeNode> implements TreeNode { + private static final AtomicInteger COUNTER = new AtomicInteger(); + private final Map nodeProperties; private final Map edgeProperties; private final Map treeProperties; private final LinkedList children = new LinkedList<>(); private String id; private Self parent; + private final int uniqueId = COUNTER.getAndIncrement(); /** * @param treeProperties Reference to tree properties map. Tree properties are maintained outside of this node. */ protected DefaultTreeNode(Map treeProperties) { - this.treeProperties = treeProperties; + this.treeProperties = treeProperties == null ? Collections.emptyMap() : treeProperties; this.edgeProperties = new HashMap<>(); this.nodeProperties = new HashMap<>(); } public DefaultTreeNode(Map nodeProperties, Map edgeProperties, Map treeProperties) { - this.nodeProperties = nodeProperties; - this.edgeProperties = edgeProperties; - this.treeProperties = treeProperties; + this.treeProperties = treeProperties == null ? Collections.emptyMap() : treeProperties; + this.edgeProperties = edgeProperties == null ? new HashMap<>() : edgeProperties; + this.nodeProperties = nodeProperties == null ? new HashMap<>() : nodeProperties; } @Override @@ -169,6 +176,37 @@ public class DefaultTreeNode> implements Tree return Optional.empty(); } + @Override + public void walkBfs(Consumer visitor) { + Queue queue = new LinkedList<>(); + queue.add(getThis()); + while (! queue.isEmpty()) { + Self node = queue.poll(); + visitor.accept(node); + queue.addAll(node.getChildren()); + } + } + + @Override + public void walkDfs(Consumer visitorUponEntry, Consumer visitorAfterChildrenVisited) { + if (visitorUponEntry != null) { + visitorUponEntry.accept(getThis()); + } + for (Self child : children) { + child.walkDfs(visitorUponEntry, visitorAfterChildrenVisited); + } + if (visitorAfterChildrenVisited != null) { + visitorAfterChildrenVisited.accept(getThis()); + } + } + + @Override + public void forEachParent(Consumer visitor) { + for (Optional p = getParent(); p.isPresent(); p = p.get().getParent()) { + visitor.accept(p.get()); + } + } + @Override public List getPathToRoot(PathOrientation orientation) { LinkedList res = new LinkedList<>(); @@ -303,4 +341,73 @@ public class DefaultTreeNode> implements Tree private Self getThis() { return (Self) this; } + + @Override + public int hashCode() { + return this.uniqueId; + } + + @Override + public boolean equals(Object obj) { + return this == obj; + } + + @Override + public Stream getParentsStream() { + Builder resBuilder = Stream.builder(); + for (Optional p = getParent(); p.isPresent(); p = p.get().getParent()) { + resBuilder.accept(p.get()); + } + return resBuilder.build(); + } + + private static final ThreadLocal TOSTRING_DETAILS = new ThreadLocal() { + @Override + protected Boolean initialValue() { + return Boolean.TRUE; + } + + }; + + /** + * Print a tree structure in a pretty ASCII format. + * Adopted from https://stackoverflow.com/a/53705889/6930869 + * + * @param prefix Current prefix. Use "" in initial call + * @param node The current node + * @param getChildrenFunc A {@link Function} that returns the children of a given node. + * @param isTail Is node the last of its siblings. Use true in initial call. (This is needed for pretty printing.) + * @param The type of your nodes. Anything that has a toString can be used. + */ + private static StringBuilder toString(StringBuilder output, String prefix, DefaultTreeNode node, boolean isTail) { + String nodeName = node.getLabel(); + if (Objects.equals(TOSTRING_DETAILS.get(), Boolean.FALSE)) { + return new StringBuilder("@").append(nodeName); + } + String nodeConnection = isTail ? (prefix.isEmpty() ? "O── " : "└── ") : "├── "; + output.append(prefix).append(nodeConnection).append(nodeName); + try { + TOSTRING_DETAILS.set(Boolean.FALSE); + output.append(node.getNodeProperties().isEmpty() ? "" : " " + node.getNodeProperties()); + } finally { + TOSTRING_DETAILS.set(Boolean.TRUE); + } + output.append(System.lineSeparator()); + List> ch = node.getChildren(); + for (int i = 0; i < ch.size(); i ++) { + String newPrefix = prefix + (isTail ? " " : "│ "); + toString(output, newPrefix, ch.get(i), i == ch.size() - 1); + } + return output; + } + + protected String getLabel() { + return getId(); + } + + @Override + public String toString() { + return toString(new StringBuilder(), "", getThis(), true).toString(); + } + } diff --git a/model/map/src/main/java/org/keycloak/models/map/storage/tree/TreeNode.java b/model/map/src/main/java/org/keycloak/models/map/storage/tree/TreeNode.java index 259edd13fe..430328632c 100644 --- a/model/map/src/main/java/org/keycloak/models/map/storage/tree/TreeNode.java +++ b/model/map/src/main/java/org/keycloak/models/map/storage/tree/TreeNode.java @@ -19,7 +19,9 @@ package org.keycloak.models.map.storage.tree; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.function.Consumer; import java.util.function.Predicate; +import java.util.stream.Stream; /** * Interface representing a node in a tree that has ID. @@ -166,4 +168,29 @@ public interface TreeNode> { * @return */ List getPathToRoot(PathOrientation orientation); + + /** + * Returns a stream of the nodes laying on the path from this node (exclusive) to the root of the tree (inclusive). + * @return + */ + Stream getParentsStream(); + + /** + * Calls the given {@code visitor} on each node laying on the path from this node (exclusive) to the root of the tree (inclusive). + * @param visitor + */ + void forEachParent(Consumer visitor); + + /** + * Walks the tree with the given visitor in depth-first search manner. + * @param visitorUponEntry Visitor called upon entry of the node. May be {@code null}, in that case no action is performed. + * @param visitorAfterChildrenVisited Visitor called before exit of the node. May be {@code null}, in that case no action is performed. + */ + void walkDfs(Consumer visitorUponEntry, Consumer visitorAfterChildrenVisited); + + /** + * Walks the tree with the given visitor in breadth-first search manner. + * @param visitor + */ + void walkBfs(Consumer visitor); } diff --git a/model/map/src/test/java/org/keycloak/models/map/storage/tree/DefaultTreeNodeTest.java b/model/map/src/test/java/org/keycloak/models/map/storage/tree/DefaultTreeNodeTest.java index efc072c4aa..5b72a486ea 100644 --- a/model/map/src/test/java/org/keycloak/models/map/storage/tree/DefaultTreeNodeTest.java +++ b/model/map/src/test/java/org/keycloak/models/map/storage/tree/DefaultTreeNodeTest.java @@ -31,6 +31,7 @@ import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.hasKey; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.lessThan; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.notNullValue; @@ -41,13 +42,16 @@ import static org.hamcrest.Matchers.notNullValue; public class DefaultTreeNodeTest { private class Node extends DefaultTreeNode { + public Node() { super(treeProperties); } + public Node(String id) { super(treeProperties); setId(id); } + public Node(Node parent, String id) { super(treeProperties); setId(id); @@ -55,7 +59,7 @@ public class DefaultTreeNodeTest { } @Override - public String toString() { + public String getLabel() { return this.getId() == null ? "Node:" + System.identityHashCode(this) : this.getId(); } } @@ -68,6 +72,7 @@ public class DefaultTreeNodeTest { private static final Integer VALUE_3 = 12345; public Map treeProperties = new HashMap<>(); + { treeProperties.put(KEY_1, VALUE_1); treeProperties.put(KEY_2, VALUE_2); @@ -330,85 +335,168 @@ public class DefaultTreeNodeTest { @Test public void testDfs() { - Node root = new Node("1"); - Node child11 = new Node(root, "1.1"); - Node child12 = new Node(root, "1.2"); + Node root = new Node("1"); + Node child11 = new Node(root, "1.1"); + Node child12 = new Node(root, "1.2"); Node child111 = new Node(child11, "1.1.1"); - Node child112 = new Node(child11, "1.1.2"); - Node child121 = new Node(child12, "1.2.1"); - Node child122 = new Node(child12, "1.2.2"); - Node child123 = new Node(child12, "1.2.3"); - Node child1121 = new Node(child112, "1.1.2.1"); + Node child112 = new Node(child11, "1.1.2"); + Node child121 = new Node(child12, "1.2.1"); + Node child122 = new Node(child12, "1.2.2"); + Node child123 = new Node(child12, "1.2.3"); + Node child1121 = new Node(child112, "1.1.2.1"); List res = new LinkedList<>(); - assertThat(root.findFirstDfs(n -> { res.add(n); return false; }), is(Optional.empty())); + assertThat(root.findFirstDfs(n -> { + res.add(n); + return false; + }), is(Optional.empty())); assertThat(res, contains(root, child11, child111, child112, child1121, child12, child121, child122, child123)); res.clear(); - assertThat(root.findFirstDfs(n -> { res.add(n); return n == child12; }), is(Optional.of(child12))); + assertThat(root.findFirstDfs(n -> { + res.add(n); + return n == child12; + }), is(Optional.of(child12))); assertThat(res, contains(root, child11, child111, child112, child1121, child12)); } @Test public void testDfsBottommost() { - Node root = new Node("1"); - Node child11 = new Node(root, "1.1"); - Node child12 = new Node(root, "1.2"); - Node child13 = new Node(root, "1.3"); + Node root = new Node("1"); + Node child11 = new Node(root, "1.1"); + Node child12 = new Node(root, "1.2"); + Node child13 = new Node(root, "1.3"); Node child111 = new Node(child11, "1.1.1"); - Node child112 = new Node(child11, "1.1.2"); - Node child121 = new Node(child12, "1.2.1"); - Node child122 = new Node(child12, "1.2.2"); - Node child123 = new Node(child12, "1.2.3"); - Node child1121 = new Node(child112, "1.1.2.1"); - Node child131 = new Node(child13, "1.3.1"); - Node child132 = new Node(child13, "1.3.2"); + Node child112 = new Node(child11, "1.1.2"); + Node child121 = new Node(child12, "1.2.1"); + Node child122 = new Node(child12, "1.2.2"); + Node child123 = new Node(child12, "1.2.3"); + Node child1121 = new Node(child112, "1.1.2.1"); + Node child131 = new Node(child13, "1.3.1"); + Node child132 = new Node(child13, "1.3.2"); List res = new LinkedList<>(); - assertThat(root.findFirstBottommostDfs(n -> { res.add(n); return false; }), is(Optional.empty())); + assertThat(root.findFirstBottommostDfs(n -> { + res.add(n); + return false; + }), is(Optional.empty())); assertThat(res, contains(root, child11, child111, child112, child1121, child12, child121, child122, child123, child13, child131, child132)); res.clear(); - assertThat(root.findFirstBottommostDfs(n -> { res.add(n); return n == child12; }), is(Optional.of(child12))); + assertThat(root.findFirstBottommostDfs(n -> { + res.add(n); + return n == child12; + }), is(Optional.of(child12))); assertThat(res, contains(root, child11, child111, child112, child1121, child12, child121, child122, child123)); res.clear(); - assertThat(root.findFirstBottommostDfs(n -> { res.add(n); return n.getId().startsWith("1.1.2"); }), is(Optional.of(child1121))); + assertThat(root.findFirstBottommostDfs(n -> { + res.add(n); + return n.getId().startsWith("1.1.2"); + }), is(Optional.of(child1121))); assertThat(res, contains(root, child11, child111, child112, child1121)); } @Test public void testBfs() { - Node root = new Node("1"); - Node child11 = new Node(root, "1.1"); - Node child12 = new Node(root, "1.2"); + Node root = new Node("1"); + Node child11 = new Node(root, "1.1"); + Node child12 = new Node(root, "1.2"); Node child111 = new Node(child11, "1.1.1"); - Node child112 = new Node(child11, "1.1.2"); - Node child121 = new Node(child12, "1.2.1"); - Node child122 = new Node(child12, "1.2.2"); - Node child123 = new Node(child12, "1.2.3"); - Node child1121 = new Node(child112, "1.1.2.1"); + Node child112 = new Node(child11, "1.1.2"); + Node child121 = new Node(child12, "1.2.1"); + Node child122 = new Node(child12, "1.2.2"); + Node child123 = new Node(child12, "1.2.3"); + Node child1121 = new Node(child112, "1.1.2.1"); List res = new LinkedList<>(); - assertThat(root.findFirstBfs(n -> { res.add(n); return false; }), is(Optional.empty())); + assertThat(root.findFirstBfs(n -> { + res.add(n); + return false; + }), is(Optional.empty())); assertThat(res, contains(root, child11, child12, child111, child112, child121, child122, child123, child1121)); res.clear(); - assertThat(root.findFirstBfs(n -> { res.add(n); return n == child12; }), is(Optional.of(child12))); + assertThat(root.findFirstBfs(n -> { + res.add(n); + return n == child12; + }), is(Optional.of(child12))); assertThat(res, contains(root, child11, child12)); } @Test - public void testPathToRoot() { - Node root = new Node("1"); - Node child11 = new Node(root, "1.1"); - Node child12 = new Node(root, "1.2"); + public void testWalkBfs() { + Node root = new Node("1"); + Node child11 = new Node(root, "1.1"); + Node child12 = new Node(root, "1.2"); Node child111 = new Node(child11, "1.1.1"); - Node child112 = new Node(child11, "1.1.2"); - Node child121 = new Node(child12, "1.2.1"); - Node child122 = new Node(child12, "1.2.2"); - Node child123 = new Node(child12, "1.2.3"); - Node child1121 = new Node(child112, "1.1.2.1"); + Node child112 = new Node(child11, "1.1.2"); + Node child121 = new Node(child12, "1.2.1"); + Node child122 = new Node(child12, "1.2.2"); + Node child123 = new Node(child12, "1.2.3"); + Node child1121 = new Node(child112, "1.1.2.1"); + + List res = new LinkedList<>(); + root.walkBfs(res::add); + assertThat(res, contains(root, child11, child12, child111, child112, child121, child122, child123, child1121)); + } + + @Test + public void testWalkDfs() { + Node root = new Node("1"); + Node child11 = new Node(root, "1.1"); + Node child12 = new Node(root, "1.2"); + Node child111 = new Node(child11, "1.1.1"); + Node child112 = new Node(child11, "1.1.2"); + Node child121 = new Node(child12, "1.2.1"); + Node child122 = new Node(child12, "1.2.2"); + Node child123 = new Node(child12, "1.2.3"); + Node child1121 = new Node(child112, "1.1.2.1"); + + List uponEntry = new LinkedList<>(); + List afterChildren = new LinkedList<>(); + root.walkDfs(uponEntry::add, afterChildren::add); + assertThat(uponEntry, contains(root, child11, child111, child112, child1121, child12, child121, child122, child123)); + assertThat(afterChildren, contains(child111, child1121, child112, child11, child121, child122, child123, child12, root)); + } + + @Test + public void testForEachParent() { + Node root = new Node("1"); + Node child11 = new Node(root, "1.1"); + Node child12 = new Node(root, "1.2"); + Node child111 = new Node(child11, "1.1.1"); + Node child112 = new Node(child11, "1.1.2"); + Node child121 = new Node(child12, "1.2.1"); + Node child122 = new Node(child12, "1.2.2"); + Node child123 = new Node(child12, "1.2.3"); + Node child1121 = new Node(child112, "1.1.2.1"); + + List res = new LinkedList<>(); + res.clear(); + root.forEachParent(res::add); + assertThat(res, empty()); + + res.clear(); + child1121.forEachParent(res::add); + assertThat(res, contains(child112, child11, root)); + + res.clear(); + child123.forEachParent(res::add); + assertThat(res, contains(child12, root)); + } + + @Test + public void testPathToRoot() { + Node root = new Node("1"); + Node child11 = new Node(root, "1.1"); + Node child12 = new Node(root, "1.2"); + Node child111 = new Node(child11, "1.1.1"); + Node child112 = new Node(child11, "1.1.2"); + Node child121 = new Node(child12, "1.2.1"); + Node child122 = new Node(child12, "1.2.2"); + Node child123 = new Node(child12, "1.2.3"); + Node child1121 = new Node(child112, "1.1.2.1"); assertThat(child1121.getPathToRoot(PathOrientation.TOP_FIRST), contains(root, child11, child112, child1121)); assertThat(child123.getPathToRoot(PathOrientation.TOP_FIRST), contains(root, child12, child123)); @@ -419,6 +507,13 @@ public class DefaultTreeNodeTest { assertThat(root.getPathToRoot(PathOrientation.BOTTOM_FIRST), contains(root)); } + @Test + public void testToStringStackOverflow() { + Node n = new Node("1"); + n.setNodeProperty("prop", n); + assertThat(n.toString().length(), lessThan(255)); + } + private void assertTreeProperties(Node node) { assertThat(node.getTreeProperty(KEY_1, String.class), notNullValue()); assertThat(node.getTreeProperty(KEY_1, Date.class), notNullValue());