# NOTE(review): removed extraction artifacts — "Spaces:" / "Runtime error"
# were page-header residue from the hosting page, not part of this module.
import os

import cv2
import numpy as np
import tensorflow as tf

from source.facelib.facer import FaceAna
import source.utils as utils
from source.mtcnn_pytorch.src.align_trans import warp_and_crop_face, get_reference_facial_points

# TF2 compatibility: this file drives TF1-style frozen graphs, so route all
# tf.* calls through the v1 compat layer and turn off eager execution.
# BUG FIX: the original used `tf.__version__ >= '2.0'`, a lexicographic
# string comparison that misorders versions (e.g. '10.0' < '2.0').
# Compare the parsed integer major version instead.
if int(tf.__version__.split('.')[0]) >= 2:
    tf = tf.compat.v1
    tf.disable_eager_execution()
class Cartoonizer():
    """Anime-style image cartoonizer built on two frozen TF graphs.

    A background model stylizes the whole frame; a head model stylizes each
    detected face after alignment to a canonical crop.  Stylized heads are
    alpha-blended back into the stylized background using a feathering mask.
    """

    def __init__(self, dataroot):
        """Load the face analyzer, both frozen graphs, and the blend mask.

        Args:
            dataroot: Directory containing 'cartoon_anime_h.pb',
                'cartoon_anime_bg.pb', 'alpha.jpg', and the face-lib assets.
        """
        self.facer = FaceAna(dataroot)
        self.sess_head = self.load_sess(
            os.path.join(dataroot, 'cartoon_anime_h.pb'), 'model_head')
        self.sess_bg = self.load_sess(
            os.path.join(dataroot, 'cartoon_anime_bg.pb'), 'model_bg')

        # Side length (px) of the aligned head crop fed to the head model.
        self.box_width = 288

        global_mask = cv2.imread(os.path.join(dataroot, 'alpha.jpg'))
        global_mask = cv2.resize(
            global_mask, (self.box_width, self.box_width),
            interpolation=cv2.INTER_AREA)
        # Single-channel float32 mask in [0, 1] used to feather the stylized
        # head back into the stylized background.
        self.global_mask = cv2.cvtColor(
            global_mask, cv2.COLOR_BGR2GRAY).astype(np.float32) / 255.0

    def load_sess(self, model_path, name):
        """Create a TF session and import a frozen GraphDef under `name`.

        Args:
            model_path: Path to a serialized frozen-graph .pb file.
            name: Name scope to import the graph under (prefixes tensor
                names, e.g. 'model_bg/input_image:0').

        Returns:
            A tf.Session with the graph imported.
        """
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)
        print(f'loading model from {model_path}')
        # FIX: FastGFile is deprecated; GFile is the supported equivalent
        # with the same read API under tf.compat.v1.
        with tf.gfile.GFile(model_path, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            sess.graph.as_default()
            tf.import_graph_def(graph_def, name=name)
        # Harmless for a frozen graph (no variables), kept for parity with
        # the original initialization sequence.
        sess.run(tf.global_variables_initializer())
        print(f'load model {model_path} done.')
        return sess

    def detect_face(self, img):
        """Detect faces in `img` (BGR) and return their landmarks.

        Args:
            img: BGR uint8 image array.

        Returns:
            Landmark array for all detected faces, or None when the
            detector reports zero boxes.
        """
        # FIX: dropped unused locals `src_h, src_w` from the original.
        src_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        boxes, landmarks, _ = self.facer.run(src_rgb)
        if boxes.shape[0] == 0:
            return None
        return landmarks

    def cartoonize(self, img):
        """Cartoonize an image and return it at the original resolution.

        Args:
            img: RGB input image of shape (h, w, 3).

        Returns:
            Stylized image resized back to (h, w).  NOTE(review): channel
            order of the result is whatever the background model emits —
            both models are fed BGR below, so the output is presumably
            BGR; confirm against callers.
        """
        ori_h, ori_w, _ = img.shape
        # Bound the working resolution for speed/memory.
        img = utils.resize_size(img, size=720)
        img_bgr = img[:, :, ::-1]  # FIX: renamed from misspelled `img_brg`

        # Background pass: pad to the size the model requires (presumably a
        # multiple of 16 — see utils.padTo16x), run, then crop padding off.
        pad_bg, pad_h, pad_w = utils.padTo16x(img_bgr)
        bg_res = self.sess_bg.run(
            self.sess_bg.graph.get_tensor_by_name('model_bg/output_image:0'),
            feed_dict={'model_bg/input_image:0': pad_bg})
        res = bg_res[:pad_h, :pad_w, :]

        landmarks = self.detect_face(img_bgr)
        if landmarks is None:
            print('No face detected!')
            return res
        print('%d faces detected!' % len(landmarks))

        # Loop-invariant frame size, hoisted out of the per-face loop
        # (original recomputed np.size(img, 0/1) per face).
        frame_w, frame_h = img.shape[1], img.shape[0]

        for landmark in landmarks:
            # get facial 5 points
            f5p = utils.get_f5p(landmark, img_bgr)
            # Align the face into a canonical box_width crop; keep the
            # inverse affine so results can be warped back into the frame.
            head_img, trans_inv = warp_and_crop_face(
                img,
                f5p,
                ratio=0.75,
                reference_pts=get_reference_facial_points(default_square=True),
                crop_size=(self.box_width, self.box_width),
                return_trans_inv=True)
            # Head pass (fed BGR, mirroring the background feed).
            head_res = self.sess_head.run(
                self.sess_head.graph.get_tensor_by_name(
                    'model_head/output_image:0'),
                feed_dict={
                    'model_head/input_image:0': head_img[:, :, ::-1]
                })
            # Warp stylized head and feather mask back to frame coordinates,
            # then alpha-blend over the stylized background.
            head_trans_inv = cv2.warpAffine(
                head_res,
                trans_inv, (frame_w, frame_h),
                borderValue=(0, 0, 0))
            mask_trans_inv = cv2.warpAffine(
                self.global_mask,
                trans_inv, (frame_w, frame_h),
                borderValue=(0, 0, 0))
            mask_trans_inv = np.expand_dims(mask_trans_inv, 2)
            res = mask_trans_inv * head_trans_inv + (1 - mask_trans_inv) * res

        res = cv2.resize(res, (ori_w, ori_h), interpolation=cv2.INTER_AREA)
        return res