| from pprint import pprint
|
| from time import perf_counter
|
| from traceback import print_exc
|
| from typing import Any
|
|
|
| from app_settings import Settings
|
| from backend.image_saver import ImageSaver
|
| from backend.lcm_text_to_image import LCMTextToImage
|
| from backend.models.lcmdiffusion_setting import DiffusionTask
|
| from backend.utils import get_blank_image
|
| from models.interface_types import InterfaceType
|
|
|
|
|
class Context:
    """Own an ``LCMTextToImage`` pipeline and run image generations with it.

    After each call to :meth:`generate_text_to_image`, :attr:`latency` and
    :attr:`error` describe the outcome of that call.
    """

    def __init__(
        self,
        interface_type: InterfaceType,
        device: str = "cpu",
    ):
        """Create the context.

        Args:
            interface_type: Interface that owns this context; only its
                ``.value`` is stored, in ``self.interface_type``.
            device: Device name passed to the ``LCMTextToImage`` constructor.
        """
        self.interface_type = interface_type.value
        self.lcm_text_to_image = LCMTextToImage(device)
        # Seconds measured by the last generation call that got past
        # ``LCMTextToImage.generate``.
        self._latency = 0
        # Reason the last generation call returned None ("" when it did not).
        self._error = ""

    @property
    def latency(self):
        """Seconds taken by the last completed generation step.

        The measured span covers settings saving, pipeline init and generation.
        """
        return self._latency

    @property
    def error(self):
        """Message describing why the last generation returned None, or ""."""
        return self._error

    def generate_text_to_image(
        self,
        settings: Settings,
        reshape: bool = False,
        device: str = "cpu",
        save_config=True,
    ) -> Any:
        """Generate images as described by ``settings.lcm_diffusion_setting``.

        Args:
            settings: Application settings. Its ``lcm_diffusion_setting`` drives
                the generation and is modified in place: ``init_image`` is
                cleared when the task is plain text-to-image.
            reshape: Forwarded to ``LCMTextToImage.generate``.
            device: Device name forwarded to ``LCMTextToImage.init``.
            save_config: When true, save the settings object returned by
                ``state.get_settings()`` before generating.

        Returns:
            The list returned by ``LCMTextToImage.generate``, modified in place.
            The ControlNet control image is appended when ControlNet is
            enabled. When the safety checker is enabled, every image it
            rejects is replaced by a blank image.

            Returns ``None`` when ``lcm_lora`` is not configured or when any
            exception is raised. In both cases :attr:`error` holds the reason.
        """
        try:
            self._error = ""
            tick = perf_counter()
            # Imported lazily, presumably to avoid a circular import with
            # the ``state`` module.
            from state import get_settings

            lcm_setting = settings.lcm_diffusion_setting
            if lcm_setting.diffusion_task == DiffusionTask.text_to_image.value:
                # Plain text-to-image must not reuse a leftover init image.
                lcm_setting.init_image = None

            if save_config:
                get_settings().save()

            pprint(lcm_setting.model_dump())
            if not lcm_setting.lcm_lora:
                # Record why nothing was generated, so callers that surface
                # ``error`` do not report an empty message.
                self._error = "LCM-LoRA settings are not configured"
                print(f"Error in generating images: {self._error}")
                return None

            self.lcm_text_to_image.init(device, lcm_setting)
            images = self.lcm_text_to_image.generate(lcm_setting, reshape)

            elapsed = perf_counter() - tick
            self._latency = elapsed
            print(f"Latency : {elapsed:.2f} seconds")

            controlnet = lcm_setting.controlnet
            if controlnet and controlnet.enabled:
                control_image = controlnet._control_image
                # Do not hand a missing control image on to the safety
                # checker or the image saver.
                if control_image is not None:
                    images.append(control_image)

            if lcm_setting.use_safety_checker:
                self._filter_unsafe_images(images, lcm_setting)
        except Exception as exception:
            print(f"Error in generating images: {exception}")
            self._error = str(exception)
            print_exc()
            return None
        return images

    @staticmethod
    def _filter_unsafe_images(images, lcm_setting) -> None:
        """Replace, in place, every image the safety checker rejects with a blank."""
        print("Safety Checker is enabled")
        # Imported lazily, like ``get_settings`` above.
        from state import get_safety_checker

        safety_checker = get_safety_checker()
        blank_image = get_blank_image(
            lcm_setting.image_width,
            lcm_setting.image_height,
        )
        for idx, image in enumerate(images):
            if not safety_checker.is_safe(image):
                # NOTE: the same blank image object fills every rejected slot.
                images[idx] = blank_image

    def save_images(
        self,
        images: Any,
        settings: Settings,
    ) -> list[str]:
        """Save ``images`` if saving is enabled in ``settings.generated_images``.

        Args:
            images: Images to save; nothing is saved when this is empty/None.
            settings: Supplies the output path, format, JPEG quality and the
                diffusion settings passed to ``ImageSaver.save_images``.

        Returns:
            The result of ``ImageSaver.save_images``, or an empty list when
            there is nothing to save or saving is disabled.
        """
        if not images or not settings.generated_images.save_image:
            return []
        return ImageSaver.save_images(
            settings.generated_images.path,
            images=images,
            lcm_diffusion_setting=settings.lcm_diffusion_setting,
            format=settings.generated_images.format,
            jpeg_quality=settings.generated_images.save_image_quality,
        )
|
|
|