aboutsummaryrefslogtreecommitdiff
path: root/src/armnnUtils
diff options
context:
space:
mode:
authorkevmay01 <kevin.may@arm.com>2019-01-24 14:05:09 +0000
committerkevmay01 <kevin.may@arm.com>2019-01-24 14:05:09 +0000
commit2b4d88e34ac1f965417fd236fd4786f26bae2042 (patch)
tree4518b52c6a22e33c4b467588a2843c9d5f1a9ee6 /src/armnnUtils
parent94412aff782472be54dce4328e2ecee0225b3e97 (diff)
downloadarmnn-2b4d88e34ac1f965417fd236fd4786f26bae2042.tar.gz
IVGCVSW-2503 Refactor RefElementwiseWorkload around Equal and Greater
* Remove Equal and Greater from RefElementwiseWorkload * Create RefComparisonWorkload and add Equal and Greater * Update ElementwiseFunction for different input/output types * Update TfParser to create Equal/Greater with Boolean output * Update relevant tests to check for Boolean comparison Change-Id: I299b7f2121769c960ac0c6139764a5f3c89c9c32
Diffstat (limited to 'src/armnnUtils')
-rw-r--r--src/armnnUtils/ParserPrototxtFixture.hpp38
1 files changed, 30 insertions, 8 deletions
diff --git a/src/armnnUtils/ParserPrototxtFixture.hpp b/src/armnnUtils/ParserPrototxtFixture.hpp
index be35e460cf..7ae0742b8e 100644
--- a/src/armnnUtils/ParserPrototxtFixture.hpp
+++ b/src/armnnUtils/ParserPrototxtFixture.hpp
@@ -53,11 +53,17 @@ struct ParserPrototxtFixture
template <std::size_t NumOutputDimensions>
void RunTest(const std::vector<float>& inputData, const std::vector<float>& expectedOutputData);
+ /// Executes the network with the given input tensor and checks the result against the given output tensor.
+ /// Calls RunTest with output type of uint8_t for checking comparison operators.
+ template <std::size_t NumOutputDimensions>
+ void RunComparisonTest(const std::map<std::string, std::vector<float>>& inputData,
+ const std::map<std::string, std::vector<uint8_t>>& expectedOutputData);
+
/// Executes the network with the given input tensors and checks the results against the given output tensors.
/// This overload supports multiple inputs and multiple outputs, identified by name.
- template <std::size_t NumOutputDimensions>
+ template <std::size_t NumOutputDimensions, typename T = float>
void RunTest(const std::map<std::string, std::vector<float>>& inputData,
- const std::map<std::string, std::vector<float>>& expectedOutputData);
+ const std::map<std::string, std::vector<T>>& expectedOutputData);
std::string m_Prototext;
std::unique_ptr<TParser, void(*)(TParser* parser)> m_Parser;
@@ -162,15 +168,24 @@ armnn::IOptimizedNetworkPtr ParserPrototxtFixture<TParser>::SetupOptimizedNetwor
template<typename TParser>
template <std::size_t NumOutputDimensions>
void ParserPrototxtFixture<TParser>::RunTest(const std::vector<float>& inputData,
- const std::vector<float>& expectedOutputData)
+ const std::vector<float>& expectedOutputData)
{
RunTest<NumOutputDimensions>({ { m_SingleInputName, inputData } }, { { m_SingleOutputName, expectedOutputData } });
}
template<typename TParser>
template <std::size_t NumOutputDimensions>
+void ParserPrototxtFixture<TParser>::RunComparisonTest(const std::map<std::string, std::vector<float>>& inputData,
+ const std::map<std::string, std::vector<uint8_t>>&
+ expectedOutputData)
+{
+ RunTest<NumOutputDimensions, uint8_t>(inputData, expectedOutputData);
+}
+
+template<typename TParser>
+template <std::size_t NumOutputDimensions, typename T>
void ParserPrototxtFixture<TParser>::RunTest(const std::map<std::string, std::vector<float>>& inputData,
- const std::map<std::string, std::vector<float>>& expectedOutputData)
+ const std::map<std::string, std::vector<T>>& expectedOutputData)
{
using BindingPointInfo = std::pair<armnn::LayerBindingId, armnn::TensorInfo>;
@@ -183,12 +198,12 @@ void ParserPrototxtFixture<TParser>::RunTest(const std::map<std::string, std::ve
}
// Allocates storage for the output tensors to be written to and sets up the armnn output tensors.
- std::map<std::string, boost::multi_array<float, NumOutputDimensions>> outputStorage;
+ std::map<std::string, boost::multi_array<T, NumOutputDimensions>> outputStorage;
armnn::OutputTensors outputTensors;
for (auto&& it : expectedOutputData)
{
BindingPointInfo bindingInfo = m_Parser->GetNetworkOutputBindingInfo(it.first);
- outputStorage.emplace(it.first, MakeTensor<float, NumOutputDimensions>(bindingInfo.second));
+ outputStorage.emplace(it.first, MakeTensor<T, NumOutputDimensions>(bindingInfo.second));
outputTensors.push_back(
{ bindingInfo.first, armnn::Tensor(bindingInfo.second, outputStorage.at(it.first).data()) });
}
@@ -243,8 +258,15 @@ void ParserPrototxtFixture<TParser>::RunTest(const std::map<std::string, std::ve
}
}
- auto outputExpected = MakeTensor<float, NumOutputDimensions>(bindingInfo.second, it.second);
- BOOST_TEST(CompareTensors(outputExpected, outputStorage[it.first]));
+ auto outputExpected = MakeTensor<T, NumOutputDimensions>(bindingInfo.second, it.second);
+ if (std::is_same<T, uint8_t>::value)
+ {
+ BOOST_TEST(CompareTensors(outputExpected, outputStorage[it.first], true));
+ }
+ else
+ {
+ BOOST_TEST(CompareTensors(outputExpected, outputStorage[it.first]));
+ }
}
}