import functools
import importlib
import os
from typing import Any, Callable, List, Mapping, Tuple
from urllib.parse import urlparse

import numpy as np
import torch
from torch import Tensor
from torch.hub import download_url_to_file, get_dir
from torch.nn import functional as F
| |
|
| |
|
def get_obj_from_str(string: str, reload: bool=False) -> Any:
    """Resolve a dotted path such as ``"package.module.Name"`` to the object it names.

    Args:
        string: Dotted path; everything before the last dot is imported as a
            module and the part after it is looked up as an attribute.
        reload: When true, the module is re-imported with ``importlib.reload``
            before the attribute lookup.

    Returns:
        The attribute found on the imported module.

    Raises:
        ValueError: If ``string`` contains no dot (the split cannot be unpacked).
        ImportError: If the module part cannot be imported.
        AttributeError: If the module has no attribute with the given name.
    """
    module_path, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_path))
    resolved_module = importlib.import_module(module_path, package=None)
    return getattr(resolved_module, attr_name)
| |
|
| |
|
def instantiate_from_config(config: Mapping[str, Any]) -> Any:
    """Build an object from a config mapping.

    Args:
        config: Mapping with a ``"target"`` entry holding a dotted path that is
            resolved through ``get_obj_from_str``, and an optional ``"params"``
            entry whose items are passed to the resolved callable as keyword
            arguments.

    Returns:
        Whatever the resolved callable returns.

    Raises:
        KeyError: If ``config`` has no ``"target"`` key.
    """
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")

    factory = get_obj_from_str(config["target"])
    kwargs = config.get("params", dict())
    return factory(**kwargs)
| |
|
| |
|
def wavelet_blur(image: Tensor, radius: int):
    """
    Apply wavelet blur to the input tensor.

    Every channel is convolved independently with the same fixed 3x3 kernel,
    dilated by ``radius``. The input is replicate-padded by ``radius`` on each
    side first, so the output has the same spatial size as the input.

    Args:
        image: Tensor whose last three dimensions are (channels, height, width).
            Any number of channels is supported.
        radius: Dilation of the kernel and amount of padding per side.

    Returns:
        The blurred tensor, same shape as ``image``.
    """
    # Read the channel count from the input instead of assuming RGB.
    # Indexing from the end keeps this correct for unbatched (C, H, W) input
    # where the installed torch supports it.
    channels = image.shape[-3]

    # 3x3 smoothing kernel; the weights sum to 1 so constant regions are preserved.
    kernel_vals = [
        [0.0625, 0.125, 0.0625],
        [0.125, 0.25, 0.125],
        [0.0625, 0.125, 0.0625],
    ]
    kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
    # Shape (channels, 1, 3, 3): one identical filter per channel for a grouped conv.
    kernel = kernel[None, None].repeat(channels, 1, 1, 1)

    image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
    # groups=channels makes the convolution depthwise (no mixing between channels).
    output = F.conv2d(image, kernel, groups=channels, dilation=radius)
    return output
| |
|
| |
|
def wavelet_decomposition(image: Tensor, levels=5):
    """
    Apply wavelet decomposition to the input tensor.
    This function only returns the low frequency & the high frequency.

    At level ``i`` the current image is blurred with ``wavelet_blur`` using
    radius ``2 ** i``; the difference between the image and its blurred version
    is accumulated as high frequency, and the blurred version becomes the input
    of the next level.

    Args:
        image: Tensor accepted by ``wavelet_blur``.
        levels: Number of blur levels. With ``levels <= 0`` no blur is applied.

    Returns:
        ``(high_freq, low_freq)``. With ``levels <= 0`` this is an all-zero
        tensor and the unchanged input.
    """
    high_freq = torch.zeros_like(image)
    # Bind low_freq up front so that levels <= 0 returns a defined value
    # instead of failing on an unassigned local.
    low_freq = image
    for level in range(levels):
        blurred = wavelet_blur(low_freq, 2 ** level)
        high_freq += (low_freq - blurred)
        low_freq = blurred

    return high_freq, low_freq
| |
|
| |
|
def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
    """
    Merge the high frequency of the content with the low frequency of the style,
    so that the content will have the same color as the style.

    Both inputs are decomposed with ``wavelet_decomposition`` at its default
    number of levels.
    """
    # Index the returned pair directly so the unused half of each
    # decomposition is released right away instead of living until return.
    detail = wavelet_decomposition(content_feat)[0]
    base = wavelet_decomposition(style_feat)[1]
    return detail + base
