aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTatWai Chong <tatwai.chong@arm.com>2022-07-25 09:17:53 -0700
committerTatWai Chong <tatwai.chong@arm.com>2022-08-29 10:22:14 -0700
commit218419bb16167170fb5ec38474d51e497b66d6e8 (patch)
tree4f2d4b958d9d63e3d707dd275c0254c1800ca7db
parentd13f5e6e8b964e2b10b5f4133043e59e45f11aaa (diff)
downloadtosa_mlir_translator-218419bb16167170fb5ec38474d51e497b66d6e8.tar.gz
Add conv3d support
Change-Id: If77423cbdb354a213677139d6cf4641db2bd6fcd
-rw-r--r--src/TosaSerialize.cpp47
1 files changed, 47 insertions, 0 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index 7256385..e407610 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -543,6 +543,53 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::Conv2DOp>(
template <>
TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::Conv3DOp>(
+ mlir::Operation &op) const {
+ std::vector<int> pad, stride, dilation;
+
+ auto pad_attr = op.getAttr("pad").dyn_cast<mlir::ArrayAttr>().getValue();
+ for (auto &int_attr : pad_attr) {
+ pad.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
+ }
+ ASSERT_VECTOR_LENGTH(pad, 6);
+
+ auto stride_attr =
+ op.getAttr("stride").dyn_cast<mlir::ArrayAttr>().getValue();
+ for (auto &int_attr : stride_attr) {
+ stride.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
+ }
+ ASSERT_VECTOR_LENGTH(stride, 3);
+
+ auto dilation_attr =
+ op.getAttr("dilation").dyn_cast<mlir::ArrayAttr>().getValue();
+ for (auto &int_attr : dilation_attr) {
+ dilation.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
+ }
+ ASSERT_VECTOR_LENGTH(dilation, 3);
+
+ std::string input0_name = GetTensorName(op.getOperand(0));
+ std::string input1_name = GetTensorName(op.getOperand(1));
+ std::string input2_name = GetTensorName(op.getOperand(2));
+ std::string output_name = GetTensorName(op.getResult(0));
+
+ auto quant_info =
+ op.getAttrOfType<mlir::tosa::ConvOpQuantizationAttr>("quantization_info");
+
+ int32_t input_zp = quant_info ? quant_info.input_zp().getInt() : 0;
+ int32_t weight_zp = quant_info ? quant_info.weight_zp().getInt() : 0;
+
+ TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp);
+
+ TosaSerializationOperator *tyop = new TosaSerializationOperator(
+ Op_CONV3D, Attribute_ConvAttribute, &attribute,
+ std::vector<std::string>{input0_name, input1_name, input2_name},
+ std::vector<std::string>{output_name});
+
+ return tyop;
+}
+
+template <>
+TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::DepthwiseConv2DOp>(
mlir::Operation &op) const {
std::vector<int> pad, stride, dilation;