在实际项目开发中,标准库提供的错误类型(例如 errors.New 返回的纯文本错误)往往不能满足需求。自定义错误可以携带资源名称、ID、错误码等结构化上下文,让调用方能够按类型判断错误并分别处理。下面先看一个最简单的自定义错误:
package main
import "fmt"
// NotFoundError reports that the resource named by Resource with the given
// ID does not exist.
type NotFoundError struct {
	Resource string
	ID       int
}

// Error implements the error interface, e.g. "用户 (ID: 0) 不存在".
func (e *NotFoundError) Error() string {
	return fmt.Sprintf("%s (ID: %d) 不存在", e.Resource, e.ID)
}

// findUser looks up a user by id. In this example every non-positive id is
// treated as missing and yields a *NotFoundError; any positive id succeeds.
func findUser(id int) error {
	if id > 0 {
		return nil
	}
	return &NotFoundError{Resource: "用户", ID: id}
}
// main demonstrates returning and printing a custom error type.
func main() {
	err := findUser(0)
	if err == nil {
		return
	}
	fmt.Println(err)
}
package main
import (
"encoding/json"
"fmt"
)
// ErrorCode is a numeric application error code.
type ErrorCode int

// Application error codes: 0 means success, failures are numbered from 1001.
const (
	CodeSuccess        ErrorCode = 0
	CodeInvalidParam   ErrorCode = 1001
	CodeNotFound       ErrorCode = 1002
	CodeUnauthorized   ErrorCode = 1003
	CodeInternalError  ErrorCode = 1004
	CodeServiceTimeout ErrorCode = 1005
)

// AppError is an application error carrying a machine-readable Code, a
// human-readable Message and an optional Detail. The struct tags define its
// JSON form; Detail is omitted from JSON when empty.
type AppError struct {
	Code    ErrorCode `json:"code"`
	Message string    `json:"message"`
	Detail  string    `json:"detail,omitempty"`
}

// NewAppError returns an AppError with the given code and message and no detail.
func NewAppError(code ErrorCode, message string) *AppError {
	appErr := new(AppError)
	appErr.Code, appErr.Message = code, message
	return appErr
}

// Error formats the error as "[code] message"; Detail is not included.
func (e *AppError) Error() string {
	return fmt.Sprintf("[%d] %s", e.Code, e.Message)
}

// WithDetail sets Detail on e in place and returns e so calls can be chained.
func (e *AppError) WithDetail(detail string) *AppError {
	e.Detail = detail
	return e
}

// ToJSON returns the JSON encoding of e. The marshal error is discarded
// because every field is an int or a string, so encoding cannot fail.
func (e *AppError) ToJSON() string {
	raw, _ := json.Marshal(e)
	return string(raw)
}
// validateUser checks that name is non-empty and that age lies in [0, 150].
// The first violation is reported as a CodeInvalidParam *AppError with a
// detail message; nil means the input is valid.
func validateUser(name string, age int) error {
	switch {
	case name == "":
		return NewAppError(CodeInvalidParam, "用户名不能为空").WithDetail("name 字段是必填的")
	case age < 0, age > 150:
		detail := fmt.Sprintf("年龄 %d 不在有效范围内", age)
		return NewAppError(CodeInvalidParam, "年龄不合法").WithDetail(detail)
	}
	return nil
}
// main validates a user with an empty name and prints every field of the
// resulting *AppError, including its JSON form.
func main() {
	// A nil error fails the assertion too, so one check covers both cases.
	appErr, ok := validateUser("", 25).(*AppError)
	if !ok {
		return
	}
	fmt.Println("错误码:", appErr.Code)
	fmt.Println("错误信息:", appErr.Message)
	fmt.Println("详细信息:", appErr.Detail)
	fmt.Println("JSON 格式:", appErr.ToJSON())
}
package main
import (
"fmt"
"runtime"
"strings"
)
// StackError is an error annotated with the call stack captured when it was
// created; it may wrap an underlying Cause.
type StackError struct {
	Message string
	Stack   []string
	Cause   error
}

