import inspect
from typing import List, Optional, Union

import torch
from PIL import Image
from transformers import AutoTokenizer, PreTrainedModel

from diffusers import AutoencoderKL, DiffusionPipeline, FlowMatchEulerDiscreteScheduler
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin, ZImageLoraLoaderMixin
from diffusers.pipelines.z_image.pipeline_output import ZImagePipelineOutput
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor

from diffusers_local.z_image_control_transformer_2d import ZImageControlTransformer2DModel


logger = logging.get_logger(__name__)


def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    # Linearly interpolate the flow-matching timestep shift `mu` from the image sequence length.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu


def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is
            passed, `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and
        the second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


class ZImageControlUnifiedPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin):
    model_cpu_offload_seq = "text_encoder->transformer->vae"
    _optional_components = []
    _callback_tensor_inputs = ["latents", "prompt_embeds"]

    def __init__(
        self,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKL,
        text_encoder: PreTrainedModel,
        tokenizer: AutoTokenizer,
        transformer: ZImageControlTransformer2DModel,
    ):
        super().__init__()
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        # Input images are preprocessed to multiples of vae_scale_factor * 2.
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)

    def _encode_prompt(self, prompt: str, device: torch.device, max_sequence_length: int) -> torch.Tensor:
        # Apply the tokenizer's chat template when available; otherwise fall back to the raw prompt.
        messages = [{"role": "user", "content": prompt}]
        if hasattr(self.tokenizer, "apply_chat_template"):
            prompt_formatted = self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True, enable_thinking=True
            )
        else:
            prompt_formatted = prompt

        text_inputs = self.tokenizer(
            prompt_formatted,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_tensors="pt",
        ).to(device)
        prompt_masks = text_inputs.attention_mask.bool()
        with torch.no_grad():
            prompt_embeds = self.text_encoder(
                input_ids=text_inputs.input_ids,
                attention_mask=prompt_masks,
                output_hidden_states=True,
            ).hidden_states[-2]
        # Return only the non-padded token embeddings (second-to-last hidden layer) for this prompt.
        return prompt_embeds[0][prompt_masks[0]]

    def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None):
        shape = (batch_size, num_channels, height // self.vae_scale_factor, width // self.vae_scale_factor)
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)
        # Flow-matching schedulers may not expose `init_noise_sigma`; only scale when it is defined.
        return latents * self.scheduler.init_noise_sigma if hasattr(self.scheduler, "init_noise_sigma") else latents

    def prepare_control_image(self, image, width, height, batch_size, num_images_per_prompt, device, dtype):
        image = self.image_processor.preprocess(image, height=height, width=width).to(device=device, dtype=dtype)

        image_batch_size = image.shape[0]
        if image_batch_size == 1:
            # A single control image is shared by every sample in the batch.
            repeat_by = batch_size
        else:
            # One control image per prompt: repeat each for the requested images per prompt.
            repeat_by = num_images_per_prompt
        image = image.repeat_interleave(repeat_by, dim=0)
        return image

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        image: Union[torch.Tensor, Image.Image],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 0.0,
        controlnet_conditioning_scale: float = 1.0,
        num_images_per_prompt: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: str = "pil",
        return_dict: bool = True,
        **kwargs,
    ):
        r"""
        Generate images from `prompt`, conditioned on the control `image`. When `height`/`width` are not given,
        they default to the control image's size. Classifier-free guidance is applied when `guidance_scale > 0`.
        """
        device = self._execution_device
        height = height or image.height
        width = width or image.width

        if isinstance(prompt, str):
            prompt = [prompt]
        if isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt]

        batch_size = len(prompt) * num_images_per_prompt
        do_cfg = guidance_scale > 0.0

        # Encode each prompt; embeddings are variable-length (padding stripped), so keep them as a list.
        prompt_embeds_list = []
        for p in prompt:
            embed = self._encode_prompt(p, device, 512)
            for _ in range(num_images_per_prompt):
                prompt_embeds_list.append(embed)

        if do_cfg:
            if negative_prompt is None:
                negative_prompt = [""] * len(prompt)
            neg_embeds_list = []
            for neg in negative_prompt:
                embed = self._encode_prompt(neg, device, 512)
                for _ in range(num_images_per_prompt):
                    neg_embeds_list.append(embed)

            # Unconditional embeddings first, conditional second (matches the chunk order after the transformer).
            prompt_input = neg_embeds_list + prompt_embeds_list
        else:
            prompt_input = prompt_embeds_list

        # Preprocess the control image and encode it into VAE latent space.
        control_tensor = self.prepare_control_image(
            image, width, height, batch_size, num_images_per_prompt, device, self.vae.dtype
        )
        if control_tensor.ndim == 3:
            control_tensor = control_tensor.unsqueeze(0)

        control_latents = self.vae.encode(control_tensor).latent_dist.mode()
        control_latents = control_latents * self.vae.config.scaling_factor

        # If the VAE yields 4 latent channels but the transformer expects 16, tile the channels to match.
        if control_latents.shape[1] == 4 and self.transformer.in_channels == 16:
            control_latents = control_latents.repeat(1, 4, 1, 1)

        control_latents = control_latents.to(dtype=self.transformer.dtype)

        # The transformer takes a list of per-sample tensors with a frame dimension: (C, 1, H, W).
        control_latents = control_latents.unsqueeze(2)
        control_context = list(control_latents.unbind(0))

        if do_cfg:
            # Duplicate the control context for the unconditional and conditional branches.
            control_context_input = control_context * 2
        else:
            control_context_input = control_context

        # Sample the initial noise latents.
        latents = self.prepare_latents(
            batch_size,
            self.transformer.in_channels,
            height,
            width,
            prompt_embeds_list[0].dtype,
            device,
            generator,
        )
        latents = latents.to(self.transformer.dtype)

        # Set up the flow-matching timestep schedule with a resolution-dependent shift.
        image_seq_len = (height // self.vae_scale_factor) * (width // self.vae_scale_factor)
        mu = calculate_shift(image_seq_len)
        self.scheduler.set_timesteps(num_inference_steps, device=device, mu=mu)

        # Denoising loop.
        for t in self.progress_bar(self.scheduler.timesteps):
            t_input = t.expand(len(prompt_input))
            # Convert the scheduler's 0-1000 timesteps to the transformer's normalized time in [0, 1].
            timestep_norm = (1000.0 - t_input) / 1000.0

            latents_input = torch.cat([latents] * 2) if do_cfg else latents

            # Unbind the batch into a list of per-sample tensors of shape (C, 1, H, W).
            latent_list = list(latents_input.unsqueeze(2).unbind(dim=0))

            model_out_list = self.transformer(
                x=latent_list,
                t=timestep_norm,
                cap_feats=prompt_input,
                control_context=control_context_input,
                conditioning_scale=controlnet_conditioning_scale,
            )[0]

            model_out = torch.stack(model_out_list, dim=0).squeeze(2)

            if do_cfg:
                neg_out, pos_out = model_out.chunk(2)
                noise_pred = neg_out + guidance_scale * (pos_out - neg_out)
            else:
                noise_pred = model_out

            # Flip the sign of the prediction to match the scheduler's update convention.
            noise_pred = -noise_pred
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        # Decode the latents unless raw latents were requested.
        if output_type != "latent":
            latents_for_vae = latents.to(self.vae.dtype)
            latents_for_vae = latents_for_vae / self.vae.config.scaling_factor + self.vae.config.shift_factor
            image = self.vae.decode(latents_for_vae, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)
        else:
            image = latents

        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)
        return ZImagePipelineOutput(images=image)
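

# Minimal usage sketch, kept behind a __main__ guard so importing the module has no side effects.
# The checkpoint path and control image file below are placeholders, not real artifacts; substitute
# an actual Z-Image control checkpoint laid out for `from_pretrained`.
if __name__ == "__main__":
    model_path = "path/to/z-image-control-checkpoint"  # hypothetical
    pipe = ZImageControlUnifiedPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    control_image = Image.open("control.png").convert("RGB")  # hypothetical control condition image
    result = pipe(
        prompt="a cozy cabin in a snowy forest at dusk",
        image=control_image,
        num_inference_steps=50,
        guidance_scale=0.0,  # values > 0 enable classifier-free guidance in this pipeline
        controlnet_conditioning_scale=1.0,
    )
    result.images[0].save("output.png")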