Upload pipeline.py
Browse files- pipeline.py +3 -3
pipeline.py
CHANGED
|
@@ -102,9 +102,9 @@ class SuperDiffSDXLPipeline(DiffusionPipeline, ConfigMixin):
|
|
| 102 |
dtype = torch.float16
|
| 103 |
|
| 104 |
vae.to(device)
|
| 105 |
-
unet.to(device)
|
| 106 |
-
text_encoder.to(device)
|
| 107 |
-
text_encoder_2.to(device)
|
| 108 |
|
| 109 |
self.register_modules(
|
| 110 |
unet=unet,
|
|
|
|
| 102 |
dtype = torch.float16
|
| 103 |
|
| 104 |
vae.to(device)
|
| 105 |
+
unet.to(device, dtype=dtype)
|
| 106 |
+
text_encoder.to(device, dtype=dtype)
|
| 107 |
+
text_encoder_2.to(device, dtype=dtype)
|
| 108 |
|
| 109 |
self.register_modules(
|
| 110 |
unet=unet,
|