推荐系统时长建模的常用方案
背景
介绍几种工业界比较常用的时长建模方案。
Weighted Logloss
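Weighted Logloss是YouTube论文1提出的一种方法。将时长作为正样本(有观看的样本)的权重,负样本的权重为1,训练一个加权的二分类模型:$L = -\sum_{i}\left[y_i T_i \log p_i + (1 - y_i)\log(1 - p_i)\right]$。设共有 $N$ 个样本,其中 $k$ 个正样本,则模型学到的 odds 为 $\frac{\sum_{i=1}^{k} T_i}{N - k} = \frac{E[T]}{1 - P} \approx E[T](1 + P)$,其中 $P = k/N$ 为正样本率;当 $P$ 很小时,odds 近似等于期望时长 $E[T]$。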
实际中一般将odds分母中的 $N - k$ 换成 $N$:把每个样本都额外作为一个权重为1的负样本,这样 odds 恰好为 $\frac{\sum_{i} T_i}{N} = E[T]$,不再依赖“$P$ 很小”这一近似。
线上预估时,需要将预估的概率 $p$ 转换成 odds 作为预估时长:$\hat{T} = \frac{p}{1 - p} = e^{Wx + b}$。
时长归一化
长视频的消费时长往往更高,所以Weighted Logloss中长视频的比重更大,更倾向于优化长视频的效果。所以,为了消除视频本身长短带来的偏差,可以先将训练样本按video duration分桶,比如0-1分钟的一档,1-5分钟的一档等等,然后每档内部按消费时长分段,最后将消费时长归一化到0到1内的值,用这个归一化后的值做label进行训练。
最优解
这里如果用Weighted Logloss是否可以呢?
那么,
另外,时长归一化在后面的几种方法里都可以叠加使用。
# Watch-time transform logic for bucketed normalization:
import bisect


class TimeTrans(object):
    """Map a watch time to a label in [0, 1] (and back) via per-duration-bucket cut points.

    A video is first assigned to a duration bucket using ``videoDurationTable``.
    That bucket's row of ``timePercentileTable`` holds ascending watch-time cut
    points ``t_0 < t_1 < ... < t_{n-1}`` (n = 50 here).  A watch time between
    ``t_i`` and ``t_{i+1}`` is mapped linearly onto ``[i / n, (i + 1) / n)``;
    the last segment runs from ``t_{n-1}`` up to ``maxReadTime * videoTime``
    and maps onto ``[(n - 1) / n, 1.0]``.
    """

    def __init__(self):
        # Lower edges of the video-duration buckets; a video belongs to the
        # last bucket whose edge is <= videoTime.
        self.videoDurationTable = [0.0, 17.03, 23.7, 31.33, 36.07, 42.25, 49.86, 61.5, 85.72, 142.83]
        # One row per duration bucket: ascending watch-time cut points
        # (presumably watch-time percentiles of that bucket -- confirm).
        self.timePercentileTable = [
            [0.0, 3.24, 3.49, 3.76, 4.04, 4.34, 4.66, 4.98, 5.33, 5.69, 6.06, 6.45, 6.84, 7.24, 7.65, 8.06, 8.48, 8.9, 9.32, 9.73, 10.15, 10.57, 10.98, 11.35, 11.71, 12.04, 12.39, 12.74, 13.09, 13.44, 13.79, 14.13, 14.47, 14.82, 15.17, 15.53, 15.89, 16.22, 16.55, 16.87, 17.21, 17.56, 17.94, 18.48, 19.34, 20.54, 22.06, 24.17, 27.17, 31.48],
            [0.0, 3.25, 3.53, 3.83, 4.15, 4.5, 4.88, 5.28, 5.72, 6.19, 6.68, 7.19, 7.73, 8.29, 8.88, 9.47, 10.08, 10.71, 11.33, 11.94, 12.57, 13.22, 13.86, 14.48, 15.08, 15.69, 16.3, 16.89, 17.43, 17.93, 18.31, 18.65, 18.98, 19.31, 19.66, 20.03, 20.42, 20.82, 21.25, 21.7, 22.15, 22.63, 23.16, 23.75, 24.43, 25.39, 27.0, 29.48, 33.24, 39.08],
            [0.0, 3.3, 3.64, 4.0, 4.4, 4.85, 5.35, 5.89, 6.46, 7.08, 7.74, 8.48, 9.31, 10.16, 11.06, 12.05, 13.05, 14.06, 15.08, 16.08, 17.08, 18.13, 19.2, 20.27, 21.32, 22.27, 23.15, 23.93, 24.62, 25.19, 25.71, 26.22, 26.73, 27.25, 27.75, 28.25, 28.74, 29.27, 29.8, 30.35, 30.86, 31.29, 31.68, 32.05, 32.51, 33.34, 34.77, 37.09, 41.32, 49.85],
            [0.0, 3.31, 3.63, 3.99, 4.38, 4.83, 5.32, 5.85, 6.44, 7.13, 7.86, 8.66, 9.53, 10.49, 11.53, 12.65, 13.85, 15.11, 16.41, 17.68, 18.86, 20.01, 21.27, 22.61, 23.98, 25.42, 26.86, 28.26, 29.44, 30.49, 31.38, 32.06, 32.53, 32.9, 33.24, 33.57, 33.9, 34.24, 34.59, 34.95, 35.31, 35.67, 36.07, 36.49, 36.94, 37.58, 38.74, 40.76, 44.55, 53.51],
            [0.0, 3.3, 3.63, 3.99, 4.4, 4.85, 5.34, 5.91, 6.55, 7.26, 8.04, 8.93, 9.89, 10.95, 12.1, 13.36, 14.71, 16.13, 17.65, 19.26, 20.74, 22.08, 23.42, 24.88, 26.45, 28.06, 29.69, 31.36, 32.91, 34.3, 35.47, 36.44, 37.17, 37.66, 38.06, 38.45, 38.86, 39.26, 39.67, 40.1, 40.55, 41.01, 41.45, 41.92, 42.45, 43.07, 43.99, 45.91, 49.7, 58.93],
            [0.0, 3.31, 3.66, 4.05, 4.48, 4.95, 5.49, 6.09, 6.76, 7.51, 8.31, 9.19, 10.19, 11.33, 12.55, 13.9, 15.38, 17.02, 18.79, 20.53, 22.37, 24.08, 25.67, 27.27, 29.08, 30.94, 32.92, 34.98, 37.05, 38.98, 40.67, 42.04, 43.06, 43.69, 44.22, 44.73, 45.23, 45.74, 46.23, 46.68, 47.15, 47.65, 48.15, 48.67, 49.26, 49.92, 50.66, 52.08, 55.46, 64.44],
            [0.0, 3.32, 3.68, 4.08, 4.53, 5.04, 5.61, 6.26, 7.0, 7.81, 8.73, 9.77, 10.95, 12.23, 13.64, 15.17, 16.85, 18.65, 20.67, 22.78, 25.1, 27.37, 29.4, 31.43, 33.41, 35.59, 37.93, 40.3, 42.77, 45.24, 47.51, 49.34, 50.78, 51.69, 52.44, 53.15, 53.83, 54.53, 55.23, 55.96, 56.68, 57.43, 58.21, 59.03, 59.93, 60.88, 61.85, 63.03, 66.14, 75.23],
            [0.0, 3.3, 3.64, 4.01, 4.43, 4.9, 5.43, 6.03, 6.71, 7.45, 8.3, 9.25, 10.34, 11.59, 12.99, 14.53, 16.27, 18.22, 20.3, 22.63, 25.02, 27.66, 30.41, 33.04, 35.62, 38.13, 40.73, 43.38, 46.17, 49.16, 52.42, 55.62, 58.59, 61.22, 63.15, 64.39, 65.57, 66.68, 67.79, 69.01, 70.32, 71.94, 73.74, 75.51, 77.27, 79.06, 81.03, 83.48, 86.07, 93.3],
            [0.0, 3.28, 3.6, 3.97, 4.38, 4.85, 5.38, 5.98, 6.64, 7.42, 8.29, 9.28, 10.41, 11.71, 13.21, 14.9, 16.82, 19.05, 21.56, 24.21, 27.18, 30.42, 33.93, 37.56, 41.35, 45.3, 49.08, 52.89, 56.73, 60.71, 64.86, 69.1, 73.47, 78.19, 82.89, 86.82, 89.24, 91.37, 93.9, 96.49, 99.29, 102.51, 106.25, 109.97, 114.23, 119.12, 123.95, 129.03, 135.51, 142.95],
            [0.0, 3.3, 3.64, 4.03, 4.47, 4.97, 5.54, 6.19, 6.96, 7.85, 8.85, 10.02, 11.35, 12.91, 14.73, 16.85, 19.3, 22.08, 25.26, 28.72, 32.59, 36.7, 41.39, 46.35, 51.64, 57.39, 63.42, 69.59, 76.3, 83.1, 90.2, 97.96, 106.55, 115.28, 124.49, 134.38, 144.24, 150.26, 157.35, 165.06, 173.25, 182.13, 192.08, 202.93, 215.55, 229.49, 243.75, 266.22, 293.19, 338.96],
        ]
        self.maxReadTime = 3  # watch time may not exceed 3x the video duration

    def getIdx(self, percentileList, score):
        """Return the index of the last element of ascending ``percentileList`` that is <= ``score``.

        Returns ``len(percentileList) - 1`` when ``score`` is at or above the
        last element, and ``-1`` when it is below the first one.
        """
        return bisect.bisect_right(percentileList, score) - 1

    def encode(self, videoTime, readTime):
        """Encode a watch time as a label in [0, 1].

        Non-positive inputs are treated as a (near-)zero watch time, and
        ``readTime`` is capped at ``maxReadTime * videoTime`` before encoding.
        """
        if videoTime <= 0 or readTime <= 0:
            videoTime = 1e-8
            readTime = 1e-8
        elif readTime > self.maxReadTime * videoTime:
            readTime = self.maxReadTime * videoTime
        durationBuk = self.getIdx(self.videoDurationTable, videoTime)
        timePercentileList = self.timePercentileTable[durationBuk]
        bukIdx = self.getIdx(timePercentileList, readTime)
        if bukIdx < 0:
            return 0.0
        numBuckets = len(timePercentileList)
        lowerLabel = 1.0 / numBuckets * bukIdx
        lowerTime = timePercentileList[bukIdx]
        if bukIdx == numBuckets - 1:
            # Last segment: interpolate from the final cut point up to the cap.
            upperLabel = 1.0
            upperTime = self.maxReadTime * videoTime
        else:
            upperLabel = 1.0 / numBuckets * (bukIdx + 1)
            upperTime = timePercentileList[bukIdx + 1]
        span = upperTime - lowerTime
        if span <= 0:
            # Degenerate last segment (cap == last cut point == readTime):
            # avoid dividing by zero and return the segment's lower label.
            return lowerLabel
        return lowerLabel + (readTime - lowerTime) * (upperLabel - lowerLabel) / span

    def decode(self, videoTime, pScore):
        """Invert ``encode``: map a predicted label back to a watch time.

        Scores <= 0 (or a non-positive ``videoTime``) decode to 0.0; scores
        >= 1 decode to the cap ``maxReadTime * videoTime``.  The result never
        exceeds that cap.
        """
        if videoTime <= 0 or pScore <= 0:
            return 0.0
        maxTime = self.maxReadTime * videoTime
        if pScore >= 1:
            return maxTime
        durationBuk = self.getIdx(self.videoDurationTable, videoTime)
        timePercentileList = self.timePercentileTable[durationBuk]
        numBuckets = len(timePercentileList)
        # Guard against float rounding pushing pScore * n up to n.
        bukIdx = min(int(pScore * numBuckets), numBuckets - 1)
        lowerLabel = 1.0 / numBuckets * bukIdx
        lowerTime = timePercentileList[bukIdx]
        if bukIdx == numBuckets - 1:
            upperLabel = 1.0
            if maxTime <= lowerTime:
                # Cap is at or below the last cut point: nothing to interpolate.
                return maxTime
            upperTime = maxTime
        else:
            upperLabel = 1.0 / numBuckets * (bukIdx + 1)
            upperTime = timePercentileList[bukIdx + 1]
        readTime = lowerTime + (pScore - lowerLabel) * (upperTime - lowerTime) / (upperLabel - lowerLabel)
        return min(readTime, maxTime)
# Demo: round-trip readTime=8 for videoTime=16 through encode/decode.
timeTrans = TimeTrans()
videoTime = 16
readTime = 8
timeLabel = timeTrans.encode(videoTime, readTime)
print(timeLabel)
readTimeRecovery = timeTrans.decode(videoTime, timeLabel)
print(readTimeRecovery)
CREAD
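CREAD是快手的论文2中提出的方法。将消费时长从小到大排序,取分位点作为切分点,比如是 $t_0 < t_1 < \dots < t_{m-1}$(下面代码中 $m = 50$),对每个切分点 $t_i$ 训练一个二分类任务,label 为 $\mathbb{1}(T > t_i)$,预估值 $\hat{p}_i$ 表示 $P(T > t_i)$。线上预估时,用 $\hat{T} = t_0 + \sum_{i=1}^{m-1}(t_i - t_{i-1})\,\hat{p}_i$ 还原时长。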
实际上,训练时除了m个分类任务,还增加了两个其他损失:一个是还原时长与真实时长的偏差(采用Huber loss),另一个是约束各分类器的预估概率随切分点单调递减(即 $\hat{p}_i \ge \hat{p}_{i+1}$)的序约束损失。
# CREAD watch-time transform logic:
import bisect


class TimeTrans(object):
    """CREAD-style watch-time transform: one binary label per watch-time cut point.

    A video is first assigned to a duration bucket using ``videoDurationTable``;
    that bucket's row of ``timePercentileTable`` holds ascending watch-time cut
    points ``t_0 < ... < t_{n-1}``.  ``encode`` emits ``1.0`` for every cut
    point strictly below the watch time, and ``decode`` restores a watch time
    from predicted per-cut-point probabilities.
    """

    def __init__(self):
        # Lower edges of the video-duration buckets; a video belongs to the
        # last bucket whose edge is <= videoTime.
        self.videoDurationTable = [0.0, 17.03, 23.7, 31.33, 36.07, 42.25, 49.86, 61.5, 85.72, 142.83]
        # One row per duration bucket: ascending watch-time cut points
        # (presumably watch-time percentiles of that bucket -- confirm).
        self.timePercentileTable = [
            [0.0, 3.24, 3.49, 3.76, 4.04, 4.34, 4.66, 4.98, 5.33, 5.69, 6.06, 6.45, 6.84, 7.24, 7.65, 8.06, 8.48, 8.9, 9.32, 9.73, 10.15, 10.57, 10.98, 11.35, 11.71, 12.04, 12.39, 12.74, 13.09, 13.44, 13.79, 14.13, 14.47, 14.82, 15.17, 15.53, 15.89, 16.22, 16.55, 16.87, 17.21, 17.56, 17.94, 18.48, 19.34, 20.54, 22.06, 24.17, 27.17, 31.48],
            [0.0, 3.25, 3.53, 3.83, 4.15, 4.5, 4.88, 5.28, 5.72, 6.19, 6.68, 7.19, 7.73, 8.29, 8.88, 9.47, 10.08, 10.71, 11.33, 11.94, 12.57, 13.22, 13.86, 14.48, 15.08, 15.69, 16.3, 16.89, 17.43, 17.93, 18.31, 18.65, 18.98, 19.31, 19.66, 20.03, 20.42, 20.82, 21.25, 21.7, 22.15, 22.63, 23.16, 23.75, 24.43, 25.39, 27.0, 29.48, 33.24, 39.08],
            [0.0, 3.3, 3.64, 4.0, 4.4, 4.85, 5.35, 5.89, 6.46, 7.08, 7.74, 8.48, 9.31, 10.16, 11.06, 12.05, 13.05, 14.06, 15.08, 16.08, 17.08, 18.13, 19.2, 20.27, 21.32, 22.27, 23.15, 23.93, 24.62, 25.19, 25.71, 26.22, 26.73, 27.25, 27.75, 28.25, 28.74, 29.27, 29.8, 30.35, 30.86, 31.29, 31.68, 32.05, 32.51, 33.34, 34.77, 37.09, 41.32, 49.85],
            [0.0, 3.31, 3.63, 3.99, 4.38, 4.83, 5.32, 5.85, 6.44, 7.13, 7.86, 8.66, 9.53, 10.49, 11.53, 12.65, 13.85, 15.11, 16.41, 17.68, 18.86, 20.01, 21.27, 22.61, 23.98, 25.42, 26.86, 28.26, 29.44, 30.49, 31.38, 32.06, 32.53, 32.9, 33.24, 33.57, 33.9, 34.24, 34.59, 34.95, 35.31, 35.67, 36.07, 36.49, 36.94, 37.58, 38.74, 40.76, 44.55, 53.51],
            [0.0, 3.3, 3.63, 3.99, 4.4, 4.85, 5.34, 5.91, 6.55, 7.26, 8.04, 8.93, 9.89, 10.95, 12.1, 13.36, 14.71, 16.13, 17.65, 19.26, 20.74, 22.08, 23.42, 24.88, 26.45, 28.06, 29.69, 31.36, 32.91, 34.3, 35.47, 36.44, 37.17, 37.66, 38.06, 38.45, 38.86, 39.26, 39.67, 40.1, 40.55, 41.01, 41.45, 41.92, 42.45, 43.07, 43.99, 45.91, 49.7, 58.93],
            [0.0, 3.31, 3.66, 4.05, 4.48, 4.95, 5.49, 6.09, 6.76, 7.51, 8.31, 9.19, 10.19, 11.33, 12.55, 13.9, 15.38, 17.02, 18.79, 20.53, 22.37, 24.08, 25.67, 27.27, 29.08, 30.94, 32.92, 34.98, 37.05, 38.98, 40.67, 42.04, 43.06, 43.69, 44.22, 44.73, 45.23, 45.74, 46.23, 46.68, 47.15, 47.65, 48.15, 48.67, 49.26, 49.92, 50.66, 52.08, 55.46, 64.44],
            [0.0, 3.32, 3.68, 4.08, 4.53, 5.04, 5.61, 6.26, 7.0, 7.81, 8.73, 9.77, 10.95, 12.23, 13.64, 15.17, 16.85, 18.65, 20.67, 22.78, 25.1, 27.37, 29.4, 31.43, 33.41, 35.59, 37.93, 40.3, 42.77, 45.24, 47.51, 49.34, 50.78, 51.69, 52.44, 53.15, 53.83, 54.53, 55.23, 55.96, 56.68, 57.43, 58.21, 59.03, 59.93, 60.88, 61.85, 63.03, 66.14, 75.23],
            [0.0, 3.3, 3.64, 4.01, 4.43, 4.9, 5.43, 6.03, 6.71, 7.45, 8.3, 9.25, 10.34, 11.59, 12.99, 14.53, 16.27, 18.22, 20.3, 22.63, 25.02, 27.66, 30.41, 33.04, 35.62, 38.13, 40.73, 43.38, 46.17, 49.16, 52.42, 55.62, 58.59, 61.22, 63.15, 64.39, 65.57, 66.68, 67.79, 69.01, 70.32, 71.94, 73.74, 75.51, 77.27, 79.06, 81.03, 83.48, 86.07, 93.3],
            [0.0, 3.28, 3.6, 3.97, 4.38, 4.85, 5.38, 5.98, 6.64, 7.42, 8.29, 9.28, 10.41, 11.71, 13.21, 14.9, 16.82, 19.05, 21.56, 24.21, 27.18, 30.42, 33.93, 37.56, 41.35, 45.3, 49.08, 52.89, 56.73, 60.71, 64.86, 69.1, 73.47, 78.19, 82.89, 86.82, 89.24, 91.37, 93.9, 96.49, 99.29, 102.51, 106.25, 109.97, 114.23, 119.12, 123.95, 129.03, 135.51, 142.95],
            [0.0, 3.3, 3.64, 4.03, 4.47, 4.97, 5.54, 6.19, 6.96, 7.85, 8.85, 10.02, 11.35, 12.91, 14.73, 16.85, 19.3, 22.08, 25.26, 28.72, 32.59, 36.7, 41.39, 46.35, 51.64, 57.39, 63.42, 69.59, 76.3, 83.1, 90.2, 97.96, 106.55, 115.28, 124.49, 134.38, 144.24, 150.26, 157.35, 165.06, 173.25, 182.13, 192.08, 202.93, 215.55, 229.49, 243.75, 266.22, 293.19, 338.96],
        ]
        self.maxReadTime = 3  # watch time may not exceed 3x the video duration

    def getIdx(self, percentileList, score):
        """Return the index of the last element of ascending ``percentileList`` that is <= ``score``.

        Returns ``len(percentileList) - 1`` when ``score`` is at or above the
        last element, and ``-1`` when it is below the first one.
        """
        return bisect.bisect_right(percentileList, score) - 1

    def encode(self, videoTime, readTime):
        """Return ``[1.0 if readTime > t_i else 0.0 for each cut point t_i]``.

        Non-positive inputs are treated as a (near-)zero watch time, and
        ``readTime`` is capped at ``maxReadTime * videoTime``.
        """
        if videoTime <= 0 or readTime <= 0:
            videoTime = 1e-8
            readTime = 1e-8
        elif readTime > self.maxReadTime * videoTime:
            readTime = self.maxReadTime * videoTime
        durationBuk = self.getIdx(self.videoDurationTable, videoTime)
        timePercentileList = self.timePercentileTable[durationBuk]
        return [1.0 if readTime > thr else 0.0 for thr in timePercentileList]

    def decode(self, videoTime, pScore):
        """Restore a watch time from per-cut-point probabilities.

        ``pScore[i]`` is read as the predicted P(readTime > t_i).  The estimate
        is ``t_0 + sum_{i>=1} (t_i - t_{i-1}) * pScore[i]`` (``pScore[0]`` is
        unused), capped at ``maxReadTime * videoTime`` like ``encode``.
        """
        if videoTime <= 0:
            return 0.0
        durationBuk = self.getIdx(self.videoDurationTable, videoTime)
        timePercentileList = self.timePercentileTable[durationBuk]
        readTime = timePercentileList[0]
        for i in range(1, len(timePercentileList)):
            readTime += (timePercentileList[i] - timePercentileList[i - 1]) * pScore[i]
        # Model outputs are not bounded by the cap, so clamp here; otherwise a
        # short video could be assigned more than maxReadTime x its duration.
        return min(readTime, self.maxReadTime * videoTime)
# Demo: round-trip readTime=8 for videoTime=16 through encode/decode.
timeTrans = TimeTrans()
videoTime = 16
readTime = 8
timeLabel = timeTrans.encode(videoTime, readTime)
print(timeLabel)
readTimeRecovery = timeTrans.decode(videoTime, timeLabel)
print(readTimeRecovery)
Earth Mover’s Distance
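传统的多分类任务,label中只有一个是正类,其他是负类,类别之间没有关系。但是在某些任务中,类别之间是有一定关系的。下图是收入预估的一个例子,adult是真实label,AB两种分布的loss是一样的,但明显B比A更合理一点。在时长预估上也有同样的道理,把时长分成10个区间,真实label在区间5,做多分类任务,我们希望5区间的预估概率最高,向左向右的概率平滑地降低。论文3使用的是Squared EMD Loss:
$$\mathrm{EMD}^2(\mathbf{p}, \hat{\mathbf{p}}) = \sum_{i=1}^{C}\left(\mathrm{CDF}_i(\mathbf{p}) - \mathrm{CDF}_i(\hat{\mathbf{p}})\right)^2,\qquad \mathrm{CDF}_i(\mathbf{p}) = \sum_{j=1}^{i} p_j$$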
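论文3不是做时长预估的,我们这里为时长预估任务把损失改成分类的EMD Loss。和CREAD类似,将消费时长从小到大排序,取分位点比如是 $t_0 < t_1 < \dots < t_{m-1}$。时长落在区间 $[t_i, t_{i+1})$ 内的样本记为第 $i$ 类;最后一类为 $[t_{m-1}, T_{\max}]$,其中 $T_{\max}$ 为视频时长的3倍。label 为 one-hot,模型输出 $m$ 类的概率 $\hat{p}_0, \dots, \hat{p}_{m-1}$,再用上面的 EMD Loss 训练。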
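线上预估时,可以通过以下公式还原成时长,和CREAD很像:$\hat{T} = t_0 + \sum_{i=1}^{m-1}(t_i - t_{i-1})\left(1 - \sum_{j=0}^{i-1}\hat{p}_j\right) + (T_{\max} - t_{m-1})\,\hat{p}_{m-1}$(最后一项仅在 $T_{\max} > t_{m-1}$ 时加上),结果再截断到不超过 $T_{\max}$。其中 $1 - \sum_{j<i}\hat{p}_j$ 即 $P(T \ge t_i)$,对应CREAD中第 $i$ 个分类器的输出。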
EMD和CREAD的实现其实本质一样,只是出发角度不同。
# EMD watch-time transform logic:
class TimeTrans(object):
    """EMD-style watch-time transform: one-hot segment labels and CDF-based restoration.

    The video's duration picks a row of ascending watch-time cut points
    ``t_0 < ... < t_{n-1}``.  ``encode`` one-hot marks the segment the watch
    time falls in; ``decode`` turns a predicted class distribution back into a
    watch time using ``1 - CDF`` as the survival probability of each cut
    point, plus the tail segment up to ``maxReadTime * videoTime``.
    """

    def __init__(self):
        # Watch time may not exceed 3x the video duration.
        self.maxReadTime = 3
        # Lower edges of the video-duration buckets.
        self.videoDurationTable = [0.0, 17.03, 23.7, 31.33, 36.07, 42.25, 49.86, 61.5, 85.72, 142.83]
        # Row k: ascending watch-time cut points for duration bucket k.
        self.timePercentileTable = [
            [0.0, 3.24, 3.49, 3.76, 4.04, 4.34, 4.66, 4.98, 5.33, 5.69, 6.06, 6.45, 6.84, 7.24, 7.65, 8.06, 8.48, 8.9, 9.32, 9.73, 10.15, 10.57, 10.98, 11.35, 11.71, 12.04, 12.39, 12.74, 13.09, 13.44, 13.79, 14.13, 14.47, 14.82, 15.17, 15.53, 15.89, 16.22, 16.55, 16.87, 17.21, 17.56, 17.94, 18.48, 19.34, 20.54, 22.06, 24.17, 27.17, 31.48],
            [0.0, 3.25, 3.53, 3.83, 4.15, 4.5, 4.88, 5.28, 5.72, 6.19, 6.68, 7.19, 7.73, 8.29, 8.88, 9.47, 10.08, 10.71, 11.33, 11.94, 12.57, 13.22, 13.86, 14.48, 15.08, 15.69, 16.3, 16.89, 17.43, 17.93, 18.31, 18.65, 18.98, 19.31, 19.66, 20.03, 20.42, 20.82, 21.25, 21.7, 22.15, 22.63, 23.16, 23.75, 24.43, 25.39, 27.0, 29.48, 33.24, 39.08],
            [0.0, 3.3, 3.64, 4.0, 4.4, 4.85, 5.35, 5.89, 6.46, 7.08, 7.74, 8.48, 9.31, 10.16, 11.06, 12.05, 13.05, 14.06, 15.08, 16.08, 17.08, 18.13, 19.2, 20.27, 21.32, 22.27, 23.15, 23.93, 24.62, 25.19, 25.71, 26.22, 26.73, 27.25, 27.75, 28.25, 28.74, 29.27, 29.8, 30.35, 30.86, 31.29, 31.68, 32.05, 32.51, 33.34, 34.77, 37.09, 41.32, 49.85],
            [0.0, 3.31, 3.63, 3.99, 4.38, 4.83, 5.32, 5.85, 6.44, 7.13, 7.86, 8.66, 9.53, 10.49, 11.53, 12.65, 13.85, 15.11, 16.41, 17.68, 18.86, 20.01, 21.27, 22.61, 23.98, 25.42, 26.86, 28.26, 29.44, 30.49, 31.38, 32.06, 32.53, 32.9, 33.24, 33.57, 33.9, 34.24, 34.59, 34.95, 35.31, 35.67, 36.07, 36.49, 36.94, 37.58, 38.74, 40.76, 44.55, 53.51],
            [0.0, 3.3, 3.63, 3.99, 4.4, 4.85, 5.34, 5.91, 6.55, 7.26, 8.04, 8.93, 9.89, 10.95, 12.1, 13.36, 14.71, 16.13, 17.65, 19.26, 20.74, 22.08, 23.42, 24.88, 26.45, 28.06, 29.69, 31.36, 32.91, 34.3, 35.47, 36.44, 37.17, 37.66, 38.06, 38.45, 38.86, 39.26, 39.67, 40.1, 40.55, 41.01, 41.45, 41.92, 42.45, 43.07, 43.99, 45.91, 49.7, 58.93],
            [0.0, 3.31, 3.66, 4.05, 4.48, 4.95, 5.49, 6.09, 6.76, 7.51, 8.31, 9.19, 10.19, 11.33, 12.55, 13.9, 15.38, 17.02, 18.79, 20.53, 22.37, 24.08, 25.67, 27.27, 29.08, 30.94, 32.92, 34.98, 37.05, 38.98, 40.67, 42.04, 43.06, 43.69, 44.22, 44.73, 45.23, 45.74, 46.23, 46.68, 47.15, 47.65, 48.15, 48.67, 49.26, 49.92, 50.66, 52.08, 55.46, 64.44],
            [0.0, 3.32, 3.68, 4.08, 4.53, 5.04, 5.61, 6.26, 7.0, 7.81, 8.73, 9.77, 10.95, 12.23, 13.64, 15.17, 16.85, 18.65, 20.67, 22.78, 25.1, 27.37, 29.4, 31.43, 33.41, 35.59, 37.93, 40.3, 42.77, 45.24, 47.51, 49.34, 50.78, 51.69, 52.44, 53.15, 53.83, 54.53, 55.23, 55.96, 56.68, 57.43, 58.21, 59.03, 59.93, 60.88, 61.85, 63.03, 66.14, 75.23],
            [0.0, 3.3, 3.64, 4.01, 4.43, 4.9, 5.43, 6.03, 6.71, 7.45, 8.3, 9.25, 10.34, 11.59, 12.99, 14.53, 16.27, 18.22, 20.3, 22.63, 25.02, 27.66, 30.41, 33.04, 35.62, 38.13, 40.73, 43.38, 46.17, 49.16, 52.42, 55.62, 58.59, 61.22, 63.15, 64.39, 65.57, 66.68, 67.79, 69.01, 70.32, 71.94, 73.74, 75.51, 77.27, 79.06, 81.03, 83.48, 86.07, 93.3],
            [0.0, 3.28, 3.6, 3.97, 4.38, 4.85, 5.38, 5.98, 6.64, 7.42, 8.29, 9.28, 10.41, 11.71, 13.21, 14.9, 16.82, 19.05, 21.56, 24.21, 27.18, 30.42, 33.93, 37.56, 41.35, 45.3, 49.08, 52.89, 56.73, 60.71, 64.86, 69.1, 73.47, 78.19, 82.89, 86.82, 89.24, 91.37, 93.9, 96.49, 99.29, 102.51, 106.25, 109.97, 114.23, 119.12, 123.95, 129.03, 135.51, 142.95],
            [0.0, 3.3, 3.64, 4.03, 4.47, 4.97, 5.54, 6.19, 6.96, 7.85, 8.85, 10.02, 11.35, 12.91, 14.73, 16.85, 19.3, 22.08, 25.26, 28.72, 32.59, 36.7, 41.39, 46.35, 51.64, 57.39, 63.42, 69.59, 76.3, 83.1, 90.2, 97.96, 106.55, 115.28, 124.49, 134.38, 144.24, 150.26, 157.35, 165.06, 173.25, 182.13, 192.08, 202.93, 215.55, 229.49, 243.75, 266.22, 293.19, 338.96],
        ]

    def getIdx(self, percentileList, score):
        """Binary-search ascending ``percentileList`` for the last index whose value is <= ``score``.

        Yields the final index when ``score`` reaches the last element and
        ``-1`` when ``score`` is below the first one.
        """
        lastPos = len(percentileList) - 1
        if score >= percentileList[lastPos]:
            return lastPos
        lo, hi = 0, len(percentileList)
        while lo <= hi:
            center = (lo + hi) // 2
            pivot = percentileList[center]
            if pivot < score:
                lo = center + 1
            elif score < pivot:
                hi = center - 1
            else:
                return center
        return hi

    def cdf(self, scores):
        """Return the running (inclusive) prefix sums of ``scores`` as a list."""
        total = 0.0
        prefix = []
        for value in scores:
            total = total + value
            prefix.append(total)
        return prefix

    def encode(self, videoTime, readTime):
        """One-hot encode the cut-point segment that ``readTime`` falls in.

        Non-positive inputs collapse to a (near-)zero watch time; otherwise
        ``readTime`` is capped at ``maxReadTime * videoTime``.
        """
        if videoTime <= 0 or readTime <= 0:
            videoTime = readTime = 1e-8
        else:
            readTime = min(readTime, self.maxReadTime * videoTime)
        row = self.timePercentileTable[self.getIdx(self.videoDurationTable, videoTime)]
        hot = self.getIdx(row, readTime)
        return [float(pos == hot) for pos in range(len(row))]

    def decode(self, videoTime, pScore):
        """Restore a watch time from a predicted class distribution ``pScore``.

        Each cut-point gap is weighted by ``1 - CDF`` of the preceding class.
        When the cap exceeds the last cut point, the tail gap is weighted by
        the last class's probability.  The result is capped at
        ``maxReadTime * videoTime``; a non-positive ``videoTime`` gives 0.0.
        """
        if videoTime <= 0:
            return 0.0
        row = self.timePercentileTable[self.getIdx(self.videoDurationTable, videoTime)]
        cumulative = self.cdf(pScore)
        ceiling = self.maxReadTime * videoTime
        estimate = row[0]
        for k, (left, right) in enumerate(zip(row, row[1:])):
            estimate += (right - left) * (1 - cumulative[k])
        if ceiling > row[-1]:
            estimate += (ceiling - row[-1]) * pScore[-1]
        return min(estimate, ceiling)
# Demo: round-trip readTime=8 for videoTime=16 through encode/decode.
timeTrans = TimeTrans()
videoTime = 16
readTime = 8
timeLabel = timeTrans.encode(videoTime, readTime)
print(timeLabel)
readTimeRecovery = timeTrans.decode(videoTime, timeLabel)
print(readTimeRecovery)
Distill Softmax
和EMD有点类似,只不过EMD从预估值分布角度考虑,而Distill Softmax是从label角度考虑,直接将时长做成了平滑的多分类soft-label。比如,把时长分成10个区间,真实label在区间5,本来的多分类label是[0,0,0,0,1,0,0,0,0,0],这里对时长做了平滑,希望5区间的预估概率最高,向左向右的概率平滑的降低,做成[0.001,0.005,0.03,0.1,0.7,0.1,0.05,0.008,0.005,0.001]。
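具体做法:以真实时长 $T$ 为中心,对每个切分点 $t_i$ 计算Laplace分布的概率密度,再归一化得到每个区间的概率值:$q_i = \frac{f(t_i)}{\sum_j f(t_j)}$,其中 $f(t) = \frac{1}{2b}\exp\left(-\frac{|T - t|}{b}\right)$。代码中取 $b = 1$,并对 $f$ 加了 $10^{-7}$ 的平滑。越靠近真实时长的区间概率越大。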
损失函数就是以soft-label为目标的多分类交叉熵:$L = -\sum_{i} q_i \log \hat{p}_i$。
线上预估时,可以通过以下公式还原成时长,和EMD一样。
Distill Softmax也能和EMD一起用,即先做soft-label变换后,再使用EMD Loss。
# Distill Softmax watch-time transform logic:
import bisect

import numpy as np


class TimeTrans(object):
    """Distill-Softmax watch-time transform: Laplace-smoothed soft labels over cut points.

    The video's duration picks a row of ascending watch-time cut points.
    ``encode`` evaluates a Laplace density centred on the watch time at every
    cut point and normalises it into a soft label.  ``decode`` restores a
    watch time from a predicted class distribution using ``1 - CDF``
    weighting, plus the tail segment up to ``maxReadTime * videoTime``.
    """

    def __init__(self):
        # Lower edges of the video-duration buckets; a video belongs to the
        # last bucket whose edge is <= videoTime.
        self.videoDurationTable = [0.0, 17.03, 23.7, 31.33, 36.07, 42.25, 49.86, 61.5, 85.72, 142.83]
        # One row per duration bucket: ascending watch-time cut points
        # (presumably watch-time percentiles of that bucket -- confirm).
        self.timePercentileTable = [
            [0.0, 3.24, 3.49, 3.76, 4.04, 4.34, 4.66, 4.98, 5.33, 5.69, 6.06, 6.45, 6.84, 7.24, 7.65, 8.06, 8.48, 8.9, 9.32, 9.73, 10.15, 10.57, 10.98, 11.35, 11.71, 12.04, 12.39, 12.74, 13.09, 13.44, 13.79, 14.13, 14.47, 14.82, 15.17, 15.53, 15.89, 16.22, 16.55, 16.87, 17.21, 17.56, 17.94, 18.48, 19.34, 20.54, 22.06, 24.17, 27.17, 31.48],
            [0.0, 3.25, 3.53, 3.83, 4.15, 4.5, 4.88, 5.28, 5.72, 6.19, 6.68, 7.19, 7.73, 8.29, 8.88, 9.47, 10.08, 10.71, 11.33, 11.94, 12.57, 13.22, 13.86, 14.48, 15.08, 15.69, 16.3, 16.89, 17.43, 17.93, 18.31, 18.65, 18.98, 19.31, 19.66, 20.03, 20.42, 20.82, 21.25, 21.7, 22.15, 22.63, 23.16, 23.75, 24.43, 25.39, 27.0, 29.48, 33.24, 39.08],
            [0.0, 3.3, 3.64, 4.0, 4.4, 4.85, 5.35, 5.89, 6.46, 7.08, 7.74, 8.48, 9.31, 10.16, 11.06, 12.05, 13.05, 14.06, 15.08, 16.08, 17.08, 18.13, 19.2, 20.27, 21.32, 22.27, 23.15, 23.93, 24.62, 25.19, 25.71, 26.22, 26.73, 27.25, 27.75, 28.25, 28.74, 29.27, 29.8, 30.35, 30.86, 31.29, 31.68, 32.05, 32.51, 33.34, 34.77, 37.09, 41.32, 49.85],
            [0.0, 3.31, 3.63, 3.99, 4.38, 4.83, 5.32, 5.85, 6.44, 7.13, 7.86, 8.66, 9.53, 10.49, 11.53, 12.65, 13.85, 15.11, 16.41, 17.68, 18.86, 20.01, 21.27, 22.61, 23.98, 25.42, 26.86, 28.26, 29.44, 30.49, 31.38, 32.06, 32.53, 32.9, 33.24, 33.57, 33.9, 34.24, 34.59, 34.95, 35.31, 35.67, 36.07, 36.49, 36.94, 37.58, 38.74, 40.76, 44.55, 53.51],
            [0.0, 3.3, 3.63, 3.99, 4.4, 4.85, 5.34, 5.91, 6.55, 7.26, 8.04, 8.93, 9.89, 10.95, 12.1, 13.36, 14.71, 16.13, 17.65, 19.26, 20.74, 22.08, 23.42, 24.88, 26.45, 28.06, 29.69, 31.36, 32.91, 34.3, 35.47, 36.44, 37.17, 37.66, 38.06, 38.45, 38.86, 39.26, 39.67, 40.1, 40.55, 41.01, 41.45, 41.92, 42.45, 43.07, 43.99, 45.91, 49.7, 58.93],
            [0.0, 3.31, 3.66, 4.05, 4.48, 4.95, 5.49, 6.09, 6.76, 7.51, 8.31, 9.19, 10.19, 11.33, 12.55, 13.9, 15.38, 17.02, 18.79, 20.53, 22.37, 24.08, 25.67, 27.27, 29.08, 30.94, 32.92, 34.98, 37.05, 38.98, 40.67, 42.04, 43.06, 43.69, 44.22, 44.73, 45.23, 45.74, 46.23, 46.68, 47.15, 47.65, 48.15, 48.67, 49.26, 49.92, 50.66, 52.08, 55.46, 64.44],
            [0.0, 3.32, 3.68, 4.08, 4.53, 5.04, 5.61, 6.26, 7.0, 7.81, 8.73, 9.77, 10.95, 12.23, 13.64, 15.17, 16.85, 18.65, 20.67, 22.78, 25.1, 27.37, 29.4, 31.43, 33.41, 35.59, 37.93, 40.3, 42.77, 45.24, 47.51, 49.34, 50.78, 51.69, 52.44, 53.15, 53.83, 54.53, 55.23, 55.96, 56.68, 57.43, 58.21, 59.03, 59.93, 60.88, 61.85, 63.03, 66.14, 75.23],
            [0.0, 3.3, 3.64, 4.01, 4.43, 4.9, 5.43, 6.03, 6.71, 7.45, 8.3, 9.25, 10.34, 11.59, 12.99, 14.53, 16.27, 18.22, 20.3, 22.63, 25.02, 27.66, 30.41, 33.04, 35.62, 38.13, 40.73, 43.38, 46.17, 49.16, 52.42, 55.62, 58.59, 61.22, 63.15, 64.39, 65.57, 66.68, 67.79, 69.01, 70.32, 71.94, 73.74, 75.51, 77.27, 79.06, 81.03, 83.48, 86.07, 93.3],
            [0.0, 3.28, 3.6, 3.97, 4.38, 4.85, 5.38, 5.98, 6.64, 7.42, 8.29, 9.28, 10.41, 11.71, 13.21, 14.9, 16.82, 19.05, 21.56, 24.21, 27.18, 30.42, 33.93, 37.56, 41.35, 45.3, 49.08, 52.89, 56.73, 60.71, 64.86, 69.1, 73.47, 78.19, 82.89, 86.82, 89.24, 91.37, 93.9, 96.49, 99.29, 102.51, 106.25, 109.97, 114.23, 119.12, 123.95, 129.03, 135.51, 142.95],
            [0.0, 3.3, 3.64, 4.03, 4.47, 4.97, 5.54, 6.19, 6.96, 7.85, 8.85, 10.02, 11.35, 12.91, 14.73, 16.85, 19.3, 22.08, 25.26, 28.72, 32.59, 36.7, 41.39, 46.35, 51.64, 57.39, 63.42, 69.59, 76.3, 83.1, 90.2, 97.96, 106.55, 115.28, 124.49, 134.38, 144.24, 150.26, 157.35, 165.06, 173.25, 182.13, 192.08, 202.93, 215.55, 229.49, 243.75, 266.22, 293.19, 338.96],
        ]
        self.maxReadTime = 3  # watch time may not exceed 3x the video duration

    def getIdx(self, percentileList, score):
        """Return the index of the last element of ascending ``percentileList`` that is <= ``score``.

        Returns ``len(percentileList) - 1`` when ``score`` is at or above the
        last element, and ``-1`` when it is below the first one.
        """
        return bisect.bisect_right(percentileList, score) - 1

    def cdf(self, scores):
        """Return the running (inclusive) prefix sums of ``scores``."""
        result = []
        curSum = 0.0
        for score in scores:
            curSum += score
            result.append(curSum)
        return result

    def laplacePdf(self, x, mean, variance=1.0):
        """Laplace density at ``x`` for location(s) ``mean``, plus a 1e-7 floor.

        NOTE: despite its name, ``variance`` is used as the Laplace *scale* b
        in ``1 / (2b) * exp(-|x - mean| / b)``.  The 1e-7 floor keeps every
        class probability strictly positive after normalisation.
        """
        return 1.0 / (2.0 * variance) * np.exp(-1.0 * np.abs(np.array(x) - np.array(mean)) / variance) + 1e-7

    def distill(self, timeLabel, labelPercentileList):
        """Turn a scalar watch time into a normalised soft label over the cut points."""
        probs = self.laplacePdf(timeLabel, labelPercentileList)
        probsSum = np.sum(probs)
        softLabel = probs / probsSum
        return softLabel

    def encode(self, videoTime, readTime):
        """Encode a watch time as a soft label (numpy array) over the bucket's cut points.

        Non-positive inputs are treated as a (near-)zero watch time, and
        ``readTime`` is capped at ``maxReadTime * videoTime``.
        """
        if videoTime <= 0 or readTime <= 0:
            videoTime = 1e-8
            readTime = 1e-8
        elif readTime > self.maxReadTime * videoTime:
            readTime = self.maxReadTime * videoTime
        durationBuk = self.getIdx(self.videoDurationTable, videoTime)
        timePercentileList = self.timePercentileTable[durationBuk]
        return self.distill(readTime, timePercentileList)

    def decode(self, videoTime, pScore):
        """Restore a watch time from a predicted class distribution ``pScore``.

        The formula is the same as in the EMD transform.  Each cut-point gap is
        weighted by ``1 - CDF`` of the preceding class, and the tail gap up to
        the cap is weighted by the last class's probability.  The result is
        capped at ``maxReadTime * videoTime``; a non-positive ``videoTime``
        gives 0.0.
        """
        if videoTime <= 0:
            return 0.0
        durationBuk = self.getIdx(self.videoDurationTable, videoTime)
        timePercentileList = self.timePercentileTable[durationBuk]
        cdfResult = self.cdf(pScore)
        readTime = timePercentileList[0]
        for i in range(1, len(timePercentileList)):
            readTime += (timePercentileList[i] - timePercentileList[i - 1]) * (1 - cdfResult[i - 1])
        if self.maxReadTime * videoTime > timePercentileList[-1]:
            readTime += (self.maxReadTime * videoTime - timePercentileList[-1]) * pScore[-1]
        if readTime > self.maxReadTime * videoTime:
            readTime = self.maxReadTime * videoTime
        return readTime
# Demo: round-trip readTime=8 for videoTime=16 through encode/decode.
timeTrans = TimeTrans()
videoTime = 16
readTime = 8
timeLabel = timeTrans.encode(videoTime, readTime)
print(timeLabel)
readTimeRecovery = timeTrans.decode(videoTime, timeLabel)
print(readTimeRecovery)
参考
- [1][Recommending What Video to Watch Next: A Multitask Ranking
System](https://dl.acm.org/doi/10.1145/3298689.3346997)
- [2][CREAD: A Classification-Restoration Framework with Error
Adaptive Discretization for Watch Time Prediction in Video Recommender
Systems](http://arxiv.org/abs/2401.07521)
- [3][Squared Earth Mover’s Distance-based Loss for Training Deep Neural Networks](http://arxiv.org/abs/1611.05916)