aboutsummaryrefslogtreecommitdiff
path: root/1.0
diff options
context:
space:
mode:
authornarpra01 <narumol.prangnawarat@arm.com>2018-10-22 14:52:32 +0100
committerMatthew Bentham <matthew.bentham@arm.com>2018-10-25 10:09:44 +0100
commit2fb804a9460239dad5949ae92b1b98fd0fb01c61 (patch)
treeba8f09d569a1a4df93b7a17e572dd1ccd8538e30 /1.0
parentfe463150228156c29a415f45d2172a43df6ce6c3 (diff)
downloadandroid-nn-driver-2fb804a9460239dad5949ae92b1b98fd0fb01c61.tar.gz
IVGCVSW-2065 - Modify HAL Policy for Normalization to use NHWC data layout
Change-Id: Ic60c4dacb55bcf2514f011a8e844e7b8f7b13560
Diffstat (limited to '1.0')
-rw-r--r--1.0/HalPolicy.cpp18
1 files changed, 7 insertions, 11 deletions
diff --git a/1.0/HalPolicy.cpp b/1.0/HalPolicy.cpp
index 174bbb3f..4c5fba3f 100644
--- a/1.0/HalPolicy.cpp
+++ b/1.0/HalPolicy.cpp
@@ -712,16 +712,14 @@ bool HalPolicy::ConvertLocalResponseNormalization(const Operation& operation,
return Fail("%s: Could not read output 0", __func__);
}
- const armnn::TensorInfo& inputInfo = input.GetTensorInfo();
+ const armnn::TensorInfo& inputInfo = input.GetTensorInfo();
const armnn::TensorInfo& outputInfo = GetTensorInfoForOperand(*output);
- const armnn::TensorInfo swizzledInputInfo = armnnUtils::Permuted(inputInfo, NHWCToArmNN);
- const armnn::TensorInfo swizzledOutputInfo = armnnUtils::Permuted(outputInfo, NHWCToArmNN);
-
armnn::NormalizationDescriptor descriptor;
+ descriptor.m_DataLayout = armnn::DataLayout::NHWC;
descriptor.m_NormChannelType = armnn::NormalizationAlgorithmChannel::Across;
- descriptor.m_NormMethodType = armnn::NormalizationAlgorithmMethod::LocalBrightness;
+ descriptor.m_NormMethodType = armnn::NormalizationAlgorithmMethod::LocalBrightness;
if (!input.IsValid() ||
!GetInputScalar(operation, 1, OperandType::INT32, descriptor.m_NormSize, model, data) ||
@@ -739,8 +737,8 @@ bool HalPolicy::ConvertLocalResponseNormalization(const Operation& operation,
if (!IsLayerSupported(__func__,
armnn::IsNormalizationSupported,
data.m_Compute,
- swizzledInputInfo,
- swizzledOutputInfo,
+ inputInfo,
+ outputInfo,
descriptor))
{
return false;
@@ -749,11 +747,9 @@ bool HalPolicy::ConvertLocalResponseNormalization(const Operation& operation,
armnn::IConnectableLayer* layer = data.m_Network->AddNormalizationLayer(descriptor);
assert(layer != nullptr);
- layer->GetOutputSlot(0).SetTensorInfo(swizzledOutputInfo);
-
- armnn::IConnectableLayer& outSwizzleLayer = SwizzleInDeswizzleOut(*data.m_Network, input, *layer);
+ input.Connect(layer->GetInputSlot(0));
- return SetupAndTrackLayerOutputSlot(operation, 0, outSwizzleLayer, model, data);
+ return SetupAndTrackLayerOutputSlot(operation, 0, *layer, model, data);
}
bool HalPolicy::ConvertLogistic(const Operation& operation, const Model& model, ConversionData& data)