aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatteo Martincigh <matteo.martincigh@arm.com>2019-11-29 11:46:50 +0000
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-12-02 08:52:43 +0000
commitbf19d2a95462fffdb7288a2289056dd443ca4275 (patch)
treeb81f54cdde3bcae54d2f08d2dd7e053c80b832a2
parentb9cb84484b29ca588661b542bf8f93a8fb14edc1 (diff)
downloadandroid-nn-driver-bf19d2a95462fffdb7288a2289056dd443ca4275.tar.gz
IVGCVSW-4209 Remove the Half.hpp header usage from the driver
* Removed the inclusion of the Half.hpp header from the Android NN Driver, as it's a private header not part of the now public armnnUtils API * Refactored the code not to use that header Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com> Change-Id: I0feeb86ccb31e016395e623029974f599a174149
-rw-r--r--Utils.cpp21
1 files changed, 6 insertions, 15 deletions
diff --git a/Utils.cpp b/Utils.cpp
index ee46aea0..a5a6ef0e 100644
--- a/Utils.cpp
+++ b/Utils.cpp
@@ -9,8 +9,6 @@
#include <armnnUtils/Permute.hpp>
-#include <Half.hpp>
-
#include <cassert>
#include <cinttypes>
@@ -25,14 +23,12 @@ const armnn::PermutationVector g_DontPermute{};
namespace
{
-template <typename T>
void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorShape& inTensorShape, const void* input,
- void* output, const armnn::PermutationVector& mappings)
+ void* output, size_t dataTypeSize, const armnn::PermutationVector& mappings)
{
- const auto inputData = static_cast<const T*>(input);
- const auto outputData = static_cast<T*>(output);
+ assert(inTensorShape.GetNumDimensions() == 4U);
- armnnUtils::Permute(armnnUtils::Permuted(inTensorShape, mappings), mappings, inputData, outputData, sizeof(T));
+ armnnUtils::Permute(armnnUtils::Permuted(inTensorShape, mappings), mappings, input, output, dataTypeSize);
}
} // anonymous namespace
@@ -42,19 +38,14 @@ void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorInfo& tensor, const void
{
assert(tensor.GetNumDimensions() == 4U);
- switch(tensor.GetDataType())
+ armnn::DataType dataType = tensor.GetDataType();
+ switch (dataType)
{
case armnn::DataType::Float16:
- SwizzleAndroidNn4dTensorToArmNn<armnn::Half>(tensor.GetShape(), input, output, mappings);
- break;
case armnn::DataType::Float32:
- SwizzleAndroidNn4dTensorToArmNn<float>(tensor.GetShape(), input, output, mappings);
- break;
case armnn::DataType::QuantisedAsymm8:
- SwizzleAndroidNn4dTensorToArmNn<uint8_t>(tensor.GetShape(), input, output, mappings);
- break;
case armnn::DataType::QuantizedSymm8PerAxis:
- SwizzleAndroidNn4dTensorToArmNn<int8_t>(tensor.GetShape(), input, output, mappings);
+ SwizzleAndroidNn4dTensorToArmNn(tensor.GetShape(), input, output, armnn::GetDataTypeSize(dataType), mappings);
break;
default:
ALOGW("Unknown armnn::DataType for swizzling");