llm_ticket3/agents/llama_vision/agent_vision_ocr.py
2025-04-25 17:18:33 +02:00

187 lines
7.7 KiB
Python

import os
import logging
from datetime import datetime
from PIL import Image
from ..base_agent import BaseAgent
from ..utils.pipeline_logger import sauvegarder_donnees
logger = logging.getLogger("AgentVisionOCR")
class AgentVisionOCR(BaseAgent):
"""
Agent LlamaVision qui extrait du texte (OCR avancé) depuis une image.
Permet une lecture plus fine pour les images conservées après tri.
"""
def __init__(self, llm):
super().__init__("AgentVisionOCR", llm)
self.params = {
"temperature": 0.1,
"top_p": 0.85,
"max_tokens": 1500
}
self.system_prompt = """You are a multilingual OCR visual assistant.
Your task is to extract all visible text from image, even if it is in French, English, or both.
Guidelines:
1. Include partial, blurry, or stylized characters
2. Group the result by type: labels, titles, buttons, errors, URLs, etc.
3. Do NOT translate any text - just extract what is visible
4. Mention if the image contains unreadable or missing parts
Respond in English."""
self._configurer_llm()
# Collecter les résultats pour la sauvegarde groupée
self.resultats = []
logger.info("AgentVisionOCR initialisé")
def _configurer_llm(self):
if hasattr(self.llm, "prompt_system"):
self.llm.prompt_system = self.system_prompt
if hasattr(self.llm, "configurer"):
self.llm.configurer(**self.params)
def _extraire_ticket_id(self, image_path):
parts = image_path.split(os.sep)
for part in parts:
if part.startswith("T") and part[1:].isdigit():
return part
return "UNKNOWN"
def executer(self, image_path: str, ocr_baseline: str = "") -> dict:
"""" Effectue un OCR visuel via LlamaVision sur l'imga spécifiée.
Args:
image_path: Chemin vers l'image ç analyser
ocr_baseline: Texte OCRé précédemment (pour comparaison)
Returns:
Dictionnaire contenant le texte extrait et les métadonnées
"""
image_name = os.path.basename(image_path)
print(f" AgentVisionOCR: Extraction sur {image_name}")
try:
if not os.path.exists(image_path):
raise FileNotFoundError(f"Image introuvable: {image_path}")
if not hasattr(self.llm, "interroger_avec_image"):
raise RuntimeError("Le modèle ne supporte pas l'analyse d'images")
response = self.llm.interroger_avec_image(image_path, self.system_prompt)
if not response or "i cannot" in response.lower():
raise ValueError("Réponse vide invalide du modèle")
# Normaliser le nom du modèle pour cohérence
model_name = getattr(self.llm, "pipeline_normalized_name",
getattr(self.llm, "modele", "llama3-vision-90b-instruct"))
# Nettoyer le nom pour éviter les problèmes de fichiers
model_name = model_name.replace(".", "-").replace(":", "-").replace("_", "-")
result = {
"extracted_text": response.strip(),
"image_name": image_name,
"image_path": image_path,
"ocr_script_text": ocr_baseline.strip(),
"ticket_id": self._extraire_ticket_id(image_path),
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"source_agent": self.nom,
"model_info": {
"model": model_name,
**self.params
}
}
# Ajouter le résultat à la liste pour sauvegarde groupée
self.resultats.append(result)
# Sauvegarder individuellement pour traçabilité
sauvegarder_donnees(
ticket_id=result["ticket_id"],
step_name="ocr_llm",
data=result,
base_dir=None,
is_resultat=True
)
logger.info(f"OCR LLM réussi pour {image_name}")
return result
except Exception as e:
error_result = {
"extracted_text": "",
"image_name": image_name,
"image_path": image_path,
"ticket_id": self._extraire_ticket_id(image_path),
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"source_agent": self.nom,
"error": str(e),
"model_info": {
"model": getattr(self.llm, "pipeline_normalized_name", "llama3-vision-90b-instruct"),
**self.params
}
}
# Ajouter l'erreur à la liste des résultats
self.resultats.append(error_result)
logger.error(f"Erreur lors de l'extraction OCR pour {image_name}: {e}")
return error_result
def sauvegarder_resultats(self) -> None:
"""
Sauvegarde tous les résultats collectés en garantissant leur accumulation.
Utilise un format de liste pour maintenir les multiples résultats.
"""
logger.info(f"Sauvegarde de {len(self.resultats)} résultats d'OCR avancé")
if not self.resultats:
logger.warning("Aucun résultat à sauvegarder")
return
# Récupérer le ticket_id du premier résultat
ticket_id = self.resultats[0].get("ticket_id", "UNKNOWN")
try:
# Obtenir directement le nom normalisé du modèle depuis l'instance LLM
if not self.llm:
logger.warning("LLM est None, utilisation du nom de modèle par défaut")
normalized_model_name = "llama3-vision-90b-instruct"
else:
# Vérifier d'abord pipeline_normalized_name puis modele
normalized_model_name = getattr(self.llm, "pipeline_normalized_name", None)
if not normalized_model_name:
normalized_model_name = getattr(self.llm, "modele", "llama3-vision-90b-instruct")
# Normaliser manuellement (dans tous les cas)
normalized_model_name = normalized_model_name.replace(".", "-").replace(":", "-").replace("_", "-")
logger.info(f"Nom de modèle normalisé pour la sauvegarde OCR: {normalized_model_name}")
# Normaliser les noms de modèles dans tous les résultats
for result in self.resultats:
if "model_info" not in result:
result["model_info"] = {}
# Utiliser le nom de modèle normalisé pour tous les résultats
result["model_info"]["model"] = normalized_model_name
# Sauvegarder en mode liste pour accumuler les résultats
sauvegarder_donnees(
ticket_id=ticket_id,
step_name="ocr_llm",
data=self.resultats,
base_dir=None,
is_resultat=True
)
logger.info(f"Sauvegarde groupée de {len(self.resultats)} résultats d'OCR avancé")
print(f"Sauvegarde de {len(self.resultats)} résultats d'OCR avancé terminée")
# Réinitialiser la liste après la sauvegarde
self.resultats = []
except Exception as e:
logger.error(f"Erreur lors de la sauvegarde des résultats d'OCR avancé: {e}")
logger.exception("Détails de l'erreur:")
print(f"Erreur lors de la sauvegarde des résultats: {e}")