Introduce ShawRelativePositionSDPA. #90
Conversation
Force-pushed from 64868d7 to f64d89d.
Like the other PRs, this is in pretty good shape. Most of my comments are nitpicks.
        if self.rel_v_embedding is not None:
            nn.init.xavier_uniform_(self.rel_v_embedding.weight)

    def rel_position_indices(self, seq_len: int) -> Tensor:
I suggest passing a device argument here and constructing the tensor on that device instead of moving it later (i.e. line 128).
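For illustration, a rough standalone sketch of what I mean; the function name follows the PR, but max_rel_pos and the exact signature are assumptions rather than the actual code:

```python
from typing import Optional

import torch
from torch import Tensor


def rel_position_indices(
    seq_len: int, max_rel_pos: int, device: Optional[torch.device] = None
) -> Tensor:
    """Return a (seq_len, seq_len) matrix of clipped, shifted relative distances."""
    # Build directly on the requested device so no later .to(device) move is needed.
    positions = torch.arange(seq_len, device=device)
    rel_pos = positions.unsqueeze(0) - positions.unsqueeze(1)
    # Clip to the maximum relative distance, as in Shaw et al. (2018).
    rel_pos = rel_pos.clamp(-max_rel_pos, max_rel_pos)
    # Shift to [0, 2 * max_rel_pos] so the values can index an embedding table.
    return rel_pos + max_rel_pos
```

Constructing the indices on the right device up front avoids an extra host-to-device copy on every forward pass.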
Not necessarily required for this PR, but I wonder whether we should cache the generated indices (see relative_position.py as an example) in the future to avoid repeatedly initializing a mostly static buffer.
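Not the fairseq2 API, just a sketch of one possible caching pattern using a non-persistent buffer (all names here are illustrative):

```python
import torch
from torch import Tensor, nn


class CachedRelPositionIndices(nn.Module):
    """Caches the relative position index matrix and rebuilds it only when needed."""

    def __init__(self, max_rel_pos: int, init_seq_len: int = 512) -> None:
        super().__init__()
        self.max_rel_pos = max_rel_pos
        # Non-persistent buffer: follows the module's device moves,
        # but is not saved in the state dict.
        self.register_buffer(
            "indices", self._build(init_seq_len, torch.device("cpu")), persistent=False
        )

    def _build(self, seq_len: int, device: torch.device) -> Tensor:
        positions = torch.arange(seq_len, device=device)
        rel_pos = (positions.unsqueeze(0) - positions.unsqueeze(1)).clamp(
            -self.max_rel_pos, self.max_rel_pos
        )
        return rel_pos + self.max_rel_pos

    def forward(self, seq_len: int) -> Tensor:
        # Rebuild only if the cached matrix is too small; otherwise slice it.
        if self.indices.size(0) < seq_len:
            self.indices = self._build(seq_len, self.indices.device)
        return self.indices[:seq_len, :seq_len]
```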
Ah yeah, I was wondering how to do this as well!
You mean relative_attention.py?
@@ -145,6 +158,9 @@ class Wav2Vec2EncoderConfig:
    conv_norm_type: Literal["batch_norm", "layer_norm"]
    """The type of norm layer in the Conformer convolution module."""

    shaw_rel_position_sdpa_config: Optional[ShawRelativePositionSDPAConfig]
Nice. It would be a BC-breaking change, but I think it would be nicer if we wrapped the parameters for the "conv" positional encoder in a dataclass in the same way. Not relevant for this PR though; just food for thought :)
Could you elaborate on "wrap parameters for the 'conv' positional encoder in the same way in a dataclass"? I don't understand it at all.
It is not related to this PR at all. In Wav2Vec2EncoderConfig we have several attributes related to the default convolutional position encoder. We could consolidate them under a single configuration dataclass, like you did for the Shaw encoder.
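Purely as an illustration of the idea, with made-up field names rather than the actual Wav2Vec2EncoderConfig attributes:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ConvPositionEncoderConfig:
    """Hypothetical grouping of the conv positional encoder attributes."""

    kernel_size: int = 128
    """The kernel size of the convolutional position encoder."""

    num_groups: int = 16
    """The number of groups of the convolution."""


@dataclass
class Wav2Vec2EncoderConfig:
    # ... the other existing fields would remain here ...

    conv_pos_encoder_config: Optional[ConvPositionEncoderConfig] = None
    """The configuration of the default convolutional position encoder, if used."""
```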
Yes, we should do that; I was thinking of that as well.
Force-pushed from c4d5e71 to 5241bd1.
What does this PR do? Please describe:
Implement the relative position SDPA as described in https://doi.org/10.48550/arxiv.1803.02155
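For context, here is a minimal self-contained sketch of the attention computation from the paper; it illustrates the technique only and is not the actual ShawRelativePositionSDPA implementation in this PR:

```python
import math

import torch
from torch import Tensor


def shaw_relative_sdpa(
    q: Tensor,            # (batch, heads, seq, head_dim)
    k: Tensor,            # (batch, heads, seq, head_dim)
    v: Tensor,            # (batch, heads, seq, head_dim)
    rel_k_embed: Tensor,  # (seq, seq, head_dim) relative key embeddings a^K
    rel_v_embed: Tensor,  # (seq, seq, head_dim) relative value embeddings a^V
) -> Tensor:
    """Scaled dot-product attention with Shaw-style relative position terms."""
    head_dim = q.size(-1)

    # Content-based scores: (batch, heads, seq_q, seq_k).
    scores = torch.matmul(q, k.transpose(-1, -2))

    # Relative-key scores: for each query position i, add q_i . a^K_ij.
    scores = scores + torch.einsum("bhid,ijd->bhij", q, rel_k_embed)

    attn = torch.softmax(scores / math.sqrt(head_dim), dim=-1)

    # Content values plus the relative-value term a^V_ij.
    out = torch.matmul(attn, v)
    out = out + torch.einsum("bhij,ijd->bhid", attn, rel_v_embed)
    return out
```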
Does your PR introduce any breaking changes? If yes, please list them:
List of all backwards-incompatible changes.
Check list: