-
Notifications
You must be signed in to change notification settings - Fork 549
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(audio): integrate audio transfromers
- Loading branch information
Showing
11 changed files
with
490 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union | ||
|
||
from outlines.generate.api import GenerationParameters, SamplingParameters | ||
from outlines.models import Transformers | ||
|
||
if TYPE_CHECKING: | ||
from outlines.processors import OutlinesLogitsProcessor | ||
|
||
|
||
class TransformersAudio(Transformers): | ||
def __init__(self, model, tokenizer, processor): | ||
super().__init__(model, tokenizer) | ||
self.processor = processor | ||
|
||
def generate( # type: ignore | ||
self, | ||
prompts: Union[str, List[str]], | ||
media: Union[List[Any], List[List[Any]]], | ||
generation_parameters: GenerationParameters, | ||
logits_processor: Optional["OutlinesLogitsProcessor"], | ||
sampling_parameters: SamplingParameters, | ||
) -> Union[str, List[str], List[List[str]]]: | ||
"""Generate text using `transformers`. | ||
Arguments | ||
--------- | ||
prompts | ||
A prompt or list of prompts. | ||
media | ||
A List[numpy.ndarray] or List[List[numpy.ndarray]] | ||
generation_parameters | ||
An instance of `GenerationParameters` that contains the prompt, | ||
the maximum number of tokens, stop sequences and seed. All the | ||
arguments to `SequenceGeneratorAdapter`'s `__cal__` method. | ||
logits_processor | ||
The logits processor to use when generating text. | ||
sampling_parameters | ||
An instance of `SamplingParameters`, a dataclass that contains | ||
the name of the sampler to use and related parameters as available | ||
in Outlines. | ||
Returns | ||
------- | ||
The generated text | ||
""" | ||
inputs = self.processor( | ||
text=prompts, audios=media, padding=True, return_tensors="pt" | ||
).to(self.model.device) | ||
|
||
generation_kwargs = self._get_generation_kwargs( | ||
prompts, | ||
generation_parameters, | ||
logits_processor, | ||
sampling_parameters, | ||
) | ||
generated_ids = self._generate_output_seq(prompts, inputs, **generation_kwargs) | ||
|
||
# if single str input and single sample per input, convert to a 1D output | ||
if isinstance(prompts, str): | ||
# Should always be true until NotImplementedError above is fixed | ||
generated_ids = generated_ids.squeeze(0) | ||
|
||
return self._decode_generation(generated_ids) | ||
|
||
def stream( # type: ignore | ||
self, | ||
prompts: Union[str, List[str]], | ||
media: Union[Any, List[Any]], # TODO: docstring | ||
generation_parameters: GenerationParameters, | ||
logits_processor: Optional["OutlinesLogitsProcessor"], | ||
sampling_parameters: SamplingParameters, | ||
) -> Iterator[Union[str, List[str]]]: | ||
raise NotImplementedError | ||
|
||
|
||
def transformers_audio( | ||
model_name: str, | ||
model_class, | ||
device: Optional[str] = None, | ||
model_kwargs: dict = {}, | ||
processor_kwargs: dict = {}, | ||
tokenizer_class=None, | ||
processor_class=None, | ||
): | ||
"""Instantiate a model from the `transformers` library and its tokenizer. | ||
Parameters | ||
---------- | ||
model_name | ||
The name of the model as listed on Hugging Face's model page. | ||
model_class | ||
The `PreTrainedModel` class from transformers to use in initializing the vision model from `model_name`. | ||
https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel | ||
device | ||
The device(s) on which the model should be loaded. This overrides | ||
the `device_map` entry in `model_kwargs` when provided. | ||
model_kwargs | ||
A dictionary that contains the keyword arguments to pass to the | ||
`from_pretrained` method when loading the model. | ||
processor_kwargs | ||
A dictionary that contains the keyword arguments to pass to the | ||
`from_pretrained` method when loading the processor. | ||
Returns | ||
------- | ||
A `TransformersModel` model instance. | ||
""" | ||
if processor_class is None or tokenizer_class is None: | ||
try: | ||
from transformers import AutoProcessor, AutoTokenizer | ||
except ImportError: | ||
raise ImportError( | ||
"The `transformers` library needs to be installed in order to use `transformers` models." | ||
) | ||
if processor_class is None: | ||
processor_class = AutoProcessor | ||
|
||
if device is not None: | ||
model_kwargs["device_map"] = device | ||
|
||
model = model_class.from_pretrained(model_name, **model_kwargs) | ||
|
||
processor_kwargs.setdefault("padding_side", "left") | ||
processor_kwargs.setdefault("pad_token", "[PAD]") | ||
processor = processor_class.from_pretrained(model_name, **processor_kwargs) | ||
|
||
if tokenizer_class is None: | ||
if getattr(processor, "tokenizer", None): | ||
tokenizer = processor.tokenizer | ||
else: | ||
tokenizer = AutoTokenizer.from_pretrained(model_name, **processor_kwargs) | ||
else: | ||
tokenizer = tokenizer_class.from_pretrained(model_name, **processor_kwargs) | ||
|
||
return TransformersAudio(model, tokenizer, processor) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.