mirror of
https://github.com/Ladebeze66/llm_lab_perso.git
synced 2025-12-15 19:06:50 +01:00
37 lines
1.8 KiB
Python
37 lines
1.8 KiB
Python
from core.base_llm import BaseLLM
|
|
import requests
|
|
|
|
class Llama2_13b(BaseLLM):
|
|
def __init__(self):
|
|
# Nom du modèle spécifique
|
|
model_name = "llama2:13b"
|
|
# Moteur utilisé pour l'inférence
|
|
engine = "Ollama"
|
|
|
|
# Paramètres par défaut spécifiques à Llama2 13B
|
|
default_params = {
|
|
"temperature": 0.75, # Contrôle la créativité : 0 = déterministe, 1 = plus créatif
|
|
"top_p": 0.9, # Nucleus sampling : sélectionne les tokens jusqu'à une probabilité cumulative de top_p
|
|
"top_k": 40, # Considère les top_k tokens les plus probables pour chaque étape de génération
|
|
"repeat_penalty": 1.1, # Pénalise les répétitions : >1 pour réduire les répétitions, 1 pour aucune pénalité
|
|
"num_predict": 768, # Nombre maximum de tokens à générer dans la réponse
|
|
"stop": [], # Liste de séquences qui arrêteront la génération si rencontrées
|
|
"seed": None, # Graine pour la reproductibilité : fixe la graine pour obtenir les mêmes résultats
|
|
"stream": False, # Si True, la réponse est envoyée en flux (streaming)
|
|
"raw": False # Si True, désactive le prompt système automatique
|
|
}
|
|
|
|
super().__init__(model_name=model_name, engine=engine, base_params=default_params)
|
|
|
|
def generate(self, user_prompt):
|
|
prompt = self._format_prompt(user_prompt)
|
|
payload = self.params.copy()
|
|
payload["prompt"] = prompt
|
|
|
|
response = requests.post("http://localhost:11434/api/generate", json=payload)
|
|
if not response.ok:
|
|
raise Exception(f"Erreur API Ollama : {response.status_code} - {response.text}")
|
|
|
|
result = response.json().get("response", "")
|
|
self._log_result(user_prompt, result)
|
|
return result |