## Introduction
One million lines of `python` code. Through them, the `transformers` library supports more than 400 model architectures, from state-of-the-art LLMs and VLMs to specialized models for audio, video, and tables.
Built on `PyTorch`, it's a foundational tool for modern LLM usage, research, education, and tens of thousands of other open-source projects. Each AI model is added by the community, harmonized into a consistent interface, and tested daily on a CI to ensure reproducibility.
This scale presents a monumental engineering challenge.
How do you keep such a ship afloat, made of so many moving, unrelated parts, contributed to by a buzzing hivemind? Especially as the pace of ML research accelerates? We receive constant feedback on everything from function signatures with hundreds of arguments to duplicated code and optimization concerns, and we listen to all of it, or try to. The library's usage keeps on growing, and we are a small team of maintainers and contributors, backed by hundreds of open-source community members. We continue supporting all models that come out, and intend to do so for the foreseeable future.
This post dissects the design philosophy that makes this possible. It's a continuation of our older principles, detailed on our previous [philosophy](https://huggingface.co/docs/transformers/en/philosophy) page and its accompanying [blog post from 2022](https://huggingface.co/blog/transformers-design-philosophy). More recently, we wrote a blog post about [recent upgrades to transformers](https://huggingface.co/blog/faster-transformers), which I recommend reading if you haven't yet; it explains in particular what makes the library faster today. Again, all of that development was only made possible thanks to these principles.
We codify the "tenets" that guide our development, demonstrate how they are implemented in code, and show the measurable impact they have on the library's sustainability and growth.
For any OSS maintainer, power user, or contributor, this is the map to understanding, using, and building upon `transformers`, but not only that: any project of comparable size will require you to make deep choices, not just about design and abstraction, but about the very mindset of the software you are building.
[Tenets exemplified](#source-of-truth) will have their summary available on hover.
[External links](https://huggingface.co/blog/welcome-openai-gpt-oss) to articles will help you solidify your knowledge.
[Several interactive visualisations](#generated-modeling) are available as you go - scroll, zoom, drag away.
Throughout this post, you'll find breadcrumb boxes like this one. They summarize what you just learned, connect it to the tenets, and point to what's coming next. Think of them as narrative signposts to help you keep track.
## The core tenets of transformers
We summarize here the foundations on which we've built everything, and write down the "tenets" of the library. They behave like _software interfaces_, hence it is crucial that they are explicitly written down. However opinionated they are, they have evolved over time.
Note that the library _evolved_ towards these principles: they _emerged_ from decisions taken along the way, and once they had emerged, they were recognized as critical.
Take the function below, which appears in many `modeling_*.py` files across `src/transformers/models/`. Why keep it? Because we want all the model logic to be [contained in the modeling file](#one-model-one-file). In order to do that, we [do repeat ourselves](#do-repeat-yourself).
```python
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
```
You can use a simple regex to find all methods of a given name across the codebase and compare their differences and similarities; that's what I did, plus a hash to group identical implementations and avoid quadratic pairwise comparisons.
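Here is a minimal sketch of that kind of scan, using `ast` plus a hash rather than a raw regex; the function name and repository layout are the only assumptions.
```python
# A minimal sketch of such a scan (not the exact script used for this post):
# collect every definition of a given function name across modeling files,
# hash the parsed bodies to group identical copies, then diff one
# representative per group instead of comparing everything pairwise.
import ast
import hashlib
from collections import defaultdict
from pathlib import Path

def find_copies(repo_root: str, func_name: str = "rotate_half"):
    groups = defaultdict(list)  # body hash -> files containing that implementation
    for path in Path(repo_root, "src/transformers/models").rglob("modeling_*.py"):
        tree = ast.parse(path.read_text())
        for node in ast.walk(tree):
            if isinstance(node, ast.FunctionDef) and node.name == func_name:
                digest = hashlib.sha1(ast.dump(node).encode()).hexdigest()
                groups[digest].append(str(path))
    return groups

# for digest, files in find_copies(".").items():
#     print(digest[:8], len(files), "copies")
```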
We want all models to have self-contained modeling code.
Every core functionality _must_ be in the modeling code, every non-core functionality _can_ be outside of it.
This comes at a great cost. Enter the `#Copied from...` mechanism: for a long time, these comments indicated that some code was copied from another model, saving time both for the reviewers and for the CI. But the LOC count kept creeping up. Each new model copied over hundreds of lines that we considered largely boilerplate, yet we could not remove them.
We needed to separate two principles that had so far been intertwined, [repetition](#do-repeat-yourself) and [hackability](#one-model-one-file).
What was the solution to this?
## Modular transformers
Transformers is an opinionated library. The previous [philosophy](https://huggingface.co/docs/transformers/en/philosophy) page and the [blog post](https://huggingface.co/blog/transformers-design-philosophy) were already pointing at the drawbacks mentioned just above, which have been iteratively addressed. [`modular` transformers were introduced](https://huggingface.co/docs/transformers/en/modular_transformers), allowing a form of inheritance without breaking [One model, One file](#one-model-one-file).
We amended the principle of [DRY*](#do-repeat-yourself) by progressively removing all pieces of code that were "copied from" another file.
It works as follows: in order to contribute a model, say GLM for instance, you define a `modular_` file that can inherit from _any function across all other modeling, configuration and processor files_.
Auto-generated modeling code
{{{fragment-glm-compare}}}
As you can see, we can now define any model as a _modular_ of another.
You might think "well that's just how inheritance works". The crucial difference is that we do _visibly_ what is essentially the _compiler_'s job: by unrolling the inheritances, we make visible all of the modeling code, keeping it [all in one piece](#one-model-one-file).
What is the consequence? When adding a model, we do not need to go over the entire modeling file. The modular (left side above) is enough.
When `AutoModel.from_pretrained(...)` is called, it is indeed the modeling (right side) that is run, and all the tests run on the modeling code.
What does that give us?
A small `modular_*.py` declares reuse; the expanded modeling file stays visible (tenet kept). Reviewers and contributors maintain the shard, not the repetition. Next: the measurable effect on effective LOC and maintenance cost.
### A maintainable control surface
The effect of modular can be measured straight from git history: at every commit, we look under each model directory.
If it only has a modeling file, we add its LOC count.
However, if a model has a `modular_*.py` and a corresponding automatically generated `modeling_*.py`, we only count the LOC of the modular file. The modeling code has no maintenance cost, as it is strictly dependent on the modular file.
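As a rough sketch, the counting rule looks like this (assuming the usual one-directory-per-model layout under `src/transformers/models/`); run it at each commit of interest to build the curve.
```python
# A rough sketch of the effective-LOC rule, not the exact measurement script:
# if a modular_*.py shard exists, only its lines count toward the maintenance
# surface; otherwise the modeling file(s) count in full.
from pathlib import Path

def effective_loc(models_root: str) -> int:
    total = 0
    for model_dir in Path(models_root).iterdir():
        if not model_dir.is_dir():
            continue
        modular_files = list(model_dir.glob("modular_*.py"))
        counted = modular_files or list(model_dir.glob("modeling_*.py"))
        total += sum(len(f.read_text().splitlines()) for f in counted)
    return total
```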
That gives an "effective LOC" curve: the **maintenance surface**.
Just look at the result: the growth rate of lines of code collapsed! Counting raw `modeling_*.py` files (with "Copied from…" everywhere) we were adding around 362 new LOC/day; with `modular` in place the effective rate is ~25 LOC/day. About 15× lower! Had we continued with a strict "one model, one file" policy, who knows where we'd have ended up.
Less code to hand-maintain means fewer places to break: cyclomatic complexity isn't LOC, but the two strongly correlate.
{{{fragment-loc-growth}}}
There's a sharp drop near the end; it's due to us [removing support for Jax and TensorFlow](https://github.com/huggingface/transformers/commit/4df2529d79d75f44e70396df5888a32ffa02d61e#diff-60849db3e9922197854ef1cac92bf4aba08b5d7fd3fe6f3c16a3511e29e0eacc) library-wide.
Of course, this is not the only effort that allowed us to reduce the maintenance load.
A related optimization concerns attention. You've likely heard about [flash attention](https://huggingface.co/docs/text-generation-inference/en/conceptual/flash_attention) and its several variants.
The _attention computation_ itself happens at a _lower_ level of abstraction than the model itself.
However, we were adding backend-specific torch operations for each implementation (SDPA, the various flash-attention versions, flex attention), and that was not a [minimal user api](#minimal-user-api).
Evidence: effective LOC drops ~15× when counting shards instead of expanded modeling. Less to read, fewer places to break. Related cleanups: attention backends moved behind a function interface. Next: how the attention interface stays standard without hiding semantics.
### External Attention classes
We moved to an [attention interface](https://huggingface.co/docs/transformers/en/attention_interface) that allowed the following:
We keep a `Callable` for the naive implementation of attention, the "eager" computation. This Callable is named `eager_attention_forward`, and it can run as long as the user has `torch` installed, which is a requirement in any case.
In other words, we moved from a class interface to a function interface: to use more complex attention implementations, the config is checked and other Callables can be dispatched, including kernel bindings that are much faster, if they are available.
This exemplifies the fact that we prefer to have an interface that is [standard, but not abstract](#standardize-dont-abstract).
```python
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
```
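To illustrate how a new backend slots in without touching the modeling code, here is a sketch of registering a custom Callable; the [attention interface documentation](https://huggingface.co/docs/transformers/en/attention_interface) is the reference for the exact API, and the checkpoint name is just an example.
```python
# A sketch, not an official recipe: wrap an existing attention implementation,
# register it under a new name, and select it via the config. The modeling
# code itself stays untouched.
from transformers import AttentionInterface, AutoModelForCausalLM
from transformers.integrations.sdpa_attention import sdpa_attention_forward

def traced_sdpa(*args, **kwargs):
    # log, profile, or tweak here, then defer to the stock SDPA implementation
    return sdpa_attention_forward(*args, **kwargs)

AttentionInterface.register("traced_sdpa", traced_sdpa)

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct", attn_implementation="traced_sdpa"
)
```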
A strength of the new attention interface is the possibility to enforce specific kwargs, which are needed by kernel providers and other dependencies. We know that kwargs are often a necessary evil that plagues tools with widespread compatibility; it is something we have aimed to reduce, and will continue to reduce, in order to improve readability. Even with them, the current system remains a [minimal user api](#minimal-user-api).
For better _information_, we plan to use `python` features such as `Annotated` to inform users of what we typically expect in an argument. That way, higher-level information could be included directly in the type annotations, like so (tentative design):
```python
from typing import Annotated
MyModelOutputAnnotated = Annotated[MyModelOutput, "shape: (B, C, H, W)"]
```
Semantics remain in eager_attention_forward; faster backends are opt-in via config. We inform via types/annotations rather than enforce rigid kwargs, preserving integrations. Next: distribution concerns are declared as a plan, not model surgery.
### Configurable Tensor Parallelism
If you're not familiar with the different flavours of parallelism, I recommend checking out [this blog post](https://huggingface.co/blog/accelerate-nd-parallel) first, and of course a full [dive into the ultra-scale playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook) is always recommended.
The essential part is that, as [the documentation states](https://huggingface.co/docs/transformers/v4.56.2/perf_train_gpu_many#tensor-parallelism), when tensors get too large to fit on a single GPU, they are sliced along a particular dimension and every slice is sent to a different GPU.
Why does it matter?
Because we want to avoid code modifications that are unrelated to the model.
We choose to place the level of abstraction above device placement: a matrix multiplication, i.e. an `nn.Linear` layer, should always be expressed in the same way, regardless of how it is sharded across devices.
Hence, we want to touch the modeling code [minimally](#minimal-user-api), and only modify it when _architectural changes_ are involved. For tensor parallelism, we instead specify a simple `tp_plan`.
The alternative would be to modify parent classes for each parallelism scheme, leaking device-placement concerns into the modeling code.
It is written once in the config and passed to `.from_pretrained()`. The plan maps module name patterns to partitioning strategies. Strategies are resolved by the internal `ParallelInterface`, which wires to sharding implementations such as `ColwiseParallel`, `RowwiseParallel`, packed variants, and so on.
{{{fragment-tp-plan}}}
This allows a user to run with multiple processes per node, e.g. 4 GPUs:
`torchrun --nproc-per-node 4 demo.py`
Semantics stay in the model (a Linear stays a Linear), while distribution is orthogonal and declared via strings: "colwise" splits columns of weights/bias across ranks; "rowwise" splits rows; packed variants shard fused weights. The mapping keys accept glob patterns like `layers.*.mlp.down_proj` to target repeated submodules.
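For illustration, a plan could look like the following, assuming a Llama-style module layout; the real plans ship inside each model's config, so the exact keys vary per architecture.
```python
# Illustrative only: keys are glob patterns over module names, values are the
# partitioning strategies resolved by the internal ParallelInterface.
tp_plan = {
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.self_attn.k_proj": "colwise",
    "layers.*.self_attn.v_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
    "layers.*.mlp.gate_proj": "colwise",
    "layers.*.mlp.up_proj": "colwise",
    "layers.*.mlp.down_proj": "rowwise",
}
```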
Sharding is configuration (tp_plan), not edits to Linears. Glob patterns target repeated blocks; modeling semantics stay intact. Next: per-layer attention/caching schedules declared in config, not hardcoded.
### Layers, attentions and caches
Following the same logic, the _nature_ of attention and caching per layer of a model should not be hardcoded. We should be able to specify, in a configuration-based fashion, how each layer is implemented. Thus we defined a mapping of allowed layer types, which configurations can then reference:
```python
ALLOWED_LAYER_TYPES = (
"full_attention",
"sliding_attention",
"chunked_attention",
"linear_attention",
...
)
```
and the configuration can be _explicit_ about which attention type is in which layer; see e.g. gpt-oss, which alternates sliding and full attention:
```python
"layer_types": [
"sliding_attention",
"full_attention",
...,
"sliding_attention",
"full_attention"
],
```
This is [minimal](#minimal-user-api) to implement on the user side, and it allows us to keep the modeling untouched. It is also easy to tweak.
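As a minimal sketch (not the actual transformers implementation), a decoder layer can simply read its flavour from the config:
```python
# A minimal sketch, assuming the config exposes `layer_types` and
# `sliding_window` as in the example above; the real layers do more work.
import torch.nn as nn

class SketchDecoderLayer(nn.Module):
    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.attention_type = config.layer_types[layer_idx]
        # only sliding layers carry a window size; full-attention layers do not
        self.sliding_window = (
            config.sliding_window if self.attention_type == "sliding_attention" else None
        )
```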
Allowed layer types are explicit; schedules (e.g., sliding/full alternation) live in config. This keeps the file readable and easy to tweak. Next: speedups come from kernels that don't change semantics.
### Community Kernels
The same principle extends to normalization, activation, and other code paths. The model defines **semantics**; a kernel defines **how** to execute them faster. We annotate the module to borrow a community-provided forward, keeping a [consistent public surface](#consistent-public-surface).
```python
@use_kernel_forward_from_hub("RMSNorm")
class GlmRMSNorm(nn.Module):
...
```
Plus, this opened another angle of contribution for the community. People who are GPU whisperers can now contribute optimized kernels. You can check on the [kernel community blog post](https://huggingface.co/blog/hello-hf-kernels) to learn more about it!
Even more resources have been added, like the formidable [kernel builder](https://github.com/huggingface/kernel-builder) with its connected resources to [help you build kernels with it](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) and [with nix](https://github.com/huggingface/kernel-builder/blob/main/docs/nix.md).
Models define semantics; kernels define how to run them faster. Use annotations to borrow community forwards while keeping a consistent public surface. Next: what modularity looks like across the repo.
## Modular developments
Now, we have a form of inheritance in our codebase. Some models become standards, and model contributors are given the opportunity to _define standards_. Pushing the boundaries of scientific knowledge can also push the boundaries of engineering, if this effort is made, and we're striving for it.
It's hard to conceptualize very large libraries and how their components interact with each other, regardless of your cognitive abilities for abstractions.
So I wanted to take a look at the current **state of modularity** across the repository. How many models are defined using components of others?
To get this graph, I used the heuristic of modular inheritance (a short sketch of it follows the list below).
1. Does this model have a `modular` file?
2. In this `modular` file, what models, configurations and processings are imported?
3. Recurse through the model list that way.
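A minimal sketch of that heuristic, assuming the standard repository layout (one directory per model, with `modular_*.py` importing from sibling models' modeling, configuration, and processing modules):
```python
# A sketch of the graph-building heuristic: parse each modular_*.py and record
# which other models' modules it imports from; each import is an edge.
import ast
from pathlib import Path

def modular_edges(models_root: str):
    edges = set()
    for modular in Path(models_root).rglob("modular_*.py"):
        child = modular.parent.name
        tree = ast.parse(modular.read_text())
        for node in ast.walk(tree):
            if isinstance(node, ast.ImportFrom) and node.module:
                parts = node.module.split(".")
                # matches `from ..llama.modeling_llama import ...` as well as
                # absolute `transformers.models.llama.modeling_llama` imports
                for i, part in enumerate(parts):
                    if i > 0 and part.startswith(("modeling_", "configuration_", "processing_")):
                        parent = parts[i - 1]
                        if parent != child:
                            edges.add((child, parent))
    return sorted(edges)
```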
So what do we see? Llama is a basis for many models, and it shows.
Radically different architectures such as mamba have spawned their own dependency subgraph.
{{{fragment-dependency-graph}}}
However, even if llava defines a few VLMs, there are far too many vision-based architectures that are not yet defined as modulars of other existing architectures. In other words, there is no strong reference point in terms of software for vision models.
As you can see, there is a small DETR island, a little llava pocket, and so on, but it's not comparable to the centrality observed for llama.
Another problem: this only covers `modular` models. Several models do NOT have a modular file.
How do we spot them, and how do we identify modularisable models?
Graph reading guide: nodes are models; edges are modular imports. Llama-lineage is a hub; several VLMs remain islands, an engineering opportunity for shared parents. Next: timeline + similarity signals to spot candidates.
### Many models, but not enough yet, are alike
So I looked into Jaccard similarity, which measures how much two sets overlap. I know that code is more than a set of characters strung together. I also tried code embedding models to measure code similarity, and they yielded better results, but for the needs of this blog post I will stick to the Jaccard index.
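For reference, the Jaccard index over two files seen as sets of tokens is just the size of the overlap divided by the size of the union; a tiny sketch:
```python
# Jaccard index over token sets: a crude but fast similarity signal between
# two modeling files (code-embedding models give finer-grained results).
import re

def jaccard(code_a: str, code_b: str) -> float:
    tokens_a = set(re.findall(r"\w+", code_a))
    tokens_b = set(re.findall(r"\w+", code_b))
    if not tokens_a and not tokens_b:
        return 1.0
    return len(tokens_a & tokens_b) / len(tokens_a | tokens_b)
```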
It is interesting, for that, to look at _when_ we deployed this modular logic and what its rippling effect on the library was. You can check the [larger space](https://huggingface.co/spaces/Molbap/transformers-modular-refactor) to play around, but the gist is: adding modular allowed us to connect more and more models to solid reference points. We still have a lot of gaps to fill.
{{{fragment-model-timeline}}}
If you've checked out llava, you've seen that llava_video is a red node, connected by a red edge to llava: it's a candidate, something that we can _likely_ remodularize, [not touching the actual model](#backwards-compatibility) while making it much more readable with [DRY*](#do-repeat-yourself).
Similarity (Jaccard; embeddings tried separately) surfaces likely parents; the timeline shows consolidation after modular landed. Red nodes/edges = candidates (e.g., llava_video → llava) for refactors that preserve behavior. Next: concrete VLM choices that avoid leaky abstractions.
### VLM improvements, avoiding abstraction
We don't have a cookbook for common VLM patterns (image token scatter, multi-tower encoders, cross-attention bridges). This is one of the main areas where we can improve.
For instance, we thought of abstracting away the mixing of `inputs_embeds`, the tensor fed into the LLM decoder in 95% of the existing VLMs. It would have looked something like this:
```python
class InputsEmbeddingMixerMixin(nn.Module):
    # ... the mixing of image/video features into inputs_embeds would live here ...
```
But this is [abstracting away an important component of the modeling](#standardize-dont-abstract). The embedding mixer is part of the model; removing it would break it. A user opening [`modeling_qwen2.5_vl`](https://huggingface.co/collections/Qwen/qwen25-vl-6795ffac22b334a837c0f9a5) should not have to go to another file to understand how it works.
This is the current state of abstractions across a modeling file:

The following [Pull request to standardize placeholder masking](https://github.com/huggingface/transformers/pull/39777) is a good example of what kind of changes are acceptable. In a VLM, we always need to insert embeddings from various encoders at various positions, so we can have a function to do it. For Qwen2 VL, for instance, it will look like this:
```python
def get_placeholder_mask(
self,
input_ids: torch.LongTensor,
inputs_embeds: torch.FloatTensor,
image_features: torch.FloatTensor = None,
video_features: torch.FloatTensor = None,
):
"""
    Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
equal to the length of multimodal features. If the lengths are different, an error is raised.
"""
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
special_video_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_video_mask = special_video_mask.all(-1)
else:
special_image_mask = input_ids == self.config.image_token_id
special_video_mask = input_ids == self.config.video_token_id
n_image_tokens = special_image_mask.sum()
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel():
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}"
)
n_video_tokens = special_video_mask.sum()
special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel():
raise ValueError(
f"Videos features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}"
)
return special_image_mask, special_video_mask
```
But this is _within_ the modeling file, not in the `PreTrainedModel` base class. It will not move away from it, because it'd break the [self-contained logic](#one-model-one-file) of the model.
Keep VLM embedding mix in the modeling file (semantics), standardize safe helpers (e.g., placeholder masking), don't migrate behavior to PreTrainedModel. Next: pipeline-level wins that came from PyTorch-first choices (fast processors).
### On image processing and processors
Choosing to be a `torch`-first library meant relieving a tremendous amount of support for `jax` and `TensorFlow`, and it also meant that we could be more liberal in the amount of torch-dependent utilities we added. One of these is the _fast processing_ of images. Where inputs were previously assumed to be minimal ndarrays, making stronger assumptions and enforcing `torch`- and `torchvision`-native inputs allowed us to massively speed up the processing time for each model.
The performance gains are immense: up to a 20x speedup for most models when using compiled torchvision ops. Further, it allows the whole pipeline to run solely on GPU.
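Opting in is a single flag on the auto class; a usage sketch (the checkpoint name is just an example):
```python
# A usage sketch: `use_fast=True` selects the torch/torchvision-backed image
# processor when one exists for the checkpoint, keeping outputs as torch tensors.
from transformers import AutoImageProcessor

processor = AutoImageProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf", use_fast=True)
# inputs = processor(images=image, return_tensors="pt")
```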

Thanks to Yoni Gozlan for the great work!
Torch-first lets processors assume torch/torchvision and run the whole pipeline on GPU; big per-model speedups. Next: how this lowers friction for contributors and downstream users.
## Reduce barrier to entry/contribution
This is an overall objective: there's no `transformers` without its community.
Being a framework means forcing users into it; it restrains the flexibility and creativity that are the fertile soil for new ideas to grow.
Among the most valuable contributions to `transformers` is of course the addition of new models. Very recently, [OpenAI added GPT-OSS](https://huggingface.co/blog/welcome-openai-gpt-oss), which prompted the addition of many new features to the library in order to support [their model](https://huggingface.co/openai/gpt-oss-120b).
A second one is the ability to fine-tune and pipeline these models into many other pieces of software. Check here on the hub how many finetunes are registered for [gpt-oss 120b](https://huggingface.co/models?other=base_model:finetune:openai/gpt-oss-120b), despite its size!
The shape of a contribution: add a model (or variant) with a small modular shard; the community and serving stacks pick it up immediately. Popularity trends (encoders/embeddings) guide where we invest. Next: power tools enabled by a consistent API.
### Models popularity
Talking about dependencies, we can look at download counts as a proxy for the popularity of transformers models. One thing we see is the prominence of encoders: this is because encoders are mostly used for embeddings; check out [EmbeddingGemma](https://huggingface.co/blog/embeddinggemma) for a modern recap. Hence, it is vital to keep the encoder side of the library viable, usable, and fine-tunable.
{{{fragment-model-visualisation}}}
As the codebase grows, we also need to keep maintaining it alongside our friend library [Sentence Transformers](https://huggingface.co/sentence-transformers). Retrieval use cases and vector databases, like FAISS-based indexing, rely on it, and thus indirectly on transformers.
In that regard, we DO want to be a modular toolbox, being [minimal](#minimal-user-api) enough and well documented enough so any ML/AI developer can use `transformers` without having to think about it. We aim to reduce the cognitive load brought about by model development, not increase it.
So, how do these design choices, these "tenets" influence development of models and overall usage of transformers?
Encoders remain critical for embeddings and retrieval; maintaining them well benefits the broader ecosystem (e.g., Sentence Transformers, FAISS). Next: dev tools that leverage unified attention APIs and PyTorch-only internals.
## A surgical toolbox for model development
### Attention visualisation
All models have the same API internally for attention computation, thanks to [the externalisation of attention classes](#external-attention-classes). This allows us to build cool tools to visualize the inner workings of the attention mechanism.
One particular piece of machinery is the `attention mask`. Here you see the famous bidirectional attention pattern for the whole prefix (text + image) in PaliGemma and all Gemma2+ models, contrasting with the usual "causal-only" models.
{{{fragment-attention-visualizer}}}
Uniform attention APIs enable cross-model diagnostics (e.g., PaliGemma prefix bidirectionality vs causal). Next: whole-model tracing for ports and regressions.
### Logging entire model activations
Further, because it is all PyTorch (even more so now that we support only PyTorch), we can easily [debug any model](https://huggingface.co/docs/transformers/internal/model_debugging_utils) when we want to add it to transformers. We now have a power-user tool for porting or adding models that wraps a forward pass, intercepts every submodule call, and logs shapes, dtypes, and sample statistics of inputs/outputs to nested JSON.
It just works with PyTorch models and is especially useful when aligning outputs with a reference implementation, in line with our [core guideline](#source-of-truth).

Forward interception and nested JSON logging align ports to reference implementations, reinforcing "Source of Truth." Next: CUDA warmup reduces load-time stalls without touching modeling semantics.
### Cooking faster CUDA warmups
Having a clean _external_ API allows us to work on the [true inner workings of transformers](#code-is-product). One of the recent additions is the _CUDA warmup_ via `caching_allocator_warmup`, which massively improved model loading by pre-allocating GPU memory to avoid malloc bottlenecks, achieving a 7x speedup for an 8B model and 6x for a 32B one; check out [the source](https://github.com/huggingface/transformers/pull/36380)!
{{{fragment-warmup_demo}}}
It's hard to overstate how much of a lifesaver that is when you're trying to load a model as fast as possible, as it's the narrowest bottleneck for your iteration speed.
Pre-allocating GPU memory removes malloc spikes (e.g., 7× for 8B, 6× for 32B in the referenced PR). Next: serving benefits directly from consistent interfaces and modularity.
### Transformers-serve and continuous batching
Having all these models readily available allows us to serve any of them with `transformers serve`, interfacing with them through an OpenAI-compatible API. As a reminder, the Hub also gives access to various [inference providers](https://huggingface.co/docs/inference-providers/en/index) if you're interested in model deployment in general.
```bash
# start the server in one terminal
transformers serve

# then query it from another
curl -X POST http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"messages": [{"role": "system", "content": "hello"}], "temperature": 0.9, "max_tokens": 1000, "stream": true, "model": "Qwen/Qwen2.5-0.5B-Instruct"}'
```
This provides an OpenAI-compatible API with features like [continuous batching](https://github.com/huggingface/transformers/pull/38085) (also check [here](https://github.com/huggingface/transformers/pull/40426)) for better GPU utilization.
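Because the surface is OpenAI-compatible, standard clients can talk to it; here is a sketch with the `openai` Python package, pointing at the same local endpoint and model as the curl example above.
```python
# A sketch assuming `transformers serve` is running locally as shown above;
# the api_key is a placeholder, assuming the local server does not check it.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")
response = client.chat.completions.create(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    messages=[{"role": "user", "content": "hello"}],
)
print(response.choices[0].message.content)
```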
Continuous batching is in itself very much linked to the great work of vLLM with the `paged attention kernel`, further justifying the facilitation of [external kernels](#community-kernels).
OpenAI-compatible surface + continuous batching; kernels/backends slot in because the modeling API stayed stable. Next: reuse across vLLM/SGLang relies on the same consistency.
## Community reusability
Transformers-serve is transformers-first, for sure, but the library is made first and foremost to be _reused_ at large by the open-source ecosystem.
Adding a model to transformers means:
- having it immediately available to the community
- having it immediately usable in vLLM, [SGLang](https://huggingface.co/blog/transformers-backend-sglang), and so on without additional code. In April 2025, transformers was added as a backend to run models on vLLM, which optimizes throughput/latency on top of existing transformers architectures [as seen in this great vLLM x HF blog post.](https://blog.vllm.ai/2025/04/11/transformers-backend.html)
This cements the need even more for a [consistent public surface](#consistent-public-surface): we are now a backend, and there's more optimized software than us to handle serving. At the time of writing, more effort is being put in that direction. We already have compatible configs for VLMs for vLLM (say that three times fast), [here for GLM4 video support](https://github.com/huggingface/transformers/pull/40696/files), and here for [MoE support](https://github.com/huggingface/transformers/pull/40132) for instance.
Being a good backend consumer requires a consistent public surface; modular shards and configs make that stability practical. Next: what changes in v5 without breaking the promise of visible semantics.
## What is coming next
The next major version of `transformers` is just around the corner. When v5 is released, we will try to keep [backwards compatibility](#backwards-compatibility) as solid as possible. The changes we make now are there to ensure this.
What we aim to be is much more of a modular toolbox. What we are not is a framework: you should not be FORCED to rewrite every modeling file. Rather, your model should be able to inherit from `PreTrainedModel` and opt into tensor parallelism, `from_pretrained`, sharding, `push_to_hub`, loss computation, as well as PEFT/TRL/SGLang/vLLM compatibility and other fine-tuning and fast inference options.