mirror of
https://github.com/Ladebeze66/llm_lab.git
synced 2025-12-15 20:06:50 +01:00
save
This commit is contained in:
parent
1cc834f9ab
commit
82c933bd5f
@ -1,157 +0,0 @@
|
|||||||
import json
|
|
||||||
import time
|
|
||||||
from typing import Dict, List, Any, Optional
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from agents.base_agent import Agent
|
|
||||||
from core.factory import LLMFactory
|
|
||||||
|
|
||||||
class AgentAnalyseJSON(Agent):
|
|
||||||
"""
|
|
||||||
Agent pour analyser des données JSON
|
|
||||||
"""
|
|
||||||
def __init__(self, nom: str = "AgentAnalyseJSON", modele: str = "mistral7b"):
|
|
||||||
super().__init__(nom)
|
|
||||||
|
|
||||||
# Choix du modèle
|
|
||||||
self.llm = LLMFactory.create(modele)
|
|
||||||
|
|
||||||
# Configuration du modèle
|
|
||||||
self.llm.set_role("formateur", {
|
|
||||||
"system_prompt": "Tu es un expert en analyse de données JSON. Tu dois extraire des informations pertinentes, identifier des tendances et répondre à des questions sur les données.",
|
|
||||||
"params": {
|
|
||||||
"temperature": 0.4,
|
|
||||||
"top_p": 0.9
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
def executer(self, json_data: Dict[str, Any],
|
|
||||||
question: str = "Analyse ces données et extrait les informations principales.") -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Analyse des données JSON
|
|
||||||
|
|
||||||
Args:
|
|
||||||
json_data: Données JSON à analyser
|
|
||||||
question: Question à poser au modèle
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Résultats de l'analyse sous forme de dictionnaire
|
|
||||||
"""
|
|
||||||
# Conversion du JSON en chaîne formatée
|
|
||||||
json_str = json.dumps(json_data, ensure_ascii=False, indent=2)
|
|
||||||
|
|
||||||
# Construction du prompt avec le JSON et la question
|
|
||||||
prompt = f"{question}\n\nDonnées JSON à analyser:\n```json\n{json_str}\n```"
|
|
||||||
|
|
||||||
# Interrogation du modèle
|
|
||||||
reponse = self.llm.generate(prompt)
|
|
||||||
|
|
||||||
# Construction du résultat
|
|
||||||
resultats = {
|
|
||||||
"question": question,
|
|
||||||
"reponse": reponse,
|
|
||||||
"timestamp": datetime.now().isoformat(),
|
|
||||||
"taille_json": len(json_str)
|
|
||||||
}
|
|
||||||
|
|
||||||
# Ajout à l'historique
|
|
||||||
self.ajouter_historique("analyse_json", prompt[:200] + "...", reponse[:200] + "...")
|
|
||||||
|
|
||||||
return resultats
|
|
||||||
|
|
||||||
def extraire_structure(self, json_data: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Extrait la structure d'un JSON (clés, types, profondeur)
|
|
||||||
"""
|
|
||||||
resultat = {
|
|
||||||
"structure": {},
|
|
||||||
"statistiques": {
|
|
||||||
"nb_cles": 0,
|
|
||||||
"profondeur_max": 0,
|
|
||||||
"types": {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
def explorer_structure(data, chemin="", profondeur=0):
|
|
||||||
nonlocal resultat
|
|
||||||
|
|
||||||
# Mise à jour de la profondeur max
|
|
||||||
resultat["statistiques"]["profondeur_max"] = max(resultat["statistiques"]["profondeur_max"], profondeur)
|
|
||||||
|
|
||||||
if isinstance(data, dict):
|
|
||||||
structure = {}
|
|
||||||
for cle, valeur in data.items():
|
|
||||||
nouveau_chemin = f"{chemin}.{cle}" if chemin else cle
|
|
||||||
resultat["statistiques"]["nb_cles"] += 1
|
|
||||||
|
|
||||||
type_valeur = type(valeur).__name__
|
|
||||||
if type_valeur not in resultat["statistiques"]["types"]:
|
|
||||||
resultat["statistiques"]["types"][type_valeur] = 0
|
|
||||||
resultat["statistiques"]["types"][type_valeur] += 1
|
|
||||||
|
|
||||||
if isinstance(valeur, (dict, list)):
|
|
||||||
structure[cle] = explorer_structure(valeur, nouveau_chemin, profondeur + 1)
|
|
||||||
else:
|
|
||||||
structure[cle] = type_valeur
|
|
||||||
return structure
|
|
||||||
|
|
||||||
elif isinstance(data, list):
|
|
||||||
if data and isinstance(data[0], (dict, list)):
|
|
||||||
# Pour les listes de structures complexes, on analyse le premier élément
|
|
||||||
return [explorer_structure(data[0], f"{chemin}[0]", profondeur + 1)]
|
|
||||||
else:
|
|
||||||
# Pour les listes de valeurs simples
|
|
||||||
type_elements = "vide" if not data else type(data[0]).__name__
|
|
||||||
resultat["statistiques"]["nb_cles"] += 1
|
|
||||||
|
|
||||||
if "list" not in resultat["statistiques"]["types"]:
|
|
||||||
resultat["statistiques"]["types"]["list"] = 0
|
|
||||||
resultat["statistiques"]["types"]["list"] += 1
|
|
||||||
|
|
||||||
return f"list[{type_elements}]"
|
|
||||||
else:
|
|
||||||
return type(data).__name__
|
|
||||||
|
|
||||||
resultat["structure"] = explorer_structure(json_data)
|
|
||||||
|
|
||||||
self.ajouter_historique("extraire_structure", "JSON", resultat)
|
|
||||||
return resultat
|
|
||||||
|
|
||||||
def fusionner_jsons(self, json1: Dict[str, Any], json2: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Fusionne deux structures JSON en conservant les informations des deux
|
|
||||||
"""
|
|
||||||
if not json1:
|
|
||||||
return json2
|
|
||||||
|
|
||||||
if not json2:
|
|
||||||
return json1
|
|
||||||
|
|
||||||
resultat = json1.copy()
|
|
||||||
|
|
||||||
# Fonction récursive pour fusionner
|
|
||||||
def fusionner(dict1, dict2):
|
|
||||||
for cle, valeur in dict2.items():
|
|
||||||
if cle in dict1:
|
|
||||||
# Si les deux sont des dictionnaires, fusion récursive
|
|
||||||
if isinstance(dict1[cle], dict) and isinstance(valeur, dict):
|
|
||||||
fusionner(dict1[cle], valeur)
|
|
||||||
# Si les deux sont des listes, concaténation
|
|
||||||
elif isinstance(dict1[cle], list) and isinstance(valeur, list):
|
|
||||||
dict1[cle].extend(valeur)
|
|
||||||
# Sinon, on garde les deux valeurs dans une liste
|
|
||||||
else:
|
|
||||||
if not isinstance(dict1[cle], list):
|
|
||||||
dict1[cle] = [dict1[cle]]
|
|
||||||
if isinstance(valeur, list):
|
|
||||||
dict1[cle].extend(valeur)
|
|
||||||
else:
|
|
||||||
dict1[cle].append(valeur)
|
|
||||||
else:
|
|
||||||
# Si la clé n'existe pas dans dict1, on l'ajoute simplement
|
|
||||||
dict1[cle] = valeur
|
|
||||||
|
|
||||||
fusionner(resultat, json2)
|
|
||||||
self.ajouter_historique("fusionner_jsons", "Fusion de deux JSON", "Fusion réussie")
|
|
||||||
|
|
||||||
return resultat
|
|
||||||
@ -1,94 +0,0 @@
|
|||||||
import os
|
|
||||||
import json
|
|
||||||
from typing import Dict, List, Any, Optional
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from agents.base_agent import Agent
|
|
||||||
from core.factory import LLMFactory
|
|
||||||
|
|
||||||
class AgentAnalyseImage(Agent):
|
|
||||||
"""
|
|
||||||
Agent pour analyser des images avec LlamaVision
|
|
||||||
"""
|
|
||||||
def __init__(self, nom: str = "AgentAnalyseImage"):
|
|
||||||
super().__init__(nom)
|
|
||||||
self.llm = LLMFactory.create("llamavision")
|
|
||||||
self.questions_standard = [
|
|
||||||
"Décris en détail ce que tu vois sur cette image.",
|
|
||||||
"Quels sont les éléments principaux visibles sur cette image?",
|
|
||||||
"Y a-t-il du texte visible sur cette image? Si oui, peux-tu le transcrire?",
|
|
||||||
"Quelle est l'ambiance générale de cette image?"
|
|
||||||
]
|
|
||||||
|
|
||||||
def executer(self, image_path: str, json_data: Optional[Dict[str, Any]] = None,
|
|
||||||
question: Optional[str] = None) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Analyse une image avec Llama Vision
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image_path: Chemin vers l'image à analyser
|
|
||||||
json_data: Données JSON optionnelles associées à l'image
|
|
||||||
question: Question à poser au modèle (si None, utilise les questions standard)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Résultats de l'analyse sous forme de dictionnaire
|
|
||||||
"""
|
|
||||||
# Vérification de l'existence de l'image
|
|
||||||
if not os.path.exists(image_path):
|
|
||||||
resultat = {"erreur": f"L'image {image_path} n'existe pas"}
|
|
||||||
self.ajouter_historique("analyse_image_erreur", image_path, resultat)
|
|
||||||
return resultat
|
|
||||||
|
|
||||||
# Préparation de la liste des questions à poser
|
|
||||||
questions_a_poser = [question] if question else self.questions_standard
|
|
||||||
|
|
||||||
# Exécution de l'analyse pour chaque question
|
|
||||||
resultats = {
|
|
||||||
"image": image_path,
|
|
||||||
"timestamp": datetime.now().isoformat(),
|
|
||||||
"analyses": []
|
|
||||||
}
|
|
||||||
|
|
||||||
# Encodage de l'image en base64 pour l'API
|
|
||||||
with open(image_path, "rb") as image_file:
|
|
||||||
image_data = image_file.read()
|
|
||||||
|
|
||||||
# Traitement pour chaque question
|
|
||||||
for q in questions_a_poser:
|
|
||||||
try:
|
|
||||||
# Appel du modèle via la factory
|
|
||||||
reponse = self.llm.generate(q, images=[image_data])
|
|
||||||
|
|
||||||
# Ajout du résultat
|
|
||||||
resultats["analyses"].append({
|
|
||||||
"question": q,
|
|
||||||
"reponse": reponse
|
|
||||||
})
|
|
||||||
|
|
||||||
# Ajout à l'historique
|
|
||||||
self.ajouter_historique("analyse_image", q, reponse[:200] + "...")
|
|
||||||
except Exception as e:
|
|
||||||
resultats["analyses"].append({
|
|
||||||
"question": q,
|
|
||||||
"erreur": str(e)
|
|
||||||
})
|
|
||||||
self.ajouter_historique("analyse_image_erreur", q, str(e))
|
|
||||||
|
|
||||||
return resultats
|
|
||||||
|
|
||||||
def sauvegarder_resultats(self, chemin_fichier: str, resultats: Dict[str, Any]) -> bool:
|
|
||||||
"""
|
|
||||||
Sauvegarde les résultats d'analyse dans un fichier JSON
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Création du dossier parent si nécessaire
|
|
||||||
os.makedirs(os.path.dirname(os.path.abspath(chemin_fichier)), exist_ok=True)
|
|
||||||
|
|
||||||
with open(chemin_fichier, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump(resultats, f, ensure_ascii=False, indent=2)
|
|
||||||
|
|
||||||
self.ajouter_historique("sauvegarder_resultats", chemin_fichier, "Résultats sauvegardés")
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
self.ajouter_historique("sauvegarder_resultats_erreur", chemin_fichier, str(e))
|
|
||||||
return False
|
|
||||||
@ -1,26 +0,0 @@
|
|||||||
from datetime import datetime
|
|
||||||
from typing import Dict, List, Any, Optional
|
|
||||||
|
|
||||||
class Agent:
|
|
||||||
"""Classe de base pour tous les agents d'analyse"""
|
|
||||||
|
|
||||||
def __init__(self, nom: str = "Agent"):
|
|
||||||
self.nom = nom
|
|
||||||
self.historique = []
|
|
||||||
|
|
||||||
def ajouter_historique(self, action: str, input_data: Any, output_data: Any) -> None:
|
|
||||||
"""Ajoute une entrée dans l'historique de l'agent"""
|
|
||||||
self.historique.append({
|
|
||||||
"timestamp": datetime.now().isoformat(),
|
|
||||||
"action": action,
|
|
||||||
"input": str(input_data)[:500], # Limite pour éviter des historiques trop grands
|
|
||||||
"output": str(output_data)[:500] # Limite pour éviter des historiques trop grands
|
|
||||||
})
|
|
||||||
|
|
||||||
def obtenir_historique(self) -> List[Dict[str, Any]]:
|
|
||||||
"""Retourne l'historique complet de l'agent"""
|
|
||||||
return self.historique
|
|
||||||
|
|
||||||
def executer(self, *args, **kwargs) -> Any:
|
|
||||||
"""Méthode abstraite à implémenter dans les classes dérivées"""
|
|
||||||
raise NotImplementedError("Chaque agent doit implémenter sa propre méthode executer()")
|
|
||||||
@ -1,6 +1,5 @@
|
|||||||
from core.mistral7b import Mistral7B
|
from core.mistral7b import Mistral7B
|
||||||
from core.llama_vision90b import LlamaVision90B
|
from core.llama_vision90b import LlamaVision90B
|
||||||
from core.mistral_api import MistralAPI
|
|
||||||
|
|
||||||
class LLMFactory:
|
class LLMFactory:
|
||||||
"""
|
"""
|
||||||
@ -9,7 +8,6 @@ class LLMFactory:
|
|||||||
_registry = {
|
_registry = {
|
||||||
"mistral7b": Mistral7B,
|
"mistral7b": Mistral7B,
|
||||||
"llamavision": LlamaVision90B,
|
"llamavision": LlamaVision90B,
|
||||||
"mistralapi": MistralAPI,
|
|
||||||
# Ajouter d'autres modèles LLM ici
|
# Ajouter d'autres modèles LLM ici
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -2,8 +2,6 @@ from core.base_llm import BaseLLM
|
|||||||
import requests
|
import requests
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import base64
|
|
||||||
from typing import Dict, List, Any, Optional
|
|
||||||
from deep_translator import GoogleTranslator
|
from deep_translator import GoogleTranslator
|
||||||
|
|
||||||
class LlamaVision90B(BaseLLM):
|
class LlamaVision90B(BaseLLM):
|
||||||
@ -14,105 +12,29 @@ class LlamaVision90B(BaseLLM):
|
|||||||
self.api_url = "http://217.182.105.173:11434/api/chat"
|
self.api_url = "http://217.182.105.173:11434/api/chat"
|
||||||
|
|
||||||
default_params = {
|
default_params = {
|
||||||
# Paramètres de créativité
|
"temperature": 0.3, #Créativité basse pour analyse technique
|
||||||
"temperature": 0.3, # Créativité basse pour analyse technique
|
"top_p": 1.0, # Conserve toute la distribution
|
||||||
"top_p": 1.0, # Conserve toute la distribution
|
"top_k": 40, #Limite vocabulaire
|
||||||
"top_k": 40, # Limite vocabulaire
|
"repeat_penalty": 1.1, #Réduction des répétitions
|
||||||
|
"num_predict": 512, #longueur max sortie
|
||||||
# Paramètres de qualité
|
"num-ctx": 4096, #Contexte étendu
|
||||||
"repeat_penalty": 1.1, # Réduction des répétitions
|
"format": "json", #Réponse structurée JSON (optionnel)
|
||||||
"min_p": 0.0, # Seuil minimal pour la probabilité des tokens
|
"stream": False, #Réponse d'un seul bloc
|
||||||
|
"raw": False, #laisse le formatage systèmes
|
||||||
# Paramètres de contrôle avancé
|
"keep_alive": "5m" #Durée de vie de la connexion
|
||||||
"mirostat": 0, # 0=désactivé, 1=v1, 2=v2
|
|
||||||
"mirostat_eta": 0.1, # Taux d'apprentissage pour mirostat
|
|
||||||
"mirostat_tau": 5.0, # Cible pour mirostat
|
|
||||||
|
|
||||||
# Paramètres de taille
|
|
||||||
"num_predict": 512, # Longueur max sortie
|
|
||||||
"num_ctx": 4096, # Contexte étendu
|
|
||||||
|
|
||||||
# Paramètres de contrôle
|
|
||||||
"seed": 0, # Graine pour reproductibilité (0=aléatoire)
|
|
||||||
"stop": [], # Séquences d'arrêt
|
|
||||||
|
|
||||||
# Paramètres de format
|
|
||||||
"format": "json", # Réponse structurée JSON (optionnel)
|
|
||||||
"stream": False, # Réponse d'un seul bloc
|
|
||||||
"raw": False, # Laisse le formatage systèmes
|
|
||||||
"keep_alive": "5m" # Durée de vie de la connexion
|
|
||||||
}
|
}
|
||||||
|
|
||||||
super().__init__(model_name=model_name, engine=engine, base_params=default_params)
|
super().__init__(model_name=model_name, engine=engine, base_params=default_params)
|
||||||
|
|
||||||
# Attributs spécifiques pour Llama Vision
|
|
||||||
self.image_data = None
|
|
||||||
self.json_data = {}
|
|
||||||
|
|
||||||
def set_image(self, image_path: str) -> bool:
|
|
||||||
"""
|
|
||||||
Définit l'image à analyser à partir d'un chemin de fichier
|
|
||||||
Retourne True si l'image a été chargée avec succès
|
|
||||||
"""
|
|
||||||
if not os.path.exists(image_path):
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(image_path, "rb") as f:
|
|
||||||
self.image_data = f.read()
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Erreur lors du chargement de l'image: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def set_image_data(self, image_data: bytes) -> bool:
|
|
||||||
"""
|
|
||||||
Définit directement les données de l'image à analyser
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
self.image_data = image_data
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Erreur lors de la définition des données d'image: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def set_json_data(self, json_data: Dict[str, Any]) -> bool:
|
|
||||||
"""
|
|
||||||
Définit les données JSON à associer à l'image
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
self.json_data = json_data
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Erreur lors de la définition des données JSON: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def generate(self, user_prompt: str, images: list = None, translate: bool = False):
|
def generate(self, user_prompt: str, images: list = None, translate: bool = False):
|
||||||
prompt = self._format_prompt(user_prompt)
|
prompt = self._format_prompt(user_prompt)
|
||||||
|
|
||||||
# Si des images sont fournies directement, utilisez-les
|
|
||||||
images_to_use = images if images else []
|
|
||||||
|
|
||||||
# Si image_data est défini et aucune image n'est fournie explicitement
|
|
||||||
if self.image_data is not None and not images:
|
|
||||||
# Encodage en base64 si ce n'est pas déjà fait
|
|
||||||
if isinstance(self.image_data, bytes):
|
|
||||||
encoded_image = base64.b64encode(self.image_data).decode('utf-8')
|
|
||||||
images_to_use = [encoded_image]
|
|
||||||
|
|
||||||
# Ajout des données JSON dans le prompt si disponibles
|
|
||||||
if self.json_data:
|
|
||||||
json_str = json.dumps(self.json_data, ensure_ascii=False, indent=2)
|
|
||||||
# On ajoute le JSON au prompt pour qu'il soit traité avec l'image
|
|
||||||
prompt = f"{prompt}\n\nVoici des données JSON associées à cette image:\n```json\n{json_str}\n```"
|
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
"messages": [
|
"messages": [
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": prompt,
|
"content": prompt,
|
||||||
"images": images_to_use
|
"images": images if images else []
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"options": self.params,
|
"options": self.params,
|
||||||
@ -129,58 +51,8 @@ class LlamaVision90B(BaseLLM):
|
|||||||
|
|
||||||
self._log_result(user_prompt, result_text)
|
self._log_result(user_prompt, result_text)
|
||||||
|
|
||||||
# Stockage du résultat pour fusion ultérieure
|
|
||||||
self.dernier_resultat = result_data
|
|
||||||
|
|
||||||
if translate:
|
if translate:
|
||||||
result_fr = GoogleTranslator(source="auto", target="fr").translate(result_text)
|
result_fr = GoogleTranslator(source="auto", target="fr").translate(result_text)
|
||||||
return result_text, result_fr
|
return result_text, result_fr
|
||||||
|
|
||||||
return result_text
|
return result_text
|
||||||
|
|
||||||
def fusionner_json_avec_resultats(self) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Fusionne les données JSON existantes avec les résultats de l'analyse d'image
|
|
||||||
"""
|
|
||||||
if not hasattr(self, 'dernier_resultat'):
|
|
||||||
return self.json_data
|
|
||||||
|
|
||||||
# Créer une copie du JSON original
|
|
||||||
resultat_fusionne = self.json_data.copy() if self.json_data else {}
|
|
||||||
|
|
||||||
# Ajouter le résultat de l'analyse d'image
|
|
||||||
if "analyse_image" not in resultat_fusionne:
|
|
||||||
resultat_fusionne["analyse_image"] = []
|
|
||||||
|
|
||||||
# Ajouter le résultat à la liste des analyses
|
|
||||||
nouvelle_analyse = {
|
|
||||||
"modele": self.model,
|
|
||||||
"reponse": self.dernier_resultat.get("message", {}).get("content", ""),
|
|
||||||
"parametres": {
|
|
||||||
"temperature": self.params.get("temperature"),
|
|
||||||
"top_p": self.params.get("top_p"),
|
|
||||||
"num_ctx": self.params.get("num_ctx")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
resultat_fusionne["analyse_image"].append(nouvelle_analyse)
|
|
||||||
|
|
||||||
return resultat_fusionne
|
|
||||||
|
|
||||||
def sauvegarder_resultats(self, chemin_fichier: str) -> bool:
|
|
||||||
"""
|
|
||||||
Sauvegarde les résultats fusionnés dans un fichier JSON
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
resultats_fusionnes = self.fusionner_json_avec_resultats()
|
|
||||||
|
|
||||||
# Création du dossier parent si nécessaire
|
|
||||||
os.makedirs(os.path.dirname(os.path.abspath(chemin_fichier)), exist_ok=True)
|
|
||||||
|
|
||||||
with open(chemin_fichier, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump(resultats_fusionnes, f, ensure_ascii=False, indent=2)
|
|
||||||
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Erreur lors de la sauvegarde des résultats: {e}")
|
|
||||||
return False
|
|
||||||
@ -1,113 +0,0 @@
|
|||||||
from core.base_llm import BaseLLM
|
|
||||||
import requests
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
from typing import Dict, List, Any, Optional
|
|
||||||
|
|
||||||
class MistralAPI(BaseLLM):
|
|
||||||
"""Intégration avec l'API Mistral (similaire à la classe Mistral de l'ancien projet)"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
model_name = "mistral-large-latest"
|
|
||||||
engine = "MistralAPI"
|
|
||||||
|
|
||||||
self.api_url = "https://api.mistral.ai/v1/chat/completions"
|
|
||||||
# À remplacer par la clé réelle ou une variable d'environnement
|
|
||||||
self.api_key = os.environ.get("MISTRAL_API_KEY", "")
|
|
||||||
|
|
||||||
default_params = {
|
|
||||||
# Paramètres de génération
|
|
||||||
"temperature": 0.7,
|
|
||||||
"top_p": 0.9,
|
|
||||||
"max_tokens": 1024,
|
|
||||||
|
|
||||||
# Paramètres de contrôle
|
|
||||||
"presence_penalty": 0,
|
|
||||||
"frequency_penalty": 0,
|
|
||||||
"stop": [],
|
|
||||||
|
|
||||||
# Paramètres divers
|
|
||||||
"random_seed": None
|
|
||||||
}
|
|
||||||
|
|
||||||
super().__init__(model_name=model_name, engine=engine, base_params=default_params)
|
|
||||||
|
|
||||||
def generate(self, user_prompt):
|
|
||||||
"""Génère une réponse à partir du prompt utilisateur via l'API Mistral"""
|
|
||||||
prompt = self._format_prompt(user_prompt)
|
|
||||||
|
|
||||||
if not self.api_key:
|
|
||||||
raise ValueError("Clé API Mistral non définie. Définissez la variable d'environnement MISTRAL_API_KEY.")
|
|
||||||
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": f"Bearer {self.api_key}"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Préparation des messages
|
|
||||||
messages = []
|
|
||||||
|
|
||||||
# Ajout du prompt système si défini
|
|
||||||
if self.system_prompt:
|
|
||||||
messages.append({"role": "system", "content": self.system_prompt})
|
|
||||||
|
|
||||||
# Ajout du prompt utilisateur
|
|
||||||
messages.append({"role": "user", "content": user_prompt})
|
|
||||||
|
|
||||||
# Préparation du payload
|
|
||||||
payload = {
|
|
||||||
"model": self.model,
|
|
||||||
"messages": messages,
|
|
||||||
"temperature": self.params.get("temperature", 0.7),
|
|
||||||
"top_p": self.params.get("top_p", 0.9),
|
|
||||||
"max_tokens": self.params.get("max_tokens", 1024)
|
|
||||||
}
|
|
||||||
|
|
||||||
# Ajout des paramètres optionnels
|
|
||||||
if self.params.get("presence_penalty") is not None:
|
|
||||||
payload["presence_penalty"] = self.params.get("presence_penalty")
|
|
||||||
|
|
||||||
if self.params.get("frequency_penalty") is not None:
|
|
||||||
payload["frequency_penalty"] = self.params.get("frequency_penalty")
|
|
||||||
|
|
||||||
# Vérifier que stop est une liste non vide
|
|
||||||
stop_sequences = self.params.get("stop")
|
|
||||||
if isinstance(stop_sequences, list) and stop_sequences:
|
|
||||||
payload["stop"] = stop_sequences
|
|
||||||
|
|
||||||
if self.params.get("random_seed") is not None:
|
|
||||||
payload["random_seed"] = self.params.get("random_seed")
|
|
||||||
|
|
||||||
# Envoi de la requête
|
|
||||||
response = requests.post(self.api_url, headers=headers, json=payload)
|
|
||||||
|
|
||||||
if not response.ok:
|
|
||||||
raise Exception(f"Erreur API Mistral: {response.status_code} - {response.text}")
|
|
||||||
|
|
||||||
# Traitement de la réponse
|
|
||||||
result_data = response.json()
|
|
||||||
result_text = result_data.get("choices", [{}])[0].get("message", {}).get("content", "")
|
|
||||||
|
|
||||||
# Logging du résultat
|
|
||||||
filename = self._log_result(user_prompt, result_text)
|
|
||||||
|
|
||||||
return result_text
|
|
||||||
|
|
||||||
def obtenir_liste_modeles(self) -> List[str]:
|
|
||||||
"""Récupère la liste des modèles disponibles via l'API Mistral"""
|
|
||||||
if not self.api_key:
|
|
||||||
raise ValueError("Clé API Mistral non définie. Définissez la variable d'environnement MISTRAL_API_KEY.")
|
|
||||||
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {self.api_key}"
|
|
||||||
}
|
|
||||||
|
|
||||||
response = requests.get("https://api.mistral.ai/v1/models", headers=headers)
|
|
||||||
|
|
||||||
if not response.ok:
|
|
||||||
raise Exception(f"Erreur API Mistral: {response.status_code} - {response.text}")
|
|
||||||
|
|
||||||
result_data = response.json()
|
|
||||||
models = [model.get("id") for model in result_data.get("data", [])]
|
|
||||||
|
|
||||||
return models
|
|
||||||
@ -1,65 +0,0 @@
|
|||||||
import unittest
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import json
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# Ajouter le répertoire parent au sys.path pour importer les modules
|
|
||||||
sys.path.append(str(Path(__file__).parent.parent))
|
|
||||||
|
|
||||||
from agents.base_agent import Agent
|
|
||||||
|
|
||||||
class TestAgent(unittest.TestCase):
|
|
||||||
"""Tests pour la classe Agent de base"""
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
"""Initialisation pour chaque test"""
|
|
||||||
self.agent = Agent(nom="TestAgent")
|
|
||||||
|
|
||||||
def test_initialisation(self):
|
|
||||||
"""Test de l'initialisation correcte de l'agent"""
|
|
||||||
self.assertEqual(self.agent.nom, "TestAgent")
|
|
||||||
self.assertEqual(len(self.agent.historique), 0)
|
|
||||||
|
|
||||||
def test_ajouter_historique(self):
|
|
||||||
"""Test de l'ajout d'entrées dans l'historique"""
|
|
||||||
self.agent.ajouter_historique("test_action", "données d'entrée", "données de sortie")
|
|
||||||
|
|
||||||
self.assertEqual(len(self.agent.historique), 1)
|
|
||||||
entry = self.agent.historique[0]
|
|
||||||
|
|
||||||
self.assertEqual(entry["action"], "test_action")
|
|
||||||
self.assertEqual(entry["input"], "données d'entrée")
|
|
||||||
self.assertEqual(entry["output"], "données de sortie")
|
|
||||||
self.assertTrue("timestamp" in entry)
|
|
||||||
|
|
||||||
def test_obtenir_historique(self):
|
|
||||||
"""Test de la récupération de l'historique"""
|
|
||||||
# Ajouter plusieurs entrées
|
|
||||||
self.agent.ajouter_historique("action1", "entrée1", "sortie1")
|
|
||||||
self.agent.ajouter_historique("action2", "entrée2", "sortie2")
|
|
||||||
|
|
||||||
historique = self.agent.obtenir_historique()
|
|
||||||
|
|
||||||
self.assertEqual(len(historique), 2)
|
|
||||||
self.assertEqual(historique[0]["action"], "action1")
|
|
||||||
self.assertEqual(historique[1]["action"], "action2")
|
|
||||||
|
|
||||||
def test_executer_not_implemented(self):
|
|
||||||
"""Test que la méthode executer lève une NotImplementedError"""
|
|
||||||
with self.assertRaises(NotImplementedError):
|
|
||||||
self.agent.executer()
|
|
||||||
|
|
||||||
def test_limite_taille_historique(self):
|
|
||||||
"""Test de la limite de taille dans l'historique"""
|
|
||||||
# Créer une entrée avec une chaîne très longue
|
|
||||||
longue_chaine = "x" * 1000
|
|
||||||
self.agent.ajouter_historique("test_limite", longue_chaine, longue_chaine)
|
|
||||||
|
|
||||||
entry = self.agent.historique[0]
|
|
||||||
self.assertEqual(len(entry["input"]), 500) # Limite à 500 caractères
|
|
||||||
self.assertEqual(len(entry["output"]), 500) # Limite à 500 caractères
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
Loading…
x
Reference in New Issue
Block a user