| import gradio as gr | |
| import os | |
| import shutil | |
| import torch | |
| from PIL import Image | |
| import argparse | |
| import pathlib | |
# One-time setup: fetch the upstream Thin-Plate-Spline Motion Model repository
# and its VoxCeleb checkpoint, then work from inside the repository, because
# demo.py, config/, checkpoints/ and assets/ are all referenced by relative
# paths further down.
if not os.path.isdir("Thin-Plate-Spline-Motion-Model"):
    # Skip the clone on restarts; re-cloning into an existing directory fails.
    os.system("git clone https://github.com/yoyo-nb/Thin-Plate-Spline-Motion-Model")
os.chdir("Thin-Plate-Spline-Motion-Model")
# exist_ok avoids the "File exists" error the shell mkdir gave on restarts.
os.makedirs("checkpoints", exist_ok=True)
# The URL is quoted because '?' is a shell glob character. wget -c resumes a
# partial download rather than starting over.
os.system("wget -c 'https://cloud.tsinghua.edu.cn/f/da8d61d012014b12a9e4/?dl=1' -O checkpoints/vox.pth.tar")
# Page heading rendered as the first gr.Markdown in main().
title = "# 图片动画"
# Intro block (markdown + HTML): paper/code links and an overview animation.
# The stray closing </b> (it had no opening <b>) has been removed.
DESCRIPTION = '''### 图片动画的Gradio实现, CVPR 2022. <a href='https://arxiv.org/abs/2203.14367'>[Paper]</a><a href='https://github.com/yoyo-nb/Thin-Plate-Spline-Motion-Model'>[Github Code]</a>
<img id="overview" alt="overview" src="https://github.com/yoyo-nb/Thin-Plate-Spline-Motion-Model/raw/main/assets/vox.gif" />
'''
# Visitor-counter badge shown at the bottom of the page.
FOOTER = '<img id="visitor-badge" alt="visitor badge" src="https://visitor-badge.glitch.me/badge?page_id=gradio-blocks.Image-Animation-using-Thin-Plate-Spline-Motion-Model" />'
# Link back to the tool index. Not referenced by main() in this file.
# An orphan markdown code fence (```) that used to end this string has been
# removed, since it would open an empty code block when rendered.
ARTICLE = r"""
---
<h2 style="font-weight: 900; margin-bottom: 7px;">点击<a href='https://www.toolchest.cn' target='_blank'>返回智能工具箱</a>查看更多好玩的人工智能项目</h2>
"""
def get_style_image_path(style_name: str) -> str:
    """Return the relative path of the bundled asset for *style_name*.

    Only 'source' (an image) and 'driving' (a video) are known; any other
    name raises KeyError.
    """
    asset_files = {'source': 'source.png', 'driving': 'driving.mp4'}
    return 'assets/' + asset_files[style_name]
def get_style_image_markdown_text(style_name: str) -> str:
    """Return an HTML <img> tag that displays the asset for *style_name*."""
    src = get_style_image_path(style_name)
    return '<img id="style-image" src="{}" alt="style image">'.format(src)
def update_style_image(style_name: str) -> dict:
    """Build a gr.Markdown update whose value is the <img> tag for *style_name*."""
    return gr.Markdown.update(value=get_style_image_markdown_text(style_name))
def set_example_image(example: list) -> dict:
    """Return an image update that loads the first entry of a clicked example row."""
    sample_path = example[0]
    return gr.Image.update(value=sample_path)
def set_example_video(example: list) -> dict:
    """Return a video update that loads the first entry of a clicked example row."""
    sample_path = example[0]
    return gr.Video.update(value=sample_path)
