In computer science, a trie, or prefix tree, is an ordered tree data structure that is used to store an associative array where the keys are usually strings. You can read more about tries on Wikipedia. In this post we will implement a trie in Java. You can visualize a trie as below:

To implement a trie, first of all, we need a representation for the nodes, the circles in the diagram above. Below is the class that we use to encapsulate these nodes:

  private static class Node {
    public char letter;
    public boolean ends;
    
    public Node(char letter, boolean ends) {
      this.letter = letter;
      this.ends = ends;
    }
    
    public Map<Character, Node> children = new HashMap<Character, Node>();
  }

In the diagram, each node stores the partial word formed up to that node. We will not do this; instead, each node stores only its own letter. To indicate that a word ends at a particular node, each node also has a boolean field, ends, which is true when a word terminates at that letter/node. For example, if we store the word “cat”, the node containing ‘t’ will have ends set to true. Each node also has a children map, which contains all the letters extending from that node. In the diagram above, the node ‘t’ has two children, ‘o’ and ‘e’.
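To make the “cat” example concrete, here is a small sketch that wires the nodes up by hand for the two words “cat” and “car”. The class name NodeExample and the helper buildCatCar are hypothetical and exist only for this illustration; the nested Node mirrors the fields of the Node class above.

```java
import java.util.HashMap;
import java.util.Map;

class NodeExample {
  // Mirrors the Node class above: one letter, an ends flag and a map of children.
  static class Node {
    char letter;
    boolean ends;
    Map<Character, Node> children = new HashMap<Character, Node>();

    Node(char letter, boolean ends) {
      this.letter = letter;
      this.ends = ends;
    }
  }

  // Hypothetical helper: builds c -> a -> {t, r} by hand, the shape that storing "cat" and "car" produces.
  static Node buildCatCar() {
    Node c = new Node('c', false);
    Node a = new Node('a', false);
    c.children.put('a', a);
    a.children.put('t', new Node('t', true)); // "cat" ends at 't'
    a.children.put('r', new Node('r', true)); // "car" ends at 'r'
    return c;
  }

  public static void main(String[] args) {
    Node a = buildCatCar().children.get('a');
    System.out.println(a.children.size());        // 2: 'a' is shared by both words
    System.out.println(a.ends);                   // false: "ca" is not a stored word
    System.out.println(a.children.get('t').ends); // true
  }
}
```

The shared prefix “ca” is stored only once; the two words diverge at the children of ‘a’, and only the last letter of each word has ends set.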

With this knowledge, let us write an empty trie class as below:

import java.util.*;

public class Trie {
  private Map<Character, Node> root = new HashMap<Character, Node>();
}

Unlike the diagram above, we do not start with an empty node for the root; instead, our root is a map of all the first letters. If our trie held every word in the English dictionary, written using only the lowercase letters a to z, the root would have at most 26 entries. Instead of using a map, we could optimize this by using an array of size 26 and deriving each letter's index from its character value, for example c - 'a'.
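The array idea can be sketched as follows, assuming words contain only the lowercase letters 'a' to 'z'. This is a hypothetical variant, not the code used in the rest of this post: ArrayTrie and containsWord are illustrative names, the array is applied to every node's children rather than just the root, and an empty root node is used, so the letter is implied by the array slot instead of being stored in the node.

```java
// Sketch only: assumes every character is in 'a'..'z'.
class ArrayTrie {
  private static class Node {
    boolean ends;
    Node[] children = new Node[26]; // slot i holds the child for letter 'a' + i
  }

  // With fixed-size arrays, a single empty root node replaces the root map.
  private final Node root = new Node();

  // Maps 'a'..'z' to 0..25 and rejects anything else.
  private static int index(char c) {
    if (c < 'a' || c > 'z') {
      throw new IllegalArgumentException("Only 'a' to 'z' supported: " + c);
    }
    return c - 'a';
  }

  public void add(String word) {
    Node current = root;
    for (char c : word.toCharArray()) {
      int i = index(c);
      if (current.children[i] == null) {
        current.children[i] = new Node(); // letter not present yet, create it
      }
      current = current.children[i];
    }
    current.ends = true; // mark the last letter as the end of a word
  }

  // Returns true only if the whole word was added, not just a prefix of it.
  public boolean containsWord(String word) {
    Node current = root;
    for (char c : word.toCharArray()) {
      current = current.children[index(c)];
      if (current == null) {
        return false;
      }
    }
    return current.ends;
  }

  public static void main(String[] args) {
    ArrayTrie trie = new ArrayTrie();
    trie.add("cat");
    trie.add("car");
    System.out.println(trie.containsWord("car")); // true
    System.out.println(trie.containsWord("ca"));  // false, "ca" is only a prefix
  }
}
```

The trade-off is memory: every node allocates 26 slots even if it has a single child, whereas a map only holds the children that actually exist.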

Now, let us implement the add method, the method that will be used to add words to the trie.

  public void add(String word) {
    char[] chars = word.toCharArray();
    
    //Get the first letter in the word and see whether it exists as an entry in the root.
    char first = chars[0];

    Node current;
    if (root.containsKey(first)) {
      //Root already contains the letter, hence treat that as the current element.
      current = root.get(first); 
    } else {
      //Letter is not already present as an entry in the root, hence create it and add it to root.
      Node node = new Node(first, false);
      root.put(first, node);
      current = node;
    }

    //Iterate over the other letters.
    for (int i = 1; i < chars.length; ++i) {
      char aChar = chars[i];

      if (current.children.containsKey(aChar)) {
        //This letter is already present in the tree as a child. Hence, just update the current to point to the node
        //represented by this letter.
        current = current.children.get(aChar);
      } else {
        //This letter is not present as a child of the current element, hence add it.
        Node node = new Node(aChar, false);
        current.children.put(aChar, node);
        current = node;
      }
    }

    //Our word ends here, mark the current node as the terminal node.
    current.ends = true;
  }
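The containsKey/get/put pattern in add can also be written more compactly with Map.computeIfAbsent (Java 8 and later), which returns the existing child for a letter or creates, stores and returns a new one. The sketch below is a hypothetical rewrite under that assumption, with its own minimal Node; CompactTrie is an illustrative name. Because it starts from the root map and walks level by level, it needs no special case for the first letter, and an empty word simply adds nothing.

```java
import java.util.HashMap;
import java.util.Map;

class CompactTrie {
  static class Node {
    final char letter;
    boolean ends;
    final Map<Character, Node> children = new HashMap<>();

    Node(char letter) {
      this.letter = letter;
    }
  }

  final Map<Character, Node> root = new HashMap<>();

  public void add(String word) {
    Map<Character, Node> level = root; // the map we are currently looking up letters in
    Node current = null;
    for (char c : word.toCharArray()) {
      // Reuse the child for this letter if present, otherwise create and store it.
      current = level.computeIfAbsent(c, k -> new Node(k));
      level = current.children;
    }
    if (current != null) {
      current.ends = true; // the word ends at its last letter
    }
  }

  public static void main(String[] args) {
    CompactTrie trie = new CompactTrie();
    trie.add("cat");
    trie.add("car");
    Node a = trie.root.get('c').children.get('a');
    System.out.println(a.children.size()); // 2
  }
}
```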

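Now, let us implement a method to check whether a sequence of letters exists in the trie, that is, whether it is a prefix of at least one stored word. Note that this method does not look at the ends flag, so after adding only “cat”, contains("ca") also returns true.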

  public boolean contains(String sequence) {
    char[] chars = sequence.toCharArray();

    char first = chars[0];

    //If the first letter is not present as a key in the root, there is no chance of this sequence being present in the trie.
    if (!root.containsKey(first)) {
      return false;
    }

    Node current = root.get(first);

    for (int i = 1; i < chars.length; ++i) {
      char aChar = chars[i];

      //This letter is not present in the trie, hence the trie does not have this sequence.
      if (!current.children.containsKey(aChar)) {
        return false;
      }

      current = current.children.get(aChar);
    }

    //If we came out of the loop, all the letters are present in the trie, in the right order.
    return true;
  }
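The contains method above answers the prefix question. If you also need to know whether a complete word was added, one option is to check the ends flag of the last node reached. The sketch below shows a hypothetical containsWord method (not part of the original code) on a trimmed-down class, WordTrie, whose Node omits the letter field because the map key already records it; its add uses Map.computeIfAbsent, available since Java 8.

```java
import java.util.HashMap;
import java.util.Map;

class WordTrie {
  static class Node {
    boolean ends;
    final Map<Character, Node> children = new HashMap<>();
  }

  private final Map<Character, Node> root = new HashMap<>();

  public void add(String word) {
    Map<Character, Node> level = root;
    Node current = null;
    for (char c : word.toCharArray()) {
      current = level.computeIfAbsent(c, k -> new Node());
      level = current.children;
    }
    if (current != null) {
      current.ends = true;
    }
  }

  // Walks the same path as contains(), but also requires that a word ends on the last node.
  public boolean containsWord(String word) {
    Map<Character, Node> level = root;
    Node current = null;
    for (char c : word.toCharArray()) {
      current = level.get(c);
      if (current == null) {
        return false; // a letter is missing, so the word cannot be present
      }
      level = current.children;
    }
    // current is null only for the empty string, which is never stored here.
    return current != null && current.ends;
  }

  public static void main(String[] args) {
    WordTrie trie = new WordTrie();
    trie.add("cat");
    System.out.println(trie.containsWord("cat")); // true
    System.out.println(trie.containsWord("ca"));  // false: only a prefix
  }
}
```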

Next comes the method that fetches all the words in the trie that start with the given prefix.

  public List<String> get(String prefix) {
    List<String> words = new LinkedList<String>();

    char[] chars = prefix.toCharArray();

    char first = chars[0];

    //First letter is not present as a key in the root, hence there cannot be any words in this trie corresponding to the
    //prefix.
    if (!root.containsKey(first)) {
      return words;
    }

    Node current = root.get(first);

    for (int i = 1; i < chars.length; ++i) {
      char aChar = chars[i];

      //No entry in the trie for this letter, hence there cannot be any words in the trie for the passed-in prefix.
      if (!current.children.containsKey(aChar)) {
        return words;
      }

      current = current.children.get(aChar);
    }

    //At this point, we can be sure that there are words in the trie corresponding to the passed in prefix. Now plow down 
    //the depths of the tree and form all the words corresponding to the prefix.
    return formWords(current, prefix, words);
  }

  private List<String> formWords(Node node, String prefix, List<String> words) {
    //A word can be formed at this point, hence add it to the container.
    if (node.ends) {
      words.add(prefix);
      
      //This branch of the tree ends here as there are no more children. Hence return.
      if (node.children.size() == 0) {
        return words;
      }
    }
    
    //Recursively go through all the children of the node and form words.
    for (Node child : node.children.values()) {
      formWords(child, prefix + child.letter, words);
    }

    return words;
  }
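Because the children are kept in a HashMap, get returns matching words in no particular order. The sketch below is a hypothetical variant, SortedTrie, that makes two swapped-in design choices: a TreeMap for the children, so a depth-first walk visits letters alphabetically and the words come out sorted, and a single StringBuilder that is appended to and trimmed back (backtracking) instead of building a new String at every level. It also uses an empty root node instead of a root map, so no special case is needed for the first letter.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

class SortedTrie {
  static class Node {
    boolean ends;
    // A TreeMap keeps children sorted by letter, so words are collected alphabetically.
    final Map<Character, Node> children = new TreeMap<>();
  }

  private final Node root = new Node(); // empty root node

  public void add(String word) {
    Node current = root;
    for (char c : word.toCharArray()) {
      current = current.children.computeIfAbsent(c, k -> new Node());
    }
    current.ends = true;
  }

  public List<String> get(String prefix) {
    List<String> words = new ArrayList<>();
    Node current = root;
    for (char c : prefix.toCharArray()) {
      current = current.children.get(c);
      if (current == null) {
        return words; // no stored word starts with this prefix
      }
    }
    formWords(current, new StringBuilder(prefix), words);
    return words;
  }

  // Depth-first walk: record complete words, append each child's letter, recurse, then backtrack.
  private void formWords(Node node, StringBuilder sb, List<String> words) {
    if (node.ends) {
      words.add(sb.toString());
    }
    for (Map.Entry<Character, Node> e : node.children.entrySet()) {
      sb.append(e.getKey());
      formWords(e.getValue(), sb, words);
      sb.setLength(sb.length() - 1); // remove the letter again before trying the next child
    }
  }

  public static void main(String[] args) {
    SortedTrie trie = new SortedTrie();
    for (String w : new String[] {"tea", "ten", "to", "inn", "in"}) {
      trie.add(w);
    }
    System.out.println(trie.get("te")); // [tea, ten]
    System.out.println(trie.get("i"));  // [in, inn]
    System.out.println(trie.get("x"));  // []
  }
}
```

A TreeMap lookup is O(log k) in the number of children rather than the expected O(1) of a HashMap, but with at most a few dozen children per node the difference is usually small.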

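That is all there is to it. The code does not do any error checking; only the happy path is taken into account. For example, passing an empty string to add, contains or get throws an ArrayIndexOutOfBoundsException, because each method reads chars[0] without checking the length. The full source code is available on GitHub.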
