ConcurrentHashMap源码分析

ConcurrentHashMap解决了HashMap的线程不安全问题,在分析之前先介绍一个将HashMap线程安全的方法。利用Collections.synchronizedMAp方法调用内部类SynchronizedMap

HashMap<String, String> map = new HashMap<>();
Map m =  Collections.synchronizedMap(map);
private static class SynchronizedMap<K,V>
        implements Map<K,V>, Serializable {
        private static final long serialVersionUID = 1978198479659022715L;

        private final Map<K,V> m;     // Backing Map
        final Object      mutex;        // Object on which to synchronize

        SynchronizedMap(Map<K,V> m) {
            this.m = Objects.requireNonNull(m);
            mutex = this;
        }

        SynchronizedMap(Map<K,V> m, Object mutex) {
            this.m = m;
            this.mutex = mutex;
        }

内部主要有两个变量,一个普通变量Map,还有一个互斥锁mutex。通过构造方法将外部的Map传入进去,如果没有要传入的mutex,则将引用this赋值给mutex,就产生了一个对象实例锁。之后,要操作Map的时候只要再外部添加一个synchronized关键字即可,很简单,但有时会影响性能。

public V put(K key, V value) {
            synchronized (mutex) {return m.put(key, value);}
        }
        public V remove(Object key) {
            synchronized (mutex) {return m.remove(key);}
        }
        public void putAll(Map<? extends K, ? extends V> map) {
            synchronized (mutex) {m.putAll(map);}
        }
        public void clear() {
            synchronized (mutex) {m.clear();}
        }

ConcurrentHashMap JDK1.7版本

在JDK1.7中,ConcurrentHashMap是由一个Segment数组和多个HashEntry组成,每一个Segment元素存储的是HashEntry数组和链表。它采用的是分段锁技术。其中,Segment继承于ReentrantLock。
ConcurrentHashMap源码分析

//默认的数组大小16(HashMap里的那个数组)
static final int DEFAULT_INITIAL_CAPACITY = 16;

//扩容因子0.75
static final float DEFAULT_LOAD_FACTOR = 0.75f;
 
//ConcurrentHashMap中的数组
final Segment<K,V>[] segments

//默认并发标准16
static final int DEFAULT_CONCURRENCY_LEVEL = 16;

//Segment是ReentrantLock子类,因此拥有锁的操作
 static final class Segment<K,V> extends ReentrantLock implements Serializable {
  //分别是数组、键值对数量、阈值、负载因子
  transient volatile HashEntry<K,V>[] table;
  transient int count;
  transient int threshold;
  final float loadFactor;

  Segment(float lf, int threshold, HashEntry<K,V>[] tab) {
            this.loadFactor = lf;
            this.threshold = threshold;
            this.table = tab;
        }
 }
 
 //HashEntry对象,存key、value、hash值以及下一个节点
 static final class HashEntry<K,V> {
        final int hash;
        final K key;
        volatile V value;
        volatile HashEntry<K,V> next;
 }
//segment中HashEntry[]数组最小长度
static final int MIN_SEGMENT_TABLE_CAPACITY = 2;

//用于定位在segments数组中的位置,下面介绍
final int segmentMask;
final int segmentShift;

变量concurrentLevel表示并发数,默认是16,理论上最多可以同时支持16个线程并发写,只要它们的操作分别分布在不同的Segment上。这个值可以在初始化的时候设置为其他值,但是一旦初始化后,它是不可以扩容的。

public ConcurrentHashMap(int initialCapacity,
                         float loadFactor, int concurrencyLevel) {
    if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
        throw new IllegalArgumentException();
    if (concurrencyLevel > MAX_SEGMENTS)
        concurrencyLevel = MAX_SEGMENTS;
    // Find power-of-two sizes best matching arguments
    int sshift = 0;
    int ssize = 1;
    // 计算并行级别 ssize,因为要保持并行级别是 2 的 n 次方
    while (ssize < concurrencyLevel) {
        ++sshift;
        ssize <<= 1;
    }
    // 我们这里先不要那么烧脑,用默认值,concurrencyLevel 为 16,sshift 为 4
    // 那么计算出 segmentShift 为 28,segmentMask 为 15,后面会用到这两个值
    this.segmentShift = 32 - sshift;
    this.segmentMask = ssize - 1;

    if (initialCapacity > MAXIMUM_CAPACITY)
        initialCapacity = MAXIMUM_CAPACITY;

    // initialCapacity 是设置整个 map 初始的大小,
    // 这里根据 initialCapacity 计算 Segment 数组中每个位置可以分到的大小
    // 如 initialCapacity 为 64,那么每个 Segment 或称之为"槽"可以分到 4 个
    int c = initialCapacity / ssize;
    if (c * ssize < initialCapacity)
        ++c;
    // 默认 MIN_SEGMENT_TABLE_CAPACITY 是 2,这个值也是有讲究的,因为这样的话,对于具体的槽上,
    // 插入一个元素不至于扩容,插入第二个的时候才会扩容
    int cap = MIN_SEGMENT_TABLE_CAPACITY; 
    while (cap < c)
        cap <<= 1;

    // 创建 Segment 数组,
    // 并创建数组的第一个元素 segment[0]
    Segment<K,V> s0 =
        new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
                         (HashEntry<K,V>[])new HashEntry[cap]);
    Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize];
    // 往数组写入 segment[0]
    UNSAFE.putOrderedObject(ss, SBASE, s0); // ordered write of segments[0]
    this.segments = ss;
}

