曹工说JDK源码(2)--ConcurrentHashMap的多线程扩容,说白了,就是分段取任务
前言
先预先说明,我这边jdk的代码版本为1.8.0_11,同时,因为我直接在本地jdk源码上进行了部分修改、调试,所以,导致大家看到的我这边贴的代码,和大家的不太一样。
不过,我对源码进行修改、重构时,会保证和原始代码的功能、逻辑严格一致,更多时候,可能只是修改变量名,方便理解。
大家也知道,jdk代码写得实在是比较深奥,变量名经常都是单字符,i,j,k啥的,实在是很难理解,所以,我一般会根据自己的理解,去重命名,为了减轻我们的头脑负担。
至于怎么去修改代码并调试,可以参考我之前的文章:
曹工力荐:调试 jdk 中 rt.jar 包部分的源码(可自由增加注释,修改代码并debug)
文章中,我改过的代码放在:
https://gitee.com/ckl111/jdk-debug
sizeCtl field的初始化
大家知道,concurrentHashMap底层是数组+链表+红黑树,数组的长度假设为n,在hashmap初始化的时候,这个n除了作为数组长度,还会作为另一个关键field的值。
/** * Table initialization and resizing control. When negative, the * table is being initialized or resized: -1 for initialization, * else -(1 + the number of active resizing threads). Otherwise, * when table is null, holds the initial table size to use upon * creation, or 0 for default. After initialization, holds the * next element count value upon which to resize the table. */ private transient volatile int sizeCtl;
该字段非常关键,根据取值不同,有不同的功能。
使用默认构造函数时
public ConcurrentHashMap() { }
此时,sizeCtl被初始化为0.
使用带初始容量的构造函数时
此时,sizeCtl也是32,和容量一致。
使用另一个map来初始化时
public ConcurrentHashMap(Map<? extends K, ? extends V> m) { this.sizeCtl = DEFAULT_CAPACITY; putAll(m); }
此时,sizeCtl,直接使用了默认值,16.
使用初始容量、负载因子来初始化时
public ConcurrentHashMap(int initialCapacity, float loadFactor) { this(initialCapacity, loadFactor, 1); }
这里重载了:
这里,我们传入的负载因子为0.75,这也是默认的负载因子,传入的初始容量为14.
这里面会根据: 1 + 14/0.75 = 19,拿到真正的size,然后根据size,获取到第一个大于19的2的n次方,即32,来作为数组容量,然后sizeCtl也被设置为32.
initTable时,对sizeCtl field的修改
实际上,new一个hashmap的时候,我们并没有创建支撑数组,那,什么时候创建数组呢?是在真正往里面放数据的时候,比如put的时候。
/** Implementation for put and putIfAbsent */ final V putVal(K key, V value, boolean onlyIfAbsent) { if (key == null || value == null) throw new NullPointerException(); int hash = spread(key.hashCode()); int binCount = 0; ConcurrentHashMapPutResultVO vo = new ConcurrentHashMapPutResultVO(); vo.setBinCount(0); for (Node<K,V>[] tab = table;;) { int tableLength; // 1 if (tab == null) { tab = initTable(); continue; } ... }
1处,即会去初始化table。
/** * Initializes table, using the size recorded in sizeCtl. * 初始化hashmap,使用sizeCtl作为容量 */ private final Node<K,V>[] initTable() { Node<K,V>[] tab; int sc; while ((tab = table) == null || tab.length == 0) { sc = sizeCtl; if (sc < 0){ Thread.yield(); // lost initialization race; just spin continue; } /** * 走到这里,说明sizeCtl大于0,大于0,代表什么,可以去看下其构造函数,此时,sizeCtl表示 * capacity的大小。 * {@link #ConcurrentHashMap(int)} * * cas修改为-1,如果成功修改为-1,则表示抢到了锁,可以进行初始化 * */ // 1 boolean bGotChanceToInit = U.compareAndSwapInt(this, SIZECTL, sc, -1); if (bGotChanceToInit) { try { tab = table; /** * 如果当前表为空,尚未初始化,则进行初始化,分配空间 */ if (tab == null || tab.length == 0) { /** * sc大于0,则以sc为准,否则使用默认的容量 */ int n = (sc > 0) ? sc : DEFAULT_CAPACITY; Node<K, V>[] nt = (Node<K, V>[]) new Node<?, ?>[n]; table = tab = nt; /** * n >>> 2,无符号右移2位,则是n的四分之一。 * n- n/4,结果为3/4 * n * 则,这里修改sc为 3/4 * n * 比如,默认容量为16,则修改sc为12 */ // 2 sc = n - (n >>> 2); } } finally { /** * 修改sizeCtl到field */ // 3 sizeCtl = sc; } break; } } return tab; }
- 1处,cas修改sizeCtl为-1,成功了的,获得初始化table的权利
- 2处,修改局部变量sc为: n - (n >>> 2),也就是修改为 0.75n,假设此时的数组容量为16,那么sc就是16 * 0.75 = 12.
- 3处,将sc赋值到field: sizeCtl
经过上面的分析,initTable时,这个字段可能有两种取值:
- -1,有线程正在对该table进行初始化
- 0.75*数组长度,此时,已经初始化完成
上面说的是,在put的时候去initTable,实际上,这个initTable,也会在以下函数中被调用,其共同点就是,都是往里面放数据的操作:
扩容时机
上面说了很多,目前,我们知道的是,在initTable后,sizeCtl的值,是旧的数组的长度 * 0.75。
接下来,我们看看扩容时机,在put时,会调用putVal,这个函数的大体步骤:
final V putVal(K key, V value, boolean onlyIfAbsent) { if (key == null || value == null) throw new NullPointerException(); // 1 int hash = spread(key.hashCode()); int binCount = 0; System.out.println("binCount:" + binCount); // 2 ConcurrentHashMapPutResultVO vo = new ConcurrentHashMapPutResultVO(); vo.setBinCount(0); for (Node<K,V>[] tab = table;;) { int tableLength; // 3 if (tab == null) { tab = initTable(); continue; } tableLength = tab.length; if (tableLength == 0) { tab = initTable(); continue; } int entryNodeHashCode; // 4 int entryNodeIndex = (tableLength - 1) & hash; Node<K,V> entryNode = tabAt(tab,entryNodeIndex); /** * 5 如果我们要放的桶,还是个空的,则直接cas放进去 */ if (entryNode == null) { Node<K, V> node = new Node<>(hash, key, value, null); // no lock when adding to empty bin boolean bSuccess = casTabAt(tab, entryNodeIndex, null, node); if (bSuccess) { break; } else { /** * 如果没成功,则继续下一轮循环 */ continue; } } entryNodeHashCode = entryNode.hash; /** * 6 如果要放的这个桶,正在迁移,则帮助迁移 */ if (entryNodeHashCode == MOVED){ tab = helpTransfer(tab, entryNode); continue; } /** * 7 对entryNode加锁 */ V oldVal = null; System.out.println("sync"); synchronized (entryNode) { /** * 这一行是判断,在我们执行前面的一堆方法的时候,看看entryNodeIndex处的node是否变化 */ if (tabAt(tab, entryNodeIndex) != entryNode) { continue; } /** * 8 hashCode大于0,说明不是处于迁移状态 */ if (entryNodeHashCode >= 0) { /** * 9 链表中找到合适的位置并放入 */ findPositionAndPut(key, value, onlyIfAbsent, hash, vo, entryNode); binCount = vo.getBinCount(); oldVal = (V) vo.getOldValue(); } else if (entryNode instanceof TreeBin) { ... } } System.out.println("binCount:" + binCount); // 10 if (binCount != 0) { if (binCount >= TREEIFY_THRESHOLD) treeifyBin(tab, entryNodeIndex); if (oldVal != null) return oldVal; break; } } // 11 addCount(1L, binCount); return null; }
1处,计算key的hashcode
2处,我这边new了一个对象,里面两个字段:
public class ConcurrentHashMapPutResultVO<V> { int binCount; V oldValue; }
其中,oldValue用来存放,如果put进去的key/value,其中key已经存在的话,一般会直接覆盖之前的旧值,这里主要存放之前的旧值,因为我们需要返回旧值。
binCount,则存放:在找到对应的hash桶之后,在链表中,遍历了多少个元素,该值后面会使用,作为一个标志,当该标志大于0的时候,才去进一步检查,看看是否扩容。
3处,如果table为null,说明table里没有任何一个键值对,数组也还没创建,则初始化table
4处,根据hashcode,和(数组长度 - 1)相与,计算出应该存放的哈希桶在数组中的索引
5处,如果要放的哈希桶,还是空的,则直接cas设置进去,成功则跳出循环,否则重试
6处,如果要放的这个桶,该节点的hashcode为MOVED(一个常量,值为-1),说明有其他线程正在扩容该hashmap,则帮助扩容
7处,对要存放的hash桶的头节点加锁
8处,如果头节点的hashcode大于0,说明是拉了一条链表,则调用子方法(我这边自己抽的),去找到合适的位置并插入到链表
9处,findPositionAndPut,在链表中,找到合适的位置,并插入
10处,在findPositionAndPut函数中,会返回:为了找到合适的位置,遍历了多少个元素,这个值,就是binCount。
如果这个binCount大于8,则说明遍历了8个元素,则需要转红黑树了。
11处,因为我们新增了一个元素,总数自然要加1,这里面会去增加总数,和检查是否需要扩容。
其中,第9步,因为是自己抽的函数,所以这里贴出来给大家看下:
/** * 遍历链表,找到应该放的位置;如果遍历完了还没找到,则放到最后 * @param key * @param value * @param onlyIfAbsent * @param hash * @param vo * @param entryNode */ private void findPositionAndPut(K key, V value, boolean onlyIfAbsent, int hash, ConcurrentHashMapPutResultVO vo, Node<K, V> entryNode) { vo.setBinCount(1); for (Node<K,V> currentIterateNode = entryNode; ; vo.setBinCount(vo.getBinCount() + 1)) { /** * 如果当前遍历指向的节点的hash值,与参数中的key的hash值相等,则, * 继续判断 */ K currentIterateNodeKey = currentIterateNode.key; boolean bKeyEqualOrNot = Objects.equals(currentIterateNodeKey, key); /** * key的hash值相等,且equals比较也相等,则就是我们要找的 */ if (currentIterateNode.hash == hash && bKeyEqualOrNot) { /** * 获取旧的值 */ vo.setOldValue(currentIterateNode.val); /** * 覆盖旧的node的val */ if (!onlyIfAbsent) currentIterateNode.val = value; // 这里直接break跳出循环 break; } /** * 把当前节点保存起来 */ Node<K,V> pred = currentIterateNode; /** * 获取下一个节点 */ currentIterateNode = currentIterateNode.next; /** * 如果下一个节点为null,说明当前已经是链表的最后一个node了 */ if ( currentIterateNode == null) { /** * 则在当前节点后面,挂上新的节点 */ pred.next = new Node<K,V>(hash, key, value, null); break; } } }
第11步,也是我们要看的重点:
private final void addCount(long delta, int check) { CounterCell[] counterCellsArray = counterCells; // 1 long b = baseCount; // 2 long newBaseCount = b + delta; /** * 3 直接cas在baseCount上增加 */ boolean bSuccess = U.compareAndSwapLong(this, BASECOUNT, b, newBaseCount); if ( counterCellsArray != null || !bSuccess) { ... newBaseCount = sumCount(); } // 4 if (check >= 0) { while (true) { Node<K,V>[] tab = table; Node<K,V>[] nt; int n = 0; // 5 int sc = sizeCtl; // 6 boolean bSumExteedSizeControl = newBaseCount >= (long) sc; // 7 boolean bContinue = bSumExteedSizeControl && tab != null && (n = tab.length) < MAXIMUM_CAPACITY; if (bContinue) { int rs = resizeStamp(n); if (sc < 0) { if ((sc >>> RESIZE_STAMP_SHIFT) != rs || sc == rs + 1 || sc == rs + MAX_RESIZERS || (nt = nextTable) == null || transferIndex <= 0) break; if (U.compareAndSwapInt(this, SIZECTL, sc, sc + 1)) transfer(tab, nt); } else if (U.compareAndSwapInt(this, SIZECTL, sc, (rs << RESIZE_STAMP_SHIFT) + 2)) // 8 transfer(tab, null); newBaseCount = sumCount(); } else { break; } } } }
1处,baseCount是一个field,存储当前hashmap中,有多少个键值对,你put一次,就一个;remove一次,就减一个。
2处,b + delta,其中,b就是baseCount,是旧的数量;dalta,我们传入的是1,就是要增加的元素数量
所以,b + delta,得到的,就是经过这次put后,预期的数量
3处,直接cas,修改baseCount这个field为 新值,也就是第二步拿到的值。
4处,这里检查check是否大于0,check,是第二个形参;这个参数,我们外边怎么传的?
addCount(1L, binCount);
不就是bincount吗,也就是说,这里检查:我们在put过程中,在链表中遍历了几个元素,如果遍历了至少1个元素,这里要进入下面的逻辑:检查是否要扩容,因为,你binCount大于0,说明可能已经开始出现哈希冲突了。
5处,取field:sizeCtl的值,给局部变量sc
6处,判断当前的新的键值对总数,是否大于sc了;比如容量是16,那么sizeCtl是12,如果此时,hashmap中存放的键值对已经大于等于12了,则要检查是否扩容了
7处,几个组合条件,查看是否要扩容,其中,主要的条件就是第6步的那个。
8处,调用transfer,进行扩容
总结一下,经过前面的第6处,我们知道,如果存放的键值对总数,已经大于等于0.75*哈希桶(也就是底层数组的长度)的数量了,那么,就基本要扩容了。
扩容的大体过程
扩容也是一个相对复杂的过程,这里只说大概,详细的放下讲。
假设,现在底层数组长度,128,也就是128个哈希桶,当存放的键值对数量,大于等于 128 * 0.75的时候,就会开始扩容,扩容的过程,大概是:
- 申请一个256(容量翻倍)的数组
- 现在有128个桶,相当于,需要对128个桶进行遍历,遍历每个桶拉出去的链表或红黑树,查看每个键值对,是需要放到新数组的什么位置
这个过程,昨天的博文,画了个图,这里再贴一下。
扩容后:
可是,如果我们要一个个去遍历所有哈希桶,然后遍历对应的链表/红黑树,会不会太慢了?完全是单线程工作啊。
换个思路,我们能不能加快点呢?比如,线程1可以去处理数组的 0 -15这16个桶,16- 31这16个桶,完全可以让线程2去做啊,这样的话,不就多线程了吗,不是就快了吗?
没错,jdk就是这么干的。
jdk维护了一个field,这个field,专门用来存当前可以获取的任务的索引,举个例子:
大家看上图就懂了,一开始,这里假设我们有128个桶,每次每个线程,去拿16个桶来处理。
刚开始的时候,field:transferIndex就等于127,也就是最后一个桶的位置,然后我们要从后往前取,那么,127 到112,刚好就是16个桶,所以,申请任务的时候,就会用cas去更新field为112,则表示,自己取到了112 到127这一个区间的hash桶迁移任务。
如果自始至终,只有一个线程呢,它处理完了112 - 127这一批hash桶后,会继续取下一波任务,96 - 112;以此类推。
如果多线程的话呢,也是类似的,反正都是去尝试cas更新transferIndex的值为任务区间的开始下标的值,成功了,就算任务认领成功了。
多线程,怎么知道需要去帮助扩容呢? 发起扩容的线程,在处理完bucket[k]时,会把老的table中的对应的bucket[k]的头节点,修改为下面这种类型的节点:
static final class ForwardingNode<K,V> extends Node<K,V> { final Node<K,V>[] nextTable; ForwardingNode(Node<K,V>[] tab) { super(MOVED, null, null, null); this.nextTable = tab; } }
其他线程,在put或者其他操作时,发现头结点变成了这个,就会去协助扩容了。
多线程扩容,和分段取任务的差别?
我个人感觉,差别不大,多线程扩容,就是多线程去获取自己的那一段任务,然后来完成。我这边写了简单的demo,不过感觉还是很有用的,可以帮助我们理解。
import sun.misc.Unsafe; import java.lang.reflect.Field; import java.util.concurrent.*; import java.util.concurrent.locks.LockSupport; public class ConcurrentTaskFetch { /** * 空闲任务索引,获取任务时,从该下标开始,往前获取。 * 比如当前下标为10,表示tasks数组中,0-10这个区间的任务,没人领取 */ // 0 private volatile int freeTaskIndexForFetch; // 1 private static final int TASK_COUNT_PER_FETCH = 16; // 2 private String[] tasks = new String[128]; public static void main(String[] args) { ConcurrentTaskFetch fetch = new ConcurrentTaskFetch(); // 3 fetch.init(); ThreadPoolExecutor executor = new ThreadPoolExecutor(10, 10, 60, TimeUnit.SECONDS, new ArrayBlockingQueue<>(100)); executor.prestartAllCoreThreads(); CyclicBarrier cyclicBarrier = new CyclicBarrier(10); // 4 for (int i = 0; i < 10; i++) { executor.execute(new Runnable() { @Override public void run() { try { cyclicBarrier.await(); } catch (InterruptedException | BrokenBarrierException e) { e.printStackTrace(); } // 5 FetchedTaskInfo fetchedTaskInfo = fetch.fetchTask(); if (fetchedTaskInfo != null) { System.out.println("thread:" + Thread.currentThread().getName() + ",get task success:" + fetchedTaskInfo); try { TimeUnit.SECONDS.sleep(3); } catch (InterruptedException e) { e.printStackTrace(); } System.out.println("thread:" + Thread.currentThread().getName() + ", process task finished"); } } }); } LockSupport.park(); } public void init() { for (int i = 0; i < 128; i++) { tasks[i] = "task" + i; } freeTaskIndexForFetch = tasks.length; } // 6 public FetchedTaskInfo fetchTask() { System.out.println("Thread start fetch task:"+Thread.currentThread().getName()+",time: "+System.currentTimeMillis()); while (true){ // 6.1 if (freeTaskIndexForFetch == 0) { System.out.println("thread:" + Thread.currentThread().getName() + ",get task failed,there is no task"); return null; } /** * 6.2 获取当前任务的集合的上界 */ int subTaskListEndIndex = this.freeTaskIndexForFetch; /** * 6.3 获取当前任务的集合的下界 */ int subTaskListStartIndex = subTaskListEndIndex > TASK_COUNT_PER_FETCH ? subTaskListEndIndex - TASK_COUNT_PER_FETCH : 0; /** * 6.4 * 现在,我们拿到了集合的上下界,即[subTaskListStartIndex,subTaskListEndIndex) * 该区间为前开后闭,所以,实际的区间为: * [subTaskListStartIndex,subTaskListEndIndex - 1] */ /** * 6.5 使用cas,尝试更新{@link freeTaskIndexForFetch} 为 subTaskListStartIndex */ if (U.compareAndSwapInt(this, FREE_TASK_INDEX_FOR_FETCH, subTaskListEndIndex, subTaskListStartIndex)) { // 6.6 FetchedTaskInfo info = new FetchedTaskInfo(); info.setStartIndex(subTaskListStartIndex); info.setEndIndex(subTaskListEndIndex - 1); return info; } } } // Unsafe mechanics private static final sun.misc.Unsafe U; private static final long FREE_TASK_INDEX_FOR_FETCH; static { try { // U = sun.misc.Unsafe.getUnsafe(); Field f = Unsafe.class.getDeclaredField("theUnsafe"); f.setAccessible(true); U = (Unsafe) f.get(null); Class<?> k = ConcurrentTaskFetch.class; FREE_TASK_INDEX_FOR_FETCH = U.objectFieldOffset (k.getDeclaredField("freeTaskIndexForFetch")); } catch (Exception e) { throw new Error(e); } } static class FetchedTaskInfo{ int startIndex; int endIndex; public int getStartIndex() { return startIndex; } public void setStartIndex(int startIndex) { this.startIndex = startIndex; } public int getEndIndex() { return endIndex; } public void setEndIndex(int endIndex) { this.endIndex = endIndex; } @Override public String toString() { return "FetchedTaskInfo{" + "startIndex=" + startIndex + ", endIndex=" + endIndex + ‘}‘; } } }
0处,定义了一个field,类似于前面的transferIndex
/** * 空闲任务索引,获取任务时,从该下标开始,往前获取。 * 比如当前下标为10,表示tasks数组中,0-10这个区间的任务,没人领取 */ // 0 private volatile int freeTaskIndexForFetch;
1,定义了每次取多少个任务,这里也是16个
private static final int TASK_COUNT_PER_FETCH = 16;
2,定义任务列表,共128个任务
3,main函数中,进行任务初始化
public void init() { for (int i = 0; i < 128; i++) { tasks[i] = "task" + i; } freeTaskIndexForFetch = tasks.length; }
主要初始化任务列表,其次,将freeTaskIndexForFetch 赋值为128,后续取任务,从这个下标开始
4处,启动10个线程,每个线程去执行取任务,按理说,我们128个任务,每个线程取16个,只能有8个线程取到任务,2个线程取不到
5处,线程逻辑里,去获取任务
6处,获取任务的方法定义
6.1 ,如果可获取的任务索引为0了,说明没任务了,直接返回
6.2,获取当前任务的集合的上界
6.3,获取当前任务的集合的下界,减去16就行了
6.4,拿到了集合的上下界,即[subTaskListStartIndex,subTaskListEndIndex)
6.5, 使用cas,更新field为:6.4中的任务下界。
执行效果演示:
可以看到,8个线程取到任务,2个线程没取到。
该思想在内存分配时的应用
其实jvm内存分配时,也是类似的思路,比如,设置堆内存为200m,那这200m是启动时立马从操作系统分配了的。
接下来,就是每次new对象的时候,去这个大内存里,找个小空间,这个过程,也是需要cas去竞争的,比如肯定也有个全局的字段,来表示当前可用内存的索引,比如该索引为100,表示,第100个字节后的空间是可以用的,那我要new个对象,这个对象有3个字段,需要大概30个字节,那我是不是需要把这个索引更新为130。
这中间是多线程的,所以也是要cas操作。
道理都是类似的。
总结
时间仓促,有问题在所难免,欢迎及时指出或加群讨论。