单元测试

39.1 测试基础

Go 语言内置了轻量级的测试框架,通过 testing 包和 go test 命令提供支持。

测试文件命名规则

  • 测试文件以 _test.go 结尾
  • 测试函数以 Test 开头
  • 测试函数签名为 func TestXxx(t *testing.T)

第一个测试

// calculator.go
package calculator

func Add(a, b int) int {
    return a + b
}

func Subtract(a, b int) int {
    return a - b
}
// calculator_test.go
package calculator

import "testing"

func TestAdd(t *testing.T) {
    result := Add(2, 3)
    if result != 5 {
        t.Errorf("Add(2, 3) = %d; want 5", result)
    }
}

func TestSubtract(t *testing.T) {
    result := Subtract(5, 3)
    if result != 2 {
        t.Errorf("Subtract(5, 3) = %d; want 2", result)
    }
}

运行测试

# 运行当前包的测试
go test

# 运行所有包的测试
go test ./...

# 显示详细输出
go test -v

# 运行特定测试
go test -run TestAdd

# 运行匹配模式的测试
go test -run "TestAdd|TestSubtract"

39.2 测试组织

表格驱动测试

package calculator

import "testing"

func TestAddTableDriven(t *testing.T) {
    tests := []struct {
        name     string
        a, b     int
        expected int
    }{
        {"正数相加", 2, 3, 5},
        {"负数相加", -2, -3, -5},
        {"正负相加", 2, -3, -1},
        {"零值相加", 0, 0, 0},
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            result := Add(tt.a, tt.b)
            if result != tt.expected {
                t.Errorf("Add(%d, %d) = %d; want %d", tt.a, tt.b, result, tt.expected)
            }
        })
    }
}

子测试

package calculator

import "testing"

func TestMath(t *testing.T) {
    t.Run("Add", func(t *testing.T) {
        if Add(2, 3) != 5 {
            t.Error("Add failed")
        }
    })

    t.Run("Subtract", func(t *testing.T) {
        if Subtract(5, 3) != 2 {
            t.Error("Subtract failed")
        }
    })
}

测试辅助函数

package calculator

import "testing"

func TestDivide(t *testing.T) {
    result, err := Divide(10, 2)
    if err != nil {
        t.Fatalf("Divide returned error: %v", err)
    }
    if result != 5 {
        t.Errorf("Divide(10, 2) = %d; want 5", result)
    }
}

func TestDivideByZero(t *testing.T) {
    _, err := Divide(10, 0)
    if err == nil {
        t.Error("Divide by zero should return error")
    }
}

测试辅助方法

package main

import "testing"

func TestHelper(t *testing.T) {
    // t.Helper() 标记为辅助函数,错误报告时会跳过
    assertEqual := func(t testing.TB, got, want int) {
        t.Helper()
        if got != want {
            t.Errorf("got %d, want %d", got, want)
        }
    }

    assertEqual(t, Add(2, 3), 5)
}

39.3 测试覆盖率

生成覆盖率报告

# 查看覆盖率
go test -cover

# 生成覆盖率详情
go test -coverprofile=coverage.out

# 查看覆盖率详情
go tool cover -func=coverage.out

# 在浏览器中查看
go tool cover -html=coverage.out

覆盖率示例

// math.go
package math

func Abs(x int) int {
    if x < 0 {
        return -x
    }
    return x
}

func Max(a, b int) int {
    if a > b {
        return a
    }
    return b
}
// math_test.go
package math

import "testing"

func TestAbs(t *testing.T) {
    tests := []struct {
        input    int
        expected int
    }{
        {1, 1},
        {-1, 1},
        {0, 0},
    }

    for _, tt := range tests {
        result := Abs(tt.input)
        if result != tt.expected {
            t.Errorf("Abs(%d) = %d; want %d", tt.input, result, tt.expected)
        }
    }
}

39.4 基准测试

基本基准测试

package calculator

import "testing"

func BenchmarkAdd(b *testing.B) {
    for i := 0; i < b.N; i++ {
        Add(2, 3)
    }
}

运行基准测试

# 运行基准测试
go test -bench=.

# 运行特定基准测试
go test -bench=BenchmarkAdd

# 指定运行时间
go test -bench=. -benchtime=5s

# 显示内存分配
go test -bench=. -benchmem

基准测试示例

package strings

import (
    "strings"
    "testing"
)

func BenchmarkConcat(b *testing.B) {
    for i := 0; i < b.N; i++ {
        s := ""
        for j := 0; j < 100; j++ {
            s += "x"
        }
    }
}

func BenchmarkStringBuilder(b *testing.B) {
    for i := 0; i < b.N; i++ {
        var builder strings.Builder
        for j := 0; j < 100; j++ {
            builder.WriteString("x")
        }
        builder.String()
    }
}

并行基准测试

package calculator

import "testing"

