背景

ChatGPT出现后,已经有许多开源项目尝试复现其效果,包括LLaMa、DeepSpeed-Chat、ColossalChat、ChatGLM等。其中DeepSpeed-Chat是微软Deep Speed团队的开源项目,其完整的提供了Supervised Fine-tuning、Reward Model Training、RLHF PPO Traing三阶段的代码,逻辑简单,模块划分清晰,另外也由于Deep Speed在大模型训练中的使用非常普遍,所以笔者近期正在研究DeepSpeed-Chat的代码。之前博客中已经介绍了全部三阶段的训练实战情况:

本文以DeepSpeed-Chat的实现为例,详细介绍下RLHF——基于人类反馈的强化学习策略,并与经典Off-Policy Actor-Critic策略做对比。

符号定义

Off-Policy Advantage Actor-Critic标准范式

Off-Policy策略,在老版本参数的模型下做出动作选择和环境交互,放入样本池,在后面训练过程中可以重复利用。样本格式可以为:

标准Advantage Actor-Critic强化学习中,和Critic的是可训练模型。

其中Actor模型的Loss为:

DeepSpeed-Chat强化学习策略

DeepSpeed-Chat和ColossalChat强化学习部分的策略借鉴了TRLX开源项目。从InstructGPT论文和一些开源复现中,可以推测出ChatGPT对于step和episode的定义。每次预估下一个token是一个step,完成一个完整response是一个episode。

Reward设计

每个episode获得一个收益,由Reward Model预估得到,Reward Model相当于强化学习中的环境。并且,所有step共享episode的reward。 Reward除了Reward Model预估值外,增加了当前Actor模型与SFT模型的KL散度,保证Actor模型不要改变的太远。因为Off-Policy理论中,采样模型和最新模型接近时才有效果保障,否则需要非常多的采样样本,因此这里增加KL保障是符合理论要求的。不过这里的KL计算逻辑和严格数学定义也不太一样。

def compute_rewards(self, prompts, log_probs, ref_log_probs, reward_score, action_mask):
kl_divergence_estimate = -self.kl_ctl * (log_probs - ref_log_probs)
rewards = kl_divergence_estimate
start = prompts.shape[1] - 1 # 状态s_1在prompt最后一个token,动作a_1表示预测response的第一个token
ends = start + action_mask[:, start:].sum(1)
reward_clip = torch.clamp(reward_score, -self.clip_reward_value, self.clip_reward_value)
batch_size = log_probs.shape[0]
for j in range(batch_size):
rewards[j, start:ends[j]][-1] += reward_clip[j] # 在最后一个token加reward_score
return rewards

Advantage设计

在标准的Advantage Actor-Critic策略中,。与此不同,ChatGPT的reward加在了最后一个token,因此每一步依赖下一步,可以看到计算adv时是从后向前遍历,reward从后向前传。

def get_advantages_and_returns(self, values, rewards, start):
lastgaelam = 0
advantages_reversed = []
length = rewards.size()[-1]
for t in reversed(range(start, length)): # 反向计算
nextvalues = values[:, t + 1] if t < length - 1 else 0.0
delta = rewards[:, t] + self.gamma * nextvalues - values[:, t]
lastgaelam = delta + self.gamma * self.lam * lastgaelam
advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1) # 再反转
returns = advantages + values[:, start:] # adv(t) + value(t+1)更合理些
return advantages.detach(), returns

Actor Model

Actor模型以SFT模型初始化,其损失函数设计与标准Actor-Critic有个不同点,是PPO2策略,整体loss对原始PPO和PPO2进行了结合。 另外,代码中没有直接计算两个概率值相除,而是使用对数指数变换,应该是数值稳定性考虑。

def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):
log_ratio = (logprobs - old_logprobs) * mask
ratio = torch.exp(log_ratio)
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange, 1.0 + self.cliprange)
pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()
return pg_loss

Critic Model

Critic模型以Reward Model初始化,其损失函数为模型预估值与回报的差别,为平方损失的回归任务。这里比较奇怪的一点是为何如此设计,感觉更合理一些。

