// // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once #include "InferenceTest.hpp" #include "DeepSpeechV1Database.hpp" #include #include #include #include namespace { template class DeepSpeechV1TestCase : public InferenceModelTestCase { public: DeepSpeechV1TestCase(Model& model, unsigned int testCaseId, const DeepSpeechV1TestCaseData& testCaseData) : InferenceModelTestCase(model, testCaseId, { testCaseData.m_InputData.m_InputSeq, testCaseData.m_InputData.m_StateH, testCaseData.m_InputData.m_StateC}, { k_OutputSize1, k_OutputSize2, k_OutputSize3 }) , m_ExpectedOutputs({testCaseData.m_ExpectedOutputData.m_InputSeq, testCaseData.m_ExpectedOutputData.m_StateH, testCaseData.m_ExpectedOutputData.m_StateC}) {} TestCaseResult ProcessResult(const InferenceTestOptions& options) override { armnn::IgnoreUnused(options); const std::vector& output1 = mapbox::util::get>(this->GetOutputs()[0]); // logits ARMNN_ASSERT(output1.size() == k_OutputSize1); const std::vector& output2 = mapbox::util::get>(this->GetOutputs()[1]); // new_state_c ARMNN_ASSERT(output2.size() == k_OutputSize2); const std::vector& output3 = mapbox::util::get>(this->GetOutputs()[2]); // new_state_h ARMNN_ASSERT(output3.size() == k_OutputSize3); // Check each output to see whether it is the expected value for (unsigned int j = 0u; j < output1.size(); j++) { if(!armnnUtils::within_percentage_tolerance(output1[j], m_ExpectedOutputs.m_InputSeq[j])) { ARMNN_LOG(error) << "InputSeq for Lstm " << this->GetTestCaseId() << " is incorrect at" << j; return TestCaseResult::Failed; } } for (unsigned int j = 0u; j < output2.size(); j++) { if(!armnnUtils::within_percentage_tolerance(output2[j], m_ExpectedOutputs.m_StateH[j])) { ARMNN_LOG(error) << "StateH for Lstm " << this->GetTestCaseId() << " is incorrect"; return TestCaseResult::Failed; } } for (unsigned int j = 0u; j < output3.size(); j++) { if(!armnnUtils::within_percentage_tolerance(output3[j], m_ExpectedOutputs.m_StateC[j])) { ARMNN_LOG(error) << "StateC for Lstm " << this->GetTestCaseId() << " is incorrect"; return TestCaseResult::Failed; } } return TestCaseResult::Ok; } private: static constexpr unsigned int k_OutputSize1 = 464u; static constexpr unsigned int k_OutputSize2 = 2048u; static constexpr unsigned int k_OutputSize3 = 2048u; LstmInput m_ExpectedOutputs; }; template class DeepSpeechV1TestCaseProvider : public IInferenceTestCaseProvider { public: template explicit DeepSpeechV1TestCaseProvider(TConstructModelCallable constructModel) : m_ConstructModel(constructModel) {} virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector& required) override { options .allow_unrecognised_options() .add_options() ("s,input-seq-dir", "Path to directory containing test data for m_InputSeq", cxxopts::value(m_InputSeqDir)) ("h,prev-state-h-dir", "Path to directory containing test data for m_PrevStateH", cxxopts::value(m_PrevStateHDir)) ("c,prev-state-c-dir", "Path to directory containing test data for m_PrevStateC", cxxopts::value(m_PrevStateCDir)) ("l,logits-dir", "Path to directory containing test data for m_Logits", cxxopts::value(m_LogitsDir)) ("H,new-state-h-dir", "Path to directory containing test data for m_NewStateH", cxxopts::value(m_NewStateHDir)) ("C,new-state-c-dir", "Path to directory containing test data for m_NewStateC", cxxopts::value(m_NewStateCDir)); required.insert(required.end(), {"input-seq-dir", "prev-state-h-dir", "prev-state-c-dir", "logits-dir", "new-state-h-dir", "new-state-c-dir"}); Model::AddCommandLineOptions(options, m_ModelCommandLineOptions, required); } virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override { if (!ValidateDirectory(m_InputSeqDir)) { return false; } if (!ValidateDirectory(m_PrevStateCDir)) { return false; } if (!ValidateDirectory(m_PrevStateHDir)) { return false; } if (!ValidateDirectory(m_LogitsDir)) { return false; } if (!ValidateDirectory(m_NewStateCDir)) { return false; } if (!ValidateDirectory(m_NewStateHDir)) { return false; } m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions); if (!m_Model) { return false; } m_Database = std::make_unique(m_InputSeqDir.c_str(), m_PrevStateHDir.c_str(), m_PrevStateCDir.c_str(), m_LogitsDir.c_str(), m_NewStateHDir.c_str(), m_NewStateCDir.c_str()); if (!m_Database) { return false; } return true; } std::unique_ptr GetTestCase(unsigned int testCaseId) override { std::unique_ptr testCaseData = m_Database->GetTestCaseData(testCaseId); if (!testCaseData) { return nullptr; } return std::make_unique>(*m_Model, testCaseId, *testCaseData); } private: typename Model::CommandLineOptions m_ModelCommandLineOptions; std::function(const InferenceTestOptions&, typename Model::CommandLineOptions)> m_ConstructModel; std::unique_ptr m_Model; std::string m_InputSeqDir; std::string m_PrevStateCDir; std::string m_PrevStateHDir; std::string m_LogitsDir; std::string m_NewStateCDir; std::string m_NewStateHDir; std::unique_ptr m_Database; }; } // anonymous namespace