Go 并发处理

Go 语言的并发特性非常适合处理多个 Ollama 请求,提高整体效率。

Goroutine 并发请求

package main

import (
    "bytes"
    "encoding/json"
    "fmt"
    "io"
    "net/http"
    "sync"
)

// Message is one turn in a chat conversation, matching the message
// shape of the Ollama chat API.
type Message struct {
    Role    string `json:"role"`    // e.g. "user" as used below — TODO confirm full set of accepted roles
    Content string `json:"content"` // the message text
}

// ChatRequest is the JSON payload POSTed to /api/chat.
type ChatRequest struct {
    Model    string    `json:"model"`    // model name, e.g. "llama3.2"
    Messages []Message `json:"messages"` // conversation history, oldest first
    Stream   bool      `json:"stream"`   // false here: the server returns the whole reply in one JSON body
}

// ChatResponse is the subset of the /api/chat response body this
// program decodes: only the assistant's message.
type ChatResponse struct {
    Message Message `json:"message"`
}

// chat sends one non-streaming chat request to the local Ollama server
// and returns the assistant's reply text.
//
// Unlike the naive version, every error path is checked: request
// encoding, the HTTP round trip, the HTTP status, body reading, and
// response decoding all surface an error instead of being ignored.
func chat(model string, messages []Message) (string, error) {
    req := ChatRequest{
        Model:    model,
        Messages: messages,
        Stream:   false, // one complete JSON response instead of a stream
    }

    body, err := json.Marshal(req)
    if err != nil {
        return "", fmt.Errorf("encoding chat request: %w", err)
    }

    resp, err := http.Post(
        "http://localhost:11434/api/chat",
        "application/json",
        bytes.NewReader(body),
    )
    if err != nil {
        return "", err
    }
    defer resp.Body.Close()

    data, err := io.ReadAll(resp.Body)
    if err != nil {
        return "", fmt.Errorf("reading chat response: %w", err)
    }
    // A non-200 body is an error message, not a ChatResponse.
    if resp.StatusCode != http.StatusOK {
        return "", fmt.Errorf("ollama returned status %d: %s", resp.StatusCode, data)
    }

    var result ChatResponse
    if err := json.Unmarshal(data, &result); err != nil {
        return "", fmt.Errorf("decoding chat response: %w", err)
    }
    return result.Message.Content, nil
}

// main fires one goroutine per question, collects the replies into a
// map guarded by a mutex, and prints them in question order.
func main() {
    questions := []string{
        "什么是 Go?",
        "什么是 Python?",
        "什么是 Rust?",
    }

    var wg sync.WaitGroup
    results := make(map[int]string)
    var mu sync.Mutex // guards results across goroutines

    for i, question := range questions {
        wg.Add(1)
        // Index and question are passed as arguments so each goroutine
        // captures its own copy (required before Go 1.22).
        go func(idx int, q string) {
            defer wg.Done()

            reply, err := chat("llama3.2", []Message{
                {Role: "user", Content: q},
            })
            if err != nil {
                fmt.Printf("请求失败: %v\n", err)
                return
            }

            mu.Lock()
            results[idx] = reply
            mu.Unlock()
        }(i, question)
    }

    wg.Wait()

    for i := 0; i < len(questions); i++ {
        fmt.Printf("问题%d: %s\n", i+1, questions[i])
        // BUG FIX: the original results[i][:100] panicked when a reply
        // was shorter than 100 bytes or when the request had failed
        // (missing key -> empty string). Truncate by runes so a
        // multi-byte Chinese character is never split either.
        reply := results[i]
        if r := []rune(reply); len(r) > 100 {
            reply = string(r[:100])
        }
        fmt.Printf("回答: %s\n\n", reply)
    }
}

Worker Pool 模式

package main

import (
    "bytes"
    "encoding/json"
    "fmt"
    "io"
    "net/http"
)

// Job is one unit of work for the worker pool: a question to send to
// the model, tagged with an ID so the answer can be matched back.
type Job struct {
    ID      int
    Message string
}

// Result is a worker's outcome for one Job, keyed back by ID.
// When Error is non-nil, Reply is empty (see worker below).
type Result struct {
    ID      int
    Reply   string
    Error   error
}

// worker drains the jobs channel until it is closed, asking the model
// each job's question and reporting every outcome — success or failure —
// on the results channel.
func worker(id int, jobs <-chan Job, results chan<- Result) {
    for j := range jobs {
        prompt := []Message{{Role: "user", Content: j.Message}}
        answer, err := chat("llama3.2", prompt)
        results <- Result{ID: j.ID, Reply: answer, Error: err}
    }
}

