aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2024-02-14 21:12:10 +0000
committerTai Ly <tai.ly@arm.com>2024-02-27 19:12:35 +0000
commit1a7ca55663ef19c95ba9d6bc2a2789414762273d (patch)
tree401c9e16f87063d8a30cf04f73b6990e5ed00663
parent150bc9bcdea84f8c24a17d5f2fcb38128afe50ab (diff)
downloadtosa_mlir_translator-1a7ca55663ef19c95ba9d6bc2a2789414762273d.tar.gz
[tosa_mlir_translator] tosa.fb name changes
This patch implements attribute name changes in tosa serialization library to align with tosa spec. Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I575e4601f12aee67ec6e3114c68a428b8db9d232
-rw-r--r--src/TosaDeserialize.cpp18
-rw-r--r--src/TosaSerialize.cpp14
m---------third_party/serialization_lib0
3 files changed, 16 insertions, 16 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index 1f24f71..82c107e 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -356,9 +356,9 @@ const std::string ResizeEnum2Str(const tosa::ResizeMode &mode) {
return "";
}
-// this is a counter part to Type2PoolAccumDType
-mlir::TypeAttr AccumDType2TypeAttr(mlir::OpBuilder *op_builder, DType dtype) {
- // def Tosa_AccType : AnyTypeOf<[I<32>, SI<32>, F16, F32]>;
+// this is a counter part to Type2PoolAccDType
+mlir::TypeAttr AccDType2TypeAttr(mlir::OpBuilder *op_builder, DType dtype) {
+ // def Tosa_AccType : AnyTypeOf<[I<32>, F16, F32]>;
if (dtype == DType_INT32) {
return mlir::TypeAttr::get(op_builder->getI32Type());
} else if (dtype == DType_FP32) {
@@ -366,7 +366,7 @@ mlir::TypeAttr AccumDType2TypeAttr(mlir::OpBuilder *op_builder, DType dtype) {
} else if (dtype == DType_FP16) {
return mlir::TypeAttr::get(op_builder->getF16Type());
} else {
- // unknown accum type
+ // unknown acc type
// for now, default to F32
return mlir::TypeAttr::get(op_builder->getF32Type());
}
@@ -504,7 +504,7 @@ std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_AVG_POOL2D>(
mlir::DenseI64ArrayAttr stride =
BuildDenseI64ArrayAttr(op_builder, attr->stride());
mlir::DenseI64ArrayAttr pad = BuildDenseI64ArrayAttr(op_builder, attr->pad());
- auto acc_attr = AccumDType2TypeAttr(op_builder, attr->accum_dtype());
+ auto acc_attr = AccDType2TypeAttr(op_builder, attr->acc_type());
int32_t input_zp = attr->input_zp();
int32_t output_zp = attr->output_zp();
@@ -1526,8 +1526,8 @@ std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_COND_IF>(
Attribute_CondIfAttribute); // double check attribute type
TosaCondIfAttribute *attr =
static_cast<TosaCondIfAttribute *>(op->GetAttribute());
- auto ser_then_region = GetTsh()->GetRegionByName(attr->then_branch());
- auto ser_else_region = GetTsh()->GetRegionByName(attr->else_branch());
+ auto ser_then_region = GetTsh()->GetRegionByName(attr->then_graph());
+ auto ser_else_region = GetTsh()->GetRegionByName(attr->else_graph());
if (!ser_then_region || !ser_else_region) {
llvm::errs() << "ERROR: " << get_string(op)
@@ -1591,8 +1591,8 @@ std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_WHILE_LOOP>(
Attribute_WhileLoopAttribute); // double check attribute type
TosaWhileLoopAttribute *attr =
static_cast<TosaWhileLoopAttribute *>(op->GetAttribute());
- auto ser_cond_region = GetTsh()->GetRegionByName(attr->cond_branch());
- auto ser_body_region = GetTsh()->GetRegionByName(attr->body_branch());
+ auto ser_cond_region = GetTsh()->GetRegionByName(attr->cond_graph());
+ auto ser_body_region = GetTsh()->GetRegionByName(attr->body_graph());
mlir::Operation *mlir_op =
op_builder->create<mlir::tosa::WhileOp>(loc, output_types, input_values);
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index 9a6cf84..5b0d2bd 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -105,13 +105,13 @@ static DType Type2DType(mlir::Type element_type) {
return DType_UNKNOWN;
}
-static DType Type2PoolAccumDType(mlir::Type element_type) {
- // def Tosa_AccType : AnyTypeOf<[I<32>, SI<32>, F16, F32]>;
+static DType Type2PoolAccDType(mlir::Type element_type) {
+ // def Tosa_AccType : AnyTypeOf<[I<32>, F16, F32]>;
if (element_type.isF32()) {
return DType_FP32;
} else if (element_type.isF16()) {
return DType_FP16;
- } else if (element_type.isInteger(32) || element_type.isSignedInteger(32)) {
+ } else if (element_type.isInteger(32)) {
return DType_INT32;
}
return DType_UNKNOWN;
@@ -461,11 +461,11 @@ TosaSerializationOperatorBuilder::BuildPoolOpFromMlirOp(mlir::Operation &op,
auto kernel = getDenseI64ArrayAttr<int>(op.getAttr("kernel"));
ASSERT_VECTOR_LENGTH(kernel, 2);
- DType accum_dtype = DType_FP32;
- // AvgPool has accum_dtype, MaxPool does not
+ DType acc_dtype = DType_FP32;
+ // AvgPool has acc_type, MaxPool does not
if (op.hasAttr("acc_type")) {
auto acc_type = op.getAttr("acc_type").cast<mlir::TypeAttr>().getValue();
- accum_dtype = Type2PoolAccumDType(acc_type);
+ acc_dtype = Type2PoolAccDType(acc_type);
}
std::string input_name = GetTensorName(op.getOperand(0));
@@ -482,7 +482,7 @@ TosaSerializationOperatorBuilder::BuildPoolOpFromMlirOp(mlir::Operation &op,
: 0;
TosaPoolAttribute attribute(pad, kernel, stride, input_zp, output_zp,
- accum_dtype);
+ acc_dtype);
TosaSerializationOperator *tyop =
new TosaSerializationOperator(opcode, Attribute_PoolAttribute, &attribute,
diff --git a/third_party/serialization_lib b/third_party/serialization_lib
-Subproject 61a8313f5a0cbcfc7c8ee8a44f05a5ca9b1015b
+Subproject 81db8ee8f580d30ec0ca53067df32ef046e6f09