// Error renders Message, a "堆栈信息:" header, one indented "file:line" entry
// per frame, and finally "原因: " followed by the cause's message when Cause
// is set. A nested *StackError cause therefore contributes its own stack too.
func (e *StackError) Error() string {
	var out strings.Builder
	out.WriteString(e.Message + "\n堆栈信息:\n")
	for _, frame := range e.Stack {
		out.WriteString(" " + frame + "\n")
	}
	if e.Cause != nil {
		out.WriteString("原因: " + e.Cause.Error())
	}
	return out.String()
}

// Unwrap returns Cause so errors.Unwrap / errors.Is / errors.As can reach it.
func (e *StackError) Unwrap() error {
	return e.Cause
}

// NewStackError creates a StackError without a cause, recording the stack of
// its caller. getStack must be called directly from here so its fixed skip
// depth still lands on the caller.
func NewStackError(message string) *StackError {
	return &StackError{Stack: getStack(), Message: message}
}

// WrapStackError creates a StackError that wraps err, recording the stack of
// its caller. Like NewStackError, it calls getStack directly.
func WrapStackError(err error, message string) *StackError {
	return &StackError{Cause: err, Stack: getStack(), Message: message}
}

// getStack returns up to 13 "file:line" entries. The first entry is the caller
// of the function that invoked getStack: skip 0 would be getStack itself and
// skip 1 its direct caller. Collection stops early when runtime.Caller runs
// out of frames.
func getStack() []string {
	const firstSkip, lastSkip = 2, 14
	var frames []string
	for skip := firstSkip; skip <= lastSkip; skip++ {
		_, file, line, ok := runtime.Caller(skip)
		if !ok {
			break
		}
		frames = append(frames, fmt.Sprintf("%s:%d", file, line))
	}
	return frames
}
// level3 is the bottom of the demo call chain; it always fails with a new
// StackError.
func level3() error {
	return NewStackError("底层错误")
}

// level2 calls level3 and wraps any failure with its own message and stack.
func level2() error {
	if err := level3(); err != nil {
		return WrapStackError(err, "level2 处理失败")
	}
	return nil
}

// level1 calls level2 and wraps any failure with its own message and stack.
func level1() error {
	if err := level2(); err != nil {
		return WrapStackError(err, "level1 处理失败")
	}
	return nil
}
// main prints the fully wrapped error chain produced by level1.
func main() {
	err := level1()
	if err == nil {
		return
	}
	fmt.Println(err)
}
Go 1.13 引入了错误链支持:errors.Unwrap、errors.Is 和 errors.As 会识别实现了 Unwrap() error 方法的错误类型。标准库并没有导出名为 Wrapper 的接口,下面的定义只是用来说明这一约定:
// Wrapper describes an error that wraps another error. It is shown for
// illustration only: errors.Unwrap (and errors.Is / errors.As) look for an
// Unwrap() error method directly, so a type with this method takes part in
// an error chain without naming this interface.
type Wrapper interface {
// Unwrap returns the directly wrapped error, or nil if there is none.
Unwrap() error
}
package main
import (
"errors"
"fmt"
)
// QueryError records the query that failed together with the underlying error.
type QueryError struct {
	Query string
	Err   error
}

// Error prefixes the underlying error with the failing query.
func (q *QueryError) Error() string {
	return fmt.Sprintf("查询失败 [%s]: %v", q.Query, q.Err)
}

// Unwrap returns the underlying error so errors.Unwrap/Is/As can reach it.
func (q *QueryError) Unwrap() error {
	return q.Err
}

// queryUser simulates a user lookup that always fails with a *QueryError
// wrapping a "连接超时" error; id is not used in this example.
func queryUser(id int) error {
	cause := errors.New("连接超时")
	return &QueryError{Query: "SELECT * FROM users WHERE id = ?", Err: cause}
}
// main prints a wrapped query error and then the error it wraps.
func main() {
	err := queryUser(1)
	if err == nil {
		return
	}
	fmt.Println("完整错误:", err)
	// errors.Unwrap returns the wrapped error, or nil when there is none.
	if inner := errors.Unwrap(err); inner != nil {
		fmt.Println("内部错误:", inner)
	}
}
package main
import (
"errors"
"fmt"
)
// NetworkError reports that network operation Op failed with Err.
type NetworkError struct {
	Op  string
	Err error
}

