From 7652e9ebfb1eaefbf9a6054b8e88ee74b7b21b07 Mon Sep 17 00:00:00 2001 From: Alaric Whitney Date: Wed, 25 Aug 2021 15:29:24 -0500 Subject: [PATCH 1/5] added filter parsing that will convert to hex vals --- filter.go | 163 +++++++++++++++++++++++++++++++++++++++++++++++++ filter_test.go | 75 +++++++++++++++++++++++ 2 files changed, 238 insertions(+) diff --git a/filter.go b/filter.go index 73505e79..820d9a37 100644 --- a/filter.go +++ b/filter.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "regexp" "strings" "unicode" "unicode/utf8" @@ -27,6 +28,10 @@ const ( FilterExtensibleMatch = 9 ) +var ( + isAlphaNumeric = regexp.MustCompile(`^[a-zA-Z0-9]+$`).MatchString +) + // FilterMap contains human readable descriptions of Filter choices var FilterMap = map[uint64]string{ FilterAnd: "And", @@ -78,6 +83,10 @@ func CompileFilter(filter string) (*ber.Packet, error) { if len(filter) == 0 || filter[0] != '(' { return nil, NewError(ErrorFilterCompile, errors.New("ldap: filter does not start with an '('")) } + filter, err := ParseFilter(filter) + if err != nil { + return nil, NewError(ErrorFilterCompile, err) + } packet, pos, err := compileFilter(filter, 1) if err != nil { return nil, err @@ -485,3 +494,157 @@ func decodeEscapedSymbols(src []byte) (string, error) { offset += runeSize } } + +// ParseFilter will take the filter string, and transform the DN into an ascii safe string +func ParseFilter(filter string) (parsedFilter string, err error) { + var startRecording bool + var value string + var balance []string + for i, val := range filter { + switch string(val) { + case "=": + if !startRecording { + startRecording = true + } else { + // we've run into a 2nd = symbol in the statement. Reset the value + parsedFilter += value + value = "" + } + parsedFilter += string(val) + case ")": + if startRecording { + balance, err = checkBalance("(", filter, balance) + if err != nil { + err = nil + startRecording = false + parsedFilter = fmt.Sprintf("%s%s%s", parsedFilter, encodeToHex(value), string(val)) + value = "" + } else { + value += string(val) + } + } else { + parsedFilter += string(val) + } + case "(": + if startRecording { + balance = append(balance, "(") + value += string(val) + } else { + parsedFilter += string(val) + } + case `\`: + if startRecording { + // look ahead for hex parenthesis to check for unbalanced queries + if utf8.RuneCountInString(filter) > i+2 { + byteVal := make([]byte, 1) + if !isAlphaNumeric(string(filter[i+1])) { + value += string(val) + continue + } + if _, err = hexpac.Decode(byteVal, []byte(filter[i+1:i+3])); err != nil { + return "", fmt.Errorf("ldap: invalid characters for escape in filter: %v", err) + } + if strings.EqualFold(filter[i+1:i+3], hexpac.EncodeToString([]byte(`(`))) { + balance = append(balance, hexpac.EncodeToString([]byte(`(`))) + } else if strings.EqualFold(filter[i+1:i+3], hexpac.EncodeToString([]byte(`)`))) { + balance, err = checkBalance(hexpac.EncodeToString([]byte(`(`)), filter, balance) + if err != nil { + err = nil + startRecording = false + parsedFilter = fmt.Sprintf("%s%s%s", parsedFilter, encodeToHex(value), string(val)) + value = "" + } else { + value += string(val) + } + } else { + value += string(val) + } + } else { + value += string(val) + } + } else { + parsedFilter += string(val) + } + case ",": + if !startRecording { + return "", fmt.Errorf("ldap: invalid filter string: %s", filter) + } + if len(balance) != 0 { + return "", fmt.Errorf("ldap: unbalanced filter: %s", filter) + } + startRecording = false + parsedFilter = fmt.Sprintf("%s%s%s", parsedFilter, encodeToHex(value), string(val)) + value = "" + default: + currentRune, _ := utf8.DecodeRuneInString(string(val)) + if currentRune == utf8.RuneError { + return "", fmt.Errorf("ldap: error reading rune at position %d", i) + } + if startRecording { + value += string(val) + } else { + parsedFilter += string(val) + } + } + } + if len(balance) != 0 { + return "", fmt.Errorf("ldap: unbalanced filter: %s", filter) + } + return +} + +// checkBalance will check if a recorded value within the filter is balanced and adjust the balance queue. If not, reset the balance queue and produce an error +func checkBalance(checkElement string, filter string, balance []string) (newBalance []string, err error) { + if len(balance) == 0 { + err = fmt.Errorf("ldap: unbalanced filter string: %s", filter) + } else if !strings.EqualFold(balance[len(balance)-1], checkElement) { + err = fmt.Errorf("ldap: unbalanced filter string: %s", filter) + } else { + newBalance = balance[:len(balance)-1] + } + return +} + +// encodeToHex will take the input value, and turn non-alpha numeric characters into hex values +func encodeToHex(filterValue string) (encodedValue string) { + var skip int + for i, val := range filterValue { + if skip != 0 { + skip-- + continue + } + if string(val) == ` ` { + encodedValue += string(val) + } else if string(val) == `*` { + encodedValue += string(val) + } else if string(val) == `\` { + if len(filterValue) > i+2 { + if isAlphaNumeric(string(filterValue[i+1 : i+3])) { + encodedValue += fmt.Sprintf("\\%s", string(filterValue[i+1:i+3])) + skip = 2 + continue + } else { + encodedValue += encodeHexValue(string(val)) + } + } else { + encodedValue += encodeHexValue(string(val)) + } + } else if isAlphaNumeric(string(val)) { + encodedValue += string(val) + } else { + encodedValue += encodeHexValue(string(val)) + } + } + return +} + +// encodeHexValue is a helper function that will take hex values longer than 2 characters, and add a \ delimiter +func encodeHexValue(input string) (output string) { + output = hexpac.EncodeToString([]byte(string(input))) + if len(output) > 2 { + for i := 2; i < len(output); i += 3 { + output = fmt.Sprintf("%s\\%s", output[:i], output[i:]) + } + } + return fmt.Sprintf("\\%s", output) +} diff --git a/filter_test.go b/filter_test.go index f67d6c74..f17c02a0 100644 --- a/filter_test.go +++ b/filter_test.go @@ -290,3 +290,78 @@ func BenchmarkFilterDecompile(b *testing.B) { DecompileFilter(filters[i%maxIdx]) } } + +func TestParseFilter(t *testing.T) { + + for _, testInfo := range []struct { + src string + expecting string + err string + }{ + { + src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main\5c,OU=Groups,DC=example,DC=foo,DC=bar))`, + expecting: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main\5c,OU=Groups,DC=example,DC=foo,DC=bar))`, + }, + { + src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main\+,OU=Groups,DC=example,DC=foo,DC=bar))`, + expecting: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main\5c\2b,OU=Groups,DC=example,DC=foo,DC=bar))`, + }, + { + src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main (Example),OU=Groups,DC=example,DC=foo,DC=bar))`, + expecting: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main \28Example\29,OU=Groups,DC=example,DC=foo,DC=bar))`, + }, + { + src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main (Example,OU=Groups,DC=example,DC=foo,DC=bar))`, + err: `ldap: unbalanced filter: (&(objectCategory=group)(objectClass=group)(memberOf=CN=Main (Example,OU=Groups,DC=example,DC=foo,DC=bar))`, + }, + { + src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main \28Example,OU=Groups,DC=example,DC=foo,DC=bar))`, + err: `ldap: unbalanced filter: (&(objectCategory=group)(objectClass=group)(memberOf=CN=Main \28Example,OU=Groups,DC=example,DC=foo,DC=bar))`, + }, + { + src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main \29Example,OU=Groups,DC=example,DC=foo,DC=bar))`, + err: `ldap: invalid filter string: (&(objectCategory=group)(objectClass=group)(memberOf=CN=Main \29Example,OU=Groups,DC=example,DC=foo,DC=bar))`, + }, + { + src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main )Example,OU=Groups,DC=example,DC=foo,DC=bar))`, + err: `ldap: invalid filter string: (&(objectCategory=group)(objectClass=group)(memberOf=CN=Main )Example,OU=Groups,DC=example,DC=foo,DC=bar))`, + }, + { + src: `(sn=Mill*)`, + expecting: `(sn=Mill*)`, + }, + { + src: `(sn=Mi*\ed\95\a8*r)`, + expecting: `(sn=Mi*\ed\95\a8*r)`, + }, + { + src: `(sn=Mi*함*r)`, + expecting: `(sn=Mi*\ed\95\a8*r)`, + }, + { + src: `\ed\95\a8\ec\88\98\eb\aa\a9\eb\a1\9d`, + expecting: `\ed\95\a8\ec\88\98\eb\aa\a9\eb\a1\9d`, + }, + { + src: `(objectGUID=\a)`, + err: `ldap: invalid characters for escape in filter: encoding/hex: invalid byte: U+0029 ')'`, + }, + { + src: `(objectGUID=\a\a)`, + err: `ldap: invalid characters for escape in filter: encoding/hex: invalid byte: U+005C '\'`, + }, + } { + + res, err := ParseFilter(testInfo.src) + if err != nil { + if err.Error() != testInfo.err { + t.Fatal(testInfo.src, "=> ", err, "!=", testInfo.err) + } + } else if testInfo.err != "" { + t.Fatal(testInfo.src, "=> ", err, "!=", testInfo.err) + } + if res != testInfo.expecting { + t.Fatal(testInfo.expecting, "=> ", "invalid result", res) + } + } +} From bd1440a8af99772ac3433df6f72d1489cf28943f Mon Sep 17 00:00:00 2001 From: Alaric Whitney Date: Wed, 25 Aug 2021 15:41:33 -0500 Subject: [PATCH 2/5] adding changes to the v3 folder --- v3/filter.go | 163 ++++++++++++++++++++++++++++++++++++++++++++++ v3/filter_test.go | 75 +++++++++++++++++++++ 2 files changed, 238 insertions(+) diff --git a/v3/filter.go b/v3/filter.go index 73505e79..820d9a37 100644 --- a/v3/filter.go +++ b/v3/filter.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "regexp" "strings" "unicode" "unicode/utf8" @@ -27,6 +28,10 @@ const ( FilterExtensibleMatch = 9 ) +var ( + isAlphaNumeric = regexp.MustCompile(`^[a-zA-Z0-9]+$`).MatchString +) + // FilterMap contains human readable descriptions of Filter choices var FilterMap = map[uint64]string{ FilterAnd: "And", @@ -78,6 +83,10 @@ func CompileFilter(filter string) (*ber.Packet, error) { if len(filter) == 0 || filter[0] != '(' { return nil, NewError(ErrorFilterCompile, errors.New("ldap: filter does not start with an '('")) } + filter, err := ParseFilter(filter) + if err != nil { + return nil, NewError(ErrorFilterCompile, err) + } packet, pos, err := compileFilter(filter, 1) if err != nil { return nil, err @@ -485,3 +494,157 @@ func decodeEscapedSymbols(src []byte) (string, error) { offset += runeSize } } + +// ParseFilter will take the filter string, and transform the DN into an ascii safe string +func ParseFilter(filter string) (parsedFilter string, err error) { + var startRecording bool + var value string + var balance []string + for i, val := range filter { + switch string(val) { + case "=": + if !startRecording { + startRecording = true + } else { + // we've run into a 2nd = symbol in the statement. Reset the value + parsedFilter += value + value = "" + } + parsedFilter += string(val) + case ")": + if startRecording { + balance, err = checkBalance("(", filter, balance) + if err != nil { + err = nil + startRecording = false + parsedFilter = fmt.Sprintf("%s%s%s", parsedFilter, encodeToHex(value), string(val)) + value = "" + } else { + value += string(val) + } + } else { + parsedFilter += string(val) + } + case "(": + if startRecording { + balance = append(balance, "(") + value += string(val) + } else { + parsedFilter += string(val) + } + case `\`: + if startRecording { + // look ahead for hex parenthesis to check for unbalanced queries + if utf8.RuneCountInString(filter) > i+2 { + byteVal := make([]byte, 1) + if !isAlphaNumeric(string(filter[i+1])) { + value += string(val) + continue + } + if _, err = hexpac.Decode(byteVal, []byte(filter[i+1:i+3])); err != nil { + return "", fmt.Errorf("ldap: invalid characters for escape in filter: %v", err) + } + if strings.EqualFold(filter[i+1:i+3], hexpac.EncodeToString([]byte(`(`))) { + balance = append(balance, hexpac.EncodeToString([]byte(`(`))) + } else if strings.EqualFold(filter[i+1:i+3], hexpac.EncodeToString([]byte(`)`))) { + balance, err = checkBalance(hexpac.EncodeToString([]byte(`(`)), filter, balance) + if err != nil { + err = nil + startRecording = false + parsedFilter = fmt.Sprintf("%s%s%s", parsedFilter, encodeToHex(value), string(val)) + value = "" + } else { + value += string(val) + } + } else { + value += string(val) + } + } else { + value += string(val) + } + } else { + parsedFilter += string(val) + } + case ",": + if !startRecording { + return "", fmt.Errorf("ldap: invalid filter string: %s", filter) + } + if len(balance) != 0 { + return "", fmt.Errorf("ldap: unbalanced filter: %s", filter) + } + startRecording = false + parsedFilter = fmt.Sprintf("%s%s%s", parsedFilter, encodeToHex(value), string(val)) + value = "" + default: + currentRune, _ := utf8.DecodeRuneInString(string(val)) + if currentRune == utf8.RuneError { + return "", fmt.Errorf("ldap: error reading rune at position %d", i) + } + if startRecording { + value += string(val) + } else { + parsedFilter += string(val) + } + } + } + if len(balance) != 0 { + return "", fmt.Errorf("ldap: unbalanced filter: %s", filter) + } + return +} + +// checkBalance will check if a recorded value within the filter is balanced and adjust the balance queue. If not, reset the balance queue and produce an error +func checkBalance(checkElement string, filter string, balance []string) (newBalance []string, err error) { + if len(balance) == 0 { + err = fmt.Errorf("ldap: unbalanced filter string: %s", filter) + } else if !strings.EqualFold(balance[len(balance)-1], checkElement) { + err = fmt.Errorf("ldap: unbalanced filter string: %s", filter) + } else { + newBalance = balance[:len(balance)-1] + } + return +} + +// encodeToHex will take the input value, and turn non-alpha numeric characters into hex values +func encodeToHex(filterValue string) (encodedValue string) { + var skip int + for i, val := range filterValue { + if skip != 0 { + skip-- + continue + } + if string(val) == ` ` { + encodedValue += string(val) + } else if string(val) == `*` { + encodedValue += string(val) + } else if string(val) == `\` { + if len(filterValue) > i+2 { + if isAlphaNumeric(string(filterValue[i+1 : i+3])) { + encodedValue += fmt.Sprintf("\\%s", string(filterValue[i+1:i+3])) + skip = 2 + continue + } else { + encodedValue += encodeHexValue(string(val)) + } + } else { + encodedValue += encodeHexValue(string(val)) + } + } else if isAlphaNumeric(string(val)) { + encodedValue += string(val) + } else { + encodedValue += encodeHexValue(string(val)) + } + } + return +} + +// encodeHexValue is a helper function that will take hex values longer than 2 characters, and add a \ delimiter +func encodeHexValue(input string) (output string) { + output = hexpac.EncodeToString([]byte(string(input))) + if len(output) > 2 { + for i := 2; i < len(output); i += 3 { + output = fmt.Sprintf("%s\\%s", output[:i], output[i:]) + } + } + return fmt.Sprintf("\\%s", output) +} diff --git a/v3/filter_test.go b/v3/filter_test.go index f67d6c74..f17c02a0 100644 --- a/v3/filter_test.go +++ b/v3/filter_test.go @@ -290,3 +290,78 @@ func BenchmarkFilterDecompile(b *testing.B) { DecompileFilter(filters[i%maxIdx]) } } + +func TestParseFilter(t *testing.T) { + + for _, testInfo := range []struct { + src string + expecting string + err string + }{ + { + src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main\5c,OU=Groups,DC=example,DC=foo,DC=bar))`, + expecting: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main\5c,OU=Groups,DC=example,DC=foo,DC=bar))`, + }, + { + src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main\+,OU=Groups,DC=example,DC=foo,DC=bar))`, + expecting: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main\5c\2b,OU=Groups,DC=example,DC=foo,DC=bar))`, + }, + { + src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main (Example),OU=Groups,DC=example,DC=foo,DC=bar))`, + expecting: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main \28Example\29,OU=Groups,DC=example,DC=foo,DC=bar))`, + }, + { + src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main (Example,OU=Groups,DC=example,DC=foo,DC=bar))`, + err: `ldap: unbalanced filter: (&(objectCategory=group)(objectClass=group)(memberOf=CN=Main (Example,OU=Groups,DC=example,DC=foo,DC=bar))`, + }, + { + src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main \28Example,OU=Groups,DC=example,DC=foo,DC=bar))`, + err: `ldap: unbalanced filter: (&(objectCategory=group)(objectClass=group)(memberOf=CN=Main \28Example,OU=Groups,DC=example,DC=foo,DC=bar))`, + }, + { + src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main \29Example,OU=Groups,DC=example,DC=foo,DC=bar))`, + err: `ldap: invalid filter string: (&(objectCategory=group)(objectClass=group)(memberOf=CN=Main \29Example,OU=Groups,DC=example,DC=foo,DC=bar))`, + }, + { + src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main )Example,OU=Groups,DC=example,DC=foo,DC=bar))`, + err: `ldap: invalid filter string: (&(objectCategory=group)(objectClass=group)(memberOf=CN=Main )Example,OU=Groups,DC=example,DC=foo,DC=bar))`, + }, + { + src: `(sn=Mill*)`, + expecting: `(sn=Mill*)`, + }, + { + src: `(sn=Mi*\ed\95\a8*r)`, + expecting: `(sn=Mi*\ed\95\a8*r)`, + }, + { + src: `(sn=Mi*함*r)`, + expecting: `(sn=Mi*\ed\95\a8*r)`, + }, + { + src: `\ed\95\a8\ec\88\98\eb\aa\a9\eb\a1\9d`, + expecting: `\ed\95\a8\ec\88\98\eb\aa\a9\eb\a1\9d`, + }, + { + src: `(objectGUID=\a)`, + err: `ldap: invalid characters for escape in filter: encoding/hex: invalid byte: U+0029 ')'`, + }, + { + src: `(objectGUID=\a\a)`, + err: `ldap: invalid characters for escape in filter: encoding/hex: invalid byte: U+005C '\'`, + }, + } { + + res, err := ParseFilter(testInfo.src) + if err != nil { + if err.Error() != testInfo.err { + t.Fatal(testInfo.src, "=> ", err, "!=", testInfo.err) + } + } else if testInfo.err != "" { + t.Fatal(testInfo.src, "=> ", err, "!=", testInfo.err) + } + if res != testInfo.expecting { + t.Fatal(testInfo.expecting, "=> ", "invalid result", res) + } + } +} From af1e9fff5cb2b2cd406e194dcc704c41d0a6da03 Mon Sep 17 00:00:00 2001 From: Alaric Whitney Date: Tue, 31 Aug 2021 10:39:03 -0500 Subject: [PATCH 3/5] updated EscapeFilter test --- ldap_test.go | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/ldap_test.go b/ldap_test.go index c3245b0e..83d70284 100644 --- a/ldap_test.go +++ b/ldap_test.go @@ -236,11 +236,31 @@ func TestMultiGoroutineSearch(t *testing.T) { } func TestEscapeFilter(t *testing.T) { - if got, want := EscapeFilter("a\x00b(c)d*e\\f"), `a\00b\28c\29d\2ae\5cf`; got != want { - t.Errorf("Got %s, expected %s", want, got) - } - if got, want := EscapeFilter("Lučić"), `Lu\c4\8di\c4\87`; got != want { - t.Errorf("Got %s, expected %s", want, got) + for _, testInfo := range []struct { + src string + expecting string + }{ + { + src: "a\x00b(c)d*e\\f", + expecting: `a\00b\28c\29d\2ae\5cf`, + }, + { + src: "Lučić", + expecting: `Lu\c4\8di\c4\87`, + }, + { + src: `\\some-server\code`, + expecting: `\5c\5csome-server\5ccode`, + }, + { + src: `Mi*함*r`, + expecting: `Mi\2a\ed\95\a8\2ar`, + }, + } { + got := EscapeFilter(testInfo.src) + if got != testInfo.expecting { + t.Errorf("Got %s, expected %s", got, testInfo.expecting) + } } } From f32dd2bccf6ea990a593e4f22d0fe83ad77a2673 Mon Sep 17 00:00:00 2001 From: Alaric Whitney Date: Tue, 31 Aug 2021 10:39:34 -0500 Subject: [PATCH 4/5] Switched to use EscapeFilter --- filter.go | 97 +++------------------- filter_test.go | 212 +++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 202 insertions(+), 107 deletions(-) diff --git a/filter.go b/filter.go index 820d9a37..4541f435 100644 --- a/filter.go +++ b/filter.go @@ -78,15 +78,21 @@ var MatchingRuleAssertionMap = map[uint64]string{ var _SymbolAny = []byte{'*'} +// CompileEscapeFilter will take a raw filter string, without any \xx hex input, and convert it into a BER-encoded packet +// All filter values will be placed through EscapeFilter, which cannot handle \xx hex input, nor * wildcards (will be translated to literal *) +func CompileEscapeFilter(filter string) (*ber.Packet, error) { + filter, err := ParseFilter(filter) + if err != nil { + return nil, NewError(ErrorFilterCompile, err) + } + return CompileFilter(filter) +} + // CompileFilter converts a string representation of a filter into a BER-encoded packet func CompileFilter(filter string) (*ber.Packet, error) { if len(filter) == 0 || filter[0] != '(' { return nil, NewError(ErrorFilterCompile, errors.New("ldap: filter does not start with an '('")) } - filter, err := ParseFilter(filter) - if err != nil { - return nil, NewError(ErrorFilterCompile, err) - } packet, pos, err := compileFilter(filter, 1) if err != nil { return nil, err @@ -495,7 +501,7 @@ func decodeEscapedSymbols(src []byte) (string, error) { } } -// ParseFilter will take the filter string, and transform the DN into an ascii safe string +// ParseFilter will take the filter string, and escape each value func ParseFilter(filter string) (parsedFilter string, err error) { var startRecording bool var value string @@ -517,7 +523,7 @@ func ParseFilter(filter string) (parsedFilter string, err error) { if err != nil { err = nil startRecording = false - parsedFilter = fmt.Sprintf("%s%s%s", parsedFilter, encodeToHex(value), string(val)) + parsedFilter = fmt.Sprintf("%s%s%s", parsedFilter, EscapeFilter(value), string(val)) value = "" } else { value += string(val) @@ -532,39 +538,6 @@ func ParseFilter(filter string) (parsedFilter string, err error) { } else { parsedFilter += string(val) } - case `\`: - if startRecording { - // look ahead for hex parenthesis to check for unbalanced queries - if utf8.RuneCountInString(filter) > i+2 { - byteVal := make([]byte, 1) - if !isAlphaNumeric(string(filter[i+1])) { - value += string(val) - continue - } - if _, err = hexpac.Decode(byteVal, []byte(filter[i+1:i+3])); err != nil { - return "", fmt.Errorf("ldap: invalid characters for escape in filter: %v", err) - } - if strings.EqualFold(filter[i+1:i+3], hexpac.EncodeToString([]byte(`(`))) { - balance = append(balance, hexpac.EncodeToString([]byte(`(`))) - } else if strings.EqualFold(filter[i+1:i+3], hexpac.EncodeToString([]byte(`)`))) { - balance, err = checkBalance(hexpac.EncodeToString([]byte(`(`)), filter, balance) - if err != nil { - err = nil - startRecording = false - parsedFilter = fmt.Sprintf("%s%s%s", parsedFilter, encodeToHex(value), string(val)) - value = "" - } else { - value += string(val) - } - } else { - value += string(val) - } - } else { - value += string(val) - } - } else { - parsedFilter += string(val) - } case ",": if !startRecording { return "", fmt.Errorf("ldap: invalid filter string: %s", filter) @@ -573,7 +546,7 @@ func ParseFilter(filter string) (parsedFilter string, err error) { return "", fmt.Errorf("ldap: unbalanced filter: %s", filter) } startRecording = false - parsedFilter = fmt.Sprintf("%s%s%s", parsedFilter, encodeToHex(value), string(val)) + parsedFilter = fmt.Sprintf("%s%s%s", parsedFilter, EscapeFilter(value), string(val)) value = "" default: currentRune, _ := utf8.DecodeRuneInString(string(val)) @@ -604,47 +577,3 @@ func checkBalance(checkElement string, filter string, balance []string) (newBala } return } - -// encodeToHex will take the input value, and turn non-alpha numeric characters into hex values -func encodeToHex(filterValue string) (encodedValue string) { - var skip int - for i, val := range filterValue { - if skip != 0 { - skip-- - continue - } - if string(val) == ` ` { - encodedValue += string(val) - } else if string(val) == `*` { - encodedValue += string(val) - } else if string(val) == `\` { - if len(filterValue) > i+2 { - if isAlphaNumeric(string(filterValue[i+1 : i+3])) { - encodedValue += fmt.Sprintf("\\%s", string(filterValue[i+1:i+3])) - skip = 2 - continue - } else { - encodedValue += encodeHexValue(string(val)) - } - } else { - encodedValue += encodeHexValue(string(val)) - } - } else if isAlphaNumeric(string(val)) { - encodedValue += string(val) - } else { - encodedValue += encodeHexValue(string(val)) - } - } - return -} - -// encodeHexValue is a helper function that will take hex values longer than 2 characters, and add a \ delimiter -func encodeHexValue(input string) (output string) { - output = hexpac.EncodeToString([]byte(string(input))) - if len(output) > 2 { - for i := 2; i < len(output); i += 3 { - output = fmt.Sprintf("%s\\%s", output[:i], output[i:]) - } - } - return fmt.Sprintf("\\%s", output) -} diff --git a/filter_test.go b/filter_test.go index f17c02a0..95aa1a76 100644 --- a/filter_test.go +++ b/filter_test.go @@ -213,6 +213,185 @@ func TestFilter(t *testing.T) { } } +func TestCompileEscapeFilter(t *testing.T) { + var testEscapeFilters = []compileTest{ + { + filterStr: "(&(sn=Miller)(givenName=Bob))", + expectedFilter: "(&(sn=Miller)(givenName=Bob))", + expectedType: FilterAnd, + }, + { + filterStr: "(|(sn=Miller)(givenName=Bob))", + expectedFilter: "(|(sn=Miller)(givenName=Bob))", + expectedType: FilterOr, + }, + { + filterStr: "(!(sn=Miller))", + expectedFilter: "(!(sn=Miller))", + expectedType: FilterNot, + }, + { + filterStr: "(sn=Miller)", + expectedFilter: "(sn=Miller)", + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn=Mill*)", + expectedFilter: `(sn=Mill\2a)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn=*Mill)", + expectedFilter: `(sn=\2aMill)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn=*Mill*)", + expectedFilter: `(sn=\2aMill\2a)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn=*i*le*)", + expectedFilter: `(sn=\2ai\2ale\2a)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn=Mi*l*r)", + expectedFilter: `(sn=Mi\2al\2ar)`, + expectedType: FilterEqualityMatch, + }, + // substring filters escape properly + { + filterStr: `(sn=Mi*함*r)`, + expectedFilter: `(sn=Mi\2a\ed\95\a8\2ar)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn=Mi*le*)", + expectedFilter: `(sn=Mi\2ale\2a)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn=*i*ler)", + expectedFilter: `(sn=\2ai\2aler)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn>=Miller)", + expectedFilter: "(sn>=Miller)", + expectedType: FilterGreaterOrEqual, + }, + { + filterStr: "(sn<=Miller)", + expectedFilter: "(sn<=Miller)", + expectedType: FilterLessOrEqual, + }, + { + filterStr: "(sn=*)", + expectedFilter: `(sn=\2a)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn~=Miller)", + expectedFilter: "(sn~=Miller)", + expectedType: FilterApproxMatch, + }, + { + filterStr: `(objectGUID=абвгдеёжзийклмнопрстуфхцчшщъыьэюя)`, + expectedFilter: `(objectGUID=\d0\b0\d0\b1\d0\b2\d0\b3\d0\b4\d0\b5\d1\91\d0\b6\d0\b7\d0\b8\d0\b9\d0\ba\d0\bb\d0\bc\d0\bd\d0\be\d0\bf\d1\80\d1\81\d1\82\d1\83\d1\84\d1\85\d1\86\d1\87\d1\88\d1\89\d1\8a\d1\8b\d1\8c\d1\8d\d1\8e\d1\8f)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: `(objectGUID=함수목록)`, + expectedFilter: `(objectGUID=\ed\95\a8\ec\88\98\eb\aa\a9\eb\a1\9d)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: `(objectGUID=`, + expectedFilter: ``, + expectedType: 0, + expectedErr: "unexpected end of filter", + }, + { + filterStr: `(objectGUID=함수목록`, + expectedFilter: ``, + expectedType: 0, + expectedErr: "unexpected end of filter", + }, + { + filterStr: `((cn=)`, + expectedFilter: ``, + expectedType: 0, + expectedErr: "unexpected end of filter", + }, + { + filterStr: `(&(objectclass=inetorgperson)(cn=中文))`, + expectedFilter: `(&(objectclass=inetorgperson)(cn=\e4\b8\ad\e6\96\87))`, + expectedType: 0, + }, + // attr extension + { + filterStr: `(memberOf:=foo)`, + expectedFilter: `(memberOf:=foo)`, + expectedType: FilterExtensibleMatch, + }, + // attr+named matching rule extension + { + filterStr: `(memberOf:test:=foo)`, + expectedFilter: `(memberOf:test:=foo)`, + expectedType: FilterExtensibleMatch, + }, + // attr+oid matching rule extension + { + filterStr: `(cn:1.2.3.4.5:=Fred Flintstone)`, + expectedFilter: `(cn:1.2.3.4.5:=Fred Flintstone)`, + expectedType: FilterExtensibleMatch, + }, + // attr+dn+oid matching rule extension + { + filterStr: `(sn:dn:2.4.6.8.10:=Barney Rubble)`, + expectedFilter: `(sn:dn:2.4.6.8.10:=Barney Rubble)`, + expectedType: FilterExtensibleMatch, + }, + // attr+dn extension + { + filterStr: `(o:dn:=Ace Industry)`, + expectedFilter: `(o:dn:=Ace Industry)`, + expectedType: FilterExtensibleMatch, + }, + // dn extension + { + filterStr: `(:dn:2.4.6.8.10:=Dino)`, + expectedFilter: `(:dn:2.4.6.8.10:=Dino)`, + expectedType: FilterExtensibleMatch, + }, + { + filterStr: `(memberOf:1.2.840.113556.1.4.1941:=CN=User1,OU=blah,DC=mydomain,DC=net)`, + expectedFilter: `(memberOf:1.2.840.113556.1.4.1941:=CN=User1,OU=blah,DC=mydomain,DC=net)`, + expectedType: FilterExtensibleMatch, + }, + } + // Test Compiler and Decompiler + for _, i := range testEscapeFilters { + filter, err := CompileEscapeFilter(i.filterStr) + switch { + case err != nil: + if i.expectedErr == "" || !strings.Contains(err.Error(), i.expectedErr) { + t.Errorf("Problem compiling '%s' - '%v' (expected error to contain '%v')", i.filterStr, err, i.expectedErr) + } + case filter.Tag != ber.Tag(i.expectedType): + t.Errorf("%q Expected %q got %q", i.filterStr, FilterMap[uint64(i.expectedType)], FilterMap[uint64(filter.Tag)]) + default: + o, err := DecompileFilter(filter) + if err != nil { + t.Errorf("Problem compiling %s - %s", i.filterStr, err.Error()) + } else if i.expectedFilter != o { + t.Errorf("%q expected, got %q", i.expectedFilter, o) + } + } + } +} + func TestDecodeEscapedSymbols(t *testing.T) { for _, testInfo := range []struct { @@ -292,19 +471,18 @@ func BenchmarkFilterDecompile(b *testing.B) { } func TestParseFilter(t *testing.T) { - for _, testInfo := range []struct { src string expecting string err string }{ { - src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main\5c,OU=Groups,DC=example,DC=foo,DC=bar))`, + src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main\,OU=Groups,DC=example,DC=foo,DC=bar))`, expecting: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main\5c,OU=Groups,DC=example,DC=foo,DC=bar))`, }, { src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main\+,OU=Groups,DC=example,DC=foo,DC=bar))`, - expecting: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main\5c\2b,OU=Groups,DC=example,DC=foo,DC=bar))`, + expecting: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main\5c+,OU=Groups,DC=example,DC=foo,DC=bar))`, }, { src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main (Example),OU=Groups,DC=example,DC=foo,DC=bar))`, @@ -315,12 +493,8 @@ func TestParseFilter(t *testing.T) { err: `ldap: unbalanced filter: (&(objectCategory=group)(objectClass=group)(memberOf=CN=Main (Example,OU=Groups,DC=example,DC=foo,DC=bar))`, }, { - src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main \28Example,OU=Groups,DC=example,DC=foo,DC=bar))`, - err: `ldap: unbalanced filter: (&(objectCategory=group)(objectClass=group)(memberOf=CN=Main \28Example,OU=Groups,DC=example,DC=foo,DC=bar))`, - }, - { - src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main \29Example,OU=Groups,DC=example,DC=foo,DC=bar))`, - err: `ldap: invalid filter string: (&(objectCategory=group)(objectClass=group)(memberOf=CN=Main \29Example,OU=Groups,DC=example,DC=foo,DC=bar))`, + src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main )Example,OU=Groups,DC=example,DC=foo,DC=bar))`, + err: `ldap: invalid filter string: (&(objectCategory=group)(objectClass=group)(memberOf=CN=Main )Example,OU=Groups,DC=example,DC=foo,DC=bar))`, }, { src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main )Example,OU=Groups,DC=example,DC=foo,DC=bar))`, @@ -328,27 +502,19 @@ func TestParseFilter(t *testing.T) { }, { src: `(sn=Mill*)`, - expecting: `(sn=Mill*)`, - }, - { - src: `(sn=Mi*\ed\95\a8*r)`, - expecting: `(sn=Mi*\ed\95\a8*r)`, + expecting: `(sn=Mill\2a)`, }, { src: `(sn=Mi*함*r)`, - expecting: `(sn=Mi*\ed\95\a8*r)`, - }, - { - src: `\ed\95\a8\ec\88\98\eb\aa\a9\eb\a1\9d`, - expecting: `\ed\95\a8\ec\88\98\eb\aa\a9\eb\a1\9d`, + expecting: `(sn=Mi\2a\ed\95\a8\2ar)`, }, { - src: `(objectGUID=\a)`, - err: `ldap: invalid characters for escape in filter: encoding/hex: invalid byte: U+0029 ')'`, + src: `(objectGUID=\a)`, + expecting: `(objectGUID=\5ca)`, }, { - src: `(objectGUID=\a\a)`, - err: `ldap: invalid characters for escape in filter: encoding/hex: invalid byte: U+005C '\'`, + src: `(objectGUID=\a\a)`, + expecting: `(objectGUID=\5ca\5ca)`, }, } { From d3b539dde95f02a9da54d9a6a9f8fca1be3ee19d Mon Sep 17 00:00:00 2001 From: Alaric Whitney Date: Tue, 31 Aug 2021 10:57:58 -0500 Subject: [PATCH 5/5] added changes to v3 --- v3/filter.go | 97 +++------------------ v3/filter_test.go | 212 +++++++++++++++++++++++++++++++++++++++++----- v3/ldap_test.go | 30 +++++-- 3 files changed, 227 insertions(+), 112 deletions(-) diff --git a/v3/filter.go b/v3/filter.go index 820d9a37..4541f435 100644 --- a/v3/filter.go +++ b/v3/filter.go @@ -78,15 +78,21 @@ var MatchingRuleAssertionMap = map[uint64]string{ var _SymbolAny = []byte{'*'} +// CompileEscapeFilter will take a raw filter string, without any \xx hex input, and convert it into a BER-encoded packet +// All filter values will be placed through EscapeFilter, which cannot handle \xx hex input, nor * wildcards (will be translated to literal *) +func CompileEscapeFilter(filter string) (*ber.Packet, error) { + filter, err := ParseFilter(filter) + if err != nil { + return nil, NewError(ErrorFilterCompile, err) + } + return CompileFilter(filter) +} + // CompileFilter converts a string representation of a filter into a BER-encoded packet func CompileFilter(filter string) (*ber.Packet, error) { if len(filter) == 0 || filter[0] != '(' { return nil, NewError(ErrorFilterCompile, errors.New("ldap: filter does not start with an '('")) } - filter, err := ParseFilter(filter) - if err != nil { - return nil, NewError(ErrorFilterCompile, err) - } packet, pos, err := compileFilter(filter, 1) if err != nil { return nil, err @@ -495,7 +501,7 @@ func decodeEscapedSymbols(src []byte) (string, error) { } } -// ParseFilter will take the filter string, and transform the DN into an ascii safe string +// ParseFilter will take the filter string, and escape each value func ParseFilter(filter string) (parsedFilter string, err error) { var startRecording bool var value string @@ -517,7 +523,7 @@ func ParseFilter(filter string) (parsedFilter string, err error) { if err != nil { err = nil startRecording = false - parsedFilter = fmt.Sprintf("%s%s%s", parsedFilter, encodeToHex(value), string(val)) + parsedFilter = fmt.Sprintf("%s%s%s", parsedFilter, EscapeFilter(value), string(val)) value = "" } else { value += string(val) @@ -532,39 +538,6 @@ func ParseFilter(filter string) (parsedFilter string, err error) { } else { parsedFilter += string(val) } - case `\`: - if startRecording { - // look ahead for hex parenthesis to check for unbalanced queries - if utf8.RuneCountInString(filter) > i+2 { - byteVal := make([]byte, 1) - if !isAlphaNumeric(string(filter[i+1])) { - value += string(val) - continue - } - if _, err = hexpac.Decode(byteVal, []byte(filter[i+1:i+3])); err != nil { - return "", fmt.Errorf("ldap: invalid characters for escape in filter: %v", err) - } - if strings.EqualFold(filter[i+1:i+3], hexpac.EncodeToString([]byte(`(`))) { - balance = append(balance, hexpac.EncodeToString([]byte(`(`))) - } else if strings.EqualFold(filter[i+1:i+3], hexpac.EncodeToString([]byte(`)`))) { - balance, err = checkBalance(hexpac.EncodeToString([]byte(`(`)), filter, balance) - if err != nil { - err = nil - startRecording = false - parsedFilter = fmt.Sprintf("%s%s%s", parsedFilter, encodeToHex(value), string(val)) - value = "" - } else { - value += string(val) - } - } else { - value += string(val) - } - } else { - value += string(val) - } - } else { - parsedFilter += string(val) - } case ",": if !startRecording { return "", fmt.Errorf("ldap: invalid filter string: %s", filter) @@ -573,7 +546,7 @@ func ParseFilter(filter string) (parsedFilter string, err error) { return "", fmt.Errorf("ldap: unbalanced filter: %s", filter) } startRecording = false - parsedFilter = fmt.Sprintf("%s%s%s", parsedFilter, encodeToHex(value), string(val)) + parsedFilter = fmt.Sprintf("%s%s%s", parsedFilter, EscapeFilter(value), string(val)) value = "" default: currentRune, _ := utf8.DecodeRuneInString(string(val)) @@ -604,47 +577,3 @@ func checkBalance(checkElement string, filter string, balance []string) (newBala } return } - -// encodeToHex will take the input value, and turn non-alpha numeric characters into hex values -func encodeToHex(filterValue string) (encodedValue string) { - var skip int - for i, val := range filterValue { - if skip != 0 { - skip-- - continue - } - if string(val) == ` ` { - encodedValue += string(val) - } else if string(val) == `*` { - encodedValue += string(val) - } else if string(val) == `\` { - if len(filterValue) > i+2 { - if isAlphaNumeric(string(filterValue[i+1 : i+3])) { - encodedValue += fmt.Sprintf("\\%s", string(filterValue[i+1:i+3])) - skip = 2 - continue - } else { - encodedValue += encodeHexValue(string(val)) - } - } else { - encodedValue += encodeHexValue(string(val)) - } - } else if isAlphaNumeric(string(val)) { - encodedValue += string(val) - } else { - encodedValue += encodeHexValue(string(val)) - } - } - return -} - -// encodeHexValue is a helper function that will take hex values longer than 2 characters, and add a \ delimiter -func encodeHexValue(input string) (output string) { - output = hexpac.EncodeToString([]byte(string(input))) - if len(output) > 2 { - for i := 2; i < len(output); i += 3 { - output = fmt.Sprintf("%s\\%s", output[:i], output[i:]) - } - } - return fmt.Sprintf("\\%s", output) -} diff --git a/v3/filter_test.go b/v3/filter_test.go index f17c02a0..95aa1a76 100644 --- a/v3/filter_test.go +++ b/v3/filter_test.go @@ -213,6 +213,185 @@ func TestFilter(t *testing.T) { } } +func TestCompileEscapeFilter(t *testing.T) { + var testEscapeFilters = []compileTest{ + { + filterStr: "(&(sn=Miller)(givenName=Bob))", + expectedFilter: "(&(sn=Miller)(givenName=Bob))", + expectedType: FilterAnd, + }, + { + filterStr: "(|(sn=Miller)(givenName=Bob))", + expectedFilter: "(|(sn=Miller)(givenName=Bob))", + expectedType: FilterOr, + }, + { + filterStr: "(!(sn=Miller))", + expectedFilter: "(!(sn=Miller))", + expectedType: FilterNot, + }, + { + filterStr: "(sn=Miller)", + expectedFilter: "(sn=Miller)", + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn=Mill*)", + expectedFilter: `(sn=Mill\2a)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn=*Mill)", + expectedFilter: `(sn=\2aMill)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn=*Mill*)", + expectedFilter: `(sn=\2aMill\2a)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn=*i*le*)", + expectedFilter: `(sn=\2ai\2ale\2a)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn=Mi*l*r)", + expectedFilter: `(sn=Mi\2al\2ar)`, + expectedType: FilterEqualityMatch, + }, + // substring filters escape properly + { + filterStr: `(sn=Mi*함*r)`, + expectedFilter: `(sn=Mi\2a\ed\95\a8\2ar)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn=Mi*le*)", + expectedFilter: `(sn=Mi\2ale\2a)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn=*i*ler)", + expectedFilter: `(sn=\2ai\2aler)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn>=Miller)", + expectedFilter: "(sn>=Miller)", + expectedType: FilterGreaterOrEqual, + }, + { + filterStr: "(sn<=Miller)", + expectedFilter: "(sn<=Miller)", + expectedType: FilterLessOrEqual, + }, + { + filterStr: "(sn=*)", + expectedFilter: `(sn=\2a)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn~=Miller)", + expectedFilter: "(sn~=Miller)", + expectedType: FilterApproxMatch, + }, + { + filterStr: `(objectGUID=абвгдеёжзийклмнопрстуфхцчшщъыьэюя)`, + expectedFilter: `(objectGUID=\d0\b0\d0\b1\d0\b2\d0\b3\d0\b4\d0\b5\d1\91\d0\b6\d0\b7\d0\b8\d0\b9\d0\ba\d0\bb\d0\bc\d0\bd\d0\be\d0\bf\d1\80\d1\81\d1\82\d1\83\d1\84\d1\85\d1\86\d1\87\d1\88\d1\89\d1\8a\d1\8b\d1\8c\d1\8d\d1\8e\d1\8f)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: `(objectGUID=함수목록)`, + expectedFilter: `(objectGUID=\ed\95\a8\ec\88\98\eb\aa\a9\eb\a1\9d)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: `(objectGUID=`, + expectedFilter: ``, + expectedType: 0, + expectedErr: "unexpected end of filter", + }, + { + filterStr: `(objectGUID=함수목록`, + expectedFilter: ``, + expectedType: 0, + expectedErr: "unexpected end of filter", + }, + { + filterStr: `((cn=)`, + expectedFilter: ``, + expectedType: 0, + expectedErr: "unexpected end of filter", + }, + { + filterStr: `(&(objectclass=inetorgperson)(cn=中文))`, + expectedFilter: `(&(objectclass=inetorgperson)(cn=\e4\b8\ad\e6\96\87))`, + expectedType: 0, + }, + // attr extension + { + filterStr: `(memberOf:=foo)`, + expectedFilter: `(memberOf:=foo)`, + expectedType: FilterExtensibleMatch, + }, + // attr+named matching rule extension + { + filterStr: `(memberOf:test:=foo)`, + expectedFilter: `(memberOf:test:=foo)`, + expectedType: FilterExtensibleMatch, + }, + // attr+oid matching rule extension + { + filterStr: `(cn:1.2.3.4.5:=Fred Flintstone)`, + expectedFilter: `(cn:1.2.3.4.5:=Fred Flintstone)`, + expectedType: FilterExtensibleMatch, + }, + // attr+dn+oid matching rule extension + { + filterStr: `(sn:dn:2.4.6.8.10:=Barney Rubble)`, + expectedFilter: `(sn:dn:2.4.6.8.10:=Barney Rubble)`, + expectedType: FilterExtensibleMatch, + }, + // attr+dn extension + { + filterStr: `(o:dn:=Ace Industry)`, + expectedFilter: `(o:dn:=Ace Industry)`, + expectedType: FilterExtensibleMatch, + }, + // dn extension + { + filterStr: `(:dn:2.4.6.8.10:=Dino)`, + expectedFilter: `(:dn:2.4.6.8.10:=Dino)`, + expectedType: FilterExtensibleMatch, + }, + { + filterStr: `(memberOf:1.2.840.113556.1.4.1941:=CN=User1,OU=blah,DC=mydomain,DC=net)`, + expectedFilter: `(memberOf:1.2.840.113556.1.4.1941:=CN=User1,OU=blah,DC=mydomain,DC=net)`, + expectedType: FilterExtensibleMatch, + }, + } + // Test Compiler and Decompiler + for _, i := range testEscapeFilters { + filter, err := CompileEscapeFilter(i.filterStr) + switch { + case err != nil: + if i.expectedErr == "" || !strings.Contains(err.Error(), i.expectedErr) { + t.Errorf("Problem compiling '%s' - '%v' (expected error to contain '%v')", i.filterStr, err, i.expectedErr) + } + case filter.Tag != ber.Tag(i.expectedType): + t.Errorf("%q Expected %q got %q", i.filterStr, FilterMap[uint64(i.expectedType)], FilterMap[uint64(filter.Tag)]) + default: + o, err := DecompileFilter(filter) + if err != nil { + t.Errorf("Problem compiling %s - %s", i.filterStr, err.Error()) + } else if i.expectedFilter != o { + t.Errorf("%q expected, got %q", i.expectedFilter, o) + } + } + } +} + func TestDecodeEscapedSymbols(t *testing.T) { for _, testInfo := range []struct { @@ -292,19 +471,18 @@ func BenchmarkFilterDecompile(b *testing.B) { } func TestParseFilter(t *testing.T) { - for _, testInfo := range []struct { src string expecting string err string }{ { - src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main\5c,OU=Groups,DC=example,DC=foo,DC=bar))`, + src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main\,OU=Groups,DC=example,DC=foo,DC=bar))`, expecting: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main\5c,OU=Groups,DC=example,DC=foo,DC=bar))`, }, { src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main\+,OU=Groups,DC=example,DC=foo,DC=bar))`, - expecting: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main\5c\2b,OU=Groups,DC=example,DC=foo,DC=bar))`, + expecting: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main\5c+,OU=Groups,DC=example,DC=foo,DC=bar))`, }, { src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main (Example),OU=Groups,DC=example,DC=foo,DC=bar))`, @@ -315,12 +493,8 @@ func TestParseFilter(t *testing.T) { err: `ldap: unbalanced filter: (&(objectCategory=group)(objectClass=group)(memberOf=CN=Main (Example,OU=Groups,DC=example,DC=foo,DC=bar))`, }, { - src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main \28Example,OU=Groups,DC=example,DC=foo,DC=bar))`, - err: `ldap: unbalanced filter: (&(objectCategory=group)(objectClass=group)(memberOf=CN=Main \28Example,OU=Groups,DC=example,DC=foo,DC=bar))`, - }, - { - src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main \29Example,OU=Groups,DC=example,DC=foo,DC=bar))`, - err: `ldap: invalid filter string: (&(objectCategory=group)(objectClass=group)(memberOf=CN=Main \29Example,OU=Groups,DC=example,DC=foo,DC=bar))`, + src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main )Example,OU=Groups,DC=example,DC=foo,DC=bar))`, + err: `ldap: invalid filter string: (&(objectCategory=group)(objectClass=group)(memberOf=CN=Main )Example,OU=Groups,DC=example,DC=foo,DC=bar))`, }, { src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main )Example,OU=Groups,DC=example,DC=foo,DC=bar))`, @@ -328,27 +502,19 @@ func TestParseFilter(t *testing.T) { }, { src: `(sn=Mill*)`, - expecting: `(sn=Mill*)`, - }, - { - src: `(sn=Mi*\ed\95\a8*r)`, - expecting: `(sn=Mi*\ed\95\a8*r)`, + expecting: `(sn=Mill\2a)`, }, { src: `(sn=Mi*함*r)`, - expecting: `(sn=Mi*\ed\95\a8*r)`, - }, - { - src: `\ed\95\a8\ec\88\98\eb\aa\a9\eb\a1\9d`, - expecting: `\ed\95\a8\ec\88\98\eb\aa\a9\eb\a1\9d`, + expecting: `(sn=Mi\2a\ed\95\a8\2ar)`, }, { - src: `(objectGUID=\a)`, - err: `ldap: invalid characters for escape in filter: encoding/hex: invalid byte: U+0029 ')'`, + src: `(objectGUID=\a)`, + expecting: `(objectGUID=\5ca)`, }, { - src: `(objectGUID=\a\a)`, - err: `ldap: invalid characters for escape in filter: encoding/hex: invalid byte: U+005C '\'`, + src: `(objectGUID=\a\a)`, + expecting: `(objectGUID=\5ca\5ca)`, }, } { diff --git a/v3/ldap_test.go b/v3/ldap_test.go index c3245b0e..83d70284 100644 --- a/v3/ldap_test.go +++ b/v3/ldap_test.go @@ -236,11 +236,31 @@ func TestMultiGoroutineSearch(t *testing.T) { } func TestEscapeFilter(t *testing.T) { - if got, want := EscapeFilter("a\x00b(c)d*e\\f"), `a\00b\28c\29d\2ae\5cf`; got != want { - t.Errorf("Got %s, expected %s", want, got) - } - if got, want := EscapeFilter("Lučić"), `Lu\c4\8di\c4\87`; got != want { - t.Errorf("Got %s, expected %s", want, got) + for _, testInfo := range []struct { + src string + expecting string + }{ + { + src: "a\x00b(c)d*e\\f", + expecting: `a\00b\28c\29d\2ae\5cf`, + }, + { + src: "Lučić", + expecting: `Lu\c4\8di\c4\87`, + }, + { + src: `\\some-server\code`, + expecting: `\5c\5csome-server\5ccode`, + }, + { + src: `Mi*함*r`, + expecting: `Mi\2a\ed\95\a8\2ar`, + }, + } { + got := EscapeFilter(testInfo.src) + if got != testInfo.expecting { + t.Errorf("Got %s, expected %s", got, testInfo.expecting) + } } }