Spaces:
Sleeping
Sleeping
Update model_loading.py
Browse files- model_loading.py +11 -0
model_loading.py
CHANGED
|
@@ -37,12 +37,15 @@ def import_model(modelID, modelType):
|
|
| 37 |
if modelType == 'StableDiffusionXLPipeline':
|
| 38 |
from diffusers import StableDiffusionXLPipeline
|
| 39 |
T2IModel = StableDiffusionXLPipeline.from_pretrained(modelID, torch_dtype=torch.float16)
|
|
|
|
| 40 |
elif modelType == 'LatentConsistencyModelPipeline':
|
| 41 |
from diffusers import DiffusionPipeline
|
| 42 |
T2IModel = DiffusionPipeline.from_pretrained(modelID, torch_dtype=torch.float16)
|
|
|
|
| 43 |
elif modelType == 'StableDiffusion3Pipeline':
|
| 44 |
from diffusers import StableDiffusion3Pipeline
|
| 45 |
T2IModel = StableDiffusion3Pipeline.from_pretrained(modelID, torch_dtype=torch.bfloat16)
|
|
|
|
| 46 |
elif modelType == 'FluxPipeline':
|
| 47 |
from diffusers import FluxPipeline
|
| 48 |
T2IModel = FluxPipeline.from_pretrained(modelID, torch_dtype=torch.bfloat16)
|
|
@@ -50,6 +53,14 @@ def import_model(modelID, modelType):
|
|
| 50 |
else:
|
| 51 |
from diffusers import AutoPipelineForText2Image
|
| 52 |
T2IModel = AutoPipelineForText2Image.from_pretrained(modelID, torch_dtype=torch.float16)
|
|
|
|
| 53 |
|
|
|
|
|
|
|
|
|
|
| 54 |
T2IModel.to("cuda")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
return T2IModel
|
|
|
|
| 37 |
if modelType == 'StableDiffusionXLPipeline':
|
| 38 |
from diffusers import StableDiffusionXLPipeline
|
| 39 |
T2IModel = StableDiffusionXLPipeline.from_pretrained(modelID, torch_dtype=torch.float16)
|
| 40 |
+
T2IModel.to("cuda")
|
| 41 |
elif modelType == 'LatentConsistencyModelPipeline':
|
| 42 |
from diffusers import DiffusionPipeline
|
| 43 |
T2IModel = DiffusionPipeline.from_pretrained(modelID, torch_dtype=torch.float16)
|
| 44 |
+
T2IModel.to("cuda")
|
| 45 |
elif modelType == 'StableDiffusion3Pipeline':
|
| 46 |
from diffusers import StableDiffusion3Pipeline
|
| 47 |
T2IModel = StableDiffusion3Pipeline.from_pretrained(modelID, torch_dtype=torch.bfloat16)
|
| 48 |
+
T2IModel.to("cuda")
|
| 49 |
elif modelType == 'FluxPipeline':
|
| 50 |
from diffusers import FluxPipeline
|
| 51 |
T2IModel = FluxPipeline.from_pretrained(modelID, torch_dtype=torch.bfloat16)
|
|
|
|
| 53 |
else:
|
| 54 |
from diffusers import AutoPipelineForText2Image
|
| 55 |
T2IModel = AutoPipelineForText2Image.from_pretrained(modelID, torch_dtype=torch.float16)
|
| 56 |
+
T2IModel.to("cuda")
|
| 57 |
|
| 58 |
+
if 'StableDiffusionXLPipeline' in modelType.split(','):
|
| 59 |
+
from diffusers import StableDiffusionXLPipeline
|
| 60 |
+
T2IModel = StableDiffusionXLPipeline.from_pretrained(modelID, torch_dtype=torch.float16)
|
| 61 |
T2IModel.to("cuda")
|
| 62 |
+
try:
|
| 63 |
+
T2IModel.safety_checker = None
|
| 64 |
+
except:
|
| 65 |
+
pass # if the model does not contain a safety checker no need to remove it
|
| 66 |
return T2IModel
|