// Error describes the failed network operation and its cause.
func (n *NetworkError) Error() string {
	return fmt.Sprintf("网络操作 %s 失败: %v", n.Op, n.Err)
}

// Unwrap returns the underlying cause.
func (n *NetworkError) Unwrap() error {
	return n.Err
}

// DatabaseError reports that database operation Op, running Query, failed with Err.
type DatabaseError struct {
	Op    string
	Query string
	Err   error
}

// Error describes the failed database operation, its query and its cause.
func (d *DatabaseError) Error() string {
	return fmt.Sprintf("数据库操作 %s 失败 [%s]: %v", d.Op, d.Query, d.Err)
}

// Unwrap returns the underlying cause.
func (d *DatabaseError) Unwrap() error {
	return d.Err
}

// connectDB simulates a dial that is always refused.
func connectDB() error {
	refused := errors.New("connection refused")
	return &NetworkError{Op: "dial", Err: refused}
}

// queryDB connects first and wraps a connection failure in a *DatabaseError,
// producing a two-level error chain.
func queryDB() error {
	err := connectDB()
	if err == nil {
		return nil
	}
	return &DatabaseError{Op: "query", Query: "SELECT * FROM users", Err: err}
}
// main walks the error chain returned by queryDB, printing each link's
// dynamic type and message.
func main() {
	err := queryDB()
	if err == nil {
		return
	}
	fmt.Println("=== 错误链遍历 ===")
	for link := err; link != nil; link = errors.Unwrap(link) {
		fmt.Printf("- %T: %v\n", link, link)
	}
}
package main
import (
"errors"
"fmt"
)
// ErrNotFound is the sentinel error reported when a record does not exist.
var ErrNotFound = errors.New("记录不存在")

// findRecord always reports ErrNotFound in this example; id is ignored.
func findRecord(id int) error {
	return ErrNotFound
}

// processRecord looks up record id. A lookup failure is wrapped with %w, so
// callers can still match the cause with errors.Is.
func processRecord(id int) error {
	if err := findRecord(id); err != nil {
		return fmt.Errorf("处理记录 %d 失败: %w", id, err)
	}
	return nil
}
// main shows that errors.Is still finds the sentinel behind a %w wrapper.
func main() {
	err := processRecord(1)
	if err == nil {
		return
	}
	fmt.Println("错误:", err)
	if errors.Is(err, ErrNotFound) {
		fmt.Println("检测到 ErrNotFound")
	}
}
package main
import (
"fmt"
"time"
)
// ErrorContext wraps Err with a descriptive Message and the time the wrap
// happened.
type ErrorContext struct {
	Message   string
	Timestamp time.Time
	Err       error
}

// Error renders "[<RFC 3339 timestamp>] <message>: <wrapped error>".
func (c *ErrorContext) Error() string {
	stamp := c.Timestamp.Format(time.RFC3339)
	return fmt.Sprintf("[%s] %s: %v", stamp, c.Message, c.Err)
}

// Unwrap returns the wrapped error.
func (c *ErrorContext) Unwrap() error {
	return c.Err
}

// WrapWithContext wraps err with message, stamped with the current time.
func WrapWithContext(err error, message string) *ErrorContext {
	wrapped := &ErrorContext{Err: err, Message: message}
	wrapped.Timestamp = time.Now()
	return wrapped
}
// main wraps a plain error with context and prints the result.
func main() {
	base := fmt.Errorf("数据库连接失败")
	fmt.Println(WrapWithContext(base, "用户服务初始化失败"))
}
package main
import (
"errors"
"fmt"
"time"
)
// temporary is implemented by errors that may succeed if the operation is
// retried (the same method name net.Error uses).
type temporary interface {
	Temporary() bool
}

