Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix GuideLogitsProcessor for MPS device #1306

Merged
merged 3 commits into from
Jan 2, 2025

Conversation

hoesler
Copy link
Contributor

@hoesler hoesler commented Dec 2, 2024

While debugging #1282, I found the the issue to be caused by non_blocking=True on a Mac with an MPS device.

The usage of non_blocking=True is only safe for CPU->GPU, as far as I understand from this guide.
For other directions, especially CPU->MPS in my case, this results in a all-zero vector instead of the real token ids which generates wrong tokens and results in errors like described in #1282.
I also tried to use torch.mps.synchronize(), but it doesn't help.

I'm haven't benchmarked the difference in speed, but I suspect it to be neglectable because the created vector is accessed directly afterwards.

Fixes #1282

@cpfiffer
Copy link
Contributor

cpfiffer commented Dec 2, 2024

I think we'd want to see some benchmarks + tests on this one for sure -- I'm worried that this might have some unanticipated consequences.

I'm curious if it's also possible to specify this by a keyword argument somewhere.

@hoesler
Copy link
Contributor Author

hoesler commented Dec 3, 2024

I think we'd want to see some benchmarks + tests on this one for sure -- I'm worried that this might have some unanticipated consequences.

As I cant't do a benchmark comparison on my MPS machine, would you mind doing that? Which tests in addition to the existing ones do you imagine?
I think the unwanted consequences are what we see now by enabling that feature for potential performance benefits. Reading the linked article, I guess blocking mode is the default for a reason.
Maybe @lapp0 can comment on that, as he introduced this code in #1192.

I'm curious if it's also possible to specify this by a keyword argument somewhere.

Do you suggest here to run different code depending on the device in use or by a user setting?

@lapp0
Copy link
Contributor

lapp0 commented Dec 4, 2024

I believe this was introduced before MLX was integrated. Good find. Could you see what this does to the benchmarks on your Apple Silicon machine?

@dylanjcastillo
Copy link

This also fixes #1316

@hoesler
Copy link
Contributor Author

hoesler commented Dec 4, 2024

@lapp0 I obviously can't compare blocking vs non_blocking but with the fix, this is what I get:

asv run --bench bench_processors.LogitsProcessorStructuredBenchmark.time_structured_generation main..fix-guide-logits-processor-mps
Couldn't load asv.plugins._mamba_helpers because
No module named 'libmambapy'
· Creating environments
· Discovering benchmarks
· Running 1 total benchmarks (1 commits * 1 environments * 1 benchmarks)
[ 0.00%] · For Outlines commit aa333be9 <fix-guide-logits-processor-mps>:
[ 0.00%] ·· Benchmarking virtualenv-py3.11
[50.00%] ··· Running (bench_processors.LogitsProcessorStructuredBenchmark.time_structured_generation--).
[100.00%] ··· bench_processors.LogitsProcessorStructuredBenchmark.time_structured_generation                                                                                                                        ok
[100.00%] ··· ======================== ============ ============
              --                                 param2
              ------------------------ -------------------------
               array_library, pattern     [^Z]*          Z*
              ======================== ============ ============
                       torch             16.5±1ms     106±10μs
                       numpy             17.5±1ms     120±10μs
                        mlx             19.1±0.5ms    185±10μs
                     torch_mps          18.6±0.3ms   1.73±0.1ms
              ======================== ============ ============

I don't know what to make of the overall significant difference between patterns and especially the huge slowdown for torch_mps for pattern Z*. The latter could be related to MPS performing worse if the number of dims is small (see pytorch/pytorch#77799).

@hoesler hoesler changed the title Fix outlines for MPS device Fix GuideLogitsProcessor for MPS device Dec 4, 2024
@hoesler
Copy link
Contributor Author

hoesler commented Dec 4, 2024

Another idea: Wouldn't it be even better to remove the necessity of device synchronization by creating the allowed_tokens tensors directly on the device?

        with torch.device(mask.device):
            for i, guide_state in enumerate(sequence_states):
                allowed_tokens = self.guide.get_next_instruction(guide_state).tokens
                allowed_tokens_batch.append(allowed_tokens)
                batch_indices.append(
                    torch.full_like(allowed_tokens, i)
                )  # Store batch index for each allowed token

This seems to be the recommended approach.

@hoesler
Copy link
Contributor Author

hoesler commented Dec 5, 2024

The following seems to perform even better. Concat the indexing tensors in cpu and move only these to target device:

        allowed_tokens = self.guide.get_next_instruction(guide_state).tokens
        ....
        allowed_tokens_concat = torch.cat(allowed_tokens_batch).to(logits.device)
        batch_indices_concat = torch.cat(batch_indices).to(logits.device)
              ------------------------ -----------------------
               array_library, pattern     [^Z]*         Z*
              ======================== ============ ==========
                       torch             17.4±2ms    90.4±6μs
                       numpy            16.6±0.9ms   89.4±5μs
                        mlx             18.4±0.7ms   139±9μs
                     torch_mps          19.0±0.9ms   775±30μs
              ======================== ============ ==========

This also removes the need to use non-blocking syncs. Will push that.

@rlouf rlouf merged commit 6a8612b into dottxt-ai:main Jan 2, 2025
6 checks passed
@rlouf
Copy link
Member

rlouf commented Jan 2, 2025

Thank you for the fix ⭐

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Cookbook "Receipt Data Extraction with VLMs" is not reproducible
5 participants