
feat: EMLX compiled mode #68

Open · wants to merge 13 commits into base: main
Conversation

polvalente (Collaborator) commented Dec 12, 2024

This PR uses Cocoa's nif_call library to bridge the gap and implement a proper Nx.Defn compiler for EMLX.

closes #61

Benchmark of backend [EMLX, EXLA] against compiler [self, Nx.Defn.Evaluator]

Mix.install([{:emlx, path: __DIR__}, :benchee, :exla])

defmodule ToBench do
  def run(backend, compiler) do
    prev_backend = Nx.default_backend(backend)

    fun = fn x, y ->
      Enum.reduce(1..10, x, fn x, acc ->
        Nx.dot(x, acc)
        |> Nx.add(y)
      end)
    end

    x = Nx.iota({10, 10}, type: :f32, backend: backend)
    y = Nx.iota({10, 10}, type: :f32, backend: backend)

    Nx.Defn.jit_apply(fun, [x, y], compiler: compiler)
    |> tap(fn _ -> Nx.default_backend(prev_backend) end)
  end
end

Benchee.run(%{
  "EMLX (evaluator)" => fn -> ToBench.run(EMLX.Backend, Nx.Defn.Evaluator) end,
  "EMLX" => fn -> ToBench.run(EMLX.Backend, EMLX) end,
  "EMLX gpu" => fn -> ToBench.run({EMLX.Backend, device: :gpu}, EMLX) end,
  "EXLA (evaluator)" => fn -> ToBench.run(EXLA.Backend, Nx.Defn.Evaluator) end,
  "EXLA" => fn -> ToBench.run(EXLA.Backend, EXLA) end
})
Name                       ips        average  deviation         median         99th %
EXLA                   37.97 K       26.34 μs    ±49.15%          22 μs       94.71 μs
EMLX                   16.64 K       60.08 μs    ±43.45%       52.88 μs      204.13 μs
EMLX gpu               16.11 K       62.07 μs    ±61.75%       53.29 μs      206.79 μs
EMLX (evaluator)        3.48 K      287.49 μs    ±30.55%      280.96 μs      421.05 μs
EXLA (evaluator)        3.17 K      315.41 μs    ±28.33%      296.25 μs      540.18 μs

Comparison: 
EXLA                   37.97 K
EMLX                   16.64 K - 2.28x slower +33.75 μs
EMLX gpu               16.11 K - 2.36x slower +35.73 μs
EMLX (evaluator)        3.48 K - 10.92x slower +261.16 μs
EXLA (evaluator)        3.17 K - 11.98x slower +289.08 μs

polvalente self-assigned this Dec 12, 2024
return output_tensors;
};

emlx::function compiled_function_ptr = mlx::core::compile(fun);
polvalente (Collaborator, Author):
@awni I got things working via the idea we discussed yesterday! ~5x speed up vs the non-compiled version as per the benchmark in the PR description.

I do wonder what XLA is doing to get such a performance boost when compiled.

Reply:

Cool, so you figured out how to pass an elixir function into C/C++?

polvalente (Collaborator, Author):

It's not exactly that; it goes through the newer library I mentioned (nif_call).

We're calling back into Elixir (in an asynchronous manner, mind you) and then executing the Nx AST with the tracer params provided by mlx::core::compile.

It's not the ideal solution, as it forces copying data between processes, but since the majority of that data consists of pointers or references, it works well enough.
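To make the mechanism concrete, here is a minimal trace-and-replay sketch of how a compile step like mlx::core::compile can work conceptually: the user function is called once with tracer placeholders that record an op graph, and the recorded graph is then replayed on real inputs. This is a pure-Python illustration with made-up names, not EMLX's or MLX's actual implementation.

```python
class Tracer:
    """Placeholder value that records the operations applied to it."""

    def __init__(self, ops, ref):
        self.ops = ops  # shared op list: the recorded "graph"
        self.ref = ref  # how to recompute this value during replay

    def _binop(self, name, other):
        other_ref = other.ref if isinstance(other, Tracer) else ("const", other)
        out = Tracer(self.ops, ("op", len(self.ops)))
        self.ops.append((name, self.ref, other_ref))
        return out

    def __add__(self, other):
        return self._binop("add", other)

    def __mul__(self, other):
        return self._binop("mul", other)


def compile_fn(fun, n_args):
    """Trace `fun` once with placeholders, return a replayable function."""
    ops = []
    tracers = [Tracer(ops, ("arg", i)) for i in range(n_args)]
    result = fun(*tracers)  # trace: run the user function once

    def replay(*args):
        values = []

        def resolve(ref):
            kind, payload = ref
            if kind == "arg":
                return args[payload]
            if kind == "const":
                return payload
            return values[payload]  # previously computed op output

        for name, a, b in ops:
            x, y = resolve(a), resolve(b)
            values.append(x + y if name == "add" else x * y)
        return resolve(result.ref)

    return replay


compiled = compile_fn(lambda x, y: x * y + x, 2)
# compiled(3, 4) -> 3 * 4 + 3 == 15
```

The key point, matching the discussion above, is that the user function runs only once (here, in Elixir, via the asynchronous nif_call round-trip) with tracer parameters; subsequent invocations only replay the recorded graph.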

polvalente requested a review from cocoa-xu December 12, 2024 11:54
Development

Successfully merging this pull request may close these issues.

Use mlx::compile for the Nx.Defn compiler