mirror of
https://github.com/Ladebeze66/ragflow_preprocess.git
synced 2026-02-04 06:00:27 +01:00
244 lines
8.2 KiB
Python
244 lines
8.2 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
Module pour l'interaction avec l'API Ollama
|
|
"""
|
|
|
|
import json
|
|
import requests
|
|
import base64
|
|
from typing import List, Dict, Any, Optional, Union, Callable
|
|
|
|
|
|
class OllamaAPI:
|
|
"""
|
|
Classe pour interagir avec l'API Ollama
|
|
"""
|
|
|
|
def __init__(self, base_url: str = "http://217.182.105.173:11434"):
|
|
"""
|
|
Initialise la connexion à l'API Ollama
|
|
|
|
Args:
|
|
base_url (str): URL de base de l'API Ollama
|
|
"""
|
|
self.base_url = base_url.rstrip("/")
|
|
self.generate_endpoint = f"{self.base_url}/api/generate"
|
|
self.chat_endpoint = f"{self.base_url}/api/chat"
|
|
self.models_endpoint = f"{self.base_url}/api/tags"
|
|
|
|
def list_models(self) -> List[str]:
|
|
"""
|
|
Récupère la liste des modèles disponibles
|
|
|
|
Returns:
|
|
List[str]: Liste des noms de modèles disponibles
|
|
"""
|
|
try:
|
|
response = requests.get(self.models_endpoint)
|
|
response.raise_for_status()
|
|
data = response.json()
|
|
|
|
# Extraire les noms des modèles
|
|
models = [model['name'] for model in data.get('models', [])]
|
|
return models
|
|
|
|
except Exception as e:
|
|
print(f"Erreur lors de la récupération des modèles: {str(e)}")
|
|
return []
|
|
|
|
def generate(self, model: str, prompt: str, images: Optional[List[bytes]] = None,
|
|
options: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
|
"""
|
|
Génère une réponse à partir d'un prompt
|
|
|
|
Args:
|
|
model (str): Nom du modèle à utiliser
|
|
prompt (str): Texte du prompt
|
|
images (List[bytes], optional): Liste d'images en bytes
|
|
options (Dict, optional): Options de génération
|
|
|
|
Returns:
|
|
Dict[str, Any]: Réponse du modèle
|
|
"""
|
|
# Options par défaut
|
|
default_options = {
|
|
"temperature": 0.2,
|
|
"top_p": 0.95,
|
|
"top_k": 40,
|
|
"num_predict": 1024
|
|
}
|
|
|
|
# Fusionner avec les options fournies
|
|
if options:
|
|
default_options.update(options)
|
|
|
|
# Construire la payload
|
|
payload = {
|
|
"model": model,
|
|
"prompt": prompt,
|
|
"options": default_options
|
|
}
|
|
|
|
# Ajouter les images si fournies (pour les modèles multimodaux)
|
|
if images:
|
|
base64_images = []
|
|
for img in images:
|
|
if isinstance(img, bytes):
|
|
base64_img = base64.b64encode(img).decode("utf-8")
|
|
base64_images.append(base64_img)
|
|
|
|
payload["images"] = base64_images
|
|
|
|
try:
|
|
# Envoyer la requête
|
|
response = requests.post(self.generate_endpoint, json=payload)
|
|
response.raise_for_status()
|
|
return response.json()
|
|
|
|
except requests.exceptions.HTTPError as e:
|
|
print(f"Erreur HTTP: {e}")
|
|
return {"error": str(e)}
|
|
|
|
except Exception as e:
|
|
print(f"Erreur lors de la génération: {str(e)}")
|
|
return {"error": str(e)}
|
|
|
|
def chat(self, model: str, messages: List[Dict[str, Any]],
|
|
images: Optional[List[bytes]] = None,
|
|
options: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
|
"""
|
|
Utilise l'API de chat pour une conversation
|
|
|
|
Args:
|
|
model (str): Nom du modèle à utiliser
|
|
messages (List[Dict]): Liste des messages de la conversation
|
|
Format: [{"role": "user", "content": "message"}, ...]
|
|
images (List[bytes], optional): Liste d'images en bytes (pour le dernier message)
|
|
options (Dict, optional): Options de génération
|
|
|
|
Returns:
|
|
Dict[str, Any]: Réponse du modèle
|
|
"""
|
|
# Options par défaut
|
|
default_options = {
|
|
"temperature": 0.2,
|
|
"top_p": 0.95,
|
|
"top_k": 40,
|
|
"num_predict": 1024
|
|
}
|
|
|
|
# Fusionner avec les options fournies
|
|
if options:
|
|
default_options.update(options)
|
|
|
|
# Construire la payload
|
|
payload = {
|
|
"model": model,
|
|
"messages": messages,
|
|
"options": default_options
|
|
}
|
|
|
|
# Ajouter les images au dernier message utilisateur si fournies
|
|
if images and messages and messages[-1]["role"] == "user":
|
|
base64_images = []
|
|
for img in images:
|
|
if isinstance(img, bytes):
|
|
base64_img = base64.b64encode(img).decode("utf-8")
|
|
base64_images.append(base64_img)
|
|
|
|
# Modifier le dernier message pour inclure les images
|
|
last_message = messages[-1].copy()
|
|
# Les images doivent être dans un champ distinct du modèle d'API d'Ollama
|
|
# Pas comme un champ texte standard mais dans un tableau d'images
|
|
if "images" not in last_message:
|
|
last_message["images"] = base64_images
|
|
|
|
# Remplacer le dernier message
|
|
payload["messages"] = messages[:-1] + [last_message]
|
|
|
|
try:
|
|
# Envoyer la requête
|
|
response = requests.post(self.chat_endpoint, json=payload)
|
|
response.raise_for_status()
|
|
return response.json()
|
|
|
|
except requests.exceptions.HTTPError as e:
|
|
print(f"Erreur HTTP: {e}")
|
|
return {"error": str(e)}
|
|
|
|
except Exception as e:
|
|
print(f"Erreur lors du chat: {str(e)}")
|
|
return {"error": str(e)}
|
|
|
|
def stream_generate(self, model: str, prompt: str,
|
|
callback: Callable[[str], None],
|
|
options: Optional[Dict[str, Any]] = None) -> None:
|
|
"""
|
|
Génère une réponse en streaming et appelle le callback pour chaque morceau
|
|
|
|
Args:
|
|
model (str): Nom du modèle à utiliser
|
|
prompt (str): Texte du prompt
|
|
callback (Callable): Fonction à appeler pour chaque morceau de texte
|
|
options (Dict, optional): Options de génération
|
|
"""
|
|
# Options par défaut
|
|
default_options = {
|
|
"temperature": 0.2,
|
|
"top_p": 0.95,
|
|
"top_k": 40,
|
|
"num_predict": 1024,
|
|
"stream": True # Activer le streaming
|
|
}
|
|
|
|
# Fusionner avec les options fournies
|
|
if options:
|
|
default_options.update(options)
|
|
|
|
# S'assurer que stream est activé
|
|
default_options["stream"] = True
|
|
|
|
# Construire la payload
|
|
payload = {
|
|
"model": model,
|
|
"prompt": prompt,
|
|
"options": default_options
|
|
}
|
|
|
|
try:
|
|
# Envoyer la requête en streaming
|
|
with requests.post(self.generate_endpoint, json=payload, stream=True) as response:
|
|
response.raise_for_status()
|
|
|
|
# Traiter chaque ligne de la réponse
|
|
for line in response.iter_lines():
|
|
if line:
|
|
try:
|
|
data = json.loads(line)
|
|
if "response" in data:
|
|
callback(data["response"])
|
|
except json.JSONDecodeError:
|
|
print(f"Erreur de décodage JSON: {line}")
|
|
|
|
except requests.exceptions.HTTPError as e:
|
|
print(f"Erreur HTTP: {e}")
|
|
callback(f"\nErreur: {str(e)}")
|
|
|
|
except Exception as e:
|
|
print(f"Erreur lors du streaming: {str(e)}")
|
|
callback(f"\nErreur: {str(e)}")
|
|
|
|
def check_connection(self) -> bool:
|
|
"""
|
|
Vérifie si la connexion à l'API Ollama est fonctionnelle
|
|
|
|
Returns:
|
|
bool: True si la connexion est établie, False sinon
|
|
"""
|
|
try:
|
|
response = requests.get(f"{self.base_url}/api/version")
|
|
return response.status_code == 200
|
|
except:
|
|
return False |