// main runs a fixed pool of 3 workers over a buffered job queue and
// prints each result as it arrives (arrival order, not job order).
func main() {
    jobs := make(chan Job, 100)
    results := make(chan Result, 100)

    // Start 3 workers that all consume from the same jobs channel.
    for w := 1; w <= 3; w++ {
        go worker(w, jobs, results)
    }

    // Enqueue all tasks.
    questions := []string{
        "什么是 Go?",
        "什么是 Python?",
        "什么是 Rust?",
        "什么是 Java?",
        "什么是 JavaScript?",
    }

    for i, q := range questions {
        jobs <- Job{ID: i, Message: q}
    }
    // Closing jobs lets each worker's range loop terminate.
    close(jobs)

    // Collect exactly one result per question.
    for i := 0; i < len(questions); i++ {
        result := <-results
        if result.Error != nil {
            fmt.Printf("任务 %d 失败: %v\n", result.ID, result.Error)
            continue
        }
        // BUG FIX: the original result.Reply[:50] panicked on replies
        // shorter than 50 bytes and could split a multi-byte rune.
        reply := result.Reply
        if r := []rune(reply); len(r) > 50 {
            reply = string(r[:50])
        }
        fmt.Printf("任务 %d: %s\n", result.ID, reply)
    }
}

并发获取嵌入

// getEmbedding asks the local Ollama server for an embedding vector of
// text using the given embedding model.
//
// All error paths are checked: request encoding, the HTTP round trip,
// the HTTP status, body reading, and response decoding.
func getEmbedding(model, text string) ([]float64, error) {
    req := map[string]string{
        "model":  model,
        "prompt": text,
    }

    body, err := json.Marshal(req)
    if err != nil {
        return nil, fmt.Errorf("encoding embedding request: %w", err)
    }

    resp, err := http.Post(
        "http://localhost:11434/api/embeddings",
        "application/json",
        bytes.NewReader(body),
    )
    if err != nil {
        return nil, err
    }
    defer resp.Body.Close()

    data, err := io.ReadAll(resp.Body)
    if err != nil {
        return nil, fmt.Errorf("reading embedding response: %w", err)
    }
    // A non-200 body is an error message, not an embedding.
    if resp.StatusCode != http.StatusOK {
        return nil, fmt.Errorf("ollama returned status %d: %s", resp.StatusCode, data)
    }

    var result struct {
        Embedding []float64 `json:"embedding"`
    }
    if err := json.Unmarshal(data, &result); err != nil {
        return nil, fmt.Errorf("decoding embedding response: %w", err)
    }
    return result.Embedding, nil
}

// batchEmbeddings computes an embedding for every text concurrently,
// one goroutine per text. Vectors come back in input order; if any
// request fails, the first error observed is returned alongside the
// (possibly partial) results.
func batchEmbeddings(texts []string) ([][]float64, error) {
    var (
        wg       sync.WaitGroup
        mu       sync.Mutex // guards vectors and firstErr
        firstErr error
    )
    vectors := make([][]float64, len(texts))

    for idx, txt := range texts {
        wg.Add(1)
        // Pass index and text explicitly so each goroutine owns its copy.
        go func(i int, s string) {
            defer wg.Done()

            vec, err := getEmbedding("nomic-embed-text", s)

            mu.Lock()
            defer mu.Unlock()
            if err != nil {
                // Keep only the first failure; later ones are dropped.
                if firstErr == nil {
                    firstErr = err
                }
                return
            }
            vectors[i] = vec
        }(idx, txt)
    }

    wg.Wait()
    return vectors, firstErr
}

// main embeds three sample sentences in parallel and reports the
// dimensionality of each resulting vector.
func main() {
    samples := []string{
        "Go 是一种编程语言",
        "Python 是一种编程语言",
        "Rust 是一种编程语言",
    }

    vectors, err := batchEmbeddings(samples)
    if err != nil {
        panic(err)
    }

    for idx, vec := range vectors {
        fmt.Printf("文本 %d: %d 维向量\n", idx, len(vec))
    }
}

限流控制

import (
    "time"
    "golang.org/x/time/rate"
)

// rateLimitedChat blocks until the limiter grants a token, then issues
// a regular chat request. Wait only fails if the limiter itself reports
// an error (e.g. the burst cannot ever satisfy the request).
//
// NOTE(review): this snippet calls context.Background() but its import
// block above lists only "time" and "golang.org/x/time/rate" —
// "context" must be added there for the example to compile.
func rateLimitedChat(limiter *rate.Limiter, model string, messages []Message) (string, error) {
    err := limiter.Wait(context.Background())
    if err != nil {
        return "", err
    }
    return chat(model, messages)
}

