diff options
author | narpra01 <narumol.prangnawarat@arm.com> | 2018-10-22 14:52:32 +0100 |
---|---|---|
committer | Matthew Bentham <matthew.bentham@arm.com> | 2018-10-25 10:09:44 +0100 |
commit | 2fb804a9460239dad5949ae92b1b98fd0fb01c61 (patch) | |
tree | ba8f09d569a1a4df93b7a17e572dd1ccd8538e30 | |
parent | fe463150228156c29a415f45d2172a43df6ce6c3 (diff) | |
download | android-nn-driver-2fb804a9460239dad5949ae92b1b98fd0fb01c61.tar.gz |
IVGCVSW-2065 - Modify HAL Policy for Normalization to use NHWC data layout
Change-Id: Ic60c4dacb55bcf2514f011a8e844e7b8f7b13560
-rw-r--r-- | 1.0/HalPolicy.cpp | 18 |
1 files changed, 7 insertions, 11 deletions
diff --git a/1.0/HalPolicy.cpp b/1.0/HalPolicy.cpp index 174bbb3f..4c5fba3f 100644 --- a/1.0/HalPolicy.cpp +++ b/1.0/HalPolicy.cpp @@ -712,16 +712,14 @@ bool HalPolicy::ConvertLocalResponseNormalization(const Operation& operation, return Fail("%s: Could not read output 0", __func__); } - const armnn::TensorInfo& inputInfo = input.GetTensorInfo(); + const armnn::TensorInfo& inputInfo = input.GetTensorInfo(); const armnn::TensorInfo& outputInfo = GetTensorInfoForOperand(*output); - const armnn::TensorInfo swizzledInputInfo = armnnUtils::Permuted(inputInfo, NHWCToArmNN); - const armnn::TensorInfo swizzledOutputInfo = armnnUtils::Permuted(outputInfo, NHWCToArmNN); - armnn::NormalizationDescriptor descriptor; + descriptor.m_DataLayout = armnn::DataLayout::NHWC; descriptor.m_NormChannelType = armnn::NormalizationAlgorithmChannel::Across; - descriptor.m_NormMethodType = armnn::NormalizationAlgorithmMethod::LocalBrightness; + descriptor.m_NormMethodType = armnn::NormalizationAlgorithmMethod::LocalBrightness; if (!input.IsValid() || !GetInputScalar(operation, 1, OperandType::INT32, descriptor.m_NormSize, model, data) || @@ -739,8 +737,8 @@ bool HalPolicy::ConvertLocalResponseNormalization(const Operation& operation, if (!IsLayerSupported(__func__, armnn::IsNormalizationSupported, data.m_Compute, - swizzledInputInfo, - swizzledOutputInfo, + inputInfo, + outputInfo, descriptor)) { return false; @@ -749,11 +747,9 @@ bool HalPolicy::ConvertLocalResponseNormalization(const Operation& operation, armnn::IConnectableLayer* layer = data.m_Network->AddNormalizationLayer(descriptor); assert(layer != nullptr); - layer->GetOutputSlot(0).SetTensorInfo(swizzledOutputInfo); - - armnn::IConnectableLayer& outSwizzleLayer = SwizzleInDeswizzleOut(*data.m_Network, input, *layer); + input.Connect(layer->GetInputSlot(0)); - return SetupAndTrackLayerOutputSlot(operation, 0, outSwizzleLayer, model, data); + return SetupAndTrackLayerOutputSlot(operation, 0, *layer, model, data); } bool HalPolicy::ConvertLogistic(const Operation& operation, const Model& model, ConversionData& data) |