diff options
Diffstat (limited to 'src/armnnUtils/ParserPrototxtFixture.hpp')
-rw-r--r-- | src/armnnUtils/ParserPrototxtFixture.hpp | 88 |
1 files changed, 81 insertions, 7 deletions
diff --git a/src/armnnUtils/ParserPrototxtFixture.hpp b/src/armnnUtils/ParserPrototxtFixture.hpp index acb8f82c4d..154f6bec2a 100644 --- a/src/armnnUtils/ParserPrototxtFixture.hpp +++ b/src/armnnUtils/ParserPrototxtFixture.hpp @@ -16,6 +16,7 @@ #include <boost/format.hpp> +#include <iomanip> #include <string> namespace armnnUtils @@ -37,6 +38,10 @@ struct ParserPrototxtFixture void SetupSingleInputSingleOutput(const armnn::TensorShape& inputTensorShape, const std::string& inputName, const std::string& outputName); + void SetupSingleInputSingleOutput(const armnn::TensorShape& inputTensorShape, + const armnn::TensorShape& outputTensorShape, + const std::string& inputName, + const std::string& outputName); void Setup(const std::map<std::string, armnn::TensorShape>& inputShapes, const std::vector<std::string>& requestedOutputs); void Setup(); @@ -56,6 +61,9 @@ struct ParserPrototxtFixture void RunTest(const std::map<std::string, std::vector<float>>& inputData, const std::map<std::string, std::vector<float>>& expectedOutputData); + /// Converts an int value into the Protobuf octal representation + std::string ConvertInt32ToOctalString(int value); + std::string m_Prototext; std::unique_ptr<TParser, void(*)(TParser* parser)> m_Parser; armnn::IRuntimePtr m_Runtime; @@ -67,6 +75,10 @@ struct ParserPrototxtFixture std::string m_SingleInputName; std::string m_SingleOutputName; /// @} + + /// This will store the output shape so it don't need to be passed to the single-input-single-output overload + /// of RunTest(). + armnn::TensorShape m_SingleOutputShape; }; template<typename TParser> @@ -91,6 +103,20 @@ void ParserPrototxtFixture<TParser>::SetupSingleInputSingleOutput(const armnn::T } template<typename TParser> +void ParserPrototxtFixture<TParser>::SetupSingleInputSingleOutput(const armnn::TensorShape& inputTensorShape, + const armnn::TensorShape& outputTensorShape, + const std::string& inputName, + const std::string& outputName) +{ + // Stores the input name, the output name and the output tensor shape + // so they don't need to be passed to the single-input-single-output RunTest(). + m_SingleInputName = inputName; + m_SingleOutputName = outputName; + m_SingleOutputShape = outputTensorShape; + Setup({ { inputName, inputTensorShape } }, { outputName }); +} + +template<typename TParser> void ParserPrototxtFixture<TParser>::Setup(const std::map<std::string, armnn::TensorShape>& inputShapes, const std::vector<std::string>& requestedOutputs) { @@ -181,17 +207,65 @@ void ParserPrototxtFixture<TParser>::RunTest(const std::map<std::string, std::ve if (bindingInfo.second.GetNumElements() != it.second.size()) { throw armnn::Exception( - boost::str( - boost::format("Output tensor %1% is expected to have %2% elements. " - "%3% elements supplied. %4%") % - it.first % - bindingInfo.second.GetNumElements() % - it.second.size() % - CHECK_LOCATION().AsString())); + boost::str(boost::format("Output tensor %1% is expected to have %2% elements. " + "%3% elements supplied. %4%") % + it.first % + bindingInfo.second.GetNumElements() % + it.second.size() % + CHECK_LOCATION().AsString())); } + + // If the expected output shape is set, the output tensor checks will be carried out. + if (m_SingleOutputShape.GetNumDimensions() != 0) + { + + if (bindingInfo.second.GetShape().GetNumDimensions() == NumOutputDimensions && + bindingInfo.second.GetShape().GetNumDimensions() == m_SingleOutputShape.GetNumDimensions()) + { + for (unsigned int i = 0; i < m_SingleOutputShape.GetNumDimensions(); ++i) + { + if (m_SingleOutputShape[i] != bindingInfo.second.GetShape()[i]) + { + throw armnn::Exception( + boost::str(boost::format("Output tensor %1% is expected to have %2% shape. " + "%3% shape supplied. %4%") % + it.first % + bindingInfo.second.GetShape() % + m_SingleOutputShape % + CHECK_LOCATION().AsString())); + } + } + } + else + { + throw armnn::Exception( + boost::str(boost::format("Output tensor %1% is expected to have %2% dimensions. " + "%3% dimensions supplied. %4%") % + it.first % + bindingInfo.second.GetShape().GetNumDimensions() % + NumOutputDimensions % + CHECK_LOCATION().AsString())); + } + } + auto outputExpected = MakeTensor<float, NumOutputDimensions>(bindingInfo.second, it.second); BOOST_TEST(CompareTensors(outputExpected, outputStorage[it.first])); } } +template<typename TParser> +std::string ParserPrototxtFixture<TParser>::ConvertInt32ToOctalString(int value) +{ + std::stringstream ss; + std::string returnString; + for (int i = 0; i < 4; ++i) + { + ss << "\\"; + ss << std::setw(3) << std::setfill('0') << std::oct << ((value >> (i * 8)) & 0xFF); + } + + ss >> returnString; + return returnString; +} + } // namespace armnnUtils |