Skip to content

Commit

Permalink
Merge pull request #690 from luraproject/backend_level_query_strings
Browse files Browse the repository at this point in the history
add backend level query strings filtering
  • Loading branch information
kpacha authored Oct 4, 2023
2 parents a02df2b + ed8dc3d commit 8f8a142
Show file tree
Hide file tree
Showing 7 changed files with 305 additions and 1 deletion.
2 changes: 2 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,8 @@ type Backend struct {
ExtraConfig ExtraConfig `mapstructure:"extra_config"`
// HeadersToPass defines the list of headers to pass to this backend
HeadersToPass []string `mapstructure:"input_headers"`
// QueryStringsToPass has the list of query string params to be sent to the backend
QueryStringsToPass []string `mapstructure:"input_query_strings"`
}

// Plugin contains the config required by the plugin module
Expand Down
2 changes: 1 addition & 1 deletion config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ func TestConfig_init(t *testing.T) {
t.Error(err.Error())
}

if hash != "0O1SlXMLFZwKikXa02ymwM301C8q0P4ekbb5PzsBbxM=" {
if hash != "v28MFBnMvvy1JAQZcC3ZBhusgtxl/o0k+7R1NiK0M34=" {
t.Errorf("unexpected hash: %s", hash)
}
}
Expand Down
2 changes: 2 additions & 0 deletions config/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ type parseableBackend struct {
SD string `json:"sd"`
HeadersToPass []string `json:"input_headers"`
SDScheme string `json:"sd_scheme"`
QueryStringsToPass []string `json:"input_query_strings"`
}

