Bertoin commited on
Commit
47b2864
·
verified ·
1 Parent(s): 3f4652b

monkey patch _compute_timestep_embedding

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -3,6 +3,7 @@ import numpy as np
3
  import random
4
  import spaces
5
  import torch
 
6
  from diffusers.pipelines.prx import PRXPipeline
7
 
8
  # monkey patch to add 1024 aspect ratios
@@ -62,7 +63,7 @@ def get_timestep_embedding(
62
  emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
63
  return emb
64
 
65
- def _compute_timestep_embedding(timestep: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
66
  return self.time_in(
67
  get_timestep_embedding(
68
  timesteps=timestep,
@@ -111,7 +112,8 @@ pipe = PRXPipeline.from_pretrained(
111
  torch_dtype=dtype
112
  ).to(device)
113
 
114
- pipe.transformer._compute_timestep_embedding = _compute_timestep_embedding
 
115
 
116
  MAX_SEED = np.iinfo(np.int32).max
117
  MAX_IMAGE_SIZE = 1024
 
3
  import random
4
  import spaces
5
  import torch
6
+ import types
7
  from diffusers.pipelines.prx import PRXPipeline
8
 
9
  # monkey patch to add 1024 aspect ratios
 
63
  emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
64
  return emb
65
 
66
+ def _compute_timestep_embedding(self, timestep: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
67
  return self.time_in(
68
  get_timestep_embedding(
69
  timesteps=timestep,
 
112
  torch_dtype=dtype
113
  ).to(device)
114
 
115
+ # Properly bind the method to the instance using types.MethodType
116
+ pipe.transformer._compute_timestep_embedding = types.MethodType(_compute_timestep_embedding, pipe.transformer)
117
 
118
  MAX_SEED = np.iinfo(np.int32).max
119
  MAX_IMAGE_SIZE = 1024