2025-04-02 09:01:55 +02:00

232 lines
8.2 KiB
Python

"""Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT."""
import json
import weakref
from typing import Any, Optional, cast
import google.auth
import google.auth.credentials
import google.auth.transport
import google.auth.transport.requests
import httpx
from mistralai_gcp import models
from mistralai_gcp._hooks import BeforeRequestHook, SDKHooks
from mistralai_gcp.chat import Chat
from mistralai_gcp.fim import Fim
from mistralai_gcp.types import UNSET, OptionalNullable
from .basesdk import BaseSDK
from .httpclient import AsyncHttpClient, ClientOwner, HttpClient, close_clients
from .sdkconfiguration import SDKConfiguration
from .utils.logger import Logger, get_default_logger
from .utils.retries import RetryConfig
# Maps dash-separated model names (``<name>-<version>``) to the
# ``<name>@<version>`` identifier returned for them by ``get_model_info``.
# Models not listed here are passed through unchanged by that function.
# NOTE(review): presumably these are the only models whose endpoint still
# expects the "@" form -- confirm before adding or removing entries.
LEGACY_MODEL_ID_FORMAT = {
    "codestral-2405": "codestral@2405",
    "mistral-large-2407": "mistral-large@2407",
    "mistral-nemo-2407": "mistral-nemo@2407",
}
def get_model_info(model: str) -> tuple[str, str]:
    """Resolve *model* into a ``(model_name, model_id)`` pair.

    If the model requires the legacy format (i.e. it is a key of
    ``LEGACY_MODEL_ID_FORMAT``), the name is the input with its last
    dash-separated segment removed and the id is the mapped legacy
    identifier. Otherwise the input is returned unchanged in both positions.
    """
    try:
        legacy_id = LEGACY_MODEL_ID_FORMAT[model]
    except KeyError:
        # Not a legacy model: nothing to translate.
        return model, model

    # Everything before the final "-" is the bare model name.
    base_name, _, _ = model.rpartition("-")
    return base_name, legacy_id
