机器不学习:Contextual Bandit算法在推荐系统中的实现及应用
机器不学习 www.jqbxx.com : 深度聚合机器学习、深度学习算法及技术实战
推荐系统选择商品展现给用户,并期待用户的正向反馈(点击、成交)。然而推荐系统并不能提前知道用户在观察到商品之后如何反馈,也就是不能提前获得本次推荐的收益,唯一能做的就是不停地尝试,并实时收集反馈以便更新自己试错的策略。目的是使得整个过程损失的收益最小。这一过程就类似与一个赌徒在赌场里玩老虎机赌博。赌徒要去摇老虎机,走进赌场一看,一排老虎机,外表一模一样,但是每个老虎机吐钱的概率可不一样,他不知道每个老虎机吐钱的概率分布是什么,那么每次该选择哪个老虎机可以做到最大化收益呢?这就是多臂赌博机问题(Multi-armed bandit problem, MAB)。
MAB问题的难点是Exploitation-Exploration(E&E)两难的问题:对已知的吐钱概率比较高的老虎机,应该更多的去尝试(exploitation),以便获得一定的累计收益;对未知的或尝试次数较少的老虎机,还要分配一定的尝试机会(exploration),以免错失收益更高的选择,但同时较多的探索也意味着较高的风险(机会成本)。
Bandit算法是一类用来实现Exploitation-Exploration机制的策略。根据是否考虑上下文特征,Bandit算法分为context-free bandit和contextual bandit两大类。
1. UCB
Context-free Bandit算法有很多种,比如ϵ−greedy、softmax、Thompson Sampling、UCB(Upper Confidence Bound)等。
在此,重点介绍一下UCB方法的基本思想。在统计学中,对于一个未知量的估计,总能找到一种量化其置信度的方法。最普遍的分布正态分布(或曰高斯分布)N(μ,δ),其中的μ就是估计量的期望,而δ
则表示其不确定性(δ越大则表示越不可信)。比如你掷一个标准的6面色子,它的平均值是3.5,而如果你只掷一次,比如说到2,那你对平均值的估计只能是2,但是这个置信度应该很低,我们可以知道,这个色子的预估平均值是2,而以95%的置信区间在[1.4,5.2]。
UCB(Upper Confidence Bound - 置信上限)就是以收益(bonus)均值的置信区间上限代表对该arm未来收益的预估值:
其中μi^是对arm i期望收益的预估,n是总的选择次数,ni是对arm i的尝试次数,可以看到尝试越多,其预估值与置信上限的差值就越小,也就是越有置信度。
UCB在此时的决策是选择置信区间上界最大的arm。这个策略的好处是,能让没有机会尝试的arm得到更多尝试的机会,是骡子是马拉出来溜溜!
- 对于未知或较少尝试的arm,尽管其均值可能很低,但是由于其不确定性会导致置信区间的上界较大,从而有较大的概率触发exploration
- 对于已经很熟悉的arm(尝试过较多次),更多的是触发exploitation机制:如果其均值很高,会获得更多的利用机会;反之,则会减少对其尝试的机会
2. LinUCB
在推荐系统中,通常把待推荐的商品作为MAB问题的arm。UCB这样的context-free类算法,没有充分利用推荐场景的上下文信息,为所有用户的选择展现商品的策略都是相同的,忽略了用户作为一个个活生生的个性本身的兴趣点、偏好、购买力等因素都是不同的,因而,同一个商品在不同的用户、不同的情景下接受程度是不同的。故在实际的推荐系统中,context-free的MAB算法基本都不会被采用。
与context-free MAB算法对应的是Contextual Bandit算法,顾名思义,这类算法在实现E&E时考虑了上下文信息,因而更加适合实际的个性化推荐场景。
形式化地说,在时间步t,contextual-bandit算法观察到当前用户ut,以及每个可选择的商品(arm)a的特征向量Xt,a, Xt,a称之为上下文信息,它概况了用户和商品两方面的信息。算法根据之前观察到的反馈结果选择一个商品at展现给用户,并接受到用户的反馈收益rt,at的期望取决于用户和商品两个方面。接着,算法根据新的观察(xt,a,at,rt,at)改进自身选择商品展现的策略,目标是使得整个过程中损失的收益最小,即regret值RA(T)最小。RA(T)的定义如下:
其中,T为实验的总步数;at∗为在时间步t时有最大期望收益的arm,不能提前得知。
LinUCB是处理Contextual Bandit的一个方法,在LinUCB中,设定每个arm的期望收益为该arm的特征向量(context)的线性函数,如下:
θa是LinUCB模型的参数,维度为d。每个arm维护一个θa.对于单个arm a,以其前m个context向量为行向量组成的矩阵称为Da,维度为m×n。前m个收益(reward)组成的向量称为ca。采用平方损失函数
,其中λ为正则项系数。求损失函数的最小值,令损失函数对θa求导,结果为
令∇θaloss=0,λ=1,可得
使用岭回归(ridge regression)方法,可以得到θa的概率分布为高斯分布:
为了符号简洁,令
于是θa的概率分布可表示为θa∼N(θ^a,Aa−1)
于是在第t次时可以得到
根据高斯分布的性质,得到置信上界后就可以使用普通UCB规则了,即每次选择
最大的arm。
需要注意的是,Aa与DaTca可以增量异步更新,于是标准流程如下:
LinUCB算法的优势:
- 计算复杂度与arm的数量成线性关系
- 支持动态变化的候选arm集合
3. 业务场景介绍
在我们的电商平台App首页,有一个商品瀑布流推荐场景,每次大概展示30个商品左右。商品候选集都是运营人工精选的历史销售情况较好,在更多流量刺激下有可能成为爆款的商品,并且每天都会汰换掉一部分,加入一些新品进来。
用过实现LinUcb算法,系统会对每个商品做充分的exploration和exploitation,从而发掘出真正有销售潜力的商品,逐渐淘汰掉不够理想的商品,纠正运营人工选品的局限。经过考验的商品,说明在一段时间内销量还是不错的,这些商品运营可以深度控价,要求商家提供更多的优惠和让利给用户,从而形成良性循环,同时也给其他商家树立标杆,促进平台更加健康地发展。
4. 系统架构
在我们的系统中,LinUCB算法的实现分为两个部分:一部分实现在推荐引擎里,主要完成特征向量提取、获取每个商品的Aa矩阵和ba向量数据、完成置信区间上届的计算、并选择最终展现的商品;另一部分逻辑实现在实时计算平台Storm上,这部分任务实时解析场景的曝光、点击和购买行为日志,每条日志里包含了商品ID、时间戳和特征向量等信息,根据公式更新每个商品的Aa矩阵和ba向量,并把更新后的结果写到Redis缓存里,供推荐引擎获取。
5. 核心代码逻辑
每个商品都维护一个队列queue,用于临时存放接受到的行为事件。之所以不在接受到事情的时候立马处理掉,主要有两个原因:一是用户对每个展现商品的反馈并不是原子操作,比如反馈是曝光未点击,或是曝光且点击但未购买,还是即曝光又点击最后还购买了,这些操作不是一步就能完成的,而是保护了好几步,因此在仅仅收到曝光事情或者点击事情时,我们不知道用户的反馈序列操作有没有结束,也就不能准确设置反馈的收益值;二是由于客户端对曝光事情有缓存不是实时上报的,因而不能保证曝光事情一定在点击事情之前被上传到日志服务器。
具体实现时,设置一个时间间隔(比如,5分钟),一个曝光或者点击事情到达Storm计算节点时,先缓存在队列里,等待设置好的时间间隔后,再消费掉。如果在时间间隔内高优先级的事情到达,则会直接移除掉低优先级的事情。优先级顺序为购买大于点击、点击大于曝光。另外,为了防止日志重复上报,还会根据推荐引擎的请求ID(不同请求不同)对接受到是事情去重。当然,为了容错,如果预设的事情队列满了,则会动态把时间间隔缩短为原来的二分之一,把新的时间间隔之前的事情强制消费掉,这个过程会递归执行,直到队列的长度小于预设的大小为止。
Storm任务中,某个商品的参数(Aa矩阵和ba向量)更新逻辑:
public int consume(Event newEvent, Config config) {
if (newEvent.getPid() != pid)
return -1;
reset(config);
int showConsumedCount = 0;
int clickConsumedCount = 0;
long current = System.currentTimeMillis();
Iterator<Event> iter = queue.iterator();
while (iter.hasNext()) {
Event event = iter.next();
if (event.isSameFlow(newEvent)) {
if (newEvent.isShow()) {
newEvent = null; // 重复曝光,或者点击事件在曝光事件前到达
break;
}
if (newEvent.isClick() && event.isClick()) {
newEvent = null; // 重复点击事件
break;
}
iter.remove(); // 高优先级的事情会覆盖低优先级的事情
if (null == newEvent.getItemTrackData()) {
newEvent.setItemTrackData(event.getItemTrackData());
}
continue;
}
if (current - event.getTimestamp() < config.cacheTimeSpan)
continue;
INDArray xt = event.getFeature(config);
INDArray x = xt.transpose();
A.addi(x.mmul(xt));
double reward = event.getReward();
if (reward > 0.0)
b.addi(x.muli(reward));
iter.remove();
if (event.isShow())
showConsumedCount ++;
else if (event.isClick())
clickConsumedCount ++;
}
if (null == newEvent)
return 0;
if (newEvent.getType().equals(EventType.BUY)) {
// 购买事情立即消费
INDArray xt = newEvent.getFeature(config);
if (null == xt) {
LogPushUtil.push("BuyFeatureMissingEvent");
}
else if (xt.length() == b.length()) {
INDArray x = xt.transpose();
A.addi(x.mmul(xt));
double reward = newEvent.getReward();
b.addi(x.muli(reward));
LogPushUtil.push("BuyEventConsumed");
records += config.updateRecords;
}
} else if (StringUtils.isNotEmpty(newEvent.getItemTrackData())) {
queue.add(newEvent);
checkQueueIsFullOrNot(config.cacheTimeSpan / 2, config);
}
if (showConsumedCount > 0) {
records += showConsumedCount;
LogPushUtil.push("ShowEventConsumed", showConsumedCount);
}
if (clickConsumedCount > 0) {
records += clickConsumedCount;
LogPushUtil.push("ClickEventConsumed", clickConsumedCount);
}
if (records >= config.updateRecords) {
writeToRedis(config);
}
return records;
}
在线推荐引擎计算每个商品的ucb分数的代码如下:
private void parallelGetLinUcbScore(Session session, INDArray features, List<Item> items) throws InterruptedException {
int concurrency = Math.min(session.config.linUcbConcurrency, items.size());
int num = items.size() / concurrency;
ExecutorService es = session.context.getThreadExecutorService(session.reqId, session.reqId.getSceneId());
final CountDownLatch countDownLatch = new CountDownLatch(concurrency);
int left = items.size() - num * concurrency;
int start = 0, end = num;
for (int i = 0; i < concurrency; ++i) {
if (i < left)
end++;
Logger.debug("prepare to start linucb score sub thread: [" + start + ", " + end + "]");
int finalStart = start;
int finalEnd = end;
Runnable runnable = () -> {
long startTime = System.currentTimeMillis();
try {
getLinUcbScore(session, features, items, finalStart, finalEnd);
} catch (Exception e) {
String exception = ExceptionUtils.getFullStackTrace(e);
Logger.error("linucb worker exception:", exception);
}
finally {
countDownLatch.countDown();
}
long time = System.currentTimeMillis() - startTime;
Logger.info("[Timer] compute sub linUcb scores [", finalStart, ", ", finalEnd, "] taken", time, "ms");
};
es.execute(TtlRunnable.get(runnable));
start = end;
end = start + num;
}
countDownLatch.await(session.config.linUcbTimeOut, TimeUnit.MILLISECONDS);
}
private void getLinUcbScore(Session session, INDArray features, List<Item> items, int start, int end) throws IOException {
if (items.isEmpty() || null == features)
return;
StopWatch stopWatch = StopWatch.CreateStopWatchAndStart();
for (int i = start; i < end; ++i) {
stopWatch.restart();
Item item = items.get(i);
long pid = item.getProductId();
byte[] matrix = getLinUcbMatrix(session, pid);
INDArray feature = features.getRow(i);
INDArray featureT = feature.transpose();
Logger.detail("<", i, ">", pid, "item feature:", item.getTrackInfo(), "vector:", feature);
if (matrix == null) {
double p = session.config.linUcbAlpha * FastMath.sqrt(feature.mmul(featureT).getDouble(0));
item.setMatchScore(p);
Logger.info("<", i, "> linucb new item:", pid, "ucb score:", p);
continue;
}
ByteArrayInputStream in = new ByteArrayInputStream(matrix);
INDArray[] Ab = ByteSerde.fromByteArrayStream(in);
if (Ab.length < 2)
continue;
INDArray invertA = Ab[0];
INDArray b = Ab[1];
Logger.detail("<", i, ">", pid, "[Timer] fetch matrix taken", stopWatch.click(), "ms");
if (b.length() != feature.length()) {
Logger.error("the length of b is", b.length(), "should be", feature.length());
double p = session.config.linUcbAlpha * FastMath.sqrt(feature.mmuli(featureT).getFloat(0));
item.setMatchScore(p);
continue;
}
INDArray theta = invertA.mmul(b);
INDArray thetaX = theta.transposei().mmul(featureT);
INDArray temp = feature.mmuli(invertA).mmul(featureT);
double p = thetaX.getDouble(0);
item.setPreferScore(p);
p += session.config.linUcbAlpha * FastMath.sqrt(temp.getDouble(0));
item.setMatchScore(p);
Logger.detail("<", i, ">", pid, "[Timer] one iteration linucb score taken", stopWatch.click(), "ms. match score:", p);
}
}
实现中矩阵计算的部分,用了ND4J(https://nd4j.org/)的库。
6. 业务效果
经过线上充分的A/B测试,最终测得LinUCB算法的UV CTR相对基准桶提升25%+,UV价值提升20%+。并且算法能够很好地支持商品动态上下架。
7. 参考资料
[Lihong Li, et al, 2010] A Contextual-Bandit Approach to Personalized News Article Recommendation.
转自:https://yangxudong.github.io/linucb/