Go 语言内置了轻量级的测试框架,通过 testing 包和 go test 命令提供支持。
Go 测试的基本约定:
- 测试文件名以 `_test.go` 结尾
- 测试函数名以 `Test` 开头
- 测试函数的签名为 `func TestXxx(t *testing.T)`

// calculator.go
package calculator
// Add returns the sum of a and b.
func Add(a, b int) int {
	sum := a + b
	return sum
}
// Subtract returns the difference a minus b.
func Subtract(a, b int) int {
	diff := a - b
	return diff
}
// calculator_test.go
package calculator
import "testing"
// TestAdd checks that Add(2, 3) yields 5.
func TestAdd(t *testing.T) {
	if got := Add(2, 3); got != 5 {
		t.Errorf("Add(2, 3) = %d; want 5", got)
	}
}
// TestSubtract checks that Subtract(5, 3) yields 2.
func TestSubtract(t *testing.T) {
	if got := Subtract(5, 3); got != 2 {
		t.Errorf("Subtract(5, 3) = %d; want 2", got)
	}
}
# 运行当前包的测试
go test
# 运行所有包的测试
go test ./...
# 显示详细输出
go test -v
# 运行特定测试(-run 的参数是正则表达式;不加 ^ 和 $ 锚点时,TestAdd 还会匹配 TestAddTableDriven 等以它为前缀的测试)
go test -run '^TestAdd$'
# 运行匹配模式的测试
go test -run "TestAdd|TestSubtract"
package calculator
import "testing"
// TestAddTableDriven runs Add against a table of cases. Each case becomes
// its own named subtest, so a failure is reported under that case's name.
func TestAddTableDriven(t *testing.T) {
	cases := []struct {
		name     string
		a, b     int
		expected int
	}{
		{name: "正数相加", a: 2, b: 3, expected: 5},
		{name: "负数相加", a: -2, b: -3, expected: -5},
		{name: "正负相加", a: 2, b: -3, expected: -1},
		{name: "零值相加", a: 0, b: 0, expected: 0},
	}
	for _, tc := range cases {
		// t.Run executes the closure synchronously (no t.Parallel), so
		// capturing tc here is safe even before Go 1.22 loop semantics.
		t.Run(tc.name, func(t *testing.T) {
			if got := Add(tc.a, tc.b); got != tc.expected {
				t.Errorf("Add(%d, %d) = %d; want %d", tc.a, tc.b, got, tc.expected)
			}
		})
	}
}
package calculator
import "testing"
// TestMath groups the Add and Subtract checks as named subtests of one
// test. The subtests run in table order.
func TestMath(t *testing.T) {
	subtests := []struct {
		name string
		pass func() bool
		msg  string
	}{
		{"Add", func() bool { return Add(2, 3) == 5 }, "Add failed"},
		{"Subtract", func() bool { return Subtract(5, 3) == 2 }, "Subtract failed"},
	}
	for _, st := range subtests {
		t.Run(st.name, func(t *testing.T) {
			if !st.pass() {
				t.Error(st.msg)
			}
		})
	}
}
package calculator
import "testing"
// TestDivide checks the success path of Divide: 10 / 2 must not error and
// must yield 5. A non-nil error stops the test right away, because the
// result is not meaningful in that case.
//
// NOTE(review): Divide is not shown here. The %d verb assumes it returns an
// integer result; if it returns a float, use %v instead (go vet's printf
// check rejects %d with a float argument).
func TestDivide(t *testing.T) {
	got, err := Divide(10, 2)
	switch {
	case err != nil:
		t.Fatalf("Divide returned error: %v", err)
	case got != 5:
		t.Errorf("Divide(10, 2) = %d; want 5", got)
	}
}
// TestDivideByZero requires Divide to report an error for a zero divisor.
// The quotient is ignored.
func TestDivideByZero(t *testing.T) {
	if _, err := Divide(10, 0); err == nil {
		t.Error("Divide by zero should return error")
	}
}
package main
import "testing"
// TestHelper demonstrates t.Helper. Because assertEqual marks itself as a
// helper, a failure inside it is reported at the caller's line instead of
// inside the helper.
func TestHelper(t *testing.T) {
	assertEqual := func(tb testing.TB, got, want int) {
		tb.Helper()
		if got == want {
			return
		}
		tb.Errorf("got %d, want %d", got, want)
	}
	assertEqual(t, Add(2, 3), 5)
}
# 查看覆盖率
go test -cover
# 生成覆盖率详情
go test -coverprofile=coverage.out
# 查看覆盖率详情
go tool cover -func=coverage.out
# 在浏览器中查看
go tool cover -html=coverage.out
// math.go
package math
// Abs returns the absolute value of x.
//
// Caveat: signed integer overflow wraps in Go, so negating the most
// negative int yields that same negative value. For that single input the
// result is still negative.
func Abs(x int) int {
	if x >= 0 {
		return x
	}
	return -x
}
// Max returns the larger of a and b. When they are equal, either value is
// the answer.
func Max(a, b int) int {
	if a <= b {
		return b
	}
	return a
}
// math_test.go
package math
import "testing"
// TestAbs checks Abs on a positive, a negative and a zero input. Every case
// is checked even after an earlier one fails.
func TestAbs(t *testing.T) {
	cases := []struct{ input, expected int }{
		{1, 1},
		{-1, 1},
		{0, 0},
	}
	for i := range cases {
		in, want := cases[i].input, cases[i].expected
		if got := Abs(in); got != want {
			t.Errorf("Abs(%d) = %d; want %d", in, got, want)
		}
	}
}
package calculator
import "testing"
// AddSink receives the result of every Add call made by BenchmarkAdd.
// Add(2, 3) is a trivially inlinable constant expression. If its result were
// simply discarded, the compiler would be free to drop the call and the
// benchmark would time an empty loop. Storing the result to a package-level
// variable keeps the work observable.
var AddSink int

// BenchmarkAdd measures the cost of one Add call. The testing framework
// chooses b.N and runs the loop that many times.
func BenchmarkAdd(b *testing.B) {
	for i := 0; i < b.N; i++ {
		AddSink = Add(2, 3)
	}
}
# 运行基准测试
go test -bench=.
# 运行特定基准测试(-bench 的参数同样是正则表达式;不加锚点时 BenchmarkAdd 也会匹配 BenchmarkAddParallel)
go test -bench='^BenchmarkAdd$'
# 指定运行时间
go test -bench=. -benchtime=5s
# 显示内存分配
go test -bench=. -benchmem
package strings
import (
"strings"
"testing"
)
// BenchmarkConcat builds a 100-character string by repeated += concatenation.
// Each append produces a new string. Compare it with BenchmarkStringBuilder
// using -benchmem.
func BenchmarkConcat(b *testing.B) {
	for n := 0; n < b.N; n++ {
		var s string
		for k := 0; k < 100; k++ {
			s += "x"
		}
	}
}
// BenchmarkStringBuilder builds the same 100-character string with a
// strings.Builder, which appends into a growable buffer instead of creating
// a new string per append.
func BenchmarkStringBuilder(b *testing.B) {
	for n := 0; n < b.N; n++ {
		var sb strings.Builder
		for k := 0; k < 100; k++ {
			sb.WriteString("x")
		}
		_ = sb.String()
	}
}
package calculator
import "testing"
// BenchmarkAddParallel measures Add under concurrent use. b.RunParallel
// distributes the b.N iterations across several goroutines, and each one
// keeps calling Add until pb.Next reports that no iterations remain.
//
// NOTE(review): Add's result is discarded, so the compiler may optimize the
// call away. Consider accumulating it per goroutine and publishing it once.
func BenchmarkAddParallel(b *testing.B) {
	work := func(pb *testing.PB) {
		for pb.Next() {
			Add(2, 3)
		}
	}
	b.RunParallel(work)
}
package calculator
import "fmt"
// ExampleAdd documents Add. go test runs it and compares the printed line
// against the Output comment.
func ExampleAdd() {
	fmt.Println(Add(2, 3))
	// Output: 5
}
// ExampleSubtract documents Subtract. go test runs it and compares the
// printed line against the Output comment.
func ExampleSubtract() {
	fmt.Println(Subtract(5, 3))
	// Output: 2
}
// Example_printUnordered demonstrates an "Unordered output" check. The
// printed lines must match the listed lines as a set, in any order.
//
// It uses the package-level Example_suffix form (the suffix must start with
// a lower-case letter). The name ExampleXxx is reserved for documenting an
// identifier Xxx, and no PrintUnordered is defined in the code shown. go
// vet's tests analyzer, which go test runs since Go 1.24, reports such an
// example as referring to an unknown identifier.
func Example_printUnordered() {
	fmt.Println("hello")
	fmt.Println("world")
	// Unordered output:
	// world
	// hello
}
# 运行示例测试
go test -run Example
# 验证示例输出
go test -v
// repository.go
package user
// User is the entity that Repository stores and Service returns.
type User struct {
	ID   int
	Name string
}

// Repository abstracts user persistence. Service depends only on this
// interface, so tests can substitute an in-memory implementation.
type Repository interface {
	Save(user *User) error
	FindByID(id int) (*User, error)
}

// Service holds user-related logic. It reaches storage only through repo.
type Service struct {
	repo Repository
}

// GetUser looks up a user by id. It delegates to the repository and returns
// the repository's user and error unchanged.
func (s *Service) GetUser(id int) (*User, error) {
	user, err := s.repo.FindByID(id)
	return user, err
}
// repository_test.go
package user
import (
"errors"
"testing"
)
// MockRepository is an in-memory Repository for tests, keyed by user ID.
// Its zero value is ready to use, because Save allocates the map on its
// first write.
type MockRepository struct {
	users map[int]*User
}

// FindByID returns the stored user with the given id. If no such user has
// been saved, it returns an error. Reading from a nil map is safe in Go, so
// this also works on a zero-value MockRepository.
func (m *MockRepository) FindByID(id int) (*User, error) {
	if user, ok := m.users[id]; ok {
		return user, nil
	}
	return nil, errors.New("user not found")
}

// Save stores user under user.ID, replacing any earlier entry with the same
// ID. It always returns nil.
func (m *MockRepository) Save(user *User) error {
	// Writing to a nil map panics. Allocating lazily lets tests use
	// &MockRepository{} without pre-populating users.
	if m.users == nil {
		m.users = make(map[int]*User)
	}
	m.users[user.ID] = user
	return nil
}
// TestService_GetUser checks that Service.GetUser returns the user held by
// the repository. It uses a MockRepository seeded with a single user.
func TestService_GetUser(t *testing.T) {
	repo := &MockRepository{users: map[int]*User{
		1: {ID: 1, Name: "Alice"},
	}}
	svc := &Service{repo: repo}

	got, err := svc.GetUser(1)
	if err != nil {
		t.Fatalf("GetUser failed: %v", err)
	}
	if name := got.Name; name != "Alice" {
		t.Errorf("expected Alice, got %s", name)
	}
}
package main
import (
"os"
"testing"
)
// TestMain wraps the whole test run with a setup step and a teardown step.
// The exit code from m.Run is captured first, teardown runs, and then the
// code is passed to os.Exit. Teardown is called explicitly because os.Exit
// does not run deferred calls.
func TestMain(m *testing.M) {
	setup := func() { println("测试开始") }
	teardown := func() { println("测试结束") }

	setup()
	exitCode := m.Run()
	teardown()
	os.Exit(exitCode)
}
package main
import (
"database/sql"
"os"
"testing"
)
// db is the shared database handle used by the tests in this package. It is
// opened in TestMain before any test runs and closed after they all finish.
var db *sql.DB

// TestMain opens an in-memory SQLite database and creates the schema. It
// then runs the package's tests, closes the database, and exits with the
// test run's status code.
//
// NOTE(review): sql.Open("sqlite3", ...) only succeeds if a driver is
// registered under the name "sqlite3", typically through a blank import
// such as _ "github.com/mattn/go-sqlite3". No such import is visible here.
func TestMain(m *testing.M) {
	var err error
	db, err = sql.Open("sqlite3", ":memory:")
	if err != nil {
		panic(err)
	}
	// SQLite gives every new connection to ":memory:" its own empty
	// database. Pinning the pool to a single connection ensures every test
	// sees the tables created below, even if the pool would otherwise open
	// another connection.
	db.SetMaxOpenConns(1)

	createTables()

	code := m.Run()

	// os.Exit skips deferred calls, so close the handle explicitly first.
	db.Close()
	os.Exit(code)
}
// createTables creates the users table in the shared db handle. It panics
// if the statement fails, matching how TestMain treats an open failure.
// Every test depends on this table, and an ignored error would only surface
// later as a confusing "no such table" failure.
func createTables() {
	if _, err := db.Exec(`CREATE TABLE users (
id INTEGER PRIMARY KEY,
name TEXT
)`); err != nil {
		panic(err)
	}
}
// TestInsertUser inserts one row into users through the shared db handle.
// It fails immediately if the statement returns an error.
func TestInsertUser(t *testing.T) {
	const insertSQL = "INSERT INTO users (name) VALUES (?)"
	if _, err := db.Exec(insertSQL, "Alice"); err != nil {
		t.Fatal(err)
	}
}
package main
import (
"runtime"
"testing"
)
// TestLinuxOnly is skipped on every operating system except Linux.
func TestLinuxOnly(t *testing.T) {
	switch runtime.GOOS {
	case "linux":
		// Linux-specific checks go here.
	default:
		t.Skip("此测试只在 Linux 上运行")
	}
}
// TestSlowOperation is skipped when go test is run with -short.
func TestSlowOperation(t *testing.T) {
	if !testing.Short() {
		// Time-consuming checks go here.
		return
	}
	t.Skip("跳过慢速测试")
}
# 只运行短测试
go test -short
# 安装 testify 第三方断言库
go get github.com/stretchr/testify
package main
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestWithTestify contrasts testify's two assertion packages. assert records
// a failure and lets the test continue. require stops the test at the first
// failure.
func TestWithTestify(t *testing.T) {
	const want = 5
	// assert: execution continues after a failed check.
	assert.Equal(t, want, Add(2, 3))
	assert.NotNil(t, "hello")
	// require: a failed check stops the test immediately.
	require.Equal(t, want, Add(2, 3))
}
// MyTestSuite is a testify test suite. Embedding suite.Suite gives its
// methods assertion helpers such as s.Equal.
type MyTestSuite struct {
	suite.Suite
}

// TestWithSuite is the go test entry point that hands MyTestSuite to
// testify's suite runner.
//
// NOTE(review): this code references the testify "suite" package, but the
// import block at the top of this snippet imports only testify's assert and
// require packages. "github.com/stretchr/testify/suite" must also be
// imported for this to compile.
func TestWithSuite(t *testing.T) {
	suite.Run(t, new(MyTestSuite))
}

// SetupTest runs before each test in the suite.
func (s *MyTestSuite) SetupTest() {}

// TearDownTest runs after each test in the suite.
func (s *MyTestSuite) TearDownTest() {}

// TestSomething checks Add using the suite's embedded assertions.
func (s *MyTestSuite) TestSomething() {
	s.Equal(5, Add(2, 3))
}
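本章详细介绍了 Go 语言的单元测试,内容包括:测试函数的编写与运行、表驱动测试与子测试、测试辅助函数(t.Helper)、测试覆盖率、基准测试、示例测试、基于接口的 Mock、在 TestMain 中进行全局初始化与清理、跳过测试,以及 testify 断言库的使用。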
单元测试是保证代码质量的重要手段,掌握测试能让你写出更可靠的代码。在下一章中,我们将学习性能调优。