diff options
author | Jim Flynn <jim.flynn@arm.com> | 2019-04-15 14:34:17 +0100 |
---|---|---|
committer | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-04-16 10:58:52 +0000 |
commit | e571d33a4550ab3bea6f42dda3ec88d5924e9c00 (patch) | |
tree | b2dd99b11ca9c9eec79277d39674ada4c194f2f8 /tests/DeepSpeechV1InferenceTest.hpp | |
parent | aab6aff4aa282810cb535eeec65e59741f1f4f0e (diff) | |
download | armnn-e571d33a4550ab3bea6f42dda3ec88d5924e9c00.tar.gz |
IVGCVSW-2855 Create TfLite reference test for DeepSpeechV1
Change-Id: I4492a85c8337bf4ea0eb998c88b9cbfc932dc4e6
Signed-off-by: Ruomei Yan <ruomei.yan@arm.com>
Signed-off-by: Jim Flynn <jim.flynn@arm.com>
Diffstat (limited to 'tests/DeepSpeechV1InferenceTest.hpp')
-rwxr-xr-x | tests/DeepSpeechV1InferenceTest.hpp | 201 |
1 files changed, 201 insertions, 0 deletions
diff --git a/tests/DeepSpeechV1InferenceTest.hpp b/tests/DeepSpeechV1InferenceTest.hpp new file mode 100755 index 0000000000..24e7dac567 --- /dev/null +++ b/tests/DeepSpeechV1InferenceTest.hpp @@ -0,0 +1,201 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// +#pragma once + +#include "InferenceTest.hpp" +#include "DeepSpeechV1Database.hpp" + +#include <boost/assert.hpp> +#include <boost/log/trivial.hpp> +#include <boost/numeric/conversion/cast.hpp> +#include <boost/test/tools/floating_point_comparison.hpp> + +#include <vector> + +namespace +{ + +template<typename Model> +class DeepSpeechV1TestCase : public InferenceModelTestCase<Model> +{ +public: + DeepSpeechV1TestCase(Model& model, + unsigned int testCaseId, + const DeepSpeechV1TestCaseData& testCaseData) + : InferenceModelTestCase<Model>(model, + testCaseId, + { testCaseData.m_InputData.m_InputSeq, + testCaseData.m_InputData.m_StateC, + testCaseData.m_InputData.m_StateH}, + { k_OutputSize1, k_OutputSize2, k_OutputSize3 }) + , m_FloatComparer(boost::math::fpc::percent_tolerance(1.0f)) + , m_ExpectedOutputs({testCaseData.m_ExpectedOutputData.m_InputSeq, testCaseData.m_ExpectedOutputData.m_StateC, + testCaseData.m_ExpectedOutputData.m_StateH}) + {} + + TestCaseResult ProcessResult(const InferenceTestOptions& options) override + { + const std::vector<float>& output1 = boost::get<std::vector<float>>(this->GetOutputs()[0]); // logits + BOOST_ASSERT(output1.size() == k_OutputSize1); + + const std::vector<float>& output2 = boost::get<std::vector<float>>(this->GetOutputs()[1]); // new_state_c + BOOST_ASSERT(output2.size() == k_OutputSize2); + + const std::vector<float>& output3 = boost::get<std::vector<float>>(this->GetOutputs()[2]); // new_state_h + BOOST_ASSERT(output3.size() == k_OutputSize3); + + // Check each output to see whether it is the expected value + for (unsigned int j = 0u; j < output1.size(); j++) + { + if(!m_FloatComparer(output1[j], m_ExpectedOutputs.m_InputSeq[j])) + { + BOOST_LOG_TRIVIAL(error) << "InputSeq for Lstm " << this->GetTestCaseId() << + " is incorrect at" << j; + return TestCaseResult::Failed; + } + } + + for (unsigned int j = 0u; j < output2.size(); j++) + { + if(!m_FloatComparer(output2[j], m_ExpectedOutputs.m_StateC[j])) + { + BOOST_LOG_TRIVIAL(error) << "StateC for Lstm " << this->GetTestCaseId() << + " is incorrect"; + return TestCaseResult::Failed; + } + } + + for (unsigned int j = 0u; j < output3.size(); j++) + { + if(!m_FloatComparer(output3[j], m_ExpectedOutputs.m_StateH[j])) + { + BOOST_LOG_TRIVIAL(error) << "StateH for Lstm " << this->GetTestCaseId() << + " is incorrect"; + return TestCaseResult::Failed; + } + } + return TestCaseResult::Ok; + } + +private: + + static constexpr unsigned int k_OutputSize1 = 464u; + static constexpr unsigned int k_OutputSize2 = 2048u; + static constexpr unsigned int k_OutputSize3 = 2048u; + + boost::math::fpc::close_at_tolerance<float> m_FloatComparer; + LstmInput m_ExpectedOutputs; +}; + +template <typename Model> +class DeepSpeechV1TestCaseProvider : public IInferenceTestCaseProvider +{ +public: + template <typename TConstructModelCallable> + explicit DeepSpeechV1TestCaseProvider(TConstructModelCallable constructModel) + : m_ConstructModel(constructModel) + {} + + virtual void AddCommandLineOptions(boost::program_options::options_description& options) override + { + namespace po = boost::program_options; + + options.add_options() + ("input-seq-dir,s", po::value<std::string>(&m_InputSeqDir)->required(), + "Path to directory containing test data for m_InputSeq"); + options.add_options() + ("prev-state-c-dir,c", po::value<std::string>(&m_PrevStateCDir)->required(), + "Path to directory containing test data for m_PrevStateC"); + options.add_options() + ("prev-state-h-dir,h", po::value<std::string>(&m_PrevStateHDir)->required(), + "Path to directory containing test data for m_PrevStateH"); + options.add_options() + ("logits-dir,l", po::value<std::string>(&m_LogitsDir)->required(), + "Path to directory containing test data for m_Logits"); + options.add_options() + ("new-state-c-dir,C", po::value<std::string>(&m_NewStateCDir)->required(), + "Path to directory containing test data for m_NewStateC"); + options.add_options() + ("new-state-h-dir,H", po::value<std::string>(&m_NewStateHDir)->required(), + "Path to directory containing test data for m_NewStateH"); + + Model::AddCommandLineOptions(options, m_ModelCommandLineOptions); + } + + virtual bool ProcessCommandLineOptions() override + { + if (!ValidateDirectory(m_InputSeqDir)) + { + return false; + } + + if (!ValidateDirectory(m_PrevStateCDir)) + { + return false; + } + + if (!ValidateDirectory(m_PrevStateHDir)) + { + return false; + } + + if (!ValidateDirectory(m_LogitsDir)) + { + return false; + } + + if (!ValidateDirectory(m_NewStateCDir)) + { + return false; + } + + if (!ValidateDirectory(m_NewStateHDir)) + { + return false; + } + + m_Model = m_ConstructModel(m_ModelCommandLineOptions); + if (!m_Model) + { + return false; + } + m_Database = std::make_unique<DeepSpeechV1Database>(m_InputSeqDir.c_str(), m_PrevStateCDir.c_str(), + m_PrevStateHDir.c_str(), m_LogitsDir.c_str(), + m_NewStateCDir.c_str(), m_NewStateHDir.c_str()); + if (!m_Database) + { + return false; + } + + return true; + } + + std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override + { + std::unique_ptr<DeepSpeechV1TestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId); + if (!testCaseData) + { + return nullptr; + } + + return std::make_unique<DeepSpeechV1TestCase<Model>>(*m_Model, testCaseId, *testCaseData); + } + +private: + typename Model::CommandLineOptions m_ModelCommandLineOptions; + std::function<std::unique_ptr<Model>(typename Model::CommandLineOptions)> m_ConstructModel; + std::unique_ptr<Model> m_Model; + + std::string m_InputSeqDir; + std::string m_PrevStateCDir; + std::string m_PrevStateHDir; + std::string m_LogitsDir; + std::string m_NewStateCDir; + std::string m_NewStateHDir; + + std::unique_ptr<DeepSpeechV1Database> m_Database; +}; + +} // anonymous namespace |