mirror of
https://github.com/Ladebeze66/llm_lab.git
synced 2025-12-13 10:46:50 +01:00
134 lines
4.7 KiB
Python
134 lines
4.7 KiB
Python
from core.base_llm import BaseLLM
|
|
import requests
|
|
import json
|
|
import os
|
|
from typing import Dict, List, Any, Optional
|
|
|
|
class MistralAPI(BaseLLM):
|
|
"""Intégration avec l'API Mistral (similaire à la classe Mistral de l'ancien projet)"""
|
|
|
|
def __init__(self):
|
|
model_name = "mistral-large-latest"
|
|
engine = "MistralAPI"
|
|
|
|
self.api_url = "https://api.mistral.ai/v1/chat/completions"
|
|
# À remplacer par la clé réelle ou une variable d'environnement
|
|
self.api_key = os.environ.get("MISTRAL_API_KEY", "")
|
|
|
|
default_params = {
|
|
# Paramètres de génération
|
|
"temperature": 0.7,
|
|
"top_p": 0.9,
|
|
"max_tokens": 1024,
|
|
|
|
# Paramètres de contrôle
|
|
"presence_penalty": 0,
|
|
"frequency_penalty": 0,
|
|
"stop": [],
|
|
|
|
# Paramètres divers
|
|
"random_seed": None
|
|
}
|
|
|
|
super().__init__(model_name=model_name, engine=engine, base_params=default_params)
|
|
|
|
def generate(self, user_prompt: str) -> str:
|
|
"""
|
|
Génère une réponse à partir du prompt utilisateur via l'API Mistral
|
|
|
|
Args:
|
|
user_prompt: Texte du prompt utilisateur
|
|
|
|
Returns:
|
|
Réponse générée par le modèle
|
|
|
|
Raises:
|
|
ValueError: Si la clé API n'est pas définie
|
|
Exception: Si une erreur survient lors de l'appel à l'API
|
|
"""
|
|
prompt = self._format_prompt(user_prompt)
|
|
|
|
if not self.api_key:
|
|
raise ValueError("Clé API Mistral non définie. Définissez la variable d'environnement MISTRAL_API_KEY.")
|
|
|
|
headers = {
|
|
"Content-Type": "application/json",
|
|
"Authorization": f"Bearer {self.api_key}"
|
|
}
|
|
|
|
# Préparation des messages
|
|
messages = []
|
|
|
|
# Ajout du prompt système si défini
|
|
if self.system_prompt:
|
|
messages.append({"role": "system", "content": self.system_prompt})
|
|
|
|
# Ajout du prompt utilisateur
|
|
messages.append({"role": "user", "content": user_prompt})
|
|
|
|
# Préparation du payload
|
|
payload = {
|
|
"model": self.model,
|
|
"messages": messages,
|
|
"temperature": self.params.get("temperature", 0.7),
|
|
"top_p": self.params.get("top_p", 0.9),
|
|
"max_tokens": self.params.get("max_tokens", 1024)
|
|
}
|
|
|
|
# Ajout des paramètres optionnels
|
|
if self.params.get("presence_penalty") is not None:
|
|
payload["presence_penalty"] = self.params.get("presence_penalty")
|
|
|
|
if self.params.get("frequency_penalty") is not None:
|
|
payload["frequency_penalty"] = self.params.get("frequency_penalty")
|
|
|
|
# Vérifier que stop est une liste non vide
|
|
stop_sequences = self.params.get("stop")
|
|
if isinstance(stop_sequences, list) and stop_sequences:
|
|
payload["stop"] = stop_sequences
|
|
|
|
if self.params.get("random_seed") is not None:
|
|
payload["random_seed"] = self.params.get("random_seed")
|
|
|
|
# Envoi de la requête
|
|
response = requests.post(self.api_url, headers=headers, json=payload)
|
|
|
|
if not response.ok:
|
|
raise Exception(f"Erreur API Mistral: {response.status_code} - {response.text}")
|
|
|
|
# Traitement de la réponse
|
|
result_data = response.json()
|
|
result_text = result_data.get("choices", [{}])[0].get("message", {}).get("content", "")
|
|
|
|
# Logging du résultat
|
|
filename = self._log_result(user_prompt, result_text)
|
|
|
|
return result_text
|
|
|
|
def obtenir_liste_modeles(self) -> List[str]:
|
|
"""
|
|
Récupère la liste des modèles disponibles via l'API Mistral
|
|
|
|
Returns:
|
|
Liste des identifiants de modèles disponibles
|
|
|
|
Raises:
|
|
ValueError: Si la clé API n'est pas définie
|
|
Exception: Si une erreur survient lors de l'appel à l'API
|
|
"""
|
|
if not self.api_key:
|
|
raise ValueError("Clé API Mistral non définie. Définissez la variable d'environnement MISTRAL_API_KEY.")
|
|
|
|
headers = {
|
|
"Authorization": f"Bearer {self.api_key}"
|
|
}
|
|
|
|
response = requests.get("https://api.mistral.ai/v1/models", headers=headers)
|
|
|
|
if not response.ok:
|
|
raise Exception(f"Erreur API Mistral: {response.status_code} - {response.text}")
|
|
|
|
result_data = response.json()
|
|
models = [model.get("id") for model in result_data.get("data", [])]
|
|
|
|
return models |