Skip to content

Commit

Permalink
add encryption
Browse files Browse the repository at this point in the history
  • Loading branch information
spikelu2016 committed Nov 16, 2024
1 parent 78d1f24 commit d3e7cc8
Show file tree
Hide file tree
Showing 9 changed files with 237 additions and 30 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
release_notes.md
target
.DS_STORE
.vscode/launch.json
.vscode/launch.json
.env
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
## 1.39.0 - 2024-11-15
### Added
- Added encryption integration

### Changed
- Removed support for Redis TLS config


## 1.38.0 - 2024-11-09
### Added
- Added support for `claude-3-5-haiku`
Expand Down
2 changes: 0 additions & 2 deletions cmd/bricksllm/.env
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ POSTGRESQL_PASSWORD=
POSTGRESQL_SSL_MODE=disable
POSTGRESQL_PORT=5432
REDIS_HOSTS=localhost
REDIS_ENABLE_TLS=false
REDIS_INSECURE_SKIP_VERIFY=false
REDIS_PORT=6379
REDIS_USERNAME=
REDIS_PASSWORD=
Expand Down
2 changes: 0 additions & 2 deletions cmd/bricksllm/config_local.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
"postgresql_port": "5432",
"redis_hosts": "localhost",
"redis_port": "6379",
"redis_enable_tls": false,
"redis_insecure_skip_verify": false,
"redis_username": "",
"redis_password": "",
"redis_read_time_out": "1s",
Expand Down
14 changes: 5 additions & 9 deletions cmd/bricksllm/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package main

