From f65ce51e3344313b744429c3763d1c85bf77a857 Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Mon, 3 Jul 2023 22:46:25 +0000 Subject: Add deserialization of AccumType Added deserialize accum_type attribute for AvgPool2D Op also fixed serialization of accum_type attribute for AvgPool2D Op also updated third_party/serialization_lib hash LLVM_REFSPEC: refs/changes/23/532123/2 TF_REFSPEC: refs/changes/34/699334/5 Signed-off-by: Tai Ly Change-Id: I2084f33e60d1bf8f76958b320a96fc1f3a94d95c --- src/TosaDeserialize.cpp | 22 ++++++++++++++++++++-- src/TosaSerialize.cpp | 19 ++++++++++++++----- 2 files changed, 34 insertions(+), 7 deletions(-) (limited to 'src') diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index 153f16f..45e1f18 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -175,6 +175,22 @@ const std::string ResizeEnum2Str(const tosa::ResizeMode &mode) { return ""; } +// this is a counter part to Type2PoolAccumDType +mlir::TypeAttr AccumDType2TypeAttr(mlir::OpBuilder *op_builder, DType dtype) { + // def Tosa_AccType : AnyTypeOf<[I<32>, SI<32>, F16, F32]>; + if (dtype == DType_INT32) { + return mlir::TypeAttr::get(op_builder->getI32Type()); + } else if (dtype == DType_FP32) { + return mlir::TypeAttr::get(op_builder->getF32Type()); + } else if (dtype == DType_FP16) { + return mlir::TypeAttr::get(op_builder->getF16Type()); + } else { + // unknown accum type + // for now, default to F32 + return mlir::TypeAttr::get(op_builder->getF32Type()); + } +} + } // namespace class TosaMlirRegionBuilder; @@ -290,17 +306,19 @@ std::vector TosaMlirOperatorBuilder::build( mlir::DenseI64ArrayAttr stride = BuildDenseI64ArrayAttr(op_builder, attr->stride()); mlir::DenseI64ArrayAttr pad = BuildDenseI64ArrayAttr(op_builder, attr->pad()); + auto acc_attr = AccumDType2TypeAttr(op_builder, attr->accum_dtype()); + int32_t input_zp = attr->input_zp(); int32_t output_zp = attr->output_zp(); mlir::Operation *mlir_op; if (input_zp == 0 && output_zp == 0) { mlir_op = op_builder->create( - loc, output_type, input_val, kernel, stride, pad); + loc, output_type, input_val, kernel, stride, pad, acc_attr); } else { auto quant = op_builder->getAttr( input_zp, output_zp); mlir_op = op_builder->create( - loc, output_type, input_val, kernel, stride, pad, quant); + loc, output_type, input_val, kernel, stride, pad, acc_attr, quant); } block->push_back(mlir_op); return std::vector({mlir_op->getResult(0)}); diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 66d0a31..33d87d0 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -107,10 +107,12 @@ static DType Type2AccumDType(mlir::Type element_type) { } static DType Type2PoolAccumDType(mlir::Type element_type) { - if (element_type.isF64() || element_type.isF32() || element_type.isF16() || - element_type.isBF16()) { + // def Tosa_AccType : AnyTypeOf<[I<32>, SI<32>, F16, F32]>; + if (element_type.isF32()) { return DType_FP32; - } else if (element_type.isInteger(8) || element_type.isInteger(16)) { + } else if (element_type.isF16()) { + return DType_FP16; + } else if (element_type.isInteger(32) || element_type.isSignedInteger(32)) { return DType_INT32; } return DType_UNKNOWN; @@ -289,6 +291,13 @@ TosaSerializationOperatorBuilder::BuildPoolOpFromMlirOp(mlir::Operation &op, auto kernel = getDenseI64ArrayAttr(op.getAttr("kernel")); ASSERT_VECTOR_LENGTH(kernel, 2); + DType accum_dtype = DType_FP32; + // AvgPool has accum_dtype, MaxPool does not + if (op.hasAttr("acc_type")) { + auto acc_type = op.getAttr("acc_type").cast().getValue(); + accum_dtype = Type2PoolAccumDType(acc_type); + } + std::string input_name = GetTensorName(op.getOperand(0)); std::string output_name = GetTensorName(op.getResult(0)); @@ -300,8 +309,8 @@ TosaSerializationOperatorBuilder::BuildPoolOpFromMlirOp(mlir::Operation &op, mlir::RankedTensorType tensor = op.getOperand(0).getType().cast(); - DType type = Type2PoolAccumDType(tensor.getElementType()); - TosaPoolAttribute attribute(pad, kernel, stride, input_zp, output_zp, type); + TosaPoolAttribute attribute(pad, kernel, stride, input_zp, output_zp, + accum_dtype); TosaSerializationOperator *tyop = new TosaSerializationOperator(opcode, Attribute_PoolAttribute, &attribute, -- cgit v1.2.1