// NetworkTimeoutError reports a network operation that timed out after Duration.
type NetworkTimeoutError struct {
	Duration time.Duration
}

// Error describes the timeout, e.g. "网络超时 (1s)".
func (e *NetworkTimeoutError) Error() string {
	return fmt.Sprintf("网络超时 (%v)", e.Duration)
}

// Temporary always reports true, so retry treats this error as retryable.
func (e *NetworkTimeoutError) Temporary() bool {
	return true
}

// Timeout always reports true.
func (e *NetworkTimeoutError) Timeout() bool {
	return true
}

// isTemporary reports whether err, or any error it wraps, implements
// temporary and says it is temporary. errors.As walks the Unwrap chain, so a
// timeout wrapped with fmt.Errorf("...: %w", err) is still recognized. A nil
// err is not temporary.
func isTemporary(err error) bool {
	var te temporary
	return errors.As(err, &te) && te.Temporary()
}

// retry calls fn up to maxRetries times in total and stops at the first
// success (nil). A non-temporary error is returned immediately. A temporary
// one is retried after a 100ms pause, unless that was the last attempt. When
// every attempt fails, the last error is returned.
//
// fn is always invoked at least once: a maxRetries below 1 used to skip fn
// entirely and report success (nil) without doing any work.
func retry(fn func() error, maxRetries int) error {
	if maxRetries < 1 {
		maxRetries = 1
	}
	var lastErr error
	for i := 0; i < maxRetries; i++ {
		err := fn()
		if err == nil {
			return nil
		}
		if !isTemporary(err) {
			return err
		}
		lastErr = err
		// Only announce and pause when another attempt will actually follow.
		if i == maxRetries-1 {
			break
		}
		fmt.Printf("第 %d 次重试...\n", i+1)
		time.Sleep(100 * time.Millisecond)
	}
	return lastErr
}
func main() {
attempts := 0
err := retry(func() error {
attempts++
if attempts < 3 {
return &NetworkTimeoutError{Duration: time.Second}
}
return nil
}, 5)
if err != nil {
fmt.Println("最终失败:", err)
} else {
fmt.Println("操作成功")
}
// 测试非临时错误
err = retry(func() error {
return errors.New("永久错误")
}, 5)
if err != nil {
fmt.Println("结果:", err)
}
}
package main
import (
"errors"
"fmt"
"time"
)
// timeout is implemented by errors that can report whether they were caused
// by a timeout (the same method name net.Error uses).
type timeout interface {
	Timeout() bool
}

// TimeoutError reports that operation Op did not finish within Duration.
//
// The limit is stored in Duration rather than in a field named Timeout: Go
// forbids a struct field and a method with the same name, and the Timeout()
// method is what satisfies the timeout interface.
type TimeoutError struct {
	Op       string
	Duration time.Duration
}

// Error describes the timed-out operation, e.g. "操作 查询 超时 (5s)".
func (e *TimeoutError) Error() string {
	return fmt.Sprintf("操作 %s 超时 (%v)", e.Op, e.Duration)
}

// Timeout always reports true.
func (e *TimeoutError) Timeout() bool {
	return true
}

// isTimeout reports whether err, or any error it wraps, implements timeout
// and says it timed out. errors.As walks the Unwrap chain, so wrapping with
// %w does not hide a timeout. A nil err is not a timeout.
func isTimeout(err error) bool {
	var te timeout
	return errors.As(err, &te) && te.Timeout()
}

// main checks isTimeout against a timeout error and an ordinary error.
func main() {
	err := &TimeoutError{Op: "查询", Duration: 5 * time.Second}
	if isTimeout(err) {
		fmt.Println("检测到超时错误")
	}
	regularErr := errors.New("普通错误")
	if !isTimeout(regularErr) {
		fmt.Println("这不是超时错误")
	}
}
package main
import (
"encoding/json"
"fmt"
"net/http"
"time"
)
// ErrorLevel is the severity attached to a BusinessError.
type ErrorLevel string