初始化可以得到:
1)Segment数组长度为16,不可以扩容
2)每个Segment元素的默认长度大小为2,负载因子为0.75,得出初始阈值为1.5,当插入第二个值时会进行第一次扩容
3)初始化了segment[0],其他位置还是null。
4)当前segmentShift的值为32-4=28,segmentMask为16-1=15,先简单分别称它们为移位码和掩码。

JDK1.7的put操作

ConcurrentHashMap源码分析

public V put(K key, V value) {
        Segment<K,V> s;
        //注意valus不能为空!!!
        if (value == null)
            throw new NullPointerException();
        //根据key计算hash值,key也不能为null,否则hash(key)报空指针
        int hash = hash(key);
        //根据hash值计算在segments数组中的位置
        int j = (hash >>> segmentShift) & segmentMask;
        //查看当前数组中指定位置Segment是否为空
        //若为空,先创建初始化Segment再put值,不为空,直接put值。
        if ((s = (Segment<K,V>)UNSAFE.getObject          // nonvolatile; recheck
             (segments, (j << SSHIFT) + SBASE)) == null) //  in ensureSegment
            s = ensureSegment(j);
        return s.put(key, hash, value, false);
    }

当key为空时,会抛出异常。根据hash来找到对应的Segment,然后执行Segment内部的put操作。
如果定位到的segment[j] 是空的,没有初始化,需要在该位置初始化一个Segmentj,调用方法ensureSegment, 要初始化的下标j是大于0的,因为segment[0]在构造函数中已经初始化了,不会为空。

private Segment<K,V> ensureSegment(int k) {
    final Segment<K,V>[] ss = this.segments;
    long u = (k << SSHIFT) + SBASE; // raw offset
    Segment<K,V> seg;
    if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
        // 这里看到为什么之前要初始化 segment[0] 了,
        // 使用当前 segment[0] 处的数组长度和负载因子来初始化 segment[k]
        // 为什么要用“当前”,因为 segment[0] 可能早就扩容过了
        Segment<K,V> proto = ss[0];
        int cap = proto.table.length;
        float lf = proto.loadFactor;
        int threshold = (int)(cap * lf);

        // 初始化 segment[k] 内部的数组
        HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
        if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
            == null) { // 再次检查一遍该槽是否被其他线程初始化了。

            Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
            // 使用 while 循环,内部用 CAS,当前线程成功设值或其他线程成功设值后,退出
            while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
                   == null) {
                if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
                    break;
            }
        }
    }
    return seg;
}

