From e571d33a4550ab3bea6f42dda3ec88d5924e9c00 Mon Sep 17 00:00:00 2001
From: Jim Flynn
Date: Mon, 15 Apr 2019 14:34:17 +0100
Subject: IVGCVSW-2855 Create TfLite reference test for DeepSpeechV1

Change-Id: I4492a85c8337bf4ea0eb998c88b9cbfc932dc4e6
Signed-off-by: Ruomei Yan
Signed-off-by: Jim Flynn
---
 tests/DeepSpeechV1Database.hpp      | 220 ++++++++++++++++++++++++++++++++++++
 tests/DeepSpeechV1InferenceTest.hpp | 201 +++++++++++++++++++++++++++++++++++
 tests/LstmCommon.hpp                |  30 ++++++
 3 files changed, 451 insertions(+)
 create mode 100644 tests/DeepSpeechV1Database.hpp
 create mode 100755 tests/DeepSpeechV1InferenceTest.hpp
 create mode 100755 tests/LstmCommon.hpp

(limited to 'tests')

diff --git a/tests/DeepSpeechV1Database.hpp b/tests/DeepSpeechV1Database.hpp
new file mode 100644
index 0000000000..4d2d591bed
--- /dev/null
+++ b/tests/DeepSpeechV1Database.hpp
@@ -0,0 +1,220 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include "LstmCommon.hpp"
+
+#include <cstdint>
+#include <exception>
+#include <fstream>
+#include <istream>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include <armnn/TypesUtils.hpp>
+
+#include <boost/algorithm/string.hpp>
+#include <boost/log/trivial.hpp>
+#include <boost/numeric/conversion/cast.hpp>
+
+#include "InferenceTestImage.hpp"
+
+namespace
+{
+
+/// Reads 'stream' line by line and converts every token into a T.
+/// Each line is split on any character in 'chars' (runs of delimiters are merged). Empty tokens are skipped
+/// and tokens for which 'parseElementFunc' throws a std::exception are logged and ignored, so a malformed
+/// file yields fewer elements rather than an error.
+/// @param stream           Text stream to read until end of input.
+/// @param parseElementFunc Callable converting a std::string token to a value convertible to T.
+/// @param chars            Set of delimiter characters.
+/// @return the successfully parsed values, in stream order.
+template<typename T, typename TParseElementFunc>
+std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc, const char * chars = "\t ,:")
+{
+    std::vector<T> result;
+    // Processes line-by-line.
+    std::string line;
+    while (std::getline(stream, line))
+    {
+        std::vector<std::string> tokens;
+        try
+        {
+            // Coverity fix: boost::split() may throw an exception of type boost::bad_function_call.
+            boost::split(tokens, line, boost::algorithm::is_any_of(chars), boost::token_compress_on);
+        }
+        catch (const std::exception& e)
+        {
+            BOOST_LOG_TRIVIAL(error) << "An error occurred when splitting tokens: " << e.what();
+            continue;
+        }
+        for (const std::string& token : tokens)
+        {
+            if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
+            {
+                try
+                {
+                    result.push_back(parseElementFunc(token));
+                }
+                catch (const std::exception&)
+                {
+                    BOOST_LOG_TRIVIAL(error) << "'" << token << "' is not a valid number. It has been ignored.";
+                }
+            }
+        }
+    }
+
+    return result;
+}
+
+/// Parses 'stream' into a vector whose element type corresponds to the non-quantized armnn::DataType.
+template<armnn::DataType NonQuantizedType>
+auto ParseDataArray(std::istream & stream);
+
+/// Parses floats from 'stream' and quantizes each one with the given scale and offset.
+template<armnn::DataType QuantizedType>
+auto ParseDataArray(std::istream& stream,
+                    const float& quantizationScale,
+                    const int32_t& quantizationOffset);
+
+// The specialisations are declared inline so that ones a translation unit does not call are not
+// reported as unused functions (they have internal linkage inside the anonymous namespace).
+template<>
+inline auto ParseDataArray<armnn::DataType::Float32>(std::istream & stream)
+{
+    return ParseArrayImpl<float>(stream, [](const std::string& s) { return std::stof(s); });
+}
+
+template<>
+inline auto ParseDataArray<armnn::DataType::Signed32>(std::istream & stream)
+{
+    return ParseArrayImpl<int>(stream, [](const std::string & s) { return std::stoi(s); });
+}
+
+template<>
+inline auto ParseDataArray<armnn::DataType::QuantisedAsymm8>(std::istream& stream,
+                                                             const float& quantizationScale,
+                                                             const int32_t& quantizationOffset)
+{
+    return ParseArrayImpl<uint8_t>(stream,
+                                   [&quantizationScale, &quantizationOffset](const std::string & s)
+                                   {
+                                       return boost::numeric_cast<uint8_t>(
+                                           armnn::Quantize<uint8_t>(std::stof(s),
+                                                                    quantizationScale,
+                                                                    quantizationOffset));
+                                   });
+}
+
+/// Input tensors for one DeepSpeechV1 inference together with the outputs expected for them.
+struct DeepSpeechV1TestCaseData
+{
+    DeepSpeechV1TestCaseData(
+        const LstmInput& inputData,
+        const LstmInput& expectedOutputData)
+        : m_InputData(inputData)
+        , m_ExpectedOutputData(expectedOutputData)
+    {}
+
+    LstmInput m_InputData;
+    LstmInput m_ExpectedOutputData; // m_InputSeq holds the expected logits
+};
+
+/// Loads DeepSpeechV1 reference data (inputs and expected outputs) from text files in six directories.
+/// Directory strings are concatenated directly with fixed file names, so each must end with a path separator.
+class DeepSpeechV1Database
+{
+public:
+    explicit DeepSpeechV1Database(const std::string& inputSeqDir, const std::string& prevStateCDir,
+                                  const std::string& prevStateHDir, const std::string& logitsDir,
+                                  const std::string& newStateCDir, const std::string& newStateHDir);
+
+    /// Reads the six data files and packages them as a test case.
+    /// @param testCaseId Only used in log messages; the same fixed file names are read for every id.
+    /// @return the loaded data, or nullptr if any file cannot be opened or read.
+    std::unique_ptr<DeepSpeechV1TestCaseData> GetTestCaseData(unsigned int testCaseId);
+
+private:
+    std::string m_InputSeqDir;
+    std::string m_PrevStateCDir;
+    std::string m_PrevStateHDir;
+    std::string m_LogitsDir;
+    std::string m_NewStateCDir;
+    std::string m_NewStateHDir;
+};
+
+inline DeepSpeechV1Database::DeepSpeechV1Database(const std::string& inputSeqDir, const std::string& prevStateCDir,
+                                                  const std::string& prevStateHDir, const std::string& logitsDir,
+                                                  const std::string& newStateCDir, const std::string& newStateHDir)
+    : m_InputSeqDir(inputSeqDir)
+    , m_PrevStateCDir(prevStateCDir)
+    , m_PrevStateHDir(prevStateHDir)
+    , m_LogitsDir(logitsDir)
+    , m_NewStateCDir(newStateCDir)
+    , m_NewStateHDir(newStateHDir)
+{}
+
+/// Opens 'path' and parses its whole content as a flat list of floats into 'data'.
+/// @return false (after logging a fatal message naming 'testCaseId') if the file cannot be opened
+///         or an exception escapes the parser; true otherwise.
+inline bool LoadFloatTensorFile(const std::string& path, unsigned int testCaseId, std::vector<float>& data)
+{
+    // An unopened std::ifstream parses as an empty tensor, so a missing file must be rejected explicitly.
+    std::ifstream file(path);
+    if (!file.is_open())
+    {
+        BOOST_LOG_TRIVIAL(fatal) << "Failed to open " << path << " for test case " << testCaseId;
+        return false;
+    }
+
+    try
+    {
+        data = ParseDataArray<armnn::DataType::Float32>(file);
+    }
+    catch (const std::exception& e)
+    {
+        BOOST_LOG_TRIVIAL(fatal) << "Failed to load " << path << " for test case " << testCaseId
+                                 << ". Error: " << e.what();
+        return false;
+    }
+
+    return true;
+}
+
+inline std::unique_ptr<DeepSpeechV1TestCaseData> DeepSpeechV1Database::GetTestCaseData(unsigned int testCaseId)
+{
+    // Load test case input
+    std::vector<float> inputSeqData;
+    std::vector<float> prevStateCData;
+    std::vector<float> prevStateHData;
+    if (!LoadFloatTensorFile(m_InputSeqDir + "input_node_0_flat.txt", testCaseId, inputSeqData) ||
+        !LoadFloatTensorFile(m_PrevStateCDir + "previous_state_c_0.txt", testCaseId, prevStateCData) ||
+        !LoadFloatTensorFile(m_PrevStateHDir + "previous_state_h_0.txt", testCaseId, prevStateHData))
+    {
+        return nullptr;
+    }
+
+    // Prepare test case expected output
+    std::vector<float> logitsData;
+    std::vector<float> expectedNewStateCData;
+    std::vector<float> expectedNewStateHData;
+    if (!LoadFloatTensorFile(m_LogitsDir + "logits.txt", testCaseId, logitsData) ||
+        !LoadFloatTensorFile(m_NewStateCDir + "new_state_c.txt", testCaseId, expectedNewStateCData) ||
+        !LoadFloatTensorFile(m_NewStateHDir + "new_state_h.txt", testCaseId, expectedNewStateHData))
+    {
+        return nullptr;
+    }
+
+    // use the struct for representing input and output data
+    const LstmInput inputDataSingleTest(inputSeqData, prevStateCData, prevStateHData);
+    const LstmInput expectedOutputsSingleTest(logitsData, expectedNewStateCData, expectedNewStateHData);
+
+    return std::make_unique<DeepSpeechV1TestCaseData>(inputDataSingleTest, expectedOutputsSingleTest);
+}
+
+} // anonymous namespace
+
diff --git a/tests/DeepSpeechV1InferenceTest.hpp b/tests/DeepSpeechV1InferenceTest.hpp
new file mode 100755
index 0000000000..24e7dac567
--- /dev/null
+++ b/tests/DeepSpeechV1InferenceTest.hpp
@@ -0,0 +1,201 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include "InferenceTest.hpp"
+#include "DeepSpeechV1Database.hpp"
+
+#include <boost/assert.hpp>
+#include <boost/log/trivial.hpp>
+#include <boost/numeric/conversion/cast.hpp>
+#include <boost/test/tools/floating_point_comparison.hpp>
+
+#include <cstddef>
+#include <functional>
+#include <initializer_list>
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace
+{
+
+/// Inference test case that feeds one set of DeepSpeechV1 inputs (input sequence plus previous C/H states)
+/// to the model and checks the three outputs (logits, new C state, new H state) against expected values.
+/// @tparam Model Inference model wrapper type accepted by InferenceModelTestCase.
+template<typename Model>
+class DeepSpeechV1TestCase : public InferenceModelTestCase<Model>
+{
+public:
+    /// @param model        Model the test case runs against.
+    /// @param testCaseId   Identifier reported in log messages.
+    /// @param testCaseData Input tensors and the outputs expected for them; both are copied.
+    DeepSpeechV1TestCase(Model& model,
+                         unsigned int testCaseId,
+                         const DeepSpeechV1TestCaseData& testCaseData)
+        : InferenceModelTestCase<Model>(model,
+                                        testCaseId,
+                                        { testCaseData.m_InputData.m_InputSeq,
+                                          testCaseData.m_InputData.m_StateC,
+                                          testCaseData.m_InputData.m_StateH },
+                                        { k_OutputSize1, k_OutputSize2, k_OutputSize3 })
+        , m_FloatComparer(boost::math::fpc::percent_tolerance(1.0f))
+        , m_ExpectedOutputs(testCaseData.m_ExpectedOutputData)
+    {}
+
+    /// Compares the model outputs with the expected data.
+    /// @return TestCaseResult::Ok when every element of every output is within 1% (relative tolerance) of its
+    ///         expected value, TestCaseResult::Failed otherwise (the first mismatch is logged).
+    TestCaseResult ProcessResult(const InferenceTestOptions& /*options*/) override
+    {
+        const std::vector<float>& logits = boost::get<std::vector<float>>(this->GetOutputs()[0]);
+        BOOST_ASSERT(logits.size() == k_OutputSize1);
+
+        const std::vector<float>& newStateC = boost::get<std::vector<float>>(this->GetOutputs()[1]);
+        BOOST_ASSERT(newStateC.size() == k_OutputSize2);
+
+        const std::vector<float>& newStateH = boost::get<std::vector<float>>(this->GetOutputs()[2]);
+        BOOST_ASSERT(newStateH.size() == k_OutputSize3);
+
+        // The expected data reuses LstmInput, so the expected logits live in m_InputSeq.
+        const bool allMatch = OutputMatches("Logits", logits, m_ExpectedOutputs.m_InputSeq) &&
+                              OutputMatches("StateC", newStateC, m_ExpectedOutputs.m_StateC) &&
+                              OutputMatches("StateH", newStateH, m_ExpectedOutputs.m_StateH);
+        return allMatch ? TestCaseResult::Ok : TestCaseResult::Failed;
+    }
+
+private:
+    /// Checks 'actual' against 'expected' element by element with m_FloatComparer.
+    /// A size mismatch is reported as a failure instead of reading past the end of 'expected';
+    /// BOOST_ASSERT alone would not catch it in builds where assertions are disabled.
+    /// @param name Output name used in the log message.
+    /// @return true if both vectors have the same size and all elements are within tolerance.
+    bool OutputMatches(const char* name, const std::vector<float>& actual, const std::vector<float>& expected)
+    {
+        if (actual.size() != expected.size())
+        {
+            BOOST_LOG_TRIVIAL(error) << name << " for Lstm " << this->GetTestCaseId() << " has "
+                                     << actual.size() << " elements, expected " << expected.size();
+            return false;
+        }
+
+        for (std::size_t j = 0u; j < actual.size(); ++j)
+        {
+            if (!m_FloatComparer(actual[j], expected[j]))
+            {
+                BOOST_LOG_TRIVIAL(error) << name << " for Lstm " << this->GetTestCaseId()
+                                         << " is incorrect at " << j;
+                return false;
+            }
+        }
+        return true;
+    }
+
+    static constexpr unsigned int k_OutputSize1 = 464u;  // logits
+    static constexpr unsigned int k_OutputSize2 = 2048u; // new_state_c
+    static constexpr unsigned int k_OutputSize3 = 2048u; // new_state_h
+
+    boost::math::fpc::close_at_tolerance<float> m_FloatComparer;
+    LstmInput m_ExpectedOutputs;
+};
+
+/// Test case provider for DeepSpeechV1.
+/// Reads six data directories from the command line, constructs the model through the callable given to the
+/// constructor and creates DeepSpeechV1TestCase objects from the data in those directories.
+/// @tparam Model Inference model wrapper type; must provide CommandLineOptions and AddCommandLineOptions().
+template <typename Model>
+class DeepSpeechV1TestCaseProvider : public IInferenceTestCaseProvider
+{
+public:
+    /// @param constructModel Callable taking Model::CommandLineOptions and returning std::unique_ptr<Model>.
+    template <typename TConstructModelCallable>
+    explicit DeepSpeechV1TestCaseProvider(TConstructModelCallable constructModel)
+        : m_ConstructModel(constructModel)
+    {}
+
+    /// Registers the six required data-directory options followed by the model's own options.
+    virtual void AddCommandLineOptions(boost::program_options::options_description& options) override
+    {
+        namespace po = boost::program_options;
+
+        options.add_options()
+            ("input-seq-dir,s", po::value<std::string>(&m_InputSeqDir)->required(),
+                "Path to directory containing test data for m_InputSeq");
+        options.add_options()
+            ("prev-state-c-dir,c", po::value<std::string>(&m_PrevStateCDir)->required(),
+                "Path to directory containing test data for m_PrevStateC");
+        options.add_options()
+            ("prev-state-h-dir,h", po::value<std::string>(&m_PrevStateHDir)->required(),
+                "Path to directory containing test data for m_PrevStateH");
+        options.add_options()
+            ("logits-dir,l", po::value<std::string>(&m_LogitsDir)->required(),
+                "Path to directory containing test data for m_Logits");
+        options.add_options()
+            ("new-state-c-dir,C", po::value<std::string>(&m_NewStateCDir)->required(),
+                "Path to directory containing test data for m_NewStateC");
+        options.add_options()
+            ("new-state-h-dir,H", po::value<std::string>(&m_NewStateHDir)->required(),
+                "Path to directory containing test data for m_NewStateH");
+
+        Model::AddCommandLineOptions(options, m_ModelCommandLineOptions);
+    }
+
+    /// Validates the data directories, then constructs the model and the test database.
+    /// @return false if any directory fails ValidateDirectory() or the model cannot be constructed.
+    virtual bool ProcessCommandLineOptions() override
+    {
+        // Stop at the first directory that fails validation (checked in option-declaration order).
+        for (std::string* dir : { &m_InputSeqDir, &m_PrevStateCDir, &m_PrevStateHDir,
+                                  &m_LogitsDir, &m_NewStateCDir, &m_NewStateHDir })
+        {
+            if (!ValidateDirectory(*dir))
+            {
+                return false;
+            }
+        }
+
+        m_Model = m_ConstructModel(m_ModelCommandLineOptions);
+        if (!m_Model)
+        {
+            return false;
+        }
+
+        // std::make_unique throws on allocation failure and never returns null, so no null check is needed.
+        m_Database = std::make_unique<DeepSpeechV1Database>(m_InputSeqDir, m_PrevStateCDir, m_PrevStateHDir,
+                                                            m_LogitsDir, m_NewStateCDir, m_NewStateHDir);
+        return true;
+    }
+
+    /// Loads the data for 'testCaseId' and wraps it in a DeepSpeechV1TestCase.
+    /// @return the test case, or nullptr if the database could not supply the data.
+    std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override
+    {
+        std::unique_ptr<DeepSpeechV1TestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
+        if (!testCaseData)
+        {
+            return nullptr;
+        }
+
+        return std::make_unique<DeepSpeechV1TestCase<Model>>(*m_Model, testCaseId, *testCaseData);
+    }
+
+private:
+    // Options filled in via Model::AddCommandLineOptions() and later handed to m_ConstructModel.
+    typename Model::CommandLineOptions m_ModelCommandLineOptions;
+    std::function<std::unique_ptr<Model>(typename Model::CommandLineOptions)> m_ConstructModel;
+    std::unique_ptr<Model> m_Model;
+
+    // Data directories filled in by the command-line parser.
+    std::string m_InputSeqDir;
+    std::string m_PrevStateCDir;
+    std::string m_PrevStateHDir;
+    std::string m_LogitsDir;
+    std::string m_NewStateCDir;
+    std::string m_NewStateHDir;
+
+    std::unique_ptr<DeepSpeechV1Database> m_Database;
+};
+
+} // anonymous namespace
diff --git a/tests/LstmCommon.hpp b/tests/LstmCommon.hpp
new file mode 100755
index 0000000000..31c4d041c1
--- /dev/null
+++ b/tests/LstmCommon.hpp
@@ -0,0 +1,30 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace
+{
+
+/// Bundle of three float vectors used as LSTM test data: sequence data, C state and H state.
+/// Copies of the vectors passed to the constructor are stored in the members.
+struct LstmInput
+{
+    LstmInput(const std::vector<float>& inputSeq, const std::vector<float>& stateC, const std::vector<float>& stateH)
+        : m_InputSeq(inputSeq), m_StateC(stateC), m_StateH(stateH)
+    {}
+
+    std::vector<float> m_InputSeq;
+    std::vector<float> m_StateC;
+    std::vector<float> m_StateH;
+};
+
+/// A string paired with a list of LstmInput values.
+using LstmInputs = std::pair<std::string, std::vector<LstmInput>>;
+
+} // anonymous namespace
\ No newline at end of file
-- 
cgit v1.2.1