diff options
Diffstat (limited to 'src/armnnUtils')
-rw-r--r-- | src/armnnUtils/ParserPrototxtFixture.hpp | 38 |
1 files changed, 30 insertions, 8 deletions
diff --git a/src/armnnUtils/ParserPrototxtFixture.hpp b/src/armnnUtils/ParserPrototxtFixture.hpp index be35e460cf..7ae0742b8e 100644 --- a/src/armnnUtils/ParserPrototxtFixture.hpp +++ b/src/armnnUtils/ParserPrototxtFixture.hpp @@ -53,11 +53,17 @@ struct ParserPrototxtFixture template <std::size_t NumOutputDimensions> void RunTest(const std::vector<float>& inputData, const std::vector<float>& expectedOutputData); + /// Executes the network with the given input tensor and checks the result against the given output tensor. + /// Calls RunTest with output type of uint8_t for checking comparison operators. + template <std::size_t NumOutputDimensions> + void RunComparisonTest(const std::map<std::string, std::vector<float>>& inputData, + const std::map<std::string, std::vector<uint8_t>>& expectedOutputData); + /// Executes the network with the given input tensors and checks the results against the given output tensors. /// This overload supports multiple inputs and multiple outputs, identified by name. - template <std::size_t NumOutputDimensions> + template <std::size_t NumOutputDimensions, typename T = float> void RunTest(const std::map<std::string, std::vector<float>>& inputData, - const std::map<std::string, std::vector<float>>& expectedOutputData); + const std::map<std::string, std::vector<T>>& expectedOutputData); std::string m_Prototext; std::unique_ptr<TParser, void(*)(TParser* parser)> m_Parser; @@ -162,15 +168,24 @@ armnn::IOptimizedNetworkPtr ParserPrototxtFixture<TParser>::SetupOptimizedNetwor template<typename TParser> template <std::size_t NumOutputDimensions> void ParserPrototxtFixture<TParser>::RunTest(const std::vector<float>& inputData, - const std::vector<float>& expectedOutputData) + const std::vector<float>& expectedOutputData) { RunTest<NumOutputDimensions>({ { m_SingleInputName, inputData } }, { { m_SingleOutputName, expectedOutputData } }); } template<typename TParser> template <std::size_t NumOutputDimensions> +void ParserPrototxtFixture<TParser>::RunComparisonTest(const std::map<std::string, std::vector<float>>& inputData, + const std::map<std::string, std::vector<uint8_t>>& + expectedOutputData) +{ + RunTest<NumOutputDimensions, uint8_t>(inputData, expectedOutputData); +} + +template<typename TParser> +template <std::size_t NumOutputDimensions, typename T> void ParserPrototxtFixture<TParser>::RunTest(const std::map<std::string, std::vector<float>>& inputData, - const std::map<std::string, std::vector<float>>& expectedOutputData) + const std::map<std::string, std::vector<T>>& expectedOutputData) { using BindingPointInfo = std::pair<armnn::LayerBindingId, armnn::TensorInfo>; @@ -183,12 +198,12 @@ void ParserPrototxtFixture<TParser>::RunTest(const std::map<std::string, std::ve } // Allocates storage for the output tensors to be written to and sets up the armnn output tensors. - std::map<std::string, boost::multi_array<float, NumOutputDimensions>> outputStorage; + std::map<std::string, boost::multi_array<T, NumOutputDimensions>> outputStorage; armnn::OutputTensors outputTensors; for (auto&& it : expectedOutputData) { BindingPointInfo bindingInfo = m_Parser->GetNetworkOutputBindingInfo(it.first); - outputStorage.emplace(it.first, MakeTensor<float, NumOutputDimensions>(bindingInfo.second)); + outputStorage.emplace(it.first, MakeTensor<T, NumOutputDimensions>(bindingInfo.second)); outputTensors.push_back( { bindingInfo.first, armnn::Tensor(bindingInfo.second, outputStorage.at(it.first).data()) }); } @@ -243,8 +258,15 @@ void ParserPrototxtFixture<TParser>::RunTest(const std::map<std::string, std::ve } } - auto outputExpected = MakeTensor<float, NumOutputDimensions>(bindingInfo.second, it.second); - BOOST_TEST(CompareTensors(outputExpected, outputStorage[it.first])); + auto outputExpected = MakeTensor<T, NumOutputDimensions>(bindingInfo.second, it.second); + if (std::is_same<T, uint8_t>::value) + { + BOOST_TEST(CompareTensors(outputExpected, outputStorage[it.first], true)); + } + else + { + BOOST_TEST(CompareTensors(outputExpected, outputStorage[it.first])); + } } } |