aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ConversionUtils.hpp37
1 files changed, 35 insertions, 2 deletions
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp
index 3ad15d30..808fc1ca 100644
--- a/ConversionUtils.hpp
+++ b/ConversionUtils.hpp
@@ -773,6 +773,37 @@ bool GetOptionalInputActivation(const HalOperation& operation,
return true;
}
+template <typename HalOperand,
+ typename HalOperandType,
+ typename HalOperation,
+ typename HalModel,
+ typename ConvolutionDescriptor>
+bool GetOptionalConvolutionDilationParams(const HalOperation& operation,
+ uint32_t dilationXIndex,
+ ConvolutionDescriptor& descriptor,
+ const HalModel& model,
+ const ConversionData& data)
+{
+ bool success = true;
+ if (operation.inputs.size() >= dilationXIndex + 2)
+ {
+ success &= GetInputScalar<HalOperand, HalOperandType>(operation,
+ dilationXIndex,
+ HalOperandType::INT32,
+ descriptor.m_DilationX,
+ model,
+ data);
+ success &= GetInputScalar<HalOperand, HalOperandType>(operation,
+ dilationXIndex + 1,
+ HalOperandType::INT32,
+ descriptor.m_DilationY,
+ model,
+ data);
+ }
+
+ return success;
+}
+
template<typename HalOperand, typename HalOperandType, typename HalModel>
bool GetTensorInt32Values(const HalOperand& operand,
std::vector<int32_t>& outValues,
@@ -1306,7 +1337,8 @@ bool ConvertDepthwiseConv2d(const HalOperation& operation, const HalModel& model
data)
|| !GetInputScalar<HalOperand, HalOperandType>(operation, 8, HalOperandType::INT32, desc.m_StrideY, model,
data)
- || !GetInputActivationFunction<HalOperand, HalOperandType>(operation, 10, activation, model, data))
+ || !GetInputActivationFunction<HalOperand, HalOperandType>(operation, 10, activation, model, data)
+ || !GetOptionalConvolutionDilationParams<HalOperand, HalOperandType>(operation, 12, desc, model, data))
{
return Fail("%s: Operation has invalid inputs", __func__);
}
@@ -1319,7 +1351,8 @@ bool ConvertDepthwiseConv2d(const HalOperation& operation, const HalModel& model
data)
|| !GetInputScalar<HalOperand, HalOperandType>(operation, 5, HalOperandType::INT32, desc.m_StrideY, model,
data)
- || !GetInputActivationFunction<HalOperand, HalOperandType>(operation, 7, activation, model, data))
+ || !GetInputActivationFunction<HalOperand, HalOperandType>(operation, 7, activation, model, data)
+ || !GetOptionalConvolutionDilationParams<HalOperand, HalOperandType>(operation, 9, desc, model, data))
{
return Fail("%s: Operation has invalid inputs", __func__);
}