class MistralGoogleCloud(BaseSDK):
    r"""Mistral AI API: Our Chat Completion and Embeddings APIs specification. Create your account on [La Plateforme](https://console.mistral.ai) to get access and read the [docs](https://docs.mistral.ai) to learn how to use it."""

    chat: Chat
    r"""Chat Completion API."""
    fim: Fim
    r"""Fill-in-the-middle API."""

    def __init__(
        self,
        region: str = "europe-west4",
        project_id: Optional[str] = None,
        access_token: Optional[str] = None,
        client: Optional[HttpClient] = None,
        async_client: Optional[AsyncHttpClient] = None,
        retry_config: OptionalNullable[RetryConfig] = UNSET,
        timeout_ms: Optional[int] = None,
        debug_logger: Optional[Logger] = None,
    ) -> None:
        r"""Instantiates the SDK configuring it with the provided parameters.

        :param region: The Google Cloud region; used to build the server URL
            ``https://{region}-aiplatform.googleapis.com`` and the request path
        :param project_id: The Google Cloud project id. When omitted and no
            ``access_token`` is given, the project returned by
            ``google.auth.default()`` is used
        :param access_token: A static access token used for authentication.
            When omitted, credentials are loaded with ``google.auth.default()``
            and refreshed each time a token is requested
        :param client: The HTTP client to use for all synchronous methods
        :param async_client: The Async HTTP client to use for all asynchronous methods
        :param retry_config: The retry configuration to use for all supported methods
        :param timeout_ms: Optional request timeout applied to each operation in milliseconds
        :param debug_logger: The logger to use; defaults to ``get_default_logger()``
        :raises models.SDKError: If the loaded credentials have an unexpected
            type or no project id could be determined
        """
        # Bind both names up front so the access_token path cannot hit an
        # UnboundLocalError when project_id is not supplied.
        credentials: Any = None
        loaded_project_id: Optional[str] = None

        if not access_token:
            credentials, loaded_project_id = google.auth.default(
                scopes=["https://www.googleapis.com/auth/cloud-platform"],
            )
            # Validate before use so a bad object yields an SDKError rather
            # than an AttributeError from refresh().
            if not isinstance(credentials, google.auth.credentials.Credentials):
                raise models.SDKError(
                    "credentials must be an instance of google.auth.credentials.Credentials"
                )
            credentials.refresh(google.auth.transport.requests.Request())

        project_id = project_id or loaded_project_id
        # An empty project id would produce an unusable request path.
        if not project_id:
            raise models.SDKError("project_id must be provided")

        def auth_token() -> str:
            # A caller-supplied token is used verbatim; otherwise the loaded
            # credentials are refreshed and their current token returned.
            if access_token:
                return access_token

            credentials.refresh(google.auth.transport.requests.Request())
            token = credentials.token
            if not token:
                raise models.SDKError("Failed to get token from credentials")
            return token

        client_supplied = True
        if client is None:
            client = httpx.Client()
            client_supplied = False

        assert issubclass(
            type(client), HttpClient
        ), "The provided client must implement the HttpClient protocol."

        async_client_supplied = True
        if async_client is None:
            async_client = httpx.AsyncClient()
            async_client_supplied = False

        assert issubclass(
            type(async_client), AsyncHttpClient
        ), "The provided async_client must implement the AsyncHttpClient protocol."

        if debug_logger is None:
            debug_logger = get_default_logger()

        def security() -> Any:
            # Evaluated lazily so every call picks up a current token.
            return models.Security(api_key=auth_token())

        BaseSDK.__init__(
            self,
            SDKConfiguration(
                client=client,
                client_supplied=client_supplied,
                async_client=async_client,
                async_client_supplied=async_client_supplied,
                security=security,
                server_url=f"https://{region}-aiplatform.googleapis.com",
                server=None,
                retry_config=retry_config,
                timeout_ms=timeout_ms,
                debug_logger=debug_logger,
            ),
        )

        # Register the hook that templates region/project/model into the path.
        hooks = SDKHooks()
        hook = GoogleCloudBeforeRequestHook(region, project_id)
        hooks.register_before_request_hook(hook)

        current_server_url, *_ = self.sdk_configuration.get_server_details()
        server_url, self.sdk_configuration.client = hooks.sdk_init(
            current_server_url, client
        )
        if current_server_url != server_url:
            self.sdk_configuration.server_url = server_url

        # Stored through __dict__ rather than plain attribute assignment.
        # NOTE(review): presumably to bypass attribute validation on the
        # configuration object -- confirm against SDKConfiguration.
        # pylint: disable=protected-access
        self.sdk_configuration.__dict__["_hooks"] = hooks

        # Run close_clients when this SDK instance is garbage collected.
        weakref.finalize(
            self,
            close_clients,
            cast(ClientOwner, self.sdk_configuration),
            self.sdk_configuration.client,
            self.sdk_configuration.client_supplied,
            self.sdk_configuration.async_client,
            self.sdk_configuration.async_client_supplied,
        )

        self._init_sdks()

    def _init_sdks(self):
        """Create the ``chat`` and ``fim`` sub-SDKs from the shared configuration."""
        self.chat = Chat(self.sdk_configuration)
        self.fim = Fim(self.sdk_configuration)

    def __enter__(self):
        """Return the SDK itself for use in a ``with`` statement."""
        return self

    async def __aenter__(self):
        """Return the SDK itself for use in an ``async with`` statement."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the synchronous client, but only if the SDK created it."""
        if (
            self.sdk_configuration.client is not None
            and not self.sdk_configuration.client_supplied
        ):
            self.sdk_configuration.client.close()
        self.sdk_configuration.client = None

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the asynchronous client, but only if the SDK created it."""
        if (
            self.sdk_configuration.async_client is not None
            and not self.sdk_configuration.async_client_supplied
        ):
            await self.sdk_configuration.async_client.aclose()
        self.sdk_configuration.async_client = None
class GoogleCloudBeforeRequestHook(BeforeRequestHook):
    """Before-request hook that rewrites the request path for Google Cloud.

    The region, project and model are templated into the URL path here so
    that the API remains more user-friendly.
    """

    def __init__(self, region: str, project_id: str):
        """Store the region and project id used to build request paths."""
        self.region = region
        self.project_id = project_id

    def before_request(
        self, hook_ctx, request: httpx.Request
    ) -> httpx.Request | Exception:
        """Return a copy of *request* targeting the publisher-model endpoint.

        The JSON body's ``"model"`` value is resolved with ``get_model_info``:
        the returned name replaces ``"model"`` in the body and the returned id
        is placed in the URL path. The path ends in ``streamRawPredict`` when
        the incoming path contains that string, otherwise in ``rawPredict``.

        :param hook_ctx: Hook context supplied by the SDK (unused here)
        :param request: The outgoing request to rewrite
        :raises models.SDKError: If the request has no body or the body does
            not name a model
        """
        model_id = None
        new_content = None
        if request.content:
            parsed = json.loads(request.content.decode("utf-8"))
            model_raw = parsed.get("model")
            model_name, model_id = get_model_info(model_raw)
            parsed["model"] = model_name
            new_content = json.dumps(parsed).encode("utf-8")

        # Covers None (no body / no "model" key) as well as the empty string;
        # otherwise the path would be built as ".../models/None:...".
        if not model_id:
            raise models.SDKError("model must be provided")

        stream = "streamRawPredict" in request.url.path
        specifier = "streamRawPredict" if stream else "rawPredict"
        url = f"/v1/projects/{self.project_id}/locations/{self.region}/publishers/mistralai/models/{model_id}:{specifier}"

        headers = dict(request.headers)
        # Delete content-length header as it will need to be recalculated
        headers.pop("content-length", None)

        next_request = httpx.Request(
            method=request.method,
            url=request.url.copy_with(path=url),
            headers=headers,
            content=new_content,
            stream=None,
        )
        return next_request