|
| 1 | +# Design for Horovod Based AllReduce |
| 2 | + |
| 3 | +This document describes the design for supporting AllReduce-based distributed |
| 4 | +training based on [Horovod](https://github.com/horovod/horovod) in ElasticDL. |
| 5 | + |
| 6 | +## Motivation |
| 7 | + |
| 8 | +We have developed elastic AllReduce based on FTlib in ElasticDL. |
| 9 | +From the [benchmark report](../benchmark/ftlib_benchmark.md), we can |
| 10 | +find that the performance of FTlib for ResNet50 is worse than Horovod. |
| 11 | + |
| 12 | +FTlib uses the gossip protocol to build consensus, which is not stable |
| 13 | +as described in this [issue](https://github.com/sql-machine-learning/elasticdl/issues/2192#issuecomment-664096185). |
| 14 | + |
| 15 | +FTlib uses Gloo to implement elastic AllReduce because the worker process |
| 16 | +can catch the exception from Gloo and not exit if any collective communications |
| 17 | +fail. Horovod also can use Gloo as the backend. |
| 18 | +There are many small parameter tensors in the ResNet50 model. We have to |
| 19 | +launch an AllReduce operator to synchronize each tensor. It brings a lot |
| 20 | +of overhead. There are many optimizations like [Tensor Fusion](https://horovod.readthedocs.io/en/latest/tensor-fusion_include.html) |
| 21 | +in Horovod to reduce the overhead. So, the performance of Horovod for ResNet50 |
| 22 | +is better than FTlib. |
| 23 | + |
| 24 | +On the Kubernetes cluster, we usually use [kubeflow/mpi-operator](https://github.com/kubeflow/mpi-operator) |
| 25 | +to submit a Horovod job. kubeflow/mpi-operator is not fault-tolerant and |
| 26 | +elastic. Horovod has supported [elastic training](https://horovod.readthedocs.io/en/latest/elastic_include.html) |
| 27 | +which can scale up and down the number of workers dynamically at runtime. |
| 28 | +Elastic Horovod needs a shell script to discover worker hosts on the cluster. |
| 29 | +However, it is difficult for users to use Kubernetes API to discover |
| 30 | +worker pod hosts. What's more, data access is a problem for data parallel |
| 31 | +training it the number of workers changes. Random sampling is a solution |
| 32 | +that may affect the training accuracy. There is a master process in ElasticDL. |
| 33 | +The master can get all worker hosts by Kubernetes API and dynamically |
| 34 | +assign data shards for workers to solve data access for elastic training. |
| 35 | +So, it is more user-friendly to run an elastic AllReduce-based training |
| 36 | +job using ElasticDL with Horovod. |
| 37 | + |
| 38 | +## ElasticDL Re-initialize Horovod When the Number of Workers Changes |
| 39 | + |
| 40 | +When using Horovod with Gloo backend, we need to create a `RendezvousServer`, |
| 41 | +which has a KVStore. Also, we need to put a host plan into this KVStore. |
| 42 | +A host plan includes worker hosts and their assigned ranks, |
| 43 | +which are required by Gloo. The master in ElasticDL is responsible for creating |
| 44 | +the RendezvousServer. To support elastic training, when the master detects |
| 45 | +the number of workers changes, it will create a new host plan and put it |
| 46 | +in the KVStore. |
| 47 | + |
| 48 | +```python |
| 49 | +import horovod |
| 50 | +from horovod.run.http.http_server import RendezvousServer |
| 51 | +from horovod.runner.common.util.hosts import get_host_assignments |
| 52 | + |
| 53 | +hosts = get_worker_hosts() |
| 54 | + |
| 55 | +host_alloc_plan = get_host_assignments(hosts, num_proc) |
| 56 | +global_rendezv_port = rendezvous.start() |
| 57 | + |
| 58 | +# Set hosts into KVStore for Gloo |
| 59 | +rendezvous.init(host_alloc_plan) |
| 60 | +``` |
| 61 | + |
| 62 | +Then, the worker can call `hvd.init` to initialize the Gloo context for |
| 63 | +AllReduce. |
| 64 | + |
| 65 | +When the master finds the number of workers changes, it can re-create a new |
| 66 | +`RendezvousServer` and notify workers to re-initialize Horovod. |
| 67 | +In the Kubernetes cluster, the number of workers may change for the |
| 68 | +following reasons: |
| 69 | + |
| 70 | +1. Some workers fail because of preemption. |
| 71 | +1. A worker pod status becomes running. |
| 72 | + |
| 73 | +In the first case, the Horovod AllReduce operator will raise an exception |
| 74 | +and the worker can catch the exception and re-initialize. |
| 75 | + |
| 76 | +In the second case, the worker will query the master periodically to see |
| 77 | +if there are new workers and re-initialization of the AllReduce process |
| 78 | +the group is needed. |
| 79 | + |
| 80 | +## The Worker Averages Gradients Using Horovod |
| 81 | + |
| 82 | +Using TensorFlow eager execution, we can use `hvd.DistributedGradientTape` |
| 83 | +to wrap `tf.GradientTape` to average gradients. |
| 84 | + |
| 85 | +```python |
| 86 | +@tf.function |
| 87 | +def training_process_with_horovod(self, features, labels): |
| 88 | + with tf.GradientTape() as tape: |
| 89 | + outputs = self._model.call(features, training=True) |
| 90 | + loss = self._loss(labels, outputs) |
| 91 | + tape = hvd.DistributedGradientTape(tape) |
| 92 | + grads = tape.gradient(loss, mnist_model.trainable_variables) |
| 93 | + return loss, grads |
| 94 | +``` |
| 95 | + |
| 96 | +If some workers fail, the `hvd.DistributedGradientTape` will raise |
| 97 | +a `tensorflow.python.framework.errors_impl.UnknownError`. We can catch |
| 98 | +the error and re-initialize the Horovod context if the error contains |
| 99 | +`HorovodAllreduce`, `HorovodAllgather`, or `HorovodBroadcast`. |
| 100 | + |
| 101 | +```python |
| 102 | +def training_process_horovod_fault_tolerance(self, freature, labels) |
| 103 | + from tensorflow.python.framework.errors_impl import UnknownError |
| 104 | + initialize_horovod = False |
| 105 | + |
| 106 | + hosts_update = query_worker_hosts_updated(master) |
| 107 | + if hosts_updated: |
| 108 | + initialize_horovod = True |
| 109 | + |
| 110 | + if not initialize_horovod: |
| 111 | + try: |
| 112 | + loss, grads = self.training_process_with_horovod(features, labels) |
| 113 | + except UnknownError as e: |
| 114 | + if ('HorovodAllreduce' in e.message or |
| 115 | + 'HorovodAllgather' in e.message or |
| 116 | + 'HorovodBroadcast' in e.message): |
| 117 | + initialize_horovod = True |
| 118 | + |
| 119 | + if initialize_horovod: |
| 120 | + hvd.shutdown() |
| 121 | + hvd.init() |
| 122 | +``` |
| 123 | + |
| 124 | +After initializing Horovod, we should broadcast the model in alive workers to |
| 125 | +the new workers. The master can assign rank 0 to the oldest worker, as it will |
| 126 | +be used as the broadcast source to synchronize models among workers. |
| 127 | + |
| 128 | +```python |
| 129 | +from horovod.tensorflow.functions import broadcast_variables |
| 130 | +def _broadcast_model(model, optimizer, backend): |
| 131 | + broadcast_variables(model.variables, root_rank=0) |
| 132 | + broadcast_variables(optimizer.variables(), root_rank=0) |
| 133 | +``` |
0 commit comments