Background

Most TensorFlow deployments use asynchronous SGD and stop training based on global_step (a sketch of this typical setup follows the list below), which causes a couple of problems:

  • There is no guarantee that every worker trains its share of the data for the same number of epochs; faster workers will run more steps on their portion of the data.
  • The chief worker saves the model parameters when it finishes, but other workers may still be running at that point, so there is no way to save the parameters only after all workers have finished.
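For reference, the pattern being criticized looks roughly like the sketch below: every worker pushes gradient updates independently and stops once the shared global_step reaches a fixed value. This is only a minimal sketch using the standard TF 1.x MonitoredTrainingSession and StopAtStepHook APIs; loss, server, task_index and ckpt_dir are placeholders assumed to be defined elsewhere.

import tensorflow as tf

# Shared step counter that every worker increments asynchronously.
global_step = tf.train.get_or_create_global_step()
train_op = tf.train.AdamOptimizer().minimize(loss, global_step=global_step)

# Stop once the *global* step count is reached, regardless of how the
# steps are distributed across workers.
hooks = [tf.train.StopAtStepHook(last_step=1000000)]

with tf.train.MonitoredTrainingSession(master=server.target,
                                       is_chief=(task_index == 0),
                                       checkpoint_dir=ckpt_dir,  # chief saves here
                                       hooks=hooks) as mon_sess:
    while not mon_sess.should_stop():
        # Asynchronous SGD: fast workers simply contribute more steps.
        mon_sess.run(train_op)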

The ideal flow should look like this:

while epoch < max_epoch:
    train_one_epoch
    barrier
    epoch++
save checkpoint

So how can the barrier mechanism be implemented?

Implementing the Barrier Mechanism

First, the data cannot be read with a QueueRunner, because it has no notion of epoch boundaries; the input has to be read with the Dataset API instead. Then a variable counter_var is added on the PS:0 node: whenever a worker finishes an epoch of training it increments counter_var by 1 and then waits until sess.run(counter_var) % num_worker == 0 holds before starting the next epoch. The concrete code is:

import time

import tensorflow as tf

# job_name, server, param_server_device, file_name, batch_size,
# num_worker and max_epoch are assumed to be defined by the surrounding
# script (e.g. parsed from command-line flags).

def _parse_function(example_proto):
    features = {}
    features['label'] = tf.FixedLenFeature([], tf.float32)
    features['feature'] = tf.FixedLenFeature([100], tf.int64)
    instance = tf.parse_example(example_proto, features)
    label = instance['label']
    feature = instance['feature']
    return label, feature

if job_name == 'ps':
    with tf.device('/cpu:0'):
        server.join()
elif job_name == 'worker':
    with tf.device(param_server_device):

        # dataset input
        dataset = tf.data.TFRecordDataset(file_name)
        dataset = dataset.prefetch(buffer_size=batch_size * 100)
        dataset = dataset.shuffle(buffer_size=batch_size * 10)
        dataset = dataset.batch(batch_size)
        dataset = dataset.map(_parse_function, num_parallel_calls=4)
        train_iterator = dataset.make_initializable_iterator()
        train_label, train_feature = train_iterator.get_next()

        # forward pass
        model = ...
        train_logits = model.forward(train_feature)

        # loss
        train_label = tf.to_int64(train_label)
        train_cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=train_logits, labels=train_label)
        train_loss = tf.reduce_mean(train_cross_entropy, name='loss')

        # optimizer
        opt = tf.train.AdamOptimizer()
        train_op = opt.minimize(train_loss)

        # barrier counter, pinned to the first parameter server so that
        # every worker sees the same variable
        ps_device = '/job:ps/task:0/cpu:0'
        with tf.device(ps_device):
            counter_var = tf.get_variable(
                'counter_barrier', (), tf.int32, initializer=tf.zeros_initializer)
            counter_add_op = counter_var.assign_add(1, use_locking=True)

    # job process
    with tf.Session(server.target) as sess:
        # barrier function: add 1 to the shared counter, then poll until
        # every worker has checked in for the current epoch
        def barrier():
            sess.run(counter_add_op)
            while sess.run(counter_var) % num_worker != 0:
                time.sleep(1)

        # init
        sess.run([tf.global_variables_initializer(),
                  tf.local_variables_initializer()])

        # training process
        epoch_num = 0
        while epoch_num < max_epoch:
            epoch_num += 1
            sess.run(train_iterator.initializer)
            barrier()  # wait for all workers to be ready

            # train one epoch until the iterator is exhausted
            while True:
                try:
                    sess.run(train_op)
                except tf.errors.OutOfRangeError:
                    break
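For completeness, the snippet above relies on the usual TF 1.x cluster boilerplate to define job_name, server, param_server_device and num_worker. A minimal sketch of that setup might look like the following (the host names are illustrative, and job_name / task_index would normally come from command-line flags):

import tensorflow as tf

# Hypothetical cluster: one parameter server and two workers.
cluster = tf.train.ClusterSpec({
    'ps': ['ps0.example.com:2222'],
    'worker': ['worker0.example.com:2222', 'worker1.example.com:2222'],
})
num_worker = cluster.num_tasks('worker')

# job_name and task_index identify the current process within the cluster.
server = tf.train.Server(cluster, job_name=job_name, task_index=task_index)

# Variables go to the ps job, ops stay on the local worker.
param_server_device = tf.train.replica_device_setter(
    worker_device='/job:worker/task:%d' % task_index, cluster=cluster)

Using replica_device_setter keeps the model variables on the parameter servers while the per-worker ops stay local, which is consistent with pinning counter_var explicitly to /job:ps/task:0 in the barrier code.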