| | import argparse, os, sys, glob, datetime, yaml |
| | import torch |
| | import time |
| | import numpy as np |
| | from tqdm import trange |
| |
|
| | from omegaconf import OmegaConf |
| | from PIL import Image |
| |
|
| | from ldm.models.diffusion.ddim import DDIMSampler |
| | from ldm.util import instantiate_from_config |
| |
|
def rescale(x):
    """Linearly map ``x`` so that -1 becomes 0 and 1 becomes 1.

    Works on anything supporting ``+`` and ``/`` with floats (Python numbers,
    numpy arrays, torch tensors).
    """
    return (x + 1.) / 2.
| |
|
def custom_to_pil(x):
    """Convert a single image tensor into an RGB PIL image.

    The tensor is detached, moved to the CPU, clamped to [-1, 1] and mapped
    linearly onto [0, 255] before being handed to PIL.

    Args:
        x: three-dimensional torch tensor. Its first axis is moved to the
            end before conversion (presumably a channels-first image;
            confirm against callers).

    Returns:
        A PIL image in mode "RGB".
    """
    clamped = torch.clamp(x.detach().cpu(), -1., 1.)
    unit_range = (clamped + 1.) / 2.
    # astype() truncates toward zero rather than rounding.
    pixels = (255 * unit_range.permute(1, 2, 0).numpy()).astype(np.uint8)
    image = Image.fromarray(pixels)
    # Normalise whatever mode PIL picked for the array to RGB.
    return image if image.mode == "RGB" else image.convert("RGB")
| |
|
| |
|
def custom_to_np(x):
    """Quantise a batch of images to uint8 with the second axis moved last.

    Values are mapped with ``(x + 1) * 127.5``, clamped to [0, 255] and cast
    to ``torch.uint8`` (the cast truncates). Axis order changes from
    (0, 1, 2, 3) to (0, 2, 3, 1).

    Note: despite the name, the result is a contiguous CPU ``torch.Tensor``,
    not a numpy array.

    Args:
        x: four-dimensional torch tensor.

    Returns:
        A contiguous ``torch.uint8`` tensor on the CPU.
    """
    on_cpu = x.detach().cpu()
    quantised = torch.clamp((on_cpu + 1) * 127.5, 0, 255).to(torch.uint8)
    return quantised.permute(0, 2, 3, 1).contiguous()
| |
|
| |
|
def logs2pil(logs, keys=("sample",)):
    """Convert every entry of a log dict into a PIL image where possible.

    Four-dimensional entries are treated as batches and only their first
    element is converted; three-dimensional entries are converted directly.
    Anything else, or anything whose conversion fails, maps to ``None``.

    Args:
        logs: mapping from name to value; values are expected to expose a
            ``shape`` attribute.
        keys: currently unused; kept so existing callers keep working. The
            default is an immutable tuple to avoid a shared mutable default.

    Returns:
        dict with the same keys as ``logs`` whose values are PIL images or
        ``None``.
    """
    imgs = dict()
    for k in logs:
        try:
            value = logs[k]
            ndim = len(value.shape)
            if ndim == 4:
                img = custom_to_pil(value[0, ...])
            elif ndim == 3:
                img = custom_to_pil(value)
            else:
                print(f"Unknown format for key {k}. ")
                img = None
        except Exception:
            # Best-effort conversion: a failing entry becomes None, but
            # KeyboardInterrupt / SystemExit are no longer swallowed.
            img = None
        imgs[k] = img
    return imgs
| |
|
| |
|
@torch.no_grad()
def convsample(model, shape, return_intermediates=True,
               verbose=True,
               make_prog_row=False):
    """Sample from ``model`` using one of its own sampling loops.

    Args:
        model: object providing ``p_sample_loop`` and
            ``progressive_denoising``.
        shape: passed through to the chosen sampler unchanged.
        return_intermediates: forwarded to ``p_sample_loop``; not used when
            ``make_prog_row`` is true.
        verbose: forwarded to whichever sampler is called.
        make_prog_row: when true, call ``progressive_denoising`` instead of
            ``p_sample_loop``.

    Returns:
        Whatever the selected sampler returns.
    """
    if make_prog_row:
        # Honour the caller's verbosity here too (was hard-coded to True).
        return model.progressive_denoising(None, shape, verbose=verbose)
    return model.p_sample_loop(
        None, shape,
        return_intermediates=return_intermediates, verbose=verbose,
    )
