llm_lab_perso/core/mistral7b.py
2025-03-27 18:40:52 +01:00

38 lines
1.8 KiB
Python
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from core.base_llm import BaseLLM
import requests
class Mistral7B(BaseLLM):
def __init__(self):
# Nom du modèle spécifique
model_name = "mistral:latest"
# Moteur utilisé pour l'inférence
engine = "Ollama"
# Paramètres par défaut spécifiques à Mistral7B
default_params = {
"temperature": 0.7, # Contrôle la créativité : 0 = déterministe, 1 = plus créatif
"top_p": 0.9, # Nucleus sampling : sélectionne les tokens jusqu'à une probabilité cumulative de top_p
"top_k": 50, # Considère les top_k tokens les plus probables pour chaque étape de génération
"repeat_penalty": 1.1, # Pénalise les répétitions : >1 pour réduire les répétitions, 1 pour aucune pénalité
"num_predict": 512, # Nombre maximum de tokens à générer dans la réponse
"stop": [], # Liste de séquences qui arrêteront la génération si rencontrées
"seed": None, # Graine pour la reproductibilité : fixe la graine pour obtenir les mêmes résultats
"stream": False, # Si True, la réponse est envoyée en flux (streaming)
"raw": False # Si True, désactive le prompt système automatique
}
super().__init__(model_name=model_name, engine=engine, base_params=default_params)
def generate(self, user_prompt):
prompt = self._format_prompt(user_prompt)
payload = self.params.copy()
payload["prompt"] = prompt
response = requests.post("http://localhost:11434/api/generate", json=payload)
if not response.ok:
raise Exception(f"Erreur API Ollama : {response.status_code} - {response.text}")
result = response.json().get("response", "")
self._log_result(user_prompt, result)
return result