DTee8 commited on
Commit
a7d2fa2
·
verified ·
1 Parent(s): d2842ad

Upload 15 files

Browse files
config.json ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "C:\\Users\\teres\\source\\repos\\ModelFineTune\\phi4-model",
3
+ "architectures": [
4
+ "Phi4MMForCausalLM"
5
+ ],
6
+ "attention_bias": false,
7
+ "attention_dropout": 0.0,
8
+ "audio_processor": {
9
+ "config": {
10
+ "activation": "swish",
11
+ "activation_checkpointing": {
12
+ "interval": 1,
13
+ "module": "transformer",
14
+ "offload": false
15
+ },
16
+ "attention_dim": 1024,
17
+ "attention_heads": 16,
18
+ "batch_norm": false,
19
+ "bias_in_glu": true,
20
+ "causal": true,
21
+ "chunk_size": -1,
22
+ "cnn_layer_norm": true,
23
+ "conv_activation": "swish",
24
+ "conv_glu_type": "swish",
25
+ "depthwise_multiplier": 1,
26
+ "depthwise_seperable_out_channel": 1024,
27
+ "dropout_rate": 0.0,
28
+ "encoder_embedding_config": {
29
+ "input_size": 80
30
+ },
31
+ "ext_pw_kernel_size": 1,
32
+ "ext_pw_out_channel": 1024,
33
+ "input_layer": "nemo_conv",
34
+ "input_size": 80,
35
+ "kernel_size": 3,
36
+ "left_chunk": 18,
37
+ "linear_units": 1536,
38
+ "nemo_conv_settings": {
39
+ "conv_channels": 1024
40
+ },
41
+ "num_blocks": 24,
42
+ "relative_attention_bias_args": {
43
+ "t5_bias_max_distance": 500,
44
+ "type": "t5"
45
+ },
46
+ "time_reduction": 8
47
+ },
48
+ "name": "cascades"
49
+ },
50
+ "auto_map": {
51
+ "AutoConfig": "configuration_phi4mm.Phi4MMConfig",
52
+ "AutoModelForCausalLM": "modeling_phi4mm.Phi4MMForCausalLM",
53
+ "AutoTokenizer": "Xenova/gpt-4o"
54
+ },
55
+ "bos_token_id": 199999,
56
+ "embd_layer": {
57
+ "audio_embd_layer": {
58
+ "compression_rate": 8,
59
+ "downsample_rate": 1,
60
+ "embedding_cls": "audio",
61
+ "enable_gradient_checkpointing": true,
62
+ "projection_cls": "mlp",
63
+ "use_conv_downsample": false,
64
+ "use_qformer": false
65
+ },
66
+ "embedding_cls": "image_audio",
67
+ "image_embd_layer": {
68
+ "crop_size": 448,
69
+ "embedding_cls": "tune_image",
70
+ "enable_gradient_checkpointing": true,
71
+ "hd_transform_order": "sub_glb",
72
+ "image_token_compression_cls": "avg_pool_2d",
73
+ "projection_cls": "mlp",
74
+ "use_hd_transform": true,
75
+ "with_learnable_separator": true
76
+ }
77
+ },
78
+ "embd_pdrop": 0.0,
79
+ "eos_token_id": 199999,
80
+ "full_attn_mod": 1,
81
+ "hidden_act": "silu",
82
+ "hidden_size": 3072,
83
+ "img_processor": null,
84
+ "initializer_range": 0.02,
85
+ "intermediate_size": 8192,
86
+ "interpolate_factor": 1,
87
+ "lm_head_bias": false,
88
+ "max_position_embeddings": 131072,
89
+ "mlp_bias": false,
90
+ "model_type": "phi4mm",
91
+ "num_attention_heads": 24,
92
+ "num_hidden_layers": 32,
93
+ "num_key_value_heads": 8,
94
+ "original_max_position_embeddings": 4096,
95
+ "pad_token_id": 199999,
96
+ "partial_rotary_factor": 0.75,
97
+ "resid_pdrop": 0.0,
98
+ "rms_norm_eps": 1e-05,
99
+ "rope_scaling": {
100
+ "long_factor": [
101
+ 1,
102
+ 1.118320672,
103
+ 1.250641126,
104
+ 1.398617824,
105
+ 1.564103225,
106
+ 1.74916897,
107
+ 1.956131817,
108
+ 2.187582649,
109
+ 2.446418898,
110
+ 2.735880826,
111
+ 3.059592084,
112
+ 3.421605075,
113
+ 3.826451687,
114
+ 4.279200023,
115
+ 4.785517845,
116
+ 5.351743533,
117
+ 5.984965424,
118
+ 6.693110555,
119
+ 7.485043894,
120
+ 8.370679318,
121
+ 9.36110372,
122
+ 10.4687158,
123
+ 11.70738129,
124
+ 13.09260651,
125
+ 14.64173252,
126
+ 16.37415215,
127
+ 18.31155283,
128
+ 20.47818807,
129
+ 22.90118105,
130
+ 25.61086418,
131
+ 28.64115884,
132
+ 32.03,
133
+ 32.1,
134
+ 32.13,
135
+ 32.23,
136
+ 32.6,
137
+ 32.61,
138
+ 32.64,
139
+ 32.66,
140
+ 32.7,
141
+ 32.71,
142
+ 32.93,
143
+ 32.97,
144
+ 33.28,
145
+ 33.49,
146
+ 33.5,
147
+ 44.16,
148
+ 47.77
149
+ ],
150
+ "short_factor": [
151
+ 1.0,
152
+ 1.0,
153
+ 1.0,
154
+ 1.0,
155
+ 1.0,
156
+ 1.0,
157
+ 1.0,
158
+ 1.0,
159
+ 1.0,
160
+ 1.0,
161
+ 1.0,
162
+ 1.0,
163
+ 1.0,
164
+ 1.0,
165
+ 1.0,
166
+ 1.0,
167
+ 1.0,
168
+ 1.0,
169
+ 1.0,
170
+ 1.0,
171
+ 1.0,
172
+ 1.0,
173
+ 1.0,
174
+ 1.0,
175
+ 1.0,
176
+ 1.0,
177
+ 1.0,
178
+ 1.0,
179
+ 1.0,
180
+ 1.0,
181
+ 1.0,
182
+ 1.0,
183
+ 1.0,
184
+ 1.0,
185
+ 1.0,
186
+ 1.0,
187
+ 1.0,
188
+ 1.0,
189
+ 1.0,
190
+ 1.0,
191
+ 1.0,
192
+ 1.0,
193
+ 1.0,
194
+ 1.0,
195
+ 1.0,
196
+ 1.0,
197
+ 1.0,
198
+ 1.0
199
+ ],
200
+ "type": "longrope"
201
+ },
202
+ "rope_theta": 10000.0,
203
+ "sliding_window": 262144,
204
+ "speech_lora": {
205
+ "dp": 0.01,
206
+ "layer": "((layers.*self_attn\\.(qkv|o)_proj)|(layers.*mlp\\.(gate_up|down)_proj))",
207
+ "lora_alpha": 640,
208
+ "r": 320
209
+ },
210
+ "tie_word_embeddings": true,
211
+ "torch_dtype": "float32",
212
+ "transformers_version": "4.46.1",
213
+ "use_cache": true,
214
+ "vision_lora": {
215
+ "dp": 0.0,
216
+ "layer": "layers.*((self_attn\\.(qkv_proj|o_proj))|(mlp\\.(gate_up|down)_proj))",
217
+ "lora_alpha": 512,
218
+ "r": 256
219
+ },
220
+ "vocab_size": 200064
221
+ }
configuration_phi4mm.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """ Phi-4-MM model configuration"""
17
+
18
+ from transformers.configuration_utils import PretrainedConfig
19
+ from transformers.utils import logging
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ class Phi4MMConfig(PretrainedConfig):
26
+ r"""
27
+ This is the configuration class to store the configuration of a [`Phi4MMModel`]. It is used to instantiate a Phi-4-MM
28
+ model according to the specified arguments, defining the model architecture.
29
+
30
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
31
+ documentation from [`PretrainedConfig`] for more information.
32
+
33
+ Args:
34
+ vocab_size (`int`, *optional*, defaults to 200064):
35
+ Vocabulary size of the Phi-4-MM model. Defines the number of different tokens that can be represented by the
36
+ `inputs_ids` passed when calling [`Phi4MMModel`].
37
+ hidden_size (`int`, *optional*, defaults to 3072):
38
+ Dimension of the hidden representations.
39
+ intermediate_size (`int`, *optional*, defaults to 8192):
40
+ Dimension of the MLP representations.
41
+ num_hidden_layers (`int`, *optional*, defaults to 32):
42
+ Number of hidden layers in the Transformer decoder.
43
+ num_attention_heads (`int`, *optional*, defaults to 32):
44
+ Number of attention heads for each attention layer in the Transformer decoder.
45
+ num_key_value_heads (`int`, *optional*):
46
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
47
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
48
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
49
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
50
+ by meanpooling all the original heads within that group. For more details checkout [this
51
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
52
+ `num_attention_heads`.
53
+ resid_pdrop (`float`, *optional*, defaults to 0.0):
54
+ Dropout probability for mlp outputs.
55
+ embd_pdrop (`int`, *optional*, defaults to 0.0):
56
+ The dropout ratio for the embeddings.
57
+ attention_dropout (`float`, *optional*, defaults to 0.0):
58
+ The dropout ratio after computing the attention scores.
59
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
60
+ The non-linear activation function (function or string) in the decoder.
61
+ max_position_embeddings (`int`, *optional*, defaults to 4096):
62
+ The maximum sequence length that this model might ever be used with.
63
+ original_max_position_embeddings (`int`, *optional*, defaults to 4096):
64
+ The maximum sequence length that this model was trained with. This is used to determine the size of the
65
+ original RoPE embeddings when using long scaling.
66
+ initializer_range (`float`, *optional*, defaults to 0.02):
67
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
68
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
69
+ The epsilon value used for the RMSNorm.
70
+ use_cache (`bool`, *optional*, defaults to `True`):
71
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
72
+ relevant if `config.is_decoder=True`. Whether to tie weight embeddings or not.
73
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
74
+ Whether to tie weight embeddings
75
+ rope_theta (`float`, *optional*, defaults to 10000.0):
76
+ The base period of the RoPE embeddings.
77
+ rope_scaling (`dict`, *optional*):
78
+ The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must
79
+ contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be `longrope` and
80
+ the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size
81
+ divided by the number of attention heads divided by 2.
82
+ partial_rotary_factor (`float`, *optional*, defaults to 0.5):
83
+ Percentage of the query and keys which will have rotary embedding.
84
+ bos_token_id (`int`, *optional*, defaults to 199999):
85
+ The id of the "beginning-of-sequence" token.
86
+ eos_token_id (`int`, *optional*, defaults to 199999):
87
+ The id of the "end-of-sequence" token.
88
+ pad_token_id (`int`, *optional*, defaults to 199999):
89
+ The id of the padding token.
90
+ sliding_window (`int`, *optional*):
91
+ Sliding window attention window size. If `None`, no sliding window is applied.
92
+
93
+ Example:
94
+
95
+ ```python
96
+ >>> from transformers import Phi4MMModel, Phi4MMConfig
97
+
98
+ >>> # Initializing a Phi-4-MM style configuration
99
+ >>> configuration = Phi4MMConfig.from_pretrained("TBA")
100
+
101
+ >>> # Initializing a model from the configuration
102
+ >>> model = Phi4MMModel(configuration)
103
+
104
+ >>> # Accessing the model configuration
105
+ >>> configuration = model.config
106
+ ```"""
107
+
108
+ model_type = "phi4mm"
109
+ keys_to_ignore_at_inference = ["past_key_values"]
110
+
111
+ def __init__(
112
+ self,
113
+ vocab_size=200064,
114
+ hidden_size=3072,
115
+ intermediate_size=8192,
116
+ num_hidden_layers=32,
117
+ num_attention_heads=32,
118
+ num_key_value_heads=None,
119
+ resid_pdrop=0.0,
120
+ embd_pdrop=0.0,
121
+ attention_dropout=0.0,
122
+ hidden_act="silu",
123
+ max_position_embeddings=4096,
124
+ original_max_position_embeddings=4096,
125
+ initializer_range=0.02,
126
+ rms_norm_eps=1e-5,
127
+ use_cache=True,
128
+ tie_word_embeddings=False,
129
+ rope_theta=10000.0,
130
+ rope_scaling=None,
131
+ partial_rotary_factor=1,
132
+ bos_token_id=199999,
133
+ eos_token_id=199999,
134
+ pad_token_id=199999,
135
+ sliding_window=None,
136
+ embd_layer: str = "default",
137
+ img_processor=None,
138
+ audio_processor=None,
139
+ vision_lora=None,
140
+ speech_lora=None,
141
+ **kwargs,
142
+ ):
143
+ self.embd_layer = embd_layer
144
+ self.img_processor = img_processor
145
+ self.audio_processor = audio_processor
146
+ self.vision_lora = vision_lora
147
+ self.speech_lora = speech_lora
148
+
149
+ self.vocab_size = vocab_size
150
+ self.hidden_size = hidden_size
151
+ self.intermediate_size = intermediate_size
152
+ self.num_hidden_layers = num_hidden_layers
153
+ self.num_attention_heads = num_attention_heads
154
+
155
+ if num_key_value_heads is None:
156
+ num_key_value_heads = num_attention_heads
157
+
158
+ self.num_key_value_heads = num_key_value_heads
159
+ self.resid_pdrop = resid_pdrop
160
+ self.embd_pdrop = embd_pdrop
161
+ self.attention_dropout = attention_dropout
162
+ self.hidden_act = hidden_act
163
+ self.max_position_embeddings = max_position_embeddings
164
+ self.original_max_position_embeddings = original_max_position_embeddings
165
+ self.initializer_range = initializer_range
166
+ self.rms_norm_eps = rms_norm_eps
167
+ self.use_cache = use_cache
168
+ self.rope_theta = rope_theta
169
+ self.rope_scaling = rope_scaling
170
+ self.partial_rotary_factor = partial_rotary_factor
171
+ self._rope_scaling_adjustment()
172
+ self._rope_scaling_validation()
173
+ self.sliding_window = sliding_window
174
+
175
+ super().__init__(
176
+ bos_token_id=bos_token_id,
177
+ eos_token_id=eos_token_id,
178
+ pad_token_id=pad_token_id,
179
+ tie_word_embeddings=tie_word_embeddings,
180
+ **kwargs,
181
+ )
182
+
183
+ def _rope_scaling_adjustment(self):
184
+ """
185
+ Adjust the `type` of the `rope_scaling` configuration for backward compatibility.
186
+ """
187
+ if self.rope_scaling is None:
188
+ return
189
+
190
+ rope_scaling_type = self.rope_scaling.get("type", None)
191
+
192
+ # For backward compatibility if previous version used "su" or "yarn"
193
+ if rope_scaling_type is not None and rope_scaling_type in ["su", "yarn"]:
194
+ self.rope_scaling["type"] = "longrope"
195
+
196
+ def _rope_scaling_validation(self):
197
+ """
198
+ Validate the `rope_scaling` configuration.
199
+ """
200
+ if self.rope_scaling is None:
201
+ return
202
+
203
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3:
204
+ raise ValueError(
205
+ "`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, "
206
+ f"got {self.rope_scaling}"
207
+ )
208
+ rope_scaling_type = self.rope_scaling.get("type", None)
209
+ rope_scaling_short_factor = self.rope_scaling.get("short_factor", None)
210
+ rope_scaling_long_factor = self.rope_scaling.get("long_factor", None)
211
+ if rope_scaling_type is None or rope_scaling_type not in ["longrope"]:
212
+ raise ValueError(f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}")
213
+ if not (
214
+ isinstance(rope_scaling_short_factor, list)
215
+ and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor)
216
+ ):
217
+ raise ValueError(
218
+ f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}"
219
+ )
220
+ rotary_ndims = int(self.hidden_size // self.num_attention_heads * self.partial_rotary_factor)
221
+ if not len(rope_scaling_short_factor) == rotary_ndims // 2:
222
+ raise ValueError(
223
+ f"`rope_scaling`'s short_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_short_factor)}"
224
+ )
225
+ if not (
226
+ isinstance(rope_scaling_long_factor, list)
227
+ and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor)
228
+ ):
229
+ raise ValueError(
230
+ f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}"
231
+ )
232
+ if not len(rope_scaling_long_factor) == rotary_ndims // 2:
233
+ raise ValueError(
234
+ f"`rope_scaling`'s long_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_long_factor)}"
235
+ )
generation_config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 199999,
4
+ "eos_token_id": [
5
+ 200020,
6
+ 199999
7
+ ],
8
+ "pad_token_id": 199999,
9
+ "transformers_version": "4.46.1"
10
+ }
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b32e9f73edc116823abe74dd6b7f13b5bdd7a5262e7710835772697d988a2aa
3
+ size 4993293944
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:293a733b7d151acf7b3803632124a67a532145f5a7386ba6f99717e3cdabbc8a
3
+ size 4953762936
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:502a0e423ba01769c3e292512706f14d21b3d1b91d9fc7092765fa53d7ba6794
3
+ size 4936985712
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d9716ae50890580ed9f6b6d77508ec0e0b5bf843ada3ede1c5dd7a7141bf2cea
3
+ size 3702770248
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_phi4mm.py ADDED
The diff for this file is too large to render. See raw diff
 
processing_phi4mm.py ADDED
@@ -0,0 +1,733 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Processor class for Phi4MM
17
+ """
18
+ import re
19
+ from typing import List, Optional, Tuple, Union
20
+ import math
21
+ from enum import Enum
22
+
23
+ import numpy as np
24
+ import scipy
25
+ import torch
26
+ import torchvision
27
+
28
+ from transformers import AutoFeatureExtractor, AutoImageProcessor
29
+ from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
30
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
31
+ from transformers.image_utils import (
32
+ ImageInput,
33
+ make_list_of_images,
34
+ valid_images,
35
+ )
36
+ from transformers.processing_utils import ProcessorMixin
37
+ from transformers.tokenization_utils_base import PaddingStrategy, TextInput, TruncationStrategy
38
+ from transformers.utils import TensorType, logging
39
+ from torch.nn.utils.rnn import pad_sequence
40
+
41
+
42
+ logger = logging.get_logger(__name__)
43
+
44
+ # Special tokens
45
+ _COMPATIBLE_IMAGE_SPECIAL_TOKEN_PATTERN = r'<\|image_\d+\|>' # For backward compatibility
46
+ _COMPATIBLE_AUDIO_SPECIAL_TOKEN_PATTERN = r'<\|audio_\d+\|>' # For backward compatibility
47
+ _IMAGE_SPECIAL_TOKEN = '<|endoftext10|>'
48
+ _AUDIO_SPECIAL_TOKEN = '<|endoftext11|>'
49
+ _IMAGE_SPECIAL_TOKEN_ID = 200010 # '<|endoftext10|>', or we can better name it (in `tokenizer_config.json`)
50
+ _AUDIO_SPECIAL_TOKEN_ID = 200011 # '<|endoftext11|>'
51
+
52
+
53
+ class InputMode(Enum):
54
+ LANGUAGE = 0
55
+ VISION = 1
56
+ SPEECH = 2
57
+ VISION_SPEECH = 3
58
+
59
+
60
+ class Phi4MMImageProcessor(BaseImageProcessor):
61
+ r"""
62
+ Constructs a Phi4MM image processor.
63
+ """
64
+ model_input_names = ["input_image_embeds", "image_sizes", "image_attention_mask"]
65
+
66
+ def __init__(
67
+ self,
68
+ dynamic_hd,
69
+ **kwargs,
70
+ ) -> None:
71
+ super().__init__(**kwargs)
72
+ self.dynamic_hd = dynamic_hd
73
+
74
+ def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
75
+ best_ratio_diff = float('inf')
76
+ best_ratio = (1, 1)
77
+ area = width * height
78
+ for ratio in target_ratios:
79
+ target_aspect_ratio = ratio[0] / ratio[1]
80
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
81
+ if ratio_diff < best_ratio_diff:
82
+ best_ratio_diff = ratio_diff
83
+ best_ratio = ratio
84
+ elif ratio_diff == best_ratio_diff:
85
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
86
+ best_ratio = ratio
87
+ return best_ratio
88
+
89
+ def dynamic_preprocess(self, image, min_num=1, max_num=12, image_size=384, mask_size=27, use_thumbnail=True):
90
+ orig_width, orig_height = image.size
91
+
92
+ w_crop_num = math.ceil(orig_width/float(image_size))
93
+ h_crop_num = math.ceil(orig_height/float(image_size))
94
+ if w_crop_num * h_crop_num > max_num:
95
+
96
+ aspect_ratio = orig_width / orig_height
97
+
98
+ # calculate the existing image aspect ratio
99
+ target_ratios = set(
100
+ (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
101
+ i * j <= max_num and i * j >= min_num)
102
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
103
+
104
+ # find the closest aspect ratio to the target
105
+ target_aspect_ratio = self.find_closest_aspect_ratio(
106
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size)
107
+
108
+ # calculate the target width and height
109
+ target_width = image_size * target_aspect_ratio[0]
110
+ target_height = image_size * target_aspect_ratio[1]
111
+ else:
112
+ target_width = image_size * w_crop_num
113
+ target_height = image_size * h_crop_num
114
+ target_aspect_ratio = (w_crop_num, h_crop_num)
115
+
116
+ # Calculate the ratio
117
+ ratio_width = target_width / orig_width
118
+ ratio_height = target_height / orig_height
119
+ if ratio_width < ratio_height:
120
+ new_size = (target_width, int(orig_height * ratio_width))
121
+ padding_width = 0
122
+ padding_height = target_height - int(orig_height * ratio_width)
123
+ else:
124
+ new_size = (int(orig_width * ratio_height), target_height)
125
+ padding_width = target_width - int(orig_width * ratio_height)
126
+ padding_height = 0
127
+
128
+ attention_mask = torch.ones((int(mask_size*target_aspect_ratio[1]), int(mask_size*target_aspect_ratio[0])))
129
+ if padding_width >= 14:
130
+ attention_mask[:, -math.floor(padding_width/14):] = 0
131
+ if padding_height >= 14:
132
+ attention_mask[-math.floor(padding_height/14):,:] = 0
133
+ assert attention_mask.sum() > 0
134
+
135
+ if min(new_size[1], target_height) < 10 or min(new_size[0], target_width) < 10:
136
+ raise ValueError(f'the aspect ratio is very extreme {new_size}')
137
+
138
+ image = torchvision.transforms.functional.resize(image, [new_size[1], new_size[0]],)
139
+
140
+ resized_img = torchvision.transforms.functional.pad(image, [0, 0, padding_width, padding_height], fill=[255,255,255])
141
+
142
+ return resized_img, attention_mask
143
+
144
+ def pad_to_max_num_crops(self, images, max_crops=5):
145
+ """
146
+ images: B x 3 x H x W, B<=max_crops
147
+ """
148
+ B, _, H, W = images.shape
149
+ if B < max_crops:
150
+ pad = torch.zeros(max_crops - B, 3, H, W, dtype=images.dtype, device=images.device)
151
+ images = torch.cat([images, pad], dim=0)
152
+ return images
153
+
154
+ def pad_mask_to_max_num_crops(self, masks, max_crops=5):
155
+ B, H, W = masks.shape
156
+ if B < max_crops:
157
+ pad = torch.ones(max_crops - B, H, W, dtype=masks.dtype, device=masks.device)
158
+ masks = torch.cat([masks, pad], dim=0)
159
+ return masks
160
+
161
+ def preprocess(
162
+ self,
163
+ images: ImageInput,
164
+ return_tensors: Optional[Union[str, TensorType]] = None,
165
+ ):
166
+ """
167
+ Args:
168
+ images (`ImageInput`):
169
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
170
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
171
+ return_tensors (`str` or `TensorType`, *optional*):
172
+ The type of tensors to return. Can be one of:
173
+ - Unset: Return a list of `np.ndarray`.
174
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
175
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
176
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
177
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
178
+ """
179
+ images = make_list_of_images(images)
180
+
181
+ if not valid_images(images):
182
+ raise ValueError(
183
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
184
+ "torch.Tensor, tf.Tensor or jax.ndarray."
185
+ )
186
+
187
+ # Basic settings.
188
+ img_processor = torchvision.transforms.Compose([
189
+ torchvision.transforms.ToTensor(),
190
+ torchvision.transforms.Normalize(
191
+ (0.5, 0.5, 0.5),
192
+ (0.5, 0.5, 0.5)
193
+ ),
194
+ ])
195
+ dyhd_base_resolution = 448
196
+
197
+ # Dynamic HD
198
+ base_resolution = dyhd_base_resolution
199
+ images = [image.convert('RGB') for image in images]
200
+ # cover 384 and 448 resolution
201
+ mask_resolution = base_resolution // 14
202
+ elems, image_attention_masks = [], []
203
+ for im in images:
204
+ elem, attention_mask = self.dynamic_preprocess(im, max_num=self.dynamic_hd, image_size=base_resolution, mask_size=mask_resolution)
205
+ elems.append(elem)
206
+ image_attention_masks.append(attention_mask)
207
+ hd_images = [img_processor(im) for im in elems]
208
+ global_image = [torch.nn.functional.interpolate(im.unsqueeze(0).float(), size=(base_resolution, base_resolution), mode='bicubic',).to(im.dtype) for im in hd_images]
209
+ shapes = [[im.size(1), im.size(2)] for im in hd_images]
210
+ mask_shapes = [[mask.size(0), mask.size(1)] for mask in image_attention_masks]
211
+ global_attention_mask = [torch.ones((1, mask_resolution, mask_resolution)) for _ in hd_images]
212
+ hd_images_reshape = [im.reshape(1, 3,
213
+ h//base_resolution,
214
+ base_resolution,
215
+ w//base_resolution,
216
+ base_resolution
217
+ ).permute(0,2,4,1,3,5).reshape(-1, 3, base_resolution, base_resolution).contiguous() for im, (h, w) in zip(hd_images, shapes)]
218
+ attention_masks_reshape = [mask.reshape(1,
219
+ h//mask_resolution,
220
+ mask_resolution,
221
+ w//mask_resolution,
222
+ mask_resolution
223
+ ).permute(0,1,3,2,4).reshape(-1, mask_resolution, mask_resolution).contiguous() for mask, (h, w) in zip(image_attention_masks, mask_shapes)]
224
+ downsample_attention_masks = [mask[:,0::2,0::2].reshape(1,
225
+ h//mask_resolution,
226
+ w//mask_resolution,
227
+ mask_resolution//2+mask_resolution%2,
228
+ mask_resolution//2+mask_resolution%2
229
+ ).permute(0,1,3,2,4) for mask, (h,w) in zip(attention_masks_reshape, mask_shapes)]
230
+ downsample_attention_masks = [mask.reshape(mask.size(1)*mask.size(2), mask.size(3)*mask.size(4))for mask in downsample_attention_masks]
231
+ num_img_tokens = [256 + 1 + int(mask.sum().item()) + int(mask[:,0].sum().item()) + 16 for mask in downsample_attention_masks]
232
+
233
+ hd_images_reshape = [torch.cat([_global_image] + [_im], dim=0) for _global_image, _im in zip(global_image, hd_images_reshape)]
234
+ hd_masks_reshape = [torch.cat([_global_mask] + [_mask], dim=0) for _global_mask, _mask in zip(global_attention_mask, attention_masks_reshape)]
235
+ max_crops = max([img.size(0) for img in hd_images_reshape])
236
+ image_transformed = [self.pad_to_max_num_crops(im, max_crops) for im in hd_images_reshape]
237
+ image_transformed = torch.stack(image_transformed, dim=0)
238
+ mask_transformed = [self.pad_mask_to_max_num_crops(mask, max_crops) for mask in hd_masks_reshape]
239
+ mask_transformed = torch.stack(mask_transformed, dim=0)
240
+
241
+ returned_input_image_embeds = image_transformed
242
+ returned_image_sizes = torch.tensor(shapes, dtype=torch.long)
243
+ returned_image_attention_mask = mask_transformed
244
+ returned_num_img_tokens = num_img_tokens
245
+
246
+ data = {
247
+ "input_image_embeds": returned_input_image_embeds,
248
+ "image_sizes": returned_image_sizes,
249
+ "image_attention_mask": returned_image_attention_mask,
250
+ "num_img_tokens": returned_num_img_tokens,
251
+ }
252
+
253
+ return BatchFeature(data=data, tensor_type=return_tensors)
254
+
255
+
256
+ AudioInput = Tuple[Union[np.ndarray, torch.Tensor], int]
257
+ AudioInputs = List[AudioInput]
258
+
259
+
260
+ def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None):
261
+ """Create a Mel filter-bank the same as SpeechLib FbankFC.
262
+
263
+ Args:
264
+ sample_rate (int): Sample rate in Hz. number > 0 [scalar]
265
+ n_fft (int): FFT size. int > 0 [scalar]
266
+ n_mel (int): Mel filter size. int > 0 [scalar]
267
+ fmin (float): lowest frequency (in Hz). If None use 0.0.
268
+ float >= 0 [scalar]
269
+ fmax: highest frequency (in Hz). If None use sample_rate / 2.
270
+ float >= 0 [scalar]
271
+
272
+ Returns
273
+ out (numpy.ndarray): Mel transform matrix
274
+ [shape=(n_mels, 1 + n_fft/2)]
275
+ """
276
+
277
+ bank_width = int(n_fft // 2 + 1)
278
+ if fmax is None:
279
+ fmax = sample_rate / 2
280
+ if fmin is None:
281
+ fmin = 0
282
+ assert fmin >= 0, "fmin cannot be negtive"
283
+ assert fmin < fmax <= sample_rate / 2, "fmax must be between (fmin, samplerate / 2]"
284
+
285
+ def mel(f):
286
+ return 1127.0 * np.log(1.0 + f / 700.0)
287
+
288
+ def bin2mel(fft_bin):
289
+ return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0))
290
+
291
+ def f2bin(f):
292
+ return int((f * n_fft / sample_rate) + 0.5)
293
+
294
+ # Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax) - 1]
295
+ klo = f2bin(fmin) + 1
296
+ khi = f2bin(fmax)
297
+
298
+ khi = max(khi, klo)
299
+
300
+ # Spec 2: SpeechLib uses trianges in Mel space
301
+ mlo = mel(fmin)
302
+ mhi = mel(fmax)
303
+ m_centers = np.linspace(mlo, mhi, n_mels + 2)
304
+ ms = (mhi - mlo) / (n_mels + 1)
305
+
306
+ matrix = np.zeros((n_mels, bank_width), dtype=np.float32)
307
+ for m in range(0, n_mels):
308
+ left = m_centers[m]
309
+ center = m_centers[m + 1]
310
+ right = m_centers[m + 2]
311
+ for fft_bin in range(klo, khi):
312
+ mbin = bin2mel(fft_bin)
313
+ if left < mbin < right:
314
+ matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms
315
+
316
+ return matrix
317
+
318
+
319
+ class Phi4MMAudioFeatureExtractor(SequenceFeatureExtractor):
320
+ model_input_names = ["input_audio_embeds", "audio_embed_sizes", "audio_attention_mask"]
321
+
322
+ def __init__(self, audio_compression_rate, audio_downsample_rate, audio_feat_stride, **kwargs):
323
+ feature_size = 80
324
+ sampling_rate = 16000
325
+ padding_value = 0.0
326
+ super().__init__(feature_size, sampling_rate, padding_value, **kwargs)
327
+
328
+ self.compression_rate = audio_compression_rate
329
+ self.qformer_compression_rate = audio_downsample_rate
330
+ self.feat_stride = audio_feat_stride
331
+
332
+ self._eightk_method = "fillzero"
333
+ self._mel = speechlib_mel(16000, 512, 80, fmin=None, fmax=7690).T
334
+
335
+ self._hamming400 = np.hamming(400) # for 16k audio
336
+ self._hamming200 = np.hamming(200) # for 8k audio
337
+
338
+ def duration_to_frames(self, duration):
339
+ """duration in s, estimated frames"""
340
+ frame_rate = 10
341
+
342
+ num_frames = duration * 1000 // frame_rate
343
+ return num_frames
344
+
345
+ def __call__(
346
+ self,
347
+ audios: List[AudioInput],
348
+ return_tensors: Optional[Union[str, TensorType]] = None,
349
+ ):
350
+ # Ref: https://github.com/huggingface/transformers/blob/v4.47.0/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py#L161
351
+ returned_input_audio_embeds = []
352
+ returned_audio_embed_sizes = []
353
+ audio_frames_list = []
354
+
355
+ for audio_data, sample_rate in audios:
356
+ audio_embeds = self._extract_features(audio_data, sample_rate)
357
+ audio_frames = len(audio_embeds) * self.feat_stride
358
+ audio_embed_size = self._compute_audio_embed_size(audio_frames)
359
+
360
+ returned_input_audio_embeds.append(torch.tensor(audio_embeds))
361
+ returned_audio_embed_sizes.append(torch.tensor(audio_embed_size).long())
362
+ audio_frames_list.append(audio_frames)
363
+
364
+ returned_input_audio_embeds = pad_sequence(
365
+ returned_input_audio_embeds, batch_first=True
366
+ )
367
+ returned_audio_embed_sizes = torch.stack(returned_audio_embed_sizes, dim=0)
368
+ audio_frames = torch.tensor(audio_frames_list)
369
+ returned_audio_attention_mask = torch.arange(0, audio_frames.max()).unsqueeze(0) < audio_frames.unsqueeze(1) if len(audios) > 1 else None
370
+
371
+ data = {
372
+ "input_audio_embeds": returned_input_audio_embeds,
373
+ "audio_embed_sizes": returned_audio_embed_sizes,
374
+ }
375
+ if returned_audio_attention_mask is not None:
376
+ data["audio_attention_mask"] = returned_audio_attention_mask
377
+
378
+ return BatchFeature(data=data, tensor_type=return_tensors)
379
+
380
+ def _extract_spectrogram(self, wav, fs):
381
+ """Extract spectrogram features from waveform.
382
+ Args:
383
+ wav (1D array): waveform of the input
384
+ fs (int): sampling rate of the waveform, 16000 or 8000.
385
+ If fs=8000, the waveform will be resampled to 16000Hz.
386
+ Output:
387
+ log_fbank (2D array): a TxD matrix of log Mel filterbank features.
388
+ D=80, and T is the number of frames.
389
+ """
390
+ if wav.ndim > 1:
391
+ wav = np.squeeze(wav)
392
+
393
+ # by default, we extract the mean if stereo
394
+ if len(wav.shape) == 2:
395
+ wav = wav.mean(1)
396
+
397
+ # Resample to 16000 or 8000 if needed
398
+ if fs > 16000:
399
+ wav = scipy.signal.resample_poly(wav, 1, fs // 16000)
400
+ fs = 16000
401
+ elif 8000 < fs < 16000:
402
+ wav = scipy.signal.resample_poly(wav, 1, fs // 8000)
403
+ fs = 8000
404
+ elif fs < 8000:
405
+ raise RuntimeError(f"Unsupported sample rate {fs}")
406
+
407
+ if fs == 8000:
408
+ if self._eightk_method == "resample":
409
+ # Input audio is 8 kHz. Convert to 16 kHz before feature
410
+ # extraction
411
+ wav = scipy.signal.resample_poly(wav, 2, 1)
412
+ fs = 16000
413
+ # Do nothing here for fillzero method
414
+ elif fs != 16000:
415
+ # Input audio is not a supported sample rate.
416
+ raise RuntimeError(f"Input data using an unsupported sample rate: {fs}")
417
+
418
+ preemphasis = 0.97
419
+
420
+ if fs == 8000:
421
+ n_fft = 256
422
+ win_length = 200
423
+ hop_length = 80
424
+ fft_window = self._hamming200
425
+ elif fs == 16000:
426
+ n_fft = 512
427
+ win_length = 400
428
+ hop_length = 160
429
+ fft_window = self._hamming400
430
+
431
+ # Spec 1: SpeechLib cut remaining sample insufficient for a hop
432
+ n_batch = (wav.shape[0] - win_length) // hop_length + 1
433
+ # Here we don't use stride_tricks since the input array may not satisfy
434
+ # memory layout requirement and we need writeable output
435
+ # Here we only use list of views before copy to desination
436
+ # so it is more efficient than broadcasting
437
+ y_frames = np.array(
438
+ [wav[_stride : _stride + win_length] for _stride in range(0, hop_length * n_batch, hop_length)],
439
+ dtype=np.float32,
440
+ )
441
+
442
+ # Spec 2: SpeechLib applies preemphasis within each batch
443
+ y_frames_prev = np.roll(y_frames, 1, axis=1)
444
+ y_frames_prev[:, 0] = y_frames_prev[:, 1]
445
+ y_frames = (y_frames - preemphasis * y_frames_prev) * 32768
446
+
447
+ S = np.fft.rfft(fft_window * y_frames, n=n_fft, axis=1).astype(np.complex64)
448
+
449
+ if fs == 8000:
450
+ # Need to pad the output to look like 16 kHz data but with zeros in
451
+ # the 4 to 8 kHz bins.
452
+ frames, bins = S.shape
453
+ padarray = np.zeros((frames, bins))
454
+ S = np.concatenate((S[:, 0:-1], padarray), axis=1) # Nyquist bin gets set to zero
455
+
456
+ spec = np.abs(S).astype(np.float32)
457
+ return spec
458
+
459
+ def _extract_features(self, wav, fs):
460
+ """Extract log filterbank features from waveform.
461
+ Args:
462
+ wav (1D array): waveform of the input
463
+ fs (int): sampling rate of the waveform, 16000 or 8000.
464
+ If fs=8000, the waveform will be resampled to 16000Hz.
465
+ Output:
466
+ log_fbank (2D array): a TxD matrix of log Mel filterbank features.
467
+ D=80, and T is the number of frames.
468
+ """
469
+ spec = self._extract_spectrogram(wav, fs)
470
+ spec_power = spec**2
471
+
472
+ fbank_power = np.clip(spec_power.dot(self._mel), 1.0, None)
473
+ log_fbank = np.log(fbank_power).astype(np.float32)
474
+
475
+ return log_fbank
476
+
477
+ def _compute_audio_embed_size(self, audio_frames):
478
+ integer = audio_frames // self.compression_rate
479
+ remainder = audio_frames % self.compression_rate
480
+
481
+ result = integer if remainder == 0 else integer + 1
482
+
483
+ integer = result // self.qformer_compression_rate
484
+ remainder = result % self.qformer_compression_rate
485
+ result = integer if remainder == 0 else integer + 1 # qformer compression
486
+
487
+ return result
488
+
489
+
490
+ class Phi4MMProcessor(ProcessorMixin):
491
+ r"""
492
+ Constructs a Phi4MM processor which raps an image processor, a audio processor, and a GPT tokenizer into a single processor.
493
+
494
+ [`Phi4MMProcessor`] offers all the functionalities of [`Phi4MMImageProcessor`] and [`GPT2Tokenizer`]. See the
495
+ [`~Phi4MMProcessor.__call__`] and [`~Phi4MMProcessor.decode`] for more information.
496
+
497
+ Args:
498
+ image_processor ([`Phi4MMImageProcessor`], *optional*):
499
+ The image processor is a required input.
500
+ tokenizer ([`GPT2Tokenizer`], *optional*):
501
+ The tokenizer is a required input.
502
+ """
503
+
504
+ attributes = ["image_processor", "audio_processor", "tokenizer"]
505
+ tokenizer_class = "GPT2TokenizerFast"
506
+ image_processor_class = "AutoImageProcessor" # Phi4MMImageProcessor will be registered later
507
+ audio_processor_class = "AutoFeatureExtractor" # Phi4MMAudioFeatureExtractor will be registered later
508
+
509
+ def __init__(self, image_processor, audio_processor, tokenizer):
510
+ self.image_processor = image_processor
511
+ self.audio_processor = audio_processor
512
+ self.tokenizer = tokenizer
513
+
514
+ def __call__(
515
+ self,
516
+ text: Union[TextInput, List[TextInput]],
517
+ images: Optional[ImageInput] = None,
518
+ audios: Optional[AudioInputs] = None,
519
+ padding: Union[bool, str, PaddingStrategy] = False,
520
+ truncation: Optional[Union[bool, str, TruncationStrategy]] = None,
521
+ max_length=None,
522
+ return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
523
+ ) -> BatchFeature:
524
+ """
525
+ Main method to prepare for the model one or several sequences(s) and image(s). This method forards the `text`
526
+ and `kwargs` arguments to GPT2Tokenizer's [`~GPT2Tokenizer.__call__`] if `text` is not `None` to encode
527
+ the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
528
+ Phi4MMImageProcessor's [`~Phi4MMImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring
529
+ of the above two methods for more information.
530
+
531
+ Args:
532
+ text (`str`, `List[str]`, `List[List[str]]`):
533
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
534
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
535
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
536
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
537
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
538
+ tensor. Both channels-first and channels-last formats are supported.
539
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
540
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
541
+ index) among:
542
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
543
+ sequence if provided).
544
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
545
+ acceptable input length for the model if that argument is not provided.
546
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
547
+ lengths).
548
+ max_length (`int`, *optional*):
549
+ Maximum length of the returned list and optionally padding length (see above).
550
+ truncation (`bool`, *optional*):
551
+ Activates truncation to cut input sequences longer than `max_length` to `max_length`.
552
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
553
+ If set, will return tensors of a particular framework. Acceptable values are:
554
+
555
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
556
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
557
+ - `'np'`: Return NumPy `np.ndarray` objects.
558
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
559
+
560
+ Returns:
561
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
562
+
563
+ - **input_ids** -- List of token ids to be fed to a model.
564
+ - **input_image_embeds** -- Pixel values to be fed to a model.
565
+ - **image_sizes** -- List of tuples specifying the size of each image in `input_image_embeds`.
566
+ - **image_attention_mask** -- List of attention masks for each image in `input_image_embeds`.
567
+ - **input_audio_embeds** -- Audio embeddings to be fed to a model.
568
+ - **audio_embed_sizes** -- List of integers specifying the size of each audio in `input_audio_embeds`.
569
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
570
+ """
571
+ image_inputs = self.image_processor(images, return_tensors=return_tensors) if images is not None else {}
572
+ audio_inputs = self.audio_processor(audios, return_tensors=return_tensors) if audios is not None else {}
573
+ inputs = self._convert_images_audios_text_to_inputs(
574
+ image_inputs,
575
+ audio_inputs,
576
+ text,
577
+ padding=padding,
578
+ truncation=truncation,
579
+ max_length=max_length,
580
+ return_tensors=return_tensors,
581
+ )
582
+
583
+ # idenfity the input mode
584
+ if len(image_inputs) > 0 and len(audio_inputs) > 0:
585
+ input_mode = InputMode.VISION_SPEECH
586
+ elif len(image_inputs) > 0:
587
+ input_mode = InputMode.VISION
588
+ elif len(audio_inputs) > 0:
589
+ input_mode = InputMode.SPEECH
590
+ else:
591
+ input_mode = InputMode.LANGUAGE
592
+ inputs["input_mode"] = torch.tensor([input_mode.value], dtype=torch.long)
593
+
594
+ return inputs
595
+
596
+ @property
597
+ def special_image_token_id(self):
598
+ return self.tokenizer.convert_tokens_to_ids(self.special_image_token)
599
+
600
+ def get_special_image_token_id(self):
601
+ return self.tokenizer.convert_tokens_to_ids(self.special_image_token)
602
+
603
+ @property
604
+ def chat_template(self):
605
+ return self.tokenizer.chat_template
606
+
607
+ def _convert_images_audios_text_to_inputs(
608
+ self, images, audios, text, padding=False, truncation=None, max_length=None, return_tensors=None
609
+ ):
610
+ # prepare image id to image input ids
611
+ if len(images) > 0:
612
+ input_image_embeds = images["input_image_embeds"]
613
+ image_sizes = images["image_sizes"]
614
+ image_attention_mask = images["image_attention_mask"]
615
+ num_img_tokens = images['num_img_tokens']
616
+ else:
617
+ input_image_embeds = torch.tensor([])
618
+ image_sizes = torch.tensor([])
619
+ image_attention_mask = torch.tensor([])
620
+ num_img_tokens = []
621
+
622
+ # prepare audio id to audio input ids
623
+ if len(audios) > 0:
624
+ input_audio_embeds = audios["input_audio_embeds"]
625
+ audio_embed_sizes = audios["audio_embed_sizes"]
626
+ audio_attention_mask = audios.get("audio_attention_mask", None)
627
+ else:
628
+ input_audio_embeds = torch.tensor([])
629
+ audio_embed_sizes = torch.tensor([])
630
+ audio_attention_mask = None
631
+
632
+ # Replace certain special tokens for compatibility
633
+ # Ref: https://stackoverflow.com/questions/11475885/python-replace-regex
634
+ if isinstance(text, str):
635
+ text = [text]
636
+ assert isinstance(text, list)
637
+ processed_text = [re.sub(_COMPATIBLE_IMAGE_SPECIAL_TOKEN_PATTERN, _IMAGE_SPECIAL_TOKEN, t) for t in text]
638
+ processed_text = [re.sub(_COMPATIBLE_AUDIO_SPECIAL_TOKEN_PATTERN, _AUDIO_SPECIAL_TOKEN, t) for t in processed_text]
639
+
640
+ input_ids_list = [self.tokenizer(t).input_ids for t in processed_text]
641
+
642
+ img_cnt, audio_cnt = 0, 0 # only needed for later assertion
643
+ image_token_count_iter = iter(num_img_tokens)
644
+ audio_embed_size_iter = iter(audio_embed_sizes.tolist())
645
+ new_input_ids_list = []
646
+ for input_ids in input_ids_list:
647
+ i = 0
648
+ while i < len(input_ids):
649
+ token_id = input_ids[i]
650
+ if token_id == _AUDIO_SPECIAL_TOKEN_ID:
651
+ token_count = next(audio_embed_size_iter)
652
+ audio_cnt += 1
653
+ elif token_id == _IMAGE_SPECIAL_TOKEN_ID:
654
+ token_count = next(image_token_count_iter)
655
+ img_cnt += 1
656
+ else:
657
+ i += 1
658
+ continue
659
+ tokens = [token_id] * token_count
660
+ input_ids = input_ids[:i] + tokens + input_ids[i + 1:]
661
+ i += token_count
662
+ input_ids = torch.tensor(input_ids, dtype=torch.long)
663
+ new_input_ids_list.append(input_ids)
664
+ lengths = torch.tensor([len(input_ids) for input_ids in new_input_ids_list])
665
+ max_len = lengths.max()
666
+ input_ids = input_ids.new_full((len(new_input_ids_list), max_len), self.tokenizer.pad_token_id)
667
+ # batched inference requires left padding
668
+ for i in range(len(new_input_ids_list)):
669
+ input_ids[i, max_len - len(new_input_ids_list[i]):] = new_input_ids_list[i]
670
+
671
+ # If the below assertion fails, it might be that input pure-text
672
+ # messages contain image/audio special tokens literally
673
+ # (<|endoftext10|>, <|endoftext11|>).
674
+ assert (
675
+ img_cnt == len(num_img_tokens)
676
+ ), (
677
+ f"Number of image tokens in prompt_token_ids ({img_cnt}) "
678
+ f"does not match number of images ({len(num_img_tokens)})"
679
+ )
680
+ assert (
681
+ audio_cnt == len(audio_embed_sizes)
682
+ ), (
683
+ f"Number of audio tokens in prompt_token_ids ({audio_cnt}) "
684
+ f"does not match number of audios ({len(audio_embed_sizes)})"
685
+ )
686
+
687
+ # prepare attention mask
688
+ seq_range = torch.arange(max_len - 1, -1, -1)
689
+ attention_mask = seq_range.unsqueeze(0) < lengths.unsqueeze(1)
690
+
691
+ # prepare batch feature
692
+ data = {
693
+ "input_ids": input_ids,
694
+ "input_image_embeds": input_image_embeds,
695
+ "image_sizes": image_sizes,
696
+ "image_attention_mask": image_attention_mask,
697
+ "input_audio_embeds": input_audio_embeds,
698
+ "audio_embed_sizes": audio_embed_sizes,
699
+ "audio_attention_mask": audio_attention_mask,
700
+ "attention_mask": attention_mask,
701
+ }
702
+
703
+ return BatchFeature(
704
+ data=data
705
+ )
706
+
707
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
708
+ def batch_decode(self, *args, **kwargs):
709
+ """
710
+ This method forwards all its arguments to GPT2Tokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
711
+ refer to the docstring of this method for more information.
712
+ """
713
+ return self.tokenizer.batch_decode(*args, **kwargs)
714
+
715
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
716
+ def decode(self, *args, **kwargs):
717
+ """
718
+ This method forwards all its arguments to GPT2Tokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
719
+ the docstring of this method for more information.
720
+ """
721
+ return self.tokenizer.decode(*args, **kwargs)
722
+
723
+ @property
724
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
725
+ def model_input_names(self):
726
+ tokenizer_input_names = self.tokenizer.model_input_names
727
+ image_processor_input_names = self.image_processor.model_input_names
728
+ audio_processor_input_names = self.audio_processor.model_input_names
729
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names + audio_processor_input_names))
730
+
731
+
732
+ AutoImageProcessor.register("Phi4MMImageProcessor", Phi4MMImageProcessor)
733
+ AutoFeatureExtractor.register("Phi4MMAudioFeatureExtractor", Phi4MMAudioFeatureExtractor)
processor_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_phi4mm.Phi4MMProcessor"
4
+ },
5
+ "processor_class": "Phi4MMProcessor"
6
+ }
speech_conformer_encoder.py ADDED
The diff for this file is too large to render. See raw diff
 
trainer_state.json ADDED
@@ -0,0 +1,2329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "best_metric": null,
3
+ "best_model_checkpoint": null,
4
+ "epoch": 252.32,
5
+ "eval_steps": 500,
6
+ "global_step": 3280,
7
+ "is_hyper_param_search": false,
8
+ "is_local_process_zero": true,
9
+ "is_world_process_zero": true,
10
+ "log_history": [
11
+ {
12
+ "epoch": 0.8,
13
+ "grad_norm": 79.95787811279297,
14
+ "learning_rate": 1.2e-05,
15
+ "loss": 10.8172,
16
+ "step": 10
17
+ },
18
+ {
19
+ "epoch": 1.56,
20
+ "grad_norm": 15.179234504699707,
21
+ "learning_rate": 3.2000000000000005e-05,
22
+ "loss": 1.4707,
23
+ "step": 20
24
+ },
25
+ {
26
+ "epoch": 2.32,
27
+ "grad_norm": 5.771359443664551,
28
+ "learning_rate": 5.2000000000000004e-05,
29
+ "loss": 0.8206,
30
+ "step": 30
31
+ },
32
+ {
33
+ "epoch": 3.08,
34
+ "grad_norm": 5.559211730957031,
35
+ "learning_rate": 7.2e-05,
36
+ "loss": 0.9515,
37
+ "step": 40
38
+ },
39
+ {
40
+ "epoch": 3.88,
41
+ "grad_norm": 2.7674646377563477,
42
+ "learning_rate": 9.200000000000001e-05,
43
+ "loss": 0.7728,
44
+ "step": 50
45
+ },
46
+ {
47
+ "epoch": 4.64,
48
+ "grad_norm": 3.8835511207580566,
49
+ "learning_rate": 9.994979079497908e-05,
50
+ "loss": 0.8743,
51
+ "step": 60
52
+ },
53
+ {
54
+ "epoch": 5.4,
55
+ "grad_norm": 2.758652925491333,
56
+ "learning_rate": 9.986610878661087e-05,
57
+ "loss": 0.6475,
58
+ "step": 70
59
+ },
60
+ {
61
+ "epoch": 6.16,
62
+ "grad_norm": 1.3047797679901123,
63
+ "learning_rate": 9.978242677824268e-05,
64
+ "loss": 0.7883,
65
+ "step": 80
66
+ },
67
+ {
68
+ "epoch": 6.96,
69
+ "grad_norm": 3.255369186401367,
70
+ "learning_rate": 9.969874476987448e-05,
71
+ "loss": 2.3198,
72
+ "step": 90
73
+ },
74
+ {
75
+ "epoch": 7.72,
76
+ "grad_norm": 2.8563857078552246,
77
+ "learning_rate": 9.961506276150628e-05,
78
+ "loss": 0.6541,
79
+ "step": 100
80
+ },
81
+ {
82
+ "epoch": 8.48,
83
+ "grad_norm": 1.9255300760269165,
84
+ "learning_rate": 9.953138075313808e-05,
85
+ "loss": 0.7187,
86
+ "step": 110
87
+ },
88
+ {
89
+ "epoch": 9.24,
90
+ "grad_norm": 2.496511697769165,
91
+ "learning_rate": 9.944769874476987e-05,
92
+ "loss": 0.6086,
93
+ "step": 120
94
+ },
95
+ {
96
+ "epoch": 10.0,
97
+ "grad_norm": 1.0335757732391357,
98
+ "learning_rate": 9.936401673640167e-05,
99
+ "loss": 0.5804,
100
+ "step": 130
101
+ },
102
+ {
103
+ "epoch": 10.8,
104
+ "grad_norm": 2.37092661857605,
105
+ "learning_rate": 9.928033472803347e-05,
106
+ "loss": 0.4609,
107
+ "step": 140
108
+ },
109
+ {
110
+ "epoch": 11.56,
111
+ "grad_norm": 5.012497901916504,
112
+ "learning_rate": 9.919665271966527e-05,
113
+ "loss": 0.4665,
114
+ "step": 150
115
+ },
116
+ {
117
+ "epoch": 12.32,
118
+ "grad_norm": 1.2357938289642334,
119
+ "learning_rate": 9.911297071129707e-05,
120
+ "loss": 0.541,
121
+ "step": 160
122
+ },
123
+ {
124
+ "epoch": 13.08,
125
+ "grad_norm": 3.381162405014038,
126
+ "learning_rate": 9.902928870292888e-05,
127
+ "loss": 0.5019,
128
+ "step": 170
129
+ },
130
+ {
131
+ "epoch": 13.88,
132
+ "grad_norm": 1.603922724723816,
133
+ "learning_rate": 9.894560669456067e-05,
134
+ "loss": 0.408,
135
+ "step": 180
136
+ },
137
+ {
138
+ "epoch": 14.64,
139
+ "grad_norm": 1.6547144651412964,
140
+ "learning_rate": 9.886192468619247e-05,
141
+ "loss": 0.3947,
142
+ "step": 190
143
+ },
144
+ {
145
+ "epoch": 15.4,
146
+ "grad_norm": 0.8650034070014954,
147
+ "learning_rate": 9.877824267782427e-05,
148
+ "loss": 0.593,
149
+ "step": 200
150
+ },
151
+ {
152
+ "epoch": 16.16,
153
+ "grad_norm": 0.9181660413742065,
154
+ "learning_rate": 9.869456066945607e-05,
155
+ "loss": 0.3337,
156
+ "step": 210
157
+ },
158
+ {
159
+ "epoch": 16.96,
160
+ "grad_norm": 1.2371879816055298,
161
+ "learning_rate": 9.861087866108786e-05,
162
+ "loss": 0.4019,
163
+ "step": 220
164
+ },
165
+ {
166
+ "epoch": 17.72,
167
+ "grad_norm": 1.0304924249649048,
168
+ "learning_rate": 9.852719665271966e-05,
169
+ "loss": 0.3883,
170
+ "step": 230
171
+ },
172
+ {
173
+ "epoch": 18.48,
174
+ "grad_norm": 1.2461388111114502,
175
+ "learning_rate": 9.844351464435146e-05,
176
+ "loss": 0.6623,
177
+ "step": 240
178
+ },
179
+ {
180
+ "epoch": 19.24,
181
+ "grad_norm": 0.6923120617866516,
182
+ "learning_rate": 9.835983263598327e-05,
183
+ "loss": 0.3604,
184
+ "step": 250
185
+ },
186
+ {
187
+ "epoch": 20.0,
188
+ "grad_norm": 4.839296817779541,
189
+ "learning_rate": 9.827615062761507e-05,
190
+ "loss": 0.387,
191
+ "step": 260
192
+ },
193
+ {
194
+ "epoch": 20.8,
195
+ "grad_norm": 1.1753798723220825,
196
+ "learning_rate": 9.819246861924687e-05,
197
+ "loss": 0.3514,
198
+ "step": 270
199
+ },
200
+ {
201
+ "epoch": 21.56,
202
+ "grad_norm": 0.9725382924079895,
203
+ "learning_rate": 9.810878661087866e-05,
204
+ "loss": 0.3922,
205
+ "step": 280
206
+ },
207
+ {
208
+ "epoch": 22.32,
209
+ "grad_norm": 0.645702600479126,
210
+ "learning_rate": 9.802510460251046e-05,
211
+ "loss": 0.3331,
212
+ "step": 290
213
+ },
214
+ {
215
+ "epoch": 23.08,
216
+ "grad_norm": 1.4330626726150513,
217
+ "learning_rate": 9.794142259414226e-05,
218
+ "loss": 0.3946,
219
+ "step": 300
220
+ },
221
+ {
222
+ "epoch": 23.88,
223
+ "grad_norm": 1.2405084371566772,
224
+ "learning_rate": 9.785774058577406e-05,
225
+ "loss": 0.354,
226
+ "step": 310
227
+ },
228
+ {
229
+ "epoch": 24.64,
230
+ "grad_norm": 0.9036368131637573,
231
+ "learning_rate": 9.777405857740585e-05,
232
+ "loss": 0.3856,
233
+ "step": 320
234
+ },
235
+ {
236
+ "epoch": 25.4,
237
+ "grad_norm": 0.7258665561676025,
238
+ "learning_rate": 9.769037656903767e-05,
239
+ "loss": 0.2859,
240
+ "step": 330
241
+ },
242
+ {
243
+ "epoch": 26.16,
244
+ "grad_norm": 0.4196911156177521,
245
+ "learning_rate": 9.760669456066946e-05,
246
+ "loss": 0.2962,
247
+ "step": 340
248
+ },
249
+ {
250
+ "epoch": 26.96,
251
+ "grad_norm": 1.2342430353164673,
252
+ "learning_rate": 9.752301255230126e-05,
253
+ "loss": 0.3721,
254
+ "step": 350
255
+ },
256
+ {
257
+ "epoch": 27.72,
258
+ "grad_norm": 0.6905648112297058,
259
+ "learning_rate": 9.743933054393306e-05,
260
+ "loss": 0.2722,
261
+ "step": 360
262
+ },
263
+ {
264
+ "epoch": 28.48,
265
+ "grad_norm": 0.851387083530426,
266
+ "learning_rate": 9.735564853556486e-05,
267
+ "loss": 0.3564,
268
+ "step": 370
269
+ },
270
+ {
271
+ "epoch": 29.24,
272
+ "grad_norm": 0.5951876640319824,
273
+ "learning_rate": 9.727196652719665e-05,
274
+ "loss": 0.2892,
275
+ "step": 380
276
+ },
277
+ {
278
+ "epoch": 30.0,
279
+ "grad_norm": 0.5764455199241638,
280
+ "learning_rate": 9.718828451882845e-05,
281
+ "loss": 0.3402,
282
+ "step": 390
283
+ },
284
+ {
285
+ "epoch": 30.8,
286
+ "grad_norm": 7.990225315093994,
287
+ "learning_rate": 9.710460251046025e-05,
288
+ "loss": 0.3439,
289
+ "step": 400
290
+ },
291
+ {
292
+ "epoch": 31.56,
293
+ "grad_norm": 1.119363784790039,
294
+ "learning_rate": 9.702092050209205e-05,
295
+ "loss": 0.2595,
296
+ "step": 410
297
+ },
298
+ {
299
+ "epoch": 32.32,
300
+ "grad_norm": 27.885967254638672,
301
+ "learning_rate": 9.693723849372386e-05,
302
+ "loss": 0.9194,
303
+ "step": 420
304
+ },
305
+ {
306
+ "epoch": 33.08,
307
+ "grad_norm": 5.194603443145752,
308
+ "learning_rate": 9.685355648535566e-05,
309
+ "loss": 0.3321,
310
+ "step": 430
311
+ },
312
+ {
313
+ "epoch": 33.88,
314
+ "grad_norm": 2.1785478591918945,
315
+ "learning_rate": 9.676987447698745e-05,
316
+ "loss": 0.3875,
317
+ "step": 440
318
+ },
319
+ {
320
+ "epoch": 34.64,
321
+ "grad_norm": 0.844071626663208,
322
+ "learning_rate": 9.668619246861925e-05,
323
+ "loss": 0.2458,
324
+ "step": 450
325
+ },
326
+ {
327
+ "epoch": 35.4,
328
+ "grad_norm": 2.1278724670410156,
329
+ "learning_rate": 9.660251046025105e-05,
330
+ "loss": 0.358,
331
+ "step": 460
332
+ },
333
+ {
334
+ "epoch": 36.16,
335
+ "grad_norm": 0.6194890141487122,
336
+ "learning_rate": 9.651882845188285e-05,
337
+ "loss": 0.3334,
338
+ "step": 470
339
+ },
340
+ {
341
+ "epoch": 36.96,
342
+ "grad_norm": 0.7329260110855103,
343
+ "learning_rate": 9.643514644351464e-05,
344
+ "loss": 0.2981,
345
+ "step": 480
346
+ },
347
+ {
348
+ "epoch": 37.72,
349
+ "grad_norm": 0.8790725469589233,
350
+ "learning_rate": 9.635146443514644e-05,
351
+ "loss": 0.2903,
352
+ "step": 490
353
+ },
354
+ {
355
+ "epoch": 38.48,
356
+ "grad_norm": 0.5892160534858704,
357
+ "learning_rate": 9.626778242677825e-05,
358
+ "loss": 0.362,
359
+ "step": 500
360
+ },
361
+ {
362
+ "epoch": 39.24,
363
+ "grad_norm": 1.6055381298065186,
364
+ "learning_rate": 9.618410041841005e-05,
365
+ "loss": 0.3838,
366
+ "step": 510
367
+ },
368
+ {
369
+ "epoch": 40.0,
370
+ "grad_norm": 11.784494400024414,
371
+ "learning_rate": 9.610041841004185e-05,
372
+ "loss": 0.3082,
373
+ "step": 520
374
+ },
375
+ {
376
+ "epoch": 40.8,
377
+ "grad_norm": 0.5743730664253235,
378
+ "learning_rate": 9.601673640167365e-05,
379
+ "loss": 0.31,
380
+ "step": 530
381
+ },
382
+ {
383
+ "epoch": 41.56,
384
+ "grad_norm": 0.592658519744873,
385
+ "learning_rate": 9.593305439330544e-05,
386
+ "loss": 0.5208,
387
+ "step": 540
388
+ },
389
+ {
390
+ "epoch": 42.32,
391
+ "grad_norm": 5.394102573394775,
392
+ "learning_rate": 9.584937238493724e-05,
393
+ "loss": 0.4108,
394
+ "step": 550
395
+ },
396
+ {
397
+ "epoch": 43.08,
398
+ "grad_norm": 0.7347181439399719,
399
+ "learning_rate": 9.576569037656904e-05,
400
+ "loss": 0.3476,
401
+ "step": 560
402
+ },
403
+ {
404
+ "epoch": 43.88,
405
+ "grad_norm": 0.9459385871887207,
406
+ "learning_rate": 9.568200836820084e-05,
407
+ "loss": 0.3008,
408
+ "step": 570
409
+ },
410
+ {
411
+ "epoch": 44.64,
412
+ "grad_norm": 5.524572849273682,
413
+ "learning_rate": 9.559832635983263e-05,
414
+ "loss": 0.266,
415
+ "step": 580
416
+ },
417
+ {
418
+ "epoch": 45.4,
419
+ "grad_norm": 0.705575704574585,
420
+ "learning_rate": 9.551464435146445e-05,
421
+ "loss": 0.3705,
422
+ "step": 590
423
+ },
424
+ {
425
+ "epoch": 46.16,
426
+ "grad_norm": 0.5305906534194946,
427
+ "learning_rate": 9.543096234309624e-05,
428
+ "loss": 0.2619,
429
+ "step": 600
430
+ },
431
+ {
432
+ "epoch": 46.96,
433
+ "grad_norm": 3.7031350135803223,
434
+ "learning_rate": 9.534728033472804e-05,
435
+ "loss": 0.4041,
436
+ "step": 610
437
+ },
438
+ {
439
+ "epoch": 47.72,
440
+ "grad_norm": 0.9455975294113159,
441
+ "learning_rate": 9.526359832635984e-05,
442
+ "loss": 0.2948,
443
+ "step": 620
444
+ },
445
+ {
446
+ "epoch": 48.48,
447
+ "grad_norm": 0.13628531992435455,
448
+ "learning_rate": 9.517991631799164e-05,
449
+ "loss": 0.2962,
450
+ "step": 630
451
+ },
452
+ {
453
+ "epoch": 49.24,
454
+ "grad_norm": 0.5263031721115112,
455
+ "learning_rate": 9.509623430962343e-05,
456
+ "loss": 0.3395,
457
+ "step": 640
458
+ },
459
+ {
460
+ "epoch": 50.0,
461
+ "grad_norm": 0.47290968894958496,
462
+ "learning_rate": 9.501255230125523e-05,
463
+ "loss": 0.3163,
464
+ "step": 650
465
+ },
466
+ {
467
+ "epoch": 50.8,
468
+ "grad_norm": 0.7216980457305908,
469
+ "learning_rate": 9.492887029288703e-05,
470
+ "loss": 0.303,
471
+ "step": 660
472
+ },
473
+ {
474
+ "epoch": 51.56,
475
+ "grad_norm": 0.44606828689575195,
476
+ "learning_rate": 9.484518828451884e-05,
477
+ "loss": 0.3086,
478
+ "step": 670
479
+ },
480
+ {
481
+ "epoch": 52.32,
482
+ "grad_norm": 0.6489527821540833,
483
+ "learning_rate": 9.476150627615064e-05,
484
+ "loss": 0.3226,
485
+ "step": 680
486
+ },
487
+ {
488
+ "epoch": 53.08,
489
+ "grad_norm": 0.4259757101535797,
490
+ "learning_rate": 9.467782426778243e-05,
491
+ "loss": 0.2513,
492
+ "step": 690
493
+ },
494
+ {
495
+ "epoch": 53.88,
496
+ "grad_norm": 0.5952382683753967,
497
+ "learning_rate": 9.459414225941423e-05,
498
+ "loss": 0.3113,
499
+ "step": 700
500
+ },
501
+ {
502
+ "epoch": 54.64,
503
+ "grad_norm": 1.0894474983215332,
504
+ "learning_rate": 9.451046025104603e-05,
505
+ "loss": 0.2583,
506
+ "step": 710
507
+ },
508
+ {
509
+ "epoch": 55.4,
510
+ "grad_norm": 0.5647149085998535,
511
+ "learning_rate": 9.442677824267783e-05,
512
+ "loss": 0.3042,
513
+ "step": 720
514
+ },
515
+ {
516
+ "epoch": 56.16,
517
+ "grad_norm": 0.9288455843925476,
518
+ "learning_rate": 9.434309623430963e-05,
519
+ "loss": 0.3281,
520
+ "step": 730
521
+ },
522
+ {
523
+ "epoch": 56.96,
524
+ "grad_norm": 0.4492562711238861,
525
+ "learning_rate": 9.425941422594142e-05,
526
+ "loss": 0.3379,
527
+ "step": 740
528
+ },
529
+ {
530
+ "epoch": 57.72,
531
+ "grad_norm": 1.0189441442489624,
532
+ "learning_rate": 9.417573221757323e-05,
533
+ "loss": 0.3194,
534
+ "step": 750
535
+ },
536
+ {
537
+ "epoch": 58.48,
538
+ "grad_norm": 0.40981024503707886,
539
+ "learning_rate": 9.409205020920503e-05,
540
+ "loss": 0.2423,
541
+ "step": 760
542
+ },
543
+ {
544
+ "epoch": 59.24,
545
+ "grad_norm": 0.354190468788147,
546
+ "learning_rate": 9.400836820083683e-05,
547
+ "loss": 0.3266,
548
+ "step": 770
549
+ },
550
+ {
551
+ "epoch": 60.0,
552
+ "grad_norm": 0.4593465328216553,
553
+ "learning_rate": 9.392468619246863e-05,
554
+ "loss": 0.3068,
555
+ "step": 780
556
+ },
557
+ {
558
+ "epoch": 60.8,
559
+ "grad_norm": 0.759248673915863,
560
+ "learning_rate": 9.384100418410042e-05,
561
+ "loss": 0.3078,
562
+ "step": 790
563
+ },
564
+ {
565
+ "epoch": 61.56,
566
+ "grad_norm": 0.8159565925598145,
567
+ "learning_rate": 9.375732217573222e-05,
568
+ "loss": 0.2668,
569
+ "step": 800
570
+ },
571
+ {
572
+ "epoch": 62.32,
573
+ "grad_norm": 0.7888874411582947,
574
+ "learning_rate": 9.367364016736402e-05,
575
+ "loss": 0.2936,
576
+ "step": 810
577
+ },
578
+ {
579
+ "epoch": 63.08,
580
+ "grad_norm": 0.8634017705917358,
581
+ "learning_rate": 9.358995815899582e-05,
582
+ "loss": 0.3182,
583
+ "step": 820
584
+ },
585
+ {
586
+ "epoch": 63.88,
587
+ "grad_norm": 0.5454868078231812,
588
+ "learning_rate": 9.350627615062762e-05,
589
+ "loss": 0.3117,
590
+ "step": 830
591
+ },
592
+ {
593
+ "epoch": 64.64,
594
+ "grad_norm": 0.9895418286323547,
595
+ "learning_rate": 9.342259414225943e-05,
596
+ "loss": 0.3014,
597
+ "step": 840
598
+ },
599
+ {
600
+ "epoch": 65.4,
601
+ "grad_norm": 1.4120994806289673,
602
+ "learning_rate": 9.333891213389122e-05,
603
+ "loss": 0.3001,
604
+ "step": 850
605
+ },
606
+ {
607
+ "epoch": 66.16,
608
+ "grad_norm": 0.6670629978179932,
609
+ "learning_rate": 9.325523012552302e-05,
610
+ "loss": 0.2651,
611
+ "step": 860
612
+ },
613
+ {
614
+ "epoch": 66.96,
615
+ "grad_norm": 0.6429235339164734,
616
+ "learning_rate": 9.317154811715482e-05,
617
+ "loss": 0.2935,
618
+ "step": 870
619
+ },
620
+ {
621
+ "epoch": 67.72,
622
+ "grad_norm": 0.6127046942710876,
623
+ "learning_rate": 9.308786610878662e-05,
624
+ "loss": 0.2936,
625
+ "step": 880
626
+ },
627
+ {
628
+ "epoch": 68.48,
629
+ "grad_norm": 0.6747534275054932,
630
+ "learning_rate": 9.300418410041841e-05,
631
+ "loss": 0.2557,
632
+ "step": 890
633
+ },
634
+ {
635
+ "epoch": 69.24,
636
+ "grad_norm": 0.8579817414283752,
637
+ "learning_rate": 9.292050209205021e-05,
638
+ "loss": 0.3045,
639
+ "step": 900
640
+ },
641
+ {
642
+ "epoch": 70.0,
643
+ "grad_norm": 0.24220693111419678,
644
+ "learning_rate": 9.283682008368201e-05,
645
+ "loss": 0.2818,
646
+ "step": 910
647
+ },
648
+ {
649
+ "epoch": 70.8,
650
+ "grad_norm": 0.5585035681724548,
651
+ "learning_rate": 9.275313807531382e-05,
652
+ "loss": 0.2579,
653
+ "step": 920
654
+ },
655
+ {
656
+ "epoch": 71.56,
657
+ "grad_norm": 0.5965076684951782,
658
+ "learning_rate": 9.266945606694562e-05,
659
+ "loss": 0.3582,
660
+ "step": 930
661
+ },
662
+ {
663
+ "epoch": 72.32,
664
+ "grad_norm": 0.4647338092327118,
665
+ "learning_rate": 9.258577405857742e-05,
666
+ "loss": 0.2384,
667
+ "step": 940
668
+ },
669
+ {
670
+ "epoch": 73.08,
671
+ "grad_norm": 0.4868353605270386,
672
+ "learning_rate": 9.250209205020921e-05,
673
+ "loss": 0.3097,
674
+ "step": 950
675
+ },
676
+ {
677
+ "epoch": 73.88,
678
+ "grad_norm": 0.4390128552913666,
679
+ "learning_rate": 9.241841004184101e-05,
680
+ "loss": 0.2722,
681
+ "step": 960
682
+ },
683
+ {
684
+ "epoch": 74.64,
685
+ "grad_norm": 1.2558406591415405,
686
+ "learning_rate": 9.233472803347281e-05,
687
+ "loss": 0.2546,
688
+ "step": 970
689
+ },
690
+ {
691
+ "epoch": 75.4,
692
+ "grad_norm": 1.468024492263794,
693
+ "learning_rate": 9.225104602510461e-05,
694
+ "loss": 0.3539,
695
+ "step": 980
696
+ },
697
+ {
698
+ "epoch": 76.16,
699
+ "grad_norm": 1.3816922903060913,
700
+ "learning_rate": 9.21673640167364e-05,
701
+ "loss": 0.3131,
702
+ "step": 990
703
+ },
704
+ {
705
+ "epoch": 76.96,
706
+ "grad_norm": 1.0851378440856934,
707
+ "learning_rate": 9.208368200836822e-05,
708
+ "loss": 0.3418,
709
+ "step": 1000
710
+ },
711
+ {
712
+ "epoch": 77.72,
713
+ "grad_norm": 0.6077636480331421,
714
+ "learning_rate": 9.200000000000001e-05,
715
+ "loss": 0.301,
716
+ "step": 1010
717
+ },
718
+ {
719
+ "epoch": 78.48,
720
+ "grad_norm": 0.39989063143730164,
721
+ "learning_rate": 9.191631799163181e-05,
722
+ "loss": 0.2551,
723
+ "step": 1020
724
+ },
725
+ {
726
+ "epoch": 79.24,
727
+ "grad_norm": 0.5581826567649841,
728
+ "learning_rate": 9.183263598326361e-05,
729
+ "loss": 0.2955,
730
+ "step": 1030
731
+ },
732
+ {
733
+ "epoch": 80.0,
734
+ "grad_norm": 0.42415767908096313,
735
+ "learning_rate": 9.17489539748954e-05,
736
+ "loss": 0.3024,
737
+ "step": 1040
738
+ },
739
+ {
740
+ "epoch": 80.8,
741
+ "grad_norm": 0.431111603975296,
742
+ "learning_rate": 9.16652719665272e-05,
743
+ "loss": 0.3108,
744
+ "step": 1050
745
+ },
746
+ {
747
+ "epoch": 81.56,
748
+ "grad_norm": 0.45482027530670166,
749
+ "learning_rate": 9.1581589958159e-05,
750
+ "loss": 0.25,
751
+ "step": 1060
752
+ },
753
+ {
754
+ "epoch": 82.32,
755
+ "grad_norm": 0.6043058633804321,
756
+ "learning_rate": 9.14979079497908e-05,
757
+ "loss": 0.3466,
758
+ "step": 1070
759
+ },
760
+ {
761
+ "epoch": 83.08,
762
+ "grad_norm": 0.7600142955780029,
763
+ "learning_rate": 9.14142259414226e-05,
764
+ "loss": 0.2888,
765
+ "step": 1080
766
+ },
767
+ {
768
+ "epoch": 83.88,
769
+ "grad_norm": 0.36540111899375916,
770
+ "learning_rate": 9.133054393305441e-05,
771
+ "loss": 0.2896,
772
+ "step": 1090
773
+ },
774
+ {
775
+ "epoch": 84.64,
776
+ "grad_norm": 1.1366398334503174,
777
+ "learning_rate": 9.12468619246862e-05,
778
+ "loss": 0.3023,
779
+ "step": 1100
780
+ },
781
+ {
782
+ "epoch": 85.4,
783
+ "grad_norm": 0.646086573600769,
784
+ "learning_rate": 9.1163179916318e-05,
785
+ "loss": 0.2462,
786
+ "step": 1110
787
+ },
788
+ {
789
+ "epoch": 86.16,
790
+ "grad_norm": 0.6224349141120911,
791
+ "learning_rate": 9.10794979079498e-05,
792
+ "loss": 0.2741,
793
+ "step": 1120
794
+ },
795
+ {
796
+ "epoch": 86.96,
797
+ "grad_norm": 0.8657971024513245,
798
+ "learning_rate": 9.09958158995816e-05,
799
+ "loss": 0.3016,
800
+ "step": 1130
801
+ },
802
+ {
803
+ "epoch": 87.72,
804
+ "grad_norm": 0.86732017993927,
805
+ "learning_rate": 9.09121338912134e-05,
806
+ "loss": 0.263,
807
+ "step": 1140
808
+ },
809
+ {
810
+ "epoch": 88.48,
811
+ "grad_norm": 0.8562549948692322,
812
+ "learning_rate": 9.08284518828452e-05,
813
+ "loss": 0.2726,
814
+ "step": 1150
815
+ },
816
+ {
817
+ "epoch": 89.24,
818
+ "grad_norm": 0.5194992423057556,
819
+ "learning_rate": 9.074476987447699e-05,
820
+ "loss": 0.3161,
821
+ "step": 1160
822
+ },
823
+ {
824
+ "epoch": 90.0,
825
+ "grad_norm": 0.3380357027053833,
826
+ "learning_rate": 9.066108786610879e-05,
827
+ "loss": 0.26,
828
+ "step": 1170
829
+ },
830
+ {
831
+ "epoch": 90.8,
832
+ "grad_norm": 0.4834354519844055,
833
+ "learning_rate": 9.057740585774059e-05,
834
+ "loss": 0.2901,
835
+ "step": 1180
836
+ },
837
+ {
838
+ "epoch": 91.56,
839
+ "grad_norm": 0.7634447813034058,
840
+ "learning_rate": 9.04937238493724e-05,
841
+ "loss": 0.2471,
842
+ "step": 1190
843
+ },
844
+ {
845
+ "epoch": 92.32,
846
+ "grad_norm": 0.5605065822601318,
847
+ "learning_rate": 9.04100418410042e-05,
848
+ "loss": 0.3091,
849
+ "step": 1200
850
+ },
851
+ {
852
+ "epoch": 93.08,
853
+ "grad_norm": 0.6867684721946716,
854
+ "learning_rate": 9.0326359832636e-05,
855
+ "loss": 0.292,
856
+ "step": 1210
857
+ },
858
+ {
859
+ "epoch": 93.88,
860
+ "grad_norm": 0.5395390391349792,
861
+ "learning_rate": 9.024267782426779e-05,
862
+ "loss": 0.2727,
863
+ "step": 1220
864
+ },
865
+ {
866
+ "epoch": 94.64,
867
+ "grad_norm": 0.8648020029067993,
868
+ "learning_rate": 9.015899581589959e-05,
869
+ "loss": 0.3273,
870
+ "step": 1230
871
+ },
872
+ {
873
+ "epoch": 95.4,
874
+ "grad_norm": 0.5256586074829102,
875
+ "learning_rate": 9.007531380753139e-05,
876
+ "loss": 0.2611,
877
+ "step": 1240
878
+ },
879
+ {
880
+ "epoch": 96.16,
881
+ "grad_norm": 0.726409375667572,
882
+ "learning_rate": 8.999163179916318e-05,
883
+ "loss": 0.2993,
884
+ "step": 1250
885
+ },
886
+ {
887
+ "epoch": 96.96,
888
+ "grad_norm": 0.5897384285926819,
889
+ "learning_rate": 8.990794979079498e-05,
890
+ "loss": 0.2628,
891
+ "step": 1260
892
+ },
893
+ {
894
+ "epoch": 97.72,
895
+ "grad_norm": 0.3650963306427002,
896
+ "learning_rate": 8.982426778242678e-05,
897
+ "loss": 0.2975,
898
+ "step": 1270
899
+ },
900
+ {
901
+ "epoch": 98.48,
902
+ "grad_norm": 0.6548069715499878,
903
+ "learning_rate": 8.974058577405858e-05,
904
+ "loss": 0.2732,
905
+ "step": 1280
906
+ },
907
+ {
908
+ "epoch": 99.24,
909
+ "grad_norm": 0.6239974498748779,
910
+ "learning_rate": 8.965690376569037e-05,
911
+ "loss": 0.3471,
912
+ "step": 1290
913
+ },
914
+ {
915
+ "epoch": 100.0,
916
+ "grad_norm": 0.00010844325879588723,
917
+ "learning_rate": 8.957322175732217e-05,
918
+ "loss": 0.2865,
919
+ "step": 1300
920
+ },
921
+ {
922
+ "epoch": 100.8,
923
+ "grad_norm": 55.954044342041016,
924
+ "learning_rate": 8.94979079497908e-05,
925
+ "loss": 2.0083,
926
+ "step": 1310
927
+ },
928
+ {
929
+ "epoch": 101.56,
930
+ "grad_norm": 1.1408623456954956,
931
+ "learning_rate": 8.94142259414226e-05,
932
+ "loss": 0.7388,
933
+ "step": 1320
934
+ },
935
+ {
936
+ "epoch": 102.32,
937
+ "grad_norm": 0.7127572298049927,
938
+ "learning_rate": 8.93305439330544e-05,
939
+ "loss": 0.2341,
940
+ "step": 1330
941
+ },
942
+ {
943
+ "epoch": 103.08,
944
+ "grad_norm": 188.80079650878906,
945
+ "learning_rate": 8.92468619246862e-05,
946
+ "loss": 0.8824,
947
+ "step": 1340
948
+ },
949
+ {
950
+ "epoch": 103.88,
951
+ "grad_norm": 22.78482437133789,
952
+ "learning_rate": 8.9163179916318e-05,
953
+ "loss": 0.6129,
954
+ "step": 1350
955
+ },
956
+ {
957
+ "epoch": 104.64,
958
+ "grad_norm": 1.2376023530960083,
959
+ "learning_rate": 8.90794979079498e-05,
960
+ "loss": 0.4303,
961
+ "step": 1360
962
+ },
963
+ {
964
+ "epoch": 105.4,
965
+ "grad_norm": 1.032791018486023,
966
+ "learning_rate": 8.899581589958159e-05,
967
+ "loss": 0.3326,
968
+ "step": 1370
969
+ },
970
+ {
971
+ "epoch": 106.16,
972
+ "grad_norm": 0.6220753192901611,
973
+ "learning_rate": 8.891213389121339e-05,
974
+ "loss": 0.3165,
975
+ "step": 1380
976
+ },
977
+ {
978
+ "epoch": 106.96,
979
+ "grad_norm": 0.9835271835327148,
980
+ "learning_rate": 8.882845188284519e-05,
981
+ "loss": 0.2899,
982
+ "step": 1390
983
+ },
984
+ {
985
+ "epoch": 107.72,
986
+ "grad_norm": 0.4846683144569397,
987
+ "learning_rate": 8.8744769874477e-05,
988
+ "loss": 0.3234,
989
+ "step": 1400
990
+ },
991
+ {
992
+ "epoch": 108.48,
993
+ "grad_norm": 42.94027328491211,
994
+ "learning_rate": 8.86610878661088e-05,
995
+ "loss": 0.2739,
996
+ "step": 1410
997
+ },
998
+ {
999
+ "epoch": 109.24,
1000
+ "grad_norm": 0.42791783809661865,
1001
+ "learning_rate": 8.857740585774059e-05,
1002
+ "loss": 0.3374,
1003
+ "step": 1420
1004
+ },
1005
+ {
1006
+ "epoch": 110.0,
1007
+ "grad_norm": 0.5355058312416077,
1008
+ "learning_rate": 8.849372384937239e-05,
1009
+ "loss": 0.2941,
1010
+ "step": 1430
1011
+ },
1012
+ {
1013
+ "epoch": 110.8,
1014
+ "grad_norm": 0.43855488300323486,
1015
+ "learning_rate": 8.841004184100419e-05,
1016
+ "loss": 0.254,
1017
+ "step": 1440
1018
+ },
1019
+ {
1020
+ "epoch": 111.56,
1021
+ "grad_norm": 0.6513474583625793,
1022
+ "learning_rate": 8.832635983263599e-05,
1023
+ "loss": 0.2992,
1024
+ "step": 1450
1025
+ },
1026
+ {
1027
+ "epoch": 112.32,
1028
+ "grad_norm": 0.5990163087844849,
1029
+ "learning_rate": 8.824267782426778e-05,
1030
+ "loss": 0.2838,
1031
+ "step": 1460
1032
+ },
1033
+ {
1034
+ "epoch": 113.08,
1035
+ "grad_norm": 0.5194154381752014,
1036
+ "learning_rate": 8.815899581589958e-05,
1037
+ "loss": 0.2896,
1038
+ "step": 1470
1039
+ },
1040
+ {
1041
+ "epoch": 113.88,
1042
+ "grad_norm": 0.90041583776474,
1043
+ "learning_rate": 8.807531380753139e-05,
1044
+ "loss": 0.2853,
1045
+ "step": 1480
1046
+ },
1047
+ {
1048
+ "epoch": 114.64,
1049
+ "grad_norm": 0.46713006496429443,
1050
+ "learning_rate": 8.799163179916319e-05,
1051
+ "loss": 0.341,
1052
+ "step": 1490
1053
+ },
1054
+ {
1055
+ "epoch": 115.4,
1056
+ "grad_norm": 1.0562089681625366,
1057
+ "learning_rate": 8.790794979079499e-05,
1058
+ "loss": 0.2485,
1059
+ "step": 1500
1060
+ },
1061
+ {
1062
+ "epoch": 116.16,
1063
+ "grad_norm": 0.7914270162582397,
1064
+ "learning_rate": 8.782426778242678e-05,
1065
+ "loss": 0.2931,
1066
+ "step": 1510
1067
+ },
1068
+ {
1069
+ "epoch": 116.96,
1070
+ "grad_norm": 0.7104222178459167,
1071
+ "learning_rate": 8.774058577405858e-05,
1072
+ "loss": 0.3166,
1073
+ "step": 1520
1074
+ },
1075
+ {
1076
+ "epoch": 117.72,
1077
+ "grad_norm": 0.6477789878845215,
1078
+ "learning_rate": 8.765690376569038e-05,
1079
+ "loss": 0.2712,
1080
+ "step": 1530
1081
+ },
1082
+ {
1083
+ "epoch": 118.48,
1084
+ "grad_norm": 0.2977544069290161,
1085
+ "learning_rate": 8.757322175732218e-05,
1086
+ "loss": 0.2537,
1087
+ "step": 1540
1088
+ },
1089
+ {
1090
+ "epoch": 119.24,
1091
+ "grad_norm": 0.6447045803070068,
1092
+ "learning_rate": 8.748953974895398e-05,
1093
+ "loss": 0.2848,
1094
+ "step": 1550
1095
+ },
1096
+ {
1097
+ "epoch": 120.0,
1098
+ "grad_norm": 0.4993550479412079,
1099
+ "learning_rate": 8.740585774058579e-05,
1100
+ "loss": 0.2712,
1101
+ "step": 1560
1102
+ },
1103
+ {
1104
+ "epoch": 120.8,
1105
+ "grad_norm": 0.28479063510894775,
1106
+ "learning_rate": 8.732217573221758e-05,
1107
+ "loss": 0.2969,
1108
+ "step": 1570
1109
+ },
1110
+ {
1111
+ "epoch": 121.56,
1112
+ "grad_norm": 0.7489855885505676,
1113
+ "learning_rate": 8.723849372384938e-05,
1114
+ "loss": 0.2512,
1115
+ "step": 1580
1116
+ },
1117
+ {
1118
+ "epoch": 122.32,
1119
+ "grad_norm": 0.6503575444221497,
1120
+ "learning_rate": 8.715481171548118e-05,
1121
+ "loss": 0.268,
1122
+ "step": 1590
1123
+ },
1124
+ {
1125
+ "epoch": 123.08,
1126
+ "grad_norm": 0.5870686769485474,
1127
+ "learning_rate": 8.707112970711298e-05,
1128
+ "loss": 0.302,
1129
+ "step": 1600
1130
+ },
1131
+ {
1132
+ "epoch": 123.88,
1133
+ "grad_norm": 0.8388033509254456,
1134
+ "learning_rate": 8.698744769874477e-05,
1135
+ "loss": 0.2784,
1136
+ "step": 1610
1137
+ },
1138
+ {
1139
+ "epoch": 124.64,
1140
+ "grad_norm": 0.7110853791236877,
1141
+ "learning_rate": 8.690376569037657e-05,
1142
+ "loss": 0.2576,
1143
+ "step": 1620
1144
+ },
1145
+ {
1146
+ "epoch": 125.4,
1147
+ "grad_norm": 0.6697489619255066,
1148
+ "learning_rate": 8.682008368200837e-05,
1149
+ "loss": 0.2863,
1150
+ "step": 1630
1151
+ },
1152
+ {
1153
+ "epoch": 126.16,
1154
+ "grad_norm": 0.6678580045700073,
1155
+ "learning_rate": 8.673640167364017e-05,
1156
+ "loss": 0.2945,
1157
+ "step": 1640
1158
+ },
1159
+ {
1160
+ "epoch": 126.96,
1161
+ "grad_norm": 0.5099469423294067,
1162
+ "learning_rate": 8.665271966527198e-05,
1163
+ "loss": 0.2744,
1164
+ "step": 1650
1165
+ },
1166
+ {
1167
+ "epoch": 127.72,
1168
+ "grad_norm": 0.8461157083511353,
1169
+ "learning_rate": 8.656903765690378e-05,
1170
+ "loss": 0.2881,
1171
+ "step": 1660
1172
+ },
1173
+ {
1174
+ "epoch": 128.48,
1175
+ "grad_norm": 0.5801335573196411,
1176
+ "learning_rate": 8.648535564853557e-05,
1177
+ "loss": 0.2743,
1178
+ "step": 1670
1179
+ },
1180
+ {
1181
+ "epoch": 129.24,
1182
+ "grad_norm": 0.00013900638441555202,
1183
+ "learning_rate": 8.640167364016737e-05,
1184
+ "loss": 0.2389,
1185
+ "step": 1680
1186
+ },
1187
+ {
1188
+ "epoch": 130.0,
1189
+ "grad_norm": 0.8893792033195496,
1190
+ "learning_rate": 8.631799163179917e-05,
1191
+ "loss": 0.3059,
1192
+ "step": 1690
1193
+ },
1194
+ {
1195
+ "epoch": 130.8,
1196
+ "grad_norm": 0.5562736392021179,
1197
+ "learning_rate": 8.623430962343097e-05,
1198
+ "loss": 0.2586,
1199
+ "step": 1700
1200
+ },
1201
+ {
1202
+ "epoch": 131.56,
1203
+ "grad_norm": 0.41472992300987244,
1204
+ "learning_rate": 8.615062761506276e-05,
1205
+ "loss": 0.2632,
1206
+ "step": 1710
1207
+ },
1208
+ {
1209
+ "epoch": 132.32,
1210
+ "grad_norm": 0.23705990612506866,
1211
+ "learning_rate": 8.606694560669456e-05,
1212
+ "loss": 0.2678,
1213
+ "step": 1720
1214
+ },
1215
+ {
1216
+ "epoch": 133.08,
1217
+ "grad_norm": 0.6223066449165344,
1218
+ "learning_rate": 8.598326359832637e-05,
1219
+ "loss": 0.3278,
1220
+ "step": 1730
1221
+ },
1222
+ {
1223
+ "epoch": 133.88,
1224
+ "grad_norm": 0.7489154934883118,
1225
+ "learning_rate": 8.589958158995817e-05,
1226
+ "loss": 0.3094,
1227
+ "step": 1740
1228
+ },
1229
+ {
1230
+ "epoch": 134.64,
1231
+ "grad_norm": 0.0001321160380030051,
1232
+ "learning_rate": 8.581589958158997e-05,
1233
+ "loss": 0.2386,
1234
+ "step": 1750
1235
+ },
1236
+ {
1237
+ "epoch": 135.4,
1238
+ "grad_norm": 0.6472476720809937,
1239
+ "learning_rate": 8.573221757322177e-05,
1240
+ "loss": 0.319,
1241
+ "step": 1760
1242
+ },
1243
+ {
1244
+ "epoch": 136.16,
1245
+ "grad_norm": 0.6706916689872742,
1246
+ "learning_rate": 8.564853556485356e-05,
1247
+ "loss": 0.2565,
1248
+ "step": 1770
1249
+ },
1250
+ {
1251
+ "epoch": 136.96,
1252
+ "grad_norm": 0.6396327614784241,
1253
+ "learning_rate": 8.556485355648536e-05,
1254
+ "loss": 0.3257,
1255
+ "step": 1780
1256
+ },
1257
+ {
1258
+ "epoch": 137.72,
1259
+ "grad_norm": 0.945931613445282,
1260
+ "learning_rate": 8.548117154811716e-05,
1261
+ "loss": 0.2912,
1262
+ "step": 1790
1263
+ },
1264
+ {
1265
+ "epoch": 138.48,
1266
+ "grad_norm": 0.4821135699748993,
1267
+ "learning_rate": 8.539748953974896e-05,
1268
+ "loss": 0.239,
1269
+ "step": 1800
1270
+ },
1271
+ {
1272
+ "epoch": 139.24,
1273
+ "grad_norm": 0.5348692536354065,
1274
+ "learning_rate": 8.531380753138077e-05,
1275
+ "loss": 0.2632,
1276
+ "step": 1810
1277
+ },
1278
+ {
1279
+ "epoch": 140.0,
1280
+ "grad_norm": 0.3767673671245575,
1281
+ "learning_rate": 8.523012552301257e-05,
1282
+ "loss": 0.2752,
1283
+ "step": 1820
1284
+ },
1285
+ {
1286
+ "epoch": 140.8,
1287
+ "grad_norm": 0.5291851758956909,
1288
+ "learning_rate": 8.514644351464436e-05,
1289
+ "loss": 0.2997,
1290
+ "step": 1830
1291
+ },
1292
+ {
1293
+ "epoch": 141.56,
1294
+ "grad_norm": 0.6979625225067139,
1295
+ "learning_rate": 8.506276150627616e-05,
1296
+ "loss": 0.2731,
1297
+ "step": 1840
1298
+ },
1299
+ {
1300
+ "epoch": 142.32,
1301
+ "grad_norm": 0.2552458941936493,
1302
+ "learning_rate": 8.497907949790796e-05,
1303
+ "loss": 0.2046,
1304
+ "step": 1850
1305
+ },
1306
+ {
1307
+ "epoch": 143.08,
1308
+ "grad_norm": 0.7138965725898743,
1309
+ "learning_rate": 8.489539748953976e-05,
1310
+ "loss": 0.3448,
1311
+ "step": 1860
1312
+ },
1313
+ {
1314
+ "epoch": 143.88,
1315
+ "grad_norm": 0.602065920829773,
1316
+ "learning_rate": 8.481171548117155e-05,
1317
+ "loss": 0.3201,
1318
+ "step": 1870
1319
+ },
1320
+ {
1321
+ "epoch": 144.64,
1322
+ "grad_norm": 0.8118859529495239,
1323
+ "learning_rate": 8.472803347280335e-05,
1324
+ "loss": 0.2622,
1325
+ "step": 1880
1326
+ },
1327
+ {
1328
+ "epoch": 145.4,
1329
+ "grad_norm": 0.00017894129268825054,
1330
+ "learning_rate": 8.464435146443515e-05,
1331
+ "loss": 0.2388,
1332
+ "step": 1890
1333
+ },
1334
+ {
1335
+ "epoch": 146.16,
1336
+ "grad_norm": 0.5059804320335388,
1337
+ "learning_rate": 8.456066945606696e-05,
1338
+ "loss": 0.3429,
1339
+ "step": 1900
1340
+ },
1341
+ {
1342
+ "epoch": 146.96,
1343
+ "grad_norm": 0.47752124071121216,
1344
+ "learning_rate": 8.447698744769876e-05,
1345
+ "loss": 0.2575,
1346
+ "step": 1910
1347
+ },
1348
+ {
1349
+ "epoch": 147.72,
1350
+ "grad_norm": 0.6666246056556702,
1351
+ "learning_rate": 8.439330543933056e-05,
1352
+ "loss": 0.284,
1353
+ "step": 1920
1354
+ },
1355
+ {
1356
+ "epoch": 148.48,
1357
+ "grad_norm": 0.6572575569152832,
1358
+ "learning_rate": 8.430962343096235e-05,
1359
+ "loss": 0.2432,
1360
+ "step": 1930
1361
+ },
1362
+ {
1363
+ "epoch": 149.24,
1364
+ "grad_norm": 0.644818902015686,
1365
+ "learning_rate": 8.422594142259415e-05,
1366
+ "loss": 0.2844,
1367
+ "step": 1940
1368
+ },
1369
+ {
1370
+ "epoch": 150.0,
1371
+ "grad_norm": 1.5558511018753052,
1372
+ "learning_rate": 8.414225941422595e-05,
1373
+ "loss": 0.2891,
1374
+ "step": 1950
1375
+ },
1376
+ {
1377
+ "epoch": 150.8,
1378
+ "grad_norm": 0.9465558528900146,
1379
+ "learning_rate": 8.405857740585775e-05,
1380
+ "loss": 0.2765,
1381
+ "step": 1960
1382
+ },
1383
+ {
1384
+ "epoch": 151.56,
1385
+ "grad_norm": 0.8076108694076538,
1386
+ "learning_rate": 8.397489539748954e-05,
1387
+ "loss": 0.2718,
1388
+ "step": 1970
1389
+ },
1390
+ {
1391
+ "epoch": 152.32,
1392
+ "grad_norm": 0.6895241141319275,
1393
+ "learning_rate": 8.389121338912134e-05,
1394
+ "loss": 0.2662,
1395
+ "step": 1980
1396
+ },
1397
+ {
1398
+ "epoch": 153.08,
1399
+ "grad_norm": 0.42998993396759033,
1400
+ "learning_rate": 8.380753138075314e-05,
1401
+ "loss": 0.2715,
1402
+ "step": 1990
1403
+ },
1404
+ {
1405
+ "epoch": 153.88,
1406
+ "grad_norm": 0.5157560706138611,
1407
+ "learning_rate": 8.372384937238494e-05,
1408
+ "loss": 0.2677,
1409
+ "step": 2000
1410
+ },
1411
+ {
1412
+ "epoch": 154.64,
1413
+ "grad_norm": 0.5690245628356934,
1414
+ "learning_rate": 8.364016736401675e-05,
1415
+ "loss": 0.2882,
1416
+ "step": 2010
1417
+ },
1418
+ {
1419
+ "epoch": 155.4,
1420
+ "grad_norm": 0.5015438199043274,
1421
+ "learning_rate": 8.355648535564855e-05,
1422
+ "loss": 0.2652,
1423
+ "step": 2020
1424
+ },
1425
+ {
1426
+ "epoch": 156.16,
1427
+ "grad_norm": 0.3558717370033264,
1428
+ "learning_rate": 8.347280334728034e-05,
1429
+ "loss": 0.2938,
1430
+ "step": 2030
1431
+ },
1432
+ {
1433
+ "epoch": 156.96,
1434
+ "grad_norm": 0.8188201189041138,
1435
+ "learning_rate": 8.338912133891214e-05,
1436
+ "loss": 0.2977,
1437
+ "step": 2040
1438
+ },
1439
+ {
1440
+ "epoch": 157.72,
1441
+ "grad_norm": 0.8057423830032349,
1442
+ "learning_rate": 8.330543933054394e-05,
1443
+ "loss": 0.2342,
1444
+ "step": 2050
1445
+ },
1446
+ {
1447
+ "epoch": 158.48,
1448
+ "grad_norm": 0.4321524500846863,
1449
+ "learning_rate": 8.322175732217574e-05,
1450
+ "loss": 0.2955,
1451
+ "step": 2060
1452
+ },
1453
+ {
1454
+ "epoch": 159.24,
1455
+ "grad_norm": 0.5147210955619812,
1456
+ "learning_rate": 8.313807531380753e-05,
1457
+ "loss": 0.2526,
1458
+ "step": 2070
1459
+ },
1460
+ {
1461
+ "epoch": 160.0,
1462
+ "grad_norm": 0.771668016910553,
1463
+ "learning_rate": 8.305439330543933e-05,
1464
+ "loss": 0.3067,
1465
+ "step": 2080
1466
+ },
1467
+ {
1468
+ "epoch": 160.8,
1469
+ "grad_norm": 0.6141147017478943,
1470
+ "learning_rate": 8.297071129707113e-05,
1471
+ "loss": 0.2733,
1472
+ "step": 2090
1473
+ },
1474
+ {
1475
+ "epoch": 161.56,
1476
+ "grad_norm": 1.0377253293991089,
1477
+ "learning_rate": 8.288702928870293e-05,
1478
+ "loss": 0.2909,
1479
+ "step": 2100
1480
+ },
1481
+ {
1482
+ "epoch": 162.32,
1483
+ "grad_norm": 0.3286207318305969,
1484
+ "learning_rate": 8.280334728033472e-05,
1485
+ "loss": 0.2491,
1486
+ "step": 2110
1487
+ },
1488
+ {
1489
+ "epoch": 163.08,
1490
+ "grad_norm": 0.5966292023658752,
1491
+ "learning_rate": 8.271966527196652e-05,
1492
+ "loss": 0.2761,
1493
+ "step": 2120
1494
+ },
1495
+ {
1496
+ "epoch": 163.88,
1497
+ "grad_norm": 0.7708263993263245,
1498
+ "learning_rate": 8.263598326359832e-05,
1499
+ "loss": 0.2874,
1500
+ "step": 2130
1501
+ },
1502
+ {
1503
+ "epoch": 164.64,
1504
+ "grad_norm": 0.6249658465385437,
1505
+ "learning_rate": 8.255230125523013e-05,
1506
+ "loss": 0.2681,
1507
+ "step": 2140
1508
+ },
1509
+ {
1510
+ "epoch": 165.4,
1511
+ "grad_norm": 0.4472973048686981,
1512
+ "learning_rate": 8.246861924686193e-05,
1513
+ "loss": 0.2753,
1514
+ "step": 2150
1515
+ },
1516
+ {
1517
+ "epoch": 166.16,
1518
+ "grad_norm": 0.7961902022361755,
1519
+ "learning_rate": 8.238493723849373e-05,
1520
+ "loss": 0.2798,
1521
+ "step": 2160
1522
+ },
1523
+ {
1524
+ "epoch": 166.96,
1525
+ "grad_norm": 0.48364782333374023,
1526
+ "learning_rate": 8.230125523012552e-05,
1527
+ "loss": 0.2924,
1528
+ "step": 2170
1529
+ },
1530
+ {
1531
+ "epoch": 167.72,
1532
+ "grad_norm": 0.5379577279090881,
1533
+ "learning_rate": 8.221757322175732e-05,
1534
+ "loss": 0.2509,
1535
+ "step": 2180
1536
+ },
1537
+ {
1538
+ "epoch": 168.48,
1539
+ "grad_norm": 0.7065776586532593,
1540
+ "learning_rate": 8.213389121338912e-05,
1541
+ "loss": 0.2763,
1542
+ "step": 2190
1543
+ },
1544
+ {
1545
+ "epoch": 169.24,
1546
+ "grad_norm": 0.664776086807251,
1547
+ "learning_rate": 8.205020920502092e-05,
1548
+ "loss": 0.2726,
1549
+ "step": 2200
1550
+ },
1551
+ {
1552
+ "epoch": 170.0,
1553
+ "grad_norm": 0.47445446252822876,
1554
+ "learning_rate": 8.196652719665271e-05,
1555
+ "loss": 0.2696,
1556
+ "step": 2210
1557
+ },
1558
+ {
1559
+ "epoch": 170.8,
1560
+ "grad_norm": 0.5877684354782104,
1561
+ "learning_rate": 8.188284518828451e-05,
1562
+ "loss": 0.3092,
1563
+ "step": 2220
1564
+ },
1565
+ {
1566
+ "epoch": 171.56,
1567
+ "grad_norm": 0.8296095728874207,
1568
+ "learning_rate": 8.179916317991632e-05,
1569
+ "loss": 0.2458,
1570
+ "step": 2230
1571
+ },
1572
+ {
1573
+ "epoch": 172.32,
1574
+ "grad_norm": 0.2947175204753876,
1575
+ "learning_rate": 8.171548117154812e-05,
1576
+ "loss": 0.245,
1577
+ "step": 2240
1578
+ },
1579
+ {
1580
+ "epoch": 173.08,
1581
+ "grad_norm": 0.6972278952598572,
1582
+ "learning_rate": 8.163179916317992e-05,
1583
+ "loss": 0.3018,
1584
+ "step": 2250
1585
+ },
1586
+ {
1587
+ "epoch": 173.88,
1588
+ "grad_norm": 0.36442530155181885,
1589
+ "learning_rate": 8.154811715481172e-05,
1590
+ "loss": 0.2875,
1591
+ "step": 2260
1592
+ },
1593
+ {
1594
+ "epoch": 174.64,
1595
+ "grad_norm": 0.38456207513809204,
1596
+ "learning_rate": 8.146443514644351e-05,
1597
+ "loss": 0.2328,
1598
+ "step": 2270
1599
+ },
1600
+ {
1601
+ "epoch": 175.4,
1602
+ "grad_norm": 0.573784589767456,
1603
+ "learning_rate": 8.138075313807531e-05,
1604
+ "loss": 0.2851,
1605
+ "step": 2280
1606
+ },
1607
+ {
1608
+ "epoch": 176.16,
1609
+ "grad_norm": 0.5698201656341553,
1610
+ "learning_rate": 8.129707112970711e-05,
1611
+ "loss": 0.2807,
1612
+ "step": 2290
1613
+ },
1614
+ {
1615
+ "epoch": 176.96,
1616
+ "grad_norm": 0.7705904841423035,
1617
+ "learning_rate": 8.121338912133891e-05,
1618
+ "loss": 0.2982,
1619
+ "step": 2300
1620
+ },
1621
+ {
1622
+ "epoch": 177.72,
1623
+ "grad_norm": 0.5071648359298706,
1624
+ "learning_rate": 8.11297071129707e-05,
1625
+ "loss": 0.2127,
1626
+ "step": 2310
1627
+ },
1628
+ {
1629
+ "epoch": 178.48,
1630
+ "grad_norm": 1.354844331741333,
1631
+ "learning_rate": 8.104602510460252e-05,
1632
+ "loss": 0.2938,
1633
+ "step": 2320
1634
+ },
1635
+ {
1636
+ "epoch": 179.24,
1637
+ "grad_norm": 0.4220433533191681,
1638
+ "learning_rate": 8.096234309623431e-05,
1639
+ "loss": 0.241,
1640
+ "step": 2330
1641
+ },
1642
+ {
1643
+ "epoch": 180.0,
1644
+ "grad_norm": 0.4945012331008911,
1645
+ "learning_rate": 8.087866108786611e-05,
1646
+ "loss": 0.2482,
1647
+ "step": 2340
1648
+ },
1649
+ {
1650
+ "epoch": 180.8,
1651
+ "grad_norm": 0.6846901774406433,
1652
+ "learning_rate": 8.079497907949791e-05,
1653
+ "loss": 0.2146,
1654
+ "step": 2350
1655
+ },
1656
+ {
1657
+ "epoch": 181.56,
1658
+ "grad_norm": 0.6438813209533691,
1659
+ "learning_rate": 8.07112970711297e-05,
1660
+ "loss": 0.252,
1661
+ "step": 2360
1662
+ },
1663
+ {
1664
+ "epoch": 182.32,
1665
+ "grad_norm": 0.486453652381897,
1666
+ "learning_rate": 8.06276150627615e-05,
1667
+ "loss": 0.2417,
1668
+ "step": 2370
1669
+ },
1670
+ {
1671
+ "epoch": 183.08,
1672
+ "grad_norm": 0.5143964290618896,
1673
+ "learning_rate": 8.05439330543933e-05,
1674
+ "loss": 0.2882,
1675
+ "step": 2380
1676
+ },
1677
+ {
1678
+ "epoch": 183.88,
1679
+ "grad_norm": 1.5306233167648315,
1680
+ "learning_rate": 8.04602510460251e-05,
1681
+ "loss": 0.2624,
1682
+ "step": 2390
1683
+ },
1684
+ {
1685
+ "epoch": 184.64,
1686
+ "grad_norm": 0.47094181180000305,
1687
+ "learning_rate": 8.037656903765691e-05,
1688
+ "loss": 0.2009,
1689
+ "step": 2400
1690
+ },
1691
+ {
1692
+ "epoch": 185.4,
1693
+ "grad_norm": 0.510780930519104,
1694
+ "learning_rate": 8.029288702928871e-05,
1695
+ "loss": 0.3383,
1696
+ "step": 2410
1697
+ },
1698
+ {
1699
+ "epoch": 186.16,
1700
+ "grad_norm": 1.601507306098938,
1701
+ "learning_rate": 8.02092050209205e-05,
1702
+ "loss": 0.2143,
1703
+ "step": 2420
1704
+ },
1705
+ {
1706
+ "epoch": 186.96,
1707
+ "grad_norm": 0.5631161332130432,
1708
+ "learning_rate": 8.01255230125523e-05,
1709
+ "loss": 0.2093,
1710
+ "step": 2430
1711
+ },
1712
+ {
1713
+ "epoch": 187.72,
1714
+ "grad_norm": 0.2900833189487457,
1715
+ "learning_rate": 8.00418410041841e-05,
1716
+ "loss": 0.1871,
1717
+ "step": 2440
1718
+ },
1719
+ {
1720
+ "epoch": 188.48,
1721
+ "grad_norm": 0.6894598007202148,
1722
+ "learning_rate": 7.99581589958159e-05,
1723
+ "loss": 0.2654,
1724
+ "step": 2450
1725
+ },
1726
+ {
1727
+ "epoch": 189.24,
1728
+ "grad_norm": 8.970075607299805,
1729
+ "learning_rate": 7.98744769874477e-05,
1730
+ "loss": 0.178,
1731
+ "step": 2460
1732
+ },
1733
+ {
1734
+ "epoch": 190.0,
1735
+ "grad_norm": 0.2863423228263855,
1736
+ "learning_rate": 7.97907949790795e-05,
1737
+ "loss": 0.1958,
1738
+ "step": 2470
1739
+ },
1740
+ {
1741
+ "epoch": 190.8,
1742
+ "grad_norm": 0.3206811547279358,
1743
+ "learning_rate": 7.97071129707113e-05,
1744
+ "loss": 0.2548,
1745
+ "step": 2480
1746
+ },
1747
+ {
1748
+ "epoch": 191.56,
1749
+ "grad_norm": 0.6876190900802612,
1750
+ "learning_rate": 7.96234309623431e-05,
1751
+ "loss": 0.1843,
1752
+ "step": 2490
1753
+ },
1754
+ {
1755
+ "epoch": 192.32,
1756
+ "grad_norm": 1.09211266040802,
1757
+ "learning_rate": 7.95397489539749e-05,
1758
+ "loss": 0.1721,
1759
+ "step": 2500
1760
+ },
1761
+ {
1762
+ "epoch": 193.08,
1763
+ "grad_norm": 14.786331176757812,
1764
+ "learning_rate": 7.94560669456067e-05,
1765
+ "loss": 0.327,
1766
+ "step": 2510
1767
+ },
1768
+ {
1769
+ "epoch": 193.88,
1770
+ "grad_norm": 43.85841369628906,
1771
+ "learning_rate": 7.93723849372385e-05,
1772
+ "loss": 0.3656,
1773
+ "step": 2520
1774
+ },
1775
+ {
1776
+ "epoch": 194.64,
1777
+ "grad_norm": 0.32373273372650146,
1778
+ "learning_rate": 7.92887029288703e-05,
1779
+ "loss": 0.2476,
1780
+ "step": 2530
1781
+ },
1782
+ {
1783
+ "epoch": 195.4,
1784
+ "grad_norm": 0.00038791695260442793,
1785
+ "learning_rate": 7.920502092050209e-05,
1786
+ "loss": 0.1761,
1787
+ "step": 2540
1788
+ },
1789
+ {
1790
+ "epoch": 196.16,
1791
+ "grad_norm": 1.776816487312317,
1792
+ "learning_rate": 7.912133891213389e-05,
1793
+ "loss": 0.1828,
1794
+ "step": 2550
1795
+ },
1796
+ {
1797
+ "epoch": 196.96,
1798
+ "grad_norm": 0.5436874628067017,
1799
+ "learning_rate": 7.903765690376569e-05,
1800
+ "loss": 0.2261,
1801
+ "step": 2560
1802
+ },
1803
+ {
1804
+ "epoch": 197.72,
1805
+ "grad_norm": 0.45118626952171326,
1806
+ "learning_rate": 7.89539748953975e-05,
1807
+ "loss": 0.1712,
1808
+ "step": 2570
1809
+ },
1810
+ {
1811
+ "epoch": 198.48,
1812
+ "grad_norm": 3.249994993209839,
1813
+ "learning_rate": 7.88702928870293e-05,
1814
+ "loss": 0.247,
1815
+ "step": 2580
1816
+ },
1817
+ {
1818
+ "epoch": 199.24,
1819
+ "grad_norm": 1.1368451118469238,
1820
+ "learning_rate": 7.878661087866109e-05,
1821
+ "loss": 0.2208,
1822
+ "step": 2590
1823
+ },
1824
+ {
1825
+ "epoch": 200.0,
1826
+ "grad_norm": 0.6388035416603088,
1827
+ "learning_rate": 7.870292887029289e-05,
1828
+ "loss": 0.1518,
1829
+ "step": 2600
1830
+ },
1831
+ {
1832
+ "epoch": 200.8,
1833
+ "grad_norm": 3.905496597290039,
1834
+ "learning_rate": 7.861924686192469e-05,
1835
+ "loss": 0.1279,
1836
+ "step": 2610
1837
+ },
1838
+ {
1839
+ "epoch": 201.56,
1840
+ "grad_norm": 27.926372528076172,
1841
+ "learning_rate": 7.853556485355649e-05,
1842
+ "loss": 0.4156,
1843
+ "step": 2620
1844
+ },
1845
+ {
1846
+ "epoch": 202.32,
1847
+ "grad_norm": 0.19806312024593353,
1848
+ "learning_rate": 7.845188284518828e-05,
1849
+ "loss": 0.0903,
1850
+ "step": 2630
1851
+ },
1852
+ {
1853
+ "epoch": 203.08,
1854
+ "grad_norm": 0.0014255548594519496,
1855
+ "learning_rate": 7.836820083682008e-05,
1856
+ "loss": 0.1962,
1857
+ "step": 2640
1858
+ },
1859
+ {
1860
+ "epoch": 203.88,
1861
+ "grad_norm": 2.5122647285461426,
1862
+ "learning_rate": 7.828451882845189e-05,
1863
+ "loss": 0.1382,
1864
+ "step": 2650
1865
+ },
1866
+ {
1867
+ "epoch": 204.64,
1868
+ "grad_norm": 5.140716075897217,
1869
+ "learning_rate": 7.820083682008369e-05,
1870
+ "loss": 0.2123,
1871
+ "step": 2660
1872
+ },
1873
+ {
1874
+ "epoch": 205.4,
1875
+ "grad_norm": 1.3244779109954834,
1876
+ "learning_rate": 7.811715481171549e-05,
1877
+ "loss": 0.1756,
1878
+ "step": 2670
1879
+ },
1880
+ {
1881
+ "epoch": 206.16,
1882
+ "grad_norm": 26.31413459777832,
1883
+ "learning_rate": 7.803347280334728e-05,
1884
+ "loss": 0.1425,
1885
+ "step": 2680
1886
+ },
1887
+ {
1888
+ "epoch": 206.96,
1889
+ "grad_norm": 1.5801438093185425,
1890
+ "learning_rate": 7.794979079497908e-05,
1891
+ "loss": 0.1822,
1892
+ "step": 2690
1893
+ },
1894
+ {
1895
+ "epoch": 207.72,
1896
+ "grad_norm": 1.6209617853164673,
1897
+ "learning_rate": 7.786610878661088e-05,
1898
+ "loss": 0.1909,
1899
+ "step": 2700
1900
+ },
1901
+ {
1902
+ "epoch": 208.48,
1903
+ "grad_norm": 0.6468285918235779,
1904
+ "learning_rate": 7.778242677824268e-05,
1905
+ "loss": 0.1741,
1906
+ "step": 2710
1907
+ },
1908
+ {
1909
+ "epoch": 209.24,
1910
+ "grad_norm": 0.04982404410839081,
1911
+ "learning_rate": 7.769874476987448e-05,
1912
+ "loss": 0.1027,
1913
+ "step": 2720
1914
+ },
1915
+ {
1916
+ "epoch": 210.0,
1917
+ "grad_norm": 0.0010119529906660318,
1918
+ "learning_rate": 7.761506276150629e-05,
1919
+ "loss": 0.1457,
1920
+ "step": 2730
1921
+ },
1922
+ {
1923
+ "epoch": 210.8,
1924
+ "grad_norm": 38.17147445678711,
1925
+ "learning_rate": 7.753138075313808e-05,
1926
+ "loss": 0.164,
1927
+ "step": 2740
1928
+ },
1929
+ {
1930
+ "epoch": 211.56,
1931
+ "grad_norm": 1.5547058582305908,
1932
+ "learning_rate": 7.744769874476988e-05,
1933
+ "loss": 0.1102,
1934
+ "step": 2750
1935
+ },
1936
+ {
1937
+ "epoch": 212.32,
1938
+ "grad_norm": 0.011813919991254807,
1939
+ "learning_rate": 7.736401673640168e-05,
1940
+ "loss": 0.1455,
1941
+ "step": 2760
1942
+ },
1943
+ {
1944
+ "epoch": 213.08,
1945
+ "grad_norm": 0.007001452147960663,
1946
+ "learning_rate": 7.728033472803348e-05,
1947
+ "loss": 0.1694,
1948
+ "step": 2770
1949
+ },
1950
+ {
1951
+ "epoch": 213.88,
1952
+ "grad_norm": 0.7564431428909302,
1953
+ "learning_rate": 7.719665271966527e-05,
1954
+ "loss": 0.2061,
1955
+ "step": 2780
1956
+ },
1957
+ {
1958
+ "epoch": 214.64,
1959
+ "grad_norm": 1.1208900213241577,
1960
+ "learning_rate": 7.711297071129707e-05,
1961
+ "loss": 0.2951,
1962
+ "step": 2790
1963
+ },
1964
+ {
1965
+ "epoch": 215.4,
1966
+ "grad_norm": 0.4398239552974701,
1967
+ "learning_rate": 7.702928870292887e-05,
1968
+ "loss": 0.1161,
1969
+ "step": 2800
1970
+ },
1971
+ {
1972
+ "epoch": 216.16,
1973
+ "grad_norm": 0.34080252051353455,
1974
+ "learning_rate": 7.694560669456067e-05,
1975
+ "loss": 0.1177,
1976
+ "step": 2810
1977
+ },
1978
+ {
1979
+ "epoch": 216.96,
1980
+ "grad_norm": 0.27304738759994507,
1981
+ "learning_rate": 7.686192468619248e-05,
1982
+ "loss": 0.117,
1983
+ "step": 2820
1984
+ },
1985
+ {
1986
+ "epoch": 217.72,
1987
+ "grad_norm": 0.00020999423577450216,
1988
+ "learning_rate": 7.677824267782428e-05,
1989
+ "loss": 0.1518,
1990
+ "step": 2830
1991
+ },
1992
+ {
1993
+ "epoch": 218.48,
1994
+ "grad_norm": 0.3551616072654724,
1995
+ "learning_rate": 7.669456066945607e-05,
1996
+ "loss": 0.0911,
1997
+ "step": 2840
1998
+ },
1999
+ {
2000
+ "epoch": 219.24,
2001
+ "grad_norm": 18.045923233032227,
2002
+ "learning_rate": 7.661087866108787e-05,
2003
+ "loss": 0.1732,
2004
+ "step": 2850
2005
+ },
2006
+ {
2007
+ "epoch": 220.0,
2008
+ "grad_norm": 0.0011837411439046264,
2009
+ "learning_rate": 7.652719665271967e-05,
2010
+ "loss": 0.3062,
2011
+ "step": 2860
2012
+ },
2013
+ {
2014
+ "epoch": 220.8,
2015
+ "grad_norm": 7.931725025177002,
2016
+ "learning_rate": 7.644351464435147e-05,
2017
+ "loss": 0.1608,
2018
+ "step": 2870
2019
+ },
2020
+ {
2021
+ "epoch": 221.56,
2022
+ "grad_norm": 0.7454537749290466,
2023
+ "learning_rate": 7.635983263598326e-05,
2024
+ "loss": 0.163,
2025
+ "step": 2880
2026
+ },
2027
+ {
2028
+ "epoch": 222.32,
2029
+ "grad_norm": 6.1156840324401855,
2030
+ "learning_rate": 7.627615062761506e-05,
2031
+ "loss": 0.0693,
2032
+ "step": 2890
2033
+ },
2034
+ {
2035
+ "epoch": 223.08,
2036
+ "grad_norm": 1.552299976348877,
2037
+ "learning_rate": 7.619246861924687e-05,
2038
+ "loss": 0.1688,
2039
+ "step": 2900
2040
+ },
2041
+ {
2042
+ "epoch": 223.88,
2043
+ "grad_norm": 11.509157180786133,
2044
+ "learning_rate": 7.610878661087867e-05,
2045
+ "loss": 0.1652,
2046
+ "step": 2910
2047
+ },
2048
+ {
2049
+ "epoch": 224.64,
2050
+ "grad_norm": 0.007240507751703262,
2051
+ "learning_rate": 7.602510460251047e-05,
2052
+ "loss": 0.1246,
2053
+ "step": 2920
2054
+ },
2055
+ {
2056
+ "epoch": 225.4,
2057
+ "grad_norm": 0.2530362010002136,
2058
+ "learning_rate": 7.594142259414227e-05,
2059
+ "loss": 0.0838,
2060
+ "step": 2930
2061
+ },
2062
+ {
2063
+ "epoch": 226.16,
2064
+ "grad_norm": 0.5732694864273071,
2065
+ "learning_rate": 7.585774058577406e-05,
2066
+ "loss": 0.1891,
2067
+ "step": 2940
2068
+ },
2069
+ {
2070
+ "epoch": 226.96,
2071
+ "grad_norm": 3.3265058994293213,
2072
+ "learning_rate": 7.577405857740586e-05,
2073
+ "loss": 0.219,
2074
+ "step": 2950
2075
+ },
2076
+ {
2077
+ "epoch": 227.72,
2078
+ "grad_norm": 19.179094314575195,
2079
+ "learning_rate": 7.569037656903766e-05,
2080
+ "loss": 0.1078,
2081
+ "step": 2960
2082
+ },
2083
+ {
2084
+ "epoch": 228.48,
2085
+ "grad_norm": 0.5755670666694641,
2086
+ "learning_rate": 7.560669456066946e-05,
2087
+ "loss": 0.1628,
2088
+ "step": 2970
2089
+ },
2090
+ {
2091
+ "epoch": 229.24,
2092
+ "grad_norm": 13.489233016967773,
2093
+ "learning_rate": 7.552301255230127e-05,
2094
+ "loss": 0.1031,
2095
+ "step": 2980
2096
+ },
2097
+ {
2098
+ "epoch": 230.0,
2099
+ "grad_norm": 0.42450758814811707,
2100
+ "learning_rate": 7.543933054393307e-05,
2101
+ "loss": 0.1062,
2102
+ "step": 2990
2103
+ },
2104
+ {
2105
+ "epoch": 230.8,
2106
+ "grad_norm": 74.2090835571289,
2107
+ "learning_rate": 7.535564853556486e-05,
2108
+ "loss": 0.2763,
2109
+ "step": 3000
2110
+ },
2111
+ {
2112
+ "epoch": 231.56,
2113
+ "grad_norm": 0.0013025101507082582,
2114
+ "learning_rate": 7.527196652719666e-05,
2115
+ "loss": 0.0475,
2116
+ "step": 3010
2117
+ },
2118
+ {
2119
+ "epoch": 232.32,
2120
+ "grad_norm": 0.36847984790802,
2121
+ "learning_rate": 7.518828451882846e-05,
2122
+ "loss": 0.1262,
2123
+ "step": 3020
2124
+ },
2125
+ {
2126
+ "epoch": 233.08,
2127
+ "grad_norm": 2.1009180545806885,
2128
+ "learning_rate": 7.510460251046026e-05,
2129
+ "loss": 0.2701,
2130
+ "step": 3030
2131
+ },
2132
+ {
2133
+ "epoch": 233.88,
2134
+ "grad_norm": 3.9676990509033203,
2135
+ "learning_rate": 7.502092050209205e-05,
2136
+ "loss": 0.0567,
2137
+ "step": 3040
2138
+ },
2139
+ {
2140
+ "epoch": 234.64,
2141
+ "grad_norm": 2.426058769226074,
2142
+ "learning_rate": 7.493723849372385e-05,
2143
+ "loss": 0.1485,
2144
+ "step": 3050
2145
+ },
2146
+ {
2147
+ "epoch": 235.4,
2148
+ "grad_norm": 0.45143410563468933,
2149
+ "learning_rate": 7.485355648535565e-05,
2150
+ "loss": 0.0808,
2151
+ "step": 3060
2152
+ },
2153
+ {
2154
+ "epoch": 236.16,
2155
+ "grad_norm": 0.001772599876858294,
2156
+ "learning_rate": 7.476987447698746e-05,
2157
+ "loss": 0.159,
2158
+ "step": 3070
2159
+ },
2160
+ {
2161
+ "epoch": 236.96,
2162
+ "grad_norm": 0.6228379011154175,
2163
+ "learning_rate": 7.468619246861926e-05,
2164
+ "loss": 0.3161,
2165
+ "step": 3080
2166
+ },
2167
+ {
2168
+ "epoch": 237.72,
2169
+ "grad_norm": 1.1277211904525757,
2170
+ "learning_rate": 7.460251046025106e-05,
2171
+ "loss": 0.0895,
2172
+ "step": 3090
2173
+ },
2174
+ {
2175
+ "epoch": 238.48,
2176
+ "grad_norm": 0.5020686388015747,
2177
+ "learning_rate": 7.451882845188285e-05,
2178
+ "loss": 0.066,
2179
+ "step": 3100
2180
+ },
2181
+ {
2182
+ "epoch": 239.24,
2183
+ "grad_norm": 0.4364977777004242,
2184
+ "learning_rate": 7.443514644351465e-05,
2185
+ "loss": 0.0973,
2186
+ "step": 3110
2187
+ },
2188
+ {
2189
+ "epoch": 240.0,
2190
+ "grad_norm": 0.331858366727829,
2191
+ "learning_rate": 7.435146443514645e-05,
2192
+ "loss": 0.1385,
2193
+ "step": 3120
2194
+ },
2195
+ {
2196
+ "epoch": 240.8,
2197
+ "grad_norm": 3.562025308609009,
2198
+ "learning_rate": 7.426778242677825e-05,
2199
+ "loss": 0.1554,
2200
+ "step": 3130
2201
+ },
2202
+ {
2203
+ "epoch": 241.56,
2204
+ "grad_norm": 0.4786972403526306,
2205
+ "learning_rate": 7.418410041841004e-05,
2206
+ "loss": 0.1127,
2207
+ "step": 3140
2208
+ },
2209
+ {
2210
+ "epoch": 242.32,
2211
+ "grad_norm": 0.22106488049030304,
2212
+ "learning_rate": 7.410041841004186e-05,
2213
+ "loss": 0.1212,
2214
+ "step": 3150
2215
+ },
2216
+ {
2217
+ "epoch": 243.08,
2218
+ "grad_norm": 0.3226916193962097,
2219
+ "learning_rate": 7.401673640167365e-05,
2220
+ "loss": 0.0622,
2221
+ "step": 3160
2222
+ },
2223
+ {
2224
+ "epoch": 243.88,
2225
+ "grad_norm": 0.18059372901916504,
2226
+ "learning_rate": 7.393305439330545e-05,
2227
+ "loss": 0.0578,
2228
+ "step": 3170
2229
+ },
2230
+ {
2231
+ "epoch": 244.64,
2232
+ "grad_norm": 0.18617404997348785,
2233
+ "learning_rate": 7.384937238493725e-05,
2234
+ "loss": 0.1017,
2235
+ "step": 3180
2236
+ },
2237
+ {
2238
+ "epoch": 245.4,
2239
+ "grad_norm": 0.4568879008293152,
2240
+ "learning_rate": 7.376569037656905e-05,
2241
+ "loss": 0.0456,
2242
+ "step": 3190
2243
+ },
2244
+ {
2245
+ "epoch": 246.16,
2246
+ "grad_norm": 0.00673318887129426,
2247
+ "learning_rate": 7.368200836820084e-05,
2248
+ "loss": 0.0578,
2249
+ "step": 3200
2250
+ },
2251
+ {
2252
+ "epoch": 246.96,
2253
+ "grad_norm": 0.0004931857693009079,
2254
+ "learning_rate": 7.359832635983264e-05,
2255
+ "loss": 0.0363,
2256
+ "step": 3210
2257
+ },
2258
+ {
2259
+ "epoch": 247.72,
2260
+ "grad_norm": 0.18300887942314148,
2261
+ "learning_rate": 7.351464435146444e-05,
2262
+ "loss": 0.0854,
2263
+ "step": 3220
2264
+ },
2265
+ {
2266
+ "epoch": 248.48,
2267
+ "grad_norm": 0.003890759777277708,
2268
+ "learning_rate": 7.343096234309624e-05,
2269
+ "loss": 0.1226,
2270
+ "step": 3230
2271
+ },
2272
+ {
2273
+ "epoch": 249.24,
2274
+ "grad_norm": 1.134564757347107,
2275
+ "learning_rate": 7.334728033472805e-05,
2276
+ "loss": 0.0956,
2277
+ "step": 3240
2278
+ },
2279
+ {
2280
+ "epoch": 250.0,
2281
+ "grad_norm": 0.00022713415091857314,
2282
+ "learning_rate": 7.326359832635985e-05,
2283
+ "loss": 0.1,
2284
+ "step": 3250
2285
+ },
2286
+ {
2287
+ "epoch": 250.8,
2288
+ "grad_norm": 7.879546165466309,
2289
+ "learning_rate": 7.317991631799164e-05,
2290
+ "loss": 0.0877,
2291
+ "step": 3260
2292
+ },
2293
+ {
2294
+ "epoch": 251.56,
2295
+ "grad_norm": 2.57737135887146,
2296
+ "learning_rate": 7.309623430962344e-05,
2297
+ "loss": 0.09,
2298
+ "step": 3270
2299
+ },
2300
+ {
2301
+ "epoch": 252.32,
2302
+ "grad_norm": 0.4551700949668884,
2303
+ "learning_rate": 7.301255230125524e-05,
2304
+ "loss": 0.1288,
2305
+ "step": 3280
2306
+ }
2307
+ ],
2308
+ "logging_steps": 10,
2309
+ "max_steps": 12000,
2310
+ "num_input_tokens_seen": 0,
2311
+ "num_train_epochs": 1000,
2312
+ "save_steps": 10,
2313
+ "stateful_callbacks": {
2314
+ "TrainerControl": {
2315
+ "args": {
2316
+ "should_epoch_stop": false,
2317
+ "should_evaluate": false,
2318
+ "should_log": false,
2319
+ "should_save": true,
2320
+ "should_training_stop": false
2321
+ },
2322
+ "attributes": {}
2323
+ }
2324
+ },
2325
+ "total_flos": 1.8125347399735104e+17,
2326
+ "train_batch_size": 1,
2327
+ "trial_name": null,
2328
+ "trial_params": null
2329
+ }
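
The block above is the Trainer's running log: one entry every 10 optimizer steps (logging_steps) recording the epoch, gradient norm, learning rate and training loss, up to global step 3280. A minimal sketch for inspecting it offline, assuming the file is the standard trainer_state.json that the Hugging Face Trainer writes next to each checkpoint (the file name is not visible in this diff and is an assumption):

```python
import json

# Sketch only: the file name is assumed; the keys used below (global_step,
# log_history, epoch, step, loss) all appear in the state shown above.
with open("trainer_state.json") as f:
    state = json.load(f)

history = state["log_history"]
print(f"{len(history)} logged entries up to global step {state['global_step']}")

# logging_steps is 10, so every 50th entry corresponds to 500 training steps.
for entry in history[::50]:
    print(f"step {entry['step']:>5}  epoch {entry['epoch']:>7.2f}  loss {entry['loss']:.4f}")

best = min(history, key=lambda e: e["loss"])
print(f"lowest logged training loss: {best['loss']:.4f} at step {best['step']}")
```

Because save_steps is also 10, every logged step has a matching checkpoint-&lt;step&gt; directory (subject to save_total_limit), so the lowest-loss step found this way can be mapped back to a concrete checkpoint.
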
training_args.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e8cef66ef127754c375a8e9b549e2244a7cc084c0396c35d5511163281be21d7
3
+ size 5507
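
The three lines above are a Git LFS pointer (spec version, SHA-256 oid, byte size), so the roughly 5.5 kB training_args.bin itself is stored out of band and has to be fetched first (for example with git lfs pull, or a huggingface_hub download). A minimal sketch of how it could then be inspected, assuming it was written the usual way, via torch.save, by the Trainer:

```python
import torch

# Sketch only: training_args.bin is normally a pickled TrainingArguments
# object, so a compatible transformers version must be importable and
# weights_only=False is required on recent torch releases.
args = torch.load("training_args.bin", weights_only=False)

print(type(args).__name__)  # expected: TrainingArguments
print(args.learning_rate, args.num_train_epochs,
      args.per_device_train_batch_size, args.logging_steps)
```
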
vision_siglip_navit.py ADDED
@@ -0,0 +1,1717 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Siglip model configuration"""
16
+
17
+ import os
18
+ from typing import Union
19
+
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.utils import logging
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+ SIGLIP_PRETRAINED_CONFIG_ARCHIVE_MAP = {
27
+ "google/siglip-base-patch16-224": "https://huggingface.co/google/siglip-base-patch16-224/resolve/main/config.json",
28
+ }
29
+
30
+
31
+ class SiglipTextConfig(PretrainedConfig):
32
+ r"""
33
+ This is the configuration class to store the configuration of a [`SiglipTextModel`]. It is used to instantiate a
34
+ Siglip text encoder according to the specified arguments, defining the model architecture. Instantiating a
35
+ configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip
36
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
37
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
38
+ documentation from [`PretrainedConfig`] for more information.
39
+ Args:
40
+ vocab_size (`int`, *optional*, defaults to 32000):
41
+ Vocabulary size of the Siglip text model. Defines the number of different tokens that can be represented by
42
+ the `inputs_ids` passed when calling [`SiglipModel`].
43
+ hidden_size (`int`, *optional*, defaults to 768):
44
+ Dimensionality of the encoder layers and the pooler layer.
45
+ intermediate_size (`int`, *optional*, defaults to 3072):
46
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
47
+ num_hidden_layers (`int`, *optional*, defaults to 12):
48
+ Number of hidden layers in the Transformer encoder.
49
+ num_attention_heads (`int`, *optional*, defaults to 12):
50
+ Number of attention heads for each attention layer in the Transformer encoder.
51
+ max_position_embeddings (`int`, *optional*, defaults to 64):
52
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
53
+ just in case (e.g., 512 or 1024 or 2048).
54
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
55
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
56
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
57
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
58
+ The epsilon used by the layer normalization layers.
59
+ attention_dropout (`float`, *optional*, defaults to 0.0):
60
+ The dropout ratio for the attention probabilities.
61
+ pad_token_id (`int`, *optional*, defaults to 1):
62
+ The id of the padding token in the vocabulary.
63
+ bos_token_id (`int`, *optional*, defaults to 49406):
64
+ The id of the beginning-of-sequence token in the vocabulary.
65
+ eos_token_id (`int`, *optional*, defaults to 49407):
66
+ The id of the end-of-sequence token in the vocabulary.
67
+ Example:
68
+ ```python
69
+ >>> from transformers import SiglipTextConfig, SiglipTextModel
70
+ >>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 style configuration
71
+ >>> configuration = SiglipTextConfig()
72
+ >>> # Initializing a SiglipTextModel (with random weights) from the google/siglip-base-patch16-224 style configuration
73
+ >>> model = SiglipTextModel(configuration)
74
+ >>> # Accessing the model configuration
75
+ >>> configuration = model.config
76
+ ```"""
77
+
78
+ model_type = "siglip_text_model"
79
+
80
+ def __init__(
81
+ self,
82
+ vocab_size=32000,
83
+ hidden_size=768,
84
+ intermediate_size=3072,
85
+ num_hidden_layers=12,
86
+ num_attention_heads=12,
87
+ max_position_embeddings=64,
88
+ hidden_act="gelu_pytorch_tanh",
89
+ layer_norm_eps=1e-6,
90
+ attention_dropout=0.0,
91
+ # This differs from `CLIPTokenizer`'s default and from openai/siglip
92
+ # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
93
+ pad_token_id=1,
94
+ bos_token_id=49406,
95
+ eos_token_id=49407,
96
+ _flash_attn_2_enabled=True,
97
+ **kwargs,
98
+ ):
99
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
100
+
101
+ self.vocab_size = vocab_size
102
+ self.hidden_size = hidden_size
103
+ self.intermediate_size = intermediate_size
104
+ self.num_hidden_layers = num_hidden_layers
105
+ self.num_attention_heads = num_attention_heads
106
+ self.max_position_embeddings = max_position_embeddings
107
+ self.layer_norm_eps = layer_norm_eps
108
+ self.hidden_act = hidden_act
109
+ self.attention_dropout = attention_dropout
110
+ self._flash_attn_2_enabled = _flash_attn_2_enabled
111
+
112
+ @classmethod
113
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
114
+ cls._set_token_in_kwargs(kwargs)
115
+
116
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
117
+
118
+ # get the text config dict if we are loading from SiglipConfig
119
+ if config_dict.get("model_type") == "siglip":
120
+ config_dict = config_dict["text_config"]
121
+
122
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
123
+ logger.warning(
124
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
125
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
126
+ )
127
+
128
+ return cls.from_dict(config_dict, **kwargs)
129
+
130
+
131
+ class SiglipVisionConfig(PretrainedConfig):
132
+ r"""
133
+ This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
134
+ Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
135
+ configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
136
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
137
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
138
+ documentation from [`PretrainedConfig`] for more information.
139
+ Args:
140
+ hidden_size (`int`, *optional*, defaults to 768):
141
+ Dimensionality of the encoder layers and the pooler layer.
142
+ intermediate_size (`int`, *optional*, defaults to 3072):
143
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
144
+ num_hidden_layers (`int`, *optional*, defaults to 12):
145
+ Number of hidden layers in the Transformer encoder.
146
+ num_attention_heads (`int`, *optional*, defaults to 12):
147
+ Number of attention heads for each attention layer in the Transformer encoder.
148
+ num_channels (`int`, *optional*, defaults to 3):
149
+ Number of channels in the input images.
150
+ image_size (`int`, *optional*, defaults to 224):
151
+ The size (resolution) of each image.
152
+ patch_size (`int`, *optional*, defaults to 16):
153
+ The size (resolution) of each patch.
154
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
155
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
156
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
157
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
158
+ The epsilon used by the layer normalization layers.
159
+ attention_dropout (`float`, *optional*, defaults to 0.0):
160
+ The dropout ratio for the attention probabilities.
161
+ Example:
162
+ ```python
163
+ >>> from transformers import SiglipVisionConfig, SiglipVisionModel
164
+ >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
165
+ >>> configuration = SiglipVisionConfig()
166
+ >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
167
+ >>> model = SiglipVisionModel(configuration)
168
+ >>> # Accessing the model configuration
169
+ >>> configuration = model.config
170
+ ```"""
171
+
172
+ model_type = "siglip_vision_model"
173
+
174
+ def __init__(
175
+ self,
176
+ hidden_size=768,
177
+ intermediate_size=3072,
178
+ num_hidden_layers=12,
179
+ num_attention_heads=12,
180
+ num_channels=3,
181
+ image_size=224,
182
+ patch_size=16,
183
+ hidden_act="gelu_pytorch_tanh",
184
+ layer_norm_eps=1e-6,
185
+ attention_dropout=0.0,
186
+ _flash_attn_2_enabled=True,
187
+ **kwargs,
188
+ ):
189
+ super().__init__(**kwargs)
190
+
191
+ self.hidden_size = hidden_size
192
+ self.intermediate_size = intermediate_size
193
+ self.num_hidden_layers = num_hidden_layers
194
+ self.num_attention_heads = num_attention_heads
195
+ self.num_channels = num_channels
196
+ self.patch_size = patch_size
197
+ self.image_size = image_size
198
+ self.attention_dropout = attention_dropout
199
+ self.layer_norm_eps = layer_norm_eps
200
+ self.hidden_act = hidden_act
201
+ self._flash_attn_2_enabled = _flash_attn_2_enabled
202
+
203
+ @classmethod
204
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
205
+ cls._set_token_in_kwargs(kwargs)
206
+
207
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
208
+
209
+ # get the vision config dict if we are loading from SiglipConfig
210
+ if config_dict.get("model_type") == "siglip":
211
+ config_dict = config_dict["vision_config"]
212
+
213
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
214
+ logger.warning(
215
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
216
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
217
+ )
218
+
219
+ return cls.from_dict(config_dict, **kwargs)
220
+
221
+
222
+ class SiglipConfig(PretrainedConfig):
223
+ r"""
224
+ [`SiglipConfig`] is the configuration class to store the configuration of a [`SiglipModel`]. It is used to
225
+ instantiate a Siglip model according to the specified arguments, defining the text model and vision model configs.
226
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip
227
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
228
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
229
+ documentation from [`PretrainedConfig`] for more information.
230
+ Args:
231
+ text_config (`dict`, *optional*):
232
+ Dictionary of configuration options used to initialize [`SiglipTextConfig`].
233
+ vision_config (`dict`, *optional*):
234
+ Dictionary of configuration options used to initialize [`SiglipVisionConfig`].
235
+ kwargs (*optional*):
236
+ Dictionary of keyword arguments.
237
+ Example:
238
+ ```python
239
+ >>> from transformers import SiglipConfig, SiglipModel
240
+ >>> # Initializing a SiglipConfig with google/siglip-base-patch16-224 style configuration
241
+ >>> configuration = SiglipConfig()
242
+ >>> # Initializing a SiglipModel (with random weights) from the google/siglip-base-patch16-224 style configuration
243
+ >>> model = SiglipModel(configuration)
244
+ >>> # Accessing the model configuration
245
+ >>> configuration = model.config
246
+ >>> # We can also initialize a SiglipConfig from a SiglipTextConfig and a SiglipVisionConfig
247
+ >>> from transformers import SiglipTextConfig, SiglipVisionConfig
248
+ >>> # Initializing a SiglipText and SiglipVision configuration
249
+ >>> config_text = SiglipTextConfig()
250
+ >>> config_vision = SiglipVisionConfig()
251
+ >>> config = SiglipConfig.from_text_vision_configs(config_text, config_vision)
252
+ ```"""
253
+
254
+ model_type = "siglip"
255
+
256
+ def __init__(self, text_config=None, vision_config=None, **kwargs):
257
+ super().__init__(**kwargs)
258
+
259
+ if text_config is None:
260
+ text_config = {}
261
+ logger.info("`text_config` is `None`. Initializing the `SiglipTextConfig` with default values.")
262
+
263
+ if vision_config is None:
264
+ vision_config = {}
265
+ logger.info("`vision_config` is `None`. Initializing the `SiglipVisionConfig` with default values.")
266
+
267
+ self.text_config = SiglipTextConfig(**text_config)
268
+ self.vision_config = SiglipVisionConfig(**vision_config)
269
+
270
+ self.initializer_factor = 1.0
271
+
272
+ @classmethod
273
+ def from_text_vision_configs(cls, text_config: SiglipTextConfig, vision_config: SiglipVisionConfig, **kwargs):
274
+ r"""
275
+ Instantiate a [`SiglipConfig`] (or a derived class) from siglip text model configuration and siglip vision
276
+ model configuration.
277
+ Returns:
278
+ [`SiglipConfig`]: An instance of a configuration object
279
+ """
280
+
281
+ return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
282
+
283
+ # coding=utf-8
284
+ # Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
285
+ #
286
+ # Licensed under the Apache License, Version 2.0 (the "License");
287
+ # you may not use this file except in compliance with the License.
288
+ # You may obtain a copy of the License at
289
+ #
290
+ # http://www.apache.org/licenses/LICENSE-2.0
291
+ #
292
+ # Unless required by applicable law or agreed to in writing, software
293
+ # distributed under the License is distributed on an "AS IS" BASIS,
294
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
295
+ # See the License for the specific language governing permissions and
296
+ # limitations under the License.
297
+ """ PyTorch Siglip model."""
298
+
299
+
300
+ import math
301
+ import warnings
302
+ from dataclasses import dataclass
303
+ from typing import Any, Optional, Tuple, Union
304
+
305
+ import numpy as np
306
+ import torch
307
+ import torch.nn.functional as F
308
+ import torch.utils.checkpoint
309
+ from torch import nn
310
+ from torch.nn.init import _calculate_fan_in_and_fan_out
311
+
312
+ from transformers.activations import ACT2FN
313
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
314
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
315
+ from transformers.modeling_utils import PreTrainedModel
316
+ from transformers.utils import (
317
+ ModelOutput,
318
+ add_start_docstrings,
319
+ add_start_docstrings_to_model_forward,
320
+ is_flash_attn_2_available,
321
+ logging,
322
+ replace_return_docstrings,
323
+ )
324
+
325
+ logger = logging.get_logger(__name__)
326
+
327
+ _CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
328
+
329
+ SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
330
+ "google/siglip-base-patch16-224",
331
+ # See all SigLIP models at https://huggingface.co/models?filter=siglip
332
+ ]
333
+
334
+ if is_flash_attn_2_available():
335
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
336
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
337
+
338
+
339
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
340
+ def _get_unpad_data(attention_mask):
341
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
342
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
343
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
344
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
345
+ return (
346
+ indices,
347
+ cu_seqlens,
348
+ max_seqlen_in_batch,
349
+ )
350
+
351
+
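For reference, `_get_unpad_data` converts a 0/1 padding mask into the flattened token indices, cumulative sequence lengths and maximum sequence length expected by the varlen flash-attention kernels. A standalone toy example (not part of the uploaded file) of what it produces for two sequences of real lengths 3 and 2, right-padded to length 4:

```python
# Standalone toy example of the unpadding bookkeeping computed above.
import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)             # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()  # tensor([0, 1, 2, 4, 5])
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))

print(indices)                        # positions of the real tokens in the flattened batch
print(cu_seqlens)                     # tensor([0, 3, 5]) -> per-sequence boundaries
print(seqlens_in_batch.max().item())  # 3 -> max_seqlen_in_batch
```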
352
+ def _trunc_normal_(tensor, mean, std, a, b):
353
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
354
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
355
+ def norm_cdf(x):
356
+ # Computes standard normal cumulative distribution function
357
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
358
+
359
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
360
+ warnings.warn(
361
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
362
+ "The distribution of values may be incorrect.",
363
+ stacklevel=2,
364
+ )
365
+
366
+ # Values are generated by using a truncated uniform distribution and
367
+ # then using the inverse CDF for the normal distribution.
368
+ # Get upper and lower cdf values
369
+ l = norm_cdf((a - mean) / std)
370
+ u = norm_cdf((b - mean) / std)
371
+
372
+ # Uniformly fill tensor with values from [l, u], then translate to
373
+ # [2l-1, 2u-1].
374
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
375
+
376
+ # Use inverse cdf transform for normal distribution to get truncated
377
+ # standard normal
378
+ if tensor.dtype in [torch.float16, torch.bfloat16]:
379
+ # The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu
380
+ og_dtype = tensor.dtype
381
+ tensor = tensor.to(torch.float32)
382
+ tensor.erfinv_()
383
+ tensor = tensor.to(og_dtype)
384
+ else:
385
+ tensor.erfinv_()
386
+
387
+ # Transform to proper mean, std
388
+ tensor.mul_(std * math.sqrt(2.0))
389
+ tensor.add_(mean)
390
+
391
+ # Clamp to ensure it's in the proper range
392
+ if tensor.dtype == torch.float16:
393
+ # The `clamp_` op is not (yet?) defined in float16+cpu
394
+ tensor = tensor.to(torch.float32)
395
+ tensor.clamp_(min=a, max=b)
396
+ tensor = tensor.to(torch.float16)
397
+ else:
398
+ tensor.clamp_(min=a, max=b)
399
+
400
+
401
+ def trunc_normal_tf_(
402
+ tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
403
+ ) -> torch.Tensor:
404
+ """Fills the input Tensor with values drawn from a truncated
405
+ normal distribution. The values are effectively drawn from the
406
+ normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)`
407
+ with values outside :math:`[a, b]` redrawn until they are within
408
+ the bounds. The method used for generating the random values works
409
+ best when :math:`a \\leq \text{mean} \\leq b`.
410
+ NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
411
+ bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
412
+ and the result is subsequently scaled and shifted by the mean and std args.
413
+ Args:
414
+ tensor: an n-dimensional `torch.Tensor`
415
+ mean: the mean of the normal distribution
416
+ std: the standard deviation of the normal distribution
417
+ a: the minimum cutoff value
418
+ b: the maximum cutoff value
419
+ """
420
+ with torch.no_grad():
421
+ _trunc_normal_(tensor, 0, 1.0, a, b)
422
+ tensor.mul_(std).add_(mean)
423
+
424
+
425
+ def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
426
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
427
+ if mode == "fan_in":
428
+ denom = fan_in
429
+ elif mode == "fan_out":
430
+ denom = fan_out
431
+ elif mode == "fan_avg":
432
+ denom = (fan_in + fan_out) / 2
433
+
434
+ variance = scale / denom
435
+
436
+ if distribution == "truncated_normal":
437
+ # constant is stddev of standard normal truncated to (-2, 2)
438
+ trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
439
+ elif distribution == "normal":
440
+ with torch.no_grad():
441
+ tensor.normal_(std=math.sqrt(variance))
442
+ elif distribution == "uniform":
443
+ bound = math.sqrt(3 * variance)
444
+ with torch.no_grad():
445
+ tensor.uniform_(-bound, bound)
446
+ else:
447
+ raise ValueError(f"invalid distribution {distribution}")
448
+
449
+
450
+ def lecun_normal_(tensor):
451
+ variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
452
+
453
+
454
+ def default_flax_embed_init(tensor):
455
+ variance_scaling_(tensor, mode="fan_in", distribution="normal")
456
+
457
+
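`lecun_normal_` and `default_flax_embed_init` are thin wrappers over `variance_scaling_` with `fan_in` scaling, so an initialized weight should end up with a standard deviation close to `sqrt(1 / fan_in)`. A hedged sanity-check sketch, assuming this file is importable on the path as a module named `vision_siglip_navit` (the module name is an assumption):

```python
# Hedged sanity check: fan_in variance scaling gives std ~ sqrt(1 / fan_in).
import math
import torch
from torch import nn
from vision_siglip_navit import lecun_normal_  # module name is an assumption

linear = nn.Linear(512, 256, bias=False)   # weight shape (256, 512) -> fan_in = 512
lecun_normal_(linear.weight)               # truncated normal with fan_in scaling
print(linear.weight.std().item(), math.sqrt(1 / 512))  # both roughly 0.044
```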
458
+ @dataclass
459
+ # Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip
460
+ class SiglipVisionModelOutput(ModelOutput):
461
+ """
462
+ Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
463
+ Args:
464
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
465
+ The image embeddings obtained by applying the projection layer to the pooler_output.
466
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
467
+ Sequence of hidden-states at the output of the last layer of the model.
468
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
469
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
470
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
471
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
472
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
473
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
474
+ sequence_length)`.
475
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
476
+ heads.
477
+ """
478
+
479
+ image_embeds: Optional[torch.FloatTensor] = None
480
+ last_hidden_state: torch.FloatTensor = None
481
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
482
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
483
+
484
+
485
+ @dataclass
486
+ # Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip
487
+ class SiglipTextModelOutput(ModelOutput):
488
+ """
489
+ Base class for text model's outputs that also contains a pooling of the last hidden states.
490
+ Args:
491
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
492
+ The text embeddings obtained by applying the projection layer to the pooler_output.
493
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
494
+ Sequence of hidden-states at the output of the last layer of the model.
495
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
496
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
497
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
498
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
499
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
500
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
501
+ sequence_length)`.
502
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
503
+ heads.
504
+ """
505
+
506
+ text_embeds: Optional[torch.FloatTensor] = None
507
+ last_hidden_state: torch.FloatTensor = None
508
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
509
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
510
+
511
+
512
+ @dataclass
513
+ # Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip
514
+ class SiglipOutput(ModelOutput):
515
+ """
516
+ Args:
517
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
518
+ Contrastive loss for image-text similarity.
519
+ logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
520
+ The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
521
+ similarity scores.
522
+ logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
523
+ The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
524
+ similarity scores.
525
+ text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim)`):
526
+ The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
527
+ image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim)`):
528
+ The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`].
529
+ text_model_output(`BaseModelOutputWithPooling`):
530
+ The output of the [`SiglipTextModel`].
531
+ vision_model_output(`BaseModelOutputWithPooling`):
532
+ The output of the [`SiglipVisionModel`].
533
+ """
534
+
535
+ loss: Optional[torch.FloatTensor] = None
536
+ logits_per_image: torch.FloatTensor = None
537
+ logits_per_text: torch.FloatTensor = None
538
+ text_embeds: torch.FloatTensor = None
539
+ image_embeds: torch.FloatTensor = None
540
+ text_model_output: BaseModelOutputWithPooling = None
541
+ vision_model_output: BaseModelOutputWithPooling = None
542
+
543
+ def to_tuple(self) -> Tuple[Any]:
544
+ return tuple(
545
+ self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
546
+ for k in self.keys()
547
+ )
548
+
549
+
550
+ class SiglipVisionEmbeddings(nn.Module):
551
+ def __init__(self, config: SiglipVisionConfig):
552
+ super().__init__()
553
+ self.config = config
554
+ self.embed_dim = config.hidden_size
555
+ self.image_size = config.image_size
556
+ self.patch_size = config.patch_size
557
+
558
+ self.patch_embedding = nn.Conv2d(
559
+ in_channels=config.num_channels,
560
+ out_channels=self.embed_dim,
561
+ kernel_size=self.patch_size,
562
+ stride=self.patch_size,
563
+ padding="valid",
564
+ )
565
+
566
+ self.num_patches_per_side = self.image_size // self.patch_size
567
+ self.num_patches = self.num_patches_per_side**2
568
+ self.num_positions = self.num_patches
569
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
570
+
571
+ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
572
+ batch_size = pixel_values.size(0)
573
+
574
+ patch_embeds = self.patch_embedding(pixel_values)
575
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
576
+
577
+ max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
578
+ max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
579
+ boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
580
+ position_ids = torch.full(
581
+ size=(
582
+ batch_size,
583
+ max_nb_patches_h * max_nb_patches_w,
584
+ ),
585
+ fill_value=0,
586
+ )
587
+
588
+ for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
589
+ nb_patches_h = p_attn_mask[:, 0].sum()
590
+ nb_patches_w = p_attn_mask[0].sum()
591
+
592
+ fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
593
+ fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
594
+
595
+ bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
596
+ bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
597
+
598
+ pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
599
+ position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
600
+
601
+ position_ids = position_ids.to(self.position_embedding.weight.device)
602
+
603
+ embeddings = embeddings + self.position_embedding(position_ids)
604
+ return embeddings
605
+
606
+
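The NaViT-style `SiglipVisionEmbeddings.forward` above assigns each real patch a position id by bucketizing its fractional row/column coordinates onto the `num_patches_per_side` grid learned at the nominal resolution, so smaller images reuse a spread-out subset of the position table. A standalone toy example (not part of the uploaded file) of that bucketing for a 2x3 patch grid against a 4x4 position grid:

```python
# Toy illustration of the fractional-coordinate bucketing used above:
# a 2x3 patch grid mapped onto a 4x4 position-embedding grid.
import torch

num_patches_per_side = 4                      # e.g. image_size=64, patch_size=16 (illustrative)
boundaries = torch.arange(1 / num_patches_per_side, 1.0, 1 / num_patches_per_side)

nb_patches_h, nb_patches_w = 2, 3             # actual patch grid of this (smaller) image
frac_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)   # tensor([0.0, 0.5])
frac_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)   # tensor([0.0, 0.333, 0.667])

bucket_h = torch.bucketize(frac_h, boundaries, right=True)   # tensor([0, 2])
bucket_w = torch.bucketize(frac_w, boundaries, right=True)   # tensor([0, 1, 2])

pos_ids = (bucket_h[:, None] * num_patches_per_side + bucket_w).flatten()
print(pos_ids)   # tensor([0, 1, 2, 8, 9, 10]) -> rows 0 and 2 of the 4x4 grid
```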
607
+ # Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip
608
+ class SiglipTextEmbeddings(nn.Module):
609
+ def __init__(self, config: SiglipTextConfig):
610
+ super().__init__()
611
+ embed_dim = config.hidden_size
612
+
613
+ self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
614
+ self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
615
+
616
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
617
+ self.register_buffer(
618
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
619
+ )
620
+
621
+ def forward(
622
+ self,
623
+ input_ids: Optional[torch.LongTensor] = None,
624
+ position_ids: Optional[torch.LongTensor] = None,
625
+ inputs_embeds: Optional[torch.FloatTensor] = None,
626
+ ) -> torch.Tensor:
627
+ seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
628
+
629
+ if position_ids is None:
630
+ position_ids = self.position_ids[:, :seq_length]
631
+
632
+ if inputs_embeds is None:
633
+ inputs_embeds = self.token_embedding(input_ids)
634
+
635
+ position_embeddings = self.position_embedding(position_ids)
636
+ embeddings = inputs_embeds + position_embeddings
637
+
638
+ return embeddings
639
+
640
+
641
+ class SiglipAttention(nn.Module):
642
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
643
+
644
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
645
+ def __init__(self, config):
646
+ super().__init__()
647
+ self.config = config
648
+ self.embed_dim = config.hidden_size
649
+ self.num_heads = config.num_attention_heads
650
+ self.head_dim = self.embed_dim // self.num_heads
651
+ if self.head_dim * self.num_heads != self.embed_dim:
652
+ raise ValueError(
653
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
654
+ f" {self.num_heads})."
655
+ )
656
+ self.scale = self.head_dim**-0.5
657
+ self.dropout = config.attention_dropout
658
+
659
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
660
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
661
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
662
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
663
+
664
+ def forward(
665
+ self,
666
+ hidden_states: torch.Tensor,
667
+ attention_mask: Optional[torch.Tensor] = None,
668
+ output_attentions: Optional[bool] = False,
669
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
670
+ """Input shape: Batch x Time x Channel"""
671
+
672
+ batch_size, q_len, _ = hidden_states.size()
673
+
674
+ query_states = self.q_proj(hidden_states)
675
+ key_states = self.k_proj(hidden_states)
676
+ value_states = self.v_proj(hidden_states)
677
+
678
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
679
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
680
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
681
+
682
+ k_v_seq_len = key_states.shape[-2]
683
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
684
+
685
+ if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
686
+ raise ValueError(
687
+ f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
688
+ f" {attn_weights.size()}"
689
+ )
690
+
691
+ if attention_mask is not None:
692
+ if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
693
+ raise ValueError(
694
+ f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
695
+ )
696
+ attn_weights = attn_weights + attention_mask
697
+
698
+ # upcast attention to fp32
699
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
700
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
701
+ attn_output = torch.matmul(attn_weights, value_states)
702
+
703
+ if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
704
+ raise ValueError(
705
+ f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
706
+ f" {attn_output.size()}"
707
+ )
708
+
709
+ attn_output = attn_output.transpose(1, 2).contiguous()
710
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
711
+
712
+ attn_output = self.out_proj(attn_output)
713
+
714
+ return attn_output, attn_weights
715
+
716
+
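The eager `SiglipAttention.forward` above is the standard softmax(QK^T / sqrt(head_dim)) V computation, so for unmasked inputs it can be cross-checked against PyTorch's fused `scaled_dot_product_attention`. A hedged, standalone sketch of that equivalence check (shapes are illustrative):

```python
# Standalone sanity check: eager attention vs. the fused SDPA kernel.
import torch
import torch.nn.functional as F

q = torch.randn(2, 12, 10, 64)   # (batch, heads, seq_len, head_dim)
k = torch.randn(2, 12, 10, 64)
v = torch.randn(2, 12, 10, 64)

scale = 64 ** -0.5
weights = torch.softmax(torch.matmul(q, k.transpose(2, 3)) * scale, dim=-1)
eager = torch.matmul(weights, v)

fused = F.scaled_dot_product_attention(q, k, v)   # default scale is 1/sqrt(head_dim)
print(torch.allclose(eager, fused, atol=1e-5))    # expected True, up to numerical tolerance
```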
717
+ class SiglipFlashAttention2(SiglipAttention):
718
+ """
719
+ Siglip flash attention module. This module inherits from `SiglipAttention` as the weights of the module stay
720
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
721
+ flash attention and deal with padding tokens in case the input contains any of them.
722
+ """
723
+
724
+ def __init__(self, *args, **kwargs):
725
+ super().__init__(*args, **kwargs)
726
+ self.is_causal = False # Hack to make sure we don't use a causal mask
727
+
728
+ def forward(
729
+ self,
730
+ hidden_states: torch.Tensor,
731
+ attention_mask: Optional[torch.LongTensor] = None,
732
+ position_ids: Optional[torch.LongTensor] = None,
733
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
734
+ output_attentions: bool = False,
735
+ use_cache: bool = False,
736
+ **kwargs,
737
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
738
+ output_attentions = False
739
+
740
+ bsz, q_len, _ = hidden_states.size()
741
+
742
+ query_states = self.q_proj(hidden_states)
743
+ key_states = self.k_proj(hidden_states)
744
+ value_states = self.v_proj(hidden_states)
745
+
746
+ # Flash attention requires the input to have the shape
747
+ # batch_size x seq_length x head_dim x hidden_dim
748
+ # therefore we just need to keep the original shape
749
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
750
+ key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
751
+ value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
752
+
753
+ kv_seq_len = key_states.shape[-2]
754
+ if past_key_value is not None:
755
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
756
+ # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
757
+ # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
758
+
759
+ # if past_key_value is not None:
760
+ # cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
761
+ # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
762
+
763
+ # TODO: These transposes are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
764
+ # to be able to avoid many of these transpose/reshape/view.
765
+ query_states = query_states.transpose(1, 2)
766
+ key_states = key_states.transpose(1, 2)
767
+ value_states = value_states.transpose(1, 2)
768
+
769
+ dropout_rate = self.dropout if self.training else 0.0
770
+
771
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
772
+ # therefore the input hidden states get silently cast to float32. Hence, we need
773
+ # cast them back in the correct dtype just to be sure everything works as expected.
774
+ # This might slow down training & inference so it is recommended to not cast the LayerNorms
775
+ # in fp32. (LlamaRMSNorm handles it correctly)
776
+
777
+ input_dtype = query_states.dtype
778
+ if input_dtype == torch.float32:
779
+ if torch.is_autocast_enabled():
780
+ target_dtype = torch.get_autocast_gpu_dtype()
781
+ # Handle the case where the model is quantized
782
+ elif hasattr(self.config, "_pre_quantization_dtype"):
783
+ target_dtype = self.config._pre_quantization_dtype
784
+ else:
785
+ target_dtype = self.q_proj.weight.dtype
786
+
787
+ logger.warning_once(
788
+ "The input hidden states seems to be silently casted in float32, this might be related to the fact"
789
+ " you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
790
+ f" {target_dtype}."
791
+ )
792
+
793
+ query_states = query_states.to(target_dtype)
794
+ key_states = key_states.to(target_dtype)
795
+ value_states = value_states.to(target_dtype)
796
+
797
+ attn_output = self._flash_attention_forward(
798
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
799
+ )
800
+
801
+ attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous()
802
+ attn_output = self.out_proj(attn_output)
803
+
804
+ if not output_attentions:
805
+ attn_weights = None
806
+
807
+ return attn_output, attn_weights
808
+
809
+ def _flash_attention_forward(
810
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
811
+ ):
812
+ """
813
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
814
+ first unpad the input, then computes the attention scores and pad the final attention scores.
815
+ Args:
816
+ query_states (`torch.Tensor`):
817
+ Input query states to be passed to Flash Attention API
818
+ key_states (`torch.Tensor`):
819
+ Input key states to be passed to Flash Attention API
820
+ value_states (`torch.Tensor`):
821
+ Input value states to be passed to Flash Attention API
822
+ attention_mask (`torch.Tensor`):
823
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
824
+ position of padding tokens and 1 for the position of non-padding tokens.
825
+ dropout (`int`, *optional*):
826
+ Attention dropout
827
+ softmax_scale (`float`, *optional*):
828
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
829
+ """
830
+
831
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
832
+ causal = self.is_causal and query_length != 1
833
+
834
+ # Contains at least one padding token in the sequence
835
+ if attention_mask is not None:
836
+ batch_size = query_states.shape[0]
837
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
838
+ query_states, key_states, value_states, attention_mask, query_length
839
+ )
840
+
841
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
842
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
843
+
844
+ attn_output_unpad = flash_attn_varlen_func(
845
+ query_states,
846
+ key_states,
847
+ value_states,
848
+ cu_seqlens_q=cu_seqlens_q,
849
+ cu_seqlens_k=cu_seqlens_k,
850
+ max_seqlen_q=max_seqlen_in_batch_q,
851
+ max_seqlen_k=max_seqlen_in_batch_k,
852
+ dropout_p=dropout,
853
+ softmax_scale=softmax_scale,
854
+ causal=causal,
855
+ )
856
+
857
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
858
+ else:
859
+ attn_output = flash_attn_func(
860
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
861
+ )
862
+
863
+ return attn_output
864
+
865
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
866
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
867
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
868
+
869
+ key_layer = index_first_axis(
870
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
871
+ )
872
+ value_layer = index_first_axis(
873
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
874
+ )
875
+ if query_length == kv_seq_len:
876
+ query_layer = index_first_axis(
877
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
878
+ )
879
+ cu_seqlens_q = cu_seqlens_k
880
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
881
+ indices_q = indices_k
882
+ elif query_length == 1:
883
+ max_seqlen_in_batch_q = 1
884
+ cu_seqlens_q = torch.arange(
885
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
886
+ ) # There is a memcpy here, that is very bad.
887
+ indices_q = cu_seqlens_q[:-1]
888
+ query_layer = query_layer.squeeze(1)
889
+ else:
890
+ # The -q_len: slice assumes left padding.
891
+ attention_mask = attention_mask[:, -query_length:]
892
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
893
+
894
+ return (
895
+ query_layer,
896
+ key_layer,
897
+ value_layer,
898
+ indices_q,
899
+ (cu_seqlens_q, cu_seqlens_k),
900
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
901
+ )
902
+
903
+
904
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
905
+ class SiglipMLP(nn.Module):
906
+ def __init__(self, config):
907
+ super().__init__()
908
+ self.config = config
909
+ self.activation_fn = ACT2FN[config.hidden_act]
910
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
911
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
912
+
913
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
914
+ hidden_states = self.fc1(hidden_states)
915
+ hidden_states = self.activation_fn(hidden_states)
916
+ hidden_states = self.fc2(hidden_states)
917
+ return hidden_states
918
+
919
+
920
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip
921
+ class SiglipEncoderLayer(nn.Module):
922
+ def __init__(self, config: SiglipConfig):
923
+ super().__init__()
924
+ self.embed_dim = config.hidden_size
925
+ self.self_attn = (
926
+ SiglipAttention(config)
927
+ if not getattr(config, "_flash_attn_2_enabled", False)
928
+ else SiglipFlashAttention2(config)
929
+ )
930
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
931
+ self.mlp = SiglipMLP(config)
932
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
933
+
934
+ def forward(
935
+ self,
936
+ hidden_states: torch.Tensor,
937
+ attention_mask: torch.Tensor,
938
+ output_attentions: Optional[bool] = False,
939
+ ) -> Tuple[torch.FloatTensor]:
940
+ """
941
+ Args:
942
+ hidden_states (`torch.FloatTensor`):
943
+ Input to the layer of shape `(batch, seq_len, embed_dim)`.
944
+ attention_mask (`torch.FloatTensor`):
945
+ Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
946
+ output_attentions (`bool`, *optional*, defaults to `False`):
947
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
948
+ returned tensors for more detail.
949
+ """
950
+ residual = hidden_states
951
+
952
+ hidden_states = self.layer_norm1(hidden_states)
953
+ hidden_states, attn_weights = self.self_attn(
954
+ hidden_states=hidden_states,
955
+ attention_mask=attention_mask,
956
+ output_attentions=output_attentions,
957
+ )
958
+ hidden_states = residual + hidden_states
959
+
960
+ residual = hidden_states
961
+ hidden_states = self.layer_norm2(hidden_states)
962
+ hidden_states = self.mlp(hidden_states)
963
+ hidden_states = residual + hidden_states
964
+
965
+ outputs = (hidden_states,)
966
+
967
+ if output_attentions:
968
+ outputs += (attn_weights,)
969
+
970
+ return outputs
971
+
972
+
973
+ class SiglipPreTrainedModel(PreTrainedModel):
974
+ """
975
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
976
+ models.
977
+ """
978
+
979
+ config_class = SiglipConfig
980
+ base_model_prefix = "siglip"
981
+ supports_gradient_checkpointing = True
982
+
983
+ def _init_weights(self, module):
984
+ """Initialize the weights"""
985
+
986
+ if isinstance(module, SiglipVisionEmbeddings):
987
+ width = (
988
+ self.config.vision_config.hidden_size
989
+ if isinstance(self.config, SiglipConfig)
990
+ else self.config.hidden_size
991
+ )
992
+ nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
993
+ elif isinstance(module, nn.Embedding):
994
+ default_flax_embed_init(module.weight)
995
+ elif isinstance(module, SiglipAttention):
996
+ nn.init.normal_(module.q_proj.weight)
997
+ nn.init.normal_(module.k_proj.weight)
998
+ nn.init.normal_(module.v_proj.weight)
999
+ nn.init.normal_(module.out_proj.weight)
1000
+ nn.init.zeros_(module.q_proj.bias)
1001
+ nn.init.zeros_(module.k_proj.bias)
1002
+ nn.init.zeros_(module.v_proj.bias)
1003
+ nn.init.zeros_(module.out_proj.bias)
1004
+ elif isinstance(module, SiglipMLP):
1005
+ nn.init.normal_(module.fc1.weight)
1006
+ nn.init.normal_(module.fc2.weight)
1007
+ nn.init.normal_(module.fc1.bias, std=1e-6)
1008
+ nn.init.normal_(module.fc2.bias, std=1e-6)
1009
+ elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
1010
+ nn.init.normal_(module.probe.data)
1011
+ nn.init.normal_(module.attention.in_proj_weight.data)
1012
+ nn.init.zeros_(module.attention.in_proj_bias.data)
1013
+ elif isinstance(module, SiglipModel):
1014
+ logit_scale_init = torch.tensor(0.0)
1015
+ module.logit_scale.data.fill_(logit_scale_init)
1016
+ module.logit_bias.data.zero_()
1017
+ elif isinstance(module, (nn.Linear, nn.Conv2d)):
1018
+ lecun_normal_(module.weight)
1019
+ if module.bias is not None:
1020
+ nn.init.zeros_(module.bias)
1021
+ elif isinstance(module, nn.LayerNorm):
1022
+ module.bias.data.zero_()
1023
+ module.weight.data.fill_(1.0)
1024
+
1025
+
1026
+ SIGLIP_START_DOCSTRING = r"""
1027
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1028
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
1029
+ etc.)
1030
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1031
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
1032
+ and behavior.
1033
+ Parameters:
1034
+ config ([`SiglipConfig`]): Model configuration class with all the parameters of the model.
1035
+ Initializing with a config file does not load the weights associated with the model, only the
1036
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1037
+ """
1038
+
1039
+ SIGLIP_TEXT_INPUTS_DOCSTRING = r"""
1040
+ Args:
1041
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1042
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1043
+ it.
1044
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1045
+ [`PreTrainedTokenizer.__call__`] for details.
1046
+ [What are input IDs?](../glossary#input-ids)
1047
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1048
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1049
+ - 1 for tokens that are **not masked**,
1050
+ - 0 for tokens that are **masked**.
1051
+ [What are attention masks?](../glossary#attention-mask)
1052
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1053
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1054
+ config.max_position_embeddings - 1]`.
1055
+ [What are position IDs?](../glossary#position-ids)
1056
+ output_attentions (`bool`, *optional*):
1057
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1058
+ tensors for more detail.
1059
+ output_hidden_states (`bool`, *optional*):
1060
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1061
+ more detail.
1062
+ return_dict (`bool`, *optional*):
1063
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1064
+ """
1065
+
1066
+ SIGLIP_VISION_INPUTS_DOCSTRING = r"""
1067
+ Args:
1068
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
1069
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
1070
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
1071
+ output_attentions (`bool`, *optional*):
1072
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1073
+ tensors for more detail.
1074
+ output_hidden_states (`bool`, *optional*):
1075
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1076
+ more detail.
1077
+ return_dict (`bool`, *optional*):
1078
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1079
+ """
1080
+
1081
+ SIGLIP_INPUTS_DOCSTRING = r"""
1082
+ Args:
1083
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1084
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1085
+ it.
1086
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1087
+ [`PreTrainedTokenizer.__call__`] for details.
1088
+ [What are input IDs?](../glossary#input-ids)
1089
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1090
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1091
+ - 1 for tokens that are **not masked**,
1092
+ - 0 for tokens that are **masked**.
1093
+ [What are attention masks?](../glossary#attention-mask)
1094
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1095
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1096
+ config.max_position_embeddings - 1]`.
1097
+ [What are position IDs?](../glossary#position-ids)
1098
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
1099
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
1100
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
1101
+ return_loss (`bool`, *optional*):
1102
+ Whether or not to return the contrastive loss.
1103
+ output_attentions (`bool`, *optional*):
1104
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1105
+ tensors for more detail.
1106
+ output_hidden_states (`bool`, *optional*):
1107
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1108
+ more detail.
1109
+ return_dict (`bool`, *optional*):
1110
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1111
+ """
1112
+
1113
+
1114
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip
1115
+ class SiglipEncoder(nn.Module):
1116
+ """
1117
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
1118
+ [`SiglipEncoderLayer`].
1119
+ Args:
1120
+ config: SiglipConfig
1121
+ """
1122
+
1123
+ def __init__(self, config: SiglipConfig):
1124
+ super().__init__()
1125
+ self.config = config
1126
+ self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
1127
+ self.gradient_checkpointing = False
1128
+
1129
+ # Ignore copy
1130
+ def forward(
1131
+ self,
1132
+ inputs_embeds,
1133
+ attention_mask: Optional[torch.Tensor] = None,
1134
+ output_attentions: Optional[bool] = None,
1135
+ output_hidden_states: Optional[bool] = None,
1136
+ return_dict: Optional[bool] = None,
1137
+ ) -> Union[Tuple, BaseModelOutput]:
1138
+ r"""
1139
+ Args:
1140
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
1141
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
1142
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
1143
+ than the model's internal embedding lookup matrix.
1144
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1145
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1146
+ - 1 for tokens that are **not masked**,
1147
+ - 0 for tokens that are **masked**.
1148
+ [What are attention masks?](../glossary#attention-mask)
1149
+ output_attentions (`bool`, *optional*):
1150
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1151
+ returned tensors for more detail.
1152
+ output_hidden_states (`bool`, *optional*):
1153
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1154
+ for more detail.
1155
+ return_dict (`bool`, *optional*):
1156
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1157
+ """
1158
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1159
+ output_hidden_states = (
1160
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1161
+ )
1162
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1163
+
1164
+ encoder_states = () if output_hidden_states else None
1165
+ all_attentions = () if output_attentions else None
1166
+
1167
+ hidden_states = inputs_embeds
1168
+ for encoder_layer in self.layers:
1169
+ if output_hidden_states:
1170
+ encoder_states = encoder_states + (hidden_states,)
1171
+ if self.gradient_checkpointing and self.training:
1172
+ layer_outputs = self._gradient_checkpointing_func(
1173
+ encoder_layer.__call__,
1174
+ hidden_states,
1175
+ attention_mask,
1176
+ output_attentions,
1177
+ )
1178
+ else:
1179
+ layer_outputs = encoder_layer(
1180
+ hidden_states,
1181
+ attention_mask,
1182
+ output_attentions=output_attentions,
1183
+ )
1184
+
1185
+ hidden_states = layer_outputs[0]
1186
+
1187
+ if output_attentions:
1188
+ all_attentions = all_attentions + (layer_outputs[1],)
1189
+
1190
+ if output_hidden_states:
1191
+ encoder_states = encoder_states + (hidden_states,)
1192
+
1193
+ if not return_dict:
1194
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
1195
+ return BaseModelOutput(
1196
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
1197
+ )
1198
+
1199
+
1200
+ class SiglipTextTransformer(nn.Module):
+     def __init__(self, config: SiglipTextConfig):
+         super().__init__()
+         self.config = config
+         embed_dim = config.hidden_size
+         self.embeddings = SiglipTextEmbeddings(config)
+         self.encoder = SiglipEncoder(config)
+         self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+         self.head = nn.Linear(embed_dim, embed_dim)
+
+     @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
+     @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutputWithPooling]:
+         r"""
+         Returns:
+         """
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         if input_ids is None:
+             raise ValueError("You have to specify input_ids")
+
+         input_shape = input_ids.size()
+         input_ids = input_ids.view(-1, input_shape[-1])
+
+         hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
+
+         # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
+         # expand attention_mask
+         if attention_mask is not None:
+             # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
+             attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
+
+         encoder_outputs = self.encoder(
+             inputs_embeds=hidden_states,
+             attention_mask=attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         last_hidden_state = encoder_outputs[0]
+         last_hidden_state = self.final_layer_norm(last_hidden_state)
+
+         # Assuming "sticky" EOS tokenization, last token is always EOS.
+         pooled_output = last_hidden_state[:, -1, :]
+         pooled_output = self.head(pooled_output)
+
+         if not return_dict:
+             return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+         return BaseModelOutputWithPooling(
+             last_hidden_state=last_hidden_state,
+             pooler_output=pooled_output,
+             hidden_states=encoder_outputs.hidden_states,
+             attentions=encoder_outputs.attentions,
+         )
+
+
+ @add_start_docstrings(
+     """The text model from SigLIP without any head or projection on top.""",
+     SIGLIP_START_DOCSTRING,
+ )
+ class SiglipTextModel(SiglipPreTrainedModel):
+     config_class = SiglipTextConfig
+
+     _no_split_modules = ["SiglipTextEmbeddings", "SiglipEncoderLayer"]
+
+     def __init__(self, config: SiglipTextConfig):
+         super().__init__(config)
+         self.text_model = SiglipTextTransformer(config)
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self) -> nn.Module:
+         return self.text_model.embeddings.token_embedding
+
+     def set_input_embeddings(self, value):
+         self.text_model.embeddings.token_embedding = value
+
+     @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
+     @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutputWithPooling]:
+         r"""
+         Returns:
+         Examples:
+         ```python
+         >>> from transformers import AutoTokenizer, SiglipTextModel
+         >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224")
+         >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
+         >>> # important: make sure to set padding="max_length" as that's how the model was trained
+         >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
+         >>> outputs = model(**inputs)
+         >>> last_hidden_state = outputs.last_hidden_state
+         >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
+         ```"""
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         return self.text_model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+
+ class SiglipVisionTransformer(nn.Module):
+     def __init__(self, config: SiglipVisionConfig):
+         super().__init__()
+         self.config = config
+         embed_dim = config.hidden_size
+
+         self.embeddings = SiglipVisionEmbeddings(config)
+         self.encoder = SiglipEncoder(config)
+         self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+         self.head = SiglipMultiheadAttentionPoolingHead(config)
+
+     @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
+     @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
+     def forward(
+         self,
+         pixel_values,
+         patch_attention_mask: Optional[torch.BoolTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutputWithPooling]:
+         r"""
+         Returns:
+         """
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         batch_size = pixel_values.size(0)
+         if patch_attention_mask is None:
+             patch_attention_mask = torch.ones(
+                 size=(
+                     batch_size,
+                     pixel_values.size(2) // self.config.patch_size,
+                     pixel_values.size(3) // self.config.patch_size,
+                 ),
+                 dtype=torch.bool,
+                 device=pixel_values.device,
+             )
+
+         hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
+
+         patch_attention_mask = patch_attention_mask.view(batch_size, -1)
+         # The call to `_upad_input` in `_flash_attention_forward` is expensive.
+         # So when `patch_attention_mask` is full of 1s (i.e. we attend to the whole sequence),
+         # we avoid passing the attention_mask, which is equivalent to attending to the full sequence.
+         if not torch.any(~patch_attention_mask):
+             attention_mask = None
+         else:
+             attention_mask = (
+                 _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
+                 if not self.config._flash_attn_2_enabled
+                 else patch_attention_mask
+             )
+
+         encoder_outputs = self.encoder(
+             inputs_embeds=hidden_states,
+             attention_mask=attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         last_hidden_state = encoder_outputs[0]
+         last_hidden_state = self.post_layernorm(last_hidden_state)
+
+         pooled_output = self.head(
+             hidden_state=last_hidden_state,
+             attention_mask=patch_attention_mask,
+         )
+
+         if not return_dict:
+             return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+         return BaseModelOutputWithPooling(
+             last_hidden_state=last_hidden_state,
+             pooler_output=pooled_output,
+             hidden_states=encoder_outputs.hidden_states,
+             attentions=encoder_outputs.attentions,
+         )
+
+
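The `patch_attention_mask` accepted above is a boolean grid with one entry per image patch, of shape `(batch_size, height // patch_size, width // patch_size)`; when it is omitted, every patch is attended to. A short sketch of how a caller might mask out padded regions when batching variable-sized images (the padding layout here is invented for illustration; sizes follow the defaults of `get_siglip_vision_model` below):

```python
import torch

# Two 448x448 canvases; the second image only fills the top half before zero-padding.
# With patch_size=14 this gives a 32x32 patch grid per image.
pixel_values = torch.zeros(2, 3, 448, 448)
patch_attention_mask = torch.ones(2, 32, 32, dtype=torch.bool)
patch_attention_mask[1, 16:, :] = False  # ignore the padded bottom half of image 1

# vision_tower = get_siglip_vision_model(_flash_attn_2_enabled=False)
# out = vision_tower(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
# out.last_hidden_state: (2, 1024, 1152); out.pooler_output: (2, 1152)
```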
+ class SiglipMultiheadAttentionPoolingHead(nn.Module):
+     """Multihead Attention Pooling."""
+
+     def __init__(self, config: SiglipVisionConfig):
+         super().__init__()
+
+         self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
+         self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
+         self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.mlp = SiglipMLP(config)
+
+     def forward(self, hidden_state, attention_mask):
+         batch_size = hidden_state.shape[0]
+         probe = self.probe.repeat(batch_size, 1, 1)
+
+         hidden_state = self.attention(
+             query=probe, key=hidden_state, value=hidden_state, key_padding_mask=~attention_mask
+         )[0]
+
+         residual = hidden_state
+         hidden_state = self.layernorm(hidden_state)
+         hidden_state = residual + self.mlp(hidden_state)
+
+         return hidden_state[:, 0]
+
+
+ @add_start_docstrings(
+     """The vision model from SigLIP without any head or projection on top.""",
+     SIGLIP_START_DOCSTRING,
+ )
+ class SiglipVisionModel(SiglipPreTrainedModel):
+     config_class = SiglipVisionConfig
+     main_input_name = "pixel_values"
+
+     def __init__(self, config: SiglipVisionConfig):
+         super().__init__(config)
+
+         self.vision_model = SiglipVisionTransformer(config)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self) -> nn.Module:
+         return self.vision_model.embeddings.patch_embedding
+
+     @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
+     @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
+     def forward(
+         self,
+         pixel_values,
+         patch_attention_mask: Optional[torch.BoolTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutputWithPooling]:
+         r"""
+         Returns:
+         Examples:
+         ```python
+         >>> from PIL import Image
+         >>> import requests
+         >>> from transformers import AutoProcessor, SiglipVisionModel
+         >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
+         >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
+         >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+         >>> image = Image.open(requests.get(url, stream=True).raw)
+         >>> inputs = processor(images=image, return_tensors="pt")
+         >>> outputs = model(**inputs)
+         >>> last_hidden_state = outputs.last_hidden_state
+         >>> pooled_output = outputs.pooler_output  # pooled features
+         ```"""
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         return self.vision_model(
+             pixel_values=pixel_values,
+             patch_attention_mask=patch_attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+
+ @add_start_docstrings(SIGLIP_START_DOCSTRING)
+ class SiglipModel(SiglipPreTrainedModel):
+     config_class = SiglipConfig
+
+     def __init__(self, config: SiglipConfig):
+         super().__init__(config)
+
+         if not isinstance(config.text_config, SiglipTextConfig):
+             raise ValueError(
+                 "config.text_config is expected to be of type SiglipTextConfig but is of type"
+                 f" {type(config.text_config)}."
+             )
+
+         if not isinstance(config.vision_config, SiglipVisionConfig):
+             raise ValueError(
+                 "config.vision_config is expected to be of type SiglipVisionConfig but is of type"
+                 f" {type(config.vision_config)}."
+             )
+
+         text_config = config.text_config
+         vision_config = config.vision_config
+
+         self.text_model = SiglipTextTransformer(text_config)
+         self.vision_model = SiglipVisionTransformer(vision_config)
+
+         self.logit_scale = nn.Parameter(torch.randn(1))
+         self.logit_bias = nn.Parameter(torch.randn(1))
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
+     def get_text_features(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> torch.FloatTensor:
+         r"""
+         Returns:
+             text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
+                 applying the projection layer to the pooled output of [`SiglipTextModel`].
+         Examples:
+         ```python
+         >>> from transformers import AutoTokenizer, AutoModel
+         >>> import torch
+         >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
+         >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
+         >>> # important: make sure to set padding="max_length" as that's how the model was trained
+         >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
+         >>> with torch.no_grad():
+         ...     text_features = model.get_text_features(**inputs)
+         ```"""
+         # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         text_outputs = self.text_model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         pooled_output = text_outputs[1]
+
+         return pooled_output
+
+     @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
+     def get_image_features(
+         self,
+         pixel_values: Optional[torch.FloatTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> torch.FloatTensor:
+         r"""
+         Returns:
+             image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
+                 applying the projection layer to the pooled output of [`SiglipVisionModel`].
+         Examples:
+         ```python
+         >>> from PIL import Image
+         >>> import requests
+         >>> from transformers import AutoProcessor, AutoModel
+         >>> import torch
+         >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
+         >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
+         >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+         >>> image = Image.open(requests.get(url, stream=True).raw)
+         >>> inputs = processor(images=image, return_tensors="pt")
+         >>> with torch.no_grad():
+         ...     image_features = model.get_image_features(**inputs)
+         ```"""
+         # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components.
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         vision_outputs = self.vision_model(
+             pixel_values=pixel_values,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         pooled_output = vision_outputs[1]
+
+         return pooled_output
+
+     @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING)
+     @replace_return_docstrings(output_type=SiglipOutput, config_class=SiglipConfig)
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         pixel_values: Optional[torch.FloatTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         return_loss: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, SiglipOutput]:
+         r"""
+         Returns:
+         Examples:
+         ```python
+         >>> from PIL import Image
+         >>> import requests
+         >>> from transformers import AutoProcessor, AutoModel
+         >>> import torch
+         >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
+         >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
+         >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+         >>> image = Image.open(requests.get(url, stream=True).raw)
+         >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
+         >>> # important: we pass `padding=max_length` since the model was trained with this
+         >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")
+         >>> with torch.no_grad():
+         ...     outputs = model(**inputs)
+         >>> logits_per_image = outputs.logits_per_image
+         >>> probs = torch.sigmoid(logits_per_image)  # these are the probabilities
+         >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
+         31.9% that image 0 is 'a photo of 2 cats'
+         ```"""
+         # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         vision_outputs = self.vision_model(
+             pixel_values=pixel_values,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         text_outputs = self.text_model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         image_embeds = vision_outputs[1]
+         text_embeds = text_outputs[1]
+
+         # normalized features
+         image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+         text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
+
+         # cosine similarity as logits
+         logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale.exp() + self.logit_bias
+         logits_per_image = logits_per_text.t()
+
+         loss = None
+         if return_loss:
+             raise NotImplementedError("SigLIP loss to be implemented")
+
+         if not return_dict:
+             output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
+             return ((loss,) + output) if loss is not None else output
+
+         return SiglipOutput(
+             loss=loss,
+             logits_per_image=logits_per_image,
+             logits_per_text=logits_per_text,
+             text_embeds=text_embeds,
+             image_embeds=image_embeds,
+             text_model_output=text_outputs,
+             vision_model_output=vision_outputs,
+         )
+
+
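`SiglipModel.forward` above leaves the training loss unimplemented (`NotImplementedError("SigLIP loss to be implemented")`). For reference, the pairwise sigmoid loss described in the SigLIP paper treats every image–text pair in the batch as an independent binary classification, positive on the diagonal and negative elsewhere; the sketch below follows the paper, not this repository's (absent) implementation:

```python
import torch
import torch.nn.functional as F

def sigmoid_contrastive_loss(logits_per_text: torch.Tensor) -> torch.Tensor:
    # logits_per_text: (num_texts, num_images) = scale * text_embeds @ image_embeds.T + bias,
    # i.e. the tensor computed in SiglipModel.forward above, with texts and images paired.
    n = logits_per_text.size(0)
    # +1 for matching pairs (the diagonal), -1 for every other pair in the batch.
    signs = 2 * torch.eye(n, device=logits_per_text.device, dtype=logits_per_text.dtype) - 1
    # -log sigmoid(sign * logit), summed over all pairs and averaged over the batch.
    return -torch.sum(F.logsigmoid(signs * logits_per_text)) / n
```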
+ def get_siglip_vision_model(_flash_attn_2_enabled=True, **kwargs):
+     siglip_vision_config = {
+         "hidden_size": 1152,
+         "image_size": 448,
+         "intermediate_size": 4304,
+         "model_type": "siglip_vision_model",
+         "num_attention_heads": 16,
+         "num_hidden_layers": 27,
+         "patch_size": 14,
+     }
+
+     model_config = SiglipVisionConfig(**siglip_vision_config, _flash_attn_2_enabled=_flash_attn_2_enabled, **kwargs)
+
+     vision_model = SiglipVisionModel(model_config).vision_model
+
+     return vision_model
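`get_siglip_vision_model` builds a freshly initialized SigLIP vision tower with SO400M-sized dimensions (27 layers, hidden size 1152, 448×448 input, patch size 14) and returns the bare `SiglipVisionTransformer`; the actual weights are presumably loaded later from the checkpoint by the surrounding Phi-4 multimodal code. A minimal smoke-test sketch (random weights, flash attention disabled so it runs on CPU):

```python
import torch

vision_tower = get_siglip_vision_model(_flash_attn_2_enabled=False)

pixel_values = torch.zeros(1, 3, 448, 448)  # one blank 448x448 RGB image
with torch.no_grad():
    out = vision_tower(pixel_values=pixel_values)

print(out.last_hidden_state.shape)  # torch.Size([1, 1024, 1152]); (448 / 14) ** 2 = 1024 patches
print(out.pooler_output.shape)      # torch.Size([1, 1152]) from the attention-pooling head
```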