# Mirror of https://github.com/Ladebeze66/projetcbaollm.git
# (synced 2025-12-16 11:47:53 +01:00; 427 lines, 12 KiB, Python)
"""Pydantic data models and other dataclasses. This is the only file that uses the Optional[]
typing syntax instead of the `X | None` syntax, for compatibility with pydantic."""
|
|
|
|
from __future__ import annotations

# Standard library
import pathlib
import secrets
import shutil
from abc import ABC, abstractmethod
from collections.abc import Iterator
from enum import Enum, auto
from typing import (
    Annotated,
    Any,
    Literal,
    NewType,
    Optional,
    TypedDict,
    Union,
)

# Third-party
from fastapi import Request
from gradio_client.documentation import document
from gradio_client.utils import is_file_obj_with_meta, traverse
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    GetCoreSchemaHandler,
    GetJsonSchemaHandler,
    RootModel,
    ValidationError,
    ValidationInfo,
    model_validator,
)
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import core_schema
from typing_extensions import NotRequired

# Some pydantic releases do not export JsonValue; fall back to Any so that
# annotations using it (e.g. JsonData.root) still resolve.
try:
    from pydantic import JsonValue
except ImportError:
    JsonValue = Any

# Distinct nominal `str` types for paths, so type checkers can keep the two
# sources apart. NOTE(review): presumably "developer" = supplied by the app
# author and "user provided" = supplied by an end user — confirm at call sites.
DeveloperPath = NewType("DeveloperPath", str)
UserProvidedPath = NewType("UserProvidedPath", str)
|
|
|
|
|
|
class CancelBody(BaseModel):
    # Request body identifying an event by session, function index and event id;
    # all three fields are required.
    # NOTE(review): "cancel" purpose is inferred from the class name — confirm
    # against the route that consumes it.
    session_hash: str
    fn_index: int
    event_id: str
|
|
|
|
|
|
class SimplePredictBody(BaseModel):
    # Minimal prediction request: a list of input values plus an optional
    # session identifier (a reduced form of PredictBody).
    data: list[Any]
    session_hash: Optional[str] = None
|
|
|
|
|
|
class _StarletteRequestPydanticAnnotation:
    """Pydantic ``Annotated`` marker that lets a Starlette/FastAPI ``Request`` be a field type.

    Validation is a plain isinstance check with no coercion, and the JSON schema
    describes the value as an opaque object titled "StarletteRequest".
    """

    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        _source_type: Any,
        _handler: GetCoreSchemaHandler,
    ) -> core_schema.CoreSchema:
        """Return a core schema whose only validation step is an isinstance check."""

        def validate_request(value: Any) -> Request:
            # Guard clause: anything that is not already a Request is rejected.
            if not isinstance(value, Request):
                raise ValueError("Input must be a Starlette Request object")
            return value

        return core_schema.no_info_plain_validator_function(validate_request)

    @classmethod
    def __get_pydantic_json_schema__(
        cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
    ) -> JsonSchemaValue:
        """Describe the field as a generic object; the request's internals are not exposed."""
        schema: JsonSchemaValue = dict(type="object", title="StarletteRequest")
        return schema
|
|
|
|
|
|
# fastapi ``Request`` type that pydantic can validate (isinstance check only) and
# render in JSON schema, via _StarletteRequestPydanticAnnotation.
PydanticStarletteRequest = Annotated[Request, _StarletteRequestPydanticAnnotation]
|
|
|
|
|
|
class PredictBody(BaseModel):
    # Request body for prediction calls; only ``data`` is required.
    # Field order is significant for pydantic, so it is kept as declared.
    session_hash: Optional[str] = None
    event_id: Optional[str] = None
    data: list[Any]
    event_data: Optional[Any] = None
    fn_index: Optional[int] = None
    trigger_id: Optional[int] = None
    simple_format: bool = False
    # Whether the data is a batch of samples (i.e. called from the queue if batch=True) or a single sample (i.e. called from the UI)
    batched: Optional[bool] = False

    @classmethod
    def __get_pydantic_json_schema__(cls, core_schema, handler):
        """Return a fixed, hand-written JSON schema for this request body.

        The generated schema is bypassed: ``core_schema`` and ``handler`` are
        ignored. Each call builds fresh dicts, so callers may mutate the result.
        """
        properties = {
            "session_hash": {"type": "string"},
            "event_id": {"type": "string"},
            "data": {"type": "array", "items": {"type": "object"}},
            "event_data": {"type": "object"},
            "fn_index": {"type": "integer"},
            "trigger_id": {"type": "integer"},
            "simple_format": {"type": "boolean"},
            "batched": {"type": "boolean"},
        }
        schema = {
            "title": "PredictBody",
            "type": "object",
            "properties": properties,
            "required": ["data"],
        }
        return schema
