Saturday, March 11, 2017

Java 8: Tree Traversal Using Streams

It's quite common to traverse a tree with hand-written recursive methods, but with Java 8 you can also expose the traversal as a stream of nodes that is evaluated lazily.

The class below represents a tree. The stream() method streams the nodes contained in the tree. You can then do all the cool things that you can do with other streams, such as filtering, mapping and collecting!

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Stream;

public class TreeNode<E> {

  private final E data;
  private final List<TreeNode<E>> children;

  /**
   * Creates the tree node with the specified data and no children.
   * @param data the data to store in this node
   */
  public TreeNode(final E data) {
    this.data = data;
    this.children = new ArrayList<>();
  }

  /**
   * @return the data contained in this node
   */
  public E getData() {
    return data;
  }

  /**
   * Adds a child to this tree node.
   * @param data the data to add
   * @return the newly created child node
   */
  public TreeNode<E> addChild(final E data) {
    final TreeNode<E> toAdd = new TreeNode<>(data);
    children.add(toAdd);
    return toAdd;
  }

  /**
   * @return a stream of the nodes in this tree in depth-first (pre-order) order, starting with this node
   */
  public Stream<TreeNode<E>> stream() {
    return Stream.concat(Stream.of(this), children.stream().flatMap(TreeNode::stream));
  }
}

Example usage:

TreeNode<String> root = new TreeNode<>("Root");
TreeNode<String> a = root.addChild("A");
a.addChild("A1");
a.addChild("A2");
root.addChild("B");
TreeNode<String> c = root.addChild("C");
c.addChild("C1");

long count = root.stream().count(); // 7 (count() returns a long, not an int)

String tree = root.stream().map(TreeNode::getData).collect(Collectors.joining(","));
// Root,A,A1,A2,B,C,C1
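
The filtering and searching mentioned earlier could look roughly like the sketch below. It is only an illustration: TreeStreamExamples, sampleTree, dataStartingWith and find are hypothetical names made up for this example, and the nested TreeNode is a condensed copy of the class above so that the snippet compiles on its own.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class TreeStreamExamples {

  // Condensed copy of the TreeNode class from this post, so the sketch is self-contained.
  static class TreeNode<E> {
    private final E data;
    private final List<TreeNode<E>> children = new ArrayList<>();

    TreeNode(final E data) {
      this.data = data;
    }

    E getData() {
      return data;
    }

    TreeNode<E> addChild(final E data) {
      final TreeNode<E> toAdd = new TreeNode<>(data);
      children.add(toAdd);
      return toAdd;
    }

    Stream<TreeNode<E>> stream() {
      return Stream.concat(Stream.of(this), children.stream().flatMap(TreeNode::stream));
    }
  }

  // Builds the same tree as the example usage above (hypothetical helper).
  static TreeNode<String> sampleTree() {
    final TreeNode<String> root = new TreeNode<>("Root");
    final TreeNode<String> a = root.addChild("A");
    a.addChild("A1");
    a.addChild("A2");
    root.addChild("B");
    final TreeNode<String> c = root.addChild("C");
    c.addChild("C1");
    return root;
  }

  // Filtering: the data of every node whose data starts with the prefix, in depth-first order.
  static List<String> dataStartingWith(final TreeNode<String> root, final String prefix) {
    return root.stream()
        .map(TreeNode::getData)
        .filter(s -> s.startsWith(prefix))
        .collect(Collectors.toList());
  }

  // Searching: the first node, in depth-first order, whose data equals the given value.
  static Optional<TreeNode<String>> find(final TreeNode<String> root, final String data) {
    return root.stream()
        .filter(node -> node.getData().equals(data))
        .findFirst();
  }

  public static void main(final String[] args) {
    final TreeNode<String> root = sampleTree();
    System.out.println(dataStartingWith(root, "A"));  // [A, A1, A2]
    System.out.println(find(root, "C1").isPresent()); // true
    System.out.println(find(root, "Z").isPresent());  // false
  }
}
```

Returning an Optional from find keeps the "not found" case explicit, so callers do not have to check for null.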