Ali Mosavian
Ali Mosavian
committed on
FIX: TRL trainer preprocessing step was running in one process (#1583)
Browse files

* FIX: TRL trainer preprocessing step was running in one process
* FIX: Changed so that dataset_num_proc is sent to CPO, KTO and ORPO trainer args and directly to the trainer when DPO
* FIX: Changed back to only support ORPO for now, since KTO is handled in another way
---------
Co-authored-by: Ali Mosavian <ali.mosavian@kry.se>
src/axolotl/core/trainer_builder.py
CHANGED
|
@@ -1462,6 +1462,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|
| 1462 |
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
|
| 1463 |
else:
|
| 1464 |
training_args_kwargs["evaluation_strategy"] = "no"
|
|
|
|
| 1465 |
if self.cfg.bf16 or self.cfg.bfloat16:
|
| 1466 |
training_args_kwargs["bf16"] = True
|
| 1467 |
|
|
@@ -1520,6 +1521,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|
| 1520 |
training_args_cls = TrainingArguments
|
| 1521 |
if self.cfg.rl == "orpo":
|
| 1522 |
training_args_cls = ORPOConfig
|
|
|
|
| 1523 |
|
| 1524 |
training_args = training_args_cls(
|
| 1525 |
per_device_train_batch_size=self.cfg.micro_batch_size,
|
|
@@ -1564,6 +1566,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|
| 1564 |
dpo_trainer_kwargs["max_target_length"] = None
|
| 1565 |
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
| 1566 |
dpo_trainer_kwargs["generate_during_eval"] = True
|
|
|
|
|
|
|
| 1567 |
elif self.cfg.rl == "orpo":
|
| 1568 |
trainer_cls = AxolotlORPOTrainer
|
| 1569 |
trainer_cls_args = [self.model]
|
|
|
|
| 1462 |
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
|
| 1463 |
else:
|
| 1464 |
training_args_kwargs["evaluation_strategy"] = "no"
|
| 1465 |
+
|
| 1466 |
if self.cfg.bf16 or self.cfg.bfloat16:
|
| 1467 |
training_args_kwargs["bf16"] = True
|
| 1468 |
|
|
|
|
| 1521 |
training_args_cls = TrainingArguments
|
| 1522 |
if self.cfg.rl == "orpo":
|
| 1523 |
training_args_cls = ORPOConfig
|
| 1524 |
+
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
| 1525 |
|
| 1526 |
training_args = training_args_cls(
|
| 1527 |
per_device_train_batch_size=self.cfg.micro_batch_size,
|
|
|
|
| 1566 |
dpo_trainer_kwargs["max_target_length"] = None
|
| 1567 |
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
| 1568 |
dpo_trainer_kwargs["generate_during_eval"] = True
|
| 1569 |
+
if self.cfg.rl == "dpo":
|
| 1570 |
+
dpo_trainer_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
| 1571 |
elif self.cfg.rl == "orpo":
|
| 1572 |
trainer_cls = AxolotlORPOTrainer
|
| 1573 |
trainer_cls_args = [self.model]
|