llm_lab/core/mistral_api.py
2025-03-26 15:40:31 +01:00

134 lines
4.7 KiB
Python

from core.base_llm import BaseLLM
import requests
import json
import os
from typing import Dict, List, Any, Optional
class MistralAPI(BaseLLM):
"""Intégration avec l'API Mistral (similaire à la classe Mistral de l'ancien projet)"""
def __init__(self):
model_name = "mistral-large-latest"
engine = "MistralAPI"
self.api_url = "https://api.mistral.ai/v1/chat/completions"
# À remplacer par la clé réelle ou une variable d'environnement
self.api_key = os.environ.get("MISTRAL_API_KEY", "")
default_params = {
# Paramètres de génération
"temperature": 0.7,
"top_p": 0.9,
"max_tokens": 1024,
# Paramètres de contrôle
"presence_penalty": 0,
"frequency_penalty": 0,
"stop": [],
# Paramètres divers
"random_seed": None
}
super().__init__(model_name=model_name, engine=engine, base_params=default_params)
def generate(self, user_prompt: str) -> str:
"""
Génère une réponse à partir du prompt utilisateur via l'API Mistral
Args:
user_prompt: Texte du prompt utilisateur
Returns:
Réponse générée par le modèle
Raises:
ValueError: Si la clé API n'est pas définie
Exception: Si une erreur survient lors de l'appel à l'API
"""
prompt = self._format_prompt(user_prompt)
if not self.api_key:
raise ValueError("Clé API Mistral non définie. Définissez la variable d'environnement MISTRAL_API_KEY.")
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}"
}
# Préparation des messages
messages = []
# Ajout du prompt système si défini
if self.system_prompt:
messages.append({"role": "system", "content": self.system_prompt})
# Ajout du prompt utilisateur
messages.append({"role": "user", "content": user_prompt})
# Préparation du payload
payload = {
"model": self.model,
"messages": messages,
"temperature": self.params.get("temperature", 0.7),
"top_p": self.params.get("top_p", 0.9),
"max_tokens": self.params.get("max_tokens", 1024)
}
# Ajout des paramètres optionnels
if self.params.get("presence_penalty") is not None:
payload["presence_penalty"] = self.params.get("presence_penalty")
if self.params.get("frequency_penalty") is not None:
payload["frequency_penalty"] = self.params.get("frequency_penalty")
# Vérifier que stop est une liste non vide
stop_sequences = self.params.get("stop")
if isinstance(stop_sequences, list) and stop_sequences:
payload["stop"] = stop_sequences
if self.params.get("random_seed") is not None:
payload["random_seed"] = self.params.get("random_seed")
# Envoi de la requête
response = requests.post(self.api_url, headers=headers, json=payload)
if not response.ok:
raise Exception(f"Erreur API Mistral: {response.status_code} - {response.text}")
# Traitement de la réponse
result_data = response.json()
result_text = result_data.get("choices", [{}])[0].get("message", {}).get("content", "")
# Logging du résultat
filename = self._log_result(user_prompt, result_text)
return result_text
def obtenir_liste_modeles(self) -> List[str]:
"""
Récupère la liste des modèles disponibles via l'API Mistral
Returns:
Liste des identifiants de modèles disponibles
Raises:
ValueError: Si la clé API n'est pas définie
Exception: Si une erreur survient lors de l'appel à l'API
"""
if not self.api_key:
raise ValueError("Clé API Mistral non définie. Définissez la variable d'environnement MISTRAL_API_KEY.")
headers = {
"Authorization": f"Bearer {self.api_key}"
}
response = requests.get("https://api.mistral.ai/v1/models", headers=headers)
if not response.ok:
raise Exception(f"Erreur API Mistral: {response.status_code} - {response.text}")
result_data = response.json()
models = [model.get("id") for model in result_data.get("data", [])]
return models