Integrate Ollama into a microservices architecture to build a scalable AI service. An API gateway routes requests to a FastAPI-based AI service, which calls Ollama for inference and uses Redis for response caching and session history.
┌─────────────┐     ┌─────────────┐     ┌─────────────┐
│   Gateway   │────▶│ AI Service  │────▶│   Ollama    │
└─────────────┘     └─────────────┘     └─────────────┘
                           │
                           ▼
                    ┌─────────────┐
                    │    Redis    │
                    └─────────────┘
from fastapi import FastAPI
from pydantic import BaseModel
from typing import Optional
import ollama
import redis
import json
import hashlib
import os

app = FastAPI()

# Use the REDIS_HOST set in docker-compose; fall back to localhost for local runs.
# The ollama client below honors the OLLAMA_HOST environment variable the same way.
redis_client = redis.Redis(host=os.getenv("REDIS_HOST", "localhost"), port=6379, db=0)


class ChatRequest(BaseModel):
    message: str
    model: str = "llama3.2"
    session_id: Optional[str] = None


class ChatResponse(BaseModel):
    response: str
    session_id: str


def get_cache_key(model: str, message: str) -> str:
    data = f"{model}:{message}"
    return hashlib.md5(data.encode()).hexdigest()


def get_session_history(session_id: str) -> list:
    key = f"session:{session_id}"
    history = redis_client.get(key)
    return json.loads(history) if history else []


def save_session_history(session_id: str, history: list):
    key = f"session:{session_id}"
    redis_client.setex(key, 3600, json.dumps(history))


@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    # The cache key covers only model + message, so it ignores conversation
    # history; only serve cached replies for stateless (no-session) requests.
    cache_key = get_cache_key(request.model, request.message)
    if not request.session_id:
        cached = redis_client.get(cache_key)
        if cached:
            return ChatResponse(response=cached.decode(), session_id="")

    history = []
    if request.session_id:
        history = get_session_history(request.session_id)

    messages = history + [{"role": "user", "content": request.message}]
    response = ollama.chat(model=request.model, messages=messages)
    reply = response["message"]["content"]

    redis_client.setex(cache_key, 3600, reply)

    if request.session_id:
        history.append({"role": "user", "content": request.message})
        history.append({"role": "assistant", "content": reply})
        save_session_history(request.session_id, history)

    return ChatResponse(response=reply, session_id=request.session_id or "")
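A quick smoke test against the running service, assuming it is reachable on localhost:8000 (the payload fields match the ChatRequest model above):

import requests

# Hypothetical smoke test: one chat turn with a reusable session id.
resp = requests.post(
    "http://localhost:8000/chat",
    json={"message": "Hello!", "model": "llama3.2", "session_id": "demo-1"},
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["response"])

The service is packaged with a small Dockerfile: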
FROM python:3.11-slim
WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
EXPOSE 8000
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
version: '3.8'

services:
  ollama:
    image: ollama/ollama
    ports:
      - "11434:11434"
    volumes:
      - ollama_data:/root/.ollama

  redis:
    image: redis:alpine
    ports:
      - "6379:6379"

  ai-service:
    build: .
    ports:
      - "8000:8000"
    depends_on:
      - ollama
      - redis
    environment:
      - OLLAMA_HOST=http://ollama:11434
      - REDIS_HOST=redis

volumes:
  ollama_data:
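A fresh ollama container has no models downloaded, so llama3.2 must be pulled before the first chat request. A minimal sketch of a pull helper using the ollama Python client (the function name and startup wiring are up to you; running `ollama pull llama3.2` inside the container works just as well):

import os

import ollama


def ensure_model(model: str = "llama3.2") -> None:
    # Pull the model into the Ollama instance pointed to by OLLAMA_HOST;
    # this is effectively a no-op if the model is already present.
    client = ollama.Client(host=os.getenv("OLLAMA_HOST", "http://localhost:11434"))
    client.pull(model)

When a single Ollama instance becomes the bottleneck, requests can be spread across several instances: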
import random
from typing import List

import requests


class OllamaLoadBalancer:
    """Rotates requests across multiple Ollama instances."""

    def __init__(self, hosts: List[str]):
        self.hosts = hosts
        self.current = 0

    def get_next_host(self) -> str:
        # Round-robin: hand out hosts in order, wrapping around.
        host = self.hosts[self.current]
        self.current = (self.current + 1) % len(self.hosts)
        return host

    def get_random_host(self) -> str:
        return random.choice(self.hosts)


load_balancer = OllamaLoadBalancer([
    "http://ollama-1:11434",
    "http://ollama-2:11434",
    "http://ollama-3:11434",
])


def chat_with_load_balance(model: str, messages: list) -> str:
    host = load_balancer.get_next_host()
    # stream=False makes /api/chat return a single JSON object instead of
    # a stream of newline-delimited chunks.
    response = requests.post(
        f"{host}/api/chat",
        json={"model": model, "messages": messages, "stream": False},
        timeout=120,
    )
    response.raise_for_status()
    return response.json()["message"]["content"]
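The round-robin balancer above keeps sending traffic to hosts that are down. Building on the load_balancer and requests import from the previous snippet, a hedged sketch of a failover wrapper (chat_with_failover is an illustrative name, not an Ollama API):

def chat_with_failover(model: str, messages: list, attempts: int = 3) -> str:
    # Try up to `attempts` different hosts before giving up.
    last_error = None
    for _ in range(attempts):
        host = load_balancer.get_next_host()
        try:
            response = requests.post(
                f"{host}/api/chat",
                json={"model": model, "messages": messages, "stream": False},
                timeout=120,
            )
            response.raise_for_status()
            return response.json()["message"]["content"]
        except requests.RequestException as exc:
            last_error = exc
    raise RuntimeError(f"All Ollama hosts failed: {last_error}")

A dedicated health endpoint makes the per-host status visible to orchestrators and dashboards: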
from fastapi import FastAPI
import requests

app = FastAPI()

OLLAMA_HOSTS = [
    "http://ollama-1:11434",
    "http://ollama-2:11434",
]


def check_ollama_health(host: str) -> bool:
    # /api/tags is a lightweight endpoint that lists locally available models.
    try:
        response = requests.get(f"{host}/api/tags", timeout=5)
        return response.status_code == 200
    except requests.RequestException:
        return False


@app.get("/health")
async def health():
    ollama_status = {host: check_ollama_health(host) for host in OLLAMA_HOSTS}
    all_healthy = all(ollama_status.values())
    return {
        "status": "healthy" if all_healthy else "degraded",
        "ollama": ollama_status,
    }
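The same check can drive host selection as well — a small sketch that reuses OLLAMA_HOSTS and check_ollama_health from above (pick_healthy_host is an illustrative name):

import random


def pick_healthy_host() -> str:
    # Prefer hosts that currently pass the health check; fall back to the
    # full list if none respond, so the caller still surfaces an error.
    healthy = [h for h in OLLAMA_HOSTS if check_ollama_health(h)]
    return random.choice(healthy or OLLAMA_HOSTS)

Finally, Prometheus metrics give visibility into request volume and latency: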
from fastapi import FastAPI, Request, Response
from prometheus_client import Counter, Histogram, generate_latest, CONTENT_TYPE_LATEST
import time

app = FastAPI()

REQUEST_COUNT = Counter('ollama_requests_total', 'Total requests')
REQUEST_LATENCY = Histogram('ollama_request_latency_seconds', 'Request latency')


@app.middleware("http")
async def add_metrics(request: Request, call_next):
    REQUEST_COUNT.inc()
    start = time.time()
    response = await call_next(request)
    REQUEST_LATENCY.observe(time.time() - start)
    return response


@app.get("/metrics")
async def metrics():
    # generate_latest() returns bytes in the Prometheus text exposition format,
    # so return a raw Response with the matching content type.
    return Response(content=generate_latest(), media_type=CONTENT_TYPE_LATEST)
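To collect these metrics, Prometheus needs a scrape job pointing at the /metrics endpoint. A minimal prometheus.yml sketch (the job name and target are assumptions matching the compose service above):

scrape_configs:
  - job_name: 'ai-service'
    scrape_interval: 15s
    static_configs:
      - targets: ['ai-service:8000']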