#!/usr/bin/env python3
"""
API Server to integrate LLM Lab with Cursor and Obsidian
"""
from flask import Flask, request, jsonify, Response
from flask_cors import CORS
import json
import os
import logging
import time
import sys
import subprocess
import psutil
import requests
import argparse

# Add current directory to Python search path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# Import LLM Lab modules
from utils.agent_manager import AgentManager
from utils.ollama_manager import ollama_manager

# Logging configuration
os.makedirs("logs", exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("logs/api_server.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger("api_server")

# Parse command line arguments
parser = argparse.ArgumentParser(description="LLM Lab API Server")
parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
args = parser.parse_args()

# Flask app initialization
app = Flask(__name__)
CORS(app)  # Allow cross-origin requests

# Custom model override based on environment variables
cursor_model = os.environ.get("CURSOR_MODEL")
obsidian_model = os.environ.get("OBSIDIAN_MODEL")

# Log which models are being used for this instance
if cursor_model:
    logger.info(f"Using custom model for Cursor: {cursor_model}")
if obsidian_model:
    logger.info(f"Using custom model for Obsidian: {obsidian_model}")

# Initialize all required agents
# Extract default models from environment or use defaults
default_cursor_model = cursor_model or "codellama:13b-python"
default_obsidian_model = obsidian_model or "llama2:13b"

# Startup preparation - preload the appropriate models
logger.info("Initializing the unified API server...")

# Preload the models if Ollama is available
if ollama_manager.is_ollama_available():
    # Determine which models to preload
    models_to_preload = []

    # Always include the models specified in the environment variables
    if cursor_model:
        models_to_preload.append(cursor_model)
        logger.info(f"Cursor model (from env variable): {cursor_model}")
    else:
        models_to_preload.append("codellama:13b-python")
        logger.info("Cursor model (default): codellama:13b-python")

    if obsidian_model:
        models_to_preload.append(obsidian_model)
        logger.info(f"Obsidian model (from env variable): {obsidian_model}")
    else:
        models_to_preload.append("llama2:13b")
        logger.info("Obsidian model (default): llama2:13b")

    # Preload the models
    logger.info(f"Preloading models: {', '.join(models_to_preload)}")
    ollama_manager.preload_models(models_to_preload)

    # Wait a few seconds to give the first model time to start loading
    logger.info("Waiting 10 seconds for model initialization...")
    time.sleep(10)
else:
    logger.warning("Ollama is not available. Model preloading is skipped.")

# Detect the request type and choose the appropriate model
def detect_request_type(prompt, endpoint_type=None):
    """
    Determine the request type (code or text) and the appropriate model

    Args:
        prompt: The request text
        endpoint_type: The endpoint type being called ('cursor', 'obsidian', or None for auto-detection)

    Returns:
        tuple: (request_type, recommended_model)
    """
    # If the endpoint is explicitly defined, use the corresponding model
    if endpoint_type == "cursor":
        return "code", cursor_model or "codellama:13b-python"
    elif endpoint_type == "obsidian":
        return "text", obsidian_model or "llama2:13b"

    # Indicators of code
    code_indicators = [
        "```", "function", "class", "def ", "import ", "sudo ", "npm ", "pip ",
        "python", "javascript", "typescript", "html", "css", "ruby", "php", "java",
        "json", "xml", "yaml", "bash", "shell", "powershell", "sql",
        "for(", "if(", "while(", "switch(", "{", "}", "==", "=>", "!=", "||", "&&"
    ]

    # Indicators of text
    text_indicators = [
        "résumé", "résume", "explique", "explique-moi", "summarize", "explain",
        "rédige", "écris", "write", "create a", "crée", "génère", "generate",
        "markdown", "obsidian", "note", "article", "blog", "histoire", "story",
        "essai", "dissertation", "rapport", "report", "livre", "book"
    ]

    # Count occurrences
    code_score = sum(1 for indicator in code_indicators if indicator.lower() in prompt.lower())
    text_score = sum(1 for indicator in text_indicators if indicator.lower() in prompt.lower())

    # Normalize the scores by the number of indicators
    code_score = code_score / len(code_indicators)
    text_score = text_score / len(text_indicators)

    # Decide based on the scores
    if code_score > text_score:
        return "code", cursor_model or "codellama:13b-python"
    else:
        return "text", obsidian_model or "llama2:13b"

# Switch the model according to the request type
def ensure_appropriate_model(prompt, endpoint_type=None):
    """
    Ensure that the appropriate model is loaded for the request

    Args:
        prompt: The request text
        endpoint_type: The endpoint type being called ('cursor', 'obsidian', or None)

    Returns:
        str: The model that will be used
    """
    request_type, recommended_model = detect_request_type(prompt, endpoint_type)

    # Check whether a model switch is needed
    if ollama_manager.is_model_switch_needed(recommended_model):
        logger.info(f"Detected a '{request_type}' request, switching to {recommended_model}")
        ollama_manager.switch_model(recommended_model, max_wait=120)
    else:
        current_model = ollama_manager.get_running_model() or "unknown"
        logger.info(f"'{request_type}' request, using the current model: {current_model}")

    return recommended_model
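
# Note (added for clarity, not in the original source): the two generation endpoints
# below (/v1/chat/completions and /generate) call ensure_appropriate_model() before
# creating an agent; e.g. ensure_appropriate_model(prompt, endpoint_type="cursor")
# pins the code model regardless of the prompt content.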

@app.route('/v1/chat/completions', methods=['POST'])
def chat_completion():
    """
    OpenAI-compatible Chat API endpoint for Cursor
    """
    try:
        # Check for valid JSON request
        if not request.is_json:
            return jsonify({"error": "Request must contain valid JSON"}), 400

        data = request.json or {}  # Use empty dict as default if None
        logger.info(f"Request received: {json.dumps(data)}")

        # Extract messages and parameters
        messages = data.get('messages', [])
        model = data.get('model', 'codellama:13b-python')
        temperature = data.get('temperature', 0.7)

        # Build prompt from messages
        system_message = next((msg['content'] for msg in messages if msg['role'] == 'system'), None)
        user_messages = [msg['content'] for msg in messages if msg['role'] == 'user']

        # Use last user message as prompt
        prompt = user_messages[-1] if user_messages else ""

        # Detect request type and ensure appropriate model is loaded
        # This is the Cursor endpoint, so we force 'cursor' as endpoint type
        ensure_appropriate_model(prompt, endpoint_type="cursor")

        # Detect task type to choose appropriate agent
        agent_name = "cursor"  # Default

        # Agent selection logic based on content
        if "obsidian" in prompt.lower() or "markdown" in prompt.lower() or "note" in prompt.lower():
            agent_name = "obsidian"
        elif "javascript" in prompt.lower() or "js" in prompt.lower() or "html" in prompt.lower() or "css" in prompt.lower():
            agent_name = "webdev"
        elif "python" in prompt.lower():
            agent_name = "python"

        logger.info(f"Selected agent: {agent_name}")

        # Create and configure agent
        agent = AgentManager.create(agent_name)

        # Apply model override from environment if available
        # This allows specific instances to use specific models
        if agent_name == "cursor" and cursor_model:
            from core.factory import LLMFactory
            from agents.roles import AGENTS  # Imported here to avoid import errors
            logger.info(f"Overriding model for cursor agent: {cursor_model}")

            agent = LLMFactory.create(cursor_model)
            agent.set_role(agent_name, AGENTS[agent_name])
        elif agent_name == "obsidian" and obsidian_model:
            from core.factory import LLMFactory
            from agents.roles import AGENTS  # Imported here to avoid import errors
            logger.info(f"Overriding model for obsidian agent: {obsidian_model}")

            agent = LLMFactory.create(obsidian_model)
            agent.set_role(agent_name, AGENTS[agent_name])

        # Replace system prompt if provided
        if system_message:
            agent.system_prompt = system_message

        # Adjust parameters
        agent.params["temperature"] = temperature

        # Generate response
        start_time = time.time()
        response = agent.generate(prompt)
        end_time = time.time()

        generation_time = end_time - start_time
        logger.info(f"Response generated for agent {agent_name} in {generation_time:.2f} seconds")

        # OpenAI API compatible formatting
        return jsonify({
            "id": f"llmlab-{agent_name}-{hash(prompt) % 10000}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": agent.model,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": response
                    },
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": len(prompt.split()),
                "completion_tokens": len(response.split()),
                "total_tokens": len(prompt.split()) + len(response.split())
            }
        })

    except Exception as e:
        logger.error(f"Error: {str(e)}", exc_info=True)
        return jsonify({
            "error": {
                "message": str(e),
                "type": "server_error",
                "code": 500
            }
        }), 500
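
# Example request (hypothetical values; assumes the server listens on the default
# port 8000):
#   curl -s http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Write a Python CSV parser"}],
#          "temperature": 0.2}'
# The response mirrors the OpenAI chat-completion schema built above
# (id, object, created, model, choices, usage).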

@app.route('/v1/models', methods=['GET'])
def list_models():
    """
    List available models (OpenAI compatible)
    """
    agents = AgentManager.list_agents()
    models = []

    for agent_name, info in agents.items():
        # Apply model overrides from environment variables
        model_name = info['model']
        if agent_name == "cursor" and cursor_model:
            model_name = cursor_model
        elif agent_name == "obsidian" and obsidian_model:
            model_name = obsidian_model

        models.append({
            "id": model_name,
            "object": "model",
            "created": int(time.time()),
            "owned_by": "llmlab",
            "permission": [{"id": agent_name, "object": "model_permission"}],
            "root": model_name,
            "parent": None
        })

    return jsonify({
        "object": "list",
        "data": models
    })

@app.route('/health', methods=['GET'])
def health_check():
    """
    Server health check endpoint
    """
    # Get current Ollama state
    current_model = "none"
    ollama_status = "unavailable"

    if ollama_manager.is_ollama_available():
        ollama_status = "online"
        current_model = ollama_manager.get_running_model() or "unknown"

    return jsonify({
        "status": "healthy",
        "version": "1.0.0",
        "timestamp": int(time.time()),
        "port": args.port,
        "cursor_model": cursor_model,
        "obsidian_model": obsidian_model,
        "ollama_status": ollama_status,
        "current_model": current_model
    })
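
# Quick checks (assuming the default port 8000): `curl http://localhost:8000/health`
# reports the Ollama status and the configured models; `curl http://localhost:8000/running`
# (defined below) additionally lists the currently loaded model and the available models.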

@app.route('/agents', methods=['GET'])
def list_agents():
    """
    List available agents (custom endpoint)
    """
    agents = AgentManager.list_agents()
    return jsonify({
        "agents": [
            {
                "name": name,
                "model": cursor_model if name == "cursor" and cursor_model else
                         obsidian_model if name == "obsidian" and obsidian_model else
                         info['model'],
                "description": info['description']
            }
            for name, info in agents.items()
        ]
    })

@app.route('/running', methods=['GET'])
def running_models():
    """
    Endpoint to check currently running models
    """
    try:
        # Try to get the list of available models via the Ollama API
        ollama_available = ollama_manager.is_ollama_available()
        available_models = ollama_manager.available_models
        running_model = ollama_manager.get_running_model()

        # Compatibility with the previous implementation
        running_models = []
        if running_model:
            running_models.append({
                "name": running_model,
                "status": "active",
                "memory": "unknown"
            })

        return jsonify({
            "ollama_available": ollama_available,
            "available_models": available_models,
            "running_models": running_models,
            "current_model": running_model,
            "cursor_model": cursor_model,
            "obsidian_model": obsidian_model,
            "timestamp": int(time.time())
        })

    except Exception as e:
        logger.error(f"Error: {str(e)}", exc_info=True)
        return jsonify({
            "error": str(e)
        }), 500

@app.route('/generate', methods=['POST'])
def generate():
    """
    Simplified endpoint for custom applications
    """
    try:
        # Check for valid JSON request
        if not request.is_json:
            return jsonify({"error": "Request must contain valid JSON"}), 400

        data = request.json or {}  # Use empty dict as default if None
        prompt = data.get('prompt', '')
        agent_name = data.get('agent', 'auto')  # Auto-detection by default

        # Optional parameters
        system_prompt = data.get('system_prompt', None)
        temperature = data.get('temperature', None)

        # Application detection - based on the port in the request URL
        endpoint_type = None
        if request.host.endswith(':8001'):
            endpoint_type = "cursor"
        elif request.host.endswith(':5001'):
            endpoint_type = "obsidian"

        # If the agent is specified explicitly
        if agent_name == "cursor":
            endpoint_type = "cursor"
        elif agent_name == "obsidian":
            endpoint_type = "obsidian"
        elif agent_name == "auto":
            # Auto-detection based on the content
            endpoint_type = None

        # Detect the request type and make sure the right model is loaded
        logger.info(f"Analyzing the request... Agent: {agent_name}, Endpoint: {endpoint_type}")
        ensure_appropriate_model(prompt, endpoint_type)

        # Determine the optimal agent if 'auto' is specified
        if agent_name == "auto":
            request_type, _ = detect_request_type(prompt)
            if request_type == "code":
                agent_name = "cursor"
            else:
                agent_name = "obsidian"
            logger.info(f"Agent auto-selected based on content: {agent_name}")

        # Create agent
        agent = AgentManager.create(agent_name)

        # Apply model override from environment if available
        if agent_name == "cursor" and cursor_model:
            from core.factory import LLMFactory
            from agents.roles import AGENTS
            logger.info(f"Overriding model for cursor agent: {cursor_model}")
            agent = LLMFactory.create(cursor_model)
            agent.set_role(agent_name, AGENTS[agent_name])
        elif agent_name == "obsidian" and obsidian_model:
            from core.factory import LLMFactory
            from agents.roles import AGENTS
            logger.info(f"Overriding model for obsidian agent: {obsidian_model}")
            agent = LLMFactory.create(obsidian_model)
            agent.set_role(agent_name, AGENTS[agent_name])

        # Apply custom parameters if provided
        if system_prompt:
            agent.system_prompt = system_prompt

        if temperature is not None:
            agent.params["temperature"] = temperature

        # Generate response
        start_time = time.time()
        response = agent.generate(prompt)
        generation_time = time.time() - start_time

        return jsonify({
            "response": response,
            "agent": agent_name,
            "model": agent.model,
            "generation_time": generation_time
        })

    except Exception as e:
        logger.error(f"Error: {str(e)}", exc_info=True)
        return jsonify({
            "error": str(e)
        }), 500
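
# Example request (hypothetical values; adjust host/port to your instance):
#   curl -s http://localhost:8000/generate \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "Summarize my meeting notes", "agent": "auto", "temperature": 0.7}'
# Returns a JSON object with "response", "agent", "model" and "generation_time".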

@app.route('/switch-model', methods=['POST'])
def switch_model():
    """
    Endpoint to manually switch Ollama to a specific model
    """
    try:
        if not request.is_json:
            return jsonify({"error": "Request must contain valid JSON"}), 400

        data = request.json or {}  # Use an empty dict if json is None
        model_name = data.get('model')

        if not model_name:
            return jsonify({"error": "Model name is required"}), 400

        success = ollama_manager.switch_model(model_name)

        if success:
            return jsonify({
                "status": "switching",
                "model": model_name,
                "message": f"Switching to model {model_name} in background"
            })
        else:
            return jsonify({
                "status": "error",
                "message": f"Failed to switch to model {model_name}"
            }), 400

    except Exception as e:
        logger.error(f"Error switching model: {str(e)}", exc_info=True)
        return jsonify({
            "error": str(e)
        }), 500
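
# Example request (hypothetical model name; any model known to Ollama should work):
#   curl -s http://localhost:8000/switch-model \
#     -H "Content-Type: application/json" \
#     -d '{"model": "llama2:13b"}'
# The switch runs in the background; /running can be polled to see when it completes.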

if __name__ == '__main__':
    port = args.port

    # Log which models are being used
    model_info = ""
    if cursor_model:
        model_info += f"\n - Cursor override model: {cursor_model}"
    if obsidian_model:
        model_info += f"\n - Obsidian override model: {obsidian_model}"

    print(f"=== LLM Lab API Server for Cursor and Obsidian ===")
    print(f"Server started on http://localhost:{port}")
    if model_info:
        print(f"\nUsing custom models:{model_info}")
    print()

    # Show Ollama status
    if ollama_manager.is_ollama_available():
        print("Ollama status: Online")
        current_model = ollama_manager.get_running_model()
        if current_model:
            print(f"Currently loaded model: {current_model}")

        # Print list of available models
        if ollama_manager.available_models:
            print("\nAvailable models:")
            for model in ollama_manager.available_models:
                print(f" - {model}")
    else:
        print("Ollama status: Offline")

    print("\nAvailable endpoints:")
    print(f" - http://localhost:{port}/v1/chat/completions (OpenAI compatible)")
    print(f" - http://localhost:{port}/v1/models (OpenAI compatible)")
    print(f" - http://localhost:{port}/generate (Simplified API)")
    print(f" - http://localhost:{port}/agents (agent list)")
    print(f" - http://localhost:{port}/running (running models)")
    print(f" - http://localhost:{port}/switch-model (manual model control)")
    print(f" - http://localhost:{port}/health (status)")
    print()

    # Show specific usage based on port for clearer user guidance
    if port == 8001:
        print("For Cursor:")
        print(" 1. Open Cursor")
        print(" 2. Go to Settings > AI")
        print(" 3. Select 'Custom endpoint'")
        print(f" 4. Enter URL: http://localhost:{port}/v1")
    elif port == 5001:
        print("For Obsidian Text Generator plugin:")
        print(" 1. In Obsidian, install the 'Text Generator' plugin")
        print(" 2. Go to Text Generator settings")
        print(" 3. Select 'Custom' endpoint")
        print(f" 4. Enter URL: http://localhost:{port}/generate")
        print(" 5. Set request method to POST")
        print(" 6. Set completion endpoint to /generate")
    else:
        print("For Cursor:")
        print(" 1. Open Cursor")
        print(" 2. Go to Settings > AI")
        print(" 3. Select 'Custom endpoint'")
        print(f" 4. Enter URL: http://localhost:{port}/v1")
        print()
        print("For Obsidian Text Generator plugin:")
        print(" 1. In Obsidian, install the 'Text Generator' plugin")
        print(" 2. Go to Text Generator settings")
        print(" 3. Select 'Custom' endpoint")
        print(f" 4. Enter URL: http://localhost:{port}/generate")
        print(" 5. Set request method to POST")
        print(" 6. Set completion endpoint to /generate")

    print()
    print("Available agents:")
    try:
        for agent_name, info in AgentManager.list_agents().items():
            # Show the customized model for agents with an override
            model_display = cursor_model if agent_name == "cursor" and cursor_model else \
                            obsidian_model if agent_name == "obsidian" and obsidian_model else \
                            info['model']
            print(f" - {agent_name}: {info['description']} ({model_display})")
    except Exception as e:
        print(f"Error listing agents: {str(e)}")
        print("Make sure LLM Lab modules are correctly installed.")

    print()
    print("Logs: logs/api_server.log")
    print("Press Ctrl+C to stop the server")

    try:
        # Import agents here to avoid circular imports
        from agents.roles import AGENTS
    except Exception as e:
        logger.error(f"Error importing AGENTS: {str(e)}")

    # Start server
    app.run(host='0.0.0.0', port=port, debug=False)