并发同步

25.1 为什么需要并发同步

在并发编程中,多个 Goroutine 可能同时访问和修改共享资源,这会导致数据竞争(Data Race)问题。如果不进行适当的同步,程序可能会产生不可预测的结果。

数据竞争示例

package main

import (
    "fmt"
    "sync"
)

func main() {
    var counter int
    var wg sync.WaitGroup

    // 启动 1000 个 Goroutine 同时增加计数器
    for i := 0; i < 1000; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            counter++ // 多个 Goroutine 同时访问,存在数据竞争
        }()
    }

    wg.Wait()
    fmt.Println("计数器值:", counter) // 期望 1000,实际可能小于 1000
}

运行这个程序多次,你会发现结果往往不是期望的 1000。这是因为 counter++ 不是原子操作,它包含三个步骤:

  1. 读取 counter 的值
  2. 将值加 1
  3. 将新值写回 counter

当多个 Goroutine 同时执行这些步骤时,可能会出现覆盖的情况。

使用 go run -race 检测数据竞争

Go 提供了竞态检测工具,可以帮助发现数据竞争:

go run -race main.go

25.2 互斥锁 (Mutex)

互斥锁是最基本的同步原语,它保证同一时间只有一个 Goroutine 可以访问共享资源。

基本使用

package main

import (
    "fmt"
    "sync"
)

func main() {
    var counter int
    var mu sync.Mutex // 互斥锁
    var wg sync.WaitGroup

    for i := 0; i < 1000; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            mu.Lock()   // 加锁
            counter++   // 安全地修改共享资源
            mu.Unlock() // 解锁
        }()
    }

    wg.Wait()
    fmt.Println("计数器值:", counter) // 总是输出 1000
}

使用 defer 确保解锁

package main

import (
    "fmt"
    "sync"
)

func main() {
    var mu sync.Mutex
    data := make(map[string]int)

    // 安全地写入 map
    mu.Lock()
    defer mu.Unlock() // 使用 defer 确保解锁,即使发生 panic 也会执行

    data["key"] = 100
    fmt.Println("写入成功")
}

封装线程安全的计数器

package main

import (
    "fmt"
    "sync"
)

// SafeCounter 是一个线程安全的计数器
type SafeCounter struct {
    mu    sync.Mutex
    value int
}

// Increment 增加计数
func (c *SafeCounter) Increment() {
    c.mu.Lock()
    defer c.mu.Unlock()
    c.value++
}

// Decrement 减少计数
func (c *SafeCounter) Decrement() {
    c.mu.Lock()
    defer c.mu.Unlock()
    c.value--
}

// Value 获取当前值
func (c *SafeCounter) Value() int {
    c.mu.Lock()
    defer c.mu.Unlock()
    return c.value
}

func main() {
    counter := &SafeCounter{}
    var wg sync.WaitGroup

    // 启动 100 个 Goroutine 增加计数
    for i := 0; i < 100; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            counter.Increment()
        }()
    }

    // 启动 50 个 Goroutine 减少计数
    for i := 0; i < 50; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            counter.Decrement()
        }()
    }

    wg.Wait()
    fmt.Printf("最终计数: %d\n", counter.Value()) // 输出: 50
}

死锁问题

使用互斥锁时要注意避免死锁:

package main

import (
    "fmt"
    "sync"
    "time"
)

func main() {
    var mu sync.Mutex

    // 死锁示例:重复加锁
    mu.Lock()
    fmt.Println("第一次加锁")

    // 尝试再次加锁会永久阻塞
    // mu.Lock() // 这行会导致死锁!
    // fmt.Println("第二次加锁")

    mu.Unlock()
    fmt.Println("解锁成功")

    // 正确做法:先解锁再加锁
    mu.Lock()
    fmt.Println("再次加锁成功")
    mu.Unlock()
}

25.3 读写锁 (RWMutex)

当读操作远多于写操作时,使用读写锁可以提高性能。读写锁允许多个 Goroutine 同时读取,但写操作是互斥的。

