Quickstart

How to train a model from scratch

Step 1: Prepare the data

To get started, download a toy English-German machine translation dataset containing 10k tokenized sentences:

wget https://s3.amazonaws.com/opennmt-trainingdata/toy-ende.tar.gz
tar xf toy-ende.tar.gz

The data consists of parallel source (src) and target (tgt) data containing one sentence per line with tokens separated by a space:

  • src-train.txt

  • tgt-train.txt

  • src-val.txt

  • tgt-val.txt

Validation files are used to evaluate the convergence of the training. They usually contain no more than 5k sentences.

$ head -n 2 toy-ende/src-train.txt
It is not acceptable that , with the help of the national bureaucracies , Parliament 's legislative prerogative should be made null and void by means of implementing provisions whose content , purpose and extent are not laid down in advance .
Federal Master Trainer and Senior Instructor of the Italian Federation of Aerobic Fitness , Group Fitness , Postural Gym , Stretching and Pilates; from 2004 , he has been collaborating with Antiche Terme as personal Trainer and Instructor of Stretching , Pilates and Postural Gym .

We need to build a YAML configuration file to specify the data that will be used:

# toy_en_de.yaml

## Where the samples will be written
save_data: toy-ende/run/example
## Where the vocab(s) will be written
src_vocab: toy-ende/run/example.vocab.src
tgt_vocab: toy-ende/run/example.vocab.tgt
# Prevent overwriting existing files in the folder
overwrite: False

# Corpus opts:
data:
    corpus_1:
        path_src: toy-ende/src-train.txt
        path_tgt: toy-ende/tgt-train.txt
    valid:
        path_src: toy-ende/src-val.txt
        path_tgt: toy-ende/tgt-val.txt
...

From this configuration, we can build the vocab(s) that will be needed to train the model:

onmt_build_vocab -config toy_en_de.yaml -n_sample 10000

Notes:

  • -n_sample is required here – it represents the number of lines sampled from each corpus to build the vocab.

  • This configuration is the simplest possible, without any tokenization or other transforms. See other example configurations for more complex pipelines.
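
As a quick sanity check, you can peek at the generated vocab files (the exact file layout may vary between OpenNMT-py versions):

head -n 5 toy-ende/run/example.vocab.src
head -n 5 toy-ende/run/example.vocab.tgt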

Step 2: Train the model

To train a model, we need to add the following to the YAML configuration file:

  • the vocabulary path(s) that will be used: these can be the ones generated by onmt_build_vocab;

  • training specific parameters.

# toy_en_de.yaml

...

# Vocabulary files that were just created
src_vocab: toy-ende/run/example.vocab.src
tgt_vocab: toy-ende/run/example.vocab.tgt

# Train on a single GPU
world_size: 1
gpu_ranks: [0]

# Where to save the checkpoints
save_model: toy-ende/run/model
save_checkpoint_steps: 500
train_steps: 1000
valid_steps: 500

Then you can simply run:

onmt_train -config toy_en_de.yaml

This configuration will run the default model, which consists of a 2-layer LSTM with 500 hidden units on both the encoder and the decoder. It will run on a single GPU (world_size 1 and gpu_ranks [0]).
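
If more than one GPU is available, the same run can be distributed by adjusting these two options. A minimal sketch, assuming two GPUs on a single machine:

# toy_en_de.yaml (excerpt)
world_size: 2
gpu_ranks: [0, 1]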

Before the training process actually starts, the *.vocab.pt and *.transforms.pt files can be dumped to the -save_data location, following the configuration given in the -config YAML file, by enabling the -dump_fields and -dump_transforms flags. It is also possible to generate transformed samples to simplify any visual inspection that may be required. The number of sample lines to dump per corpus is set with the -n_sample flag.
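
For example, the following invocation is a sketch of such a dump run (flag behavior may differ slightly between versions), writing the fields, the transforms and 5 transformed sample lines per corpus before training starts:

onmt_train -config toy_en_de.yaml -dump_fields -dump_transforms -n_sample 5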

For more advanced models and parameters, see other example configurations or the FAQ.

Step 3: Translate

onmt_translate -model toy-ende/run/model_step_1000.pt -src toy-ende/src-test.txt -output toy-ende/pred_1000.txt -gpu 0 -verbose

You now have a model that you can use to predict on new data. This is done by running beam search, and the predictions are written to toy-ende/pred_1000.txt.
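
Beam search parameters can be tuned directly on the command line; for instance, a sketch reusing the options above with an explicit beam width of 5:

onmt_translate -model toy-ende/run/model_step_1000.pt -src toy-ende/src-test.txt -output toy-ende/pred_1000.txt -gpu 0 -beam_size 5 -verbose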

Note:

The predictions are going to be quite terrible, as the demo dataset is small. Try running on some larger datasets!

For example, you can download millions of parallel sentences for translation or summarization.

How to generate with a pretrained LLM

Step 1: Convert a model from Hugging Face Hub

In “tools” you will find a set of converters for models 1) from the Hugging Face hub (T5, Falcon, MPT, Openllama, Redpajama, Xgen) or 2) the legacy Llama checkpoints from Meta.

T5 (and its variant Flan-T5), Llama and Openllama use SentencePiece. The command line to convert a model to OpenNMT-py is:

python tools/convert_openllama.py --model_dir /path_to_HF_model --tokenizer_model /path_to_tokenizer.model --output /path_to_Openllama-onmt.pt --format ['pytorch', 'safetensors'] --nshards N

Other models use BPE, so we had to reconstruct the BPE model and vocab files:

  • MPT bpe model

  • MPT vocab

  • Redpajama bpe model

  • Redpajama vocab

  • Falcon bpe model

  • Falcon vocab

The command line to convert a model to OpenNMT-py is:

python tools/convert_mpt.py --model_dir /path_to_HF_model --vocab_file /path_to_mpt.vocab --output /path_to_MPT-onmt.pt --format ['pytorch', 'safetensors'] --nshards N

/path_to_HF_model can also be a Hugging Face repository id directly.
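
For instance, a sketch of a conversion pulling the weights straight from the hub (the mosaicml/mpt-7b repository id, the output paths and the shard count are illustrative assumptions):

python tools/convert_mpt.py --model_dir mosaicml/mpt-7b --vocab_file /path_to/mpt.vocab --output /path_to/mpt-onmt.pt --format safetensors --nshards 1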

Step 2: Prepare an inference.yaml config file

Even though it is not mandatory, the best way to run inference is to use a config file; here is an example:

transforms: [sentencepiece]

#### Subword
src_subword_model: "/path_to/llama7B/tokenizer.model"
tgt_subword_model: "/path_to/llama7B/tokenizer.model"

# Model info
model: "/path_to/llama7B/llama7B-onmt.pt"

# Inference
seed: 42
max_length: 256
gpu: 0
batch_type: sents
batch_size: 1
precision: fp16
#random_sampling_topk: 40
#random_sampling_topp: 0.75
#random_sampling_temp: 0.1
beam_size: 1
n_best: 1
report_time: true

or similarly for a model using BPE:

transforms: [onmt_tokenize]

#### Subword
src_subword_type: bpe
src_subword_model: "/path_to/mpt7B/mpt-model.bpe"
src_onmttok_kwargs: '{"mode": "conservative"}'

tgt_subword_type: bpe
tgt_subword_model: "/path_to/mpt7B/mpt-model.bpe"
tgt_onmttok_kwargs: '{"mode": "conservative"}'
gpt2_pretok: true
# Model info
model: "/path_to/mpt7B/mpt-onmt.pt"

# Inference
seed: 42
max_length: 1
gpu: 0
batch_type: sents
batch_size: 1
precision: fp16
#random_sampling_topk: 40
#random_sampling_topp: 0.75
#random_sampling_temp: 0.8
beam_size: 1
report_time: true
src: None
tgt: None

In this second example, we use max_length: 1 together with src: None and tgt: None, which is typically the configuration used in a scoring script such as MMLU, where only one token is expected as the answer.

WARNING: For inhomogeneous batches with many examples, the potentially high number of tokens inserted to pad the shortest examples leads to degraded results when attention-layer quantization and flash attention are activated. In practice, when batch_size is greater than 1 in the inference configuration file, delete ‘linear_values’, ‘linear_query’, ‘linear_keys’ and ‘final_linear’ from quant_layers and specify self_attn_type: scaled-dot.
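
Concretely, a batched inference config would keep only the feed-forward weights under quant_layers and switch the attention implementation. A sketch of the relevant excerpt, assuming a quant_layers list like the one used in the finetuning example further down this page:

# inference.yaml (excerpt, for batch_size > 1)
batch_size: 8
quant_layers: ['w_1', 'w_2', 'w_3']
quant_type: "bnb_NF4"
self_attn_type: scaled-dot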

You can run this script with the following command line:

python eval_llm/MMLU/run_mmlu_opennmt.py --config myinference.yaml

Step 3: Generate text

Generating text is also easier with an inference config file (in which you can set max_length or random sampling settings):

python onmt/bin/translate.py --config /path_to_config/llama7B/llama-inference.yaml --src /path_to_source/input.txt --output /path_to_target/output.txt

How to finetune a pretrained LLM

Step 1: Convert a model from Hugging Face Hub

See instructions in the previous section.

Step 2: Prepare a finetune.yaml config file

Finetuning requires the same settings as training. Here is an example finetune.yaml file for Llama:

# Corpus opts:
data:
    alpaca:
        path_src: "/path_to/dataAI/alpaca_clean.txt"
        transforms: [sentencepiece, filtertoolong]
        weight: 10

    valid:
        path_src: "/path_to/dataAI/valid.txt"
        transforms: [sentencepiece]

### Transform related opts:
#### Subword
src_subword_model: "/path_to/dataAI/llama7B/tokenizer.model"
tgt_subword_model: "/path_to/dataAI/llama7B/tokenizer.model"

#### Filter
src_seq_length: 1024
tgt_seq_length: 1024

#truncated_decoder: 32

# silently ignore empty lines in the data
skip_empty_level: silent

# General opts
train_from: "/path_to/dataAI/llama7B/llama7B-onmt.pt"
save_model: "/path_to/dataAI/llama7B/llama7B-alpaca"
save_format: safetensors
keep_checkpoint: 10
save_checkpoint_steps: 1000
seed: 1234
report_every: 10
train_steps: 5000
valid_steps: 1000

# Batching
bucket_size: 32768
num_workers: 2
world_size: 1
gpu_ranks: [0]
batch_type: "tokens"
batch_size: 1024
valid_batch_size: 256
batch_size_multiple: 1
accum_count: [32]
accum_steps: [0]

override_opts: true  # CAREFUL: this requires all settings to be defined below

share_vocab: true
save_data: "/path_to/dataAI/llama7B"
src_vocab: "/path_to/dataAI/llama7B/llama.vocab"
src_vocab_size: 32000
tgt_vocab_size: 32000

decoder_start_token: '<s>'
# Optimization
model_dtype: "fp16"
apex_opt_level: ""
optim: "fusedadam"
learning_rate: 0.0001
warmup_steps: 100
decay_method: "none"
#learning_rate_decay: 0.98
#start_decay_steps: 100
#decay_steps: 10
adam_beta2: 0.998
max_grad_norm: 0
label_smoothing: 0.0
param_init: 0
param_init_glorot: true
normalization: "tokens"

#4/8bit
quant_layers: ['w_1', 'w_2', 'w_3', 'linear_values', 'linear_query', 'linear_keys', 'final_linear']
quant_type: "bnb_NF4"

#LoRa
lora_layers: ['linear_values', 'linear_query', 'linear_keys', 'final_linear']
lora_rank: 8
lora_dropout: 0.05
lora_alpha: 16
lora_embedding: false

# Checkpointing
#use_ckpting: ['ffn', 'lora']

# Model
model_task: lm
encoder_type: transformer_lm
decoder_type: transformer_lm
layer_norm: rms
pos_ffn_activation_fn: 'silu'
max_relative_positions: -1
position_encoding: false
add_qkvbias: False
add_ffnbias: False
parallel_residual: false
dec_layers: 32
heads: 32
hidden_size: 4096
word_vec_size: 4096
transformer_ff: 11008
dropout_steps: [0]
dropout: [0.0]
attention_dropout: [0.0]

If you want to enable the “zero-out prompt loss” mechanism to ignore the prompt when calculating the loss, you can add the insert_mask_before_placeholder transform as well as the zero_out_prompt_loss flag:

transforms: [insert_mask_before_placeholder, sentencepiece, filtertoolong]
zero_out_prompt_loss: true

The default value of response_pattern, which is used to locate the end of the prompt, is “Response : ⦅newline⦆”, but you can choose another pattern to align it with your training data.
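
For example, if your prompts end with an "Answer" marker instead (an illustrative assumption), the transform section would become:

transforms: [insert_mask_before_placeholder, sentencepiece, filtertoolong]
zero_out_prompt_loss: true
response_pattern: "Answer : ⦅newline⦆"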

Step 3: Finetune

You can then run the training with the regular train.py command line:

python train.py --config /path_to/finetune.yaml