Suppose we want to train a machine learning model on a large dataset $D$ using $N$ computing units. We naturally split $D$ into $N$ chunks and feed each chunk to one computing unit. This leaves us with $N$ gradients $\mathbf g_1, \mathbf g_2, \dots, \mathbf g_N$. How do we compute the combined gradient

$$\mathbf G=\frac{1}{N}\sum_{i=1}^{N} \mathbf g_i$$

and propagate it to all computing units?

## Parameter Server

This is the straightforward solution. Each unit sends its gradient to a central server (hence the name "parameter server"). The central server computes the combined gradient and sends it back to each unit.

The problem with this approach is that it requires high bandwidth on the Parameter Server. Let $S$ be the size of a gradient. The Parameter Server needs bandwidth $2NS$: it receives $N$ gradients and sends $N$ copies of the result.

## Tree AllReduce

Tree AllReduce eases the bandwidth problem of the Parameter Server. Instead of all units sending to one single unit, the units form a balanced binary tree. Starting from the leaves, each unit sends its (partially combined) gradient to its parent, which adds it to its own, all the way up to the root. The root then passes the fully combined gradient back down to the leaves.

Let's look at the bandwidth requirement of each unit. A leaf makes 1 send and 1 receive, so its bandwidth is $2S$. The root makes 2 receives and 2 sends, so its bandwidth is $4S$. Every other unit makes 3 receives (2 from its children on the way up, 1 from its parent on the way down) and 3 sends (1 to its parent, 2 to its children), so its bandwidth is $6S$. The maximum bandwidth of this algorithm is therefore $6S$. This is a great step forward because the bandwidth requirement per unit no longer scales with the number of computing units.

If we compare latency, i.e. the number of sequential message passes before a unit holds the final combined gradient, the Parameter Server needs only 2, whereas Tree AllReduce requires roughly $2\log_2 N$ (one pass per tree level on the way up, and again on the way down).
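To make the Tree AllReduce bandwidth argument concrete, here is a minimal single-process Python sketch that simulates the reduce and broadcast phases. It assumes (this is not stated in the text) that the units form a heap-ordered binary tree, where unit $i$ has children $2i+1$ and $2i+2$ and parent $\lfloor (i-1)/2 \rfloor$, and it counts traffic as the number of vector elements each unit sends plus receives. `tree_allreduce` and `Vector` are hypothetical names chosen for illustration.

```python
from typing import List, Tuple

Vector = List[float]


def tree_allreduce(grads: List[Vector]) -> Tuple[List[Vector], List[int]]:
    """Simulate Tree AllReduce; return (per-unit result, per-unit traffic).

    Assumption: unit i has children 2i+1, 2i+2 and parent (i-1)//2.
    Traffic counts vector elements sent plus received by each unit.
    """
    n = len(grads)
    size = len(grads[0])  # S, the gradient size
    partial = [list(g) for g in grads]
    traffic = [0] * n

    # Reduce phase: a child always has a larger index than its parent, so
    # visiting units in descending order guarantees each unit has already
    # absorbed both children before it sends its partial sum upward.
    for child in range(n - 1, 0, -1):
        parent = (child - 1) // 2
        partial[parent] = [p + c for p, c in zip(partial[parent], partial[child])]
        traffic[child] += size   # send to parent
        traffic[parent] += size  # receive from child

    # The root now holds the sum of all gradients; turn it into the mean G.
    result: List[Vector] = [[] for _ in range(n)]
    result[0] = [x / n for x in partial[0]]

    # Broadcast phase: ascending order, so a parent has G before its children.
    for child in range(1, n):
        parent = (child - 1) // 2
        result[child] = list(result[parent])
        traffic[parent] += size  # send to child
        traffic[child] += size   # receive from parent

    return result, traffic


# 7 units form a perfect binary tree; each gradient has S = 4 elements.
grads = [[float(i), 2.0 * i, 1.0, -float(i)] for i in range(7)]
result, traffic = tree_allreduce(grads)
print(result[6])  # [3.0, 6.0, 1.0, -3.0]
print(traffic)    # [16, 24, 24, 8, 8, 8, 8]
```

With 7 units and $S=4$, the counts match the analysis above: the root uses $4S=16$, the two middle units use $6S=24$, and the four leaves use $2S=8$.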
### Double Binary Tree Optimization

This technique provides a constant-factor optimization to Tree AllReduce. Notice that bandwidth usage across units is unbalanced: leaf units use $2S$ while middle units use $6S$. The double binary tree technique balances the bandwidth usage, thereby reducing the peak bandwidth.

We create 2 binary trees $A$ and $B$. In tree $A$, the even-numbered units are the leaves; in tree $B$, the odd-numbered units are the leaves. We then split the gradient $\mathbf g$ into 2 chunks:

$$\mathbf g=(\mathbf g_A\quad \mathbf g_B)$$

Tree $A$ is responsible for combining $\mathbf g_A$, while tree $B$ is responsible for $\mathbf g_B$. Each chunk has size $S/2$, so the bandwidth of each unit with this setup is:

|            | Tree $A$    | Tree $B$    | Total |
| ---------- | ----------- | ----------- | ----- |
| Even units | $2(S/2)=S$  | $6(S/2)=3S$ | $4S$  |
| Odd units  | $6(S/2)=3S$ | $2(S/2)=S$  | $4S$  |

The bandwidth is now balanced across units, and the peak bandwidth drops from $6S$ to $4S$. Double Binary Tree Optimization is implemented in [NVIDIA's NCCL](https://github.com/NVIDIA/nccl) for inter-GPU communication.

## Ring AllReduce

Ring AllReduce addresses the bandwidth problem by arranging the units in a logical ring, where each unit sends only to its successor and receives only from its predecessor. High-level idea:

1. **Setup:** Break each gradient vector into $N$ chunks: $\mathbf g_i=(\mathbf g_{i,1} \quad \mathbf g_{i,2} \quad \dots \quad \mathbf g_{i,N})$
2. **Scatter-Reduce:** Computing unit $i$ is responsible for the combined gradient of the $i$-th chunk: $\mathbf G_i = \frac{1}{N}\sum_j\mathbf g_{j, i}$
3. **All-Gather:** Each combined chunk $\mathbf G_i$ is passed around the ring until every unit holds all chunks and can assemble $\mathbf G=(\mathbf G_1\quad \mathbf G_2 \quad \dots \quad \mathbf G_N)$

Scatter-Reduce and All-Gather each take $N-1$ rounds, so each unit goes through $2(N-1)$ rounds. In each round, a unit sends one chunk and receives one chunk, each of size $S/N$, so the per-round bandwidth is $2S/N$. In total, the bandwidth is $\frac{4(N-1)}{N} S\leq 4S$.
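The three Ring AllReduce steps above can be simulated sequentially in a single process with the following minimal Python sketch. It assumes the gradient length is divisible by $N$ and that unit $i$ sends to unit $(i+1) \bmod N$, and it indexes units and chunks from 0 rather than 1. `ring_allreduce` is a hypothetical name, not a library API; traffic is again counted as vector elements sent plus received.

```python
from typing import List, Tuple

Vector = List[float]


def ring_allreduce(grads: List[Vector]) -> Tuple[List[Vector], List[int]]:
    """Simulate Ring AllReduce; return (per-unit G, per-unit traffic).

    Assumptions: len(grad) is divisible by n and unit i sends to (i+1) % n.
    """
    n = len(grads)
    size = len(grads[0])
    assert size % n == 0, "gradient size must be divisible by the unit count"
    c = size // n  # chunk size S/N
    # chunks[i][k] is unit i's current copy of chunk k.
    chunks = [[list(g[k * c:(k + 1) * c]) for k in range(n)] for g in grads]
    traffic = [0] * n

    # Scatter-Reduce: in round s, unit i forwards its partial sum of chunk
    # (i - 1 - s) mod n to its successor, which adds it to its own copy.
    # After n-1 rounds, unit i holds the full sum of chunk i.
    for s in range(n - 1):
        # Collect all messages first: every unit sends simultaneously.
        msgs = [((i - 1 - s) % n, list(chunks[i][(i - 1 - s) % n])) for i in range(n)]
        for i, (k, payload) in enumerate(msgs):
            dst = (i + 1) % n
            chunks[dst][k] = [a + b for a, b in zip(chunks[dst][k], payload)]
            traffic[i] += c
            traffic[dst] += c

    # Turn each owned sum into the mean: unit i now holds G_i.
    for i in range(n):
        chunks[i][i] = [x / n for x in chunks[i][i]]

    # All-Gather: in round s, unit i forwards the final chunk (i - s) mod n,
    # and its successor overwrites its stale copy of that chunk.
    for s in range(n - 1):
        msgs = [((i - s) % n, list(chunks[i][(i - s) % n])) for i in range(n)]
        for i, (k, payload) in enumerate(msgs):
            dst = (i + 1) % n
            chunks[dst][k] = payload
            traffic[i] += c
            traffic[dst] += c

    result = [[x for chunk in unit for x in chunk] for unit in chunks]
    return result, traffic


# N = 4 units, S = 8 elements per gradient (chunk size 2).
grads = [[float(i + k) for k in range(8)] for i in range(4)]
result, traffic = ring_allreduce(grads)
print(result[0])  # [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5]
print(traffic)    # [24, 24, 24, 24]
```

Every unit moves $\frac{4(N-1)}{N}S = \frac{4 \cdot 3}{4} \cdot 8 = 24$ elements, matching the total derived above and staying below $4S = 32$. Collecting each round's messages before applying them mimics all units sending at the same time.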
## Gossip

The mechanism of the Gossip algorithm is very simple. The user sets a bandwidth budget, and each unit repeatedly sends its gradient to randomly chosen neighbors as long as it stays within that budget. Gossip only guarantees **eventual consistency**: at any point in time, different units may hold different gradients. This makes it useful for decentralized learning over the open internet, where devices (like mobile phones) drop in and out frequently.

## Comparison

|                                     | Bandwidth         | Latency                   | Fault-tolerance |
| ----------------------------------- | ----------------- | ------------------------- | --------------- |
| Parameter Server                    | $O(N)$[^1]        | $O(1)$                    | Robust          |
| Tree AllReduce                      | $6S=O(1)$         | $O(\log N)$               | Brittle         |
| Tree AllReduce (Double Binary Tree) | $4S=O(1)$         | $O(\log N)$               | Brittle         |
| Ring AllReduce                      | $\leq 4S=O(1)$    | $O(N)$                    | Brittle         |
| Gossip                              | High (tunable)    | $O(\log N)$ Probabilistic | Robust          |

[^1]: Only on the Parameter Server itself, which is the bottleneck. Workers need $O(1)$.
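Returning to the Gossip section, the sketch below is one possible instance of it: randomized pairwise gossip averaging. The text does not say how a unit combines what it receives, so this is an assumption: whenever two units gossip, both replace their gradients with the average of the two. Further assumptions are that every unit can reach every other unit, and that `budget` (a hypothetical parameter) caps how many exchanges each unit initiates per round. Because each exchange preserves the sum of all gradients, every unit drifts toward $\mathbf G$ without ever being guaranteed to hold it exactly, which is the eventual consistency described above.

```python
import random
from typing import List

Vector = List[float]


def gossip_round(grads: List[Vector], budget: int, rng: random.Random) -> None:
    """One round of randomized pairwise gossip averaging, updated in place.

    Assumptions: each unit initiates at most `budget` exchanges per round,
    any unit can talk to any other, and both partners keep the pairwise mean.
    """
    n = len(grads)
    for i in range(n):
        for _ in range(budget):
            j = rng.randrange(n - 1)
            if j >= i:
                j += 1  # pick a random partner other than unit i itself
            avg = [(a + b) / 2 for a, b in zip(grads[i], grads[j])]
            grads[i] = list(avg)
            grads[j] = list(avg)


rng = random.Random(0)
grads = [[float(i), -2.0 * i] for i in range(16)]
target = [sum(col) / len(grads) for col in zip(*grads)]  # the exact G

for _ in range(50):
    gossip_round(grads, budget=1, rng=rng)

# Largest deviation of any unit's gradient from the exact combined gradient.
error = max(abs(x - t) for g in grads for x, t in zip(g, target))
print(f"max error after 50 rounds: {error:.2e}")
```

The printed error depends on the random seed, but it shrinks roughly geometrically with the number of rounds; a larger `budget` spends more bandwidth per round to shrink it faster.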