diff --git a/nx/lib/nx/serving.ex b/nx/lib/nx/serving.ex index acb0523448..ab0a524bd7 100644 --- a/nx/lib/nx/serving.ex +++ b/nx/lib/nx/serving.ex @@ -1117,7 +1117,7 @@ defmodule Nx.Serving do end defp distributed_batched_run_with_retries!(name, input, retries) do - case :pg.get_members(Nx.Serving.PG, __MODULE__) do + case :pg.get_members(Nx.Serving.PG, name) do [] -> exit({:noproc, {__MODULE__, :distributed_batched_run, [name, input, [retries: retries]]}}) @@ -1332,7 +1332,7 @@ defmodule Nx.Serving do ) serving_weight = max(1, weight * partitions_count) - :pg.join(Nx.Serving.PG, __MODULE__, List.duplicate(self(), serving_weight)) + :pg.join(Nx.Serving.PG, name, List.duplicate(self(), serving_weight)) for batch_key <- batch_keys do stack_init(batch_key) diff --git a/nx/test/nx/serving_test.exs b/nx/test/nx/serving_test.exs index f6aa54368d..49d72573f0 100644 --- a/nx/test/nx/serving_test.exs +++ b/nx/test/nx/serving_test.exs @@ -1288,7 +1288,8 @@ defmodule Nx.ServingTest do ] Node.spawn_link(:"secondary@127.0.0.1", DistributedServings, :multiply, [parent, opts]) - assert_receive {_, :join, Nx.Serving, _} + assert_receive {_, :join, name, _} + assert name == config.test batch = Nx.Batch.concatenate([Nx.tensor([1, 2])]) @@ -1327,14 +1328,16 @@ defmodule Nx.ServingTest do opts2 = Keyword.put(opts, :distribution_weight, 4) Node.spawn_link(:"secondary@127.0.0.1", DistributedServings, :multiply, [parent, opts]) - assert_receive {_, :join, Nx.Serving, pids} + assert_receive {_, :join, name, pids} assert length(pids) == 1 + assert name == config.test Node.spawn_link(:"tertiary@127.0.0.1", DistributedServings, :multiply, [parent, opts2]) - assert_receive {_, :join, Nx.Serving, pids} + assert_receive {_, :join, name, pids} assert length(pids) == 4 + assert name == config.test - members = :pg.get_members(Nx.Serving.PG, Nx.Serving) + members = :pg.get_members(Nx.Serving.PG, config.test) assert length(members) == 5 end @@ -1356,7 +1359,8 @@ defmodule Nx.ServingTest do args = [parent, opts] Node.spawn_link(:"secondary@127.0.0.1", DistributedServings, :add_five_round_about, args) - assert_receive {_, :join, Nx.Serving, _} + assert_receive {_, :join, name, _} + assert name == config.test batch = Nx.Batch.concatenate([Nx.tensor([1, 2])]) @@ -1412,7 +1416,8 @@ defmodule Nx.ServingTest do ] Node.spawn_link(:"tertiary@127.0.0.1", DistributedServings, :multiply, [parent, opts]) - assert_receive {_, :join, Nx.Serving, _} + assert_receive {_, :join, name, _} + assert name == config.test batch = Nx.Batch.concatenate([Nx.tensor([1, 2])])