考虑到并发,会利用CAS机制来进行初始化,加载因子和数组长度和Segment[0]一致。之后,就进入这个segment进行put操作。

final V put(K key, int hash, V value, boolean onlyIfAbsent) {
            //步骤① start
            HashEntry<K,V> node = tryLock() ? null :
                scanAndLockForPut(key, hash, value);
            //步骤① end
            V oldValue;
            try {
                //步骤② start
                //获取Segment中的HashEntry[]
                HashEntry<K,V>[] tab = table;
                //算出在HashEntry[]中的位置
                int index = (tab.length - 1) & hash;
                //找到HashEntry[]中的指定位置的第一个节点
                HashEntry<K,V> first = entryAt(tab, index);
                for (HashEntry<K,V> e = first;;) {
                    //如果不为空,遍历这条链
                    if (e != null) {
                        K k;
                        //情况① 之前已存过,则替换原值
                        if ((k = e.key) == key ||
                            (e.hash == hash && key.equals(k))) {
                            oldValue = e.value;
                            if (!onlyIfAbsent) {
                                e.value = value;
                                ++modCount;
                            }
                            break;
                        }
                        e = e.next;
                    }
                    else {
                        //情况② 另一个线程的准备工作
                        if (node != null)
                            //链表头插入方式
                            node.setNext(first);
                        else //情况③ 该位置为空,则新建一个节点(注意这里采用链表头插入方式)
                            node = new HashEntry<K,V>(hash, key, value, first);
                        //键值对数量+1
                        int c = count + 1;
                        //如果键值对数量超过阈值
                        if (c > threshold && tab.length < MAXIMUM_CAPACITY)
                            //扩容
                            rehash(node);
                        else //未超过阈值,直接放在指定位置
                            setEntryAt(tab, index, node);
                        ++modCount;
                        count = c;
                        //插入成功返回null
                        oldValue = null;
                        break;
                    }
                }
            //步骤② end
            } finally {
                //步骤③
                //解锁
                unlock();
            }
            //修改成功,返回原值
            return oldValue;
        }

在执行put操作时首先调用tryLock尝试获得锁,如果获取失败就说明有其他线程竞争,则利用scanAndLockForPut()通过自旋获取锁。在里面如果重试的次数达到了max_scan_retries则改为阻塞锁获取,保证能获得成功。之后就是按照hashmap的1.7版本的put操作那样插入数据,即是按头插法插入的。最后是解锁。

// 方法参数上的 node 是这次扩容后,需要添加到新的数组中的数据。
private void rehash(HashEntry<K,V> node) {
    HashEntry<K,V>[] oldTable = table;
    int oldCapacity = oldTable.length;
    // 2 倍
    int newCapacity = oldCapacity << 1;
    threshold = (int)(newCapacity * loadFactor);
    // 创建新数组
    HashEntry<K,V>[] newTable =
        (HashEntry<K,V>[]) new HashEntry[newCapacity];
    // 新的掩码,如从 16 扩容到 32,那么 sizeMask 为 31,对应二进制 ‘000...00011111’
    int sizeMask = newCapacity - 1;

    // 遍历原数组,老套路,将原数组位置 i 处的链表拆分到 新数组位置 i 和 i+oldCap 两个位置
    for (int i = 0; i < oldCapacity ; i++) {
        // e 是链表的第一个元素
        HashEntry<K,V> e = oldTable[i];
        if (e != null) {
            HashEntry<K,V> next = e.next;
            // 计算应该放置在新数组中的位置,
            // 假设原数组长度为 16,e 在 oldTable[3] 处,那么 idx 只可能是 3 或者是 3 + 16 = 19
            int idx = e.hash & sizeMask;
            if (next == null)   // 该位置处只有一个元素,那比较好办
                newTable[idx] = e;
            else { // Reuse consecutive sequence at same slot
                // e 是链表表头
                HashEntry<K,V> lastRun = e;
                // idx 是当前链表的头结点 e 的新位置
                int lastIdx = idx;

                // 下面这个 for 循环会找到一个 lastRun 节点,这个节点之后的所有元素是将要放到一起的
                for (HashEntry<K,V> last = next;
                     last != null;
                     last = last.next) {
                    int k = last.hash & sizeMask;
                    if (k != lastIdx) {
                        lastIdx = k;
                        lastRun = last;
                    }
                }
                // 将 lastRun 及其之后的所有节点组成的这个链表放到 lastIdx 这个位置
                newTable[lastIdx] = lastRun;
                // 下面的操作是处理 lastRun 之前的节点,
                //    这些节点可能分配在另一个链表中,也可能分配到上面的那个链表中
                for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {
                    V v = p.value;
                    int h = p.hash;
                    int k = h & sizeMask;
                    HashEntry<K,V> n = newTable[k];
                    newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
                }
            }
        }
    }
    // 将新来的 node 放到新数组中刚刚的 两个链表之一 的 头部
    int nodeIndex = node.hash & sizeMask; // add the new node
    node.setNext(newTable[nodeIndex]);
    newTable[nodeIndex] = node;
    table = newTable;
}

