Skip to content

Commit

Permalink
Use serving name as pg group name (#1566)
Browse files Browse the repository at this point in the history
Co-authored-by: Steffen Deusch <[email protected]>
  • Loading branch information
SteffenDE and Steffen Deusch authored Dec 12, 2024
1 parent b2fdb9a commit f0b3f10
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
4 changes: 2 additions & 2 deletions nx/lib/nx/serving.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1117,7 +1117,7 @@ defmodule Nx.Serving do
end

defp distributed_batched_run_with_retries!(name, input, retries) do
case :pg.get_members(Nx.Serving.PG, __MODULE__) do
case :pg.get_members(Nx.Serving.PG, name) do
[] ->
exit({:noproc, {__MODULE__, :distributed_batched_run, [name, input, [retries: retries]]}})

Expand Down Expand Up @@ -1332,7 +1332,7 @@ defmodule Nx.Serving do
)

serving_weight = max(1, weight * partitions_count)
:pg.join(Nx.Serving.PG, __MODULE__, List.duplicate(self(), serving_weight))
:pg.join(Nx.Serving.PG, name, List.duplicate(self(), serving_weight))

for batch_key <- batch_keys do
stack_init(batch_key)
Expand Down
17 changes: 11 additions & 6 deletions nx/test/nx/serving_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -1288,7 +1288,8 @@ defmodule Nx.ServingTest do
]

Node.spawn_link(:"[email protected]", DistributedServings, :multiply, [parent, opts])
assert_receive {_, :join, Nx.Serving, _}
assert_receive {_, :join, name, _}
assert name == config.test

batch = Nx.Batch.concatenate([Nx.tensor([1, 2])])

Expand Down Expand Up @@ -1327,14 +1328,16 @@ defmodule Nx.ServingTest do
opts2 = Keyword.put(opts, :distribution_weight, 4)

Node.spawn_link(:"[email protected]", DistributedServings, :multiply, [parent, opts])
assert_receive {_, :join, Nx.Serving, pids}
assert_receive {_, :join, name, pids}
assert length(pids) == 1
assert name == config.test

Node.spawn_link(:"[email protected]", DistributedServings, :multiply, [parent, opts2])
assert_receive {_, :join, Nx.Serving, pids}
assert_receive {_, :join, name, pids}
assert length(pids) == 4
assert name == config.test

members = :pg.get_members(Nx.Serving.PG, Nx.Serving)
members = :pg.get_members(Nx.Serving.PG, config.test)
assert length(members) == 5
end

Expand All @@ -1356,7 +1359,8 @@ defmodule Nx.ServingTest do

args = [parent, opts]
Node.spawn_link(:"[email protected]", DistributedServings, :add_five_round_about, args)
assert_receive {_, :join, Nx.Serving, _}
assert_receive {_, :join, name, _}
assert name == config.test

batch = Nx.Batch.concatenate([Nx.tensor([1, 2])])

Expand Down Expand Up @@ -1412,7 +1416,8 @@ defmodule Nx.ServingTest do
]

Node.spawn_link(:"[email protected]", DistributedServings, :multiply, [parent, opts])
assert_receive {_, :join, Nx.Serving, _}
assert_receive {_, :join, name, _}
assert name == config.test

batch = Nx.Batch.concatenate([Nx.tensor([1, 2])])

Expand Down

0 comments on commit f0b3f10

Please sign in to comment.