import (
"context"
"crypto/tls"
"flag"
"fmt"
"os"
Expand All @@ -13,6 +12,7 @@ import (
auth "github.com/bricks-cloud/bricksllm/internal/authenticator"
"github.com/bricks-cloud/bricksllm/internal/cache"
"github.com/bricks-cloud/bricksllm/internal/config"
"github.com/bricks-cloud/bricksllm/internal/encryptor"
"github.com/bricks-cloud/bricksllm/internal/logger/zap"
"github.com/bricks-cloud/bricksllm/internal/manager"
"github.com/bricks-cloud/bricksllm/internal/message"
Expand Down Expand Up @@ -182,12 +182,6 @@ func main() {
DB: cfg.RedisDBStartIndex + dbIndex,
}

if cfg.RedisEnableTLS {
options.TLSConfig = &tls.Config{
InsecureSkipVerify: cfg.RedisInsecureSkipVerify,
}
}

return options
}

Expand Down Expand Up @@ -292,9 +286,11 @@ func main() {
psCache := redisStorage.NewProviderSettingsCache(providerSettingsRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout)
keysCache := redisStorage.NewKeysCache(keysRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout)

encryptor := encryptor.NewEncryptor(cfg.DecryptionEndpoint, cfg.EncryptionEndpoint, cfg.EnableEncrytion, cfg.EncryptionTimeout)

m := manager.NewManager(store, costLimitCache, rateLimitCache, accessCache, keysCache)
krm := manager.NewReportingManager(costStorage, store, store)
psm := manager.NewProviderSettingsManager(store, psCache)
psm := manager.NewProviderSettingsManager(store, psCache, encryptor)
cpm := manager.NewCustomProvidersManager(store, cpMemStore)
rm := manager.NewRouteManager(store, store, rMemStore, psm)
pm := manager.NewPolicyManager(store, rMemStore)
Expand Down Expand Up @@ -332,7 +328,7 @@ func main() {

rec := recorder.NewRecorder(costStorage, userCostStorage, costLimitCache, userCostLimitCache, ce, store)
rlm := manager.NewRateLimitManager(rateLimitCache, userRateLimitCache)
a := auth.NewAuthenticator(psm, m, rm, store)
a := auth.NewAuthenticator(psm, m, rm, store, encryptor)

c := cache.NewCache(apiCache)

Expand Down
50 changes: 41 additions & 9 deletions internal/authenticator/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"math/rand"
"net/http"
"strconv"
"strings"

internal_errors "github.com/bricks-cloud/bricksllm/internal/errors"
Expand Down Expand Up @@ -34,19 +35,26 @@ type keyStorage interface {
GetKeyByHash(hash string) (*key.ResponseKey, error)
}

type Decryptor interface {
Decrypt(input string, headers map[string]string) (string, error)
Enabled() bool
}

type Authenticator struct {
psm providerSettingsManager
kc keysCache
rm routesManager
ks keyStorage
psm providerSettingsManager
kc keysCache
rm routesManager
ks keyStorage
decryptor Decryptor
}

func NewAuthenticator(psm providerSettingsManager, kc keysCache, rm routesManager, ks keyStorage) *Authenticator {
func NewAuthenticator(psm providerSettingsManager, kc keysCache, rm routesManager, ks keyStorage, decryptor Decryptor) *Authenticator {
return &Authenticator{
psm: psm,
kc: kc,
rm: rm,
ks: ks,
psm: psm,
kc: kc,
rm: rm,
ks: ks,
decryptor: decryptor,
}
}

Expand Down Expand Up @@ -268,6 +276,30 @@ func (a *Authenticator) AuthenticateHttpRequest(req *http.Request) (*key.Respons
used = selected[rand.Intn(len(selected))]
}

if a.decryptor.Enabled() {
encryptedParam := ""
if used.Provider == "amazon" {
encryptedParam = used.Setting["awsSecretAccessKey"]
} else if len(used.Setting["apikey"]) != 0 {
encryptedParam = used.Setting["apikey"]
}

if len(encryptedParam) != 0 {
decryptedSecret, err := a.decryptor.Decrypt(encryptedParam, map[string]string{"X-UPDATED-AT": strconv.FormatInt(used.UpdatedAt, 10)})
if err == nil {
if used.Provider == "amazon" {
used.Setting["awsSecretAccessKey"] = decryptedSecret
} else {
used.Setting["apikey"] = decryptedSecret
}
}

if err != nil {
fmt.Println(fmt.Printf("error when encrypting %v", err))
}
}
}

err := rewriteHttpAuthHeader(req, used)
if err != nil {
return nil, nil, err
Expand Down
11 changes: 9 additions & 2 deletions internal/config/config.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package config

import (
"errors"
"os"
"path/filepath"
"time"
Expand All @@ -25,8 +26,6 @@ type Config struct {
RedisPort string `koanf:"redis_port" env:"REDIS_PORT" envDefault:"6379"`
RedisUsername string `koanf:"redis_username" env:"REDIS_USERNAME"`
RedisPassword string `koanf:"redis_password" env:"REDIS_PASSWORD"`
RedisEnableTLS bool `koanf:"redis_enable_tls" env:"REDIS_ENABLE_TLS" envDefault:"false"`
RedisInsecureSkipVerify bool `koanf:"redis_insecure_skip_verify" env:"REDIS_INSECURE_SKIP_VERIFY" envDefault:"false"`
RedisDBStartIndex int `koanf:"redis_db_start_index" env:"REDIS_DB_START_INDEX" envDefault:"0"`
RedisReadTimeout time.Duration `koanf:"redis_read_time_out" env:"REDIS_READ_TIME_OUT" envDefault:"1s"`
RedisWriteTimeout time.Duration `koanf:"redis_write_time_out" env:"REDIS_WRITE_TIME_OUT" envDefault:"500ms"`
Expand All @@ -47,6 +46,10 @@ type Config struct {
AmazonRequestTimeout time.Duration `koanf:"amazon_request_timeout" env:"AMAZON_REQUEST_TIMEOUT" envDefault:"5s"`
AmazonConnectionTimeout time.Duration `koanf:"amazon_connection_timeout" env:"AMAZON_CONNECTION_TIMEOUT" envDefault:"10s"`
RemoveUserAgent bool `koanf:"remove_user_agent" env:"REMOVE_USER_AGENT" envDefault:"false"`
EnableEncrytion bool `koanf:"enable_encryption" env:"ENABLE_ENCRYPTION" envDefault:"false"`
EncryptionEndpoint string `koanf:"encryption_endpoint" env:"ENCRYPTION_ENDPOINT"`
DecryptionEndpoint string `koanf:"decryption_endpoint" env:"DECRYPTION_ENDPOINT"`
EncryptionTimeout time.Duration `koanf:"encryption_timeout" env:"ENCRYPTION_TIMEOUT" envDefault:"5s"`
}

func prepareDotEnv(envFilePath string) error {
Expand Down Expand Up @@ -82,6 +85,10 @@ func LoadConfig(log *zap.Logger) (*Config, error) {
return nil, err
}

if cfg.EnableEncrytion && len(cfg.EncryptionEndpoint) == 0 {
return nil, errors.New("encryption endpoint cannot be empty")
}

err = prepareDotEnv(".env")
if err != nil {
log.Sugar().Infof("error loading config from .env file: %v", err)
Expand Down
120 changes: 120 additions & 0 deletions internal/encryptor/encryptor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package encryptor

import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"time"
)

type Encryptor struct {
decryptionURL string
encryptionURL string
enabled bool
client http.Client
timeout time.Duration
}

type Secret struct {
Secret string `json:"secret"`
}

type EncryptionResponse struct {
EncryptedSecret string `json:"encryptedSecret"`
}

type DecryptionResponse struct {
DecryptedSecret string `json:"decryptedSecret"`
}

func NewEncryptor(decryptionURL string, encryptionURL string, enabled bool, timeout time.Duration) Encryptor {
return Encryptor{
decryptionURL: decryptionURL,
encryptionURL: encryptionURL,
client: http.Client{},
enabled: enabled,
timeout: timeout,
}
}

func (e Encryptor) Encrypt(input string, headers map[string]string) (string, error) {
data, err := json.Marshal(Secret{
Secret: input,
})
if err != nil {
return "", err
}

ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, e.encryptionURL, bytes.NewBuffer(data))
if err != nil {
return "", err
}

for header, value := range headers {
req.Header.Add(header, value)
}

res, err := e.client.Do(req)
if err != nil {
return "", err
}

bytes, err := io.ReadAll(res.Body)
if err != nil {
return "", err
}

encryptionResponse := EncryptionResponse{}
err = json.Unmarshal(bytes, &encryptionResponse)
if err != nil {
return "", err
}

return encryptionResponse.EncryptedSecret, nil
}

func (e Encryptor) Enabled() bool {
return e.enabled && len(e.decryptionURL) != 0 && len(e.encryptionURL) != 0
}

func (e Encryptor) Decrypt(input string, headers map[string]string) (string, error) {
data, err := json.Marshal(Secret{
Secret: input,
})
if err != nil {
return "", err
}

ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, e.decryptionURL, bytes.NewBuffer(data))
if err != nil {
return "", err
}

for header, value := range headers {
req.Header.Add(header, value)
}

res, err := e.client.Do(req)
if err != nil {
return "", err
}

bytes, err := io.ReadAll(res.Body)
if err != nil {
return "", err
}

decryptionSecret := DecryptionResponse{}
err = json.Unmarshal(bytes, &decryptionSecret)
if err != nil {
return "", err
}

return decryptionSecret.DecryptedSecret, nil
}
Loading

0 comments on commit d3e7cc8

Please sign in to comment.