diff options
Diffstat (limited to 'tests/DeepSpeechV1InferenceTest.hpp')
-rwxr-xr-x | tests/DeepSpeechV1InferenceTest.hpp | 35 |
1 files changed, 18 insertions, 17 deletions
diff --git a/tests/DeepSpeechV1InferenceTest.hpp b/tests/DeepSpeechV1InferenceTest.hpp index 633176219c..3195d2bb14 100755 --- a/tests/DeepSpeechV1InferenceTest.hpp +++ b/tests/DeepSpeechV1InferenceTest.hpp @@ -27,12 +27,12 @@ public: : InferenceModelTestCase<Model>(model, testCaseId, { testCaseData.m_InputData.m_InputSeq, - testCaseData.m_InputData.m_StateC, - testCaseData.m_InputData.m_StateH}, + testCaseData.m_InputData.m_StateH, + testCaseData.m_InputData.m_StateC}, { k_OutputSize1, k_OutputSize2, k_OutputSize3 }) , m_FloatComparer(boost::math::fpc::percent_tolerance(1.0f)) - , m_ExpectedOutputs({testCaseData.m_ExpectedOutputData.m_InputSeq, testCaseData.m_ExpectedOutputData.m_StateC, - testCaseData.m_ExpectedOutputData.m_StateH}) + , m_ExpectedOutputs({testCaseData.m_ExpectedOutputData.m_InputSeq, testCaseData.m_ExpectedOutputData.m_StateH, + testCaseData.m_ExpectedOutputData.m_StateC}) {} TestCaseResult ProcessResult(const InferenceTestOptions& options) override @@ -59,9 +59,9 @@ public: for (unsigned int j = 0u; j < output2.size(); j++) { - if(!m_FloatComparer(output2[j], m_ExpectedOutputs.m_StateC[j])) + if(!m_FloatComparer(output2[j], m_ExpectedOutputs.m_StateH[j])) { - BOOST_LOG_TRIVIAL(error) << "StateC for Lstm " << this->GetTestCaseId() << + BOOST_LOG_TRIVIAL(error) << "StateH for Lstm " << this->GetTestCaseId() << " is incorrect"; return TestCaseResult::Failed; } @@ -69,9 +69,9 @@ public: for (unsigned int j = 0u; j < output3.size(); j++) { - if(!m_FloatComparer(output3[j], m_ExpectedOutputs.m_StateH[j])) + if(!m_FloatComparer(output3[j], m_ExpectedOutputs.m_StateC[j])) { - BOOST_LOG_TRIVIAL(error) << "StateH for Lstm " << this->GetTestCaseId() << + BOOST_LOG_TRIVIAL(error) << "StateC for Lstm " << this->GetTestCaseId() << " is incorrect"; return TestCaseResult::Failed; } @@ -106,20 +106,21 @@ public: ("input-seq-dir,s", po::value<std::string>(&m_InputSeqDir)->required(), "Path to directory containing test data for m_InputSeq"); options.add_options() - ("prev-state-c-dir,c", po::value<std::string>(&m_PrevStateCDir)->required(), - "Path to directory containing test data for m_PrevStateC"); - options.add_options() ("prev-state-h-dir,h", po::value<std::string>(&m_PrevStateHDir)->required(), "Path to directory containing test data for m_PrevStateH"); options.add_options() + ("prev-state-c-dir,c", po::value<std::string>(&m_PrevStateCDir)->required(), + "Path to directory containing test data for m_PrevStateC"); + options.add_options() ("logits-dir,l", po::value<std::string>(&m_LogitsDir)->required(), "Path to directory containing test data for m_Logits"); options.add_options() - ("new-state-c-dir,C", po::value<std::string>(&m_NewStateCDir)->required(), - "Path to directory containing test data for m_NewStateC"); - options.add_options() ("new-state-h-dir,H", po::value<std::string>(&m_NewStateHDir)->required(), "Path to directory containing test data for m_NewStateH"); + options.add_options() + ("new-state-c-dir,C", po::value<std::string>(&m_NewStateCDir)->required(), + "Path to directory containing test data for m_NewStateC"); + Model::AddCommandLineOptions(options, m_ModelCommandLineOptions); } @@ -161,9 +162,9 @@ public: { return false; } - m_Database = std::make_unique<DeepSpeechV1Database>(m_InputSeqDir.c_str(), m_PrevStateCDir.c_str(), - m_PrevStateHDir.c_str(), m_LogitsDir.c_str(), - m_NewStateCDir.c_str(), m_NewStateHDir.c_str()); + m_Database = std::make_unique<DeepSpeechV1Database>(m_InputSeqDir.c_str(), m_PrevStateHDir.c_str(), + m_PrevStateCDir.c_str(), m_LogitsDir.c_str(), + m_NewStateHDir.c_str(), m_NewStateCDir.c_str()); if (!m_Database) { return false; |