aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2024-03-13 19:19:53 +0000
committerTai Ly <tai.ly@arm.com>2024-03-19 15:39:48 -0700
commit5eddcd35c1776784baeeb39e92bad81da826e065 (patch)
tree03f066e45067220ab940d53488efc905e6036dd8
parent909d4d159ee12c6bc8113974d76f46249b6fd7fb (diff)
downloadtosa_mlir_translator-5eddcd35c1776784baeeb39e92bad81da826e065.tar.gz
[tosa_mlir_translator] Add acc_type to conv ops
Add serializing/deserializing acc_type to/from ConvAttribute Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I20780056f467952eb8baf6f5e80d242df6f0b33f
-rw-r--r--src/TosaDeserialize.cpp25
-rw-r--r--src/TosaSerialize.cpp28
m---------third_party/serialization_lib0
3 files changed, 36 insertions, 17 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index b12d22a..6fa691e 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -356,19 +356,21 @@ const std::string ResizeEnum2Str(const tosa::ResizeMode &mode) {
return "";
}
-// this is a counter part to Type2PoolAccDType
-mlir::TypeAttr AccDType2TypeAttr(mlir::OpBuilder *op_builder, DType dtype) {
- // def Tosa_AccType : AnyTypeOf<[I<32>, F16, F32]>;
+// this is a counter part to Type2AccDType
+mlir::Type AccDType2Type(mlir::OpBuilder *op_builder, DType dtype) {
+ // def Tosa_AccType : AnyTypeOf<[I<32>, I<48>, F16, F32]>;
if (dtype == DType_INT32) {
- return mlir::TypeAttr::get(op_builder->getI32Type());
+ return op_builder->getI32Type();
+ } else if (dtype == DType_INT48) {
+ return op_builder->getIntegerType(48);
} else if (dtype == DType_FP32) {
- return mlir::TypeAttr::get(op_builder->getF32Type());
+ return op_builder->getF32Type();
} else if (dtype == DType_FP16) {
- return mlir::TypeAttr::get(op_builder->getF16Type());
+ return op_builder->getF16Type();
} else {
// unknown acc type
// for now, default to F32
- return mlir::TypeAttr::get(op_builder->getF32Type());
+ return op_builder->getF32Type();
}
}
@@ -504,7 +506,8 @@ std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_AVG_POOL2D>(
mlir::DenseI64ArrayAttr stride =
BuildDenseI64ArrayAttr(op_builder, attr->stride());
mlir::DenseI64ArrayAttr pad = BuildDenseI64ArrayAttr(op_builder, attr->pad());
- auto acc_attr = AccDType2TypeAttr(op_builder, attr->acc_type());
+ auto acc_attr =
+ mlir::TypeAttr::get(AccDType2Type(op_builder, attr->acc_type()));
int32_t input_zp = attr->input_zp();
int32_t output_zp = attr->output_zp();
@@ -776,6 +779,7 @@ TosaMlirOperatorBuilder::BuildConvOp(TosaSerializationOperator *op) const {
auto input_zp = attr->input_zp();
auto weight_zp = attr->weight_zp();
bool local_bound = attr->local_bound();
+ auto acc_type = AccDType2Type(op_builder, attr->acc_type());
// input_zp/weight_zp is not allowed for float type
mlir::Operation *mlir_op;
@@ -787,7 +791,7 @@ TosaMlirOperatorBuilder::BuildConvOp(TosaSerializationOperator *op) const {
auto weight_zp_attr = op_builder->getI32IntegerAttr(weight_zp);
mlir_op = op_builder->create<MLIR_OP>(
loc, output_type, input0_val, input1_val, input2_val, pad, stride,
- dilation, input_zp_attr, weight_zp_attr, local_bound);
+ dilation, acc_type, input_zp_attr, weight_zp_attr, local_bound);
block->push_back(mlir_op);
return std::vector<mlir::Value>({mlir_op->getResult(0)});
@@ -827,6 +831,7 @@ std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_TRANSPOSE_CONV2D>(
auto input_zp = attr->input_zp();
auto weight_zp = attr->weight_zp();
bool local_bound = attr->local_bound();
+ auto acc_type = AccDType2Type(op_builder, attr->acc_type());
// input_zp/weight_zp is not allowed for float type
mlir::Operation *mlir_op;
@@ -839,7 +844,7 @@ std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_TRANSPOSE_CONV2D>(
mlir_op = op_builder->create<mlir::tosa::TransposeConv2DOp>(
loc, output_type, input0_val, input1_val, input2_val, out_pad, stride,
- output_shape, input_zp_attr, weight_zp_attr, local_bound);
+ output_shape, acc_type, input_zp_attr, weight_zp_attr, local_bound);
block->push_back(mlir_op);
return std::vector<mlir::Value>({mlir_op->getResult(0)});
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index 55a11fd..875303e 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -105,14 +105,16 @@ static DType Type2DType(mlir::Type element_type) {
return DType_UNKNOWN;
}
-static DType Type2PoolAccDType(mlir::Type element_type) {
- // def Tosa_AccType : AnyTypeOf<[I<32>, F16, F32]>;
+static DType Type2AccDType(mlir::Type element_type) {
+ // def Tosa_AccType : AnyTypeOf<[I<32>, I<48>, F16, F32]>;
if (element_type.isF32()) {
return DType_FP32;
} else if (element_type.isF16()) {
return DType_FP16;
} else if (element_type.isInteger(32)) {
return DType_INT32;
+ } else if (element_type.isInteger(48)) {
+ return DType_INT48;
}
return DType_UNKNOWN;
}
@@ -465,7 +467,7 @@ TosaSerializationOperatorBuilder::BuildPoolOpFromMlirOp(mlir::Operation &op,
// AvgPool has acc_type, MaxPool does not
if (op.hasAttr("acc_type")) {
auto acc_type = op.getAttr("acc_type").cast<mlir::TypeAttr>().getValue();
- acc_dtype = Type2PoolAccDType(acc_type);
+ acc_dtype = Type2AccDType(acc_type);
}
std::string input_name = GetTensorName(op.getOperand(0));
@@ -735,8 +737,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::Conv2DOp>(
? op.getAttr("local_bound").dyn_cast<mlir::BoolAttr>().getValue()
: false;
+ auto acc_type = op.getAttr("acc_type").cast<mlir::TypeAttr>().getValue();
+ auto acc_dtype = Type2AccDType(acc_type);
+
TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp,
- local_bound);
+ local_bound, acc_dtype);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_CONV2D, Attribute_ConvAttribute, &attribute,
@@ -778,8 +783,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::Conv3DOp>(
? op.getAttr("local_bound").dyn_cast<mlir::BoolAttr>().getValue()
: false;
+ auto acc_type = op.getAttr("acc_type").cast<mlir::TypeAttr>().getValue();
+ auto acc_dtype = Type2AccDType(acc_type);
+
TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp,
- local_bound);
+ local_bound, acc_dtype);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_CONV3D, Attribute_ConvAttribute, &attribute,
@@ -821,8 +829,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::DepthwiseConv2DOp>(
? op.getAttr("local_bound").dyn_cast<mlir::BoolAttr>().getValue()
: false;
+ auto acc_type = op.getAttr("acc_type").cast<mlir::TypeAttr>().getValue();
+ auto acc_dtype = Type2AccDType(acc_type);
+
TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp,
- local_bound);
+ local_bound, acc_dtype);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_DEPTHWISE_CONV2D, Attribute_ConvAttribute, &attribute,
@@ -867,8 +878,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::TransposeConv2DOp>(
? op.getAttr("local_bound").dyn_cast<mlir::BoolAttr>().getValue()
: false;
+ auto acc_type = op.getAttr("acc_type").cast<mlir::TypeAttr>().getValue();
+ auto acc_dtype = Type2AccDType(acc_type);
+
TosaTransposeConvAttribute attribute(out_pad, stride, out_shape, input_zp,
- weight_zp, local_bound);
+ weight_zp, local_bound, acc_dtype);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_TRANSPOSE_CONV2D, Attribute_TransposeConvAttribute, &attribute,
diff --git a/third_party/serialization_lib b/third_party/serialization_lib
-Subproject 0b6d7c271af1e6593e6a2cf14b32acea765f4b6
+Subproject ad78daaf0fa1e41742cbed314459c3dbbb483c2