PierrunoYT committed
Commit 50f1efd · 1 Parent(s): 06a685c

perf(model): enable 8-bit quantization and explicit CUDA device targeting


Enable 8-bit model loading to reduce memory usage and specify CUDA device type
for automatic mixed precision operations to improve GPU performance.

Files changed (2):
  1. app.py +1 -1
  2. llava/model/qlinear_te.py +2 -2
app.py CHANGED
@@ -13,7 +13,7 @@ import copy
 MODEL_BASE_SINGLE = snapshot_download(repo_id="nvidia/audio-flamingo-3")
 MODEL_BASE_THINK = os.path.join(MODEL_BASE_SINGLE, 'stage35')
 
-model_single = llava.load(MODEL_BASE_SINGLE, model_base=None)
+model_single = llava.load(MODEL_BASE_SINGLE, model_base=None, load_8bit=True)
 model_single_copy = copy.deepcopy(model_single)
 
 # Move the model to GPU
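For context, `load_8bit=True` follows the usual Hugging Face convention of quantizing the model's linear-layer weights to int8 with bitsandbytes at load time, roughly halving weight memory versus bf16. A minimal sketch of that pattern, assuming a `transformers`-style backend; the model id is a placeholder and the actual internals of `llava.load` may differ:

```python
# Sketch of the typical load_8bit pattern (transformers + bitsandbytes).
# The model id is hypothetical; llava.load's internals may differ.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(load_in_8bit=True)  # store weights as int8

model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-model",            # placeholder model id
    quantization_config=quant_config,
    device_map="auto",                # bitsandbytes layers must live on GPU
    torch_dtype=torch.bfloat16,       # dtype for the non-quantized modules
)
```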
llava/model/qlinear_te.py CHANGED
@@ -98,7 +98,7 @@ class QLinearTE(nn.Linear):
 
 class QuantLinearTE(Function):
     @staticmethod
-    @amp.custom_fwd(cast_inputs=torch.bfloat16)
+    @amp.custom_fwd(cast_inputs=torch.bfloat16, device_type='cuda')
     def forward(ctx, input, weight, bias, args, layer_name):
 
         time_bench = os.getenv("TIME_BENCH")
@@ -149,7 +149,7 @@ class QuantLinearTE(Function):
         return fc_output
 
     @staticmethod
-    @amp.custom_bwd
+    @amp.custom_bwd(device_type='cuda')
     def backward(ctx, grad_output):
         Qinput_t, Iscale, Qweight_t, Wscale, bias, args, layer_name = ctx.saved
 
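The added `device_type` argument matches the device-generic `torch.amp` decorators: PyTorch 2.4 deprecated `torch.cuda.amp.custom_fwd`/`custom_bwd` in favor of `torch.amp.custom_fwd(device_type=...)`/`torch.amp.custom_bwd(device_type=...)`, so `amp` here presumably refers to `torch.amp`. A self-contained sketch of the same decorator pattern on a toy autograd Function (the quantization logic of `QuantLinearTE` is omitted):

```python
# Toy autograd Function using the torch.amp decorators (PyTorch >= 2.4).
import torch
from torch import amp
from torch.autograd import Function

class ToyLinear(Function):
    @staticmethod
    @amp.custom_fwd(cast_inputs=torch.bfloat16, device_type='cuda')
    def forward(ctx, input, weight):
        # Under an active CUDA autocast region, floating-point inputs are
        # cast to bfloat16 before this body runs (with autocast disabled).
        ctx.save_for_backward(input, weight)
        return input @ weight.t()

    @staticmethod
    @amp.custom_bwd(device_type='cuda')
    def backward(ctx, grad_output):
        # custom_bwd runs backward under the same autocast state as forward.
        input, weight = ctx.saved_tensors
        return grad_output @ weight, grad_output.t() @ input

# Usage: out = ToyLinear.apply(x, w) inside a torch.autocast('cuda', ...) block.
```

On PyTorch versions before 2.4, the equivalent decorators are `torch.cuda.amp.custom_fwd`/`custom_bwd`, which do not accept a `device_type` argument.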