上面的代码先找出扩容前后需要转移的节点,先执行转移,然后在把该条链上剩下的节点转移。整体的put流程图如下:

get方法分析
1)计算hash值,找到segment数组中的具体位置
2)槽中也是一个数组,根据hash找到数组中的具体位置
3)这个时候获取到的是链表了,顺着链表进行查找就可以了
get方法比较简单,它能够实现无锁化操作的主要原因是使用UNSAFE对象的getObjectVolatile()方法提供原子语义,来获取segment和头节点。

jDK 1.8版本

在jdk1.8版本中ConcurrentHashMap利用CAS+Sychronized来确保线程安全,它的底层数组结构依然是数组+链表+红黑树
重要属性

//存放node的数组,大小是2的幂次方
    transient volatile Node[] table;
    //扩容时用于存放数据的变量,平时为null
    private transient volatile Node[] nextTable;
    //通过CAS更新,记录容器的容量大小
    private transient volatile long baseCount;
    /**
     * 控制标志符
     * 负数: 代表正在进行初始化或扩容操作,其中-1表示正在初始化,-N 表示有N-1个线程正在进行扩容操作
     * 正数或0: 代表hash表还没有被初始化,这个数值表示初始化或下一次进行扩容的大小,类似于扩容阈值
     * 它的值始终是当前ConcurrentHashMap容量的0.75倍,这与loadfactor是对应的。
     * 实际容量 >= sizeCtl,则扩容
     */
    private transient volatile int sizeCtl;
    //下次transfer方法的起始下标index加上1之后的值
    private transient volatile int transferIndex;
    //CAS自旋锁标志位
    private transient volatile int cellsBusy;
    //counter cell表,长度总为2的幂次
    private transient volatile CounterCell[] counterCells;

重要内部类
Node节点类

static class Node implements Map.Entry {
        final int hash;
        final K key;
        volatile V val;
        volatile Node next;
        ...
   }

value和next属性用volatile修饰保证了内存可见性,没有setValue方法直接改变Node的value属性

static final class ForwardingNode extends Node {
        final Node[] nextTable;
        //ForwardingNode节点hash为-1,若操作中遇到此类型节点,表明有线程正在扩容
        ForwardingNode(Node[] tab) {
            super(MOVED, null, null, null);
            this.nextTable = tab;
        }
        ...
    }

ForwardingNode是一种临时节点只有扩容时使用,表明当前桶已做过处理。
initTable方法

private final Node[] initTable() {
        Node[] tab; int sc;
        while ((tab = table) == null || tab.length == 0) {
            //若sizeCtl<0,即存在其他线程正在初始化操作,确保只有一个线程进行初始化
            if ((sc = sizeCtl) < 0)
                Thread.yield(); // lost initialization race; just spin
            //利用CAS方法把sizectl的值置为-1,表明已有线程进行初始化
            else if (U.compareAndSwapInt(this, SIZECTL, sc, -1)) {
                try {
                    if ((tab = table) == null || tab.length == 0) {
                        //获得桶容量
                        int n = (sc > 0) ? sc : DEFAULT_CAPACITY;
                        @SuppressWarnings("unchecked")
                        //初始化node数组
                        Node[] nt = (Node[])new Node[n];
                        table = tab = nt;
                        //计算扩容阈值0.75n
                        sc = n - (n >>> 2);
                    }
                } finally {
                    sizeCtl = sc;
                }
                break;
            }
        }
        return tab;
    }

