Skip to content

Commit

Permalink
prompt class #72
Browse files Browse the repository at this point in the history
  • Loading branch information
YY-SONG0718 committed Dec 14, 2023
1 parent 43e5bd4 commit f96a4f9
Showing 1 changed file with 70 additions and 0 deletions.
70 changes: 70 additions & 0 deletions biochatter/messages.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#!/usr/bin/env Python3

from typing import Optional
import re


primary_messages = {
"message_entity": (
Expand Down Expand Up @@ -48,6 +50,74 @@
}


def extract_placeholders(text):

pattern = r'{(.*?)}'
matches = re.findall(pattern, text)
return matches


class Prompt:
def __init__(self,
text_template_set: dict,
elements: dict,
text_template: Optional[str] = None,):
self.text_template = text_template ## current active template
self.text_template_set = text_template_set ## all templates
self.elements = elements

# repr
def generate_prompt(self):
return self.text_template.format(**self.elements)


class SystemPrompt(Prompt):

def __init__(
self,
text_template_set: dict,
elements: dict,
prompt_for: str,
text_template: Optional[str] = None,
message_type: str = "system",


):
super().__init__(
text_template_set=text_template_set,
elements=elements
)

prompt_for_options = ["entity", "relationship", "property", "query"]

if prompt_for is not None:
if prompt_for not in prompt_for_options:
raise ValueError(
"prompt_for must be one of ['entity', 'relationship', 'property', 'query']"
)

self.prompt_for = prompt_for

def generate_system_message_entity(self):
self.prompt_for = "entity"
self.text_template = self.text_template_set["message_entity"]

# TODO: an option to add kg info at the beginning of this message
# TODO: other propts then primary

placeholders = extract_placeholders(self.text_template)

# check if all placeholder keys are in elements dictionary

if not all([ i in self.elements.keys() for i in placeholders ]):
raise ValueError("Not all placeholders in text template provided in elements")

return(self.generate_prompt())





class message(str):
def __init__(
self,
Expand Down

0 comments on commit f96a4f9

Please sign in to comment.