|
|
|
|
|
|
class PredictBodyInternal(PredictBody):
    "Separate class to avoid exposing PydanticStarletteRequest in the API validation"

    # The incoming Request object itself (carrying headers, query parameters,
    # url, etc.), used to pass the request through for queuing. Validation only
    # accepts a real Request instance (see _StarletteRequestPydanticAnnotation).
    request: Optional[PydanticStarletteRequest] = (
        None  # the Request object (headers, query parameters, url, etc.), used to pass in the request for queuing
    )
|
|
|
|
|
|
class ResetBody(BaseModel):
    # Request body carrying only the id of the event to act on.
    # NOTE(review): what "reset" does is not visible here — confirm with the route.
    event_id: str
|
|
|
|
|
|
class ComponentServerJSONBody(BaseModel):
    # Request addressing function `fn_name` of component `component_id` in
    # session `session_hash`, with a free-form payload.
    # NOTE(review): call semantics inferred from field names — confirm.
    session_hash: str
    component_id: int
    fn_name: str
    data: Any  # accepts any value (typed Any)
|
|
|
|
|
|
class DataWithFiles(BaseModel):
    # Arbitrary data bundled with raw file contents.
    data: Any
    files: list[tuple[str, bytes]]  # NOTE(review): presumably (filename, content) pairs — confirm
|
|
|
|
|
|
class ComponentServerBlobBody(BaseModel):
    # Same addressing fields as ComponentServerJSONBody, but `data` is a
    # DataWithFiles (payload plus raw file bytes) rather than Any.
    session_hash: str
    component_id: int
    fn_name: str
    data: DataWithFiles
|
|
|
|
|
|
class InterfaceTypes(Enum):
    """Variants of an Interface, distinguished by which sides it exposes.

    NOTE(review): the meaning of each member is inferred from its name —
    confirm against the code that consumes this enum. Values are the same
    1..4 sequence that ``auto()`` would assign.
    """

    STANDARD = 1
    INPUT_ONLY = 2
    OUTPUT_ONLY = 3
    UNIFIED = 4
|
|
|
|
|
|
class GradioBaseModel(ABC):
    """Mixin for Gradio's pydantic models adding file relocation and JSON rebuilding.

    Concrete classes must also derive from pydantic's ``BaseModel`` or
    ``RootModel`` (enforced at runtime by ``copy_to_dir``) and implement
    ``from_json``.
    """

    def copy_to_dir(self, dir: str | pathlib.Path) -> GradioDataModel:
        """Return a new model whose embedded files have been copied under ``dir``.

        The model is dumped, and every nested value accepted by
        ``FileData.is_file_data`` is copied into its own randomly named
        sub-directory of ``dir`` (``secrets.token_hex(10)``). The rewritten dump
        is then turned back into a model via ``from_json``.

        Raises:
            TypeError: if ``self`` is not a pydantic ``BaseModel``/``RootModel``.
        """
        if not isinstance(self, (BaseModel, RootModel)):
            raise TypeError("must be used in a Pydantic model")
        target_root = pathlib.Path(dir)

        # TODO: Making sure path is unique should be done in caller
        def _relocate(file_dict: dict):
            file_data = FileData(**file_dict)
            destination = target_root / secrets.token_hex(10)
            return file_data._copy_to_dir(str(destination)).model_dump()

        dumped = self.model_dump()
        rewritten = traverse(dumped, _relocate, FileData.is_file_data)
        return self.__class__.from_json(x=rewritten)

    @classmethod
    @abstractmethod
    def from_json(cls, x) -> GradioDataModel:
        """Construct an instance of ``cls`` from its JSON-compatible dump ``x``."""
        pass
|
|
|
|
|
|
class JsonData(RootModel):
    """JSON data returned from a component that should not be modified further."""

    # pydantic's JsonValue when the installed pydantic provides it, otherwise Any.
    root: JsonValue
|
|
|
|
|
|
class GradioModel(GradioBaseModel, BaseModel):
    # Field-based Gradio data model; concrete subclasses declare pydantic fields.
    @classmethod
    def from_json(cls, x) -> GradioModel:
        """Build an instance by passing the mapping ``x`` as keyword arguments (validated by pydantic)."""
        return cls(**x)
|
|
|
|
|
|
class GradioRootModel(GradioBaseModel, RootModel):
    # Gradio data model whose whole payload is a single pydantic ``root`` value.
    @classmethod
    def from_json(cls, x) -> GradioRootModel:
        """Build an instance with ``x`` as its ``root`` value (validated by pydantic)."""
        return cls(root=x)
