diff options
Diffstat (limited to 'src/TosaSerialize.cpp')
-rw-r--r-- | src/TosaSerialize.cpp | 47 |
1 files changed, 47 insertions, 0 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 7256385..e407610 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -543,6 +543,53 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::Conv2DOp>( template <> TosaSerializationOperator * +TosaSerializationOperatorBuilder::build<mlir::tosa::Conv3DOp>( + mlir::Operation &op) const { + std::vector<int> pad, stride, dilation; + + auto pad_attr = op.getAttr("pad").dyn_cast<mlir::ArrayAttr>().getValue(); + for (auto &int_attr : pad_attr) { + pad.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt()); + } + ASSERT_VECTOR_LENGTH(pad, 6); + + auto stride_attr = + op.getAttr("stride").dyn_cast<mlir::ArrayAttr>().getValue(); + for (auto &int_attr : stride_attr) { + stride.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt()); + } + ASSERT_VECTOR_LENGTH(stride, 3); + + auto dilation_attr = + op.getAttr("dilation").dyn_cast<mlir::ArrayAttr>().getValue(); + for (auto &int_attr : dilation_attr) { + dilation.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt()); + } + ASSERT_VECTOR_LENGTH(dilation, 3); + + std::string input0_name = GetTensorName(op.getOperand(0)); + std::string input1_name = GetTensorName(op.getOperand(1)); + std::string input2_name = GetTensorName(op.getOperand(2)); + std::string output_name = GetTensorName(op.getResult(0)); + + auto quant_info = + op.getAttrOfType<mlir::tosa::ConvOpQuantizationAttr>("quantization_info"); + + int32_t input_zp = quant_info ? quant_info.input_zp().getInt() : 0; + int32_t weight_zp = quant_info ? quant_info.weight_zp().getInt() : 0; + + TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_CONV3D, Attribute_ConvAttribute, &attribute, + std::vector<std::string>{input0_name, input1_name, input2_name}, + std::vector<std::string>{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * TosaSerializationOperatorBuilder::build<mlir::tosa::DepthwiseConv2DOp>( mlir::Operation &op) const { std::vector<int> pad, stride, dilation; |