aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-07-03 22:46:25 +0000
committerTai Ly <tai.ly@arm.com>2023-07-18 23:02:10 +0000
commitf65ce51e3344313b744429c3763d1c85bf77a857 (patch)
treecc07319a88e04e6a5ea2e933fb436448a9276966
parent1acb3672107eeb94c7c23d13c24df3d7671dbcc6 (diff)
downloadtosa_mlir_translator-f65ce51e3344313b744429c3763d1c85bf77a857.tar.gz
Add deserialization of AccumType
Added deserialize accum_type attribute for AvgPool2D Op also fixed serialization of accum_type attribute for AvgPool2D Op also updated third_party/serialization_lib hash LLVM_REFSPEC: refs/changes/23/532123/2 TF_REFSPEC: refs/changes/34/699334/5 Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I2084f33e60d1bf8f76958b320a96fc1f3a94d95c
-rw-r--r--src/TosaDeserialize.cpp22
-rw-r--r--src/TosaSerialize.cpp19
m---------third_party/serialization_lib0
3 files changed, 34 insertions, 7 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index 153f16f..45e1f18 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -175,6 +175,22 @@ const std::string ResizeEnum2Str(const tosa::ResizeMode &mode) {
return "";
}
+// this is a counter part to Type2PoolAccumDType
+mlir::TypeAttr AccumDType2TypeAttr(mlir::OpBuilder *op_builder, DType dtype) {
+ // def Tosa_AccType : AnyTypeOf<[I<32>, SI<32>, F16, F32]>;
+ if (dtype == DType_INT32) {
+ return mlir::TypeAttr::get(op_builder->getI32Type());
+ } else if (dtype == DType_FP32) {
+ return mlir::TypeAttr::get(op_builder->getF32Type());
+ } else if (dtype == DType_FP16) {
+ return mlir::TypeAttr::get(op_builder->getF16Type());
+ } else {
+ // unknown accum type
+ // for now, default to F32
+ return mlir::TypeAttr::get(op_builder->getF32Type());
+ }
+}
+
} // namespace
class TosaMlirRegionBuilder;
@@ -290,17 +306,19 @@ std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_AVG_POOL2D>(
mlir::DenseI64ArrayAttr stride =
BuildDenseI64ArrayAttr(op_builder, attr->stride());
mlir::DenseI64ArrayAttr pad = BuildDenseI64ArrayAttr(op_builder, attr->pad());
+ auto acc_attr = AccumDType2TypeAttr(op_builder, attr->accum_dtype());
+
int32_t input_zp = attr->input_zp();
int32_t output_zp = attr->output_zp();
mlir::Operation *mlir_op;
if (input_zp == 0 && output_zp == 0) {
mlir_op = op_builder->create<mlir::tosa::AvgPool2dOp>(
- loc, output_type, input_val, kernel, stride, pad);
+ loc, output_type, input_val, kernel, stride, pad, acc_attr);
} else {
auto quant = op_builder->getAttr<mlir::tosa::UnaryOpQuantizationAttr>(
input_zp, output_zp);
mlir_op = op_builder->create<mlir::tosa::AvgPool2dOp>(
- loc, output_type, input_val, kernel, stride, pad, quant);
+ loc, output_type, input_val, kernel, stride, pad, acc_attr, quant);
}
block->push_back(mlir_op);
return std::vector<mlir::Value>({mlir_op->getResult(0)});
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index 66d0a31..33d87d0 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -107,10 +107,12 @@ static DType Type2AccumDType(mlir::Type element_type) {
}
static DType Type2PoolAccumDType(mlir::Type element_type) {
- if (element_type.isF64() || element_type.isF32() || element_type.isF16() ||
- element_type.isBF16()) {
+ // def Tosa_AccType : AnyTypeOf<[I<32>, SI<32>, F16, F32]>;
+ if (element_type.isF32()) {
return DType_FP32;
- } else if (element_type.isInteger(8) || element_type.isInteger(16)) {
+ } else if (element_type.isF16()) {
+ return DType_FP16;
+ } else if (element_type.isInteger(32) || element_type.isSignedInteger(32)) {
return DType_INT32;
}
return DType_UNKNOWN;
@@ -289,6 +291,13 @@ TosaSerializationOperatorBuilder::BuildPoolOpFromMlirOp(mlir::Operation &op,
auto kernel = getDenseI64ArrayAttr<int>(op.getAttr("kernel"));
ASSERT_VECTOR_LENGTH(kernel, 2);
+ DType accum_dtype = DType_FP32;
+ // AvgPool has accum_dtype, MaxPool does not
+ if (op.hasAttr("acc_type")) {
+ auto acc_type = op.getAttr("acc_type").cast<mlir::TypeAttr>().getValue();
+ accum_dtype = Type2PoolAccumDType(acc_type);
+ }
+
std::string input_name = GetTensorName(op.getOperand(0));
std::string output_name = GetTensorName(op.getResult(0));
@@ -300,8 +309,8 @@ TosaSerializationOperatorBuilder::BuildPoolOpFromMlirOp(mlir::Operation &op,
mlir::RankedTensorType tensor =
op.getOperand(0).getType().cast<mlir::RankedTensorType>();
- DType type = Type2PoolAccumDType(tensor.getElementType());
- TosaPoolAttribute attribute(pad, kernel, stride, input_zp, output_zp, type);
+ TosaPoolAttribute attribute(pad, kernel, stride, input_zp, output_zp,
+ accum_dtype);
TosaSerializationOperator *tyop =
new TosaSerializationOperator(opcode, Attribute_PoolAttribute, &attribute,
diff --git a/third_party/serialization_lib b/third_party/serialization_lib
-Subproject 8a2704330c18532d79a65ad2733458a80bf9c5b
+Subproject 89963aa8fad822ab7a6e1ff92f6b7b4ee0b9350