|
|
|
|
|
|
# Any Gradio data model: field-based (GradioModel) or root-value (GradioRootModel).
GradioDataModel = Union[GradioModel, GradioRootModel]
|
|
|
|
|
|
class FileDataDict(TypedDict):
    # Plain-dict counterpart of FileData, with the same keys.
    path: str  # server filepath
    url: NotRequired[Optional[str]]  # normalised server url
    size: NotRequired[Optional[int]]  # size in bytes
    orig_name: NotRequired[Optional[str]]  # original filename
    mime_type: NotRequired[Optional[str]]
    # Required key here, although FileData gives is_stream a default of False.
    is_stream: bool
    meta: NotRequired[dict]
|
|
|
|
|
|
@document()
class FileData(GradioModel):
    """
    The FileData class is a subclass of the GradioModel class that represents a file object within a Gradio interface. It is used to store file data and metadata when a file is uploaded.

    Attributes:
        path: The server file path where the file is stored.
        url: The normalized server URL pointing to the file.
        size: The size of the file in bytes.
        orig_name: The original filename before upload.
        mime_type: The MIME type of the file.
        is_stream: Indicates whether the file is a stream.
        meta: Additional metadata used internally (should not be changed).
    """

    path: str  # server filepath
    url: Optional[str] = None  # normalised server url
    size: Optional[int] = None  # size in bytes
    orig_name: Optional[str] = None  # original filename
    mime_type: Optional[str] = None
    is_stream: bool = False
    # Marker value checked by validate_model when "validate_meta" is requested.
    meta: dict = {"_type": "gradio.FileData"}

    @model_validator(mode="before")
    @classmethod
    def validate_model(cls, v, info: ValidationInfo):
        """Optionally require the ``meta`` marker on the raw input, before field validation.

        Enforcement is on only when the validation context has a truthy
        ``"validate_meta"`` entry; otherwise ``v`` is returned unchanged.

        Raises:
            ValueError: if enforcement is on and ``is_file_obj_with_meta(v)`` is false.
        """
        if (
            info.context
            and info.context.get("validate_meta")
            and not is_file_obj_with_meta(v)
        ):
            raise ValueError(
                "The 'meta' field must be explicitly provided in the input data and be equal to {'_type': 'gradio.FileData'}."
            )
        return v

    @property
    def is_none(self) -> bool:
        """
        Checks if the FileData object is empty, i.e., all attributes are None.

        Note that ``path`` is declared as a required ``str``, so for a normally
        validated instance this is False.

        Returns:
            bool: True if all attributes (except 'is_stream' and 'meta') are None, False otherwise.
        """
        return all(
            f is None
            for f in [
                self.path,
                self.url,
                self.size,
                self.orig_name,
                self.mime_type,
            ]
        )

    @classmethod
    def from_path(cls, path: str) -> FileData:
        """
        Creates a FileData object from a given file path.

        Only ``path`` is set; every other field keeps its default.

        Args:
            path: The file path.

        Returns:
            FileData: An instance of FileData representing the file at the specified path.
        """
        return cls(path=path)

    def _copy_to_dir(self, dir: str) -> FileData:
        """
        Copies the file to a specified directory and returns a new FileData object representing the copied file.

        ``dir`` itself is created if missing, but its parent must already exist.
        The copy keeps the source file's name. All other fields are carried over
        unchanged and re-validated.

        Args:
            dir: The destination directory.

        Returns:
            FileData: A new FileData object representing the copied file.

        Raises:
            ValueError: If the source file path is not set (empty).
        """
        # Validate before touching the filesystem, so a failing call does not
        # leave a freshly created, empty directory behind.
        if not self.path:
            raise ValueError("Source file path is not set")
        pathlib.Path(dir).mkdir(exist_ok=True)
        new_obj = dict(self)  # pydantic model iteration yields (field, value) pairs
        # shutil.copy returns the path of the newly created file.
        new_obj["path"] = shutil.copy(self.path, dir)
        return self.__class__(**new_obj)

    @classmethod
    def is_file_data(cls, obj: Any) -> bool:
        """
        Checks if an object is a valid FileData instance.

        Only dicts qualify. The dict must construct a ``FileData`` (always the
        base class, even when called on a subclass) whose ``is_none`` is False.

        Args:
            obj: The object to check.

        Returns:
            bool: True if the object is a valid FileData instance, False otherwise.
        """
        if isinstance(obj, dict):
            try:
                return not FileData(**obj).is_none
            # TypeError covers non-string keys in **obj; ValidationError covers bad values.
            except (TypeError, ValidationError):
                return False
        return False
|
|
|
|
|
|
class ListFiles(GradioRootModel):
    # Root model wrapping a list of FileData, with list-style indexing and iteration.
    root: list[FileData]

    def __getitem__(self, index):
        """Return ``self.root[index]`` (a single FileData, or a list for a slice)."""
        return self.root[index]

    def __iter__(self) -> Iterator[FileData]:  # type: ignore[override]
        """Iterate over the wrapped FileData items.

        This replaces pydantic's model iteration, which yields ``(field, value)``
        pairs; hence the ``type: ignore[override]``.
        """
        return iter(self.root)
