aboutsummaryrefslogtreecommitdiff
path: root/tests/DeepSpeechV1InferenceTest.hpp
diff options
context:
space:
mode:
authorJim Flynn <jim.flynn@arm.com>2019-04-15 14:34:17 +0100
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-04-16 10:58:52 +0000
commite571d33a4550ab3bea6f42dda3ec88d5924e9c00 (patch)
treeb2dd99b11ca9c9eec79277d39674ada4c194f2f8 /tests/DeepSpeechV1InferenceTest.hpp
parentaab6aff4aa282810cb535eeec65e59741f1f4f0e (diff)
downloadarmnn-e571d33a4550ab3bea6f42dda3ec88d5924e9c00.tar.gz
IVGCVSW-2855 Create TfLite reference test for DeepSpeechV1
Change-Id: I4492a85c8337bf4ea0eb998c88b9cbfc932dc4e6 Signed-off-by: Ruomei Yan <ruomei.yan@arm.com> Signed-off-by: Jim Flynn <jim.flynn@arm.com>
Diffstat (limited to 'tests/DeepSpeechV1InferenceTest.hpp')
-rwxr-xr-xtests/DeepSpeechV1InferenceTest.hpp201
1 files changed, 201 insertions, 0 deletions
diff --git a/tests/DeepSpeechV1InferenceTest.hpp b/tests/DeepSpeechV1InferenceTest.hpp
new file mode 100755
index 0000000000..24e7dac567
--- /dev/null
+++ b/tests/DeepSpeechV1InferenceTest.hpp
@@ -0,0 +1,201 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include "InferenceTest.hpp"
+#include "DeepSpeechV1Database.hpp"
+
+#include <boost/assert.hpp>
+#include <boost/log/trivial.hpp>
+#include <boost/numeric/conversion/cast.hpp>
+#include <boost/test/tools/floating_point_comparison.hpp>
+
+#include <vector>
+
+namespace
+{
+
+template<typename Model>
+class DeepSpeechV1TestCase : public InferenceModelTestCase<Model>
+{
+public:
+ DeepSpeechV1TestCase(Model& model,
+ unsigned int testCaseId,
+ const DeepSpeechV1TestCaseData& testCaseData)
+ : InferenceModelTestCase<Model>(model,
+ testCaseId,
+ { testCaseData.m_InputData.m_InputSeq,
+ testCaseData.m_InputData.m_StateC,
+ testCaseData.m_InputData.m_StateH},
+ { k_OutputSize1, k_OutputSize2, k_OutputSize3 })
+ , m_FloatComparer(boost::math::fpc::percent_tolerance(1.0f))
+ , m_ExpectedOutputs({testCaseData.m_ExpectedOutputData.m_InputSeq, testCaseData.m_ExpectedOutputData.m_StateC,
+ testCaseData.m_ExpectedOutputData.m_StateH})
+ {}
+
+ TestCaseResult ProcessResult(const InferenceTestOptions& options) override
+ {
+ const std::vector<float>& output1 = boost::get<std::vector<float>>(this->GetOutputs()[0]); // logits
+ BOOST_ASSERT(output1.size() == k_OutputSize1);
+
+ const std::vector<float>& output2 = boost::get<std::vector<float>>(this->GetOutputs()[1]); // new_state_c
+ BOOST_ASSERT(output2.size() == k_OutputSize2);
+
+ const std::vector<float>& output3 = boost::get<std::vector<float>>(this->GetOutputs()[2]); // new_state_h
+ BOOST_ASSERT(output3.size() == k_OutputSize3);
+
+ // Check each output to see whether it is the expected value
+ for (unsigned int j = 0u; j < output1.size(); j++)
+ {
+ if(!m_FloatComparer(output1[j], m_ExpectedOutputs.m_InputSeq[j]))
+ {
+ BOOST_LOG_TRIVIAL(error) << "InputSeq for Lstm " << this->GetTestCaseId() <<
+ " is incorrect at" << j;
+ return TestCaseResult::Failed;
+ }
+ }
+
+ for (unsigned int j = 0u; j < output2.size(); j++)
+ {
+ if(!m_FloatComparer(output2[j], m_ExpectedOutputs.m_StateC[j]))
+ {
+ BOOST_LOG_TRIVIAL(error) << "StateC for Lstm " << this->GetTestCaseId() <<
+ " is incorrect";
+ return TestCaseResult::Failed;
+ }
+ }
+
+ for (unsigned int j = 0u; j < output3.size(); j++)
+ {
+ if(!m_FloatComparer(output3[j], m_ExpectedOutputs.m_StateH[j]))
+ {
+ BOOST_LOG_TRIVIAL(error) << "StateH for Lstm " << this->GetTestCaseId() <<
+ " is incorrect";
+ return TestCaseResult::Failed;
+ }
+ }
+ return TestCaseResult::Ok;
+ }
+
+private:
+
+ static constexpr unsigned int k_OutputSize1 = 464u;
+ static constexpr unsigned int k_OutputSize2 = 2048u;
+ static constexpr unsigned int k_OutputSize3 = 2048u;
+
+ boost::math::fpc::close_at_tolerance<float> m_FloatComparer;
+ LstmInput m_ExpectedOutputs;
+};
+
+template <typename Model>
+class DeepSpeechV1TestCaseProvider : public IInferenceTestCaseProvider
+{
+public:
+ template <typename TConstructModelCallable>
+ explicit DeepSpeechV1TestCaseProvider(TConstructModelCallable constructModel)
+ : m_ConstructModel(constructModel)
+ {}
+
+ virtual void AddCommandLineOptions(boost::program_options::options_description& options) override
+ {
+ namespace po = boost::program_options;
+
+ options.add_options()
+ ("input-seq-dir,s", po::value<std::string>(&m_InputSeqDir)->required(),
+ "Path to directory containing test data for m_InputSeq");
+ options.add_options()
+ ("prev-state-c-dir,c", po::value<std::string>(&m_PrevStateCDir)->required(),
+ "Path to directory containing test data for m_PrevStateC");
+ options.add_options()
+ ("prev-state-h-dir,h", po::value<std::string>(&m_PrevStateHDir)->required(),
+ "Path to directory containing test data for m_PrevStateH");
+ options.add_options()
+ ("logits-dir,l", po::value<std::string>(&m_LogitsDir)->required(),
+ "Path to directory containing test data for m_Logits");
+ options.add_options()
+ ("new-state-c-dir,C", po::value<std::string>(&m_NewStateCDir)->required(),
+ "Path to directory containing test data for m_NewStateC");
+ options.add_options()
+ ("new-state-h-dir,H", po::value<std::string>(&m_NewStateHDir)->required(),
+ "Path to directory containing test data for m_NewStateH");
+
+ Model::AddCommandLineOptions(options, m_ModelCommandLineOptions);
+ }
+
+ virtual bool ProcessCommandLineOptions() override
+ {
+ if (!ValidateDirectory(m_InputSeqDir))
+ {
+ return false;
+ }
+
+ if (!ValidateDirectory(m_PrevStateCDir))
+ {
+ return false;
+ }
+
+ if (!ValidateDirectory(m_PrevStateHDir))
+ {
+ return false;
+ }
+
+ if (!ValidateDirectory(m_LogitsDir))
+ {
+ return false;
+ }
+
+ if (!ValidateDirectory(m_NewStateCDir))
+ {
+ return false;
+ }
+
+ if (!ValidateDirectory(m_NewStateHDir))
+ {
+ return false;
+ }
+
+ m_Model = m_ConstructModel(m_ModelCommandLineOptions);
+ if (!m_Model)
+ {
+ return false;
+ }
+ m_Database = std::make_unique<DeepSpeechV1Database>(m_InputSeqDir.c_str(), m_PrevStateCDir.c_str(),
+ m_PrevStateHDir.c_str(), m_LogitsDir.c_str(),
+ m_NewStateCDir.c_str(), m_NewStateHDir.c_str());
+ if (!m_Database)
+ {
+ return false;
+ }
+
+ return true;
+ }
+
+ std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override
+ {
+ std::unique_ptr<DeepSpeechV1TestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
+ if (!testCaseData)
+ {
+ return nullptr;
+ }
+
+ return std::make_unique<DeepSpeechV1TestCase<Model>>(*m_Model, testCaseId, *testCaseData);
+ }
+
+private:
+ typename Model::CommandLineOptions m_ModelCommandLineOptions;
+ std::function<std::unique_ptr<Model>(typename Model::CommandLineOptions)> m_ConstructModel;
+ std::unique_ptr<Model> m_Model;
+
+ std::string m_InputSeqDir;
+ std::string m_PrevStateCDir;
+ std::string m_PrevStateHDir;
+ std::string m_LogitsDir;
+ std::string m_NewStateCDir;
+ std::string m_NewStateHDir;
+
+ std::unique_ptr<DeepSpeechV1Database> m_Database;
+};
+
+} // anonymous namespace