llm_ticket3/agents/llama_vision/agent_vision_ocr.py

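"""LlamaVision-based OCR agent: extracts text from ticket images via a vision LLM."""
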
import os
import json
import logging
from datetime import datetime
from typing import Optional

from ..base_agent import BaseAgent
from ..utils.pipeline_logger import sauvegarder_donnees

logger = logging.getLogger("AgentVisionOCR")


class AgentVisionOCR(BaseAgent):
    """
    LlamaVision agent that extracts text (advanced OCR) from an image.
    """

    def __init__(self, llm):
        super().__init__("AgentVisionOCR", llm)
        # Model parameter configuration
        self.params = {
            "stream": False,
            "seed": 0,
            # "stop_sequence": [],
            "temperature": 1.5,
            # "reasoning_effort": 0.5,
            # "logit_bias": {},
            "mirostat": 0,
            "mirostat_eta": 0.1,
            "mirostat_tau": 5.0,
            "top_k": 40,
            "top_p": 0.85,
            "min_p": 0.05,
            "frequency_penalty": 0.0,
            "presence_penalty": 0.0,
            "repeat_penalty": 1.1,
            "repeat_last_n": 128,
            "tfs_z": 1.0,
            "num_keep": 0,
            "num_predict": 4096,
            "num_ctx": 16384,
            "num_batch": 2048,
            # "mmap": True,
            # "mlock": False,
            # "num_thread": 4,
            # "num_gpu": 1
        }
        # Optimized OCR prompt
        self.system_prompt = ("""
        Extract all text from this technical document with laboratory-grade precision:
        DOCUMENT STRUCTURE:
        1. HEADER
        * Title/Document name
        * Reference numbers
        * Date/Time stamps
        * Laboratory identifiers
        2. MAIN CONTENT
        * Test names/methods
        * Technical parameters
        * Measurement values
        * Units and scales
        * Standard references
        3. METADATA
        * Protocol numbers
        * Batch/Sample IDs
        * Equipment references
        * Operator information
        4. SUPPLEMENTARY
        * Notes/Remarks
        * Warning messages
        * System notifications
        * Status indicators
        Rules:
        - Extract EVERY number, symbol, and abbreviation
        - Maintain exact formatting of technical values
        - Include all reference codes and standards
        - Report partial or truncated information
        - Capture system messages and alerts
        - Note any calibration or verification data
        Format: Use bullet points (*) for each text element, grouped by section
        """)
        self._configurer_llm()
        self.resultats = []
        self.images_traitees = set()
        logger.info("AgentVisionOCR initialized with improved prompt.")

    def _configurer_llm(self):
        """Push the system prompt and generation parameters to the LLM wrapper if it supports them."""
        if hasattr(self.llm, "prompt_system"):
            self.llm.prompt_system = self.system_prompt
        if hasattr(self.llm, "configurer"):
            self.llm.configurer(**self.params)

    def _extraire_ticket_id(self, image_path):
        """Derive a ticket ID such as 'T11143' from a path segment, or return 'UNKNOWN'."""
        if not image_path:
            return "UNKNOWN"
        segments = image_path.replace('\\', '/').split('/')
        for segment in segments:
            if segment.startswith('T') and segment[1:].isdigit():
                return segment
            if segment.startswith('ticket_T') and segment[8:].isdigit():
                return 'T' + segment[8:]
        return "UNKNOWN"

    def executer(self, image_path: str, ocr_baseline: str = "", ticket_id: Optional[str] = None) -> dict:
        """Run LLM-based OCR on one image and return the result as a dict."""
        image_path_abs = os.path.abspath(image_path)
        image_name = os.path.basename(image_path)
        # Skip images already processed in this session
        if image_path_abs in self.images_traitees:
            logger.warning(f"[OCR-LLM] Image already processed, skipping: {image_name}")
            print(f"AgentVisionOCR: image already processed, skipping: {image_name}")
            return {
                "extracted_text": "DUPLICATE - Already processed",
                "image_name": image_name,
                "image_path": image_path_abs,
                "ticket_id": ticket_id or self._extraire_ticket_id(image_path),
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                "source_agent": self.nom,
                "is_duplicate": True
            }
        self.images_traitees.add(image_path_abs)
        logger.info(f"[OCR-LLM] OCR extraction on {image_name}")
        print(f"AgentVisionOCR: OCR extraction on {image_name}")
        ticket_id = ticket_id or self._extraire_ticket_id(image_path)
        try:
            if not os.path.exists(image_path):
                raise FileNotFoundError(f"Image not found: {image_path}")
            if not hasattr(self.llm, "interroger_avec_image"):
                raise RuntimeError("The model does not support image analysis.")
            # Query the vision model with the image and the OCR system prompt
            response = self.llm.interroger_avec_image(image_path, self.system_prompt)
            if not response or "i cannot" in response.lower():
                raise ValueError("Empty or invalid response from the model")
            cleaned_text = response.strip()
            model_name = getattr(self.llm, "pipeline_normalized_name",
                                 getattr(self.llm, "modele", "llama3-vision-90b-instruct"))
            model_name = model_name.replace(".", "-").replace(":", "-").replace("_", "-")
            result = {
                "extracted_text": cleaned_text,
                "image_name": image_name,
                "image_path": image_path_abs,
                "ticket_id": ticket_id,
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                "source_agent": self.nom,
                "model_info": {
                    "model": model_name,
                    **self.params
                }
            }
            self.resultats.append(result)
            logger.info(f"[OCR-LLM] OCR succeeded ({len(cleaned_text)} characters) for {image_name}")
            return result
        except Exception as e:
            error_result = {
                "extracted_text": "",
                "image_name": image_name,
                "image_path": image_path_abs,
                "ticket_id": ticket_id or "UNKNOWN",
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                "source_agent": self.nom,
                "error": str(e),
                "model_info": {
                    "model": getattr(self.llm, "pipeline_normalized_name", "llama3-vision-90b-instruct"),
                    **self.params
                }
            }
            self.resultats.append(error_result)
            logger.error(f"[OCR-LLM] OCR error for {image_name}: {e}")
            return error_result

    def sauvegarder_resultats(self, ticket_id: str = "T11143") -> None:
        """Persist deduplicated OCR results for the ticket, then reset the agent state."""
        if not self.resultats:
            logger.warning("[OCR-LLM] No results to save")
            return
        # Keep only the most recent result per image path
        resultats_dedupliques = {}
        for resultat in self.resultats:
            image_path = resultat.get("image_path")
            if not image_path:
                continue
            if image_path not in resultats_dedupliques or \
                    resultat.get("timestamp", "") > resultats_dedupliques[image_path].get("timestamp", ""):
                resultats_dedupliques[image_path] = resultat
        resultats_finaux = list(resultats_dedupliques.values())
        try:
            logger.info(f"[OCR-LLM] Saving {len(resultats_finaux)} results")
            sauvegarder_donnees(
                ticket_id=ticket_id,
                step_name="ocr_llm",
                data=resultats_finaux,
                base_dir=None,
                is_resultat=True
            )
            self.resultats = []
            self.images_traitees = set()
        except Exception as e:
            logger.error(f"[OCR-LLM] Error while saving results: {e}")