From 04a8b05b25d3b752040a262a2725fa59753dd9b5 Mon Sep 17 00:00:00 2001 From: Narumol Prangnawarat Date: Fri, 26 Apr 2019 13:48:57 +0100 Subject: IVGCVSW-3005 Correct the order of inputs and outputs of deepspeech v1 Change-Id: I36b3467e74508ad4e8f3140285f965bc63433d1d Signed-off-by: Narumol Prangnawarat --- tests/DeepSpeechV1Database.hpp | 24 ++++++++++++------------ tests/DeepSpeechV1InferenceTest.hpp | 35 ++++++++++++++++++----------------- tests/LstmCommon.hpp | 8 ++++---- 3 files changed, 34 insertions(+), 33 deletions(-) diff --git a/tests/DeepSpeechV1Database.hpp b/tests/DeepSpeechV1Database.hpp index 274bf6e22f..037c810122 100644 --- a/tests/DeepSpeechV1Database.hpp +++ b/tests/DeepSpeechV1Database.hpp @@ -115,30 +115,30 @@ struct DeepSpeechV1TestCaseData class DeepSpeechV1Database { public: - explicit DeepSpeechV1Database(const std::string& inputSeqDir, const std::string& prevStateCDir, - const std::string& prevStateHDir, const std::string& logitsDir, - const std::string& newStateCDir, const std::string& newStateHDir); + explicit DeepSpeechV1Database(const std::string& inputSeqDir, const std::string& prevStateHDir, + const std::string& prevStateCDir, const std::string& logitsDir, + const std::string& newStateHDir, const std::string& newStateCDir); std::unique_ptr GetTestCaseData(unsigned int testCaseId); private: std::string m_InputSeqDir; - std::string m_PrevStateCDir; std::string m_PrevStateHDir; + std::string m_PrevStateCDir; std::string m_LogitsDir; - std::string m_NewStateCDir; std::string m_NewStateHDir; + std::string m_NewStateCDir; }; -DeepSpeechV1Database::DeepSpeechV1Database(const std::string& inputSeqDir, const std::string& prevStateCDir, - const std::string& prevStateHDir, const std::string& logitsDir, - const std::string& newStateCDir, const std::string& newStateHDir) +DeepSpeechV1Database::DeepSpeechV1Database(const std::string& inputSeqDir, const std::string& prevStateHDir, + const std::string& prevStateCDir, const std::string& logitsDir, + const std::string& newStateHDir, const std::string& newStateCDir) : m_InputSeqDir(inputSeqDir) - , m_PrevStateCDir(prevStateCDir) , m_PrevStateHDir(prevStateHDir) + , m_PrevStateCDir(prevStateCDir) , m_LogitsDir(logitsDir) - , m_NewStateCDir(newStateCDir) , m_NewStateHDir(newStateHDir) + , m_NewStateCDir(newStateCDir) {} std::unique_ptr DeepSpeechV1Database::GetTestCaseData(unsigned int testCaseId) @@ -194,9 +194,9 @@ std::unique_ptr DeepSpeechV1Database::GetTestCaseData( } // use the struct for representing input and output data - LstmInput inputDataSingleTest(inputSeqData, prevStateCData, prevStateHData); + LstmInput inputDataSingleTest(inputSeqData, prevStateHData, prevStateCData); - LstmInput expectedOutputsSingleTest(logitsData, expectedNewStateCData, expectedNewStateHData); + LstmInput expectedOutputsSingleTest(logitsData, expectedNewStateHData, expectedNewStateCData); return std::make_unique(inputDataSingleTest, expectedOutputsSingleTest); } diff --git a/tests/DeepSpeechV1InferenceTest.hpp b/tests/DeepSpeechV1InferenceTest.hpp index 633176219c..3195d2bb14 100755 --- a/tests/DeepSpeechV1InferenceTest.hpp +++ b/tests/DeepSpeechV1InferenceTest.hpp @@ -27,12 +27,12 @@ public: : InferenceModelTestCase(model, testCaseId, { testCaseData.m_InputData.m_InputSeq, - testCaseData.m_InputData.m_StateC, - testCaseData.m_InputData.m_StateH}, + testCaseData.m_InputData.m_StateH, + testCaseData.m_InputData.m_StateC}, { k_OutputSize1, k_OutputSize2, k_OutputSize3 }) , m_FloatComparer(boost::math::fpc::percent_tolerance(1.0f)) - , m_ExpectedOutputs({testCaseData.m_ExpectedOutputData.m_InputSeq, testCaseData.m_ExpectedOutputData.m_StateC, - testCaseData.m_ExpectedOutputData.m_StateH}) + , m_ExpectedOutputs({testCaseData.m_ExpectedOutputData.m_InputSeq, testCaseData.m_ExpectedOutputData.m_StateH, + testCaseData.m_ExpectedOutputData.m_StateC}) {} TestCaseResult ProcessResult(const InferenceTestOptions& options) override @@ -59,9 +59,9 @@ public: for (unsigned int j = 0u; j < output2.size(); j++) { - if(!m_FloatComparer(output2[j], m_ExpectedOutputs.m_StateC[j])) + if(!m_FloatComparer(output2[j], m_ExpectedOutputs.m_StateH[j])) { - BOOST_LOG_TRIVIAL(error) << "StateC for Lstm " << this->GetTestCaseId() << + BOOST_LOG_TRIVIAL(error) << "StateH for Lstm " << this->GetTestCaseId() << " is incorrect"; return TestCaseResult::Failed; } @@ -69,9 +69,9 @@ public: for (unsigned int j = 0u; j < output3.size(); j++) { - if(!m_FloatComparer(output3[j], m_ExpectedOutputs.m_StateH[j])) + if(!m_FloatComparer(output3[j], m_ExpectedOutputs.m_StateC[j])) { - BOOST_LOG_TRIVIAL(error) << "StateH for Lstm " << this->GetTestCaseId() << + BOOST_LOG_TRIVIAL(error) << "StateC for Lstm " << this->GetTestCaseId() << " is incorrect"; return TestCaseResult::Failed; } @@ -105,21 +105,22 @@ public: options.add_options() ("input-seq-dir,s", po::value(&m_InputSeqDir)->required(), "Path to directory containing test data for m_InputSeq"); - options.add_options() - ("prev-state-c-dir,c", po::value(&m_PrevStateCDir)->required(), - "Path to directory containing test data for m_PrevStateC"); options.add_options() ("prev-state-h-dir,h", po::value(&m_PrevStateHDir)->required(), "Path to directory containing test data for m_PrevStateH"); + options.add_options() + ("prev-state-c-dir,c", po::value(&m_PrevStateCDir)->required(), + "Path to directory containing test data for m_PrevStateC"); options.add_options() ("logits-dir,l", po::value(&m_LogitsDir)->required(), "Path to directory containing test data for m_Logits"); - options.add_options() - ("new-state-c-dir,C", po::value(&m_NewStateCDir)->required(), - "Path to directory containing test data for m_NewStateC"); options.add_options() ("new-state-h-dir,H", po::value(&m_NewStateHDir)->required(), "Path to directory containing test data for m_NewStateH"); + options.add_options() + ("new-state-c-dir,C", po::value(&m_NewStateCDir)->required(), + "Path to directory containing test data for m_NewStateC"); + Model::AddCommandLineOptions(options, m_ModelCommandLineOptions); } @@ -161,9 +162,9 @@ public: { return false; } - m_Database = std::make_unique(m_InputSeqDir.c_str(), m_PrevStateCDir.c_str(), - m_PrevStateHDir.c_str(), m_LogitsDir.c_str(), - m_NewStateCDir.c_str(), m_NewStateHDir.c_str()); + m_Database = std::make_unique(m_InputSeqDir.c_str(), m_PrevStateHDir.c_str(), + m_PrevStateCDir.c_str(), m_LogitsDir.c_str(), + m_NewStateHDir.c_str(), m_NewStateCDir.c_str()); if (!m_Database) { return false; diff --git a/tests/LstmCommon.hpp b/tests/LstmCommon.hpp index 31c4d041c1..0876d263ed 100755 --- a/tests/LstmCommon.hpp +++ b/tests/LstmCommon.hpp @@ -13,16 +13,16 @@ namespace struct LstmInput { LstmInput(const std::vector& inputSeq, - const std::vector& stateC, - const std::vector& stateH) + const std::vector& stateH, + const std::vector& stateC) : m_InputSeq(inputSeq) - , m_StateC(stateC) , m_StateH(stateH) + , m_StateC(stateC) {} std::vector m_InputSeq; - std::vector m_StateC; std::vector m_StateH; + std::vector m_StateC; }; using LstmInputs = std::pair>; -- cgit v1.2.1