# Mirror of https://github.com/Ladebeze66/llm_ticket3.git
# Synced 2025-12-16 20:47:49 +01:00 (232 lines, 8.2 KiB, Python)
"""Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT."""
|
|
|
|
import json
|
|
import weakref
|
|
from typing import Any, Optional, cast
|
|
|
|
import google.auth
|
|
import google.auth.credentials
|
|
import google.auth.transport
|
|
import google.auth.transport.requests
|
|
import httpx
|
|
|
|
from mistralai_gcp import models
|
|
from mistralai_gcp._hooks import BeforeRequestHook, SDKHooks
|
|
from mistralai_gcp.chat import Chat
|
|
from mistralai_gcp.fim import Fim
|
|
from mistralai_gcp.types import UNSET, OptionalNullable
|
|
|
|
from .basesdk import BaseSDK
|
|
from .httpclient import AsyncHttpClient, ClientOwner, HttpClient, close_clients
|
|
from .sdkconfiguration import SDKConfiguration
|
|
from .utils.logger import Logger, get_default_logger
|
|
from .utils.retries import RetryConfig
|
|
|
|
# Models that Vertex AI addresses with a "name@version" id instead of the
# dashed "name-version" form used by the public Mistral naming scheme.
LEGACY_MODEL_ID_FORMAT = {
    "codestral-2405": "codestral@2405",
    "mistral-large-2407": "mistral-large@2407",
    "mistral-nemo-2407": "mistral-nemo@2407",
}


def get_model_info(model: str) -> tuple[str, str]:
    """Resolve *model* into its ``(request_name, url_model_id)`` pair.

    Legacy models are addressed in the URL with Vertex AI's ``name@version``
    id while the request body carries the bare name (trailing ``-<version>``
    stripped). Any other model is returned unchanged for both uses.
    """
    legacy_id = LEGACY_MODEL_ID_FORMAT.get(model)
    if legacy_id is None:
        # Not a legacy model: same identifier everywhere.
        return model, model
    # Drop the trailing "-<version>" segment for the request-body name.
    bare_name, _, _version = model.rpartition("-")
    return bare_name, legacy_id
|
|
|
|
|
|
class MistralGoogleCloud(BaseSDK):
    r"""Mistral AI API: Our Chat Completion and Embeddings APIs specification. Create your account on [La Plateforme](https://console.mistral.ai) to get access and read the [docs](https://docs.mistral.ai) to learn how to use it."""

    # Sub-SDKs; instantiated in _init_sdks() once the shared configuration exists.
    chat: Chat
    r"""Chat Completion API."""
    fim: Fim
    r"""Fill-in-the-middle API."""

    def __init__(
        self,
        region: str = "europe-west4",
        project_id: Optional[str] = None,
        access_token: Optional[str] = None,
        client: Optional[HttpClient] = None,
        async_client: Optional[AsyncHttpClient] = None,
        retry_config: OptionalNullable[RetryConfig] = UNSET,
        timeout_ms: Optional[int] = None,
        debug_logger: Optional[Logger] = None,
    ) -> None:
        r"""Instantiates the SDK configuring it with the provided parameters.

        :param region: Google Cloud region used to build the server URL
            ``https://{region}-aiplatform.googleapis.com``
        :param project_id: Google Cloud project id; when omitted it is taken
            from the application-default credentials
        :param access_token: Pre-fetched OAuth2 access token; when omitted,
            application-default credentials are loaded and refreshed on demand
        :param client: The HTTP client to use for all synchronous methods
        :param async_client: The Async HTTP client to use for all asynchronous methods
        :param retry_config: The retry configuration to use for all supported methods
        :param timeout_ms: Optional request timeout applied to each operation in milliseconds
        :param debug_logger: Optional logger for debug output; defaults to
            ``get_default_logger()``
        :raises models.SDKError: if credentials are invalid or no project id
            can be resolved on the application-default-credentials path
        """

        # No explicit token: fall back to Google application-default
        # credentials (ADC) and resolve the project id from them.
        if not access_token:
            credentials, loaded_project_id = google.auth.default(
                scopes=["https://www.googleapis.com/auth/cloud-platform"],
            )
            # Refresh eagerly once so the credentials carry a usable token.
            credentials.refresh(google.auth.transport.requests.Request())

            if not isinstance(credentials, google.auth.credentials.Credentials):
                raise models.SDKError(
                    "credentials must be an instance of google.auth.credentials.Credentials"
                )

            project_id = project_id or loaded_project_id

            # NOTE(review): project_id is only validated on this ADC branch;
            # when access_token is supplied, a None project_id flows into the
            # URL-templating hook unchecked — confirm callers pass it then.
            if project_id is None:
                raise models.SDKError("project_id must be provided")

        def auth_token() -> str:
            # Per-request token supplier: returns the fixed token when one was
            # given, otherwise re-refreshes the ADC credentials on each call.
            if access_token:
                return access_token

            credentials.refresh(google.auth.transport.requests.Request())
            token = credentials.token
            if not token:
                raise models.SDKError("Failed to get token from credentials")
            return token

        # Track whether each client was caller-supplied so that only
        # internally-created clients are closed on exit/finalization.
        client_supplied = True
        if client is None:
            client = httpx.Client()
            client_supplied = False

        assert issubclass(
            type(client), HttpClient
        ), "The provided client must implement the HttpClient protocol."

        async_client_supplied = True
        if async_client is None:
            async_client = httpx.AsyncClient()
            async_client_supplied = False

        if debug_logger is None:
            debug_logger = get_default_logger()

        assert issubclass(
            type(async_client), AsyncHttpClient
        ), "The provided async_client must implement the AsyncHttpClient protocol."

        # auth_token is a local function, so this check is always true and the
        # else arm is unreachable (generated-code artifact, see file header).
        security: Any = None
        if callable(auth_token):
            security = lambda: models.Security(  # pylint: disable=unnecessary-lambda-assignment
                api_key=auth_token()
            )
        else:
            security = models.Security(api_key=auth_token)

        BaseSDK.__init__(
            self,
            SDKConfiguration(
                client=client,
                client_supplied=client_supplied,
                async_client=async_client,
                async_client_supplied=async_client_supplied,
                security=security,
                server_url=f"https://{region}-aiplatform.googleapis.com",
                server=None,
                retry_config=retry_config,
                timeout_ms=timeout_ms,
                debug_logger=debug_logger,
            ),
        )

        # Register the hook that rewrites every request URL with the
        # region/project/model Vertex AI path before it is sent.
        hooks = SDKHooks()
        hook = GoogleCloudBeforeRequestHook(region, project_id)
        hooks.register_before_request_hook(hook)
        current_server_url, *_ = self.sdk_configuration.get_server_details()
        server_url, self.sdk_configuration.client = hooks.sdk_init(
            current_server_url, client
        )
        # sdk_init may substitute a different server URL; only persist a change.
        if current_server_url != server_url:
            self.sdk_configuration.server_url = server_url

        # pylint: disable=protected-access
        self.sdk_configuration.__dict__["_hooks"] = hooks

        # Close internally-created HTTP clients when this SDK instance is
        # garbage-collected; the *_supplied flags let close_clients leave
        # caller-owned clients open (mirrors __exit__/__aexit__ below).
        weakref.finalize(
            self,
            close_clients,
            cast(ClientOwner, self.sdk_configuration),
            self.sdk_configuration.client,
            self.sdk_configuration.client_supplied,
            self.sdk_configuration.async_client,
            self.sdk_configuration.async_client_supplied,
        )

        self._init_sdks()

    def _init_sdks(self) -> None:
        # Wire the sub-SDKs to the shared configuration built in __init__.
        self.chat = Chat(self.sdk_configuration)
        self.fim = Fim(self.sdk_configuration)

    def __enter__(self):
        return self

    async def __aenter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Close the sync client only if the SDK created it itself.
        if (
            self.sdk_configuration.client is not None
            and not self.sdk_configuration.client_supplied
        ):
            self.sdk_configuration.client.close()
            self.sdk_configuration.client = None

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Close the async client only if the SDK created it itself.
        if (
            self.sdk_configuration.async_client is not None
            and not self.sdk_configuration.async_client_supplied
        ):
            await self.sdk_configuration.async_client.aclose()
            self.sdk_configuration.async_client = None
