diff --git a/filter.go b/filter.go index 73505e79..4541f435 100644 --- a/filter.go +++ b/filter.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "regexp" "strings" "unicode" "unicode/utf8" @@ -27,6 +28,10 @@ const ( FilterExtensibleMatch = 9 ) +var ( + isAlphaNumeric = regexp.MustCompile(`^[a-zA-Z0-9]+$`).MatchString +) + // FilterMap contains human readable descriptions of Filter choices var FilterMap = map[uint64]string{ FilterAnd: "And", @@ -73,6 +78,16 @@ var MatchingRuleAssertionMap = map[uint64]string{ var _SymbolAny = []byte{'*'} +// CompileEscapeFilter will take a raw filter string, without any \xx hex input, and convert it into a BER-encoded packet +// All filter values will be placed through EscapeFilter, which cannot handle \xx hex input, nor * wildcards (will be translated to literal *) +func CompileEscapeFilter(filter string) (*ber.Packet, error) { + filter, err := ParseFilter(filter) + if err != nil { + return nil, NewError(ErrorFilterCompile, err) + } + return CompileFilter(filter) +} + // CompileFilter converts a string representation of a filter into a BER-encoded packet func CompileFilter(filter string) (*ber.Packet, error) { if len(filter) == 0 || filter[0] != '(' { @@ -485,3 +500,80 @@ func decodeEscapedSymbols(src []byte) (string, error) { offset += runeSize } } + +// ParseFilter will take the filter string, and escape each value +func ParseFilter(filter string) (parsedFilter string, err error) { + var startRecording bool + var value string + var balance []string + for i, val := range filter { + switch string(val) { + case "=": + if !startRecording { + startRecording = true + } else { + // we've run into a 2nd = symbol in the statement. Reset the value + parsedFilter += value + value = "" + } + parsedFilter += string(val) + case ")": + if startRecording { + balance, err = checkBalance("(", filter, balance) + if err != nil { + err = nil + startRecording = false + parsedFilter = fmt.Sprintf("%s%s%s", parsedFilter, EscapeFilter(value), string(val)) + value = "" + } else { + value += string(val) + } + } else { + parsedFilter += string(val) + } + case "(": + if startRecording { + balance = append(balance, "(") + value += string(val) + } else { + parsedFilter += string(val) + } + case ",": + if !startRecording { + return "", fmt.Errorf("ldap: invalid filter string: %s", filter) + } + if len(balance) != 0 { + return "", fmt.Errorf("ldap: unbalanced filter: %s", filter) + } + startRecording = false + parsedFilter = fmt.Sprintf("%s%s%s", parsedFilter, EscapeFilter(value), string(val)) + value = "" + default: + currentRune, _ := utf8.DecodeRuneInString(string(val)) + if currentRune == utf8.RuneError { + return "", fmt.Errorf("ldap: error reading rune at position %d", i) + } + if startRecording { + value += string(val) + } else { + parsedFilter += string(val) + } + } + } + if len(balance) != 0 { + return "", fmt.Errorf("ldap: unbalanced filter: %s", filter) + } + return +} + +// checkBalance will check if a recorded value within the filter is balanced and adjust the balance queue. If not, reset the balance queue and produce an error +func checkBalance(checkElement string, filter string, balance []string) (newBalance []string, err error) { + if len(balance) == 0 { + err = fmt.Errorf("ldap: unbalanced filter string: %s", filter) + } else if !strings.EqualFold(balance[len(balance)-1], checkElement) { + err = fmt.Errorf("ldap: unbalanced filter string: %s", filter) + } else { + newBalance = balance[:len(balance)-1] + } + return +} diff --git a/filter_test.go b/filter_test.go index f67d6c74..95aa1a76 100644 --- a/filter_test.go +++ b/filter_test.go @@ -213,6 +213,185 @@ func TestFilter(t *testing.T) { } } +func TestCompileEscapeFilter(t *testing.T) { + var testEscapeFilters = []compileTest{ + { + filterStr: "(&(sn=Miller)(givenName=Bob))", + expectedFilter: "(&(sn=Miller)(givenName=Bob))", + expectedType: FilterAnd, + }, + { + filterStr: "(|(sn=Miller)(givenName=Bob))", + expectedFilter: "(|(sn=Miller)(givenName=Bob))", + expectedType: FilterOr, + }, + { + filterStr: "(!(sn=Miller))", + expectedFilter: "(!(sn=Miller))", + expectedType: FilterNot, + }, + { + filterStr: "(sn=Miller)", + expectedFilter: "(sn=Miller)", + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn=Mill*)", + expectedFilter: `(sn=Mill\2a)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn=*Mill)", + expectedFilter: `(sn=\2aMill)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn=*Mill*)", + expectedFilter: `(sn=\2aMill\2a)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn=*i*le*)", + expectedFilter: `(sn=\2ai\2ale\2a)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn=Mi*l*r)", + expectedFilter: `(sn=Mi\2al\2ar)`, + expectedType: FilterEqualityMatch, + }, + // substring filters escape properly + { + filterStr: `(sn=Mi*함*r)`, + expectedFilter: `(sn=Mi\2a\ed\95\a8\2ar)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn=Mi*le*)", + expectedFilter: `(sn=Mi\2ale\2a)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn=*i*ler)", + expectedFilter: `(sn=\2ai\2aler)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn>=Miller)", + expectedFilter: "(sn>=Miller)", + expectedType: FilterGreaterOrEqual, + }, + { + filterStr: "(sn<=Miller)", + expectedFilter: "(sn<=Miller)", + expectedType: FilterLessOrEqual, + }, + { + filterStr: "(sn=*)", + expectedFilter: `(sn=\2a)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn~=Miller)", + expectedFilter: "(sn~=Miller)", + expectedType: FilterApproxMatch, + }, + { + filterStr: `(objectGUID=абвгдеёжзийклмнопрстуфхцчшщъыьэюя)`, + expectedFilter: `(objectGUID=\d0\b0\d0\b1\d0\b2\d0\b3\d0\b4\d0\b5\d1\91\d0\b6\d0\b7\d0\b8\d0\b9\d0\ba\d0\bb\d0\bc\d0\bd\d0\be\d0\bf\d1\80\d1\81\d1\82\d1\83\d1\84\d1\85\d1\86\d1\87\d1\88\d1\89\d1\8a\d1\8b\d1\8c\d1\8d\d1\8e\d1\8f)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: `(objectGUID=함수목록)`, + expectedFilter: `(objectGUID=\ed\95\a8\ec\88\98\eb\aa\a9\eb\a1\9d)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: `(objectGUID=`, + expectedFilter: ``, + expectedType: 0, + expectedErr: "unexpected end of filter", + }, + { + filterStr: `(objectGUID=함수목록`, + expectedFilter: ``, + expectedType: 0, + expectedErr: "unexpected end of filter", + }, + { + filterStr: `((cn=)`, + expectedFilter: ``, + expectedType: 0, + expectedErr: "unexpected end of filter", + }, + { + filterStr: `(&(objectclass=inetorgperson)(cn=中文))`, + expectedFilter: `(&(objectclass=inetorgperson)(cn=\e4\b8\ad\e6\96\87))`, + expectedType: 0, + }, + // attr extension + { + filterStr: `(memberOf:=foo)`, + expectedFilter: `(memberOf:=foo)`, + expectedType: FilterExtensibleMatch, + }, + // attr+named matching rule extension + { + filterStr: `(memberOf:test:=foo)`, + expectedFilter: `(memberOf:test:=foo)`, + expectedType: FilterExtensibleMatch, + }, + // attr+oid matching rule extension + { + filterStr: `(cn:1.2.3.4.5:=Fred Flintstone)`, + expectedFilter: `(cn:1.2.3.4.5:=Fred Flintstone)`, + expectedType: FilterExtensibleMatch, + }, + // attr+dn+oid matching rule extension + { + filterStr: `(sn:dn:2.4.6.8.10:=Barney Rubble)`, + expectedFilter: `(sn:dn:2.4.6.8.10:=Barney Rubble)`, + expectedType: FilterExtensibleMatch, + }, + // attr+dn extension + { + filterStr: `(o:dn:=Ace Industry)`, + expectedFilter: `(o:dn:=Ace Industry)`, + expectedType: FilterExtensibleMatch, + }, + // dn extension + { + filterStr: `(:dn:2.4.6.8.10:=Dino)`, + expectedFilter: `(:dn:2.4.6.8.10:=Dino)`, + expectedType: FilterExtensibleMatch, + }, + { + filterStr: `(memberOf:1.2.840.113556.1.4.1941:=CN=User1,OU=blah,DC=mydomain,DC=net)`, + expectedFilter: `(memberOf:1.2.840.113556.1.4.1941:=CN=User1,OU=blah,DC=mydomain,DC=net)`, + expectedType: FilterExtensibleMatch, + }, + } + // Test Compiler and Decompiler + for _, i := range testEscapeFilters { + filter, err := CompileEscapeFilter(i.filterStr) + switch { + case err != nil: + if i.expectedErr == "" || !strings.Contains(err.Error(), i.expectedErr) { + t.Errorf("Problem compiling '%s' - '%v' (expected error to contain '%v')", i.filterStr, err, i.expectedErr) + } + case filter.Tag != ber.Tag(i.expectedType): + t.Errorf("%q Expected %q got %q", i.filterStr, FilterMap[uint64(i.expectedType)], FilterMap[uint64(filter.Tag)]) + default: + o, err := DecompileFilter(filter) + if err != nil { + t.Errorf("Problem compiling %s - %s", i.filterStr, err.Error()) + } else if i.expectedFilter != o { + t.Errorf("%q expected, got %q", i.expectedFilter, o) + } + } + } +} + func TestDecodeEscapedSymbols(t *testing.T) { for _, testInfo := range []struct { @@ -290,3 +469,65 @@ func BenchmarkFilterDecompile(b *testing.B) { DecompileFilter(filters[i%maxIdx]) } } + +func TestParseFilter(t *testing.T) { + for _, testInfo := range []struct { + src string + expecting string + err string + }{ + { + src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main\,OU=Groups,DC=example,DC=foo,DC=bar))`, + expecting: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main\5c,OU=Groups,DC=example,DC=foo,DC=bar))`, + }, + { + src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main\+,OU=Groups,DC=example,DC=foo,DC=bar))`, + expecting: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main\5c+,OU=Groups,DC=example,DC=foo,DC=bar))`, + }, + { + src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main (Example),OU=Groups,DC=example,DC=foo,DC=bar))`, + expecting: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main \28Example\29,OU=Groups,DC=example,DC=foo,DC=bar))`, + }, + { + src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main (Example,OU=Groups,DC=example,DC=foo,DC=bar))`, + err: `ldap: unbalanced filter: (&(objectCategory=group)(objectClass=group)(memberOf=CN=Main (Example,OU=Groups,DC=example,DC=foo,DC=bar))`, + }, + { + src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main )Example,OU=Groups,DC=example,DC=foo,DC=bar))`, + err: `ldap: invalid filter string: (&(objectCategory=group)(objectClass=group)(memberOf=CN=Main )Example,OU=Groups,DC=example,DC=foo,DC=bar))`, + }, + { + src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main )Example,OU=Groups,DC=example,DC=foo,DC=bar))`, + err: `ldap: invalid filter string: (&(objectCategory=group)(objectClass=group)(memberOf=CN=Main )Example,OU=Groups,DC=example,DC=foo,DC=bar))`, + }, + { + src: `(sn=Mill*)`, + expecting: `(sn=Mill\2a)`, + }, + { + src: `(sn=Mi*함*r)`, + expecting: `(sn=Mi\2a\ed\95\a8\2ar)`, + }, + { + src: `(objectGUID=\a)`, + expecting: `(objectGUID=\5ca)`, + }, + { + src: `(objectGUID=\a\a)`, + expecting: `(objectGUID=\5ca\5ca)`, + }, + } { + + res, err := ParseFilter(testInfo.src) + if err != nil { + if err.Error() != testInfo.err { + t.Fatal(testInfo.src, "=> ", err, "!=", testInfo.err) + } + } else if testInfo.err != "" { + t.Fatal(testInfo.src, "=> ", err, "!=", testInfo.err) + } + if res != testInfo.expecting { + t.Fatal(testInfo.expecting, "=> ", "invalid result", res) + } + } +} diff --git a/ldap_test.go b/ldap_test.go index c3245b0e..83d70284 100644 --- a/ldap_test.go +++ b/ldap_test.go @@ -236,11 +236,31 @@ func TestMultiGoroutineSearch(t *testing.T) { } func TestEscapeFilter(t *testing.T) { - if got, want := EscapeFilter("a\x00b(c)d*e\\f"), `a\00b\28c\29d\2ae\5cf`; got != want { - t.Errorf("Got %s, expected %s", want, got) - } - if got, want := EscapeFilter("Lučić"), `Lu\c4\8di\c4\87`; got != want { - t.Errorf("Got %s, expected %s", want, got) + for _, testInfo := range []struct { + src string + expecting string + }{ + { + src: "a\x00b(c)d*e\\f", + expecting: `a\00b\28c\29d\2ae\5cf`, + }, + { + src: "Lučić", + expecting: `Lu\c4\8di\c4\87`, + }, + { + src: `\\some-server\code`, + expecting: `\5c\5csome-server\5ccode`, + }, + { + src: `Mi*함*r`, + expecting: `Mi\2a\ed\95\a8\2ar`, + }, + } { + got := EscapeFilter(testInfo.src) + if got != testInfo.expecting { + t.Errorf("Got %s, expected %s", got, testInfo.expecting) + } } } diff --git a/v3/filter.go b/v3/filter.go index 73505e79..4541f435 100644 --- a/v3/filter.go +++ b/v3/filter.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "regexp" "strings" "unicode" "unicode/utf8" @@ -27,6 +28,10 @@ const ( FilterExtensibleMatch = 9 ) +var ( + isAlphaNumeric = regexp.MustCompile(`^[a-zA-Z0-9]+$`).MatchString +) + // FilterMap contains human readable descriptions of Filter choices var FilterMap = map[uint64]string{ FilterAnd: "And", @@ -73,6 +78,16 @@ var MatchingRuleAssertionMap = map[uint64]string{ var _SymbolAny = []byte{'*'} +// CompileEscapeFilter will take a raw filter string, without any \xx hex input, and convert it into a BER-encoded packet +// All filter values will be placed through EscapeFilter, which cannot handle \xx hex input, nor * wildcards (will be translated to literal *) +func CompileEscapeFilter(filter string) (*ber.Packet, error) { + filter, err := ParseFilter(filter) + if err != nil { + return nil, NewError(ErrorFilterCompile, err) + } + return CompileFilter(filter) +} + // CompileFilter converts a string representation of a filter into a BER-encoded packet func CompileFilter(filter string) (*ber.Packet, error) { if len(filter) == 0 || filter[0] != '(' { @@ -485,3 +500,80 @@ func decodeEscapedSymbols(src []byte) (string, error) { offset += runeSize } } + +// ParseFilter will take the filter string, and escape each value +func ParseFilter(filter string) (parsedFilter string, err error) { + var startRecording bool + var value string + var balance []string + for i, val := range filter { + switch string(val) { + case "=": + if !startRecording { + startRecording = true + } else { + // we've run into a 2nd = symbol in the statement. Reset the value + parsedFilter += value + value = "" + } + parsedFilter += string(val) + case ")": + if startRecording { + balance, err = checkBalance("(", filter, balance) + if err != nil { + err = nil + startRecording = false + parsedFilter = fmt.Sprintf("%s%s%s", parsedFilter, EscapeFilter(value), string(val)) + value = "" + } else { + value += string(val) + } + } else { + parsedFilter += string(val) + } + case "(": + if startRecording { + balance = append(balance, "(") + value += string(val) + } else { + parsedFilter += string(val) + } + case ",": + if !startRecording { + return "", fmt.Errorf("ldap: invalid filter string: %s", filter) + } + if len(balance) != 0 { + return "", fmt.Errorf("ldap: unbalanced filter: %s", filter) + } + startRecording = false + parsedFilter = fmt.Sprintf("%s%s%s", parsedFilter, EscapeFilter(value), string(val)) + value = "" + default: + currentRune, _ := utf8.DecodeRuneInString(string(val)) + if currentRune == utf8.RuneError { + return "", fmt.Errorf("ldap: error reading rune at position %d", i) + } + if startRecording { + value += string(val) + } else { + parsedFilter += string(val) + } + } + } + if len(balance) != 0 { + return "", fmt.Errorf("ldap: unbalanced filter: %s", filter) + } + return +} + +// checkBalance will check if a recorded value within the filter is balanced and adjust the balance queue. If not, reset the balance queue and produce an error +func checkBalance(checkElement string, filter string, balance []string) (newBalance []string, err error) { + if len(balance) == 0 { + err = fmt.Errorf("ldap: unbalanced filter string: %s", filter) + } else if !strings.EqualFold(balance[len(balance)-1], checkElement) { + err = fmt.Errorf("ldap: unbalanced filter string: %s", filter) + } else { + newBalance = balance[:len(balance)-1] + } + return +} diff --git a/v3/filter_test.go b/v3/filter_test.go index f67d6c74..95aa1a76 100644 --- a/v3/filter_test.go +++ b/v3/filter_test.go @@ -213,6 +213,185 @@ func TestFilter(t *testing.T) { } } +func TestCompileEscapeFilter(t *testing.T) { + var testEscapeFilters = []compileTest{ + { + filterStr: "(&(sn=Miller)(givenName=Bob))", + expectedFilter: "(&(sn=Miller)(givenName=Bob))", + expectedType: FilterAnd, + }, + { + filterStr: "(|(sn=Miller)(givenName=Bob))", + expectedFilter: "(|(sn=Miller)(givenName=Bob))", + expectedType: FilterOr, + }, + { + filterStr: "(!(sn=Miller))", + expectedFilter: "(!(sn=Miller))", + expectedType: FilterNot, + }, + { + filterStr: "(sn=Miller)", + expectedFilter: "(sn=Miller)", + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn=Mill*)", + expectedFilter: `(sn=Mill\2a)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn=*Mill)", + expectedFilter: `(sn=\2aMill)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn=*Mill*)", + expectedFilter: `(sn=\2aMill\2a)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn=*i*le*)", + expectedFilter: `(sn=\2ai\2ale\2a)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn=Mi*l*r)", + expectedFilter: `(sn=Mi\2al\2ar)`, + expectedType: FilterEqualityMatch, + }, + // substring filters escape properly + { + filterStr: `(sn=Mi*함*r)`, + expectedFilter: `(sn=Mi\2a\ed\95\a8\2ar)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn=Mi*le*)", + expectedFilter: `(sn=Mi\2ale\2a)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn=*i*ler)", + expectedFilter: `(sn=\2ai\2aler)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn>=Miller)", + expectedFilter: "(sn>=Miller)", + expectedType: FilterGreaterOrEqual, + }, + { + filterStr: "(sn<=Miller)", + expectedFilter: "(sn<=Miller)", + expectedType: FilterLessOrEqual, + }, + { + filterStr: "(sn=*)", + expectedFilter: `(sn=\2a)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: "(sn~=Miller)", + expectedFilter: "(sn~=Miller)", + expectedType: FilterApproxMatch, + }, + { + filterStr: `(objectGUID=абвгдеёжзийклмнопрстуфхцчшщъыьэюя)`, + expectedFilter: `(objectGUID=\d0\b0\d0\b1\d0\b2\d0\b3\d0\b4\d0\b5\d1\91\d0\b6\d0\b7\d0\b8\d0\b9\d0\ba\d0\bb\d0\bc\d0\bd\d0\be\d0\bf\d1\80\d1\81\d1\82\d1\83\d1\84\d1\85\d1\86\d1\87\d1\88\d1\89\d1\8a\d1\8b\d1\8c\d1\8d\d1\8e\d1\8f)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: `(objectGUID=함수목록)`, + expectedFilter: `(objectGUID=\ed\95\a8\ec\88\98\eb\aa\a9\eb\a1\9d)`, + expectedType: FilterEqualityMatch, + }, + { + filterStr: `(objectGUID=`, + expectedFilter: ``, + expectedType: 0, + expectedErr: "unexpected end of filter", + }, + { + filterStr: `(objectGUID=함수목록`, + expectedFilter: ``, + expectedType: 0, + expectedErr: "unexpected end of filter", + }, + { + filterStr: `((cn=)`, + expectedFilter: ``, + expectedType: 0, + expectedErr: "unexpected end of filter", + }, + { + filterStr: `(&(objectclass=inetorgperson)(cn=中文))`, + expectedFilter: `(&(objectclass=inetorgperson)(cn=\e4\b8\ad\e6\96\87))`, + expectedType: 0, + }, + // attr extension + { + filterStr: `(memberOf:=foo)`, + expectedFilter: `(memberOf:=foo)`, + expectedType: FilterExtensibleMatch, + }, + // attr+named matching rule extension + { + filterStr: `(memberOf:test:=foo)`, + expectedFilter: `(memberOf:test:=foo)`, + expectedType: FilterExtensibleMatch, + }, + // attr+oid matching rule extension + { + filterStr: `(cn:1.2.3.4.5:=Fred Flintstone)`, + expectedFilter: `(cn:1.2.3.4.5:=Fred Flintstone)`, + expectedType: FilterExtensibleMatch, + }, + // attr+dn+oid matching rule extension + { + filterStr: `(sn:dn:2.4.6.8.10:=Barney Rubble)`, + expectedFilter: `(sn:dn:2.4.6.8.10:=Barney Rubble)`, + expectedType: FilterExtensibleMatch, + }, + // attr+dn extension + { + filterStr: `(o:dn:=Ace Industry)`, + expectedFilter: `(o:dn:=Ace Industry)`, + expectedType: FilterExtensibleMatch, + }, + // dn extension + { + filterStr: `(:dn:2.4.6.8.10:=Dino)`, + expectedFilter: `(:dn:2.4.6.8.10:=Dino)`, + expectedType: FilterExtensibleMatch, + }, + { + filterStr: `(memberOf:1.2.840.113556.1.4.1941:=CN=User1,OU=blah,DC=mydomain,DC=net)`, + expectedFilter: `(memberOf:1.2.840.113556.1.4.1941:=CN=User1,OU=blah,DC=mydomain,DC=net)`, + expectedType: FilterExtensibleMatch, + }, + } + // Test Compiler and Decompiler + for _, i := range testEscapeFilters { + filter, err := CompileEscapeFilter(i.filterStr) + switch { + case err != nil: + if i.expectedErr == "" || !strings.Contains(err.Error(), i.expectedErr) { + t.Errorf("Problem compiling '%s' - '%v' (expected error to contain '%v')", i.filterStr, err, i.expectedErr) + } + case filter.Tag != ber.Tag(i.expectedType): + t.Errorf("%q Expected %q got %q", i.filterStr, FilterMap[uint64(i.expectedType)], FilterMap[uint64(filter.Tag)]) + default: + o, err := DecompileFilter(filter) + if err != nil { + t.Errorf("Problem compiling %s - %s", i.filterStr, err.Error()) + } else if i.expectedFilter != o { + t.Errorf("%q expected, got %q", i.expectedFilter, o) + } + } + } +} + func TestDecodeEscapedSymbols(t *testing.T) { for _, testInfo := range []struct { @@ -290,3 +469,65 @@ func BenchmarkFilterDecompile(b *testing.B) { DecompileFilter(filters[i%maxIdx]) } } + +func TestParseFilter(t *testing.T) { + for _, testInfo := range []struct { + src string + expecting string + err string + }{ + { + src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main\,OU=Groups,DC=example,DC=foo,DC=bar))`, + expecting: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main\5c,OU=Groups,DC=example,DC=foo,DC=bar))`, + }, + { + src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main\+,OU=Groups,DC=example,DC=foo,DC=bar))`, + expecting: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main\5c+,OU=Groups,DC=example,DC=foo,DC=bar))`, + }, + { + src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main (Example),OU=Groups,DC=example,DC=foo,DC=bar))`, + expecting: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main \28Example\29,OU=Groups,DC=example,DC=foo,DC=bar))`, + }, + { + src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main (Example,OU=Groups,DC=example,DC=foo,DC=bar))`, + err: `ldap: unbalanced filter: (&(objectCategory=group)(objectClass=group)(memberOf=CN=Main (Example,OU=Groups,DC=example,DC=foo,DC=bar))`, + }, + { + src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main )Example,OU=Groups,DC=example,DC=foo,DC=bar))`, + err: `ldap: invalid filter string: (&(objectCategory=group)(objectClass=group)(memberOf=CN=Main )Example,OU=Groups,DC=example,DC=foo,DC=bar))`, + }, + { + src: `(&(objectCategory=group)(objectClass=group)(memberOf=CN=Main )Example,OU=Groups,DC=example,DC=foo,DC=bar))`, + err: `ldap: invalid filter string: (&(objectCategory=group)(objectClass=group)(memberOf=CN=Main )Example,OU=Groups,DC=example,DC=foo,DC=bar))`, + }, + { + src: `(sn=Mill*)`, + expecting: `(sn=Mill\2a)`, + }, + { + src: `(sn=Mi*함*r)`, + expecting: `(sn=Mi\2a\ed\95\a8\2ar)`, + }, + { + src: `(objectGUID=\a)`, + expecting: `(objectGUID=\5ca)`, + }, + { + src: `(objectGUID=\a\a)`, + expecting: `(objectGUID=\5ca\5ca)`, + }, + } { + + res, err := ParseFilter(testInfo.src) + if err != nil { + if err.Error() != testInfo.err { + t.Fatal(testInfo.src, "=> ", err, "!=", testInfo.err) + } + } else if testInfo.err != "" { + t.Fatal(testInfo.src, "=> ", err, "!=", testInfo.err) + } + if res != testInfo.expecting { + t.Fatal(testInfo.expecting, "=> ", "invalid result", res) + } + } +} diff --git a/v3/ldap_test.go b/v3/ldap_test.go index c3245b0e..83d70284 100644 --- a/v3/ldap_test.go +++ b/v3/ldap_test.go @@ -236,11 +236,31 @@ func TestMultiGoroutineSearch(t *testing.T) { } func TestEscapeFilter(t *testing.T) { - if got, want := EscapeFilter("a\x00b(c)d*e\\f"), `a\00b\28c\29d\2ae\5cf`; got != want { - t.Errorf("Got %s, expected %s", want, got) - } - if got, want := EscapeFilter("Lučić"), `Lu\c4\8di\c4\87`; got != want { - t.Errorf("Got %s, expected %s", want, got) + for _, testInfo := range []struct { + src string + expecting string + }{ + { + src: "a\x00b(c)d*e\\f", + expecting: `a\00b\28c\29d\2ae\5cf`, + }, + { + src: "Lučić", + expecting: `Lu\c4\8di\c4\87`, + }, + { + src: `\\some-server\code`, + expecting: `\5c\5csome-server\5ccode`, + }, + { + src: `Mi*함*r`, + expecting: `Mi\2a\ed\95\a8\2ar`, + }, + } { + got := EscapeFilter(testInfo.src) + if got != testInfo.expecting { + t.Errorf("Got %s, expected %s", got, testInfo.expecting) + } } }