| |
|
| |
|
@torch.no_grad()
def convsample_ddim(model, steps, shape, eta=1.0):
    """Sample from ``model`` through a ``DDIMSampler``.

    Args:
        model: model handed to ``DDIMSampler``.
        steps: first positional argument of ``DDIMSampler.sample``.
        shape: sequence whose first element is used as the batch size and
            whose remaining elements are passed as the per-sample shape.
        eta: forwarded to the sampler.

    Returns:
        The ``(samples, intermediates)`` pair produced by the sampler.
    """
    sampler = DDIMSampler(model)
    samples, intermediates = sampler.sample(
        steps,
        batch_size=shape[0],
        shape=shape[1:],
        eta=eta,
        verbose=False,
    )
    return samples, intermediates
| |
|
| |
|
@torch.no_grad()
def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=None, eta=1.0,):
    """Draw one batch of samples, decode it and report timing.

    Sampling runs inside ``model.ema_scope("Plotting")``. The sample shape
    is built from ``in_channels`` and ``image_size`` of
    ``model.model.diffusion_model``.

    Args:
        model: diffusion model wrapper.
        batch_size: number of samples to draw.
        vanilla: use ``convsample`` (progressive denoising) when true,
            otherwise ``convsample_ddim``.
        custom_steps: step count forwarded to ``convsample_ddim``.
        eta: forwarded to ``convsample_ddim``.

    Returns:
        dict with keys ``"sample"`` (output of ``decode_first_stage``),
        ``"time"`` (seconds spent sampling) and ``"throughput"`` (samples
        per second of sampling time).
    """
    diffusion_model = model.model.diffusion_model
    shape = [
        batch_size,
        diffusion_model.in_channels,
        diffusion_model.image_size,
        diffusion_model.image_size,
    ]

    with model.ema_scope("Plotting"):
        started = time.time()
        # The second return value of either sampler is not used here.
        if vanilla:
            sample, _ = convsample(model, shape, make_prog_row=True)
        else:
            sample, _ = convsample_ddim(model, steps=custom_steps, shape=shape, eta=eta)
        finished = time.time()

    # Decoding is deliberately excluded from the measured time.
    x_sample = model.decode_first_stage(sample)
    elapsed = finished - started

    log = {
        "sample": x_sample,
        "time": elapsed,
        "throughput": sample.shape[0] / elapsed,
    }
    print(f'Throughput for this batch: {log["throughput"]}')
    return log
