跳至主要内容

[Golang] waitgroup 與 mutex

此篇為各筆記之整理,非原創內容,資料來源可見文後參考資料。

WaitGroup:等待許多任務執行完後再繼續

WaitGroup 的用法適合用在需要將單一任務拆成許多次任務,待所有任務完成後才繼續執行的情境。

提示

這種做法適合用在單純等待任務完成,而不需要從 goroutine 中取得所需資料的情況,如果會需要從 goroutine 中返回資料,那麼比較好的做法是使用 channel。

使用 sync.WaitGroup package 提供的:

  • var wg sync.WaitGroup 可以建立 waitgroup,預設 counter 是 0
  • wg.Add(delta int) 增加要等待的次數(increment counter),也可以是負值,通常就是要等待完成的 goroutine 數目
  • wg.Done() 會把要等待的次數減 1(decrement counter),可以使用 defer wg.Done()
  • wg.Wait() 會阻塞在這,直到 counter 歸零,也就是所有 WaitGroup 都呼叫過 done 後才往後執行
// 程式來源:https://medium.com/rungo/anatomy-of-channels-in-go-concurrency-in-go-1ec336086adb
var start time.Time

func init() {
start = time.Now()
}

func service(wg *sync.WaitGroup, instance int) {
time.Sleep(time.Duration(instance) * 500 * time.Millisecond)
fmt.Println("Service called on instance", instance, time.Since(start))
wg.Done() // 4. 減少 counter
}

func main() {
fmt.Println("main() started ", time.Since(start))
var wg sync.WaitGroup // 1. 建立 waitgroup(empty struct)

for i := 1; i <= 3; i++ {
wg.Add(1) // 2. 增加 counter
go service(&wg, i) // 一共啟動了 3 個 goroutine
}

wg.Wait() // 3. blocking 直到 counter 為 0
fmt.Println("main() stopped ", time.Since(start))
}

這裡的 wg 需要把 pointer 傳進去 goroutine 中,如果不是傳 pointer 進去而是傳 value 的話,將沒辦法有效把 main goroutine 中的 waitGroup 的 counter 減 1。

// 程式碼修改自 Concurrency Patterns in Go: sync.WaitGroup @ https://www.calhoun.io/

func main() {
notify("Service-1", "Service-2", "Service-3")
}

func notifying(wg *sync.WaitGroup, s string) {
fmt.Printf("Starting to notifying %s...\n", s)
time.Sleep(time.Duration(rand.Intn(3)) * time.Second)
fmt.Printf("Finish notifying %s\n", s)
wg.Done() // 執行 done
}

func notify(services ...string) {

var wg sync.WaitGroup // 建立 WaitGroup

for _, service := range services {
wg.Add(1) // 添加 counter 的次數
go notifying(&wg, service)
}

wg.Wait() // block 在這,直到 counter 歸零後才繼續往後執行

fmt.Println("All service notified!")
}

如果我們需要使用到 goroutine 中回傳的資料,那個應該要使用 channel 而不是 waitGroup,例如:

// 程式碼修改自 Concurrency Patterns in Go: sync.WaitGroup @ https://www.calhoun.io/

func main() {
notify("Service-1", "Service-2", "Service-3")
}

func notifying(res chan string, s string) {
fmt.Printf("Starting to notifying %s...\n", s)
time.Sleep(time.Duration(rand.Intn(3)) * time.Second)
res <- fmt.Sprintf("Finish notifying %s", s)
}

func notify(services ...string) {
res := make(chan string)
var count int = 0

for _, service := range services {
count++
go notifying(res, service)
}

for i := 0; i < count; i++ {
fmt.Println(<-res)
}

fmt.Println("All service notified!")
}

Worker Pool

worker pool 指的是有許多的 goroutines 同步的進行一個工作。要建立 worker pool,會先建立許多的 worker goroutine,這些 goroutine 中會:

  • 進行相同的 job
  • 有兩個 channel,一個用來接受任務(task channel),一個用來回傳結果(result channel)
  • 都等待 task channel 傳來要進行的 tasks
  • 一但收到 tasks 就可以做事並透過 result channel 回傳結果
