Skip to content

Commit

Permalink
return indices
Browse files Browse the repository at this point in the history
  • Loading branch information
matthieugomez committed Dec 12, 2019
1 parent c5f5df5 commit 538c379
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 33 deletions.
9 changes: 5 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,19 @@ compare(::AbstractString, ::AbstractString, ::PreMetric = TokenMax(Levenshtein()
```

## Find
`find_best` returns the element of an iterator with the highest similarity score
`find_best` returns the index of the element with the highest similarity score.
It returns `nothing` if all elements have a similarity score below `min_score` (which defaults to 0.0).
```julia
find_best("New York", ["NewYork", "Newark", "San Francisco"], Levenshtein())
#> "NewYork"
#> 1
```

`find_all` returns all the elements of an iterator with a similarity score higher than a minimum value (default to 0.8)
`find_all` returns the indices of the elements with a similarity score greater than or equal to a minimum value, `min_score` (which defaults to 0.8).

```julia
find_all("New York", ["NewYork", "Newark", "San Francisco"], Levenshtein(); min_score = 0.8)
#> 1-element Array{String,1}:
#> "NewYork"
#> [1]
```

While these functions are defined for any distance, they are particularly optimized for `Levenshtein` and `DamerauLevenshtein` distances (as well as their modifications via `Partial`, `TokenSort`, `TokenSet`, or `TokenMax`)
Expand Down
2 changes: 1 addition & 1 deletion benchmark/.sublime2Terminal.jl
Original file line number Diff line number Diff line change
@@ -1 +1 @@
# Timing one-liner from the benchmark scratch file. The commit header reports
# one addition and one deletion here, so these two lines are presumably the
# replaced and the replacing version of the same single line.
# `x` and `y` are not defined in this view (assumed to be string collections
# created elsewhere in the benchmark setup; TODO confirm).
@time find_best(x[1], y, Levenshtein())
@time find_all(x[1], y, TokenMax(DamerauLevenshtein()))
45 changes: 21 additions & 24 deletions src/find.jl
Original file line number Diff line number Diff line change
@@ -1,44 +1,41 @@
"""
find_best(s1::AbstractString, iter, dist::PreMetric; min_score = 0.0)
find_best(s::AbstractString, iter::AbstractVector, dist::PreMetric; min_score = 0.0)
`find_best` returns the element of the iterator `iter` that has the highest similarity score with `s1` according to the distance `dist`. Return nothing if all elements have a similarity score below `min_score`.
`find_best` returns the index of the element of `iter` that has the highest similarity score with `s` according to the distance `dist`.
It returns nothing if all elements have a similarity score below `min_score` (default to 0.0)
The function is optimized for `Levenshtein` and `DamerauLevenshtein` distances (potentially modified by `Partial`, `TokenSort`, `TokenSet`, or `TokenMax`)
"""
function find_best(s1::AbstractString, iter_s2, dist::PreMetric; min_score = 0.0)
function find_best(s::AbstractString, iter::AbstractVector, dist::PreMetric; min_score = 0.0)
# NOTE(review): this span is a rendered diff. Lines of the pre-change version
# (names `s1`, `iter_s2`, `best_s2s`, `best_scores`; returns the matching
# element) are interleaved with lines of the post-change version (names `s`,
# `iter`, `is`, `scores`; returns the matching index). It is not runnable as
# shown. The comments below describe what each visible line does.

# Guard on `min_score`. NOTE(review): the condition accepts 0 although the
# message says "positive", and a bare String is thrown rather than an
# exception type such as ArgumentError.
min_score >= 0 || throw("min_score should be positive")
# Per-thread result slots, one entry per thread, indexed by
# `Threads.threadid()` inside the loop.
best_s2s = AbstractString["" for _ in 1:Threads.nthreads()]
best_scores = [-1.0 for _ in 1:Threads.nthreads()]
# `is` holds the candidate index per thread; 0 marks "nothing recorded".
is = [0 for _ in 1:Threads.nthreads()]
# `scores` holds the candidate score per thread; -1.0 marks "nothing recorded".
scores = [-1.0 for _ in 1:Threads.nthreads()]
# Threshold shared by all threads. It is raised as better scores are found and
# passed to `compare` as its `min_score` on every call.
min_score_atomic = Threads.Atomic{typeof(min_score)}(min_score)
Threads.@threads for s2 in iter_s2
score = compare(s1, s2, dist; min_score = min_score_atomic[])
Threads.@threads for i in 1:length(iter)
score = compare(s, iter[i], dist; min_score = min_score_atomic[])
# `atomic_max!` stores max(current, score) and returns the previous value,
# so the comparison below is against the threshold before this update.
min_score_atomic_old = Threads.atomic_max!(min_score_atomic, score)
if score >= min_score_atomic_old
best_s2s[Threads.threadid()] = s2
best_scores[Threads.threadid()] = score
score == 1.0 && return s2
# NOTE(review): this `return` is inside a `Threads.@threads` loop body, so it
# presumably leaves only the current thread's work rather than `find_best`
# itself. It also runs before `is`/`scores` are written on the next two lines,
# so a perfect match may never be recorded. Verify against the
# `Threads.@threads` documentation.
score == 1.0 && return i
is[Threads.threadid()] = i
scores[Threads.threadid()] = score
end
end
# Pick the thread slot holding the highest recorded score.
i = argmax(best_scores)
if best_scores[i] < 0
return nothing
else
return best_s2s[i]
end
i = argmax(scores)
# A slot index of 0 means no element reached the threshold: return `nothing`.
is[i] == 0 ? nothing : is[i]
end


"""
find_all(s1::AbstractString, iter, dist::PreMetric; min_score = 0.8)
`find_all` returns the vector with all the elements of `iter` that have a similarity score higher or equal than `min_score` according to the distance `dist`.
find_all(s::AbstractString, iter::AbstractVector, dist::PreMetric; min_score = 0.8)
`find_all` returns the vector of indices for elements of `iter` that have a similarity score higher or equal than `min_score` according to the distance `dist`.
The function is optimized for `Levenshtein` and `DamerauLevenshtein` distances (potentially modified by `Partial`, `TokenSort`, `TokenSet`, or `TokenMax`)
"""
function find_all(s1::AbstractString, iter_s2, dist::PreMetric; min_score = 0.8)
# NOTE(review): this span is a rendered diff. Lines of the pre-change version
# (names `s1`, `iter_s2`, `best_s2s`; collects matching elements) are
# interleaved with lines of the post-change version (names `s`, `iter`, `out`;
# collects matching indices). It is not runnable as shown.

# One accumulator vector per thread, indexed by `Threads.threadid()`.
best_s2s = [eltype(iter_s2)[] for _ in 1:Threads.nthreads()]
Threads.@threads for s2 in iter_s2
score = compare(s1, s2, dist; min_score = min_score)
function find_all(s::AbstractString, iter::AbstractVector, dist::PreMetric; min_score = 0.8)
out = [Int[] for _ in 1:Threads.nthreads()]
Threads.@threads for i in 1:length(iter)
# `min_score` is forwarded to `compare` and then used again as the keep
# threshold below.
score = compare(s, iter[i], dist; min_score = min_score)
if score >= min_score
push!(best_s2s[Threads.threadid()], s2)
push!(out[Threads.threadid()], i)
end
end
# Concatenate the per-thread vectors into a single result.
# NOTE(review): results are grouped by thread, so whether the output keeps
# the input order depends on how `@threads` assigns iterations; confirm before
# relying on it. Splatting into `vcat` could be written `reduce(vcat, out)`.
vcat(best_s2s...)
vcat(out...)
end
10 changes: 6 additions & 4 deletions test/modifiers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,12 @@ using StringDistances, Test
end

# check find_best and find_all
# NOTE(review): rendered diff. The first four assertions expect matching
# elements (strings), while the six after them expect matching indices and
# add `min_score = 0.99` cases that expect no match.
@test find_best("New York", ["NewYork", "Newark", "San Francisco"], Levenshtein()) == "NewYork"
@test find_best("New York", ["NewYork", "Newark", "San Francisco"], Jaro()) == "NewYork"
@test find_all("New York", ["NewYork", "Newark", "San Francisco"], Levenshtein()) == ["NewYork"]
@test find_all("New York", ["NewYork", "Newark", "San Francisco"], Jaro()) == ["NewYork", "Newark"]
@test find_best("New York", ["NewYork", "Newark", "San Francisco"], Levenshtein()) == 1
# NOTE(review): `== nothing` works here, but `=== nothing` or `isnothing(...)`
# is the idiomatic test for `nothing`.
@test find_best("New York", ["NewYork", "Newark", "San Francisco"], Levenshtein(); min_score = 0.99) == nothing
@test find_best("New York", ["NewYork", "Newark", "San Francisco"], Jaro()) == 1
@test find_all("New York", ["NewYork", "Newark", "San Francisco"], Levenshtein()) == [1]
@test find_all("New York", ["NewYork", "Newark", "San Francisco"], Jaro()) == [1, 2]
@test find_all("New York", ["NewYork", "Newark", "San Francisco"], Jaro(); min_score = 0.99) == Int[]

end

Expand Down

0 comments on commit 538c379

Please sign in to comment.