Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ The project implements a custom runtime that applies many performance optimizati

The following model types are currently supported:

* Encoder-decoder models: Transformer base/big, M2M-100, NLLB, BART, mBART, Pegasus, T5, Whisper T5Gemma
* Encoder-decoder models: Transformer base/big, M2M-100, NLLB, BART, mBART, Pegasus, T5, Whisper, T5Gemma, T5Gemma2
* Decoder-only models: GPT-2, GPT-J, GPT-NeoX, OPT, BLOOM, MPT, Llama, Mistral, Gemma, CodeGen, GPTBigCode, Falcon, Qwen2
* Encoder-only models: BERT, DistilBERT, XLM-RoBERTa

Expand Down
39 changes: 39 additions & 0 deletions docs/guides/transformers.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ CTranslate2 supports selected models from Hugging Face's [Transformers](https://
* Qwen 3
* T5
* T5Gemma
* T5Gemma2
* Whisper
* XLM-RoBERTa

Expand Down Expand Up @@ -631,6 +632,44 @@ print(final_translation)
```


## T5Gemma2

[T5Gemma2](https://huggingface.co/collections/google/t5gemma2-686038abe6c47de48d6d3aa4) is a collection of Google encoder-decoder models that adapts the Gemma 3 architecture to a T5-style encoder-decoder design. The models mix sliding-window and full attention layers.

To convert a model:

```bash
ct2-transformers-converter --model google/t5gemma-2-270m-270m --output_dir t5gemma2_270m_270m.ct2
```

Usage:

```python
import ctranslate2
import transformers

# Load the converted CTranslate2 model and the tokenizer of the original checkpoint.
translator = ctranslate2.Translator("t5gemma2_270m_270m.ct2")
tokenizer = transformers.AutoTokenizer.from_pretrained("google/t5gemma-2-270m-270m")

sentences = ["Question: Why is the sky blue? Answer:"]

# translate_batch is given token strings here, so encode each sentence to ids
# and map the ids back to their token strings.
tokenized_sentences = [
tokenizer.convert_ids_to_tokens(tokenizer.encode(sentence))
for sentence in sentences
]

# Greedy decoding (beam_size=1) with a repetition penalty, capped at 50 decoding steps.
translated_batches = translator.translate_batch(
tokenized_sentences, beam_size=1, repetition_penalty=1.2, max_decoding_length=50
)