读写锁的规则

操作是否允许说明
读 + 读✅ 允许多个读操作可以并发执行
读 + 写❌ 禁止读写互斥
写 + 写❌ 禁止写写互斥

基本使用

package main

import (
    "fmt"
    "sync"
    "time"
)

type Cache struct {
    mu    sync.RWMutex
    data  map[string]string
}

func NewCache() *Cache {
    return &Cache{
        data: make(map[string]string),
    }
}

// Get 读取数据(使用读锁)
func (c *Cache) Get(key string) (string, bool) {
    c.mu.RLock()         // 读锁
    defer c.mu.RUnlock() // 解读锁
    value, ok := c.data[key]
    return value, ok
}

// Set 写入数据(使用写锁)
func (c *Cache) Set(key, value string) {
    c.mu.Lock()         // 写锁
    defer c.mu.Unlock() // 解写锁
    c.data[key] = value
}

func main() {
    cache := NewCache()
    var wg sync.WaitGroup

    // 启动多个读协程
    for i := 0; i < 5; i++ {
        wg.Add(1)
        go func(id int) {
            defer wg.Done()
            for j := 0; j < 3; j++ {
                if value, ok := cache.Get("name"); ok {
                    fmt.Printf("读者 %d: %s\n", id, value)
                } else {
                    fmt.Printf("读者 %d: 未找到数据\n", id)
                }
                time.Sleep(100 * time.Millisecond)
            }
        }(i)
    }

    // 启动写协程
    wg.Add(1)
    go func() {
        defer wg.Done()
        names := []string{"Alice", "Bob", "Charlie"}
        for _, name := range names {
            cache.Set("name", name)
            fmt.Printf("写入: %s\n", name)
            time.Sleep(200 * time.Millisecond)
        }
    }()

    wg.Wait()
}

Mutex vs RWMutex 性能对比

package main

import (
    "fmt"
    "sync"
    "time"
)

type Counter struct {
    mu    sync.Mutex
    value int
}

func (c *Counter) Read() int {
    c.mu.Lock()
    defer c.mu.Unlock()
    return c.value
}

func (c *Counter) Write(n int) {
    c.mu.Lock()
    defer c.mu.Unlock()
    c.value = n
}

type RWCounter struct {
    mu    sync.RWMutex
    value int
}

func (c *RWCounter) Read() int {
    c.mu.RLock()
    defer c.mu.RUnlock()
    return c.value
}

func (c *RWCounter) Write(n int) {
    c.mu.Lock()
    defer c.mu.Unlock()
    c.value = n
}

func main() {
    // 测试 Mutex
    counter := &Counter{}
    start := time.Now()
    var wg sync.WaitGroup

    for i := 0; i < 1000; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            for j := 0; j < 100; j++ {
                counter.Read()
            }
        }()
    }
    wg.Wait()
    fmt.Printf("Mutex 读耗时: %v\n", time.Since(start))

    // 测试 RWMutex
    rwCounter := &RWCounter{}
    start = time.Now()

    for i := 0; i < 1000; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            for j := 0; j < 100; j++ {
                rwCounter.Read()
            }
        }()
    }
    wg.Wait()
    fmt.Printf("RWMutex 读耗时: %v\n", time.Since(start))
}

25.4 WaitGroup

WaitGroup 用于等待一组 Goroutine 完成。它有三个方法:

  • Add(n):增加等待计数
  • Done():减少等待计数(等同于 Add(-1))
  • Wait():阻塞直到计数为 0

基本使用

package main

import (
    "fmt"
    "sync"
    "time"
)

func worker(id int, wg *sync.WaitGroup) {
    defer wg.Done() // 完成时通知 WaitGroup

    fmt.Printf("Worker %d 开始工作\n", id)
    time.Sleep(time.Duration(id) * 100 * time.Millisecond)
    fmt.Printf("Worker %d 完成工作\n", id)
}

