自定义错误与错误包装

28.1 为什么需要自定义错误

在实际项目开发中,标准库提供的错误类型往往不能满足需求。自定义错误可以:

  1. 携带更多上下文信息:如错误码、时间戳、堆栈信息
  2. 实现特定行为:如临时错误判断、超时判断
  3. 统一错误格式:便于 API 返回统一的错误结构
  4. 支持错误分类:便于错误统计和监控

28.2 自定义错误类型

基本自定义错误

package main

import "fmt"

// NotFoundError reports that a resource identified by an integer ID does not exist.
type NotFoundError struct {
    Resource string // human-readable resource name used in the message
    ID       int    // identifier that was looked up
}

// Error implements the error interface, rendering "<Resource> (ID: <ID>) 不存在".
func (nf *NotFoundError) Error() string {
    return fmt.Sprintf("%s (ID: %d) 不存在", nf.Resource, nf.ID)
}

// findUser succeeds (returns nil) for any positive id and returns a *NotFoundError
// for the "用户" resource otherwise.
func findUser(id int) error {
    if id > 0 {
        return nil
    }
    return &NotFoundError{Resource: "用户", ID: id}
}

// main demonstrates that findUser reports a NotFoundError for an invalid ID.
func main() {
    err := findUser(0)
    if err == nil {
        return
    }
    fmt.Println(err)
}

带错误码的错误类型

package main

import (
    "encoding/json"
    "fmt"
)

// ErrorCode is a numeric application error code.
type ErrorCode int

// Application error codes.
const (
    CodeSuccess        ErrorCode = 0
    CodeInvalidParam   ErrorCode = 1001
    CodeNotFound       ErrorCode = 1002
    CodeUnauthorized   ErrorCode = 1003
    CodeInternalError  ErrorCode = 1004
    CodeServiceTimeout ErrorCode = 1005
)

// AppError is an error carrying an ErrorCode, a short message and an optional detail.
// It serializes to JSON with the keys "code", "message" and, when non-empty, "detail".
type AppError struct {
    Code    ErrorCode `json:"code"`
    Message string    `json:"message"`
    Detail  string    `json:"detail,omitempty"`
}

// Error renders the error as "[<code>] <message>".
func (ae *AppError) Error() string {
    return fmt.Sprintf("[%d] %s", ae.Code, ae.Message)
}

// ToJSON returns the JSON encoding of the error. The marshal error is discarded
// because every field is an int or a string, which json.Marshal always encodes.
func (ae *AppError) ToJSON() string {
    encoded, _ := json.Marshal(ae)
    return string(encoded)
}

// NewAppError returns an *AppError with the given code and message and no detail.
func NewAppError(code ErrorCode, message string) *AppError {
    ae := new(AppError)
    ae.Code, ae.Message = code, message
    return ae
}

// WithDetail sets the detail text in place and returns the same *AppError, so it
// can be chained directly after NewAppError.
func (ae *AppError) WithDetail(detail string) *AppError {
    ae.Detail = detail
    return ae
}

// validateUser checks name and age. It returns an *AppError with CodeInvalidParam
// for the first rule that fails (empty name, then age outside 0..150), or nil.
func validateUser(name string, age int) error {
    switch {
    case name == "":
        return NewAppError(CodeInvalidParam, "用户名不能为空").
            WithDetail("name 字段是必填的")
    case age < 0 || age > 150:
        detail := fmt.Sprintf("年龄 %d 不在有效范围内", age)
        return NewAppError(CodeInvalidParam, "年龄不合法").WithDetail(detail)
    default:
        return nil
    }
}

// main validates an empty user name and prints each field of the resulting AppError.
func main() {
    // The comma-ok assertion also yields ok == false when validation succeeds (nil).
    appErr, ok := validateUser("", 25).(*AppError)
    if !ok {
        return
    }
    fmt.Println("错误码:", appErr.Code)
    fmt.Println("错误信息:", appErr.Message)
    fmt.Println("详细信息:", appErr.Detail)
    fmt.Println("JSON 格式:", appErr.ToJSON())
}

带堆栈信息的错误

package main

import (
    "fmt"
    "runtime"
    "strings"
)

// StackError is an error that records the call sites active when it was created
// and, optionally, the error it wraps.
type StackError struct {
    Message string   // description of this layer's failure
    Stack   []string // "file:line" entries, starting at the caller of the constructor
    Cause   error    // wrapped error; nil for a root error
}

// Error renders the message, then one indented "file:line" line per recorded frame,
// then, if a cause is present, "原因: " followed by the cause's own Error() text.
func (se *StackError) Error() string {
    var out strings.Builder
    out.WriteString(se.Message + "\n堆栈信息:\n")
    for _, frame := range se.Stack {
        fmt.Fprintf(&out, "  %s\n", frame)
    }
    if se.Cause == nil {
        return out.String()
    }
    out.WriteString("原因: " + se.Cause.Error())
    return out.String()
}

// Unwrap exposes the wrapped cause to errors.Unwrap, errors.Is and errors.As.
func (se *StackError) Unwrap() error { return se.Cause }

// NewStackError creates a root StackError (no cause) capturing the caller's stack.
// getStack must be called directly from here: it skips a fixed number of frames.
func NewStackError(message string) *StackError {
    se := &StackError{Message: message}
    se.Stack = getStack()
    return se
}

// WrapStackError wraps err with message, capturing the caller's stack.
// getStack must be called directly from here: it skips a fixed number of frames.
func WrapStackError(err error, message string) *StackError {
    se := &StackError{Message: message, Cause: err}
    se.Stack = getStack()
    return se
}

// getStack collects up to 13 "file:line" entries. Skip 0 is getStack itself and
// skip 1 is the constructor that called it, so collection starts at skip 2.
func getStack() []string {
    var frames []string
    skip := 2
    for skip < 15 {
        _, file, line, ok := runtime.Caller(skip)
        if !ok {
            return frames
        }
        frames = append(frames, fmt.Sprintf("%s:%d", file, line))
        skip++
    }
    return frames
}

// level3 is the innermost layer: it always fails with a fresh StackError.
func level3() error {
    return NewStackError("底层错误")
}

// level2 wraps any failure from level3 with its own message and stack.
func level2() error {
    if err := level3(); err != nil {
        return WrapStackError(err, "level2 处理失败")
    }
    return nil
}

// level1 wraps any failure from level2 with its own message and stack.
func level1() error {
    if err := level2(); err != nil {
        return WrapStackError(err, "level1 处理失败")
    }
    return nil
}

// main runs the three-level call chain and prints the full nested error report.
func main() {
    err := level1()
    if err == nil {
        return
    }
    fmt.Println(err)
}

28.3 实现 Unwrap 接口

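Go 1.13 为错误链引入了 Unwrap 方法约定:只要错误类型实现了 Unwrap() error 方法,errors.Unwrap、errors.Is 和 errors.As 就能沿错误链逐层解包。标准库并没有导出名为 Wrapper 的接口,下面的定义只是用来说明这一约定: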

// Wrapper names the method an error implements to expose the error it wraps;
// errors.Unwrap relies on this method when walking an error chain.
// NOTE(review): declared here only to illustrate the convention — callers normally
// rely on the Unwrap method itself rather than a named interface; confirm against
// the errors package documentation.
type Wrapper interface {
    // Unwrap returns the next error in the chain, or nil if there is none.
    Unwrap() error
}

实现 Unwrap

package main

import (
    "errors"
    "fmt"
)

// QueryError reports that a database query failed and wraps the underlying cause.
type QueryError struct {
    Query string // SQL text that failed
    Err   error  // underlying cause, exposed via Unwrap
}

// Error renders "查询失败 [<query>]: <cause>".
func (qe *QueryError) Error() string {
    return fmt.Sprintf("查询失败 [%s]: %v", qe.Query, qe.Err)
}

// Unwrap exposes the cause to errors.Unwrap, errors.Is and errors.As.
func (qe *QueryError) Unwrap() error { return qe.Err }

// queryUser simulates a user lookup that always fails with a connection timeout;
// the id argument is currently unused.
func queryUser(id int) error {
    cause := errors.New("连接超时")
    return &QueryError{Query: "SELECT * FROM users WHERE id = ?", Err: cause}
}

// main prints the full QueryError and then the cause obtained via errors.Unwrap.
func main() {
    err := queryUser(1)
    if err == nil {
        return
    }
    fmt.Println("完整错误:", err)

    // errors.Unwrap calls QueryError.Unwrap to reach the inner error.
    if inner := errors.Unwrap(err); inner != nil {
        fmt.Println("内部错误:", inner)
    }
}

多层错误链

package main

import (
    "errors"
    "fmt"
)

// NetworkError reports a failed network operation Op and wraps its cause.
type NetworkError struct {
    Op  string
    Err error
}

// Error renders "网络操作 <op> 失败: <cause>".
func (ne *NetworkError) Error() string {
    return fmt.Sprintf("网络操作 %s 失败: %v", ne.Op, ne.Err)
}

// Unwrap exposes the cause so the chain can be walked.
func (ne *NetworkError) Unwrap() error { return ne.Err }

// DatabaseError reports a failed database operation Op on Query and wraps its cause.
type DatabaseError struct {
    Op    string
    Query string
    Err   error
}

// Error renders "数据库操作 <op> 失败 [<query>]: <cause>".
func (de *DatabaseError) Error() string {
    return fmt.Sprintf("数据库操作 %s 失败 [%s]: %v", de.Op, de.Query, de.Err)
}

// Unwrap exposes the cause so the chain can be walked.
func (de *DatabaseError) Unwrap() error { return de.Err }

// connectDB simulates a dial that is always refused.
func connectDB() error {
    refused := errors.New("connection refused")
    return &NetworkError{Op: "dial", Err: refused}
}

// queryDB wraps any connection failure in a DatabaseError, building a two-level chain.
func queryDB() error {
    connErr := connectDB()
    if connErr == nil {
        return nil
    }
    return &DatabaseError{Op: "query", Query: "SELECT * FROM users", Err: connErr}
}

// main walks the error chain from queryDB, printing each layer's dynamic type and message.
func main() {
    err := queryDB()
    if err == nil {
        return
    }
    fmt.Println("=== 错误链遍历 ===")
    for layer := err; layer != nil; layer = errors.Unwrap(layer) {
        fmt.Printf("- %T: %v\n", layer, layer)
    }
}

28.4 错误包装最佳实践

使用 fmt.Errorf 包装

package main

import (
    "errors"
    "fmt"
)

// ErrNotFound is the sentinel returned when a record lookup finds nothing.
var ErrNotFound = errors.New("记录不存在")

// findRecord simulates a lookup that always misses; id is currently unused.
func findRecord(id int) error {
    return ErrNotFound
}

// processRecord loads record id and wraps any failure with %w, so callers can still
// match the original cause with errors.Is.
func processRecord(id int) error {
    if err := findRecord(id); err != nil {
        return fmt.Errorf("处理记录 %d 失败: %w", id, err)
    }
    return nil
}

// main shows that errors.Is still finds ErrNotFound through the %w wrapping.
func main() {
    err := processRecord(1)
    if err == nil {
        return
    }
    fmt.Println("错误:", err)
    if errors.Is(err, ErrNotFound) {
        fmt.Println("检测到 ErrNotFound")
    }
}

自定义包装函数

package main

import (
    "fmt"
    "time"
)

// ErrorContext wraps an error with a message and the time the wrapping happened.
type ErrorContext struct {
    Message   string
    Timestamp time.Time
    Err       error
}

// Error renders "[<RFC3339 timestamp>] <message>: <cause>".
func (ec *ErrorContext) Error() string {
    stamp := ec.Timestamp.Format(time.RFC3339)
    return fmt.Sprintf("[%s] %s: %v", stamp, ec.Message, ec.Err)
}

// Unwrap exposes the wrapped error to errors.Is / errors.As.
func (ec *ErrorContext) Unwrap() error { return ec.Err }

// WrapWithContext wraps err with message, stamping it with time.Now(). A nil err is
// still wrapped; it is not treated as "no error".
func WrapWithContext(err error, message string) *ErrorContext {
    ec := &ErrorContext{Err: err, Message: message}
    ec.Timestamp = time.Now()
    return ec
}

// main wraps a plain error with context and prints the timestamped result.
func main() {
    base := fmt.Errorf("数据库连接失败")
    fmt.Println(WrapWithContext(base, "用户服务初始化失败"))
}

28.5 实现特定行为的错误

临时错误判断

package main

import (
    "errors"
    "fmt"
    "time"
)

// temporary is satisfied by errors that can report whether retrying may succeed.
type temporary interface {
    Temporary() bool
}

// NetworkTimeoutError is a retryable network timeout of the given Duration.
type NetworkTimeoutError struct {
    Duration time.Duration
}

// Error renders "网络超时 (<duration>)".
func (e *NetworkTimeoutError) Error() string {
    return fmt.Sprintf("网络超时 (%v)", e.Duration)
}

// Temporary always reports true, marking the error as retryable.
func (e *NetworkTimeoutError) Temporary() bool {
    return true
}

// Timeout always reports true.
func (e *NetworkTimeoutError) Timeout() bool {
    return true
}

// isTemporary reports whether err, or an error in its Unwrap chain, implements
// temporary and says it is temporary. errors.As walks the chain, so a timeout
// wrapped with fmt.Errorf("...: %w", err) is still recognized as retryable.
func isTemporary(err error) bool {
    var te temporary
    return errors.As(err, &te) && te.Temporary()
}

// retry calls fn up to maxRetries times, stopping at the first success or the first
// non-temporary error. It always makes at least one attempt, so a non-positive
// maxRetries cannot report success without running fn. It waits 100ms between
// attempts but not after the final one. It returns nil on success, otherwise the
// last error fn returned.
func retry(fn func() error, maxRetries int) error {
    attempts := maxRetries
    if attempts < 1 {
        attempts = 1
    }
    var err error
    for i := 1; i <= attempts; i++ {
        if err = fn(); err == nil {
            return nil
        }
        if !isTemporary(err) || i == attempts {
            return err
        }
        fmt.Printf("第 %d 次重试...\n", i)
        time.Sleep(100 * time.Millisecond)
    }
    return err
}

func main() {
    attempts := 0
    err := retry(func() error {
        attempts++
        if attempts < 3 {
            return &NetworkTimeoutError{Duration: time.Second}
        }
        return nil
    }, 5)

    if err != nil {
        fmt.Println("最终失败:", err)
    } else {
        fmt.Println("操作成功")
    }

    // 测试非临时错误
    err = retry(func() error {
        return errors.New("永久错误")
    }, 5)

    if err != nil {
        fmt.Println("结果:", err)
    }
}

超时错误判断

package main

import (
    "errors"
    "fmt"
    "time"
)

// timeout is satisfied by errors that can report whether they were caused by a timeout.
type timeout interface {
    Timeout() bool
}

// TimeoutError reports that operation Op exceeded its time limit.
//
// The limit is stored in Duration rather than in a field named Timeout: a Go type
// cannot have a field and a method with the same name, and Timeout() is needed
// to satisfy the timeout interface.
type TimeoutError struct {
    Op       string
    Duration time.Duration
}

// Error renders "操作 <op> 超时 (<duration>)".
func (e *TimeoutError) Error() string {
    return fmt.Sprintf("操作 %s 超时 (%v)", e.Op, e.Duration)
}

// Timeout always reports true.
func (e *TimeoutError) Timeout() bool {
    return true
}

// isTimeout reports whether err, or an error in its Unwrap chain, implements
// timeout and reports a timeout. errors.As also finds wrapped timeout errors.
func isTimeout(err error) bool {
    var te timeout
    return errors.As(err, &te) && te.Timeout()
}

// main checks one timeout error and one ordinary error against isTimeout.
func main() {
    err := &TimeoutError{Op: "查询", Duration: 5 * time.Second}

    if isTimeout(err) {
        fmt.Println("检测到超时错误")
    }

    regularErr := errors.New("普通错误")
    if !isTimeout(regularErr) {
        fmt.Println("这不是超时错误")
    }
}

28.6 企业级错误设计

统一错误结构

package main

import (
    "encoding/json"
    "fmt"
    "net/http"
    "time"
)

// ErrorLevel is the severity attached to a BusinessError.
type ErrorLevel string

// Supported severity levels.
const (
    ErrorLevelInfo  ErrorLevel = "info"
    ErrorLevelWarn  ErrorLevel = "warn"
    ErrorLevelError ErrorLevel = "error"
    ErrorLevelFatal ErrorLevel = "fatal"
)

// BusinessError is the unified error structure sent to API clients. Cause and
// HTTPStatus are excluded from the JSON body (json:"-").
type BusinessError struct {
    Code       string     `json:"code"`
    Message    string     `json:"message"`
    Level      ErrorLevel `json:"level"`
    TraceID    string     `json:"trace_id"`
    Timestamp  time.Time  `json:"timestamp"`
    Stack      string     `json:"stack,omitempty"`
    Details    any        `json:"details,omitempty"`
    Cause      error      `json:"-"`
    HTTPStatus int        `json:"-"`
}

// Error renders "[<code>] <level>: <message>".
func (e *BusinessError) Error() string {
    return fmt.Sprintf("[%s] %s: %s", e.Code, e.Level, e.Message)
}

// Unwrap exposes Cause to errors.Is / errors.As.
func (e *BusinessError) Unwrap() error {
    return e.Cause
}

// ToJSON returns the JSON encoding of e. Details is caller-supplied (any) and may
// hold values encoding/json cannot encode (channels, funcs, ...). In that case the
// error is re-encoded without Details instead of returning an empty string, so an
// HTTP response built from it is never blank.
func (e *BusinessError) ToJSON() string {
    if b, err := json.Marshal(e); err == nil {
        return string(b)
    }
    fallback := *e
    fallback.Details = nil
    if b, err := json.Marshal(&fallback); err == nil {
        return string(b)
    }
    // Last resort (e.g. Timestamp itself is unencodable): a map of plain strings
    // always encodes.
    b, _ := json.Marshal(map[string]string{"code": e.Code, "message": e.Message})
    return string(b)
}

// Response returns the HTTP status and JSON body to send for this error.
func (e *BusinessError) Response() (int, string) {
    return e.HTTPStatus, e.ToJSON()
}

// ErrorBuilder assembles a BusinessError through chained setter calls.
type ErrorBuilder struct {
    err *BusinessError
}

// NewErrorBuilder starts a builder for code, stamping the current time.
func NewErrorBuilder(code string) *ErrorBuilder {
    target := &BusinessError{Code: code, Timestamp: time.Now()}
    return &ErrorBuilder{err: target}
}

// Message sets the human-readable message.
func (eb *ErrorBuilder) Message(msg string) *ErrorBuilder {
    eb.err.Message = msg
    return eb
}

// Level sets the severity.
func (eb *ErrorBuilder) Level(level ErrorLevel) *ErrorBuilder {
    eb.err.Level = level
    return eb
}

// TraceID sets the trace identifier.
func (eb *ErrorBuilder) TraceID(id string) *ErrorBuilder {
    eb.err.TraceID = id
    return eb
}

// Details attaches arbitrary extra data serialized under "details".
func (eb *ErrorBuilder) Details(details any) *ErrorBuilder {
    eb.err.Details = details
    return eb
}

// Cause records the underlying error exposed through Unwrap.
func (eb *ErrorBuilder) Cause(cause error) *ErrorBuilder {
    eb.err.Cause = cause
    return eb
}

// HTTPStatus sets the HTTP status code returned by Response.
func (eb *ErrorBuilder) HTTPStatus(status int) *ErrorBuilder {
    eb.err.HTTPStatus = status
    return eb
}

// Build fills defaults for unset fields (ErrorLevelError, HTTP 500) and returns the
// assembled error. The builder keeps its pointer, so setter calls made after Build
// mutate the returned value.
func (eb *ErrorBuilder) Build() *BusinessError {
    built := eb.err
    if len(built.Level) == 0 {
        built.Level = ErrorLevelError
    }
    if built.HTTPStatus == 0 {
        built.HTTPStatus = http.StatusInternalServerError
    }
    return built
}

// main builds a USER_NOT_FOUND error and prints its text, JSON and HTTP response.
func main() {
    bizErr := NewErrorBuilder("USER_NOT_FOUND").
        Message("用户不存在").
        Level(ErrorLevelWarn).
        TraceID("trace-123").
        Details(map[string]int{"user_id": 100}).
        HTTPStatus(http.StatusNotFound).
        Build()

    fmt.Println("错误信息:", bizErr)
    fmt.Println("JSON 格式:", bizErr.ToJSON())

    code, payload := bizErr.Response()
    fmt.Printf("HTTP 响应: 状态码=%d, 内容=%s\n", code, payload)
}

错误码注册表

package main

import (
    "fmt"
    "sync"
)

// ErrorDefinition is the registered template for one error code.
type ErrorDefinition struct {
    Code        string
    Message     string
    HTTPStatus  int
    Level       string
    Description string
}

// ErrorRegistry maps error codes to their definitions. Access is guarded by an
// RWMutex. The zero value is ready to use: the map is allocated on first Register.
type ErrorRegistry struct {
    mu    sync.RWMutex
    codes map[string]*ErrorDefinition
}

// registry is the package-level registry used by the helpers in this example.
var registry = &ErrorRegistry{
    codes: make(map[string]*ErrorDefinition),
}

// Register stores def under def.Code, replacing any existing definition for that code.
func (r *ErrorRegistry) Register(def *ErrorDefinition) {
    r.mu.Lock()
    defer r.mu.Unlock()
    if r.codes == nil {
        // Writing to a nil map panics; allocate lazily so a zero-value registry works.
        r.codes = make(map[string]*ErrorDefinition)
    }
    r.codes[def.Code] = def
}

// Get returns the definition registered for code and whether one was found.
func (r *ErrorRegistry) Get(code string) (*ErrorDefinition, bool) {
    r.mu.RLock()
    defer r.mu.RUnlock()
    def, ok := r.codes[code]
    return def, ok
}

func init() {
    registry.Register(&ErrorDefinition{
        Code:        "USER_NOT_FOUND",
        Message:     "用户不存在",
        HTTPStatus:  404,
        Level:       "warn",
        Description: "请求的用户 ID 在系统中不存在",
    })
    registry.Register(&ErrorDefinition{
        Code:        "INVALID_PARAMETER",
        Message:     "参数错误",
        HTTPStatus:  400,
        Level:       "warn",
        Description: "请求参数不符合要求",
    })
    registry.Register(&ErrorDefinition{
        Code:        "INTERNAL_ERROR",
        Message:     "内部错误",
        HTTPStatus:  500,
        Level:       "error",
        Description: "服务器内部错误",
    })
}

// NewError builds a BusinessError from the definition registered under code,
// stamped with time.Now(). An unregistered code falls back to an
// "UNKNOWN"/"未知错误"/500/"error" definition. Only the first details argument is
// used; any further ones are ignored.
// NOTE(review): BusinessError and ErrorLevel come from the 28.6 example, and this
// snippet's import list must also include "time" for time.Now to resolve.
func NewError(code string, details ...any) *BusinessError {
    def, found := registry.Get(code)
    if !found {
        def = &ErrorDefinition{Code: "UNKNOWN", Message: "未知错误", HTTPStatus: 500, Level: "error"}
    }

    bizErr := &BusinessError{
        Code:       def.Code,
        Message:    def.Message,
        Level:      ErrorLevel(def.Level),
        Timestamp:  time.Now(),
        HTTPStatus: def.HTTPStatus,
    }
    if len(details) != 0 {
        bizErr.Details = details[0]
    }
    return bizErr
}

// main creates a registered USER_NOT_FOUND error and prints it with its HTTP status.
func main() {
    notFound := NewError("USER_NOT_FOUND", map[string]int{"user_id": 100})
    fmt.Println(notFound)
    fmt.Println("HTTP 状态码:", notFound.HTTPStatus)
}

28.7 错误处理中间件

package main

import (
    "encoding/json"
    "errors"
    "fmt"
    "log"
    "net/http"
)

// Response is the JSON envelope written for every API reply.
type Response struct {
    Code    int    `json:"code"`
    Message string `json:"message"`
    Data    any    `json:"data,omitempty"`
}

// writeJSON sets the JSON content type, writes status, then encodes data as the body.
// An encoding failure is only logged: the status line has already been sent, so
// nothing else can be reported to the client.
func writeJSON(w http.ResponseWriter, status int, data any) {
    w.Header().Set("Content-Type", "application/json")
    w.WriteHeader(status)
    if err := json.NewEncoder(w).Encode(data); err != nil {
        log.Printf("writeJSON: encoding response: %v", err)
    }
}

// errorMiddleware turns a panic in next into a logged 500 JSON response.
//
// A panic with http.ErrAbortHandler is re-raised: net/http uses that value to abort
// a response on purpose, so recovering it would turn an intentional abort into a
// bogus 500 reply.
func errorMiddleware(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        defer func() {
            recovered := recover()
            if recovered == nil {
                return
            }
            if recovered == http.ErrAbortHandler {
                panic(recovered)
            }
            log.Printf("Panic recovered: %v", recovered)
            // NOTE(review): if next already wrote headers, this second WriteHeader is
            // ignored by net/http and the client sees the partial response.
            writeJSON(w, http.StatusInternalServerError, Response{
                Code:    500,
                Message: "服务器内部错误",
            })
        }()
        next.ServeHTTP(w, r)
    })
}

// ErrorHandler adapts a handler that returns an error into an http.Handler.
type ErrorHandler func(w http.ResponseWriter, r *http.Request) error

// ServeHTTP runs h and, if it fails, writes a JSON error reply.
//
// A *BusinessError anywhere in the error's Unwrap chain (found with errors.As, so
// fmt.Errorf("...: %w", bizErr) still matches) is reported with its own HTTP status
// and message. Any other error is logged and answered with a generic 500, so
// internal details are not sent to the client.
func (h ErrorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    err := h(w, r)
    if err == nil {
        return
    }

    var bizErr *BusinessError
    if errors.As(err, &bizErr) {
        status := bizErr.HTTPStatus
        if status == 0 {
            // A hand-built BusinessError may leave HTTPStatus unset, and WriteHeader
            // rejects such a code; fall back to 500.
            status = http.StatusInternalServerError
        }
        writeJSON(w, status, Response{
            Code:    status,
            Message: bizErr.Message,
        })
        return
    }

    log.Printf("%s %s: %v", r.Method, r.URL.Path, err)
    writeJSON(w, http.StatusInternalServerError, Response{
        Code:    500,
        Message: "服务器内部错误",
    })
}

// getUserHandler serves /user?id=<id>. A missing id yields INVALID_PARAMETER, id "0"
// yields USER_NOT_FOUND, and any other id is answered with a stub user record.
func getUserHandler(w http.ResponseWriter, r *http.Request) error {
    switch userID := r.URL.Query().Get("id"); userID {
    case "":
        return NewError("INVALID_PARAMETER", "缺少用户 ID")
    case "0":
        return NewError("USER_NOT_FOUND", map[string]string{"user_id": userID})
    default:
        payload := map[string]any{
            "id":   userID,
            "name": "用户" + userID,
        }
        writeJSON(w, http.StatusOK, Response{Code: 0, Message: "成功", Data: payload})
        return nil
    }
}

// main wires /user behind the panic-recovery middleware and serves on :8080.
// http.ListenAndServe always returns a non-nil error (e.g. the port is already in
// use), so the error is reported fatally instead of being silently discarded.
// NOTE(review): production servers should use an http.Server with read/write
// timeouts rather than the bare ListenAndServe helper.
func main() {
    mux := http.NewServeMux()
    mux.Handle("/user", ErrorHandler(getUserHandler))

    wrappedMux := errorMiddleware(mux)

    fmt.Println("服务器启动在 :8080")
    if err := http.ListenAndServe(":8080", wrappedMux); err != nil {
        log.Fatal(err)
    }
}

28.8 小结

本章深入讲解了 Go 语言的自定义错误和错误包装:

  1. 自定义错误类型:实现 Error() 方法,携带更多上下文信息
  2. 错误码设计:统一错误码格式,便于 API 返回和错误统计
  3. Unwrap 方法:支持错误链,配合 errors.Is 和 errors.As 使用
  4. 特定行为错误:实现 Temporary()、Timeout() 等方法
  5. 企业级错误设计:统一错误结构、错误码注册表、错误处理中间件

良好的错误设计是构建可靠系统的关键。在下一章中,我们将学习错误处理的更多高级技巧。