diff --git a/keys.go b/keys.go index fa650fe7..dde406ae 100644 --- a/keys.go +++ b/keys.go @@ -37,6 +37,7 @@ import ( type PrivateKey []byte +// MustPrivateKeyFromBase58 returns a PrivateKey from a base58-encoded string, panicking if the input is invalid. func MustPrivateKeyFromBase58(in string) PrivateKey { out, err := PrivateKeyFromBase58(in) if err != nil { @@ -45,11 +46,30 @@ func MustPrivateKeyFromBase58(in string) PrivateKey { return out } +// PrivateKeyFromBase58 returns a PrivateKey from a base58-encoded string. +// +// PrivateKeyFromBase58 returns a PrivateKey from a base58-encoded string. The function +// first decodes the input string using base58, then checks if the resulting private key +// is valid by deriving the corresponding public key and checking if it is on the Ed25519 +// curve. If the private key is invalid, an error is returned. +// +// Parameters: +// +// privkey - the base58-encoded private key string +// +// Returns: +// +// PrivateKey - the decoded private key +// error - an error if the input string is invalid or the derived public key is not on the curve func PrivateKeyFromBase58(privkey string) (PrivateKey, error) { res, err := base58.Decode(privkey) if err != nil { return nil, err } + pub := PrivateKey(res).PublicKey().Bytes() + if !IsOnCurve(pub) { + return nil, errors.New("invalid private key") + } return res, nil } @@ -72,6 +92,15 @@ func (k PrivateKey) String() string { return base58.Encode(k) } +// NewRandomPrivateKey generates a new random Ed25519 private key. +// +// NewRandomPrivateKey returns a new random Ed25519 private key. The private key is +// generated using a cryptographically secure random number generator. +// +// Returns: +// +// PrivateKey: a new random Ed25519 private key +// error: an error if the key generation fails func NewRandomPrivateKey() (PrivateKey, error) { pub, priv, err := ed25519.GenerateKey(crypto_rand.Reader) if err != nil { diff --git a/keys_test.go b/keys_test.go index 923c3206..d29e9dac 100644 --- a/keys_test.go +++ b/keys_test.go @@ -18,10 +18,13 @@ package solana import ( + "crypto/ed25519" + "crypto/rand" "encoding/binary" "encoding/hex" "errors" "flag" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -101,6 +104,89 @@ func TestPublicKeyFromBase58(t *testing.T) { } } +func TestPrivateKeyFromBase58(t *testing.T) { + tests := []struct { + name string + in string + want string + wantErr error + }{ + { + name: "normal case", + in: "6HsFaXKVD7mo43oTbdqyGgAnYFeNNhqY75B3JGJ6K8a227KjjG3uW3v", + want: "6HsFaXKVD7mo43oTbdqyGgAnYFeNNhqY75B3JGJ6K8a227KjjG3uW3v", + }, + { + name: "edge case - empty string", + in: "", + want: "", + wantErr: errors.New("zero length string"), + }, + { + name: "edge case - invalid base58", + in: "invalid-base58", + want: "", + wantErr: errors.New("invalid base58 digit ('l')"), + }, + { + name: "extreme case - very long input", + in: strings.Repeat("3yZe7d", 100), + want: strings.Repeat("3yZe7d", 100), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + got, err := PrivateKeyFromBase58(test.in) + require.Equal(t, test.wantErr, err) + require.Equal(t, test.want, got.String()) + }) + } +} + +func TestMustPrivateKeyFromBase58(t *testing.T) { + tests := []struct { + name string + in string + want string + wantPanic bool + }{ + { + name: "normal case", + in: "3yZe7d", + want: "3yZe7d", + }, + { + name: "edge case - empty string", + in: "", + wantPanic: true, + }, + { + name: "edge case - invalid base58", + in: "invalid-base58", + wantPanic: true, + }, + { + name: "extreme case - very long input", + in: strings.Repeat("3yZe7d", 100), + want: strings.Repeat("3yZe7d", 100), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.wantPanic { + require.Panics(t, func() { + MustPrivateKeyFromBase58(test.in) + }) + } else { + got := MustPrivateKeyFromBase58(test.in) + require.Equal(t, test.want, got.String()) + } + }) + } +} + func TestPrivateKeyFromSolanaKeygenFile(t *testing.T) { tests := []struct { inFile string @@ -395,7 +481,38 @@ func TestGetAddedRemoved(t *testing.T) { ) } } +func TestIsOnCurve(t *testing.T) { + // Test a valid private key + privateKey, err := NewRandomPrivateKey() + if err != nil { + t.Errorf("Failed to generate private key: %v", err) + } + // Test a valid public key + publicKey := privateKey.PublicKey() + if !IsOnCurve(publicKey.Bytes()) { + t.Errorf("Valid public key is not on the curve") + } + + // Test an invalid key (too short) + shortKey := []byte{1, 2, 3} + if IsOnCurve(shortKey) { + t.Errorf("Invalid key (too short) is on the curve") + } + + // Test an invalid key (too long) + longKey := make([]byte, ed25519.PrivateKeySize+1) + if IsOnCurve(longKey) { + t.Errorf("Invalid key (too long) is on the curve") + } + + // Test an invalid key (random bytes) + randKey := make([]byte, ed25519.PrivateKeySize) + _, _ = rand.Read(randKey) + if IsOnCurve(randKey) { + t.Errorf("Invalid key (random bytes) is on the curve") + } +} func TestIsNativeProgramID(t *testing.T) { require.True(t, isNativeProgramID(ConfigProgramID)) }