|
|
|
|
|
|
class _StaticFiles:
    """
    Registry of the file-system paths an app treats as static files.

    ``all_paths`` lives on the class, and ``clear()`` replaces that class-level
    list. NOTE(review): ``__init__`` assigns ``self.all_paths``, which shadows
    the class attribute on the new instance instead of adding to the registry —
    confirm this is intended.
    """

    all_paths = []

    def __init__(self, paths: list[str | pathlib.Path]) -> None:
        """Keep ``paths`` exactly as given and store their resolved absolute forms."""
        self.paths = paths
        self.all_paths = [path.resolve() for path in map(pathlib.Path, paths)]

    @classmethod
    def clear(cls):
        """Replace the class-level ``all_paths`` with a new empty list."""
        cls.all_paths = []
|
|
|
|
|
|
class BodyCSS(TypedDict):
    # Styling for the page body in light mode and, via the `_dark` keys, dark mode.
    # NOTE(review): values are presumably CSS color/background strings — confirm.
    body_background_fill: str
    body_text_color: str
    body_background_fill_dark: str
    body_text_color_dark: str
|
|
|
|
|
|
class Layout(TypedDict):
    # Recursive layout-tree node: `children` holds plain ints (leaves) or nested
    # Layout nodes. NOTE(review): `id` is presumably a component id — confirm.
    id: int
    children: NotRequired[list[int | Layout]]
|
|
|
|
|
|
class Page(TypedDict):
    # Per-page entry of BlocksConfigDict.page. NOTE(review): `components` and
    # `dependencies` are presumably ids into the app-wide lists — confirm.
    components: list[int]
    dependencies: list[int]
    layout: Layout
|
|
|
|
|
|
class BlocksConfigDict(TypedDict):
    # Shape of a Blocks app's configuration dict. Keys wrapped in NotRequired
    # may be absent; every other key must be present.
    version: str
    mode: str
    app_id: int
    dev_mode: bool
    analytics_enabled: bool
    components: list[dict[str, Any]]
    css: str | None
    connect_heartbeat: bool
    # Custom JS source, the literal True, or None.
    # NOTE(review): what True signals is not visible here — confirm with the producer.
    js: str | Literal[True] | None
    head: str | None
    title: str
    space_id: str | None
    enable_queue: bool
    show_error: bool
    show_api: bool
    is_colab: bool
    max_file_size: int | None  # NOTE(review): presumably bytes, None = no limit — confirm
    stylesheets: list[str]
    theme: str | None
    # Transport protocol name (values suggest websocket or SSE variants).
    protocol: Literal["ws", "sse", "sse_v1", "sse_v2", "sse_v2.1", "sse_v3"]
    body_css: BodyCSS
    fill_height: bool
    fill_width: bool
    theme_hash: str
    layout: NotRequired[Layout]
    dependencies: NotRequired[list[dict[str, Any]]]
    root: NotRequired[str | None]
    username: NotRequired[str | None]
    api_prefix: str
    pwa: NotRequired[bool]
    # Per-page data keyed by a page identifier (presumably its route — confirm).
    page: dict[str, Page]
    pages: list[tuple[str, str]]  # NOTE(review): meaning of the two strings not visible here
    current_page: NotRequired[str]
|
|
|
|
|
|
class MediaStreamChunk(TypedDict):
    # One chunk of streamed media.
    data: bytes  # raw chunk bytes
    duration: float  # NOTE(review): presumably seconds — confirm
    extension: str  # NOTE(review): presumably the chunk's file extension — confirm
    id: NotRequired[str]
|
|
|
|
|
|
class ImageData(GradioModel):
    # Image payload with the same fields as FileData, except that every field
    # here is optional. Each Field's description is carried into the generated
    # JSON schema.
    path: Optional[str] = Field(default=None, description="Path to a local file")
    url: Optional[str] = Field(
        default=None, description="Publicly available url or base64 encoded image"
    )
    size: Optional[int] = Field(default=None, description="Size of image in bytes")
    orig_name: Optional[str] = Field(default=None, description="Original filename")
    mime_type: Optional[str] = Field(default=None, description="mime type of image")
    is_stream: bool = Field(default=False, description="Can always be set to False")
    meta: dict = {"_type": "gradio.FileData"}  # same marker value as FileData.meta

    # Adds a model-level description to the generated JSON schema.
    model_config = ConfigDict(
        json_schema_extra={
            "description": "For input, either path or url must be provided. For output, path is always provided."
        }
    )
|
|
|
|
|
|
class Base64ImageData(GradioModel):
    # Image supplied as a base64-encoded string in `url` (required; no default).
    url: str = Field(description="base64 encoded image")
|