// main launches 20 requests through a shared limiter allowing at most
// 5 requests/second with a burst of 10, then sleeps to let them finish.
func main() {
    // rate.NewLimiter(5, 10): 5 tokens/second, burst capacity 10.
    limiter := rate.NewLimiter(5, 10)

    for i := 0; i < 20; i++ {
        go func(idx int) {
            reply, err := rateLimitedChat(limiter, "llama3.2", []Message{
                {Role: "user", Content: fmt.Sprintf("请求 %d", idx)},
            })
            if err != nil {
                fmt.Printf("请求 %d 失败: %v\n", idx, err)
                return
            }
            // BUG FIX: the original reply[:30] panicked on replies
            // shorter than 30 bytes and could split a multi-byte rune.
            if r := []rune(reply); len(r) > 30 {
                reply = string(r[:30])
            }
            fmt.Printf("请求 %d: %s\n", idx, reply)
        }(i)
    }

    // Demo-only synchronization: a fixed sleep instead of a WaitGroup,
    // so requests still running after 10s are cut off. Prefer
    // sync.WaitGroup in real code.
    time.Sleep(10 * time.Second)
}

超时控制

import (
    "context"
    "time"
)

// chatWithTimeout sends a non-streaming chat request that is cancelled
// if it takes longer than timeout, using a context attached to the
// HTTP request.
//
// All error paths are checked: request encoding, request construction,
// the HTTP round trip, the HTTP status, body reading, and decoding.
func chatWithTimeout(model string, messages []Message, timeout time.Duration) (string, error) {
    ctx, cancel := context.WithTimeout(context.Background(), timeout)
    defer cancel()

    req := ChatRequest{
        Model:    model,
        Messages: messages,
        Stream:   false,
    }

    body, err := json.Marshal(req)
    if err != nil {
        return "", fmt.Errorf("encoding chat request: %w", err)
    }

    httpReq, err := http.NewRequestWithContext(
        ctx,
        http.MethodPost,
        "http://localhost:11434/api/chat",
        bytes.NewReader(body),
    )
    if err != nil {
        return "", fmt.Errorf("building chat request: %w", err)
    }
    httpReq.Header.Set("Content-Type", "application/json")

    client := &http.Client{}
    resp, err := client.Do(httpReq)
    if err != nil {
        // Includes context.DeadlineExceeded when the timeout fires.
        return "", err
    }
    defer resp.Body.Close()

    data, err := io.ReadAll(resp.Body)
    if err != nil {
        return "", fmt.Errorf("reading chat response: %w", err)
    }
    // A non-200 body is an error message, not a ChatResponse.
    if resp.StatusCode != http.StatusOK {
        return "", fmt.Errorf("ollama returned status %d: %s", resp.StatusCode, data)
    }

    var result ChatResponse
    if err := json.Unmarshal(data, &result); err != nil {
        return "", fmt.Errorf("decoding chat response: %w", err)
    }
    return result.Message.Content, nil
}

func main() {
    reply, err := chatWithTimeout("llama3.2", []Message{
        {Role: "user", Content: "你好"},
    }, 30*time.Second)
    
    if err != nil {
        fmt.Printf("请求超时或失败: %v\n", err)
        return
    }
    
    fmt.Println(reply)
}

连接池

// OllamaClient bundles a reusable http.Client (configured with a
// connection-pooling transport in NewOllamaClient) with the base URL
// of the Ollama server.
type OllamaClient struct {
    client *http.Client
    host   string
}

// NewOllamaClient builds a client for the Ollama server at host,
// backed by a transport that keeps idle connections alive so repeated
// requests can reuse TCP connections instead of redialing.
func NewOllamaClient(host string) *OllamaClient {
    httpClient := &http.Client{
        Transport: &http.Transport{
            MaxIdleConns:        100,              // total idle connections kept across all hosts
            MaxIdleConnsPerHost: 10,               // idle connections kept per host
            IdleConnTimeout:     90 * time.Second, // drop idle connections after this long
        },
        // Generous overall cap per request; model generation can be slow.
        Timeout: 120 * time.Second,
    }
    return &OllamaClient{client: httpClient, host: host}
}

// Chat sends one non-streaming chat request through the client's pooled
// connection and returns the assistant's reply text.
//
// All error paths are checked: request encoding, the HTTP round trip,
// the HTTP status, body reading, and response decoding.
func (c *OllamaClient) Chat(model string, messages []Message) (string, error) {
    req := ChatRequest{
        Model:    model,
        Messages: messages,
        Stream:   false,
    }

    body, err := json.Marshal(req)
    if err != nil {
        return "", fmt.Errorf("encoding chat request: %w", err)
    }

    resp, err := c.client.Post(
        c.host+"/api/chat",
        "application/json",
        bytes.NewReader(body),
    )
    if err != nil {
        return "", err
    }
    defer resp.Body.Close()

    data, err := io.ReadAll(resp.Body)
    if err != nil {
        return "", fmt.Errorf("reading chat response: %w", err)
    }
    // A non-200 body is an error message, not a ChatResponse.
    if resp.StatusCode != http.StatusOK {
        return "", fmt.Errorf("ollama returned status %d: %s", resp.StatusCode, data)
    }

    var result ChatResponse
    if err := json.Unmarshal(data, &result); err != nil {
        return "", fmt.Errorf("decoding chat response: %w", err)
    }
    return result.Message.Content, nil
}