From a02683328d2443a7ec63142ac1fb9bd3707bc418 Mon Sep 17 00:00:00 2001 From: Scott Newcomer Date: Tue, 5 Nov 2024 06:40:03 -0600 Subject: [PATCH 1/6] feat: add mpnet model family --- lib/bumblebee.ex | 2 + lib/bumblebee/text/mpnet.ex | 458 ++++++++++++++++++++++++++++++++++++ 2 files changed, 460 insertions(+) create mode 100644 lib/bumblebee/text/mpnet.ex diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex index 51f2330f..66d1eb54 100644 --- a/lib/bumblebee.ex +++ b/lib/bumblebee.ex @@ -170,6 +170,8 @@ defmodule Bumblebee do "MistralModel" => {Bumblebee.Text.Mistral, :base}, "MistralForCausalLM" => {Bumblebee.Text.Mistral, :for_causal_language_modeling}, "MistralForSequenceClassification" => {Bumblebee.Text.Mistral, :for_sequence_classification}, + "MPNetModel" => {Bumblebee.Text.MPNet, :base}, + "MPNetForMaskedLM" => {Bumblebee.Text.MPNet, :for_masked_language_modeling}, "PhiModel" => {Bumblebee.Text.Phi, :base}, "PhiForCausalLM" => {Bumblebee.Text.Phi, :for_causal_language_modeling}, "PhiForSequenceClassification" => {Bumblebee.Text.Phi, :for_sequence_classification}, diff --git a/lib/bumblebee/text/mpnet.ex b/lib/bumblebee/text/mpnet.ex new file mode 100644 index 00000000..d74c06ff --- /dev/null +++ b/lib/bumblebee/text/mpnet.ex @@ -0,0 +1,458 @@ +defmodule Bumblebee.Text.MPNet do + alias Bumblebee.Shared + + options = + [ + vocab_size: [ + default: 30527, + doc: """ + the vocabulary size of the token embedding. This corresponds to the number of distinct + tokens that can be represented in model input and output + """ + ], + max_positions: [ + default: 514, + doc: """ + the vocabulary size of the position embedding. This corresponds to the maximum sequence + length that this model can process. Typically this is set to a large value just in case, + such as 512, 1024 or 2048 + """ + ], + max_token_types: [ + default: 2, + doc: """ + the vocabulary size of the token type embedding (also referred to as segment embedding). + This corresponds to how many different token groups can be distinguished in the input + """ + ], + hidden_size: [ + default: 768, + doc: "the dimensionality of hidden layers" + ], + num_blocks: [ + default: 12, + doc: "the number of Transformer blocks in the encoder" + ], + num_attention_heads: [ + default: 12, + doc: "the number of attention heads for each attention layer in the encoder" + ], + intermediate_size: [ + default: 3072, + doc: + "the dimensionality of the intermediate layer in the transformer feed-forward network (FFN) in the encoder" + ], + activation: [ + default: :gelu, + doc: "the activation function" + ], + dropout_rate: [ + default: 0.1, + doc: "the dropout rate for embedding and encoder" + ], + attention_dropout_rate: [ + default: 0.1, + doc: "the dropout rate for attention weights" + ], + classifier_dropout_rate: [ + default: nil, + doc: + "the dropout rate for the classification head. If not specified, the value of `:dropout_rate` is used instead" + ], + layer_norm_epsilon: [ + default: 1.0e-05, + doc: "the epsilon used by the layer normalization layers" + ], + initializer_scale: [ + default: 0.02, + doc: + "the standard deviation of the normal initializer used for initializing kernel parameters" + ] + ] ++ Shared.common_options([:use_cross_attention, :num_labels, :id_to_label]) + + @moduledoc """ + MPNet model family. + + ## Architectures + + * `:base` - plain MPNet without any head on top + + * `:for_masked_language_modeling` - MPNet with a language modeling + head. 
The head returns logits for each token in the original
+    sequence
+
+  ## Inputs
+
+    * `"input_ids"` - `{batch_size, sequence_length}`
+
+      Indices of input sequence tokens in the vocabulary.
+
+    * `"attention_mask"` - `{batch_size, sequence_length}`
+
+      Mask indicating which tokens to attend to. This is used to ignore
+      padding tokens, which are added when processing a batch of sequences
+      with different length.
+
+    * `"token_type_ids"` - `{batch_size, sequence_length}`
+
+      Mask distinguishing groups in the input sequence. This is used
+      when the input sequence is semantically a pair of sequences.
+
+    * `"position_ids"` - `{batch_size, sequence_length}`
+
+      Indices of positions of each input sequence tokens in the position
+      embeddings.
+
+    * `"attention_head_mask"` - `{num_blocks, num_attention_heads}`
+
+      Mask to nullify selected heads of the self-attention blocks in
+      the encoder.
+
+  ### Exceptions
+
+  The `:for_causal_language_modeling` model is a decoder and accepts
+  the following additional inputs: `"encoder_hidden_state"`,
+  `"encoder_attention_mask"`, `"cross_attention_head_mask"`, `"cache"`.
+
+  ## Global layer options
+
+  #{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])}
+
+  ## Configuration
+
+  #{Shared.options_doc(options)}
+
+  ## References
+
+    * [MPNet: MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/pdf/2004.09297)
+
+  """
+
+  defstruct [architecture: :base] ++ Shared.option_defaults(options)
+
+  @behaviour Bumblebee.ModelSpec
+  @behaviour Bumblebee.Configurable
+  @behaviour Bumblebee.Text.Generation
+
+  import Bumblebee.Utils.Model, only: [join: 2]
+
+  alias Bumblebee.Layers
+
+  @impl true
+  def architectures(),
+    do: [
+      :base,
+      :for_masked_language_modeling
+    ]
+
+  @impl true
+  def config(spec, opts) do
+    spec
+    |> Shared.put_config_attrs(opts)
+    |> Shared.validate_label_options()
+  end
+
+  def input_template(_spec) do
+    %{"input_ids" => Nx.template({1, 1}, :u32)}
+  end
+
+  @impl true
+  def model(%__MODULE__{architecture: :base} = spec) do
+    inputs = inputs(spec)
+
+    inputs
+    |> core(spec)
+    |> Layers.output()
+  end
+
+  def model(%__MODULE__{architecture: :for_masked_language_modeling} = spec) do
+    inputs = inputs(spec)
+    outputs = core(inputs, spec)
+
+    logits = language_modeling_head(outputs.hidden_state, spec, name: "language_modeling_head")
+
+    Layers.output(%{
+      logits: logits,
+      hidden_states: outputs.hidden_states,
+      attentions: outputs.attentions
+    })
+  end
+
+  @impl true
+  def init_cache(spec, batch_size, max_length, inputs) do
+    encoder_sequence_length =
+      if encoder_hidden_state = inputs["encoder_hidden_state"] do
+        Nx.axis_size(encoder_hidden_state, 1)
+      end
+
+    Layers.Decoder.init_cache(batch_size, max_length,
+      hidden_size: spec.hidden_size,
+      decoder_num_attention_heads: spec.num_attention_heads,
+      encoder_num_attention_heads: spec.num_attention_heads,
+      decoder_num_blocks: spec.num_blocks,
+      encoder_sequence_length: encoder_sequence_length
+    )
+  end
+
+  @impl true
+  def traverse_cache(_spec, cache, fun) do
+    Layers.Decoder.traverse_cache(cache, fun)
+  end
+
+  defp inputs(spec, opts \\ []) do
+    shape = Keyword.get(opts, :shape, {nil, nil})
+    decoder? 
= Keyword.get(opts, :decoder?, false) + + hidden_shape = Tuple.append(shape, spec.hidden_size) + attention_head_mask_shape = {spec.num_blocks, spec.num_attention_heads} + + inputs = + Bumblebee.Utils.Model.inputs_to_map([ + Axon.input("input_ids", shape: shape), + Axon.input("attention_mask", optional: true, shape: shape), + Axon.input("token_type_ids", optional: true, shape: shape), + Axon.input("position_ids", optional: true, shape: shape), + Axon.input("attention_head_mask", optional: true, shape: attention_head_mask_shape) + ]) + + extra_decoder_inputs = + Bumblebee.Utils.Model.inputs_to_map([ + Axon.input("encoder_hidden_state", optional: true, shape: hidden_shape), + Axon.input("encoder_attention_mask", optional: true, shape: shape), + Axon.input("cross_attention_head_mask", optional: true, shape: attention_head_mask_shape), + Axon.input("cache", optional: true) + ]) + + extra_decoder_inputs = + if decoder? do + extra_decoder_inputs + else + Map.new(extra_decoder_inputs, fn {name, _input} -> {name, Layers.none()} end) + end + + Map.merge(inputs, extra_decoder_inputs) + end + + defp core(inputs, spec, opts \\ []) do + decoder? = Keyword.get(opts, :decoder?, false) + + embeddings = + embedder(inputs["input_ids"], inputs["position_ids"], inputs["token_type_ids"], spec, + name: "embedder" + ) + + encoder_outputs = + encoder( + embeddings, + inputs["attention_mask"], + inputs["attention_head_mask"], + inputs["encoder_hidden_state"], + inputs["encoder_attention_mask"], + inputs["cross_attention_head_mask"], + inputs["cache"], + spec, + decoder?: decoder?, + name: "encoder" + ) + + pooled_state = pooler(encoder_outputs.hidden_state, spec, name: "pooler") + + %{ + hidden_state: encoder_outputs.hidden_state, + pooled_state: pooled_state, + hidden_states: encoder_outputs.hidden_states, + attentions: encoder_outputs.attentions, + cross_attentions: encoder_outputs.cross_attentions, + cache: encoder_outputs.cache + } + end + + defp embedder(input_ids, position_ids, token_type_ids, spec, opts) do + name = opts[:name] + + position_ids = + Layers.default position_ids do + Layers.default_position_ids(input_ids) + end + + token_type_ids = + Layers.default token_type_ids do + Layers.default_token_type_ids(input_ids) + end + + inputs_embeddings = + Axon.embedding(input_ids, spec.vocab_size, spec.hidden_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "token_embedding") + ) + + position_embeddings = + Axon.embedding(position_ids, spec.max_positions, spec.hidden_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "position_embedding") + ) + + token_type_embeddings = + Axon.embedding(token_type_ids, spec.max_token_types, spec.hidden_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "token_type_embedding") + ) + + Axon.add([inputs_embeddings, position_embeddings, token_type_embeddings]) + |> Axon.layer_norm(epsilon: spec.layer_norm_epsilon, name: join(name, "norm")) + |> Axon.dropout(rate: spec.dropout_rate, name: join(name, "dropout")) + end + + defp encoder( + hidden_state, + attention_mask, + attention_head_mask, + encoder_hidden_state, + encoder_attention_mask, + cross_attention_head_mask, + cache, + spec, + opts + ) do + name = opts[:name] + decoder? = opts[:decoder?] + + cross_attention? = decoder? 
and spec.use_cross_attention + + Layers.Transformer.blocks( + hidden_state, + [ + attention_mask: attention_mask, + attention_head_mask: attention_head_mask, + cache: cache, + causal: decoder?, + num_blocks: spec.num_blocks, + num_attention_heads: spec.num_attention_heads, + hidden_size: spec.hidden_size, + kernel_initializer: kernel_initializer(spec), + dropout_rate: spec.dropout_rate, + attention_dropout_rate: spec.attention_dropout_rate, + layer_norm: [ + epsilon: spec.layer_norm_epsilon + ], + ffn: [ + intermediate_size: spec.intermediate_size, + activation: spec.activation + ], + name: join(name, "blocks") + ] ++ + if(cross_attention?, + do: [ + cross_hidden_state: encoder_hidden_state, + cross_attention_mask: encoder_attention_mask, + cross_attention_head_mask: cross_attention_head_mask + ], + else: [] + ) + ) + end + + defp pooler(hidden_state, spec, opts) do + name = opts[:name] + + hidden_state + |> Layers.take_token(index: 0, axis: 1) + |> Axon.dense(spec.hidden_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "output") + ) + |> Axon.tanh() + end + + defp language_modeling_head(hidden_state, spec, opts) do + name = opts[:name] + + # TODO: use a shared parameter with embeddings.word_embeddings.kernel + # if spec.tie_word_embeddings is true (relevant for training) + + hidden_state + |> Axon.dense(spec.hidden_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "dense") + ) + |> Layers.activation(spec.activation, name: join(name, "activation")) + |> Axon.layer_norm(epsilon: spec.layer_norm_epsilon, name: join(name, "norm")) + # We reuse the kernel of input embeddings and add bias for each token + |> Layers.dense_transposed(spec.vocab_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "output") + ) + |> Axon.bias(name: join(name, "bias")) + end + + defp classifier_dropout_rate(spec) do + spec.classifier_dropout_rate || spec.dropout_rate + end + + defp kernel_initializer(spec) do + Axon.Initializers.normal(scale: spec.initializer_scale) + end + + defimpl Bumblebee.HuggingFace.Transformers.Config do + def load(spec, data) do + import Shared.Converters + + opts = + convert!(data, + vocab_size: {"vocab_size", number()}, + max_positions: {"max_position_embeddings", number()}, + max_token_types: {"type_vocab_size", number()}, + hidden_size: {"hidden_size", number()}, + num_blocks: {"num_hidden_layers", number()}, + num_attention_heads: {"num_attention_heads", number()}, + intermediate_size: {"intermediate_size", number()}, + activation: {"hidden_act", activation()}, + dropout_rate: {"hidden_dropout_prob", number()}, + attention_dropout_rate: {"attention_probs_dropout_prob", number()}, + classifier_dropout_rate: {"classifier_dropout", optional(number())}, + layer_norm_epsilon: {"layer_norm_eps", number()}, + initializer_scale: {"initializer_range", number()} + ) ++ Shared.common_options_from_transformers(data, spec) + + @for.config(spec, opts) + end + end + + defimpl Bumblebee.HuggingFace.Transformers.Model do + def params_mapping(_spec) do + %{ + "embedder.token_embedding" => "mpnet.embeddings.word_embeddings", + "embedder.position_embedding" => "mpnet.embeddings.position_embeddings", + "embedder.token_type_embedding" => "mpnet.embeddings.token_type_embeddings", + "embedder.norm" => "mpnet.embeddings.LayerNorm", + "encoder.blocks.{n}.self_attention.query" => + "mpnet.encoder.layer.{n}.attention.self.query", + "encoder.blocks.{n}.self_attention.key" => "mpnet.encoder.layer.{n}.attention.self.key", + 
"encoder.blocks.{n}.self_attention.value" => + "mpnet.encoder.layer.{n}.attention.self.value", + "encoder.blocks.{n}.self_attention.output" => + "mpnet.encoder.layer.{n}.attention.output.dense", + "encoder.blocks.{n}.self_attention_norm" => + "mpnet.encoder.layer.{n}.attention.output.LayerNorm", + "encoder.blocks.{n}.cross_attention.query" => + "mpnet.encoder.layer.{n}.crossattention.self.query", + "encoder.blocks.{n}.cross_attention.key" => + "mpnet.encoder.layer.{n}.crossattention.self.key", + "encoder.blocks.{n}.cross_attention.value" => + "mpnet.encoder.layer.{n}.crossattention.self.value", + "encoder.blocks.{n}.cross_attention.output" => + "mpnet.encoder.layer.{n}.crossattention.output.dense", + "encoder.blocks.{n}.cross_attention_norm" => + "mpnet.encoder.layer.{n}.crossattention.output.LayerNorm", + "encoder.blocks.{n}.ffn.intermediate" => "mpnet.encoder.layer.{n}.intermediate.dense", + "encoder.blocks.{n}.ffn.output" => "mpnet.encoder.layer.{n}.output.dense", + "encoder.blocks.{n}.output_norm" => "mpnet.encoder.layer.{n}.output.LayerNorm", + "pooler.output" => "mpnet.pooler.dense", + "language_modeling_head.dense" => "cls.predictions.transform.dense", + "language_modeling_head.norm" => "cls.predictions.transform.LayerNorm", + "language_modeling_head.output" => "cls.predictions.decoder", + "language_modeling_head.bias" => "cls.predictions" + } + end + end +end From 34a43d859c14a92705496a3638ca8b1a53c39f3a Mon Sep 17 00:00:00 2001 From: Scott Newcomer Date: Tue, 5 Nov 2024 06:52:16 -0600 Subject: [PATCH 2/6] add test --- test/bumblebee/text/mpnet_test.exs | 51 ++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 test/bumblebee/text/mpnet_test.exs diff --git a/test/bumblebee/text/mpnet_test.exs b/test/bumblebee/text/mpnet_test.exs new file mode 100644 index 00000000..26f437d5 --- /dev/null +++ b/test/bumblebee/text/mpnet_test.exs @@ -0,0 +1,51 @@ +defmodule Bumblebee.Text.MPNetTest do + use ExUnit.Case, async: true + + import Bumblebee.TestHelpers + + @moduletag model_test_tags() + + test ":base" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-MPNetModel"}) + + assert %Bumblebee.Text.MPNet{architecture: :base} = spec + + inputs = %{ + "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), + "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.hidden_state) == {1, 10, 32} + + assert_all_close( + outputs.hidden_state[[.., 1..3, 1..3]], + Nx.tensor([ + [[-0.2331, 1.7817, 1.1736], [-1.1001, 1.3922, -0.3391], [0.0408, 0.8677, -0.0779]] + ]) + ) + end + + test ":for_masked_language_modeling" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-MPNetForMaskedLM"}) + + assert %Bumblebee.Text.MPNet{architecture: :for_masked_language_modeling} = spec + + inputs = %{ + "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), + "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.logits) == {1, 10, 1124} + + assert_all_close( + outputs.logits[[.., 1..3, 1..3]], + Nx.tensor([[[-0.0127, 0.0508, 0.0904], [0.1151, 0.1189, 0.0922], [0.0089, 0.1132, -0.2470]]]) + ) + end +end From 356f92a209f073fba03e6e09ad73d9db83ecc015 Mon Sep 17 00:00:00 2001 From: Scott Newcomer Date: Wed, 6 Nov 2024 07:47:53 -0600 Subject: 
[PATCH 3/6] address some comments

---
 lib/bumblebee.ex                   |  4 ++--
 lib/bumblebee/text/mpnet.ex        | 10 +++++-----
 test/bumblebee/text/mpnet_test.exs | 16 ++++++++--------
 3 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex
index 66d1eb54..d50a817f 100644
--- a/lib/bumblebee.ex
+++ b/lib/bumblebee.ex
@@ -170,8 +170,8 @@ defmodule Bumblebee do
     "MistralModel" => {Bumblebee.Text.Mistral, :base},
     "MistralForCausalLM" => {Bumblebee.Text.Mistral, :for_causal_language_modeling},
     "MistralForSequenceClassification" => {Bumblebee.Text.Mistral, :for_sequence_classification},
-    "MPNetModel" => {Bumblebee.Text.MPNet, :base},
-    "MPNetForMaskedLM" => {Bumblebee.Text.MPNet, :for_masked_language_modeling},
+    "MPNetModel" => {Bumblebee.Text.MpNet, :base},
+    "MPNetForMaskedLM" => {Bumblebee.Text.MpNet, :for_masked_language_modeling},
     "PhiModel" => {Bumblebee.Text.Phi, :base},
     "PhiForCausalLM" => {Bumblebee.Text.Phi, :for_causal_language_modeling},
     "PhiForSequenceClassification" => {Bumblebee.Text.Phi, :for_sequence_classification},
diff --git a/lib/bumblebee/text/mpnet.ex b/lib/bumblebee/text/mpnet.ex
index d74c06ff..6b334d8e 100644
--- a/lib/bumblebee/text/mpnet.ex
+++ b/lib/bumblebee/text/mpnet.ex
@@ -1,4 +1,4 @@
-defmodule Bumblebee.Text.MPNet do
+defmodule Bumblebee.Text.MpNet do
   alias Bumblebee.Shared
 
   options =
@@ -71,13 +71,13 @@ defmodule Bumblebee.Text.MpNet do
     ] ++ Shared.common_options([:use_cross_attention, :num_labels, :id_to_label])
 
   @moduledoc """
-  MPNet model family.
+  MpNet model family.
 
   ## Architectures
 
-    * `:base` - plain MPNet without any head on top
+    * `:base` - plain MpNet without any head on top
 
-    * `:for_masked_language_modeling` - MPNet with a language modeling
+    * `:for_masked_language_modeling` - MpNet with a language modeling
       head. 
The head returns logits for each token in the original
      sequence
 
@@ -124,7 +124,7 @@ defmodule Bumblebee.Text.MpNet do
 
   ## References
 
-    * [MPNet: MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/pdf/2004.09297)
+    * [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/pdf/2004.09297)
 
   """
 
diff --git a/test/bumblebee/text/mpnet_test.exs b/test/bumblebee/text/mpnet_test.exs
index 26f437d5..4f43f919 100644
--- a/test/bumblebee/text/mpnet_test.exs
+++ b/test/bumblebee/text/mpnet_test.exs
@@ -1,4 +1,4 @@
-defmodule Bumblebee.Text.MPNetTest do
+defmodule Bumblebee.Text.MpNetTest do
   use ExUnit.Case, async: true
 
   import Bumblebee.TestHelpers
@@ -7,9 +7,9 @@
 
   test ":base" do
     assert {:ok, %{model: model, params: params, spec: spec}} =
-             Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-MPNetModel"})
+             Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-MpNetModel"})
 
-    assert %Bumblebee.Text.MPNet{architecture: :base} = spec
+    assert %Bumblebee.Text.MpNet{architecture: :base} = spec
 
     inputs = %{
       "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
@@ -18,21 +18,21 @@
 
     outputs = Axon.predict(model, params, inputs)
 
-    assert Nx.shape(outputs.hidden_state) == {1, 10, 32}
+    assert Nx.shape(outputs.hidden_state) == {1, 10, 64}
 
     assert_all_close(
-      outputs.hidden_state[[.., 1..3, 1..3]],
+      outputs.hidden_state[[.., 1..4, 1..4]],
       Nx.tensor([
-        [[-0.2331, 1.7817, 1.1736], [-1.1001, 1.3922, -0.3391], [0.0408, 0.8677, -0.0779]]
+        [[0.0033, -0.2547, 0.4954], [-1.5348, -1.5433, 0.4846], [0.7795, -0.3995, -0.9499]]
       ])
     )
   end
 
   test ":for_masked_language_modeling" do
     assert {:ok, %{model: model, params: params, spec: spec}} =
-             Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-MPNetForMaskedLM"})
+             Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-MpNetForMaskedLM"})
 
-    assert %Bumblebee.Text.MPNet{architecture: :for_masked_language_modeling} = spec
+    assert %Bumblebee.Text.MpNet{architecture: :for_masked_language_modeling} = spec
 
     inputs = %{
       "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),

From 848be4389045b7455f4dcc16c2811771badf7c68 Mon Sep 17 00:00:00 2001
From: Scott Newcomer
Date: Wed, 6 Nov 2024 07:52:05 -0600
Subject: [PATCH 4/6] update to torch

---
 lib/bumblebee/text/mpnet.ex | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/lib/bumblebee/text/mpnet.ex b/lib/bumblebee/text/mpnet.ex
index 6b334d8e..09f875ef 100644
--- a/lib/bumblebee/text/mpnet.ex
+++ b/lib/bumblebee/text/mpnet.ex
@@ -11,7 +11,7 @@ defmodule Bumblebee.Text.MpNet do
         """
       ],
       max_positions: [
-        default: 514,
+        default: 512,
         doc: """
         the vocabulary size of the position embedding. This corresponds to the maximum sequence
         length that this model can process. Typically this is set to a large value just in case,
         such as 512, 1024 or 2048
         """
       ],
       max_token_types: [
@@ -60,7 +60,7 @@ defmodule Bumblebee.Text.MpNet do
           "the dropout rate for the classification head. 
If not specified, the value of `:dropout_rate` is used instead" ], layer_norm_epsilon: [ - default: 1.0e-05, + default: 1.0e-12, doc: "the epsilon used by the layer normalization layers" ], initializer_scale: [ From f36abb61bfc01fc8724584bc75d25b1e64b64041 Mon Sep 17 00:00:00 2001 From: Scott Newcomer Date: Wed, 6 Nov 2024 08:53:57 -0600 Subject: [PATCH 5/6] no cross attention due to mpnet architecture --- lib/bumblebee/text/mpnet.ex | 36 ++++++------------------------------ 1 file changed, 6 insertions(+), 30 deletions(-) diff --git a/lib/bumblebee/text/mpnet.ex b/lib/bumblebee/text/mpnet.ex index 09f875ef..3f9d1b75 100644 --- a/lib/bumblebee/text/mpnet.ex +++ b/lib/bumblebee/text/mpnet.ex @@ -68,7 +68,7 @@ defmodule Bumblebee.Text.MpNet do doc: "the standard deviation of the normal initializer used for initializing kernel parameters" ] - ] ++ Shared.common_options([:use_cross_attention, :num_labels, :id_to_label]) + ] ++ Shared.common_options([:num_labels, :id_to_label]) @moduledoc """ MpNet model family. @@ -112,7 +112,7 @@ defmodule Bumblebee.Text.MpNet do The `:for_causal_language_modeling` model is a decoder and accepts the following additional inputs: `"encoder_hidden_state"`, - `"encoder_attention_mask"`, `"cross_attention_head_mask"`, `"cache"`. + `"encoder_attention_mask"`, `"cache"`. ## Global layer options @@ -219,7 +219,6 @@ defmodule Bumblebee.Text.MpNet do Bumblebee.Utils.Model.inputs_to_map([ Axon.input("encoder_hidden_state", optional: true, shape: hidden_shape), Axon.input("encoder_attention_mask", optional: true, shape: shape), - Axon.input("cross_attention_head_mask", optional: true, shape: attention_head_mask_shape), Axon.input("cache", optional: true) ]) @@ -248,7 +247,6 @@ defmodule Bumblebee.Text.MpNet do inputs["attention_head_mask"], inputs["encoder_hidden_state"], inputs["encoder_attention_mask"], - inputs["cross_attention_head_mask"], inputs["cache"], spec, decoder?: decoder?, @@ -262,7 +260,6 @@ defmodule Bumblebee.Text.MpNet do pooled_state: pooled_state, hidden_states: encoder_outputs.hidden_states, attentions: encoder_outputs.attentions, - cross_attentions: encoder_outputs.cross_attentions, cache: encoder_outputs.cache } end @@ -309,7 +306,6 @@ defmodule Bumblebee.Text.MpNet do attention_head_mask, encoder_hidden_state, encoder_attention_mask, - cross_attention_head_mask, cache, spec, opts @@ -317,8 +313,6 @@ defmodule Bumblebee.Text.MpNet do name = opts[:name] decoder? = opts[:decoder?] - cross_attention? = decoder? 
and spec.use_cross_attention
-
     Layers.Transformer.blocks(
       hidden_state,
       [
         attention_mask: attention_mask,
@@ -340,15 +334,7 @@ defmodule Bumblebee.Text.MpNet do
           activation: spec.activation
         ],
         name: join(name, "blocks")
-      ] ++
-        if(cross_attention?,
-          do: [
-            cross_hidden_state: encoder_hidden_state,
-            cross_attention_mask: encoder_attention_mask,
-            cross_attention_head_mask: cross_attention_head_mask
-          ],
-          else: []
-        )
+      ]
     )
   end
 
@@ -426,24 +412,14 @@ defmodule Bumblebee.Text.MpNet do
         "embedder.token_type_embedding" => "mpnet.embeddings.token_type_embeddings",
         "embedder.norm" => "mpnet.embeddings.LayerNorm",
         "encoder.blocks.{n}.self_attention.query" =>
-          "mpnet.encoder.layer.{n}.attention.self.query",
-        "encoder.blocks.{n}.self_attention.key" => "mpnet.encoder.layer.{n}.attention.self.key",
+          "mpnet.encoder.layer.{n}.attention.self.q",
+        "encoder.blocks.{n}.self_attention.key" => "mpnet.encoder.layer.{n}.attention.self.k",
         "encoder.blocks.{n}.self_attention.value" =>
-          "mpnet.encoder.layer.{n}.attention.self.value",
+          "mpnet.encoder.layer.{n}.attention.self.v",
         "encoder.blocks.{n}.self_attention.output" =>
           "mpnet.encoder.layer.{n}.attention.output.dense",
         "encoder.blocks.{n}.self_attention_norm" =>
           "mpnet.encoder.layer.{n}.attention.output.LayerNorm",
-        "encoder.blocks.{n}.cross_attention.query" =>
-          "mpnet.encoder.layer.{n}.crossattention.self.query",
-        "encoder.blocks.{n}.cross_attention.key" =>
-          "mpnet.encoder.layer.{n}.crossattention.self.key",
-        "encoder.blocks.{n}.cross_attention.value" =>
-          "mpnet.encoder.layer.{n}.crossattention.self.value",
-        "encoder.blocks.{n}.cross_attention.output" =>
-          "mpnet.encoder.layer.{n}.crossattention.output.dense",
-        "encoder.blocks.{n}.cross_attention_norm" =>
-          "mpnet.encoder.layer.{n}.crossattention.output.LayerNorm",
         "encoder.blocks.{n}.ffn.intermediate" => "mpnet.encoder.layer.{n}.intermediate.dense",
         "encoder.blocks.{n}.ffn.output" => "mpnet.encoder.layer.{n}.output.dense",
         "encoder.blocks.{n}.output_norm" => "mpnet.encoder.layer.{n}.output.LayerNorm",

From e2312e9a2c8bf09415b89f8a5c046f8b8bff33b1 Mon Sep 17 00:00:00 2001
From: Scott Newcomer
Date: Wed, 6 Nov 2024 09:06:01 -0600
Subject: [PATCH 6/6] moar alignment

---
 lib/bumblebee/text/mpnet.ex | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/lib/bumblebee/text/mpnet.ex b/lib/bumblebee/text/mpnet.ex
index 3f9d1b75..153021f5 100644
--- a/lib/bumblebee/text/mpnet.ex
+++ b/lib/bumblebee/text/mpnet.ex
@@ -143,6 +143,10 @@ defmodule Bumblebee.Text.MpNet do
     do: [
       :base,
       :for_masked_language_modeling
+      # :for_sequence_classification, # https://github.com/huggingface/transformers/blob/main/src/transformers/models/mpnet/modeling_mpnet.py
+      # :for_token_classification,
+      # :for_question_answering,
+      # :for_multiple_choice,
     ]
 
   @impl true
@@ -424,10 +428,10 @@ defmodule Bumblebee.Text.MpNet do
         "encoder.blocks.{n}.ffn.output" => "mpnet.encoder.layer.{n}.output.dense",
         "encoder.blocks.{n}.output_norm" => "mpnet.encoder.layer.{n}.output.LayerNorm",
         "pooler.output" => "mpnet.pooler.dense",
-        "language_modeling_head.dense" => "cls.predictions.transform.dense",
-        "language_modeling_head.norm" => "cls.predictions.transform.LayerNorm",
-        "language_modeling_head.output" => "cls.predictions.decoder",
-        "language_modeling_head.bias" => "cls.predictions"
+        "language_modeling_head.dense" => "lm_head.dense",
+        "language_modeling_head.norm" => "lm_head.layer_norm",
+        "language_modeling_head.output" => "lm_head.decoder",
+        "language_modeling_head.bias" => "lm_head.bias"
       }
     end
   end
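
Usage sketch (not part of the patch series above): a minimal example of how the MPNet support added here might be exercised once the series is applied. The checkpoint name "microsoft/mpnet-base", the `<mask>` token format, the assumption that the checkpoint ships a tokenizer Bumblebee can load, and the dependency versions are illustrative assumptions, not anything the patches themselves pin down.

```elixir
# Hypothetical usage sketch for the MPNet masked language modeling support
# introduced in this series. Assumes "microsoft/mpnet-base" resolves to the
# :for_masked_language_modeling architecture and provides a loadable tokenizer.
Mix.install([
  {:bumblebee, "~> 0.6"},
  {:exla, ">= 0.0.0"}
])

Nx.global_default_backend(EXLA.Backend)

{:ok, model_info} = Bumblebee.load_model({:hf, "microsoft/mpnet-base"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "microsoft/mpnet-base"})

# Tokenize a sentence containing the mask token and run the model directly,
# mirroring the Axon.predict calls used in the tests above.
inputs = Bumblebee.apply_tokenizer(tokenizer, "The capital of France is <mask>.")
outputs = Axon.predict(model_info.model, model_info.params, inputs)

# Logits over the vocabulary for every position in the input sequence.
IO.inspect(Nx.shape(outputs.logits))
```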