# Convert the best hypothesis of each result back to text.
translations = [
tokenizer.decode(tokenizer.convert_tokens_to_ids(t.hypotheses[0]))
for t in translated_batches
]
print(translations[0])
```


## Whisper

[Whisper](https://huggingface.co/docs/transformers/model_doc/whisper) is a multilingual speech recognition model published by OpenAI.
Expand Down
18 changes: 18 additions & 0 deletions include/ctranslate2/layers/attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,23 @@ namespace ctranslate2 {
return _relative_position_keys || _relative_attention_bias || _rotary_embeddings || _alibi;
}

// Tells whether this layer owns a fused K+V projection for the encoder memory
// (`_memory_kv` is set), i.e. whether it was built for merged attention.
bool has_merged_encoder_attention() const {
  return _memory_kv != nullptr;
}

// Forward pass used when the layer runs in merged-attention mode
// (see has_merged_encoder_attention()).
// NOTE(review): the definition is not visible from this header. Judging by the
// parameter names, this presumably attends over the decoder's own positions and
// over the encoder `memory` (projected through the fused `_memory_kv` layer) in
// one call, keeping separate key/value caches for each side, and writes the
// result to `output` -- confirm against the definition.
void forward_merged(const StorageView& queries,
const StorageView* memory,
const StorageView* memory_lengths_mask,
const StorageView* self_lengths_mask,
StorageView& output,
StorageView* cached_self_keys,
StorageView* cached_self_values,
StorageView* cached_memory_keys,
StorageView* cached_memory_values,
const Padder* queries_padder,
const Padder* memory_padder,
dim_t offset) const;

protected:
void process_cross_attention(const StorageView& queries,
const StorageView& values,
Expand Down Expand Up @@ -90,6 +107,7 @@ namespace ctranslate2 {
std::unique_ptr<const LayerNorm> _q_norm; // Query normalization
std::unique_ptr<const LayerNorm> _k_norm; // Key normalization
std::unique_ptr<const LayerNorm> _v_norm; // Value normalization (no learnable scale)
std::unique_ptr<const Dense> _memory_kv; // Fused K+V projection of encoder memory (merged attention)
};
}
}
3 changes: 2 additions & 1 deletion include/ctranslate2/layers/transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ namespace ctranslate2 {
}

// Reports whether this decoder layer attends to the encoder output, either
// through a dedicated encoder-attention sublayer or through merged attention
// (in which case no separate `_encoder_attention` sublayer is checked for).
bool has_cross_attention() const {
  return bool(_encoder_attention) || _has_merged_encoder_attention;
}

const AttentionLayer& get_self_attention() const {
Expand All @@ -130,6 +130,7 @@ namespace ctranslate2 {
const std::unique_ptr<const LayerNorm> _external_pre_encoder_attention_layer_norm;
const std::unique_ptr<const LayerNorm> _external_post_encoder_attention_layer_norm;
const float _layer_scalar;
const bool _has_merged_encoder_attention;
};

class TransformerEncoder : public Encoder
Expand Down
201 changes: 200 additions & 1 deletion python/ctranslate2/converters/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3756,7 +3756,7 @@ def architecture_name(self):
return "T5GemmaForConditionalGeneration"

def set_layer_norm(self, spec, layer_norm):
    """Copies the weight of ``layer_norm`` into ``spec.gamma``.

    1.0 is added to the stored weight and the result is cast to float32.
    NOTE(review): the +1.0 offset presumably matches Gemma-style norms that
    scale by ``1 + weight`` -- confirm against the Transformers module.
    """
    spec.gamma = (layer_norm.weight.data + 1.0).float()

def get_model_spec(self, model):
encoder_config = model.config.encoder
Expand Down Expand Up @@ -4009,3 +4009,202 @@ def set_decoder(
delattr(layer, "cross_attn")
delattr(layer, "mlp")
gc.collect()


@register_loader("T5Gemma2Config")
class T5Gemma2Loader(ModelLoader):
    """Converts Hugging Face T5Gemma2 checkpoints into a CTranslate2 Transformer spec.

    Both sides are built as pre-norm, RMS-normalized, GLU feed-forward stacks with
    rotary embeddings and query/key normalization. The decoder is created with
    ``merged_encoder_attention=True``: instead of a separate cross-attention
    sublayer, the self-attention key/value projections are stored a second time
    as a fused ``memory_kv`` linear layer.
    """

    @property
    def architecture_name(self):
        """Name of the Transformers model class handled by this loader."""
        return "T5Gemma2ForConditionalGeneration"

    def _global_rope_theta(self, side_config):
        """Returns the RoPE base used by full-attention layers.

        Lookup order: ``rope_parameters["full_attention"]["rope_theta"]``, then
        ``side_config.rope_theta``, then 1_000_000. Sharing this lookup keeps the
        spec-level default and the per-layer override in agreement.
        """
        rope_params = getattr(side_config, "rope_parameters", {}) or {}
        return rope_params.get("full_attention", {}).get(
            "rope_theta", getattr(side_config, "rope_theta", 1_000_000)
        )

    def _side_kwargs(self, side_config):
        """Builds the spec constructor arguments shared by the encoder and decoder.

        Args:
          side_config: The encoder (text) or decoder configuration.

        Returns:
          A dict of keyword arguments for the Transformer encoder/decoder spec.
        """
        num_heads = side_config.num_attention_heads
        num_heads_kv = getattr(side_config, "num_key_value_heads", num_heads)
        # The KV head count is only passed when it differs from the query head count.
        if num_heads_kv == num_heads:
            num_heads_kv = None
        head_dim = side_config.head_dim
        return dict(
            num_layers=side_config.num_hidden_layers,
            num_heads=num_heads,
            pre_norm=True,
            activation=_SUPPORTED_ACTIVATIONS[side_config.hidden_activation],
            ffn_glu=True,
            rms_norm=True,
            rotary_dim=head_dim,
            rotary_interleave=False,
            rotary_scaling_type=attention_spec.RotaryScalingType.Linear,
            rotary_scaling_factor=1,
            rotary_base=self._global_rope_theta(side_config),
            sliding_window=getattr(side_config, "sliding_window", 0),
            num_heads_kv=num_heads_kv,
            head_dim=head_dim,
            qk_norm=True,
            pre_post_layer_norm=True,
        )

    def _apply_layer_types(self, side_config, spec_layers):
        """Overrides per-layer rotary base and sliding window from ``layer_types``.

        Layers typed ``"full_attention"`` get the global RoPE base, a sliding
        window of 0 and, when the full-attention rope type is ``"linear"``, the
        configured scaling factor. Every other layer gets the local RoPE base and
        the configured sliding window. Nothing is changed when the configuration
        has no ``layer_types``.

        Args:
          side_config: The encoder (text) or decoder configuration.
          spec_layers: The layer specs of the matching side, in layer order.
        """
        layer_types = getattr(side_config, "layer_types", None)
        if not layer_types:
            return
        rope_params = getattr(side_config, "rope_parameters", {}) or {}
        # Same lookup as in _side_kwargs so both agree on the full-attention base.
        global_theta = self._global_rope_theta(side_config)
        local_theta = rope_params.get("sliding_attention", {}).get("rope_theta", 10_000)
        sliding_window = getattr(side_config, "sliding_window", 0)
        full_attn_params = rope_params.get("full_attention", {})
        full_rope_type = full_attn_params.get("rope_type", "default")
        full_rope_factor = full_attn_params.get("factor", 1.0)
        for layer_type, layer_spec in zip(layer_types, spec_layers):
            attn = layer_spec.self_attention
            if layer_type == "full_attention":
                attn.rotary_base = np.dtype("float32").type(global_theta)
                attn.sliding_window = np.dtype("int32").type(0)
                if full_rope_type == "linear":
                    attn.rotary_scaling_factor = np.dtype("float32").type(
                        full_rope_factor
                    )
            else:
                attn.rotary_base = np.dtype("float32").type(local_theta)
                attn.sliding_window = np.dtype("int32").type(sliding_window)

    def get_model_spec(self, model):
        """Builds the Transformer spec and fills it with the weights of ``model``."""
        encoder_config = model.config.encoder.text_config
        decoder_config = model.config.decoder

        encoder = transformer_spec.TransformerEncoderSpec(
            **self._side_kwargs(encoder_config)
        )
        decoder = transformer_spec.TransformerDecoderSpec(
            **self._side_kwargs(decoder_config),
            with_encoder_attention=True,
            merged_encoder_attention=True,
        )
        spec = transformer_spec.TransformerSpec(encoder, decoder)

        self.set_encoder(spec.encoder, model.model.encoder.text_model)
        self._apply_layer_types(encoder_config, spec.encoder.layer)

        self.set_decoder(spec.decoder, model.model.decoder)
        self._apply_layer_types(decoder_config, spec.decoder.layer)

        # Use the LM head weight when it exists, otherwise reuse the decoder
        # embedding matrix as the output projection.
        if hasattr(model.lm_head, "weight"):
            self.set_linear(spec.decoder.projection, model.lm_head)
        else:
            self.set_linear(spec.decoder.projection, model.model.decoder.embed_tokens)
        return spec

    def set_vocabulary(self, spec, tokens):
        """Registers ``tokens`` as both the source and the target vocabulary."""
        spec.register_source_vocabulary(tokens)
        spec.register_target_vocabulary(tokens)

    def get_vocabulary(self, model, tokenizer):
        """Returns the tokenizer vocabulary padded to ``model.config.vocab_size``.

        Missing entries are filled with ``<extra_id_N>`` placeholder tokens.
        """
        tokens = super().get_vocabulary(model, tokenizer)
        extra_ids = model.config.vocab_size - len(tokens)
        for i in range(extra_ids):
            tokens.append("<extra_id_%d>" % i)
        return tokens

    def set_config(self, config, model, tokenizer):
        """Copies special tokens and the RMS norm epsilon into the model config.

        The decoder start token is the tokenizer BOS token.
        """
        config.bos_token = getattr(tokenizer, "bos_token", None)
        config.eos_token = getattr(tokenizer, "eos_token", None)
        config.unk_token = getattr(tokenizer, "unk_token", None)
        config.decoder_start_token = getattr(tokenizer, "bos_token", None)
        config.layer_norm_epsilon = model.config.encoder.text_config.rms_norm_eps

    def set_encoder(self, spec, encoder):
        """Copies the encoder weights into ``spec``.

        The ``self_attn`` and ``mlp`` submodules of each layer are deleted once
        converted, followed by a garbage collection pass.
        """
        encoder_emb_spec = (
            spec.embeddings[0] if isinstance(spec.embeddings, list) else spec.embeddings
        )
        self.set_embeddings(encoder_emb_spec, encoder.embed_tokens)
        self.set_layer_norm(spec.layer_norm, encoder.norm)
        embed_scale = getattr(encoder.embed_tokens, "embed_scale", None)
        spec.scale_embeddings = float(embed_scale) if embed_scale is not None else False

        for layer_spec, layer in zip(spec.layer, encoder.layers):
            self.set_layer_norm(
                layer_spec.input_layer_norm, layer.pre_self_attn_layernorm
            )
            self.set_layer_norm(
                layer_spec.post_attention_layer_norm, layer.post_self_attn_layernorm
            )
            self._set_self_attention(layer_spec.self_attention, layer)
            self.set_layer_norm(
                layer_spec.pre_feedforward_layer_norm, layer.pre_feedforward_layernorm
            )
            self.set_layer_norm(
                layer_spec.post_feedforward_layer_norm, layer.post_feedforward_layernorm
            )
            self.set_linear(layer_spec.ffn.linear_0, layer.mlp.gate_proj)
            self.set_linear(layer_spec.ffn.linear_0_noact, layer.mlp.up_proj)
            self.set_linear(layer_spec.ffn.linear_1, layer.mlp.down_proj)
            delattr(layer, "self_attn")
            delattr(layer, "mlp")
            gc.collect()

    def set_decoder(self, spec, module, quant_type=common_spec.Quantization.CT2):
        """Copies the decoder weights into ``spec``.

        Args:
          spec: The decoder spec to fill.
          module: The Transformers decoder module.
          quant_type: Quantization type forwarded to every ``set_linear`` call.
        """
        embed_scale = getattr(module.embed_tokens, "embed_scale", None)
        spec.scale_embeddings = float(embed_scale) if embed_scale is not None else False
        spec.start_from_zero_embedding = False
        self.set_embeddings(spec.embeddings, module.embed_tokens)
        self.set_layer_norm(spec.layer_norm, module.norm)

        for layer_spec, layer in zip(spec.layer, module.layers):
            attn_spec = layer_spec.self_attention
            self.set_layer_norm(
                layer_spec.input_layer_norm, layer.pre_self_attn_layernorm
            )
            self.set_layer_norm(
                layer_spec.post_attention_layer_norm, layer.post_self_attn_layernorm
            )
            # Merged attention: same K/V projections feed both self-attn and cross-attn.
            # Save them again as a fused memory_kv linear so the runtime can project
            # encoder memory through them at inference time.
            kv_split = [common_spec.LinearSpec() for _ in range(2)]
            self.set_linear(kv_split[0], layer.self_attn.k_proj, quant_type=quant_type)
            self.set_linear(kv_split[1], layer.self_attn.v_proj, quant_type=quant_type)
            utils.fuse_linear(attn_spec.memory_kv, kv_split)

            self._set_self_attention(attn_spec, layer, quant_type=quant_type)

            self.set_layer_norm(
                layer_spec.pre_feedforward_layer_norm, layer.pre_feedforward_layernorm
            )
            self.set_layer_norm(
                layer_spec.post_feedforward_layer_norm, layer.post_feedforward_layernorm
            )
            self.set_linear(
                layer_spec.ffn.linear_0, layer.mlp.gate_proj, quant_type=quant_type
            )
            self.set_linear(
                layer_spec.ffn.linear_0_noact, layer.mlp.up_proj, quant_type=quant_type
            )
            self.set_linear(
                layer_spec.ffn.linear_1, layer.mlp.down_proj, quant_type=quant_type
            )
            delattr(layer, "self_attn")
            delattr(layer, "mlp")
            gc.collect()

    def _set_self_attention(
        self, attn_spec, layer, quant_type=common_spec.Quantization.CT2
    ):
        """Fills a self-attention spec from ``layer.self_attn``.

        Q, K and V are fused into ``attn_spec.linear[0]``; the output projection
        goes to ``attn_spec.linear[1]``; the query/key norms are copied as well.
        """
        # T5Gemma2 wraps self-attn pre/post norms on the layer (not inside attention).
        # We map them via input_layer_norm/post_attention_layer_norm on the layer spec.
        qkv_split = [common_spec.LinearSpec() for _ in range(3)]
        self.set_linear(qkv_split[0], layer.self_attn.q_proj, quant_type=quant_type)
        self.set_linear(qkv_split[1], layer.self_attn.k_proj, quant_type=quant_type)
        self.set_linear(qkv_split[2], layer.self_attn.v_proj, quant_type=quant_type)
        utils.fuse_linear(attn_spec.linear[0], qkv_split)
        self.set_linear(
            attn_spec.linear[1], layer.self_attn.o_proj, quant_type=quant_type
        )
        self.set_layer_norm(attn_spec.q_norm, layer.self_attn.q_norm)
        self.set_layer_norm(attn_spec.k_norm, layer.self_attn.k_norm)

    def set_layer_norm(self, spec, layer_norm):
        """Stores ``layer_norm.weight + 1.0`` as a float32 ``spec.gamma``.

        NOTE(review): the +1.0 offset presumably matches Gemma-style norms that
        scale by ``1 + weight`` -- confirm against the Transformers module.
        """
        spec.gamma = (layer_norm.weight.data + 1.0).float()
5 changes: 5 additions & 0 deletions python/ctranslate2/specs/attention_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(
qk_norm_rms=True,
v_norm=False,
has_norm=True,
merged_encoder_attention=False,
):
self.queries_scale = model_spec.OPTIONAL

Expand Down Expand Up @@ -100,3 +101,7 @@ def __init__(

if sliding_window is not None:
self.sliding_window = np.dtype("int32").type(sliding_window)

if merged_encoder_attention:
self.merged_encoder_attention = True
self.memory_kv = common_spec.LinearSpec()
7 changes: 6 additions & 1 deletion python/ctranslate2/specs/transformer_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def __init__(
rms_norm=rms_norm,
num_heads_kv=num_heads_kv,
head_dim=head_dim,
sliding_window=sliding_window,
rotary_dim=rotary_dim,
rotary_interleave=rotary_interleave,
rotary_scaling_type=rotary_scaling_type,
Expand Down Expand Up @@ -153,6 +154,7 @@ def __init__(
qk_norm: bool = False,
v_norm: bool = False,
external_pre_post_encoder_layers: Optional[bool] = False,
merged_encoder_attention: bool = False,
):
"""Initializes a Transformer decoder specification.

Expand Down Expand Up @@ -265,6 +267,7 @@ def __init__(
qk_norm=qk_norm,
v_norm=v_norm,
external_pre_post_encoder_layers=external_pre_post_encoder_layers,
merged_encoder_attention=merged_encoder_attention,
)
for _ in range(num_layers)
]
Expand Down Expand Up @@ -364,6 +367,7 @@ def __init__(
qk_norm=False,
v_norm=False,
external_pre_post_encoder_layers=False,
merged_encoder_attention=False,
):
self.self_attention = attention_spec.MultiHeadAttentionSpec(
self_attention=True,
Expand All @@ -382,9 +386,10 @@ def __init__(
sliding_window=sliding_window,
qk_norm=qk_norm,
v_norm=v_norm,
merged_encoder_attention=merged_encoder_attention,
)

if with_encoder_attention:
if with_encoder_attention and not merged_encoder_attention:
self.attention = attention_spec.MultiHeadAttentionSpec(
rms_norm=rms_norm,
num_heads_kv=num_heads_kv,
Expand Down
8 changes: 8 additions & 0 deletions python/tests/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,14 @@ def clear_transformers_cache_in_ci():
"\n\n Answer : \n\n The ▁sky ▁is ▁blue .",
dict(),
),
(
"jordimas/t5gemma-2-270m-270m",
["<bos> Question : ▁Why ▁is ▁the ▁sky ▁blue ? ▁Answer :"],
"",
"<unused6237> ▁The ▁sky ▁is ▁blue ▁because ▁the ▁sun ▁shines ▁on ▁it . "
"▁The ▁sun ▁is ▁the ▁source ▁of ▁all ▁the ▁light ▁in ▁the ▁sky .",
dict(max_decoding_length=50),
),
]


Expand Down
Loading
Loading