// Supported severities, serialized as their lowercase string values.
const (
	ErrorLevelInfo  ErrorLevel = "info"
	ErrorLevelWarn  ErrorLevel = "warn"
	ErrorLevelError ErrorLevel = "error"
	ErrorLevelFatal ErrorLevel = "fatal"
)

// BusinessError is a structured application error intended to be returned
// to API clients as JSON. Cause and HTTPStatus are excluded from the JSON
// body (tag "-"); Stack and Details are omitted when empty.
type BusinessError struct {
	Code       string     `json:"code"`
	Message    string     `json:"message"`
	Level      ErrorLevel `json:"level"`
	TraceID    string     `json:"trace_id"`
	Timestamp  time.Time  `json:"timestamp"`
	Stack      string     `json:"stack,omitempty"`
	Details    any        `json:"details,omitempty"`
	Cause      error      `json:"-"`
	HTTPStatus int        `json:"-"`
}

// Error formats the error as "[code] level: message".
func (e *BusinessError) Error() string {
	return fmt.Sprintf("[%s] %s: %s", e.Code, e.Level, e.Message)
}

// Unwrap returns Cause so errors.Is / errors.As can inspect it.
func (e *BusinessError) Unwrap() error {
	return e.Cause
}

// ToJSON returns the JSON encoding of e and never returns an empty string.
// If Details cannot be encoded (channels, funcs, cyclic data), e is encoded
// again without Details. If that also fails, for example because Timestamp
// lies outside the years time.Time can encode, a minimal
// {"code","message"} object is returned.
func (e *BusinessError) ToJSON() string {
	if b, err := json.Marshal(e); err == nil {
		return string(b)
	}
	trimmed := *e
	trimmed.Details = nil
	if b, err := json.Marshal(&trimmed); err == nil {
		return string(b)
	}
	// A map of plain strings always marshals, so this error cannot occur.
	b, _ := json.Marshal(map[string]string{"code": e.Code, "message": e.Message})
	return string(b)
}

// Response returns the HTTP status and JSON body to send for e.
func (e *BusinessError) Response() (int, string) {
	return e.HTTPStatus, e.ToJSON()
}
// ErrorBuilder assembles a *BusinessError through chained setter calls.
// Every setter writes to the same underlying error and returns the builder.
type ErrorBuilder struct {
	err *BusinessError
}

// NewErrorBuilder starts a builder for code with Timestamp set to time.Now().
func NewErrorBuilder(code string) *ErrorBuilder {
	be := &BusinessError{Code: code, Timestamp: time.Now()}
	return &ErrorBuilder{err: be}
}

// Message sets the human-readable message.
func (eb *ErrorBuilder) Message(msg string) *ErrorBuilder {
	eb.err.Message = msg
	return eb
}

// Level sets the severity.
func (eb *ErrorBuilder) Level(level ErrorLevel) *ErrorBuilder {
	eb.err.Level = level
	return eb
}

// TraceID sets the trace identifier.
func (eb *ErrorBuilder) TraceID(id string) *ErrorBuilder {
	eb.err.TraceID = id
	return eb
}

// Details attaches arbitrary extra data.
func (eb *ErrorBuilder) Details(details any) *ErrorBuilder {
	eb.err.Details = details
	return eb
}

// Cause sets the wrapped underlying error.
func (eb *ErrorBuilder) Cause(cause error) *ErrorBuilder {
	eb.err.Cause = cause
	return eb
}

// HTTPStatus sets the HTTP status code used by Response.
func (eb *ErrorBuilder) HTTPStatus(status int) *ErrorBuilder {
	eb.err.HTTPStatus = status
	return eb
}

