aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatteo Martincigh <matteo.martincigh@arm.com>2018-10-26 16:39:28 +0100
committerMatteo Martincigh <matteo.martincigh@arm.com>2018-11-12 11:08:37 +0000
commit39fc547629a1f787bfa32398c6954b298f4f9d9a (patch)
treee8b1bd31f3bd3a5da7bdd496eeeeddf0b7a7dd17
parent46e9d1e8aa2ca5a05021a9185d1caea9d08b5499 (diff)
downloadandroid-nn-driver-39fc547629a1f787bfa32398c6954b298f4f9d9a.tar.gz
IVGCVSW-1981: Edit HAL Policy for NHWC Pooling2D
* Removes permutes of tensors for Pooling2D, as NHWC is now supported by Arm NN. Change-Id: I48417c91f387b6f73bc071e473828f2ee5949332
-rw-r--r--ConversionUtils.hpp35
1 files changed, 16 insertions, 19 deletions
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp
index 597edc47..68ce09d8 100644
--- a/ConversionUtils.hpp
+++ b/ConversionUtils.hpp
@@ -940,12 +940,10 @@ bool ConvertPooling2d(const HalOperation& operation,
const armnn::TensorInfo& inputInfo = input.GetTensorInfo();
const armnn::TensorInfo& outputInfo = GetTensorInfoForOperand(*output);
- const armnn::TensorInfo swizzledInputInfo = armnnUtils::Permuted(inputInfo, NHWCToArmNN);
- const armnn::TensorInfo swizzledOutputInfo = armnnUtils::Permuted(outputInfo, NHWCToArmNN);
-
armnn::Pooling2dDescriptor desc;
desc.m_PoolType = poolType;
desc.m_OutputShapeRounding = armnn::OutputShapeRounding::Floor;
+ desc.m_DataLayout = armnn::DataLayout::NHWC;
ActivationFn activation;
@@ -963,8 +961,8 @@ bool ConvertPooling2d(const HalOperation& operation,
return Fail("%s: Operation has invalid inputs", operationName);
}
- const unsigned int inputWidth = swizzledInputInfo.GetShape()[3];
- const unsigned int inputHeight = swizzledInputInfo.GetShape()[2];
+ const unsigned int inputWidth = inputInfo.GetShape()[2];
+ const unsigned int inputHeight = inputInfo.GetShape()[1];
CalcPadding(inputWidth, desc.m_PoolWidth, desc.m_StrideX, desc.m_PadLeft, desc.m_PadRight, scheme);
CalcPadding(inputHeight, desc.m_PoolHeight, desc.m_StrideY, desc.m_PadTop, desc.m_PadBottom, scheme);
@@ -986,32 +984,31 @@ bool ConvertPooling2d(const HalOperation& operation,
}
}
- armnn::IConnectableLayer* startLayer = nullptr;
-
if (!IsLayerSupported(__func__,
armnn::IsPooling2dSupported,
data.m_Compute,
- swizzledInputInfo,
- swizzledOutputInfo,
+ inputInfo,
+ outputInfo,
desc))
{
return false;
}
- startLayer = data.m_Network->AddPooling2dLayer(desc);
-
- armnn::IConnectableLayer* endLayer = ProcessActivation(swizzledOutputInfo, activation, startLayer, data);
-
- if (endLayer != nullptr)
+ armnn::IConnectableLayer* pooling2dLayer = data.m_Network->AddPooling2dLayer(desc);
+ if (!pooling2dLayer)
{
- armnn::IConnectableLayer& outSwizzleLayer =
- SwizzleInDeswizzleOut(*data.m_Network, input, *startLayer, *endLayer);
- return SetupAndTrackLayerOutputSlot(operation, 0, outSwizzleLayer, model, data);
+ return Fail("%s: AddPooling2dLayer failed", __func__);
}
- else
+
+ armnn::IConnectableLayer* endLayer = ProcessActivation(outputInfo, activation, pooling2dLayer, data);
+ if (!endLayer)
{
- return Fail("%s: ProcessActivation failed", operationName);
+ return Fail("%s: ProcessActivation failed", __func__);
}
+
+ input.Connect(pooling2dLayer->GetInputSlot(0));
+
+ return SetupAndTrackLayerOutputSlot(operation, 0, *endLayer, model, data);
}
} // namespace armnn_driver