中间件传值

中间件之间、中间件和处理器之间经常需要传递数据。Gin 的 Context 提供了便捷的存取方法。

基本的 Set 和 Get

使用 c.Set() 存储数据,c.Get() 获取数据:

package main

import (
    "net/http"
    "github.com/gin-gonic/gin"
)

func Auth() gin.HandlerFunc {
    return func(c *gin.Context) {
        token := c.GetHeader("Authorization")
        
        user := parseToken(token)
        c.Set("currentUser", user)
        c.Set("userID", user.ID)
        c.Set("userRole", user.Role)
        
        c.Next()
    }
}

func main() {
    r := gin.Default()
    
    r.GET("/profile", Auth(), func(c *gin.Context) {
        user, exists := c.Get("currentUser")
        if !exists {
            c.JSON(http.StatusInternalServerError, gin.H{"error": "用户信息丢失"})
            return
        }
        
        c.JSON(http.StatusOK, user)
    })
    
    r.Run(":8080")
}

MustGet 方法

如果确定值一定存在,可以用 MustGet,不存在时会 panic:

r.GET("/me", Auth(), func(c *gin.Context) {
    userID := c.MustGet("userID").(int)
    userRole := c.MustGet("userRole").(string)
    
    c.JSON(http.StatusOK, gin.H{
        "id":   userID,
        "role": userRole,
    })
})

类型断言

Get 返回的是 interface{},需要类型断言:

r.GET("/info", Auth(), func(c *gin.Context) {
    if userID, exists := c.Get("userID"); exists {
        id := userID.(int)
        c.JSON(http.StatusOK, gin.H{"id": id})
    }
    
    if user, exists := c.Get("currentUser"); exists {
        u := user.(*User)
        c.JSON(http.StatusOK, gin.H{"name": u.Name})
    }
})

中间件之间传值

中间件可以读取前面中间件设置的值:

func Auth() gin.HandlerFunc {
    return func(c *gin.Context) {
        token := c.GetHeader("Authorization")
        user, _ := parseToken(token)
        c.Set("user", user)
        c.Next()
    }
}

func AdminOnly() gin.HandlerFunc {
    return func(c *gin.Context) {
        user, exists := c.Get("user")
        if !exists {
            c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "未登录"})
            return
        }
        
        u := user.(*User)
        if u.Role != "admin" {
            c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "需要管理员权限"})
            return
        }
        
        c.Next()
    }
}

r.DELETE("/users/:id", Auth(), AdminOnly(), deleteUser)

请求计时示例

func Timer() gin.HandlerFunc {
    return func(c *gin.Context) {
        start := time.Now()
        c.Set("startTime", start)
        c.Next()
    }
}

func Logger() gin.HandlerFunc {
    return func(c *gin.Context) {
        c.Next()
        
        start, exists := c.Get("startTime")
        if exists {
            duration := time.Since(start.(time.Time))
            log.Printf("请求耗时: %v", duration)
        }
    }
}

r.Use(Timer())
r.Use(Logger())

请求追踪示例

func RequestID() gin.HandlerFunc {
    return func(c *gin.Context) {
        requestID := uuid.New().String()
        c.Set("requestID", requestID)
        c.Header("X-Request-ID", requestID)
        c.Next()
    }
}

func AuditLog() gin.HandlerFunc {
    return func(c *gin.Context) {
        c.Next()
        
        requestID, _ := c.Get("requestID")
        log.Printf("[%s] %s %s - %d",
            requestID,
            c.Request.Method,
            c.Request.URL.Path,
            c.Writer.Status(),
        )
    }
}

r.Use(RequestID())
r.Use(AuditLog())

存储请求信息

func RequestInfo() gin.HandlerFunc {
    return func(c *gin.Context) {
        c.Set("clientIP", c.ClientIP())
        c.Set("userAgent", c.GetHeader("User-Agent"))
        c.Set("referer", c.GetHeader("Referer"))
        c.Next()
    }
}

r.GET("/track", RequestInfo(), func(c *gin.Context) {
    ip := c.MustGet("clientIP").(string)
    agent := c.MustGet("userAgent").(string)
    
    saveToAnalytics(ip, agent)
    c.String(http.StatusOK, "OK")
})

封装获取方法

可以封装便捷的获取方法:

func GetCurrentUser(c *gin.Context) (*User, error) {
    user, exists := c.Get("currentUser")
    if !exists {
        return nil, errors.New("用户未登录")
    }
    return user.(*User), nil
}

func GetUserID(c *gin.Context) (int, error) {
    id, exists := c.Get("userID")
    if !exists {
        return 0, errors.New("用户ID不存在")
    }
    return id.(int), nil
}

r.GET("/me", Auth(), func(c *gin.Context) {
    user, err := GetCurrentUser(c)
    if err != nil {
        c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
        return
    }
    
    c.JSON(http.StatusOK, user)
})

小结

中间件传值是 Gin 中常用的模式,通过 c.Set()c.Get() 可以在中间件链中传递数据。注意类型断言的正确使用,必要时封装便捷方法。合理使用中间件传值可以让代码更加解耦。