// Build applies defaults and returns the error. A missing Level becomes
// ErrorLevelError and a zero HTTPStatus becomes 500. The builder keeps the
// same pointer, so later setter calls still modify the returned error.
func (eb *ErrorBuilder) Build() *BusinessError {
	built := eb.err
	if built.Level == "" {
		built.Level = ErrorLevelError
	}
	if built.HTTPStatus == 0 {
		built.HTTPStatus = http.StatusInternalServerError
	}
	return built
}
// main builds a USER_NOT_FOUND error with the builder and prints it as text,
// as JSON, and as an HTTP status/body pair.
func main() {
	builder := NewErrorBuilder("USER_NOT_FOUND").
		Message("用户不存在").
		Level(ErrorLevelWarn).
		TraceID("trace-123")
	builder.Details(map[string]int{"user_id": 100}).HTTPStatus(http.StatusNotFound)
	err := builder.Build()

	fmt.Println("错误信息:", err)
	fmt.Println("JSON 格式:", err.ToJSON())
	code, payload := err.Response()
	fmt.Printf("HTTP 响应: 状态码=%d, 内容=%s\n", code, payload)
}
package main
import (
	"fmt"
	"sync"
	"time"
)
// ErrorDefinition is the static description of one registered error code.
type ErrorDefinition struct {
	Code        string
	Message     string
	HTTPStatus  int
	Level       string
	Description string
}

// ErrorRegistry maps error codes to their definitions. mu guards codes, so
// Register and Get may be called concurrently.
type ErrorRegistry struct {
	mu    sync.RWMutex
	codes map[string]*ErrorDefinition
}

// registry is the package-level registry that init fills and NewError reads.
var registry = &ErrorRegistry{codes: map[string]*ErrorDefinition{}}

// Register stores def under def.Code, replacing any earlier definition with
// the same code.
func (r *ErrorRegistry) Register(def *ErrorDefinition) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.codes[def.Code] = def
}

// Get returns the definition registered for code and whether one exists.
func (r *ErrorRegistry) Get(code string) (def *ErrorDefinition, ok bool) {
	r.mu.RLock()
	defer r.mu.RUnlock()
	def, ok = r.codes[code]
	return def, ok
}
func init() {
registry.Register(&ErrorDefinition{
Code: "USER_NOT_FOUND",
Message: "用户不存在",
HTTPStatus: 404,
Level: "warn",
Description: "请求的用户 ID 在系统中不存在",
})
registry.Register(&ErrorDefinition{
Code: "INVALID_PARAMETER",
Message: "参数错误",
HTTPStatus: 400,
Level: "warn",
Description: "请求参数不符合要求",
})
registry.Register(&ErrorDefinition{
Code: "INTERNAL_ERROR",
Message: "内部错误",
HTTPStatus: 500,
Level: "error",
Description: "服务器内部错误",
})
}
// NewError builds a *BusinessError from the definition registered for code,
// with Timestamp set to time.Now().
//
// An unregistered code falls back to a generic "UNKNOWN" / 500 / "error"
// definition; the requested code is not preserved. If details are given,
// only the first one is attached as Details and the rest are ignored.
func NewError(code string, details ...any) *BusinessError {
	def, found := registry.Get(code)
	if !found {
		def = &ErrorDefinition{
			Code:       "UNKNOWN",
			Message:    "未知错误",
			HTTPStatus: 500,
			Level:      "error",
		}
	}

	var extra any
	if len(details) != 0 {
		extra = details[0]
	}

	return &BusinessError{
		Code:       def.Code,
		Message:    def.Message,
		Level:      ErrorLevel(def.Level),
		Timestamp:  time.Now(),
		HTTPStatus: def.HTTPStatus,
		Details:    extra,
	}
}
// main creates a registered error and prints it with its HTTP status.
func main() {
	bizErr := NewError("USER_NOT_FOUND", map[string]int{"user_id": 100})
	fmt.Println(bizErr)
	fmt.Println("HTTP 状态码:", bizErr.HTTPStatus)
}
package main
import (
"encoding/json"
"fmt"
"log"
"net/http"
)
// Response is the JSON envelope for every API reply; Data is omitted when nil.
type Response struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
	Data    any    `json:"data,omitempty"`
}

// writeJSON writes data as a JSON response body with the given status code.
// The status line is sent before encoding, so an encoding failure can no
// longer change the response; it is logged instead of being silently dropped.
func writeJSON(w http.ResponseWriter, status int, data any) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	if err := json.NewEncoder(w).Encode(data); err != nil {
		log.Printf("writeJSON: encoding response: %v", err)
	}
}

