From f96a4f9c959f110e39b905ae3301d6b26197bdb4 Mon Sep 17 00:00:00 2001 From: Yuyao Song <48168412+YY-SONG0718@users.noreply.github.com> Date: Thu, 14 Dec 2023 18:38:45 +0100 Subject: [PATCH] prompt class #72 --- biochatter/messages.py | 70 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/biochatter/messages.py b/biochatter/messages.py index d1d8471a..4cf96395 100644 --- a/biochatter/messages.py +++ b/biochatter/messages.py @@ -1,6 +1,8 @@ #!/usr/bin/env Python3 from typing import Optional +import re + primary_messages = { "message_entity": ( @@ -48,6 +50,74 @@ } +def extract_placeholders(text): + + pattern = r'{(.*?)}' + matches = re.findall(pattern, text) + return matches + + +class Prompt: + def __init__(self, + text_template_set: dict, + elements: dict, + text_template: Optional[str] = None,): + self.text_template = text_template ## current active template + self.text_template_set = text_template_set ## all templates + self.elements = elements + + # repr + def generate_prompt(self): + return self.text_template.format(**self.elements) + + +class SystemPrompt(Prompt): + + def __init__( + self, + text_template_set: dict, + elements: dict, + prompt_for: str, + text_template: Optional[str] = None, + message_type: str = "system", + + + ): + super().__init__( + text_template_set=text_template_set, + elements=elements + ) + + prompt_for_options = ["entity", "relationship", "property", "query"] + + if prompt_for is not None: + if prompt_for not in prompt_for_options: + raise ValueError( + "prompt_for must be one of ['entity', 'relationship', 'property', 'query']" + ) + + self.prompt_for = prompt_for + + def generate_system_message_entity(self): + self.prompt_for = "entity" + self.text_template = self.text_template_set["message_entity"] + + # TODO: an option to add kg info at the beginning of this message + # TODO: other propts then primary + + placeholders = extract_placeholders(self.text_template) + + # check if all placeholder keys are in elements dictionary + + if not all([ i in self.elements.keys() for i in placeholders ]): + raise ValueError("Not all placeholders in text template provided in elements") + + return(self.generate_prompt()) + + + + + class message(str): def __init__( self,