
Base model v2: gemma_3_800M_base_v2_multilingual_10B_data

June 23

Base model trained on 10B kk, en, and ru data.

Inference example

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"  # set before CUDA is initialized

import torch
from transformers import AutoTokenizer, Gemma3ForCausalLM

# Load your trained model
model_path = "SRP-base-model-training/gemma_3_800M_base_v2_multilingual_10B_data"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = Gemma3ForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

# example = {"system": "Вы профессиональный переводчик. Переведите следующее предложение на қазақ язык.", "user": "<src=ru><tgt=kk>\nЗа один год с тех пор какие изменения произошли в Туркестане, какое дело доведено до конца?", "assistant": "Содан бергі бір жыл ішінде Түркістанда қандай өзгерістер болды, нендей іс тындырылды?"}
# example = {"system": "Сіз кәсіби аудармашысыз. Төмендегі сөйлемді English тіліне аударыңыз.", "user": "<src=kk><tgt=en>\nСауда-саттықта салқынқандылық басым.", "assistant": "Composure prevails in trade."}
example = {"system": "Сіз кәсіби аудармашысыз. Төмендегі сөйлемді English тіліне аударыңыз.", "user": "<src=kk><tgt=en>\nқала картасы", "assistant": "city map"}
s = example["system"]
u = example["user"]
a = example["assistant"]

tok = tokenizer
# Prompt in chat format
prompt = (
    f"<start_of_turn>system\n{s}<end_of_turn>\n"
    f"<start_of_turn>user\n{u}<end_of_turn>\n"
    f"<start_of_turn>assistant"
)

model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
input_len = model_inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(
        **model_inputs,
        max_new_tokens=64,
        do_sample=True,
        top_p=0.9,
        # temperature=0.7,
        # repetition_penalty=1.2,
        eos_token_id=tok.convert_tokens_to_ids("<end_of_turn>"),
        pad_token_id=tok.eos_token_id,
        # min_new_tokens=5,
    )
    generation = generation[0][input_len:]

decoded = tokenizer.decode(generation, skip_special_tokens=True)
print(decoded)
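
Below is a minimal sketch (not from the original card) of batched inference with the same prompt format, assuming the model and tokenizer loaded above; the sentences and system prompt are illustrative. Left padding is needed for batched generation with a decoder-only model.

# Batched-inference sketch; assumes `model` and `tokenizer` are already loaded.
sentences = ["қала картасы", "Сауда-саттықта салқынқандылық басым."]
system = "Сіз кәсіби аудармашысыз. Төмендегі сөйлемді English тіліне аударыңыз."

prompts = [
    f"<start_of_turn>system\n{system}<end_of_turn>\n"
    f"<start_of_turn>user\n<src=kk><tgt=en>\n{s}<end_of_turn>\n"
    f"<start_of_turn>assistant"
    for s in sentences
]

tokenizer.padding_side = "left"          # left padding for decoder-only generation
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

batch = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
with torch.inference_mode():
    out = model.generate(
        **batch,
        max_new_tokens=64,
        do_sample=False,  # greedy decoding for a deterministic comparison
        eos_token_id=tokenizer.convert_tokens_to_ids("<end_of_turn>"),
        pad_token_id=tokenizer.pad_token_id,
    )

# Strip the prompt tokens and decode only the generated continuations
prompt_len = batch["input_ids"].shape[-1]
for src, seq in zip(sentences, out):
    print(src, "->", tokenizer.decode(seq[prompt_len:], skip_special_tokens=True))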

Training

Main training script:

# gemma_pretrain_mix_cli.py  – balance 50 % KK, 30 % RU, 20 % EN

import os, math, json, argparse
from pathlib import Path
from datasets import (load_dataset, concatenate_datasets,
                      disable_caching)
from transformers import (AutoTokenizer, Gemma3TextConfig,
                          Gemma3ForCausalLM,
                          DataCollatorForLanguageModeling)
from trl import SFTTrainer, SFTConfig

disable_caching()

# ────────── CLI ──────────
parser = argparse.ArgumentParser()
parser.add_argument("--tokenizer_path", required=True)
parser.add_argument("--meta_files", nargs=3, required=True,
                    metavar=("META_KK", "META_RU", "META_EN"),
                    help="пути к meta_*.json в порядке kk ru en")
parser.add_argument("--output_dir", default="runs/gemma_mix_50_30_20")
parser.add_argument("--model_path")
parser.add_argument("--max_seq_length", type=int, default=2048)
parser.add_argument("--per_device_batch_size", type=int, default=32)
parser.add_argument("--gradient_accumulation_steps", type=int, default=8)
parser.add_argument("--learning_rate", type=float, default=3e-4)
parser.add_argument("--wandb_project", default="gemma-pretrain")
parser.add_argument("--wandb_run_name")
args = parser.parse_args()

cpu = os.cpu_count()
os.environ["WANDB_PROJECT"]          = args.wandb_project
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# ────────── Tokenizer / Model ──────────
tok = AutoTokenizer.from_pretrained(args.tokenizer_path, use_fast=True)

if args.model_path:
    model = Gemma3ForCausalLM.from_pretrained(
        args.model_path, torch_dtype="bfloat16", _attn_implementation="eager")
else:
    # TODO: verify these config values for the 800M model before training from scratch
    cfg = Gemma3TextConfig(
        vocab_size=len(tok),
        bos_token_id=tok.bos_token_id, eos_token_id=tok.eos_token_id, pad_token_id=tok.pad_token_id,
        hidden_size=2304, num_hidden_layers=26, num_attention_heads=4, head_dim=256,
        intermediate_size=9216, max_position_embeddings=32_768,
        torch_dtype="bfloat16", _attn_implementation="eager")
    model = Gemma3ForCausalLM(cfg)
    model.resize_token_embeddings(len(tok))

# ────────── Load helper ──────────
def load_meta(path: str):
    meta = json.load(open(path))
    return concatenate_datasets(
        [load_dataset("json", data_files=i["path"], split="train")
         for i in meta.values()]
    )

kk_ds, ru_ds, en_ds = [load_meta(p) for p in args.meta_files]
print(f"Raw rows — KK={len(kk_ds):,}, RU={len(ru_ds):,}, EN={len(en_ds):,}")

# ────────── Target sizes 50 / 30 / 20 ──────────
target_total = int(len(kk_ds) / 0.50)      # kk = 50 %
need_ru = int(target_total * 0.30)
need_en = int(target_total * 0.20)

def resize(ds, need):
    if len(ds) >= need:                       # down-sample
        return ds.shuffle(seed=42).select(range(need))
    reps  = need // len(ds) + 1               # up-sample
    big   = concatenate_datasets([ds] * reps).shuffle(seed=42)
    return big.select(range(need))

ru_ds = resize(ru_ds, need_ru)
en_ds = resize(en_ds, need_en)
print(f"Balanced rows — KK={len(kk_ds):,}, RU={len(ru_ds):,}, EN={len(en_ds):,}")

# ────────── Merge & preprocess ──────────
ds = concatenate_datasets([kk_ds, ru_ds, en_ds]).shuffle(seed=42)

def add_bos_eos(ex):
    return {"text": f"{tok.bos_token}{ex['text']}{tok.eos_token}"}
ds = ds.map(add_bos_eos, num_proc=cpu)

# ────────── Training params ──────────
world  = int(os.getenv("WORLD_SIZE", 1))
eff_bs = args.per_device_batch_size * args.gradient_accumulation_steps * world
max_st = math.ceil(len(ds) / eff_bs)
print(f"Dataset={len(ds):,}  eff_batch={eff_bs}  max_steps={max_st}")

collator = DataCollatorForLanguageModeling(tok, mlm=False)
cfg_t = SFTConfig(
    output_dir=args.output_dir,
    max_seq_length=args.max_seq_length,
    packing=True, bf16=True,
    per_device_train_batch_size=args.per_device_batch_size,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    learning_rate=args.learning_rate,
    warmup_ratio=0.05,
    max_grad_norm=2.0,
    max_steps=max_st,
    lr_scheduler_type="cosine",
    optim="paged_adamw_8bit",
    save_steps=200, save_total_limit=20,
    logging_steps=1,
    deepspeed="ds_stage1.json",
    run_name=args.wandb_run_name,
    report_to="wandb",
    dataloader_num_workers=8,
    dataset_text_field="text",
    dataset_num_proc=cpu,
)

trainer = SFTTrainer(model=model, args=cfg_t,
                     train_dataset=ds, data_collator=collator,
                     processing_class=tok, formatting_func=None)

if __name__ == "__main__":
    print("🚀 Start pre-training 50/30/20")
    trainer.train()
    trainer.save_model(f"{args.output_dir}/checkpoint-final")
    tok.save_pretrained(f"{args.output_dir}/checkpoint-final")

To launch training, use a bash script similar to the one below:

#!/bin/bash

export TRITON_CACHE_DIR=/scratch/vladimir_albrekht/projects/smollm/trl_italian_apporach/utils/cache/.triton
mkdir -p "$TRITON_CACHE_DIR"

export WANDB_API_KEY=""

OUTPUT_DIR='/scratch/vladimir_albrekht/projects/smollm/output_checkpoints/test_2_multiling'
WANDB_RUN_NAME='base-model-v1_gemma_1B_test_v2_with_kk_en_ru'
if [ ! -d "$OUTPUT_DIR" ]; then
  mkdir -p "$OUTPUT_DIR"
fi

# --model_path "/scratch/vladimir_albrekht/projects/smollm/trl_italian_apporach/runs/my_experiment/checkpoint-final" \

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
torchrun --standalone --nproc_per_node 8 base_train_v2_multi.py \
  --tokenizer_path "/scratch/vladimir_albrekht/projects/smollm/models/tokenizers/tok_best_version_50_000_vocab_abai_20_june" \
  --max_seq_length 2048 \
  --meta_files \
      /scratch/vladimir_albrekht/projects/smollm/data/base_train_dataset_04_06_2025/meta_kk.json \
      /scratch/vladimir_albrekht/projects/smollm/data/base_train_dataset_04_06_2025/meta_ru.json \
      /scratch/vladimir_albrekht/projects/smollm/data/base_train_dataset_04_06_2025/meta_en.json \
  --per_device_batch_size 32 \
  --gradient_accumulation_steps 8 \
  --learning_rate 3e-4 \
  --output_dir ${OUTPUT_DIR} \
  --wandb_project "small_llm_SRP" \
  --wandb_run_name ${WANDB_RUN_NAME}

The meta files have the following format:

  "train_en_news_cleaned_v2_splited_processed.jsonl": {
    "path": "/scratch/vladimir_albrekht/projects/smollm/data/base_train_dataset_04_06_2025/en_data/train.jsonl",
    "examples": 268890,
    "tokens": 92970273
  },
    "train_en_news_cleaned_v2_splited_processed_2.jsonl": {
    "path": "/scratch/vladimir_albrekht/projects/smollm/data/base_train_dataset_04_06_2025/en_data/train_2.jsonl",
    "examples": 268123,
    "tokens": 64523423
  }
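
A hypothetical helper (not part of the original repository) that could produce a meta file in this format from a directory of JSONL shards; it assumes each row has a "text" field, matching dataset_text_field="text" in the training script.

import json
from pathlib import Path
from transformers import AutoTokenizer

def build_meta(jsonl_dir: str, tokenizer_path: str, out_path: str) -> None:
    """Scan *.jsonl shards and write one meta entry per shard."""
    tok = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
    meta = {}
    for shard in sorted(Path(jsonl_dir).glob("*.jsonl")):
        examples, tokens = 0, 0
        with shard.open(encoding="utf-8") as f:
            for line in f:
                text = json.loads(line)["text"]  # assumes a "text" field per row
                examples += 1
                tokens += len(tok(text, add_special_tokens=False)["input_ids"])
        meta[shard.name] = {
            "path": str(shard.resolve()),
            "examples": examples,
            "tokens": tokens,
        }
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, ensure_ascii=False, indent=2)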

Note: checkpoint path: /scratch/vladimir_albrekht/projects/smollm/output_checkpoints/test_2_multiling/checkpoint-1978