| |
|
| |
|
| | |
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Load file from http url, will download models if necessary.

    The file is only downloaded when no file exists yet at the target path;
    otherwise the existing path is returned untouched.

    Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py

    Args:
        url (str): URL to be downloaded.
        model_dir (str): Directory in which the file is stored. It is created
            if missing. If None, the ``checkpoints`` folder inside the pytorch
            hub dir is used. Default: None.
        progress (bool): Whether to show the download progress. Default: True.
        file_name (str): Name of the local file. If None, the last path
            component of the url is used. Default: None.

    Returns:
        str: The absolute path to the (possibly just downloaded) file.
    """
    if model_dir is None:
        model_dir = os.path.join(get_dir(), 'checkpoints')
    os.makedirs(model_dir, exist_ok=True)

    # The url is always parsed, even when an explicit file name overrides it.
    url_basename = os.path.basename(urlparse(url).path)
    target_name = url_basename if file_name is None else file_name
    cached_file = os.path.abspath(os.path.join(model_dir, target_name))

    if os.path.exists(cached_file):
        return cached_file

    print(f'Downloading: "{url}" to {cached_file}\n')
    download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file
| |
|
| |
|
def sliding_windows(h: int, w: int, tile_size: int, tile_stride: int) -> List[Tuple[int, int, int, int]]:
    """Enumerate square tiles that cover an ``h`` x ``w`` grid.

    Tiles start at multiples of ``tile_stride`` along each axis. When the last
    regular tile does not end exactly on the border, one extra tile flush with
    the border is appended so the whole extent is covered.

    Args:
        h: Height of the grid.
        w: Width of the grid.
        tile_size: Side length of every tile.
        tile_stride: Distance between the starts of consecutive tiles.

    Returns:
        A list of ``(top, bottom, left, right)`` tuples in row-major order,
        with ``bottom - top == right - left == tile_size``.

    Raises:
        ValueError: If ``tile_size`` or ``tile_stride`` is not positive, or if
            the tile does not fit inside the grid.
    """
    if tile_size <= 0 or tile_stride <= 0:
        raise ValueError(
            f"tile_size and tile_stride must be positive, got {tile_size} and {tile_stride}"
        )
    # Without this check a too-large tile silently produces negative start
    # offsets (or no tiles at all), which callers would misuse as slice bounds.
    if tile_size > h or tile_size > w:
        raise ValueError(
            f"tile_size {tile_size} does not fit into a {h}x{w} grid"
        )

    def _starts(extent: int) -> List[int]:
        # Regular grid of start offsets plus a final border-aligned one if needed.
        last_start = extent - tile_size
        starts = list(range(0, last_start + 1, tile_stride))
        if last_start % tile_stride != 0:
            starts.append(last_start)
        return starts

    hi_list = _starts(h)
    wi_list = _starts(w)
    return [
        (hi, hi + tile_size, wi, wi + tile_size)
        for hi in hi_list
        for wi in wi_list
    ]
| |
|
| |
|
| | |
def gaussian_weights(tile_width: int, tile_height: int) -> np.ndarray:
    """Generates a gaussian mask of weights for tile contributions.

    The mask is the outer product of two 1-D Gaussian profiles, one per axis.
    Each profile is centred on the middle of its axis and evaluated on
    positions normalised by the axis length, with a fixed variance of 0.01.

    Args:
        tile_width: Number of columns of the mask.
        tile_height: Number of rows of the mask.

    Returns:
        Array of shape ``(tile_height, tile_width)``, symmetric under both
        horizontal and vertical flips.
    """
    var = 0.01

    def _axis_profile(length: int) -> np.ndarray:
        # Centre on (length - 1) / 2 for BOTH axes; using length / 2 for one
        # axis shifts the peak by half a pixel and breaks the symmetry.
        midpoint = (length - 1) / 2
        offsets = np.arange(length) - midpoint
        exponent = -(offsets * offsets) / (length * length) / (2 * var)
        return np.exp(exponent) / np.sqrt(2 * np.pi * var)

    weights = np.outer(_axis_profile(tile_height), _axis_profile(tile_width))
    return weights
| |
|
| |
|
def _env_flag(name: str) -> bool:
    """Interpret environment variable ``name`` as an on/off switch."""
    # A plain bool() on the raw string would treat "0" or "false" as enabled,
    # because every non-empty string is truthy.
    raw = os.environ.get(name, "")
    return raw.strip().lower() not in ("", "0", "false", "no", "off")


# Read once at import time; set COUNT_VRAM=1 in the environment to enable.
COUNT_VRAM = _env_flag("COUNT_VRAM")


def count_vram_usage(func: Callable) -> Callable:
    """Decorator that prints the CUDA peak memory around each call of ``func``.

    When ``COUNT_VRAM`` is false the function is returned unchanged, so the
    decorator adds no overhead. Otherwise each call prints the value of
    ``torch.cuda.max_memory_allocated()`` (divided by 1024 ** 3) taken before
    the call and after a ``torch.cuda.synchronize()`` following it.

    Args:
        func: The callable to instrument.

    Returns:
        ``func`` itself, or a wrapper with the same metadata and return value.
    """
    if not COUNT_VRAM:
        return func

    @functools.wraps(func)  # keep __name__ / __doc__ of the wrapped function
    def wrapper(*args, **kwargs):
        peak_before = torch.cuda.max_memory_allocated() / (1024 ** 3)
        ret = func(*args, **kwargs)
        # Wait for queued GPU work so the peak reflects the finished call.
        torch.cuda.synchronize()
        peak_after = torch.cuda.max_memory_allocated() / (1024 ** 3)
        print(f"VRAM peak before {func.__name__}: {peak_before:.5f} GB, after: {peak_after:.5f} GB")
        return ret
    return wrapper