diff options
author | Jim Flynn <jim.flynn@arm.com> | 2019-04-15 14:34:17 +0100 |
---|---|---|
committer | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-04-16 10:58:52 +0000 |
commit | e571d33a4550ab3bea6f42dda3ec88d5924e9c00 (patch) | |
tree | b2dd99b11ca9c9eec79277d39674ada4c194f2f8 /tests/DeepSpeechV1Database.hpp | |
parent | aab6aff4aa282810cb535eeec65e59741f1f4f0e (diff) | |
download | armnn-e571d33a4550ab3bea6f42dda3ec88d5924e9c00.tar.gz |
IVGCVSW-2855 Create TfLite reference test for DeepSpeechV1
Change-Id: I4492a85c8337bf4ea0eb998c88b9cbfc932dc4e6
Signed-off-by: Ruomei Yan <ruomei.yan@arm.com>
Signed-off-by: Jim Flynn <jim.flynn@arm.com>
Diffstat (limited to 'tests/DeepSpeechV1Database.hpp')
-rw-r--r-- | tests/DeepSpeechV1Database.hpp | 203 |
1 files changed, 203 insertions, 0 deletions
diff --git a/tests/DeepSpeechV1Database.hpp b/tests/DeepSpeechV1Database.hpp new file mode 100644 index 0000000000..4d2d591bed --- /dev/null +++ b/tests/DeepSpeechV1Database.hpp @@ -0,0 +1,203 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// +#pragma once + +#include "LstmCommon.hpp" + +#include <memory> +#include <string> +#include <vector> + +#include <armnn/TypesUtils.hpp> +#include <backendsCommon/test/QuantizeHelper.hpp> + +#include <boost/log/trivial.hpp> +#include <boost/numeric/conversion/cast.hpp> + +#include <array> +#include <string> + +#include "InferenceTestImage.hpp" + +namespace +{ + +template<typename T, typename TParseElementFunc> +std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc, const char * chars = "\t ,:") +{ + std::vector<T> result; + // Processes line-by-line. + std::string line; + while (std::getline(stream, line)) + { + std::vector<std::string> tokens; + try + { + // Coverity fix: boost::split() may throw an exception of type boost::bad_function_call. + boost::split(tokens, line, boost::algorithm::is_any_of(chars), boost::token_compress_on); + } + catch (const std::exception& e) + { + BOOST_LOG_TRIVIAL(error) << "An error occurred when splitting tokens: " << e.what(); + continue; + } + for (const std::string& token : tokens) + { + if (!token.empty()) // See https://stackoverflow.com/questions/10437406/ + { + try + { + result.push_back(parseElementFunc(token)); + } + catch (const std::exception&) + { + BOOST_LOG_TRIVIAL(error) << "'" << token << "' is not a valid number. It has been ignored."; + } + } + } + } + + return result; +} + +template<armnn::DataType NonQuantizedType> +auto ParseDataArray(std::istream & stream); + +template<armnn::DataType QuantizedType> +auto ParseDataArray(std::istream& stream, + const float& quantizationScale, + const int32_t& quantizationOffset); + +template<> +auto ParseDataArray<armnn::DataType::Float32>(std::istream & stream) +{ + return ParseArrayImpl<float>(stream, [](const std::string& s) { return std::stof(s); }); +} + +template<> +auto ParseDataArray<armnn::DataType::Signed32>(std::istream & stream) +{ + return ParseArrayImpl<int>(stream, [](const std::string & s) { return std::stoi(s); }); +} + +template<> +auto ParseDataArray<armnn::DataType::QuantisedAsymm8>(std::istream& stream, + const float& quantizationScale, + const int32_t& quantizationOffset) +{ + return ParseArrayImpl<uint8_t>(stream, + [&quantizationScale, &quantizationOffset](const std::string & s) + { + return boost::numeric_cast<uint8_t>( + armnn::Quantize<u_int8_t>(std::stof(s), + quantizationScale, + quantizationOffset)); + }); +} + +struct DeepSpeechV1TestCaseData +{ + DeepSpeechV1TestCaseData( + const LstmInput& inputData, + const LstmInput& expectedOutputData) + : m_InputData(inputData) + , m_ExpectedOutputData(expectedOutputData) + {} + + LstmInput m_InputData; + LstmInput m_ExpectedOutputData; +}; + +class DeepSpeechV1Database +{ +public: + explicit DeepSpeechV1Database(const std::string& inputSeqDir, const std::string& prevStateCDir, + const std::string& prevStateHDir, const std::string& logitsDir, + const std::string& newStateCDir, const std::string& newStateHDir); + + std::unique_ptr<DeepSpeechV1TestCaseData> GetTestCaseData(unsigned int testCaseId); + +private: + std::string m_InputSeqDir; + std::string m_PrevStateCDir; + std::string m_PrevStateHDir; + std::string m_LogitsDir; + std::string m_NewStateCDir; + std::string m_NewStateHDir; +}; + +DeepSpeechV1Database::DeepSpeechV1Database(const std::string& inputSeqDir, const std::string& prevStateCDir, + const std::string& prevStateHDir, const std::string& logitsDir, + const std::string& newStateCDir, const std::string& newStateHDir) + : m_InputSeqDir(inputSeqDir) + , m_PrevStateCDir(prevStateCDir) + , m_PrevStateHDir(prevStateHDir) + , m_LogitsDir(logitsDir) + , m_NewStateCDir(newStateCDir) + , m_NewStateHDir(newStateHDir) +{} + +std::unique_ptr<DeepSpeechV1TestCaseData> DeepSpeechV1Database::GetTestCaseData(unsigned int testCaseId) +{ + // Load test case input + const std::string inputSeqPath = m_InputSeqDir + "input_node_0_flat.txt"; + const std::string prevStateCPath = m_PrevStateCDir + "previous_state_c_0.txt"; + const std::string prevStateHPath = m_PrevStateHDir + "previous_state_h_0.txt"; + + std::vector<float> inputSeqData; + std::vector<float> prevStateCData; + std::vector<float> prevStateHData; + + std::ifstream inputSeqFile(inputSeqPath); + std::ifstream prevStateCTensorFile(prevStateCPath); + std::ifstream prevStateHTensorFile(prevStateHPath); + + try + { + inputSeqData = ParseDataArray<armnn::DataType::Float32>(inputSeqFile); + prevStateCData = ParseDataArray<armnn::DataType::Float32>(prevStateCTensorFile); + prevStateHData = ParseDataArray<armnn::DataType::Float32>(prevStateHTensorFile); + } + catch (const InferenceTestImageException& e) + { + BOOST_LOG_TRIVIAL(fatal) << "Failed to load image for test case " << testCaseId << ". Error: " << e.what(); + return nullptr; + } + + // Prepare test case expected output + const std::string logitsPath = m_LogitsDir + "logits.txt"; + const std::string newStateCPath = m_NewStateCDir + "new_state_c.txt"; + const std::string newStateHPath = m_NewStateHDir + "new_state_h.txt"; + + std::vector<float> logitsData; + std::vector<float> expectedNewStateCData; + std::vector<float> expectedNewStateHData; + + std::ifstream logitsTensorFile(logitsPath); + std::ifstream newStateCTensorFile(newStateCPath); + std::ifstream newStateHTensorFile(newStateHPath); + + try + { + logitsData = ParseDataArray<armnn::DataType::Float32>(logitsTensorFile); + expectedNewStateCData = ParseDataArray<armnn::DataType::Float32>(newStateCTensorFile); + expectedNewStateHData = ParseDataArray<armnn::DataType::Float32>(newStateHTensorFile); + } + catch (const InferenceTestImageException& e) + { + BOOST_LOG_TRIVIAL(fatal) << "Failed to load image for test case " << testCaseId << ". Error: " << e.what(); + return nullptr; + } + + // use the struct for representing input and output data + LstmInput inputDataSingleTest(inputSeqData, prevStateCData, prevStateHData); + + LstmInput expectedOutputsSingleTest(logitsData, expectedNewStateCData, expectedNewStateHData); + + return std::make_unique<DeepSpeechV1TestCaseData>(inputDataSingleTest, expectedOutputsSingleTest); +} + +} // anonymous namespace + |