func main() {
    var wg sync.WaitGroup

    // 启动 5 个 worker
    for i := 1; i <= 5; i++ {
        wg.Add(1) // 在启动 Goroutine 前增加计数
        go worker(i, &wg)
    }

    fmt.Println("等待所有 worker 完成...")
    wg.Wait() // 阻塞直到所有 worker 完成
    fmt.Println("所有 worker 已完成")
}

常见错误示例

package main

import (
    "fmt"
    "sync"
)

func main() {
    var wg sync.WaitGroup

    // 错误示例:在 Goroutine 内部调用 Add
    for i := 0; i < 5; i++ {
        go func() {
            wg.Add(1) // 错误!可能在 Wait 之后才执行
            defer wg.Done()
            fmt.Println("工作完成")
        }()
    }
    wg.Wait() // 可能在 Add 之前就返回了
    fmt.Println("结束")
}

正确做法:

package main

import (
    "fmt"
    "sync"
)

func main() {
    var wg sync.WaitGroup

    for i := 0; i < 5; i++ {
        wg.Add(1) // 正确:在启动 Goroutine 前调用 Add
        go func() {
            defer wg.Done()
            fmt.Println("工作完成")
        }()
    }
    wg.Wait()
    fmt.Println("结束")
}

嵌套 WaitGroup

package main

import (
    "fmt"
    "sync"
    "time"
)

func main() {
    var wg sync.WaitGroup

    // 外层任务
    for i := 0; i < 3; i++ {
        wg.Add(1)
        go func(taskID int) {
            defer wg.Done()
            fmt.Printf("外层任务 %d 开始\n", taskID)

            // 内层任务
            var innerWg sync.WaitGroup
            for j := 0; j < 2; j++ {
                innerWg.Add(1)
                go func(subTaskID int) {
                    defer innerWg.Done()
                    time.Sleep(100 * time.Millisecond)
                    fmt.Printf("  内层任务 %d-%d 完成\n", taskID, subTaskID)
                }(j)
            }
            innerWg.Wait()
            fmt.Printf("外层任务 %d 完成\n", taskID)
        }(i)
    }

    wg.Wait()
    fmt.Println("所有任务完成")
}

25.5 Once

Once 确保某个操作只执行一次,常用于单例模式和初始化操作。

基本使用

package main

import (
    "fmt"
    "sync"
)

var (
    instance *Database
    once     sync.Once
)

type Database struct {
    name string
}

func GetDatabase() *Database {
    once.Do(func() {
        fmt.Println("初始化数据库连接...")
        instance = &Database{name: "MySQL"}
    })
    return instance
}

func main() {
    var wg sync.WaitGroup

    // 多个 Goroutine 同时调用 GetDatabase
    for i := 0; i < 5; i++ {
        wg.Add(1)
        go func(id int) {
            defer wg.Done()
            db := GetDatabase()
            fmt.Printf("协程 %d 获取数据库实例: %s\n", id, db.name)
        }(i)
    }

    wg.Wait()
}

实现单例模式

package main

import (
    "fmt"
    "sync"
)

type Singleton struct {
    data string
}

var (
    singletonInstance *Singleton
    singletonOnce     sync.Once
)

func GetSingleton() *Singleton {
    singletonOnce.Do(func() {
        fmt.Println("创建单例实例")
        singletonInstance = &Singleton{data: "我是单例"}
    })
    return singletonInstance
}

func main() {
    // 多次调用只会创建一次实例
    s1 := GetSingleton()
    s2 := GetSingleton()
    s3 := GetSingleton()

    fmt.Printf("s1: %p\n", s1)
    fmt.Printf("s2: %p\n", s2)
    fmt.Printf("s3: %p\n", s3)
    fmt.Printf("是否相同: %v\n", s1 == s2 && s2 == s3)
}

25.6 Cond

Cond(条件变量)用于在特定条件下等待和通知 Goroutine。