func (p *parseableBackend) normalize() *Backend {
Expand All @@ -374,6 +375,7 @@ func (p *parseableBackend) normalize() *Backend {
AllowList: p.AllowList,
DenyList: p.DenyList,
HeadersToPass: p.HeadersToPass,
QueryStringsToPass: p.QueryStringsToPass,
}
if b.SDScheme == "" {
b.SDScheme = "http"
Expand Down
1 change: 1 addition & 0 deletions proxy/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ func (pf defaultFactory) newStack(backend *config.Backend) (p Proxy) {
p = NewBackendPluginMiddleware(pf.logger, backend)(p)
p = NewGraphQLMiddleware(pf.logger, backend)(p)
p = NewFilterHeadersMiddleware(pf.logger, backend)(p)
p = NewFilterQueryStringsMiddleware(pf.logger, backend)(p)
p = NewLoadBalancedMiddlewareWithSubscriberAndLogger(pf.logger, pf.subscriberFactory(backend))(p)
if backend.ConcurrentCalls > 1 {
p = NewConcurrentMiddlewareWithLogger(pf.logger, backend)(p)
Expand Down
67 changes: 67 additions & 0 deletions proxy/headers_filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,70 @@ func TestNewFilterHeadersMiddleware(t *testing.T) {
return
}
}

func TestNewFilterHeadersMiddlewareBlockAll(t *testing.T) {
mw := NewFilterHeadersMiddleware(
logging.NoOp,
&config.Backend{
HeadersToPass: []string{""},
},
)

var receivedReq *Request
prxy := mw(func(ctx context.Context, req *Request) (*Response, error) {
receivedReq = req
return nil, nil
})

sentReq := &Request{
Body: nil,
Params: map[string]string{},
Headers: map[string][]string{
"X-This-Shall-Pass": []string{"tupu", "supu"},
"X-You-Shall-Not-Pass": []string{"Balrog"},
},
}

prxy(context.Background(), sentReq)

if receivedReq == sentReq {
t.Errorf("request should be different")
return
}

if len(receivedReq.Headers) != 0 {
t.Errorf("should have blocked all headers")
return
}
}

func TestNewFilterHeadersMiddlewareAllowAll(t *testing.T) {
mw := NewFilterHeadersMiddleware(
logging.NoOp,
&config.Backend{
HeadersToPass: []string{},
},
)

var receivedReq *Request
prxy := mw(func(ctx context.Context, req *Request) (*Response, error) {
receivedReq = req
return nil, nil
})

sentReq := &Request{
Body: nil,
Params: map[string]string{},
Headers: map[string][]string{
"X-This-Shall-Pass": []string{"tupu", "supu"},
"X-You-Shall-Not-Pass": []string{"Balrog"},
},
}

prxy(context.Background(), sentReq)

if len(receivedReq.Headers) != 2 {
t.Errorf("should have let pass all headers")
return
}
}
64 changes: 64 additions & 0 deletions proxy/query_strings_filter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// SPDX-License-Identifier: Apache-2.0

package proxy

import (
"context"
"net/url"

"github.com/luraproject/lura/v2/config"
"github.com/luraproject/lura/v2/logging"
)

// NewFilterQueryStringsMiddleware returns a middleware with or without a header filtering
// proxy wrapping the next element (depending on the configuration).
func NewFilterQueryStringsMiddleware(logger logging.Logger, remote *config.Backend) Middleware {
if len(remote.QueryStringsToPass) == 0 {
return emptyMiddlewareFallback(logger)
}

return func(next ...Proxy) Proxy {
if len(next) > 1 {
logger.Fatal("too many proxies for this proxy middleware: NewFilterQueryStringsMiddleware only accepts 1 proxy, got %d", len(next))
return nil
}
nextProxy := next[0]
return func(ctx context.Context, request *Request) (*Response, error) {
if len(request.Query) == 0 {
return nextProxy(ctx, request)
}
numQueryStringsToPass := 0
for _, v := range remote.QueryStringsToPass {
if _, ok := request.Query[v]; ok {
numQueryStringsToPass++
}
}
if numQueryStringsToPass == len(request.Query) {
// all the query strings should pass, no need to clone the headers
return nextProxy(ctx, request)
}
// ATTENTION: this is not a clone of query strings!
// this just filters the query strings we do not want to send:
// issues and race conditions could happen the same way as when we
// do not filter the headers. This is a design decission, and if we
// want to clone the query string values (because of write modifications),
// that should be done at an upper level (so the approach is the same
// for non filtered parallel requests).
newQueryStrings := make(url.Values, numQueryStringsToPass)
for _, v := range remote.QueryStringsToPass {
if values, ok := request.Query[v]; ok {
newQueryStrings[v] = values
}
}
return nextProxy(ctx, &Request{
Method: request.Method,
URL: request.URL,
Query: newQueryStrings,
Path: request.Path,
Body: request.Body,
Params: request.Params,
Headers: request.Headers,
})
}
}
}
168 changes: 168 additions & 0 deletions proxy/query_strings_filter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
// SPDX-License-Identifier: Apache-2.0

package proxy

import (
"context"
"testing"

"github.com/luraproject/lura/v2/config"
"github.com/luraproject/lura/v2/logging"
)

func TestNewFilterQueryStringsMiddleware(t *testing.T) {
mw := NewFilterQueryStringsMiddleware(
logging.NoOp,
&config.Backend{
QueryStringsToPass: []string{
"oak",
"cedar",
},
},
)

var receivedReq *Request
prxy := mw(func(ctx context.Context, req *Request) (*Response, error) {
receivedReq = req
return nil, nil
})

sentReq := &Request{
Body: nil,
Params: map[string]string{},
Query: map[string][]string{
"oak": []string{"acorn", "evergreen"},
"maple": []string{"tree", "shrub"},
"cedar": []string{"mediterranean", "himalayas"},
"willow": []string{"350"},
},
}

prxy(context.Background(), sentReq)

if receivedReq == sentReq {
t.Errorf("request should be different")
return
}

oak, ok := receivedReq.Query["oak"]
if !ok {
t.Errorf("missing 'oak'")
return
}
if len(oak) != len(sentReq.Query["oak"]) {
t.Errorf("want len(oak): %d, got %d",
len(sentReq.Query["oak"]), len(oak))
return
}

for idx, expected := range sentReq.Query["oak"] {
if expected != oak[idx] {
t.Errorf("want oak[%d] = %s, got %s",
idx, expected, oak[idx])
return
}
}

if _, ok := receivedReq.Query["cedar"]; !ok {
t.Errorf("missing 'cedar'")
return
}

if _, ok := receivedReq.Query["mapple"]; ok {
t.Errorf("should not be there: 'mapple'")
return
}

if _, ok := receivedReq.Query["willow"]; ok {
t.Errorf("should not be there: 'willow'")
return
}

// check that when query strings are all the expected, no need to copy
sentReq = &Request{
Body: nil,
Params: map[string]string{},
Query: map[string][]string{
"oak": []string{"acorn", "evergreen"},
"cedar": []string{"mediterranean", "himalayas"},
},
}

prxy(context.Background(), sentReq)

if receivedReq != sentReq {
t.Errorf("request should be the same, no modification of query string expected")
return
}
}

func TestFilterQueryStringsBlockAll(t *testing.T) {
// In order to block all the query strings, we must only let pass
// the 'empty' string ""
mw := NewFilterQueryStringsMiddleware(
logging.NoOp,
&config.Backend{
QueryStringsToPass: []string{""},
},
)

var receivedReq *Request
prxy := mw(func(ctx context.Context, req *Request) (*Response, error) {
receivedReq = req
return nil, nil
})

sentReq := &Request{
Body: nil,
Params: map[string]string{},
Query: map[string][]string{
"oak": []string{"acorn", "evergreen"},
"maple": []string{"tree", "shrub"},
},
}

prxy(context.Background(), sentReq)

if receivedReq == sentReq {
t.Errorf("request should be different")
return
}

if len(receivedReq.Query) != 0 {
t.Errorf("should have blocked all query strings")
return
}
}

func TestFilterQueryStringsAllowAll(t *testing.T) {
// Empty backend query strings to passa everything
mw := NewFilterQueryStringsMiddleware(
logging.NoOp,
&config.Backend{
QueryStringsToPass: []string{},
},
)

var receivedReq *Request
prxy := mw(func(ctx context.Context, req *Request) (*Response, error) {
receivedReq = req
return nil, nil
})

sentReq := &Request{
Body: nil,
Params: map[string]string{},
Query: map[string][]string{
"oak": []string{"acorn", "evergreen"},
"maple": []string{"tree", "shrub"},
},
}

prxy(context.Background(), sentReq)

if len(receivedReq.Query) != 2 {
t.Errorf("should have passed all query strings")
return
}
}

0 comments on commit 8f8a142

Please sign in to comment.