Skip to content

Commit

Permalink
Support BM25 scoring for chunk matches
Browse files Browse the repository at this point in the history
  • Loading branch information
jtibshirani committed Jan 14, 2025
1 parent b51a233 commit 7c931a6
Show file tree
Hide file tree
Showing 7 changed files with 474 additions and 204 deletions.
20 changes: 19 additions & 1 deletion api.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ const (
stringHeaderBytes uint64 = 16
pointerSize uint64 = 8
interfaceBytes uint64 = 16
maxUInt16 = 0xffff
)

// FileMatch contains all the matches within a file.
Expand Down Expand Up @@ -135,6 +136,22 @@ func (m *FileMatch) sizeBytes() (sz uint64) {
return
}

// addScore adds computed to m.Score.
//
// When debugScore is true and computed is non-zero, an entry of the form
// "what:computed, " is also appended to m.Debug, with computed printed to two
// decimal places. If raw is anything other than -1, it is included in
// parentheses directly after what, e.g. "what(raw):computed, ".
func (m *FileMatch) addScore(what string, computed float64, raw float64, debugScore bool) {
	m.Score += computed

	// Nothing to report: either debugging is off or the signal contributed nothing.
	if !debugScore || computed == 0 {
		return
	}

	var entry strings.Builder
	entry.WriteString(what)
	if raw != -1 {
		entry.WriteString("(" + strconv.FormatFloat(raw, 'f', -1, 64) + ")")
	}
	fmt.Fprintf(&entry, ":%.2f, ", computed)
	m.Debug += entry.String()
}

