Background

Synchronization is important in distributed machine-learning frameworks such as TensorFlow. Typical TensorFlow scenarios that need it include:

  • After the chief worker finishes an epoch, it must wait until all workers have finished before saving the model.
  • BSP-style SGD training needs a synchronization point after every batch.

Without synchronization, the following problems can occur:

  • Most TensorFlow deployments use asynchronous SGD with global_step as the stopping condition, which cannot guarantee that every worker trains its data shard for the same number of epochs: faster workers get more steps on the data they are responsible for.
  • The chief worker saves the model parameters when it finishes, but other workers may still be running, so the saved model is not fully trained.

Ideally the training loop should look like this:

epoch = 0
while epoch < max_epoch:
    train_one_epoch   # run one epoch of data
    barrier           # block here until every worker has finished the epoch
    save checkpoint   # save this epoch's model
    do evaluation     # run the validation set
    epoch++           # move on to the next epoch

So how do we implement the barrier?

Barrier Mechanism Implementation

The idea: on the PS:0 node, create a group of counter variables (counter_vars), one per worker, all initialized to 0. When a worker finishes an epoch, it increments the counter_var at its own worker_index by 1, then polls every worker's counter_var in turn; once all counters have reached the count expected for this epoch (1 after the first epoch, 2 after the second, and so on), every worker has completed the epoch and training can proceed to the next one. A barrier can of course be reused for any other synchronization point as well, for example at exit: wait until all workers have finished, save the model, then quit.

import time

import tensorflow as tf


class Barrier(object):
    def __init__(self, worker_num, barrier_num, sleep_time_ms=10):
        self._worker_num = worker_num
        self._barrier_num = barrier_num
        self._sleep_time_ms = sleep_time_ms
        self._counter_vars = []
        self._counter_add_ops = []
        self._counter_reset_ops = []
        ps_device = '/job:ps/task:0/cpu:0'
        with tf.device(ps_device):
            for i in range(self._barrier_num):
                for j in range(self._worker_num):
                    counter_var = tf.get_variable(
                        'counter-{}_{}'.format(i, j),
                        (),
                        tf.int32,
                        initializer=tf.zeros_initializer())
                    self._counter_vars.append(counter_var)
                    # use_locking makes the increments safe under concurrent updates
                    self._counter_add_ops.append(
                        counter_var.assign_add(1, use_locking=True))
                    self._counter_reset_ops.append(
                        counter_var.assign(0, use_locking=True))

    def barrier_reset(self, session, worker_index, barrier_index):
        index = barrier_index * self._worker_num + worker_index
        session.run(self._counter_reset_ops[index])

    def barrier(self, session, worker_index, barrier_index, epoch):
        # announce our own arrival first, then wait for everyone else
        session.run(
            self._counter_add_ops[barrier_index * self._worker_num + worker_index])
        for task_index in range(self._worker_num):
            index = barrier_index * self._worker_num + task_index
            count = session.run(self._counter_vars[index])
            retry_num = 0
            # after `epoch` completed rounds, every counter must reach epoch + 1
            while count < epoch + 1:
                time.sleep(self._sleep_time_ms / 1000.0)  # time.sleep takes seconds
                retry_num += 1
                if retry_num == 1:
                    tf.logging.info(
                        'worker {} waiting for worker {} at barrier {}'.format(
                            worker_index, task_index, barrier_index))
                count = session.run(self._counter_vars[index])
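The counter-based protocol itself is independent of TensorFlow. As a sanity check, here is a minimal pure-Python sketch of the same idea, using threads in place of workers and a plain list in place of the PS-hosted variables. The names `counters`, `barrier_wait`, `passed_order`, and `NUM_WORKERS` are illustrative, not from the original code:

```python
import threading
import time

NUM_WORKERS = 4
counters = [0] * NUM_WORKERS   # plays the role of counter_vars on PS:0
lock = threading.Lock()        # plays the role of use_locking=True
passed_order = []              # records which workers made it past the barrier

def barrier_wait(worker_index, epoch):
    # announce arrival, then poll until every counter reaches epoch + 1
    with lock:
        counters[worker_index] += 1
    for task_index in range(NUM_WORKERS):
        while counters[task_index] < epoch + 1:
            time.sleep(0.001)
    passed_order.append(worker_index)

def worker(worker_index):
    time.sleep(0.01 * worker_index)  # simulate workers finishing at different times
    barrier_wait(worker_index, epoch=0)

threads = [threading.Thread(target=worker, args=(i,)) for i in range(NUM_WORKERS)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# no worker can pass the barrier before all four have arrived
print(sorted(passed_order))  # → [0, 1, 2, 3]
```

Even though the workers arrive at staggered times, none of them leaves `barrier_wait` until every counter has reached the target, which is exactly the behavior the TensorFlow version implements with variables on the parameter server.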

Training Code

First, QueueRunner cannot be used to read the data, because it has no notion of epoch boundaries; use the Dataset API instead, so that each worker knows when it has read one full epoch of data.
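The reason an initializable iterator can mark an epoch boundary is that it raises tf.errors.OutOfRangeError once the data is exhausted, much like an ordinary Python iterator raises StopIteration. A minimal pure-Python analogy of the inner training loop (the names `make_epoch_iterator`, `records`, and `steps_per_epoch` are illustrative):

```python
records = [1, 2, 3]  # stand-in for one epoch of TFRecord batches

def make_epoch_iterator():
    # stand-in for sess.run(train_iterator.initializer): a fresh pass over the data
    return iter(records)

steps_per_epoch = []
for epoch in range(2):
    it = make_epoch_iterator()
    steps = 0
    while True:
        try:
            next(it)           # stand-in for sess.run(train_op)
            steps += 1
        except StopIteration:  # stand-in for tf.errors.OutOfRangeError
            break
    steps_per_epoch.append(steps)

print(steps_per_epoch)  # → [3, 3]
```

Each pass over the iterator ends with an exception rather than a sentinel value, so the worker gets an unambiguous "epoch finished" signal and can hit the barrier at exactly that point.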

def _parse_function(example_proto):
    features = {
        'label': tf.FixedLenFeature([], tf.float32),
        'feature': tf.FixedLenFeature([100], tf.int64),
    }
    # parse_example expects a batch, so this map runs after dataset.batch()
    instance = tf.parse_example(example_proto, features)
    label = instance['label']
    feature = instance['feature']
    return label, feature

if job_name == 'ps':
    with tf.device('/cpu:0'):
        server.join()
elif job_name == 'worker':
    with tf.device(param_server_device):
        # dataset input
        dataset = tf.data.TFRecordDataset(file_name)
        dataset = dataset.prefetch(buffer_size=batch_size * 100)
        dataset = dataset.shuffle(buffer_size=batch_size * 10)
        dataset = dataset.batch(batch_size)
        dataset = dataset.map(_parse_function, num_parallel_calls=4)
        train_iterator = dataset.make_initializable_iterator()
        train_label, train_feature = train_iterator.get_next()

        # forward pass
        model = ...
        train_logits = model.forward(train_feature)

        # loss
        train_label = tf.to_int64(train_label)
        train_cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=train_logits, labels=train_label)
        train_loss = tf.reduce_mean(train_cross_entropy, name='loss')

        # optimizer
        opt = tf.train.AdamOptimizer()
        train_op = opt.minimize(train_loss)

        # two synchronization points below, hence barrier_num=2
        barrier_op = barrier.Barrier(self.num_worker, 2)

    # job process: connect the session to this worker's in-process server
    with tf.Session(server.target) as sess:
        # init
        sess.run([tf.global_variables_initializer(),
                  tf.local_variables_initializer()])

        # training process
        epoch_num = 0
        # barrier 0: wait until every worker has started before training begins
        barrier_op.barrier(sess, self.task_index, 0, epoch_num)
        while epoch_num < max_epoch:
            sess.run(train_iterator.initializer)  # re-initialize the data each epoch
            while True:
                try:
                    sess.run(train_op)
                except tf.errors.OutOfRangeError:
                    break  # this epoch's data is exhausted
            # barrier 1: wait for all workers to finish this epoch
            barrier_op.barrier(sess, self.task_index, 1, epoch_num)
            # save this epoch's checkpoint here
            epoch_num += 1  # move on to the next epoch
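To make the counter arithmetic concrete, here is a small sequential simulation (no TF, no threads; `NUM_WORKERS`, `MAX_EPOCH`, and `history` are illustrative names) showing that after epoch e every counter at a given barrier index equals e + 1, which is exactly the count each worker polls for:

```python
# simulate the barrier-1 counters across two epochs for three workers
NUM_WORKERS = 3
MAX_EPOCH = 2
counters = [0] * NUM_WORKERS  # all counters start at 0

history = []
for epoch in range(MAX_EPOCH):
    # each worker finishes the epoch and increments its own counter
    for worker_index in range(NUM_WORKERS):
        counters[worker_index] += 1
    # the release condition: every counter has reached epoch + 1
    assert all(c == epoch + 1 for c in counters)
    history.append(list(counters))

print(history)  # → [[1, 1, 1], [2, 2, 2]]
```

Because the counters are never reset between epochs, the target count grows with the epoch number; barrier_reset is only needed if the same barrier index is recycled for an unrelated synchronization point.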