// // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once #include "InferenceTest.hpp" #include "DeepSpeechV1Database.hpp" #include #include #include #include #include namespace { template class DeepSpeechV1TestCase : public InferenceModelTestCase { public: DeepSpeechV1TestCase(Model& model, unsigned int testCaseId, const DeepSpeechV1TestCaseData& testCaseData) : InferenceModelTestCase(model, testCaseId, { testCaseData.m_InputData.m_InputSeq, testCaseData.m_InputData.m_StateH, testCaseData.m_InputData.m_StateC}, { k_OutputSize1, k_OutputSize2, k_OutputSize3 }) , m_FloatComparer(boost::math::fpc::percent_tolerance(1.0f)) , m_ExpectedOutputs({testCaseData.m_ExpectedOutputData.m_InputSeq, testCaseData.m_ExpectedOutputData.m_StateH, testCaseData.m_ExpectedOutputData.m_StateC}) {} TestCaseResult ProcessResult(const InferenceTestOptions& options) override { const std::vector& output1 = boost::get>(this->GetOutputs()[0]); // logits BOOST_ASSERT(output1.size() == k_OutputSize1); const std::vector& output2 = boost::get>(this->GetOutputs()[1]); // new_state_c BOOST_ASSERT(output2.size() == k_OutputSize2); const std::vector& output3 = boost::get>(this->GetOutputs()[2]); // new_state_h BOOST_ASSERT(output3.size() == k_OutputSize3); // Check each output to see whether it is the expected value for (unsigned int j = 0u; j < output1.size(); j++) { if(!m_FloatComparer(output1[j], m_ExpectedOutputs.m_InputSeq[j])) { BOOST_LOG_TRIVIAL(error) << "InputSeq for Lstm " << this->GetTestCaseId() << " is incorrect at" << j; return TestCaseResult::Failed; } } for (unsigned int j = 0u; j < output2.size(); j++) { if(!m_FloatComparer(output2[j], m_ExpectedOutputs.m_StateH[j])) { BOOST_LOG_TRIVIAL(error) << "StateH for Lstm " << this->GetTestCaseId() << " is incorrect"; return TestCaseResult::Failed; } } for (unsigned int j = 0u; j < output3.size(); j++) { if(!m_FloatComparer(output3[j], m_ExpectedOutputs.m_StateC[j])) { BOOST_LOG_TRIVIAL(error) << "StateC for Lstm " << this->GetTestCaseId() << " is incorrect"; return TestCaseResult::Failed; } } return TestCaseResult::Ok; } private: static constexpr unsigned int k_OutputSize1 = 464u; static constexpr unsigned int k_OutputSize2 = 2048u; static constexpr unsigned int k_OutputSize3 = 2048u; boost::math::fpc::close_at_tolerance m_FloatComparer; LstmInput m_ExpectedOutputs; }; template class DeepSpeechV1TestCaseProvider : public IInferenceTestCaseProvider { public: template explicit DeepSpeechV1TestCaseProvider(TConstructModelCallable constructModel) : m_ConstructModel(constructModel) {} virtual void AddCommandLineOptions(boost::program_options::options_description& options) override { namespace po = boost::program_options; options.add_options() ("input-seq-dir,s", po::value(&m_InputSeqDir)->required(), "Path to directory containing test data for m_InputSeq"); options.add_options() ("prev-state-h-dir,h", po::value(&m_PrevStateHDir)->required(), "Path to directory containing test data for m_PrevStateH"); options.add_options() ("prev-state-c-dir,c", po::value(&m_PrevStateCDir)->required(), "Path to directory containing test data for m_PrevStateC"); options.add_options() ("logits-dir,l", po::value(&m_LogitsDir)->required(), "Path to directory containing test data for m_Logits"); options.add_options() ("new-state-h-dir,H", po::value(&m_NewStateHDir)->required(), "Path to directory containing test data for m_NewStateH"); options.add_options() ("new-state-c-dir,C", po::value(&m_NewStateCDir)->required(), "Path to directory containing test data for m_NewStateC"); Model::AddCommandLineOptions(options, m_ModelCommandLineOptions); } virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override { if (!ValidateDirectory(m_InputSeqDir)) { return false; } if (!ValidateDirectory(m_PrevStateCDir)) { return false; } if (!ValidateDirectory(m_PrevStateHDir)) { return false; } if (!ValidateDirectory(m_LogitsDir)) { return false; } if (!ValidateDirectory(m_NewStateCDir)) { return false; } if (!ValidateDirectory(m_NewStateHDir)) { return false; } m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions); if (!m_Model) { return false; } m_Database = std::make_unique(m_InputSeqDir.c_str(), m_PrevStateHDir.c_str(), m_PrevStateCDir.c_str(), m_LogitsDir.c_str(), m_NewStateHDir.c_str(), m_NewStateCDir.c_str()); if (!m_Database) { return false; } return true; } std::unique_ptr GetTestCase(unsigned int testCaseId) override { std::unique_ptr testCaseData = m_Database->GetTestCaseData(testCaseId); if (!testCaseData) { return nullptr; } return std::make_unique>(*m_Model, testCaseId, *testCaseData); } private: typename Model::CommandLineOptions m_ModelCommandLineOptions; std::function(const InferenceTestOptions&, typename Model::CommandLineOptions)> m_ConstructModel; std::unique_ptr m_Model; std::string m_InputSeqDir; std::string m_PrevStateCDir; std::string m_PrevStateHDir; std::string m_LogitsDir; std::string m_NewStateCDir; std::string m_NewStateHDir; std::unique_ptr m_Database; }; } // anonymous namespace