// ChunkMatch is a set of non-overlapping matches within a contiguous range of
// lines in the file.
type ChunkMatch struct {
Expand Down Expand Up @@ -976,7 +993,8 @@ type SearchOptions struct {

// EXPERIMENTAL. If true, use text-search style scoring instead of the default
// scoring formula. The scoring algorithm treats each match in a file as a term
// and computes an approximation to BM25.
// and computes an approximation to BM25. When enabled, BM25 scoring is used for
// the overall FileMatch score, as well as individual LineMatch and ChunkMatch scores.
//
// The calculation of IDF assumes that Zoekt visits all documents containing any
// of the query terms during evaluation. This is true, for example, if all query
Expand Down
17 changes: 15 additions & 2 deletions build/scoring_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,21 @@ func TestBM25(t *testing.T) {
language: "Java",
// bm25-score: 1.81 <- sum-termFrequencyScore: 116.00, length-ratio: 1.00
wantScore: 1.81,
// line 3: public class InnerClasses {
wantBestLineMatch: 3,
// line 54: private static <A, B> B runInnerInterface(InnerInterface<A, B> fn, A a) {
wantBestLineMatch: 54,
}, {
// Another content-only match
fileName: "example.java",
query: &query.And{Children: []query.Q{
&query.Substring{Pattern: "system"},
&query.Substring{Pattern: "time"},
}},
content: exampleJava,
language: "Java",
// bm25-score: 0.96 <- sum-termFrequencies: 12, length-ratio: 1.00
wantScore: 0.96,
// line 59: if (System.nanoTime() > System.currentTimeMillis()) {
wantBestLineMatch: 59,
},
{
// Matches only on filename
Expand Down
184 changes: 22 additions & 162 deletions contentprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,10 @@ package zoekt

import (
"bytes"
"fmt"
"log"
"path"
"slices"
"sort"
"strings"
"unicode"
"unicode/utf8"

Expand Down Expand Up @@ -145,7 +143,7 @@ func (p *contentProvider) findOffset(filename bool, r uint32) uint32 {
//
// Note: the byte slices may be backed by mmapped data, so before being
// returned by the API it needs to be copied.
func (p *contentProvider) fillMatches(ms []*candidateMatch, numContextLines int, language string, debug bool) []LineMatch {
func (p *contentProvider) fillMatches(ms []*candidateMatch, numContextLines int, language string, opts *SearchOptions) []LineMatch {
var filenameMatches []*candidateMatch
contentMatches := make([]*candidateMatch, 0, len(ms))

Expand All @@ -160,16 +158,16 @@ func (p *contentProvider) fillMatches(ms []*candidateMatch, numContextLines int,
// If there are any content matches, we only return these and skip filename matches.
if len(contentMatches) > 0 {
contentMatches = breakMatchesOnNewlines(contentMatches, p.data(false))
return p.fillContentMatches(contentMatches, numContextLines, language, debug)
return p.fillContentMatches(contentMatches, numContextLines, language, opts)
}

// Otherwise, we return a single line containing the filename match.
bestMatch, _ := p.candidateMatchScore(filenameMatches, language, debug)
lineScore, _ := p.scoreLine(filenameMatches, language, -1 /* must pass -1 for filenames */, opts)
res := LineMatch{
Line: p.id.fileName(p.idx),
FileName: true,
Score: bestMatch.score,
DebugScore: bestMatch.debugScore,
Score: lineScore.score,
DebugScore: lineScore.debugScore,
}

for _, m := range ms {
Expand All @@ -192,7 +190,7 @@ func (p *contentProvider) fillMatches(ms []*candidateMatch, numContextLines int,
//
// Note: the byte slices may be backed by mmapped data, so before being
// returned by the API it needs to be copied.
func (p *contentProvider) fillChunkMatches(ms []*candidateMatch, numContextLines int, language string, debug bool) []ChunkMatch {
func (p *contentProvider) fillChunkMatches(ms []*candidateMatch, numContextLines int, language string, opts *SearchOptions) []ChunkMatch {
var filenameMatches []*candidateMatch
contentMatches := make([]*candidateMatch, 0, len(ms))

Expand All @@ -206,11 +204,11 @@ func (p *contentProvider) fillChunkMatches(ms []*candidateMatch, numContextLines

// If there are any content matches, we only return these and skip filename matches.
if len(contentMatches) > 0 {
return p.fillContentChunkMatches(contentMatches, numContextLines, language, debug)
return p.fillContentChunkMatches(contentMatches, numContextLines, language, opts)
}

// Otherwise, we return a single chunk representing the filename match.
bestMatch, _ := p.candidateMatchScore(filenameMatches, language, debug)
lineScore, _ := p.scoreLine(filenameMatches, language, -1 /* must pass -1 for filenames */, opts)
fileName := p.id.fileName(p.idx)
ranges := make([]Range, 0, len(ms))
for _, m := range ms {
Expand All @@ -233,12 +231,12 @@ func (p *contentProvider) fillChunkMatches(ms []*candidateMatch, numContextLines
ContentStart: Location{ByteOffset: 0, LineNumber: 1, Column: 1},
Ranges: ranges,
FileName: true,
Score: bestMatch.score,
DebugScore: bestMatch.debugScore,
Score: lineScore.score,
DebugScore: lineScore.debugScore,
}}
}

func (p *contentProvider) fillContentMatches(ms []*candidateMatch, numContextLines int, language string, debug bool) []LineMatch {
func (p *contentProvider) fillContentMatches(ms []*candidateMatch, numContextLines int, language string, opts *SearchOptions) []LineMatch {
var result []LineMatch
for len(ms) > 0 {
m := ms[0]
Expand Down Expand Up @@ -296,16 +294,17 @@ func (p *contentProvider) fillContentMatches(ms []*candidateMatch, numContextLin
finalMatch.After = p.newlines().getLines(data, num+1, num+1+numContextLines)
}

bestMatch, symbolInfo := p.candidateMatchScore(lineCands, language, debug)
finalMatch.Score = bestMatch.score
finalMatch.DebugScore = bestMatch.debugScore
lineScore, symbolInfo := p.scoreLine(lineCands, language, num, opts)
finalMatch.Score = lineScore.score
finalMatch.DebugScore = lineScore.debugScore

for i, m := range lineCands {
fragment := LineFragmentMatch{
Offset: m.byteOffset,
LineOffset: int(m.byteOffset) - lineStart,
MatchLength: int(m.byteMatchSz),
}

if i < len(symbolInfo) && symbolInfo[i] != nil {
fragment.SymbolInfo = symbolInfo[i]
}
Expand All @@ -317,8 +316,7 @@ func (p *contentProvider) fillContentMatches(ms []*candidateMatch, numContextLin
return result
}

func (p *contentProvider) fillContentChunkMatches(ms []*candidateMatch, numContextLines int, language string, debug bool) []ChunkMatch {
newlines := p.newlines()
func (p *contentProvider) fillContentChunkMatches(ms []*candidateMatch, numContextLines int, language string, opts *SearchOptions) []ChunkMatch {
data := p.data(false)

// columnHelper prevents O(len(ms) * len(data)) lookups for all columns.
Expand All @@ -332,11 +330,10 @@ func (p *contentProvider) fillContentChunkMatches(ms []*candidateMatch, numConte
sort.Sort((sortByOffsetSlice)(ms))
}

newlines := p.newlines()
chunks := chunkCandidates(ms, newlines, numContextLines)
chunkMatches := make([]ChunkMatch, 0, len(chunks))
for _, chunk := range chunks {
bestMatch, symbolInfo := p.candidateMatchScore(chunk.candidates, language, debug)

ranges := make([]Range, 0, len(chunk.candidates))
for _, cm := range chunk.candidates {
startOffset := cm.byteOffset
Expand All @@ -363,14 +360,7 @@ func (p *contentProvider) fillContentChunkMatches(ms []*candidateMatch, numConte
}
firstLineStart := newlines.lineStart(firstLineNumber)

bestLineMatch := 0
if bestMatch.match != nil {
bestLineMatch = newlines.atOffset(bestMatch.match.byteOffset)
if debug {
bestMatch.debugScore = fmt.Sprintf("%s, (line: %d)", bestMatch.debugScore, bestLineMatch)
}
}

chunkScore, symbolInfo := p.scoreChunk(chunk.candidates, language, opts)
chunkMatches = append(chunkMatches, ChunkMatch{
Content: newlines.getLines(data, firstLineNumber, int(chunk.lastLine)+numContextLines+1),
ContentStart: Location{
Expand All @@ -381,9 +371,9 @@ func (p *contentProvider) fillContentChunkMatches(ms []*candidateMatch, numConte
FileName: false,
Ranges: ranges,
SymbolInfo: symbolInfo,
BestLineMatch: uint32(bestLineMatch),
Score: bestMatch.score,
DebugScore: bestMatch.debugScore,
BestLineMatch: uint32(chunkScore.bestLine),
Score: chunkScore.score,
DebugScore: chunkScore.debugScore,
})
}
return chunkMatches
Expand All @@ -405,6 +395,7 @@ type candidateChunk struct {
// output invariants: if you flatten candidates the input invariant is retained.
func chunkCandidates(ms []*candidateMatch, newlines newlines, numContextLines int) []candidateChunk {
var chunks []candidateChunk

for _, m := range ms {
startOffset := m.byteOffset
endOffset := m.byteOffset + m.byteMatchSz
Expand Down Expand Up @@ -536,10 +527,6 @@ const (
scoreKindMatch = 100.0
scoreFactorAtomMatch = 400.0

// File-only scoring signals. For now these are also bounded ~9000 to give them
// equal weight with the query-dependent signals.
scoreFileRankFactor = 9000.0

// Used for ordering line and chunk matches within a file.
scoreLineOrderFactor = 1.0

Expand Down Expand Up @@ -643,133 +630,6 @@ func (p *contentProvider) findSymbol(cm *candidateMatch) (DocumentSection, *Symb
return sec, si, true
}

// calculateTermFrequency treats every candidate match as a term (keyed by its
// lowered substring) and returns how often each term occurs in this file.
//
// Weighting:
//   - A match on the filename or on a symbol counts 5 times. This mimics the
//     common text-search strategy of boosting matches on document titles and
//     rewards matches on symbol definitions.
//   - Any other content match counts once.
//
// As a side effect, df is incremented by one for every distinct term found in
// this file.
func (p *contentProvider) calculateTermFrequency(cands []*candidateMatch, df termDocumentFrequency) map[string]int {
	freqs := map[string]int{}
	for _, cand := range cands {
		weight := 1
		if cand.fileName || p.matchesSymbol(cand) {
			weight = 5
		}
		freqs[string(cand.substrLowered)] += weight
	}

	// Each distinct term contributes exactly one document to its document frequency.
	for term := range freqs {
		df[term]++
	}
	return freqs
}

// scoredMatch holds the score information for a candidate match.
type scoredMatch struct {
// score is the numeric score computed for the match.
score float64
// debugScore is a human-readable breakdown of how score was reached. It is
// only filled in when debug output was requested.
debugScore string
// match is the candidate the score belongs to. It stays nil when no
// candidate scored above zero.
match *candidateMatch
}

// candidateMatchScore scores every candidate in ms and returns the
// highest-scoring one along with its score information.
//
// A candidate's score is the sum of:
//   - a word-boundary signal (full or partial word match),
//   - for filename matches, a signal for how the match aligns with the base name,
//   - for content matches that overlap a symbol, a symbol-alignment signal plus
//     a symbol-kind signal when symbol metadata is available.
//
// The sum is then multiplied by the candidate's scoreWeight when that weight
// is not (approximately) one.
//
// The returned slice is nil unless at least one candidate came from a symbol
// tree and has symbol metadata; in that case it has len(ms) entries, aligned
// by index with ms, with nil for candidates that carry no symbol.
//
// When debug is true, the returned debugScore describes how the best score was
// built up. If no candidate scores above zero, the returned match is nil.
//
// Invariant: there should be at least one input candidate, len(ms) > 0.
func (p *contentProvider) candidateMatchScore(ms []*candidateMatch, language string, debug bool) (scoredMatch, []*Symbol) {
	// total and trace hold the running score and debug text of the candidate
	// currently being evaluated; they are reset at the top of each iteration.
	var (
		total float64
		trace string
	)
	add := func(label string, value float64) {
		total += value
		if debug && value != 0 {
			trace += fmt.Sprintf("%s:%.2f, ", label, value)
		}
	}

	fileNameData := p.data(true)

	var (
		best    scoredMatch
		symbols []*Symbol
	)
	for idx, cand := range ms {
		content := p.data(cand.fileName)
		contentLen := uint32(len(content))
		end := cand.byteOffset + cand.byteMatchSz

		// A boundary exists where the byte class changes, or at either edge of the data.
		startsOnBoundary := cand.byteOffset < contentLen && (cand.byteOffset == 0 || byteClass(content[cand.byteOffset-1]) != byteClass(content[cand.byteOffset]))
		endsOnBoundary := end > 0 && (end == contentLen || byteClass(content[end-1]) != byteClass(content[end]))

		total, trace = 0, ""

		switch {
		case startsOnBoundary && endsOnBoundary:
			add("WordMatch", scoreWordMatch)
		case startsOnBoundary || endsOnBoundary:
			add("PartialWordMatch", scorePartialWordMatch)
		}

		if cand.fileName {
			lastSlash := bytes.LastIndexByte(content, '/')
			atBaseStart := int(cand.byteOffset) == lastSlash+1
			atBaseEnd := end == contentLen
			switch {
			case atBaseStart && atBaseEnd:
				add("Base", scoreBase)
			case atBaseStart || atBaseEnd:
				add("EdgeBase", (scoreBase+scorePartialBase)/2)
			case lastSlash < int(cand.byteOffset):
				add("InnerBase", scorePartialBase)
			}
		} else if section, info, found := p.findSymbol(cand); found {
			atSymStart := section.Start == cand.byteOffset
			atSymEnd := section.End == end
			switch {
			case atSymStart && atSymEnd:
				add("Symbol", scoreSymbol)
			case atSymStart || atSymEnd:
				add("EdgeSymbol", (scoreSymbol+scorePartialSymbol)/2)
			default:
				add("OverlapSymbol", scorePartialSymbol)
			}

			// Score based on symbol data
			if info != nil {
				kind := ctags.ParseSymbolKind(info.Kind)
				symText := sectionSlice(content, section)

				add(fmt.Sprintf("kind:%s:%s", language, info.Kind), scoreSymbolKind(language, fileNameData, symText, kind))

				// Candidates that originate from a symbol tree must carry
				// their symbol information back to the caller.
				if cand.symbol {
					if symbols == nil {
						symbols = make([]*Symbol, len(ms))
					}
					// findSymbols does not hydrate in Sym. So we need to store it.
					info.Sym = string(symText)
					symbols[idx] = info
				}
			}
		}

		// scoreWeight != 1 means it affects score
		if !epsilonEqualsOne(cand.scoreWeight) {
			total *= cand.scoreWeight
			if debug {
				trace += fmt.Sprintf("boost:%.2f, ", cand.scoreWeight)
			}
		}

		if total > best.score {
			best = scoredMatch{score: total, debugScore: trace, match: cand}
		}
	}

	if debug {
		best.debugScore = fmt.Sprintf("score:%.2f <- %s", best.score, strings.TrimSuffix(best.debugScore, ", "))
	}

	return best, symbols
}

// sectionSlice will return data[sec.Start:sec.End] but will clip Start and
// End such that it won't be out of range.
func sectionSlice(data []byte, sec DocumentSection) []byte {
Expand Down
4 changes: 2 additions & 2 deletions eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,9 +327,9 @@ nextFileMatch:
finalCands := d.gatherMatches(nextDoc, mt, known, shouldMergeMatches)

if opts.ChunkMatches {
fileMatch.ChunkMatches = cp.fillChunkMatches(finalCands, opts.NumContextLines, fileMatch.Language, opts.DebugScore)
fileMatch.ChunkMatches = cp.fillChunkMatches(finalCands, opts.NumContextLines, fileMatch.Language, opts)
} else {
fileMatch.LineMatches = cp.fillMatches(finalCands, opts.NumContextLines, fileMatch.Language, opts.DebugScore)
fileMatch.LineMatches = cp.fillMatches(finalCands, opts.NumContextLines, fileMatch.Language, opts)
}

var tf map[string]int
Expand Down
Loading

0 comments on commit 7c931a6

Please sign in to comment.