mirror of
https://github.com/Ladebeze66/llm_lab.git
synced 2025-12-13 10:46:50 +01:00
save
This commit is contained in:
parent
1cc834f9ab
commit
82c933bd5f
@ -1,157 +0,0 @@
|
||||
import json
|
||||
import time
|
||||
from typing import Dict, List, Any, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from agents.base_agent import Agent
|
||||
from core.factory import LLMFactory
|
||||
|
||||
class AgentAnalyseJSON(Agent):
|
||||
"""
|
||||
Agent pour analyser des données JSON
|
||||
"""
|
||||
def __init__(self, nom: str = "AgentAnalyseJSON", modele: str = "mistral7b"):
|
||||
super().__init__(nom)
|
||||
|
||||
# Choix du modèle
|
||||
self.llm = LLMFactory.create(modele)
|
||||
|
||||
# Configuration du modèle
|
||||
self.llm.set_role("formateur", {
|
||||
"system_prompt": "Tu es un expert en analyse de données JSON. Tu dois extraire des informations pertinentes, identifier des tendances et répondre à des questions sur les données.",
|
||||
"params": {
|
||||
"temperature": 0.4,
|
||||
"top_p": 0.9
|
||||
}
|
||||
})
|
||||
|
||||
def executer(self, json_data: Dict[str, Any],
|
||||
question: str = "Analyse ces données et extrait les informations principales.") -> Dict[str, Any]:
|
||||
"""
|
||||
Analyse des données JSON
|
||||
|
||||
Args:
|
||||
json_data: Données JSON à analyser
|
||||
question: Question à poser au modèle
|
||||
|
||||
Returns:
|
||||
Résultats de l'analyse sous forme de dictionnaire
|
||||
"""
|
||||
# Conversion du JSON en chaîne formatée
|
||||
json_str = json.dumps(json_data, ensure_ascii=False, indent=2)
|
||||
|
||||
# Construction du prompt avec le JSON et la question
|
||||
prompt = f"{question}\n\nDonnées JSON à analyser:\n```json\n{json_str}\n```"
|
||||
|
||||
# Interrogation du modèle
|
||||
reponse = self.llm.generate(prompt)
|
||||
|
||||
# Construction du résultat
|
||||
resultats = {
|
||||
"question": question,
|
||||
"reponse": reponse,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"taille_json": len(json_str)
|
||||
}
|
||||
|
||||
# Ajout à l'historique
|
||||
self.ajouter_historique("analyse_json", prompt[:200] + "...", reponse[:200] + "...")
|
||||
|
||||
return resultats
|
||||
|
||||
def extraire_structure(self, json_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Extrait la structure d'un JSON (clés, types, profondeur)
|
||||
"""
|
||||
resultat = {
|
||||
"structure": {},
|
||||
"statistiques": {
|
||||
"nb_cles": 0,
|
||||
"profondeur_max": 0,
|
||||
"types": {}
|
||||
}
|
||||
}
|
||||
|
||||
def explorer_structure(data, chemin="", profondeur=0):
|
||||
nonlocal resultat
|
||||
|
||||
# Mise à jour de la profondeur max
|
||||
resultat["statistiques"]["profondeur_max"] = max(resultat["statistiques"]["profondeur_max"], profondeur)
|
||||
|
||||
if isinstance(data, dict):
|
||||
structure = {}
|
||||
for cle, valeur in data.items():
|
||||
nouveau_chemin = f"{chemin}.{cle}" if chemin else cle
|
||||
resultat["statistiques"]["nb_cles"] += 1
|
||||
|
||||
type_valeur = type(valeur).__name__
|
||||
if type_valeur not in resultat["statistiques"]["types"]:
|
||||
resultat["statistiques"]["types"][type_valeur] = 0
|
||||
resultat["statistiques"]["types"][type_valeur] += 1
|
||||
|
||||
if isinstance(valeur, (dict, list)):
|
||||
structure[cle] = explorer_structure(valeur, nouveau_chemin, profondeur + 1)
|
||||
else:
|
||||
structure[cle] = type_valeur
|
||||
return structure
|
||||
|
||||
elif isinstance(data, list):
|
||||
if data and isinstance(data[0], (dict, list)):
|
||||
# Pour les listes de structures complexes, on analyse le premier élément
|
||||
return [explorer_structure(data[0], f"{chemin}[0]", profondeur + 1)]
|
||||
else:
|
||||
# Pour les listes de valeurs simples
|
||||
type_elements = "vide" if not data else type(data[0]).__name__
|
||||
resultat["statistiques"]["nb_cles"] += 1
|
||||
|
||||
if "list" not in resultat["statistiques"]["types"]:
|
||||
resultat["statistiques"]["types"]["list"] = 0
|
||||
resultat["statistiques"]["types"]["list"] += 1
|
||||
|
||||
return f"list[{type_elements}]"
|
||||
else:
|
||||
return type(data).__name__
|
||||
|
||||
resultat["structure"] = explorer_structure(json_data)
|
||||
|
||||
self.ajouter_historique("extraire_structure", "JSON", resultat)
|
||||
return resultat
|
||||
|
||||
def fusionner_jsons(self, json1: Dict[str, Any], json2: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Fusionne deux structures JSON en conservant les informations des deux
|
||||
"""
|
||||
if not json1:
|
||||
return json2
|
||||
|
||||
if not json2:
|
||||
return json1
|
||||
|
||||
resultat = json1.copy()
|
||||
|
||||
# Fonction récursive pour fusionner
|
||||
def fusionner(dict1, dict2):
|
||||
for cle, valeur in dict2.items():
|
||||
if cle in dict1:
|
||||
# Si les deux sont des dictionnaires, fusion récursive
|
||||
if isinstance(dict1[cle], dict) and isinstance(valeur, dict):
|
||||
fusionner(dict1[cle], valeur)
|
||||
# Si les deux sont des listes, concaténation
|
||||
elif isinstance(dict1[cle], list) and isinstance(valeur, list):
|
||||
dict1[cle].extend(valeur)
|
||||
# Sinon, on garde les deux valeurs dans une liste
|
||||
else:
|
||||
if not isinstance(dict1[cle], list):
|
||||
dict1[cle] = [dict1[cle]]
|
||||
if isinstance(valeur, list):
|
||||
dict1[cle].extend(valeur)
|
||||
else:
|
||||
dict1[cle].append(valeur)
|
||||
else:
|
||||
# Si la clé n'existe pas dans dict1, on l'ajoute simplement
|
||||
dict1[cle] = valeur
|
||||
|
||||
fusionner(resultat, json2)
|
||||
self.ajouter_historique("fusionner_jsons", "Fusion de deux JSON", "Fusion réussie")
|
||||
|
||||
return resultat
|
||||
@ -1,94 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
from typing import Dict, List, Any, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from agents.base_agent import Agent
|
||||
from core.factory import LLMFactory
|
||||
|
||||
class AgentAnalyseImage(Agent):
|
||||
"""
|
||||
Agent pour analyser des images avec LlamaVision
|
||||
"""
|
||||
def __init__(self, nom: str = "AgentAnalyseImage"):
|
||||
super().__init__(nom)
|
||||
self.llm = LLMFactory.create("llamavision")
|
||||
self.questions_standard = [
|
||||
"Décris en détail ce que tu vois sur cette image.",
|
||||
"Quels sont les éléments principaux visibles sur cette image?",
|
||||
"Y a-t-il du texte visible sur cette image? Si oui, peux-tu le transcrire?",
|
||||
"Quelle est l'ambiance générale de cette image?"
|
||||
]
|
||||
|
||||
def executer(self, image_path: str, json_data: Optional[Dict[str, Any]] = None,
|
||||
question: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Analyse une image avec Llama Vision
|
||||
|
||||
Args:
|
||||
image_path: Chemin vers l'image à analyser
|
||||
json_data: Données JSON optionnelles associées à l'image
|
||||
question: Question à poser au modèle (si None, utilise les questions standard)
|
||||
|
||||
Returns:
|
||||
Résultats de l'analyse sous forme de dictionnaire
|
||||
"""
|
||||
# Vérification de l'existence de l'image
|
||||
if not os.path.exists(image_path):
|
||||
resultat = {"erreur": f"L'image {image_path} n'existe pas"}
|
||||
self.ajouter_historique("analyse_image_erreur", image_path, resultat)
|
||||
return resultat
|
||||
|
||||
# Préparation de la liste des questions à poser
|
||||
questions_a_poser = [question] if question else self.questions_standard
|
||||
|
||||
# Exécution de l'analyse pour chaque question
|
||||
resultats = {
|
||||
"image": image_path,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"analyses": []
|
||||
}
|
||||
|
||||
# Encodage de l'image en base64 pour l'API
|
||||
with open(image_path, "rb") as image_file:
|
||||
image_data = image_file.read()
|
||||
|
||||
# Traitement pour chaque question
|
||||
for q in questions_a_poser:
|
||||
try:
|
||||
# Appel du modèle via la factory
|
||||
reponse = self.llm.generate(q, images=[image_data])
|
||||
|
||||
# Ajout du résultat
|
||||
resultats["analyses"].append({
|
||||
"question": q,
|
||||
"reponse": reponse
|
||||
})
|
||||
|
||||
# Ajout à l'historique
|
||||
self.ajouter_historique("analyse_image", q, reponse[:200] + "...")
|
||||
except Exception as e:
|
||||
resultats["analyses"].append({
|
||||
"question": q,
|
||||
"erreur": str(e)
|
||||
})
|
||||
self.ajouter_historique("analyse_image_erreur", q, str(e))
|
||||
|
||||
return resultats
|
||||
|
||||
def sauvegarder_resultats(self, chemin_fichier: str, resultats: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Sauvegarde les résultats d'analyse dans un fichier JSON
|
||||
"""
|
||||
try:
|
||||
# Création du dossier parent si nécessaire
|
||||
os.makedirs(os.path.dirname(os.path.abspath(chemin_fichier)), exist_ok=True)
|
||||
|
||||
with open(chemin_fichier, 'w', encoding='utf-8') as f:
|
||||
json.dump(resultats, f, ensure_ascii=False, indent=2)
|
||||
|
||||
self.ajouter_historique("sauvegarder_resultats", chemin_fichier, "Résultats sauvegardés")
|
||||
return True
|
||||
except Exception as e:
|
||||
self.ajouter_historique("sauvegarder_resultats_erreur", chemin_fichier, str(e))
|
||||
return False
|
||||
@ -1,26 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Any, Optional
|
||||
|
||||
class Agent:
|
||||
"""Classe de base pour tous les agents d'analyse"""
|
||||
|
||||
def __init__(self, nom: str = "Agent"):
|
||||
self.nom = nom
|
||||
self.historique = []
|
||||
|
||||
def ajouter_historique(self, action: str, input_data: Any, output_data: Any) -> None:
|
||||
"""Ajoute une entrée dans l'historique de l'agent"""
|
||||
self.historique.append({
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"action": action,
|
||||
"input": str(input_data)[:500], # Limite pour éviter des historiques trop grands
|
||||
"output": str(output_data)[:500] # Limite pour éviter des historiques trop grands
|
||||
})
|
||||
|
||||
def obtenir_historique(self) -> List[Dict[str, Any]]:
|
||||
"""Retourne l'historique complet de l'agent"""
|
||||
return self.historique
|
||||
|
||||
def executer(self, *args, **kwargs) -> Any:
|
||||
"""Méthode abstraite à implémenter dans les classes dérivées"""
|
||||
raise NotImplementedError("Chaque agent doit implémenter sa propre méthode executer()")
|
||||
@ -1,6 +1,5 @@
|
||||
from core.mistral7b import Mistral7B
|
||||
from core.llama_vision90b import LlamaVision90B
|
||||
from core.mistral_api import MistralAPI
|
||||
|
||||
class LLMFactory:
|
||||
"""
|
||||
@ -9,7 +8,6 @@ class LLMFactory:
|
||||
_registry = {
|
||||
"mistral7b": Mistral7B,
|
||||
"llamavision": LlamaVision90B,
|
||||
"mistralapi": MistralAPI,
|
||||
# Ajouter d'autres modèles LLM ici
|
||||
}
|
||||
|
||||
|
||||
@ -2,8 +2,6 @@ from core.base_llm import BaseLLM
|
||||
import requests
|
||||
import json
|
||||
import os
|
||||
import base64
|
||||
from typing import Dict, List, Any, Optional
|
||||
from deep_translator import GoogleTranslator
|
||||
|
||||
class LlamaVision90B(BaseLLM):
|
||||
@ -14,105 +12,29 @@ class LlamaVision90B(BaseLLM):
|
||||
self.api_url = "http://217.182.105.173:11434/api/chat"
|
||||
|
||||
default_params = {
|
||||
# Paramètres de créativité
|
||||
"temperature": 0.3, # Créativité basse pour analyse technique
|
||||
"top_p": 1.0, # Conserve toute la distribution
|
||||
"top_k": 40, # Limite vocabulaire
|
||||
|
||||
# Paramètres de qualité
|
||||
"repeat_penalty": 1.1, # Réduction des répétitions
|
||||
"min_p": 0.0, # Seuil minimal pour la probabilité des tokens
|
||||
|
||||
# Paramètres de contrôle avancé
|
||||
"mirostat": 0, # 0=désactivé, 1=v1, 2=v2
|
||||
"mirostat_eta": 0.1, # Taux d'apprentissage pour mirostat
|
||||
"mirostat_tau": 5.0, # Cible pour mirostat
|
||||
|
||||
# Paramètres de taille
|
||||
"num_predict": 512, # Longueur max sortie
|
||||
"num_ctx": 4096, # Contexte étendu
|
||||
|
||||
# Paramètres de contrôle
|
||||
"seed": 0, # Graine pour reproductibilité (0=aléatoire)
|
||||
"stop": [], # Séquences d'arrêt
|
||||
|
||||
# Paramètres de format
|
||||
"format": "json", # Réponse structurée JSON (optionnel)
|
||||
"stream": False, # Réponse d'un seul bloc
|
||||
"raw": False, # Laisse le formatage systèmes
|
||||
"keep_alive": "5m" # Durée de vie de la connexion
|
||||
"temperature": 0.3, #Créativité basse pour analyse technique
|
||||
"top_p": 1.0, # Conserve toute la distribution
|
||||
"top_k": 40, #Limite vocabulaire
|
||||
"repeat_penalty": 1.1, #Réduction des répétitions
|
||||
"num_predict": 512, #longueur max sortie
|
||||
"num-ctx": 4096, #Contexte étendu
|
||||
"format": "json", #Réponse structurée JSON (optionnel)
|
||||
"stream": False, #Réponse d'un seul bloc
|
||||
"raw": False, #laisse le formatage systèmes
|
||||
"keep_alive": "5m" #Durée de vie de la connexion
|
||||
}
|
||||
|
||||
super().__init__(model_name=model_name, engine=engine, base_params=default_params)
|
||||
|
||||
# Attributs spécifiques pour Llama Vision
|
||||
self.image_data = None
|
||||
self.json_data = {}
|
||||
|
||||
def set_image(self, image_path: str) -> bool:
|
||||
"""
|
||||
Définit l'image à analyser à partir d'un chemin de fichier
|
||||
Retourne True si l'image a été chargée avec succès
|
||||
"""
|
||||
if not os.path.exists(image_path):
|
||||
return False
|
||||
|
||||
try:
|
||||
with open(image_path, "rb") as f:
|
||||
self.image_data = f.read()
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Erreur lors du chargement de l'image: {e}")
|
||||
return False
|
||||
|
||||
def set_image_data(self, image_data: bytes) -> bool:
|
||||
"""
|
||||
Définit directement les données de l'image à analyser
|
||||
"""
|
||||
try:
|
||||
self.image_data = image_data
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Erreur lors de la définition des données d'image: {e}")
|
||||
return False
|
||||
|
||||
def set_json_data(self, json_data: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Définit les données JSON à associer à l'image
|
||||
"""
|
||||
try:
|
||||
self.json_data = json_data
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Erreur lors de la définition des données JSON: {e}")
|
||||
return False
|
||||
|
||||
def generate(self, user_prompt: str, images: list = None, translate: bool = False):
|
||||
prompt = self._format_prompt(user_prompt)
|
||||
|
||||
# Si des images sont fournies directement, utilisez-les
|
||||
images_to_use = images if images else []
|
||||
|
||||
# Si image_data est défini et aucune image n'est fournie explicitement
|
||||
if self.image_data is not None and not images:
|
||||
# Encodage en base64 si ce n'est pas déjà fait
|
||||
if isinstance(self.image_data, bytes):
|
||||
encoded_image = base64.b64encode(self.image_data).decode('utf-8')
|
||||
images_to_use = [encoded_image]
|
||||
|
||||
# Ajout des données JSON dans le prompt si disponibles
|
||||
if self.json_data:
|
||||
json_str = json.dumps(self.json_data, ensure_ascii=False, indent=2)
|
||||
# On ajoute le JSON au prompt pour qu'il soit traité avec l'image
|
||||
prompt = f"{prompt}\n\nVoici des données JSON associées à cette image:\n```json\n{json_str}\n```"
|
||||
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
"images": images_to_use
|
||||
"images": images if images else []
|
||||
}
|
||||
],
|
||||
"options": self.params,
|
||||
@ -129,58 +51,8 @@ class LlamaVision90B(BaseLLM):
|
||||
|
||||
self._log_result(user_prompt, result_text)
|
||||
|
||||
# Stockage du résultat pour fusion ultérieure
|
||||
self.dernier_resultat = result_data
|
||||
|
||||
if translate:
|
||||
result_fr = GoogleTranslator(source="auto", target="fr").translate(result_text)
|
||||
return result_text, result_fr
|
||||
|
||||
return result_text
|
||||
|
||||
def fusionner_json_avec_resultats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Fusionne les données JSON existantes avec les résultats de l'analyse d'image
|
||||
"""
|
||||
if not hasattr(self, 'dernier_resultat'):
|
||||
return self.json_data
|
||||
|
||||
# Créer une copie du JSON original
|
||||
resultat_fusionne = self.json_data.copy() if self.json_data else {}
|
||||
|
||||
# Ajouter le résultat de l'analyse d'image
|
||||
if "analyse_image" not in resultat_fusionne:
|
||||
resultat_fusionne["analyse_image"] = []
|
||||
|
||||
# Ajouter le résultat à la liste des analyses
|
||||
nouvelle_analyse = {
|
||||
"modele": self.model,
|
||||
"reponse": self.dernier_resultat.get("message", {}).get("content", ""),
|
||||
"parametres": {
|
||||
"temperature": self.params.get("temperature"),
|
||||
"top_p": self.params.get("top_p"),
|
||||
"num_ctx": self.params.get("num_ctx")
|
||||
}
|
||||
}
|
||||
|
||||
resultat_fusionne["analyse_image"].append(nouvelle_analyse)
|
||||
|
||||
return resultat_fusionne
|
||||
|
||||
def sauvegarder_resultats(self, chemin_fichier: str) -> bool:
|
||||
"""
|
||||
Sauvegarde les résultats fusionnés dans un fichier JSON
|
||||
"""
|
||||
try:
|
||||
resultats_fusionnes = self.fusionner_json_avec_resultats()
|
||||
|
||||
# Création du dossier parent si nécessaire
|
||||
os.makedirs(os.path.dirname(os.path.abspath(chemin_fichier)), exist_ok=True)
|
||||
|
||||
with open(chemin_fichier, 'w', encoding='utf-8') as f:
|
||||
json.dump(resultats_fusionnes, f, ensure_ascii=False, indent=2)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Erreur lors de la sauvegarde des résultats: {e}")
|
||||
return False
|
||||
return result_text
|
||||
@ -1,113 +0,0 @@
|
||||
from core.base_llm import BaseLLM
|
||||
import requests
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, List, Any, Optional
|
||||
|
||||
class MistralAPI(BaseLLM):
|
||||
"""Intégration avec l'API Mistral (similaire à la classe Mistral de l'ancien projet)"""
|
||||
|
||||
def __init__(self):
|
||||
model_name = "mistral-large-latest"
|
||||
engine = "MistralAPI"
|
||||
|
||||
self.api_url = "https://api.mistral.ai/v1/chat/completions"
|
||||
# À remplacer par la clé réelle ou une variable d'environnement
|
||||
self.api_key = os.environ.get("MISTRAL_API_KEY", "")
|
||||
|
||||
default_params = {
|
||||
# Paramètres de génération
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
"max_tokens": 1024,
|
||||
|
||||
# Paramètres de contrôle
|
||||
"presence_penalty": 0,
|
||||
"frequency_penalty": 0,
|
||||
"stop": [],
|
||||
|
||||
# Paramètres divers
|
||||
"random_seed": None
|
||||
}
|
||||
|
||||
super().__init__(model_name=model_name, engine=engine, base_params=default_params)
|
||||
|
||||
def generate(self, user_prompt):
|
||||
"""Génère une réponse à partir du prompt utilisateur via l'API Mistral"""
|
||||
prompt = self._format_prompt(user_prompt)
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("Clé API Mistral non définie. Définissez la variable d'environnement MISTRAL_API_KEY.")
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
}
|
||||
|
||||
# Préparation des messages
|
||||
messages = []
|
||||
|
||||
# Ajout du prompt système si défini
|
||||
if self.system_prompt:
|
||||
messages.append({"role": "system", "content": self.system_prompt})
|
||||
|
||||
# Ajout du prompt utilisateur
|
||||
messages.append({"role": "user", "content": user_prompt})
|
||||
|
||||
# Préparation du payload
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"temperature": self.params.get("temperature", 0.7),
|
||||
"top_p": self.params.get("top_p", 0.9),
|
||||
"max_tokens": self.params.get("max_tokens", 1024)
|
||||
}
|
||||
|
||||
# Ajout des paramètres optionnels
|
||||
if self.params.get("presence_penalty") is not None:
|
||||
payload["presence_penalty"] = self.params.get("presence_penalty")
|
||||
|
||||
if self.params.get("frequency_penalty") is not None:
|
||||
payload["frequency_penalty"] = self.params.get("frequency_penalty")
|
||||
|
||||
# Vérifier que stop est une liste non vide
|
||||
stop_sequences = self.params.get("stop")
|
||||
if isinstance(stop_sequences, list) and stop_sequences:
|
||||
payload["stop"] = stop_sequences
|
||||
|
||||
if self.params.get("random_seed") is not None:
|
||||
payload["random_seed"] = self.params.get("random_seed")
|
||||
|
||||
# Envoi de la requête
|
||||
response = requests.post(self.api_url, headers=headers, json=payload)
|
||||
|
||||
if not response.ok:
|
||||
raise Exception(f"Erreur API Mistral: {response.status_code} - {response.text}")
|
||||
|
||||
# Traitement de la réponse
|
||||
result_data = response.json()
|
||||
result_text = result_data.get("choices", [{}])[0].get("message", {}).get("content", "")
|
||||
|
||||
# Logging du résultat
|
||||
filename = self._log_result(user_prompt, result_text)
|
||||
|
||||
return result_text
|
||||
|
||||
def obtenir_liste_modeles(self) -> List[str]:
|
||||
"""Récupère la liste des modèles disponibles via l'API Mistral"""
|
||||
if not self.api_key:
|
||||
raise ValueError("Clé API Mistral non définie. Définissez la variable d'environnement MISTRAL_API_KEY.")
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
}
|
||||
|
||||
response = requests.get("https://api.mistral.ai/v1/models", headers=headers)
|
||||
|
||||
if not response.ok:
|
||||
raise Exception(f"Erreur API Mistral: {response.status_code} - {response.text}")
|
||||
|
||||
result_data = response.json()
|
||||
models = [model.get("id") for model in result_data.get("data", [])]
|
||||
|
||||
return models
|
||||
@ -1,65 +0,0 @@
|
||||
import unittest
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
# Ajouter le répertoire parent au sys.path pour importer les modules
|
||||
sys.path.append(str(Path(__file__).parent.parent))
|
||||
|
||||
from agents.base_agent import Agent
|
||||
|
||||
class TestAgent(unittest.TestCase):
|
||||
"""Tests pour la classe Agent de base"""
|
||||
|
||||
def setUp(self):
|
||||
"""Initialisation pour chaque test"""
|
||||
self.agent = Agent(nom="TestAgent")
|
||||
|
||||
def test_initialisation(self):
|
||||
"""Test de l'initialisation correcte de l'agent"""
|
||||
self.assertEqual(self.agent.nom, "TestAgent")
|
||||
self.assertEqual(len(self.agent.historique), 0)
|
||||
|
||||
def test_ajouter_historique(self):
|
||||
"""Test de l'ajout d'entrées dans l'historique"""
|
||||
self.agent.ajouter_historique("test_action", "données d'entrée", "données de sortie")
|
||||
|
||||
self.assertEqual(len(self.agent.historique), 1)
|
||||
entry = self.agent.historique[0]
|
||||
|
||||
self.assertEqual(entry["action"], "test_action")
|
||||
self.assertEqual(entry["input"], "données d'entrée")
|
||||
self.assertEqual(entry["output"], "données de sortie")
|
||||
self.assertTrue("timestamp" in entry)
|
||||
|
||||
def test_obtenir_historique(self):
|
||||
"""Test de la récupération de l'historique"""
|
||||
# Ajouter plusieurs entrées
|
||||
self.agent.ajouter_historique("action1", "entrée1", "sortie1")
|
||||
self.agent.ajouter_historique("action2", "entrée2", "sortie2")
|
||||
|
||||
historique = self.agent.obtenir_historique()
|
||||
|
||||
self.assertEqual(len(historique), 2)
|
||||
self.assertEqual(historique[0]["action"], "action1")
|
||||
self.assertEqual(historique[1]["action"], "action2")
|
||||
|
||||
def test_executer_not_implemented(self):
|
||||
"""Test que la méthode executer lève une NotImplementedError"""
|
||||
with self.assertRaises(NotImplementedError):
|
||||
self.agent.executer()
|
||||
|
||||
def test_limite_taille_historique(self):
|
||||
"""Test de la limite de taille dans l'historique"""
|
||||
# Créer une entrée avec une chaîne très longue
|
||||
longue_chaine = "x" * 1000
|
||||
self.agent.ajouter_historique("test_limite", longue_chaine, longue_chaine)
|
||||
|
||||
entry = self.agent.historique[0]
|
||||
self.assertEqual(len(entry["input"]), 500) # Limite à 500 caractères
|
||||
self.assertEqual(len(entry["output"]), 500) # Limite à 500 caractères
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
x
Reference in New Issue
Block a user