From 793a6c5088d1cb59753eccc158173fbab8e44190 Mon Sep 17 00:00:00 2001 From: Suraj Sudhir Date: Thu, 14 Mar 2024 23:43:29 +0000 Subject: [tosa-mlir-translator] Remove TRANSPOSE_CONV2D out_shape argument Change-Id: I0a96662eed0466cc2220f88522fb97d3ad221559 Signed-off-by: Suraj Sudhir --- src/TosaDeserialize.cpp | 4 +--- src/TosaSerialize.cpp | 7 ++----- third_party/serialization_lib | 2 +- 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index fdbd892..215d760 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -842,8 +842,6 @@ std::vector TosaMlirOperatorBuilder::build( BuildDenseI64ArrayAttr(op_builder, attr->out_pad()); mlir::DenseI64ArrayAttr stride = BuildDenseI64ArrayAttr(op_builder, attr->stride()); - mlir::DenseI64ArrayAttr output_shape = - BuildDenseI64ArrayAttr(op_builder, attr->output_shape()); auto input_zp = attr->input_zp(); auto weight_zp = attr->weight_zp(); bool local_bound = attr->local_bound(); @@ -860,7 +858,7 @@ std::vector TosaMlirOperatorBuilder::build( mlir_op = op_builder->create( loc, output_type, input0_val, input1_val, input2_val, out_pad, stride, - output_shape, acc_type, input_zp_attr, weight_zp_attr, local_bound); + acc_type, input_zp_attr, weight_zp_attr, local_bound); block->push_back(mlir_op); return std::vector({mlir_op->getResult(0)}); diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 6553944..c3a9878 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -888,9 +888,6 @@ TosaSerializationOperatorBuilder::build( auto stride = getDenseI64ArrayAttr(op.getAttr("stride")); ASSERT_VECTOR_LENGTH(stride, 2); - auto out_shape = getDenseI64ArrayAttr(op.getAttr("out_shape")); - ASSERT_VECTOR_LENGTH(out_shape, 4); - std::string input0_name = GetTensorName(op.getOperand(0)); std::string input1_name = GetTensorName(op.getOperand(1)); std::string input2_name = GetTensorName(op.getOperand(2)); @@ -916,8 +913,8 @@ TosaSerializationOperatorBuilder::build( auto acc_type = op.getAttr("acc_type").cast().getValue(); auto acc_dtype = Type2AccDType(acc_type); - TosaTransposeConvAttribute attribute(out_pad, stride, out_shape, input_zp, - weight_zp, local_bound, acc_dtype); + TosaTransposeConvAttribute attribute(out_pad, stride, input_zp, weight_zp, + local_bound, acc_dtype); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_TRANSPOSE_CONV2D, Attribute_TransposeConvAttribute, &attribute, diff --git a/third_party/serialization_lib b/third_party/serialization_lib index 57d7818..50256e1 160000 --- a/third_party/serialization_lib +++ b/third_party/serialization_lib @@ -1 +1 @@ -Subproject commit 57d781883142db8a45fe98ac1a1dfacc49cba78a +Subproject commit 50256e168c3e759f03445bb872d0a43da1a6ba30 -- cgit v1.2.1