mirror of
https://github.com/Ladebeze66/ragflow_preprocess.git
synced 2026-02-04 05:30:26 +01:00
193 lines
8.4 KiB
Python
193 lines
8.4 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
Agent pour l'analyse d'images et de schémas
|
|
"""
|
|
|
|
import os
|
|
import time
|
|
import uuid
|
|
from PIL import Image
|
|
import io
|
|
from typing import List, Optional, Dict, Any
|
|
|
|
from .base import LLMBaseAgent
|
|
from utils.api_ollama import OllamaAPI
|
|
|
|
class VisionAgent(LLMBaseAgent):
|
|
"""
|
|
Agent pour l'analyse d'images avec des modèles multimodaux
|
|
"""
|
|
|
|
def __init__(self, model_name: str, endpoint: str = "http://217.182.105.173:11434", **config):
|
|
"""
|
|
Initialise l'agent de vision
|
|
|
|
Args:
|
|
model_name (str): Nom du modèle à utiliser
|
|
endpoint (str): URL de l'API Ollama
|
|
**config: Configuration supplémentaire
|
|
"""
|
|
super().__init__(model_name, endpoint, **config)
|
|
|
|
# Configuration par défaut
|
|
default_config = {
|
|
"save_images": True # Enregistrer les images par défaut
|
|
}
|
|
|
|
# Mettre à jour la configuration avec les valeurs par défaut si non spécifiées
|
|
for key, value in default_config.items():
|
|
if key not in self.config:
|
|
self.config[key] = value
|
|
|
|
# Création du répertoire pour sauvegarder les images analysées
|
|
self.image_dir = os.path.join("data", "images")
|
|
os.makedirs(self.image_dir, exist_ok=True)
|
|
|
|
def generate(self, prompt: Optional[str] = "", images: Optional[List[bytes]] = None,
|
|
selection_type: str = "autre", context: Optional[str] = "") -> str:
|
|
"""
|
|
Génère une description ou une analyse d'une image
|
|
|
|
Args:
|
|
prompt (str, optional): Prompt supplémentaire (non utilisé)
|
|
images (List[bytes], optional): Liste d'images à analyser
|
|
selection_type (str): Type de la sélection (schéma, tableau, formule...)
|
|
context (str, optional): Contexte textuel
|
|
|
|
Returns:
|
|
str: Description générée par le modèle
|
|
"""
|
|
if not images or len(images) == 0:
|
|
return "Erreur: Aucune image fournie pour l'analyse"
|
|
|
|
image_data = images[0]
|
|
|
|
# Sauvegarder l'image pour référence future seulement si l'option est activée
|
|
image_id = str(uuid.uuid4())[:8]
|
|
timestamp = time.strftime("%Y%m%d-%H%M%S")
|
|
image_filename = f"{timestamp}_{image_id}.png"
|
|
image_path = os.path.join(self.image_dir, image_filename)
|
|
|
|
if self.config.get("save_images", True):
|
|
try:
|
|
# Sauvegarder l'image
|
|
img = Image.open(io.BytesIO(image_data))
|
|
img.save(image_path)
|
|
print(f"Image sauvegardée: {image_path}")
|
|
except Exception as e:
|
|
print(f"Erreur lors de la sauvegarde de l'image: {str(e)}")
|
|
|
|
# Construction du prompt en anglais pour le modèle
|
|
system_prompt = "Analyze the following image"
|
|
|
|
# Mapper les types de sélection en français vers l'anglais
|
|
content_type_mapping = {
|
|
"schéma": "diagram",
|
|
"tableau": "table",
|
|
"formule": "formula",
|
|
"graphique": "chart",
|
|
"autre": "content"
|
|
}
|
|
|
|
# Obtenir le type en anglais ou utiliser le type original
|
|
content_type_en = content_type_mapping.get(selection_type.lower(), selection_type)
|
|
system_prompt += f" which contains {content_type_en}"
|
|
|
|
# Ajout d'instructions spécifiques selon le type de contenu
|
|
if content_type_en == "diagram":
|
|
system_prompt += ". Please describe in detail what this diagram shows, including all components, connections, and what it represents."
|
|
elif content_type_en == "table":
|
|
system_prompt += ". Please extract and format the table content, describing its structure, headers, and data. If possible, recreate the table structure."
|
|
elif content_type_en == "formula" or content_type_en == "equation":
|
|
system_prompt += ". Please transcribe this mathematical formula/equation and explain what it represents and its components."
|
|
elif content_type_en == "chart" or content_type_en == "graph":
|
|
system_prompt += ". Please describe this chart/graph in detail, including the axes, data points, trends, and what information it conveys."
|
|
else:
|
|
system_prompt += ". Please provide a detailed description of what you see in this image."
|
|
|
|
# Ajouter des instructions générales
|
|
system_prompt += "\n\nPlease be detailed and precise in your analysis."
|
|
|
|
# Préparer le prompt avec le contexte
|
|
user_prompt = ""
|
|
if context and context.strip():
|
|
# Le contexte est déjà en français, pas besoin de le traduire
|
|
# mais préciser explicitement que c'est en français pour le modèle
|
|
user_prompt = f"Here is additional context that may help with your analysis (may be in French):\n{context}"
|
|
|
|
# Créer l'API Ollama pour l'appel direct
|
|
api = OllamaAPI(base_url=self.endpoint)
|
|
|
|
# Journaliser le prompt complet
|
|
full_prompt = f"System: {system_prompt}\n\nUser: {user_prompt}"
|
|
print(f"Envoi du prompt au modèle {self.model_name}:\n{full_prompt}")
|
|
|
|
try:
|
|
# Pour les modèles qui supportent le format de chat
|
|
if "llama" in self.model_name.lower() or "llava" in self.model_name.lower():
|
|
# Formater en tant que messages de chat
|
|
messages = [
|
|
{"role": "system", "content": system_prompt}
|
|
]
|
|
|
|
if user_prompt:
|
|
messages.append({"role": "user", "content": user_prompt})
|
|
|
|
response = api.chat(
|
|
model=self.model_name,
|
|
messages=messages,
|
|
images=[image_data],
|
|
options={
|
|
"temperature": self.config.get("temperature", 0.2),
|
|
"top_p": self.config.get("top_p", 0.95),
|
|
"top_k": self.config.get("top_k", 40),
|
|
"num_predict": self.config.get("max_tokens", 1024)
|
|
}
|
|
)
|
|
|
|
if "message" in response and "content" in response["message"]:
|
|
result = response["message"]["content"]
|
|
else:
|
|
result = response.get("response", "Erreur: Format de réponse inattendu")
|
|
else:
|
|
# Format de génération standard pour les autres modèles
|
|
prompt_text = system_prompt
|
|
if user_prompt:
|
|
prompt_text += f"\n\n{user_prompt}"
|
|
|
|
response = api.generate(
|
|
model=self.model_name,
|
|
prompt=prompt_text,
|
|
images=[image_data],
|
|
options={
|
|
"temperature": self.config.get("temperature", 0.2),
|
|
"top_p": self.config.get("top_p", 0.95),
|
|
"top_k": self.config.get("top_k", 40),
|
|
"num_predict": self.config.get("max_tokens", 1024)
|
|
}
|
|
)
|
|
|
|
result = response.get("response", "Erreur: Pas de réponse")
|
|
|
|
# Enregistrer la réponse dans un fichier si l'option d'enregistrement est activée
|
|
if self.config.get("save_images", True):
|
|
response_path = os.path.join(self.image_dir, f"{timestamp}_{image_id}_response.txt")
|
|
with open(response_path, "w", encoding="utf-8") as f:
|
|
f.write(f"Prompt:\n{full_prompt}\n\nResponse:\n{result}")
|
|
|
|
print(f"Réponse enregistrée dans: {response_path}")
|
|
return result
|
|
|
|
except Exception as e:
|
|
error_msg = f"Erreur lors de l'analyse de l'image: {str(e)}"
|
|
print(error_msg)
|
|
|
|
# Enregistrer l'erreur si l'option d'enregistrement est activée
|
|
if self.config.get("save_images", True):
|
|
error_path = os.path.join(self.image_dir, f"{timestamp}_{image_id}_error.txt")
|
|
with open(error_path, "w", encoding="utf-8") as f:
|
|
f.write(f"Prompt:\n{full_prompt}\n\nError:\n{str(e)}")
|
|
|
|
return error_msg |