diff options
-rw-r--r-- | ConversionUtils.hpp | 37 |
1 files changed, 35 insertions, 2 deletions
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp index 3ad15d30..808fc1ca 100644 --- a/ConversionUtils.hpp +++ b/ConversionUtils.hpp @@ -773,6 +773,37 @@ bool GetOptionalInputActivation(const HalOperation& operation, return true; } +template <typename HalOperand, + typename HalOperandType, + typename HalOperation, + typename HalModel, + typename ConvolutionDescriptor> +bool GetOptionalConvolutionDilationParams(const HalOperation& operation, + uint32_t dilationXIndex, + ConvolutionDescriptor& descriptor, + const HalModel& model, + const ConversionData& data) +{ + bool success = true; + if (operation.inputs.size() >= dilationXIndex + 2) + { + success &= GetInputScalar<HalOperand, HalOperandType>(operation, + dilationXIndex, + HalOperandType::INT32, + descriptor.m_DilationX, + model, + data); + success &= GetInputScalar<HalOperand, HalOperandType>(operation, + dilationXIndex + 1, + HalOperandType::INT32, + descriptor.m_DilationY, + model, + data); + } + + return success; +} + template<typename HalOperand, typename HalOperandType, typename HalModel> bool GetTensorInt32Values(const HalOperand& operand, std::vector<int32_t>& outValues, @@ -1306,7 +1337,8 @@ bool ConvertDepthwiseConv2d(const HalOperation& operation, const HalModel& model data) || !GetInputScalar<HalOperand, HalOperandType>(operation, 8, HalOperandType::INT32, desc.m_StrideY, model, data) - || !GetInputActivationFunction<HalOperand, HalOperandType>(operation, 10, activation, model, data)) + || !GetInputActivationFunction<HalOperand, HalOperandType>(operation, 10, activation, model, data) + || !GetOptionalConvolutionDilationParams<HalOperand, HalOperandType>(operation, 12, desc, model, data)) { return Fail("%s: Operation has invalid inputs", __func__); } @@ -1319,7 +1351,8 @@ bool ConvertDepthwiseConv2d(const HalOperation& operation, const HalModel& model data) || !GetInputScalar<HalOperand, HalOperandType>(operation, 5, HalOperandType::INT32, desc.m_StrideY, model, data) - || !GetInputActivationFunction<HalOperand, HalOperandType>(operation, 7, activation, model, data)) + || !GetInputActivationFunction<HalOperand, HalOperandType>(operation, 7, activation, model, data) + || !GetOptionalConvolutionDilationParams<HalOperand, HalOperandType>(operation, 9, desc, model, data)) { return Fail("%s: Operation has invalid inputs", __func__); } |