// errorMiddleware recovers from panics in next, logs them, and answers with a
// generic 500 JSON response.
//
// http.ErrAbortHandler is re-panicked rather than recovered. net/http uses it
// as the sentinel for "abort this response silently", and recovering it would
// wrongly turn a deliberate abort into a 500 body.
func errorMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		defer func() {
			recovered := recover()
			if recovered == nil {
				return
			}
			if recovered == http.ErrAbortHandler {
				panic(recovered)
			}
			log.Printf("Panic recovered: %v", recovered)
			writeJSON(w, http.StatusInternalServerError, Response{
				Code:    500,
				Message: "服务器内部错误",
			})
		}()
		next.ServeHTTP(w, r)
	})
}
// ErrorHandler adapts a handler that returns an error into an http.Handler,
// so failures are turned into JSON responses in one place.
type ErrorHandler func(w http.ResponseWriter, r *http.Request) error

// ServeHTTP runs h and, if it returns an error, writes a JSON error response.
//
// A *BusinessError is reported with its own HTTPStatus and Message. The
// status falls back to 500 when HTTPStatus is not a valid status code,
// because http.ResponseWriter.WriteHeader panics on codes outside 100-999
// (e.g. a zero value). Any other error is logged and answered with a generic
// 500, so internal details are not exposed to the client.
func (h ErrorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	err := h(w, r)
	if err == nil {
		return
	}
	// NOTE(review): a type assertion only matches an unwrapped *BusinessError.
	// errors.As would also match one wrapped with %w, but needs the "errors"
	// import, which this file does not have.
	if bizErr, ok := err.(*BusinessError); ok {
		status := bizErr.HTTPStatus
		if status < 100 || status > 999 {
			status = http.StatusInternalServerError
		}
		writeJSON(w, status, Response{
			Code:    status,
			Message: bizErr.Message,
		})
		return
	}
	// The client only sees a generic message, so record the real cause here.
	log.Printf("%s %s: unhandled error: %v", r.Method, r.URL.Path, err)
	writeJSON(w, http.StatusInternalServerError, Response{
		Code:    500,
		Message: "服务器内部错误",
	})
}
// getUserHandler serves /user?id=... with a demo user. A missing id yields an
// INVALID_PARAMETER error and id "0" yields USER_NOT_FOUND. Any other id gets
// a 200 response whose data holds that id and a generated name.
func getUserHandler(w http.ResponseWriter, r *http.Request) error {
	userID := r.URL.Query().Get("id")
	switch userID {
	case "":
		return NewError("INVALID_PARAMETER", "缺少用户 ID")
	case "0":
		return NewError("USER_NOT_FOUND", map[string]string{"user_id": userID})
	}

	user := map[string]any{
		"id":   userID,
		"name": "用户" + userID,
	}
	writeJSON(w, http.StatusOK, Response{Code: 0, Message: "成功", Data: user})
	return nil
}
// main routes /user through ErrorHandler and errorMiddleware and serves on
// :8080. ListenAndServe only returns on failure (for example when the port
// is already in use), so its error is fatal instead of being ignored.
//
// NOTE(review): production servers should use an http.Server with
// ReadHeaderTimeout/ReadTimeout/WriteTimeout set. That needs the "time"
// import, which this file does not have.
func main() {
	mux := http.NewServeMux()
	mux.Handle("/user", ErrorHandler(getUserHandler))
	wrappedMux := errorMiddleware(mux)
	fmt.Println("服务器启动在 :8080")
	log.Fatal(http.ListenAndServe(":8080", wrappedMux))
}
本章深入讲解了 Go 语言的自定义错误和错误包装:

- 自定义错误类型:实现 Error() 方法,携带更多上下文信息
- 错误包装与错误链:通过 %w 和 Unwrap() 构建错误链,并使用 errors.Is 和 errors.As 进行判断
- 错误行为接口:通过 Temporary()、Timeout() 等方法描述错误的特征

良好的错误设计是构建可靠系统的关键。在下一章中,我们将学习错误处理的更多高级技巧。