Spaces:
Sleeping
Sleeping
Update model_loading.py
Browse filesAdded support for 'StableDiffusion3Pipeline' and 'FluxPipeline' models
- model_loading.py +7 -1
model_loading.py
CHANGED
|
@@ -8,7 +8,7 @@ else:
|
|
| 8 |
device = 'cpu'
|
| 9 |
|
| 10 |
validT2IModelTypes = ["KandinskyPipeline", "StableDiffusionPipeline", "DiffusionPipeline", "StableDiffusionXLPipeline",
|
| 11 |
-
"LatentConsistencyModelPipeline"]
|
| 12 |
def check_if_model_exists(repoName):
|
| 13 |
modelLoaded = None
|
| 14 |
huggingFaceURL = "https://huggingface.co/" + repoName + "/raw/main/model_index.json"
|
|
@@ -40,6 +40,12 @@ def import_model(modelID, modelType):
|
|
| 40 |
elif modelType == 'LatentConsistencyModelPipeline':
|
| 41 |
from diffusers import DiffusionPipeline
|
| 42 |
T2IModel = DiffusionPipeline.from_pretrained(modelID, torch_dtype=torch.float16)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
else:
|
| 44 |
from diffusers import AutoPipelineForText2Image
|
| 45 |
T2IModel = AutoPipelineForText2Image.from_pretrained(modelID, torch_dtype=torch.float16)
|
|
|
|
| 8 |
device = 'cpu'
|
| 9 |
|
| 10 |
validT2IModelTypes = ["KandinskyPipeline", "StableDiffusionPipeline", "DiffusionPipeline", "StableDiffusionXLPipeline",
|
| 11 |
+
"LatentConsistencyModelPipeline","StableDiffusion3Pipeline", "FluxPipeline"]
|
| 12 |
def check_if_model_exists(repoName):
|
| 13 |
modelLoaded = None
|
| 14 |
huggingFaceURL = "https://huggingface.co/" + repoName + "/raw/main/model_index.json"
|
|
|
|
| 40 |
elif modelType == 'LatentConsistencyModelPipeline':
|
| 41 |
from diffusers import DiffusionPipeline
|
| 42 |
T2IModel = DiffusionPipeline.from_pretrained(modelID, torch_dtype=torch.float16)
|
| 43 |
+
elif modelType == 'StableDiffusion3Pipeline':
|
| 44 |
+
from diffusers import StableDiffusion3Pipeline
|
| 45 |
+
T2IModel = StableDiffusion3Pipeline.from_pretrained(modelID, torch_dtype=torch.bfloat16)
|
| 46 |
+
elif modelType == 'FluxPipeline':
|
| 47 |
+
from diffusers import FluxPipeline
|
| 48 |
+
T2IModel = FluxPipeline.from_pretrained(modelID, torch_dtype=torch.bfloat16)
|
| 49 |
else:
|
| 50 |
from diffusers import AutoPipelineForText2Image
|
| 51 |
T2IModel = AutoPipelineForText2Image.from_pretrained(modelID, torch_dtype=torch.float16)
|