基本方法

  • Wait():等待条件满足,会自动解锁并阻塞
  • Signal():唤醒一个等待的 Goroutine
  • Broadcast():唤醒所有等待的 Goroutine

生产者-消费者模式

package main

import (
    "fmt"
    "sync"
    "time"
)

type Queue struct {
    items []int
    mu    sync.Mutex
    cond  *sync.Cond
}

func NewQueue() *Queue {
    q := &Queue{
        items: make([]int, 0),
    }
    q.cond = sync.NewCond(&q.mu)
    return q
}

func (q *Queue) Put(item int) {
    q.mu.Lock()
    defer q.mu.Unlock()

    q.items = append(q.items, item)
    fmt.Printf("生产: %d (队列长度: %d)\n", item, len(q.items))
    q.cond.Signal() // 通知消费者
}

func (q *Queue) Get() int {
    q.mu.Lock()
    defer q.mu.Unlock()

    // 等待队列不为空
    for len(q.items) == 0 {
        fmt.Println("队列为空,等待...")
        q.cond.Wait()
    }

    item := q.items[0]
    q.items = q.items[1:]
    fmt.Printf("消费: %d (队列长度: %d)\n", item, len(q.items))
    return item
}

func main() {
    queue := NewQueue()
    var wg sync.WaitGroup

    // 启动消费者
    for i := 0; i < 2; i++ {
        wg.Add(1)
        go func(id int) {
            defer wg.Done()
            for j := 0; j < 3; j++ {
                item := queue.Get()
                _ = item
                time.Sleep(100 * time.Millisecond)
            }
        }(i)
    }

    // 启动生产者
    go func() {
        for i := 1; i <= 6; i++ {
            queue.Put(i)
            time.Sleep(50 * time.Millisecond)
        }
    }()

    wg.Wait()
    fmt.Println("所有任务完成")
}

使用 Broadcast 唤醒所有等待者

package main

import (
    "fmt"
    "sync"
    "time"
)

func main() {
    var mu sync.Mutex
    cond := sync.NewCond(&mu)
    ready := false

    var wg sync.WaitGroup

    // 启动多个等待者
    for i := 0; i < 5; i++ {
        wg.Add(1)
        go func(id int) {
            defer wg.Done()

            mu.Lock()
            for !ready {
                fmt.Printf("等待者 %d 正在等待...\n", id)
                cond.Wait()
            }
            fmt.Printf("等待者 %d 被唤醒\n", id)
            mu.Unlock()
        }(i)
    }

    // 等待所有等待者就位
    time.Sleep(1 * time.Second)

    // 改变条件并广播
    mu.Lock()
    ready = true
    cond.Broadcast() // 唤醒所有等待者
    mu.Unlock()

    wg.Wait()
    fmt.Println("所有等待者已处理完毕")
}

25.7 原子操作

Go 的 sync/atomic 包提供了底层的原子操作,对于简单的计数器等场景,原子操作比互斥锁更高效。

基本原子操作

package main

import (
    "fmt"
    "sync"
    "sync/atomic"
)

func main() {
    var counter int64
    var wg sync.WaitGroup

    // 使用原子操作增加计数器
    for i := 0; i < 1000; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            atomic.AddInt64(&counter, 1) // 原子加
        }()
    }

    wg.Wait()
    fmt.Printf("计数器值: %d\n", atomic.LoadInt64(&counter))
}

常用原子操作函数

package main

import (
    "fmt"
    "sync/atomic"
)

func main() {
    var value int64 = 100

    // 加载
    fmt.Printf("当前值: %d\n", atomic.LoadInt64(&value))

    // 存储
    atomic.StoreInt64(&value, 200)
    fmt.Printf("存储后: %d\n", value)

    // 加法
    atomic.AddInt64(&value, 50)
    fmt.Printf("加法后: %d\n", value)

    // 比较并交换 (CAS)
    swapped := atomic.CompareAndSwapInt64(&value, 250, 300)
    fmt.Printf("CAS 结果: %v, 值: %d\n", swapped, value)

    // 交换
    old := atomic.SwapInt64(&value, 500)
    fmt.Printf("旧值: %d, 新值: %d\n", old, value)
}