def critic_loss_fn(self, values, old_values, returns, mask):
values_clipped = torch.clamp(values, old_values - self.cliprange_value, old_values + self.cliprange_value)
vf_loss1 = (values - returns)**2
vf_loss2 = (values_clipped - returns)**2
vf_loss = 0.5 * torch.sum(torch.max(vf_loss1, vf_loss2) * mask) / mask.sum()
return vf_loss

DeepSpeed-Chat强化学习训练逻辑

训练逻辑是Off-Policy策略,外层循环读取prompt数据生成prompt+response数据放入样本池,内层循环从样本池中读取prompt+response数据进行Actor Model和Critic Model的训练。性能上,SFT模型放到CPU上,Actor模型通过DeepSpeed Hybrid Engine支持训练和推理两种模式的高效切换。 另外,Instruct论文中在Actor Loss中增加了一个SFT Loss和一个Unsupervised Loss,两个Loss也加到之前的Actor Loss上。

最终的Actor Loss为:

其中,SFT Loss部分保证和Actor模型和SFT模型偏离不远,Unsupervised Loss部分增加了一个自回归任务,整体Loss计算梯度做模型更新,而DeepSpeed-Chat只使用了Unsupervised,没有增加SFT部分(在reward计算时使用了,间接引入),并且先用Actor Loss更新,再用Unsupervised Loss更新。Actor模型参数都采用了Exponential Moving Averages策略。

def generate_experience(self, prompts, mask):
self.eval()
seq = self._generate_sequence(prompts, mask) # [batch_size, prompt_response_max_len]
self.train()

pad_token_id = self.tokenizer.pad_token_id
attention_mask = seq.not_equal(pad_token_id).long()

with torch.no_grad():
output = self.actor_model(seq, attention_mask=attention_mask)
output_ref = self.ref_model(seq, attention_mask=attention_mask)
reward_score = self.reward_model.forward_value(seq, attention_mask, prompt_length=self.prompt_length)['chosen_end_scores'].detach()
values = self.critic_model.forward_value(seq, attention_mask, return_value_only=True).detach()[:, :-1]

logits = output.logits
logits_ref = output_ref.logits

return {
'prompts': prompts, # [batch_size, prompt_max_len]
'logprobs': gather_log_probs(logits[:, :-1, :], seq[:, 1:]), # [batch_size, prompt_response_max_len]
'ref_logprobs': gather_log_probs(logits_ref[:, :-1, :], seq[:, 1:]), # [batch_size, prompt_response_max_len]
'value': values, # [batch_size, prompt_response_max_len]
'rewards': reward_score, # [batch_size, 1]
'input_ids': seq, # [batch_size, prompt_response_max_len]
"attention_mask": attention_mask # [batch_size, prompt_response_max_len]
}

def train_rlhf(self, inputs):
### process the old outputs
prompts = inputs['prompts']
log_probs = inputs['logprobs']
ref_log_probs = inputs['ref_logprobs']
reward_score = inputs['rewards']
values = inputs['value']
attention_mask = inputs['attention_mask']
seq = inputs['input_ids']

start = prompts.size()[-1] - 1
action_mask = attention_mask[:, 1:]

old_values = values
with torch.no_grad():
old_rewards = self.compute_rewards(prompts, log_probs, ref_log_probs, reward_score, action_mask)
advantages, returns = self.get_advantages_and_returns(old_values, old_rewards, start)

### process the new outputs
batch = {'input_ids': seq, "attention_mask": attention_mask}
actor_prob = self.actor_model(**batch, use_cache=False).logits
actor_log_prob = gather_log_probs(actor_prob[:, :-1, :], seq[:, 1:])
actor_loss = self.actor_loss_fn(actor_log_prob[:, start:], log_probs[:, start:], advantages, action_mask[:, start:])
self.actor_model.backward(actor_loss)
self.actor_model.step()
value = self.critic_model.forward_value(**batch, return_value_only=True, use_cache=False)[:, :-1]
critic_loss = self.critic_loss_fn(value[:, start:], old_values[:, start:], returns, action_mask[:, start:])
self.critic_model.backward(critic_loss)
self.critic_model.step()

return actor_loss, critic_loss

参考