aboutsummaryrefslogtreecommitdiff
path: root/ConversionUtils.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'ConversionUtils.hpp')
-rw-r--r--ConversionUtils.hpp42
1 files changed, 27 insertions, 15 deletions
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp
index 9a2b08f0..759514d6 100644
--- a/ConversionUtils.hpp
+++ b/ConversionUtils.hpp
@@ -1332,7 +1332,30 @@ bool ConvertPooling2d(const HalOperation& operation,
ActivationFn activation;
- if (operation.inputs.size() == 7)
+ auto inputSize = operation.inputs.size();
+
+ if (inputSize >= 10)
+ {
+ // one input, 9 parameters (padding l r t b, stridex, stridey, width, height, activation type)
+ if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::INT32, desc.m_PadLeft, model, data) ||
+ !GetInputScalar<HalPolicy>(operation, 2, HalOperandType::INT32, desc.m_PadRight, model, data) ||
+ !GetInputScalar<HalPolicy>(operation, 3, HalOperandType::INT32, desc.m_PadTop, model, data) ||
+ !GetInputScalar<HalPolicy>(operation, 4, HalOperandType::INT32, desc.m_PadBottom, model, data) ||
+ !GetInputScalar<HalPolicy>(operation, 5, HalOperandType::INT32, desc.m_StrideX, model, data) ||
+ !GetInputScalar<HalPolicy>(operation, 6, HalOperandType::INT32, desc.m_StrideY, model, data) ||
+ !GetInputScalar<HalPolicy>(operation, 7, HalOperandType::INT32, desc.m_PoolWidth, model, data) ||
+ !GetInputScalar<HalPolicy>(operation, 8, HalOperandType::INT32, desc.m_PoolHeight, model, data) ||
+ !GetInputActivationFunction<HalPolicy>(operation, 9, activation, model, data))
+ {
+ return Fail("%s: Operation has invalid inputs", operationName);
+ }
+
+ if (Is12Operand(*output))
+ {
+ desc.m_DataLayout = OptionalDataLayout<HalPolicy>(operation, 10, model, data);
+ }
+ }
+ else
{
// one input, 6 parameters (padding, stridex, stridey, width, height, activation type)
android::nn::PaddingScheme scheme;
@@ -1351,21 +1374,10 @@ bool ConvertPooling2d(const HalOperation& operation,
CalcPadding(inputWidth, desc.m_PoolWidth, desc.m_StrideX, desc.m_PadLeft, desc.m_PadRight, scheme);
CalcPadding(inputHeight, desc.m_PoolHeight, desc.m_StrideY, desc.m_PadTop, desc.m_PadBottom, scheme);
- }
- else
- {
- // one input, 9 parameters (padding l r t b, stridex, stridey, width, height, activation type)
- if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::INT32, desc.m_PadLeft, model, data) ||
- !GetInputScalar<HalPolicy>(operation, 2, HalOperandType::INT32, desc.m_PadRight, model, data) ||
- !GetInputScalar<HalPolicy>(operation, 3, HalOperandType::INT32, desc.m_PadTop, model, data) ||
- !GetInputScalar<HalPolicy>(operation, 4, HalOperandType::INT32, desc.m_PadBottom, model, data) ||
- !GetInputScalar<HalPolicy>(operation, 5, HalOperandType::INT32, desc.m_StrideX, model, data) ||
- !GetInputScalar<HalPolicy>(operation, 6, HalOperandType::INT32, desc.m_StrideY, model, data) ||
- !GetInputScalar<HalPolicy>(operation, 7, HalOperandType::INT32, desc.m_PoolWidth, model, data) ||
- !GetInputScalar<HalPolicy>(operation, 8, HalOperandType::INT32, desc.m_PoolHeight, model, data) ||
- !GetInputActivationFunction<HalPolicy>(operation, 9, activation, model, data))
+
+ if (Is12Operand(*output))
{
- return Fail("%s: Operation has invalid inputs", operationName);
+ desc.m_DataLayout = OptionalDataLayout<HalPolicy>(operation, 7, model, data);
}
}