"""Utility function for gradio/external.py, designed for internal use."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
import math
|
|
import re
|
|
import warnings
|
|
|
|
import httpx
|
|
import yaml
|
|
from huggingface_hub import HfApi, ImageClassificationOutputElement, InferenceClient
|
|
|
|
from gradio import components
|
|
from gradio.exceptions import Error, TooManyRequestsError


def get_model_info(model_name, hf_token=None):
    hf_api = HfApi(token=hf_token)
    print(f"Fetching model from: https://huggingface.co/{model_name}")

    model_info = hf_api.model_info(model_name)
    pipeline = model_info.pipeline_tag
    tags = model_info.tags
    return pipeline, tags


##################
# Helper functions for processing tabular data
##################


def get_tabular_examples(model_name: str) -> dict[str, list[float]]:
    readme = httpx.get(f"https://huggingface.co/{model_name}/resolve/main/README.md")
    if readme.status_code != 200:
        warnings.warn(f"Cannot load examples from README for {model_name}", UserWarning)
        example_data = {}
    else:
        yaml_regex = re.search(
            "(?:^|[\r\n])---[\n\r]+([\\S\\s]*?)[\n\r]+---([\n\r]|$)", readme.text
        )
        if yaml_regex is None:
            example_data = {}
        else:
            example_yaml = next(
                yaml.safe_load_all(readme.text[: yaml_regex.span()[-1]])
            )
            example_data = example_yaml.get("widget", {}).get("structuredData", {})
    if not example_data:
        raise ValueError(
            f"No example data found in README.md of {model_name} - Cannot build gradio demo. "
            "See the README.md here: https://huggingface.co/scikit-learn/tabular-playground/blob/main/README.md "
            "for a reference on how to provide example data to your model."
        )
    # Replace NaN values with the string "NaN" for Inference Endpoints
    for data in example_data.values():
        for i, val in enumerate(data):
            if isinstance(val, float) and math.isnan(val):
                data[i] = "NaN"
    return example_data
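
# Example: assuming a hypothetical model card whose YAML front matter contains
#   widget:
#     structuredData:
#       sepal_length: [5.1, .nan]
#       petal_width: [0.2, 1.8]
# get_tabular_examples would return {"sepal_length": [5.1, "NaN"], "petal_width": [0.2, 1.8]}.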


def cols_to_rows(
    example_data: dict[str, list[float | str] | None],
) -> tuple[list[str], list[list[float]]]:
    headers = list(example_data.keys())
    n_rows = max(len(example_data[header] or []) for header in headers)
    data = []
    for row_index in range(n_rows):
        row_data = []
        for header in headers:
            col = example_data[header] or []
            if row_index >= len(col):
                row_data.append("NaN")
            else:
                row_data.append(col[row_index])
        data.append(row_data)
    return headers, data
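
# Example with hypothetical columns: cols_to_rows({"a": [1.0, 2.0], "b": [3.0]}) returns
# (["a", "b"], [[1.0, 3.0], [2.0, "NaN"]]) - shorter columns are padded with "NaN".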


def rows_to_cols(incoming_data: dict) -> dict[str, dict[str, dict[str, list[str]]]]:
    data_column_wise = {}
    for i, header in enumerate(incoming_data["headers"]):
        data_column_wise[header] = [str(row[i]) for row in incoming_data["data"]]
    return {"inputs": {"data": data_column_wise}}


##################
# Helper functions for processing other kinds of data
##################


def postprocess_label(scores: list[ImageClassificationOutputElement]) -> dict:
    return {c.label: c.score for c in scores}


def postprocess_mask_tokens(scores: list[dict[str, str | float]]) -> dict:
    return {c["token_str"]: c["score"] for c in scores}


def postprocess_question_answering(answer: dict) -> tuple[str, dict]:
    return answer["answer"], {answer["answer"]: answer["score"]}


def postprocess_visual_question_answering(scores: list[dict[str, str | float]]) -> dict:
    return {c["answer"]: c["score"] for c in scores}


def zero_shot_classification_wrapper(client: InferenceClient):
    def zero_shot_classification_inner(input: str, labels: str, multi_label: bool):
        return client.zero_shot_classification(
            input, labels.split(","), multi_label=multi_label
        )

    return zero_shot_classification_inner
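
# Example with hypothetical inputs: zero_shot_classification_inner("Great movie", "positive,negative", False)
# splits the comma-separated labels and forwards ["positive", "negative"] to the client.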


def sentence_similarity_wrapper(client: InferenceClient):
    def sentence_similarity_inner(input: str, sentences: str):
        return client.sentence_similarity(input, sentences.split("\n"))

    return sentence_similarity_inner


def text_generation_wrapper(client: InferenceClient):
    def text_generation_inner(input: str):
        return input + client.text_generation(input)

    return text_generation_inner
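
# Note: the inner function prepends the prompt to the generated continuation, so callers
# receive the full text rather than only the completion.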


def conversational_wrapper(client: InferenceClient):
    def chat_fn(message, history):
        if not history:
            history = []
        history.append({"role": "user", "content": message})
        try:
            out = ""
            for chunk in client.chat_completion(messages=history, stream=True):
                out += chunk.choices[0].delta.content or ""
                yield out
        except Exception as e:
            handle_hf_error(e)

    return chat_fn
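
# chat_fn streams: each chunk's delta text is appended to `out` and the accumulated string
# is yielded, so the partial reply grows as chunks arrive.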


def encode_to_base64(r: httpx.Response) -> str:
    # Handles the different ways the HF API returns the prediction
    base64_repr = base64.b64encode(r.content).decode("utf-8")
    data_prefix = ";base64,"
    # Case 1: base64 representation already includes data prefix
    if data_prefix in base64_repr:
        return base64_repr
    else:
        content_type = r.headers.get("content-type")
        # Case 2: the data prefix is a key in the response
        if content_type == "application/json":
            try:
                data = r.json()[0]
                content_type = data["content-type"]
                base64_repr = data["blob"]
            except KeyError as ke:
                raise ValueError(
                    "Cannot determine content type returned by external API."
                ) from ke
        # Case 3: the data prefix is included in the response headers
        else:
            pass
        new_base64 = f"data:{content_type};base64,{base64_repr}"
        return new_base64
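
# Example with a hypothetical response: raw PNG bytes served with content-type "image/png"
# come back as "data:image/png;base64,<encoded bytes>" (Case 3).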


def format_ner_list(input_string: str, ner_groups: list[dict[str, str | int]]):
    if len(ner_groups) == 0:
        return [(input_string, None)]

    output = []
    end = 0
    prev_end = 0

    for group in ner_groups:
        entity, start, end = group["entity_group"], group["start"], group["end"]
        output.append((input_string[prev_end:start], None))
        output.append((input_string[start:end], entity))
        prev_end = end

    output.append((input_string[end:], None))
    return output
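
# Example with hypothetical NER output:
#   format_ner_list("I live in Paris", [{"entity_group": "LOC", "start": 10, "end": 15}])
# returns [("I live in ", None), ("Paris", "LOC"), ("", None)] - alternating unlabeled
# text and labeled entity spans.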


def token_classification_wrapper(client: InferenceClient):
    def token_classification_inner(input: str):
        ner_list = client.token_classification(input)
        return format_ner_list(input, ner_list)  # type: ignore

    return token_classification_inner


def object_detection_wrapper(client: InferenceClient):
    def object_detection_inner(input: str):
        annotations = client.object_detection(input)
        formatted_annotations = [
            (
                (
                    a["box"]["xmin"],
                    a["box"]["ymin"],
                    a["box"]["xmax"],
                    a["box"]["ymax"],
                ),
                a["label"],
            )
            for a in annotations
        ]
        return (input, formatted_annotations)

    return object_detection_inner
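
# object_detection_inner returns (input, formatted_annotations): the original image plus a
# list of ((xmin, ymin, xmax, ymax), label) tuples, one per detected object.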


def chatbot_preprocess(text, state):
    if not state:
        return text, [], []
    return (
        text,
        state["conversation"]["generated_responses"],
        state["conversation"]["past_user_inputs"],
    )


def chatbot_postprocess(response):
    chatbot_history = list(
        zip(
            response["conversation"]["past_user_inputs"],
            response["conversation"]["generated_responses"],
            strict=False,
        )
    )
    return chatbot_history, response
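
# Example with a hypothetical response: past_user_inputs ["Hi"] and generated_responses
# ["Hello!"] become ([("Hi", "Hello!")], response) - paired chat history plus the raw response.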


def tabular_wrapper(client: InferenceClient, pipeline: str):
    # This wrapper is needed to handle an issue in the InferenceClient where the model name is not
    # automatically loaded when using the tabular_classification and tabular_regression methods.
    # See: https://github.com/huggingface/huggingface_hub/issues/2015
    def tabular_inner(data):
        if pipeline not in ("tabular_classification", "tabular_regression"):
            raise TypeError(f"pipeline type {pipeline!r} not supported")
        assert client.model  # noqa: S101
        if pipeline == "tabular_classification":
            return client.tabular_classification(data, model=client.model)
        else:
            return client.tabular_regression(data, model=client.model)

    return tabular_inner


##################
# Helper function for cleaning up an Interface loaded from HF Spaces
##################


def streamline_spaces_interface(config: dict) -> dict:
    """Streamlines the interface config dictionary to remove unnecessary keys."""
    config["inputs"] = [
        components.get_component_instance(component)
        for component in config["input_components"]
    ]
    config["outputs"] = [
        components.get_component_instance(component)
        for component in config["output_components"]
    ]
    parameters = {
        "article",
        "description",
        "flagging_options",
        "inputs",
        "outputs",
        "title",
    }
    config = {k: config[k] for k in parameters}
    return config


def handle_hf_error(e: Exception):
    if "429" in str(e):
        raise TooManyRequestsError() from e
    elif "401" in str(e) or "You must provide an api_key" in str(e):
        raise Error("Unauthorized, please make sure you are signed in.") from e
    else:
        raise Error(str(e)) from e