| |
|
def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None):
    """Sample images from an unconditional model and write them to disk.

    Every batch is saved as PNG files in ``logdir``; when ``nplog`` is given,
    all samples (trimmed to ``n_samples``) are additionally stored in a
    single ``.npz`` file there.

    Args:
        model: diffusion model; must have ``cond_stage_model`` set to None.
        logdir: directory receiving the PNG files. PNGs already present are
            counted so numbering continues after them.
        batch_size: number of samples per sampling call.
        vanilla: use the model's own sampling loop instead of DDIM.
        custom_steps: number of DDIM steps.
        eta: DDIM eta.
        n_samples: number of samples requested.
        nplog: directory for the ``.npz`` export, or None to skip it.

    Raises:
        NotImplementedError: if the model has a conditioning stage.
    """
    if vanilla:
        print(f'Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.')
    else:
        print(f'Using DDIM sampling with {custom_steps} sampling steps and eta={eta}')

    if model.cond_stage_model is not None:
        raise NotImplementedError('Currently only sampling for unconditional models supported.')

    tstart = time.time()
    # Number of images already on disk == index of the next image to write.
    # (The previous "- 1" made the first file "sample_-00001.png" and the
    # reported total one too small.)
    n_saved = len(glob.glob(os.path.join(logdir, '*.png')))

    # Round up so requests that are not a multiple of batch_size (including
    # n_samples < batch_size) still yield enough samples; the surplus is
    # trimmed from the .npz export below.
    n_batches = -(-n_samples // batch_size)

    all_images = []
    print(f"Running unconditional sampling for {n_samples} samples")
    for _ in trange(n_batches, desc="Sampling Batches (unconditional)"):
        logs = make_convolutional_sample(model, batch_size=batch_size,
                                         vanilla=vanilla, custom_steps=custom_steps,
                                         eta=eta)
        n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample")
        all_images.append(custom_to_np(logs["sample"]))
        if n_saved >= n_samples:
            print(f'Finish after generating {n_saved} samples')
            break

    if nplog is None:
        # The default used to crash in os.path.join after all sampling work
        # had already been done.
        print("No numpy log directory given; skipping .npz export.")
    elif all_images:
        all_img = np.concatenate(all_images, axis=0)
        all_img = all_img[:n_samples]
        shape_str = "x".join(str(dim) for dim in all_img.shape)
        nppath = os.path.join(nplog, f"{shape_str}-samples.npz")
        np.savez(nppath, all_img)

    print(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.")
| |
|
| |
|
def save_logs(logs, path, n_saved=0, key="sample", np_path=None):
    """Write the batch stored under ``key`` to disk and update the counter.

    Without ``np_path`` every element of the batch is written as
    ``{key}_{index:06}.png`` into ``path``. With ``np_path`` the whole batch
    is converted by ``custom_to_np`` and stored as one ``.npz`` file there.

    Args:
        logs: mapping that may contain ``key``.
        path: directory for PNG output.
        n_saved: number of items saved so far; used as the next index.
        key: entry of ``logs`` to save.
        np_path: directory for ``.npz`` output, or None for PNG output.

    Returns:
        The updated count; unchanged when ``key`` is not in ``logs``.
    """
    if key not in logs:
        return n_saved

    batch = logs[key]
    if np_path is not None:
        npbatch = custom_to_np(batch)
        dims = "x".join(map(str, npbatch.shape))
        np.savez(os.path.join(np_path, f"{n_saved}-{dims}-samples.npz"), npbatch)
        return n_saved + npbatch.shape[0]

    for image in batch:
        target = os.path.join(path, f"{key}_{n_saved:06}.png")
        custom_to_pil(image).save(target)
        n_saved += 1
    return n_saved
| |
|
| |
|
def get_parser():
    """Build the command-line parser for the sampling script.

    Options: ``-r/--resume``, ``-n/--n_samples``, ``-e/--eta``,
    ``-v/--vanilla_sample``, ``-l/--logdir``, ``-c/--custom_steps`` and
    ``--batch_size``.

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    # (flags, add_argument keyword arguments), in the order they are shown
    # in the help output.
    option_specs = (
        (("-r", "--resume"),
         dict(type=str, nargs="?",
              help="load from logdir or checkpoint in logdir")),
        (("-n", "--n_samples"),
         dict(type=int, nargs="?", default=50000,
              help="number of samples to draw")),
        (("-e", "--eta"),
         dict(type=float, nargs="?", default=1.0,
              help="eta for ddim sampling (0.0 yields deterministic sampling)")),
        (("-v", "--vanilla_sample"),
         dict(default=False, action='store_true',
              help="vanilla sampling (default option is DDIM sampling)?")),
        (("-l", "--logdir"),
         dict(type=str, nargs="?", default="none",
              help="extra logdir")),
        (("-c", "--custom_steps"),
         dict(type=int, nargs="?", default=50,
              help="number of steps for ddim and fastdpm sampling")),
        (("--batch_size",),
         dict(type=int, nargs="?", default=10,
              help="the bs")),
    )

    parser = argparse.ArgumentParser()
    for flags, settings in option_specs:
        parser.add_argument(*flags, **settings)
    return parser
| |
|
| |
|
def load_model_from_config(config, sd):
    """Instantiate a model from ``config`` and load the weights ``sd``.

    The model is moved to CUDA unconditionally and put into eval mode.

    Args:
        config: configuration passed to ``instantiate_from_config``.
        sd: state dict handed to ``load_state_dict``.

    Returns:
        The instantiated model.
    """
    net = instantiate_from_config(config)
    # Non-strict loading; presumably tolerates key mismatches between the
    # checkpoint and the model (torch.nn.Module semantics, confirm).
    net.load_state_dict(sd, strict=False)
    net.cuda()
    net.eval()
    return net
| |
|
| |
|
def load_model(config, ckpt, gpu, eval_mode):
    """Load the model described by ``config.model`` from a checkpoint.

    Args:
        config: configuration object with a ``model`` attribute.
        ckpt: checkpoint path. When falsy, no file is read and a state dict
            of None is passed on.
        gpu: unused in this function.
        eval_mode: unused in this function.

    Returns:
        ``(model, global_step)``; ``global_step`` comes from the checkpoint's
        ``"global_step"`` entry and is None when no checkpoint was given.
    """
    if not ckpt:
        pl_sd = {"state_dict": None}
        global_step = None
    else:
        print(f"Loading model from {ckpt}")
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        pl_sd = torch.load(ckpt, map_location="cpu")
        global_step = pl_sd["global_step"]

    model = load_model_from_config(config.model, pl_sd["state_dict"])
    return model, global_step
| |
|
| |
|
if __name__ == "__main__":
    # Timestamp that gives every invocation its own output directory.
    now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    sys.path.append(os.getcwd())
    command = " ".join(sys.argv)

    parser = get_parser()
    # Arguments the parser does not know are kept and later interpreted as
    # OmegaConf dotlist overrides.
    opt, unknown = parser.parse_known_args()
    ckpt = None

    if opt.resume is None:
        # Without this check os.path.exists(None) fails with an unhelpful
        # TypeError.
        parser.error("-r/--resume is required: pass a log directory or a checkpoint file")
    if not os.path.exists(opt.resume):
        raise ValueError("Cannot find {}".format(opt.resume))
    if os.path.isfile(opt.resume):
        # A checkpoint file was given: its directory is the log directory.
        logdir = os.path.dirname(opt.resume)
        print(f'Logdir is {logdir}')
        ckpt = opt.resume
    else:
        if not os.path.isdir(opt.resume):
            raise ValueError(f"{opt.resume} is not a directory")
        logdir = opt.resume.rstrip("/")
        ckpt = os.path.join(logdir, "model.ckpt")

    base_configs = sorted(glob.glob(os.path.join(logdir, "config.yaml")))
    opt.base = base_configs

    # Command-line overrides are merged last so they win over config.yaml.
    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    config = OmegaConf.merge(*configs, cli)

    gpu = True
    eval_mode = True

    if opt.logdir != "none":
        # Re-root the output under the extra logdir, keeping the run's own
        # directory name. normpath handles trailing separators and an empty
        # logdir, which previously raised IndexError.
        locallog = os.path.basename(os.path.normpath(logdir))
        print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'")
        logdir = os.path.join(opt.logdir, locallog)

    print(config)

    model, global_step = load_model(config, ckpt, gpu, eval_mode)
    print(f"global step: {global_step}")
    print(75 * "=")
    print("logging to:")
    logdir = os.path.join(logdir, "samples", f"{global_step:08}", now)
    imglogdir = os.path.join(logdir, "img")
    numpylogdir = os.path.join(logdir, "numpy")

    os.makedirs(imglogdir)
    os.makedirs(numpylogdir)
    print(logdir)
    print(75 * "=")

    # Record the options used for this run next to its outputs.
    sampling_file = os.path.join(logdir, "sampling_config.yaml")
    sampling_conf = vars(opt)

    with open(sampling_file, 'w') as f:
        yaml.dump(sampling_conf, f, default_flow_style=False)
    print(sampling_conf)

    run(model, imglogdir, eta=opt.eta,
        vanilla=opt.vanilla_sample, n_samples=opt.n_samples, custom_steps=opt.custom_steps,
        batch_size=opt.batch_size, nplog=numpylogdir)

    print("done.")
| |
|