# llm_ticket3/agents/llama_vision/agent_vision_ocr.py
# Last modified: 2025-04-29 10:04:52 +02:00
#
# Repository viewer metadata: 164 lines, 6.1 KiB, Python

import os
import logging
from datetime import datetime
from PIL import Image
from ..base_agent import BaseAgent
from ..utils.pipeline_logger import sauvegarder_donnees
from typing import Optional
# Module-level logger, named after the agent class rather than ``__name__``.
logger = logging.getLogger("AgentVisionOCR")
class AgentVisionOCR(BaseAgent):
    """
    LlamaVision agent that extracts text (advanced OCR) from an image.

    Revised version intended to maximise the quality of the extracted text.
    Each call to ``executer`` appends one result dict to ``self.resultats``
    (duplicates excepted); ``sauvegarder_resultats`` persists those results
    and, on success, resets the agent's per-batch state.
    """

    # Fallback model name recorded in results when the LLM object exposes
    # neither ``pipeline_normalized_name`` nor ``modele``.
    _MODELE_PAR_DEFAUT = "llama3-vision-90b-instruct"

    def __init__(self, llm):
        super().__init__("AgentVisionOCR", llm)
        # Sampling parameters: forwarded to the LLM through ``configurer`` and
        # copied into every result's ``model_info``.
        self.params = {
            "temperature": 0.1,
            "top_p": 0.85,
            "max_tokens": 1500,
        }
        # Lightweight prompt, kept deliberately short.
        self.system_prompt = (
            "Extract all visible text from this image as accurately as possible.\n"
            "If there is little or no text, briefly describe the visual content instead.\n"
            "Preserve the original language and formatting if possible."
        )
        self._configurer_llm()
        # Results accumulated since the last successful save.
        self.resultats = []
        # Absolute paths already submitted in this batch, used to skip duplicates.
        self.images_traitees = set()
        logger.info("AgentVisionOCR initialisé avec prompt allégé.")

    def _configurer_llm(self):
        """Push the system prompt and sampling parameters to the LLM when it supports them."""
        if hasattr(self.llm, "prompt_system"):
            self.llm.prompt_system = self.system_prompt
        if hasattr(self.llm, "configurer"):
            self.llm.configurer(**self.params)

    def _extraire_ticket_id(self, image_path):
        """
        Infer the ticket identifier from the components of ``image_path``.

        A component of the form ``T<digits>`` is returned as-is; one of the
        form ``ticket_T<digits>`` is returned as ``T<digits>``. The first
        matching component, scanning from the start of the path, wins.

        Returns:
            The ticket id, or "UNKNOWN" if the path is empty or nothing matches.
        """
        if not image_path:
            return "UNKNOWN"
        # Normalise Windows separators so the split works for both styles.
        for segment in image_path.replace("\\", "/").split("/"):
            if segment.startswith("T") and segment[1:].isdigit():
                return segment
            if segment.startswith("ticket_T") and segment[8:].isdigit():
                return "T" + segment[8:]
        return "UNKNOWN"

    @staticmethod
    def _horodatage():
        """Return the current local time formatted as ``YYYY-MM-DD HH:MM:SS``."""
        return datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    def _nom_modele(self):
        """
        Return the model name recorded in result metadata.

        Prefers ``llm.pipeline_normalized_name``, then ``llm.modele``, then the
        class default. '.', ':' and '_' are replaced by '-', so success and
        error results report the same name.
        """
        nom = getattr(self.llm, "pipeline_normalized_name",
                      getattr(self.llm, "modele", self._MODELE_PAR_DEFAUT))
        return str(nom).replace(".", "-").replace(":", "-").replace("_", "-")

    def _interroger_llm(self, image_path):
        """
        Query the vision LLM on ``image_path`` and return its stripped reply.

        Raises:
            FileNotFoundError: the image does not exist.
            RuntimeError: the LLM has no ``interroger_avec_image`` method.
            ValueError: the reply is empty or whitespace-only, or contains
                "i cannot" (case-insensitive); either is treated as invalid.
        """
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image introuvable: {image_path}")
        if not hasattr(self.llm, "interroger_avec_image"):
            raise RuntimeError("Le modèle ne supporte pas l'analyse d'images.")
        response = self.llm.interroger_avec_image(image_path, self.system_prompt)
        # Strip before validating so a whitespace-only reply is rejected
        # instead of being recorded as a successful, empty extraction.
        cleaned_text = (response or "").strip()
        if not cleaned_text or "i cannot" in cleaned_text.lower():
            raise ValueError("Réponse vide ou invalide du modèle")
        return cleaned_text

    def executer(self, image_path: str, ocr_baseline: str = "", ticket_id: Optional[str] = None) -> dict:
        """
        Run LLM-based OCR on a single image.

        Args:
            image_path: Path to the image to analyse.
            ocr_baseline: Accepted for interface compatibility; not used here.
            ticket_id: Ticket identifier; inferred from ``image_path`` when
                omitted or empty.

        Returns:
            A dict with ``extracted_text``, ``image_name``, ``image_path``
            (absolute), ``ticket_id``, ``timestamp`` and ``source_agent``.
            - On success it also has ``model_info``.
            - On failure it has ``error`` plus ``model_info``, with an empty
              ``extracted_text``.
            - If the image was already submitted in this batch, it has
              ``is_duplicate=True``.

        Errors are not raised. They are logged, appended to ``self.resultats``
        and returned as an error result.
        """
        image_path_abs = os.path.abspath(image_path)
        image_name = os.path.basename(image_path)
        # Resolve the ticket once so that success, error and duplicate results
        # all agree, including failures that happen before the LLM is queried.
        ticket_id = ticket_id or self._extraire_ticket_id(image_path)

        if image_path_abs in self.images_traitees:
            logger.warning(f"[OCR-LLM] Image déjà traitée, ignorée: {image_name}")
            print(f" AgentVisionOCR: Image déjà traitée, ignorée: {image_name}")
            return {
                "extracted_text": "DUPLICATE - Already processed",
                "image_name": image_name,
                "image_path": image_path_abs,
                "ticket_id": ticket_id,
                "timestamp": self._horodatage(),
                "source_agent": self.nom,
                "is_duplicate": True,
            }

        self.images_traitees.add(image_path_abs)
        logger.info(f"[OCR-LLM] Extraction OCR sur {image_name}")
        print(f" AgentVisionOCR: Extraction OCR sur {image_name}")

        try:
            cleaned_text = self._interroger_llm(image_path)
        except Exception as e:
            # No text was obtained, so forget the image. A later call then
            # retries it instead of reporting it as a duplicate.
            self.images_traitees.discard(image_path_abs)
            error_result = {
                "extracted_text": "",
                "image_name": image_name,
                "image_path": image_path_abs,
                "ticket_id": ticket_id,
                "timestamp": self._horodatage(),
                "source_agent": self.nom,
                "error": str(e),
                "model_info": {
                    "model": self._nom_modele(),
                    **self.params,
                },
            }
            self.resultats.append(error_result)
            logger.error(f"[OCR-LLM] Erreur OCR pour {image_name}: {e}")
            return error_result

        result = {
            "extracted_text": cleaned_text,
            "image_name": image_name,
            "image_path": image_path_abs,
            "ticket_id": ticket_id,
            "timestamp": self._horodatage(),
            "source_agent": self.nom,
            "model_info": {
                "model": self._nom_modele(),
                **self.params,
            },
        }
        self.resultats.append(result)
        logger.info(f"[OCR-LLM] OCR réussi ({len(cleaned_text)} caractères) pour {image_name}")
        return result

    def sauvegarder_resultats(self, ticket_id: str = "T11143") -> None:
        """
        Persist the accumulated results under ``ticket_id`` via ``sauvegarder_donnees``.

        Results are deduplicated by absolute image path, keeping the most
        recent one. When timestamps are equal, the entry appended later wins.

        On success, ``resultats`` and ``images_traitees`` are reset. On
        failure, the error is logged and the state is kept, so the save can
        be retried.
        """
        # NOTE(review): the hard-coded default "T11143" looks like a
        # development leftover; callers should pass the ticket explicitly.
        if not self.resultats:
            logger.warning("[OCR-LLM] Aucun résultat à sauvegarder")
            return

        resultats_dedupliques = {}
        for resultat in self.resultats:
            image_path = resultat.get("image_path")
            if not image_path:
                continue
            precedent = resultats_dedupliques.get(image_path)
            # Timestamps have one-second resolution. Use ``>=`` so that on a
            # tie the later entry (e.g. a successful retry) replaces the earlier one.
            if precedent is None or resultat.get("timestamp", "") >= precedent.get("timestamp", ""):
                resultats_dedupliques[image_path] = resultat
        resultats_finaux = list(resultats_dedupliques.values())

        logger.info(f"[OCR-LLM] Sauvegarde de {len(resultats_finaux)} résultats")
        try:
            sauvegarder_donnees(
                ticket_id=ticket_id,
                step_name="ocr_llm",
                data=resultats_finaux,
                base_dir=None,
                is_resultat=True,
            )
        except Exception as e:
            # Best-effort save: log and keep the batch for a later retry.
            logger.error(f"[OCR-LLM] Erreur sauvegarde résultats: {e}")
        else:
            self.resultats = []
            self.images_traitees = set()