def inference(img, vid):
    """Animate the face in *img* with the motion from the driving video *vid*.

    Saves the image as temp/image.jpg and runs the repository's demo.py on the
    CPU, using config/vox-256.yaml and ./checkpoints/vox.pth.tar.

    Args:
        img: The source face image, as a PIL image (the UI uses type="pil").
        vid: Path to the driving video file, as supplied by the video input.

    Returns:
        './temp/result.mp4', the path of the generated video.

    Raises:
        ValueError: If the image or the video is missing.
        subprocess.CalledProcessError: If demo.py exits with a non-zero status.
        FileNotFoundError: If demo.py reports success but wrote no result file.
    """
    import subprocess
    import sys

    if img is None or not vid:
        raise ValueError("Both a source image and a driving video are required.")

    source_path = "temp/image.jpg"
    result_path = './temp/result.mp4'
    os.makedirs('temp', exist_ok=True)
    # JPEG cannot store alpha or palette modes, so normalise to RGB first.
    img.convert("RGB").save(source_path, "JPEG")
    # Remove any previous output, so a failed run can never hand back an old
    # (possibly another user's) result.
    if os.path.exists(result_path):
        os.remove(result_path)
    # Arguments are passed as a list with no shell involved, so an uploaded
    # file name containing spaces or shell metacharacters is passed through
    # verbatim rather than being interpreted.
    subprocess.run(
        [
            sys.executable, "demo.py",
            "--config", "config/vox-256.yaml",
            "--checkpoint", "./checkpoints/vox.pth.tar",
            "--source_image", source_path,
            "--driving_video", str(vid),
            "--result_video", result_path,
            "--cpu",
        ],
        check=True,
    )
    if not os.path.exists(result_path):
        raise FileNotFoundError(result_path)
    return result_path
def main():
    """Build the three-step image-animation page and launch the Gradio server.

    Step 1 collects a face image, offering assets/*.png as clickable examples.
    Step 2 collects a driving video, offering assets/*.mp4 as examples.
    Step 3 has a button that runs `inference` on both inputs and shows the
    resulting video.
    """

    def _asset_samples(pattern):
        # One single-column sample row per file in assets/ matching *pattern*,
        # sorted by path.
        return [[p.as_posix()] for p in sorted(pathlib.Path('assets').glob(pattern))]

    step1_text = ('## 第1步 (上传人脸图片)\n'
                  '- 拖一张含人脸的图片到 **输入图片**.\n'
                  '- 如果图片中有多张人脸, 使用右上角的编辑按钮裁剪图片.\n')
    step2_text = ('## 第2步 (选择动态视频)\n'
                  '- **为人脸图片选择目标视频**.\n')
    step3_text = ('## 第3步 (基于视频生成动态图片)\n'
                  '- 点击 **开始** 按钮. (注意: 由于是在CPU上运行, 生成最终结果需要花费大约3分钟.)\n')

    with gr.Blocks(theme="huggingface", css='style.css') as demo:
        gr.Markdown(title)
        gr.Markdown(DESCRIPTION)

        # Step 1: source face image, plus example images.
        with gr.Box():
            gr.Markdown(step1_text)
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        input_image = gr.Image(label='输入图片', type="pil")
                    with gr.Row():
                        example_images = gr.Dataset(
                            components=[input_image],
                            samples=_asset_samples('*.png'),
                        )

        # Step 2: driving video, plus example videos.
        with gr.Box():
            gr.Markdown(step2_text)
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        driving_video = gr.Video(label='目标视频', format="mp4")
                    with gr.Row():
                        example_video = gr.Dataset(
                            components=[driving_video],
                            samples=_asset_samples('*.mp4'),
                        )

        # Step 3: run the model and show the output.
        with gr.Box():
            gr.Markdown(step3_text)
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        generate_button = gr.Button('开始')
                with gr.Column():
                    result = gr.Video(type="file", label="输出")

        gr.Markdown(FOOTER)

        # Event wiring. Registration order is kept so handler indices stay
        # the same.
        generate_button.click(
            fn=inference,
            inputs=[input_image, driving_video],
            outputs=result,
        )
        example_images.click(
            fn=set_example_image,
            inputs=example_images,
            outputs=example_images.components,
        )
        example_video.click(
            fn=set_example_video,
            inputs=example_video,
            outputs=example_video.components,
        )

    demo.launch(enable_queue=True, debug=True)
# Build the UI and start the server only when executed as a script.
if __name__ == "__main__":
    main()