From 627b35f8acfa62663222485a105cfa76bf4bdfc4 Mon Sep 17 00:00:00 2001 From: "K. R" Date: Wed, 3 Apr 2024 14:45:47 -0400 Subject: [PATCH] created echo agent and quantize character functions --- Assets/CommandTerminal/CommandBuffer.cs | 9 ++ Assets/Scripts/Agents/EchoAgent.cs | 69 ++++++++++++++++ Assets/Scripts/Agents/EchoAgent.cs.meta | 2 + Assets/Scripts/Agents/Lexer.cs | 29 ++++++- Tests/LexerTests.cs | 105 +++++++++++++++++++++++- 5 files changed, 209 insertions(+), 5 deletions(-) create mode 100644 Assets/Scripts/Agents/EchoAgent.cs create mode 100644 Assets/Scripts/Agents/EchoAgent.cs.meta diff --git a/Assets/CommandTerminal/CommandBuffer.cs b/Assets/CommandTerminal/CommandBuffer.cs index 03d8436..fd63261 100644 --- a/Assets/CommandTerminal/CommandBuffer.cs +++ b/Assets/CommandTerminal/CommandBuffer.cs @@ -109,6 +109,15 @@ private void ProcessLogLine(LogItem item, int maxChars, ref List lines, } } + public string GetLastLog() + { + if (_Logs.Count == 0) + { + return string.Empty; + } + return _Logs[_Logs.Count - 1].Message; + } + public void Reset() { _Logs.Clear(); diff --git a/Assets/Scripts/Agents/EchoAgent.cs b/Assets/Scripts/Agents/EchoAgent.cs new file mode 100644 index 0000000..a644ae7 --- /dev/null +++ b/Assets/Scripts/Agents/EchoAgent.cs @@ -0,0 +1,69 @@ +using Unity.MLAgents; +using Unity.MLAgents.Sensors; +using Unity.MLAgents.Actuators; +using System.Collections.Generic; +using CommandTerminal; + +namespace DialogosEngine +{ + public class EchoAgent : Agent + { + char _GuessedChar; + + public override void OnEpisodeBegin() + { + ClearConsole(); + } + + public void FixedUpdate() + { + char expectedChar = GetExpectedChar(); + float reward = CalculateReward(expectedChar, _GuessedChar); + SetReward(reward); + } + + public override void CollectObservations(VectorSensor sensor) + { + string buffer = GetConsoleBuffer(); + float[] vectorizedBuffer = Lexer.VectorizeUTF8(buffer); + foreach (var obs in vectorizedBuffer) + { + sensor.AddObservation(obs); + } + } + + public override void OnActionReceived(ActionBuffers actions) + { + float[] actionArray = new float[1] { actions.ContinuousActions[0] }; + _GuessedChar = Lexer.QuantizeUTF8(actionArray)[0]; + HandleGuessedCharacter(_GuessedChar); + } + + private void ClearConsole() + { + Terminal.Instance.Buffer.Reset(); + } + + private float CalculateReward(char expectedChar, char guessedChar) + { + // Implementation to calculate the reward based on the guessed character + return 0; + } + + private string GetConsoleBuffer() + { + return Terminal.Instance.Buffer.GetLastLog(); + } + + private void HandleGuessedCharacter(char guessedChar) + { + // Implementation to handle the guessed character + } + + private char GetExpectedChar() + { + // Implementation to get the expected character for the current step + return new char(); + } + } +} diff --git a/Assets/Scripts/Agents/EchoAgent.cs.meta b/Assets/Scripts/Agents/EchoAgent.cs.meta new file mode 100644 index 0000000..87ecd66 --- /dev/null +++ b/Assets/Scripts/Agents/EchoAgent.cs.meta @@ -0,0 +1,2 @@ +fileFormatVersion: 2 +guid: 91356ebb54747f140a42a1e87c3f5965 \ No newline at end of file diff --git a/Assets/Scripts/Agents/Lexer.cs b/Assets/Scripts/Agents/Lexer.cs index fdbd652..d3c415e 100644 --- a/Assets/Scripts/Agents/Lexer.cs +++ b/Assets/Scripts/Agents/Lexer.cs @@ -1,6 +1,5 @@ using System; using System.Collections.Generic; -using System.IO; using System.Linq; using System.Text; using UnityEditor; @@ -82,6 +81,34 @@ public static float[] VectorizeUTF8(string line) return vector; } + public static string QuantizeUTF8(float[] vector) + { + if (vector == null) + { + throw new ArgumentNullException(nameof(vector), "Input vector cannot be null."); + } + + byte[] utf8Bytes = new byte[vector.Length]; + int index = 0; + while (index < vector.Length) + { + utf8Bytes[index] = (byte)(vector[index] / MultiplierUTF8); + index++; + } + + Decoder utf8Decoder = Encoding.UTF8.GetDecoder(); + int charCount = utf8Decoder.GetCharCount(utf8Bytes, 0, utf8Bytes.Length); + if (charCount > k_MaxChars) + { + throw new LexerException($"Output exceeds the maximum length of {k_MaxChars} characters."); + } + + char[] chars = new char[charCount]; + utf8Decoder.GetChars(utf8Bytes, 0, utf8Bytes.Length, chars, 0); + + return new string(chars); + } + public static float CalculateWhitespace(string[] text) { int _totalWhitespace = text.Sum(line => line.Count(char.IsWhiteSpace)); diff --git a/Tests/LexerTests.cs b/Tests/LexerTests.cs index d6d0653..1a4a606 100644 --- a/Tests/LexerTests.cs +++ b/Tests/LexerTests.cs @@ -1,6 +1,4 @@ -๏ปฟusing System.Diagnostics; - -namespace DialogosEngine.Tests +๏ปฟnamespace DialogosEngine.Tests { [TestFixture] public static class LexerTests @@ -94,7 +92,6 @@ public static void Vectorize_GivenStringWithSpecialChars_ConvertsToAsciiFloatArr TestContext.WriteLine($"Test passed: Input string '{input}' converts to expected packed float array."); } - [Test] public static void VectorizeUTF8_GivenString_ConvertsToUtf8FloatArray() { @@ -306,5 +303,105 @@ public static void VectorizeUTF8_ChineseHeading_ConvertsToUtf8FloatArray() TestContext.WriteLine($"Test passed: Input string '{input}' converts to expected UTF-8 float array."); } + [Test] + public static void QuantizeUTF8_GivenFloatArray_ConvertsToUtf8String() + { + // Arrange + string expected = "Test"; // The expected UTF-8 string output + float[] input = Lexer.VectorizeUTF8(expected); // Use VectorizeUTF8 to get the correct float array + + TestContext.WriteLine($"Testing with input float array: '{Utility.FormatFloatArray(input)}'."); + + // Act + string result = Lexer.QuantizeUTF8(input); + TestContext.WriteLine($"Resulting UTF-8 string: '{result}'"); + + // Assert + Assert.That(result, Is.Not.Null); + Assert.That(result, Is.EqualTo(expected), "The resulting string should match the expected UTF-8 string."); + TestContext.WriteLine($"Test passed: Input float array converts to expected UTF-8 string '{expected}'."); + } + + [Test] + public static void QuantizeUTF8_ComplexStringWithEmojis_ConvertsToUtf8String() + { + // Arrange + string expected = "The quick brown fox jumps over the lazy dog ๐Ÿš€๐ŸฆŠ๐Ÿถ"; + float[] input = Lexer.VectorizeUTF8(expected); // Use VectorizeUTF8 to get the correct float array + + TestContext.WriteLine($"Testing with input float array: '{Utility.FormatFloatArray(input)}'."); + + // Act + string result = Lexer.QuantizeUTF8(input); + TestContext.WriteLine($"Resulting UTF-8 string: '{result}'"); + + // Assert + Assert.That(result, Is.Not.Null); + Assert.That(result, Is.EqualTo(expected), "The resulting string should match the expected complex UTF-8 string with emojis."); + TestContext.WriteLine($"Test passed: Input float array converts to expected UTF-8 string with emojis '{expected}'."); + } + + [Test] + public static void QuantizeUTF8_JapaneseText_ConvertsToUtf8String() + { + // Arrange + string expected = "ใ“ใ‚Œใฏๆ—ฅๆœฌ่ชžใฎๆฎต่ฝใƒ†ใ‚นใƒˆใงใ™ใ€‚ใ“ใฎใƒ†ใ‚นใƒˆใฏใ€UTF-8้–ขๆ•ฐใŒๆ—ฅๆœฌ่ชžใฎๆ–‡ๅญ—ใ‚’ๅซใ‚€ใƒ†ใ‚ญใ‚นใƒˆใ‚’ๆญฃใ—ใๅ‡ฆ็†ใงใใ‚‹ใ“ใจใ‚’็ขบ่ชใ™ใ‚‹ใŸใ‚ใฎใ‚‚ใฎใงใ™ใ€‚"; // A sample Japanese paragraph + // Ensure the string is not longer than 1000 characters + expected = expected.Substring(0, Math.Min(1000, expected.Length)); + float[] input = Lexer.VectorizeUTF8(expected); // Use VectorizeUTF8 to get the correct float array + + TestContext.WriteLine($"Testing with input float array: '{Utility.FormatFloatArray(input)}'."); + + // Act + string result = Lexer.QuantizeUTF8(input); + TestContext.WriteLine($"Resulting UTF-8 string: '{result}'"); + + // Assert + Assert.That(result, Is.Not.Null); + Assert.That(result, Is.EqualTo(expected), "The resulting string should match the expected Japanese UTF-8 string."); + TestContext.WriteLine($"Test passed: Input float array converts to expected UTF-8 Japanese text '{expected}'."); + } + + [Test] + public static void QuantizeUTF8_ChineseText_ConvertsToUtf8String() + { + // Arrange + string expected = "่ฟ™ๆ˜ฏไธ€ไธชไธญๆ–‡ๆฎต่ฝๆต‹่ฏ•ใ€‚่ฟ™ไธชๆต‹่ฏ•ๅฐ†้ชŒ่ฏUTF-8ๅ‡ฝๆ•ฐๆ˜ฏๅฆ่ƒฝๅคŸๆญฃ็กฎๅค„็†ๅŒ…ๅซไธญๆ–‡ๅญ—็ฌฆ็š„ๆ–‡ๆœฌใ€‚"; // A sample Chinese paragraph + // Ensure the string is not longer than 1000 characters + expected = expected.Substring(0, Math.Min(1000, expected.Length)); + float[] input = Lexer.VectorizeUTF8(expected); // Use VectorizeUTF8 to get the correct float array + + TestContext.WriteLine($"Testing with input float array: '{Utility.FormatFloatArray(input)}'."); + + // Act + string result = Lexer.QuantizeUTF8(input); + TestContext.WriteLine($"Resulting UTF-8 string: '{result}'"); + + // Assert + Assert.That(result, Is.Not.Null); + Assert.That(result, Is.EqualTo(expected), "The resulting string should match the expected Chinese UTF-8 string."); + TestContext.WriteLine($"Test passed: Input float array converts to expected UTF-8 Chinese text '{expected}'."); + } + + [Test] + public static void QuantizeUTF8_EmojiString_ConvertsToUtf8String() + { + // Arrange + string expected = "๐Ÿ˜€๐Ÿ˜ƒ๐Ÿ˜„๐Ÿ˜๐Ÿ˜†๐Ÿ˜…๐Ÿ˜‚๐Ÿคฃ๐Ÿ˜Š๐Ÿ˜‡๐Ÿ™‚๐Ÿ™ƒ๐Ÿ˜‰๐Ÿ˜Œ๐Ÿ˜๐Ÿฅฐ๐Ÿ˜˜๐Ÿ˜—๐Ÿ˜™๐Ÿ˜š๐Ÿ˜‹๐Ÿ˜›๐Ÿ˜๐Ÿ˜œ๐Ÿคช๐Ÿคจ๐Ÿง๐Ÿค“๐Ÿ˜Ž๐Ÿคฉ๐Ÿฅณ๐Ÿ˜๐Ÿ˜’๐Ÿ˜ž๐Ÿ˜”๐Ÿ˜Ÿ๐Ÿ˜•๐Ÿ™โ˜น๏ธ๐Ÿ˜ฃ๐Ÿ˜–๐Ÿ˜ซ๐Ÿ˜ฉ๐Ÿฅบ๐Ÿ˜ข๐Ÿ˜ญ๐Ÿ˜ค๐Ÿ˜ ๐Ÿ˜ก๐Ÿคฌ๐Ÿ˜ฑ๐Ÿ˜จ๐Ÿ˜ฐ๐Ÿ˜ฅ๐Ÿ˜“๐Ÿค—๐Ÿค”๐Ÿคญ๐Ÿคซ๐Ÿคฅ๐Ÿ˜ถ๐Ÿ˜๐Ÿ˜‘๐Ÿ˜ฌ๐Ÿ™„๐Ÿ˜ฏ๐Ÿ˜ฆ๐Ÿ˜ง๐Ÿ˜ฎ๐Ÿ˜ฒ๐Ÿฅฑ๐Ÿ˜ด๐Ÿคค๐Ÿ˜ช๐Ÿ˜ต๐Ÿค๐Ÿฅด๐Ÿคข๐Ÿคฎ๐Ÿคง๐Ÿ˜ท๐Ÿค’๐Ÿค•๐Ÿ˜ˆ๐Ÿ‘ฟ๐Ÿ‘น๐Ÿ‘บ๐Ÿ’€โ˜ ๏ธ๐Ÿ‘ป๐Ÿ‘ฝ๐Ÿ‘พ๐Ÿค–๐Ÿ’ฉ๐Ÿ˜บ๐Ÿ˜ธ๐Ÿ˜น๐Ÿ˜ป๐Ÿ˜ผ๐Ÿ˜ฝ๐Ÿ™€๐Ÿ˜ฟ๐Ÿ˜พ"; + // Ensure the string is not longer than 1000 characters + expected = expected.Substring(0, Math.Min(1000, expected.Length)); + float[] input = Lexer.VectorizeUTF8(expected); // Use VectorizeUTF8 to get the correct float array + + TestContext.WriteLine($"Testing with input float array: '{Utility.FormatFloatArray(input)}'."); + + // Act + string result = Lexer.QuantizeUTF8(input); + TestContext.WriteLine($"Resulting UTF-8 string: '{result}'"); + + // Assert + Assert.That(result, Is.Not.Null); + Assert.That(result, Is.EqualTo(expected), "The resulting string should match the expected emoji-only UTF-8 string."); + TestContext.WriteLine($"Test passed: Input float array converts to expected UTF-8 emoji-only string '{expected}'."); + } } }