fix device issue during calibration (#2100)
Signed-off-by: Xin He <[email protected]>
xin3he authored Dec 25, 2024
1 parent fa8ad83 commit a02dcc1
Showing 1 changed file with 6 additions and 4 deletions.
@@ -403,11 +403,12 @@ def calib_func(prepared_model):
             max_seq_length=args.gptq_max_seq_length,
         )
         dataloader_for_calibration = dataloaderPreprocessor.get_prepared_dataloader()
-        from neural_compressor.torch.algorithms.weight_only.utility import move_input_to_device
+        from neural_compressor.torch.utils import get_model_device, move_input_device
         from tqdm import tqdm
         def run_fn_for_gptq(model, dataloader_for_calibration, *args):
             for batch in tqdm(dataloader_for_calibration):
-                batch = move_input_to_device(batch, device=None)
+                device = get_model_device(model)
+                batch = move_input_device(batch, device=device)
                 if isinstance(batch, tuple) or isinstance(batch, list):
                     model(batch[0])
                 elif isinstance(batch, dict):
@@ -525,11 +526,12 @@ def run_fn_for_autoround(model, dataloader):
             )
             dataloader = dataloaderPreprocessor.get_prepared_dataloader()
             custom_tune_config = TuningConfig(config_set=get_woq_tuning_config())
-            from neural_compressor.torch.algorithms.weight_only.utility import move_input_to_device
+            from neural_compressor.torch.utils import get_model_device, move_input_device
             from tqdm import tqdm
             def run_fn_for_gptq(model, dataloader_for_calibration, *args):
                 for batch in tqdm(dataloader_for_calibration):
-                    batch = move_input_to_device(batch, device=None)
+                    device = get_model_device(model)
+                    batch = move_input_device(batch, device=device)
                     if isinstance(batch, tuple) or isinstance(batch, list):
                         model(batch[0])
                     elif isinstance(batch, dict):
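The fix is the same in both hunks: instead of calling move_input_to_device(batch, device=None), the calibration function now asks the model where it actually lives (get_model_device(model)) and moves each batch to that device before the forward pass, so calibration no longer breaks when the model sits on an accelerator while the dataloader yields CPU tensors. Below is a minimal, self-contained sketch of that pattern in plain PyTorch. The *_sketch helpers are hypothetical stand-ins for neural_compressor.torch.utils.get_model_device / move_input_device (only the call signatures visible in the diff are taken from the source), and the model(**batch) call in the dict branch is an assumption about the part of the example script the diff truncates.

    # Minimal sketch of the fixed calibration pattern, assuming plain PyTorch.
    # The *_sketch helpers are hypothetical stand-ins for the library utilities
    # named in the diff; only their call signatures are taken from the source.
    import torch
    from tqdm import tqdm


    def get_model_device_sketch(model: torch.nn.Module) -> torch.device:
        # Assumption: all parameters share one device; read it from the first one.
        return next(model.parameters()).device


    def move_input_device_sketch(batch, device: torch.device):
        # Recursively move tensors in a tensor / tuple / list / dict batch to `device`;
        # anything else is returned unchanged.
        if isinstance(batch, torch.Tensor):
            return batch.to(device)
        if isinstance(batch, (tuple, list)):
            return type(batch)(move_input_device_sketch(x, device) for x in batch)
        if isinstance(batch, dict):
            return {k: move_input_device_sketch(v, device) for k, v in batch.items()}
        return batch


    def run_fn_for_gptq_sketch(model, dataloader_for_calibration, *args):
        # Mirrors the fixed run_fn_for_gptq: resolve the model's device, move the
        # batch onto it, then run the calibration forward pass.
        for batch in tqdm(dataloader_for_calibration):
            device = get_model_device_sketch(model)
            batch = move_input_device_sketch(batch, device=device)
            if isinstance(batch, (tuple, list)):
                model(batch[0])
            elif isinstance(batch, dict):
                model(**batch)  # assumption: the truncated dict branch unpacks the batch
            else:
                model(batch)

Resolving the device from the model's own parameters, rather than hard-coding a device or passing None, keeps the same calibration loop working whether the model is on CPU, CUDA, or another accelerator backend.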
