Spaces:
Sleeping
Sleeping
Update model_loading.py
Browse files- model_loading.py +2 -0
model_loading.py
CHANGED
|
@@ -46,8 +46,10 @@ def import_model(modelID, modelType):
|
|
| 46 |
elif modelType == 'FluxPipeline':
|
| 47 |
from diffusers import FluxPipeline
|
| 48 |
T2IModel = FluxPipeline.from_pretrained(modelID, torch_dtype=torch.bfloat16)
|
|
|
|
| 49 |
else:
|
| 50 |
from diffusers import AutoPipelineForText2Image
|
| 51 |
T2IModel = AutoPipelineForText2Image.from_pretrained(modelID, torch_dtype=torch.float16)
|
|
|
|
| 52 |
T2IModel.to("cuda")
|
| 53 |
return T2IModel
|
|
|
|
| 46 |
elif modelType == 'FluxPipeline':
|
| 47 |
from diffusers import FluxPipeline
|
| 48 |
T2IModel = FluxPipeline.from_pretrained(modelID, torch_dtype=torch.bfloat16)
|
| 49 |
+
T2IModel.enable_model_cpu_offload()
|
| 50 |
else:
|
| 51 |
from diffusers import AutoPipelineForText2Image
|
| 52 |
T2IModel = AutoPipelineForText2Image.from_pretrained(modelID, torch_dtype=torch.float16)
|
| 53 |
+
|
| 54 |
T2IModel.to("cuda")
|
| 55 |
return T2IModel
|