Reenable OpenAI Tokenizer #3062

Merged 6 commits on Nov 8, 2024

backend/danswer/db/engine.py (2 changes: 1 addition & 1 deletion)
@@ -335,7 +335,7 @@ def get_session_with_tenant(
     engine = get_sqlalchemy_engine()

     # Store the previous tenant ID
-    previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
+    previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or POSTGRES_DEFAULT_SCHEMA

     if tenant_id is None:
         tenant_id = previous_tenant_id
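
Context for the change above: a `ContextVar` created with a `None` default returns `None` until a tenant sets it, so the `or POSTGRES_DEFAULT_SCHEMA` fallback keeps downstream code from ever using `None` as a schema name. A minimal sketch of that behavior, assuming the contextvar defaults to `None` and a default schema of `"public"` (the real definitions live in Danswer's config):

```python
from contextvars import ContextVar

# Assumed stand-ins for illustration; the real values come from Danswer's config.
POSTGRES_DEFAULT_SCHEMA = "public"
CURRENT_TENANT_ID_CONTEXTVAR: ContextVar[str | None] = ContextVar(
    "current_tenant_id", default=None
)

# With no tenant set, .get() returns None and `or` falls back to the default schema.
previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or POSTGRES_DEFAULT_SCHEMA
assert previous_tenant_id == "public"

# Once a tenant is set in this context, its ID wins.
CURRENT_TENANT_ID_CONTEXTVAR.set("tenant_abc")
previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or POSTGRES_DEFAULT_SCHEMA
assert previous_tenant_id == "tenant_abc"
```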

backend/danswer/natural_language_processing/utils.py (76 changes: 51 additions & 25 deletions)
@@ -35,23 +35,31 @@ def decode(self, tokens: list[int]) -> str:
 class TiktokenTokenizer(BaseTokenizer):
     _instances: dict[str, "TiktokenTokenizer"] = {}

-    def __new__(cls, encoding_name: str = "cl100k_base") -> "TiktokenTokenizer":
-        if encoding_name not in cls._instances:
-            cls._instances[encoding_name] = super(TiktokenTokenizer, cls).__new__(cls)
-        return cls._instances[encoding_name]
+    def __new__(cls, model_name: str) -> "TiktokenTokenizer":
+        if model_name not in cls._instances:
+            cls._instances[model_name] = super(TiktokenTokenizer, cls).__new__(cls)
+        return cls._instances[model_name]

-    def __init__(self, encoding_name: str = "cl100k_base"):
+    def __init__(self, model_name: str):
         if not hasattr(self, "encoder"):
             import tiktoken

-            self.encoder = tiktoken.get_encoding(encoding_name)
+            self.encoder = tiktoken.encoding_for_model(model_name)

     def encode(self, string: str) -> list[int]:
-        # this returns no special tokens
+        # this ignores special tokens that the model is trained on, see encode_ordinary for details
         return self.encoder.encode_ordinary(string)

     def tokenize(self, string: str) -> list[str]:
-        return [self.encoder.decode([token]) for token in self.encode(string)]
+        encoded = self.encode(string)
+        decoded = [self.encoder.decode([token]) for token in encoded]
+
+        if len(decoded) != len(encoded):
+            logger.warning(
+                f"OpenAI tokenized length {len(decoded)} does not match encoded length {len(encoded)} for string: {string}"
+            )
+
+        return decoded

     def decode(self, tokens: list[int]) -> str:
         return self.encoder.decode(tokens)
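
The key change in this hunk is `tiktoken.encoding_for_model(model_name)` in place of a hard-coded `get_encoding("cl100k_base")`, so each OpenAI model resolves its own encoding, while `__new__` keeps one cached tokenizer instance per model name. A quick sketch of the underlying tiktoken calls (assumes the `tiktoken` package is installed and recognizes the model name):

```python
import tiktoken

# Map a model name to its encoding; "text-embedding-3-small" currently
# resolves to cl100k_base.
enc = tiktoken.encoding_for_model("text-embedding-3-small")

# encode_ordinary skips special tokens, matching TiktokenTokenizer.encode.
tokens = enc.encode_ordinary("hello world")
assert enc.decode(tokens) == "hello world"

# tokenize() above builds its string list by decoding one token at a time.
# Multi-byte characters may decode to replacement characters per token; the
# class logs a warning if the decoded count ever diverges from the token count.
pieces = [enc.decode([t]) for t in tokens]
print(pieces)  # e.g. ['hello', ' world']
```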
@@ -74,22 +82,35 @@ def decode(self, tokens: list[int]) -> str:
         return self.encoder.decode(tokens)


-_TOKENIZER_CACHE: dict[str, BaseTokenizer] = {}
+_TOKENIZER_CACHE: dict[tuple[EmbeddingProvider | None, str | None], BaseTokenizer] = {}


-def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
+def _check_tokenizer_cache(
+    model_provider: EmbeddingProvider | None, model_name: str | None
+) -> BaseTokenizer:
     global _TOKENIZER_CACHE

-    if tokenizer_name not in _TOKENIZER_CACHE:
-        if tokenizer_name == "openai":
-            _TOKENIZER_CACHE[tokenizer_name] = TiktokenTokenizer("cl100k_base")
-            return _TOKENIZER_CACHE[tokenizer_name]
+    id_tuple = (model_provider, model_name)
+
+    if id_tuple not in _TOKENIZER_CACHE:
+        if model_provider in [EmbeddingProvider.OPENAI, EmbeddingProvider.AZURE]:
+            if model_name is None:
+                raise ValueError(
+                    "model_name is required for OPENAI and AZURE embeddings"
+                )
+
+            _TOKENIZER_CACHE[id_tuple] = TiktokenTokenizer(model_name)
+            return _TOKENIZER_CACHE[id_tuple]

         try:
-            logger.debug(f"Initializing HuggingFaceTokenizer for: {tokenizer_name}")
-            _TOKENIZER_CACHE[tokenizer_name] = HuggingFaceTokenizer(tokenizer_name)
+            if model_name is None:
+                model_name = DOCUMENT_ENCODER_MODEL
+
+            logger.debug(f"Initializing HuggingFaceTokenizer for: {model_name}")
+            _TOKENIZER_CACHE[id_tuple] = HuggingFaceTokenizer(model_name)
         except Exception as primary_error:
             logger.error(
-                f"Error initializing HuggingFaceTokenizer for {tokenizer_name}: {primary_error}"
+                f"Error initializing HuggingFaceTokenizer for {model_name}: {primary_error}"
             )
             logger.warning(
                 f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}"
@@ -98,18 +119,18 @@ def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
             try:
                 # Cache this tokenizer name to the default so we don't have to try to load it again
                 # and fail again
-                _TOKENIZER_CACHE[tokenizer_name] = HuggingFaceTokenizer(
+                _TOKENIZER_CACHE[id_tuple] = HuggingFaceTokenizer(
                     DOCUMENT_ENCODER_MODEL
                 )
             except Exception as fallback_error:
                 logger.error(
                     f"Error initializing fallback HuggingFaceTokenizer: {fallback_error}"
                 )
                 raise ValueError(
-                    f"Failed to initialize tokenizer for {tokenizer_name} and fallback model"
+                    f"Failed to initialize tokenizer for {model_name} and fallback model"
                 ) from fallback_error

-    return _TOKENIZER_CACHE[tokenizer_name]
+    return _TOKENIZER_CACHE[id_tuple]


 _DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
@@ -118,11 +139,16 @@ def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
 def get_tokenizer(
     model_name: str | None, provider_type: EmbeddingProvider | str | None
 ) -> BaseTokenizer:
-    # Currently all of the viable models use the same sentencepiece tokenizer
-    # OpenAI uses a different one but currently it's not supported due to quality issues
-    # the inconsistent chunking makes using the sentencepiece tokenizer default better for now
-    # LLM tokenizers are specified by strings
     global _DEFAULT_TOKENIZER
     if provider_type is not None:
+        if isinstance(provider_type, str):
+            try:
+                provider_type = EmbeddingProvider(provider_type)
+            except ValueError:
+                logger.debug(
+                    f"Invalid provider_type '{provider_type}'. Falling back to default tokenizer."
+                )
+                return _DEFAULT_TOKENIZER
+        return _check_tokenizer_cache(provider_type, model_name)
     return _DEFAULT_TOKENIZER
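
Net effect of the utils.py changes: the cache key is now a `(provider, model_name)` tuple, so OpenAI/Azure tiktoken tokenizers and HuggingFace tokenizers for different models coexist without collisions, and `get_tokenizer` coerces string provider types into the enum before the lookup. A sketch of just that coercion logic, with a stub enum standing in for Danswer's `EmbeddingProvider`:

```python
from enum import Enum


# Stub for illustration; the real EmbeddingProvider enum lives in Danswer.
class EmbeddingProvider(str, Enum):
    OPENAI = "openai"
    AZURE = "azure"


def coerce_provider(
    provider_type: EmbeddingProvider | str | None,
) -> EmbeddingProvider | None:
    # Mirrors get_tokenizer: valid strings become enum members, and invalid
    # strings fall back to None (i.e., the default tokenizer path).
    if isinstance(provider_type, str):
        try:
            return EmbeddingProvider(provider_type)
        except ValueError:
            return None
    return provider_type


assert coerce_provider("openai") is EmbeddingProvider.OPENAI
assert coerce_provider("not-a-provider") is None
assert coerce_provider(None) is None
```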

web/src/app/admin/embeddings/modals/ChangeCredentialsModal.tsx (18 changes: 9 additions & 9 deletions)
@@ -11,6 +11,7 @@ import {
   LLM_PROVIDERS_ADMIN_URL,
 } from "../../configuration/llm/constants";
 import { mutate } from "swr";
+import { testEmbedding } from "../pages/utils";

 export function ChangeCredentialsModal({
   provider,
@@ -112,16 +113,15 @@ export function ChangeCredentialsModal({
     const normalizedProviderType = provider.provider_type
       .toLowerCase()
       .split(" ")[0];
+
     try {
-      const testResponse = await fetch("/api/admin/embedding/test-embedding", {
-        method: "POST",
-        headers: { "Content-Type": "application/json" },
-        body: JSON.stringify({
-          provider_type: normalizedProviderType,
-          api_key: apiKey,
-          api_url: apiUrl,
-          model_name: modelName,
-        }),
+      const testResponse = await testEmbedding({
+        provider_type: normalizedProviderType,
+        modelName,
+        apiKey,
+        apiUrl,
+        apiVersion: null,
+        deploymentName: null,
       });

       if (!testResponse.ok) {

web/src/app/admin/embeddings/modals/ProviderCreationModal.tsx (23 changes: 15 additions & 8 deletions)
@@ -110,20 +110,27 @@ export function ProviderCreationModal({
     setErrorMsg("");
     try {
       const customConfig = Object.fromEntries(values.custom_config);
+      const providerType = values.provider_type.toLowerCase().split(" ")[0];
+      const isOpenAI = providerType === "openai";
+
+      const testModelName =
+        isOpenAI || isAzure ? "text-embedding-3-small" : values.model_name;
+
+      const testEmbeddingPayload = {
+        provider_type: providerType,
+        api_key: values.api_key,
+        api_url: values.api_url,
+        model_name: testModelName,
+        api_version: values.api_version,
+        deployment_name: values.deployment_name,
+      };

       const initialResponse = await fetch(
         "/api/admin/embedding/test-embedding",
         {
           method: "POST",
           headers: { "Content-Type": "application/json" },
-          body: JSON.stringify({
-            provider_type: values.provider_type.toLowerCase().split(" ")[0],
-            api_key: values.api_key,
-            api_url: values.api_url,
-            model_name: values.model_name,
-            api_version: values.api_version,
-            deployment_name: values.deployment_name,
-          }),
+          body: JSON.stringify(testEmbeddingPayload),
         }
       );

web/src/app/admin/embeddings/pages/utils.ts (34 changes: 34 additions & 0 deletions)
@@ -8,3 +8,37 @@ export const deleteSearchSettings = async (search_settings_id: number) => {
   });
   return response;
 };
+
+export const testEmbedding = async ({
+  provider_type,
+  modelName,
+  apiKey,
+  apiUrl,
+  apiVersion,
+  deploymentName,
+}: {
+  provider_type: string;
+  modelName: string;
+  apiKey: string | null;
+  apiUrl: string | null;
+  apiVersion: string | null;
+  deploymentName: string | null;
+}) => {
+  const testModelName =
+    provider_type === "openai" ? "text-embedding-3-small" : modelName;
+
+  const testResponse = await fetch("/api/admin/embedding/test-embedding", {
+    method: "POST",
+    headers: { "Content-Type": "application/json" },
+    body: JSON.stringify({
+      provider_type: provider_type,
+      api_key: apiKey,
+      api_url: apiUrl,
+      model_name: testModelName,
+      api_version: apiVersion,
+      deployment_name: deploymentName,
+    }),
+  });
+
+  return testResponse;
+};
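
For reference, the request this shared helper sends can be reproduced outside the frontend. A minimal sketch in Python, assuming a deployment reachable at `localhost:8080` and placeholder credentials (the host, port, and API key are assumptions, and `requests` is a third-party package):

```python
import requests

# Mirrors the JSON body built by testEmbedding; OpenAI providers are pinned
# to "text-embedding-3-small" for the connectivity test.
payload = {
    "provider_type": "openai",
    "api_key": "sk-placeholder",  # placeholder, not a real key
    "api_url": None,
    "model_name": "text-embedding-3-small",
    "api_version": None,
    "deployment_name": None,
}

resp = requests.post(
    "http://localhost:8080/api/admin/embedding/test-embedding",  # assumed host/port
    json=payload,
    timeout=30,
)
print(resp.status_code, resp.ok)
```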