只有一个线程参与初始化过程,其他线程必须挂起;构造函数不初始化过程,初始化真正是在put操作触发。
当sizeCtl为-1时表明已有一个线程正在执行初始化操作,当前线程要执行Thread.yield()操作让出CPU时间片。而正在进行初始化的线程会利用CAS操作将sizeCtl改为-1,创建出一个数组后,并将sizeCtl赋值为当前可用的数组大小。

整体流程:
1)首先对于每一个放入的值,首先利用spread方法对key的hashcode进行一次hash计算,获取在table数组的索引下标地址
2)如果当前table还没有初始化,先调用initTable()进行初始化
3)如果该位置为null,说明还没有数据放入,则利用CAS操作直接放入
4)如果不为空,说明存在哈希碰撞,当fh==MOED(-1),说明数组正在扩容
5)当数组不在扩容状态,对该节点利用sychronized加锁,然后再进行一次判断当前节点是否发生变化,没有变化执行下面的方法;发生了变化直接跳转到第8步
6)如果是链表节点(fh>0),开始遍历链表节点,如果key相等,则进行值覆盖;如果都没有到节点尾部插入新节点
7)如果这个节点类型是TreeBin,利用红黑树的方法插入新的节点。
8)如果链表长度大于8,则利用treeifyBin把这个链表转化为红黑树,但是不是大于8就转化为红黑树,当数组长度小于MIN_TREEIFY_CAPACITY(默认是64)时,进行扩容操作。
9)如果当前实际大小数量+1超过了临界值,就进行扩容
spread()方法
计算hash,主要是将key的hashcode的低16位和高16位进行异或运算,0x7fffffff主要是
用于和负数hash值进行 & 运算,将其转化为正数(绝对值不相等)。

static final int spread(int h) {
        return (h ^ (h >>> 16)) & HASH_BITS;
    }

transfer 扩容操作

总体流程:
1)计算每个线程可以处理的桶区间,默认16
2) 构建一个nextTable,容量是原来的两倍
3)死循环开始,根据一个finishing变量来判断,当为true时表示扩容结束,否则继续扩容
3.1)进入一个while循环,分配数组中一个桶的区间给线程,默认是16。从大到小进行分配。当拿到分配值后,进行i--递减。这个i是数组下标。其中,bound变量是指该线程此次可以处理的区间的最小下标,超过这个下标,就需要重新领取区间或者结束扩容;advance变量是值是否转移到下一个桶,如果为true,表明该桶已经处理好了,向下一个桶推进;如果为false,说明还没有处理好当前桶,不能推进。
3.2)判断扩容是否结束,如果扩容结束,清空临时变量,更新table变量,更新库容阈值
3.3)如果当前桶内没有节点,则通过CAS操作插入到ForwardingNode节点,用于告诉其他线程该桶已经处理了。

else if ((f = tabAt(tab, i)) == null)
                advance = casTabAt(tab, i, null, fwd);

3.4)如果当前桶已经被其他线程处理了,当前线程处理到这个节点时,获得的hash值应该为-1(MOVED),则直接跳过,向前一个桶处理。

else if ((fh = f.hash) == MOVED)
                advance = true; // already processed

4) 如果该桶没有被处理,则开始李勇sychronized加锁,然后再判断一下该桶的头节点是否发生了变化,没有发生变化继续执行。

synchronized (f) {
                    if (tabAt(tab, i) == f) {

4.1)如果该桶存储的是链表的话
4.1.1)因为扩容后与扩容前就增加了一位,只要比较新增的最高位是1还是0即可。int runBit = fh & n;是标识新增的位标志。然后开始对链表进行遍历。lastRun表示该节点及剩余的节点的新位置都是一样的,不需要再向下遍历,只要把这部分的头结点,即lastRun移动到新的位置,就能使剩余的部分都移到了新位置。此时的runBit表示该节点位置的标识,可能是1,也可能是0。

