aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFinnWilliamsArm <Finn.Williams@Arm.com>2019-11-25 16:02:07 +0000
committerKevin May <kevin.may@arm.com>2019-11-27 13:37:28 +0000
commit54c59758d49e0932c83a9ff27eea6f93c044eed6 (patch)
treef676985fbfe03c8fda712a7fedc596fa6df5ec40
parent7a13acc8747a829fe3c912fcad028cbffdae4e49 (diff)
downloadandroid-nn-driver-54c59758d49e0932c83a9ff27eea6f93c044eed6.tar.gz
MLCE-144 Fix 2d pooling convert function
Signed-off-by: FinnWilliamsArm <Finn.Williams@Arm.com> Change-Id: I999d9091bc4046861433d4eb3109fe972611bd82
-rw-r--r--ConversionUtils.hpp15
1 files changed, 8 insertions, 7 deletions
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp
index 6f1f100d..fabf1896 100644
--- a/ConversionUtils.hpp
+++ b/ConversionUtils.hpp
@@ -1388,7 +1388,7 @@ bool ConvertPooling2d(const HalOperation& operation,
LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
if (!input.IsValid())
{
- return Fail("%s: Could not read input 0", operationName);
+ return Fail("%s: Operation Could not read input 0", operationName);
}
const HalOperand* output = GetOutputOperand<HalPolicy>(operation, 0, model);
@@ -1449,16 +1449,17 @@ bool ConvertPooling2d(const HalOperation& operation,
return Fail("%s: Operation has invalid inputs", operationName);
}
- const unsigned int inputWidth = inputInfo.GetShape()[2];
- const unsigned int inputHeight = inputInfo.GetShape()[1];
-
- CalcPadding(inputWidth, desc.m_PoolWidth, desc.m_StrideX, desc.m_PadLeft, desc.m_PadRight, scheme);
- CalcPadding(inputHeight, desc.m_PoolHeight, desc.m_StrideY, desc.m_PadTop, desc.m_PadBottom, scheme);
-
if (Is12Operand(*output))
{
desc.m_DataLayout = OptionalDataLayout<HalPolicy>(operation, 7, model, data);
}
+
+ const armnnUtils::DataLayoutIndexed dataLayout(desc.m_DataLayout);
+ const unsigned int inputWidth = inputInfo.GetShape()[dataLayout.GetWidthIndex()];
+ const unsigned int inputHeight = inputInfo.GetShape()[dataLayout.GetHeightIndex()];
+
+ CalcPadding(inputWidth, desc.m_PoolWidth, desc.m_StrideX, desc.m_PadLeft, desc.m_PadRight, scheme);
+ CalcPadding(inputHeight, desc.m_PoolHeight, desc.m_StrideY, desc.m_PadTop, desc.m_PadBottom, scheme);
}
bool isSupported = false;