From 1a7ca55663ef19c95ba9d6bc2a2789414762273d Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Wed, 14 Feb 2024 21:12:10 +0000 Subject: [tosa_mlir_translator] tosa.fb name changes This patch implements attribute name changes in tosa serialization library to align with tosa spec. Signed-off-by: Tai Ly Change-Id: I575e4601f12aee67ec6e3114c68a428b8db9d232 --- src/TosaDeserialize.cpp | 18 +++++++++--------- src/TosaSerialize.cpp | 14 +++++++------- third_party/serialization_lib | 2 +- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index 1f24f71..82c107e 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -356,9 +356,9 @@ const std::string ResizeEnum2Str(const tosa::ResizeMode &mode) { return ""; } -// this is a counter part to Type2PoolAccumDType -mlir::TypeAttr AccumDType2TypeAttr(mlir::OpBuilder *op_builder, DType dtype) { - // def Tosa_AccType : AnyTypeOf<[I<32>, SI<32>, F16, F32]>; +// this is a counter part to Type2PoolAccDType +mlir::TypeAttr AccDType2TypeAttr(mlir::OpBuilder *op_builder, DType dtype) { + // def Tosa_AccType : AnyTypeOf<[I<32>, F16, F32]>; if (dtype == DType_INT32) { return mlir::TypeAttr::get(op_builder->getI32Type()); } else if (dtype == DType_FP32) { @@ -366,7 +366,7 @@ mlir::TypeAttr AccumDType2TypeAttr(mlir::OpBuilder *op_builder, DType dtype) { } else if (dtype == DType_FP16) { return mlir::TypeAttr::get(op_builder->getF16Type()); } else { - // unknown accum type + // unknown acc type // for now, default to F32 return mlir::TypeAttr::get(op_builder->getF32Type()); } @@ -504,7 +504,7 @@ std::vector TosaMlirOperatorBuilder::build( mlir::DenseI64ArrayAttr stride = BuildDenseI64ArrayAttr(op_builder, attr->stride()); mlir::DenseI64ArrayAttr pad = BuildDenseI64ArrayAttr(op_builder, attr->pad()); - auto acc_attr = AccumDType2TypeAttr(op_builder, attr->accum_dtype()); + auto acc_attr = AccDType2TypeAttr(op_builder, attr->acc_type()); int32_t input_zp = attr->input_zp(); int32_t output_zp = attr->output_zp(); @@ -1526,8 +1526,8 @@ std::vector TosaMlirOperatorBuilder::build( Attribute_CondIfAttribute); // double check attribute type TosaCondIfAttribute *attr = static_cast(op->GetAttribute()); - auto ser_then_region = GetTsh()->GetRegionByName(attr->then_branch()); - auto ser_else_region = GetTsh()->GetRegionByName(attr->else_branch()); + auto ser_then_region = GetTsh()->GetRegionByName(attr->then_graph()); + auto ser_else_region = GetTsh()->GetRegionByName(attr->else_graph()); if (!ser_then_region || !ser_else_region) { llvm::errs() << "ERROR: " << get_string(op) @@ -1591,8 +1591,8 @@ std::vector TosaMlirOperatorBuilder::build( Attribute_WhileLoopAttribute); // double check attribute type TosaWhileLoopAttribute *attr = static_cast(op->GetAttribute()); - auto ser_cond_region = GetTsh()->GetRegionByName(attr->cond_branch()); - auto ser_body_region = GetTsh()->GetRegionByName(attr->body_branch()); + auto ser_cond_region = GetTsh()->GetRegionByName(attr->cond_graph()); + auto ser_body_region = GetTsh()->GetRegionByName(attr->body_graph()); mlir::Operation *mlir_op = op_builder->create(loc, output_types, input_values); diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 9a6cf84..5b0d2bd 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -105,13 +105,13 @@ static DType Type2DType(mlir::Type element_type) { return DType_UNKNOWN; } -static DType Type2PoolAccumDType(mlir::Type element_type) { - // def Tosa_AccType : AnyTypeOf<[I<32>, SI<32>, F16, F32]>; +static DType Type2PoolAccDType(mlir::Type element_type) { + // def Tosa_AccType : AnyTypeOf<[I<32>, F16, F32]>; if (element_type.isF32()) { return DType_FP32; } else if (element_type.isF16()) { return DType_FP16; - } else if (element_type.isInteger(32) || element_type.isSignedInteger(32)) { + } else if (element_type.isInteger(32)) { return DType_INT32; } return DType_UNKNOWN; @@ -461,11 +461,11 @@ TosaSerializationOperatorBuilder::BuildPoolOpFromMlirOp(mlir::Operation &op, auto kernel = getDenseI64ArrayAttr(op.getAttr("kernel")); ASSERT_VECTOR_LENGTH(kernel, 2); - DType accum_dtype = DType_FP32; - // AvgPool has accum_dtype, MaxPool does not + DType acc_dtype = DType_FP32; + // AvgPool has acc_type, MaxPool does not if (op.hasAttr("acc_type")) { auto acc_type = op.getAttr("acc_type").cast().getValue(); - accum_dtype = Type2PoolAccumDType(acc_type); + acc_dtype = Type2PoolAccDType(acc_type); } std::string input_name = GetTensorName(op.getOperand(0)); @@ -482,7 +482,7 @@ TosaSerializationOperatorBuilder::BuildPoolOpFromMlirOp(mlir::Operation &op, : 0; TosaPoolAttribute attribute(pad, kernel, stride, input_zp, output_zp, - accum_dtype); + acc_dtype); TosaSerializationOperator *tyop = new TosaSerializationOperator(opcode, Attribute_PoolAttribute, &attribute, diff --git a/third_party/serialization_lib b/third_party/serialization_lib index 61a8313..81db8ee 160000 --- a/third_party/serialization_lib +++ b/third_party/serialization_lib @@ -1 +1 @@ -Subproject commit 61a8313f5a0cbcfc7c8ee8a44f05a5ca9b1015b9 +Subproject commit 81db8ee8f580d30ec0ca53067df32ef046e6f09e -- cgit v1.2.1