Skip to content

Commit

Permalink
Merge branch 'fla-org:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
yibozhong authored Jan 18, 2025
2 parents 6f2276a + 7b05692 commit ed12c30
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions fla/layers/rwkv6.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,8 @@ def __init__(
input_dim: int,
output_dim: int,
low_rank_dim: int,
bias: Optional[bool] = True
bias: Optional[bool] = True,
activation: Optional[str] = 'tanh'
):
super().__init__()

Expand All @@ -200,9 +201,20 @@ def __init__(
self.low_rank_dim = low_rank_dim
self.bias = bias

if activation is None:
self.activation = nn.Identity()
elif activation == 'sigmoid':
self.activation = nn.Sigmoid()
elif activation == 'tanh':
self.activation = nn.Tanh()
elif activation == 'relu':
self.activation = nn.ReLU()
else:
raise ValueError(f"Not supported activation `{activation}`.")

self.lora = nn.Sequential(
nn.Linear(input_dim, low_rank_dim, bias=False),
nn.Tanh(),
self.activation,
nn.Linear(low_rank_dim, output_dim, bias=bias)
)

Expand Down

0 comments on commit ed12c30

Please sign in to comment.