Node<K,V> lastRun = f;
for (Node<K,V> p = f.next; p != null; p = p.next) {
      int b = p.hash & n;
      if (b != runBit) {
         runBit = b;
         lastRun = p;
      }
}

4.1.2)如果最后一个需要移动的节点是到原来的索引下标下,则将低位置头结点ln=lastRun;如果是到新的索引下标下,则将高位置头节点设置为hn = lastRun;

if (runBit == 0) {
    ln = lastRun;
    hn = null;
}
else {
    hn = lastRun;
    ln = null;
}

4.1.3)然后对链表进行遍历,知道最后一个需要移动的节点就终止,将节点分别插入到lnhn,利用头插法插入。

for (Node<K,V> p = f; p != lastRun; p = p.next) {
    int ph = p.hash; K pk = p.key; V pv = p.val;
    if ((ph & n) == 0)
        ln = new Node<K,V>(ph, pk, pv, ln);
    else
        hn = new Node<K,V>(ph, pk, pv, hn);
    }

4.1.4 )分别将lnhn插入到新数组,并将旧数组的该位置的节点变成ForwardingNode类型。之后,设置advance为true,表明该桶处理完了。

setTabAt(nextTab, i, ln);
setTabAt(nextTab, i + n, hn);
setTabAt(tab, i, fwd);
advance = true;

4.2)如果桶存储的是红黑树类型
也是判断是最高位是0还是1,生成两个树lohi,然后判断这个树如果小于6,就转化为链表,如果不是,则处理成标准的红黑树。之后,设置advance为true,表明该桶处理完了。

在旧数组中节点设置为ForwardingNode,表明该节点已经被处理了,里面的nextTable执行新的数组。

get()方法

先通过hash值获取在哪个桶,如果头节点的key相等,则返回值。如果hash小于0,表明该节点是ForwardingNode类型,已经发生了移动,则调用该类型节点的find方法查找;其他情况就是遍历链表记行查询。

public V get(Object key) {
        Node<K,V>[] tab; Node<K,V> e, p; int n, eh; K ek;
        int h = spread(key.hashCode());
        if ((tab = table) != null && (n = tab.length) > 0 &&
            (e = tabAt(tab, (n - 1) & h)) != null) {
            if ((eh = e.hash) == h) {
                if ((ek = e.key) == key || (ek != null && key.equals(ek)))
                    return e.val;
            }
            else if (eh < 0)
                return (p = e.find(h, key)) != null ? p.val : null;
            while ((e = e.next) != null) {
                if (e.hash == h &&
                    ((ek = e.key) == key || (ek != null && key.equals(ek))))
                    return e.val;
            }
        }
        return null;
    }
Node<K,V> find(int h, Object k) {
            // loop to avoid arbitrarily deep recursion on forwarding nodes
            outer: for (Node<K,V>[] tab = nextTable;;) {
                Node<K,V> e; int n;
                if (k == null || tab == null || (n = tab.length) == 0 ||
                    (e = tabAt(tab, (n - 1) & h)) == null)
                    return null;
                for (;;) {
                    int eh; K ek;
                    if ((eh = e.hash) == h &&
                        ((ek = e.key) == k || (ek != null && k.equals(ek))))
                        return e;
                    if (eh < 0) {
                        if (e instanceof ForwardingNode) {
                            tab = ((ForwardingNode<K,V>)e).nextTable;
                            continue outer;
                        }
                        else
                            return e.find(h, k);
                    }
                    if ((e = e.next) == null)
                        return null;
                }
            }
        }

参考文章:

  1. 图解ConcurrentHashMap
  2. 第四天:ConcurrentHashMap全解析(上)
  3. Java并发——ConcurrentHashMap(JDK 1.8)
  4. 深入分析ConcurrentHashMap1.8的扩容实现