mirror of
https://github.com/Ladebeze66/ragflow_preprocess.git
synced 2026-02-04 05:50:26 +01:00
185 lines
6.8 KiB
Python
185 lines
6.8 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
Agent for optical character recognition (OCR) in images
|
|
"""
|
|
|
|
import os
|
|
import time
|
|
import uuid
|
|
import re
|
|
from typing import Dict, Optional, List, Any, Union
|
|
import pytesseract
|
|
from PIL import Image
|
|
import io
|
|
import platform
|
|
|
|
from .base import LLMBaseAgent
|
|
|
|
class OCRAgent(LLMBaseAgent):
    """
    Agent for optical character recognition (OCR).

    Runs Tesseract (through ``pytesseract``) on in-memory images and keeps a
    trace of every run (image, extracted text, errors) under ``data/ocr_logs``.
    """

    def __init__(self, model_name: str = "ocr", endpoint: str = "", **config):
        """
        Initialize the OCR agent.

        Args:
            model_name (str): Model name (default "ocr" as OCR doesn't use LLM models)
            endpoint (str): API endpoint (not used for OCR)
            **config: Additional configuration. Keys read by this class:
                "language" (Tesseract language code, default "fra"),
                "tesseract_config" (extra Tesseract CLI flags,
                default "--psm 1 --oem 3") and "tesseract_path"
                (path to the Tesseract executable).
        """
        # Call the parent constructor with the required parameters
        super().__init__(model_name, endpoint, **config)

        # Default configuration for OCR
        default_config = {
            "language": "fra",  # Default language: French
            "tesseract_config": "--psm 1 --oem 3",  # Default Tesseract config
        }

        # Merge with provided configuration: only fill in the keys the caller
        # did not supply, so that explicit settings are never overwritten.
        for key, value in default_config.items():
            if key not in self.config:
                self.config[key] = value

        # Windows-specific configuration: auto-detect the executable, but only
        # when the caller did not already provide a path.
        if platform.system() == "Windows" and not self.config.get("tesseract_path"):
            home = os.path.expanduser("~")

            # Possible paths for Tesseract on Windows
            possible_paths = [
                r"C:\Program Files\Tesseract-OCR\tesseract.exe",
                r"C:\Program Files (x86)\Tesseract-OCR\tesseract.exe",
                r"C:\Tesseract-OCR\tesseract.exe",
                r"C:\Users\PCDEV\AppData\Local\Programs\Tesseract-OCR\tesseract.exe",
                r"C:\Users\PCDEV\Tesseract-OCR\tesseract.exe",
                # Same per-user locations, resolved for the current user
                os.path.join(home, "AppData", "Local", "Programs", "Tesseract-OCR", "tesseract.exe"),
                os.path.join(home, "Tesseract-OCR", "tesseract.exe"),
            ]

            # Look for Tesseract in possible paths (first existing one wins)
            tesseract_path = next(
                (path for path in possible_paths if os.path.exists(path)),
                None,
            )

            if tesseract_path:
                self.config["tesseract_path"] = tesseract_path
                print(f"Tesseract found at: {tesseract_path}")
            else:
                print("WARNING: Tesseract was not found in standard paths.")
                print("Please install Tesseract OCR from: https://github.com/UB-Mannheim/tesseract/wiki")
                print("Or manually specify the path with the tesseract_path parameter")

        # Point pytesseract at the configured (or auto-detected) executable
        if self.config.get("tesseract_path"):
            pytesseract.pytesseract.tesseract_cmd = self.config["tesseract_path"]

        # Create directory for OCR logs
        self.log_dir = os.path.join("data", "ocr_logs")
        os.makedirs(self.log_dir, exist_ok=True)

    def generate(self, prompt: str = "", images: Optional[List[bytes]] = None) -> str:
        """
        Perform optical character recognition on provided images.

        Each image is OCR'ed independently. The image (as PNG), its extracted
        text and any error are written to ``self.log_dir``, as is the combined
        result of the whole call.

        Args:
            prompt (str, optional): Not used for OCR
            images (List[bytes], optional): List of images to process in bytes

        Returns:
            str: Texts extracted from the images, joined by a blank line. An
            image that fails contributes an "[OCR Error on image N: ...]"
            marker instead of text. A fixed message is returned when no image
            is given or when nothing could be extracted.
        """
        if not images:
            return "Error: No images provided for OCR"

        results = []
        image_count = len(images)

        # Generate unique ID for this OCR session
        ocr_id = str(uuid.uuid4())[:8]
        timestamp = time.strftime("%Y%m%d-%H%M%S")

        # The OCR settings do not change between images: read them once
        lang = self.config.get("language", "fra")
        config = self.config.get("tesseract_config", "--psm 1 --oem 3")

        for index, img_bytes in enumerate(images, start=1):
            file_prefix = os.path.join(self.log_dir, f"{timestamp}_{ocr_id}_img{index}")
            try:
                # Open image from bytes; the context manager releases the
                # decoder resources once the image has been processed.
                with Image.open(io.BytesIO(img_bytes)) as img:
                    # Perform OCR with Tesseract, then basic text cleaning
                    text = pytesseract.image_to_string(img, lang=lang, config=config)
                    text = self._clean_text(text)

                    if text:
                        results.append(text)

                    # Save image
                    img.save(f"{file_prefix}.png", "PNG")

                # Save extracted text
                text_path = f"{file_prefix}_ocr.txt"
                with open(text_path, "w", encoding="utf-8") as f:
                    f.write(f"OCR Language: {lang}\n")
                    f.write(f"Tesseract config: {config}\n\n")
                    f.write(text)

                print(f"OCR performed on image {index}/{image_count}, saved to: {text_path}")

            except Exception as e:
                # Best-effort per image: one bad image must not abort the batch
                error_msg = f"Error processing image {index}: {str(e)}"
                print(error_msg)

                # Log the error
                with open(f"{file_prefix}_error.txt", "w", encoding="utf-8") as f:
                    f.write(f"Error processing image {index}:\n{str(e)}")

                # Add error message to results
                results.append(f"[OCR Error on image {index}: {str(e)}]")

        # Combine all extracted texts
        if not results:
            return "No text could be extracted from the provided images."

        combined_result = "\n\n".join(results)

        # Save combined result
        combined_path = os.path.join(self.log_dir, f"{timestamp}_{ocr_id}_combined.txt")
        with open(combined_path, "w", encoding="utf-8") as f:
            f.write(f"OCR Language: {lang}\n")
            f.write(f"Number of images: {image_count}\n\n")
            f.write(combined_result)

        return combined_result

    def _clean_text(self, text: str) -> str:
        """
        Clean the text extracted by OCR.

        Args:
            text (str): Raw text to clean

        Returns:
            str: Cleaned text: non-printable characters removed (newlines are
            kept), runs of three or more newlines collapsed to a single blank
            line, and surrounding whitespace stripped. Empty input gives "".
        """
        if not text:
            return ""

        # Remove non-printable characters first, so the steps below operate on
        # the final characters (e.g. "\n\r\n\r\n" becomes "\n\n\n" and can then
        # be collapsed, instead of slipping through).
        text = ''.join(c for c in text if c.isprintable() or c == '\n')

        # Remove multiple empty lines
        text = re.sub(r'\n{3,}', '\n\n', text)

        # Remove spaces at beginning and end
        return text.strip()