Skip to content

Commit

Permalink
allow for new map-like functions (#296)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcMush authored Feb 8, 2024
1 parent 6575743 commit befd9f8
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 76 deletions.
22 changes: 20 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ end
It possible to disable the progress meter when the use is optional.

```julia
x,n = 1,10
x, n = 1, 10
p = Progress(n; enabled = false)
for iter in 1:10
x *= 2
Expand All @@ -431,7 +431,25 @@ In cases where the output is text output such as CI or in an HPC scheduler, the
```julia
is_logging(io) = isa(io, Base.TTY) == false || (get(ENV, "CI", nothing) == "true")
p = Progress(n; output = stderr, enabled = !is_logging(stderr))
````
```

### Adding support for more map-like functions

To add support for other functions, `ProgressMeter.ncalls` must be defined,
where `ncalls_map`, `ncalls_broadcast`, `ncalls_broadcast!` or `ncalls_reduce` can help

For example, with `tmap` from [`ThreadTools.jl`](https://github.com/baggepinnen/ThreadTools.jl):

```julia
using ThreadTools, ProgressMeter

ProgressMeter.ncalls(::typeof(tmap), ::Function, args...) = ProgressMeter.ncalls_map(args...)
ProgressMeter.ncalls(::typeof(tmap), ::Function, ::Int, args...) = ProgressMeter.ncalls_map(args...)

@showprogress tmap(abs2, 1:10^5)
@showprogress tmap(abs2, 4, 1:10^5)
```


## Development/debugging tips

Expand Down
91 changes: 55 additions & 36 deletions src/ProgressMeter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -761,7 +761,7 @@ end
"""
Equivalent of @showprogress for a distributed for loop.
```
result = @showprogress dt "Computing..." @distributed (+) for i = 1:50
result = @showprogress @distributed (+) for i = 1:50
sleep(0.1)
i^2
end
Expand Down Expand Up @@ -863,9 +863,14 @@ displays progress in performing a computation. You may optionally
supply a custom message to be printed that specifies the computation
being performed or other options.
`@showprogress` works for loops, comprehensions, `asyncmap`,
`broadcast`, `broadcast!`, `foreach`, `map`, `mapfoldl`,
`mapfoldr`, `mapreduce`, `pmap` and `reduce`.
`@showprogress` works for loops, comprehensions, and `map`-like
functions. These `map`-like functions rely on `ncalls` being defined
and can be checked with `methods(ProgressMeter.ncalls)`. New ones can
be added by defining `ProgressMeter.ncalls(::typeof(mapfun), args...) = ...`.
`@showprogress` is thread-safe and will work with `@distributed` loops
as well as threaded or distributed functions like `pmap` and `asyncmap`.
"""
macro showprogress(args...)
showprogress(args...)
Expand All @@ -889,8 +894,6 @@ function showprogress(args...)
return expr
end
metersym = gensym("meter")
mapfuns = (:asyncmap, :broadcast, :broadcast!, :foreach, :map,
:mapfoldl, :mapfoldr, :mapreduce, :pmap, :reduce)
kind = :invalid # :invalid, :loop, or :map

if isa(expr, Expr)
Expand All @@ -906,18 +909,18 @@ function showprogress(args...)
outerassignidx = lastindex(expr.args)
loopbodyidx = 2
kind = :loop
elseif expr.head == :call && expr.args[1] in mapfuns
elseif expr.head == :call
kind = :map
elseif expr.head == :do
call = expr.args[1]
if call.head == :call && call.args[1] in mapfuns
if call.head == :call
kind = :map
end
end
end

if kind == :invalid
throw(ArgumentError("Final argument to @showprogress must be a for loop, comprehension, map, reduce, or pmap; got $expr"))
throw(ArgumentError("Final argument to @showprogress must be a for loop, comprehension, or a map-like function; got $expr"))
elseif kind == :loop
# As of julia 0.5, a comprehension's "loop" is actually one level deeper in the syntax tree.
if expr.head !== :for
Expand Down Expand Up @@ -995,7 +998,7 @@ function showprogress(args...)
return isa(a, Symbol) || isa(a, Number) || !(a.head in (:kw, :parameters))
end)
if expr.head == :do
insert!(mapargs, 1, :nothing) # to make args for ncalls line up
insert!(mapargs, 1, identity) # to make args for ncalls line up
end

# change call to progress_map
Expand All @@ -1011,7 +1014,7 @@ function showprogress(args...)
end

# create appropriate Progress expression
lenex = :(ncalls($(esc(mapfun)), ($([esc(a) for a in mapargs]...),)))
lenex = :(ncalls($(esc(mapfun)), $(esc.(mapargs)...)))
progex = :(Progress($lenex, $(showprogress_process_args(progressargs)...)))

# insert progress and mapfun kwargs
Expand All @@ -1028,10 +1031,12 @@ end
Run a `map`-like function while displaying progress.
`mapfun` can be any function, but it is only tested with `map`, `reduce` and `pmap`.
`ProgressMeter.ncalls(::typeof(mapfun), ::Function, args...)` must be defined to
specify the number of calls to `f`.
"""
function progress_map(args...; mapfun=map,
progress=Progress(ncalls(mapfun, args)),
channel_bufflen=min(1000, ncalls(mapfun, args)),
progress=Progress(ncalls(mapfun, args...)),
channel_bufflen=min(1000, ncalls(mapfun, args...)),
kwargs...)
isempty(args) && return mapfun(; kwargs...)
f = first(args)
Expand Down Expand Up @@ -1066,36 +1071,50 @@ Run `pmap` while displaying progress.
progress_pmap(args...; kwargs...) = progress_map(args...; mapfun=pmap, kwargs...)

"""
Infer the number of calls to the mapped function (i.e. the length of the returned array) given the input arguments to map, reduce or pmap.
ProgressMeter.ncalls(::typeof(mapfun), ::Function, args...)
Infer the number of calls to the mapped function (often the length of the returned array)
to define the length of the `Progress` in `@showprogress` and `progress_map`.
Internally uses one of `ncalls_map`, `ncalls_broadcast(!)` or `ncalls_reduce` depending
on the type of `mapfun`.
Support for additional functions can be added by defining
`ProgressMeter.ncalls(::typeof(mapfun), ::Function, args...)`.
"""
function ncalls(::typeof(broadcast), map_args)
length(map_args) < 2 && return 1
return prod(length, Broadcast.combine_axes(map_args[2:end]...))
end
ncalls(::typeof(map), ::Function, args...) = ncalls_map(args...)
ncalls(::typeof(map!), ::Function, args...) = ncalls_map(args...)
ncalls(::typeof(foreach), ::Function, args...) = ncalls_map(args...)
ncalls(::typeof(asyncmap), ::Function, args...) = ncalls_map(args...)

function ncalls(::typeof(broadcast!), map_args)
length(map_args) < 2 && return 1
return length(map_args[2])
end
ncalls(::typeof(pmap), ::Function, args...) = ncalls_map(args...)
ncalls(::typeof(pmap), ::Function, ::AbstractWorkerPool, args...) = ncalls_map(args...)

function ncalls(::Union{typeof(mapreduce),typeof(mapfoldl),typeof(mapfoldr)}, map_args)
length(map_args) < 3 && return 1
return minimum(length, map_args[3:end])
ncalls(::typeof(mapfoldl), ::Function, ::Function, args...) = ncalls_map(args...)
ncalls(::typeof(mapfoldr), ::Function, ::Function, args...) = ncalls_map(args...)
ncalls(::typeof(mapreduce), ::Function, ::Function, args...) = ncalls_map(args...)

ncalls(::typeof(broadcast), ::Function, args...) = ncalls_broadcast(args...)
ncalls(::typeof(broadcast!), ::Function, args...) = ncalls_broadcast!(args...)

ncalls(::typeof(foldl), ::Function, arg) = ncalls_reduce(arg)
ncalls(::typeof(foldr), ::Function, arg) = ncalls_reduce(arg)
ncalls(::typeof(reduce), ::Function, arg) = ncalls_reduce(arg)

ncalls_reduce(arg) = length(arg) - 1

function ncalls_broadcast(args...)
length(args) < 1 && return 1
return prod(length, Broadcast.combine_axes(args...))
end

function ncalls(::typeof(pmap), map_args)
if length(map_args) 2 && map_args[2] isa AbstractWorkerPool
length(map_args) < 3 && return 1
return minimum(length, map_args[3:end])
else
length(map_args) < 2 && return 1
return minimum(length, map_args[2:end])
end
function ncalls_broadcast!(args...)
length(args) < 1 && return 1
return length(args[1])
end

function ncalls(mapfun::Function, map_args)
length(map_args) < 2 && return 1
return minimum(length, map_args[2:end])
function ncalls_map(args...)
length(args) < 1 && return 1
return minimum(length, args)
end

include("deprecated.jl")
Expand Down
90 changes: 52 additions & 38 deletions test/test_map.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,28 +54,52 @@ wp = WorkerPool(procs)
println()

# test ncalls
@test ncalls(map, (+, 1:10)) == 10
@test ncalls(pmap, (+, 1:10, 1:100)) == 10
@test ncalls(pmap, (+, wp, 1:10)) == 10
@test ncalls(reduce, (+, 1:10)) == 10
@test ncalls(mapreduce, (+, +, 1:10, (1:10)')) == 10
@test ncalls(mapfoldl, (+, +, 1:10, (1:10)')) == 10
@test ncalls(mapfoldr, (+, +, 1:10, (1:10)')) == 10
@test ncalls(foreach, (+, 1:10)) == 10
@test ncalls(broadcast, (+, 1:10, 1:10)) == 10
@test ncalls(broadcast, (+, 1:8, (1:7)', 1)) == 8*7
@test ncalls(broadcast, (+, 1:3, (1:5)', ones(1,1,2))) == 3*5*2
@test ncalls(broadcast!, (+, zeros(10,8))) == 80
@test ncalls(broadcast!, (+, zeros(10,8,7), 1:10)) == 10*8*7

@test ncalls(map, (time,)) == 1
@test ncalls(foreach, (time,)) == 1
@test ncalls(broadcast, (time,)) == 1
@test ncalls(broadcast!, (time, [1])) == 1
@test ncalls(mapreduce, (time, +)) == 1

@test_throws DimensionMismatch ncalls(broadcast, (+, 1:10, 1:100))
@test_throws DimensionMismatch ncalls(broadcast, (+, 1:100, 1:10))
@test ncalls(map, +, 1:10) == 10
@test ncalls(pmap, +, 1:10, 1:100) == 10
@test ncalls(pmap, +, wp, 1:10) == 10
@test ncalls(foldr, +, 1:10) == 9
@test ncalls(foldl, +, 1:10) == 9
@test ncalls(reduce, +, 1:10) == 9
@test ncalls(mapreduce, +, +, 1:10, (1:10)') == 10
@test ncalls(mapfoldl, +, +, 1:10, (1:10)') == 10
@test ncalls(mapfoldr, +, +, 1:10, (1:10)') == 10
@test ncalls(foreach, +, 1:10) == 10
@test ncalls(broadcast, +, 1:10, 1:10) == 10
@test ncalls(broadcast, +, 1:8, (1:7)', 1) == 8*7
@test ncalls(broadcast, +, 1:3, (1:5)', ones(1,1,2)) == 3*5*2
@test ncalls(broadcast!, +, zeros(10,8,7), 1:10) == 10*8*7

# functions with no args
# map(f) and foreach(f) were removed (#291)
@test ncalls(broadcast, time) == 1
@test ncalls(broadcast!, time, [1]) == 1
@test ncalls(broadcast!, time, zeros(10,8)) == 80
@test ncalls(mapreduce, time, +) == 1

@test_throws DimensionMismatch ncalls(broadcast, +, 1:10, 1:100)
@test_throws DimensionMismatch ncalls(broadcast, +, 1:100, 1:10)

@test_throws MethodError ncalls(map, 1:10, 1:10)
@test_throws MethodError @showprogress map(1:10, 1:10)

# test custom mapfun
mymap(f, x) = map(f, [x ; x])
@test_throws MethodError ncalls(mymap, +, 1:10)
@test_throws MethodError @showprogress mymap(+, 1:10)

ProgressMeter.ncalls(::typeof(mymap), ::Function, args...) = 2*ProgressMeter.ncalls_map(args...)
@test ncalls(mymap, +, 1:10) == 20

println("Testing custom map")
vals = @showprogress mymap(1:10) do x
sleep(0.1)
return x^2
end
@test vals == map(x->x^2, [1:10; 1:10])

println("Testing custom map with kwarg (color red)")
vals = @showprogress color=:red mymap(x->(sleep(0.1); x^2), 1:10)
@test vals == map(x->x^2, [1:10; 1:10])

# @showprogress
vals = @showprogress map(1:10) do x
Expand Down Expand Up @@ -137,9 +161,7 @@ wp = WorkerPool(procs)
return x
end
@test A == repeat(1:10, 1, 8)




# function passed by name
function testfun(x)
return x^2
Expand Down Expand Up @@ -172,7 +194,6 @@ wp = WorkerPool(procs)
@test broadcast(constfun) == @showprogress broadcast(constfun)
#@test mapreduce(constfun, error) == @showprogress mapreduce(constfun, error) # julia 1.2+


# #136: make sure mid progress shows up even without sleep
println("Verify that intermediate progress is displayed:")
@showprogress map(1:100) do i
Expand All @@ -184,41 +205,34 @@ wp = WorkerPool(procs)
vals = @showprogress pmap((x,y)->x*y, 1:10, 2:11)
@test vals == map((x,y)->x*y, 1:10, 2:11)







# Progress args
vals = @showprogress dt=0.1 desc="Computing" pmap(testfun, 1:10)
@test vals == map(testfun, 1:10)



# named vector arg
a = collect(1:10)
vals = @showprogress pmap(x->x^2, a)
@test vals == map(x->x^2, a)



# global variable in do
C = 10
vals = @showprogress pmap(1:10) do x
return C*x
end
@test vals == map(x->C*x, 1:10)



# keyword arguments
vals = @showprogress pmap(x->x^2, 1:100, batch_size=10)
@test vals == map(x->x^2, 1:100)
# with semicolon
vals = @showprogress pmap(x->x^2, 1:100; batch_size=10)
@test vals == map(x->x^2, 1:100)

A = rand(0:999, 7, 11, 13)
vals = @showprogress mapreduce(abs2, +, A; dims=1, init=0)
@test vals == mapreduce(abs2, +, A; dims=1, init=0)
vals = @showprogress mapfoldl(abs2, -, A; init=1)
@test vals == mapfoldl(abs2, -, A; init=1)

# pipes after map
@showprogress map(testfun, 1:10) |> sum |> length
Expand Down

0 comments on commit befd9f8

Please sign in to comment.