Go 语言的并发特性非常适合同时处理多个 Ollama 请求，从而提高整体效率。
package main
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"sync"
)
// Message is a single turn in an Ollama chat conversation.
type Message struct {
	Role    string `json:"role"`    // sender of the message, e.g. "user"
	Content string `json:"content"` // plain-text body of the message
}
// ChatRequest is the JSON payload sent to Ollama's /api/chat endpoint.
type ChatRequest struct {
	Model    string    `json:"model"`    // model name, e.g. "llama3.2"
	Messages []Message `json:"messages"` // conversation to send
	Stream   bool      `json:"stream"`   // false: receive the whole reply in one response
}
// ChatResponse is the non-streaming reply from /api/chat; only the
// assistant message field is decoded here, other fields are ignored.
type ChatResponse struct {
	Message Message `json:"message"` // the model's reply
}
// chat sends a non-streaming chat request to the local Ollama server and
// returns the assistant's reply text.
//
// model is the Ollama model name (e.g. "llama3.2"); messages is the
// conversation to send. An error is returned for encoding, transport,
// non-200 status, or decoding failures.
func chat(model string, messages []Message) (string, error) {
	req := ChatRequest{
		Model:    model,
		Messages: messages,
		Stream:   false,
	}
	body, err := json.Marshal(req)
	if err != nil {
		return "", fmt.Errorf("encoding chat request: %w", err)
	}
	resp, err := http.Post(
		"http://localhost:11434/api/chat",
		"application/json",
		bytes.NewReader(body),
	)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	data, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("reading chat response: %w", err)
	}
	// The server reports errors (bad model name, etc.) via non-200 status;
	// without this check we would silently return an empty reply.
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("chat: unexpected status %d: %s", resp.StatusCode, data)
	}
	var result ChatResponse
	if err := json.Unmarshal(data, &result); err != nil {
		return "", fmt.Errorf("decoding chat response: %w", err)
	}
	return result.Message.Content, nil
}
// main fires one goroutine per question, collects replies into a map keyed
// by question index, and prints them in the original question order.
func main() {
	questions := []string{
		"什么是 Go?",
		"什么是 Python?",
		"什么是 Rust?",
	}
	var wg sync.WaitGroup
	results := make(map[int]string)
	var mu sync.Mutex // guards results: concurrent map writes are a data race
	for i, question := range questions {
		wg.Add(1)
		go func(idx int, q string) {
			defer wg.Done()
			reply, err := chat("llama3.2", []Message{
				{Role: "user", Content: q},
			})
			if err != nil {
				fmt.Printf("请求失败: %v\n", err)
				return
			}
			mu.Lock()
			results[idx] = reply
			mu.Unlock()
		}(i, question)
	}
	wg.Wait()
	for i := 0; i < len(questions); i++ {
		fmt.Printf("问题%d: %s\n", i+1, questions[i])
		reply, ok := results[i]
		if !ok {
			// The request for this question failed; skip it instead of
			// panicking — the old results[i][:100] sliced the empty string.
			fmt.Printf("回答: (无结果)\n\n")
			continue
		}
		// Truncate by runes, not bytes: reply[:100] panics when the reply is
		// shorter than 100 bytes and can split a multi-byte UTF-8 character.
		r := []rune(reply)
		if len(r) > 100 {
			r = r[:100]
		}
		fmt.Printf("回答: %s\n\n", string(r))
	}
}
package main
import (
"bytes"
"encoding/json"
"io"
"net/http"
)
// Job is one unit of work for the worker pool: a prompt tagged with an ID
// so its result can be matched back to the originating question.
type Job struct {
	ID      int    // index of the originating question
	Message string // prompt text to send to the model
}
// Result carries a worker's outcome for one Job.
type Result struct {
	ID    int    // matches Job.ID
	Reply string // model reply text (empty when Error is non-nil)
	Error error  // non-nil when the chat request failed
}
// worker drains the jobs channel, sends each prompt to the model, and emits
// exactly one Result per job on results. It returns when jobs is closed and
// empty.
func worker(id int, jobs <-chan Job, results chan<- Result) {
	for j := range jobs {
		msgs := []Message{{Role: "user", Content: j.Message}}
		reply, err := chat("llama3.2", msgs)
		results <- Result{ID: j.ID, Reply: reply, Error: err}
	}
}
// main runs a fixed pool of 3 workers over a queue of questions and prints
// each result as it arrives (completion order, not submission order).
func main() {
	jobs := make(chan Job, 100)
	results := make(chan Result, 100)
	// 启动 3 个 worker
	for w := 1; w <= 3; w++ {
		go worker(w, jobs, results)
	}
	// 发送任务
	questions := []string{
		"什么是 Go?",
		"什么是 Python?",
		"什么是 Rust?",
		"什么是 Java?",
		"什么是 JavaScript?",
	}
	for i, q := range questions {
		jobs <- Job{ID: i, Message: q}
	}
	// Closing jobs lets each worker's range loop terminate once the queue
	// is drained.
	close(jobs)
	// 收集结果: exactly one Result per submitted job.
	for i := 0; i < len(questions); i++ {
		result := <-results
		if result.Error != nil {
			fmt.Printf("任务 %d 失败: %v\n", result.ID, result.Error)
			continue
		}
		// Truncate by runes, not bytes: Reply[:50] panics on replies shorter
		// than 50 bytes and can split a multi-byte UTF-8 character.
		r := []rune(result.Reply)
		if len(r) > 50 {
			r = r[:50]
		}
		fmt.Printf("任务 %d: %s\n", result.ID, string(r))
	}
}
// getEmbedding requests an embedding vector for text from Ollama's
// /api/embeddings endpoint using the given embedding model. It returns an
// error for encoding, transport, non-200 status, or decoding failures.
func getEmbedding(model, text string) ([]float64, error) {
	req := map[string]string{
		"model":  model,
		"prompt": text,
	}
	body, err := json.Marshal(req)
	if err != nil {
		return nil, fmt.Errorf("encoding embeddings request: %w", err)
	}
	resp, err := http.Post(
		"http://localhost:11434/api/embeddings",
		"application/json",
		bytes.NewReader(body),
	)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	data, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("reading embeddings response: %w", err)
	}
	// Surface server-side errors instead of silently returning a nil vector.
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("embeddings: unexpected status %d: %s", resp.StatusCode, data)
	}
	var result struct {
		Embedding []float64 `json:"embedding"`
	}
	if err := json.Unmarshal(data, &result); err != nil {
		return nil, fmt.Errorf("decoding embeddings response: %w", err)
	}
	return result.Embedding, nil
}
// batchEmbeddings computes an embedding for every text concurrently, one
// goroutine per text. The returned slice is index-aligned with texts; if
// any request fails, the first error observed is returned alongside the
// partial results (failed slots stay nil).
func batchEmbeddings(texts []string) ([][]float64, error) {
	vectors := make([][]float64, len(texts))
	var (
		wg       sync.WaitGroup
		mu       sync.Mutex // guards vectors and firstErr
		firstErr error
	)
	for i, text := range texts {
		wg.Add(1)
		go func(pos int, input string) {
			defer wg.Done()
			vec, err := getEmbedding("nomic-embed-text", input)
			mu.Lock()
			defer mu.Unlock()
			if err != nil {
				// Keep only the first failure; later ones are dropped.
				if firstErr == nil {
					firstErr = err
				}
				return
			}
			vectors[pos] = vec
		}(i, text)
	}
	wg.Wait()
	return vectors, firstErr
}
// main embeds a few sample sentences in parallel and reports the
// dimensionality of each resulting vector.
func main() {
	texts := []string{
		"Go 是一种编程语言",
		"Python 是一种编程语言",
		"Rust 是一种编程语言",
	}
	embeddings, err := batchEmbeddings(texts)
	if err != nil {
		panic(err)
	}
	for i := range embeddings {
		fmt.Printf("文本 %d: %d 维向量\n", i, len(embeddings[i]))
	}
}
import (
"time"
"golang.org/x/time/rate"
)
// rateLimitedChat blocks until limiter grants a token, then performs a
// normal chat call. If waiting fails (e.g. the limiter's burst cannot
// satisfy the request), that error is returned and no request is sent.
func rateLimitedChat(limiter *rate.Limiter, model string, messages []Message) (string, error) {
	err := limiter.Wait(context.Background())
	if err != nil {
		return "", err
	}
	return chat(model, messages)
}
// main issues 20 rate-limited chat requests concurrently and waits for all
// of them to complete before exiting.
func main() {
	// 每秒最多 5 个请求 (burst of 10).
	limiter := rate.NewLimiter(5, 10)
	var wg sync.WaitGroup
	for i := 0; i < 20; i++ {
		wg.Add(1)
		go func(idx int) {
			defer wg.Done()
			reply, err := rateLimitedChat(limiter, "llama3.2", []Message{
				{Role: "user", Content: fmt.Sprintf("请求 %d", idx)},
			})
			if err != nil {
				fmt.Printf("请求 %d 失败: %v\n", idx, err)
				return
			}
			// Truncate by runes: reply[:30] panics on replies shorter than
			// 30 bytes and can split a multi-byte UTF-8 character.
			r := []rune(reply)
			if len(r) > 30 {
				r = r[:30]
			}
			fmt.Printf("请求 %d: %s\n", idx, string(r))
		}(i)
	}
	// Wait for every request instead of sleeping a fixed 10 seconds, which
	// could cut off slow requests or waste time when all finish early.
	wg.Wait()
}
import (
"context"
"time"
)
// chatWithTimeout behaves like chat but aborts the request — including
// connection setup and body read — once timeout elapses, via a
// context-scoped HTTP request.
func chatWithTimeout(model string, messages []Message, timeout time.Duration) (string, error) {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	req := ChatRequest{
		Model:    model,
		Messages: messages,
		Stream:   false,
	}
	body, err := json.Marshal(req)
	if err != nil {
		return "", fmt.Errorf("encoding chat request: %w", err)
	}
	httpReq, err := http.NewRequestWithContext(
		ctx,
		http.MethodPost,
		"http://localhost:11434/api/chat",
		bytes.NewReader(body),
	)
	if err != nil {
		return "", fmt.Errorf("building chat request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	client := &http.Client{}
	resp, err := client.Do(httpReq)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	data, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("reading chat response: %w", err)
	}
	// Surface server-side errors instead of silently returning "".
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("chat: unexpected status %d: %s", resp.StatusCode, data)
	}
	var result ChatResponse
	if err := json.Unmarshal(data, &result); err != nil {
		return "", fmt.Errorf("decoding chat response: %w", err)
	}
	return result.Message.Content, nil
}
func main() {
reply, err := chatWithTimeout("llama3.2", []Message{
{Role: "user", Content: "你好"},
}, 30*time.Second)
if err != nil {
fmt.Printf("请求超时或失败: %v\n", err)
return
}
fmt.Println(reply)
}
// OllamaClient is a reusable Ollama API client that shares one underlying
// http.Client — and therefore one connection pool — across requests.
type OllamaClient struct {
	client *http.Client // shared HTTP client with pooled idle connections
	host   string       // base URL of the Ollama server, e.g. "http://localhost:11434"
}
// NewOllamaClient builds a client for the Ollama server at host. The
// transport keeps up to 100 idle connections (10 per host, recycled after
// 90 seconds) so repeated requests reuse TCP sessions, and the overall
// per-request timeout is 120 seconds to allow for slow model generations.
func NewOllamaClient(host string) *OllamaClient {
	httpClient := &http.Client{
		Timeout: 120 * time.Second,
		Transport: &http.Transport{
			MaxIdleConns:        100,
			MaxIdleConnsPerHost: 10,
			IdleConnTimeout:     90 * time.Second,
		},
	}
	return &OllamaClient{client: httpClient, host: host}
}
// Chat sends a non-streaming chat request to the client's Ollama host and
// returns the assistant's reply text. It reuses c's pooled HTTP client and
// returns an error for encoding, transport, non-200 status, or decoding
// failures.
func (c *OllamaClient) Chat(model string, messages []Message) (string, error) {
	req := ChatRequest{
		Model:    model,
		Messages: messages,
		Stream:   false,
	}
	body, err := json.Marshal(req)
	if err != nil {
		return "", fmt.Errorf("encoding chat request: %w", err)
	}
	resp, err := c.client.Post(
		c.host+"/api/chat",
		"application/json",
		bytes.NewReader(body),
	)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	data, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("reading chat response: %w", err)
	}
	// Surface server-side errors instead of silently returning "".
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("chat: unexpected status %d: %s", resp.StatusCode, data)
	}
	var result ChatResponse
	if err := json.Unmarshal(data, &result); err != nil {
		return "", fmt.Errorf("decoding chat response: %w", err)
	}
	return result.Message.Content, nil
}