Exploring Computation Graphs in Rust

17 Jun 2018
Paul Kernfeld dot com

Recently, I’ve been trying to figure out a good way to model computation graphs in Rust. In this post, I explore using a graph with vector indices. I’m not sure if this is the best approach, but writing it out has helped me to understand the advantages and disadvantages better.

When I say “computation graph,” I mean a representation of a mathematical expression like 2 * a + a * b. This example contains a constant (2), two variables (a and b), and two functions (addition and multiplication). This expression can be modeled as a directed acyclic graph:

2   a   b
 \ / \ /
  *   *
   \ /
    +

In my ASCII diagram above, all edges go downwards. In general, edges don’t really have any interesting information associated with them, so I’m going to pretty much ignore them.

A homogeneous graph

Below is a homogeneous DAG as a jumping-off point. This is strongly inspired by “Modeling graphs in Rust using vector indices.” This isn’t a computation graph yet, but it does let us implement an algorithm that traverses the graph and memoizes intermediate values. Memoization is important because the number of paths through a DAG can grow exponentially with the number of nodes, meaning that naive recursive implementations, which revisit a shared node once per path, can be very slow.

/// Deriving Copy reduces ownership headaches
#[derive(Copy, Clone)]
pub struct Idx(usize);

pub struct Node {
    children: Vec<Idx>,
}

#[derive(Default)]
pub struct Graph {
    nodes: Vec<Node>,
}

/// Graph maintains the invariant that nodes can only be added, never removed. This means that
/// a particular Idx will always be valid as long as it is used with the correct Graph.
impl Graph {
    pub fn push(&mut self, children: Vec<Idx>) -> Idx {
        self.nodes.push(Node { children });
        Idx(self.nodes.len() - 1)
    }

    /// This returns the number of paths between each leaf node and the final node. This
    /// implementation memoizes the number of paths from leaves to each node.
    ///
    /// Note that "final node" is only a meaningful concept in a DAG where there is one node
    /// that is the ancestor of every other node in the graph; I'm using it here for simplicity.
    pub fn count_paths(&self) -> usize {
        let mut path_counts = Vec::new();

        for node in &self.nodes {
            let paths_to_here = if node.children.is_empty() {
                1
            } else {
                node.children
                    .iter()
                    .map(|child_index| path_counts[child_index.0])
                    .sum()
            };
            path_counts.push(paths_to_here);
        }

        // The final node is the last one pushed; an empty graph has no paths.
        path_counts.last().copied().unwrap_or(0)
    }
}

let mut g = Graph::default();
let a = g.push(vec![]);
let b = g.push(vec![a]);
let c = g.push(vec![a, b]);
let d = g.push(vec![a, b, c]);

// All paths are:
// a -> b -> c -> d
// a -> b ------> d
// a ------> c -> d
// a -----------> d
assert_eq!(4, g.count_paths());
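
To make the exponential blow-up concrete, here is a rough sketch (not part of the post's tested source) that compares a naive recursive count with the memoized one. It uses a "complete" DAG in which every node has all earlier nodes as children. The names complete_dag, count_paths_naive, and count_paths_memoized are made up for this illustration, and bare Vec<Vec<usize>> adjacency lists stand in for the Graph above to keep things short.

```rust
/// Hypothetical helper: build a DAG with `n` nodes in which node `i` has every
/// earlier node `0..i` as a child. Each inner Vec lists the children of one node.
fn complete_dag(n: usize) -> Vec<Vec<usize>> {
    (0..n).map(|i| (0..i).collect()).collect()
}

/// Naive recursion, no memoization. Returns (number of leaf-to-`node` paths,
/// number of recursive calls made), so the wasted work is visible.
fn count_paths_naive(children: &[Vec<usize>], node: usize) -> (u64, u64) {
    if children[node].is_empty() {
        return (1, 1);
    }
    let mut paths: u64 = 0;
    let mut calls: u64 = 1;
    for &child in &children[node] {
        let (child_paths, child_calls) = count_paths_naive(children, child);
        paths += child_paths;
        calls += child_calls;
    }
    (paths, calls)
}

/// Memoized version: a single pass in index order, the same idea as
/// `count_paths` above. Each node is visited exactly once.
fn count_paths_memoized(children: &[Vec<usize>]) -> u64 {
    let mut path_counts: Vec<u64> = Vec::with_capacity(children.len());
    for node_children in children {
        let paths: u64 = if node_children.is_empty() {
            1
        } else {
            node_children.iter().map(|&child| path_counts[child]).sum()
        };
        path_counts.push(paths);
    }
    path_counts.last().copied().unwrap_or(0)
}
```

On this family of graphs, the naive version makes twice as many calls every time a node is added, while the memoized version still visits each node once. With 20 nodes, that is 524288 recursive calls versus 20 loop iterations.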

What if Node is an enum?

Next, a graph that can actually do some computation. I’ve installed a few upgrades relative to the previous implementation:

use std::collections::{HashMap, HashSet};
use std::ops::{Add, Index};

/// To enable this to be used in HashMap and HashSet, this derives Eq, PartialEq, and Hash
#[derive(Copy, Clone, Eq, Hash, PartialEq)]
pub struct Idx(usize);

impl Add for Idx {
    type Output = Node;

    fn add(self, rhs: Idx) -> Node {
        Node::Sum { children: vec![self, rhs] }
    }
}

pub enum Node {
    Constant(f64),
    Variable,
    Sum { children: Vec<Idx> },
}

impl Node {
    fn get_value(&self, my_index: &Idx, values: &HashMap<Idx, f64>) -> f64 {
        match self {
            Node::Constant(value) => *value,
            Node::Variable => values[my_index],
            Node::Sum { children } => children.iter().map(|child| values[child]).sum(),
        }
    }

    fn derivative(
        &self,
        my_index: &Idx,
        wrt: &HashSet<Idx>,
        derivatives: &HashMap<Idx, Idx>,
    ) -> Node {
        match self {
            Node::Constant(_) => Node::Constant(0.0),
            Node::Variable => {
                if wrt.contains(my_index) {
                    Node::Constant(1.0)
                } else {
                    Node::Constant(0.0)
                }
            }
            Node::Sum { children } => {
                Node::Sum {
                    children: children.iter().map(|child| derivatives[child]).collect(),
                }
            }
        }
    }
}

/// This helps us to represent the idea that only a subset of the nodes in a graph might be
/// relevant for a particular computation. The indices in a Subgraph are ordered such that a
/// child always comes before one of its parents.
pub struct Subgraph {
    indices: Vec<Idx>,
}

impl Subgraph {
    fn new(indices_unsorted: impl Iterator<Item = Idx>) -> Self {
        let mut indices: Vec<Idx> = indices_unsorted.collect();

        // This is an easy way to enforce the order condition
        indices.sort_unstable_by_key(|index| index.0);
        Self { indices }
    }
}

#[derive(Default)]
pub struct Graph {
    nodes: Vec<Node>,
}

impl Graph {
    pub fn push(&mut self, node: Node) -> Idx {
        self.nodes.push(node);
        Idx(self.nodes.len() - 1)
    }

    pub fn as_subgraph(&self) -> Subgraph {
        Subgraph { indices: (0..self.nodes.len()).map(Idx).collect() }
    }

    /// Given values for each relevant variable, this computes the value for each node in the
    /// graph.
    pub fn evaluate_subgraph(
        &self,
        subgraph: Subgraph,
        variable_to_value: HashMap<Idx, f64>,
    ) -> HashMap<Idx, f64> {
        let mut result = variable_to_value;

        for index in subgraph.indices.iter() {
            let value = self[*index].get_value(index, &result);
            result.insert(*index, value);
        }

        result
    }

    pub fn evaluate(&self, variable_to_value: HashMap<Idx, f64>) -> HashMap<Idx, f64> {
        self.evaluate_subgraph(self.as_subgraph(), variable_to_value)
    }

    /// This transforms the graph by taking the derivative
    pub fn derivative(&mut self, of: Idx, wrt: HashSet<Idx>) -> (Idx, Subgraph) {
        // Memoize the derivative of each node
        let mut derivatives: HashMap<Idx, Idx> = HashMap::new();

        for old_index in 0..self.nodes.len() {
            let old_index = Idx(old_index);
            let new_node = self[old_index].derivative(&old_index, &wrt, &derivatives);
            let new_index = self.push(new_node);
            derivatives.insert(old_index, new_index);
        }

        // The subgraph contains all the new nodes we just created
        (
            derivatives[&of],
            Subgraph::new(derivatives.values().cloned()),
        )
    }
}

impl Index<Idx> for Graph {
    type Output = Node;

    fn index(&self, index: Idx) -> &Node {
        &self.nodes[index.0]
    }
}

// c = 1 + b
let mut g = Graph::default();
let a = g.push(Node::Constant(1.0));
let b = g.push(Node::Variable);
let c = g.push(a + b);

// 1 + 2 = 3
let variable_to_value = {
    let mut result = HashMap::new();
    result.insert(b, 2.0);
    result
};
assert_eq!(3.0, g.evaluate(variable_to_value)[&c]);

// The derivative of c wrt b is just 1
let wrt = {
    let mut result = HashSet::new();
    result.insert(b);
    result
};
let (d_c_b, subgraph) = g.derivative(c, wrt);
assert_eq!(1.0, g.evaluate_subgraph(subgraph, HashMap::new())[&d_c_b]);

Overall, I’m pretty happy with this implementation. However, adding many different types of node will cause the match statements to balloon. Additionally, I would rather see all the code for one type of Node in one place.
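
As a small illustration of how the match statements grow, here is a sketch of what adding one more node type might look like. Product is a hypothetical variant that is not in the implementation above. To keep the sketch short, two simplifications are assumed: a Variable stores its value directly, and plain usize indices with a Vec<f64> of values replace Idx and the HashMap.

```rust
/// A cut-down copy of the enum with one hypothetical extra variant (`Product`).
pub enum Node {
    Constant(f64),
    /// Simplification for this sketch: a variable carries its own value.
    Variable(f64),
    Sum { children: Vec<usize> },
    Product { children: Vec<usize> },
}

impl Node {
    /// `values[i]` must already hold the value of every child `i` of this node.
    fn get_value(&self, values: &[f64]) -> f64 {
        match self {
            Node::Constant(value) => *value,
            Node::Variable(value) => *value,
            Node::Sum { children } => children.iter().map(|&c| values[c]).sum(),
            // Each new variant needs an arm here, and would need another one
            // in a `derivative` method like the one in the post.
            Node::Product { children } => children.iter().map(|&c| values[c]).product(),
        }
    }
}

/// Evaluate all nodes in index order (children always come before parents).
pub fn evaluate(nodes: &[Node]) -> Vec<f64> {
    let mut values: Vec<f64> = Vec::with_capacity(nodes.len());
    for node in nodes {
        let value = node.get_value(&values);
        values.push(value);
    }
    values
}
```

With Product available, the expression 2 * a + a * b from the introduction can be written as six nodes: the constant, the two variables, two products, and a sum. With a = 3 and b = 4, the last node evaluates to 18.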

What if Node is a trait?

To solve this, I’m going to make Node a trait instead of an enum. I have omitted the parts that are identical to the previous implementation; the complete implementation is in the source code for this post.

pub trait Node: 'static {
    /// The input must include values for all variables and for all children of this node.
    fn get_value(&self, my_index: &Idx, values: &HashMap<Idx, f64>) -> f64;

    fn derivative(
        &self,
        my_index: &Idx,
        wrt: &HashSet<Idx>,
        derivatives: &HashMap<Idx, Idx>,
    ) -> Box<dyn Node>;
}

pub struct Constant(f64);

impl Node for Constant {
    fn get_value(&self, _my_index: &Idx, _values: &HashMap<Idx, f64>) -> f64 {
        self.0
    }

    fn derivative(
        &self,
        _my_index: &Idx,
        _wrt: &HashSet<Idx>,
        _derivatives: &HashMap<Idx, Idx>,
    ) -> Box<dyn Node> {
        Box::new(Constant(0.0))
    }
}

pub struct Variable;

impl Node for Variable {
    fn get_value(&self, my_index: &Idx, values: &HashMap<Idx, f64>) -> f64 {
        values[my_index]
    }

    fn derivative(
        &self,
        my_index: &Idx,
        wrt: &HashSet<Idx>,
        _derivatives: &HashMap<Idx, Idx>,
    ) -> Box<dyn Node> {
        if wrt.contains(my_index) {
            Box::new(Constant(1.0))
        } else {
            Box::new(Constant(0.0))
        }
    }
}

pub struct Sum {
    children: Vec<Idx>,
}

impl Node for Sum {
    fn get_value(&self, _my_index: &Idx, values: &HashMap<Idx, f64>) -> f64 {
        self.children.iter().map(|child| values[child]).sum()
    }

    fn derivative(
        &self,
        _my_index: &Idx,
        _wrt: &HashSet<Idx>,
        derivatives: &HashMap<Idx, Idx>,
    ) -> Box<dyn Node> {
        Box::new(Sum {
            children: self.children
                .iter()
                .map(|child| derivatives[child])
                .collect(),
        })
    }
}

/// Since a `dyn Node` trait object is not Sized, we need to box it so we can put it into a Vec.
#[derive(Default)]
pub struct Graph {
    nodes: Vec<Box<dyn Node>>,
}

/// This is almost identical to the Graph implementation with the enum, except that push is
/// now generic over any type implementing Node and boxes it, and I've added push_box for
/// nodes that are already boxed.
impl Graph {
    pub fn push_box(&mut self, box_node: Box<dyn Node>) -> Idx {
        self.nodes.push(box_node);
        Idx(self.nodes.len() - 1)
    }

    pub fn push<N: Node>(&mut self, node: N) -> Idx {
        self.push_box(Box::new(node))
    }
}

// c = 1 + b
let mut g = Graph::default();
let a = g.push(Constant(1.0));
let b = g.push(Variable);
let c = g.push_box(a + b);

// 1 + 2 = 3
let variable_to_value = {
    let mut result = HashMap::new();
    result.insert(b, 2.0);
    result
};
assert_eq!(3.0, g.evaluate(variable_to_value)[&c]);

// The derivative of c wrt b is just 1
let wrt = {
    let mut result = HashSet::new();
    result.insert(b);
    result
};
let (d_c_b, subgraph) = g.derivative(c, wrt);
assert_eq!(1.0, g.evaluate_subgraph(subgraph, HashMap::new())[&d_c_b]);

While implementing this, I noticed that making Node a trait enforces a clean separation of responsibility between the graph and the node. I actually used this version to clean up the factoring of the enum version.

However, making Node a trait brings with it the ergonomic disadvantage that nodes often need to be passed around inside a Box, which is slightly annoying. Separately, there is a performance penalty because we are now using dynamic dispatch instead of static dispatch. I don’t think that I care too much about this, because I’m interested in using these graphs with large tensors where the cost of the actual computation will dwarf the cost of traversing the graph.

Advantages

I was happy about several aspects of this experiment:

Disadvantages

Future directions

I have an unhealthy obsession with building an elegant DSL in vanilla Rust. I would love to be able to create a graph by writing something like this:

let x = variable();
let y = variable();
let z = x * 2.0 + y;
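
One way this syntax could be made to work is sketched below. It is only a sketch under a loud assumption: a single implicit, thread-local graph that every expression writes into. The names GRAPH, Expr, push, and node are hypothetical, and the Node enum is a cut-down stand-in rather than the one from the post.

```rust
use std::cell::RefCell;
use std::ops::{Add, Mul};

/// Hypothetical node type for this sketch only.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Node {
    Constant(f64),
    Variable,
    Sum(usize, usize),
    Product(usize, usize),
}

thread_local! {
    /// Assumption: one implicit graph per thread, shared by all expressions.
    static GRAPH: RefCell<Vec<Node>> = RefCell::new(Vec::new());
}

/// A copyable handle to a node in the thread-local graph.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Expr(usize);

/// Append a node to the thread-local graph and return a handle to it.
fn push(node: Node) -> Expr {
    GRAPH.with(|graph| {
        let mut nodes = graph.borrow_mut();
        nodes.push(node);
        Expr(nodes.len() - 1)
    })
}

pub fn variable() -> Expr {
    push(Node::Variable)
}

impl Add for Expr {
    type Output = Expr;

    fn add(self, rhs: Expr) -> Expr {
        push(Node::Sum(self.0, rhs.0))
    }
}

impl Mul<f64> for Expr {
    type Output = Expr;

    /// Multiplying by a float first pushes a Constant node, then the Product.
    fn mul(self, rhs: f64) -> Expr {
        let constant = push(Node::Constant(rhs));
        push(Node::Product(self.0, constant.0))
    }
}

/// Return a copy of the node behind a handle, for inspection.
pub fn node(expr: Expr) -> Node {
    GRAPH.with(|graph| graph.borrow()[expr.0])
}
```

With these definitions, the three lines above compile as written, and z ends up as a handle to a Sum node whose children are a Product and y. The trade-off is the hidden global state: there is no obvious way to have two graphs on one thread, and writing 2.0 * x would need a separate impl of Mul<Expr> for f64.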

One somewhat crazy direction of exploration would be to allow different nodes in the graph to implement different traits.

I would like to be able to save and load graphs.
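
Because nodes refer to each other by vector index rather than by pointer, a graph is already close to a flat, serializable list. The sketch below assumes a made-up line-based text format ("const 1", "var", "sum 0 1") and a cut-down Node enum with usize indices; the functions save and load are hypothetical and use only the standard library.

```rust
/// Hypothetical, simplified node type for this sketch.
#[derive(Debug, PartialEq)]
pub enum Node {
    Constant(f64),
    Variable,
    Sum { children: Vec<usize> },
}

/// Write one node per line, in index order.
pub fn save(nodes: &[Node]) -> String {
    let mut out = String::new();
    for node in nodes {
        match node {
            Node::Constant(value) => out.push_str(&format!("const {}\n", value)),
            Node::Variable => out.push_str("var\n"),
            Node::Sum { children } => {
                out.push_str("sum");
                for child in children {
                    out.push_str(&format!(" {}", child));
                }
                out.push('\n');
            }
        }
    }
    out
}

/// Parse the format written by `save`. Children must refer to earlier nodes,
/// which preserves the "child before parent" ordering the graph relies on.
pub fn load(text: &str) -> Result<Vec<Node>, String> {
    let mut nodes = Vec::new();
    for (line_no, line) in text.lines().enumerate() {
        let mut parts = line.split_whitespace();
        let node = match parts.next() {
            Some("const") => {
                let value = parts
                    .next()
                    .ok_or_else(|| format!("line {}: missing value", line_no))?
                    .parse::<f64>()
                    .map_err(|e| format!("line {}: {}", line_no, e))?;
                Node::Constant(value)
            }
            Some("var") => Node::Variable,
            Some("sum") => {
                let mut children = Vec::new();
                for part in parts {
                    let child = part
                        .parse::<usize>()
                        .map_err(|e| format!("line {}: {}", line_no, e))?;
                    if child >= nodes.len() {
                        return Err(format!("line {}: child {} is not an earlier node", line_no, child));
                    }
                    children.push(child);
                }
                Node::Sum { children }
            }
            other => return Err(format!("line {}: unknown node {:?}", line_no, other)),
        };
        nodes.push(node);
    }
    Ok(nodes)
}
```

This works easily for the enum version because the set of node types is closed. For the trait version, each boxed node would need some way to name its own type when saving, and the loader would need a registry of constructors, so that design is left open here.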

It would be good to have a way to represent functions that are composed of smaller functions, like softmax.
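
The simplest version of this is a plain function that pushes several primitive nodes and returns the indices of its outputs. The sketch below assumes hypothetical Input, Exp, and Div node types (none of which exist in the code above), stores input values directly in the nodes, and uses the textbook softmax without the usual max-subtraction trick for numerical stability.

```rust
/// Hypothetical primitive nodes for this sketch.
pub enum Node {
    Input(f64),
    Exp(usize),
    Sum(Vec<usize>),
    Div(usize, usize),
}

#[derive(Default)]
pub struct Graph {
    nodes: Vec<Node>,
}

impl Graph {
    pub fn push(&mut self, node: Node) -> usize {
        self.nodes.push(node);
        self.nodes.len() - 1
    }

    /// A composite "function": it adds no new node type. It pushes one Exp per
    /// input, one Sum over those, and one Div per input, then returns the
    /// indices of the Div nodes.
    pub fn softmax(&mut self, inputs: &[usize]) -> Vec<usize> {
        let exps: Vec<usize> = inputs.iter().map(|&i| self.push(Node::Exp(i))).collect();
        let total = self.push(Node::Sum(exps.clone()));
        exps.iter().map(|&e| self.push(Node::Div(e, total))).collect()
    }

    /// Evaluate every node in index order; children always precede parents.
    pub fn evaluate(&self) -> Vec<f64> {
        let mut values: Vec<f64> = Vec::with_capacity(self.nodes.len());
        for node in &self.nodes {
            let value: f64 = match node {
                Node::Input(v) => *v,
                Node::Exp(child) => values[*child].exp(),
                Node::Sum(children) => children.iter().map(|&c| values[c]).sum(),
                Node::Div(numerator, denominator) => values[*numerator] / values[*denominator],
            };
            values.push(value);
        }
        values
    }
}
```

The catch is that this flattens the composite: once the nodes are pushed, the graph no longer knows they came from softmax. Keeping that structure around, for example to use a specialized derivative, is the part that remains an open question.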

These questions aside, the most obvious ways to make this more useful would be to implement many different functions and to allow computation on data such as tensors.

About

This blog post was produced using cargo-readme to ensure that all of the code actually works. The source code is here.