u0feff committed
Commit 7f09d23 · 1 Parent(s): 786e257
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. LICENSE +159 -0
  2. app.py +570 -0
  3. assets/BBOX_SHIFT.md +26 -0
  4. assets/demo/sit/sit.jpeg +0 -0
  5. assets/demo/yongen/yongen.jpeg +0 -0
  6. configs/inference/realtime.yaml +10 -0
  7. configs/inference/test.yaml +10 -0
  8. configs/training/gpu.yaml +21 -0
  9. configs/training/preprocess.yaml +31 -0
  10. configs/training/stage1.yaml +89 -0
  11. configs/training/stage2.yaml +89 -0
  12. configs/training/syncnet.yaml +19 -0
  13. download_weights.bat +45 -0
  14. download_weights.sh +51 -0
  15. entrypoint.sh +9 -0
  16. inference.sh +72 -0
  17. musetalk/data/audio.py +168 -0
  18. musetalk/data/dataset.py +607 -0
  19. musetalk/data/sample_method.py +233 -0
  20. musetalk/loss/basic_loss.py +81 -0
  21. musetalk/loss/conv.py +44 -0
  22. musetalk/loss/discriminator.py +145 -0
  23. musetalk/loss/resnet.py +152 -0
  24. musetalk/loss/syncnet.py +95 -0
  25. musetalk/loss/vgg_face.py +237 -0
  26. musetalk/models/syncnet.py +240 -0
  27. musetalk/models/unet.py +51 -0
  28. musetalk/models/vae.py +148 -0
  29. musetalk/utils/__init__.py +5 -0
  30. musetalk/utils/audio_processor.py +102 -0
  31. musetalk/utils/blending.py +136 -0
  32. musetalk/utils/dwpose/default_runtime.py +54 -0
  33. musetalk/utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py +257 -0
  34. musetalk/utils/face_detection/README.md +1 -0
  35. musetalk/utils/face_detection/__init__.py +7 -0
  36. musetalk/utils/face_detection/api.py +240 -0
  37. musetalk/utils/face_detection/detection/__init__.py +1 -0
  38. musetalk/utils/face_detection/detection/core.py +130 -0
  39. musetalk/utils/face_detection/detection/sfd/__init__.py +1 -0
  40. musetalk/utils/face_detection/detection/sfd/bbox.py +129 -0
  41. musetalk/utils/face_detection/detection/sfd/detect.py +114 -0
  42. musetalk/utils/face_detection/detection/sfd/net_s3fd.py +129 -0
  43. musetalk/utils/face_detection/detection/sfd/sfd_detector.py +59 -0
  44. musetalk/utils/face_detection/models.py +261 -0
  45. musetalk/utils/face_detection/utils.py +313 -0
  46. musetalk/utils/face_parsing/__init__.py +117 -0
  47. musetalk/utils/face_parsing/model.py +283 -0
  48. musetalk/utils/face_parsing/resnet.py +109 -0
  49. musetalk/utils/preprocessing.py +155 -0
  50. musetalk/utils/training_utils.py +337 -0
LICENSE ADDED
@@ -0,0 +1,159 @@
1
+
2
+ MIT License
3
+
4
+ Copyright (c) 2024 Tencent Music Entertainment Group
5
+
6
+ Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ of this software and associated documentation files (the "Software"), to deal
8
+ in the Software without restriction, including without limitation the rights
9
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ copies of the Software, and to permit persons to whom the Software is
11
+ furnished to do so, subject to the following conditions:
12
+
13
+ The above copyright notice and this permission notice shall be included in all
14
+ copies or substantial portions of the Software.
15
+
16
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ SOFTWARE.
23
+
24
+
25
+ Other dependencies and licenses:
26
+
27
+
28
+ Open Source Software Licensed under the MIT License:
29
+ --------------------------------------------------------------------
30
+ 1. sd-vae-ft-mse
31
+ Files:https://huggingface.co/stabilityai/sd-vae-ft-mse/tree/main
32
+ License:MIT license
33
+ For details:https://choosealicense.com/licenses/mit/
34
+
35
+ 2. whisper
36
+ Files:https://github.com/openai/whisper
37
+ License:MIT license
38
+ Copyright (c) 2022 OpenAI
39
+ For details:https://github.com/openai/whisper/blob/main/LICENSE
40
+
41
+ 3. face-parsing.PyTorch
42
+ Files:https://github.com/zllrunning/face-parsing.PyTorch
43
+ License:MIT License
44
+ Copyright (c) 2019 zll
45
+ For details:https://github.com/zllrunning/face-parsing.PyTorch/blob/master/LICENSE
46
+
47
+
48
+
49
+ Open Source Software Licensed under the Apache License Version 2.0:
50
+ --------------------------------------------------------------------
51
+ 1. DWpose
52
+ Files:https://huggingface.co/yzd-v/DWPose/tree/main
53
+ License:Apache-2.0
54
+ For details:https://choosealicense.com/licenses/apache-2.0/
55
+
56
+
57
+ Terms of the Apache License Version 2.0:
58
+ --------------------------------------------------------------------
59
+ Apache License
60
+
61
+ Version 2.0, January 2004
62
+
63
+ http://www.apache.org/licenses/
64
+
65
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
66
+ 1. Definitions.
67
+
68
+ "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
69
+
70
+ "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
71
+
72
+ "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
73
+
74
+ "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
75
+
76
+ "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
77
+
78
+ "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
79
+
80
+ "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
81
+
82
+ "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
83
+
84
+ "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
85
+
86
+ "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
87
+
88
+ 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
89
+
90
+ 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
91
+
92
+ 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
93
+
94
+ You must give any other recipients of the Work or Derivative Works a copy of this License; and
95
+
96
+ You must cause any modified files to carry prominent notices stating that You changed the files; and
97
+
98
+ You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
99
+
100
+ If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
101
+
102
+ You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
103
+
104
+ 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
105
+
106
+ 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
107
+
108
+ 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
109
+
110
+ 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
111
+
112
+ 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
113
+
114
+ END OF TERMS AND CONDITIONS
115
+
116
+
117
+
118
+ Open Source Software Licensed under the BSD 3-Clause License:
119
+ --------------------------------------------------------------------
120
+ 1. face-alignment
121
+ Files:https://github.com/1adrianb/face-alignment/tree/master
122
+ License:BSD 3-Clause License
123
+ Copyright (c) 2017, Adrian Bulat
124
+ All rights reserved.
125
+ For details:https://github.com/1adrianb/face-alignment/blob/master/LICENSE
126
+
127
+
128
+ Terms of the BSD 3-Clause License:
129
+ --------------------------------------------------------------------
130
+ Redistribution and use in source and binary forms, with or without
131
+ modification, are permitted provided that the following conditions are met:
132
+
133
+ * Redistributions of source code must retain the above copyright notice, this
134
+ list of conditions and the following disclaimer.
135
+
136
+ * Redistributions in binary form must reproduce the above copyright notice,
137
+ this list of conditions and the following disclaimer in the documentation
138
+ and/or other materials provided with the distribution.
139
+
140
+ * Neither the name of the copyright holder nor the names of its
141
+ contributors may be used to endorse or promote products derived from
142
+ this software without specific prior written permission.
143
+
144
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
145
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
146
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
147
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
148
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
149
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
150
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
151
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
152
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
153
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
154
+
155
+
156
+ Open Source Software:
157
+ --------------------------------------------------------------------
158
+ 1. S3FD
159
+ Files:https://github.com/yxlijun/S3FD.pytorch
app.py ADDED
@@ -0,0 +1,570 @@
1
+ import os
2
+ import time
3
+ import pdb
4
+ import re
5
+
6
+ import gradio as gr
7
+ import numpy as np
8
+ import sys
9
+ import subprocess
10
+
11
+ from huggingface_hub import snapshot_download
12
+ import requests
13
+
14
+ import argparse
15
+ import os
16
+ from omegaconf import OmegaConf
17
+ import numpy as np
18
+ import cv2
19
+ import torch
20
+ import glob
21
+ import pickle
22
+ from tqdm import tqdm
23
+ import copy
24
+ from argparse import Namespace
25
+ import shutil
26
+ import gdown
27
+ import imageio
28
+ import ffmpeg
29
+ from moviepy.editor import *
30
+ from transformers import WhisperModel
31
+
32
+ ProjectDir = os.path.abspath(os.path.dirname(__file__))
33
+ CheckpointsDir = os.path.join(ProjectDir, "models")
34
+
35
+ @torch.no_grad()
36
+ def debug_inpainting(video_path, bbox_shift, extra_margin=10, parsing_mode="jaw",
37
+ left_cheek_width=90, right_cheek_width=90):
38
+ """Debug inpainting parameters, only process the first frame"""
39
+ # Set default parameters
40
+ args_dict = {
41
+ "result_dir": './results/debug',
42
+ "fps": 25,
43
+ "batch_size": 1,
44
+ "output_vid_name": '',
45
+ "use_saved_coord": False,
46
+ "audio_padding_length_left": 2,
47
+ "audio_padding_length_right": 2,
48
+ "version": "v15",
49
+ "extra_margin": extra_margin,
50
+ "parsing_mode": parsing_mode,
51
+ "left_cheek_width": left_cheek_width,
52
+ "right_cheek_width": right_cheek_width
53
+ }
54
+ args = Namespace(**args_dict)
55
+
56
+ # Create debug directory
57
+ os.makedirs(args.result_dir, exist_ok=True)
58
+
59
+ # Read first frame
60
+ if get_file_type(video_path) == "video":
61
+ reader = imageio.get_reader(video_path)
62
+ first_frame = reader.get_data(0)
63
+ reader.close()
64
+ else:
65
+ first_frame = cv2.imread(video_path)
66
+ first_frame = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
67
+
68
+ # Save first frame
69
+ debug_frame_path = os.path.join(args.result_dir, "debug_frame.png")
70
+ cv2.imwrite(debug_frame_path, cv2.cvtColor(first_frame, cv2.COLOR_RGB2BGR))
71
+
72
+ # Get face coordinates
73
+ coord_list, frame_list = get_landmark_and_bbox([debug_frame_path], bbox_shift)
74
+ bbox = coord_list[0]
75
+ frame = frame_list[0]
76
+
77
+ if bbox == coord_placeholder:
78
+ return None, "No face detected, please adjust bbox_shift parameter"
79
+
80
+ # Initialize face parser
81
+ fp = FaceParsing(
82
+ left_cheek_width=args.left_cheek_width,
83
+ right_cheek_width=args.right_cheek_width
84
+ )
85
+
86
+ # Process first frame
87
+ x1, y1, x2, y2 = bbox
88
+ y2 = y2 + args.extra_margin
89
+ y2 = min(y2, frame.shape[0])
90
+ crop_frame = frame[y1:y2, x1:x2]
91
+ crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
92
+
93
+ # Generate random audio features
94
+ random_audio = torch.randn(1, 50, 384, device=device, dtype=weight_dtype)
95
+ audio_feature = pe(random_audio)
96
+
97
+ # Get latents
98
+ latents = vae.get_latents_for_unet(crop_frame)
99
+ latents = latents.to(dtype=weight_dtype)
100
+
101
+ # Generate prediction results
102
+ pred_latents = unet.model(latents, timesteps, encoder_hidden_states=audio_feature).sample
103
+ recon = vae.decode_latents(pred_latents)
104
+
105
+ # Inpaint back to original image
106
+ res_frame = recon[0]
107
+ res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
108
+ combine_frame = get_image(frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
109
+
110
+ # Save results (no need to convert color space again since get_image already returns RGB format)
111
+ debug_result_path = os.path.join(args.result_dir, "debug_result.png")
112
+ cv2.imwrite(debug_result_path, combine_frame)
113
+
114
+ # Create information text
115
+ info_text = f"Parameter information:\n" + \
116
+ f"bbox_shift: {bbox_shift}\n" + \
117
+ f"extra_margin: {extra_margin}\n" + \
118
+ f"parsing_mode: {parsing_mode}\n" + \
119
+ f"left_cheek_width: {left_cheek_width}\n" + \
120
+ f"right_cheek_width: {right_cheek_width}\n" + \
121
+ f"Detected face coordinates: [{x1}, {y1}, {x2}, {y2}]"
122
+
123
+ return cv2.cvtColor(combine_frame, cv2.COLOR_RGB2BGR), info_text
124
+
125
+ def print_directory_contents(path):
126
+ for child in os.listdir(path):
127
+ child_path = os.path.join(path, child)
128
+ if os.path.isdir(child_path):
129
+ print(child_path)
130
+
131
+ def download_model():
132
+ # Check that the required model files exist
133
+ required_models = {
134
+ "MuseTalk": f"{CheckpointsDir}/musetalkV15/unet.pth",
135
+ "MuseTalk": f"{CheckpointsDir}/musetalkV15/musetalk.json",
136
+ "SD VAE": f"{CheckpointsDir}/sd-vae/config.json",
137
+ "Whisper": f"{CheckpointsDir}/whisper/config.json",
138
+ "DWPose": f"{CheckpointsDir}/dwpose/dw-ll_ucoco_384.pth",
139
+ "SyncNet": f"{CheckpointsDir}/syncnet/latentsync_syncnet.pt",
140
+ "Face Parse": f"{CheckpointsDir}/face-parse-bisent/79999_iter.pth",
141
+ "ResNet": f"{CheckpointsDir}/face-parse-bisent/resnet18-5c106cde.pth"
142
+ }
143
+
144
+ missing_models = []
145
+ for model_name, model_path in required_models.items():
146
+ if not os.path.exists(model_path):
147
+ missing_models.append(model_name)
148
+
149
+ if missing_models:
150
+ # Keep all messages in English
151
+ print("The following required model files are missing:")
152
+ for model in missing_models:
153
+ print(f"- {model}")
154
+ print("\nPlease run the download script to download the missing models:")
155
+ if sys.platform == "win32":
156
+ print("Windows: Run download_weights.bat")
157
+ else:
158
+ print("Linux/Mac: Run ./download_weights.sh")
159
+ sys.exit(1)
160
+ else:
161
+ print("All required model files exist.")
162
+
163
+
164
+
165
+
166
+ download_model() # for huggingface deployment.
167
+
168
+ from musetalk.utils.blending import get_image
169
+ from musetalk.utils.face_parsing import FaceParsing
170
+ from musetalk.utils.audio_processor import AudioProcessor
171
+ from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
172
+ from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder, get_bbox_range
173
+
174
+
175
+ def fast_check_ffmpeg():
176
+ try:
177
+ subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
178
+ return True
179
+ except (subprocess.CalledProcessError, FileNotFoundError):
180
+ return False
181
+
182
+
183
+ @torch.no_grad()
184
+ def inference(audio_path, video_path, bbox_shift, extra_margin=10, parsing_mode="jaw",
185
+ left_cheek_width=90, right_cheek_width=90, progress=gr.Progress(track_tqdm=True)):
186
+ # Set default parameters, aligned with inference.py
187
+ args_dict = {
188
+ "result_dir": './results/output',
189
+ "fps": 25,
190
+ "batch_size": 8,
191
+ "output_vid_name": '',
192
+ "use_saved_coord": False,
193
+ "audio_padding_length_left": 2,
194
+ "audio_padding_length_right": 2,
195
+ "version": "v15", # Fixed use v15 version
196
+ "extra_margin": extra_margin,
197
+ "parsing_mode": parsing_mode,
198
+ "left_cheek_width": left_cheek_width,
199
+ "right_cheek_width": right_cheek_width
200
+ }
201
+ args = Namespace(**args_dict)
202
+
203
+ # Check ffmpeg
204
+ if not fast_check_ffmpeg():
205
+ print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")
206
+
207
+ input_basename = os.path.basename(video_path).split('.')[0]
208
+ audio_basename = os.path.basename(audio_path).split('.')[0]
209
+ output_basename = f"{input_basename}_{audio_basename}"
210
+
211
+ # Create temporary directory
212
+ temp_dir = os.path.join(args.result_dir, f"{args.version}")
213
+ os.makedirs(temp_dir, exist_ok=True)
214
+
215
+ # Set result save path
216
+ result_img_save_path = os.path.join(temp_dir, output_basename)
217
+ crop_coord_save_path = os.path.join(args.result_dir, "../", input_basename+".pkl")
218
+ os.makedirs(result_img_save_path, exist_ok=True)
219
+
220
+ if args.output_vid_name == "":
221
+ output_vid_name = os.path.join(temp_dir, output_basename+".mp4")
222
+ else:
223
+ output_vid_name = os.path.join(temp_dir, args.output_vid_name)
224
+
225
+ ############################################## extract frames from source video ##############################################
226
+ if get_file_type(video_path) == "video":
227
+ save_dir_full = os.path.join(temp_dir, input_basename)
228
+ os.makedirs(save_dir_full, exist_ok=True)
229
+ # Read video
230
+ reader = imageio.get_reader(video_path)
231
+
232
+ # Save images
233
+ for i, im in enumerate(reader):
234
+ imageio.imwrite(f"{save_dir_full}/{i:08d}.png", im)
235
+ input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
236
+ fps = get_video_fps(video_path)
237
+ else: # input img folder
238
+ input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
239
+ input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
240
+ fps = args.fps
241
+
242
+ ############################################## extract audio feature ##############################################
243
+ # Extract audio features
244
+ whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
245
+ whisper_chunks = audio_processor.get_whisper_chunk(
246
+ whisper_input_features,
247
+ device,
248
+ weight_dtype,
249
+ whisper,
250
+ librosa_length,
251
+ fps=fps,
252
+ audio_padding_length_left=args.audio_padding_length_left,
253
+ audio_padding_length_right=args.audio_padding_length_right,
254
+ )
255
+
256
+ ############################################## preprocess input image ##############################################
257
+ if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
258
+ print("using extracted coordinates")
259
+ with open(crop_coord_save_path,'rb') as f:
260
+ coord_list = pickle.load(f)
261
+ frame_list = read_imgs(input_img_list)
262
+ else:
263
+ print("extracting landmarks...time consuming")
264
+ coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
265
+ with open(crop_coord_save_path, 'wb') as f:
266
+ pickle.dump(coord_list, f)
267
+ bbox_shift_text = get_bbox_range(input_img_list, bbox_shift)
268
+
269
+ # Initialize face parser
270
+ fp = FaceParsing(
271
+ left_cheek_width=args.left_cheek_width,
272
+ right_cheek_width=args.right_cheek_width
273
+ )
274
+
275
+ i = 0
276
+ input_latent_list = []
277
+ for bbox, frame in zip(coord_list, frame_list):
278
+ if bbox == coord_placeholder:
279
+ continue
280
+ x1, y1, x2, y2 = bbox
281
+ y2 = y2 + args.extra_margin
282
+ y2 = min(y2, frame.shape[0])
283
+ crop_frame = frame[y1:y2, x1:x2]
284
+ crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
285
+ latents = vae.get_latents_for_unet(crop_frame)
286
+ input_latent_list.append(latents)
287
+
288
+ # to smooth the first and the last frame
289
+ frame_list_cycle = frame_list + frame_list[::-1]
290
+ coord_list_cycle = coord_list + coord_list[::-1]
291
+ input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
292
+
293
+ ############################################## inference batch by batch ##############################################
294
+ print("start inference")
295
+ video_num = len(whisper_chunks)
296
+ batch_size = args.batch_size
297
+ gen = datagen(
298
+ whisper_chunks=whisper_chunks,
299
+ vae_encode_latents=input_latent_list_cycle,
300
+ batch_size=batch_size,
301
+ delay_frame=0,
302
+ device=device,
303
+ )
304
+ res_frame_list = []
305
+ for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
306
+ audio_feature_batch = pe(whisper_batch)
307
+ # Ensure latent_batch is consistent with model weight type
308
+ latent_batch = latent_batch.to(dtype=weight_dtype)
309
+
310
+ pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
311
+ recon = vae.decode_latents(pred_latents)
312
+ for res_frame in recon:
313
+ res_frame_list.append(res_frame)
314
+
315
+ ############################################## pad to full image ##############################################
316
+ print("pad talking image to original video")
317
+ for i, res_frame in enumerate(tqdm(res_frame_list)):
318
+ bbox = coord_list_cycle[i%(len(coord_list_cycle))]
319
+ ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
320
+ x1, y1, x2, y2 = bbox
321
+ y2 = y2 + args.extra_margin
322
+ y2 = min(y2, ori_frame.shape[0])
323
+ try:
324
+ res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
325
+ except:
326
+ continue
327
+
328
+ # Use v15 version blending
329
+ combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
330
+
331
+ cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame)
332
+
333
+ # Frame rate
334
+ fps = 25
335
+ # Output video path
336
+ output_video = 'temp.mp4'
337
+
338
+ # Read images
339
+ def is_valid_image(file):
340
+ pattern = re.compile(r'\d{8}\.png')
341
+ return pattern.match(file)
342
+
343
+ images = []
344
+ files = [file for file in os.listdir(result_img_save_path) if is_valid_image(file)]
345
+ files.sort(key=lambda x: int(x.split('.')[0]))
346
+
347
+ for file in files:
348
+ filename = os.path.join(result_img_save_path, file)
349
+ images.append(imageio.imread(filename))
350
+
351
+
352
+ # Save video
353
+ imageio.mimwrite(output_video, images, 'FFMPEG', fps=fps, codec='libx264', pixelformat='yuv420p')
354
+
355
+ input_video = './temp.mp4'
356
+ # Check if the input_video and audio_path exist
357
+ if not os.path.exists(input_video):
358
+ raise FileNotFoundError(f"Input video file not found: {input_video}")
359
+ if not os.path.exists(audio_path):
360
+ raise FileNotFoundError(f"Audio file not found: {audio_path}")
361
+
362
+ # Read video
363
+ reader = imageio.get_reader(input_video)
364
+ fps = reader.get_meta_data()['fps'] # Get original video frame rate
365
+ reader.close() # Close the reader; otherwise on Windows 11: PermissionError: [WinError 32] The process cannot access the file because it is being used by another process: 'temp.mp4'
366
+ # Store frames in list
367
+ frames = images
368
+
369
+ print(len(frames))
370
+
371
+ # Load the video
372
+ video_clip = VideoFileClip(input_video)
373
+
374
+ # Load the audio
375
+ audio_clip = AudioFileClip(audio_path)
376
+
377
+ # Set the audio to the video
378
+ video_clip = video_clip.set_audio(audio_clip)
379
+
380
+ # Write the output video
381
+ video_clip.write_videofile(output_vid_name, codec='libx264', audio_codec='aac',fps=25)
382
+
383
+ os.remove("temp.mp4")
384
+ #shutil.rmtree(result_img_save_path)
385
+ print(f"result is save to {output_vid_name}")
386
+ return output_vid_name,bbox_shift_text
387
+
388
+
389
+
390
+ # load model weights
391
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
392
+ vae, unet, pe = load_all_model(
393
+ unet_model_path="./models/musetalkV15/unet.pth",
394
+ vae_type="sd-vae",
395
+ unet_config="./models/musetalkV15/musetalk.json",
396
+ device=device
397
+ )
398
+
399
+ # Parse command line arguments
400
+ parser = argparse.ArgumentParser()
401
+ parser.add_argument("--ffmpeg_path", type=str, default=r"ffmpeg-master-latest-win64-gpl-shared\bin", help="Path to ffmpeg executable")
402
+ parser.add_argument("--ip", type=str, default="127.0.0.1", help="IP address to bind to")
403
+ parser.add_argument("--port", type=int, default=7860, help="Port to bind to")
404
+ parser.add_argument("--share", action="store_true", help="Create a public link")
405
+ parser.add_argument("--use_float16", action="store_true", help="Use float16 for faster inference")
406
+ args = parser.parse_args()
407
+
408
+ # Set data type
409
+ if args.use_float16:
410
+ # Convert models to half precision for better performance
411
+ pe = pe.half()
412
+ vae.vae = vae.vae.half()
413
+ unet.model = unet.model.half()
414
+ weight_dtype = torch.float16
415
+ else:
416
+ weight_dtype = torch.float32
417
+
418
+ # Move models to specified device
419
+ pe = pe.to(device)
420
+ vae.vae = vae.vae.to(device)
421
+ unet.model = unet.model.to(device)
422
+
423
+ timesteps = torch.tensor([0], device=device)
424
+
425
+ # Initialize audio processor and Whisper model
426
+ audio_processor = AudioProcessor(feature_extractor_path="./models/whisper")
427
+ whisper = WhisperModel.from_pretrained("./models/whisper")
428
+ whisper = whisper.to(device=device, dtype=weight_dtype).eval()
429
+ whisper.requires_grad_(False)
430
+
431
+
432
+ def check_video(video):
433
+ if not isinstance(video, str):
434
+ return video # handle the None case
435
+ # Define the output video file name
436
+ dir_path, file_name = os.path.split(video)
437
+ if file_name.startswith("outputxxx_"):
438
+ return video
439
+ # Add the output prefix to the file name
440
+ output_file_name = "outputxxx_" + file_name
441
+
442
+ os.makedirs('./results',exist_ok=True)
443
+ os.makedirs('./results/output',exist_ok=True)
444
+ os.makedirs('./results/input',exist_ok=True)
445
+
446
+ # Combine the directory path and the new file name
447
+ output_video = os.path.join('./results/input', output_file_name)
448
+
449
+
450
+ # read video
451
+ reader = imageio.get_reader(video)
452
+ fps = reader.get_meta_data()['fps'] # get fps from original video
453
+
454
+ # convert fps to 25
455
+ frames = [im for im in reader]
456
+ target_fps = 25
457
+
458
+ L = len(frames)
459
+ L_target = int(L / fps * target_fps)
460
+ original_t = [x / fps for x in range(1, L+1)]
461
+ t_idx = 0
462
+ target_frames = []
463
+ for target_t in range(1, L_target+1):
464
+ while target_t / target_fps > original_t[t_idx]:
465
+ t_idx += 1 # find the first t_idx so that target_t / target_fps <= original_t[t_idx]
466
+ if t_idx >= L:
467
+ break
468
+ target_frames.append(frames[t_idx])
469
+
470
+ # save video
471
+ imageio.mimwrite(output_video, target_frames, 'FFMPEG', fps=25, codec='libx264', quality=9, pixelformat='yuv420p')
472
+ return output_video
473
+
474
+
475
+
476
+
477
+ css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024px; max-height: 576px}"""
478
+
479
+ with gr.Blocks(css=css) as demo:
480
+ gr.Markdown(
481
+ """<div align='center'> <h1>MuseTalk: Real-Time High-Fidelity Video Dubbing via Spatio-Temporal Sampling</h1> \
482
+ <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
483
+ </br>\
484
+ Yue Zhang <sup>*</sup>,\
485
+ Zhizhou Zhong <sup>*</sup>,\
486
+ Minhao Liu<sup>*</sup>,\
487
+ Zhaokang Chen,\
488
+ Bin Wu<sup>†</sup>,\
489
+ Yubin Zeng,\
490
+ Chao Zhang,\
491
+ Yingjie He,\
492
+ Junxin Huang,\
493
+ Wenjiang Zhou <br>\
494
+ (<sup>*</sup>Equal Contribution, <sup>†</sup>Corresponding Author, [email protected])\
495
+ Lyra Lab, Tencent Music Entertainment\
496
+ </h2> \
497
+ <a style='font-size:18px;color: #000000' href='https://github.com/TMElyralab/MuseTalk'>[Github Repo]</a>\
498
+ <a style='font-size:18px;color: #000000' href='https://github.com/TMElyralab/MuseTalk'>[Huggingface]</a>\
499
+ <a style='font-size:18px;color: #000000' href='https://arxiv.org/abs/2410.10122'> [Technical report] </a>"""
500
+ )
501
+
502
+ with gr.Row():
503
+ with gr.Column():
504
+ audio = gr.Audio(label="Drving Audio",type="filepath")
505
+ video = gr.Video(label="Reference Video",sources=['upload'])
506
+ bbox_shift = gr.Number(label="BBox_shift value, px", value=0)
507
+ extra_margin = gr.Slider(label="Extra Margin", minimum=0, maximum=40, value=10, step=1)
508
+ parsing_mode = gr.Radio(label="Parsing Mode", choices=["jaw", "raw"], value="jaw")
509
+ left_cheek_width = gr.Slider(label="Left Cheek Width", minimum=20, maximum=160, value=90, step=5)
510
+ right_cheek_width = gr.Slider(label="Right Cheek Width", minimum=20, maximum=160, value=90, step=5)
511
+ bbox_shift_scale = gr.Textbox(label="The 'left_cheek_width' and 'right_cheek_width' parameters determine the editable range of the left and right cheeks when the parsing mode is 'jaw'. The 'extra_margin' parameter determines the movement range of the jaw. Adjust these three parameters freely to obtain better inpainting results.")
512
+
513
+ with gr.Row():
514
+ debug_btn = gr.Button("1. Test Inpainting ")
515
+ btn = gr.Button("2. Generate")
516
+ with gr.Column():
517
+ debug_image = gr.Image(label="Test Inpainting Result (First Frame)")
518
+ debug_info = gr.Textbox(label="Parameter Information", lines=5)
519
+ out1 = gr.Video()
520
+
521
+ video.change(
522
+ fn=check_video, inputs=[video], outputs=[video]
523
+ )
524
+ btn.click(
525
+ fn=inference,
526
+ inputs=[
527
+ audio,
528
+ video,
529
+ bbox_shift,
530
+ extra_margin,
531
+ parsing_mode,
532
+ left_cheek_width,
533
+ right_cheek_width
534
+ ],
535
+ outputs=[out1,bbox_shift_scale]
536
+ )
537
+ debug_btn.click(
538
+ fn=debug_inpainting,
539
+ inputs=[
540
+ video,
541
+ bbox_shift,
542
+ extra_margin,
543
+ parsing_mode,
544
+ left_cheek_width,
545
+ right_cheek_width
546
+ ],
547
+ outputs=[debug_image, debug_info]
548
+ )
549
+
550
+ # Check ffmpeg and add to PATH
551
+ if not fast_check_ffmpeg():
552
+ print(f"Adding ffmpeg to PATH: {args.ffmpeg_path}")
553
+ # According to operating system, choose path separator
554
+ path_separator = ';' if sys.platform == 'win32' else ':'
555
+ os.environ["PATH"] = f"{args.ffmpeg_path}{path_separator}{os.environ['PATH']}"
556
+ if not fast_check_ffmpeg():
557
+ print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")
558
+
559
+ # Solve asynchronous IO issues on Windows
560
+ if sys.platform == 'win32':
561
+ import asyncio
562
+ asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
563
+
564
+ # Start Gradio application
565
+ demo.queue().launch(
566
+ share=args.share,
567
+ debug=True,
568
+ server_name=args.ip,
569
+ server_port=args.port
570
+ )
assets/BBOX_SHIFT.md ADDED
@@ -0,0 +1,26 @@
1
+ ## Why is there a "bbox_shift" parameter?
2
+ When processing training data, we utilize the combination of face detection results (bbox) and facial landmarks to determine the region of the head segmentation box. Specifically, we use the upper bound of the bbox as the upper boundary of the segmentation box, the maximum y value of the facial landmark coordinates as the lower boundary, and the minimum and maximum x values of the landmark coordinates as the left and right boundaries. Processing the dataset in this way ensures the integrity of the face.
3
+
4
+ However, we have observed that the masked ratio on the face varies across images due to the subjects' varying face shapes. Furthermore, we found that the upper bound of the mask mainly lies close to landmark28, landmark29, and landmark30 (as shown in Fig.1), which correspond to 15%, 63%, and 22% of the dataset, respectively.
5
+
6
+ During the inference process, we discover that as the upper-bound of the mask gets closer to the mouth (near landmark30), the audio features contribute more to lip movements. Conversely, as the upper-bound of the mask moves away from the mouth (near landmark28), the audio features contribute more to generating details of facial appearance. Hence, we define this characteristic as a parameter that can adjust the contribution of audio features to generating lip movements, which users can modify according to their specific needs in practical scenarios.
7
+
8
+ ![landmark](figs/landmark_ref.png)
9
+
10
+ Fig.1. Facial landmarks
11
+ ### Step 0.
12
+ Running with the default configuration to obtain the adjustable value range.
13
+ ```
14
+ python -m scripts.inference --inference_config configs/inference/test.yaml
15
+ ```
16
+ ```
17
+ ********************************************bbox_shift parameter adjustment**********************************************************
18
+ Total frame:「838」 Manually adjust range : [ -9~9 ] , the current value: 0
19
+ *************************************************************************************************************************************
20
+ ```
21
+ ### Step 1.
22
+ Re-run the script within the above range.
23
+ ```
24
+ python -m scripts.inference --inference_config configs/inference/test.yaml --bbox_shift xx # where xx is in [-9, 9].
25
+ ```
26
+ In our experimental observations, we found that positive values (moving towards the lower half) generally increase mouth openness, while negative values (moving towards the upper half) generally decrease mouth openness. However, it's important to note that this is not an absolute rule, and users may need to adjust the parameter according to their specific needs and the desired effect.
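To make the role of `bbox_shift` concrete, the following is a minimal illustrative sketch (not the repository's actual preprocessing code) of how a signed pixel offset could move the upper boundary of the landmark-based crop box described above; the helper name and the `bbox`/`landmarks` inputs are assumptions for illustration only.

```python
import numpy as np

def crop_box_with_shift(bbox, landmarks, bbox_shift=0):
    """Illustrative only: build a head crop box from a face bbox and an (N, 2)
    array of facial landmarks, then move its upper bound by `bbox_shift`
    pixels. In image coordinates a positive shift moves the boundary down
    towards the mouth (near landmark30); a negative shift moves it up
    (near landmark28)."""
    landmarks = np.asarray(landmarks)
    x1 = int(landmarks[:, 0].min())   # left boundary: min landmark x
    x2 = int(landmarks[:, 0].max())   # right boundary: max landmark x
    y1 = int(bbox[1]) + bbox_shift    # upper boundary: bbox top, shifted
    y2 = int(landmarks[:, 1].max())   # lower boundary: max landmark y
    return x1, y1, x2, y2
```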
assets/demo/sit/sit.jpeg ADDED
assets/demo/yongen/yongen.jpeg ADDED
configs/inference/realtime.yaml ADDED
@@ -0,0 +1,10 @@
1
+ avator_1:
2
+ preparation: True # you can set it to False to reuse an existing avatar; it will save time
3
+ bbox_shift: 5
4
+ video_path: "data/video/yongen.mp4"
5
+ audio_clips:
6
+ audio_0: "data/audio/yongen.wav"
7
+ audio_1: "data/audio/eng.wav"
8
+
9
+
10
+
configs/inference/test.yaml ADDED
@@ -0,0 +1,10 @@
1
+ task_0:
2
+ video_path: "data/video/yongen.mp4"
3
+ audio_path: "data/audio/yongen.wav"
4
+
5
+ task_1:
6
+ video_path: "data/video/yongen.mp4"
7
+ audio_path: "data/audio/eng.wav"
8
+ bbox_shift: -7
9
+
10
+
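For orientation, here is a minimal sketch (assuming OmegaConf, which app.py already imports) of how a script could iterate over the task entries of this inference config; the loop itself is hypothetical, and `bbox_shift` falls back to 0 when a task omits it, matching the optional field above.

```python
from omegaconf import OmegaConf

# Illustrative only: walk the task_* entries of configs/inference/test.yaml.
config = OmegaConf.load("configs/inference/test.yaml")
for task_id, task in config.items():
    bbox_shift = task.get("bbox_shift", 0)  # optional per-task override
    print(f"{task_id}: video={task.video_path}, audio={task.audio_path}, "
          f"bbox_shift={bbox_shift}")
```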
configs/training/gpu.yaml ADDED
@@ -0,0 +1,21 @@
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: True
3
+ deepspeed_config:
4
+ offload_optimizer_device: none
5
+ offload_param_device: none
6
+ zero3_init_flag: False
7
+ zero_stage: 2
8
+
9
+ distributed_type: DEEPSPEED
10
+ downcast_bf16: 'no'
11
+ gpu_ids: "5, 7" # modify this according to your GPU number
12
+ machine_rank: 0
13
+ main_training_function: main
14
+ num_machines: 1
15
+ num_processes: 2 # it should be the same as the number of GPUs
16
+ rdzv_backend: static
17
+ same_network: true
18
+ tpu_env: []
19
+ tpu_use_cluster: false
20
+ tpu_use_sudo: false
21
+ use_cpu: false
configs/training/preprocess.yaml ADDED
@@ -0,0 +1,31 @@
1
+ clip_len_second: 30 # length of each video clip in seconds
2
+ video_root_raw: "./dataset/HDTF/source/" # the path of the original video
3
+ val_list_hdtf:
4
+ - RD_Radio7_000
5
+ - RD_Radio8_000
6
+ - RD_Radio9_000
7
+ - WDA_TinaSmith_000
8
+ - WDA_TomCarper_000
9
+ - WDA_TomPerez_000
10
+ - WDA_TomUdall_000
11
+ - WDA_VeronicaEscobar0_000
12
+ - WDA_VeronicaEscobar1_000
13
+ - WDA_WhipJimClyburn_000
14
+ - WDA_XavierBecerra_000
15
+ - WDA_XavierBecerra_001
16
+ - WDA_XavierBecerra_002
17
+ - WDA_ZoeLofgren_000
18
+ - WRA_SteveScalise1_000
19
+ - WRA_TimScott_000
20
+ - WRA_ToddYoung_000
21
+ - WRA_TomCotton_000
22
+ - WRA_TomPrice_000
23
+ - WRA_VickyHartzler_000
24
+
25
+ # following dir will be automatically generated
26
+ video_root_25fps: "./dataset/HDTF/video_root_25fps/"
27
+ video_file_list: "./dataset/HDTF/video_file_list.txt"
28
+ video_audio_clip_root: "./dataset/HDTF/video_audio_clip_root/"
29
+ meta_root: "./dataset/HDTF/meta/"
30
+ video_clip_file_list_train: "./dataset/HDTF/train.txt"
31
+ video_clip_file_list_val: "./dataset/HDTF/val.txt"
configs/training/stage1.yaml ADDED
@@ -0,0 +1,89 @@
1
+ exp_name: 'test' # Name of the experiment
2
+ output_dir: './exp_out/stage1/' # Directory to save experiment outputs
3
+ unet_sub_folder: musetalk # Subfolder name for UNet model
4
+ random_init_unet: True # Whether to randomly initialize UNet (stage1) or use pretrained weights (stage2)
5
+ whisper_path: "./models/whisper" # Path to the Whisper model
6
+ pretrained_model_name_or_path: "./models" # Path to pretrained models
7
+ resume_from_checkpoint: True # Whether to resume training from a checkpoint
8
+ padding_pixel_mouth: 10 # Number of pixels to pad around the mouth region
9
+ vae_type: "sd-vae" # Type of VAE model to use
10
+ # Validation parameters
11
+ num_images_to_keep: 8 # Number of validation images to keep
12
+ ref_dropout_rate: 0 # Dropout rate for reference images
13
+ syncnet_config_path: "./configs/training/syncnet.yaml" # Path to SyncNet configuration
14
+ use_adapted_weight: False # Whether to use adapted weights for loss calculation
15
+ cropping_jaw2edge_margin_mean: 10 # Mean margin for jaw-to-edge cropping
16
+ cropping_jaw2edge_margin_std: 10 # Standard deviation for jaw-to-edge cropping
17
+ crop_type: "crop_resize" # Type of cropping method
18
+ random_margin_method: "normal" # Method for random margin generation
19
+ num_backward_frames: 16 # Number of frames to use for backward pass in SyncNet
20
+
21
+ data:
22
+ dataset_key: "HDTF" # Dataset to use for training
23
+ train_bs: 32 # Training batch size (actual batch size is train_bs*n_sample_frames)
24
+ image_size: 256 # Size of input images
25
+ n_sample_frames: 1 # Number of frames to sample per batch
26
+ num_workers: 8 # Number of data loading workers
27
+ audio_padding_length_left: 2 # Left padding length for audio features
28
+ audio_padding_length_right: 2 # Right padding length for audio features
29
+ sample_method: pose_similarity_and_mouth_dissimilarity # Method for sampling frames
30
+ top_k_ratio: 0.51 # Ratio for top-k sampling
31
+ contorl_face_min_size: True # Whether to control minimum face size
32
+ min_face_size: 150 # Minimum face size in pixels
33
+
34
+ loss_params:
35
+ l1_loss: 1.0 # Weight for L1 loss
36
+ vgg_loss: 0.01 # Weight for VGG perceptual loss
37
+ vgg_layer_weight: [1, 1, 1, 1, 1] # Weights for different VGG layers
38
+ pyramid_scale: [1, 0.5, 0.25, 0.125] # Scales for image pyramid
39
+ gan_loss: 0 # Weight for GAN loss
40
+ fm_loss: [1.0, 1.0, 1.0, 1.0] # Weights for feature matching loss
41
+ sync_loss: 0 # Weight for sync loss
42
+ mouth_gan_loss: 0 # Weight for mouth-specific GAN loss
43
+
44
+ model_params:
45
+ discriminator_params:
46
+ scales: [1] # Scales for discriminator
47
+ block_expansion: 32 # Expansion factor for discriminator blocks
48
+ max_features: 512 # Maximum number of features in discriminator
49
+ num_blocks: 4 # Number of blocks in discriminator
50
+ sn: True # Whether to use spectral normalization
51
+ image_channel: 3 # Number of image channels
52
+ estimate_jacobian: False # Whether to estimate Jacobian
53
+
54
+ discriminator_train_params:
55
+ lr: 0.000005 # Learning rate for discriminator
56
+ eps: 0.00000001 # Epsilon for optimizer
57
+ weight_decay: 0.01 # Weight decay for optimizer
58
+ patch_size: 1 # Size of patches for discriminator
59
+ betas: [0.5, 0.999] # Beta parameters for Adam optimizer
60
+ epochs: 10000 # Number of training epochs
61
+ start_gan: 1000 # Step to start GAN training
62
+
63
+ solver:
64
+ gradient_accumulation_steps: 1 # Number of steps for gradient accumulation
65
+ uncond_steps: 10 # Number of unconditional steps
66
+ mixed_precision: 'fp32' # Precision mode for training
67
+ enable_xformers_memory_efficient_attention: True # Whether to use memory efficient attention
68
+ gradient_checkpointing: True # Whether to use gradient checkpointing
69
+ max_train_steps: 250000 # Maximum number of training steps
70
+ max_grad_norm: 1.0 # Maximum gradient norm for clipping
71
+ # Learning rate parameters
72
+ learning_rate: 2.0e-5 # Base learning rate
73
+ scale_lr: False # Whether to scale learning rate
74
+ lr_warmup_steps: 1000 # Number of warmup steps for learning rate
75
+ lr_scheduler: "linear" # Type of learning rate scheduler
76
+ # Optimizer parameters
77
+ use_8bit_adam: False # Whether to use 8-bit Adam optimizer
78
+ adam_beta1: 0.5 # Beta1 parameter for Adam optimizer
79
+ adam_beta2: 0.999 # Beta2 parameter for Adam optimizer
80
+ adam_weight_decay: 1.0e-2 # Weight decay for Adam optimizer
81
+ adam_epsilon: 1.0e-8 # Epsilon for Adam optimizer
82
+
83
+ total_limit: 10 # Maximum number of checkpoints to keep
84
+ save_model_epoch_interval: 250000 # Interval between model saves
85
+ checkpointing_steps: 10000 # Number of steps between checkpoints
86
+ val_freq: 2000 # Frequency of validation
87
+
88
+ seed: 41 # Random seed for reproducibility
89
+
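As a quick reference, here is a hedged sketch of loading this training config with OmegaConf and reading a few of the fields defined above; the training entry point itself is not part of this file list, so nothing here should be taken as the project's actual training code.

```python
from omegaconf import OmegaConf

# Illustrative only: load the stage-1 config and inspect key settings.
cfg = OmegaConf.load("configs/training/stage1.yaml")
print("experiment:", cfg.exp_name, "->", cfg.output_dir)
print("effective batch:", cfg.data.train_bs * cfg.data.n_sample_frames, "frames")
print("learning rate:", cfg.solver.learning_rate)
print("loss weights:", OmegaConf.to_container(cfg.loss_params))
```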
configs/training/stage2.yaml ADDED
@@ -0,0 +1,89 @@
1
+ exp_name: 'test' # Name of the experiment
2
+ output_dir: './exp_out/stage2/' # Directory to save experiment outputs
3
+ unet_sub_folder: musetalk # Subfolder name for UNet model
4
+ random_init_unet: False # Whether to randomly initialize UNet (stage1) or use pretrained weights (stage2)
5
+ whisper_path: "./models/whisper" # Path to the Whisper model
6
+ pretrained_model_name_or_path: "./models" # Path to pretrained models
7
+ resume_from_checkpoint: True # Whether to resume training from a checkpoint
8
+ padding_pixel_mouth: 10 # Number of pixels to pad around the mouth region
9
+ vae_type: "sd-vae" # Type of VAE model to use
10
+ # Validation parameters
11
+ num_images_to_keep: 8 # Number of validation images to keep
12
+ ref_dropout_rate: 0 # Dropout rate for reference images
13
+ syncnet_config_path: "./configs/training/syncnet.yaml" # Path to SyncNet configuration
14
+ use_adapted_weight: False # Whether to use adapted weights for loss calculation
15
+ cropping_jaw2edge_margin_mean: 10 # Mean margin for jaw-to-edge cropping
16
+ cropping_jaw2edge_margin_std: 10 # Standard deviation for jaw-to-edge cropping
17
+ crop_type: "dynamic_margin_crop_resize" # Type of cropping method
18
+ random_margin_method: "normal" # Method for random margin generation
19
+ num_backward_frames: 16 # Number of frames to use for backward pass in SyncNet
20
+
21
+ data:
22
+ dataset_key: "HDTF" # Dataset to use for training
23
+ train_bs: 2 # Training batch size (actual batch size is train_bs*n_sample_frames)
24
+ image_size: 256 # Size of input images
25
+ n_sample_frames: 16 # Number of frames to sample per batch
26
+ num_workers: 8 # Number of data loading workers
27
+ audio_padding_length_left: 2 # Left padding length for audio features
28
+ audio_padding_length_right: 2 # Right padding length for audio features
29
+ sample_method: pose_similarity_and_mouth_dissimilarity # Method for sampling frames
30
+ top_k_ratio: 0.51 # Ratio for top-k sampling
31
+ contorl_face_min_size: True # Whether to control minimum face size
32
+ min_face_size: 200 # Minimum face size in pixels
33
+
34
+ loss_params:
35
+ l1_loss: 1.0 # Weight for L1 loss
36
+ vgg_loss: 0.01 # Weight for VGG perceptual loss
37
+ vgg_layer_weight: [1, 1, 1, 1, 1] # Weights for different VGG layers
38
+ pyramid_scale: [1, 0.5, 0.25, 0.125] # Scales for image pyramid
39
+ gan_loss: 0.01 # Weight for GAN loss
40
+ fm_loss: [1.0, 1.0, 1.0, 1.0] # Weights for feature matching loss
41
+ sync_loss: 0.05 # Weight for sync loss
42
+ mouth_gan_loss: 0.01 # Weight for mouth-specific GAN loss
43
+
44
+ model_params:
45
+ discriminator_params:
46
+ scales: [1] # Scales for discriminator
47
+ block_expansion: 32 # Expansion factor for discriminator blocks
48
+ max_features: 512 # Maximum number of features in discriminator
49
+ num_blocks: 4 # Number of blocks in discriminator
50
+ sn: True # Whether to use spectral normalization
51
+ image_channel: 3 # Number of image channels
52
+ estimate_jacobian: False # Whether to estimate Jacobian
53
+
54
+ discriminator_train_params:
55
+ lr: 0.000005 # Learning rate for discriminator
56
+ eps: 0.00000001 # Epsilon for optimizer
57
+ weight_decay: 0.01 # Weight decay for optimizer
58
+ patch_size: 1 # Size of patches for discriminator
59
+ betas: [0.5, 0.999] # Beta parameters for Adam optimizer
60
+ epochs: 10000 # Number of training epochs
61
+ start_gan: 1000 # Step to start GAN training
62
+
63
+ solver:
64
+ gradient_accumulation_steps: 8 # Number of steps for gradient accumulation
65
+ uncond_steps: 10 # Number of unconditional steps
66
+ mixed_precision: 'fp32' # Precision mode for training
67
+ enable_xformers_memory_efficient_attention: True # Whether to use memory efficient attention
68
+ gradient_checkpointing: True # Whether to use gradient checkpointing
69
+ max_train_steps: 250000 # Maximum number of training steps
70
+ max_grad_norm: 1.0 # Maximum gradient norm for clipping
71
+ # Learning rate parameters
72
+ learning_rate: 5.0e-6 # Base learning rate
73
+ scale_lr: False # Whether to scale learning rate
74
+ lr_warmup_steps: 1000 # Number of warmup steps for learning rate
75
+ lr_scheduler: "linear" # Type of learning rate scheduler
76
+ # Optimizer parameters
77
+ use_8bit_adam: False # Whether to use 8-bit Adam optimizer
78
+ adam_beta1: 0.5 # Beta1 parameter for Adam optimizer
79
+ adam_beta2: 0.999 # Beta2 parameter for Adam optimizer
80
+ adam_weight_decay: 1.0e-2 # Weight decay for Adam optimizer
81
+ adam_epsilon: 1.0e-8 # Epsilon for Adam optimizer
82
+
83
+ total_limit: 10 # Maximum number of checkpoints to keep
84
+ save_model_epoch_interval: 250000 # Interval between model saves
85
+ checkpointing_steps: 2000 # Number of steps between checkpoints
86
+ val_freq: 2000 # Frequency of validation
87
+
88
+ seed: 41 # Random seed for reproducibility
89
+
configs/training/syncnet.yaml ADDED
@@ -0,0 +1,19 @@
1
+ # This file is modified from LatentSync (https://github.com/bytedance/LatentSync/blob/main/latentsync/configs/training/syncnet_16_pixel.yaml).
2
+ model:
3
+ audio_encoder: # input (1, 80, 52)
4
+ in_channels: 1
5
+ block_out_channels: [32, 64, 128, 256, 512, 1024, 2048]
6
+ downsample_factors: [[2, 1], 2, 2, 1, 2, 2, [2, 3]]
7
+ attn_blocks: [0, 0, 0, 0, 0, 0, 0]
8
+ dropout: 0.0
9
+ visual_encoder: # input (48, 128, 256)
10
+ in_channels: 48
11
+ block_out_channels: [64, 128, 256, 256, 512, 1024, 2048, 2048]
12
+ downsample_factors: [[1, 2], 2, 2, 2, 2, 2, 2, 2]
13
+ attn_blocks: [0, 0, 0, 0, 0, 0, 0, 0]
14
+ dropout: 0.0
15
+
16
+ ckpt:
17
+ resume_ckpt_path: ""
18
+ inference_ckpt_path: ./models/syncnet/latentsync_syncnet.pt # this pretrained model is from LatentSync (https://huggingface.co/ByteDance/LatentSync/tree/main)
19
+ save_ckpt_steps: 2500
download_weights.bat ADDED
@@ -0,0 +1,45 @@
1
+ @echo off
2
+ setlocal
3
+
4
+ :: Set the checkpoints directory
5
+ set CheckpointsDir=models
6
+
7
+ :: Create necessary directories
8
+ mkdir %CheckpointsDir%\musetalk
9
+ mkdir %CheckpointsDir%\musetalkV15
10
+ mkdir %CheckpointsDir%\syncnet
11
+ mkdir %CheckpointsDir%\dwpose
12
+ mkdir %CheckpointsDir%\face-parse-bisent
13
+ mkdir %CheckpointsDir%\sd-vae
14
+ mkdir %CheckpointsDir%\whisper
15
+
16
+ :: Install required packages
17
+ pip install -U "huggingface_hub[cli]"
18
+ pip install gdown
19
+
20
+ :: Set HuggingFace endpoint
21
+ set HF_ENDPOINT=https://hf-mirror.com
22
+
23
+ :: Download MuseTalk weights
24
+ huggingface-cli download TMElyralab/MuseTalk --local-dir %CheckpointsDir%
25
+
26
+ :: Download SD VAE weights
27
+ huggingface-cli download stabilityai/sd-vae-ft-mse --local-dir %CheckpointsDir%\sd-vae --include "config.json" "diffusion_pytorch_model.bin"
28
+
29
+ :: Download Whisper weights
30
+ huggingface-cli download openai/whisper-tiny --local-dir %CheckpointsDir%\whisper --include "config.json" "pytorch_model.bin" "preprocessor_config.json"
31
+
32
+ :: Download DWPose weights
33
+ huggingface-cli download yzd-v/DWPose --local-dir %CheckpointsDir%\dwpose --include "dw-ll_ucoco_384.pth"
34
+
35
+ :: Download SyncNet weights
36
+ huggingface-cli download ByteDance/LatentSync --local-dir %CheckpointsDir%\syncnet --include "latentsync_syncnet.pt"
37
+
38
+ :: Download Face Parse Bisent weights (using gdown)
39
+ gdown --id 154JgKpzCPW82qINcVieuPH3fZ2e0P812 -O %CheckpointsDir%\face-parse-bisent\79999_iter.pth
40
+
41
+ :: Download ResNet weights
42
+ curl -L https://download.pytorch.org/models/resnet18-5c106cde.pth -o %CheckpointsDir%\face-parse-bisent\resnet18-5c106cde.pth
43
+
44
+ echo All weights have been downloaded successfully!
45
+ endlocal
download_weights.sh ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # Set the checkpoints directory
4
+ CheckpointsDir="models"
5
+
6
+ # Create necessary directories
7
+ mkdir -p models/musetalk models/musetalkV15 models/syncnet models/dwpose models/face-parse-bisent models/sd-vae models/whisper
8
+
9
+ # Install required packages
10
+ pip install -U "huggingface_hub[cli]"
11
+ pip install gdown
12
+
13
+ # Set HuggingFace mirror endpoint
14
+ export HF_ENDPOINT=https://hf-mirror.com
15
+
16
+ # Download MuseTalk V1.0 weights
17
+ huggingface-cli download TMElyralab/MuseTalk \
18
+ --local-dir $CheckpointsDir \
19
+ --include "musetalk/musetalk.json" "musetalk/pytorch_model.bin"
20
+
21
+ # Download MuseTalk V1.5 weights (unet.pth)
22
+ huggingface-cli download TMElyralab/MuseTalk \
23
+ --local-dir $CheckpointsDir \
24
+ --include "musetalkV15/musetalk.json" "musetalkV15/unet.pth"
25
+
26
+ # Download SD VAE weights
27
+ huggingface-cli download stabilityai/sd-vae-ft-mse \
28
+ --local-dir $CheckpointsDir/sd-vae \
29
+ --include "config.json" "diffusion_pytorch_model.bin"
30
+
31
+ # Download Whisper weights
32
+ huggingface-cli download openai/whisper-tiny \
33
+ --local-dir $CheckpointsDir/whisper \
34
+ --include "config.json" "pytorch_model.bin" "preprocessor_config.json"
35
+
36
+ # Download DWPose weights
37
+ huggingface-cli download yzd-v/DWPose \
38
+ --local-dir $CheckpointsDir/dwpose \
39
+ --include "dw-ll_ucoco_384.pth"
40
+
41
+ # Download SyncNet weights
42
+ huggingface-cli download ByteDance/LatentSync \
43
+ --local-dir $CheckpointsDir/syncnet \
44
+ --include "latentsync_syncnet.pt"
45
+
46
+ # Download Face Parse Bisent weights
47
+ gdown --id 154JgKpzCPW82qINcVieuPH3fZ2e0P812 -O $CheckpointsDir/face-parse-bisent/79999_iter.pth
48
+ curl -L https://download.pytorch.org/models/resnet18-5c106cde.pth \
49
+ -o $CheckpointsDir/face-parse-bisent/resnet18-5c106cde.pth
50
+
51
+ echo "✅ All weights have been downloaded successfully!"
entrypoint.sh ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ echo "entrypoint.sh"
4
+ whoami
5
+ which python
6
+ source /opt/conda/etc/profile.d/conda.sh
7
+ conda activate musev
8
+ which python
9
+ python app.py
inference.sh ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # This script runs inference based on the version and mode specified by the user.
4
+ # Usage:
5
+ # To run v1.0 inference: sh inference.sh v1.0 [normal|realtime]
6
+ # To run v1.5 inference: sh inference.sh v1.5 [normal|realtime]
7
+
8
+ # Check if the correct number of arguments is provided
9
+ if [ "$#" -ne 2 ]; then
10
+ echo "Usage: $0 <version> <mode>"
11
+ echo "Example: $0 v1.0 normal or $0 v1.5 realtime"
12
+ exit 1
13
+ fi
14
+
15
+ # Get the version and mode from the user input
16
+ version=$1
17
+ mode=$2
18
+
19
+ # Validate mode
20
+ if [ "$mode" != "normal" ] && [ "$mode" != "realtime" ]; then
21
+ echo "Invalid mode specified. Please use 'normal' or 'realtime'."
22
+ exit 1
23
+ fi
24
+
25
+ # Set config path based on mode
26
+ if [ "$mode" = "normal" ]; then
27
+ config_path="./configs/inference/test.yaml"
28
+ result_dir="./results/test"
29
+ else
30
+ config_path="./configs/inference/realtime.yaml"
31
+ result_dir="./results/realtime"
32
+ fi
33
+
34
+ # Define the model paths based on the version
35
+ if [ "$version" = "v1.0" ]; then
36
+ model_dir="./models/musetalk"
37
+ unet_model_path="$model_dir/pytorch_model.bin"
38
+ unet_config="$model_dir/musetalk.json"
39
+ version_arg="v1"
40
+ elif [ "$version" = "v1.5" ]; then
41
+ model_dir="./models/musetalkV15"
42
+ unet_model_path="$model_dir/unet.pth"
43
+ unet_config="$model_dir/musetalk.json"
44
+ version_arg="v15"
45
+ else
46
+ echo "Invalid version specified. Please use v1.0 or v1.5."
47
+ exit 1
48
+ fi
49
+
50
+ # Set script name based on mode
51
+ if [ "$mode" = "normal" ]; then
52
+ script_name="scripts.inference"
53
+ else
54
+ script_name="scripts.realtime_inference"
55
+ fi
56
+
57
+ # Base command arguments
58
+ cmd_args="--inference_config $config_path \
59
+ --result_dir $result_dir \
60
+ --unet_model_path $unet_model_path \
61
+ --unet_config $unet_config \
62
+ --version $version_arg"
63
+
64
+ # Add realtime-specific arguments if in realtime mode
65
+ if [ "$mode" = "realtime" ]; then
66
+ cmd_args="$cmd_args \
67
+ --fps 25"
68
69
+ fi
70
+
71
+ # Run inference
72
+ python3 -m $script_name $cmd_args
musetalk/data/audio.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa
2
+ import librosa.filters
3
+ import numpy as np
4
+ from scipy import signal
5
+ from scipy.io import wavfile
6
+
7
+ class HParams:
8
+ # copy from wav2lip
9
+ def __init__(self):
10
+ self.n_fft = 800
11
+ self.hop_size = 200
12
+ self.win_size = 800
13
+ self.sample_rate = 16000
14
+ self.frame_shift_ms = None
15
+ self.signal_normalization = True
16
+
17
+ self.allow_clipping_in_normalization = True
18
+ self.symmetric_mels = True
19
+ self.max_abs_value = 4.0
20
+ self.preemphasize = True
21
+ self.preemphasis = 0.97
22
+ self.min_level_db = -100
23
+ self.ref_level_db = 20
24
+ self.fmin = 55
25
+ self.fmax=7600
26
+
27
+ self.use_lws=False
28
+ self.num_mels=80 # Number of mel-spectrogram channels and local conditioning dimensionality
29
+ self.rescale=True # Whether to rescale audio prior to preprocessing
30
+ self.rescaling_max=0.9 # Rescaling value
31
32
+
33
+
34
+ hp = HParams()
35
+
36
+ def load_wav(path, sr):
37
+ return librosa.core.load(path, sr=sr)[0]
38
+ #def load_wav(path, sr):
39
+ # audio, sr_native = sf.read(path)
40
+ # if sr != sr_native:
41
+ # audio = librosa.resample(audio.T, sr_native, sr).T
42
+ # return audio
43
+
44
+ def save_wav(wav, path, sr):
45
+ wav *= 32767 / max(0.01, np.max(np.abs(wav)))
46
+ #proposed by @dsmiller
47
+ wavfile.write(path, sr, wav.astype(np.int16))
48
+
49
+ def save_wavenet_wav(wav, path, sr):
50
+ librosa.output.write_wav(path, wav, sr=sr)  # note: librosa.output was removed in librosa >= 0.8; use soundfile.write on newer versions
51
+
52
+ def preemphasis(wav, k, preemphasize=True):
53
+ if preemphasize:
54
+ return signal.lfilter([1, -k], [1], wav)
55
+ return wav
56
+
57
+ def inv_preemphasis(wav, k, inv_preemphasize=True):
58
+ if inv_preemphasize:
59
+ return signal.lfilter([1], [1, -k], wav)
60
+ return wav
61
+
62
+ def get_hop_size():
63
+ hop_size = hp.hop_size
64
+ if hop_size is None:
65
+ assert hp.frame_shift_ms is not None
66
+ hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
67
+ return hop_size
68
+
69
+ def linearspectrogram(wav):
70
+ D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
71
+ S = _amp_to_db(np.abs(D)) - hp.ref_level_db
72
+
73
+ if hp.signal_normalization:
74
+ return _normalize(S)
75
+ return S
76
+
77
+ def melspectrogram(wav):
78
+ D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
79
+ S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db
80
+
81
+ if hp.signal_normalization:
82
+ return _normalize(S)
83
+ return S
84
+
85
+ def _lws_processor():
86
+ import lws
87
+ return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")
88
+
89
+ def _stft(y):
90
+ if hp.use_lws:
91
+ return _lws_processor().stft(y).T
92
+ else:
93
+ return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)
94
+
95
+ ##########################################################
96
+ #Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
97
+ def num_frames(length, fsize, fshift):
98
+ """Compute number of time frames of spectrogram
99
+ """
100
+ pad = (fsize - fshift)
101
+ if length % fshift == 0:
102
+ M = (length + pad * 2 - fsize) // fshift + 1
103
+ else:
104
+ M = (length + pad * 2 - fsize) // fshift + 2
105
+ return M
106
+
107
+
108
+ def pad_lr(x, fsize, fshift):
109
+ """Compute left and right padding
110
+ """
111
+ M = num_frames(len(x), fsize, fshift)
112
+ pad = (fsize - fshift)
113
+ T = len(x) + 2 * pad
114
+ r = (M - 1) * fshift + fsize - T
115
+ return pad, pad + r
116
+ ##########################################################
117
+ #Librosa correct padding
118
+ def librosa_pad_lr(x, fsize, fshift):
119
+ return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
120
+
121
+ # Conversions
122
+ _mel_basis = None
123
+
124
+ def _linear_to_mel(spectogram):
125
+ global _mel_basis
126
+ if _mel_basis is None:
127
+ _mel_basis = _build_mel_basis()
128
+ return np.dot(_mel_basis, spectogram)
129
+
130
+ def _build_mel_basis():
131
+ assert hp.fmax <= hp.sample_rate // 2
132
+ return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels,
133
+ fmin=hp.fmin, fmax=hp.fmax)
134
+
135
+ def _amp_to_db(x):
136
+ min_level = np.exp(hp.min_level_db / 20 * np.log(10))
137
+ return 20 * np.log10(np.maximum(min_level, x))
138
+
139
+ def _db_to_amp(x):
140
+ return np.power(10.0, (x) * 0.05)
141
+
142
+ def _normalize(S):
143
+ if hp.allow_clipping_in_normalization:
144
+ if hp.symmetric_mels:
145
+ return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value,
146
+ -hp.max_abs_value, hp.max_abs_value)
147
+ else:
148
+ return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value)
149
+
150
+ assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
151
+ if hp.symmetric_mels:
152
+ return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
153
+ else:
154
+ return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))
155
+
156
+ def _denormalize(D):
157
+ if hp.allow_clipping_in_normalization:
158
+ if hp.symmetric_mels:
159
+ return (((np.clip(D, -hp.max_abs_value,
160
+ hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value))
161
+ + hp.min_level_db)
162
+ else:
163
+ return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
164
+
165
+ if hp.symmetric_mels:
166
+ return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
167
+ else:
168
+ return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
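A minimal usage sketch for the helpers above (the wav path is a placeholder, not a file from this repo): load a 16 kHz waveform and compute the normalized mel spectrogram used by the SyncNet pipeline.

    from musetalk.data import audio

    wav = audio.load_wav("example.wav", sr=16000)  # placeholder path
    mel = audio.melspectrogram(wav)                # shape (80, num_frames), clipped to [-4, 4]
    print(mel.shape)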
musetalk/data/dataset.py ADDED
@@ -0,0 +1,607 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import random
4
+ from PIL import Image
5
+ import torch
6
+ from torch.utils.data import Dataset, ConcatDataset
7
+ import torchvision.transforms as transforms
8
+ from transformers import AutoFeatureExtractor
9
+ import librosa
10
+ import time
11
+ import json
12
+ import math
13
+ from decord import AudioReader, VideoReader
14
+ from decord.ndarray import cpu
15
+
16
+ from musetalk.data.sample_method import get_src_idx, shift_landmarks_to_face_coordinates, resize_landmark
17
+ from musetalk.data import audio
18
+
19
+ syncnet_mel_step_size = math.ceil(16 / 5 * 16) # latentsync
20
+
21
+
22
+ class FaceDataset(Dataset):
23
+ """Dataset class for loading and processing video data
24
+
25
+ Each video can be represented as:
26
+ - Concatenated frame images
27
+ - '.mp4' or '.gif' files
28
+ - Folder containing all frames
29
+ """
30
+ def __init__(self,
31
+ cfg,
32
+ list_paths,
33
+ root_path='./dataset/',
34
+ repeats=None):
35
+ # Initialize dataset paths
36
+ meta_paths = []
37
+ if repeats is None:
38
+ repeats = [1] * len(list_paths)
39
+ assert len(repeats) == len(list_paths)
40
+
41
+ # Load data list
42
+ for list_path, repeat_time in zip(list_paths, repeats):
43
+ with open(list_path, 'r') as f:
44
+ num = 0
45
+ f.readline() # Skip header line
46
+ for line in f.readlines():
47
+ line_info = line.strip()
48
+ meta = line_info.split()
49
+ meta = meta[0]
50
+ meta_paths.extend([os.path.join(root_path, meta)] * repeat_time)
51
+ num += 1
52
+ print(f'{list_path}: {num} x {repeat_time} = {num * repeat_time} samples')
53
+
54
+ # Set basic attributes
55
+ self.meta_paths = meta_paths
56
+ self.root_path = root_path
57
+ self.image_size = cfg['image_size']
58
+ self.min_face_size = cfg['min_face_size']
59
+ self.T = cfg['T']
60
+ self.sample_method = cfg['sample_method']
61
+ self.top_k_ratio = cfg['top_k_ratio']
62
+ self.max_attempts = 200
63
+ self.padding_pixel_mouth = cfg['padding_pixel_mouth']
64
+
65
+ # Cropping related parameters
66
+ self.crop_type = cfg['crop_type']
67
+ self.jaw2edge_margin_mean = cfg['cropping_jaw2edge_margin_mean']
68
+ self.jaw2edge_margin_std = cfg['cropping_jaw2edge_margin_std']
69
+ self.random_margin_method = cfg['random_margin_method']
70
+
71
+ # Image transformations
72
+ self.to_tensor = transforms.Compose([
73
+ transforms.ToTensor(),
74
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
75
+ ])
76
+ self.pose_to_tensor = transforms.Compose([
77
+ transforms.ToTensor(),
78
+ ])
79
+
80
+ # Feature extractor
81
+ self.feature_extractor = AutoFeatureExtractor.from_pretrained(cfg['whisper_path'])
82
+ self.contorl_face_min_size = cfg["contorl_face_min_size"]
83
+
84
+ print("The sample method is: ", self.sample_method)
85
+ print(f"only use face size > {self.min_face_size}", self.contorl_face_min_size)
86
+
87
+ def generate_random_value(self):
88
+ """Generate random value
89
+
90
+ Returns:
91
+ float: Generated random value
92
+ """
93
+ if self.random_margin_method == "uniform":
94
+ random_value = np.random.uniform(
95
+ self.jaw2edge_margin_mean - self.jaw2edge_margin_std,
96
+ self.jaw2edge_margin_mean + self.jaw2edge_margin_std
97
+ )
98
+ elif self.random_margin_method == "normal":
99
+ random_value = np.random.normal(
100
+ loc=self.jaw2edge_margin_mean,
101
+ scale=self.jaw2edge_margin_std
102
+ )
103
+ random_value = np.clip(
104
+ random_value,
105
+ self.jaw2edge_margin_mean - self.jaw2edge_margin_std,
106
+ self.jaw2edge_margin_mean + self.jaw2edge_margin_std,
107
+ )
108
+ else:
109
+ raise ValueError(f"Invalid random margin method: {self.random_margin_method}")
110
+ return max(0, random_value)
111
+
112
+ def dynamic_margin_crop(self, img, original_bbox, extra_margin=None):
113
+ """Dynamically crop image with dynamic margin
114
+
115
+ Args:
116
+ img: Input image
117
+ original_bbox: Original bounding box
118
+ extra_margin: Extra margin
119
+
120
+ Returns:
121
+ tuple: (x1, y1, x2, y2, extra_margin)
122
+ """
123
+ if extra_margin is None:
124
+ extra_margin = self.generate_random_value()
125
+ w, h = img.size
126
+ x1, y1, x2, y2 = original_bbox
127
+ y2 = min(y2 + int(extra_margin), h)
128
+ return x1, y1, x2, y2, extra_margin
129
+
130
+ def crop_resize_img(self, img, bbox, crop_type='crop_resize', extra_margin=None):
131
+ """Crop and resize image
132
+
133
+ Args:
134
+ img: Input image
135
+ bbox: Bounding box
136
+ crop_type: Type of cropping
137
+ extra_margin: Extra margin
138
+
139
+ Returns:
140
+ tuple: (Processed image, extra_margin, mask_scaled_factor)
141
+ """
142
+ mask_scaled_factor = 1.
143
+ if crop_type == 'crop_resize':
144
+ x1, y1, x2, y2 = bbox
145
+ img = img.crop((x1, y1, x2, y2))
146
+ img = img.resize((self.image_size, self.image_size), Image.LANCZOS)
147
+ elif crop_type == 'dynamic_margin_crop_resize':
148
+ x1, y1, x2, y2, extra_margin = self.dynamic_margin_crop(img, bbox, extra_margin)
149
+ w_original, _ = img.size
150
+ img = img.crop((x1, y1, x2, y2))
151
+ w_cropped, _ = img.size
152
+ mask_scaled_factor = w_cropped / w_original
153
+ img = img.resize((self.image_size, self.image_size), Image.LANCZOS)
154
+ elif crop_type == 'resize':
155
+ w, h = img.size
156
+ scale = np.sqrt(self.image_size ** 2 / (h * w))
157
+ new_w = int(w * scale) // 64 * 64
158
+ new_h = int(h * scale) // 64 * 64
159
+ img = img.resize((new_w, new_h), Image.LANCZOS)
160
+ return img, extra_margin, mask_scaled_factor
161
+
162
+ def get_audio_file(self, wav_path, start_index):
163
+ """Get audio file features
164
+
165
+ Args:
166
+ wav_path: Audio file path
167
+ start_index: Starting index
168
+
169
+ Returns:
170
+ tuple: (Audio features, start index)
171
+ """
172
+ if not os.path.exists(wav_path):
173
+ return None
174
+ audio_input_librosa, sampling_rate = librosa.load(wav_path, sr=16000)
175
+ assert sampling_rate == 16000
176
+
177
+ while start_index >= 25 * 30:
178
+ audio_input = audio_input_librosa[16000*30:]
179
+ start_index -= 25 * 30
180
+ if start_index + 2 * 25 >= 25 * 30:
181
+ start_index -= 4 * 25
182
+ audio_input = audio_input_librosa[16000*4:16000*34]
183
+ else:
184
+ audio_input = audio_input_librosa[:16000*30]
185
+
186
+ assert 2 * (start_index) >= 0
187
+ assert 2 * (start_index + 2 * 25) <= 1500
188
+
189
+ audio_input = self.feature_extractor(
190
+ audio_input,
191
+ return_tensors="pt",
192
+ sampling_rate=sampling_rate
193
+ ).input_features
194
+ return audio_input, start_index
195
+
196
+ def get_audio_file_mel(self, wav_path, start_index):
197
+ """Get mel spectrogram of audio file
198
+
199
+ Args:
200
+ wav_path: Audio file path
201
+ start_index: Starting index
202
+
203
+ Returns:
204
+ tuple: (Mel spectrogram, start index)
205
+ """
206
+ if not os.path.exists(wav_path):
207
+ return None
208
+
209
+ audio_input, sampling_rate = librosa.load(wav_path, sr=16000)
210
+ assert sampling_rate == 16000
211
+
212
+ audio_input = self.mel_feature_extractor(audio_input)
213
+ return audio_input, start_index
214
+
215
+ def mel_feature_extractor(self, audio_input):
216
+ """Extract mel spectrogram features
217
+
218
+ Args:
219
+ audio_input: Input audio
220
+
221
+ Returns:
222
+ ndarray: Mel spectrogram features
223
+ """
224
+ orig_mel = audio.melspectrogram(audio_input)
225
+ return orig_mel.T
226
+
227
+ def crop_audio_window(self, spec, start_frame_num, fps=25):
228
+ """Crop audio window
229
+
230
+ Args:
231
+ spec: Spectrogram
232
+ start_frame_num: Starting frame number
233
+ fps: Frames per second
234
+
235
+ Returns:
236
+ ndarray: Cropped spectrogram
237
+ """
238
+ start_idx = int(80. * (start_frame_num / float(fps)))
239
+ end_idx = start_idx + syncnet_mel_step_size
240
+ return spec[start_idx: end_idx, :]
241
+
242
+ def get_syncnet_input(self, video_path):
243
+ """Get SyncNet input features
244
+
245
+ Args:
246
+ video_path: Video file path
247
+
248
+ Returns:
249
+ ndarray: SyncNet input features
250
+ """
251
+ ar = AudioReader(video_path, sample_rate=16000)
252
+ original_mel = audio.melspectrogram(ar[:].asnumpy().squeeze(0))
253
+ return original_mel.T
254
+
255
+ def get_resized_mouth_mask(
256
+ self,
257
+ img_resized,
258
+ landmark_array,
259
+ face_shape,
260
+ padding_pixel_mouth=0,
261
+ image_size=256,
262
+ crop_margin=0
263
+ ):
264
+ landmark_array = np.array(landmark_array)
265
+ resized_landmark = resize_landmark(
266
+ landmark_array, w=face_shape[0], h=face_shape[1], new_w=image_size, new_h=image_size)
267
+
268
+ landmark_array = np.array(resized_landmark[48 : 67]) # the lip landmarks in 68 landmarks format
269
+ min_x, min_y = np.min(landmark_array, axis=0)
270
+ max_x, max_y = np.max(landmark_array, axis=0)
271
+ min_x = min_x - padding_pixel_mouth
272
+ max_x = max_x + padding_pixel_mouth
273
+
274
+ # Calculate x-axis length and use it for y-axis
275
+ width = max_x - min_x
276
+
277
+ # Calculate old center point
278
+ center_y = (max_y + min_y) / 2
279
+
280
+ # Determine new min_y and max_y based on width
281
+ min_y = center_y - width / 4
282
+ max_y = center_y + width / 4
283
+
284
+ # Adjust mask position for dynamic crop, shift y-axis
285
+ min_y = min_y - crop_margin
286
+ max_y = max_y - crop_margin
287
+
288
+ # Prevent out of bounds
289
+ min_x = max(min_x, 0)
290
+ min_y = max(min_y, 0)
291
+ max_x = min(max_x, face_shape[0])
292
+ max_y = min(max_y, face_shape[1])
293
+
294
+ mask = np.zeros_like(np.array(img_resized))
295
+ mask[round(min_y):round(max_y), round(min_x):round(max_x)] = 255
296
+ return Image.fromarray(mask)
297
+
298
+ def __len__(self):
299
+ return 100000
300
+
301
+ def __getitem__(self, idx):
302
+ attempts = 0
303
+ while attempts < self.max_attempts:
304
+ try:
305
+ meta_path = random.sample(self.meta_paths, k=1)[0]
306
+ with open(meta_path, 'r') as f:
307
+ meta_data = json.load(f)
308
+ except Exception as e:
309
+ print(f"meta file error:{meta_path}")
310
+ print(e)
311
+ attempts += 1
312
+ time.sleep(0.1)
313
+ continue
314
+
315
+ video_path = meta_data["mp4_path"]
316
+ wav_path = meta_data["wav_path"]
317
+ bbox_list = meta_data["face_list"]
318
+ landmark_list = meta_data["landmark_list"]
319
+ T = self.T
320
+
321
+ s = 0
322
+ e = meta_data["frames"]
323
+ len_valid_clip = e - s
324
+
325
+ if len_valid_clip < T * 10:
326
+ attempts += 1
327
+ print(f"video {video_path} has less than {T * 10} frames")
328
+ continue
329
+
330
+ try:
331
+ cap = VideoReader(video_path, fault_tol=1, ctx=cpu(0))
332
+ total_frames = len(cap)
333
+ assert total_frames == len(landmark_list)
334
+ assert total_frames == len(bbox_list)
335
+ landmark_shape = np.array(landmark_list).shape
336
+ if landmark_shape != (total_frames, 68, 2):
337
+ attempts += 1
338
+ print(f"video {video_path} has invalid landmark shape: {landmark_shape}, expected: {(total_frames, 68, 2)}") # we use 68 landmarks
339
+ continue
340
+ except Exception as e:
341
+ print(f"video file error:{video_path}")
342
+ print(e)
343
+ attempts += 1
344
+ time.sleep(0.1)
345
+ continue
346
+
347
+ shift_landmarks, bbox_list_union, face_shapes = shift_landmarks_to_face_coordinates(
348
+ landmark_list,
349
+ bbox_list
350
+ )
351
+ if self.contorl_face_min_size and face_shapes[0][0] < self.min_face_size:
352
+ print(f"video {video_path} has face size {face_shapes[0][0]} less than minimum required {self.min_face_size}")
353
+ attempts += 1
354
+ continue
355
+
356
+ step = 1
357
+ drive_idx_start = random.randint(s, e - T * step)
358
+ drive_idx_list = list(
359
+ range(drive_idx_start, drive_idx_start + T * step, step))
360
+ assert len(drive_idx_list) == T
361
+
362
+ src_idx_list = []
363
+ list_index_out_of_range = False
364
+ for drive_idx in drive_idx_list:
365
+ src_idx = get_src_idx(
366
+ drive_idx, T, self.sample_method, shift_landmarks, face_shapes, self.top_k_ratio)
367
+ if src_idx is None:
368
+ list_index_out_of_range = True
369
+ break
370
+ src_idx = min(src_idx, e - 1)
371
+ src_idx = max(src_idx, s)
372
+ src_idx_list.append(src_idx)
373
+
374
+ if list_index_out_of_range:
375
+ attempts += 1
376
+ print(f"video {video_path} has invalid source index for drive frames")
377
+ continue
378
+
379
+ ref_face_valid_flag = True
380
+ extra_margin = self.generate_random_value()
381
+
382
+ # Get reference images
383
+ ref_imgs = []
384
+ for src_idx in src_idx_list:
385
+ imSrc = Image.fromarray(cap[src_idx].asnumpy())
386
+ bbox_s = bbox_list_union[src_idx]
387
+ imSrc, _, _ = self.crop_resize_img(
388
+ imSrc,
389
+ bbox_s,
390
+ self.crop_type,
391
+ extra_margin=None
392
+ )
393
+ if self.contorl_face_min_size and min(imSrc.size[0], imSrc.size[1]) < self.min_face_size:
394
+ ref_face_valid_flag = False
395
+ break
396
+ ref_imgs.append(imSrc)
397
+
398
+ if not ref_face_valid_flag:
399
+ attempts += 1
400
+ print(f"video {video_path} has reference face size smaller than minimum required {self.min_face_size}")
401
+ continue
402
+
403
+ # Get target images and masks
404
+ imSameIDs = []
405
+ bboxes = []
406
+ face_masks = []
407
+ face_mask_valid = True
408
+ target_face_valid_flag = True
409
+
410
+ for drive_idx in drive_idx_list:
411
+ imSameID = Image.fromarray(cap[drive_idx].asnumpy())
412
+ bbox_s = bbox_list_union[drive_idx]
413
+ imSameID, _ , mask_scaled_factor = self.crop_resize_img(
414
+ imSameID,
415
+ bbox_s,
416
+ self.crop_type,
417
+ extra_margin=extra_margin
418
+ )
419
+ if self.contorl_face_min_size and min(imSameID.size[0], imSameID.size[1]) < self.min_face_size:
420
+ target_face_valid_flag = False
421
+ break
422
+ crop_margin = extra_margin * mask_scaled_factor
423
+ face_mask = self.get_resized_mouth_mask(
424
+ imSameID,
425
+ shift_landmarks[drive_idx],
426
+ face_shapes[drive_idx],
427
+ self.padding_pixel_mouth,
428
+ self.image_size,
429
+ crop_margin=crop_margin
430
+ )
431
+ if np.count_nonzero(face_mask) == 0:
432
+ face_mask_valid = False
433
+ break
434
+
435
+ if face_mask.size[1] == 0 or face_mask.size[0] == 0:
436
+ print(f"video {video_path} has invalid face mask size at frame {drive_idx}")
437
+ face_mask_valid = False
438
+ break
439
+
440
+ imSameIDs.append(imSameID)
441
+ bboxes.append(bbox_s)
442
+ face_masks.append(face_mask)
443
+
444
+ if not face_mask_valid:
445
+ attempts += 1
446
+ print(f"video {video_path} has invalid face mask")
447
+ continue
448
+
449
+ if not target_face_valid_flag:
450
+ attempts += 1
451
+ print(f"video {video_path} has target face size smaller than minimum required {self.min_face_size}")
452
+ continue
453
+
454
+ # Process audio features
455
+ audio_offset = drive_idx_list[0]
456
+ audio_step = step
457
+ fps = 25.0 / step
458
+
459
+ try:
460
+ audio_feature, audio_offset = self.get_audio_file(wav_path, audio_offset)
461
+ _, audio_offset = self.get_audio_file_mel(wav_path, audio_offset)
462
+ audio_feature_mel = self.get_syncnet_input(video_path)
463
+ except Exception as e:
464
+ print(f"audio file error:{wav_path}")
465
+ print(e)
466
+ attempts += 1
467
+ time.sleep(0.1)
468
+ continue
469
+
470
+ mel = self.crop_audio_window(audio_feature_mel, audio_offset)
471
+ if mel.shape[0] != syncnet_mel_step_size:
472
+ attempts += 1
473
+ print(f"video {video_path} has invalid mel spectrogram shape: {mel.shape}, expected: {syncnet_mel_step_size}")
474
+ continue
475
+
476
+ mel = torch.FloatTensor(mel.T).unsqueeze(0)
477
+
478
+ # Build sample dictionary
479
+ sample = dict(
480
+ pixel_values_vid=torch.stack(
481
+ [self.to_tensor(imSameID) for imSameID in imSameIDs], dim=0),
482
+ pixel_values_ref_img=torch.stack(
483
+ [self.to_tensor(ref_img) for ref_img in ref_imgs], dim=0),
484
+ pixel_values_face_mask=torch.stack(
485
+ [self.pose_to_tensor(face_mask) for face_mask in face_masks], dim=0),
486
+ audio_feature=audio_feature[0],
487
+ audio_offset=audio_offset,
488
+ audio_step=audio_step,
489
+ mel=mel,
490
+ wav_path=wav_path,
491
+ fps=fps,
492
+ )
493
+
494
+ return sample
495
+
496
+ raise ValueError("Unable to find a valid sample after maximum attempts.")
497
+
498
+ class HDTFDataset(FaceDataset):
499
+ """HDTF dataset class"""
500
+ def __init__(self, cfg):
501
+ root_path = './dataset/HDTF/meta'
502
+ list_paths = [
503
+ './dataset/HDTF/train.txt',
504
+ ]
505
+
506
+
507
+ repeats = [10]
508
+ super().__init__(cfg, list_paths, root_path, repeats)
509
+ print('HDTFDataset: ', len(self))
510
+
511
+ class VFHQDataset(FaceDataset):
512
+ """VFHQ dataset class"""
513
+ def __init__(self, cfg):
514
+ root_path = './dataset/VFHQ/meta'
515
+ list_paths = [
516
+ './dataset/VFHQ/train.txt',
517
+ ]
518
+ repeats = [1]
519
+ super().__init__(cfg, list_paths, root_path, repeats)
520
+ print('VFHQDataset: ', len(self))
521
+
522
+ def PortraitDataset(cfg=None):
523
+ """Return dataset based on configuration
524
+
525
+ Args:
526
+ cfg: Configuration dictionary
527
+
528
+ Returns:
529
+ Dataset: Combined dataset
530
+ """
531
+ if cfg["dataset_key"] == "HDTF":
532
+ return ConcatDataset([HDTFDataset(cfg)])
533
+ elif cfg["dataset_key"] == "VFHQ":
534
+ return ConcatDataset([VFHQDataset(cfg)])
535
+ else:
536
+ print("############ use all dataset ############ ")
537
+ return ConcatDataset([HDTFDataset(cfg), VFHQDataset(cfg)])
538
+
539
+
540
+ if __name__ == '__main__':
541
+ # Set random seeds for reproducibility
542
+ seed = 42
543
+ random.seed(seed)
544
+ np.random.seed(seed)
545
+ torch.manual_seed(seed)
546
+ torch.cuda.manual_seed(seed)
547
+ torch.cuda.manual_seed_all(seed)
548
+
549
+ # Create dataset with configuration parameters
550
+ dataset = PortraitDataset(cfg={
551
+ 'T': 1, # Number of frames to process at once
552
+ 'random_margin_method': "normal", # Method for generating random margins: "normal" or "uniform"
553
+ 'dataset_key': "HDTF", # Dataset to use: "HDTF", "VFHQ", or None for both
554
+ 'image_size': 256, # Size of processed images (height and width)
555
+ 'sample_method': 'pose_similarity_and_mouth_dissimilarity', # Method for selecting reference frames
556
+ 'top_k_ratio': 0.51, # Ratio for top-k selection in reference frame sampling
557
+ 'contorl_face_min_size': True, # Whether to enforce minimum face size
558
+ 'padding_pixel_mouth': 10, # Padding pixels around mouth region in mask
559
+ 'min_face_size': 200, # Minimum face size requirement for dataset
560
+ 'whisper_path': "./models/whisper", # Path to Whisper model
561
+ 'cropping_jaw2edge_margin_mean': 10, # Mean margin for jaw-to-edge cropping
562
+ 'cropping_jaw2edge_margin_std': 10, # Standard deviation for jaw-to-edge cropping
563
+ 'crop_type': "dynamic_margin_crop_resize", # Type of cropping: "crop_resize", "dynamic_margin_crop_resize", or "resize"
564
+ })
565
+ print(len(dataset))
566
+
567
+ import torchvision
568
+ os.makedirs('debug', exist_ok=True)
569
+ for i in range(10): # Check 10 samples
570
+ sample = dataset[0]
571
+ print(f"processing {i}")
572
+
573
+ # Get images and mask
574
+ ref_img = (sample['pixel_values_ref_img'] + 1.0) / 2 # (b, c, h, w)
575
+ target_img = (sample['pixel_values_vid'] + 1.0) / 2
576
+ face_mask = sample['pixel_values_face_mask']
577
+
578
+ # Print dimension information
579
+ print(f"ref_img shape: {ref_img.shape}")
580
+ print(f"target_img shape: {target_img.shape}")
581
+ print(f"face_mask shape: {face_mask.shape}")
582
+
583
+ # Create visualization images
584
+ b, c, h, w = ref_img.shape
585
+
586
+ # Apply mask only to target image
587
+ target_mask = face_mask
588
+
589
+ # Keep reference image unchanged
590
+ ref_with_mask = ref_img.clone()
591
+
592
+ # Create mask overlay for target image
593
+ target_with_mask = target_img.clone()
594
+ target_with_mask = target_with_mask * (1 - target_mask) + target_mask # Apply mask only to target
595
+
596
+ # Save original images, mask, and overlay results
597
+ # First row: original images
598
+ # Second row: mask
599
+ # Third row: overlay effect
600
+ concatenated_img = torch.cat((
601
+ ref_img, target_img, # Original images
602
+ torch.zeros_like(ref_img), target_mask, # Mask (black for ref)
603
+ ref_with_mask, target_with_mask # Overlay effect
604
+ ), dim=3)
605
+
606
+ torchvision.utils.save_image(
607
+ concatenated_img, f'debug/mask_check_{i}.jpg', nrow=2)
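For context, a sketch of how the dataset above could be fed to a standard PyTorch DataLoader; batch size and worker count are illustrative, and it assumes the HDTF/VFHQ meta files referenced in the class constructors are present:

    from torch.utils.data import DataLoader

    # `dataset` as constructed in the __main__ block above
    loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4, drop_last=True)
    batch = next(iter(loader))
    print(batch["pixel_values_vid"].shape)  # (batch, T, 3, 256, 256)
    print(batch["mel"].shape)               # (batch, 1, 80, 52)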
musetalk/data/sample_method.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import random
3
+
4
+ def summarize_tensor(x):
5
+ return f"\033[34m{str(tuple(x.shape)).ljust(24)}\033[0m (\033[31mmin {x.min().item():+.4f}\033[0m / \033[32mmean {x.mean().item():+.4f}\033[0m / \033[33mmax {x.max().item():+.4f}\033[0m)"
6
+
7
+ def calculate_mouth_open_similarity(landmarks_list, select_idx,top_k=50,ascending=True):
8
+ num_landmarks = len(landmarks_list)
9
+ mouth_open_ratios = np.zeros(num_landmarks) # Initialize as a numpy array
10
+ print(np.shape(landmarks_list))
11
+ ## Calculate mouth opening ratios
12
+ for i, landmarks in enumerate(landmarks_list):
13
+ # Assuming landmarks are in the format [x, y] and accessible by index
14
+ mouth_top = landmarks[165] # Adjust index according to your landmarks format
15
+ mouth_bottom = landmarks[147] # Adjust index according to your landmarks format
16
+ mouth_open_ratio = np.linalg.norm(mouth_top - mouth_bottom)
17
+ mouth_open_ratios[i] = mouth_open_ratio
18
+
19
+ # Calculate differences matrix
20
+ differences_matrix = np.abs(mouth_open_ratios[:, np.newaxis] - mouth_open_ratios[select_idx])
21
+ differences_matrix_with_signs = mouth_open_ratios[:, np.newaxis] - mouth_open_ratios[select_idx]
22
+ print(differences_matrix.shape)
23
+ # Find top_k similar indices for each landmark set
24
+ if ascending:
25
+ top_indices = np.argsort(differences_matrix[:, 0])[:top_k]
26
+ else:
27
+ top_indices = np.argsort(-differences_matrix[:, 0])[:top_k]
28
+ similar_landmarks_indices = top_indices.tolist()
29
+ similar_landmarks_distances = differences_matrix_with_signs[:, 0].tolist()  # note: do not sort these signed distances
30
+
31
+ return similar_landmarks_indices, similar_landmarks_distances
32
+ #############################################################################################
33
+ def get_closed_mouth(landmarks_list,ascending=True,top_k=50):
34
+ num_landmarks = len(landmarks_list)
35
+
36
+ mouth_open_ratios = np.zeros(num_landmarks) # Initialize as a numpy array
37
+ ## Calculate mouth opening ratios
38
+ #print("landmarks shape",np.shape(landmarks_list))
39
+ for i, landmarks in enumerate(landmarks_list):
40
+ # Assuming landmarks are in the format [x, y] and accessible by index
41
+ #print(landmarks[165])
42
+ mouth_top = np.array(landmarks[165])# Adjust index according to your landmarks format
43
+ mouth_bottom = np.array(landmarks[147]) # Adjust index according to your landmarks format
44
+ mouth_open_ratio = np.linalg.norm(mouth_top - mouth_bottom)
45
+ mouth_open_ratios[i] = mouth_open_ratio
46
+
47
+ # Find top_k similar indices for each landmark set
48
+ if ascending:
49
+ top_indices = np.argsort(mouth_open_ratios)[:top_k]
50
+ else:
51
+ top_indices = np.argsort(-mouth_open_ratios)[:top_k]
52
+ return top_indices
53
+
54
+ def calculate_landmarks_similarity(selected_idx, landmarks_list,image_shapes, start_index, end_index, top_k=50,ascending=True):
55
+ """
56
+ Calculate the similarity between sets of facial landmarks and return the indices of the most similar faces.
57
+
58
+ Parameters:
59
+ landmarks_list (list): A list containing sets of facial landmarks, each element is a set of landmarks.
60
+ image_shapes (list): A list containing the shape of each image, each element is a (width, height) tuple.
61
+ start_index (int): The starting index of the facial landmarks.
62
+ end_index (int): The ending index of the facial landmarks.
63
+ top_k (int): The number of most similar landmark sets to return. Default is 50.
64
+ ascending (bool): Controls the sorting order. If True, sort in ascending order; If False, sort in descending order. Default is True.
65
+
66
+ Returns:
67
+ similar_landmarks_indices (list): A list containing the indices of the most similar facial landmarks for each face.
68
+ resized_landmarks (list): A list containing the resized facial landmarks.
69
+ """
70
+ num_landmarks = len(landmarks_list)
71
+ resized_landmarks = []
72
+
73
+ # Preprocess landmarks
74
+ for i in range(num_landmarks):
75
+ landmark_array = np.array(landmarks_list[i])
76
+ selected_landmarks = landmark_array[start_index:end_index]
77
+ resized_landmark = resize_landmark(selected_landmarks, w=image_shapes[i][0], h=image_shapes[i][1],new_w=256,new_h=256)
78
+ resized_landmarks.append(resized_landmark)
79
+
80
+ resized_landmarks_array = np.array(resized_landmarks) # Convert list to array for easier manipulation
81
+
82
+ # Calculate similarity
83
+ distances = np.linalg.norm(resized_landmarks_array - resized_landmarks_array[selected_idx][np.newaxis, :], axis=2)
84
+ overall_distances = np.mean(distances, axis=1) # Calculate mean distance for each set of landmarks
85
+
86
+ if ascending:
87
+ sorted_indices = np.argsort(overall_distances)
88
+ similar_landmarks_indices = sorted_indices[1:top_k+1].tolist() # Exclude self and take top_k
89
+ else:
90
+ sorted_indices = np.argsort(-overall_distances)
91
+ similar_landmarks_indices = sorted_indices[0:top_k].tolist()
92
+
93
+ return similar_landmarks_indices
94
+
95
+ def process_bbox_musetalk(face_array, landmark_array):
96
+ x_min_face, y_min_face, x_max_face, y_max_face = map(int, face_array)
97
+ x_min_lm = min([int(x) for x, y in landmark_array])
98
+ y_min_lm = min([int(y) for x, y in landmark_array])
99
+ x_max_lm = max([int(x) for x, y in landmark_array])
100
+ y_max_lm = max([int(y) for x, y in landmark_array])
101
+ x_min = min(x_min_face, x_min_lm)
102
+ y_min = min(y_min_face, y_min_lm)
103
+ x_max = max(x_max_face, x_max_lm)
104
+ y_max = max(y_max_face, y_max_lm)
105
+
106
+ x_min = max(x_min, 0)
107
+ y_min = max(y_min, 0)
108
+
109
+ return [x_min, y_min, x_max, y_max]
110
+
111
+ def shift_landmarks_to_face_coordinates(landmark_list, face_list):
112
+ """
113
+ Translates the data in landmark_list to the coordinates of the cropped larger face.
114
+
115
+ Parameters:
116
+ landmark_list (list): A list containing multiple sets of facial landmarks.
117
+ face_list (list): A list containing multiple facial images.
118
+
119
+ Returns:
120
+ landmark_list_shift (list): The list of translated landmarks.
121
+ bbox_union (list): The list of union bounding boxes.
122
+ face_shapes (list): The list of facial shapes.
123
+ """
124
+ landmark_list_shift = []
125
+ bbox_union = []
126
+ face_shapes = []
127
+
128
+ for i in range(len(face_list)):
129
+ landmark_array = np.array(landmark_list[i])  # convert to a numpy array (makes a copy)
130
+ face_array = face_list[i]
131
+ f_landmark_bbox = process_bbox_musetalk(face_array, landmark_array)
132
+ x_min, y_min, x_max, y_max = f_landmark_bbox
133
+ landmark_array[:, 0] = landmark_array[:, 0] - f_landmark_bbox[0]
134
+ landmark_array[:, 1] = landmark_array[:, 1] - f_landmark_bbox[1]
135
+ landmark_list_shift.append(landmark_array)
136
+ bbox_union.append(f_landmark_bbox)
137
+ face_shapes.append((x_max - x_min, y_max - y_min))
138
+
139
+ return landmark_list_shift, bbox_union, face_shapes
140
+
141
+ def resize_landmark(landmark, w, h, new_w, new_h):
142
+ landmark_norm = landmark / [w, h]
143
+ landmark_resized = landmark_norm * [new_w, new_h]
144
+
145
+ return landmark_resized
146
+
147
+ def get_src_idx(drive_idx, T, sample_method,landmarks_list,image_shapes,top_k_ratio):
148
+ """
149
+ Calculate the source index (src_idx) based on the given drive index, T, s, e, and sampling method.
150
+
151
+ Parameters:
152
+ - drive_idx (int): The current drive index.
153
+ - T (int): Total number of frames or a specific range limit.
154
+ - sample_method (str): Sampling method, which can be "random" or other methods.
155
+ - landmarks_list (list): List of facial landmarks.
156
+ - image_shapes (list): List of image shapes.
157
+ - top_k_ratio (float): Ratio for selecting top k similar frames.
158
+
159
+ Returns:
160
+ - src_idx (int): The calculated source index.
161
+ """
162
+ if sample_method == "random":
163
+ src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T)
164
+ elif sample_method == "pose_similarity":
165
+ top_k = int(top_k_ratio*len(landmarks_list))
166
+ try:
167
+ top_k = int(top_k_ratio*len(landmarks_list))
168
+ # facial contour
169
+ landmark_start_idx = 0
170
+ landmark_end_idx = 16
171
+ pose_similarity_list = calculate_landmarks_similarity(drive_idx, landmarks_list,image_shapes, landmark_start_idx, landmark_end_idx,top_k=top_k, ascending=True)
172
+ src_idx = random.choice(pose_similarity_list)
173
+ while abs(src_idx-drive_idx)<5:
174
+ src_idx = random.choice(pose_similarity_list)
175
+ except Exception as e:
176
+ print(e)
177
+ return None
178
+ elif sample_method=="pose_similarity_and_closed_mouth":
179
+ # facial contour
180
+ landmark_start_idx = 0
181
+ landmark_end_idx = 16
182
+ try:
183
+ top_k = int(top_k_ratio*len(landmarks_list))
184
+ closed_mouth_list = get_closed_mouth(landmarks_list, ascending=True,top_k=top_k)
185
+ #print("closed_mouth_list",closed_mouth_list)
186
+ pose_similarity_list = calculate_landmarks_similarity(drive_idx, landmarks_list,image_shapes, landmark_start_idx, landmark_end_idx,top_k=top_k, ascending=True)
187
+ #print("pose_similarity_list",pose_similarity_list)
188
+ common_list = list(set(closed_mouth_list).intersection(set(pose_similarity_list)))
189
+ if len(common_list) == 0:
190
+ src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T)
191
+ else:
192
+ src_idx = random.choice(common_list)
193
+
194
+ while abs(src_idx-drive_idx) <5:
195
+ src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T)
196
+
197
+ except Exception as e:
198
+ print(e)
199
+ return None
200
+
201
+ elif sample_method=="pose_similarity_and_mouth_dissimilarity":
202
+ top_k = int(top_k_ratio*len(landmarks_list))
203
+ try:
204
+ top_k = int(top_k_ratio*len(landmarks_list))
205
+
206
+ # facial contour for 68 landmarks format
207
+ landmark_start_idx = 0
208
+ landmark_end_idx = 16
209
+
210
+ pose_similarity_list = calculate_landmarks_similarity(drive_idx, landmarks_list,image_shapes, landmark_start_idx, landmark_end_idx,top_k=top_k, ascending=True)
211
+
212
+ # Mouth inner contour for 68 landmarks format
213
+ landmark_start_idx = 60
214
+ landmark_end_idx = 67
215
+
216
+ mouth_dissimilarity_list = calculate_landmarks_similarity(drive_idx, landmarks_list,image_shapes, landmark_start_idx, landmark_end_idx,top_k=top_k, ascending=False)
217
+
218
+ common_list = list(set(pose_similarity_list).intersection(set(mouth_dissimilarity_list)))
219
+ if len(common_list) == 0:
220
+ src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T)
221
+ else:
222
+ src_idx = random.choice(common_list)
223
+
224
+ while abs(src_idx-drive_idx) <5:
225
+ src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T)
226
+
227
+ except Exception as e:
228
+ print(e)
229
+ return None
230
+
231
+ else:
232
+ raise ValueError(f"Unknown sample_method: {sample_method}")
233
+ return src_idx
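A small illustration of the reference-frame selection entry point used by the dataset above, run on synthetic 68-point landmarks (toy data, not from the repo):

    import numpy as np
    from musetalk.data.sample_method import shift_landmarks_to_face_coordinates, get_src_idx

    # 100 frames of synthetic 68-point landmarks and matching face boxes
    landmark_list = [np.random.rand(68, 2) * 200 + 20 for _ in range(100)]
    face_list = [[10, 10, 240, 240] for _ in range(100)]

    shifted, bbox_union, face_shapes = shift_landmarks_to_face_coordinates(landmark_list, face_list)
    src_idx = get_src_idx(drive_idx=50, T=1,
                          sample_method="pose_similarity_and_mouth_dissimilarity",
                          landmarks_list=shifted, image_shapes=face_shapes, top_k_ratio=0.5)
    print(src_idx)  # an index at least 5 frames away from the driving frame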
musetalk/loss/basic_loss.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from omegaconf import OmegaConf
4
5
+ import torch.nn.functional as F
6
+ from torch import nn, optim
7
+ from torch.optim.lr_scheduler import CosineAnnealingLR
8
+ from musetalk.loss.discriminator import MultiScaleDiscriminator,DiscriminatorFullModel
9
+ import musetalk.loss.vgg_face as vgg_face
10
+
11
+ class Interpolate(nn.Module):
12
+ def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None):
13
+ super(Interpolate, self).__init__()
14
+ self.size = size
15
+ self.scale_factor = scale_factor
16
+ self.mode = mode
17
+ self.align_corners = align_corners
18
+
19
+ def forward(self, input):
20
+ return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners)
21
+
22
+ def set_requires_grad(net, requires_grad=False):
23
+ if net is not None:
24
+ for param in net.parameters():
25
+ param.requires_grad = requires_grad
26
+
27
+ if __name__ == "__main__":
28
+ cfg = OmegaConf.load("config/audio_adapter/E7.yaml")
29
+
30
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
31
+ pyramid_scale = [1, 0.5, 0.25, 0.125]
32
+ vgg_IN = vgg_face.Vgg19().to(device)
33
+ pyramid = vgg_face.ImagePyramide(cfg.loss_params.pyramid_scale, 3).to(device)
34
+ vgg_IN.eval()
35
+ downsampler = Interpolate(size=(224, 224), mode='bilinear', align_corners=False)
36
+
37
+ image = torch.rand(8, 3, 256, 256).to(device)
38
+ image_pred = torch.rand(8, 3, 256, 256).to(device)
39
+ pyramide_real = pyramid(downsampler(image))
40
+ pyramide_generated = pyramid(downsampler(image_pred))
41
+
42
+
43
+ loss_IN = 0
44
+ for scale in cfg.loss_params.pyramid_scale:
45
+ x_vgg = vgg_IN(pyramide_generated['prediction_' + str(scale)])
46
+ y_vgg = vgg_IN(pyramide_real['prediction_' + str(scale)])
47
+ for i, weight in enumerate(cfg.loss_params.vgg_layer_weight):
48
+ value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
49
+ loss_IN += weight * value
50
+ loss_IN /= sum(cfg.loss_params.vgg_layer_weight)  # average over the VGG layer weights; the pyramid loss is summed across scales
51
+ print(loss_IN)
52
+
53
+ #print(cfg.model_params.discriminator_params)
54
+
55
+ discriminator = MultiScaleDiscriminator(**cfg.model_params.discriminator_params).to(device)
56
+ discriminator_full = DiscriminatorFullModel(discriminator)
57
+ disc_scales = cfg.model_params.discriminator_params.scales
58
+ # Prepare optimizer and loss function
59
+ optimizer_D = optim.AdamW(discriminator.parameters(),
60
+ lr=cfg.discriminator_train_params.lr,
61
+ weight_decay=cfg.discriminator_train_params.weight_decay,
62
+ betas=cfg.discriminator_train_params.betas,
63
+ eps=cfg.discriminator_train_params.eps)
64
+ scheduler_D = CosineAnnealingLR(optimizer_D,
65
+ T_max=cfg.discriminator_train_params.epochs,
66
+ eta_min=1e-6)
67
+
68
+ discriminator.train()
69
+
70
+ set_requires_grad(discriminator, False)
71
+
72
+ loss_G = 0.
73
+ discriminator_maps_generated = discriminator(pyramide_generated)
74
+ discriminator_maps_real = discriminator(pyramide_real)
75
+
76
+ for scale in disc_scales:
77
+ key = 'prediction_map_%s' % scale
78
+ value = ((1 - discriminator_maps_generated[key]) ** 2).mean()
79
+ loss_G += value
80
+
81
+ print(loss_G)
musetalk/loss/conv.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+ class Conv2d(nn.Module):
6
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
7
+ super().__init__(*args, **kwargs)
8
+ self.conv_block = nn.Sequential(
9
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
10
+ nn.BatchNorm2d(cout)
11
+ )
12
+ self.act = nn.ReLU()
13
+ self.residual = residual
14
+
15
+ def forward(self, x):
16
+ out = self.conv_block(x)
17
+ if self.residual:
18
+ out += x
19
+ return self.act(out)
20
+
21
+ class nonorm_Conv2d(nn.Module):
22
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
23
+ super().__init__(*args, **kwargs)
24
+ self.conv_block = nn.Sequential(
25
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
26
+ )
27
+ self.act = nn.LeakyReLU(0.01, inplace=True)
28
+
29
+ def forward(self, x):
30
+ out = self.conv_block(x)
31
+ return self.act(out)
32
+
33
+ class Conv2dTranspose(nn.Module):
34
+ def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
35
+ super().__init__(*args, **kwargs)
36
+ self.conv_block = nn.Sequential(
37
+ nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
38
+ nn.BatchNorm2d(cout)
39
+ )
40
+ self.act = nn.ReLU()
41
+
42
+ def forward(self, x):
43
+ out = self.conv_block(x)
44
+ return self.act(out)
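The blocks above are thin Conv+BatchNorm+activation wrappers; a quick shape check on toy tensors (illustrative only):

    import torch
    from musetalk.loss.conv import Conv2d, Conv2dTranspose

    x = torch.randn(1, 16, 64, 64)
    down = Conv2d(16, 32, kernel_size=3, stride=2, padding=1)         # halves the spatial size
    up = Conv2dTranspose(32, 16, kernel_size=4, stride=2, padding=1)  # doubles it back
    print(down(x).shape)      # torch.Size([1, 32, 32, 32])
    print(up(down(x)).shape)  # torch.Size([1, 16, 64, 64])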
musetalk/loss/discriminator.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ import torch.nn.functional as F
3
+ import torch
4
+ from musetalk.loss.vgg_face import ImagePyramide
5
+
6
+ class DownBlock2d(nn.Module):
7
+ """
8
+ Simple block for processing video (encoder).
9
+ """
10
+
11
+ def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False):
12
+ super(DownBlock2d, self).__init__()
13
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size)
14
+
15
+ if sn:
16
+ self.conv = nn.utils.spectral_norm(self.conv)
17
+
18
+ if norm:
19
+ self.norm = nn.InstanceNorm2d(out_features, affine=True)
20
+ else:
21
+ self.norm = None
22
+ self.pool = pool
23
+
24
+ def forward(self, x):
25
+ out = x
26
+ out = self.conv(out)
27
+ if self.norm:
28
+ out = self.norm(out)
29
+ out = F.leaky_relu(out, 0.2)
30
+ if self.pool:
31
+ out = F.avg_pool2d(out, (2, 2))
32
+ return out
33
+
34
+
35
+ class Discriminator(nn.Module):
36
+ """
37
+ Discriminator similar to Pix2Pix
38
+ """
39
+
40
+ def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512,
41
+ sn=False, **kwargs):
42
+ super(Discriminator, self).__init__()
43
+
44
+ down_blocks = []
45
+ for i in range(num_blocks):
46
+ down_blocks.append(
47
+ DownBlock2d(num_channels if i == 0 else min(max_features, block_expansion * (2 ** i)),
48
+ min(max_features, block_expansion * (2 ** (i + 1))),
49
+ norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn))
50
+
51
+ self.down_blocks = nn.ModuleList(down_blocks)
52
+ self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1)
53
+ if sn:
54
+ self.conv = nn.utils.spectral_norm(self.conv)
55
+
56
+ def forward(self, x):
57
+ feature_maps = []
58
+ out = x
59
+
60
+ for down_block in self.down_blocks:
61
+ feature_maps.append(down_block(out))
62
+ out = feature_maps[-1]
63
+ prediction_map = self.conv(out)
64
+
65
+ return feature_maps, prediction_map
66
+
67
+
68
+ class MultiScaleDiscriminator(nn.Module):
69
+ """
70
+ Multi-scale (scale) discriminator
71
+ """
72
+
73
+ def __init__(self, scales=(), **kwargs):
74
+ super(MultiScaleDiscriminator, self).__init__()
75
+ self.scales = scales
76
+ discs = {}
77
+ for scale in scales:
78
+ discs[str(scale).replace('.', '-')] = Discriminator(**kwargs)
79
+ self.discs = nn.ModuleDict(discs)
80
+
81
+ def forward(self, x):
82
+ out_dict = {}
83
+ for scale, disc in self.discs.items():
84
+ scale = str(scale).replace('-', '.')
85
+ key = 'prediction_' + scale
86
+ #print(key)
87
+ #print(x)
88
+ feature_maps, prediction_map = disc(x[key])
89
+ out_dict['feature_maps_' + scale] = feature_maps
90
+ out_dict['prediction_map_' + scale] = prediction_map
91
+ return out_dict
92
+
93
+
94
+
95
+ class DiscriminatorFullModel(torch.nn.Module):
96
+ """
97
+ Merge all discriminator related updates into single model for better multi-gpu usage
98
+ """
99
+
100
+ def __init__(self, discriminator):
101
+ super(DiscriminatorFullModel, self).__init__()
102
+ self.discriminator = discriminator
103
+ self.scales = self.discriminator.scales
104
+ print("scales",self.scales)
105
+ self.pyramid = ImagePyramide(self.scales, 3)
106
+ if torch.cuda.is_available():
107
+ self.pyramid = self.pyramid.cuda()
108
+
109
+ self.zero_tensor = None
110
+
111
+ def get_zero_tensor(self, input):
112
+ if self.zero_tensor is None:
113
+ self.zero_tensor = torch.FloatTensor(1).fill_(0).cuda()
114
+ self.zero_tensor.requires_grad_(False)
115
+ return self.zero_tensor.expand_as(input)
116
+
117
+ def forward(self, x, generated, gan_mode='ls'):
118
+ pyramide_real = self.pyramid(x)
119
+ pyramide_generated = self.pyramid(generated.detach())
120
+
121
+ discriminator_maps_generated = self.discriminator(pyramide_generated)
122
+ discriminator_maps_real = self.discriminator(pyramide_real)
123
+
124
+ value_total = 0
125
+ for scale in self.scales:
126
+ key = 'prediction_map_%s' % scale
127
+ if gan_mode == 'hinge':
128
+ value = -torch.mean(torch.min(discriminator_maps_real[key]-1, self.get_zero_tensor(discriminator_maps_real[key]))) - torch.mean(torch.min(-discriminator_maps_generated[key]-1, self.get_zero_tensor(discriminator_maps_generated[key])))
129
+ elif gan_mode == 'ls':
130
+ value = ((1 - discriminator_maps_real[key]) ** 2 + discriminator_maps_generated[key] ** 2).mean()
131
+ else:
132
+ raise ValueError('Unexpected gan_mode {}'.format(self.train_params['gan_mode']))
133
+
134
+ value_total += value
135
+
136
+ return value_total
137
+
138
+ def main():
139
+ discriminator = MultiScaleDiscriminator(scales=[1],
140
+ block_expansion=32,
141
+ max_features=512,
142
+ num_blocks=4,
143
+ sn=True,
144
+ image_channel=3,
145
+ estimate_jacobian=False)
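A minimal sketch of the LS-GAN discriminator loss defined above, on random images; the constructor arguments echo the discriminator_params in the training configs (block_expansion is an assumed value, as that key is not shown in this hunk):

    import torch
    from musetalk.loss.discriminator import MultiScaleDiscriminator, DiscriminatorFullModel

    device = "cuda" if torch.cuda.is_available() else "cpu"
    disc = MultiScaleDiscriminator(scales=[1], block_expansion=32, max_features=512,
                                   num_blocks=4, sn=True)
    disc_full = DiscriminatorFullModel(disc).to(device)

    real = torch.rand(2, 3, 256, 256, device=device)
    fake = torch.rand(2, 3, 256, 256, device=device)
    d_loss = disc_full(real, fake, gan_mode='ls')  # real pushed toward 1, detached fake toward 0
    print(d_loss.item())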
musetalk/loss/resnet.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import math
3
+
4
+ __all__ = ['ResNet', 'resnet50']
5
+
6
+ def conv3x3(in_planes, out_planes, stride=1):
7
+ """3x3 convolution with padding"""
8
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
9
+ padding=1, bias=False)
10
+
11
+
12
+ class BasicBlock(nn.Module):
13
+ expansion = 1
14
+
15
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
16
+ super(BasicBlock, self).__init__()
17
+ self.conv1 = conv3x3(inplanes, planes, stride)
18
+ self.bn1 = nn.BatchNorm2d(planes)
19
+ self.relu = nn.ReLU(inplace=True)
20
+ self.conv2 = conv3x3(planes, planes)
21
+ self.bn2 = nn.BatchNorm2d(planes)
22
+ self.downsample = downsample
23
+ self.stride = stride
24
+
25
+ def forward(self, x):
26
+ residual = x
27
+
28
+ out = self.conv1(x)
29
+ out = self.bn1(out)
30
+ out = self.relu(out)
31
+
32
+ out = self.conv2(out)
33
+ out = self.bn2(out)
34
+
35
+ if self.downsample is not None:
36
+ residual = self.downsample(x)
37
+
38
+ out += residual
39
+ out = self.relu(out)
40
+
41
+ return out
42
+
43
+
44
+ class Bottleneck(nn.Module):
45
+ expansion = 4
46
+
47
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
48
+ super(Bottleneck, self).__init__()
49
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
50
+ self.bn1 = nn.BatchNorm2d(planes)
51
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
52
+ self.bn2 = nn.BatchNorm2d(planes)
53
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
54
+ self.bn3 = nn.BatchNorm2d(planes * 4)
55
+ self.relu = nn.ReLU(inplace=True)
56
+ self.downsample = downsample
57
+ self.stride = stride
58
+
59
+ def forward(self, x):
60
+ residual = x
61
+
62
+ out = self.conv1(x)
63
+ out = self.bn1(out)
64
+ out = self.relu(out)
65
+
66
+ out = self.conv2(out)
67
+ out = self.bn2(out)
68
+ out = self.relu(out)
69
+
70
+ out = self.conv3(out)
71
+ out = self.bn3(out)
72
+
73
+ if self.downsample is not None:
74
+ residual = self.downsample(x)
75
+
76
+ out += residual
77
+ out = self.relu(out)
78
+
79
+ return out
80
+
81
+
82
+ class ResNet(nn.Module):
83
+
84
+ def __init__(self, block, layers, num_classes=1000, include_top=True):
85
+ self.inplanes = 64
86
+ super(ResNet, self).__init__()
87
+ self.include_top = include_top
88
+
89
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
90
+ self.bn1 = nn.BatchNorm2d(64)
91
+ self.relu = nn.ReLU(inplace=True)
92
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True)
93
+
94
+ self.layer1 = self._make_layer(block, 64, layers[0])
95
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
96
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
97
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
98
+ self.avgpool = nn.AvgPool2d(7, stride=1)
99
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
100
+
101
+ for m in self.modules():
102
+ if isinstance(m, nn.Conv2d):
103
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
104
+ m.weight.data.normal_(0, math.sqrt(2. / n))
105
+ elif isinstance(m, nn.BatchNorm2d):
106
+ m.weight.data.fill_(1)
107
+ m.bias.data.zero_()
108
+
109
+ def _make_layer(self, block, planes, blocks, stride=1):
110
+ downsample = None
111
+ if stride != 1 or self.inplanes != planes * block.expansion:
112
+ downsample = nn.Sequential(
113
+ nn.Conv2d(self.inplanes, planes * block.expansion,
114
+ kernel_size=1, stride=stride, bias=False),
115
+ nn.BatchNorm2d(planes * block.expansion),
116
+ )
117
+
118
+ layers = []
119
+ layers.append(block(self.inplanes, planes, stride, downsample))
120
+ self.inplanes = planes * block.expansion
121
+ for i in range(1, blocks):
122
+ layers.append(block(self.inplanes, planes))
123
+
124
+ return nn.Sequential(*layers)
125
+
126
+ def forward(self, x):
127
+ x = x * 255.
128
+ x = x.flip(1)
129
+ x = self.conv1(x)
130
+ x = self.bn1(x)
131
+ x = self.relu(x)
132
+ x = self.maxpool(x)
133
+
134
+ x = self.layer1(x)
135
+ x = self.layer2(x)
136
+ x = self.layer3(x)
137
+ x = self.layer4(x)
138
+
139
+ x = self.avgpool(x)
140
+
141
+ if not self.include_top:
142
+ return x
143
+
144
+ x = x.view(x.size(0), -1)
145
+ x = self.fc(x)
146
+ return x
147
+
148
+ def resnet50(**kwargs):
149
+ """Constructs a ResNet-50 model.
150
+ """
151
+ model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
152
+ return model
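Note that `ResNet.forward` multiplies its input by 255 and flips the channel order, so it expects RGB tensors scaled to [0, 1] and converts them internally to the 0-255 BGR convention of the VGGFace2 Caffe weights. A shape-only usage sketch (no pretrained weights are loaded here):

```python
import torch
from musetalk.loss.resnet import resnet50

model = resnet50(num_classes=8631, include_top=False).eval()
x = torch.rand(1, 3, 224, 224)      # RGB batch in [0, 1]
with torch.no_grad():
    feats = model(x)                # pooled features, shape (1, 2048, 1, 1)
print(feats.shape)
```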
musetalk/loss/syncnet.py ADDED
@@ -0,0 +1,95 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+ from .conv import Conv2d
6
+
7
+ logloss = nn.BCELoss(reduction="none")
8
+ def cosine_loss(a, v, y):
9
+ d = nn.functional.cosine_similarity(a, v)
10
+ d = d.clamp(0,1) # cosine_similarity ranges over [-1, 1]; feeding negative values to BCELoss raises "RuntimeError: CUDA error: device-side assert triggered"
11
+ loss = logloss(d.unsqueeze(1), y).squeeze()
12
+ loss = loss.mean()
13
+ return loss, d
14
+
15
+ def get_sync_loss(
16
+ audio_embed,
17
+ gt_frames,
18
+ pred_frames,
19
+ syncnet,
20
+ adapted_weight,
21
+ frames_left_index=0,
22
+ frames_right_index=16,
23
+ ):
24
+ # Splice pred_frames into gt_frames (insert/swap at the given index range) to save GPU memory
25
+ assert pred_frames.shape[1] == (frames_right_index - frames_left_index) * 3
26
+ # 3-channel images
27
+ frames_sync_loss = torch.cat(
28
+ [gt_frames[:, :3 * frames_left_index, ...], pred_frames, gt_frames[:, 3 * frames_right_index:, ...]],
29
+ axis=1
30
+ )
31
+ vision_embed = syncnet.get_image_embed(frames_sync_loss)
32
+ y = torch.ones(frames_sync_loss.size(0), 1).float().to(audio_embed.device)
33
+ loss, score = cosine_loss(audio_embed, vision_embed, y)
34
+ return loss, score
35
+
36
+ class SyncNet_color(nn.Module):
37
+ def __init__(self):
38
+ super(SyncNet_color, self).__init__()
39
+
40
+ self.face_encoder = nn.Sequential(
41
+ Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3),
42
+
43
+ Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1),
44
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
45
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
46
+
47
+ Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
48
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
49
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
50
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
51
+
52
+ Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
53
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
54
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
55
+
56
+ Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
57
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
58
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
59
+
60
+ Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
61
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=0),
62
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
63
+
64
+ self.audio_encoder = nn.Sequential(
65
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
66
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
67
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
68
+
69
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
70
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
71
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
72
+
73
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
74
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
75
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
76
+
77
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
78
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
79
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
80
+
81
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
82
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
83
+
84
+ def forward(self, audio_sequences, face_sequences): # audio_sequences := (B, dim, T)
85
+ face_embedding = self.face_encoder(face_sequences)
86
+ audio_embedding = self.audio_encoder(audio_sequences)
87
+
88
+ audio_embedding = audio_embedding.view(audio_embedding.size(0), -1)
89
+ face_embedding = face_embedding.view(face_embedding.size(0), -1)
90
+
91
+ audio_embedding = F.normalize(audio_embedding, p=2, dim=1)
92
+ face_embedding = F.normalize(face_embedding, p=2, dim=1)
93
+
94
+
95
+ return audio_embedding, face_embedding
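`cosine_loss` clamps the audio/visual cosine similarity to [0, 1] and treats it as a probability for a BCE loss against the sync label. A toy sketch of its behaviour with nearly aligned embeddings; the embedding size and batch size are arbitrary assumptions:

```python
import torch
import torch.nn.functional as F
from musetalk.loss.syncnet import cosine_loss

a = F.normalize(torch.randn(4, 512), dim=1)
v = F.normalize(a + 0.05 * torch.randn(4, 512), dim=1)  # almost in sync with the audio
y = torch.ones(4, 1)                                     # label 1 = in sync
loss, score = cosine_loss(a, v, y)
print(loss.item(), score)                                # small loss, scores close to 1
```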
musetalk/loss/vgg_face.py ADDED
@@ -0,0 +1,237 @@
1
+ '''
2
+ This part of code contains a pretrained vgg_face model.
3
+ ref link: https://github.com/prlz77/vgg-face.pytorch
4
+ '''
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import torch.utils.model_zoo
8
+ import pickle
9
+ from musetalk.loss import resnet as ResNet
10
+
11
+
12
+ MODEL_URL = "https://github.com/claudio-unipv/vggface-pytorch/releases/download/v0.1/vggface-9d491dd7c30312.pth"
13
+ VGG_FACE_PATH = '/apdcephfs_cq8/share_1367250/zhentaoyu/Driving/00_VASA/00_data/models/pretrain_models/resnet50_ft_weight.pkl'
14
+
15
+ # It was 93.5940, 104.7624, 129.1863 before dividing by 255
16
+ MEAN_RGB = [
17
+ 0.367035294117647,
18
+ 0.41083294117647057,
19
+ 0.5066129411764705
20
+ ]
21
+ def load_state_dict(model, fname):
22
+ """
23
+ Set parameters converted from Caffe models authors of VGGFace2 provide.
24
+ See https://www.robots.ox.ac.uk/~vgg/data/vgg_face2/.
25
+
26
+ Arguments:
27
+ model: model
28
+ fname: file name of parameters converted from a Caffe model, assuming the file format is Pickle.
29
+ """
30
+ with open(fname, 'rb') as f:
31
+ weights = pickle.load(f, encoding='latin1')
32
+
33
+ own_state = model.state_dict()
34
+ for name, param in weights.items():
35
+ if name in own_state:
36
+ try:
37
+ own_state[name].copy_(torch.from_numpy(param))
38
+ except Exception:
39
+ raise RuntimeError('While copying the parameter named {}, whose dimensions in the model are {} and whose '\
40
+ 'dimensions in the checkpoint are {}.'.format(name, own_state[name].size(), param.size()))
41
+ else:
42
+ raise KeyError('unexpected key "{}" in state_dict'.format(name))
43
+
44
+
45
+ def vggface2(pretrained=True):
46
+ vggface = ResNet.resnet50(num_classes=8631, include_top=True)
47
+ load_state_dict(vggface, VGG_FACE_PATH)
48
+ return vggface
49
+
50
+ def vggface(pretrained=False, **kwargs):
51
+ """VGGFace model.
52
+
53
+ Args:
54
+ pretrained (bool): If True, returns pre-trained model
55
+ """
56
+ model = VggFace(**kwargs)
57
+ if pretrained:
58
+ state = torch.utils.model_zoo.load_url(MODEL_URL)
59
+ model.load_state_dict(state)
60
+ return model
61
+
62
+
63
+ class VggFace(torch.nn.Module):
64
+ def __init__(self, classes=2622):
65
+ """VGGFace model.
66
+
67
+ Face recognition network. It takes as input a Bx3x224x224
68
+ batch of face images and gives as output a BxC score vector
69
+ (C is the number of identities).
70
+ Input images need to be scaled in the 0-1 range and then
71
+ normalized with respect to the mean RGB used during training.
72
+
73
+ Args:
74
+ classes (int): number of identities recognized by the
75
+ network
76
+
77
+ """
78
+ super().__init__()
79
+ self.conv1 = _ConvBlock(3, 64, 64)
80
+ self.conv2 = _ConvBlock(64, 128, 128)
81
+ self.conv3 = _ConvBlock(128, 256, 256, 256)
82
+ self.conv4 = _ConvBlock(256, 512, 512, 512)
83
+ self.conv5 = _ConvBlock(512, 512, 512, 512)
84
+ self.dropout = torch.nn.Dropout(0.5)
85
+ self.fc1 = torch.nn.Linear(7 * 7 * 512, 4096)
86
+ self.fc2 = torch.nn.Linear(4096, 4096)
87
+ self.fc3 = torch.nn.Linear(4096, classes)
88
+
89
+ def forward(self, x):
90
+ x = self.conv1(x)
91
+ x = self.conv2(x)
92
+ x = self.conv3(x)
93
+ x = self.conv4(x)
94
+ x = self.conv5(x)
95
+ x = x.view(x.size(0), -1)
96
+ x = self.dropout(F.relu(self.fc1(x)))
97
+ x = self.dropout(F.relu(self.fc2(x)))
98
+ x = self.fc3(x)
99
+ return x
100
+
101
+
102
+ class _ConvBlock(torch.nn.Module):
103
+ """A Convolutional block."""
104
+
105
+ def __init__(self, *units):
106
+ """Create a block with len(units) - 1 convolutions.
107
+
108
+ convolution number i transforms the number of channels from
109
+ units[i - 1] to units[i] channels.
110
+
111
+ """
112
+ super().__init__()
113
+ self.convs = torch.nn.ModuleList([
114
+ torch.nn.Conv2d(in_, out, 3, 1, 1)
115
+ for in_, out in zip(units[:-1], units[1:])
116
+ ])
117
+
118
+ def forward(self, x):
119
+ # Each convolution is followed by a ReLU, then the block is
120
+ # concluded by a max pooling.
121
+ for c in self.convs:
122
+ x = F.relu(c(x))
123
+ return F.max_pool2d(x, 2, 2, 0, ceil_mode=True)
124
+
125
+
126
+
127
+ import numpy as np
128
+ from torchvision import models
129
+ class Vgg19(torch.nn.Module):
130
+ """
131
+ Vgg19 network for perceptual loss.
132
+ """
133
+ def __init__(self, requires_grad=False):
134
+ super(Vgg19, self).__init__()
135
+ vgg_pretrained_features = models.vgg19(pretrained=True).features
136
+ self.slice1 = torch.nn.Sequential()
137
+ self.slice2 = torch.nn.Sequential()
138
+ self.slice3 = torch.nn.Sequential()
139
+ self.slice4 = torch.nn.Sequential()
140
+ self.slice5 = torch.nn.Sequential()
141
+ for x in range(2):
142
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
143
+ for x in range(2, 7):
144
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
145
+ for x in range(7, 12):
146
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
147
+ for x in range(12, 21):
148
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
149
+ for x in range(21, 30):
150
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
151
+
152
+ self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
153
+ requires_grad=False)
154
+ self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
155
+ requires_grad=False)
156
+
157
+ if not requires_grad:
158
+ for param in self.parameters():
159
+ param.requires_grad = False
160
+
161
+ def forward(self, X):
162
+ X = (X - self.mean) / self.std
163
+ h_relu1 = self.slice1(X)
164
+ h_relu2 = self.slice2(h_relu1)
165
+ h_relu3 = self.slice3(h_relu2)
166
+ h_relu4 = self.slice4(h_relu3)
167
+ h_relu5 = self.slice5(h_relu4)
168
+ out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
169
+ return out
170
+
171
+
172
+ from torch import nn
173
+ class AntiAliasInterpolation2d(nn.Module):
174
+ """
175
+ Band-limited downsampling, for better preservation of the input signal.
176
+ """
177
+ def __init__(self, channels, scale):
178
+ super(AntiAliasInterpolation2d, self).__init__()
179
+ sigma = (1 / scale - 1) / 2
180
+ kernel_size = 2 * round(sigma * 4) + 1
181
+ self.ka = kernel_size // 2
182
+ self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka
183
+
184
+ kernel_size = [kernel_size, kernel_size]
185
+ sigma = [sigma, sigma]
186
+ # The gaussian kernel is the product of the
187
+ # gaussian function of each dimension.
188
+ kernel = 1
189
+ meshgrids = torch.meshgrid(
190
+ [
191
+ torch.arange(size, dtype=torch.float32)
192
+ for size in kernel_size
193
+ ]
194
+ )
195
+ for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
196
+ mean = (size - 1) / 2
197
+ kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))
198
+
199
+ # Make sure sum of values in gaussian kernel equals 1.
200
+ kernel = kernel / torch.sum(kernel)
201
+ # Reshape to depthwise convolutional weight
202
+ kernel = kernel.view(1, 1, *kernel.size())
203
+ kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
204
+
205
+ self.register_buffer('weight', kernel)
206
+ self.groups = channels
207
+ self.scale = scale
208
+ inv_scale = 1 / scale
209
+ self.int_inv_scale = int(inv_scale)
210
+
211
+ def forward(self, input):
212
+ if self.scale == 1.0:
213
+ return input
214
+
215
+ out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
216
+ out = F.conv2d(out, weight=self.weight, groups=self.groups)
217
+ out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale]
218
+
219
+ return out
220
+
221
+
222
+ class ImagePyramide(torch.nn.Module):
223
+ """
224
+ Create image pyramide for computing pyramide perceptual loss.
225
+ """
226
+ def __init__(self, scales, num_channels):
227
+ super(ImagePyramide, self).__init__()
228
+ downs = {}
229
+ for scale in scales:
230
+ downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
231
+ self.downs = nn.ModuleDict(downs)
232
+
233
+ def forward(self, x):
234
+ out_dict = {}
235
+ for scale, down_module in self.downs.items():
236
+ out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
237
+ return out_dict
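`ImagePyramide` and `Vgg19` are typically combined into a multi-scale perceptual loss: each pyramid level is passed through the frozen VGG19 slices and the feature differences are accumulated. A hedged sketch of that wiring; the scale list and the plain L1 aggregation are assumptions, not values taken from the training config:

```python
import torch
from musetalk.loss.vgg_face import Vgg19, ImagePyramide

scales = [1, 0.5, 0.25]                     # assumed pyramid scales
vgg = Vgg19().eval()                        # downloads torchvision VGG19 weights
pyramid = ImagePyramide(scales, num_channels=3)

pred, target = torch.rand(1, 3, 256, 256), torch.rand(1, 3, 256, 256)
pyr_pred, pyr_tgt = pyramid(pred), pyramid(target)

loss = 0.0
for s in scales:
    key = 'prediction_' + str(s)
    feats_pred, feats_tgt = vgg(pyr_pred[key]), vgg(pyr_tgt[key])
    loss += sum((p - t.detach()).abs().mean() for p, t in zip(feats_pred, feats_tgt))
print(loss)
```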
musetalk/models/syncnet.py ADDED
@@ -0,0 +1,240 @@
1
+ """
2
+ This file is modified from LatentSync (https://github.com/bytedance/LatentSync/blob/main/latentsync/models/stable_syncnet.py).
3
+ """
4
+
5
+ import torch
6
+ from torch import nn
7
+ from einops import rearrange
8
+ from torch.nn import functional as F
9
+
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from diffusers.models.attention import Attention as CrossAttention, FeedForward
14
+ from diffusers.utils.import_utils import is_xformers_available
15
+ from einops import rearrange
16
+
17
+
18
+ class SyncNet(nn.Module):
19
+ def __init__(self, config):
20
+ super().__init__()
21
+ self.audio_encoder = DownEncoder2D(
22
+ in_channels=config["audio_encoder"]["in_channels"],
23
+ block_out_channels=config["audio_encoder"]["block_out_channels"],
24
+ downsample_factors=config["audio_encoder"]["downsample_factors"],
25
+ dropout=config["audio_encoder"]["dropout"],
26
+ attn_blocks=config["audio_encoder"]["attn_blocks"],
27
+ )
28
+
29
+ self.visual_encoder = DownEncoder2D(
30
+ in_channels=config["visual_encoder"]["in_channels"],
31
+ block_out_channels=config["visual_encoder"]["block_out_channels"],
32
+ downsample_factors=config["visual_encoder"]["downsample_factors"],
33
+ dropout=config["visual_encoder"]["dropout"],
34
+ attn_blocks=config["visual_encoder"]["attn_blocks"],
35
+ )
36
+
37
+ self.eval()
38
+
39
+ def forward(self, image_sequences, audio_sequences):
40
+ vision_embeds = self.visual_encoder(image_sequences) # (b, c, 1, 1)
41
+ audio_embeds = self.audio_encoder(audio_sequences) # (b, c, 1, 1)
42
+
43
+ vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1) # (b, c)
44
+ audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1) # (b, c)
45
+
46
+ # Make them unit vectors
47
+ vision_embeds = F.normalize(vision_embeds, p=2, dim=1)
48
+ audio_embeds = F.normalize(audio_embeds, p=2, dim=1)
49
+
50
+ return vision_embeds, audio_embeds
51
+
52
+ def get_image_embed(self, image_sequences):
53
+ vision_embeds = self.visual_encoder(image_sequences) # (b, c, 1, 1)
54
+
55
+ vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1) # (b, c)
56
+
57
+ # Make them unit vectors
58
+ vision_embeds = F.normalize(vision_embeds, p=2, dim=1)
59
+
60
+ return vision_embeds
61
+
62
+ def get_audio_embed(self, audio_sequences):
63
+ audio_embeds = self.audio_encoder(audio_sequences) # (b, c, 1, 1)
64
+
65
+ audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1) # (b, c)
66
+
67
+ audio_embeds = F.normalize(audio_embeds, p=2, dim=1)
68
+
69
+ return audio_embeds
70
+
71
+ class ResnetBlock2D(nn.Module):
72
+ def __init__(
73
+ self,
74
+ in_channels: int,
75
+ out_channels: int,
76
+ dropout: float = 0.0,
77
+ norm_num_groups: int = 32,
78
+ eps: float = 1e-6,
79
+ act_fn: str = "silu",
80
+ downsample_factor=2,
81
+ ):
82
+ super().__init__()
83
+
84
+ self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
85
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
86
+
87
+ self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=eps, affine=True)
88
+ self.dropout = nn.Dropout(dropout)
89
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
90
+
91
+ if act_fn == "relu":
92
+ self.act_fn = nn.ReLU()
93
+ elif act_fn == "silu":
94
+ self.act_fn = nn.SiLU()
95
+
96
+ if in_channels != out_channels:
97
+ self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
98
+ else:
99
+ self.conv_shortcut = None
100
+
101
+ if isinstance(downsample_factor, list):
102
+ downsample_factor = tuple(downsample_factor)
103
+
104
+ if downsample_factor == 1:
105
+ self.downsample_conv = None
106
+ else:
107
+ self.downsample_conv = nn.Conv2d(
108
+ out_channels, out_channels, kernel_size=3, stride=downsample_factor, padding=0
109
+ )
110
+ self.pad = (0, 1, 0, 1)
111
+ if isinstance(downsample_factor, tuple):
112
+ if downsample_factor[0] == 1:
113
+ self.pad = (0, 1, 1, 1) # The padding order is from back to front
114
+ elif downsample_factor[1] == 1:
115
+ self.pad = (1, 1, 0, 1)
116
+
117
+ def forward(self, input_tensor):
118
+ hidden_states = input_tensor
119
+
120
+ hidden_states = self.norm1(hidden_states)
121
+ hidden_states = self.act_fn(hidden_states)
122
+
123
+ hidden_states = self.conv1(hidden_states)
124
+ hidden_states = self.norm2(hidden_states)
125
+ hidden_states = self.act_fn(hidden_states)
126
+
127
+ hidden_states = self.dropout(hidden_states)
128
+ hidden_states = self.conv2(hidden_states)
129
+
130
+ if self.conv_shortcut is not None:
131
+ input_tensor = self.conv_shortcut(input_tensor)
132
+
133
+ hidden_states += input_tensor
134
+
135
+ if self.downsample_conv is not None:
136
+ hidden_states = F.pad(hidden_states, self.pad, mode="constant", value=0)
137
+ hidden_states = self.downsample_conv(hidden_states)
138
+
139
+ return hidden_states
140
+
141
+
142
+ class AttentionBlock2D(nn.Module):
143
+ def __init__(self, query_dim, norm_num_groups=32, dropout=0.0):
144
+ super().__init__()
145
+ if not is_xformers_available():
146
+ raise ModuleNotFoundError(
147
+ "You have to install xformers to enable memory efficient attetion", name="xformers"
148
+ )
149
+ # inner_dim = dim_head * heads
150
+ self.norm1 = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=query_dim, eps=1e-6, affine=True)
151
+ self.norm2 = nn.LayerNorm(query_dim)
152
+ self.norm3 = nn.LayerNorm(query_dim)
153
+
154
+ self.ff = FeedForward(query_dim, dropout=dropout, activation_fn="geglu")
155
+
156
+ self.conv_in = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)
157
+ self.conv_out = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)
158
+
159
+ self.attn = CrossAttention(query_dim=query_dim, heads=8, dim_head=query_dim // 8, dropout=dropout, bias=True)
160
+ self.attn._use_memory_efficient_attention_xformers = True
161
+
162
+ def forward(self, hidden_states):
163
+ assert hidden_states.dim() == 4, f"Expected hidden_states to have ndim=4, but got ndim={hidden_states.dim()}."
164
+
165
+ batch, channel, height, width = hidden_states.shape
166
+ residual = hidden_states
167
+
168
+ hidden_states = self.norm1(hidden_states)
169
+ hidden_states = self.conv_in(hidden_states)
170
+ hidden_states = rearrange(hidden_states, "b c h w -> b (h w) c")
171
+
172
+ norm_hidden_states = self.norm2(hidden_states)
173
+ hidden_states = self.attn(norm_hidden_states, attention_mask=None) + hidden_states
174
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
175
+
176
+ hidden_states = rearrange(hidden_states, "b (h w) c -> b c h w", h=height, w=width)
177
+ hidden_states = self.conv_out(hidden_states)
178
+
179
+ hidden_states = hidden_states + residual
180
+ return hidden_states
181
+
182
+
183
+ class DownEncoder2D(nn.Module):
184
+ def __init__(
185
+ self,
186
+ in_channels=4 * 16,
187
+ block_out_channels=[64, 128, 256, 256],
188
+ downsample_factors=[2, 2, 2, 2],
189
+ layers_per_block=2,
190
+ norm_num_groups=32,
191
+ attn_blocks=[1, 1, 1, 1],
192
+ dropout: float = 0.0,
193
+ act_fn="silu",
194
+ ):
195
+ super().__init__()
196
+ self.layers_per_block = layers_per_block
197
+
198
+ # in
199
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
200
+
201
+ # down
202
+ self.down_blocks = nn.ModuleList([])
203
+
204
+ output_channels = block_out_channels[0]
205
+ for i, block_out_channel in enumerate(block_out_channels):
206
+ input_channels = output_channels
207
+ output_channels = block_out_channel
208
+ # is_final_block = i == len(block_out_channels) - 1
209
+
210
+ down_block = ResnetBlock2D(
211
+ in_channels=input_channels,
212
+ out_channels=output_channels,
213
+ downsample_factor=downsample_factors[i],
214
+ norm_num_groups=norm_num_groups,
215
+ dropout=dropout,
216
+ act_fn=act_fn,
217
+ )
218
+
219
+ self.down_blocks.append(down_block)
220
+
221
+ if attn_blocks[i] == 1:
222
+ attention_block = AttentionBlock2D(query_dim=output_channels, dropout=dropout)
223
+ self.down_blocks.append(attention_block)
224
+
225
+ # out
226
+ self.norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
227
+ self.act_fn_out = nn.ReLU()
228
+
229
+ def forward(self, hidden_states):
230
+ hidden_states = self.conv_in(hidden_states)
231
+
232
+ # down
233
+ for down_block in self.down_blocks:
234
+ hidden_states = down_block(hidden_states)
235
+
236
+ # post-process
237
+ hidden_states = self.norm_out(hidden_states)
238
+ hidden_states = self.act_fn_out(hidden_states)
239
+
240
+ return hidden_states
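`SyncNet` is configured entirely through a nested dict. Below is a minimal forward-pass sketch with made-up channel counts and downsample factors; the real values live in `configs/training/syncnet.yaml`, and attention blocks are disabled here so that xformers is not required:

```python
import torch
from musetalk.models.syncnet import SyncNet

def enc(c_in, factors):
    # Illustrative encoder config; not the shipped syncnet.yaml values.
    return dict(in_channels=c_in, block_out_channels=[32, 64, 128],
                downsample_factors=factors, dropout=0.0, attn_blocks=[0, 0, 0])

config = {"visual_encoder": enc(48, [4, 4, 4]),   # e.g. 16 stacked 3-channel mouth crops (assumption)
          "audio_encoder": enc(8, [4, 4, 2])}     # toy audio feature map (assumption)

syncnet = SyncNet(config)
frames = torch.rand(2, 48, 64, 64)
audio = torch.rand(2, 8, 32, 32)
vision_embeds, audio_embeds = syncnet(frames, audio)
print(vision_embeds.shape, audio_embeds.shape)    # (2, 128) unit vectors each
```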
musetalk/models/unet.py ADDED
@@ -0,0 +1,51 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+ import json
5
+
6
+ from diffusers import UNet2DConditionModel
7
+ import sys
8
+ import time
9
+ import numpy as np
10
+ import os
11
+
12
+ class PositionalEncoding(nn.Module):
13
+ def __init__(self, d_model=384, max_len=5000):
14
+ super(PositionalEncoding, self).__init__()
15
+ pe = torch.zeros(max_len, d_model)
16
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
17
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
18
+ pe[:, 0::2] = torch.sin(position * div_term)
19
+ pe[:, 1::2] = torch.cos(position * div_term)
20
+ pe = pe.unsqueeze(0)
21
+ self.register_buffer('pe', pe)
22
+
23
+ def forward(self, x):
24
+ b, seq_len, d_model = x.size()
25
+ pe = self.pe[:, :seq_len, :]
26
+ x = x + pe.to(x.device)
27
+ return x
28
+
29
+ class UNet():
30
+ def __init__(self,
31
+ unet_config,
32
+ model_path,
33
+ use_float16=False,
34
+ device=None
35
+ ):
36
+ with open(unet_config, 'r') as f:
37
+ unet_config = json.load(f)
38
+ self.model = UNet2DConditionModel(**unet_config)
39
+ self.pe = PositionalEncoding(d_model=384)
40
+ if device is not None:
41
+ self.device = device
42
+ else:
43
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
44
+ weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device)
45
+ self.model.load_state_dict(weights)
46
+ if use_float16:
47
+ self.model = self.model.half()
48
+ self.model.to(self.device)
49
+
50
+ if __name__ == "__main__":
51
+ unet = UNet()
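`PositionalEncoding` adds fixed sinusoidal positions to the Whisper audio embeddings before they condition the UNet; note that the bare `UNet()` call in the `__main__` guard above would need `unet_config` and `model_path` arguments to actually run. A small sketch of the positional encoding on its own:

```python
import torch
from musetalk.models.unet import PositionalEncoding

pe = PositionalEncoding(d_model=384)
audio_feats = torch.rand(2, 50, 384)   # (batch, sequence length, Whisper feature dim)
out = pe(audio_feats)                  # same shape, with sinusoidal positions added
print(out.shape)                       # torch.Size([2, 50, 384])
```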
musetalk/models/vae.py ADDED
@@ -0,0 +1,148 @@
1
+ from diffusers import AutoencoderKL
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ import torch.nn.functional as F
5
+ import cv2
6
+ import numpy as np
7
+ from PIL import Image
8
+ import os
9
+
10
+ class VAE():
11
+ """
12
+ VAE (Variational Autoencoder) class for image processing.
13
+ """
14
+
15
+ def __init__(self, model_path="./models/sd-vae-ft-mse/", resized_img=256, use_float16=False):
16
+ """
17
+ Initialize the VAE instance.
18
+
19
+ :param model_path: Path to the trained model.
20
+ :param resized_img: The size to which images are resized.
21
+ :param use_float16: Whether to use float16 precision.
22
+ """
23
+ self.model_path = model_path
24
+ self.vae = AutoencoderKL.from_pretrained(self.model_path)
25
+
26
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
+ self.vae.to(self.device)
28
+
29
+ if use_float16:
30
+ self.vae = self.vae.half()
31
+ self._use_float16 = True
32
+ else:
33
+ self._use_float16 = False
34
+
35
+ self.scaling_factor = self.vae.config.scaling_factor
36
+ self.transform = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
37
+ self._resized_img = resized_img
38
+ self._mask_tensor = self.get_mask_tensor()
39
+
40
+ def get_mask_tensor(self):
41
+ """
42
+ Creates a mask tensor for image processing.
43
+ :return: A mask tensor.
44
+ """
45
+ mask_tensor = torch.zeros((self._resized_img,self._resized_img))
46
+ mask_tensor[:self._resized_img//2,:] = 1
47
+ mask_tensor[mask_tensor< 0.5] = 0
48
+ mask_tensor[mask_tensor>= 0.5] = 1
49
+ return mask_tensor
50
+
51
+ def preprocess_img(self,img_name,half_mask=False):
52
+ """
53
+ Preprocess an image for the VAE.
54
+
55
+ :param img_name: The image file path or a list of image file paths.
56
+ :param half_mask: Whether to apply a half mask to the image.
57
+ :return: A preprocessed image tensor.
58
+ """
59
+ window = []
60
+ if isinstance(img_name, str):
61
+ window_fnames = [img_name]
62
+ for fname in window_fnames:
63
+ img = cv2.imread(fname)
64
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
65
+ img = cv2.resize(img, (self._resized_img, self._resized_img),
66
+ interpolation=cv2.INTER_LANCZOS4)
67
+ window.append(img)
68
+ else:
69
+ img = cv2.cvtColor(img_name, cv2.COLOR_BGR2RGB)
70
+ window.append(img)
71
+
72
+ x = np.asarray(window) / 255.
73
+ x = np.transpose(x, (3, 0, 1, 2))
74
+ x = torch.squeeze(torch.FloatTensor(x))
75
+ if half_mask:
76
+ x = x * (self._mask_tensor>0.5)
77
+ x = self.transform(x)
78
+
79
+ x = x.unsqueeze(0) # [1, 3, 256, 256] torch tensor
80
+ x = x.to(self.vae.device)
81
+
82
+ return x
83
+
84
+ def encode_latents(self,image):
85
+ """
86
+ Encode an image into latent variables.
87
+
88
+ :param image: The image tensor to encode.
89
+ :return: The encoded latent variables.
90
+ """
91
+ with torch.no_grad():
92
+ init_latent_dist = self.vae.encode(image.to(self.vae.dtype)).latent_dist
93
+ init_latents = self.scaling_factor * init_latent_dist.sample()
94
+ return init_latents
95
+
96
+ def decode_latents(self, latents):
97
+ """
98
+ Decode latent variables back into an image.
99
+ :param latents: The latent variables to decode.
100
+ :return: A NumPy array representing the decoded image.
101
+ """
102
+ latents = (1/ self.scaling_factor) * latents
103
+ image = self.vae.decode(latents.to(self.vae.dtype)).sample
104
+ image = (image / 2 + 0.5).clamp(0, 1)
105
+ image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
106
+ image = (image * 255).round().astype("uint8")
107
+ image = image[...,::-1] # RGB to BGR
108
+ return image
109
+
110
+ def get_latents_for_unet(self,img):
111
+ """
112
+ Prepare latent variables for a U-Net model.
113
+ :param img: The image to process.
114
+ :return: A concatenated tensor of latents for U-Net input.
115
+ """
116
+
117
+ ref_image = self.preprocess_img(img,half_mask=True) # [1, 3, 256, 256] RGB, torch tensor
118
+ masked_latents = self.encode_latents(ref_image) # [1, 4, 32, 32], torch tensor
119
+ ref_image = self.preprocess_img(img,half_mask=False) # [1, 3, 256, 256] RGB, torch tensor
120
+ ref_latents = self.encode_latents(ref_image) # [1, 4, 32, 32], torch tensor
121
+ latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)
122
+ return latent_model_input
123
+
124
+ if __name__ == "__main__":
125
+ vae_model_path = "./models/sd-vae-ft-mse/"
126
+ vae = VAE(model_path=vae_model_path, use_float16=False)
127
+ img_path = "./results/sun001_crop/00000.png"
128
+
129
+ crop_imgs_path = "./results/sun001_crop/"
130
+ latents_out_path = "./results/latents/"
131
+ if not os.path.exists(latents_out_path):
132
+ os.mkdir(latents_out_path)
133
+
134
+ files = os.listdir(crop_imgs_path)
135
+ files.sort()
136
+ files = [file for file in files if file.split(".")[-1] == "png"]
137
+
138
+ for file in files:
139
+ index = file.split(".")[0]
140
+ img_path = crop_imgs_path + file
141
+ latents = vae.get_latents_for_unet(img_path)
142
+ print(img_path,"latents",latents.size())
143
+ #torch.save(latents,os.path.join(latents_out_path,index+".pt"))
144
+ #reload_tensor = torch.load('tensor.pt')
145
+ #print(reload_tensor.size())
146
+
147
+
148
+
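The `__main__` block above shows the batch workflow; a hedged single-image round trip through the VAE looks like the sketch below. The paths are the placeholders already used above, and the crop is assumed to be 256x256.

```python
import cv2
from musetalk.models.vae import VAE

vae = VAE(model_path="./models/sd-vae-ft-mse/", use_float16=False)
frame = cv2.imread("./results/sun001_crop/00000.png")   # BGR uint8, assumed 256x256
x = vae.preprocess_img(frame, half_mask=False)          # (1, 3, 256, 256), normalized to [-1, 1]
latents = vae.encode_latents(x)                         # (1, 4, 32, 32)
recon = vae.decode_latents(latents)                     # (1, 256, 256, 3) BGR uint8
print(latents.shape, recon.shape)
```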
musetalk/utils/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ import sys
2
+ from os.path import abspath, dirname
3
+ current_dir = dirname(abspath(__file__))
4
+ parent_dir = dirname(current_dir)
5
+ sys.path.append(parent_dir+'/utils')
musetalk/utils/audio_processor.py ADDED
@@ -0,0 +1,102 @@
1
+ import math
2
+ import os
3
+
4
+ import librosa
5
+ import numpy as np
6
+ import torch
7
+ from einops import rearrange
8
+ from transformers import AutoFeatureExtractor
9
+
10
+
11
+ class AudioProcessor:
12
+ def __init__(self, feature_extractor_path="openai/whisper-tiny/"):
13
+ self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_path)
14
+
15
+ def get_audio_feature(self, wav_path, start_index=0, weight_dtype=None):
16
+ if not os.path.exists(wav_path):
17
+ return None
18
+ librosa_output, sampling_rate = librosa.load(wav_path, sr=16000)
19
+ assert sampling_rate == 16000
20
+ # Split audio into 30s segments
21
+ segment_length = 30 * sampling_rate
22
+ segments = [librosa_output[i:i + segment_length] for i in range(0, len(librosa_output), segment_length)]
23
+
24
+ features = []
25
+ for segment in segments:
26
+ audio_feature = self.feature_extractor(
27
+ segment,
28
+ return_tensors="pt",
29
+ sampling_rate=sampling_rate
30
+ ).input_features
31
+ if weight_dtype is not None:
32
+ audio_feature = audio_feature.to(dtype=weight_dtype)
33
+ features.append(audio_feature)
34
+
35
+ return features, len(librosa_output)
36
+
37
+ def get_whisper_chunk(
38
+ self,
39
+ whisper_input_features,
40
+ device,
41
+ weight_dtype,
42
+ whisper,
43
+ librosa_length,
44
+ fps=25,
45
+ audio_padding_length_left=2,
46
+ audio_padding_length_right=2,
47
+ ):
48
+ audio_feature_length_per_frame = 2 * (audio_padding_length_left + audio_padding_length_right + 1)
49
+ whisper_feature = []
50
+ # Process multiple 30s mel input features
51
+ for input_feature in whisper_input_features:
52
+ input_feature = input_feature.to(device).to(weight_dtype)
53
+ audio_feats = whisper.encoder(input_feature, output_hidden_states=True).hidden_states
54
+ audio_feats = torch.stack(audio_feats, dim=2)
55
+ whisper_feature.append(audio_feats)
56
+
57
+ whisper_feature = torch.cat(whisper_feature, dim=1)
58
+ # Trim the last segment to remove padding
59
+ sr = 16000
60
+ audio_fps = 50
61
+ fps = int(fps)
62
+ whisper_idx_multiplier = audio_fps / fps
63
+ num_frames = math.floor((librosa_length / sr) * fps)
64
+ actual_length = math.floor((librosa_length / sr) * audio_fps)
65
+ whisper_feature = whisper_feature[:,:actual_length,...]
66
+
67
+ # Calculate padding amount
68
+ padding_nums = math.ceil(whisper_idx_multiplier)
69
+ # Add padding at start and end
70
+ whisper_feature = torch.cat([
71
+ torch.zeros_like(whisper_feature[:, :padding_nums * audio_padding_length_left]),
72
+ whisper_feature,
73
+ # Add extra padding to prevent out of bounds
74
+ torch.zeros_like(whisper_feature[:, :padding_nums * 3 * audio_padding_length_right])
75
+ ], 1)
76
+
77
+ audio_prompts = []
78
+ for frame_index in range(num_frames):
79
+ try:
80
+ audio_index = math.floor(frame_index * whisper_idx_multiplier)
81
+ audio_clip = whisper_feature[:, audio_index: audio_index + audio_feature_length_per_frame]
82
+ assert audio_clip.shape[1] == audio_feature_length_per_frame
83
+ audio_prompts.append(audio_clip)
84
+ except Exception as e:
85
+ print(f"Error occurred: {e}")
86
+ print(f"whisper_feature.shape: {whisper_feature.shape}")
87
+ print(f"audio_clip.shape: {audio_clip.shape}")
88
+ print(f"num frames: {num_frames}, fps: {fps}, whisper_idx_multiplier: {whisper_idx_multiplier}")
89
+ print(f"frame_index: {frame_index}, audio_index: {audio_index}-{audio_index + audio_feature_length_per_frame}")
90
+ exit()
91
+
92
+ audio_prompts = torch.cat(audio_prompts, dim=0) # T, 10, 5, 384
93
+ audio_prompts = rearrange(audio_prompts, 'b c h w -> b (c h) w')
94
+ return audio_prompts
95
+
96
+ if __name__ == "__main__":
97
+ audio_processor = AudioProcessor()
98
+ wav_path = "./2.wav"
99
+ audio_feature, librosa_feature_length = audio_processor.get_audio_feature(wav_path)
100
+ print("Audio Feature shape:", audio_feature.shape)
101
+ print("librosa_feature_length:", librosa_feature_length)
102
+
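The index arithmetic in `get_whisper_chunk` maps 25 fps video frames onto 50 Hz Whisper features: each frame advances two feature steps, and with two frames of left/right context the per-frame window is 2 * (2 + 2 + 1) = 10 steps. A small sketch of just that mapping:

```python
import math

fps, audio_fps = 25, 50
pad_left = pad_right = 2
whisper_idx_multiplier = audio_fps / fps          # 2.0 audio steps per video frame
window = 2 * (pad_left + pad_right + 1)           # 10 feature steps per frame
for frame_index in range(3):
    start = math.floor(frame_index * whisper_idx_multiplier)
    print(frame_index, (start, start + window))   # (0, 10), (2, 12), (4, 14)
```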
musetalk/utils/blending.py ADDED
@@ -0,0 +1,136 @@
1
+ from PIL import Image
2
+ import numpy as np
3
+ import cv2
4
+ import copy
5
+
6
+
7
+ def get_crop_box(box, expand):
8
+ x, y, x1, y1 = box
9
+ x_c, y_c = (x+x1)//2, (y+y1)//2
10
+ w, h = x1-x, y1-y
11
+ s = int(max(w, h)//2*expand)
12
+ crop_box = [x_c-s, y_c-s, x_c+s, y_c+s]
13
+ return crop_box, s
14
+
15
+
16
+ def face_seg(image, mode="raw", fp=None):
17
+ """
18
+ Run face parsing on the image and produce a mask of the face region.
19
+
20
+ Args:
21
+ image (PIL.Image): Input image.
22
+
23
+ Returns:
24
+ PIL.Image: Mask image of the face region.
25
+ """
26
+ seg_image = fp(image, mode=mode) # parse the face with the FaceParsing model
27
+ if seg_image is None:
28
+ print("error, no person_segment") # no face detected, report the error
29
+ return None
30
+
31
+ seg_image = seg_image.resize(image.size) # resize the mask to the input image size
32
+ return seg_image
33
+
34
+
35
+ def get_image(image, face, face_box, upper_boundary_ratio=0.5, expand=1.5, mode="raw", fp=None):
36
+ """
37
+ Paste the cropped face back onto the original image and blend the result.
38
+
39
+ Args:
40
+ image (numpy.ndarray): Original image (full body frame).
41
+ face (numpy.ndarray): Cropped face image.
42
+ face_box (tuple): Face bounding-box coordinates (x, y, x1, y1).
43
+ upper_boundary_ratio (float): Ratio controlling how much of the face region is kept.
44
+ expand (float): Expansion factor used to enlarge the crop box.
45
+ mode: How the blending mask is constructed.
46
+
47
+ Returns:
48
+ numpy.ndarray: The processed (blended) image.
49
+ """
50
+ # Convert the numpy arrays to PIL images
51
+ body = Image.fromarray(image[:, :, ::-1]) # full-body image (whole frame)
52
+ face = Image.fromarray(face[:, :, ::-1]) # face image
53
+
54
+ x, y, x1, y1 = face_box # face bounding-box coordinates
55
+ crop_box, s = get_crop_box(face_box, expand) # compute the expanded crop box
56
+ x_s, y_s, x_e, y_e = crop_box # crop-box coordinates
57
+ face_position = (x, y) # position of the face in the original image
58
+
59
+ # Crop the expanded face region from the body image (leaving a margin below the chin)
60
+ face_large = body.crop(crop_box)
61
+
62
+ ori_shape = face_large.size # original size of the cropped region
63
+
64
+ # Run face parsing on the cropped region to generate a mask
65
+ mask_image = face_seg(face_large, mode=mode, fp=fp)
66
+
67
+ mask_small = mask_image.crop((x - x_s, y - y_s, x1 - x_s, y1 - y_s)) # crop the mask of the face region
68
+
69
+ mask_image = Image.new('L', ori_shape, 0) # create an all-black mask image
70
+ mask_image.paste(mask_small, (x - x_s, y - y_s, x1 - x_s, y1 - y_s)) # paste the face mask onto the black image
71
+
72
+
73
+ # Keep only the lower part of the mask (the talking area), controlled by upper_boundary_ratio
74
+ width, height = mask_image.size
75
+ top_boundary = int(height * upper_boundary_ratio) # boundary between the discarded and kept parts
76
+ modified_mask_image = Image.new('L', ori_shape, 0) # new all-black mask image
77
+ modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary)) # paste the lower part of the mask
78
+
79
+
80
+ # Apply Gaussian blur to the mask to smooth its edges
81
+ blur_kernel_size = int(0.05 * ori_shape[0] // 2 * 2) + 1 # odd blur kernel size
82
+ mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0) # Gaussian blur
83
+ #mask_array = np.array(modified_mask_image)
84
+ mask_image = Image.fromarray(mask_array) # convert the blurred mask back to a PIL image
85
+
86
+ # Paste the cropped face back into the expanded face region
87
+ face_large.paste(face, (x - x_s, y - y_s, x1 - x_s, y1 - y_s))
88
+
89
+ body.paste(face_large, crop_box[:2], mask_image)
90
+
91
+ body = np.array(body) # convert the PIL image back to a numpy array
92
+
93
+ return body[:, :, ::-1] # return the blended image (RGB converted back to BGR)
94
+
95
+
96
+ def get_image_blending(image, face, face_box, mask_array, crop_box):
97
+ body = Image.fromarray(image[:,:,::-1])
98
+ face = Image.fromarray(face[:,:,::-1])
99
+
100
+ x, y, x1, y1 = face_box
101
+ x_s, y_s, x_e, y_e = crop_box
102
+ face_large = body.crop(crop_box)
103
+
104
+ mask_image = Image.fromarray(mask_array)
105
+ mask_image = mask_image.convert("L")
106
+ face_large.paste(face, (x-x_s, y-y_s, x1-x_s, y1-y_s))
107
+ body.paste(face_large, crop_box[:2], mask_image)
108
+ body = np.array(body)
109
+ return body[:,:,::-1]
110
+
111
+
112
+ def get_image_prepare_material(image, face_box, upper_boundary_ratio=0.5, expand=1.5, fp=None, mode="raw"):
113
+ body = Image.fromarray(image[:,:,::-1])
114
+
115
+ x, y, x1, y1 = face_box
116
+ #print(x1-x,y1-y)
117
+ crop_box, s = get_crop_box(face_box, expand)
118
+ x_s, y_s, x_e, y_e = crop_box
119
+
120
+ face_large = body.crop(crop_box)
121
+ ori_shape = face_large.size
122
+
123
+ mask_image = face_seg(face_large, mode=mode, fp=fp)
124
+ mask_small = mask_image.crop((x-x_s, y-y_s, x1-x_s, y1-y_s))
125
+ mask_image = Image.new('L', ori_shape, 0)
126
+ mask_image.paste(mask_small, (x-x_s, y-y_s, x1-x_s, y1-y_s))
127
+
128
+ # keep upper_boundary_ratio of talking area
129
+ width, height = mask_image.size
130
+ top_boundary = int(height * upper_boundary_ratio)
131
+ modified_mask_image = Image.new('L', ori_shape, 0)
132
+ modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary))
133
+
134
+ blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
135
+ mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
136
+ return mask_array, crop_box
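`get_image_prepare_material` and `get_image_blending` split the work into a one-off mask computation and a cheap per-frame paste, which is useful for the realtime inference path. A hedged usage sketch; the `FaceParsing` constructor arguments, image paths and bounding box are placeholders:

```python
import cv2
from musetalk.utils.blending import get_image_prepare_material, get_image_blending
from musetalk.utils.face_parsing import FaceParsing   # assumed parser class from this repo

fp = FaceParsing()                                     # constructor args omitted in this sketch
frame = cv2.imread("frame.png")                        # original full frame (placeholder)
face = cv2.imread("generated_face.png")                # generated mouth region (placeholder)
x, y, x1, y1 = face_box = (120, 80, 376, 336)          # illustrative face box

# Heavy part, done once per template frame:
mask_array, crop_box = get_image_prepare_material(frame, face_box, fp=fp)
# Cheap part, repeated for every generated face:
face = cv2.resize(face, (x1 - x, y1 - y))
out = get_image_blending(frame, face, face_box, mask_array, crop_box)
print(out.shape)
```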
musetalk/utils/dwpose/default_runtime.py ADDED
@@ -0,0 +1,54 @@
1
+ default_scope = 'mmpose'
2
+
3
+ # hooks
4
+ default_hooks = dict(
5
+ timer=dict(type='IterTimerHook'),
6
+ logger=dict(type='LoggerHook', interval=50),
7
+ param_scheduler=dict(type='ParamSchedulerHook'),
8
+ checkpoint=dict(type='CheckpointHook', interval=10),
9
+ sampler_seed=dict(type='DistSamplerSeedHook'),
10
+ visualization=dict(type='PoseVisualizationHook', enable=False),
11
+ badcase=dict(
12
+ type='BadCaseAnalysisHook',
13
+ enable=False,
14
+ out_dir='badcase',
15
+ metric_type='loss',
16
+ badcase_thr=5))
17
+
18
+ # custom hooks
19
+ custom_hooks = [
20
+ # Synchronize model buffers such as running_mean and running_var in BN
21
+ # at the end of each epoch
22
+ dict(type='SyncBuffersHook')
23
+ ]
24
+
25
+ # multi-processing backend
26
+ env_cfg = dict(
27
+ cudnn_benchmark=False,
28
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
29
+ dist_cfg=dict(backend='nccl'),
30
+ )
31
+
32
+ # visualizer
33
+ vis_backends = [
34
+ dict(type='LocalVisBackend'),
35
+ # dict(type='TensorboardVisBackend'),
36
+ # dict(type='WandbVisBackend'),
37
+ ]
38
+ visualizer = dict(
39
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
40
+
41
+ # logger
42
+ log_processor = dict(
43
+ type='LogProcessor', window_size=50, by_epoch=True, num_digits=6)
44
+ log_level = 'INFO'
45
+ load_from = None
46
+ resume = False
47
+
48
+ # file I/O backend
49
+ backend_args = dict(backend='local')
50
+
51
+ # training/validation/testing progress
52
+ train_cfg = dict(by_epoch=True)
53
+ val_cfg = dict()
54
+ test_cfg = dict()
musetalk/utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py ADDED
@@ -0,0 +1,257 @@
1
+ #_base_ = ['../../../_base_/default_runtime.py']
2
+ _base_ = ['default_runtime.py']
3
+
4
+ # runtime
5
+ max_epochs = 270
6
+ stage2_num_epochs = 30
7
+ base_lr = 4e-3
8
+ train_batch_size = 32
9
+ val_batch_size = 32
10
+
11
+ train_cfg = dict(max_epochs=max_epochs, val_interval=10)
12
+ randomness = dict(seed=21)
13
+
14
+ # optimizer
15
+ optim_wrapper = dict(
16
+ type='OptimWrapper',
17
+ optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
18
+ paramwise_cfg=dict(
19
+ norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
20
+
21
+ # learning rate
22
+ param_scheduler = [
23
+ dict(
24
+ type='LinearLR',
25
+ start_factor=1.0e-5,
26
+ by_epoch=False,
27
+ begin=0,
28
+ end=1000),
29
+ dict(
30
+ # use cosine lr from 150 to 300 epoch
31
+ type='CosineAnnealingLR',
32
+ eta_min=base_lr * 0.05,
33
+ begin=max_epochs // 2,
34
+ end=max_epochs,
35
+ T_max=max_epochs // 2,
36
+ by_epoch=True,
37
+ convert_to_iter_based=True),
38
+ ]
39
+
40
+ # automatically scaling LR based on the actual training batch size
41
+ auto_scale_lr = dict(base_batch_size=512)
42
+
43
+ # codec settings
44
+ codec = dict(
45
+ type='SimCCLabel',
46
+ input_size=(288, 384),
47
+ sigma=(6., 6.93),
48
+ simcc_split_ratio=2.0,
49
+ normalize=False,
50
+ use_dark=False)
51
+
52
+ # model settings
53
+ model = dict(
54
+ type='TopdownPoseEstimator',
55
+ data_preprocessor=dict(
56
+ type='PoseDataPreprocessor',
57
+ mean=[123.675, 116.28, 103.53],
58
+ std=[58.395, 57.12, 57.375],
59
+ bgr_to_rgb=True),
60
+ backbone=dict(
61
+ _scope_='mmdet',
62
+ type='CSPNeXt',
63
+ arch='P5',
64
+ expand_ratio=0.5,
65
+ deepen_factor=1.,
66
+ widen_factor=1.,
67
+ out_indices=(4, ),
68
+ channel_attention=True,
69
+ norm_cfg=dict(type='SyncBN'),
70
+ act_cfg=dict(type='SiLU'),
71
+ init_cfg=dict(
72
+ type='Pretrained',
73
+ prefix='backbone.',
74
+ checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
75
+ 'rtmpose/cspnext-l_udp-aic-coco_210e-256x192-273b7631_20230130.pth' # noqa: E501
76
+ )),
77
+ head=dict(
78
+ type='RTMCCHead',
79
+ in_channels=1024,
80
+ out_channels=133,
81
+ input_size=codec['input_size'],
82
+ in_featuremap_size=(9, 12),
83
+ simcc_split_ratio=codec['simcc_split_ratio'],
84
+ final_layer_kernel_size=7,
85
+ gau_cfg=dict(
86
+ hidden_dims=256,
87
+ s=128,
88
+ expansion_factor=2,
89
+ dropout_rate=0.,
90
+ drop_path=0.,
91
+ act_fn='SiLU',
92
+ use_rel_bias=False,
93
+ pos_enc=False),
94
+ loss=dict(
95
+ type='KLDiscretLoss',
96
+ use_target_weight=True,
97
+ beta=10.,
98
+ label_softmax=True),
99
+ decoder=codec),
100
+ test_cfg=dict(flip_test=True, ))
101
+
102
+ # base dataset settings
103
+ dataset_type = 'UBody2dDataset'
104
+ data_mode = 'topdown'
105
+ data_root = 'data/UBody/'
106
+
107
+ backend_args = dict(backend='local')
108
+
109
+ scenes = [
110
+ 'Magic_show', 'Entertainment', 'ConductMusic', 'Online_class', 'TalkShow',
111
+ 'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow', 'Singing',
112
+ 'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference'
113
+ ]
114
+
115
+ train_datasets = [
116
+ dict(
117
+ type='CocoWholeBodyDataset',
118
+ data_root='data/coco/',
119
+ data_mode=data_mode,
120
+ ann_file='annotations/coco_wholebody_train_v1.0.json',
121
+ data_prefix=dict(img='train2017/'),
122
+ pipeline=[])
123
+ ]
124
+
125
+ for scene in scenes:
126
+ train_dataset = dict(
127
+ type=dataset_type,
128
+ data_root=data_root,
129
+ data_mode=data_mode,
130
+ ann_file=f'annotations/{scene}/train_annotations.json',
131
+ data_prefix=dict(img='images/'),
132
+ pipeline=[],
133
+ sample_interval=10)
134
+ train_datasets.append(train_dataset)
135
+
136
+ # pipelines
137
+ train_pipeline = [
138
+ dict(type='LoadImage', backend_args=backend_args),
139
+ dict(type='GetBBoxCenterScale'),
140
+ dict(type='RandomFlip', direction='horizontal'),
141
+ dict(type='RandomHalfBody'),
142
+ dict(
143
+ type='RandomBBoxTransform', scale_factor=[0.5, 1.5], rotate_factor=90),
144
+ dict(type='TopdownAffine', input_size=codec['input_size']),
145
+ dict(type='mmdet.YOLOXHSVRandomAug'),
146
+ dict(
147
+ type='Albumentation',
148
+ transforms=[
149
+ dict(type='Blur', p=0.1),
150
+ dict(type='MedianBlur', p=0.1),
151
+ dict(
152
+ type='CoarseDropout',
153
+ max_holes=1,
154
+ max_height=0.4,
155
+ max_width=0.4,
156
+ min_holes=1,
157
+ min_height=0.2,
158
+ min_width=0.2,
159
+ p=1.0),
160
+ ]),
161
+ dict(type='GenerateTarget', encoder=codec),
162
+ dict(type='PackPoseInputs')
163
+ ]
164
+ val_pipeline = [
165
+ dict(type='LoadImage', backend_args=backend_args),
166
+ dict(type='GetBBoxCenterScale'),
167
+ dict(type='TopdownAffine', input_size=codec['input_size']),
168
+ dict(type='PackPoseInputs')
169
+ ]
170
+
171
+ train_pipeline_stage2 = [
172
+ dict(type='LoadImage', backend_args=backend_args),
173
+ dict(type='GetBBoxCenterScale'),
174
+ dict(type='RandomFlip', direction='horizontal'),
175
+ dict(type='RandomHalfBody'),
176
+ dict(
177
+ type='RandomBBoxTransform',
178
+ shift_factor=0.,
179
+ scale_factor=[0.5, 1.5],
180
+ rotate_factor=90),
181
+ dict(type='TopdownAffine', input_size=codec['input_size']),
182
+ dict(type='mmdet.YOLOXHSVRandomAug'),
183
+ dict(
184
+ type='Albumentation',
185
+ transforms=[
186
+ dict(type='Blur', p=0.1),
187
+ dict(type='MedianBlur', p=0.1),
188
+ dict(
189
+ type='CoarseDropout',
190
+ max_holes=1,
191
+ max_height=0.4,
192
+ max_width=0.4,
193
+ min_holes=1,
194
+ min_height=0.2,
195
+ min_width=0.2,
196
+ p=0.5),
197
+ ]),
198
+ dict(type='GenerateTarget', encoder=codec),
199
+ dict(type='PackPoseInputs')
200
+ ]
201
+
202
+ # data loaders
203
+ train_dataloader = dict(
204
+ batch_size=train_batch_size,
205
+ num_workers=10,
206
+ persistent_workers=True,
207
+ sampler=dict(type='DefaultSampler', shuffle=True),
208
+ dataset=dict(
209
+ type='CombinedDataset',
210
+ metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
211
+ datasets=train_datasets,
212
+ pipeline=train_pipeline,
213
+ test_mode=False,
214
+ ))
215
+
216
+ val_dataloader = dict(
217
+ batch_size=val_batch_size,
218
+ num_workers=10,
219
+ persistent_workers=True,
220
+ drop_last=False,
221
+ sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
222
+ dataset=dict(
223
+ type='CocoWholeBodyDataset',
224
+ data_root=data_root,
225
+ data_mode=data_mode,
226
+ ann_file='data/coco/annotations/coco_wholebody_val_v1.0.json',
227
+ bbox_file='data/coco/person_detection_results/'
228
+ 'COCO_val2017_detections_AP_H_56_person.json',
229
+ data_prefix=dict(img='coco/val2017/'),
230
+ test_mode=True,
231
+ pipeline=val_pipeline,
232
+ ))
233
+ test_dataloader = val_dataloader
234
+
235
+ # hooks
236
+ default_hooks = dict(
237
+ checkpoint=dict(
238
+ save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1))
239
+
240
+ custom_hooks = [
241
+ dict(
242
+ type='EMAHook',
243
+ ema_type='ExpMomentumEMA',
244
+ momentum=0.0002,
245
+ update_buffers=True,
246
+ priority=49),
247
+ dict(
248
+ type='mmdet.PipelineSwitchHook',
249
+ switch_epoch=max_epochs - stage2_num_epochs,
250
+ switch_pipeline=train_pipeline_stage2)
251
+ ]
252
+
253
+ # evaluators
254
+ val_evaluator = dict(
255
+ type='CocoWholeBodyMetric',
256
+ ann_file='data/coco/annotations/coco_wholebody_val_v1.0.json')
257
+ test_evaluator = val_evaluator
musetalk/utils/face_detection/README.md ADDED
@@ -0,0 +1 @@
1
+ The code for Face Detection in this folder has been taken from the wonderful [face_alignment](https://github.com/1adrianb/face-alignment) repository. This has been modified to take batches of faces at a time.
musetalk/utils/face_detection/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ __author__ = """Adrian Bulat"""
4
+ __email__ = '[email protected]'
5
+ __version__ = '1.0.1'
6
+
7
+ from .api import FaceAlignment, LandmarksType, NetworkSize, YOLOv8_face
musetalk/utils/face_detection/api.py ADDED
@@ -0,0 +1,240 @@
1
+ from __future__ import print_function
2
+ import os
3
+ import torch
4
+ from torch.utils.model_zoo import load_url
5
+ from enum import Enum
6
+ import numpy as np
7
+ import cv2
8
+ try:
9
+ import urllib.request as request_file
10
+ except BaseException:
11
+ import urllib as request_file
12
+
13
+ from .models import FAN, ResNetDepth
14
+ from .utils import *
15
+
16
+
17
+ class LandmarksType(Enum):
18
+ """Enum class defining the type of landmarks to detect.
19
+
20
+ ``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face
21
+ ``_2halfD`` - this points represent the projection of the 3D points into 3D
22
+ ``_3D`` - detect the points ``(x,y,z)``` in a 3D space
23
+
24
+ """
25
+ _2D = 1
26
+ _2halfD = 2
27
+ _3D = 3
28
+
29
+
30
+ class NetworkSize(Enum):
31
+ # TINY = 1
32
+ # SMALL = 2
33
+ # MEDIUM = 3
34
+ LARGE = 4
35
+
36
+ def __new__(cls, value):
37
+ member = object.__new__(cls)
38
+ member._value_ = value
39
+ return member
40
+
41
+ def __int__(self):
42
+ return self.value
43
+
44
+
45
+
46
+ class FaceAlignment:
47
+ def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
48
+ device='cuda', flip_input=False, face_detector='sfd', verbose=False):
49
+ self.device = device
50
+ self.flip_input = flip_input
51
+ self.landmarks_type = landmarks_type
52
+ self.verbose = verbose
53
+
54
+ network_size = int(network_size)
55
+
56
+ if 'cuda' in device:
57
+ torch.backends.cudnn.benchmark = True
58
+ # torch.backends.cuda.matmul.allow_tf32 = False
59
+ # torch.backends.cudnn.benchmark = True
60
+ # torch.backends.cudnn.deterministic = False
61
+ # torch.backends.cudnn.allow_tf32 = True
62
+ print('cuda start')
63
+
64
+
65
+ # Get the face detector
66
+ face_detector_module = __import__('face_detection.detection.' + face_detector,
67
+ globals(), locals(), [face_detector], 0)
68
+
69
+ self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)
70
+
71
+ def get_detections_for_batch(self, images):
72
+ images = images[..., ::-1]
73
+ detected_faces = self.face_detector.detect_from_batch(images.copy())
74
+ results = []
75
+
76
+ for i, d in enumerate(detected_faces):
77
+ if len(d) == 0:
78
+ results.append(None)
79
+ continue
80
+ d = d[0]
81
+ d = np.clip(d, 0, None)
82
+
83
+ x1, y1, x2, y2 = map(int, d[:-1])
84
+ results.append((x1, y1, x2, y2))
85
+
86
+ return results
87
+
88
+
89
+ class YOLOv8_face:
90
+ def __init__(self, path = 'face_detection/weights/yolov8n-face.onnx', conf_thres=0.2, iou_thres=0.5):
91
+ self.conf_threshold = conf_thres
92
+ self.iou_threshold = iou_thres
93
+ self.class_names = ['face']
94
+ self.num_classes = len(self.class_names)
95
+ # Initialize model
96
+ self.net = cv2.dnn.readNet(path)
97
+ self.input_height = 640
98
+ self.input_width = 640
99
+ self.reg_max = 16
100
+
101
+ self.project = np.arange(self.reg_max)
102
+ self.strides = (8, 16, 32)
103
+ self.feats_hw = [(math.ceil(self.input_height / self.strides[i]), math.ceil(self.input_width / self.strides[i])) for i in range(len(self.strides))]
104
+ self.anchors = self.make_anchors(self.feats_hw)
105
+
106
+ def make_anchors(self, feats_hw, grid_cell_offset=0.5):
107
+ """Generate anchors from features."""
108
+ anchor_points = {}
109
+ for i, stride in enumerate(self.strides):
110
+ h,w = feats_hw[i]
111
+ x = np.arange(0, w) + grid_cell_offset # shift x
112
+ y = np.arange(0, h) + grid_cell_offset # shift y
113
+ sx, sy = np.meshgrid(x, y)
114
+ # sy, sx = np.meshgrid(y, x)
115
+ anchor_points[stride] = np.stack((sx, sy), axis=-1).reshape(-1, 2)
116
+ return anchor_points
117
+
118
+ def softmax(self, x, axis=1):
119
+ x_exp = np.exp(x)
120
+ # for a column vector, use axis=0
121
+ x_sum = np.sum(x_exp, axis=axis, keepdims=True)
122
+ s = x_exp / x_sum
123
+ return s
124
+
125
+ def resize_image(self, srcimg, keep_ratio=True):
126
+ top, left, newh, neww = 0, 0, self.input_width, self.input_height
127
+ if keep_ratio and srcimg.shape[0] != srcimg.shape[1]:
128
+ hw_scale = srcimg.shape[0] / srcimg.shape[1]
129
+ if hw_scale > 1:
130
+ newh, neww = self.input_height, int(self.input_width / hw_scale)
131
+ img = cv2.resize(srcimg, (neww, newh), interpolation=cv2.INTER_AREA)
132
+ left = int((self.input_width - neww) * 0.5)
133
+ img = cv2.copyMakeBorder(img, 0, 0, left, self.input_width - neww - left, cv2.BORDER_CONSTANT,
134
+ value=(0, 0, 0)) # add border
135
+ else:
136
+ newh, neww = int(self.input_height * hw_scale), self.input_width
137
+ img = cv2.resize(srcimg, (neww, newh), interpolation=cv2.INTER_AREA)
138
+ top = int((self.input_height - newh) * 0.5)
139
+ img = cv2.copyMakeBorder(img, top, self.input_height - newh - top, 0, 0, cv2.BORDER_CONSTANT,
140
+ value=(0, 0, 0))
141
+ else:
142
+ img = cv2.resize(srcimg, (self.input_width, self.input_height), interpolation=cv2.INTER_AREA)
143
+ return img, newh, neww, top, left
144
+
145
+ def detect(self, srcimg):
146
+ input_img, newh, neww, padh, padw = self.resize_image(cv2.cvtColor(srcimg, cv2.COLOR_BGR2RGB))
147
+ scale_h, scale_w = srcimg.shape[0]/newh, srcimg.shape[1]/neww
148
+ input_img = input_img.astype(np.float32) / 255.0
149
+
150
+ blob = cv2.dnn.blobFromImage(input_img)
151
+ self.net.setInput(blob)
152
+ outputs = self.net.forward(self.net.getUnconnectedOutLayersNames())
153
+ # if isinstance(outputs, tuple):
154
+ # outputs = list(outputs)
155
+ # if float(cv2.__version__[:3])>=4.7:
156
+ # outputs = [outputs[2], outputs[0], outputs[1]] ### this reordering is needed for OpenCV 4.7, not for OpenCV 4.5
157
+ # Perform inference on the image
158
+ det_bboxes, det_conf, det_classid, landmarks = self.post_process(outputs, scale_h, scale_w, padh, padw)
159
+ return det_bboxes, det_conf, det_classid, landmarks
160
+
161
+ def post_process(self, preds, scale_h, scale_w, padh, padw):
162
+ bboxes, scores, landmarks = [], [], []
163
+ for i, pred in enumerate(preds):
164
+ stride = int(self.input_height/pred.shape[2])
165
+ pred = pred.transpose((0, 2, 3, 1))
166
+
167
+ box = pred[..., :self.reg_max * 4]
168
+ cls = 1 / (1 + np.exp(-pred[..., self.reg_max * 4:-15])).reshape((-1,1))
169
+ kpts = pred[..., -15:].reshape((-1,15)) ### x1,y1,score1, ..., x5,y5,score5
170
+
171
+ # tmp = box.reshape(self.feats_hw[i][0], self.feats_hw[i][1], 4, self.reg_max)
172
+ tmp = box.reshape(-1, 4, self.reg_max)
173
+ bbox_pred = self.softmax(tmp, axis=-1)
174
+ bbox_pred = np.dot(bbox_pred, self.project).reshape((-1,4))
175
+
176
+ bbox = self.distance2bbox(self.anchors[stride], bbox_pred, max_shape=(self.input_height, self.input_width)) * stride
177
+ kpts[:, 0::3] = (kpts[:, 0::3] * 2.0 + (self.anchors[stride][:, 0].reshape((-1,1)) - 0.5)) * stride
178
+ kpts[:, 1::3] = (kpts[:, 1::3] * 2.0 + (self.anchors[stride][:, 1].reshape((-1,1)) - 0.5)) * stride
179
+ kpts[:, 2::3] = 1 / (1+np.exp(-kpts[:, 2::3]))
180
+
181
+ bbox -= np.array([[padw, padh, padw, padh]]) ### undo letterbox padding (NumPy broadcasting)
182
+ bbox *= np.array([[scale_w, scale_h, scale_w, scale_h]])
183
+ kpts -= np.tile(np.array([padw, padh, 0]), 5).reshape((1,15))
184
+ kpts *= np.tile(np.array([scale_w, scale_h, 1]), 5).reshape((1,15))
185
+
186
+ bboxes.append(bbox)
187
+ scores.append(cls)
188
+ landmarks.append(kpts)
189
+
190
+ bboxes = np.concatenate(bboxes, axis=0)
191
+ scores = np.concatenate(scores, axis=0)
192
+ landmarks = np.concatenate(landmarks, axis=0)
193
+
194
+ bboxes_wh = bboxes.copy()
195
+ bboxes_wh[:, 2:4] = bboxes[:, 2:4] - bboxes[:, 0:2] ####xywh
196
+ classIds = np.argmax(scores, axis=1)
197
+ confidences = np.max(scores, axis=1) ####max_class_confidence
198
+
199
+ mask = confidences>self.conf_threshold
200
+ bboxes_wh = bboxes_wh[mask] ### keep detections above the confidence threshold
201
+ confidences = confidences[mask]
202
+ classIds = classIds[mask]
203
+ landmarks = landmarks[mask]
204
+
205
+ indices = cv2.dnn.NMSBoxes(bboxes_wh.tolist(), confidences.tolist(), self.conf_threshold,
206
+ self.iou_threshold).flatten()
207
+ if len(indices) > 0:
208
+ mlvl_bboxes = bboxes_wh[indices]
209
+ confidences = confidences[indices]
210
+ classIds = classIds[indices]
211
+ landmarks = landmarks[indices]
212
+ return mlvl_bboxes, confidences, classIds, landmarks
213
+ else:
214
+ print('nothing detected')
215
+ return np.array([]), np.array([]), np.array([]), np.array([])
216
+
217
+ def distance2bbox(self, points, distance, max_shape=None):
218
+ x1 = points[:, 0] - distance[:, 0]
219
+ y1 = points[:, 1] - distance[:, 1]
220
+ x2 = points[:, 0] + distance[:, 2]
221
+ y2 = points[:, 1] + distance[:, 3]
222
+ if max_shape is not None:
223
+ x1 = np.clip(x1, 0, max_shape[1])
224
+ y1 = np.clip(y1, 0, max_shape[0])
225
+ x2 = np.clip(x2, 0, max_shape[1])
226
+ y2 = np.clip(y2, 0, max_shape[0])
227
+ return np.stack([x1, y1, x2, y2], axis=-1)
228
+
229
+ def draw_detections(self, image, boxes, scores, kpts):
230
+ for box, score, kp in zip(boxes, scores, kpts):
231
+ x, y, w, h = box.astype(int)
232
+ # Draw rectangle
233
+ cv2.rectangle(image, (x, y), (x + w, y + h), (0, 0, 255), thickness=3)
234
+ cv2.putText(image, "face:"+str(round(score,2)), (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), thickness=2)
235
+ for i in range(5):
236
+ cv2.circle(image, (int(kp[i * 3]), int(kp[i * 3 + 1])), 4, (0, 255, 0), thickness=-1)
237
+ # cv2.putText(image, str(i), (int(kp[i * 3]), int(kp[i * 3 + 1]) - 10), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), thickness=1)
238
+ return image
239
+
240
+ ROOT = os.path.dirname(os.path.abspath(__file__))
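For reference, a minimal usage sketch of the FaceAlignment wrapper defined above; the top-level import path and the frame files are assumptions, not part of this commit:

import cv2
import numpy as np
from musetalk.utils.face_detection import FaceAlignment, LandmarksType  # assumed export of this package

fa = FaceAlignment(LandmarksType._2D, flip_input=False, device='cuda')
frames = [cv2.imread(p) for p in ['frame_000.png', 'frame_001.png']]     # hypothetical BGR frames of equal size
batch = np.stack(frames)                                                 # (B, H, W, 3) uint8
boxes = fa.get_detections_for_batch(batch)                               # (x1, y1, x2, y2) or None per frame
for i, box in enumerate(boxes):
    print(i, box)
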
musetalk/utils/face_detection/detection/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .core import FaceDetector
musetalk/utils/face_detection/detection/core.py ADDED
@@ -0,0 +1,130 @@
1
+ import logging
2
+ import glob
3
+ from tqdm import tqdm
4
+ import numpy as np
5
+ import torch
6
+ import cv2
7
+
8
+
9
+ class FaceDetector(object):
10
+ """An abstract class representing a face detector.
11
+
12
+ Any other face detection implementation must subclass it. All subclasses
13
+ must implement ``detect_from_image``, that return a list of detected
14
+ bounding boxes. Optionally, for speed considerations detect from path is
15
+ recommended.
16
+ """
17
+
18
+ def __init__(self, device, verbose):
19
+ self.device = device
20
+ self.verbose = verbose
21
+
22
+ if verbose:
23
+ logger = logging.getLogger(__name__)
24
+ if 'cpu' in device:
25
+ logger.warning("Detection running on CPU, this may be potentially slow.")
26
+
27
+ if 'cpu' not in device and 'cuda' not in device:
28
+ if verbose:
29
+ logger.error("Expected values for device are: {cpu, cuda} but got: %s", device)
30
+ raise ValueError
31
+
32
+ def detect_from_image(self, tensor_or_path):
33
+ """Detects faces in a given image.
34
+
35
+ This function detects the faces present in a provided BGR(usually)
36
+ image. The input can be either the image itself or the path to it.
37
+
38
+ Arguments:
39
+ tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path
40
+ to an image or the image itself.
41
+
42
+ Example::
43
+
44
+ >>> path_to_image = 'data/image_01.jpg'
45
+ ... detected_faces = detect_from_image(path_to_image)
46
+ [A list of bounding boxes (x1, y1, x2, y2)]
47
+ >>> image = cv2.imread(path_to_image)
48
+ ... detected_faces = detect_from_image(image)
49
+ [A list of bounding boxes (x1, y1, x2, y2)]
50
+
51
+ """
52
+ raise NotImplementedError
53
+
54
+ def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True):
55
+ """Detects faces from all the images present in a given directory.
56
+
57
+ Arguments:
58
+ path {string} -- a string containing a path that points to the folder containing the images
59
+
60
+ Keyword Arguments:
61
+ extensions {list} -- list of string containing the extensions to be
62
+ consider in the following format: ``.extension_name`` (default:
63
+ {['.jpg', '.png']}) recursive {bool} -- option wherever to scan the
64
+ folder recursively (default: {False}) show_progress_bar {bool} --
65
+ display a progressbar (default: {True})
66
+
67
+ Example:
68
+ >>> directory = 'data'
69
+ ... detected_faces = detect_from_directory(directory)
70
+ {A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]}
71
+
72
+ """
73
+ if self.verbose:
74
+ logger = logging.getLogger(__name__)
75
+
76
+ if len(extensions) == 0:
77
+ if self.verbose:
78
+ logger.error("Expected at list one extension, but none was received.")
79
+ raise ValueError
80
+
81
+ if self.verbose:
82
+ logger.info("Constructing the list of images.")
83
+ additional_pattern = '/**/*' if recursive else '/*'
84
+ files = []
85
+ for extension in extensions:
86
+ files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive))
87
+
88
+ if self.verbose:
89
+ logger.info("Finished searching for images. %s images found", len(files))
90
+ logger.info("Preparing to run the detection.")
91
+
92
+ predictions = {}
93
+ for image_path in tqdm(files, disable=not show_progress_bar):
94
+ if self.verbose:
95
+ logger.info("Running the face detector on image: %s", image_path)
96
+ predictions[image_path] = self.detect_from_image(image_path)
97
+
98
+ if self.verbose:
99
+ logger.info("The detector was successfully run on all %s images", len(files))
100
+
101
+ return predictions
102
+
103
+ @property
104
+ def reference_scale(self):
105
+ raise NotImplementedError
106
+
107
+ @property
108
+ def reference_x_shift(self):
109
+ raise NotImplementedError
110
+
111
+ @property
112
+ def reference_y_shift(self):
113
+ raise NotImplementedError
114
+
115
+ @staticmethod
116
+ def tensor_or_path_to_ndarray(tensor_or_path, rgb=True):
117
+ """Convert path (represented as a string) or torch.tensor to a numpy.ndarray
118
+
119
+ Arguments:
120
+ tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself
121
+ """
122
+ if isinstance(tensor_or_path, str):
123
+ return cv2.imread(tensor_or_path) if not rgb else cv2.imread(tensor_or_path)[..., ::-1]
124
+ elif torch.is_tensor(tensor_or_path):
125
+ # Call cpu in case its coming from cuda
126
+ return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy()
127
+ elif isinstance(tensor_or_path, np.ndarray):
128
+ return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path
129
+ else:
130
+ raise TypeError
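For reference, a minimal sketch of driving the abstract FaceDetector interface above through a concrete subclass; the SFD alias comes from the sfd package below, while the top-level import path and image directory are assumptions:

from musetalk.utils.face_detection.detection.sfd import FaceDetector as SFDDetector  # assumed top-level path

detector = SFDDetector(device='cuda', verbose=True)
predictions = detector.detect_from_directory('data/frames', extensions=['.jpg', '.png'])
for image_path, boxes in predictions.items():
    print(image_path, len(boxes), 'face(s)')
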
musetalk/utils/face_detection/detection/sfd/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .sfd_detector import SFDDetector as FaceDetector
musetalk/utils/face_detection/detection/sfd/bbox.py ADDED
@@ -0,0 +1,129 @@
1
+ from __future__ import print_function
2
+ import os
3
+ import sys
4
+ import cv2
5
+ import random
6
+ import datetime
7
+ import time
8
+ import math
9
+ import argparse
10
+ import numpy as np
11
+ import torch
12
+
13
+ try:
14
+ from iou import IOU
15
+ except BaseException:
16
+ # IOU cython speedup 10x
17
+ def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2):
18
+ sa = abs((ax2 - ax1) * (ay2 - ay1))
19
+ sb = abs((bx2 - bx1) * (by2 - by1))
20
+ x1, y1 = max(ax1, bx1), max(ay1, by1)
21
+ x2, y2 = min(ax2, bx2), min(ay2, by2)
22
+ w = x2 - x1
23
+ h = y2 - y1
24
+ if w < 0 or h < 0:
25
+ return 0.0
26
+ else:
27
+ return 1.0 * w * h / (sa + sb - w * h)
28
+
29
+
30
+ def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh):
31
+ xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1
32
+ dx, dy = (xc - axc) / aww, (yc - ayc) / ahh
33
+ dw, dh = math.log(ww / aww), math.log(hh / ahh)
34
+ return dx, dy, dw, dh
35
+
36
+
37
+ def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh):
38
+ xc, yc = dx * aww + axc, dy * ahh + ayc
39
+ ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh
40
+ x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2
41
+ return x1, y1, x2, y2
42
+
43
+
44
+ def nms(dets, thresh):
45
+ if 0 == len(dets):
46
+ return []
47
+ x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4]
48
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
49
+ order = scores.argsort()[::-1]
50
+
51
+ keep = []
52
+ while order.size > 0:
53
+ i = order[0]
54
+ keep.append(i)
55
+ xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]])
56
+ xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]])
57
+
58
+ w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1)
59
+ ovr = w * h / (areas[i] + areas[order[1:]] - w * h)
60
+
61
+ inds = np.where(ovr <= thresh)[0]
62
+ order = order[inds + 1]
63
+
64
+ return keep
65
+
66
+
67
+ def encode(matched, priors, variances):
68
+ """Encode the variances from the priorbox layers into the ground truth boxes
69
+ we have matched (based on jaccard overlap) with the prior boxes.
70
+ Args:
71
+ matched: (tensor) Coords of ground truth for each prior in point-form
72
+ Shape: [num_priors, 4].
73
+ priors: (tensor) Prior boxes in center-offset form
74
+ Shape: [num_priors,4].
75
+ variances: (list[float]) Variances of priorboxes
76
+ Return:
77
+ encoded boxes (tensor), Shape: [num_priors, 4]
78
+ """
79
+
80
+ # dist b/t match center and prior's center
81
+ g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
82
+ # encode variance
83
+ g_cxcy /= (variances[0] * priors[:, 2:])
84
+ # match wh / prior wh
85
+ g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
86
+ g_wh = torch.log(g_wh) / variances[1]
87
+ # return target for smooth_l1_loss
88
+ return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
89
+
90
+
91
+ def decode(loc, priors, variances):
92
+ """Decode locations from predictions using priors to undo
93
+ the encoding we did for offset regression at train time.
94
+ Args:
95
+ loc (tensor): location predictions for loc layers,
96
+ Shape: [num_priors,4]
97
+ priors (tensor): Prior boxes in center-offset form.
98
+ Shape: [num_priors,4].
99
+ variances: (list[float]) Variances of priorboxes
100
+ Return:
101
+ decoded bounding box predictions
102
+ """
103
+
104
+ boxes = torch.cat((
105
+ priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
106
+ priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
107
+ boxes[:, :2] -= boxes[:, 2:] / 2
108
+ boxes[:, 2:] += boxes[:, :2]
109
+ return boxes
110
+
111
+ def batch_decode(loc, priors, variances):
112
+ """Decode locations from predictions using priors to undo
113
+ the encoding we did for offset regression at train time.
114
+ Args:
115
+ loc (tensor): location predictions for loc layers,
116
+ Shape: [num_priors,4]
117
+ priors (tensor): Prior boxes in center-offset form.
118
+ Shape: [num_priors,4].
119
+ variances: (list[float]) Variances of priorboxes
120
+ Return:
121
+ decoded bounding box predictions
122
+ """
123
+
124
+ boxes = torch.cat((
125
+ priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:],
126
+ priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1])), 2)
127
+ boxes[:, :, :2] -= boxes[:, :, 2:] / 2
128
+ boxes[:, :, 2:] += boxes[:, :, :2]
129
+ return boxes
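For reference, a minimal sketch of the nms() helper above on hand-made (x1, y1, x2, y2, score) rows; the import path is an assumption:

import numpy as np
from musetalk.utils.face_detection.detection.sfd.bbox import nms  # assumed path

dets = np.array([
    [10, 10, 110, 110, 0.95],    # highest score, kept
    [12, 12, 112, 112, 0.90],    # heavy overlap with the first box, suppressed
    [200, 200, 300, 300, 0.80],  # disjoint region, kept
])
keep = nms(dets, thresh=0.3)
print(keep)  # expected: [0, 2]
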
musetalk/utils/face_detection/detection/sfd/detect.py ADDED
@@ -0,0 +1,114 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ import os
5
+ import sys
6
+ import cv2
7
+ import random
8
+ import datetime
9
+ import math
10
+ import argparse
11
+ import numpy as np
12
+
13
+ import scipy.io as sio
14
+ import zipfile
15
+ from .net_s3fd import s3fd
16
+ from .bbox import *
17
+
18
+
19
+ def detect(net, img, device):
20
+ img = img - np.array([104, 117, 123])
21
+ img = img.transpose(2, 0, 1)
22
+ img = img.reshape((1,) + img.shape)
23
+
24
+ if 'cuda' in device:
25
+ torch.backends.cudnn.benchmark = True
26
+
27
+ img = torch.from_numpy(img).float().to(device)
28
+ BB, CC, HH, WW = img.size()
29
+ with torch.no_grad():
30
+ olist = net(img)
31
+
32
+ bboxlist = []
33
+ for i in range(len(olist) // 2):
34
+ olist[i * 2] = F.softmax(olist[i * 2], dim=1)
35
+ olist = [oelem.data.cpu() for oelem in olist]
36
+ for i in range(len(olist) // 2):
37
+ ocls, oreg = olist[i * 2], olist[i * 2 + 1]
38
+ FB, FC, FH, FW = ocls.size() # feature map size
39
+ stride = 2**(i + 2) # 4,8,16,32,64,128
40
+ anchor = stride * 4
41
+ poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
42
+ for Iindex, hindex, windex in poss:
43
+ axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
44
+ score = ocls[0, 1, hindex, windex]
45
+ loc = oreg[0, :, hindex, windex].contiguous().view(1, 4)
46
+ priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]])
47
+ variances = [0.1, 0.2]
48
+ box = decode(loc, priors, variances)
49
+ x1, y1, x2, y2 = box[0] * 1.0
50
+ # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
51
+ bboxlist.append([x1, y1, x2, y2, score])
52
+ bboxlist = np.array(bboxlist)
53
+ if 0 == len(bboxlist):
54
+ bboxlist = np.zeros((1, 5))
55
+
56
+ return bboxlist
57
+
58
+ def batch_detect(net, imgs, device):
59
+ imgs = imgs - np.array([104, 117, 123])
60
+ imgs = imgs.transpose(0, 3, 1, 2)
61
+
62
+ if 'cuda' in device:
63
+ torch.backends.cudnn.benchmark = True
64
+
65
+ imgs = torch.from_numpy(imgs).float().to(device)
66
+ BB, CC, HH, WW = imgs.size()
67
+ with torch.no_grad():
68
+ olist = net(imgs)
69
+ # print(olist)
70
+
71
+ bboxlist = []
72
+ for i in range(len(olist) // 2):
73
+ olist[i * 2] = F.softmax(olist[i * 2], dim=1)
74
+
75
+ olist = [oelem.cpu() for oelem in olist]
76
+ for i in range(len(olist) // 2):
77
+ ocls, oreg = olist[i * 2], olist[i * 2 + 1]
78
+ FB, FC, FH, FW = ocls.size() # feature map size
79
+ stride = 2**(i + 2) # 4,8,16,32,64,128
80
+ anchor = stride * 4
81
+ poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
82
+ for Iindex, hindex, windex in poss:
83
+ axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
84
+ score = ocls[:, 1, hindex, windex]
85
+ loc = oreg[:, :, hindex, windex].contiguous().view(BB, 1, 4)
86
+ priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]).view(1, 1, 4)
87
+ variances = [0.1, 0.2]
88
+ box = batch_decode(loc, priors, variances)
89
+ box = box[:, 0] * 1.0
90
+ # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
91
+ bboxlist.append(torch.cat([box, score.unsqueeze(1)], 1).cpu().numpy())
92
+ bboxlist = np.array(bboxlist)
93
+ if 0 == len(bboxlist):
94
+ bboxlist = np.zeros((1, BB, 5))
95
+
96
+ return bboxlist
97
+
98
+ def flip_detect(net, img, device):
99
+ img = cv2.flip(img, 1)
100
+ b = detect(net, img, device)
101
+
102
+ bboxlist = np.zeros(b.shape)
103
+ bboxlist[:, 0] = img.shape[1] - b[:, 2]
104
+ bboxlist[:, 1] = b[:, 1]
105
+ bboxlist[:, 2] = img.shape[1] - b[:, 0]
106
+ bboxlist[:, 3] = b[:, 3]
107
+ bboxlist[:, 4] = b[:, 4]
108
+ return bboxlist
109
+
110
+
111
+ def pts_to_bb(pts):
112
+ min_x, min_y = np.min(pts, axis=0)
113
+ max_x, max_y = np.max(pts, axis=0)
114
+ return np.array([min_x, min_y, max_x, max_y])
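For reference, a minimal single-image sketch combining detect() above with nms() from bbox.py; the checkpoint location, image path, and top-level import paths are assumptions:

import cv2
import torch
from musetalk.utils.face_detection.detection.sfd.net_s3fd import s3fd  # assumed paths
from musetalk.utils.face_detection.detection.sfd.detect import detect
from musetalk.utils.face_detection.detection.sfd.bbox import nms

net = s3fd()
net.load_state_dict(torch.load('s3fd.pth', map_location='cpu'))  # hypothetical checkpoint location
net.eval()

img = cv2.imread('face.jpg')                                     # hypothetical BGR image
bboxes = detect(net, img, device='cpu')                          # (N, 5): x1, y1, x2, y2, score
keep = nms(bboxes, 0.3)
print(bboxes[keep])
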
musetalk/utils/face_detection/detection/sfd/net_s3fd.py ADDED
@@ -0,0 +1,129 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class L2Norm(nn.Module):
7
+ def __init__(self, n_channels, scale=1.0):
8
+ super(L2Norm, self).__init__()
9
+ self.n_channels = n_channels
10
+ self.scale = scale
11
+ self.eps = 1e-10
12
+ self.weight = nn.Parameter(torch.Tensor(self.n_channels))
13
+ self.weight.data *= 0.0
14
+ self.weight.data += self.scale
15
+
16
+ def forward(self, x):
17
+ norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
18
+ x = x / norm * self.weight.view(1, -1, 1, 1)
19
+ return x
20
+
21
+
22
+ class s3fd(nn.Module):
23
+ def __init__(self):
24
+ super(s3fd, self).__init__()
25
+ self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
26
+ self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
27
+
28
+ self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
29
+ self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
30
+
31
+ self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
32
+ self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
33
+ self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
34
+
35
+ self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
36
+ self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
37
+ self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
38
+
39
+ self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
40
+ self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
41
+ self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
42
+
43
+ self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3)
44
+ self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0)
45
+
46
+ self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
47
+ self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
48
+
49
+ self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0)
50
+ self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
51
+
52
+ self.conv3_3_norm = L2Norm(256, scale=10)
53
+ self.conv4_3_norm = L2Norm(512, scale=8)
54
+ self.conv5_3_norm = L2Norm(512, scale=5)
55
+
56
+ self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
57
+ self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
58
+ self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
59
+ self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
60
+ self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
61
+ self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
62
+
63
+ self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1)
64
+ self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1)
65
+ self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
66
+ self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
67
+ self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1)
68
+ self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
69
+
70
+ def forward(self, x):
71
+ h = F.relu(self.conv1_1(x))
72
+ h = F.relu(self.conv1_2(h))
73
+ h = F.max_pool2d(h, 2, 2)
74
+
75
+ h = F.relu(self.conv2_1(h))
76
+ h = F.relu(self.conv2_2(h))
77
+ h = F.max_pool2d(h, 2, 2)
78
+
79
+ h = F.relu(self.conv3_1(h))
80
+ h = F.relu(self.conv3_2(h))
81
+ h = F.relu(self.conv3_3(h))
82
+ f3_3 = h
83
+ h = F.max_pool2d(h, 2, 2)
84
+
85
+ h = F.relu(self.conv4_1(h))
86
+ h = F.relu(self.conv4_2(h))
87
+ h = F.relu(self.conv4_3(h))
88
+ f4_3 = h
89
+ h = F.max_pool2d(h, 2, 2)
90
+
91
+ h = F.relu(self.conv5_1(h))
92
+ h = F.relu(self.conv5_2(h))
93
+ h = F.relu(self.conv5_3(h))
94
+ f5_3 = h
95
+ h = F.max_pool2d(h, 2, 2)
96
+
97
+ h = F.relu(self.fc6(h))
98
+ h = F.relu(self.fc7(h))
99
+ ffc7 = h
100
+ h = F.relu(self.conv6_1(h))
101
+ h = F.relu(self.conv6_2(h))
102
+ f6_2 = h
103
+ h = F.relu(self.conv7_1(h))
104
+ h = F.relu(self.conv7_2(h))
105
+ f7_2 = h
106
+
107
+ f3_3 = self.conv3_3_norm(f3_3)
108
+ f4_3 = self.conv4_3_norm(f4_3)
109
+ f5_3 = self.conv5_3_norm(f5_3)
110
+
111
+ cls1 = self.conv3_3_norm_mbox_conf(f3_3)
112
+ reg1 = self.conv3_3_norm_mbox_loc(f3_3)
113
+ cls2 = self.conv4_3_norm_mbox_conf(f4_3)
114
+ reg2 = self.conv4_3_norm_mbox_loc(f4_3)
115
+ cls3 = self.conv5_3_norm_mbox_conf(f5_3)
116
+ reg3 = self.conv5_3_norm_mbox_loc(f5_3)
117
+ cls4 = self.fc7_mbox_conf(ffc7)
118
+ reg4 = self.fc7_mbox_loc(ffc7)
119
+ cls5 = self.conv6_2_mbox_conf(f6_2)
120
+ reg5 = self.conv6_2_mbox_loc(f6_2)
121
+ cls6 = self.conv7_2_mbox_conf(f7_2)
122
+ reg6 = self.conv7_2_mbox_loc(f7_2)
123
+
124
+ # max-out background label
125
+ chunk = torch.chunk(cls1, 4, 1)
126
+ bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
127
+ cls1 = torch.cat([bmax, chunk[3]], dim=1)
128
+
129
+ return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6]
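For reference, a minimal sketch that runs a random tensor through the untrained s3fd network above, just to inspect its six (cls, reg) output pairs; the import path is an assumption:

import torch
from musetalk.utils.face_detection.detection.sfd.net_s3fd import s3fd  # assumed path

net = s3fd().eval()
x = torch.randn(1, 3, 256, 256)  # any reasonably sized 3-channel input
with torch.no_grad():
    outputs = net(x)
for i in range(0, len(outputs), 2):
    cls, reg = outputs[i], outputs[i + 1]
    print(f'scale {i // 2}: cls {tuple(cls.shape)}, reg {tuple(reg.shape)}')
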
musetalk/utils/face_detection/detection/sfd/sfd_detector.py ADDED
@@ -0,0 +1,59 @@
1
+ import os
2
+ import cv2
3
+ from torch.utils.model_zoo import load_url
4
+
5
+ from ..core import FaceDetector
6
+
7
+ from .net_s3fd import s3fd
8
+ from .bbox import *
9
+ from .detect import *
10
+
11
+ models_urls = {
12
+ 's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth',
13
+ }
14
+
15
+
16
+ class SFDDetector(FaceDetector):
17
+ def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path.abspath(__file__)), 's3fd.pth'), verbose=False):
18
+ super(SFDDetector, self).__init__(device, verbose)
19
+
20
+ # Initialise the face detector
21
+ if not os.path.isfile(path_to_detector):
22
+ model_weights = load_url(models_urls['s3fd'])
23
+ else:
24
+ model_weights = torch.load(path_to_detector)
25
+
26
+ self.face_detector = s3fd()
27
+ self.face_detector.load_state_dict(model_weights)
28
+ self.face_detector.to(device)
29
+ self.face_detector.eval()
30
+
31
+ def detect_from_image(self, tensor_or_path):
32
+ image = self.tensor_or_path_to_ndarray(tensor_or_path)
33
+
34
+ bboxlist = detect(self.face_detector, image, device=self.device)
35
+ keep = nms(bboxlist, 0.3)
36
+ bboxlist = bboxlist[keep, :]
37
+ bboxlist = [x for x in bboxlist if x[-1] > 0.5]
38
+
39
+ return bboxlist
40
+
41
+ def detect_from_batch(self, images):
42
+ bboxlists = batch_detect(self.face_detector, images, device=self.device)
43
+ keeps = [nms(bboxlists[:, i, :], 0.3) for i in range(bboxlists.shape[1])]
44
+ bboxlists = [bboxlists[keep, i, :] for i, keep in enumerate(keeps)]
45
+ bboxlists = [[x for x in bboxlist if x[-1] > 0.5] for bboxlist in bboxlists]
46
+
47
+ return bboxlists
48
+
49
+ @property
50
+ def reference_scale(self):
51
+ return 195
52
+
53
+ @property
54
+ def reference_x_shift(self):
55
+ return 0
56
+
57
+ @property
58
+ def reference_y_shift(self):
59
+ return 0
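For reference, a minimal batched sketch of the SFDDetector above; when s3fd.pth is not found next to this file the weights are pulled from models_urls. The frame paths and top-level import path are assumptions:

import cv2
import numpy as np
from musetalk.utils.face_detection.detection.sfd import FaceDetector  # assumed path (alias of SFDDetector)

detector = FaceDetector(device='cuda')
frames = np.stack([cv2.imread('frame_000.png'), cv2.imread('frame_001.png')])  # (B, H, W, 3), equal-sized frames
bboxlists = detector.detect_from_batch(frames)
for i, boxes in enumerate(bboxlists):
    print(f'frame {i}: {len(boxes)} face(s) with score > 0.5')
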
musetalk/utils/face_detection/models.py ADDED
@@ -0,0 +1,261 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+
6
+
7
+ def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
8
+ "3x3 convolution with padding"
9
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3,
10
+ stride=strd, padding=padding, bias=bias)
11
+
12
+
13
+ class ConvBlock(nn.Module):
14
+ def __init__(self, in_planes, out_planes):
15
+ super(ConvBlock, self).__init__()
16
+ self.bn1 = nn.BatchNorm2d(in_planes)
17
+ self.conv1 = conv3x3(in_planes, int(out_planes / 2))
18
+ self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
19
+ self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
20
+ self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
21
+ self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
22
+
23
+ if in_planes != out_planes:
24
+ self.downsample = nn.Sequential(
25
+ nn.BatchNorm2d(in_planes),
26
+ nn.ReLU(True),
27
+ nn.Conv2d(in_planes, out_planes,
28
+ kernel_size=1, stride=1, bias=False),
29
+ )
30
+ else:
31
+ self.downsample = None
32
+
33
+ def forward(self, x):
34
+ residual = x
35
+
36
+ out1 = self.bn1(x)
37
+ out1 = F.relu(out1, True)
38
+ out1 = self.conv1(out1)
39
+
40
+ out2 = self.bn2(out1)
41
+ out2 = F.relu(out2, True)
42
+ out2 = self.conv2(out2)
43
+
44
+ out3 = self.bn3(out2)
45
+ out3 = F.relu(out3, True)
46
+ out3 = self.conv3(out3)
47
+
48
+ out3 = torch.cat((out1, out2, out3), 1)
49
+
50
+ if self.downsample is not None:
51
+ residual = self.downsample(residual)
52
+
53
+ out3 += residual
54
+
55
+ return out3
56
+
57
+
58
+ class Bottleneck(nn.Module):
59
+
60
+ expansion = 4
61
+
62
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
63
+ super(Bottleneck, self).__init__()
64
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
65
+ self.bn1 = nn.BatchNorm2d(planes)
66
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
67
+ padding=1, bias=False)
68
+ self.bn2 = nn.BatchNorm2d(planes)
69
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
70
+ self.bn3 = nn.BatchNorm2d(planes * 4)
71
+ self.relu = nn.ReLU(inplace=True)
72
+ self.downsample = downsample
73
+ self.stride = stride
74
+
75
+ def forward(self, x):
76
+ residual = x
77
+
78
+ out = self.conv1(x)
79
+ out = self.bn1(out)
80
+ out = self.relu(out)
81
+
82
+ out = self.conv2(out)
83
+ out = self.bn2(out)
84
+ out = self.relu(out)
85
+
86
+ out = self.conv3(out)
87
+ out = self.bn3(out)
88
+
89
+ if self.downsample is not None:
90
+ residual = self.downsample(x)
91
+
92
+ out += residual
93
+ out = self.relu(out)
94
+
95
+ return out
96
+
97
+
98
+ class HourGlass(nn.Module):
99
+ def __init__(self, num_modules, depth, num_features):
100
+ super(HourGlass, self).__init__()
101
+ self.num_modules = num_modules
102
+ self.depth = depth
103
+ self.features = num_features
104
+
105
+ self._generate_network(self.depth)
106
+
107
+ def _generate_network(self, level):
108
+ self.add_module('b1_' + str(level), ConvBlock(self.features, self.features))
109
+
110
+ self.add_module('b2_' + str(level), ConvBlock(self.features, self.features))
111
+
112
+ if level > 1:
113
+ self._generate_network(level - 1)
114
+ else:
115
+ self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features))
116
+
117
+ self.add_module('b3_' + str(level), ConvBlock(self.features, self.features))
118
+
119
+ def _forward(self, level, inp):
120
+ # Upper branch
121
+ up1 = inp
122
+ up1 = self._modules['b1_' + str(level)](up1)
123
+
124
+ # Lower branch
125
+ low1 = F.avg_pool2d(inp, 2, stride=2)
126
+ low1 = self._modules['b2_' + str(level)](low1)
127
+
128
+ if level > 1:
129
+ low2 = self._forward(level - 1, low1)
130
+ else:
131
+ low2 = low1
132
+ low2 = self._modules['b2_plus_' + str(level)](low2)
133
+
134
+ low3 = low2
135
+ low3 = self._modules['b3_' + str(level)](low3)
136
+
137
+ up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
138
+
139
+ return up1 + up2
140
+
141
+ def forward(self, x):
142
+ return self._forward(self.depth, x)
143
+
144
+
145
+ class FAN(nn.Module):
146
+
147
+ def __init__(self, num_modules=1):
148
+ super(FAN, self).__init__()
149
+ self.num_modules = num_modules
150
+
151
+ # Base part
152
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
153
+ self.bn1 = nn.BatchNorm2d(64)
154
+ self.conv2 = ConvBlock(64, 128)
155
+ self.conv3 = ConvBlock(128, 128)
156
+ self.conv4 = ConvBlock(128, 256)
157
+
158
+ # Stacking part
159
+ for hg_module in range(self.num_modules):
160
+ self.add_module('m' + str(hg_module), HourGlass(1, 4, 256))
161
+ self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
162
+ self.add_module('conv_last' + str(hg_module),
163
+ nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
164
+ self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
165
+ self.add_module('l' + str(hg_module), nn.Conv2d(256,
166
+ 68, kernel_size=1, stride=1, padding=0))
167
+
168
+ if hg_module < self.num_modules - 1:
169
+ self.add_module(
170
+ 'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
171
+ self.add_module('al' + str(hg_module), nn.Conv2d(68,
172
+ 256, kernel_size=1, stride=1, padding=0))
173
+
174
+ def forward(self, x):
175
+ x = F.relu(self.bn1(self.conv1(x)), True)
176
+ x = F.avg_pool2d(self.conv2(x), 2, stride=2)
177
+ x = self.conv3(x)
178
+ x = self.conv4(x)
179
+
180
+ previous = x
181
+
182
+ outputs = []
183
+ for i in range(self.num_modules):
184
+ hg = self._modules['m' + str(i)](previous)
185
+
186
+ ll = hg
187
+ ll = self._modules['top_m_' + str(i)](ll)
188
+
189
+ ll = F.relu(self._modules['bn_end' + str(i)]
190
+ (self._modules['conv_last' + str(i)](ll)), True)
191
+
192
+ # Predict heatmaps
193
+ tmp_out = self._modules['l' + str(i)](ll)
194
+ outputs.append(tmp_out)
195
+
196
+ if i < self.num_modules - 1:
197
+ ll = self._modules['bl' + str(i)](ll)
198
+ tmp_out_ = self._modules['al' + str(i)](tmp_out)
199
+ previous = previous + ll + tmp_out_
200
+
201
+ return outputs
202
+
203
+
204
+ class ResNetDepth(nn.Module):
205
+
206
+ def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68):
207
+ self.inplanes = 64
208
+ super(ResNetDepth, self).__init__()
209
+ self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3,
210
+ bias=False)
211
+ self.bn1 = nn.BatchNorm2d(64)
212
+ self.relu = nn.ReLU(inplace=True)
213
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
214
+ self.layer1 = self._make_layer(block, 64, layers[0])
215
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
216
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
217
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
218
+ self.avgpool = nn.AvgPool2d(7)
219
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
220
+
221
+ for m in self.modules():
222
+ if isinstance(m, nn.Conv2d):
223
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
224
+ m.weight.data.normal_(0, math.sqrt(2. / n))
225
+ elif isinstance(m, nn.BatchNorm2d):
226
+ m.weight.data.fill_(1)
227
+ m.bias.data.zero_()
228
+
229
+ def _make_layer(self, block, planes, blocks, stride=1):
230
+ downsample = None
231
+ if stride != 1 or self.inplanes != planes * block.expansion:
232
+ downsample = nn.Sequential(
233
+ nn.Conv2d(self.inplanes, planes * block.expansion,
234
+ kernel_size=1, stride=stride, bias=False),
235
+ nn.BatchNorm2d(planes * block.expansion),
236
+ )
237
+
238
+ layers = []
239
+ layers.append(block(self.inplanes, planes, stride, downsample))
240
+ self.inplanes = planes * block.expansion
241
+ for i in range(1, blocks):
242
+ layers.append(block(self.inplanes, planes))
243
+
244
+ return nn.Sequential(*layers)
245
+
246
+ def forward(self, x):
247
+ x = self.conv1(x)
248
+ x = self.bn1(x)
249
+ x = self.relu(x)
250
+ x = self.maxpool(x)
251
+
252
+ x = self.layer1(x)
253
+ x = self.layer2(x)
254
+ x = self.layer3(x)
255
+ x = self.layer4(x)
256
+
257
+ x = self.avgpool(x)
258
+ x = x.view(x.size(0), -1)
259
+ x = self.fc(x)
260
+
261
+ return x
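For reference, a minimal sketch of the FAN network above: a 256x256 crop yields one 68-channel heatmap tensor per hourglass module. The import path is an assumption:

import torch
from musetalk.utils.face_detection.models import FAN  # assumed path

fan = FAN(num_modules=1).eval()
crop = torch.randn(1, 3, 256, 256)   # a face crop, e.g. produced by utils.crop()
with torch.no_grad():
    heatmaps = fan(crop)
print(len(heatmaps), tuple(heatmaps[-1].shape))  # 1 (1, 68, 64, 64)
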
musetalk/utils/face_detection/utils.py ADDED
@@ -0,0 +1,313 @@
1
+ from __future__ import print_function
2
+ import os
3
+ import sys
4
+ import time
5
+ import torch
6
+ import math
7
+ import numpy as np
8
+ import cv2
9
+
10
+
11
+ def _gaussian(
12
+ size=3, sigma=0.25, amplitude=1, normalize=False, width=None,
13
+ height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5,
14
+ mean_vert=0.5):
15
+ # handle some defaults
16
+ if width is None:
17
+ width = size
18
+ if height is None:
19
+ height = size
20
+ if sigma_horz is None:
21
+ sigma_horz = sigma
22
+ if sigma_vert is None:
23
+ sigma_vert = sigma
24
+ center_x = mean_horz * width + 0.5
25
+ center_y = mean_vert * height + 0.5
26
+ gauss = np.empty((height, width), dtype=np.float32)
27
+ # generate kernel
28
+ for i in range(height):
29
+ for j in range(width):
30
+ gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / (
31
+ sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0))
32
+ if normalize:
33
+ gauss = gauss / np.sum(gauss)
34
+ return gauss
35
+
36
+
37
+ def draw_gaussian(image, point, sigma):
38
+ # Check if the gaussian is inside
39
+ ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)]
40
+ br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)]
41
+ if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1):
42
+ return image
43
+ size = 6 * sigma + 1
44
+ g = _gaussian(size)
45
+ g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))]
46
+ g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))]
47
+ img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
48
+ img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
49
+ assert (g_x[0] > 0 and g_y[1] > 0)
50
+ image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]
51
+ ] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]]
52
+ image[image > 1] = 1
53
+ return image
54
+
55
+
56
+ def transform(point, center, scale, resolution, invert=False):
57
+ """Generate and affine transformation matrix.
58
+
59
+ Given a set of points, a center, a scale and a targer resolution, the
60
+ function generates and affine transformation matrix. If invert is ``True``
61
+ it will produce the inverse transformation.
62
+
63
+ Arguments:
64
+ point {torch.tensor} -- the input 2D point
65
+ center {torch.tensor or numpy.array} -- the center around which to perform the transformations
66
+ scale {float} -- the scale of the face/object
67
+ resolution {float} -- the output resolution
68
+
69
+ Keyword Arguments:
70
+ invert {bool} -- whether the function should produce the direct or the
71
+ inverse transformation matrix (default: {False})
72
+ """
73
+ _pt = torch.ones(3)
74
+ _pt[0] = point[0]
75
+ _pt[1] = point[1]
76
+
77
+ h = 200.0 * scale
78
+ t = torch.eye(3)
79
+ t[0, 0] = resolution / h
80
+ t[1, 1] = resolution / h
81
+ t[0, 2] = resolution * (-center[0] / h + 0.5)
82
+ t[1, 2] = resolution * (-center[1] / h + 0.5)
83
+
84
+ if invert:
85
+ t = torch.inverse(t)
86
+
87
+ new_point = (torch.matmul(t, _pt))[0:2]
88
+
89
+ return new_point.int()
90
+
91
+
92
+ def crop(image, center, scale, resolution=256.0):
93
+ """Center crops an image or set of heatmaps
94
+
95
+ Arguments:
96
+ image {numpy.array} -- an rgb image
97
+ center {numpy.array} -- the center of the object, usually the same as of the bounding box
98
+ scale {float} -- scale of the face
99
+
100
+ Keyword Arguments:
101
+ resolution {float} -- the size of the output cropped image (default: {256.0})
102
+
103
+ Returns:
104
+ numpy.array -- the cropped image, resized to ``resolution`` x ``resolution``
105
+ """ # Crop around the center point
106
+ """ Crops the image around the center. Input is expected to be an np.ndarray """
107
+ ul = transform([1, 1], center, scale, resolution, True)
108
+ br = transform([resolution, resolution], center, scale, resolution, True)
109
+ # pad = math.ceil(torch.norm((ul - br).float()) / 2.0 - (br[0] - ul[0]) / 2.0)
110
+ if image.ndim > 2:
111
+ newDim = np.array([br[1] - ul[1], br[0] - ul[0],
112
+ image.shape[2]], dtype=np.int32)
113
+ newImg = np.zeros(newDim, dtype=np.uint8)
114
+ else:
115
+ newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int32)
116
+ newImg = np.zeros(newDim, dtype=np.uint8)
117
+ ht = image.shape[0]
118
+ wd = image.shape[1]
119
+ newX = np.array(
120
+ [max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32)
121
+ newY = np.array(
122
+ [max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32)
123
+ oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32)
124
+ oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32)
125
+ newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1]
126
+ ] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :]
127
+ newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)),
128
+ interpolation=cv2.INTER_LINEAR)
129
+ return newImg
130
+
131
+
132
+ def get_preds_fromhm(hm, center=None, scale=None):
133
+ """Obtain (x,y) coordinates given a set of N heatmaps. If the center
134
+ and the scale is provided the function will return the points also in
135
+ the original coordinate frame.
136
+
137
+ Arguments:
138
+ hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
139
+
140
+ Keyword Arguments:
141
+ center {torch.tensor} -- the center of the bounding box (default: {None})
142
+ scale {float} -- face scale (default: {None})
143
+ """
144
+ max, idx = torch.max(
145
+ hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
146
+ idx += 1
147
+ preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
148
+ preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
149
+ preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
150
+
151
+ for i in range(preds.size(0)):
152
+ for j in range(preds.size(1)):
153
+ hm_ = hm[i, j, :]
154
+ pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
155
+ if pX > 0 and pX < 63 and pY > 0 and pY < 63:
156
+ diff = torch.FloatTensor(
157
+ [hm_[pY, pX + 1] - hm_[pY, pX - 1],
158
+ hm_[pY + 1, pX] - hm_[pY - 1, pX]])
159
+ preds[i, j].add_(diff.sign_().mul_(.25))
160
+
161
+ preds.add_(-.5)
162
+
163
+ preds_orig = torch.zeros(preds.size())
164
+ if center is not None and scale is not None:
165
+ for i in range(hm.size(0)):
166
+ for j in range(hm.size(1)):
167
+ preds_orig[i, j] = transform(
168
+ preds[i, j], center, scale, hm.size(2), True)
169
+
170
+ return preds, preds_orig
171
+
172
+ def get_preds_fromhm_batch(hm, centers=None, scales=None):
173
+ """Obtain (x,y) coordinates given a set of N heatmaps. If the centers
174
+ and the scales is provided the function will return the points also in
175
+ the original coordinate frame.
176
+
177
+ Arguments:
178
+ hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
179
+
180
+ Keyword Arguments:
181
+ centers {torch.tensor} -- the centers of the bounding box (default: {None})
182
+ scales {float} -- face scales (default: {None})
183
+ """
184
+ max, idx = torch.max(
185
+ hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
186
+ idx += 1
187
+ preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
188
+ preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
189
+ preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
190
+
191
+ for i in range(preds.size(0)):
192
+ for j in range(preds.size(1)):
193
+ hm_ = hm[i, j, :]
194
+ pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
195
+ if pX > 0 and pX < 63 and pY > 0 and pY < 63:
196
+ diff = torch.FloatTensor(
197
+ [hm_[pY, pX + 1] - hm_[pY, pX - 1],
198
+ hm_[pY + 1, pX] - hm_[pY - 1, pX]])
199
+ preds[i, j].add_(diff.sign_().mul_(.25))
200
+
201
+ preds.add_(-.5)
202
+
203
+ preds_orig = torch.zeros(preds.size())
204
+ if centers is not None and scales is not None:
205
+ for i in range(hm.size(0)):
206
+ for j in range(hm.size(1)):
207
+ preds_orig[i, j] = transform(
208
+ preds[i, j], centers[i], scales[i], hm.size(2), True)
209
+
210
+ return preds, preds_orig
211
+
212
+ def shuffle_lr(parts, pairs=None):
213
+ """Shuffle the points left-right according to the axis of symmetry
214
+ of the object.
215
+
216
+ Arguments:
217
+ parts {torch.tensor} -- a 3D or 4D object containing the
218
+ heatmaps.
219
+
220
+ Keyword Arguments:
221
+ pairs {list of integers} -- [order of the flipped points] (default: {None})
222
+ """
223
+ if pairs is None:
224
+ pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
225
+ 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35,
226
+ 34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41,
227
+ 40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63,
228
+ 62, 61, 60, 67, 66, 65]
229
+ if parts.ndimension() == 3:
230
+ parts = parts[pairs, ...]
231
+ else:
232
+ parts = parts[:, pairs, ...]
233
+
234
+ return parts
235
+
236
+
237
+ def flip(tensor, is_label=False):
238
+ """Flip an image or a set of heatmaps left-right
239
+
240
+ Arguments:
241
+ tensor {numpy.array or torch.tensor} -- [the input image or heatmaps]
242
+
243
+ Keyword Arguments:
244
+ is_label {bool} -- [whether the input is an image or a set of heatmaps] (default: {False})
245
+ """
246
+ if not torch.is_tensor(tensor):
247
+ tensor = torch.from_numpy(tensor)
248
+
249
+ if is_label:
250
+ tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1)
251
+ else:
252
+ tensor = tensor.flip(tensor.ndimension() - 1)
253
+
254
+ return tensor
255
+
256
+ # From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py)
257
+
258
+
259
+ def appdata_dir(appname=None, roaming=False):
260
+ """ appdata_dir(appname=None, roaming=False)
261
+
262
+ Get the path to the application directory, where applications are allowed
263
+ to write user specific files (e.g. configurations). For non-user specific
264
+ data, consider using common_appdata_dir().
265
+ If appname is given, a subdir is appended (and created if necessary).
266
+ If roaming is True, will prefer a roaming directory (Windows Vista/7).
267
+ """
268
+
269
+ # Define default user directory
270
+ userDir = os.getenv('FACEALIGNMENT_USERDIR', None)
271
+ if userDir is None:
272
+ userDir = os.path.expanduser('~')
273
+ if not os.path.isdir(userDir): # pragma: no cover
274
+ userDir = '/var/tmp' # issue #54
275
+
276
+ # Get system app data dir
277
+ path = None
278
+ if sys.platform.startswith('win'):
279
+ path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA')
280
+ path = (path2 or path1) if roaming else (path1 or path2)
281
+ elif sys.platform.startswith('darwin'):
282
+ path = os.path.join(userDir, 'Library', 'Application Support')
283
+ # On Linux and as fallback
284
+ if not (path and os.path.isdir(path)):
285
+ path = userDir
286
+
287
+ # Maybe we should store things local to the executable (in case of a
288
+ # portable distro or a frozen application that wants to be portable)
289
+ prefix = sys.prefix
290
+ if getattr(sys, 'frozen', None):
291
+ prefix = os.path.abspath(os.path.dirname(sys.executable))
292
+ for reldir in ('settings', '../settings'):
293
+ localpath = os.path.abspath(os.path.join(prefix, reldir))
294
+ if os.path.isdir(localpath): # pragma: no cover
295
+ try:
296
+ open(os.path.join(localpath, 'test.write'), 'wb').close()
297
+ os.remove(os.path.join(localpath, 'test.write'))
298
+ except IOError:
299
+ pass # We cannot write in this directory
300
+ else:
301
+ path = localpath
302
+ break
303
+
304
+ # Get path specific for this app
305
+ if appname:
306
+ if path == userDir:
307
+ appname = '.' + appname.lstrip('.') # Make it a hidden directory
308
+ path = os.path.join(path, appname)
309
+ if not os.path.isdir(path): # pragma: no cover
310
+ os.mkdir(path)
311
+
312
+ # Done
313
+ return path
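For reference, a minimal sketch tying crop() and get_preds_fromhm() above together; the frame, box and heatmap values are made up for illustration, and the import path is an assumption:

import numpy as np
import torch
from musetalk.utils.face_detection.utils import crop, get_preds_fromhm  # assumed path

image = np.zeros((480, 640, 3), dtype=np.uint8)           # placeholder frame
x1, y1, x2, y2 = 200, 100, 400, 300                       # a detected face box
center = torch.FloatTensor([(x1 + x2) / 2.0, (y1 + y2) / 2.0])
scale = (x2 - x1 + y2 - y1) / 195.0                       # 195 is the SFD reference_scale
face_crop = crop(image, center, scale, resolution=256.0)  # (256, 256, 3)

heatmaps = torch.rand(1, 68, 64, 64)                      # stand-in for FAN output
preds, preds_orig = get_preds_fromhm(heatmaps, center, scale)
print(face_crop.shape, preds_orig.shape)                  # (256, 256, 3) torch.Size([1, 68, 2])
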
musetalk/utils/face_parsing/__init__.py ADDED
@@ -0,0 +1,117 @@
1
+ import torch
2
+ import time
3
+ import os
4
+ import cv2
5
+ import numpy as np
6
+ from PIL import Image
7
+ from .model import BiSeNet
8
+ import torchvision.transforms as transforms
9
+
10
+ class FaceParsing():
11
+ def __init__(self, left_cheek_width=80, right_cheek_width=80):
12
+ self.net = self.model_init()
13
+ self.preprocess = self.image_preprocess()
14
+ # Ensure all size parameters are integers
15
+ cone_height = 21
16
+ tail_height = 12
17
+ total_size = cone_height + tail_height
18
+
19
+ # Create kernel with explicit integer dimensions
20
+ kernel = np.zeros((total_size, total_size), dtype=np.uint8)
21
+ center_x = total_size // 2 # Ensure center coordinates are integers
22
+
23
+ # Cone part
24
+ for row in range(cone_height):
25
+ if row < cone_height//2:
26
+ continue
27
+ width = int(2 * (row - cone_height//2) + 1)
28
+ start = int(center_x - (width // 2))
29
+ end = int(center_x + (width // 2) + 1)
30
+ kernel[row, start:end] = 1
31
+
32
+ # Vertical extension part
33
+ if cone_height > 0:
34
+ base_width = int(kernel[cone_height-1].sum())
35
+ else:
36
+ base_width = 1
37
+
38
+ for row in range(cone_height, total_size):
39
+ start = max(0, int(center_x - (base_width//2)))
40
+ end = min(total_size, int(center_x + (base_width//2) + 1))
41
+ kernel[row, start:end] = 1
42
+ self.kernel = kernel
43
+
44
+ # Modify cheek erosion kernel to be flatter ellipse
45
+ self.cheek_kernel = cv2.getStructuringElement(
46
+ cv2.MORPH_ELLIPSE, (35, 3))
47
+
48
+ # Add cheek area mask (protect chin area)
49
+ self.cheek_mask = self._create_cheek_mask(left_cheek_width=left_cheek_width, right_cheek_width=right_cheek_width)
50
+
51
+ def _create_cheek_mask(self, left_cheek_width=80, right_cheek_width=80):
52
+ """Create cheek area mask (1/4 area on both sides)"""
53
+ mask = np.zeros((512, 512), dtype=np.uint8)
54
+ center = 512 // 2
55
+ cv2.rectangle(mask, (0, 0), (center - left_cheek_width, 512), 255, -1) # Left cheek
56
+ cv2.rectangle(mask, (center + right_cheek_width, 0), (512, 512), 255, -1) # Right cheek
57
+ return mask
58
+
59
+ def model_init(self,
60
+ resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth',
61
+ model_pth='./models/face-parse-bisent/79999_iter.pth'):
62
+ net = BiSeNet(resnet_path)
63
+ if torch.cuda.is_available():
64
+ net.cuda()
65
+ net.load_state_dict(torch.load(model_pth))
66
+ else:
67
+ net.load_state_dict(torch.load(model_pth, map_location=torch.device('cpu')))
68
+ net.eval()
69
+ return net
70
+
71
+ def image_preprocess(self):
72
+ return transforms.Compose([
73
+ transforms.ToTensor(),
74
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
75
+ ])
76
+
77
+ def __call__(self, image, size=(512, 512), mode="raw"):
78
+ if isinstance(image, str):
79
+ image = Image.open(image)
80
+
81
+ width, height = image.size
82
+ with torch.no_grad():
83
+ image = image.resize(size, Image.BILINEAR)
84
+ img = self.preprocess(image)
85
+ if torch.cuda.is_available():
86
+ img = torch.unsqueeze(img, 0).cuda()
87
+ else:
88
+ img = torch.unsqueeze(img, 0)
89
+ out = self.net(img)[0]
90
+ parsing = out.squeeze(0).cpu().numpy().argmax(0)
91
+
92
+ # Add 14:neck, remove 10:nose and 7:8:9
93
+ if mode == "neck":
94
+ parsing[np.isin(parsing, [1, 11, 12, 13, 14])] = 255
95
+ parsing[np.where(parsing!=255)] = 0
96
+ elif mode == "jaw":
97
+ face_region = np.isin(parsing, [1])*255
98
+ face_region = face_region.astype(np.uint8)
99
+ original_dilated = cv2.dilate(face_region, self.kernel, iterations=1)
100
+ eroded = cv2.erode(original_dilated, self.cheek_kernel, iterations=2)
101
+ face_region = cv2.bitwise_and(eroded, self.cheek_mask)
102
+ face_region = cv2.bitwise_or(face_region, cv2.bitwise_and(original_dilated, ~self.cheek_mask))
103
+ parsing[(face_region==255) & (~np.isin(parsing, [10]))] = 255
104
+ parsing[np.isin(parsing, [11, 12, 13])] = 255
105
+ parsing[np.where(parsing!=255)] = 0
106
+ else:
107
+ parsing[np.isin(parsing, [1, 11, 12, 13])] = 255
108
+ parsing[np.where(parsing!=255)] = 0
109
+
110
+ parsing = Image.fromarray(parsing.astype(np.uint8))
111
+ return parsing
112
+
113
+ if __name__ == "__main__":
114
+ fp = FaceParsing()
115
+ segmap = fp('154_small.png')
116
+ segmap.save('res.png')
117
+
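For reference, a minimal sketch of the FaceParsing wrapper above producing a binary mask for blending; it assumes the BiSeNet weights referenced in model_init() are already under ./models/, and the crop path and cheek widths are hypothetical:

from PIL import Image
from musetalk.utils.face_parsing import FaceParsing  # assumed import path

fp = FaceParsing(left_cheek_width=90, right_cheek_width=90)
face_img = Image.open('face_crop.png')    # hypothetical face crop
mask = fp(face_img, mode='jaw')           # 512x512 PIL mask, 255 = face/jaw region
mask = mask.resize(face_img.size)         # back to the crop's original size
mask.save('face_mask.png')
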
musetalk/utils/face_parsing/model.py ADDED
@@ -0,0 +1,283 @@
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torchvision
9
+
10
+ from .resnet import Resnet18
11
+ # from modules.bn import InPlaceABNSync as BatchNorm2d
12
+
13
+
14
+ class ConvBNReLU(nn.Module):
15
+ def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
16
+ super(ConvBNReLU, self).__init__()
17
+ self.conv = nn.Conv2d(in_chan,
18
+ out_chan,
19
+ kernel_size = ks,
20
+ stride = stride,
21
+ padding = padding,
22
+ bias = False)
23
+ self.bn = nn.BatchNorm2d(out_chan)
24
+ self.init_weight()
25
+
26
+ def forward(self, x):
27
+ x = self.conv(x)
28
+ x = F.relu(self.bn(x))
29
+ return x
30
+
31
+ def init_weight(self):
32
+ for ly in self.children():
33
+ if isinstance(ly, nn.Conv2d):
34
+ nn.init.kaiming_normal_(ly.weight, a=1)
35
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
36
+
37
+ class BiSeNetOutput(nn.Module):
38
+ def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
39
+ super(BiSeNetOutput, self).__init__()
40
+ self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
41
+ self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
42
+ self.init_weight()
43
+
44
+ def forward(self, x):
45
+ x = self.conv(x)
46
+ x = self.conv_out(x)
47
+ return x
48
+
49
+ def init_weight(self):
50
+ for ly in self.children():
51
+ if isinstance(ly, nn.Conv2d):
52
+ nn.init.kaiming_normal_(ly.weight, a=1)
53
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
54
+
55
+ def get_params(self):
56
+ wd_params, nowd_params = [], []
57
+ for name, module in self.named_modules():
58
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
59
+ wd_params.append(module.weight)
60
+ if not module.bias is None:
61
+ nowd_params.append(module.bias)
62
+ elif isinstance(module, nn.BatchNorm2d):
63
+ nowd_params += list(module.parameters())
64
+ return wd_params, nowd_params
65
+
66
+
67
+ class AttentionRefinementModule(nn.Module):
68
+ def __init__(self, in_chan, out_chan, *args, **kwargs):
69
+ super(AttentionRefinementModule, self).__init__()
70
+ self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
71
+ self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
72
+ self.bn_atten = nn.BatchNorm2d(out_chan)
73
+ self.sigmoid_atten = nn.Sigmoid()
74
+ self.init_weight()
75
+
76
+ def forward(self, x):
77
+ feat = self.conv(x)
78
+ atten = F.avg_pool2d(feat, feat.size()[2:])
79
+ atten = self.conv_atten(atten)
80
+ atten = self.bn_atten(atten)
81
+ atten = self.sigmoid_atten(atten)
82
+ out = torch.mul(feat, atten)
83
+ return out
84
+
85
+ def init_weight(self):
86
+ for ly in self.children():
87
+ if isinstance(ly, nn.Conv2d):
88
+ nn.init.kaiming_normal_(ly.weight, a=1)
89
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
90
+
91
+
92
+ class ContextPath(nn.Module):
93
+ def __init__(self, resnet_path, *args, **kwargs):
94
+ super(ContextPath, self).__init__()
95
+ self.resnet = Resnet18(resnet_path)
96
+ self.arm16 = AttentionRefinementModule(256, 128)
97
+ self.arm32 = AttentionRefinementModule(512, 128)
98
+ self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
99
+ self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
100
+ self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
101
+
102
+ self.init_weight()
103
+
104
+ def forward(self, x):
105
+ H0, W0 = x.size()[2:]
106
+ feat8, feat16, feat32 = self.resnet(x)
107
+ H8, W8 = feat8.size()[2:]
108
+ H16, W16 = feat16.size()[2:]
109
+ H32, W32 = feat32.size()[2:]
110
+
111
+ avg = F.avg_pool2d(feat32, feat32.size()[2:])
112
+ avg = self.conv_avg(avg)
113
+ avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
114
+
115
+ feat32_arm = self.arm32(feat32)
116
+ feat32_sum = feat32_arm + avg_up
117
+ feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
118
+ feat32_up = self.conv_head32(feat32_up)
119
+
120
+ feat16_arm = self.arm16(feat16)
121
+ feat16_sum = feat16_arm + feat32_up
122
+ feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
123
+ feat16_up = self.conv_head16(feat16_up)
124
+
125
+ return feat8, feat16_up, feat32_up # x8, x8, x16
126
+
127
+ def init_weight(self):
128
+ for ly in self.children():
129
+ if isinstance(ly, nn.Conv2d):
130
+ nn.init.kaiming_normal_(ly.weight, a=1)
131
+ if ly.bias is not None: nn.init.constant_(ly.bias, 0)
132
+
133
+ def get_params(self):
134
+ wd_params, nowd_params = [], []
135
+ for name, module in self.named_modules():
136
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
137
+ wd_params.append(module.weight)
138
+ if module.bias is not None:
139
+ nowd_params.append(module.bias)
140
+ elif isinstance(module, nn.BatchNorm2d):
141
+ nowd_params += list(module.parameters())
142
+ return wd_params, nowd_params
143
+
144
+
145
+ ### This is not used: the spatial path is replaced by the resnet feature map of the same size
146
+ class SpatialPath(nn.Module):
147
+ def __init__(self, *args, **kwargs):
148
+ super(SpatialPath, self).__init__()
149
+ self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
150
+ self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
151
+ self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
152
+ self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
153
+ self.init_weight()
154
+
155
+ def forward(self, x):
156
+ feat = self.conv1(x)
157
+ feat = self.conv2(feat)
158
+ feat = self.conv3(feat)
159
+ feat = self.conv_out(feat)
160
+ return feat
161
+
162
+ def init_weight(self):
163
+ for ly in self.children():
164
+ if isinstance(ly, nn.Conv2d):
165
+ nn.init.kaiming_normal_(ly.weight, a=1)
166
+ if ly.bias is not None: nn.init.constant_(ly.bias, 0)
167
+
168
+ def get_params(self):
169
+ wd_params, nowd_params = [], []
170
+ for name, module in self.named_modules():
171
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
172
+ wd_params.append(module.weight)
173
+ if module.bias is not None:
174
+ nowd_params.append(module.bias)
175
+ elif isinstance(module, nn.BatchNorm2d):
176
+ nowd_params += list(module.parameters())
177
+ return wd_params, nowd_params
178
+
179
+
180
+ class FeatureFusionModule(nn.Module):
181
+ def __init__(self, in_chan, out_chan, *args, **kwargs):
182
+ super(FeatureFusionModule, self).__init__()
183
+ self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
184
+ self.conv1 = nn.Conv2d(out_chan,
185
+ out_chan//4,
186
+ kernel_size = 1,
187
+ stride = 1,
188
+ padding = 0,
189
+ bias = False)
190
+ self.conv2 = nn.Conv2d(out_chan//4,
191
+ out_chan,
192
+ kernel_size = 1,
193
+ stride = 1,
194
+ padding = 0,
195
+ bias = False)
196
+ self.relu = nn.ReLU(inplace=True)
197
+ self.sigmoid = nn.Sigmoid()
198
+ self.init_weight()
199
+
200
+ def forward(self, fsp, fcp):
201
+ fcat = torch.cat([fsp, fcp], dim=1)
202
+ feat = self.convblk(fcat)
203
+ atten = F.avg_pool2d(feat, feat.size()[2:])
204
+ atten = self.conv1(atten)
205
+ atten = self.relu(atten)
206
+ atten = self.conv2(atten)
207
+ atten = self.sigmoid(atten)
208
+ feat_atten = torch.mul(feat, atten)
209
+ feat_out = feat_atten + feat
210
+ return feat_out
211
+
212
+ def init_weight(self):
213
+ for ly in self.children():
214
+ if isinstance(ly, nn.Conv2d):
215
+ nn.init.kaiming_normal_(ly.weight, a=1)
216
+ if ly.bias is not None: nn.init.constant_(ly.bias, 0)
217
+
218
+ def get_params(self):
219
+ wd_params, nowd_params = [], []
220
+ for name, module in self.named_modules():
221
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
222
+ wd_params.append(module.weight)
223
+ if module.bias is not None:
224
+ nowd_params.append(module.bias)
225
+ elif isinstance(module, nn.BatchNorm2d):
226
+ nowd_params += list(module.parameters())
227
+ return wd_params, nowd_params
228
+
229
+
230
+ class BiSeNet(nn.Module):
231
+ def __init__(self, resnet_path='models/resnet18-5c106cde.pth', n_classes=19, *args, **kwargs):
232
+ super(BiSeNet, self).__init__()
233
+ self.cp = ContextPath(resnet_path)
234
+ ## here self.sp is deleted
235
+ self.ffm = FeatureFusionModule(256, 256)
236
+ self.conv_out = BiSeNetOutput(256, 256, n_classes)
237
+ self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
238
+ self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
239
+ self.init_weight()
240
+
241
+ def forward(self, x):
242
+ H, W = x.size()[2:]
243
+ feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
244
+ feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
245
+ feat_fuse = self.ffm(feat_sp, feat_cp8)
246
+
247
+ feat_out = self.conv_out(feat_fuse)
248
+ feat_out16 = self.conv_out16(feat_cp8)
249
+ feat_out32 = self.conv_out32(feat_cp16)
250
+
251
+ feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
252
+ feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
253
+ feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
254
+ return feat_out, feat_out16, feat_out32
255
+
256
+ def init_weight(self):
257
+ for ly in self.children():
258
+ if isinstance(ly, nn.Conv2d):
259
+ nn.init.kaiming_normal_(ly.weight, a=1)
260
+ if ly.bias is not None: nn.init.constant_(ly.bias, 0)
261
+
262
+ def get_params(self):
263
+ wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
264
+ for name, child in self.named_children():
265
+ child_wd_params, child_nowd_params = child.get_params()
266
+ if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
267
+ lr_mul_wd_params += child_wd_params
268
+ lr_mul_nowd_params += child_nowd_params
269
+ else:
270
+ wd_params += child_wd_params
271
+ nowd_params += child_nowd_params
272
+ return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
273
+
274
+
275
+ if __name__ == "__main__":
276
+ net = BiSeNet(n_classes=19)  # pass by keyword: the first positional argument is resnet_path
277
+ net.cuda()
278
+ net.eval()
279
+ in_ten = torch.randn(16, 3, 640, 480).cuda()
280
+ out, out16, out32 = net(in_ten)
281
+ print(out.shape)
282
+
283
+ net.get_params()
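For reference, a minimal inference sketch for the BiSeNet defined above (not part of the original file; the face-parsing checkpoint path is a placeholder, and the ResNet-18 weights at models/resnet18-5c106cde.pth are assumed to be downloaded, since ContextPath loads them at construction time):

import torch

net = BiSeNet(resnet_path='models/resnet18-5c106cde.pth', n_classes=19)
net.load_state_dict(torch.load('path/to/face_parsing_checkpoint.pth', map_location='cpu'))
net.eval()
with torch.no_grad():
    img = torch.randn(1, 3, 512, 512)      # stand-in for a normalized RGB face crop
    out, out16, out32 = net(img)           # all three heads are upsampled to the input size
    parsing = out.argmax(dim=1)            # per-pixel class index in [0, n_classes)
print(parsing.shape)                       # torch.Size([1, 512, 512])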
musetalk/utils/face_parsing/resnet.py ADDED
@@ -0,0 +1,109 @@
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torch.utils.model_zoo as modelzoo
8
+
9
+ # from modules.bn import InPlaceABNSync as BatchNorm2d
10
+
11
+ resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
12
+
13
+
14
+ def conv3x3(in_planes, out_planes, stride=1):
15
+ """3x3 convolution with padding"""
16
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
17
+ padding=1, bias=False)
18
+
19
+
20
+ class BasicBlock(nn.Module):
21
+ def __init__(self, in_chan, out_chan, stride=1):
22
+ super(BasicBlock, self).__init__()
23
+ self.conv1 = conv3x3(in_chan, out_chan, stride)
24
+ self.bn1 = nn.BatchNorm2d(out_chan)
25
+ self.conv2 = conv3x3(out_chan, out_chan)
26
+ self.bn2 = nn.BatchNorm2d(out_chan)
27
+ self.relu = nn.ReLU(inplace=True)
28
+ self.downsample = None
29
+ if in_chan != out_chan or stride != 1:
30
+ self.downsample = nn.Sequential(
31
+ nn.Conv2d(in_chan, out_chan,
32
+ kernel_size=1, stride=stride, bias=False),
33
+ nn.BatchNorm2d(out_chan),
34
+ )
35
+
36
+ def forward(self, x):
37
+ residual = self.conv1(x)
38
+ residual = F.relu(self.bn1(residual))
39
+ residual = self.conv2(residual)
40
+ residual = self.bn2(residual)
41
+
42
+ shortcut = x
43
+ if self.downsample is not None:
44
+ shortcut = self.downsample(x)
45
+
46
+ out = shortcut + residual
47
+ out = self.relu(out)
48
+ return out
49
+
50
+
51
+ def create_layer_basic(in_chan, out_chan, bnum, stride=1):
52
+ layers = [BasicBlock(in_chan, out_chan, stride=stride)]
53
+ for i in range(bnum-1):
54
+ layers.append(BasicBlock(out_chan, out_chan, stride=1))
55
+ return nn.Sequential(*layers)
56
+
57
+
58
+ class Resnet18(nn.Module):
59
+ def __init__(self, model_path):
60
+ super(Resnet18, self).__init__()
61
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
62
+ bias=False)
63
+ self.bn1 = nn.BatchNorm2d(64)
64
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
65
+ self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
66
+ self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
67
+ self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
68
+ self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
69
+ self.init_weight(model_path)
70
+
71
+ def forward(self, x):
72
+ x = self.conv1(x)
73
+ x = F.relu(self.bn1(x))
74
+ x = self.maxpool(x)
75
+
76
+ x = self.layer1(x)
77
+ feat8 = self.layer2(x) # 1/8
78
+ feat16 = self.layer3(feat8) # 1/16
79
+ feat32 = self.layer4(feat16) # 1/32
80
+ return feat8, feat16, feat32
81
+
82
+ def init_weight(self, model_path):
83
+ state_dict = torch.load(model_path) #modelzoo.load_url(resnet18_url)
84
+ self_state_dict = self.state_dict()
85
+ for k, v in state_dict.items():
86
+ if 'fc' in k: continue
87
+ self_state_dict.update({k: v})
88
+ self.load_state_dict(self_state_dict)
89
+
90
+ def get_params(self):
91
+ wd_params, nowd_params = [], []
92
+ for name, module in self.named_modules():
93
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
94
+ wd_params.append(module.weight)
95
+ if module.bias is not None:
96
+ nowd_params.append(module.bias)
97
+ elif isinstance(module, nn.BatchNorm2d):
98
+ nowd_params += list(module.parameters())
99
+ return wd_params, nowd_params
100
+
101
+
102
+ if __name__ == "__main__":
103
+ net = Resnet18('models/resnet18-5c106cde.pth')  # model_path is required; this path matches the default used in model.py
104
+ x = torch.randn(16, 3, 224, 224)
105
+ out = net(x)
106
+ print(out[0].size())
107
+ print(out[1].size())
108
+ print(out[2].size())
109
+ net.get_params()
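A quick shape check for the backbone above (sketch only; it assumes the ResNet-18 weights file used as the default path in model.py has been downloaded, because init_weight loads it at construction time):

import torch

backbone = Resnet18('models/resnet18-5c106cde.pth')
with torch.no_grad():
    feat8, feat16, feat32 = backbone(torch.randn(1, 3, 256, 256))
print(feat8.shape, feat16.shape, feat32.shape)
# strides 8/16/32 with 128/256/512 channels:
# torch.Size([1, 128, 32, 32]) torch.Size([1, 256, 16, 16]) torch.Size([1, 512, 8, 8])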
musetalk/utils/preprocessing.py ADDED
@@ -0,0 +1,155 @@
1
+ import sys
2
+ from face_detection import FaceAlignment,LandmarksType
3
+ from os import listdir, path
4
+ import subprocess
5
+ import numpy as np
6
+ import cv2
7
+ import pickle
8
+ import os
9
+ import json
10
+ from mmpose.apis import inference_topdown, init_model
11
+ from mmpose.structures import merge_data_samples
12
+ import torch
13
+ from tqdm import tqdm
14
+
15
+ # initialize the mmpose model
16
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+ config_file = './musetalk/utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py'
18
+ checkpoint_file = './models/dwpose/dw-ll_ucoco_384.pth'
19
+ model = init_model(config_file, checkpoint_file, device=device)
20
+
21
+ # initialize the face detection model
22
+ device = "cuda" if torch.cuda.is_available() else "cpu"
23
+ fa = FaceAlignment(LandmarksType._2D, flip_input=False,device=device)
24
+
25
+ # placeholder marker used when no sufficient bbox is found
26
+ coord_placeholder = (0.0,0.0,0.0,0.0)
27
+
28
+ def resize_landmark(landmark, w, h, new_w, new_h):
29
+ w_ratio = new_w / w
30
+ h_ratio = new_h / h
31
+ landmark_norm = landmark / [w, h]
32
+ landmark_resized = landmark_norm * [new_w, new_h]
33
+ return landmark_resized
34
+
35
+ def read_imgs(img_list):
36
+ frames = []
37
+ print('reading images...')
38
+ for img_path in tqdm(img_list):
39
+ frame = cv2.imread(img_path)
40
+ frames.append(frame)
41
+ return frames
42
+
43
+ def get_bbox_range(img_list,upperbondrange =0):
44
+ frames = read_imgs(img_list)
45
+ batch_size_fa = 1
46
+ batches = [frames[i:i + batch_size_fa] for i in range(0, len(frames), batch_size_fa)]
47
+ coords_list = []
48
+ landmarks = []
49
+ if upperbondrange != 0:
50
+ print('get key_landmark and face bounding boxes with the bbox_shift:',upperbondrange)
51
+ else:
52
+ print('get key_landmark and face bounding boxes with the default value')
53
+ average_range_minus = []
54
+ average_range_plus = []
55
+ for fb in tqdm(batches):
56
+ results = inference_topdown(model, np.asarray(fb)[0])
57
+ results = merge_data_samples(results)
58
+ keypoints = results.pred_instances.keypoints
59
+ face_land_mark= keypoints[0][23:91]
60
+ face_land_mark = face_land_mark.astype(np.int32)
61
+
62
+ # get bounding boxes by face detection
63
+ bbox = fa.get_detections_for_batch(np.asarray(fb))
64
+
65
+ # adjust the bounding box with reference to the landmarks
66
+ # Add the bounding box to a tuple and append it to the coordinates list
67
+ for j, f in enumerate(bbox):
68
+ if f is None: # no face in the image
69
+ coords_list += [coord_placeholder]
70
+ continue
71
+
72
+ half_face_coord = face_land_mark[29]#np.mean([face_land_mark[28], face_land_mark[29]], axis=0)
73
+ range_minus = (face_land_mark[30]- face_land_mark[29])[1]
74
+ range_plus = (face_land_mark[29]- face_land_mark[28])[1]
75
+ average_range_minus.append(range_minus)
76
+ average_range_plus.append(range_plus)
77
+ if upperbondrange != 0:
78
+ half_face_coord[1] = upperbondrange + half_face_coord[1]  # manual adjustment: positive values shift down (towards landmark 29), negative values shift up (towards landmark 28)
79
+
80
+ text_range=f"Total frame:「{len(frames)}」 Manually adjust range : [ -{int(sum(average_range_minus) / len(average_range_minus))}~{int(sum(average_range_plus) / len(average_range_plus))} ] , the current value: {upperbondrange}"
81
+ return text_range
82
+
83
+
84
+ def get_landmark_and_bbox(img_list,upperbondrange =0):
85
+ frames = read_imgs(img_list)
86
+ batch_size_fa = 1
87
+ batches = [frames[i:i + batch_size_fa] for i in range(0, len(frames), batch_size_fa)]
88
+ coords_list = []
89
+ landmarks = []
90
+ if upperbondrange != 0:
91
+ print('get key_landmark and face bounding boxes with the bbox_shift:',upperbondrange)
92
+ else:
93
+ print('get key_landmark and face bounding boxes with the default value')
94
+ average_range_minus = []
95
+ average_range_plus = []
96
+ for fb in tqdm(batches):
97
+ results = inference_topdown(model, np.asarray(fb)[0])
98
+ results = merge_data_samples(results)
99
+ keypoints = results.pred_instances.keypoints
100
+ face_land_mark= keypoints[0][23:91]
101
+ face_land_mark = face_land_mark.astype(np.int32)
102
+
103
+ # get bounding boxes by face detection
104
+ bbox = fa.get_detections_for_batch(np.asarray(fb))
105
+
106
+ # adjust the bounding box with reference to the landmarks
107
+ # Add the bounding box to a tuple and append it to the coordinates list
108
+ for j, f in enumerate(bbox):
109
+ if f is None: # no face in the image
110
+ coords_list += [coord_placeholder]
111
+ continue
112
+
113
+ half_face_coord = face_land_mark[29]#np.mean([face_land_mark[28], face_land_mark[29]], axis=0)
114
+ range_minus = (face_land_mark[30]- face_land_mark[29])[1]
115
+ range_plus = (face_land_mark[29]- face_land_mark[28])[1]
116
+ average_range_minus.append(range_minus)
117
+ average_range_plus.append(range_plus)
118
+ if upperbondrange != 0:
119
+ half_face_coord[1] = upperbondrange + half_face_coord[1]  # manual adjustment: positive values shift down (towards landmark 29), negative values shift up (towards landmark 28)
120
+ half_face_dist = np.max(face_land_mark[:,1]) - half_face_coord[1]
121
+ min_upper_bond = 0
122
+ upper_bond = max(min_upper_bond, half_face_coord[1] - half_face_dist)
123
+
124
+ f_landmark = (np.min(face_land_mark[:, 0]),int(upper_bond),np.max(face_land_mark[:, 0]),np.max(face_land_mark[:,1]))
125
+ x1, y1, x2, y2 = f_landmark
126
+
127
+ if y2-y1<=0 or x2-x1<=0 or x1<0: # if the landmark bbox is not suitable, reuse the bbox
128
+ coords_list += [f]
129
+ w,h = f[2]-f[0], f[3]-f[1]
130
+ print("error bbox:",f)
131
+ else:
132
+ coords_list += [f_landmark]
133
+
134
+ print("********************************************bbox_shift parameter adjustment**********************************************************")
135
+ print(f"Total frame:「{len(frames)}」 Manually adjust range : [ -{int(sum(average_range_minus) / len(average_range_minus))}~{int(sum(average_range_plus) / len(average_range_plus))} ] , the current value: {upperbondrange}")
136
+ print("*************************************************************************************************************************************")
137
+ return coords_list,frames
138
+
139
+
140
+ if __name__ == "__main__":
141
+ img_list = ["./results/lyria/00000.png","./results/lyria/00001.png","./results/lyria/00002.png","./results/lyria/00003.png"]
142
+ crop_coord_path = "./coord_face.pkl"
143
+ coords_list,full_frames = get_landmark_and_bbox(img_list)
144
+ with open(crop_coord_path, 'wb') as f:
145
+ pickle.dump(coords_list, f)
146
+
147
+ for bbox, frame in zip(coords_list,full_frames):
148
+ if bbox == coord_placeholder:
149
+ continue
150
+ x1, y1, x2, y2 = bbox
151
+ crop_frame = frame[y1:y2, x1:x2]
152
+ print('Cropped shape', crop_frame.shape)
153
+
154
+ #cv2.imwrite(path.join(save_dir, '{}.png'.format(i)),full_frames[i][0][y1:y2, x1:x2])
155
+ print(coords_list)
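Usage sketch for the helpers above (the frame folder, the positive bbox_shift value, and the 256x256 crop size are assumptions for illustration). A positive upperbondrange moves the half-face reference point, and hence the top edge of the resulting crop, downwards; a negative value moves it upwards:

import glob
import pickle
import cv2

frame_paths = sorted(glob.glob('./results/my_video/*.png'))   # hypothetical frame folder
coords, frames = get_landmark_and_bbox(frame_paths, upperbondrange=5)
crops = []
for (x1, y1, x2, y2), frame in zip(coords, frames):
    if (x1, y1, x2, y2) == coord_placeholder:                 # no usable face in this frame
        continue
    crops.append(cv2.resize(frame[int(y1):int(y2), int(x1):int(x2)], (256, 256)))
with open('./coord_face.pkl', 'wb') as f:                     # same cache format as the __main__ block above
    pickle.dump(coords, f)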
musetalk/utils/training_utils.py ADDED
@@ -0,0 +1,337 @@
1
+ import os
2
+ import json
3
+ import logging
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.optim as optim
7
+ from torch.optim.lr_scheduler import CosineAnnealingLR
8
+ from diffusers import AutoencoderKL, UNet2DConditionModel
9
+ from transformers import WhisperModel
10
+ from diffusers.optimization import get_scheduler
11
+ from omegaconf import OmegaConf
12
+ from einops import rearrange
13
+
14
+ from musetalk.models.syncnet import SyncNet
15
+ from musetalk.loss.discriminator import MultiScaleDiscriminator, DiscriminatorFullModel
16
+ from musetalk.loss.basic_loss import Interpolate
17
+ import musetalk.loss.vgg_face as vgg_face
18
+ from musetalk.data.dataset import PortraitDataset
19
+ from musetalk.utils.utils import (
20
+ get_image_pred,
21
+ process_audio_features,
22
+ process_and_save_images
23
+ )
24
+
25
+ class Net(nn.Module):
26
+ def __init__(
27
+ self,
28
+ unet: UNet2DConditionModel,
29
+ ):
30
+ super().__init__()
31
+ self.unet = unet
32
+
33
+ def forward(
34
+ self,
35
+ input_latents,
36
+ timesteps,
37
+ audio_prompts,
38
+ ):
39
+ model_pred = self.unet(
40
+ input_latents,
41
+ timesteps,
42
+ encoder_hidden_states=audio_prompts
43
+ ).sample
44
+ return model_pred
45
+
46
+ logger = logging.getLogger(__name__)
47
+
48
+ def initialize_models_and_optimizers(cfg, accelerator, weight_dtype):
49
+ """Initialize models and optimizers"""
50
+ model_dict = {
51
+ 'vae': None,
52
+ 'unet': None,
53
+ 'net': None,
54
+ 'wav2vec': None,
55
+ 'optimizer': None,
56
+ 'lr_scheduler': None,
57
+ 'scheduler_max_steps': None,
58
+ 'trainable_params': None
59
+ }
60
+
61
+ model_dict['vae'] = AutoencoderKL.from_pretrained(
62
+ cfg.pretrained_model_name_or_path,
63
+ subfolder=cfg.vae_type,
64
+ )
65
+
66
+ unet_config_file = os.path.join(
67
+ cfg.pretrained_model_name_or_path,
68
+ cfg.unet_sub_folder + "/musetalk.json"
69
+ )
70
+
71
+ with open(unet_config_file, 'r') as f:
72
+ unet_config = json.load(f)
73
+ model_dict['unet'] = UNet2DConditionModel(**unet_config)
74
+
75
+ if not cfg.random_init_unet:
76
+ pretrained_unet_path = os.path.join(cfg.pretrained_model_name_or_path, cfg.unet_sub_folder, "pytorch_model.bin")
77
+ print(f"### Loading existing unet weights from {pretrained_unet_path}. ###")
78
+ checkpoint = torch.load(pretrained_unet_path, map_location=accelerator.device)
79
+ model_dict['unet'].load_state_dict(checkpoint)
80
+
81
+ unet_params = [p.numel() for n, p in model_dict['unet'].named_parameters()]
82
+ logger.info(f"unet {sum(unet_params) / 1e6}M-parameter")
83
+
84
+ model_dict['vae'].requires_grad_(False)
85
+ model_dict['unet'].requires_grad_(True)
86
+
87
+ model_dict['vae'].to(accelerator.device, dtype=weight_dtype)
88
+
89
+ model_dict['net'] = Net(model_dict['unet'])
90
+
91
+ model_dict['wav2vec'] = WhisperModel.from_pretrained(cfg.whisper_path).to(
92
+ device="cuda", dtype=weight_dtype).eval()
93
+ model_dict['wav2vec'].requires_grad_(False)
94
+
95
+ if cfg.solver.gradient_checkpointing:
96
+ model_dict['unet'].enable_gradient_checkpointing()
97
+
98
+ if cfg.solver.scale_lr:
99
+ learning_rate = (
100
+ cfg.solver.learning_rate
101
+ * cfg.solver.gradient_accumulation_steps
102
+ * cfg.data.train_bs
103
+ * accelerator.num_processes
104
+ )
105
+ else:
106
+ learning_rate = cfg.solver.learning_rate
107
+
108
+ if cfg.solver.use_8bit_adam:
109
+ try:
110
+ import bitsandbytes as bnb
111
+ except ImportError:
112
+ raise ImportError(
113
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
114
+ )
115
+ optimizer_cls = bnb.optim.AdamW8bit
116
+ else:
117
+ optimizer_cls = torch.optim.AdamW
118
+
119
+ model_dict['trainable_params'] = list(filter(lambda p: p.requires_grad, model_dict['net'].parameters()))
120
+ if accelerator.is_main_process:
121
+ print('trainable params')
122
+ for n, p in model_dict['net'].named_parameters():
123
+ if p.requires_grad:
124
+ print(n)
125
+
126
+ model_dict['optimizer'] = optimizer_cls(
127
+ model_dict['trainable_params'],
128
+ lr=learning_rate,
129
+ betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
130
+ weight_decay=cfg.solver.adam_weight_decay,
131
+ eps=cfg.solver.adam_epsilon,
132
+ )
133
+
134
+ model_dict['scheduler_max_steps'] = cfg.solver.max_train_steps * cfg.solver.gradient_accumulation_steps
135
+ model_dict['lr_scheduler'] = get_scheduler(
136
+ cfg.solver.lr_scheduler,
137
+ optimizer=model_dict['optimizer'],
138
+ num_warmup_steps=cfg.solver.lr_warmup_steps * cfg.solver.gradient_accumulation_steps,
139
+ num_training_steps=model_dict['scheduler_max_steps'],
140
+ )
141
+
142
+ return model_dict
143
+
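When cfg.solver.scale_lr is enabled, the base learning rate above is scaled linearly with the global batch size; a small worked example with hypothetical config values:

# hypothetical values, for illustration only
base_lr, grad_accum, train_bs, num_processes = 1e-5, 4, 2, 8
effective_lr = base_lr * grad_accum * train_bs * num_processes   # = 6.4e-4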
144
+ def initialize_dataloaders(cfg):
145
+ """Initialize training and validation dataloaders"""
146
+ dataloader_dict = {
147
+ 'train_dataset': None,
148
+ 'val_dataset': None,
149
+ 'train_dataloader': None,
150
+ 'val_dataloader': None
151
+ }
152
+
153
+ dataloader_dict['train_dataset'] = PortraitDataset(cfg={
154
+ 'image_size': cfg.data.image_size,
155
+ 'T': cfg.data.n_sample_frames,
156
+ "sample_method": cfg.data.sample_method,
157
+ 'top_k_ratio': cfg.data.top_k_ratio,
158
+ "contorl_face_min_size": cfg.data.contorl_face_min_size,
159
+ "dataset_key": cfg.data.dataset_key,
160
+ "padding_pixel_mouth": cfg.padding_pixel_mouth,
161
+ "whisper_path": cfg.whisper_path,
162
+ "min_face_size": cfg.data.min_face_size,
163
+ "cropping_jaw2edge_margin_mean": cfg.cropping_jaw2edge_margin_mean,
164
+ "cropping_jaw2edge_margin_std": cfg.cropping_jaw2edge_margin_std,
165
+ "crop_type": cfg.crop_type,
166
+ "random_margin_method": cfg.random_margin_method,
167
+ })
168
+
169
+ dataloader_dict['train_dataloader'] = torch.utils.data.DataLoader(
170
+ dataloader_dict['train_dataset'],
171
+ batch_size=cfg.data.train_bs,
172
+ shuffle=True,
173
+ num_workers=cfg.data.num_workers,
174
+ )
175
+
176
+ dataloader_dict['val_dataset'] = PortraitDataset(cfg={
177
+ 'image_size': cfg.data.image_size,
178
+ 'T': cfg.data.n_sample_frames,
179
+ "sample_method": cfg.data.sample_method,
180
+ 'top_k_ratio': cfg.data.top_k_ratio,
181
+ "contorl_face_min_size": cfg.data.contorl_face_min_size,
182
+ "dataset_key": cfg.data.dataset_key,
183
+ "padding_pixel_mouth": cfg.padding_pixel_mouth,
184
+ "whisper_path": cfg.whisper_path,
185
+ "min_face_size": cfg.data.min_face_size,
186
+ "cropping_jaw2edge_margin_mean": cfg.cropping_jaw2edge_margin_mean,
187
+ "cropping_jaw2edge_margin_std": cfg.cropping_jaw2edge_margin_std,
188
+ "crop_type": cfg.crop_type,
189
+ "random_margin_method": cfg.random_margin_method,
190
+ })
191
+
192
+ dataloader_dict['val_dataloader'] = torch.utils.data.DataLoader(
193
+ dataloader_dict['val_dataset'],
194
+ batch_size=cfg.data.train_bs,
195
+ shuffle=True,
196
+ num_workers=1,
197
+ )
198
+
199
+ return dataloader_dict
200
+
201
+ def initialize_loss_functions(cfg, accelerator, scheduler_max_steps):
202
+ """Initialize loss functions and discriminators"""
203
+ loss_dict = {
204
+ 'L1_loss': nn.L1Loss(reduction='mean'),
205
+ 'discriminator': None,
206
+ 'mouth_discriminator': None,
207
+ 'optimizer_D': None,
208
+ 'mouth_optimizer_D': None,
209
+ 'scheduler_D': None,
210
+ 'mouth_scheduler_D': None,
211
+ 'disc_scales': None,
212
+ 'discriminator_full': None,
213
+ 'mouth_discriminator_full': None
214
+ }
215
+
216
+ if cfg.loss_params.gan_loss > 0:
217
+ loss_dict['discriminator'] = MultiScaleDiscriminator(
218
+ **cfg.model_params.discriminator_params).to(accelerator.device)
219
+ loss_dict['discriminator_full'] = DiscriminatorFullModel(loss_dict['discriminator'])
220
+ loss_dict['disc_scales'] = cfg.model_params.discriminator_params.scales
221
+ loss_dict['optimizer_D'] = optim.AdamW(
222
+ loss_dict['discriminator'].parameters(),
223
+ lr=cfg.discriminator_train_params.lr,
224
+ weight_decay=cfg.discriminator_train_params.weight_decay,
225
+ betas=cfg.discriminator_train_params.betas,
226
+ eps=cfg.discriminator_train_params.eps)
227
+ loss_dict['scheduler_D'] = CosineAnnealingLR(
228
+ loss_dict['optimizer_D'],
229
+ T_max=scheduler_max_steps,
230
+ eta_min=1e-6
231
+ )
232
+
233
+ if cfg.loss_params.mouth_gan_loss > 0:
234
+ loss_dict['mouth_discriminator'] = MultiScaleDiscriminator(
235
+ **cfg.model_params.discriminator_params).to(accelerator.device)
236
+ loss_dict['mouth_discriminator_full'] = DiscriminatorFullModel(loss_dict['mouth_discriminator'])
237
+ loss_dict['mouth_optimizer_D'] = optim.AdamW(
238
+ loss_dict['mouth_discriminator'].parameters(),
239
+ lr=cfg.discriminator_train_params.lr,
240
+ weight_decay=cfg.discriminator_train_params.weight_decay,
241
+ betas=cfg.discriminator_train_params.betas,
242
+ eps=cfg.discriminator_train_params.eps)
243
+ loss_dict['mouth_scheduler_D'] = CosineAnnealingLR(
244
+ loss_dict['mouth_optimizer_D'],
245
+ T_max=scheduler_max_steps,
246
+ eta_min=1e-6
247
+ )
248
+
249
+ return loss_dict
250
+
251
+ def initialize_syncnet(cfg, accelerator, weight_dtype):
252
+ """Initialize SyncNet model"""
253
+ if cfg.loss_params.sync_loss > 0 or cfg.use_adapted_weight:
254
+ if cfg.data.n_sample_frames != 16:
255
+ raise ValueError(
256
+ f"Invalid n_sample_frames {cfg.data.n_sample_frames} for sync_loss, it should be 16."
257
+ )
258
+ syncnet_config = OmegaConf.load(cfg.syncnet_config_path)
259
+ syncnet = SyncNet(OmegaConf.to_container(
260
+ syncnet_config.model)).to(accelerator.device)
261
+ print(
262
+ f"Load SyncNet checkpoint from: {syncnet_config.ckpt.inference_ckpt_path}")
263
+ checkpoint = torch.load(
264
+ syncnet_config.ckpt.inference_ckpt_path, map_location=accelerator.device)
265
+ syncnet.load_state_dict(checkpoint["state_dict"])
266
+ syncnet.to(dtype=weight_dtype)
267
+ syncnet.requires_grad_(False)
268
+ syncnet.eval()
269
+ return syncnet
270
+ return None
271
+
272
+ def initialize_vgg(cfg, accelerator):
273
+ """Initialize VGG model"""
274
+ if cfg.loss_params.vgg_loss > 0:
275
+ vgg_IN = vgg_face.Vgg19().to(accelerator.device)
276
+ pyramid = vgg_face.ImagePyramide(
277
+ cfg.loss_params.pyramid_scale, 3).to(accelerator.device)
278
+ vgg_IN.eval()
279
+ downsampler = Interpolate(
280
+ size=(224, 224), mode='bilinear', align_corners=False).to(accelerator.device)
281
+ return vgg_IN, pyramid, downsampler
282
+ return None, None, None
283
+
284
+ def validation(
285
+ cfg,
286
+ val_dataloader,
287
+ net,
288
+ vae,
289
+ wav2vec,
290
+ accelerator,
291
+ save_dir,
292
+ global_step,
293
+ weight_dtype,
294
+ syncnet_score=1,
295
+ ):
296
+ """Validation function for model evaluation"""
297
+ net.eval() # Set the model to evaluation mode
298
+ for batch in val_dataloader:
299
+ # The same ref_latents
300
+ ref_pixel_values = batch["pixel_values_ref_img"].to(weight_dtype).to(
301
+ accelerator.device, non_blocking=True
302
+ )
303
+ pixel_values = batch["pixel_values_vid"].to(weight_dtype).to(
304
+ accelerator.device, non_blocking=True
305
+ )
306
+ bsz, num_frames, c, h, w = ref_pixel_values.shape
307
+
308
+ audio_prompts = process_audio_features(cfg, batch, wav2vec, bsz, num_frames, weight_dtype)
309
+ # audio feature for unet
310
+ audio_prompts = rearrange(
311
+ audio_prompts,
312
+ 'b f c h w-> (b f) c h w'
313
+ )
314
+ audio_prompts = rearrange(
315
+ audio_prompts,
316
+ '(b f) c h w -> (b f) (c h) w',
317
+ b=bsz
318
+ )
319
+ # different masked_latents
320
+ image_pred_train = get_image_pred(
321
+ pixel_values, ref_pixel_values, audio_prompts, vae, net, weight_dtype)
322
+ image_pred_infer = get_image_pred(
323
+ ref_pixel_values, ref_pixel_values, audio_prompts, vae, net, weight_dtype)
324
+
325
+ process_and_save_images(
326
+ batch,
327
+ image_pred_train,
328
+ image_pred_infer,
329
+ save_dir,
330
+ global_step,
331
+ accelerator,
332
+ cfg.num_images_to_keep,
333
+ syncnet_score
334
+ )
335
+ # only infer 1 image in validation
336
+ break
337
+ net.train() # Set the model back to training mode
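A condensed sketch of how the initializers above could be wired together in a training entry point (the config path, Accelerator setup, and loop body are assumptions, not this repo's actual script):

from accelerate import Accelerator
from omegaconf import OmegaConf
import torch

cfg = OmegaConf.load('path/to/training_config.yaml')            # placeholder config path
accelerator = Accelerator(gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps)
weight_dtype = torch.float16

models = initialize_models_and_optimizers(cfg, accelerator, weight_dtype)
loaders = initialize_dataloaders(cfg)
losses = initialize_loss_functions(cfg, accelerator, models['scheduler_max_steps'])
syncnet = initialize_syncnet(cfg, accelerator, weight_dtype)
vgg_IN, pyramid, downsampler = initialize_vgg(cfg, accelerator)

net, optimizer, lr_scheduler, train_loader = accelerator.prepare(
    models['net'], models['optimizer'], models['lr_scheduler'], loaders['train_dataloader']
)
# training loop (omitted): forward through net, combine L1 / VGG / GAN / sync losses,
# accelerator.backward(total_loss), optimizer.step(), lr_scheduler.step(),
# and call validation(...) at a cfg-defined step interval.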