
Introduce ShawRelativePositionSDPA. #90

Merged — 3 commits merged into main from relpos_attn on Oct 9, 2023
Conversation

@kauterry (Contributor) commented Oct 6, 2023

What does this PR do? Please describe:

Implements the relative position SDPA described in Shaw et al., "Self-Attention with Relative Position Representations" (https://doi.org/10.48550/arxiv.1803.02155).
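For context, a minimal sketch of the core idea from the referenced paper: relative distances between query and key positions are clipped to a maximum range and used to look up learned embeddings that are added to the attention logits. The names and shapes below are illustrative only, not the exact fairseq2 API:

```python
import torch

def shaw_rel_position_indices(seq_len: int, max_pos: int) -> torch.Tensor:
    # Relative distance (j - i) between every key position j and query position i,
    # clipped to [-max_pos, max_pos] and shifted to [0, 2 * max_pos] so it can
    # index an embedding table of size 2 * max_pos + 1.
    positions = torch.arange(seq_len)
    rel = positions[None, :] - positions[:, None]  # (seq_len, seq_len)
    return rel.clamp(-max_pos, max_pos) + max_pos

# Illustrative use inside attention (single head, no batching):
#   q, k: (seq_len, head_dim); rel_k: nn.Embedding(2 * max_pos + 1, head_dim)
#   scores = q @ k.T \
#       + torch.einsum("qd,qkd->qk", q, rel_k(shaw_rel_position_indices(seq_len, max_pos)))
```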

Does your PR introduce any breaking changes? If yes, please list them:
List of all backwards-incompatible changes.

Check list:

  • Was the content of this PR discussed and approved via a GitHub issue? (no need for typos or documentation improvements)
  • Did you read the contributor guideline?
  • Did you make sure that your PR does only one thing instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests?
  • Did you verify new and existing tests pass locally with your changes?
  • Did you update the CHANGELOG? (no need for typos, documentation, or minor internal changes)

@kauterry kauterry requested a review from cbalioglu as a code owner October 6, 2023 02:29
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 6, 2023
@kauterry kauterry force-pushed the relpos_attn branch 2 times, most recently from 64868d7 to f64d89d on October 7, 2023 22:02
@kauterry kauterry changed the base branch from main to causal_conv1d October 7, 2023 22:03
@cbalioglu (Contributor) left a comment:

Like your other PRs, this is in pretty good shape; most of my comments are nitpicks.

src/fairseq2/nn/transformer/relative_position_attention.py (2 outdated review threads, resolved)
        if self.rel_v_embedding is not None:
            nn.init.xavier_uniform_(self.rel_v_embedding.weight)

    def rel_position_indices(self, seq_len: int) -> Tensor:
@cbalioglu (Contributor):

I suggest passing a device argument here and constructing the tensor on that device instead of moving it later (i.e. line 128).
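A rough sketch of what this suggestion could look like, assuming a `device` parameter is threaded through; the signature and the `self.max_pos` attribute are illustrative assumptions, not necessarily the final fairseq2 code:

```python
import torch
from torch import Tensor

def rel_position_indices(self, seq_len: int, device: torch.device) -> Tensor:
    # Construct the index matrix directly on the target device instead of
    # building it on the CPU and calling `.to(device)` afterwards.
    positions = torch.arange(seq_len, device=device)
    rel = positions[None, :] - positions[:, None]
    # `self.max_pos` is a hypothetical attribute holding the clipping distance.
    return rel.clamp(-self.max_pos, self.max_pos) + self.max_pos
```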

@cbalioglu (Contributor):

Not necessarily required for this PR, but I wonder whether we should cache the generated indices in the future (see relative_position.py as an example) to avoid repeatedly initializing a mostly static buffer (one possible caching approach is sketched after this thread).

@kauterry (Contributor, Author):

ah yeah I was wondering how to do this as well!

@kauterry (Contributor, Author):

You mean relative_attention.py?
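As a follow-up to the caching idea above, a rough sketch of one way a lazily grown index cache could look; the class and attribute names are hypothetical, loosely modeled on the buffer-caching pattern the reviewer refers to:

```python
import torch
from torch import Tensor, nn

class CachedRelPositionIndices(nn.Module):
    """Hypothetical helper that caches the relative-position index matrix."""

    def __init__(self, max_pos: int) -> None:
        super().__init__()
        self.max_pos = max_pos
        # Non-persistent buffer: moves with the module, but is not saved in the state dict.
        self.register_buffer("indices", torch.empty((0, 0), dtype=torch.int64), persistent=False)

    def forward(self, seq_len: int) -> Tensor:
        # Regrow the cached matrix only when a longer sequence is seen.
        if self.indices.size(0) < seq_len:
            positions = torch.arange(seq_len, device=self.indices.device)
            rel = positions[None, :] - positions[:, None]
            self.indices = rel.clamp(-self.max_pos, self.max_pos) + self.max_pos
        return self.indices[:seq_len, :seq_len]
```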

src/fairseq2/nn/transformer/relative_position_attention.py (5 outdated review threads, resolved)
src/fairseq2/models/wav2vec2/builder.py (review thread, resolved)
@@ -145,6 +158,9 @@ class Wav2Vec2EncoderConfig:
    conv_norm_type: Literal["batch_norm", "layer_norm"]
    """The type of norm layer in the Conformer convolution module."""

    shaw_rel_position_sdpa_config: Optional[ShawRelativePositionSDPAConfig]
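For illustration, a nested configuration dataclass of this shape might look roughly as follows; the field names below are an assumption, not necessarily the ones introduced by this PR:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ShawRelativePositionSDPAConfig:
    """Sketch of a grouped config for the Shaw relative position SDPA (illustrative fields)."""

    max_left_rel_pos: int
    """The left clipping distance for relative positions."""

    max_right_rel_pos: Optional[int]
    """The right clipping distance; assumed to fall back to the left value if ``None``."""

    use_rel_pos_values: bool
    """Whether to also apply relative position embeddings to the values."""
```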
@cbalioglu (Contributor):

Nice! It would be a BC-breaking change, but I think it would be nicer if we could wrap the parameters for the "conv" positional encoder in a dataclass in the same way. Not relevant for this PR though, just food for thought :)

@kauterry (Contributor, Author):

Could you elaborate on "wrap parameters for the 'conv' positional encoder in the same way in a dataclass"? I don't understand at all.

@cbalioglu (Contributor):

It is not related to this PR at all. In Wav2Vec2EncoderConfig we have several attributes related to the default convolutional position encoder. We could consolidate them under a single configuration dataclass, like you did for the Shaw encoder (a rough sketch follows this thread).

@kauterry (Contributor, Author):

Yes, we should do that; I was thinking of that as well.
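To make the consolidation concrete, here is a hedged sketch of how the convolutional position encoder attributes could be grouped; the class and field names are hypothetical, chosen for illustration rather than taken from the actual Wav2Vec2EncoderConfig:

```python
from dataclasses import dataclass

@dataclass
class ConvPositionEncoderConfig:
    """Hypothetical grouping of the default conv positional encoder options."""

    kernel_size: int
    """The kernel size of the positional convolution."""

    num_groups: int
    """The number of convolution groups."""

    depth: int
    """The number of stacked positional convolutions."""

@dataclass
class Wav2Vec2EncoderConfig:
    # ... other encoder fields elided ...

    conv_pos_encoder_config: ConvPositionEncoderConfig
    """Grouped options for the default convolutional position encoder."""
```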

Base automatically changed from causal_conv1d to main October 9, 2023 19:26
src/fairseq2/nn/transformer/shaw_attention.py (3 outdated review threads, resolved)
@cbalioglu cbalioglu merged commit 199cf93 into main Oct 9, 2023
@cbalioglu cbalioglu deleted the relpos_attn branch October 9, 2023 22:47