func BenchmarkAddParallel(b *testing.B) {
    b.RunParallel(func(pb *testing.PB) {
        for pb.Next() {
            Add(2, 3)
        }
    })
}

39.5 Example 测试

基本示例

package calculator

import "fmt"

func ExampleAdd() {
    result := Add(2, 3)
    fmt.Println(result)
    // Output: 5
}

func ExampleSubtract() {
    result := Subtract(5, 3)
    fmt.Println(result)
    // Output: 2
}

无序输出

func ExamplePrintUnordered() {
    fmt.Println("hello")
    fmt.Println("world")
    // Unordered output:
    // world
    // hello
}

运行示例测试

# 运行示例测试
go test -run Example

# 验证示例输出
go test -v

39.6 Mock 和 Stub

接口 Mock

// repository.go
package user

type Repository interface {
    FindByID(id int) (*User, error)
    Save(user *User) error
}

type User struct {
    ID   int
    Name string
}

type Service struct {
    repo Repository
}

func (s *Service) GetUser(id int) (*User, error) {
    return s.repo.FindByID(id)
}
// repository_test.go
package user

import (
    "errors"
    "testing"
)

type MockRepository struct {
    users map[int]*User
}

func (m *MockRepository) FindByID(id int) (*User, error) {
    if user, ok := m.users[id]; ok {
        return user, nil
    }
    return nil, errors.New("user not found")
}

func (m *MockRepository) Save(user *User) error {
    m.users[user.ID] = user
    return nil
}

func TestService_GetUser(t *testing.T) {
    mockRepo := &MockRepository{
        users: map[int]*User{
            1: {ID: 1, Name: "Alice"},
        },
    }

    service := &Service{repo: mockRepo}

    user, err := service.GetUser(1)
    if err != nil {
        t.Fatalf("GetUser failed: %v", err)
    }

    if user.Name != "Alice" {
        t.Errorf("expected Alice, got %s", user.Name)
    }
}

39.7 测试 Main

TestMain 函数

package main

import (
    "os"
    "testing"
)

func TestMain(m *testing.M) {
    // 测试前设置
    println("测试开始")

    // 运行测试
    code := m.Run()

    // 测试后清理
    println("测试结束")

    os.Exit(code)
}

数据库测试设置

package main

import (
    "database/sql"
    "os"
    "testing"
)

var db *sql.DB

func TestMain(m *testing.M) {
    // 设置测试数据库
    var err error
    db, err = sql.Open("sqlite3", ":memory:")
    if err != nil {
        panic(err)
    }

    // 创建表
    createTables()

    // 运行测试
    code := m.Run()

    // 清理
    db.Close()

    os.Exit(code)
}

func createTables() {
    db.Exec(`CREATE TABLE users (
        id INTEGER PRIMARY KEY,
        name TEXT
    )`)
}

func TestInsertUser(t *testing.T) {
    _, err := db.Exec("INSERT INTO users (name) VALUES (?)", "Alice")
    if err != nil {
        t.Fatal(err)
    }
}

39.8 跳过测试

条件跳过

package main

import (
    "runtime"
    "testing"
)

func TestLinuxOnly(t *testing.T) {
    if runtime.GOOS != "linux" {
        t.Skip("此测试只在 Linux 上运行")
    }
    // Linux 特定测试
}

func TestSlowOperation(t *testing.T) {
    if testing.Short() {
        t.Skip("跳过慢速测试")
    }
    // 耗时测试
}

运行短测试

# 只运行短测试
go test -short

39.9 测试工具库

testify 库

go get github.com/stretchr/testify
package main

import (
    "testing"

    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"
)

func TestWithTestify(t *testing.T) {
    // assert: 失败后继续执行
    assert.Equal(t, 5, Add(2, 3))
    assert.NotNil(t, "hello")

    // require: 失败后立即停止
    require.Equal(t, 5, Add(2, 3))
}

func TestWithSuite(t *testing.T) {
    // 使用测试套件
    suite.Run(t, new(MyTestSuite))
}

type MyTestSuite struct {
    suite.Suite
}

func (s *MyTestSuite) SetupTest() {
    // 每个测试前执行
}

func (s *MyTestSuite) TearDownTest() {
    // 每个测试后执行
}

func (s *MyTestSuite) TestSomething() {
    s.Equal(5, Add(2, 3))
}

39.10 小结

本章详细介绍了 Go 语言的单元测试:

  1. 测试基础:测试文件命名、测试函数编写
  2. 测试组织:表格驱动测试、子测试
  3. 测试覆盖率:生成和分析覆盖率报告
  4. 基准测试:性能测试和内存分析
  5. Example 测试:文档示例测试
  6. Mock 和 Stub:接口模拟和依赖注入
  7. TestMain:全局测试设置和清理
  8. 跳过测试:条件跳过测试
  9. 测试工具:使用 testify 等第三方库

单元测试是保证代码质量的重要手段,掌握测试能让你写出更可靠的代码。在下一章中,我们将学习性能调优。