// 程式來源:https://medium.com/rungo/anatomy-of-channels-in-go-concurrency-in-go-1ec336086adb
// STEP 3:在 worker goroutines 中會做相同的工作
// tasks is receive only channel
// results is send only channel
func sqrWorker(tasks <-chan int, results chan<- int, instance int) {
// 一旦收到 tasks channel 傳來資料,就可以動工並回傳結果
for num := range tasks {
time.Sleep(500 * time.Millisecond) // 模擬會阻塞的任務
fmt.Printf("[worker %v] Sending result of task %v \n", instance, num)
results <- num * num
}
}

func main() {
fmt.Println("[main] main() started")

// STEP 1:建立兩個 channel,一個用來傳送 tasks,一個用來接收 results
tasks := make(chan int, 10)
results := make(chan int, 10)

// STEP 2 啟動三個不同的 worker goroutines
for i := 1; i <= 3; i++ {
go sqrWorker(tasks, results, i)
}

// STEP 4:發送 5 個不同的任務
for i := 1; i <= 5; i++ {
tasks <- i // non-blocking(因為 buffered channel 的 capacity 是 10)
}

fmt.Println("[main] Wrote 5 tasks")

// STEP 5:發送完任務後把 channel 關閉(非必要,但可減少 bug)
close(tasks)

// STEP 6:等待各個 worker 從 result channel 回傳結果
for i := 1; i <= 5; i++ {
result := <-results // blocking(因為 buffer 是空的)
fmt.Println("[main] Result", i, ":", result)
}

fmt.Println("[main] main() stopped")
}

輸出的結果可能是:

[main] main() started
[main] Wrote 5 tasks
[worker 3] Sending result of task 1
[main] Result 1 : 1
[worker 2] Sending result of task 3
[worker 1] Sending result of task 2
[main] Result 2 : 9
[main] Result 3 : 4
[worker 2] Sending result of task 5
[worker 3] Sending result of task 4
[main] Result 4 : 25
[main] Result 5 : 16
[main] main() stopped

從上面的例子中,可以看到當所有 worker 都剛好 blocking 的時候,控制權就會交回 main goroutine,這時候就可以立即看到計算好的結果。

WorkerGroup 搭配 WaitGroup

但有些時候,我們希望所有的 tasks 都執行完後才讓 main goroutine 繼續往後做,這時候可以搭配 WaitGroup 使用:

func sqrWorker(wg *sync.WaitGroup, tasks <-chan int, results chan<- int, instance int) {
defer wg.Done()

// 一旦收到 tasks channel 傳來資料,就可以動工並回傳結果
for num := range tasks {
time.Sleep(500 * time.Millisecond) // 模擬會阻塞的任務
fmt.Printf("[worker %v] Sending result of task %v \n", instance, num)
results <- num * num
}
}

func main() {
fmt.Println("[main] main() started")

var wg sync.WaitGroup

tasks := make(chan int, 10)
results := make(chan int, 10)

for i := 1; i <= 3; i++ {
wg.Add(1)
go sqrWorker(&wg, tasks, results, i)
}

for i := 1; i <= 5; i++ {
tasks <- i // non-blocking(因為 buffered channel 的 capacity 是 10)
}

fmt.Println("[main] Wrote 5 tasks")

close(tasks) // 有用 waitGroup 的話這個 close 不能省略

// 直到所有的 worker goroutine 把所有 tasks 都做完後才繼續往後
wg.Wait()

for i := 1; i <= 5; i++ {
result := <-results // blocking(因為 buffer 是空的)
fmt.Println("[main] Result", i, ":", result)
}

fmt.Println("[main] main() stopped")
}

這時會等到所有的 worker 都完工下班後,才開始輸出計算好的結果。搭配 WaitGroup 的好處是可以等到所有 worker 都完工後還讓程式繼續,但相對的會需要花更長的時間在等待所有人完工:

[main] main() started
[main] Wrote 5 tasks
[worker 2] Sending result of task 3
[worker 1] Sending result of task 2
[worker 3] Sending result of task 1
[worker 1] Sending result of task 5
[worker 2] Sending result of task 4
[main] Result 1 : 9
[main] Result 2 : 4
[main] Result 3 : 1
[main] Result 4 : 25
[main] Result 5 : 16
[main] main() stopped

Mutex and Race Condition

在 goroutines 中,由於有獨立的 stack,因此並不會在彼此之間共享資料(也就是在 scope 中的變數);然而在 heap 中的資料是會在不同 goroutine 之間共享的(也就是 global 的變數),在這種情況下,許多的 goroutine 會試著操作相同記憶體位址的資料,導致未預期的結果發生。

