理解Snowflake算法的实现原理
前提
Snowflake (雪花)是 Twitter 开源的高性能 ID 生成算法(服务)。
上图是 Snowflake 的 Github 仓库, master 分支中的 REAEMDE 文件中提示:初始版本于 2010 年发布,基于 Apache Thrift ,早于 Finagle (这里的 Finagle 是 Twitter 上用于 RPC 服务的构建模块)发布,而 Twitter 内部使用的 Snowflake 是一个完全重写的程序,在很大程度上依靠 Twitter 上的现有基础架构来运行。
而 2010 年发布的初版 Snowflake 源码是使用 Scala 语言编写的,归档于 scala_28 分支。换言之, 大家目前使用的 Snowflake 算法原版或者改良版已经是十年前(当前是 2020 年)的产物,不得不说这个算法确实比较厉害 。 scala_28 分支中有介绍该算法的动机和要求,这里简单摘录一下:
动机:
- Cassandra 中没有生成顺序 ID 的工具, Twitter 由使用 MySQL 转向使用 Cassandra 的时候需要一种新的方式来生成 ID (印证了架构不是设计出来,而是基于业务场景迭代出来)。
要求:
- 高性能:每秒每个进程至少产生 10K 个 ID ,加上网络延迟响应速度要在 2ms 内。
- 顺序性:具备按照时间的自增趋势,可以直接排序。
- 紧凑性:保持生成的 ID 的长度在 64 bit 或更短。
- 高可用: ID 生成方案需要和存储服务一样高可用。
- 下面就 Snowflake 的源码分析一下他的实现原理。
Snowflake方案简述
Snowflake 在初版设计方案是:
- 时间: 41 bit 长度,使用毫秒级别精度,带有一个自定义 epoch ,那么可以使用大概 69 年。
- 可配置的机器 ID : 10 bit 长度,可以满足 1024 个机器使用。
- 序列号: 12 bit 长度,可以在 4096 个数字中随机取值,从而避免单个机器在 1 ms 内生成重复的序列号。
但是在实际源码实现中, Snowflake 把 10 bit 的可配置的机器 ID 拆分为 5 bit 的 Worker ID (这个可以理解为原来的机器 ID )和 5 bit 的 Data Center ID (数据中心 ID ),详情见 IdWorker.scala :
也就是说,支持配置最多 32 个机器 ID 和最多 32 个数据中心 ID :
由于算法是 Scala 语言编写,是依赖于 JVM 的语言,返回的 ID 值为 Long 类型,也就是 64 bit 的整数,原来的算法生成序列中只使用了 63 bit 的长度,要返回的是无符号数,所以在高位补一个 0 (占用 1 bit ),那么加起来整个 ID 的长度就是 64 bit :
其中:
- 41 bit 毫秒级别时间戳的取值范围是: [0, 2^41 - 1] => 0 ~ 2199023255551 ,一共 2199023255552 个数字。
- 5 bit 机器 ID 的取值范围是: [0, 2^5 - 1] => 0 ~ 31 ,一共 32 个数字。
- 5 bit 数据中心 ID 的取值范围是: [0, 2^5 - 1] => 0 ~ 31 ,一共 32 个数字。
- 12 bit 序列号的取值范围是: [0, 2^12 - 1] => 0 ~ 4095 ,一共 4096 个数字。
那么理论上可以生成 2199023255552 * 32 * 32 * 4096 个完全不同的 ID 值。
Snowflake 算法还有一个明显的特征: 依赖于系统时钟 。 41 bit 长度毫秒级别的时间来源于系统时间戳,所以必须保证系统时间是向前递进,不能发生 时钟回拨 (通说来说就是不能在同一个时刻产生多个相同的时间戳或者产生了过去的时间戳)。一旦发生时钟回拨, Snowflake 会拒绝生成下一个 ID 。
位运算知识补充
Snowflake 算法中使用了大量的位运算。由于整数的补码才是在计算机中的存储形式, Java 或者 Scala 中的整型都使用补码表示,这里稍微提一下原码和补码的知识。
- 原码用于阅读,补码用于计算。
- 正数的补码与其原码相同。
- 负数的补码是除最高位其他所有位取反,然后加 1 (反码加 1 ),而负数的补码还原为原码也是使用这个方式。
- +0 的原码是 0000 0000 ,而 -0 的原码是 1000 0000 ,补码只有一个 0 值,用 0000 0000 表示,这一点很重要,补码的 0 没有二义性。
简单来看就是这样:
* [+ 11] 原码 = [0000 1011] 补码 = [0000 1011] * [- 11] 原码 = [1000 1011] 补码 = [1111 0101] * [- 11]的补码计算过程: 原码 1000 1011 除了最高位其他位取反 1111 0100 加1 1111 0101 (补码)
使用原码、反码在计算的时候得到的不一定是准确的值,而使用补码的时候计算结果才是正确的,记住这个结论即可,这里不在举例。由于 Snowflake 的 ID 生成方案中,除了最高位,其他四个部分都是无符号整数,所以四个部分的整数 使用补码进行位运算的效率会比较高,也只有这样才能满足Snowflake高性能设计的初衷 。 Snowflake 算法中使用了几种位运算:异或( ^ )、按位与( & )、按位或( | )和带符号左移( << )。
异或
异或的运算规则是: 0^0=0 0^1=1 1^0=1 1^1=0 ,也就是位不同则结果为1,位相同则结果为0。主要作用是:
- 特定位翻转,也就是一个数和 N 个位都为 1 的数进行异或操作,这对应的 N 个位都会翻转,例如 0100 & 1111 ,结果就是 1011 。
- 与 0 项异或,则结果和原来的值一致。
- 两数的值交互: a=a^b b=b^a a=a^b ,这三个操作完成之后, a 和 b 的值完成交换。
这里推演一下最后一条:
* [+ 11] 原码 = [0000 1011] 补码 = [0000 1011] a * [- 11] 原码 = [1000 1011] 补码 = [1111 0101] b a=a^b 0000 1011 1111 0101 ---------^ 1111 1110 b=b^a 1111 0101 ---------^ 0000 1011 (十进制数:11) b a=a^b 1111 1110 ---------^ 1111 0101 (十进制数:-11) a
按位与
按位与的运算规则是: 0^0=0 0^1=0 1^0=0 1^1=1 ,只有对应的位都为1的时候计算结果才是1,其他情况的计算结果都是0。主要作用是:
- 清零,如果想把一个数清零,那么和所有位为 0 的数进行按位与即可。
- 取一个数中的指定位,例如要取 X 中的低 4 位,只需要和 zzzz...1111 进行按位与即可,例如取 1111 0110 的低 4 位,则 11110110 & 00001111 即可得到 00000110 。
按位或
按位与的运算规则是: 0^0=0 0^1=1 1^0=1 1^1=1 ,只要有其中一个位存在1则计算结果是1,只有两个位同时为0的情况下计算结果才是0。主要作用是:
- 对一个数的部分位赋值为 1 ,只需要和对应位全为 0 的数做按位或操作就行,例如 1011 0000 如果低 4 位想全部赋值为 1 ,那么 10110000 | 00001111 即可得到 1011 1111 。
带符号左移
带符号左移的运算符是 << ,一般格式是: M << n 。作用如下:
- M 的二进制数(补码)向左移动 n 位。
- 左边(高位)移出部分直接舍弃,右边(低位)移入部分全部补 0 。
- 移位结果:相当于 M 的值乘以 2 的 n 次方,并且0、正、负数通用。
- 移动的位数超过了该类型的最大位数,那么编译器会对移动的位数取模,例如 int 移位 33 位,实际上只移动了 33 % 2 = 1 位。
推演过程如下(假设 n = 2 ):
* [+ 11] 原码 = [0000 1011] 补码 = [0000 1011] * [- 11] 原码 = [1000 1011] 补码 = [1111 0101] * [+ 11 << 2]的计算过程 补码 0000 1011 左移2位 0000 1011 舍高补低 0010 1100 十进制数 2^2 + 2^3 + 2^5 = 44 * [- 11 << 2]的计算过程 补码 1111 0101 左移2位 1111 0101 舍高补低 1101 0100 原码 1010 1100 (补码除最高位其他所有位取反再加1) 十进制数 - (2^2 + 2^3 + 2^5) = -44
可以写个 main 方法验证一下:
public static void main(String[] args) { System.out.println(-11 << 2); // -44 System.out.println(11 << 2); // 44 }
组合技巧
利用上面提到的三个位运算符,相互组合可以实现一些高效的计算方案。
计算n个bit能表示的最大数值:
Snowflake 算法中有这样的代码:
// 机器ID的位长度 private val workerIdBits = 5L; // 最大机器ID -> 31 private val maxWorkerId = -1L ^ (-1L << workerIdBits);
这里的算子是 -1L ^ (-1L << 5L) ,整理运算符的顺序,再使用 64 bit 的二进制数推演计算过程如下:
* [-1] 的补码 11111111 11111111 11111111 11111111 11111111 11111111 11111111 11111111 左移5位 11111111 11111111 11111111 11111111 11111111 11111111 11111111 11100000 [-1] 的补码 11111111 11111111 11111111 11111111 11111111 11111111 11111111 11111111 异或 ----------------------------------------------------------------------- ^ 结果的补码 00000000 00000000 00000000 00000000 00000000 00000000 00000000 00011111 (十进制数 2^0 + 2^1 + 2^2 + 2^3 + 2^4 = 31)
这样就能计算出 5 bit 能表示的最大数值 n , n 为整数并且 0 <= n <= 31 ,即 0、1、2、3...31 。 Worker ID 和 Data Center ID 部分的最大值就是使用这种组合运算得出的。
用固定位的最大值作为Mask避免溢出:
Snowflake 算法中有这样的代码:
var sequence = 0L ...... private val sequenceBits = 12L // 这里得到的是sequence的最大值4095 private val sequenceMask = -1L ^ (-1L << sequenceBits) ...... sequence = (sequence + 1) & sequenceMask
最后这个算子其实就是 sequence = (sequence + 1) & 4095 ,假设 sequence 当前值为 4095 ,推演一下计算过程:
* [4095] 的补码 00000000 00000000 00000000 00000000 00000000 00000000 00000111 11111111 [sequence + 1] 的补码 00000000 00000000 00000000 00000000 00000000 00000000 00001000 00000000 按位与 ----------------------------------------------------------------------- & 计算结果 00000000 00000000 00000000 00000000 00000000 00000000 00000000 00000000 (十进制数:0)
可以编写一个 main 方法验证一下:
public static void main(String[] args) { int mask = 4095; System.out.println(0 & mask); // 0 System.out.println(1 & mask); // 1 System.out.println(2 & mask); // 2 System.out.println(4095 & mask); // 4095 System.out.println(4096 & mask); // 0 System.out.println(4097 & mask); // 1 }
也就是 x = (x + 1) & (-1L ^ (-1L << N)) 能保证最终得到的 x 值不会超过 N ,这是利用了按位与中的"取指定位"的特性。
Snowflake算法实现源码分析
Snowflake 虽然用 Scala 语言编写,语法其实和 Java 差不多,当成 Java 代码这样阅读就行,下面阅读代码的时候会跳过一些日志记录和度量统计的逻辑。先看 IdWorker.scala 的属性值:
// 定义基准纪元值,这个值是北京时间2010-11-04 09:42:54,估计就是2010年初版提交代码时候定义的一个时间戳 val twepoch = 1288834974657L // 初始化序列号为0 var sequence = 0L //TODO after 2.8 make this a constructor param with a default of 0 // 机器ID的最大位长度为5 private val workerIdBits = 5L // 数据中心ID的最大位长度为5 private val datacenterIdBits = 5L // 最大的机器ID值,十进制数为为31 private val maxWorkerId = -1L ^ (-1L << workerIdBits) // 最大的数据中心ID值,十进制数为为31 private val maxDatacenterId = -1L ^ (-1L << datacenterIdBits) // 序列号的最大位长度为12 private val sequenceBits = 12L // 机器ID需要左移的位数12 private val workerIdShift = sequenceBits // 数据中心ID需要左移的位数 = 12 + 5 private val datacenterIdShift = sequenceBits + workerIdBits // 时间戳需要左移的位数 = 12 + 5 + 5 private val timestampLeftShift = sequenceBits + workerIdBits + datacenterIdBits // 序列号的掩码,十进制数为4095 private val sequenceMask = -1L ^ (-1L << sequenceBits) // 初始化上一个时间戳快照值为-1 private var lastTimestamp = -1L // 下面的代码块为参数校验和初始化日志打印,这里不做分析 if (workerId > maxWorkerId || workerId < 0) { exceptionCounter.incr(1) throw new IllegalArgumentException("worker Id can't be greater than %d or less than 0".format(maxWorkerId)) } if (datacenterId > maxDatacenterId || datacenterId < 0) { exceptionCounter.incr(1) throw new IllegalArgumentException("datacenter Id can't be greater than %d or less than 0".format(maxDatacenterId)) } log.info("worker starting. timestamp left shift %d, datacenter id bits %d, worker id bits %d, sequence bits %d, workerid %d", timestampLeftShift, datacenterIdBits, workerIdBits, sequenceBits, workerId)
接着看算法的核心代码逻辑:
// 同步方法,其实就是protected synchronized long nextId(){ ...... } protected[snowflake] def nextId(): Long = synchronized { // 获取系统时间戳(毫秒) var timestamp = timeGen() // 高并发场景,同一毫秒内生成多个ID if (lastTimestamp == timestamp) { // 确保sequence + 1之后不会溢出,最大值为4095,其实也就是保证1毫秒内最多生成4096个ID值 sequence = (sequence + 1) & sequenceMask // 如果sequence溢出则变为0,说明1毫秒内并发生成的ID数量超过了4096个,这个时候同1毫秒的第4097个生成的ID必须等待下一毫秒 if (sequence == 0) { // 死循环等待下一个毫秒值,直到比lastTimestamp大 timestamp = tilNextMillis(lastTimestamp) } } else { // 低并发场景,不同毫秒中生成ID // 不同毫秒的情况下,由于外层方法保证了timestamp大于或者小于lastTimestamp,而小于的情况是发生了时钟回拨,下面会抛出异常,所以不用考虑 // 也就是只需要考虑一种情况:timestamp > lastTimestamp,也就是当前生成的ID所在的毫秒数比上一个ID大 // 所以如果时间戳部分增大,可以确定整数值一定变大,所以序列号其实可以不用计算,这里直接赋值为0 sequence = 0 } // 获取到的时间戳比上一个保存的时间戳小,说明时钟回拨,这种情况下直接抛出异常,拒绝生成ID // 个人认为,这个方法应该可以提前到var timestamp = timeGen()这段代码之后 if (timestamp < lastTimestamp) { exceptionCounter.incr(1) log.error("clock is moving backwards. Rejecting requests until %d.", lastTimestamp); throw new InvalidSystemClock("Clock moved backwards. Refusing to generate id for %d milliseconds".format(lastTimestamp - timestamp)); } // lastTimestamp保存当前时间戳,作为方法下次被调用的上一个时间戳的快照 lastTimestamp = timestamp // 度量统计,生成的ID计数器加1 genCounter.incr() // X = (系统时间戳 - 自定义的纪元值) 然后左移22位 // Y = (数据中心ID左移17位) // Z = (机器ID左移12位) // 最后ID = X | Y | Z | 计算出来的序列号sequence ((timestamp - twepoch) << timestampLeftShift) | (datacenterId << datacenterIdShift) | (workerId << workerIdShift) | sequence } // 辅助方法:获取系统当前的时间戳(毫秒) protected def timeGen(): Long = System.currentTimeMillis() // 辅助方法:获取系统当前的时间戳(毫秒),用死循环保证比传入的lastTimestamp大,也就是获取下一个比lastTimestamp大的毫秒数 protected def tilNextMillis(lastTimestamp: Long): Long = { var timestamp = timeGen() while (timestamp <= lastTimestamp) { timestamp = timeGen() } timestamp }
最后一段逻辑的位操作比较多,但是如果熟练使用位运算操作符,其实逻辑并不复杂,这里可以画个图推演一下:
四个部分的整数完成左移之后,由于空缺的低位都会补充了 0 ,基于按位或的特性,所有低位只要存在 1 ,那么对应的位就会填充为 1 ,由于四个部分的位不会越界分配,所以这里的本质就是: 四个部分左移完毕后最终的数字进行加法计算 。
Snowflake算法改良
Snowflake 算法有几个比较大的问题:
- 低并发场景会产生连续偶数,原因是低并发场景系统时钟总是走到下一个毫秒值,导致序列号重置为 0 。
- 依赖系统时钟,时钟回拨会拒绝生成新的 ID (直接抛出异常)。
- Woker ID 和 Data Center ID 的管理比较麻烦,特别是同一个服务的不同集群节点需要保证每个节点的 Woker ID 和 Data Center ID 组合唯一。
这三个问题美团开源的 Leaf 提供了解决思路,下图截取自 com.sankuai.inf.leaf.snowflake.SnowflakeIDGenImpl :
对应的解决思路是(不进行深入的源码分析,有兴趣可以阅读以下 Leaf 的源码):
- 序列号生成添加随机源,会稍微减少同一个毫秒内能产生的最大 ID 数量。
- 时钟回拨则进行一定期限的等待。
- 使用 Zookeeper 缓存和管理 Woker ID 和 Data Center ID 。
Woker ID 和 Data Center ID 的配置是极其重要的,对于同一个服务(例如支付服务)集群的多个节点,必须配置不同的机器 ID 和数据中心 ID 或者同样的数据中心 ID 和不同的机器 ID ( 简单说就是确保 Woker ID 和 Data Center ID 的组合全局唯一 ),否则在高并发的场景下,在系统时钟一致的情况下,很容易在多个节点产生相同的 ID 值,所以一般的部署架构如下:
管理这两个 ID 的方式有很多种,或者像 Leaf 这样的开源框架引入分布式缓存进行管理,再如笔者所在的创业小团队生产服务比较少,直接把 Woker ID 和 Data Center ID 硬编码在服务启动脚本中,然后把所有服务使用的 Woker ID 和 Data Center ID 统一登记在团队内部知识库中。
自实现简化版Snowflake
如果完全不考虑性能的话,也不考虑时钟回拨、序列号生成等等问题,其实可以把 Snowflake 的位运算和异常处理部分全部去掉,使用 Long.toBinaryString() 方法结合字符串按照 Snowflake 算法思路拼接出 64 bit 的二进制数,再通过 Long.parseLong() 方法转化为 Long 类型。编写一个 main 方法如下:
public class Main { private static final String HIGH = "0"; /** * 2020-08-01 00:00:00 */ private static final long EPOCH = 1596211200000L; public static void main(String[] args) { long workerId = 1L; long dataCenterId = 1L; long seq = 4095; String timestampString = leftPadding(Long.toBinaryString(System.currentTimeMillis() - EPOCH), 41); String workerIdString = leftPadding(Long.toBinaryString(workerId), 5); String dataCenterIdString = leftPadding(Long.toBinaryString(dataCenterId), 5); String seqString = leftPadding(Long.toBinaryString(seq), 12); String value = HIGH + timestampString + workerIdString + dataCenterIdString + seqString; long num = Long.parseLong(value, 2); System.out.println(num); // 某个时刻输出为3125927076831231 } private static String leftPadding(String value, int maxLength) { int diff = maxLength - value.length(); StringBuilder builder = new StringBuilder(); for (int i = 0; i < diff; i++) { builder.append("0"); } builder.append(value); return builder.toString(); } }
然后把代码规范一下,编写出一个简版 Snowflake 算法实现的工程化代码:
// 主键生成器接口 public interface PrimaryKeyGenerator { long generate(); } // 简易Snowflake实现 public class SimpleSnowflake implements PrimaryKeyGenerator { private static final String HIGH = "0"; private static final long MAX_WORKER_ID = 31; private static final long MIN_WORKER_ID = 0; private static final long MAX_DC_ID = 31; private static final long MIN_DC_ID = 0; private static final long MAX_SEQUENCE = 4095; /** * 机器ID */ private final long workerId; /** * 数据中心ID */ private final long dataCenterId; /** * 基准纪元值 */ private final long epoch; private long sequence = 0L; private long lastTimestamp = -1L; public SimpleSnowflake(long workerId, long dataCenterId, long epoch) { this.workerId = workerId; this.dataCenterId = dataCenterId; this.epoch = epoch; checkArgs(); } private void checkArgs() { if (!(MIN_WORKER_ID <= workerId && workerId <= MAX_WORKER_ID)) { throw new IllegalArgumentException("Worker id must be in [0,31]"); } if (!(MIN_DC_ID <= dataCenterId && dataCenterId <= MAX_DC_ID)) { throw new IllegalArgumentException("Data center id must be in [0,31]"); } } @Override public synchronized long generate() { long timestamp = System.currentTimeMillis(); // 时钟回拨 if (timestamp < lastTimestamp) { throw new IllegalStateException("Clock moved backwards"); } // 同一毫秒内并发 if (lastTimestamp == timestamp) { sequence = sequence + 1; if (sequence == MAX_SEQUENCE) { timestamp = untilNextMillis(lastTimestamp); sequence = 0L; } } else { // 下一毫秒重置sequence为0 sequence = 0L; } lastTimestamp = timestamp; // 41位时间戳字符串,不够位数左边补"0" String timestampString = leftPadding(Long.toBinaryString(timestamp - epoch), 41); // 5位机器ID字符串,不够位数左边补"0" String workerIdString = leftPadding(Long.toBinaryString(workerId), 5); // 5位数据中心ID字符串,不够位数左边补"0" String dataCenterIdString = leftPadding(Long.toBinaryString(dataCenterId), 5); // 12位序列号字符串,不够位数左边补"0" String seqString = leftPadding(Long.toBinaryString(sequence), 12); String value = HIGH + timestampString + workerIdString + dataCenterIdString + seqString; return Long.parseLong(value, 2); } private long untilNextMillis(long lastTimestamp) { long timestamp; do { timestamp = System.currentTimeMillis(); } while (timestamp <= lastTimestamp); return timestamp; } private static String leftPadding(String value, int maxLength) { int diff = maxLength - value.length(); StringBuilder builder = new StringBuilder(); for (int i = 0; i < diff; i++) { builder.append("0"); } builder.append(value); return builder.toString(); } public static void main(String[] args) { long epoch = LocalDateTime.of(1970, 1, 1, 0, 0, 0, 0) .toInstant(ZoneOffset.of("+8")).toEpochMilli(); PrimaryKeyGenerator generator = new SimpleSnowflake(1L, 1L, epoch); for (int i = 0; i < 5; i++) { System.out.println(String.format("第%s个生成的ID: %d", i + 1, generator.generate())); } } } // 某个时刻输出如下 第1个生成的ID: 6698247966366502912 第2个生成的ID: 6698248027448152064 第3个生成的ID: 6698248032162549760 第4个生成的ID: 6698248033076908032 第5个生成的ID: 6698248033827688448
通过字符串拼接的写法虽然运行效率低,但是可读性会比较高,工程化处理后的代码可以在实例化时候直接指定 Worker ID 和 Data Center ID 等值,并且这个简易的 Snowflake 实现没有第三方库依赖,拷贝下来可以直接运行。上面的方法使用字符串拼接看起来比较低端,其实最后那部分的按位或, 可以完全转化为加法 :
public class Main { /** * 2020-08-01 00:00:00 */ private static final long EPOCH = 1596211200000L; public static void main(String[] args) { long workerId = 1L; long dataCenterId = 1L; long seq = 4095; long timestampDiff = System.currentTimeMillis() - EPOCH; long num = (long) (timestampDiff * Math.pow(2, 22)) + (long) (dataCenterId * Math.pow(2, 17)) + (long) (workerId * Math.pow(2, 12)) + seq; System.out.println(num); // 某个时刻输出为3248473482862591 } }
这样看起来整个算法都变得简单,不过这里涉及到指数运算和加法运算,效率会比较低。
小结