ragflow_preprocess/utils/api_ollama.py
2025-03-27 14:08:10 +01:00

244 lines
8.2 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Module pour l'interaction avec l'API Ollama
"""
import json
import requests
import base64
from typing import List, Dict, Any, Optional, Union, Callable
class OllamaAPI:
"""
Classe pour interagir avec l'API Ollama
"""
def __init__(self, base_url: str = "http://217.182.105.173:11434"):
"""
Initialise la connexion à l'API Ollama
Args:
base_url (str): URL de base de l'API Ollama
"""
self.base_url = base_url.rstrip("/")
self.generate_endpoint = f"{self.base_url}/api/generate"
self.chat_endpoint = f"{self.base_url}/api/chat"
self.models_endpoint = f"{self.base_url}/api/tags"
def list_models(self) -> List[str]:
"""
Récupère la liste des modèles disponibles
Returns:
List[str]: Liste des noms de modèles disponibles
"""
try:
response = requests.get(self.models_endpoint)
response.raise_for_status()
data = response.json()
# Extraire les noms des modèles
models = [model['name'] for model in data.get('models', [])]
return models
except Exception as e:
print(f"Erreur lors de la récupération des modèles: {str(e)}")
return []
def generate(self, model: str, prompt: str, images: Optional[List[bytes]] = None,
options: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""
Génère une réponse à partir d'un prompt
Args:
model (str): Nom du modèle à utiliser
prompt (str): Texte du prompt
images (List[bytes], optional): Liste d'images en bytes
options (Dict, optional): Options de génération
Returns:
Dict[str, Any]: Réponse du modèle
"""
# Options par défaut
default_options = {
"temperature": 0.2,
"top_p": 0.95,
"top_k": 40,
"num_predict": 1024
}
# Fusionner avec les options fournies
if options:
default_options.update(options)
# Construire la payload
payload = {
"model": model,
"prompt": prompt,
"options": default_options
}
# Ajouter les images si fournies (pour les modèles multimodaux)
if images:
base64_images = []
for img in images:
if isinstance(img, bytes):
base64_img = base64.b64encode(img).decode("utf-8")
base64_images.append(base64_img)
payload["images"] = base64_images
try:
# Envoyer la requête
response = requests.post(self.generate_endpoint, json=payload)
response.raise_for_status()
return response.json()
except requests.exceptions.HTTPError as e:
print(f"Erreur HTTP: {e}")
return {"error": str(e)}
except Exception as e:
print(f"Erreur lors de la génération: {str(e)}")
return {"error": str(e)}
def chat(self, model: str, messages: List[Dict[str, Any]],
images: Optional[List[bytes]] = None,
options: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""
Utilise l'API de chat pour une conversation
Args:
model (str): Nom du modèle à utiliser
messages (List[Dict]): Liste des messages de la conversation
Format: [{"role": "user", "content": "message"}, ...]
images (List[bytes], optional): Liste d'images en bytes (pour le dernier message)
options (Dict, optional): Options de génération
Returns:
Dict[str, Any]: Réponse du modèle
"""
# Options par défaut
default_options = {
"temperature": 0.2,
"top_p": 0.95,
"top_k": 40,
"num_predict": 1024
}
# Fusionner avec les options fournies
if options:
default_options.update(options)
# Construire la payload
payload = {
"model": model,
"messages": messages,
"options": default_options
}
# Ajouter les images au dernier message utilisateur si fournies
if images and messages and messages[-1]["role"] == "user":
base64_images = []
for img in images:
if isinstance(img, bytes):
base64_img = base64.b64encode(img).decode("utf-8")
base64_images.append(base64_img)
# Modifier le dernier message pour inclure les images
last_message = messages[-1].copy()
# Les images doivent être dans un champ distinct du modèle d'API d'Ollama
# Pas comme un champ texte standard mais dans un tableau d'images
if "images" not in last_message:
last_message["images"] = base64_images
# Remplacer le dernier message
payload["messages"] = messages[:-1] + [last_message]
try:
# Envoyer la requête
response = requests.post(self.chat_endpoint, json=payload)
response.raise_for_status()
return response.json()
except requests.exceptions.HTTPError as e:
print(f"Erreur HTTP: {e}")
return {"error": str(e)}
except Exception as e:
print(f"Erreur lors du chat: {str(e)}")
return {"error": str(e)}
def stream_generate(self, model: str, prompt: str,
callback: Callable[[str], None],
options: Optional[Dict[str, Any]] = None) -> None:
"""
Génère une réponse en streaming et appelle le callback pour chaque morceau
Args:
model (str): Nom du modèle à utiliser
prompt (str): Texte du prompt
callback (Callable): Fonction à appeler pour chaque morceau de texte
options (Dict, optional): Options de génération
"""
# Options par défaut
default_options = {
"temperature": 0.2,
"top_p": 0.95,
"top_k": 40,
"num_predict": 1024,
"stream": True # Activer le streaming
}
# Fusionner avec les options fournies
if options:
default_options.update(options)
# S'assurer que stream est activé
default_options["stream"] = True
# Construire la payload
payload = {
"model": model,
"prompt": prompt,
"options": default_options
}
try:
# Envoyer la requête en streaming
with requests.post(self.generate_endpoint, json=payload, stream=True) as response:
response.raise_for_status()
# Traiter chaque ligne de la réponse
for line in response.iter_lines():
if line:
try:
data = json.loads(line)
if "response" in data:
callback(data["response"])
except json.JSONDecodeError:
print(f"Erreur de décodage JSON: {line}")
except requests.exceptions.HTTPError as e:
print(f"Erreur HTTP: {e}")
callback(f"\nErreur: {str(e)}")
except Exception as e:
print(f"Erreur lors du streaming: {str(e)}")
callback(f"\nErreur: {str(e)}")
def check_connection(self) -> bool:
"""
Vérifie si la connexion à l'API Ollama est fonctionnelle
Returns:
bool: True si la connexion est établie, False sinon
"""
try:
response = requests.get(f"{self.base_url}/api/version")
return response.status_code == 200
except:
return False