從下面的例子中可以看到,我們預期 i 的結果會是 1000,但是因為 race condition 的情況,最終的結果並不會是 1000:

// 程式來源:https://medium.com/rungo/anatomy-of-channels-in-go-concurrency-in-go-1ec336086adb
var i int // i == 0

func worker(wg *sync.WaitGroup) {
i = i + 1
wg.Done()
}

func main() {
var wg sync.WaitGroup

for i := 0; i < 1000; i++ {
wg.Add(1)
go worker(&wg)
}

wg.Wait()
fmt.Println(i) // value i should be 1000 but it did not
}

之所以會有這個情況產生,是因為多個 goroutine 在執行時,在為 i 賦值前(即,i = i + 1)拿到的是同一個值的 i,因此雖然跑了多次 goroutine,但對於 i 來說,值並沒有增加。

為了要避免多個 goroutine 同時取用到一個 heap 中的變數,第一原則是應該要盡可能避免在多個 goroutine 中使用共享的資源(變數)。

如果無法避免會需要操作共用的變數的話,則可以使用 Mutex(mutual exclusion),也就是說在一個時間內只有一個 goroutine(thread)可以對該變數進行操作,在對該變數進行操作前,會先把它「上鎖」,操作完後再進行「解鎖」的動作,當一個變數被上鎖的時候,其他的 goroutine 都不能對該變數進行讀取和寫入

mutex 是 map 型別的方法,被放在 sync package 中,使用 mutex.Lock() 可以上鎖,使用 mutex.Unlock() 可以解鎖:

var i int // i == 0

func worker(wg *sync.WaitGroup, m *sync.Mutex) {
m.Lock() // 上鎖
i = i + 1
m.Unlock() // 解鎖
wg.Done()
}

func main() {
var wg sync.WaitGroup
var m sync.Mutex

for i := 0; i < 1000; i++ {
wg.Add(1)
go worker(&wg, &m) // 把 mutex 的記憶體位址傳入
}

wg.Wait()
fmt.Println(i) // 在使用 mutex 對 heap 中的變數進行上鎖和解鎖後,即可以確保最終的值是 1000
}

⚠️ mutex 和 waitgroup 一樣,都是把「記憶體位址」傳入 goroutine 中使用。

如同前面所說,第一原則應該是要避免 race condition 的方法,也就是不要在 goroutine 中對共用的變數進行操作,在 go 的 CLI 中可以透過下面的指令檢測程式中是否有 race condition 的情況:

# 檢查程式執行時會不會碰到 race condition
$ go run -race .

# 檢查執行的測試會不會碰到 race condition
$ go test -race .

範例程式碼

範例 WaitGroup 搭配 Channel

⚠️ 邏輯上,可以單獨使用 channel 就好,不需要使用 WaitGroup。

func worker(wg *sync.WaitGroup, c chan<- int, i int) {
fmt.Println("[worker] start i:", i)
time.Sleep(time.Second * 1)
defer wg.Done()
c <- i
fmt.Println("[worker] finish i:", i)
}

func main() {
numOfFacilities := 6
var wg sync.WaitGroup

c := make(chan int, numOfFacilities)

for i := 0; i < numOfFacilities; i++ {
fmt.Println("[main] add i: ", i)
wg.Add(1)
go worker(&wg, c, i)
}

wg.Wait()

var numbers []int
for i := 0; i < numOfFacilities; i++ {
numbers = append(numbers, <-c)
}
fmt.Println("[main] ---all finish---", numbers)

defer close(c)
}

範例

func controller(c chan string, wg *sync.WaitGroup) {
fmt.Println("controller() start and block")
wg.Wait()
fmt.Println("controller() unblock and close channel")
close(c)
fmt.Println("controller() end")
}

func printString(s string, c chan string, wg *sync.WaitGroup) {
fmt.Println(s)
wg.Done()
c <- "Done printing: " + s
}

func main() {
fmt.Println("main() start")
c := make(chan string)
var wg sync.WaitGroup
for i := 1; i <= 4; i++ {
go printString("Hello ~ "+strconv.Itoa(i), c, &wg)
wg.Add(1)
}

go controller(c, &wg)

for message := range c {
fmt.Println(message)
}

fmt.Println("main() end")
}