Spaces:
Runtime error
Runtime error
Commit
·
50f1efd
1
Parent(s):
06a685c
perf(model): enable 8-bit quantization and explicit CUDA device targeting
Browse files
Enable 8-bit model loading to reduce memory usage and specify the CUDA device type
for automatic mixed precision operations to improve GPU performance.
- app.py +1 -1
- llava/model/qlinear_te.py +2 -2
app.py
CHANGED
|
@@ -13,7 +13,7 @@ import copy
|
|
| 13 |
MODEL_BASE_SINGLE = snapshot_download(repo_id="nvidia/audio-flamingo-3")
|
| 14 |
MODEL_BASE_THINK = os.path.join(MODEL_BASE_SINGLE, 'stage35')
|
| 15 |
|
| 16 |
-
model_single = llava.load(MODEL_BASE_SINGLE, model_base=None)
|
| 17 |
model_single_copy = copy.deepcopy(model_single)
|
| 18 |
|
| 19 |
# Move the model to GPU
|
|
|
|
| 13 |
MODEL_BASE_SINGLE = snapshot_download(repo_id="nvidia/audio-flamingo-3")
|
| 14 |
MODEL_BASE_THINK = os.path.join(MODEL_BASE_SINGLE, 'stage35')
|
| 15 |
|
| 16 |
+
model_single = llava.load(MODEL_BASE_SINGLE, model_base=None, load_8bit=True)
|
| 17 |
model_single_copy = copy.deepcopy(model_single)
|
| 18 |
|
| 19 |
# Move the model to GPU
|
llava/model/qlinear_te.py
CHANGED
|
@@ -98,7 +98,7 @@ class QLinearTE(nn.Linear):
|
|
| 98 |
|
| 99 |
class QuantLinearTE(Function):
|
| 100 |
@staticmethod
|
| 101 |
-
@amp.custom_fwd(cast_inputs=torch.bfloat16)
|
| 102 |
def forward(ctx, input, weight, bias, args, layer_name):
|
| 103 |
|
| 104 |
time_bench = os.getenv("TIME_BENCH")
|
|
@@ -149,7 +149,7 @@ class QuantLinearTE(Function):
|
|
| 149 |
return fc_output
|
| 150 |
|
| 151 |
@staticmethod
|
| 152 |
-
@amp.custom_bwd
|
| 153 |
def backward(ctx, grad_output):
|
| 154 |
Qinput_t, Iscale, Qweight_t, Wscale, bias, args, layer_name = ctx.saved_tensors
|
| 155 |
|
|
|
|
| 98 |
|
| 99 |
class QuantLinearTE(Function):
|
| 100 |
@staticmethod
|
| 101 |
+
@amp.custom_fwd(cast_inputs=torch.bfloat16, device_type='cuda')
|
| 102 |
def forward(ctx, input, weight, bias, args, layer_name):
|
| 103 |
|
| 104 |
time_bench = os.getenv("TIME_BENCH")
|
|
|
|
| 149 |
return fc_output
|
| 150 |
|
| 151 |
@staticmethod
|
| 152 |
+
@amp.custom_bwd(device_type='cuda')
|
| 153 |
def backward(ctx, grad_output):
|
| 154 |
Qinput_t, Iscale, Qweight_t, Wscale, bias, args, layer_name = ctx.saved_tensors
|
| 155 |
|