diff options
Diffstat (limited to 'ConversionUtils.hpp')
-rw-r--r-- | ConversionUtils.hpp | 42 |
1 files changed, 27 insertions, 15 deletions
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp index 9a2b08f0..759514d6 100644 --- a/ConversionUtils.hpp +++ b/ConversionUtils.hpp @@ -1332,7 +1332,30 @@ bool ConvertPooling2d(const HalOperation& operation, ActivationFn activation; - if (operation.inputs.size() == 7) + auto inputSize = operation.inputs.size(); + + if (inputSize >= 10) + { + // one input, 9 parameters (padding l r t b, stridex, stridey, width, height, activation type) + if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::INT32, desc.m_PadLeft, model, data) || + !GetInputScalar<HalPolicy>(operation, 2, HalOperandType::INT32, desc.m_PadRight, model, data) || + !GetInputScalar<HalPolicy>(operation, 3, HalOperandType::INT32, desc.m_PadTop, model, data) || + !GetInputScalar<HalPolicy>(operation, 4, HalOperandType::INT32, desc.m_PadBottom, model, data) || + !GetInputScalar<HalPolicy>(operation, 5, HalOperandType::INT32, desc.m_StrideX, model, data) || + !GetInputScalar<HalPolicy>(operation, 6, HalOperandType::INT32, desc.m_StrideY, model, data) || + !GetInputScalar<HalPolicy>(operation, 7, HalOperandType::INT32, desc.m_PoolWidth, model, data) || + !GetInputScalar<HalPolicy>(operation, 8, HalOperandType::INT32, desc.m_PoolHeight, model, data) || + !GetInputActivationFunction<HalPolicy>(operation, 9, activation, model, data)) + { + return Fail("%s: Operation has invalid inputs", operationName); + } + + if (Is12Operand(*output)) + { + desc.m_DataLayout = OptionalDataLayout<HalPolicy>(operation, 10, model, data); + } + } + else { // one input, 6 parameters (padding, stridex, stridey, width, height, activation type) android::nn::PaddingScheme scheme; @@ -1351,21 +1374,10 @@ bool ConvertPooling2d(const HalOperation& operation, CalcPadding(inputWidth, desc.m_PoolWidth, desc.m_StrideX, desc.m_PadLeft, desc.m_PadRight, scheme); CalcPadding(inputHeight, desc.m_PoolHeight, desc.m_StrideY, desc.m_PadTop, desc.m_PadBottom, scheme); - } - else - { - // one input, 9 parameters (padding l r t b, stridex, stridey, width, height, activation type) - if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::INT32, desc.m_PadLeft, model, data) || - !GetInputScalar<HalPolicy>(operation, 2, HalOperandType::INT32, desc.m_PadRight, model, data) || - !GetInputScalar<HalPolicy>(operation, 3, HalOperandType::INT32, desc.m_PadTop, model, data) || - !GetInputScalar<HalPolicy>(operation, 4, HalOperandType::INT32, desc.m_PadBottom, model, data) || - !GetInputScalar<HalPolicy>(operation, 5, HalOperandType::INT32, desc.m_StrideX, model, data) || - !GetInputScalar<HalPolicy>(operation, 6, HalOperandType::INT32, desc.m_StrideY, model, data) || - !GetInputScalar<HalPolicy>(operation, 7, HalOperandType::INT32, desc.m_PoolWidth, model, data) || - !GetInputScalar<HalPolicy>(operation, 8, HalOperandType::INT32, desc.m_PoolHeight, model, data) || - !GetInputActivationFunction<HalPolicy>(operation, 9, activation, model, data)) + + if (Is12Operand(*output)) { - return Fail("%s: Operation has invalid inputs", operationName); + desc.m_DataLayout = OptionalDataLayout<HalPolicy>(operation, 7, model, data); } } |