Skip to content

Commit

Permalink
fix: update ret_output shape (#1147)
Browse files Browse the repository at this point in the history
* fix: update ret_output shape

* fix: update w_g

* fix: update w_o shape
  • Loading branch information
ArnolFokam authored Dec 3, 2024
1 parent eec8945 commit 371fd9d
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions mava/networks/retention.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,12 +237,12 @@ def setup(self) -> None:
self.w_g = self.param(
"w_g",
nn.initializers.normal(stddev=1 / self.embed_dim),
(self.embed_dim, self.head_size),
(self.embed_dim, self.embed_dim),
)
self.w_o = self.param(
"w_o",
nn.initializers.normal(stddev=1 / self.embed_dim),
(self.head_size, self.embed_dim),
(self.embed_dim, self.embed_dim),
)
self.group_norm = nn.GroupNorm(num_groups=self.n_head)

Expand Down Expand Up @@ -278,7 +278,7 @@ def __call__(
if self.memory_config.timestep_positional_encoding:
key, query, value = self.pe(key, query, value, step_count)

ret_output = jnp.zeros((B, C, self.head_size), dtype=value.dtype)
ret_output = jnp.zeros((B, C, self.embed_dim), dtype=value.dtype)
for head in range(self.n_head):
y, new_hs = self.retention_heads[head](key, query, value, hstate[:, head], dones)
ret_output = ret_output.at[
Expand All @@ -304,7 +304,7 @@ def recurrent(
if self.memory_config.timestep_positional_encoding:
key_n, query_n, value_n = self.pe(key_n, query_n, value_n, step_count)

ret_output = jnp.zeros((B, S, self.head_size), dtype=value_n.dtype)
ret_output = jnp.zeros((B, S, self.embed_dim), dtype=value_n.dtype)
for head in range(self.n_head):
y, new_hs = self.retention_heads[head].recurrent(
key_n, query_n, value_n, hstate[:, head]
Expand Down

0 comments on commit 371fd9d

Please sign in to comment.