diff options
author | Ferran Balaguer <ferran.balaguer@arm.com> | 2018-10-26 16:41:17 +0100 |
---|---|---|
committer | Matthew Bentham <matthew.bentham@arm.com> | 2018-11-06 17:02:26 +0000 |
commit | d2966a96822772683f4b4a8a368edd3ee8adfdad (patch) | |
tree | fec9a51b65545173c30911a379ab9d5b00de5d8f /ConversionUtils.hpp | |
parent | ae622b7553d3d3f49447160655f2feb7aa3b0e17 (diff) | |
download | android-nn-driver-branches/nhwc-preview.tar.gz |
MLCE-65 Early access release with NHWC for Androidbranches/nhwc-preview
Change-Id: I612efc41b5ea460f4893bb05b8d358d21ee393bb
Diffstat (limited to 'ConversionUtils.hpp')
-rw-r--r-- | ConversionUtils.hpp | 35 |
1 files changed, 16 insertions, 19 deletions
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp index 783f7cec..9b56a9aa 100644 --- a/ConversionUtils.hpp +++ b/ConversionUtils.hpp @@ -939,10 +939,8 @@ bool ConvertPooling2d(const HalOperation& operation, const armnn::TensorInfo& inputInfo = input.GetTensorInfo(); const armnn::TensorInfo& outputInfo = GetTensorInfoForOperand(*output); - const armnn::TensorInfo swizzledInputInfo = armnnUtils::Permuted(inputInfo, NHWCToArmNN); - const armnn::TensorInfo swizzledOutputInfo = armnnUtils::Permuted(outputInfo, NHWCToArmNN); - armnn::Pooling2dDescriptor desc; + desc.m_DataLayout = armnn::DataLayout::NHWC; desc.m_PoolType = poolType; desc.m_OutputShapeRounding = armnn::OutputShapeRounding::Floor; @@ -962,8 +960,8 @@ bool ConvertPooling2d(const HalOperation& operation, return Fail("%s: Operation has invalid inputs", operationName); } - const unsigned int inputWidth = swizzledInputInfo.GetShape()[3]; - const unsigned int inputHeight = swizzledInputInfo.GetShape()[2]; + const unsigned int inputWidth = inputInfo.GetShape()[2]; + const unsigned int inputHeight = inputInfo.GetShape()[1]; CalcPadding(inputWidth, desc.m_PoolWidth, desc.m_StrideX, desc.m_PadLeft, desc.m_PadRight, scheme); CalcPadding(inputHeight, desc.m_PoolHeight, desc.m_StrideY, desc.m_PadTop, desc.m_PadBottom, scheme); @@ -985,32 +983,31 @@ bool ConvertPooling2d(const HalOperation& operation, } } - armnn::IConnectableLayer* startLayer = nullptr; - if (!IsLayerSupported(__func__, armnn::IsPooling2dSupported, data.m_Compute, - swizzledInputInfo, - swizzledOutputInfo, + inputInfo, + outputInfo, desc)) { return false; } - startLayer = data.m_Network->AddPooling2dLayer(desc); - - armnn::IConnectableLayer* endLayer = ProcessActivation(swizzledOutputInfo, activation, startLayer, data); - - if (endLayer != nullptr) + armnn::IConnectableLayer* startLayer = data.m_Network->AddPooling2dLayer(desc); + if (!startLayer) { - armnn::IConnectableLayer& outSwizzleLayer = - SwizzleInDeswizzleOut(*data.m_Network, input, *startLayer, *endLayer); - return SetupAndTrackLayerOutputSlot(operation, 0, outSwizzleLayer, model, data); + return Fail("%s: AddPooling2dLayer failed", __func__); } - else + + armnn::IConnectableLayer* endLayer = ProcessActivation(outputInfo, activation, startLayer, data); + if (!endLayer) { - return Fail("%s: ProcessActivation failed", operationName); + return Fail("%s: ProcessActivation failed", __func__); } + + input.Connect(startLayer->GetInputSlot(0)); + + return SetupAndTrackLayerOutputSlot(operation, 0, *endLayer, model, data); } } // namespace armnn_driver |