Skip to content

Commit

Permalink
Merge pull request #733
Browse files Browse the repository at this point in the history
feat: Add support for embeddings only with search threshold
  • Loading branch information
drazvan authored Sep 6, 2024
2 parents 937c9d4 + 5e64f3a commit c468d64
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 15 deletions.
47 changes: 39 additions & 8 deletions nemoguardrails/actions/v2_x/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

"""A set of actions for generating various types of completions using an LLMs."""

import logging
import re
import textwrap
Expand Down Expand Up @@ -146,21 +147,45 @@ async def _init_flows_index(self) -> None:

async def _collect_user_intent_and_examples(
self, state: State, user_action: str, max_example_flows: int
) -> Tuple[List[str], str]:
) -> Tuple[List[str], str, bool]:
# We search for the most relevant similar user intents
examples = ""
potential_user_intents = []
embedding_only = False

if self.user_message_index:
threshold = None

if self.config.rails.dialog.user_messages:
threshold = (
self.config.rails.dialog.user_messages.embeddings_only_similarity_threshold
)

results = await self.user_message_index.search(
text=user_action, max_results=max_example_flows
text=user_action, max_results=max_example_flows, threshold=threshold
)

# We add these in reverse order so the most relevant is towards the end.
for result in reversed(results):
examples += f"user action: user said \"{result.text}\"\nuser intent: {result.meta['intent']}\n\n"
if result.meta["intent"] not in potential_user_intents:
potential_user_intents.append(result.meta["intent"])
if results and self.config.rails.dialog.user_messages.embeddings_only:
intent = results[0].meta["intent"]
potential_user_intents.append(intent)
embedding_only = True

elif (
self.config.rails.dialog.user_messages.embeddings_only
and self.config.rails.dialog.user_messages.embeddings_only_fallback_intent
):
intent = (
self.config.rails.dialog.user_messages.embeddings_only_fallback_intent
)
potential_user_intents.append(intent)
embedding_only = True

else:
# We add these in reverse order so the most relevant is towards the end.
for result in reversed(results):
examples += f"user action: user said \"{result.text}\"\nuser intent: {result.meta['intent']}\n\n"
if result.meta["intent"] not in potential_user_intents:
potential_user_intents.append(result.meta["intent"])

# We add all currently active user intents (heads on match statements)
heads = find_all_active_event_matchers(state)
Expand Down Expand Up @@ -196,8 +221,10 @@ async def _collect_user_intent_and_examples(
elif flow_id not in potential_user_intents:
examples += f"user intent: {flow_id}\n\n"
potential_user_intents.append(flow_id)

examples = examples.strip("\n")
return (potential_user_intents, examples)

return potential_user_intents, examples, embedding_only

@action(name="GetLastUserMessageAction", is_system_action=True)
async def get_last_user_message(
Expand Down Expand Up @@ -225,9 +252,12 @@ async def generate_user_intent(
(
potential_user_intents,
examples,
embedding_only,
) = await self._collect_user_intent_and_examples(
state, user_action, max_example_flows
)
if embedding_only:
return f"{potential_user_intents[0]}"

prompt = self.llm_task_manager.render_task_prompt(
task=Task.GENERATE_USER_INTENT_FROM_USER_ACTION,
Expand Down Expand Up @@ -294,6 +324,7 @@ async def generate_user_intent_and_bot_action(
(
potential_user_intents,
examples,
embedding_only,
) = await self._collect_user_intent_and_examples(
state, user_action, max_example_flows
)
Expand Down
43 changes: 36 additions & 7 deletions tests/test_embeddings_only_user_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,36 +45,67 @@ def config():
)


@pytest.fixture
def colang_2_config():
return RailsConfig.from_content(
"""
import core
import llm
flow main
activate greeting
activate llm continuation
flow user expressed greeting
user said "hi"
flow bot express greeting
bot say "Hello!"
flow greeting
user expressed greeting
bot express greeting
""",
"""
colang_version: 2.x
rails:
dialog:
user_messages:
embeddings_only: True
embeddings_only_similarity_threshold: 0.8
embeddings_only_fallback_intent: "user expressed greeting"
""",
)


@pytest.mark.parametrize("config", ["config", "colang_2_config"], indirect=True)
def test_greeting(config):
"""Test that the bot responds with 'Hello!' when the user says 'hello'."""

chat = TestChat(
config,
llm_completions=[],
)

chat >> "hello"
chat << "Hello!"


@pytest.mark.parametrize("config", ["config", "colang_2_config"], indirect=True)
def test_error_when_embeddings_only_is_false(config):
"""Test that an error is raised when the 'embeddings_only' option is False."""

# Check that if we deactivate the embeddings_only option we get an error
config.rails.dialog.user_messages.embeddings_only = False
chat = TestChat(
config,
llm_completions=[],
)

with pytest.raises(LLMCallException):
chat >> "hello"
chat << "Hello!"


@pytest.mark.parametrize("config", ["config", "colang_2_config"], indirect=True)
def test_fallback_intent(config):
"""Test that the bot uses the fallback intent when it doesn't recognize the user's message."""

rails = LLMRails(config)
res = rails.generate(messages=[{"role": "user", "content": "lets use fallback"}])
assert res["content"] == "Hello!"
Expand All @@ -83,5 +114,3 @@ def test_fallback_intent(config):
rails = LLMRails(config)
with pytest.raises(LLMCallException):
rails.generate(messages=[{"role": "user", "content": "lets use fallback"}])
#
# Check the bot's response

0 comments on commit c468d64

Please sign in to comment.