# llm_ticket3/agents/llama_vision/agent_vision_ocr.py
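"""AgentVisionOCR: advanced OCR on web-interface screenshots. A raw Tesseract
pass feeds a LlamaVision model call (see executer for the pipeline steps)."""
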
import os
import json
import logging
from datetime import datetime
from typing import Optional
from pathlib import Path
from ..base_agent import BaseAgent
from ..utils.pipeline_logger import sauvegarder_donnees
from utils.ocr_avance.ocr_cleaner import clean_text_with_profiles
from utils.ocr_brut.ocr_utils import extraire_texte
from utils.ocr_avance.image_preparer import prepare_image_for_llama_vision

logger = logging.getLogger("AgentVisionOCR")


class AgentVisionOCR(BaseAgent):
    """
    LlamaVision agent that performs advanced OCR text extraction from an image.
    """

    def __init__(self, llm):
        super().__init__("AgentVisionOCR", llm)
        # Model parameter configuration
        self.params = {
            "stream": False,
            "seed": 0,
            # "stop_sequence": [],
            "temperature": 1.3,
            # "reasoning_effort": 0.5,
            # "logit_bias": {},
            "mirostat": 0,
            "mirostat_eta": 0.1,
            "mirostat_tau": 5.0,
            "top_k": 35,
            "top_p": 0.85,
            "min_p": 0.06,
            "frequency_penalty": 0.15,
            "presence_penalty": 0.1,
            "repeat_penalty": 1.15,
            "repeat_last_n": 128,
            "tfs_z": 1.0,
            "num_keep": 0,
            "num_predict": 2048,
            "num_ctx": 16384,
            "num_batch": 2048,
            # "mmap": True,
            # "mlock": False,
            # "num_thread": 4,
            # "num_gpu": 1
        }
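        # The key names above follow the Ollama generation-options style (an
        # assumption based on the names): num_ctx sets the context window,
        # num_predict caps response length, and the sampling knobs
        # (top_k/top_p/min_p, penalties, mirostat) trade diversity against
        # repetition in long structured OCR output.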
        # Optimized OCR prompt
        self.system_prompt = ("""You are tasked with performing an exhaustive OCR extraction on a technical or administrative web interface screenshot.
GOAL: Extract **every legible piece of text**, even partially visible, faded, or cropped. Structure your output for clarity. Do not guess, but always report what is visible.
FORMAT USING THESE CATEGORIES:
1. PAGE STRUCTURE
- Page titles
- Interface headers or section labels
- Navigation bars or visible URLs
2. IDENTIFIERS & DATA
- Operator or user names
- Sample IDs, test references
- Materials, dates, batch numbers
3. INTERFACE ELEMENTS (MANDATORY SCAN)
- Button labels (e.g., RAZ, SAVE)
- Tabs (e.g., MATERIAL, OBSERVATIONS)
- Sidebars, form field labels
4. SYSTEM MESSAGES
- Connection or server errors
- Domains, IP addresses, server notices
5. METADATA
- Standard references (e.g., "NF EN ####-#")
- Version numbers, document codes, timestamps
6. UNCLEAR / CROPPED TEXT
- Logos, partial lines (use "..." for truncated)
- Background/faded elements, labels not fully legible
RULES:
- Preserve punctuation, case, accents exactly.
- Include duplicates if text appears more than once.
- Never skip faint or partial text; use "..." if incomplete.
- Even if cropped, report as much as possible from any UI region.
This prompt is designed to generalize across all web portals, technical forms, or reports. Prioritize completeness over certainty. Do not ignore UI components or system messages.
""")
        self._configurer_llm()
        self.resultats = []
        self.images_traitees = set()
        logger.info("AgentVisionOCR initialized with improved prompt.")

    def _configurer_llm(self):
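        # Push the system prompt and generation parameters onto the LLM
        # wrapper when it exposes those hooks (duck-typed via hasattr).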
        if hasattr(self.llm, "prompt_system"):
            self.llm.prompt_system = self.system_prompt
        if hasattr(self.llm, "configurer"):
            self.llm.configurer(**self.params)

    def _extraire_ticket_id(self, image_path):
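        """
        Recover a ticket id such as "T11143" from a path segment.

        Illustrative examples (not from the source):
            "data/T11143/img.png"       -> "T11143"
            "out/ticket_T11143/img.png" -> "T11143"
            any other path              -> "UNKNOWN"
        """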
        if not image_path:
            return "UNKNOWN"
        segments = image_path.replace('\\', '/').split('/')
        for segment in segments:
            if segment.startswith('T') and segment[1:].isdigit():
                return segment
            if segment.startswith('ticket_T') and segment[8:].isdigit():
                return 'T' + segment[8:]
        return "UNKNOWN"

    def executer(self, image_path: str, ocr_baseline: str = "", ticket_id: Optional[str] = None) -> dict:
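        """
        Run the full OCR pipeline on one image and return a result dict.

        Steps: duplicate check -> raw Tesseract OCR -> prompt enrichment ->
        image preparation -> vision-model call -> text cleaning -> save.
        (The ocr_baseline argument is accepted but unused in this body.)
        """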
        image_path_abs = os.path.abspath(image_path)
        image_name = os.path.basename(image_path)
        if image_path_abs in self.images_traitees:
            logger.warning(f"[OCR-LLM] Image already processed, skipping: {image_name}")
            print(f" AgentVisionOCR: image already processed, skipping: {image_name}")
            return {
                "extracted_text": "DUPLICATE - Already processed",
                "image_name": image_name,
                "image_path": image_path_abs,
                "ticket_id": ticket_id or self._extraire_ticket_id(image_path),
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                "source_agent": self.nom,
                "is_duplicate": True
            }
        self.images_traitees.add(image_path_abs)
        logger.info(f"[OCR-LLM] OCR extraction on {image_name}")
        print(f" AgentVisionOCR: OCR extraction on {image_name}")
        ticket_id = ticket_id or self._extraire_ticket_id(image_path)
        try:
            if not os.path.exists(image_path):
                raise FileNotFoundError(f"Image not found: {image_path}")
            if not hasattr(self.llm, "interroger_avec_image"):
                raise RuntimeError("The model does not support image analysis.")
            # Step 1: raw OCR with Tesseract
            ocr_brut, _ = extraire_texte(image_path, lang="auto")
            ocr_brut = ocr_brut.strip()
            # Step 2: prompt enrichment
            if ocr_brut and len(ocr_brut) > 30:
                contexte_ocr_brut = f"\n\nNote: the following raw text was detected with a traditional OCR tool. It may be incomplete or inaccurate. Use it to guide or validate your structured output:\n---\n{ocr_brut[:1000]}\n---\n"
            else:
                contexte_ocr_brut = ""
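            # The 30-character floor filters out near-empty Tesseract output,
            # and the excerpt is capped at 1000 characters to bound prompt size.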
            prompt_avec_contexte = self.system_prompt + contexte_ocr_brut
            # Step 3: prepare the image for the vision model
            image_stem = Path(image_path).stem
            # Use the new path for advanced-OCR results
            os.makedirs("results/ocr_avance", exist_ok=True)
            vision_ready_path = os.path.join("results/ocr_avance", f"vision_ready_{image_stem}.png")
            prepare_image_for_llama_vision(image_path, vision_ready_path)
            # Step 4: call the model with the prepared image
            response = self.llm.interroger_avec_image(vision_ready_path, prompt_avec_contexte)
            if not response or "i cannot" in response.lower():
                raise ValueError("Empty or invalid response from the model")
            cleaned_text = clean_text_with_profiles(response.strip(), active_profiles=("ocr",))
            model_name = getattr(self.llm, "pipeline_normalized_name",
                                 getattr(self.llm, "modele", "llama3-vision-90b-instruct"))
            model_name = model_name.replace(".", "-").replace(":", "-").replace("_", "-")
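            # Replacing "." / ":" / "_" with "-" keeps the model name safe for
            # filenames and downstream identifiers.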
            # Save the result under results/ocr_avance
            try:
                result_dir = "results/ocr_avance"
                os.makedirs(result_dir, exist_ok=True)
                with open(f"{result_dir}/ocr_{image_stem}.txt", "w", encoding="utf-8") as f:
                    f.write(cleaned_text)
            except Exception as e:
                logger.error(f"[OCR-LLM] Error saving text: {e}")
            result = {
                "extracted_text": cleaned_text,
                "image_name": image_name,
                "image_path": image_path_abs,
                "ticket_id": ticket_id,
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                "source_agent": self.nom,
                "model_info": {
                    "model": model_name,
                    **self.params
                }
            }
            self.resultats.append(result)
            logger.info(f"[OCR-LLM] OCR succeeded ({len(cleaned_text)} characters) for {image_name}")
            return result
        except Exception as e:
            error_result = {
                "extracted_text": "",
                "image_name": image_name,
                "image_path": image_path_abs,
                "ticket_id": ticket_id or "UNKNOWN",
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                "source_agent": self.nom,
                "error": str(e),
                "model_info": {
                    "model": getattr(self.llm, "pipeline_normalized_name", "llama3-vision-90b-instruct"),
                    **self.params
                }
            }
            self.resultats.append(error_result)
            logger.error(f"[OCR-LLM] OCR error for {image_name}: {e}")
            return error_result

    def sauvegarder_resultats(self, ticket_id: str = "T11143") -> None:
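        """
        Deduplicate accumulated results per image path, keeping the entry
        with the most recent timestamp, then persist them via
        sauvegarder_donnees and reset the in-memory state.
        """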
        if not self.resultats:
            logger.warning("[OCR-LLM] No results to save")
            return
        resultats_dedupliques = {}
        for resultat in self.resultats:
            image_path = resultat.get("image_path")
            if not image_path:
                continue
            if image_path not in resultats_dedupliques or \
               resultat.get("timestamp", "") > resultats_dedupliques[image_path].get("timestamp", ""):
                resultats_dedupliques[image_path] = resultat
        resultats_finaux = list(resultats_dedupliques.values())
        try:
            logger.info(f"[OCR-LLM] Saving {len(resultats_finaux)} results")
            sauvegarder_donnees(
                ticket_id=ticket_id,
                step_name="ocr_llm",
                data=resultats_finaux,
                base_dir=None,
                is_resultat=True
            )
            self.resultats = []
            self.images_traitees = set()
        except Exception as e:
            logger.error(f"[OCR-LLM] Error saving results: {e}")