2025-04-02 09:01:55 +02:00

150 lines
5.3 KiB
Python

"""Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT."""
import weakref
from typing import Any, Callable, Dict, Optional, Union, cast
import httpx
from mistralai_azure import models, utils
from mistralai_azure._hooks import SDKHooks
from mistralai_azure.chat import Chat
from mistralai_azure.types import UNSET, OptionalNullable
from .basesdk import BaseSDK
from .httpclient import AsyncHttpClient, ClientOwner, HttpClient, close_clients
from .sdkconfiguration import SDKConfiguration
from .utils.logger import Logger, get_default_logger
from .utils.retries import RetryConfig
class MistralAzure(BaseSDK):
    r"""Mistral AI API: Our Chat Completion and Embeddings APIs specification. Create your account on [La Plateforme](https://console.mistral.ai) to get access and read the [docs](https://docs.mistral.ai) to learn how to use it."""

    chat: Chat
    r"""Chat Completion API."""

    def __init__(
        self,
        azure_api_key: Union[str, Callable[[], str]],
        azure_endpoint: str,
        url_params: Optional[Dict[str, str]] = None,
        client: Optional[HttpClient] = None,
        async_client: Optional[AsyncHttpClient] = None,
        retry_config: OptionalNullable[RetryConfig] = UNSET,
        timeout_ms: Optional[int] = None,
        debug_logger: Optional[Logger] = None,
    ) -> None:
        r"""Instantiates the SDK configuring it with the provided parameters.

        :param azure_api_key: The azure_api_key required for authentication, or a
            zero-argument callable returning it. A callable is wrapped in a lambda,
            so it is only invoked when the security object is requested, not here.
        :param azure_endpoint: The Azure AI endpoint URL to use for all methods.
            A trailing ``/`` and a final ``v1/`` path segment are appended when
            not already present.
        :param url_params: Parameters to optionally template the server URL with
        :param client: The HTTP client to use for all synchronous methods. When
            omitted, an ``httpx.Client`` is created and marked as SDK-owned.
        :param async_client: The Async HTTP client to use for all asynchronous
            methods. When omitted, an ``httpx.AsyncClient`` is created and marked
            as SDK-owned.
        :param retry_config: The retry configuration to use for all supported methods
        :param timeout_ms: Optional request timeout applied to each operation in milliseconds
        :param debug_logger: Logger used for debug output; defaults to
            ``get_default_logger()`` when omitted.
        :raises AssertionError: if ``client`` / ``async_client`` do not implement
            the ``HttpClient`` / ``AsyncHttpClient`` protocols.
        """
        server_url = self._normalize_azure_endpoint(azure_endpoint)

        # Track whether each client came from the caller: only SDK-created
        # clients are closed by __exit__/__aexit__ and the finalizer.
        client_supplied = True
        if client is None:
            client = httpx.Client()
            client_supplied = False

        assert issubclass(
            type(client), HttpClient
        ), "The provided client must implement the HttpClient protocol."

        async_client_supplied = True
        if async_client is None:
            async_client = httpx.AsyncClient()
            async_client_supplied = False

        if debug_logger is None:
            debug_logger = get_default_logger()

        assert issubclass(
            type(async_client), AsyncHttpClient
        ), "The provided async_client must implement the AsyncHttpClient protocol."

        security: Any = None
        if callable(azure_api_key):
            # Defer calling the key provider so a rotating key is re-read lazily.
            security = lambda: models.Security(api_key=azure_api_key())  # pylint: disable=unnecessary-lambda-assignment
        else:
            security = models.Security(api_key=azure_api_key)

        # server_url is always a str here (it is derived from azure_endpoint),
        # so only the optional templating needs a guard.
        if url_params is not None:
            server_url = utils.template_url(server_url, url_params)

        BaseSDK.__init__(
            self,
            SDKConfiguration(
                client=client,
                client_supplied=client_supplied,
                async_client=async_client,
                async_client_supplied=async_client_supplied,
                security=security,
                server_url=server_url,
                server=None,
                retry_config=retry_config,
                timeout_ms=timeout_ms,
                debug_logger=debug_logger,
            ),
        )

        # SDK hooks may rewrite the server URL and/or swap the sync client.
        hooks = SDKHooks()

        current_server_url, *_ = self.sdk_configuration.get_server_details()
        server_url, self.sdk_configuration.client = hooks.sdk_init(
            current_server_url, client
        )
        if current_server_url != server_url:
            self.sdk_configuration.server_url = server_url

        # pylint: disable=protected-access
        self.sdk_configuration.__dict__["_hooks"] = hooks

        # Ensure SDK-owned clients are released when this instance is garbage
        # collected, even if it was never used as a context manager.
        weakref.finalize(
            self,
            close_clients,
            cast(ClientOwner, self.sdk_configuration),
            self.sdk_configuration.client,
            self.sdk_configuration.client_supplied,
            self.sdk_configuration.async_client,
            self.sdk_configuration.async_client_supplied,
        )

        self._init_sdks()

    @staticmethod
    def _normalize_azure_endpoint(azure_endpoint: str) -> str:
        """Return ``azure_endpoint`` ending in a slash and a final ``v1/`` path segment."""
        if not azure_endpoint.endswith("/"):
            azure_endpoint += "/"
        # Require a whole ``/v1/`` segment: a bare ``endswith("v1/")`` check would
        # wrongly accept endpoints like ``https://proxy/mistral-v1/`` and never
        # append the API version.
        if not azure_endpoint.endswith("/v1/"):
            azure_endpoint += "v1/"
        return azure_endpoint

    def _init_sdks(self):
        """Create the sub-SDK objects that share this instance's configuration."""
        self.chat = Chat(self.sdk_configuration)

    def __enter__(self):
        """Enter a synchronous ``with`` block; returns this SDK instance."""
        return self

    async def __aenter__(self):
        """Enter an ``async with`` block; returns this SDK instance."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the synchronous client if the SDK created it; caller-supplied clients are left open."""
        if (
            self.sdk_configuration.client is not None
            and not self.sdk_configuration.client_supplied
        ):
            self.sdk_configuration.client.close()
        self.sdk_configuration.client = None

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the asynchronous client if the SDK created it; caller-supplied clients are left open."""
        if (
            self.sdk_configuration.async_client is not None
            and not self.sdk_configuration.async_client_supplied
        ):
            await self.sdk_configuration.async_client.aclose()
        self.sdk_configuration.async_client = None