
Feature/update sciphi local logic rebased (#120)
* tweak local logic for sciphi

* fix
emrgnt-cmplxty authored Oct 31, 2023
1 parent e9d6a8e commit 045d5ca
Showing 3 changed files with 69 additions and 18 deletions.
35 changes: 30 additions & 5 deletions sciphi/interface/llm/sciphi_interface.py
@@ -71,8 +71,8 @@ class SciPhiLLMInterface(LLMInterface):

def __init__(
self,
rag_interface: RAGInterface,
config: SciPhiConfig = SciPhiConfig(),
rag_interface: Optional[RAGInterface] = None,
*args,
**kwargs,
) -> None:
@@ -98,11 +98,22 @@ def get_chat_completion(
if not added_system_prompt:
prompt = f"### System:\n{SciPhiLLMInterface.ALPACA_CHAT_SYSTEM_PROMPT}.\n\n{prompt}"

context = self.rag_interface.get_contexts([last_user_message])[0]
prompt += f"### Response:\n{SciPhiFormatter.RETRIEVAL_TOKEN} {SciPhiFormatter.INIT_PARAGRAPH_TOKEN}{context}{SciPhiFormatter.END_PARAGRAPH_TOKEN}"
# TODO - Cleanup RAG logic checks across this script.

if not generation_config.model_name:
raise ValueError("No model name provided")
if "RAG" in generation_config.model_name:
if not self.rag_interface:
raise ValueError(
"RAG generation requested but no RAG interface provided"
)
context = self.rag_interface.get_contexts([last_user_message])[0]
prompt += f"### Response:\n{SciPhiFormatter.RETRIEVAL_TOKEN} {SciPhiFormatter.INIT_PARAGRAPH_TOKEN}{context}{SciPhiFormatter.END_PARAGRAPH_TOKEN}"
else:
prompt += f"### Response:\n"
latest_completion = self.model.get_instruct_completion(
prompt, generation_config
).strip()
)

return SciPhiFormatter.remove_cruft(latest_completion)

@@ -114,6 +125,20 @@ def get_completion(
logger.debug(
f"Requesting completion from local vLLM with model={generation_config.model_name} and prompt={prompt}"
)

# TODO - Cleanup and consolidate RAG logic checks across this script.

if not generation_config.model_name:
raise ValueError("No model name provided")

if "RAG" not in generation_config.model_name:
return self.model.get_instruct_completion(
prompt, generation_config
).strip()
if not self.rag_interface:
raise ValueError(
"RAG model requested, but no RAG interface provided"
)
completion = ""
while True:
prompt_with_context = (
@@ -133,7 +158,7 @@ def get_completion(
)
context = self.rag_interface.get_contexts([context_query])[0]
completion += f"{SciPhiFormatter.INIT_PARAGRAPH_TOKEN}{context}{SciPhiFormatter.END_PARAGRAPH_TOKEN}"
return SciPhiFormatter.remove_cruft(completion)
return SciPhiFormatter.remove_cruft(completion).strip()

def get_batch_completion(
self, prompts: List[str], generation_config: GenerationConfig
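Note: a minimal sketch of the constructor contract after this change. config is now the leading argument and rag_interface is optional; retrieval is only attempted when the requested model name contains "RAG", and a missing RAG interface then raises ValueError. The argument order is taken from the updated sciphi_chat.py below; the import path is an assumption, since the enclosing import statement is not visible in these hunks.

# Sketch only -- the sciphi.interface import path is assumed, not shown in this diff.
from sciphi.interface import SciPhiLLMInterface, SciPhiWikiRAGInterface
from sciphi.llm import SciPhiConfig

# Plain completion: no RAG interface is needed when the model name does not contain "RAG".
llm_interface = SciPhiLLMInterface(SciPhiConfig())

# RAG completion: pass a RAG interface, otherwise the new checks above raise ValueError.
rag_llm_interface = SciPhiLLMInterface(SciPhiConfig(), SciPhiWikiRAGInterface())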
20 changes: 17 additions & 3 deletions sciphi/llm/models/sciphi_llm.py
@@ -28,7 +28,10 @@ class SciPhiConfig(LLMConfig):
sub_provider_name: LLMProviderName = LLMProviderName.VLLM

# SciPhi Extras...
mode = SciPhiProviderMode.REMOTE
mode: SciPhiProviderMode = SciPhiProviderMode.REMOTE
## Local
model_name: Optional[str] = None
## Remote
api_base: Optional[str] = "https://api.sciphi.ai/v1"
api_key: Optional[str] = None

@@ -47,14 +50,15 @@ def __init__(
SciPhiProviderMode.REMOTE,
SciPhiProviderMode.LOCAL_VLLM,
]:
# Remote and local vLLM are both powered by vLLM
assert self.config.sub_provider_name == LLMProviderName.VLLM
from sciphi.llm.models.vllm_llm import (
vLLM,
vLLMConfig,
vLLMProviderMode,
)

# Remote and local vLLM are both powered by vLLM
assert self.config.sub_provider_name == LLMProviderName.VLLM

if self.config.mode == SciPhiProviderMode.REMOTE:
api_key = config.api_key or os.getenv("SCIPHI_API_KEY")
if not api_key:
@@ -64,11 +68,21 @@ def __init__(
self.model = vLLM(
vLLMConfig(
provider_name=config.provider_name,
model_name=config.model_name,
api_base=config.api_base,
api_key=api_key,
mode=vLLMProviderMode.REMOTE,
),
)
elif self.config.mode == SciPhiProviderMode.LOCAL_VLLM:
self.model = vLLM(
vLLMConfig(
provider_name=config.provider_name,
model_name=config.model_name,
mode=vLLMProviderMode.LOCAL,
),
)

elif self.config.mode == SciPhiProviderMode.LOCAL_HF:
from sciphi.llm.models import hugging_face_llm # noqa F401

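Note: a hedged sketch of the two SciPhiConfig shapes this hunk supports, using the fields added above. The remote default reads SCIPHI_API_KEY from the environment, while the new LOCAL_VLLM branch forwards model_name into vLLMConfig so the local engine knows which weights to load. The model identifier is the default from sciphi_chat.py; everything else is illustrative.

from sciphi.llm import SciPhiConfig
from sciphi.llm.models.sciphi_llm import SciPhiProviderMode

# Remote (default): mode=REMOTE, api_base="https://api.sciphi.ai/v1", key taken from SCIPHI_API_KEY.
remote_config = SciPhiConfig()

# Local vLLM: the new model_name field is passed through to vLLMConfig.
local_config = SciPhiConfig(
    mode=SciPhiProviderMode.LOCAL_VLLM,
    model_name="SciPhi/SciPhi-Self-RAG-Mistral-7B-32k",  # default model from sciphi_chat.py
)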
32 changes: 22 additions & 10 deletions sciphi/scripts/sciphi_chat.py
@@ -7,7 +7,7 @@
SciPhiLLMInterface,
SciPhiWikiRAGInterface,
)
from sciphi.llm import GenerationConfig
from sciphi.llm import GenerationConfig, SciPhiConfig


def filter_relevant_args(dataclass_type, args_dict):
@@ -20,10 +20,22 @@ def filter_relevant_args(dataclass_type, args_dict):

def main(
query: str = "Who is the president of the United States?",
mode="remote",
llm_model_name="SciPhi/SciPhi-Self-RAG-Mistral-7B-32k",
):
rag_interface = SciPhiWikiRAGInterface()
llm_interface = SciPhiLLMInterface(rag_interface)

config = None
if mode == "local":
from sciphi.llm.models.sciphi_llm import SciPhiProviderMode

config = SciPhiConfig(
mode=SciPhiProviderMode.LOCAL_VLLM, model_name=llm_model_name
)
else:
config = SciPhiConfig()

llm_interface = SciPhiLLMInterface(config, rag_interface)

generation_config = GenerationConfig(
model_name=llm_model_name,
@@ -36,14 +48,14 @@ def main(
"role": "system",
"content": "You are a helpful and informative professor. You give long, accurate, and detailed explanations to student questions. You answer EVERY question that is given to you. You retrieve data multiple times if necessary.",
},
{
"role": "user",
"content": "Who is the president of the United States?",
},
{
"role": "assistant",
"content": "Joe Biden is the current president of the United States.",
},
# {
# "role": "user",
# "content": "Who is the president of the United States?",
# },
# {
# "role": "assistant",
# "content": "Joe Biden is the current president of the United States.",
# },
{
"role": "user",
"content": query,
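Note: an illustrative call into the updated main() entry point, shown as a direct Python call; how the script exposes main() on the command line lies outside the visible hunks. Argument names and defaults come from the new signature above.

# Illustrative only; "local" selects SciPhiProviderMode.LOCAL_VLLM, any other mode keeps the remote default.
main(
    query="Who is the president of the United States?",
    mode="local",
    llm_model_name="SciPhi/SciPhi-Self-RAG-Mistral-7B-32k",
)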
