缓存策略

缓存可以减少重复请求,提高响应速度,降低服务器负载。

内存缓存

基本缓存

import hashlib
import json
from typing import Any, Callable, Dict, Optional

import ollama

class SimpleCache:
    """Unbounded in-memory key/value cache with deterministic request keys.

    Fix: the original annotated values as ``any`` (the builtin function),
    not ``typing.Any``; the annotations now use the correct type.
    """

    def __init__(self):
        # Plain dict; unbounded. See LRUCache / TTLCache for bounded variants.
        self.cache: Dict[str, Any] = {}

    def _make_key(self, model: str, messages: list) -> str:
        """Build a stable cache key from the model name and message list.

        ``sort_keys=True`` makes the JSON serialization deterministic so
        identical requests always hash to the same key.  MD5 is used for
        bucketing only, not security.
        """
        data = {'model': model, 'messages': messages}
        return hashlib.md5(json.dumps(data, sort_keys=True).encode()).hexdigest()

    def get(self, key: str) -> Optional[Any]:
        """Return the cached value for *key*, or None when absent."""
        return self.cache.get(key)

    def set(self, key: str, value: Any):
        """Store *value* under *key*, overwriting any previous entry."""
        self.cache[key] = value

    def clear(self):
        """Drop every cached entry."""
        self.cache.clear()

class CachedOllamaClient:
    """Ollama chat client that memoizes responses through a SimpleCache."""

    def __init__(self, model: str = 'llama3.2'):
        self.model = model
        self.cache = SimpleCache()

    def chat(self, messages: list, use_cache: bool = True) -> str:
        """Return the model reply for *messages*, serving repeats from cache.

        When *use_cache* is False the cache is bypassed on read, but the
        fresh response is still stored for later lookups.
        """
        cache_key = self.cache._make_key(self.model, messages)

        if use_cache:
            hit = self.cache.get(cache_key)
            if hit is not None:
                print("使用缓存")
                return hit

        reply = ollama.chat(model=self.model, messages=messages)
        content = reply['message']['content']

        self.cache.set(cache_key, content)
        return content

# Usage: the second identical request is answered from the cache
client = CachedOllamaClient()

result1 = client.chat([{'role': 'user', 'content': '你好'}])  # calls the model
result2 = client.chat([{'role': 'user', 'content': '你好'}])  # same key -> cache hit

LRU 缓存

from collections import OrderedDict
from typing import Any

class LRUCache:
    """Least-recently-used cache bounded at *max_size* entries."""

    def __init__(self, max_size: int = 100):
        self.max_size = max_size
        # OrderedDict preserves access order; the front is the LRU candidate.
        self.cache: OrderedDict[str, Any] = OrderedDict()

    def get(self, key: str) -> Optional[Any]:
        """Return the value for *key* (marking it most-recently-used), or None."""
        if key in self.cache:
            self.cache.move_to_end(key)
            return self.cache[key]
        return None

    def set(self, key: str, value: Any):
        """Insert or update *key*, evicting the LRU entry when full.

        Bug fix: the original only moved an existing key to the end and
        never wrote the new value, so updates to existing keys were
        silently dropped.  The write now happens unconditionally.
        """
        if key in self.cache:
            self.cache.move_to_end(key)
        elif len(self.cache) >= self.max_size:
            self.cache.popitem(last=False)  # evict least-recently-used
        self.cache[key] = value

    def clear(self):
        """Remove all entries."""
        self.cache.clear()

    def size(self) -> int:
        """Current number of cached entries."""
        return len(self.cache)

# Usage: inserting 15 entries into a 10-slot cache evicts the 5 oldest
cache = LRUCache(max_size=10)

for i in range(15):
    cache.set(f'key_{i}', f'value_{i}')

print(f"缓存大小: {cache.size()}")  # 10

带过期时间的缓存

import time
from typing import Any, Optional

class TTLCache:
    """Cache whose entries expire *ttl* seconds after being written."""

    def __init__(self, ttl: int = 3600):
        self.ttl = ttl
        # key -> (value, insertion timestamp)
        self.cache: Dict[str, tuple] = {}

    def get(self, key: str) -> Optional[Any]:
        """Return the live value for *key*; expired entries are evicted and yield None."""
        entry = self.cache.get(key)
        if entry is None:
            return None

        value, stored_at = entry
        if time.time() - stored_at < self.ttl:
            return value

        # Entry outlived its TTL: remove it lazily on read.
        del self.cache[key]
        return None

    def set(self, key: str, value: Any):
        """Store *value* stamped with the current time."""
        self.cache[key] = (value, time.time())

    def clear_expired(self):
        """Purge every entry whose age has reached the TTL."""
        cutoff = time.time()
        stale = [
            k for k, (_, stamp) in self.cache.items()
            if cutoff - stamp >= self.ttl
        ]
        for k in stale:
            del self.cache[k]

# Usage: entries live for 60 seconds
cache = TTLCache(ttl=60)

cache.set('test', 'value')
print(cache.get('test'))  # value

# NOTE: this demo really sleeps past the TTL, so running it takes >60s
time.sleep(61)
print(cache.get('test'))  # None

持久化缓存

import pickle
import os
from pathlib import Path
from typing import Optional, Any

