diff --git a/data/my_squad_answer.txt b/data/my_squad_answer.txt
new file mode 100644
index 0000000..7e6ac22
--- /dev/null
+++ b/data/my_squad_answer.txt
@@ -0,0 +1 @@
+copenhagen telephone exchange
\ No newline at end of file
diff --git a/data/my_squad_question.txt b/data/my_squad_question.txt
new file mode 100644
index 0000000..e19758c
--- /dev/null
+++ b/data/my_squad_question.txt
@@ -0,0 +1 @@
+Where did Erlang work?
\ No newline at end of file
diff --git a/data/my_squad_source.txt b/data/my_squad_source.txt
new file mode 100644
index 0000000..e468f07
--- /dev/null
+++ b/data/my_squad_source.txt
@@ -0,0 +1 @@
+Erlang worked for the Copenhagen Telephone Exchange and wanted to analyze and optimize its operations
\ No newline at end of file
diff --git a/modules/7_nlp/include/model.hpp b/modules/7_nlp/include/model.hpp
index 52a2f3e..0bf2c9b 100644
--- a/modules/7_nlp/include/model.hpp
+++ b/modules/7_nlp/include/model.hpp
@@ -14,4 +14,5 @@ class SQuADModel {
 private:
     Tokenizer tokenizer;
     InferenceEngine::InferRequest req;
+    std::string outputName;
 };
diff --git a/modules/7_nlp/include/tokenizer.hpp b/modules/7_nlp/include/tokenizer.hpp
index 0efd486..db0db60 100644
--- a/modules/7_nlp/include/tokenizer.hpp
+++ b/modules/7_nlp/include/tokenizer.hpp
@@ -27,6 +27,6 @@ class Tokenizer {
     std::vector<int> tokensToIndices(const std::vector<std::string>& tokens, int maxNumTokens=128);

 private:
-    std::vector<std::string> vocab;
     std::map<std::string, int> vocabMap;
+    std::vector<std::string> vocab;
 };
diff --git a/modules/7_nlp/src/model.cpp b/modules/7_nlp/src/model.cpp
index 254e127..d49788d 100644
--- a/modules/7_nlp/src/model.cpp
+++ b/modules/7_nlp/src/model.cpp
@@ -11,16 +11,23 @@ using namespace InferenceEngine;
 using namespace cv;
 using namespace cv::utils::fs;

+Blob::Ptr wrapVecToBlob(const std::vector<int>& v) {
+    std::vector<size_t> dims = {1, v.size()};
+    return make_shared_blob<int>(TensorDesc(Precision::I32, dims, Layout::NC), (int*)v.data());
+}
+
 SQuADModel::SQuADModel() : tokenizer(join(DATA_FOLDER, "bert-large-uncased-vocab.txt")) {
     Core ie;

     // Load deep learning network into memory
     CNNNetwork net = ie.ReadNetwork(join(DATA_FOLDER, "distilbert.xml"),
                                     join(DATA_FOLDER, "distilbert.bin"));
-
+    InputInfo::Ptr inputInfo = net.getInputsInfo()["input.1"];
+    inputInfo->setLayout(Layout::NC);
+    inputInfo->setPrecision(Precision::I32);
+    outputName = net.getOutputsInfo().begin()->first;
     // Initialize runnable object on CPU device
     ExecutableNetwork execNet = ie.LoadNetwork(net, "CPU");
-
     // Create a single processing thread
     req = execNet.CreateInferRequest();
 }
@@ -39,8 +46,39 @@ std::string SQuADModel::getAnswer(const std::string& question, const std::string
     tokens.push_back("[SEP]");

     std::vector<int> indices = tokenizer.tokensToIndices(tokens);
+    Blob::Ptr input = wrapVecToBlob(indices);
+    req.SetBlob("input.1", input);
+    req.Infer();
+    float* output1 = req.GetBlob("Squeeze_437")->buffer().as<float*>();
+    float* output2 = req.GetBlob("Squeeze_438")->buffer().as<float*>();
+    float max1 = output1[0], max2 = output2[0];
+    int indMax1 = 0, indMax2 = 0;
+    for (int i = 0; i < 128; i++) {
+        if (output1[i] > max1) {
+            max1 = output1[i];
+            indMax1 = i;
+        }
-    // TODO: forward indices through the network and return an answer
-
-    return "";
+        if (output2[i] > max2) {
+            max2 = output2[i];
+            indMax2 = i;
+        }
+    }
+
+    std::cout << indMax1 << " " << indMax2 << std::endl;
+    std::string result = "";
+    CV_CheckLE(indMax1, indMax2, "indMax1 > indMax2");
+    for (int i = indMax1; i < indMax2 + 1; i++) {
+        std::string word = tokens[i];
+        if (word[0] == '#') {
+            result.pop_back();
+            result += word.substr(2, word.length() - 2);
+            result += (char)32;
+        }
+        else {
+            result += word + (char)32;
+        }
+    }
+    result.pop_back();
+    return result;
 }

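Side note on the three new data files above: together they read like an end-to-end smoke test (question, source passage, expected answer). Below is a minimal, hypothetical driver sketch showing how they could be fed through SQuADModel; the file paths come from this patch, but the driver itself (readAll, main) is an assumption and not part of the diff, and it presumes the DATA_FOLDER model and vocab files used by the existing code are in place.

// Hypothetical driver (not part of this patch): reads the new data files
// and runs the question-answering path end to end.
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>

#include "model.hpp"

// Read a whole file into a string (the new data files have no trailing newline).
static std::string readAll(const std::string& path) {
    std::ifstream f(path);
    std::stringstream ss;
    ss << f.rdbuf();
    return ss.str();
}

int main() {
    SQuADModel model;  // constructor loads distilbert.xml/.bin and the vocab from DATA_FOLDER
    std::string question = readAll("data/my_squad_question.txt");
    std::string source = readAll("data/my_squad_source.txt");
    // Assuming the second argument of getAnswer is the source passage
    // (its parameter name is truncated in the hunk header below).
    std::cout << model.getAnswer(question, source) << std::endl;
    return 0;
}

If getAnswer behaves as the test data expects, the printed answer should match data/my_squad_answer.txt ("copenhagen telephone exchange").
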
diff --git a/modules/7_nlp/src/tokenizer.cpp b/modules/7_nlp/src/tokenizer.cpp
index 1ac56b1..e17390a 100644
--- a/modules/7_nlp/src/tokenizer.cpp
+++ b/modules/7_nlp/src/tokenizer.cpp
@@ -6,7 +6,26 @@
 #include

 std::vector<std::string> basicTokenize(const std::string& text) {
-    CV_Error(cv::Error::StsNotImplemented, "basicTokenize");
+    std::vector<std::string> basicTokens;
+    std::string currToken = "";
+    for (auto ch : text) {
+        if (isspace(ch)) {
+            if (!currToken.empty())
+                basicTokens.push_back(currToken);
+            currToken = "";
+        } else if (ispunct(ch)) {
+            if (!currToken.empty())
+                basicTokens.push_back(currToken);
+            currToken = ""; currToken += ch;
+            basicTokens.push_back(currToken);
+            currToken = "";
+        } else {
+            currToken += tolower(ch);
+        }
+    }
+    if (!currToken.empty())
+        basicTokens.push_back(currToken);
+    return basicTokens;
 }

 std::vector<std::string> wordTokenize(const std::string& word,