实现无锁计数器

package main

import (
    "fmt"
    "sync"
    "sync/atomic"
)

type AtomicCounter struct {
    value int64
}

func (c *AtomicCounter) Increment() {
    atomic.AddInt64(&c.value, 1)
}

func (c *AtomicCounter) Decrement() {
    atomic.AddInt64(&c.value, -1)
}

func (c *AtomicCounter) Value() int64 {
    return atomic.LoadInt64(&c.value)
}

func main() {
    counter := &AtomicCounter{}
    var wg sync.WaitGroup

    for i := 0; i < 1000; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            counter.Increment()
        }()
    }

    wg.Wait()
    fmt.Printf("最终计数: %d\n", counter.Value())
}

原子值 (atomic.Value)

atomic.Value 可以原子地存储和加载任意类型的值:

package main

import (
    "fmt"
    "sync"
    "sync/atomic"
    "time"
)

type Config struct {
    Host    string
    Port    int
    Timeout time.Duration
}

func main() {
    var configValue atomic.Value

    // 初始化配置
    configValue.Store(&Config{
        Host:    "localhost",
        Port:    8080,
        Timeout: 30 * time.Second,
    })

    var wg sync.WaitGroup

    // 读取配置的 Goroutine
    for i := 0; i < 5; i++ {
        wg.Add(1)
        go func(id int) {
            defer wg.Done()
            for j := 0; j < 3; j++ {
                cfg := configValue.Load().(*Config)
                fmt.Printf("读者 %d: %s:%d\n", id, cfg.Host, cfg.Port)
                time.Sleep(100 * time.Millisecond)
            }
        }(i)
    }

    // 更新配置的 Goroutine
    go func() {
        time.Sleep(150 * time.Millisecond)
        configValue.Store(&Config{
            Host:    "production.example.com",
            Port:    443,
            Timeout: 60 * time.Second,
        })
        fmt.Println("配置已更新")
    }()

    wg.Wait()
}

25.8 实战案例:线程安全的 Map

Go 的内置 map 不是线程安全的,我们可以使用 sync.RWMutex 封装一个线程安全的 Map:

package main

import (
    "fmt"
    "sync"
)

type SafeMap struct {
    mu   sync.RWMutex
    data map[string]interface{}
}

func NewSafeMap() *SafeMap {
    return &SafeMap{
        data: make(map[string]interface{}),
    }
}

func (m *SafeMap) Set(key string, value interface{}) {
    m.mu.Lock()
    defer m.mu.Unlock()
    m.data[key] = value
}

func (m *SafeMap) Get(key string) (interface{}, bool) {
    m.mu.RLock()
    defer m.mu.RUnlock()
    value, ok := m.data[key]
    return value, ok
}

func (m *SafeMap) Delete(key string) {
    m.mu.Lock()
    defer m.mu.Unlock()
    delete(m.data, key)
}

func (m *SafeMap) Keys() []string {
    m.mu.RLock()
    defer m.mu.RUnlock()
    keys := make([]string, 0, len(m.data))
    for k := range m.data {
        keys = append(keys, k)
    }
    return keys
}

func (m *SafeMap) Len() int {
    m.mu.RLock()
    defer m.mu.RUnlock()
    return len(m.data)
}

func main() {
    sm := NewSafeMap()
    var wg sync.WaitGroup

    // 并发写入
    for i := 0; i < 10; i++ {
        wg.Add(1)
        go func(n int) {
            defer wg.Done()
            key := fmt.Sprintf("key%d", n)
            sm.Set(key, n*10)
            fmt.Printf("写入 %s: %d\n", key, n*10)
        }(i)
    }

    // 并发读取
    for i := 0; i < 5; i++ {
        wg.Add(1)
        go func(n int) {
            defer wg.Done()
            key := fmt.Sprintf("key%d", n)
            if value, ok := sm.Get(key); ok {
                fmt.Printf("读取 %s: %v\n", key, value)
            }
        }(i)
    }

    wg.Wait()
    fmt.Printf("Map 大小: %d\n", sm.Len())
    fmt.Printf("所有键: %v\n", sm.Keys())
}

