
Passing the Torch: Training a Mamba Model for Smooth Handover

We present our explorations on training language models based on the new Mamba architecture, which deviates from the traditional Transformer architecture.

April 10, 2024

TL;DR

"We experiment with the Warmup-Stable-Decay (WSD) learning rate scheduler and a novel positional weighting of the loss for language model pre-training. We find that WSD outperforms the cosine scheduler, and that positional weighting results in better top-k accuracy. We conduct the experiments using the Mamba architecture, which, thanks to its linear complexity, achieves substantially higher inference throughput than transformers. Finally, based on our experiments, we train Mambaoutai, a 1.6B-parameter model, on 300B tokens. The training dataset mainly comprises French, English and code from various programming languages. Over 80 checkpoints of Mambaoutai 1.6B are released for the ML community to explore; thanks to the WSD scheduler, they can be further pre-trained smoothly without any degradation."

1. Introduction

We present our explorations on training language models based on the new Mamba architecture [1], which deviates from the traditional Transformer architecture. Because of its linear complexity and computational efficiency, Mamba has attracted a lot of attention among practitioners and the broader ML community. These explorations yielded a training recipe that we used to train a 1.6B model, Mambaoutai, which we also release openly for the community. We make over 80 checkpoints available, enabling interpretability studies and further training. Besides the main English training corpus, the model’s training data includes a significant portion of French data and code from various programming languages (Rust is among them, if you were wondering).

We will outline all the components involved in training Mambaoutai, from data preparation and modeling decisions to the training setup itself. We discuss the training throughput achieved with a Fully Sharded Data Parallel (FSDP) codebase and compare it with the Nanotron library [3]. We further explore the impact of different training techniques on Mambaoutai’s performance and convergence during pre-training. We focus on comparing three strategies aimed at enhancing pre-training: two learning rate schedulers and positional weighting of the pre-training loss.

First, we examine the effects of different learning rate schedulers on model performance within a fixed training budget. Our comparison includes the traditional warmup with cosine decay and the Warmup-Stable-Decay (WSD) scheduler introduced for MiniCPM [2]. Compared to the cosine scheduler, the WSD scheduler has a clear separation between pre-training stages, which enables the introduction of high-quality and/or instruction data during the final decay phase and allows seamless continuation of pre-training from a pre-decay checkpoint without the concern of a cold restart.

Second, we explore how weighting different token positions in the loss function can improve top-k accuracy by decreasing the loss weight assigned to the model’s predictions for the first tokens of every text sequence.

All the checkpoints are available on Hugging Face at: Mambaoutai.

The training code is available on GitHub at: Code.

2. Training: Throughput analysis

A central aspect of large language model training is distributing the training load over several accelerators. For instance, DBRX-132B was reportedly trained over 2 weeks on 3072 GPUs [9], and Falcon-180B was trained on up to 4096 GPUs [10], which illustrates how crucial efficient distributed computation is at such model sizes. Although these models are substantially larger than the models trained in our Mamba experiments, smaller models still benefit significantly from distributed training, as it speeds up training runs that would otherwise be too slow given the scale of modern training datasets.

There are two general approaches to distributed training:

3D parallelism [14, 18]: this approach combines data, tensor and pipeline parallelism and is architecture-specific: the model implementation has to be modified to comply with the specifics of 3D parallelism.

PyTorch FSDP [11, 12]: this approach is much simpler and does not require intrusive modifications to the model’s forward pass. It shards the model parameters, gradients and optimizer states across accelerators by wrapping the model into individual units called FSDP units.
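As a rough illustration of why sharding matters, the back-of-the-envelope sketch below estimates per-GPU memory for model states (parameters, gradients and optimizer states) with and without full sharding. It assumes 16 bytes per parameter (fp32 parameters and gradients plus two fp32 AdamW moments) and ignores activations and communication buffers; these are illustrative assumptions, not measurements from our runs.

```python
def model_state_gb(n_params: float, n_gpus: int = 1, bytes_per_param: int = 16) -> float:
    """Approximate per-GPU memory (GB) for parameters, gradients and AdamW states.

    Assumes fp32 parameters (4 B) and gradients (4 B) plus two fp32 AdamW
    moments (8 B), i.e. 16 bytes per parameter, split evenly across n_gpus
    as full sharding does. Activations and temporary buffers are ignored.
    """
    return n_params * bytes_per_param / n_gpus / 1e9


# A 1.6B-parameter model on one node of 4 GPUs.
replicated = model_state_gb(1.6e9)          # every GPU holds a full copy (plain data parallel)
sharded = model_state_gb(1.6e9, n_gpus=4)   # states split into 4 shards
print(f"replicated: {replicated:.1f} GB/GPU, sharded: {sharded:.1f} GB/GPU")
```

In practice, FSDP temporarily all-gathers the full parameters of one FSDP unit during that unit's forward and backward pass, so peak memory also depends on how large each unit is, which is one reason the wrapping policy matters.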

We compare our adapted version of the Composer codebase [23], which uses FSDP, with the Nanotron ⚡ codebase [3], which leverages 3D parallelism, at the 2B parameter scale. All training runs are conducted with bfloat16 [19] and mixed precision training [13].

The benchmark is run on one node of 4× A100-64GB GPUs, with the context length fixed at 4096. We optimized both setups and only show the best result for each codebase. The results are presented in the table below:

| Parallelism type | Model | Batch size | Block-wise activation recomputation | Throughput | TFLOP/s per GPU |
|---|---|---|---|---|---|
| FSDP | 1.6B | 5 | Yes | 11000 | 96.25 |
| 3D (\(dp=2\), \(tp=1\), \(pp=2\)) | 1.6B | 2 | Not supported | 4880 | 51 |

With our codebase, we achieve a throughput of 96.25 TFLOPs per second per GPU, which corresponds to a Model FLOPs Utilization (MFU) of \(0.31\). We expect the actual hardware utilization to be much higher, because MFU does not account for activation recomputation: the reported FLOPs only count one forward and one backward pass, which underestimates the floating point operations actually performed by the hardware. A higher MFU can be achieved by increasing the number of gradient accumulation steps. Our training codebase supports different types of activation recomputation: in addition to the activation recomputation described in the original Mamba paper [1], we also support recomputing the output of every Mamba block. We get the highest throughput when using both types of activation checkpointing simultaneously, as the reduced memory usage allows for a larger micro batch size. We also note that the FSDP wrapping policy is another key element for optimal results: the default behavior, which wraps the entire model as a single unit, gives suboptimal results, so adapting the wrapping policy was necessary.
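The MFU figure can be reproduced from the measured throughput, assuming the A100's dense bf16 peak of 312 TFLOP/s. The second function is a rough estimate, not a measurement: it assumes that activation recomputation adds exactly one extra forward pass (about 8 instead of 6 FLOPs per parameter per token), which is the usual way hardware FLOPs utilization (HFU) is related to MFU. Both function names are hypothetical.

```python
A100_BF16_PEAK_TFLOPS = 312.0  # dense bf16 peak of an A100 (vendor specification)


def mfu(achieved_tflops: float, peak_tflops: float = A100_BF16_PEAK_TFLOPS) -> float:
    """Model FLOPs Utilization: useful model FLOP/s divided by the hardware peak."""
    return achieved_tflops / peak_tflops


def hfu_with_full_recompute(model_flops_util: float) -> float:
    """Rough HFU if the whole forward pass is recomputed once.

    Model FLOPs count one forward and one backward pass (~6 FLOPs per parameter
    per token); recomputing the forward adds ~2 more (~8 in total), so
    HFU ~= MFU * 8 / 6. Partial recomputation would land in between.
    """
    return model_flops_util * 8 / 6


m = mfu(96.25)
print(f"MFU = {m:.2f}, HFU with one full recompute ~ {hfu_with_full_recompute(m):.2f}")
```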

Finally, we find that Nanotron’s 3D parallelism achieves worse results than our PyTorch FSDP codebase at this model size. One reason is that Nanotron does not support per-block activation recomputation, which limits the micro batch size; other reasons might include increased communication and bubble times in pipeline parallelism (Nanotron uses a one-forward-one-backward (1F1B) pipeline schedule), resulting in underutilization of the hardware. We expect 3D parallelism to be better suited to larger models and to achieve better throughput at a larger scale.
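To give a feel for the bubble overhead, the sketch below uses the standard idle-time estimate for GPipe-style and 1F1B pipeline schedules, \((p - 1) / (m + p - 1)\) for \(p\) stages and \(m\) micro-batches per step. It assumes perfectly balanced stages and ignores communication; the micro-batch counts are illustrative, not the values used in our Nanotron runs.

```python
def pipeline_bubble_fraction(stages: int, micro_batches: int) -> float:
    """Fraction of a training step each GPU sits idle under GPipe/1F1B scheduling.

    With p stages and m micro-batches, a step lasts (m + p - 1) micro-batch
    slots, of which (p - 1) are spent filling and draining the pipeline.
    Assumes equal compute per stage and ignores communication time.
    """
    if stages < 1 or micro_batches < 1:
        raise ValueError("stages and micro_batches must be >= 1")
    return (stages - 1) / (micro_batches + stages - 1)


# pp=2 as in the table above; the micro-batch counts are illustrative only.
for m in (2, 4, 8, 32):
    print(m, round(pipeline_bubble_fraction(2, m), 3))
```

The bubble shrinks as the number of micro-batches per step grows, but for a fixed global batch size this means smaller micro-batches, which tend to use the GPU less efficiently.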

3. Model Architecture and Dataset

3.1 Model Architecture

The model configuration for Mambaoutai 1.6B follows the original implementation but differs in a few architectural details. First, we use a much wider hidden state relative to the number of layers, as this enables higher throughput with close to negligible degradation in training loss. In theory, a larger hidden state allows the model to compress more information from earlier parts of the context, at the cost of less depth for a constant model size. Second, for simplicity and ease of integration with PyTorch FSDP, we do not tie the word embedding matrix with the final projection layer. Finally, one of the key aspects of the Mamba architecture is that its complexity grows linearly with the context length, which makes it well suited to long contexts. We thus extended the context length to \(4096\) to further leverage this property.
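To illustrate the width/depth trade-off at a fixed parameter budget, the sketch below uses a rough count: with an expansion factor of 2 (an assumption matching the original Mamba default), the input projection (\(d \to 4d\)) and output projection (\(2d \to d\)) dominate, giving about \(6d^2\) parameters per block, plus \(2Vd\) for untied input embeddings and output projection. The widths are hypothetical and are not Mambaoutai's actual configuration.

```python
def mamba_params(d_model: int, n_layers: int, vocab: int = 64_000, expand: int = 2) -> int:
    """Rough parameter count of a Mamba LM with untied embeddings.

    Per block, only in_proj (d -> 2*expand*d) and out_proj (expand*d -> d) are
    counted: 3*expand*d^2 (= 6*d^2 for expand=2). SSM, convolution and norm
    parameters are ignored, so this slightly underestimates the total.
    """
    per_block = 3 * expand * d_model ** 2
    embeddings = 2 * vocab * d_model  # input embedding + untied output projection
    return n_layers * per_block + embeddings


def layers_for_budget(budget: float, d_model: int, vocab: int = 64_000) -> int:
    """Largest depth whose rough parameter count stays within the budget."""
    fixed = mamba_params(d_model, 0, vocab)
    per_block = mamba_params(d_model, 1, vocab) - fixed
    return int((budget - fixed) // per_block)


# Hypothetical widths at a ~1.6B-parameter budget: wider means shallower.
for d in (2048, 2560, 3072):
    print(d, layers_for_budget(1.6e9, d))
```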

3.2 Data Preparation

We wanted Mambaoutai to be competent in French, English and coding. We thus created a pre-training dataset reflecting these domains, in line with state-of-the-art large language model practices, incorporating a mix of widely used datasets such as RedPajama-Data-V2, subsets of RedPajama-Data-1T and The Stack. More details about the datasets and the custom tokenizer training are given below.

3.2.1 Datasets

To train Mambaoutai, we used only open-source datasets available on Hugging Face, to encourage reproducibility and transparency, following recent releases [25]. These datasets are already deduplicated or ship with precomputed deduplication and quality signals, which saved us a lot of time and resources in preprocessing.

togethercomputer/RedPajama-Data-V2

This dataset, comprising over 100B text documents from 84 Common Crawl snapshots, was our primary data source during the stable pre-training phase. Each Common Crawl snapshot contains 5000 shards per language.

We used the first 30 shards of each snapshot (head_middle partition) to gather 78B tokens in English, and the first 300 shards to gather around 97B tokens in French.

Using the precomputed quality signals of RedPajama-Data-V2, we removed duplicate samples and filtered the dataset with the Gopher [6] and C4 [7] filters.
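We relied on RedPajama-Data-V2's precomputed signals; for readers without them, the sketch below recomputes a subset of the Gopher quality heuristics directly from raw text. The thresholds follow the rules described in the Gopher paper [6]; whitespace word splitting, the set of bullet characters and the function name are simplifying assumptions, and the English stop-word rule is omitted since it does not apply to French.

```python
BULLETS = ("•", "-", "*", "·")  # assumption: characters treated as bullet points


def passes_gopher_quality(text: str) -> bool:
    """Subset of Gopher-style document quality rules, computed from raw text."""
    words = text.split()
    n_words = len(words)
    if not 50 <= n_words <= 100_000:
        return False
    if not 3 <= sum(len(w) for w in words) / n_words <= 10:  # mean word length
        return False
    # Symbol-to-word ratio for '#' and for ellipses ("..." or "…").
    n_hash = text.count("#")
    n_ellipsis = text.count("...") + text.count("…")
    if max(n_hash, n_ellipsis) / n_words > 0.1:
        return False
    lines = [line.strip() for line in text.splitlines() if line.strip()]
    if lines:
        if sum(line.startswith(BULLETS) for line in lines) / len(lines) > 0.9:
            return False
        if sum(line.endswith(("...", "…")) for line in lines) / len(lines) > 0.3:
            return False
    # At least 80% of words must contain an alphabetic character.
    with_alpha = sum(any(c.isalpha() for c in w) for w in words)
    return with_alpha / n_words >= 0.8


print(passes_gopher_quality(" ".join(["Le chat dort sur le canapé du salon."] * 10)))
```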

togethercomputer/RedPajama-Data-1T

We used only the book, arxiv, and wikipedia (English and French) subsets of this dataset.

The other subsets of this dataset (C4, Common Crawl and GitHub) were not used because of possible duplication with the other training datasets (RedPajama V2 for Common Crawl and The Stack for code).

bigcode/the-stack-dedup

Since we wanted our model to have reasonable coding capabilities, we decided to use a non-negligible amount of code extracted from this near-deduplicated version of bigcode/the-stack.

We only used the most popular programming languages on GitHub: JavaScript, Python, C++, C, Java, Go, Rust and PHP.

3.2.2 Dataset distribution

As previously introduced, after a short warmup the WSD scheduler has two main phases: the stable phase, during which the learning rate is kept steady and high for fast training, and the decay phase, during which the model converges and thus benefits from higher-quality data. Accordingly, we use one data mixture for each of these phases.

Our stable pre-training dataset distribution is as follows:


Data distribution during the stable phase of the training

We then add instruction data to our decay phase, resulting in the following dataset distribution:


Data distribution during the decay phase of the training

For the decay phase, we added one epoch of the UltraChat dataset [24] and scaled down the other datasets accordingly.
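Adding an instruction dataset and scaling down the other datasets amounts to renormalizing the sampling weights. A minimal sketch, assuming the scale-down is proportional so that the relative proportions of the original sources are preserved; the function name, source names and weights below are placeholders, not the actual distributions shown in the figures. The instruction share could, for example, be the tokens in one epoch of the instruction set divided by the decay-phase token budget.

```python
def add_to_mixture(mixture: dict[str, float], name: str, share: float) -> dict[str, float]:
    """Insert a new source with the given share, scaling the others proportionally.

    `mixture` maps source name -> sampling weight (any positive scale).
    Returns normalized weights that sum to 1.
    """
    if not 0 < share < 1:
        raise ValueError("share must be in (0, 1)")
    total = sum(mixture.values())
    scaled = {k: (1 - share) * v / total for k, v in mixture.items()}
    scaled[name] = share
    return scaled


# Placeholder stable-phase weights (illustrative only).
stable = {"english_web": 0.5, "french_web": 0.3, "code": 0.2}
decay = add_to_mixture(stable, "instructions", share=0.1)
print(decay)
```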

3.2.3 Tokenizer

We decided to train our tokenizer from scratch with the Hugging Face tokenizers library to improve coverage of our training datasets and to have a larger vocabulary. A larger vocabulary compresses text into fewer tokens, so more information fits into a context window of fixed length, which is critical for language models.

We trained a Byte-Pair Encoding (BPE) tokenizer with a vocabulary size of 64k tokens and added a set of special “ChatML” tokens used during the instruction fine-tuning to simulate a conversation between the user and the assistant. The tokenizer was trained on 15GB of data, and the language distribution approximately matches the training distribution of the Mamba model (10% of code, 30% of French and the remaining 60% of English).
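To illustrate what BPE training does, and why more merges (a larger vocabulary) shorten tokenized sequences, here is a toy, pure-Python BPE learner over characters. It is a didactic sketch with hypothetical function names; the production tokenizer was trained with the Hugging Face tokenizers library's optimized trainer on 15GB of text.

```python
from collections import Counter


def merge_pair(symbols: list[str], pair: tuple[str, str]) -> list[str]:
    """Replace every non-overlapping occurrence of `pair` with its concatenation."""
    out, i = [], 0
    while i < len(symbols):
        if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
            out.append(symbols[i] + symbols[i + 1])
            i += 2
        else:
            out.append(symbols[i])
            i += 1
    return out


def learn_bpe(corpus: list[str], n_merges: int) -> list[tuple[str, str]]:
    """Learn BPE merges by repeatedly fusing the most frequent adjacent pair."""
    words = [list(w) for w in corpus]  # each word starts as a list of characters
    merges: list[tuple[str, str]] = []
    for _ in range(n_merges):
        pairs: Counter = Counter()
        for w in words:
            pairs.update(zip(w, w[1:]))
        if not pairs:  # every word is already a single symbol
            break
        best = max(pairs, key=pairs.get)
        merges.append(best)
        words = [merge_pair(w, best) for w in words]
    return merges


def encode(word: str, merges: list[tuple[str, str]]) -> list[str]:
    """Tokenize a word by applying the learned merges in the order they were learned."""
    symbols = list(word)
    for pair in merges:
        symbols = merge_pair(symbols, pair)
    return symbols


corpus = "le chat et le chien jouent avec le chaton".split()
for n in (0, 5, 20):
    merges = learn_bpe(corpus, n)
    print(f"{n} merges -> {sum(len(encode(w, merges)) for w in corpus)} tokens")
```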

4. Experiments and Ablations

In this section, we detail the experiments conducted to validate technical choices for the training procedure, namely the choice of learning rate scheduler and the effect of loss positional weighting on performance. For computational reasons, we used a smaller 0.5B-parameter model and at most 10B tokens.

We adhere to the architecture and hyperparameters of the original Mamba implementation [1]. The peak learning rate is chosen based on a small hyperparameter search over 1B tokens, with a maximum value of \(2 \times 10^{-3}\). Following the literature, we use a batch size of 2 million tokens, linearly ramped up over the first 1B tokens, and a sequence length of 4096. After warmup, we experiment with two schedules: cosine and WSD. We use the AdamW optimizer with \(\beta_1 = 0.9\), \(\beta_2 = 0.95\), a weight decay of \(0.1\) and gradient clipping at a threshold of \(1.0\). As for the main model, we do not tie the word embedding and softmax matrices in these experiments. Training runs on 16 nodes of 4× A100-40GB GPUs, using PyTorch FSDP to shard the model over \(64\) GPUs. For data, we use the French subset of RedPajama-Data-V2 available on Hugging Face, with a fixed vocabulary size of 64k, and hold out part of this French data as validation data.
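The setup above can be summarized as a small config object, with the batch-size ramp sketched as a linear function of tokens seen. The field names, the starting batch of 32 sequences and the exact value of 2,000,000 tokens (rather than, say, \(2^{21}\)) are assumptions; the text only states "a batch size of 2 million tokens, linearly ramped up over the first 1B tokens". This is not taken from our training code.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AblationConfig:
    # Values stated in the text; field names are hypothetical.
    peak_lr: float = 2e-3
    min_lr: float = 2e-4
    warmup_steps: int = 480
    seq_len: int = 4096
    batch_tokens: int = 2_000_000        # "2 million tokens" (exact value assumed)
    ramp_tokens: int = 1_000_000_000     # batch size ramped over the first 1B tokens
    total_tokens: int = 10_000_000_000
    adam_betas: tuple = (0.9, 0.95)
    weight_decay: float = 0.1
    grad_clip: float = 1.0


def batch_size_in_sequences(cfg: AblationConfig, tokens_seen: int, start_seqs: int = 32) -> int:
    """Linearly ramp the batch from `start_seqs` (hypothetical) to its full size."""
    full_seqs = cfg.batch_tokens // cfg.seq_len
    frac = min(tokens_seen / cfg.ramp_tokens, 1.0)
    return round(start_seqs + frac * (full_seqs - start_seqs))


cfg = AblationConfig()
print(cfg.batch_tokens // cfg.seq_len)            # full batch, in sequences
print(batch_size_in_sequences(cfg, 500_000_000))  # halfway through the ramp
```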

4.1 Learning rate schedulers

4.1.1 Cosine scheduler

The cosine scheduler makes the learning rate follow a cosine-shaped function, starting with a high learning rate at the beginning of training and gradually decreasing it to a minimum value over time. This approach allows for faster initial progress while ensuring finer steps towards the end of training. Compared to other schedulers, such as step decay or exponential decay, the cosine scheduler has been shown to provide better performance on various tasks, including image classification, language modeling and neural machine translation, and has become the default scheduler, at least in NLP [10].

Our cosine scheduler starts with a linear warmup of 480 steps and then decays the learning rate along a cosine curve to a minimum value of \(2 \times 10^{-4}\), as illustrated in the following figure.


The cosine scheduler with a linear warmup phase
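A minimal sketch of this schedule, using the values from the text (480 warmup steps, peak \(2 \times 10^{-3}\), minimum \(2 \times 10^{-4}\)). Warming up linearly from zero and the example horizon of 5,000 steps (roughly 10B tokens at 2M tokens per step, ignoring the batch-size ramp) are assumptions.

```python
import math


def cosine_lr(step: int, total_steps: int, peak_lr: float = 2e-3,
              min_lr: float = 2e-4, warmup_steps: int = 480) -> float:
    """Linear warmup to peak_lr, then cosine decay to min_lr at total_steps."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = min((step - warmup_steps) / max(1, total_steps - warmup_steps), 1.0)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))


for step in (0, 479, 2000, 4000, 5000):
    print(step, f"{cosine_lr(step, total_steps=5000):.2e}")
```

Note that `total_steps` enters every post-warmup value, so the schedule has to be fixed in advance: stopping early leaves the learning rate well above its minimum, and training longer means stretching a curve that was designed to end at a given step.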

A critical limitation of cosine learning rate decay is that it achieves optimal performance only when a full cosine cycle is completed [4]. This forces practitioners to fix the number of steps beforehand, which is a significant hurdle if we want to continue pre-training later, when more data and/or compute becomes available.

4.1.2 WSD scheduler

The Warmup-Stable-Decay (WSD) [2] scheduler solves this issue and makes it possible to continue pre-training without a predetermined number of steps. This is particularly handy because the number of training steps might not be known a priori, especially if the model is trained further after being openly released. The WSD schedule is composed of three stages: a warmup phase, a stable phase and a decay phase. The warmup stage gradually increases the learning rate from a minimum to a maximum value over a specified number of steps or amount of training. The stable stage keeps the learning rate constant at its peak value, which speeds up training overall. In the decay stage, the learning rate decreases according to a predefined function, such as exponential or linear decay; this ensures convergence while avoiding overshooting near convergence points. Additionally, high-quality and/or instruction data can be added to the pre-training data mix at the start of the decay phase, which is an efficient way of integrating scarce data during pre-training. Even more interestingly, we can continue pre-training from any checkpoint of the stable stage without worrying about the very low learning rates reached at the end of a cosine schedule or the model degradation caused by re-warming the learning rate [16, 17]. Unlike checkpoints obtained with the cosine scheduler, any checkpoint prior to the decay phase can be used as a starting point to continue pre-training. For example, ML practitioners can easily take a model pre-trained on 300B tokens and continue pre-training it up to 3T tokens. This practice has already been advocated in the computer vision domain [26], and we think it should become the standard for NLP.

We envision this as a game changer for the open model ecosystem, since organizations can iteratively continue pre-training open models and collectively train them on volumes of data that would be out of reach for smaller organizations. In this spirit, we release Mambaoutai both before and after the decay phase, for open science and in the hope of kickstarting a virtuous collective pre-training effort.

The WSD scheduler we used starts with a linear warmup for 480 steps, similar to the cosine schedule, followed by a constant learning rate at its peak value for \(80\%\) of the total training steps. The decay phase then begins at the end of the stable training phase, where the learning rate decreases linearly to a minimum value of \(2 \times 10^{-4}\). This scheduler is illustrated in the following figure.


The WSD scheduler
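The corresponding sketch for WSD uses the same warmup, peak and minimum values. We assume that "constant for 80% of the total training steps" means the decay starts at 80% of `total_steps`, with the remaining 20% decaying linearly to the minimum; the 5,000-step horizon is again illustrative.

```python
def wsd_lr(step: int, total_steps: int, peak_lr: float = 2e-3,
           min_lr: float = 2e-4, warmup_steps: int = 480,
           stable_frac: float = 0.8) -> float:
    """Warmup-Stable-Decay: linear warmup, constant peak, then linear decay to min_lr."""
    decay_start = int(stable_frac * total_steps)  # assumption: decay begins at 80% of steps
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    if step < decay_start:
        return peak_lr
    progress = min((step - decay_start) / max(1, total_steps - decay_start), 1.0)
    return peak_lr + progress * (min_lr - peak_lr)


for step in (0, 1000, 4000, 4500, 5000):
    print(step, f"{wsd_lr(step, total_steps=5000):.2e}")
```

Because the learning rate in the stable phase does not depend on `total_steps`, a checkpoint taken before the decay can be resumed with a larger `total_steps` at the same peak learning rate, and decayed once the new budget is reached.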

4.2 Loss Positional Weighting

The early tokens of a sequence are harder for the model to predict, as they are conditioned on fewer tokens. As a consequence, the cross-entropy loss is substantially higher for the initial tokens, and we hypothesize that the loss on early tokens contributes a noisier learning signal. Furthermore, model generations are rarely conditioned on just a few tokens in practice. Hence, we propose a positional weighting of the loss, i.e., multiplying each token’s loss term by a weight based on its position. We experiment with down-weighting the loss on early tokens during training, using a scaled hyperbolic tangent (tanh) function as a starting point, where the weight is calculated as:

\[\mathrm{positional\_weighting} = \tanh\left(\frac{\mathrm{position}}{10}\right)\]

The figure below shows that, in practice, only the first few tokens (around \(30\) in this case) are down-weighted; the remaining tokens have a weight of approximately \(1\) (\(\tanh(3) \approx 0.995\)), which is equivalent to traditional language modeling.


The scaled hyperbolic tangent (tanh) function used to weight the loss
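A minimal sketch of the weighted loss, in plain Python over per-token losses for readability. Two details are assumptions not specified in the text: positions are counted from 0 (so the very first prediction gets weight 0), and the weighted sum is normalized by the sum of the weights rather than by the number of tokens. The function names are hypothetical.

```python
import math


def positional_weights(seq_len: int, scale: float = 10.0) -> list[float]:
    """w(position) = tanh(position / scale), with positions counted from 0."""
    return [math.tanh(p / scale) for p in range(seq_len)]


def weighted_lm_loss(token_losses: list[float], scale: float = 10.0) -> float:
    """Position-weighted mean of per-token cross-entropy losses."""
    if len(token_losses) < 2:
        raise ValueError("need at least two positions (the first weight is 0)")
    w = positional_weights(len(token_losses), scale)
    return sum(wi * li for wi, li in zip(w, token_losses)) / sum(w)


# Early tokens with a high loss barely move the weighted average.
losses = [9.0] * 5 + [2.0] * 95
print(sum(losses) / len(losses), round(weighted_lm_loss(losses), 3))
```

In a PyTorch training loop, the same idea amounts to calling `torch.nn.functional.cross_entropy` with `reduction="none"`, multiplying the per-token losses by a weight tensor and normalizing.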

In summary, we have conducted three experiments:

  • We train a 0.5B Mamba with a WSD scheduler for 10B tokens.
  • We train a 0.5B Mamba with a Cosine scheduler for 10B tokens.
  • We train a 0.5B Mamba with a WSD scheduler and positional weighting for 10B tokens.

4.3 Insights

The results for each setup are shown in the following figure:


Training loss curves for all three experiments

We observe that the training losses of the different setups were similar throughout training, except for some spikes in the runs using the WSD scheduler. Despite the spikes, the WSD runs still reached a lower training loss than the cosine run. Loss spikes can occur for various reasons, such as dirty data or model scale, and may also happen randomly when optimization passes through a bad region of the loss landscape. As the spikes only occur with WSD, and not with the cosine schedule for the same model size and exactly the same data, we can rule out bad data and model size as the main trigger. Instead, the spikes seem to be caused by maintaining the peak learning rate for an extended period, which suggests that stable training with WSD may require a lower peak learning rate than a cosine run. In any case, even though WSD might be less stable, it remained a better choice than the cosine scheduler in our experiments.

Focusing on the decay phase gives a clearer view of the training dynamics during this phase, as illustrated in the following figure. The WSD scheduler exhibits very different training dynamics from the cosine scheduler: while it lags behind the latter during the entire stable stage, it quickly catches up and surpasses it during the decay phase. This illustrates how critical the decay phase is for the model and why it should be carefully designed. It also means that sharing pre-decay checkpoints allows practitioners not only to continue pre-training but also to perform their own decay phase.


Training loss curves during the decay phase for all three experiments

On the validation set, the WSD schedule achieved the best perplexity of \(11.58\), followed by WSD coupled with positional weighting with a perplexity of \(11.60\). The cosine schedule had a higher perplexity of \(11.80\).


Validation perplexity during the decay phase for all three experiments

However, perplexity alone may not provide a comprehensive evaluation of the final model. We therefore also report top-k accuracy for k in \(\{3, 10\}\) to gain a more granular view of model performance. Top-k accuracy measures the proportion of positions where the correct next token is among the model’s k highest-probability candidates, and can provide insights beyond perplexity. Using both perplexity and top-k accuracy lets us better assess the overall quality and generalization capabilities of our models.
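For reference, here is a minimal sketch of top-k accuracy over next-token predictions, in plain Python for clarity (a real evaluation would operate on logit tensors). The chance level for a vocabulary of \(V\) tokens is simply \(k / V\). The function names are hypothetical.

```python
import re  # not needed here; kept out deliberately


def top_k_accuracy(logits: list[list[float]], targets: list[int], k: int) -> float:
    """Fraction of positions whose target token is among the k highest-scoring tokens."""
    if len(logits) != len(targets) or not targets:
        raise ValueError("need one non-empty score vector per target")
    hits = 0
    for scores, target in zip(logits, targets):
        ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        hits += target in ranked[:k]
    return hits / len(targets)


def chance_top_k(k: int, vocab_size: int = 64_000) -> float:
    """Top-k accuracy of a uniformly random ranking of the vocabulary."""
    return k / vocab_size


logits = [[0.1, 0.5, 0.2, 0.9], [0.3, 0.1, 0.8, 0.0]]  # 2 positions, toy vocabulary of 4
targets = [1, 3]
print(top_k_accuracy(logits, targets, 2), chance_top_k(3), chance_top_k(10))
```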

In the figure below, we see that the run using WSD and positional weighting achieves better validation results in terms of top-3 and top-10 accuracy, while the runs without positional weighting (both WSD and cosine) are almost identical for top-10 accuracy. The top-3 accuracy of the WSD run with positional weighting is \(0.358\), far above the chance level of \(3/64000 \approx 4.7 \times 10^{-5}\). Its top-10 accuracy is \(0.369\), also much higher than the chance level of picking 10 random tokens from the vocabulary, \(10/64000 \approx 1.6 \times 10^{-4}\).


Top-k accuracy metrics for k in \(\{3, 10\}\) for all experiments

While positional weighting of the loss resulted in slightly better top-k validation metrics, it did not improve perplexity. This is expected, as it trades prediction quality on earlier tokens for later ones, and perplexity is dominated by the high loss on the earliest tokens. Top-k accuracy, on the other hand, counts each position as a single hit or miss regardless of its loss, and is thus less dominated by model performance on early positions. The results are consistent with our hypothesis that the first tokens are not modeled with the same strategy as later ones: to predict the earliest tokens, the model essentially has to model their unconditional distribution, which is close to guessing, while predictions of later tokens can better exploit the preceding context. Moreover, since the loss is dominated by the first few tokens, so is the gradient. As predictions for the first few tokens are rarely needed in prompt-based use cases, down-weighting them reduces the task’s complexity, mitigates this split objective and reduces gradient noise. Overall, combining these techniques can lead to improved model performance and generalization.

5. Experimental Results and Benchmarks

In the following section, we present and analyze the performance of our model Mambaoutai 1.6B, trained based on our observations in the previous section, using the WSD scheduler and positional weighting.

We train Mambaoutai 1.6B following the previous recipe while scaling up the hidden state and the number of layers. We select the peak learning rate using the same strategy as in the experiments. Specifically, we use a batch size of \(896\) with a learning rate of \(4.5 \times 10^{-4}\) and a context length of \(4096\) tokens. We train Mambaoutai on 300B tokens with a warmup phase of \(2000\) steps, and the linear decay phase of the WSD scheduler lasts for 30B tokens. We run two different decay phases, resulting in two different models: one decay phase with the standard pre-training data mix, and one where we add instruction data to the original data mix so that the pre-trained model can follow instructions.
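A quick sanity check of the token budget, assuming the batch size of \(896\) is counted in sequences of \(4096\) tokens (the unit is not stated explicitly):

```python
SEQ_LEN = 4096
BATCH_SEQS = 896                          # assumed: sequences per optimizer step
TOKENS_PER_STEP = BATCH_SEQS * SEQ_LEN    # 3,670,016 tokens

total_steps = 300e9 / TOKENS_PER_STEP     # whole 300B-token run
decay_steps = 30e9 / TOKENS_PER_STEP      # linear decay phase of WSD
warmup_tokens = 2000 * TOKENS_PER_STEP    # tokens seen during the 2000 warmup steps

print(f"{TOKENS_PER_STEP:,} tokens/step")
print(f"~{total_steps:,.0f} steps in total, ~{decay_steps:,.0f} in the decay phase")
print(f"decay share: {decay_steps / total_steps:.0%}, warmup: {warmup_tokens / 1e9:.1f}B tokens")
```

Under this assumption, the decay phase covers about 10% of the run, and the warmup consumes roughly 7.3B tokens.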

During the training of Mambaoutai, we also encountered some loss spikes. Whenever a spike occurred, we restarted training from a previous checkpoint and skipped the data that caused the spike. We also had to restart training once due to hardware issues. The training loss curve below shows the full training history, including these restarts.


The training loss for Mambaoutai with the different restarts
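The restart-and-skip procedure can be sketched as follows. The spike criterion (a loss exceeding an exponential moving average of earlier losses by a fixed ratio), the EMA factor, the checkpoint interval and the skip window are all hypothetical choices for illustration; the text does not specify how spikes were detected or how much data was skipped.

```python
def find_spikes(losses: list[float], ratio: float = 1.3, beta: float = 0.9) -> list[int]:
    """Return steps whose loss exceeds `ratio` times an EMA of earlier, non-spike losses."""
    spikes: list[int] = []
    ema = None
    for step, loss in enumerate(losses):
        if ema is not None and loss > ratio * ema:
            spikes.append(step)
            continue  # keep spikes out of the running average
        ema = loss if ema is None else beta * ema + (1 - beta) * loss
    return spikes


def restart_plan(spike_step: int, ckpt_every: int, window: int) -> tuple[int, list[int]]:
    """Pick the checkpoint to resume from and the batch indices to skip.

    Assumes checkpoint s stores the training state before batch s is consumed,
    and that the `window` batches up to and including the spike caused it.
    """
    resume_from = (spike_step // ckpt_every) * ckpt_every
    first_skip = max(resume_from, spike_step - window + 1)
    return resume_from, list(range(first_skip, spike_step + 1))


losses = [3.0] * 20 + [6.0] + [3.0] * 5
for spike in find_spikes(losses):
    print(spike, restart_plan(spike, ckpt_every=8, window=4))
```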

5.1 Long-range Performance

One of the main strengths of the Mamba architecture is its ability to process long inputs efficiently. Given that we specifically trained Mambaoutai to leverage this property (larger context and hidden state), we evaluate the resulting capacity with the Needle-in-a-haystack test. As described in [5], this test evaluates the retrieval performance of language models over various context window sizes: a specific needle (a random number between \(1\) and \(50000\)) is inserted amid a large volume of unrelated text, the haystack, and the evaluated model must locate the needle by answering a question that requires finding this fact within the long context. We hide the needle at 10 random points within the context and observe the model’s performance across context window sizes ranging from \(500\) to \(16384\) tokens. Note that the training context size of Mambaoutai is limited to \(4096\) tokens, but we evaluate it at \(4\times\) that length.
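A minimal sketch of how the needle-in-a-haystack prompts can be assembled and scored. The filler text, the needle sentence, the question wording, the word-boundary scoring and the use of whitespace-separated words as a stand-in for tokens are all assumptions for illustration; [5] describes the actual protocol.

```python
import random
import re

NEEDLE_TEMPLATE = "The special magic number is {}."  # hypothetical needle sentence
QUESTION = "What is the special magic number mentioned in the text?"


def build_haystack_prompt(filler: list[str], context_len: int, depth: float,
                          needle_value: int) -> tuple[str, str]:
    """Hide a needle at relative `depth` (0.0 = start, 1.0 = end) of a filler context.

    Whitespace-separated words stand in for tokens in this sketch.
    """
    repeats = context_len // len(filler) + 1
    haystack = (filler * repeats)[:context_len]
    pos = int(depth * len(haystack))
    needle = NEEDLE_TEMPLATE.format(needle_value).split()
    words = haystack[:pos] + needle + haystack[pos:]
    return " ".join(words) + "\n\n" + QUESTION, str(needle_value)


def is_correct(completion: str, answer: str) -> bool:
    """Score a completion: the needle value must appear as a standalone number."""
    return re.search(rf"\b{answer}\b", completion) is not None


rng = random.Random(0)
filler = "the grass is green and the sky is blue".split()
prompts = []
for context_len in (500, 4096, 16384):
    for _ in range(10):  # 10 random insertion depths per context length
        prompt, answer = build_haystack_prompt(
            filler, context_len, rng.random(), rng.randint(1, 50_000))
        prompts.append((context_len, prompt, answer))  # model output scored with is_correct
```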

We note in particular that, because Mamba does not require positional embeddings, its context length can be extended without bells and whistles, unlike transformer-based models, where extending the context window requires dedicated methods such as YaRN [20], StreamingLLM [21], NTK-YaRN [22] or LongLlama.

The figure below shows the performance of Mambaoutai on the Needle-in-a-Haystack evaluation. The model is nearly flawless at identifying the needle within the pre-training context length of \(4096\), and maintains decent performance even when the context length is tripled. However, as the context length grows further, the model starts to struggle to find the needle, particularly when it is positioned at the beginning of a long context. This behavior is consistent with how the Mamba architecture scans the context and selectively compresses essential information into a fixed-size hidden state: at the longest context lengths, the context becomes too large for the hidden state, forcing the model to discard older information to accommodate new data, which hurts most when the needle is located at the beginning of the text.


Needle in a haystack evaluation for Mambaoutai on a 16k context length

To compare Mambaoutai with existing transformer models of the same size, we chose the Stable LM 2 1.6B [8] model released by Stability AI, which was trained with the same context length of \(4096\) on 2T tokens, and ran the identical benchmark to produce the following figure. Stable LM 2 1.6B handles context lengths of approximately 4k and slightly beyond. Beyond this threshold, however, it fails completely when used as is, a limitation we attribute to its positional embeddings, which restrict the context lengths a transformer can handle without extension methods. In contrast, the Mamba architecture is free from this constraint, which gives it a distinct advantage in processing extended contexts.


Needle in a haystack evaluation for Stable LM 2 1.6B on a 16k context length

Whether larger Mamba models can surpass the performance of transformer-based architectures remains unanswered. Expanding the size of a Mamba model’s hidden state and consequently its model size would likely allow it to retain more context information. However, for a constant model size, there is a compromise between depth and width, thus resulting in a possible tradeoff between processing capabilities and long context handling. Moreover, the lack of positional embeddings in Mamba models means they can naturally handle larger contexts without worrying about pre-training context length.

5.2 Academic Benchmarks

Generally, benchmarks ought to be approached with caution due to the risk of contamination or manipulation, especially when models are trained on trillions of tokens. With this in mind, we provide indicative results to give an idea of model performance on popular benchmarks. In line with our training data mixture, whose two largest subsets are French and English, we evaluated our model on two sets of benchmarks: FrenchBench and the Open LLM Leaderboard benchmarks (English). As references for comparison, we chose two similarly sized models: Stable LM 2 1.6B and Falcon-RW-1B (1.3B). Before diving into the results, note that these models were trained on very different numbers of tokens, and model performance is highly correlated with the number of training tokens. The amount of data used to train each model is as follows:

| Model | French data (tokens) | English data (tokens) | Total (tokens) |
|---|---|---|---|
| Mambaoutai 1.6B | 100B | 150B | 300B |
| Stable LM 2 1.6B | 20B | 1675B | 2000B |
| Falcon-RW-1B | 0 | 350B | 350B |

First, we note that Mambaoutai 1.6B was trained on the fewest tokens, followed by Falcon-RW-1B and then, by a large margin, Stable LM 2 1.6B with 2T tokens. Second, Falcon-RW-1B was trained mainly on English, and Stable LM 2 1.6B used only 20B tokens of French, while Mambaoutai was trained on 100B tokens of French. Finally, Stable LM 2 1.6B was trained on around 1675B tokens of English, followed by Falcon-RW-1B with 350B tokens and Mambaoutai 1.6B with 150B tokens. The Open LLM Leaderboard benchmarks include common sense reasoning benchmarks such as HellaSwag (10-shot), WinoGrande (5-shot) and ARC-Challenge (25-shot), as well as GSM8K (5-shot), TruthfulQA (zero-shot) and MMLU (5-shot).

As expected given these differences in pre-training data, Stable LM 2 1.6B typically performs better than both Falcon-RW-1B and Mambaoutai 1.6B on the English-centric benchmarks, with the exception of TruthfulQA, where Falcon-RW-1B outperforms the other two models. On the remaining benchmarks, Mambaoutai 1.6B outperforms Falcon-RW-1B in three cases and matches it in two others. For GSM8K, note that Stable LM 2 1.6B is the only model with dedicated math datasets in its pre-training data; Mambaoutai and Falcon-RW-1B have no such data and reach an accuracy close to zero.


Open LLM Leaderboard Benchmarks

The FrenchBench consists of an array of classification and generation tasks covering various orthogonal aspects of model performance in the French language. In particular, it includes the French Language Test, which assesses the grammar and vocabulary capabilities of models through language tests. It provides a structured evaluation of a model’s linguistic proficiency, measuring its competency in understanding and generating coherent and grammatically accurate sentences in French. Moreover, FrenchBench also includes popular benchmarks such as HellaSwag and ARC-Challenge, translated with GPT-3.5.

Finally, on French, Mambaoutai outperforms both Stable LM 2 1.6B and Falcon-RW-1B on tasks such as grammar, vocabulary and topic-based Natural Language Inference, while it still lags behind on the reasoning-heavy tasks translated from English with GPT-3.5. We attribute these results to differences in training data, as Mambaoutai 1.6B was trained on more French tokens than both Stable LM 2 1.6B and Falcon-RW-1B.


French Bench for French language evaluations

In conclusion, the model underperforms its similarly sized transformer counterpart Stable LM 2 1.6B [8] on standard English NLP benchmarks, but is on par with it on French benchmarks. The gap on English benchmarks can largely be attributed to the fact that Stable LM 2 1.6B was trained on more than 6× as many tokens as Mambaoutai, and that Mambaoutai has a substantially lower share of English in its training data because we wanted it to be capable in French and code. We also acknowledge that increasing the model width, to achieve higher throughput and a hidden state that can compress more information from long contexts, may have weakened its reasoning capabilities, as the model depth had to be decreased. We encourage further research on this topic for Mamba-like architectures.

6. Conclusion

In this blog post, we have presented the details of pre-training a relatively large language model based on the recently introduced Mamba architecture. We have shared the results of our explorations to give some insights to practitioners who want to train a Mamba model. The insights on the WSD scheduler and positional weighting should also transfer to other architectures such as Transformers, since both techniques are architecture-agnostic for autoregressive language modeling.
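Because the WSD schedule depends only on the optimizer step and not on the model, it can be sketched independently of the architecture. The function below is a minimal illustration assuming a linear warmup, a constant plateau, and a linear decay to `min_lr`. The decay shape, the function name `wsd_lr`, and its parameters are our assumptions for this sketch, and other decay shapes are possible.

```python
def wsd_lr(step: int, peak_lr: float, warmup_steps: int, stable_steps: int,
           decay_steps: int, min_lr: float = 0.0) -> float:
    """Warmup-Stable-Decay learning rate at a given optimizer step.

    Illustrative assumptions, not Mambaoutai's exact settings: linear warmup
    up to peak_lr, a constant plateau, then a linear decay to min_lr that
    stays at min_lr afterwards.
    """
    if step < warmup_steps:
        # Linear warmup; (step + 1) avoids a zero learning rate at step 0.
        return peak_lr * (step + 1) / warmup_steps
    if step < warmup_steps + stable_steps:
        # Stable phase: constant learning rate. A checkpoint taken here can be
        # resumed for continued pre-training by simply extending stable_steps.
        return peak_lr
    if decay_steps <= 0:
        return min_lr
    # Decay phase: interpolate linearly from peak_lr down to min_lr.
    progress = min((step - warmup_steps - stable_steps) / decay_steps, 1.0)
    return peak_lr + (min_lr - peak_lr) * progress
```

To use it with PyTorch, one option is `torch.optim.lr_scheduler.LambdaLR`, which multiplies the optimizer's base learning rate by the factor returned from `lr_lambda`; the lambda would then return `wsd_lr(step, ...) / peak_lr`. Since the learning rate is constant throughout the stable phase, resuming from a stable-phase checkpoint only requires extending `stable_steps`, and the decay phase can be started whenever a final model is needed.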

First, we have validated that the WSD scheduler is a promising alternative to the cosine scheduler: it gives better performance on validation metrics while also enabling smooth continued pre-training and a customizable decay phase. We also found that positional weighting of the language modeling loss improves the resulting model's top-k accuracy. Another benefit of using a state space model is its native support for long contexts: Mambaoutai outperformed highly optimized Transformer models of the same size on the needle-in-a-haystack task. Finally, we are releasing checkpoints from various stages of training so that others can evaluate the final performance, continue pre-training, or run a decay phase better suited to their downstream applications.
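The needle-in-a-haystack task hides a short fact (the "needle") at a chosen depth inside long filler text and checks whether the model can retrieve it. The sketch below shows one way to build such a prompt. The helper names, the prompt format, and the use of sentence counts rather than token counts to position the needle are simplifying assumptions, not the exact evaluation setup used for Mambaoutai.

```python
def build_haystack_prompt(needle: str, filler_sentences: list[str],
                          depth: float, question: str) -> str:
    """Insert `needle` at a relative depth (0.0 = start, 1.0 = end) of the filler text.

    Hypothetical helper: depth is measured in sentences, not tokens.
    """
    if not 0.0 <= depth <= 1.0:
        raise ValueError("depth must be in [0, 1]")
    insert_at = round(depth * len(filler_sentences))
    sentences = filler_sentences[:insert_at] + [needle] + filler_sentences[insert_at:]
    context = " ".join(sentences)
    return f"{context}\n\nQuestion: {question}\nAnswer:"


def needle_found(completion: str, expected: str) -> bool:
    """Score a completion as a hit if it contains the expected answer (case-insensitive)."""
    return expected.lower() in completion.lower()
```

A full evaluation would sweep both the context length and the needle depth and report the retrieval rate for each combination.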

Future work:

  • Investigate positional weighting further: experiment with more advanced weighting functions and evaluate their impact on downstream performance.
  • Test retrieval and long-context capabilities beyond the needle-in-a-haystack task. Fine-tuning the model on longer contexts is a viable option if we want to use it as a document scanner before feeding the retrieved results to a bigger model.
  • Train a bigger pure Mamba model to validate its effectiveness against the Transformer once and for all.
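One way to experiment with different weighting functions is to compute per-position losses and combine them with a pluggable weight function. In the sketch below, `linear_ramp` is a hypothetical weighting that upweights later positions; it is a placeholder, not the weighting function used for Mambaoutai. The `top_k_accuracy` helper shows the metric on which we observed the improvement.

```python
from typing import Callable, Sequence

WeightFn = Callable[[int, int], float]


def linear_ramp(position: int, seq_len: int) -> float:
    """Hypothetical weighting: grows linearly from 1 at the first position to 2 at the last."""
    return 1.0 + position / max(seq_len - 1, 1)


def positional_weights(seq_len: int, weight_fn: WeightFn) -> list[float]:
    """Evaluate weight_fn at every position of a sequence."""
    return [weight_fn(i, seq_len) for i in range(seq_len)]


def weighted_lm_loss(token_losses: Sequence[float], weights: Sequence[float]) -> float:
    """Weighted average of per-position next-token cross-entropy losses.

    Dividing by the sum of weights keeps the result on the same scale as the
    plain (unweighted) mean loss.
    """
    if len(token_losses) != len(weights):
        raise ValueError("one weight per position is required")
    return sum(w * l for w, l in zip(weights, token_losses)) / sum(weights)


def top_k_accuracy(logits: Sequence[Sequence[float]], targets: Sequence[int], k: int) -> float:
    """Fraction of positions whose target token is among the k highest-scoring tokens."""
    hits = 0
    for scores, target in zip(logits, targets):
        top_k = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)[:k]
        hits += target in top_k
    return hits / len(targets)
```

In PyTorch, the per-position losses could come from `torch.nn.functional.cross_entropy(..., reduction="none")`, after which the weighting is an element-wise multiply followed by a normalized sum. Swapping `linear_ramp` for another function is then the only change needed to try a different weighting.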

Limitations:

  • Mambaoutai is trained only on French, English and code. Other languages are not supported.
  • Constrained by its size, the model may exhibit hallucinations and repetitions, particularly with longer and more elaborate prompts.
  • Constrained by its capacity, the model's knowledge recall has limited accuracy.
  • The evaluation of the WSD scheduler and positional weighting is still very limited.

To cite this work, please use the following BibTeX:

@misc{mambaoutai,
  title={Passing the Torch: Training a Mamba Model for Smooth Handover},
  author={Hallström, Oskar and Taghadouini, Said and Thiriet, Clément and Chaffin, Antoine},
  url={https://www.lighton.ai/blog/lighton-s-blog-4/passing-the-torch-training-a-mamba-model-for-smooth-handover-54},
  year={2024}
}

References

[1] Albert Gu, Tri Dao. Mamba: Linear-Time Sequence Modeling with Selective State Spaces. https://arxiv.org/abs/2312.00752 (2023)

[2] MiniCPM: Unveiling the Potential of End-side Large Language Models. https://shengdinghu.notion.site/MiniCPM-Unveiling-the-Potential-of-End-side-Large-Language-Models-d4d3a8c426424654a4e80e42a711cb20

[3] Nanotron. https://github.com/huggingface/nanotron (as of 28/03/2024)

[4] Jordan Hoffmann et al. Training Compute-Optimal Large Language Models. https://arxiv.org/pdf/2203.15556.pdf (2022)

[5] Amirkeivan Mohtashami, Martin Jaggi. Landmark Attention: Random-Access Infinite Context Length for Transformers. https://arxiv.org/abs/2305.16300 (2023)

[6] Jack W. Rae et al. Scaling Language Models: Methods, Analysis & Insights from Training Gopher. https://arxiv.org/abs/2112.11446 (2021)

[7] Colin Raffel et al. Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer. https://arxiv.org/abs/1910.10683 (2019)

[8] Marco Bellagente et al. Stable LM 2 1.6B Technical Report. https://arxiv.org/abs/2402.17834 (2024)

[9] https://www.wired.com/story/dbrx-inside-the-creation-of-the-worlds-most-powerful-open-source-ai-model/

[10] Almazrouei, Ebtesam, et al. The Falcon Series of Open Language Models. arXiv preprint arXiv:2311.16867 (2023).

[11] Xu et al. Automatic cross-replica sharding of weight update in data-parallel training. arXiv preprint arXiv:2004.13336 (2020).

[12] Rajbhandari, Samyam, et al. “Zero: Memory optimizations toward training trillion parameter models.” SC20: International Conference for High Performance Computing, Networking, Storage and Analysis. IEEE, 2020.

[13] Micikevicius, Paulius, et al. “Mixed precision training.” arXiv preprint arXiv:1710.03740 (2017).

[14] Shoeybi, Mohammad, et al. “Megatron-LM: Training multi-billion parameter language models using model parallelism.” arXiv preprint arXiv:1909.08053 (2019).

[15] Chen, Tianqi, et al. “Training deep nets with sublinear memory cost.” arXiv preprint arXiv:1604.06174 (2016).

[16] Ibrahim, Adam, et al. “Simple and Scalable Strategies to Continually Pre-train Large Language Models.” arXiv preprint arXiv:2403.08763 (2024).

[17] Ash, Jordan T., and Ryan P. Adams. “On the difficulty of warm-starting neural network training.” (2019).

[18] DeepSpeed: Extreme-scale model training for everyone. https://www.microsoft.com/en-us/research/blog/deepspeed-extreme-scale-model-training-for-everyone/ (2020).

[19] BFloat16: The secret to high performance on Cloud TPUs. https://cloud.google.com/blog/products/ai-machine-learning/bfloat16-the-secret-to-high-performance-on-cloud-tpus (2019).

[20] YaRN: Efficient Context Window Extension of Large Language Models. https://arxiv.org/abs/2309.00071 (2023).

[21] Efficient Streaming Language Models with Attention Sinks. https://arxiv.org/abs/2309.17453 (2023).

[22] Hallström, Oskar et al. “Alfred-40B-1023” https://huggingface.co/lightonai/alfred-40b-1023 (2023)

[23] The Mosaic ML Team. “Composer” (2021)

[24] Ning Ding et al. Enhancing Chat Language Models by Scaling High-quality Instructional Conversations (2023) https://arxiv.org/abs/2305.14233

[25] Open Language Model: OLMo (2024) https://allenai.org/olmo

[26] Zhai et al. Scaling Vision Transformers (2021) https://arxiv.org/abs/2106.04560
