Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move generation of sampling params #1340

Merged
merged 3 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 7 additions & 21 deletions outlines/generate/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union

from outlines.generate.generator import sequence_generator
from outlines.samplers import BeamSearchSampler, GreedySampler, MultinomialSampler

if TYPE_CHECKING:
import torch
Expand Down Expand Up @@ -353,11 +352,13 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]:
]

generated_sequences = [
self.format_sequence(
self.strip_stop_sequences(sequence, stop_sequences)
(
self.format_sequence(
self.strip_stop_sequences(sequence, stop_sequences)
)
if stop
else sequence
)
if stop
else sequence
for sequence, stop in zip(
generated_sequences, is_stop_at_reached
)
Expand Down Expand Up @@ -428,22 +429,7 @@ def __init__(self, model, logits_processor, sampler):
self.model = model
self.logits_processor = logits_processor

if isinstance(sampler, MultinomialSampler):
self.sampling_params = SamplingParameters(
"multinomial",
sampler.samples,
sampler.top_p,
sampler.top_k,
sampler.temperature,
)
elif isinstance(sampler, GreedySampler):
self.sampling_params = SamplingParameters(
"greedy", sampler.samples, None, None, 0.0
)
elif isinstance(sampler, BeamSearchSampler):
self.sampling_params = SamplingParameters(
"beam_search", sampler.samples, None, None, 1.0
)
self.sampling_params = sampler.sampling_params

def prepare_generation_parameters(
self,
Expand Down
30 changes: 30 additions & 0 deletions outlines/samplers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Optional, Protocol, Tuple

if TYPE_CHECKING:
Expand All @@ -17,6 +18,17 @@ def __call__(
...


@dataclass(frozen=True)
class SamplingParameters:
    """Sampling parameters available in Outlines.

    Immutable value object describing how a sampler draws tokens; each
    sampler exposes one via its ``sampling_params`` property.
    """

    # Name of the sampling algorithm: "greedy", "multinomial" or "beam_search".
    sampler: str
    # Number of sequences sampled in parallel (the beam count for beam search).
    num_samples: int = 1
    # Nucleus (top-p) cutoff; None disables it.
    top_p: Optional[float] = None
    # Top-k cutoff; None disables it.
    top_k: Optional[int] = None
    # Softmax temperature; greedy reports 0.0 and beam search 1.0.
    temperature: Optional[float] = None


class GreedySampler:
"""Greedy Sampling algorithm.

Expand Down Expand Up @@ -76,6 +88,10 @@ def __call__(

return next_token_ids, ancestors, weights

@property
def sampling_params(self):
    """Return the ``SamplingParameters`` describing greedy decoding.

    Greedy decoding applies no top-p/top-k truncation and reports a
    temperature of 0.0.
    """
    return SamplingParameters(
        sampler="greedy",
        num_samples=self.samples,
        top_p=None,
        top_k=None,
        temperature=0.0,
    )


greedy = GreedySampler

Expand Down Expand Up @@ -161,6 +177,16 @@ def __call__(

return next_token_ids, ancestors, weights

@property
def sampling_params(self):
    """Return the ``SamplingParameters`` mirroring this sampler's settings.

    All fields are taken directly from the attributes configured on the
    multinomial sampler instance.
    """
    return SamplingParameters(
        sampler="multinomial",
        num_samples=self.samples,
        top_p=self.top_p,
        top_k=self.top_k,
        temperature=self.temperature,
    )


multinomial = MultinomialSampler

Expand Down Expand Up @@ -320,5 +346,9 @@ def __call__(

return next_token_ids, ancestors, weights

@property
def sampling_params(self):
    """Return the ``SamplingParameters`` describing beam search.

    Beam search applies no top-p/top-k truncation and reports a
    temperature of 1.0; ``num_samples`` carries the beam count.
    """
    return SamplingParameters(
        sampler="beam_search",
        num_samples=self.samples,
        top_p=None,
        top_k=None,
        temperature=1.0,
    )


beam_search = BeamSearchSampler
23 changes: 23 additions & 0 deletions tests/test_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@ def test_greedy():
assert ancestors.equal(torch.tensor([0, 1]))
assert weights.equal(torch.tensor([logprobs[0, 0], logprobs[1, 2]]))

params = sampler.sampling_params
assert params.sampler == "greedy"
assert params.num_samples == 1
assert params.top_p is None
assert params.top_k is None
assert params.temperature == 0.0


def test_multinomial():
rng = torch.Generator()
Expand All @@ -72,6 +79,14 @@ def test_multinomial():
assert ancestors.equal(torch.tensor([0, 1]))
assert weights.equal(torch.tensor([logprobs[0, 0], logprobs[1, 2]]))

sampler = MultinomialSampler(samples=5, top_k=10, top_p=0.9, temperature=0.8)
params = sampler.sampling_params
assert params.sampler == "multinomial"
assert params.num_samples == 5
assert params.top_p == 0.9
assert params.top_k == 10
assert params.temperature == 0.8


def test_multinomial_init():
sampler = MultinomialSampler()
Expand Down Expand Up @@ -252,3 +267,11 @@ def test_beam_search():
]
)
)

sampler = BeamSearchSampler(beams=3)
params = sampler.sampling_params
assert params.sampler == "beam_search"
assert params.num_samples == 3
assert params.top_p is None
assert params.top_k is None
assert params.temperature == 1.0
Loading