|
|
|
|
|
|
class GoogleCloudBeforeRequestHook(BeforeRequestHook):
    """Rewrites each outgoing request onto the Vertex AI rawPredict URL scheme.

    The public API lets callers use plain model names; right before sending,
    this hook templates the region, project and (possibly legacy-formatted)
    model id into the URL path so the rest of the SDK stays user-friendly.
    """

    def __init__(self, region: str, project_id: str):
        self.region = region
        self.project_id = project_id

    def before_request(
        self, hook_ctx, request: httpx.Request
    ) -> httpx.Request | Exception:
        """Return a copy of *request* targeting the region/project/model path.

        :param hook_ctx: Hook context passed by the SDK (unused here).
        :param request: The outgoing request whose JSON body may name a model.
        :raises models.SDKError: if the JSON body carries no usable ``model``.
        """
        model_id = None
        new_content = None
        if request.content:
            parsed = json.loads(request.content.decode("utf-8"))
            model_raw = parsed.get("model")
            # Fix: a missing/None model previously slipped past the old
            # ``model_id == ""`` check and produced a ".../models/None:..."
            # URL; fail fast on any falsy model value (covers "" and None).
            if not model_raw:
                raise models.SDKError("model must be provided")
            # Legacy models use a "name@version" id in the URL while the body
            # keeps the bare name; all other models pass through unchanged.
            model_name, model_id = get_model_info(model_raw)
            parsed["model"] = model_name
            new_content = json.dumps(parsed).encode("utf-8")

        # Streaming requests are recognized by the operation path that the
        # generated SDK methods used for the original URL.
        stream = "streamRawPredict" in request.url.path
        specifier = "streamRawPredict" if stream else "rawPredict"
        url = f"/v1/projects/{self.project_id}/locations/{self.region}/publishers/mistralai/models/{model_id}:{specifier}"

        headers = dict(request.headers)
        # Delete content-length header as it will need to be recalculated
        headers.pop("content-length", None)

        next_request = httpx.Request(
            method=request.method,
            url=request.url.copy_with(path=url),
            headers=headers,
            content=new_content,
            stream=None,
        )

        return next_request
|