"""Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT.""" import json import weakref from typing import Any, Optional, cast import google.auth import google.auth.credentials import google.auth.transport import google.auth.transport.requests import httpx from mistralai_gcp import models from mistralai_gcp._hooks import BeforeRequestHook, SDKHooks from mistralai_gcp.chat import Chat from mistralai_gcp.fim import Fim from mistralai_gcp.types import UNSET, OptionalNullable from .basesdk import BaseSDK from .httpclient import AsyncHttpClient, ClientOwner, HttpClient, close_clients from .sdkconfiguration import SDKConfiguration from .utils.logger import Logger, get_default_logger from .utils.retries import RetryConfig LEGACY_MODEL_ID_FORMAT = { "codestral-2405": "codestral@2405", "mistral-large-2407": "mistral-large@2407", "mistral-nemo-2407": "mistral-nemo@2407", } def get_model_info(model: str) -> tuple[str, str]: # if the model requiers the legacy fomat, use it, else do nothing. if model in LEGACY_MODEL_ID_FORMAT: return "-".join(model.split("-")[:-1]), LEGACY_MODEL_ID_FORMAT[model] return model, model class MistralGoogleCloud(BaseSDK): r"""Mistral AI API: Our Chat Completion and Embeddings APIs specification. Create your account on [La Plateforme](https://console.mistral.ai) to get access and read the [docs](https://docs.mistral.ai) to learn how to use it.""" chat: Chat r"""Chat Completion API.""" fim: Fim r"""Fill-in-the-middle API.""" def __init__( self, region: str = "europe-west4", project_id: Optional[str] = None, access_token: Optional[str] = None, client: Optional[HttpClient] = None, async_client: Optional[AsyncHttpClient] = None, retry_config: OptionalNullable[RetryConfig] = UNSET, timeout_ms: Optional[int] = None, debug_logger: Optional[Logger] = None, ) -> None: r"""Instantiates the SDK configuring it with the provided parameters. :param api_key: The api_key required for authentication :param server: The server by name to use for all methods :param server_url: The server URL to use for all methods :param url_params: Parameters to optionally template the server URL with :param client: The HTTP client to use for all synchronous methods :param async_client: The Async HTTP client to use for all asynchronous methods :param retry_config: The retry configuration to use for all supported methods :param timeout_ms: Optional request timeout applied to each operation in milliseconds """ if not access_token: credentials, loaded_project_id = google.auth.default( scopes=["https://www.googleapis.com/auth/cloud-platform"], ) credentials.refresh(google.auth.transport.requests.Request()) if not isinstance(credentials, google.auth.credentials.Credentials): raise models.SDKError( "credentials must be an instance of google.auth.credentials.Credentials" ) project_id = project_id or loaded_project_id if project_id is None: raise models.SDKError("project_id must be provided") def auth_token() -> str: if access_token: return access_token credentials.refresh(google.auth.transport.requests.Request()) token = credentials.token if not token: raise models.SDKError("Failed to get token from credentials") return token client_supplied = True if client is None: client = httpx.Client() client_supplied = False assert issubclass( type(client), HttpClient ), "The provided client must implement the HttpClient protocol." async_client_supplied = True if async_client is None: async_client = httpx.AsyncClient() async_client_supplied = False if debug_logger is None: debug_logger = get_default_logger() assert issubclass( type(async_client), AsyncHttpClient ), "The provided async_client must implement the AsyncHttpClient protocol." security: Any = None if callable(auth_token): security = lambda: models.Security( # pylint: disable=unnecessary-lambda-assignment api_key=auth_token() ) else: security = models.Security(api_key=auth_token) BaseSDK.__init__( self, SDKConfiguration( client=client, client_supplied=client_supplied, async_client=async_client, async_client_supplied=async_client_supplied, security=security, server_url=f"https://{region}-aiplatform.googleapis.com", server=None, retry_config=retry_config, timeout_ms=timeout_ms, debug_logger=debug_logger, ), ) hooks = SDKHooks() hook = GoogleCloudBeforeRequestHook(region, project_id) hooks.register_before_request_hook(hook) current_server_url, *_ = self.sdk_configuration.get_server_details() server_url, self.sdk_configuration.client = hooks.sdk_init( current_server_url, client ) if current_server_url != server_url: self.sdk_configuration.server_url = server_url # pylint: disable=protected-access self.sdk_configuration.__dict__["_hooks"] = hooks weakref.finalize( self, close_clients, cast(ClientOwner, self.sdk_configuration), self.sdk_configuration.client, self.sdk_configuration.client_supplied, self.sdk_configuration.async_client, self.sdk_configuration.async_client_supplied, ) self._init_sdks() def _init_sdks(self): self.chat = Chat(self.sdk_configuration) self.fim = Fim(self.sdk_configuration) def __enter__(self): return self async def __aenter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): if ( self.sdk_configuration.client is not None and not self.sdk_configuration.client_supplied ): self.sdk_configuration.client.close() self.sdk_configuration.client = None async def __aexit__(self, exc_type, exc_val, exc_tb): if ( self.sdk_configuration.async_client is not None and not self.sdk_configuration.async_client_supplied ): await self.sdk_configuration.async_client.aclose() self.sdk_configuration.async_client = None class GoogleCloudBeforeRequestHook(BeforeRequestHook): def __init__(self, region: str, project_id: str): self.region = region self.project_id = project_id def before_request( self, hook_ctx, request: httpx.Request ) -> httpx.Request | Exception: # The goal of this function is to template in the region, project and model into the URL path # We do this here so that the API remains more user-friendly model_id = None new_content = None if request.content: parsed = json.loads(request.content.decode("utf-8")) model_raw = parsed.get("model") model_name, model_id = get_model_info(model_raw) parsed["model"] = model_name new_content = json.dumps(parsed).encode("utf-8") if model_id == "": raise models.SDKError("model must be provided") stream = "streamRawPredict" in request.url.path specifier = "streamRawPredict" if stream else "rawPredict" url = f"/v1/projects/{self.project_id}/locations/{self.region}/publishers/mistralai/models/{model_id}:{specifier}" headers = dict(request.headers) # Delete content-length header as it will need to be recalculated headers.pop("content-length", None) next_request = httpx.Request( method=request.method, url=request.url.copy_with(path=url), headers=headers, content=new_content, stream=None, ) return next_request