Batch Processing

Batch processing is key to throughput. This chapter shows how to handle large numbers of requests efficiently, from a simple sequential loop to concurrent, asynchronous, and rate-limited approaches.

Basic Batch Processing

Sequential Processing

import ollama

def batch_generate(prompts: list, model: str = 'llama3.2'):
    results = []
    
    for prompt in prompts:
        response = ollama.generate(
            model=model,
            prompt=prompt
        )
        results.append(response['response'])
    
    return results

prompts = ['Hello', 'Goodbye', 'Thank you']
results = batch_generate(prompts)

for prompt, result in zip(prompts, results):
    print(f"{prompt} -> {result}")

Concurrent Processing

import ollama
from concurrent.futures import ThreadPoolExecutor, as_completed

def concurrent_generate(prompts: list, model: str = 'llama3.2', max_workers: int = 4):
    def generate(prompt):
        return ollama.generate(model=model, prompt=prompt)['response']
    
    results = {}
    
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {executor.submit(generate, p): p for p in prompts}
        
        # Collect results as each future completes, keyed by the original prompt
        for future in as_completed(futures):
            prompt = futures[future]
            try:
                results[prompt] = future.result()
            except Exception as e:
                results[prompt] = f"Error: {e}"
    
    return results

prompts = ['Question 1', 'Question 2', 'Question 3', 'Question 4', 'Question 5']
results = concurrent_generate(prompts, max_workers=3)
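
concurrent_generate returns a dict keyed by prompt (results arrive in completion order, not submission order), so a minimal sketch for printing them:

for prompt, response in results.items():
    print(f"{prompt}: {response}")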

A Batch Processor Class

import ollama
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict, Callable
import time

class BatchProcessor:
    def __init__(self, model: str = 'llama3.2', max_workers: int = 4):
        self.model = model
        self.max_workers = max_workers
    
    def process_prompts(self, prompts: List[str]) -> List[Dict]:
        def process(prompt):
            start = time.time()
            response = ollama.generate(
                model=self.model,
                prompt=prompt
            )
            duration = time.time() - start
            
            return {
                'prompt': prompt,
                'response': response['response'],
                'duration': duration,
                'tokens': response.get('eval_count', 0)
            }
        
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            results = list(executor.map(process, prompts))
        
        return results
    
    def process_with_callback(
        self,
        items: List[str],
        callback: Callable,
        processor: Callable = None
    ):
        def default_processor(item):
            return ollama.generate(model=self.model, prompt=item)['response']
        
        proc = processor or default_processor
        
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = []
            
            for item in items:
                future = executor.submit(proc, item)
                future.add_done_callback(
                    lambda f, i=item: callback(i, f.result())
                )
                futures.append(future)
            
            for future in futures:
                future.result()
    
    def get_stats(self, results: List[Dict]) -> Dict:
        total_duration = sum(r['duration'] for r in results)
        total_tokens = sum(r['tokens'] for r in results)
        
        return {
            'total_items': len(results),
            'total_duration': total_duration,
            'total_tokens': total_tokens,
            'avg_duration': total_duration / len(results) if results else 0,
            'tokens_per_second': total_tokens / total_duration if total_duration > 0 else 0
        }

# Usage
processor = BatchProcessor(max_workers=3)

prompts = ['Write a poem', 'Tell a joke', 'Explain quantum mechanics']
results = processor.process_prompts(prompts)

stats = processor.get_stats(results)
print(f"处理 {stats['total_items']} 个任务")
print(f"总耗时: {stats['total_duration']:.2f}秒")
print(f"速度: {stats['tokens_per_second']:.1f} tokens/s")

Asynchronous Batch Processing

from ollama import AsyncClient
import asyncio
from typing import List

async def async_batch_generate(prompts: List[str], model: str = 'llama3.2'):
    client = AsyncClient()
    
    async def generate(prompt):
        response = await client.generate(model=model, prompt=prompt)
        return prompt, response['response']
    
    tasks = [generate(p) for p in prompts]
    results = await asyncio.gather(*tasks)
    
    return dict(results)

async def async_batch_chat(messages_list: List[List[dict]], model: str = 'llama3.2'):
    client = AsyncClient()
    
    async def chat(messages):
        response = await client.chat(model=model, messages=messages)
        return response['message']['content']
    
    tasks = [chat(m) for m in messages_list]
    results = await asyncio.gather(*tasks)
    
    return results

# Usage
async def main():
    prompts = ['Question 1', 'Question 2', 'Question 3']
    results = await async_batch_generate(prompts)
    
    for prompt, response in results.items():
        print(f"{prompt}: {response[:50]}...")

asyncio.run(main())
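
async_batch_chat is defined above but not shown in use. A minimal sketch, with placeholder message contents:

async def chat_demo():
    messages_list = [
        [{'role': 'user', 'content': 'Question A'}],
        [{'role': 'user', 'content': 'Question B'}],
    ]
    replies = await async_batch_chat(messages_list)
    for reply in replies:
        print(reply[:50])

asyncio.run(chat_demo())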

Batch Processing Files

import ollama
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor

class FileBatchProcessor:
    def __init__(self, model: str = 'llama3.2', max_workers: int = 4):
        self.model = model
        self.max_workers = max_workers
    
    def process_files(
        self,
        input_dir: str,
        output_dir: str,
        prompt_template: str
    ):
        input_path = Path(input_dir)
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        
        files = list(input_path.glob('*.txt'))
        
        def process_file(filepath):
            with open(filepath, 'r', encoding='utf-8') as f:
                content = f.read()
            
            prompt = prompt_template.format(content=content)
            
            response = ollama.generate(
                model=self.model,
                prompt=prompt
            )
            
            output_file = output_path / f"{filepath.stem}_processed.txt"
            with open(output_file, 'w', encoding='utf-8') as f:
                f.write(response['response'])
            
            return filepath.name, True
        
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            results = list(executor.map(process_file, files))
        
        return results
    
    def summarize_files(self, input_dir: str, output_file: str):
        input_path = Path(input_dir)
        files = list(input_path.glob('*.txt'))
        
        summaries = []
        
        for filepath in files:
            with open(filepath, 'r', encoding='utf-8') as f:
                content = f.read()
            
            response = ollama.generate(
                model=self.model,
                prompt=f"总结以下内容:\n\n{content}"
            )
            
            summaries.append(f"## {filepath.name}\n\n{response['response']}\n")
        
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write('\n'.join(summaries))

# Usage
processor = FileBatchProcessor()

results = processor.process_files(
    input_dir='./documents',
    output_dir='./summaries',
    prompt_template='Please summarize the following content:\n\n{content}'
)

processor.summarize_files('./documents', 'all_summaries.md')
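
process_files returns a list of (filename, success) tuples; a quick way to report them, shown here as a sketch:

for name, ok in results:
    print(f"{name}: {'ok' if ok else 'failed'}")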

Rate Limiting

import ollama
import time
from threading import Semaphore, Lock
from concurrent.futures import ThreadPoolExecutor

class RateLimitedProcessor:
    def __init__(
        self,
        model: str = 'llama3.2',
        max_concurrent: int = 3,
        requests_per_second: float = 2.0
    ):
        self.model = model
        self.max_concurrent = max_concurrent
        self.semaphore = Semaphore(max_concurrent)
        self.min_interval = 1.0 / requests_per_second
        self.last_request_time = 0.0
        self.lock = Lock()
    
    def _rate_limit(self):
        # Serialize the timing check so threads space requests at least min_interval apart
        with self.lock:
            elapsed = time.time() - self.last_request_time
            if elapsed < self.min_interval:
                time.sleep(self.min_interval - elapsed)
            
            self.last_request_time = time.time()
    
    def process(self, prompts: list):
        def generate(prompt):
            # The semaphore caps in-flight requests; _rate_limit spaces them out
            with self.semaphore:
                self._rate_limit()
                return ollama.generate(model=self.model, prompt=prompt)['response']
        
        with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
            results = list(executor.map(generate, prompts))
        
        return results

# Usage
processor = RateLimitedProcessor(
    max_concurrent=2,
    requests_per_second=1.0
)

results = processor.process(['Question 1', 'Question 2', 'Question 3'])
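
To confirm the limiter is working, you can time the batch; with requests_per_second=1.0 the requests should be spaced roughly a second apart. A rough check, not part of the original example:

import time

start = time.time()
results = processor.process(['Question 1', 'Question 2', 'Question 3'])
print(f"Elapsed: {time.time() - start:.1f}s for {len(results)} requests")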

Progress Tracking

import ollama
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm

def batch_with_progress(prompts: list, model: str = 'llama3.2'):
    def generate(prompt):
        return ollama.generate(model=model, prompt=prompt)['response']
    
    with ThreadPoolExecutor(max_workers=4) as executor:
        # executor.map yields results in input order; tqdm wraps the iterator
        # to show a progress bar (tqdm is a third-party package: pip install tqdm)
        results = list(tqdm(
            executor.map(generate, prompts),
            total=len(prompts),
            desc="Processing"
        ))
    
    return results

prompts = [f"问题 {i}" for i in range(20)]
results = batch_with_progress(prompts)