class DiskCache:
    """File-backed cache: each key is pickled to its own file under *cache_dir*."""

    def __init__(self, cache_dir: str = '.ollama_cache'):
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(exist_ok=True)

    def _get_cache_path(self, key: str) -> Path:
        """Path of the file backing *key* (assumes *key* is filesystem-safe)."""
        return self.cache_dir / f"{key}.cache"

    def get(self, key: str) -> Optional[Any]:
        """Load and return the cached value, or None when missing or unreadable.

        Bug fix: the original ``return data = pickle.load(f)`` was a
        SyntaxError -- a return statement cannot contain an assignment.

        SECURITY: ``pickle.load`` can execute arbitrary code from the
        file; only use this cache with a trusted cache directory.
        """
        cache_path = self._get_cache_path(key)

        if not cache_path.exists():
            return None
        try:
            with open(cache_path, 'rb') as f:
                return pickle.load(f)
        except Exception:
            # Treat corrupt/unreadable cache files as a miss.
            return None

    def set(self, key: str, value: Any):
        """Serialize *value* to the key's cache file."""
        with open(self._get_cache_path(key), 'wb') as f:
            pickle.dump(value, f)

    def clear(self):
        """Delete every *.cache file in the cache directory."""
        for cache_file in self.cache_dir.glob('*.cache'):
            cache_file.unlink()

# Usage: values survive process restarts because they live on disk
cache = DiskCache(cache_dir='./cache')

cache.set('test_key', {'data': 'test'})
result = cache.get('test_key')
print(result)

智能缓存

import ollama
import hashlib
import json
from typing import List, Dict, Optional
import time

class SmartCache:
    """Size-bounded LRU cache with per-entry TTL and hit-rate statistics.

    Fixes: annotations used the builtin ``any`` instead of ``typing.Any``;
    the ``_total_requests`` / ``_hits`` counters were read by
    ``_calculate_hit_rate`` but never incremented, so the reported hit
    rate was always 0.0.  ``get()`` now maintains them.
    """

    def __init__(
        self,
        max_size: int = 100,
        ttl: int = 3600,
        persist: bool = False
    ):
        self.max_size = max_size
        self.ttl = ttl
        # NOTE(review): persist is stored but never used by this class --
        # disk persistence is not implemented here.
        self.persist = persist
        self.cache: Dict[str, Dict] = {}
        # Keys ordered oldest-access first; index 0 is the LRU candidate.
        self.access_order: List[str] = []
        # Hit-rate bookkeeping, updated by get().
        self._total_requests = 0
        self._hits = 0

    def _make_key(self, model: str, messages: list, options: dict = None) -> str:
        """Deterministic SHA-256 key over model, messages and options."""
        data = {
            'model': model,
            'messages': messages,
            'options': options or {}
        }
        return hashlib.sha256(json.dumps(data, sort_keys=True).encode()).hexdigest()

    def get(self, key: str) -> Optional[Any]:
        """Return the cached value, or None on miss/expiry; updates stats and recency."""
        self._total_requests += 1

        if key not in self.cache:
            return None

        entry = self.cache[key]

        # Expired entries count as misses and are evicted eagerly.
        if time.time() - entry['timestamp'] > self.ttl:
            del self.cache[key]
            self.access_order.remove(key)
            return None

        # Refresh recency: move the key to the most-recent end.
        self.access_order.remove(key)
        self.access_order.append(key)

        self._hits += 1
        return entry['value']

    def set(self, key: str, value: Any):
        """Insert/update *key*; evicts the least-recently-used entry when full."""
        if key in self.cache:
            self.access_order.remove(key)
        elif len(self.cache) >= self.max_size:
            oldest_key = self.access_order.pop(0)
            del self.cache[oldest_key]

        self.cache[key] = {
            'value': value,
            'timestamp': time.time()
        }
        self.access_order.append(key)

    def get_stats(self) -> dict:
        """Snapshot of current size, capacity and observed hit rate."""
        return {
            'size': len(self.cache),
            'max_size': self.max_size,
            'hit_rate': self._calculate_hit_rate()
        }

    def _calculate_hit_rate(self) -> float:
        """Hits divided by total get() calls; 0.0 before any lookups."""
        total = getattr(self, '_total_requests', 0)
        hits = getattr(self, '_hits', 0)
        return hits / total if total > 0 else 0

class SmartCachedClient:
    """Ollama chat client backed by a SmartCache (LRU + TTL)."""

    def __init__(
        self,
        model: str = 'llama3.2',
        cache_size: int = 100,
        cache_ttl: int = 3600
    ):
        self.model = model
        self.cache = SmartCache(max_size=cache_size, ttl=cache_ttl)

    def chat(
        self,
        messages: list,
        options: dict = None,
        use_cache: bool = True
    ) -> str:
        """Return the model reply for *messages*, serving repeats from the cache.

        Bug fix: the original cache-hit test was ``if cached is:`` -- a
        SyntaxError; the intended check is ``is not None``.
        """
        key = self.cache._make_key(self.model, messages, options)

        if use_cache:
            cached = self.cache.get(key)
            if cached is not None:
                return cached

        response = ollama.chat(
            model=self.model,
            messages=messages,
            **(options or {})
        )
        result = response['message']['content']

        self.cache.set(key, result)
        return result

# Usage: a client holding up to 50 cached responses
client = SmartCachedClient(model='llama3.2', cache_size=50)

result = client.chat([{'role': 'user', 'content': '你好'}])
print(result)