Spaces:
Running
Running
IceClear
commited on
Commit
·
af83433
1
Parent(s):
63837ca
update
Browse files
projects/video_diffusion_sr/infer.py
CHANGED
|
@@ -26,14 +26,14 @@ from common.diffusion import (
|
|
| 26 |
create_sampling_timesteps_from_config,
|
| 27 |
create_schedule_from_config,
|
| 28 |
)
|
| 29 |
-
from common.distributed import (
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
)
|
| 33 |
-
|
| 34 |
-
from common.distributed.meta_init_utils import (
|
| 35 |
-
|
| 36 |
-
)
|
| 37 |
# from common.fs import download
|
| 38 |
|
| 39 |
from models.dit_v2 import na
|
|
@@ -69,7 +69,7 @@ class VideoDiffusionInfer():
|
|
| 69 |
raise NotImplementedError
|
| 70 |
|
| 71 |
# @log_on_entry
|
| 72 |
-
@log_runtime
|
| 73 |
def configure_dit_model(self, device="cpu", checkpoint=None):
|
| 74 |
# Load dit checkpoint.
|
| 75 |
# For fast init & resume,
|
|
@@ -90,7 +90,7 @@ class VideoDiffusionInfer():
|
|
| 90 |
loading_info = self.dit.load_state_dict(state, strict=True, assign=True)
|
| 91 |
print(f"Loading pretrained ckpt from {checkpoint}")
|
| 92 |
print(f"Loading info: {loading_info}")
|
| 93 |
-
self.dit = meta_non_persistent_buffer_init_fn(self.dit)
|
| 94 |
|
| 95 |
# if device in [get_device(), "cuda"]:
|
| 96 |
self.dit.to("cuda")
|
|
@@ -100,7 +100,7 @@ class VideoDiffusionInfer():
|
|
| 100 |
print(f"DiT trainable parameters: {num_params:,}")
|
| 101 |
|
| 102 |
# @log_on_entry
|
| 103 |
-
@log_runtime
|
| 104 |
def configure_vae_model(self):
|
| 105 |
# Create vae model.
|
| 106 |
dtype = getattr(torch, self.config.vae.dtype)
|
|
|
|
| 26 |
create_sampling_timesteps_from_config,
|
| 27 |
create_schedule_from_config,
|
| 28 |
)
|
| 29 |
+
# from common.distributed import (
|
| 30 |
+
# get_device,
|
| 31 |
+
# get_global_rank,
|
| 32 |
+
# )
|
| 33 |
+
|
| 34 |
+
# from common.distributed.meta_init_utils import (
|
| 35 |
+
# meta_non_persistent_buffer_init_fn,
|
| 36 |
+
# )
|
| 37 |
# from common.fs import download
|
| 38 |
|
| 39 |
from models.dit_v2 import na
|
|
|
|
| 69 |
raise NotImplementedError
|
| 70 |
|
| 71 |
# @log_on_entry
|
| 72 |
+
# @log_runtime
|
| 73 |
def configure_dit_model(self, device="cpu", checkpoint=None):
|
| 74 |
# Load dit checkpoint.
|
| 75 |
# For fast init & resume,
|
|
|
|
| 90 |
loading_info = self.dit.load_state_dict(state, strict=True, assign=True)
|
| 91 |
print(f"Loading pretrained ckpt from {checkpoint}")
|
| 92 |
print(f"Loading info: {loading_info}")
|
| 93 |
+
# self.dit = meta_non_persistent_buffer_init_fn(self.dit)
|
| 94 |
|
| 95 |
# if device in [get_device(), "cuda"]:
|
| 96 |
self.dit.to("cuda")
|
|
|
|
| 100 |
print(f"DiT trainable parameters: {num_params:,}")
|
| 101 |
|
| 102 |
# @log_on_entry
|
| 103 |
+
# @log_runtime
|
| 104 |
def configure_vae_model(self):
|
| 105 |
# Create vae model.
|
| 106 |
dtype = getattr(torch, self.config.vae.dtype)
|