限流

限流是保护服务的重要手段,防止恶意请求或流量洪峰压垮系统。

简单的内存限流

基于内存的简单限流实现:

package main

import (
    "net/http"
    "sync"
    "time"
    
    "github.com/gin-gonic/gin"
)

// RateLimiter is an in-memory sliding-log limiter. For every key it records
// the times of requests it admitted, and it admits a new request only while
// fewer than limit of those times fall inside the last window. Rejected
// requests are not recorded.
type RateLimiter struct {
	mu       sync.Mutex             // guards requests
	requests map[string][]time.Time // key -> admitted request times, oldest first
	limit    int                    // maximum admitted requests per window
	window   time.Duration          // length of the rolling window
}

// NewRateLimiter returns a limiter that admits at most limit requests per key
// within any rolling window of the given length.
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
	rl := new(RateLimiter)
	rl.requests = map[string][]time.Time{}
	rl.limit = limit
	rl.window = window
	return rl
}

// Allow reports whether a request for ip may proceed now. Timestamps that
// have left the window are discarded on every call, and the current time is
// recorded only when the request is admitted.
func (rl *RateLimiter) Allow(ip string) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	cutoff := now.Add(-rl.window)

	history := rl.requests[ip]
	kept := make([]time.Time, 0, len(history)+1)
	for i := range history {
		if history[i].After(cutoff) {
			kept = append(kept, history[i])
		}
	}

	allowed := len(kept) < rl.limit
	if allowed {
		kept = append(kept, now)
	}
	rl.requests[ip] = kept
	return allowed
}

// RateLimit adapts limiter into gin middleware keyed by the client IP.
// Admitted requests continue down the handler chain; the rest are aborted
// with 429 Too Many Requests and a JSON error body.
func RateLimit(limiter *RateLimiter) gin.HandlerFunc {
	return func(c *gin.Context) {
		if limiter.Allow(c.ClientIP()) {
			c.Next()
			return
		}
		c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
			"error": "请求过于频繁,请稍后再试",
		})
	}
}

func main() {
    r := gin.Default()
    
    limiter := NewRateLimiter(100, time.Minute)
    
    r.Use(RateLimit(limiter))
    
    r.GET("/api", func(c *gin.Context) {
        c.JSON(http.StatusOK, gin.H{"message": "OK"})
    })
    
    go func() {
        for {
            time.Sleep(time.Minute)
            limiter.mu.Lock()
            limiter.requests = make(map[string][]time.Time)
            limiter.mu.Unlock()
        }
    }()
    
    r.Run(":8080")
}

令牌桶限流

基于 golang.org/x/time/rate 的令牌桶算法,按固定速率补充令牌,并允许一定的突发流量:

import (
    "golang.org/x/time/rate"
)

// TokenBucketLimiter returns gin middleware that gives every client IP its own
// token bucket. Each bucket refills at rps tokens per second and holds at most
// burst tokens. A request that finds its bucket empty is aborted with 429.
//
// Buckets idle long enough to have refilled completely are evicted during an
// occasional sweep. A freshly created limiter starts full, so eviction never
// changes an admission decision; it only keeps memory proportional to the
// number of recently active IPs instead of growing by one entry per distinct
// IP forever.
//
// NOTE(review): gin's ClientIP may derive the address from forwarding headers
// depending on the engine's trusted-proxy settings — confirm those are
// configured, otherwise clients may be able to pick their own key.
func TokenBucketLimiter(rps int, burst int) gin.HandlerFunc {
	type bucket struct {
		limiter  *rate.Limiter
		lastSeen time.Time
	}

	// idleTTL is at least the time an emptied bucket needs to refill (plus a
	// small margin). With rps <= 0 a bucket never refills, so its state must
	// be kept and eviction is disabled.
	idleTTL := time.Minute
	evict := rps > 0
	if evict {
		refill := time.Duration(float64(burst)/float64(rps)*float64(time.Second)) + time.Second
		if refill > idleTTL {
			idleTTL = refill
		}
	}

	var mu sync.Mutex
	buckets := make(map[string]*bucket)
	lastSweep := time.Now()

	return func(c *gin.Context) {
		ip := c.ClientIP()
		now := time.Now()

		mu.Lock()
		if evict && now.Sub(lastSweep) >= idleTTL {
			for key, b := range buckets {
				if now.Sub(b.lastSeen) >= idleTTL {
					delete(buckets, key)
				}
			}
			lastSweep = now
		}
		b, ok := buckets[ip]
		if !ok {
			b = &bucket{limiter: rate.NewLimiter(rate.Limit(rps), burst)}
			buckets[ip] = b
		}
		b.lastSeen = now
		limiter := b.limiter
		mu.Unlock()

		// rate.Limiter is safe for concurrent use, so it is consulted outside mu.
		if !limiter.Allow() {
			c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
				"error": "请求过于频繁",
			})
			return
		}

		c.Next()
	}
}

// Example: per client IP, refill 100 tokens per second, bursts of up to 200.
r.Use(TokenBucketLimiter(100, 200))

滑动窗口限流

// SlidingWindowLimiter is a per-key sliding-log limiter. It remembers the time
// of every request it admitted for a key during the last windowSize, and it
// admits a new request only while fewer than limit times are remembered.
type SlidingWindowLimiter struct {
	mu         sync.Mutex         // guards windows
	windows    map[string]*Window // key -> admitted request times
	limit      int                // maximum admitted requests per window
	windowSize time.Duration      // length of the rolling window
}

// Window holds the admitted-request timestamps of a single key, oldest first.
type Window struct {
	timestamps []time.Time
}

// NewSlidingWindowLimiter returns a limiter that admits at most limit
// requests per key within any rolling window of length windowSize.
func NewSlidingWindowLimiter(limit int, windowSize time.Duration) *SlidingWindowLimiter {
	l := &SlidingWindowLimiter{limit: limit, windowSize: windowSize}
	l.windows = map[string]*Window{}
	return l
}

// Allow reports whether a request for key may proceed now. Expired
// timestamps are discarded on every call; the current time is stored only
// when the request is admitted.
func (l *SlidingWindowLimiter) Allow(key string) bool {
	l.mu.Lock()
	defer l.mu.Unlock()

	now := time.Now()

	w, ok := l.windows[key]
	if !ok {
		w = &Window{timestamps: []time.Time{}}
		l.windows[key] = w
	}

	cutoff := now.Add(-l.windowSize)
	recent := make([]time.Time, 0, len(w.timestamps)+1)
	for _, ts := range w.timestamps {
		if ts.After(cutoff) {
			recent = append(recent, ts)
		}
	}

	if len(recent) < l.limit {
		w.timestamps = append(recent, now)
		return true
	}
	w.timestamps = recent
	return false
}

// SlidingWindowRateLimit adapts limiter into gin middleware keyed by the
// client IP. Requests over the limit are aborted with 429 and a JSON error
// body; the rest continue down the handler chain.
func SlidingWindowRateLimit(limiter *SlidingWindowLimiter) gin.HandlerFunc {
	reject := func(c *gin.Context) {
		c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
			"error": "请求过于频繁",
		})
	}

	return func(c *gin.Context) {
		if !limiter.Allow(c.ClientIP()) {
			reject(c)
			return
		}
		c.Next()
	}
}

分布式限流

使用 Redis 实现分布式限流(固定窗口计数器:每个 IP 一个计数键,首次计数时设置过期时间,多个实例共享同一计数):

import (
    "context"
    "github.com/go-redis/redis/v8"
)

// RedisRateLimit returns gin middleware implementing a fixed-window counter
// stored in Redis, so every instance sharing rdb enforces one combined limit.
// Each client IP gets the key "rate_limit:<ip>". The first request in a
// window creates the key with a TTL of window, and requests beyond limit
// within that TTL are aborted with 429.
//
// If Redis is unreachable the request is let through (fail open), so a Redis
// outage degrades rate limiting rather than taking the service down.
func RedisRateLimit(rdb *redis.Client, limit int, window time.Duration) gin.HandlerFunc {
	// INCR and PEXPIRE execute atomically inside one script. Issued as two
	// separate commands, a crash or failed EXPIRE after the first INCR would
	// leave a key with no TTL, blocking that client permanently once it
	// crossed the limit.
	incrWithTTL := redis.NewScript(`
local n = redis.call("INCR", KEYS[1])
if n == 1 then
  redis.call("PEXPIRE", KEYS[1], ARGV[1])
end
return n
`)

	// PEXPIRE with 0 would delete the key immediately; keep at least 1ms.
	ttlMillis := window.Milliseconds()
	if ttlMillis < 1 {
		ttlMillis = 1
	}

	return func(c *gin.Context) {
		key := "rate_limit:" + c.ClientIP()

		// Tie the Redis call to the request so it is abandoned if the client
		// goes away.
		var ctx context.Context = c.Request.Context()
		count, err := incrWithTTL.Run(ctx, rdb, []string{key}, ttlMillis).Int64()
		if err != nil {
			c.Next() // fail open, see doc comment
			return
		}

		if count > int64(limit) {
			c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
				"error": "请求过于频繁",
			})
			return
		}

		c.Next()
	}
}

不同路由不同限制

// RouteRateLimit returns gin middleware that applies a token bucket per
// (route pattern, client IP) pair. routeLimits maps a route pattern as
// reported by c.FullPath() to its rate; routes not listed use defaultLimit.
// A rate of n means n tokens per second with bursts of up to 2n. Requests
// that find their bucket empty are aborted with 429.
//
// Buckets unused for idleTTL are evicted during an occasional sweep. That
// keeps memory proportional to recently active (route, IP) pairs instead of
// growing forever. Eviction does not change any admission decision: a bucket
// refills completely in 2s, and a fresh limiter starts full.
func RouteRateLimit(defaultLimit int, routeLimits map[string]int) gin.HandlerFunc {
	type bucket struct {
		limiter  *rate.Limiter
		lastSeen time.Time
	}
	const idleTTL = time.Minute

	var mu sync.Mutex
	buckets := make(map[string]*bucket)
	lastSweep := time.Now()

	return func(c *gin.Context) {
		route := c.FullPath()

		limit := defaultLimit
		if l, ok := routeLimits[route]; ok {
			limit = l
		}

		key := route + ":" + c.ClientIP()
		now := time.Now()

		mu.Lock()
		if now.Sub(lastSweep) >= idleTTL {
			for k, b := range buckets {
				if now.Sub(b.lastSeen) >= idleTTL {
					delete(buckets, k)
				}
			}
			lastSweep = now
		}
		b, ok := buckets[key]
		if !ok {
			b = &bucket{limiter: rate.NewLimiter(rate.Limit(limit), limit*2)}
			buckets[key] = b
		}
		b.lastSeen = now
		limiter := b.limiter
		mu.Unlock()

		if !limiter.Allow() {
			c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
				"error": "请求过于频繁",
			})
			return
		}

		c.Next()
	}
}

// Example: RouteRateLimit passes each value to rate.Limit, which counts events
// per second. These are per-second rates per client IP (bursts of 2x the
// rate); routes not listed get 100/s.
// NOTE(review): 5 logins per second may be looser than intended — confirm the
// unit the author had in mind.
r.Use(RouteRateLimit(100, map[string]int{
    "/api/login":    5,
    "/api/register": 3,
    "/api/upload":   10,
}))

返回限流信息

// RateLimitWithHeaders behaves like RateLimit but also reports quota state to
// the client. It sets:
//   - X-RateLimit-Limit: the limiter's per-window limit;
//   - X-RateLimit-Remaining: requests still available, after counting this one;
//   - X-RateLimit-Reset: Unix time at which the oldest counted request leaves
//     the window, i.e. when the next slot frees up;
//   - Retry-After (on rejection only): whole seconds until that moment.
//
// The decision and the header values come from one critical section, so they
// are consistent with each other even under concurrent requests from one IP.
func RateLimitWithHeaders(limiter *RateLimiter) gin.HandlerFunc {
	return func(c *gin.Context) {
		ip := c.ClientIP()
		allowed, remaining, resetAt := takeWithInfo(limiter, ip)

		c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", limiter.limit))
		c.Header("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining))
		c.Header("X-RateLimit-Reset", fmt.Sprintf("%d", resetAt.Unix()))

		if !allowed {
			// Round up so a well-behaved client never retries before a slot
			// is actually free.
			retryAfter := int((time.Until(resetAt) + time.Second - 1) / time.Second)
			if retryAfter < 1 {
				retryAfter = 1
			}
			c.Header("Retry-After", fmt.Sprintf("%d", retryAfter))
			c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
				"error":       "请求过于频繁",
				"retry_after": retryAfter,
			})
			return
		}

		c.Next()
	}
}

// takeWithInfo applies the same sliding-log rule as RateLimiter.Allow to key
// and, under the same lock, reports:
//   - allowed: whether this request was admitted;
//   - remaining: how many more requests fit in the current window;
//   - resetAt: when the oldest counted request leaves it (now if none are
//     counted).
func takeWithInfo(rl *RateLimiter, key string) (allowed bool, remaining int, resetAt time.Time) {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	windowStart := now.Add(-rl.window)

	valid := []time.Time{}
	for _, t := range rl.requests[key] {
		if t.After(windowStart) {
			valid = append(valid, t)
		}
	}

	allowed = len(valid) < rl.limit
	if allowed {
		valid = append(valid, now)
	}
	rl.requests[key] = valid

	remaining = rl.limit - len(valid)
	if remaining < 0 {
		remaining = 0
	}

	resetAt = now
	if len(valid) > 0 {
		// Timestamps are appended under the lock in arrival order, so
		// valid[0] is the oldest one still counted.
		resetAt = valid[0].Add(rl.window)
	}
	return allowed, remaining, resetAt
}

小结

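限流是保护服务的重要手段。单机限流可以用内存实现,分布式环境需要借助 Redis。令牌桶算法按固定速率补充令牌,适合允许一定突发流量的场景;滑动窗口(滑动日志)统计最精确,但需要为每个客户端保存窗口内的全部请求时间戳,内存开销随请求量增长;Redis 固定窗口计数器实现简单,但在窗口边界附近可能短时间放行接近两倍限额的请求。无论采用哪种实现,都要定期清理长时间不活跃客户端的状态,避免内存无限增长。另外,c.ClientIP() 可能读取 X-Forwarded-For 等请求头,应通过 SetTrustedProxies 配置可信代理,否则客户端可以伪造 IP 绕过限流。记得在响应头中返回限流信息,方便客户端处理。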