From 218419bb16167170fb5ec38474d51e497b66d6e8 Mon Sep 17 00:00:00 2001 From: TatWai Chong Date: Mon, 25 Jul 2022 09:17:53 -0700 Subject: Add conv3d support Change-Id: If77423cbdb354a213677139d6cf4641db2bd6fcd --- src/TosaSerialize.cpp | 47 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 7256385..e407610 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -541,6 +541,53 @@ TosaSerializationOperatorBuilder::build( return tyop; } +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::vector pad, stride, dilation; + + auto pad_attr = op.getAttr("pad").dyn_cast().getValue(); + for (auto &int_attr : pad_attr) { + pad.push_back(int_attr.dyn_cast().getInt()); + } + ASSERT_VECTOR_LENGTH(pad, 6); + + auto stride_attr = + op.getAttr("stride").dyn_cast().getValue(); + for (auto &int_attr : stride_attr) { + stride.push_back(int_attr.dyn_cast().getInt()); + } + ASSERT_VECTOR_LENGTH(stride, 3); + + auto dilation_attr = + op.getAttr("dilation").dyn_cast().getValue(); + for (auto &int_attr : dilation_attr) { + dilation.push_back(int_attr.dyn_cast().getInt()); + } + ASSERT_VECTOR_LENGTH(dilation, 3); + + std::string input0_name = GetTensorName(op.getOperand(0)); + std::string input1_name = GetTensorName(op.getOperand(1)); + std::string input2_name = GetTensorName(op.getOperand(2)); + std::string output_name = GetTensorName(op.getResult(0)); + + auto quant_info = + op.getAttrOfType("quantization_info"); + + int32_t input_zp = quant_info ? quant_info.input_zp().getInt() : 0; + int32_t weight_zp = quant_info ? quant_info.weight_zp().getInt() : 0; + + TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_CONV3D, Attribute_ConvAttribute, &attribute, + std::vector{input0_name, input1_name, input2_name}, + std::vector{output_name}); + + return tyop; +} + template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build( -- cgit v1.2.1