Abstract

This repository provides a domain-adapted, instruction-tuned Turkish legal model derived from meta-llama/Llama-3.1-8B-Instruct. It corresponds to the BF16 baseline configuration using the default tensorwise quantization scaling recipe, trained on 8 nodes with a global batch size of 32 as part of the “Harnessing Fully Sharded Data Parallelism v2 with Float8 Precision for Faster Training” study. The model was fine-tuned on the newmindai/EuroHPC-Legal corpus (Q/A format) to strengthen reasoning across multiple Turkish legal subdomains. It shows stable convergence, high accuracy, and consistent loss behavior, making it an important baseline for evaluating mixed-precision strategies under large-scale distributed training.

Experiment Context

This model was trained as part of our study comparing FSDP2 with bfloat16 precision against FSDP2 with FP8 mixed precision (bf16-fp8). The base model is meta-llama/Llama-3.1-8B-Instruct. It was loaded with torch_dtype=bfloat16 and wrapped with FSDP2 in a single call, and bfloat16 was also used for computation in the forward and backward passes, as shown below.

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed._composable.fsdp import fully_shard

# Build a 1-D device mesh spanning every rank in the job.
use_cuda = torch.cuda.is_available()
world_size = dist.get_world_size()
mesh_device_type = "cuda" if use_cuda else "cpu"
mesh = DeviceMesh(mesh_device_type, list(range(world_size)))

fsdp_kwargs = {
    "mesh": mesh,
    "reshard_after_forward": True,  # free unsharded params after each forward
}
model = fully_shard(model, **fsdp_kwargs)
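
Here, reshard_after_forward=True makes each sharded module free its gathered (unsharded) parameters as soon as its forward pass finishes and re-gather them during backward, trading an extra all-gather per module for a lower peak memory footprint.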

Base Model Technical Specifications

  • Parameters: 8 Billion
  • Architecture Family: Llama 3.1
  • Maximum Position Embeddings: 131,072
  • Attention Heads: 32 (num_attention_heads)
  • Key-Value Heads: 8 (num_key_value_heads)
  • Hidden Layers: 32 (num_hidden_layers)
  • Hidden Size: 4,096 (hidden_size)
  • Intermediate Size: 14,336
  • Vocabulary Size: 128,256
  • Precision: bfloat16
  • RoPE Scaling: type llama3, factor = 8.0
  • RMS Norm Epsilon: 1e-05
  • Activation: SiLU
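
As a quick way to verify the values above, the base model configuration can be loaded with transformers.AutoConfig; a minimal sketch (assumes access to the gated meta-llama repository on the Hub):

from transformers import AutoConfig

# Load the base model configuration and print the fields listed above.
cfg = AutoConfig.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
print(cfg.num_hidden_layers)        # 32
print(cfg.num_attention_heads)      # 32
print(cfg.num_key_value_heads)      # 8
print(cfg.hidden_size)              # 4096
print(cfg.intermediate_size)        # 14336
print(cfg.max_position_embeddings)  # 131072
print(cfg.rope_scaling)             # llama3 scaling, factor 8.0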

Training Methodology

Training Configuration

  • Model: meta-llama/Llama-3.1-8B-Instruct
  • Sequence Length: 4,096 (seq_len)
  • Epochs: 2
  • Max Steps: 1,200
  • Per-Device Micro Batch Size: 4
  • Gradient Accumulation: 8
  • GPUs per Node: 4 (via CUDA_VISIBLE_DEVICES=0,1,2,3)
  • dtype: bf16 (fp8 = false)
    • Weights: bfloat16
    • Activations: bfloat16
  • Optimizer: AdamW
    • Learning Rate: 2e-5
    • Weight Decay: 0.01
    • Betas: (0.9, 0.95)
    • Epsilon: 1e-8
  • LR Scheduler: Cosine; warmup ratio 0.1 (warmup_ratio=0.1, warmup_steps=100)
  • Max Grad Norm: 1.0
  • Gradient Checkpointing: Enabled
  • Checkpointing: every 10 steps; keep last 5; select best by eval_loss
  • Logging: every step to file; Weights & Biases in offline mode
  • Seed: 100
  • Distributed Training: torch.distributed.run (multi-node, multi-GPU)
    • FSDP2 (Fully Sharded Data Parallel v2)
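
A minimal sketch of how the optimizer, scheduler, and gradient clipping above could be reconstructed in plain PyTorch (model is assumed to be the FSDP2-wrapped module; the actual training loop is not reproduced here):

import torch
from transformers import get_cosine_schedule_with_warmup

# AdamW with the hyperparameters listed above.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=2e-5,
    weight_decay=0.01,
    betas=(0.9, 0.95),
    eps=1e-8,
)

# Cosine schedule: 100 warmup steps (~10% of the 1,200 max steps).
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=100,
    num_training_steps=1200,
)

# Clip gradients to max_grad_norm = 1.0 before each optimizer step.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)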

Setups

  • Precision: half-precision bfloat16 for both parameter storage and computation.
  • Hardware: EuroHPC (BSC-class) cluster, 8 nodes with 4 × NVIDIA H100 GPUs each.
  • Framework: PyTorch with torchrun for distributed training.
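
A rough sketch of the per-process setup when launching with torchrun (the actual launch scripts are not reproduced here):

import os
import torch
import torch.distributed as dist

# torchrun sets RANK, LOCAL_RANK, and WORLD_SIZE for every spawned process.
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)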

Dependencies

Package Version
Transformers 4.57.1
torch 2.9.0+cu128
accelerate 0.14.1
datasets 4.3.0
huggingface-hub 0.36.0
tensorboard 2.20.0
tensorboard-data-server 0.7.2
wandb 0.22.1

Job Details

Model Job ID Runtime (mins) Nodes GPUs/node Node-hours GPU-hours micro-batch batch-size gradient_accumulation total_batch_size
Llama-3.1-8B-Instruct_w16a8_rw 31768103 115.75 1 4 1.929 7.716 2 2 4 32
Llama-3.1-8B-Instruct_w16a8_rw_with_gw_hp 31837629 109.00 1 4 1.816 7.266 2 2 4 32
Llama-3.1-8B-Instruct-w16a8-mxtw 31768031 64.00 1 4 1.066 4.266 2 2 4 32
Llama-3.1-8B-Instruct-w16a16-tw 31768074 138.75 1 4 2.312 9.250 2 2 4 32
Llama-3.1-8B-Instruct-w16a8-1node-bs8 31768093 123.75 1 4 2.062 8.250 2 2 4 32
Llama-3.1-8B-Instruct-w16a16-4nodes-bs32 31478433 31.75 4 4 2.117 8.467 4 4 8 512
Llama-3.1-8B-Instruct-w16a8-4nodes-bs32 31478468 39.75 4 4 2.650 10.600 4 4 8 512
Llama-3.1-8B-Instruct-w16a16-8nodes-bs32 31476914 22.00 8 4 2.933 11.733 4 4 8 1024
Llama-3.1-8B-Instruct-w16a8-8nodes-bs32 31476844 23.50 8 4 3.133 12.533 4 4 8 1024
Llama-3.1-8B-Instruct-w16a16-8nodes-bs64 31476914 22.00 8 4 2.933 11.733 4 8 8 1024
Llama-3.1-8B-Instruct-w16a8-8nodes-bs64 31476844 23.50 8 4 3.133 12.533 4 8 8 1024
Llama-3.1-8B-Instruct-w16a8-rw_4nodes 33477070 39.75 4 4 2.650 10.600 4 4 8 512
Llama-3.1-8B-Instruct-w16a8-rw-8nodes 33476690 23.50 8 4 3.133 12.533 4 4 8 1024
Llama-3.1-8B-Instruct-w16a8-rw_with_gw_hp_4nodes 33477179 37.43 4 4 2.495 9.982 4 4 8 512
Llama-3.1-8B-Instruct-w16a8-rw-with-gw-hp-8nodes 33476618 22.13 8 4 2.951 11.802 4 4 8 1024
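
Node-hours and GPU-hours in the table follow directly from the runtime: for example, the w16a16-8nodes-bs32 run took 22.00 minutes on 8 nodes, i.e. 22.00 / 60 × 8 ≈ 2.933 node-hours and 2.933 × 4 GPUs per node ≈ 11.733 GPU-hours. Likewise, total_batch_size corresponds to micro-batch × gradient_accumulation × total GPUs (e.g., 4 × 8 × 32 GPUs = 1,024 for the 8-node runs).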

All 15 models were trained on 1, 4, and 8 nodes, covering both the bf16 and bf16-fp8 configurations and the different FP8 scaling recipes.

Figures: perplexity, accuracy, training loss, memory allocation, and GPU utilization for the bf16 and bf16-fp8 configurations.
Model Batch Size Max Loss (train) Min Loss (train) Avg Loss (train) Final Loss (train) ± Std (train) Max Loss (val) Min Loss (val) Avg Loss (val) Final Loss (val) ± Std (val)
Llama-3.1-8B-Instruct-w16a8-rw 8 3.1682 0.5740 0.8118 0.6431 0.2746 1.0613 0.8394 0.8937 0.8394
Llama-3.1-8B-Instruct_w16a8_rw_with_gw_hp 8 3.1837 0.5763 0.8116 0.6420 0.2751 1.0599 0.8391 0.8933 0.8391
Llama-3.1-8B-Instruct-w16a8-mxtw 8 3.1983 0.5747 0.8115 0.6446 0.2758 1.0562 0.8384 0.8923 0.8384
Llama-3.1-8B-Instruct-w16a16-tw 8 3.1235 0.7203 0.9750 0.3344 0.7612 1.9113 0.8907 0.9831 0.1897
Llama-3.1-8B-Instruct-w16a8-1node-bs8 8 3.1661 0.7261 0.9804 0.3374 0.7672 1.9230 0.8948 0.9867 0.1906
Llama-3.1-8B-Instruct-w16a16-4nodes-bs32 32 3.2452 0.7414 0.9665 0.4844 0.7504 1.0538 0.8382 0.8844 0.0725
Llama-3.1-8B-Instruct-w16a8-4nodes-bs32 32 3.2840 0.7478 0.9748 0.4905 0.7581 1.0701 0.8430 0.8922 0.0764
Llama-3.1-8B-Instruct-w16a16-8nodes-bs32 32 3.2311 0.8448 1.1856 0.6434 0.8448 1.0257 0.8977 0.9460 0.0568
Llama-3.1-8B-Instruct-w16a8-8nodes-bs32 32 3.3003 0.8473 1.1866 0.6481 0.8473 1.0203 0.8992 0.9445 0.0539
Llama-3.1-8B-Instruct-w16a16-8nodes-bs64 64 3.2311 0.8448 1.1856 0.6434 0.8448 1.0257 0.8977 0.9460 0.0568
Llama-3.1-8B-Instruct-w16a8-8nodes-bs64 64 3.3003 0.8473 1.1866 0.6481 0.8473 1.0203 0.8992 0.9445 0.0539
Llama-3.1-8B-Instruct-w16a8-rw_4nodes 64 3.4517 0.7624 1.1173 0.7624 0.6891 1.3225 0.8791 0.9732 0.8791
Llama-3.1-8B-Instruct-w16a8-rw_8nodes 64 3.8944 0.9583 1.6423 0.9583 1.0117 1.5384 1.0253 1.2103 1.0253
Llama-3.1-8B-Instruct-w16a8-rw_with_gw_hp_4nodes 64 3.4517 0.7481 1.1091 0.7481 0.7021 1.3393 0.8660 0.9641 0.8666
Llama-3.1-8B-Instruct-w16a8-rw_with_gw_hp_8nodes 64 3.9289 0.9702 1.6514 0.9702 1.0127 1.5537 1.0377 1.2222 1.0377
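
The perplexity curves referenced above are presumably derived from the cross-entropy loss; assuming perplexity = exp(loss), a quick sanity check against the table values:

import math

# Assumption: reported perplexity = exp(cross-entropy loss).
final_val_loss = 0.8394          # final validation loss of the w16a8-rw run
print(math.exp(final_val_loss))  # ≈ 2.3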

Implementation

Usage

Note: the final model has been saved in bfloat16 format. For inference, load the model in bfloat16 or float16 as shown below:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model_name = "newmindai/Llama-3.1-8B-Instruct-w16a16-8nodes-bs32"
dtype = torch.bfloat16
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=dtype,
    device_map="auto"
)
prompt = "Soru: Kişisel Verilerin Korunması Kanunu uyarınca hangi durumlarda açık rıza aranmaz? Cevap:"
inputs = tok(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=False
    )

print(tok.decode(out[0], skip_special_tokens=True))

Ethical Considerations and Disclaimers

  • Research & development purposes only; not a substitute for professional legal counsel.
  • Users must ensure compliance with data protection and sector regulations.
  • Potential biases may exist in domain data and model outputs.

Model & Data Card Metadata

  • Total Parameters: 8,030,261,248
  • Serialized Size (approx.): 16,060,522,496 bytes
  • Config precision: bfloat16
  • RoPE: llama3 scaling, factor 8.0
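
The serialized size follows from the parameter count: 8,030,261,248 parameters × 2 bytes per bfloat16 value = 16,060,522,496 bytes.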

References and Citations

Base Model

@misc{meta_llama31_8b_instruct,
  title={Llama 3.1 8B Instruct},
  author={Meta AI},
  year={2024},
  howpublished={\url{https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct}}
}

Training Dataset

@misc{euro_hpc_legal,
  title={EuroHPC-Legal},
  author={newmindai},
  year={2025},
  howpublished={\url{https://huggingface.co/datasets/newmindai/EuroHPC-Legal}}
}