from __future__ import annotations

import copy
import re
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Hashable, Iterator
from pathlib import Path
from typing import Any, ClassVar, Dict, List, Optional, TypeVar, Union

from PIL import Image
from pydantic import Field, field_validator
from tenacity import retry, stop_after_attempt, stop_after_delay

from ...base import BotBase
from ...utils.env import EnvVar
from ...utils.general import LRUCache
from ...utils.registry import registry
from .prompt.base import _OUTPUT_PARSER, StrParser
from .prompt.parser import BaseOutputParser
from .prompt.prompt import PromptTemplate
from .schemas import Message

T = TypeVar("T", str, dict, list)


class BaseLLM(BotBase, ABC):
    cache: bool = False
    lru_cache: LRUCache = Field(default=LRUCache(EnvVar.LLM_CACHE_NUM))

    @property
    def workflow_instance_id(self) -> str:
        if hasattr(self, "_parent"):
            return self._parent.workflow_instance_id
        return None

    @workflow_instance_id.setter
    def workflow_instance_id(self, value: str):
        if hasattr(self, "_parent"):
            self._parent.workflow_instance_id = value

    @abstractmethod
    def _call(self, records: List[Message], **kwargs) -> str:
        """Run the LLM on the given prompt and input."""

    async def _acall(self, records: List[Message], **kwargs) -> str:
        """Run the LLM on the given prompt and input."""
        raise NotImplementedError("Async generation not implemented for this LLM.")

    def generate(self, records: List[Message], **kwargs) -> str:  # TODO: use python native lru cache
        """Run the LLM on the given prompt and input."""
        if self.cache:
            key = self._cache_key(records)
            cached_res = self.lru_cache.get(key)
            if cached_res:
                return cached_res
            else:
                gen = self._call(records, **kwargs)
                self.lru_cache.put(key, gen)
                return gen
        else:
            return self._call(records, **kwargs)

    async def agenerate(self, records: List[Message], **kwargs) -> str:
        """Run the LLM on the given prompt and input."""
        if self.cache:
            key = self._cache_key(records)
            cached_res = self.lru_cache.get(key)
            if cached_res:
                return cached_res
            else:
                gen = await self._acall(records, **kwargs)
                self.lru_cache.put(key, gen)
                return gen
        else:
            return await self._acall(records, **kwargs)

    def _cache_key(self, records: List[Message]) -> str:
        return str([item.model_dump() for item in records])

    def dict(self, *args, **kwargs):
        kwargs["exclude"] = {"lru_cache"}
        return super().model_dump(*args, **kwargs)

    def json(self, *args, **kwargs):
        kwargs["exclude"] = {"lru_cache"}
        return super().model_dump_json(*args, **kwargs)
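

# Illustrative sketch (not part of the original module): a concrete LLM only has
# to implement `_call` (and optionally `_acall`); `generate`/`agenerate` then
# layer the optional LRU cache on top. The class below is a hypothetical toy
# example, not one of the project's registered LLMs.
class _EchoLLM(BaseLLM):
    """Toy LLM that simply serializes the incoming messages; illustration only."""

    def _call(self, records: List[Message], **kwargs) -> str:
        # A real implementation would send `records` to a model endpoint here.
        return str([m.model_dump() for m in records])

    async def _acall(self, records: List[Message], **kwargs) -> str:
        return self._call(records, **kwargs)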


class BaseLLMBackend(BotBase, ABC):
    """Prompt preparation and LLM inference."""

    output_parser: Optional[BaseOutputParser] = None
    prompts: List[PromptTemplate] = []
    llm: BaseLLM

    @property
    def token_usage(self):
        if not hasattr(self, "workflow_instance_id"):
            raise AttributeError("workflow_instance_id not set")
        return dict(
            self.stm(self.workflow_instance_id).get("token_usage", defaultdict(int))
        )

    @field_validator("output_parser", mode="before")
    @classmethod
    def set_output_parser(cls, output_parser: Union[BaseOutputParser, Dict, None]):
        if output_parser is None:
            return StrParser()
        elif isinstance(output_parser, BaseOutputParser):
            return output_parser
        elif isinstance(output_parser, dict):
            return _OUTPUT_PARSER[output_parser["name"]](**output_parser)
        else:
            raise ValueError

    @field_validator("prompts", mode="before")
    @classmethod
    def set_prompts(
        cls, prompts: List[Union[PromptTemplate, Dict, str]]
    ) -> List[PromptTemplate]:
        init_prompts = []
        for prompt in prompts:
            prompt = copy.deepcopy(prompt)
            if isinstance(prompt, Path):
                if prompt.suffix == ".prompt":
                    init_prompts.append(PromptTemplate.from_file(prompt))
            elif isinstance(prompt, str):
                if prompt.endswith(".prompt"):
                    init_prompts.append(PromptTemplate.from_file(prompt))
                else:
                    init_prompts.append(PromptTemplate.from_template(prompt))
            elif isinstance(prompt, dict):
                init_prompts.append(PromptTemplate.from_config(prompt))
            elif isinstance(prompt, PromptTemplate):
                init_prompts.append(prompt)
            else:
                raise ValueError(
                    "Prompt only support str, dict and PromptTemplate object"
                )
        return init_prompts

    @field_validator("llm", mode="before")
    @classmethod
    def set_llm(cls, llm: Union[BaseLLM, Dict]):
        if isinstance(llm, dict):
            return registry.get_llm(llm["name"])(**llm)
        elif isinstance(llm, BaseLLM):
            return llm
        else:
            raise ValueError("LLM only support dict and BaseLLM object")

    def prep_prompt(
        self, input_list: List[Dict[str, Any]], prompts=None, **kwargs
    ) -> List[List[Message]]:
        """Prepare prompts from inputs."""
        if prompts is None:
            prompts = self.prompts
        images = []
        if len(kwargs_images := kwargs.get("images", [])):
            images = kwargs_images
        processed_prompts = []
        for inputs in input_list:
            records = []
            for prompt in prompts:
                selected_inputs = {
                    k: inputs.get(k, "") for k in prompt.input_variables
                }
                prompt_str = prompt.template
                parts = re.split(r"(\{\{.*?\}\})", prompt_str)
                formatted_parts = []
                for part in parts:
                    if part.startswith("{{") and part.endswith("}}"):
                        part = part[2:-2].strip()
                        value = selected_inputs[part]
                        if isinstance(value, (Image.Image, list)):
                            formatted_parts.extend(
                                [value] if isinstance(value, Image.Image) else value
                            )
                        else:
                            formatted_parts.append(str(value))
                    else:
                        formatted_parts.append(str(part))
                formatted_parts = (
                    formatted_parts[0]
                    if len(formatted_parts) == 1
                    else formatted_parts
                )
                if prompt.role == "system":
                    records.append(Message.system(formatted_parts))
                elif prompt.role == "user":
                    records.append(Message.user(formatted_parts))
            if len(images):
                records.append(Message.user(images))
            processed_prompts.append(records)
        return processed_prompts

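    # Example of what `prep_prompt` above produces (illustrative): for a user
    # prompt template "Describe {{img}} in one sentence." and a PIL image passed
    # as the "img" input, the {{img}} placeholder is spliced in unchanged, so the
    # resulting Message content is a mixed list roughly like
    #     ["Describe ", <PIL.Image.Image>, " in one sentence."]
    # which lets multimodal LLM subclasses interleave text and image parts.
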
    def infer(self, input_list: List[Dict[str, Any]], **kwargs) -> List[T]:
        prompts = self.prep_prompt(input_list, **kwargs)
        res = []
        stm_token_usage = self.stm(self.workflow_instance_id).get(
            "token_usage", defaultdict(int)
        )

        def process_stream(self, stream_output):
            for chunk in stream_output:
                if chunk.usage is not None:
                    for key, value in chunk.usage.dict().items():
                        if key in ["prompt_tokens", "completion_tokens", "total_tokens"]:
                            if value is not None:
                                stm_token_usage[key] += value
                    self.stm(self.workflow_instance_id)["token_usage"] = stm_token_usage
                yield chunk

        for prompt in prompts:
            output = self.llm.generate(prompt, **kwargs)
            if not isinstance(output, Iterator):
                for key, value in output.get("usage", {}).items():
                    if key in ["prompt_tokens", "completion_tokens", "total_tokens"]:
                        if value is not None:
                            stm_token_usage[key] += value
            if not self.llm.stream:
                for choice in output["choices"]:
                    if choice.get("message"):
                        choice["message"]["content"] = self.output_parser.parse(
                            choice["message"]["content"]
                        )
                res.append(output)
            else:
                res.append(process_stream(self, output))
        self.stm(self.workflow_instance_id)["token_usage"] = stm_token_usage
        return res

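    # Note on `infer` above: when `self.llm.stream` is False each element of the
    # returned list is the raw completion dict (with parsed message content);
    # when streaming is enabled, each element is instead a generator produced by
    # `process_stream`, so callers must iterate it to consume chunks and to have
    # token usage recorded in short-term memory.
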
    async def ainfer(self, input_list: List[Dict[str, Any]], **kwargs) -> List[T]:
        prompts = self.prep_prompt(input_list)
        res = []
        stm_token_usage = self.stm(self.workflow_instance_id).get(
            "token_usage", defaultdict(int)
        )
        for prompt in prompts:
            output = await self.llm.agenerate(prompt, **kwargs)
            for key, value in output.get("usage", {}).items():
                if key in ["prompt_tokens", "completion_tokens", "total_tokens"]:
                    if value is not None:
                        stm_token_usage[key] += value
            for choice in output["choices"]:
                if choice.get("message"):
                    choice["message"]["content"] = self.output_parser.parse(
                        choice["message"]["content"]
                    )
            res.append(output)
        self.stm(self.workflow_instance_id)["token_usage"] = stm_token_usage
        return res

    def simple_infer(self, **kwargs: Any) -> T:
        return self.infer([kwargs])[0]

    async def simple_ainfer(self, **kwargs: Any) -> T:
        return (await self.ainfer([kwargs]))[0]
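

# Illustrative usage sketch (not part of the original module): because the field
# validators above accept plain dicts and strings, a concrete backend can
# typically be configured along these lines. The registered LLM name
# "OpenaiGPTLLM", the parser config, and the template text are assumptions; the
# actual names depend on what is registered in `registry`, and in practice the
# backend runs inside a workflow worker that supplies `workflow_instance_id`
# and short-term memory (`stm`).
#
#     class QABackend(BaseLLMBackend):
#         pass
#
#     backend = QABackend(
#         llm={"name": "OpenaiGPTLLM", "model_id": "gpt-4o"},
#         prompts=["Answer briefly: {{question}}"],
#         output_parser={"name": "StrParser"},
#     )
#     answer = backend.simple_infer(question="What is the capital of France?")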