This repository has been archived by the owner on Feb 12, 2024. It is now read-only.

Add full SciPhi interface #106

Merged · 1 commit · Oct 30, 2023
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -7,7 +7,7 @@ write_to = "sciphi/_version.py"

[tool.poetry]
name = "sciphi"
version = "0.1.4"
version = "0.1.5"
description = "SciPhi: A Framework for LLM Powered Data."
authors = ["Owen Colegrove <[email protected]>"]
license = "Apache-2.0"
2 changes: 2 additions & 0 deletions sciphi/interface/__init__.py
@@ -10,6 +10,7 @@
from sciphi.interface.llm.llama_index_interface import LlamaIndexInterface
from sciphi.interface.llm.llamacpp_interface import LlamaCPPInterface
from sciphi.interface.llm.openai_interface import OpenAILLMInterface
from sciphi.interface.llm.sciphi_interface import SciPhiInterface
from sciphi.interface.llm.vllm_interface import vLLMInterface
from sciphi.interface.llm_interface_manager import LLMInterfaceManager
from sciphi.interface.rag.sciphi_wiki import (
@@ -31,6 +32,7 @@
"vLLMInterface",
"LiteLLMInterface",
"LlamaCPPInterface",
"SciPhiInterface",
# RAG
"RAGInterfaceManager",
"RAGProviderConfig",
2 changes: 1 addition & 1 deletion sciphi/interface/base.py
@@ -50,7 +50,7 @@ class RAGProviderConfig(ABC):

rag_provider_name: RAGProviderName
base: str
token: str
api_key: str
max_context: int = 2_048


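The `RAGProviderConfig` base class renames its credential field from `token` to `api_key`, matching the naming used by the new SciPhi LLM config later in this diff. A minimal sketch of a concrete config after the rename (the subclass name and values below are hypothetical; only the field names come from the diff):

```python
# Hypothetical concrete config illustrating the renamed field; only the
# field names (base, api_key, max_context) are taken from the diff.
from dataclasses import dataclass


@dataclass
class ExampleRAGProviderConfig:
    base: str = "http://localhost:8000"
    api_key: str = "my-rag-key"  # previously named `token`
    max_context: int = 2_048
```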
42 changes: 42 additions & 0 deletions sciphi/interface/llm/sciphi_interface.py
@@ -0,0 +1,42 @@
"""A module for interfacing with local vLLM models"""
import logging
from typing import List

from sciphi.interface.base import LLMInterface, LLMProviderName
from sciphi.interface.llm_interface_manager import llm_interface
from sciphi.llm import SciPhiConfig, SciPhiLLM

logger = logging.getLogger(__name__)


@llm_interface
class SciPhiInterface(LLMInterface):
"""A class to interface with local vLLM models."""

llm_provider_name = LLMProviderName.SCIPHI

def __init__(
self,
config: SciPhiConfig = SciPhiConfig(),
) -> None:
self._model = SciPhiLLM(config)

def get_completion(self, prompt: str) -> str:
"""Get a completion from the local vLLM provider."""

logger.debug(
f"Requesting completion from local vLLM with model={self._model.config.model_name} and prompt={prompt}"
)
return self.model.get_instruct_completion(prompt)

def get_batch_completion(self, prompts: List[str]) -> List[str]:
"""Get a completion from the local vLLM provider."""

logger.debug(
f"Requesting completion from local vLLM with model={self._model.config.model_name} and prompts={prompts}"
)
return self.model.get_batch_instruct_completion(prompts)

@property
def model(self) -> SciPhiLLM:
return self._model
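The new `SciPhiInterface` registers itself with the LLM interface manager via the `@llm_interface` decorator and delegates all completion calls to `SciPhiLLM`; note that the docstrings and log messages were carried over from the vLLM interface and still read "local vLLM". A rough usage sketch, assuming a SciPhi-compatible completion server and a wiki RAG endpoint are reachable (all URLs and keys below are hypothetical, and the `vllm` package must be installed because `SciPhiLLM.__init__` imports `SamplingParams` from `vllm`):

```python
# Rough usage sketch; URLs and keys are hypothetical placeholders.
from sciphi.interface import SciPhiInterface
from sciphi.llm import SciPhiConfig

config = SciPhiConfig(
    server_base="http://localhost:8000/v1",   # assumed OpenAI-compatible completion endpoint
    api_key="sk-dummy",                       # hypothetical key
    rag_server_base="http://localhost:8001",  # assumed SciPhi wiki search endpoint
    rag_api_key="rag-dummy",                  # hypothetical key
)

interface = SciPhiInterface(config)
print(interface.get_completion("Who proved the four color theorem?"))
```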
4 changes: 2 additions & 2 deletions sciphi/interface/rag/sciphi_wiki.py
@@ -33,7 +33,7 @@ def get_contexts(self, prompts: list[str]) -> list[str]:
raw_contexts = wiki_search_api(
prompts,
self.config.base,
self.config.token,
self.config.api_key,
self.config.top_k,
)

@@ -63,7 +63,7 @@ def wiki_search_api(
"""
# Make the GET request with basic authentication and the query parameter
response = requests.get(
rag_api_base,
f"{rag_api_base}/search",
params={"queries": queries, "top_k": top_k},
headers={"Authorization": f"Bearer {rag_api_key}"},
)
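Alongside the `api_key` rename, `wiki_search_api` now targets a `/search` path on the configured base URL rather than the bare base. A sketch of the resulting request shape outside the provider class (the host, key, and response handling below are assumptions, not taken from the diff):

```python
# Sketch of the GET request wiki_search_api issues after this change.
# The host, key, and JSON handling are assumptions for illustration.
import requests

rag_api_base = "http://localhost:8001"  # hypothetical RAG server
rag_api_key = "rag-dummy"               # hypothetical key

response = requests.get(
    f"{rag_api_base}/search",
    params={"queries": ["four color theorem"], "top_k": 100},
    headers={"Authorization": f"Bearer {rag_api_key}"},
)
response.raise_for_status()
raw_contexts = response.json()  # assumed JSON payload of contexts, one entry per query
```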
7 changes: 5 additions & 2 deletions sciphi/llm/__init__.py
@@ -14,6 +14,7 @@
from sciphi.llm.models.llama_index_llm import LLamaIndexConfig, LlamaIndexLLM
from sciphi.llm.models.llamacpp_llm import LlamaCPP, LLamaCPPConfig
from sciphi.llm.models.openai_llm import OpenAIConfig, OpenAILLM
from sciphi.llm.models.sciphi_llm import SciPhiConfig, SciPhiLLM
from sciphi.llm.models.vllm_llm import vLLM, vLLMConfig

__all__ = [
@@ -31,12 +32,14 @@
"LlamaIndexLLM",
"OpenAIConfig",
"OpenAILLM",
"vLLMConfig",
"vLLM",
"LiteLLMConfig",
"LiteLLM",
"LLamaCPPConfig",
"LlamaCPP",
"SciPhiConfig",
"SciPhiLLM",
"vLLMConfig",
"vLLM",
# Embedding Helpers
"process_documents",
"sectionize_documents",
89 changes: 64 additions & 25 deletions sciphi/llm/models/sciphi_llm.py
@@ -1,13 +1,13 @@
"""A module for managing local vLLM models."""

import logging
import re
from dataclasses import dataclass
from typing import Optional

from sciphi.core import LLMProviderName, RAGProviderName
from sciphi.interface.rag_interface_manager import RAGInterfaceManager
from sciphi.llm.base import LLM, LLMConfig
from sciphi.llm.config_manager import model_config
from sciphi.llm.models.vllm_llm import vLLM, vLLMConfig

logging.basicConfig(level=logging.INFO)

@@ -21,74 +21,89 @@ class SciPhiFormatter:
END_PARAGRAPH_TOKEN = "</paragraph>"

RETRIEVAL_TOKEN = "[Retrieval]"
FULLY_SUPPORTED = "[Fully supported]"
NO_RETRIEVAL_TOKEN = "[No Retrieval]"
EVIDENCE_TOKEN = "[Continue to Use Evidence]"
UTILITY_TOKEN = "[Utility:5]"
RELEVANT_TOKEN = "[Relevant]"
PARTIALLY_SUPPORTED_TOKEN = "[Partially supported]"
SUFFIX_CRUFT = "[Utility:5]</s>"
NO_SUPPORT_TOKEN = "[No support / Contradictory]"
END_TOKEN = "</s>"

@staticmethod
def format_prompt(input: str) -> str:
"""Format the prompt for the model."""
return f"{SciPhiFormatter.INSTRUCTION_PREFIX}\n{input}\n\n{SciPhiFormatter.INSTRUCTION_SUFFIX}"

@staticmethod
def extract_post_prompt(completion: str) -> str:
if SciPhiFormatter.INSTRUCTION_SUFFIX not in completion:
raise ValueError(
f"Full Completion does not contain {SciPhiFormatter.INSTRUCTION_SUFFIX}"
)

return completion.split(SciPhiFormatter.INSTRUCTION_SUFFIX)[1]

@staticmethod
def remove_cruft(result: str) -> str:
pattern = f"{re.escape(SciPhiFormatter.INIT_PARAGRAPH_TOKEN)}.*?{re.escape(SciPhiFormatter.END_PARAGRAPH_TOKEN)}"
# Remove <paragraph>{arbitrary text...}</paragraph>
result = re.sub(pattern, "", result, flags=re.DOTALL)

return (
result.replace(SciPhiFormatter.RETRIEVAL_TOKEN, " ")
result.replace(SciPhiFormatter.RETRIEVAL_TOKEN, "")
.replace(SciPhiFormatter.NO_RETRIEVAL_TOKEN, "")
.replace(SciPhiFormatter.EVIDENCE_TOKEN, " ")
.replace(SciPhiFormatter.SUFFIX_CRUFT, "")
.replace(SciPhiFormatter.UTILITY_TOKEN, "")
.replace(SciPhiFormatter.RELEVANT_TOKEN, "")
.replace(SciPhiFormatter.PARTIALLY_SUPPORTED_TOKEN, "")
.replace(SciPhiFormatter.FULLY_SUPPORTED, "")
.replace(SciPhiFormatter.END_TOKEN, "")
.replace(SciPhiFormatter.NO_SUPPORT_TOKEN, "")
)


@model_config
@dataclass
class SciPhiConfig(vLLMConfig):
class SciPhiConfig(LLMConfig):
"""Configuration for local vLLM models."""

# Base
provider_name: LLMProviderName = LLMProviderName.SCIPHI
model_name: str = "selfrag/selfrag_llama2_7b"
llm_provider_name: LLMProviderName = LLMProviderName.SCIPHI
model_name: str = "SciPhi/SciPhi-Self-RAG-Mistral-7B-32k"
temperature: float = 0.1
top_p: float = 1.0
top_k: int = 100
max_tokens_to_sample: int = 256

# SciPhi Extras...
max_tokens_to_sample: int = 1_024
server_base: Optional[str] = None
api_key: Optional[str] = None

# RAG Parameters
rag_provider_name: RAGProviderName = RAGProviderName.SCIPHI_WIKI
rag_provider_base: Optional[str] = None
rag_provider_token: Optional[str] = None
rag_server_base: Optional[str] = None
rag_api_key: Optional[str] = None
rag_top_k: int = 100


class SciPhiLLM(vLLM):
class SciPhiLLM(LLM):
"""Configuration for local vLLM models."""

def __init__(
self,
config: SciPhiConfig,
) -> None:
super().__init__(config)
from vllm import SamplingParams

# Hack to avoid typing errors
self.config: SciPhiConfig = config
self.sampling_params = SamplingParams(
temperature=config.temperature,
top_p=config.top_p,
top_k=config.top_k,
max_tokens=config.max_tokens_to_sample,
skip_special_tokens=False, # RAG Fine Tune includes special tokens
stop=SciPhiFormatter.INIT_PARAGRAPH_TOKEN, # Stops on Retrieval
)

from sciphi.interface.rag_interface_manager import RAGInterfaceManager

self.rag_provider = RAGInterfaceManager.get_interface_from_args(
provider_name=config.rag_provider_name,
base=config.rag_provider_base or "http://localhost:8000",
token=config.rag_provider_token or "",
base=config.rag_server_base or config.server_base,
api_key=config.rag_api_key or config.api_key,
top_k=config.rag_top_k,
)

@@ -98,13 +113,14 @@ def get_chat_completion(self, messages: list[dict[str, str]]) -> str:
"Chat completion not yet implemented for SciPhi."
)

def get_instruct_completion(self, prompt: str) -> str:
def _get_instruct_completion(self, prompt: str) -> str:
"""Get an instruction completion from local SciPhi API."""
import openai

openai.api_base = self.config.server_base or ""
return openai.Completion.create(
model=self.config.model_name,
api_key=self.config.api_key,
temperature=self.config.temperature,
top_p=self.config.top_p,
top_k=self.config.top_k,
@@ -114,6 +130,29 @@ def get_instruct_completion(self, prompt: str) -> str:
stop=SciPhiFormatter.INIT_PARAGRAPH_TOKEN,
)

def get_instruct_completion(self, prompt: str) -> str:
"""Get an instruction completion from local SciPhi API."""
completion = ""
while True:
prompt_with_context = (
SciPhiFormatter.format_prompt(prompt) + completion
)
latest_completion = self._get_instruct_completion(
prompt_with_context
)["choices"][0]["text"].strip()
completion += latest_completion

if not completion.endswith(SciPhiFormatter.RETRIEVAL_TOKEN):
break
context_query = (
prompt
if completion == SciPhiFormatter.RETRIEVAL_TOKEN
else f"{SciPhiFormatter.remove_cruft(completion)}"
)
context = self.rag_provider.get_contexts([context_query])[0]
completion += f"{SciPhiFormatter.INIT_PARAGRAPH_TOKEN}{context}{SciPhiFormatter.END_PARAGRAPH_TOKEN}"
return SciPhiFormatter.remove_cruft(completion)

def get_batch_instruct_completion(self, prompts: list[str]) -> list[str]:
"""Get batch instruction completion from local vLLM."""
raise NotImplementedError(
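The core of this change is the retrieval loop in `get_instruct_completion`: sampling is configured to stop at the `<paragraph>` token, and whenever the partial completion ends with `[Retrieval]` the model queries the RAG provider, splices the returned context between `<paragraph>`/`</paragraph>` markers, and resumes generating; `remove_cruft` then strips the Self-RAG control tokens from the final text. A standalone sketch of that loop, with the model call, retriever, and cleaner replaced by hypothetical callables:

```python
# Standalone sketch of the Self-RAG retrieval loop implemented above.
# `generate`, `fetch_context`, and `clean` are hypothetical stand-ins for the
# model completion call, the RAG provider lookup, and remove_cruft.
from typing import Callable

RETRIEVAL_TOKEN = "[Retrieval]"
INIT_PARAGRAPH_TOKEN = "<paragraph>"
END_PARAGRAPH_TOKEN = "</paragraph>"


def self_rag_completion(
    prompt: str,
    generate: Callable[[str], str],       # returns text up to the next stop token
    fetch_context: Callable[[str], str],  # returns retrieved context for a query
    clean: Callable[[str], str],          # strips Self-RAG control tokens
) -> str:
    completion = ""
    while True:
        completion += generate(prompt + completion)
        if not completion.endswith(RETRIEVAL_TOKEN):
            break
        # The first retrieval queries the raw prompt; later ones query the
        # cleaned completion so far, mirroring the logic in the diff.
        query = prompt if completion == RETRIEVAL_TOKEN else clean(completion)
        completion += f"{INIT_PARAGRAPH_TOKEN}{fetch_context(query)}{END_PARAGRAPH_TOKEN}"
    return clean(completion)
```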