25.9 sync.Map

Go 1.9 引入了 sync.Map,它是线程安全的 Map,适用于读多写少的场景。

基本使用

package main

import (
    "fmt"
    "sync"
)

func main() {
    var m sync.Map

    // 存储
    m.Store("name", "Alice")
    m.Store("age", 30)
    m.Store("city", "Beijing")

    // 加载
    if value, ok := m.Load("name"); ok {
        fmt.Printf("name: %v\n", value)
    }

    // 加载或存储
    value, loaded := m.LoadOrStore("name", "Bob")
    fmt.Printf("name: %v, 是否已存在: %v\n", value, loaded)

    // 遍历
    m.Range(func(key, value interface{}) bool {
        fmt.Printf("%v: %v\n", key, value)
        return true
    })

    // 删除
    m.Delete("city")

    // 加载并删除
    value, loaded = m.LoadAndDelete("age")
    fmt.Printf("删除的值: %v, 是否存在: %v\n", value, loaded)
}

sync.Map vs 加锁的 map

package main

import (
    "fmt"
    "sync"
    "time"
)

func main() {
    // 测试 sync.Map
    var syncMap sync.Map
    start := time.Now()
    var wg sync.WaitGroup

    for i := 0; i < 1000; i++ {
        wg.Add(1)
        go func(n int) {
            defer wg.Done()
            for j := 0; j < 100; j++ {
                syncMap.Store(fmt.Sprintf("key%d", n), n)
                syncMap.Load(fmt.Sprintf("key%d", n))
            }
        }(i)
    }
    wg.Wait()
    fmt.Printf("sync.Map 耗时: %v\n", time.Since(start))

    // 测试加锁 map
    lockedMap := NewSafeMap()
    start = time.Now()

    for i := 0; i < 1000; i++ {
        wg.Add(1)
        go func(n int) {
            defer wg.Done()
            for j := 0; j < 100; j++ {
                lockedMap.Set(fmt.Sprintf("key%d", n), n)
                lockedMap.Get(fmt.Sprintf("key%d", n))
            }
        }(i)
    }
    wg.Wait()
    fmt.Printf("加锁 Map 耗时: %v\n", time.Since(start))
}

type SafeMap struct {
    mu   sync.RWMutex
    data map[string]interface{}
}

func NewSafeMap() *SafeMap {
    return &SafeMap{
        data: make(map[string]interface{}),
    }
}

func (m *SafeMap) Set(key string, value interface{}) {
    m.mu.Lock()
    defer m.mu.Unlock()
    m.data[key] = value
}

func (m *SafeMap) Get(key string) (interface{}, bool) {
    m.mu.RLock()
    defer m.mu.RUnlock()
    value, ok := m.data[key]
    return value, ok
}

25.10 小结

本章详细介绍了 Go 语言中的并发同步机制:

同步原语用途适用场景
Mutex互斥锁保护共享资源,读写都需要互斥
RWMutex读写锁读多写少的场景
WaitGroup等待组等待一组 Goroutine 完成
Once单次执行初始化、单例模式
Cond条件变量条件等待和通知
atomic原子操作简单的计数器、标志位
sync.Map线程安全 Map读多写少的并发 Map 场景

选择合适的同步原语:

  1. 简单计数器:使用 atomic
  2. 保护共享资源:使用 Mutex
  3. 读多写少:使用 RWMutex
  4. 等待任务完成:使用 WaitGroup
  5. 只执行一次:使用 Once
  6. 条件等待:使用 Cond
  7. 并发 Map:使用 sync.Map