CountDownLatch源码解析

CountDownLatch

CountDownLatch基于AQS实现的同步器,允许一个或者多个线程通过await()方法进入阻塞等待,直到一个或者多个线程执行countDown()完成。CountDownLatch在创建时需要传入一个count值,一旦某个或者多个线程调用了await()方法,那么需要等待count值减为0,才能继续执行。

countDown()方法每执行一次,count(state)值减1,直到减为0。一个线程可以多次调用countDown()方法,每次调用都会造成count减1

CountDownLatch在RocketMQ底层通信被大量使用,实现远程调用异步转同步。Netty Client发送消息之前创建一个ResponseFutureReponseFuture中有一个CountDownLatch属性,发送消息之后调用await() ,等待response,当接收到响应之后,调用对应ResponseFutureCountDownLatch#countDown,唤醒阻塞线程。

内部类AQS实现

private static final class Sync extends AbstractQueuedSynchronizer {
    private static final long serialVersionUID = 4982264981922014374L;

    Sync(int count) {
        setState(count);
    }

    int getCount() {
        return getState();
    }

    protected int tryAcquireShared(int acquires) {
        return (getState() == 0) ? 1 : -1;
    }

    protected boolean tryReleaseShared(int releases) {
        // Decrement count; signal when transition to zero
        for (;;) {
            int c = getState();
            if (c == 0)
                return false;
            int nextc = c-1;
            if (compareAndSetState(c, nextc))
                return nextc == 0;
        }
    }
}

构造函数

public CountDownLatch(int count) {
    // count不能为负数
    if (count < 0) throw new IllegalArgumentException("count < 0");
    // 创建同步器,设置state为count
    this.sync = new Sync(count);
}

await

public void await() throws InterruptedException {
    // AQS#acquireSharedInterruptibly -> Sync#tryAcquireShared(如果state=0 返回1,立即返回,线程继续向下执行,如果state != 0, 返回-1,线程进入同步队列,阻塞排队)
    sync.acquireSharedInterruptibly(1);
}

public final void acquireSharedInterruptibly(int arg)
    throws InterruptedException {
    if (Thread.interrupted())
        throw new InterruptedException();
    // 如果state != 0,tryAcquireShared()方法返回-1,说明需要等待其他线程执行countDown()方法,线程进入同步队列阻塞
    // 如果state = 0,tryAcquireShared()方法返回1,线程立即返回,继续向下执行
    if (tryAcquireShared(arg) < 0)
        doAcquireSharedInterruptibly(arg);
}


protected int tryAcquireShared(int acquires) {
    return (getState() == 0) ? 1 : -1;
}

private void doAcquireSharedInterruptibly(int arg)
    throws InterruptedException {
    // 进入同步队列阻塞
    final Node node = addWaiter(Node.SHARED);
    boolean failed = true;
    try {
        // 自旋等待state = 0,等待其他线程执行完毕
        for (;;) {
            final Node p = node.predecessor();
            if (p == head) {
                // 如果state = 0,表明其他同步线程执行完毕,线程阻塞结束
                int r = tryAcquireShared(arg);
                if (r >= 0) {
                    // 更新头节点为自己,并向后唤醒其他阻塞的线程
                    setHeadAndPropagate(node, r);
                    p.next = null; // help GC
                    failed = false;
                    return;
                }
            }
            if (shouldParkAfterFailedAcquire(p, node) &&
                parkAndCheckInterrupt())
                throw new InterruptedException();
        }
    } finally {
        if (failed)
            cancelAcquire(node);
    }
}

countDown

/**
 * count(state)值减1,当减为0时,由于await调用阻塞的线程将被唤醒继续执行
 */
public void countDown() {
    sync.releaseShared(1);
}

public final boolean releaseShared(int arg) {
    if (tryReleaseShared(arg)) { // 将count值减-1,如果count值减1后等于0,返回true,
        // count值减1后等于0,唤醒在同步队列上等待的第一个线程,第一个线程会向后传播,唤醒后驱节点(doAcquireSharedInterruptibly)
        doReleaseShared(); 
        return true;
    }
    return false;
}

/**
 * 自旋 + CAS完成更新
 */
protected boolean tryReleaseShared(int releases) {
    // Decrement count; signal when transition to zero
    for (;;) {
        int c = getState();
        if (c == 0)
            return false;
        int nextc = c-1;
        if (compareAndSetState(c, nextc))
            return nextc == 0;
    }
}

 /**
  * count(state)值减1后等于0,会调用该方法,该方法唤醒在同步队列上等待的第一个线程
  */ 
 private void doReleaseShared() {
       
        for (;;) {
            Node h = head;
            if (h != null && h != tail) {
                int ws = h.waitStatus;
                if (ws == Node.SIGNAL) {
                    if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
                        continue;            // loop to recheck cases
                    unparkSuccessor(h);
                }
                else if (ws == 0 &&
                         !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
                    continue;                // loop on failed CAS
            }
            if (h == head)                   // loop if head changed
                break;
        }
    }

获取count

public long getCount() {
    return sync.getCount();
}

相关推荐