mirror of
https://github.com/Ladebeze66/projetcbaollm.git
synced 2025-12-16 22:47:49 +01:00
127 lines
4.4 KiB
Python
127 lines
4.4 KiB
Python
from __future__ import annotations
|
|
|
|
import datetime
|
|
import os
|
|
import threading
|
|
from collections import OrderedDict
|
|
from collections.abc import Iterator
|
|
from copy import copy, deepcopy
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
if TYPE_CHECKING:
|
|
from gradio.blocks import Blocks
|
|
from gradio.components import State
|
|
|
|
|
|
class StateHolder:
|
|
def __init__(self):
|
|
self.capacity = 10000
|
|
self.session_data: OrderedDict[str, SessionState] = OrderedDict()
|
|
self.time_last_used: dict[str, datetime.datetime] = {}
|
|
self.lock = threading.Lock()
|
|
|
|
def set_blocks(self, blocks: Blocks):
|
|
self.blocks = blocks
|
|
blocks.state_holder = self
|
|
self.capacity = blocks.state_session_capacity
|
|
|
|
def reset(self, blocks: Blocks):
|
|
"""Reset the state holder with new blocks. Used during reload mode."""
|
|
self.session_data = OrderedDict()
|
|
# Call set blocks again to set new ids
|
|
self.set_blocks(blocks)
|
|
|
|
def __getitem__(self, session_id: str) -> SessionState:
|
|
if session_id not in self.session_data:
|
|
self.session_data[session_id] = SessionState(self.blocks)
|
|
self.update(session_id)
|
|
self.time_last_used[session_id] = datetime.datetime.now()
|
|
return self.session_data[session_id]
|
|
|
|
def __contains__(self, session_id: str):
|
|
return session_id in self.session_data
|
|
|
|
def update(self, session_id: str):
|
|
with self.lock:
|
|
if session_id in self.session_data:
|
|
self.session_data.move_to_end(session_id)
|
|
if len(self.session_data) > self.capacity:
|
|
self.session_data.popitem(last=False)
|
|
|
|
def delete_all_expired_state(
|
|
self,
|
|
):
|
|
for session_id in self.session_data:
|
|
self.delete_state(session_id, expired_only=True)
|
|
|
|
def delete_state(self, session_id: str, expired_only: bool = False):
|
|
if session_id not in self.session_data:
|
|
return
|
|
to_delete = []
|
|
session_state = self.session_data[session_id]
|
|
for component, value, expired in session_state.state_components:
|
|
if not expired_only or expired:
|
|
component.delete_callback(value)
|
|
to_delete.append(component._id)
|
|
for component in to_delete:
|
|
del session_state.state_data[component]
|
|
|
|
|
|
class SessionState:
|
|
def __init__(self, blocks: Blocks):
|
|
self.blocks_config = copy(blocks.default_config)
|
|
self.state_data: dict[int, Any] = {}
|
|
self._state_ttl = {}
|
|
self.is_closed = False
|
|
# When a session is closed, the state is stored for an hour to give the user time to reopen the session.
|
|
# During testing we set to a lower value to be able to test
|
|
self.STATE_TTL_WHEN_CLOSED = (
|
|
1 if os.getenv("GRADIO_IS_E2E_TEST", None) else 3600
|
|
)
|
|
|
|
def __getitem__(self, key: int) -> Any:
|
|
block = self.blocks_config.blocks[key]
|
|
if block.stateful:
|
|
if key not in self.state_data:
|
|
self.state_data[key] = deepcopy(getattr(block, "value", None))
|
|
return self.state_data[key]
|
|
else:
|
|
return block
|
|
|
|
def __setitem__(self, key: int, value: Any):
|
|
from gradio.components import State
|
|
|
|
block = self.blocks_config.blocks[key]
|
|
if isinstance(block, State):
|
|
self._state_ttl[key] = (
|
|
block.time_to_live,
|
|
datetime.datetime.now(),
|
|
)
|
|
self.state_data[key] = value
|
|
else:
|
|
self.blocks_config.blocks[key] = value
|
|
|
|
def __contains__(self, key: int):
|
|
block = self.blocks_config.blocks[key]
|
|
if block.stateful:
|
|
return key in self.state_data
|
|
else:
|
|
return key in self.blocks_config.blocks
|
|
|
|
@property
|
|
def state_components(self) -> Iterator[tuple[State, Any, bool]]:
|
|
from gradio.components import State
|
|
|
|
for id in self.state_data:
|
|
block = self.blocks_config.blocks[id]
|
|
if isinstance(block, State) and id in self._state_ttl:
|
|
time_to_live, created_at = self._state_ttl[id]
|
|
if self.is_closed:
|
|
time_to_live = self.STATE_TTL_WHEN_CLOSED
|
|
value = self.state_data[id]
|
|
yield (
|
|
block,
|
|
value,
|
|
(datetime.datetime.now() - created_at).seconds > time_to_live,
|
|
)
|