Go WaitGroup 源码分析

概述

go语言sync库中的WaitGroup是用于等待一个协程或者一组携程。使用Add函数增加计数器,使用Done函数减少计数器。当使用Wait函数等待计数器归零之后则唤醒主携程。需要注意的是:

  • Add和Done函数一定要配对,否则可能发生死锁
  • WaitGroup结构体不能复制

源码分析

WaitGroup 对象

type WaitGroup struct {
    noCopy noCopy
    // 位值:高32位是计数器,低32位是goroution等待计数。
    state1 [12]byte
    // 信号量,用于唤醒goroution
    sema   uint32
}

func (wg *WaitGroup) state() *uint64 {
    if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
        return (*uint64)(unsafe.Pointer(&wg.state1))
    } else {
        return (*uint64)(unsafe.Pointer(&wg.state1[4]))
    }
}

Add,Done,Wait

func (wg *WaitGroup) Add(delta int) {
    // 获取状态码
    statep := wg.state()
    if race.Enabled {
        _ = *statep // trigger nil deref early
        if delta < 0 {
            // Synchronize decrements with Wait.
            race.ReleaseMerge(unsafe.Pointer(wg))
        }
        race.Disable()
        defer race.Enable()
    }
    // 把传入的delta用原子操作加入到statep,
    state := atomic.AddUint64(statep, uint64(delta)<<32)
    // 获取计数器数值
    v := int32(state >> 32)
    // 获取等待数量
    w := uint32(state)
    if race.Enabled && delta > 0 && v == int32(delta) {
        // The first increment must be synchronized with Wait.
        // Need to model this as a read, because there can be
        // several concurrent wg.counter transitions from 0.
        race.Read(unsafe.Pointer(&wg.sema))
    }
    // 计数器小于0 报错
    if v < 0 {
        panic("sync: negative WaitGroup counter")
    }
    
    if w != 0 && delta > 0 && v == int32(delta) {
        panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    // 如果等待为0或者计数器大于0 意味着没有等待或者还有读锁 不需要唤醒goroutine则返回 add操作完毕
    if v > 0 || w == 0 {
        return
    }
    
    if *statep != state {
        panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    // 
    // 唤醒所有等待的线程
    for ; w != 0; w-- {
        runtime_Semrelease(&wg.sema, false)
    }
}

// Done函数 调用了Add函数传入-1 相当于锁的数量减1
func (wg *WaitGroup) Done() {
    wg.Add(-1)
}

func (wg *WaitGroup) Wait() {
    // 获取waitGroup的状态码
    statep := wg.state()
    if race.Enabled {
        _ = *statep // trigger nil deref early
        race.Disable()
    }
    // 循环
    for {
        // 调用load获取状态
        state := atomic.LoadUint64(statep)
        // 获取计数器数值
        v := int32(state >> 32)
        // 获取等待数量
        w := uint32(state)
        
        if v == 0 {
            // Counter is 0, no need to wait.
            if race.Enabled {
                race.Enable()
                race.Acquire(unsafe.Pointer(wg))
            }
            return
        }
        // 添加等待数量 如果cas失败则重新获取状态 避免计数有错
        if atomic.CompareAndSwapUint64(statep, state, state+1) {
            if race.Enabled && w == 0 {
                // Wait must be synchronized with the first Add.
                // Need to model this is as a write to race with the read in Add.
                // As a consequence, can do the write only for the first waiter,
                // otherwise concurrent Waits will race with each other.
                race.Write(unsafe.Pointer(&wg.sema))
            }
            // 阻塞goroutine 等待唤醒
            runtime_Semacquire(&wg.sema)
            if *statep != 0 {
                panic("sync: WaitGroup is reused before previous Wait has returned")
            }
            if race.Enabled {
                race.Enable()
                race.Acquire(unsafe.Pointer(wg))
            }
            return
        }
    }
}

相关推荐