aboutsummaryrefslogtreecommitdiff
path: root/src/TosaSerialize.cpp
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2024-02-14 21:12:10 +0000
committerTai Ly <tai.ly@arm.com>2024-02-27 19:12:35 +0000
commit1a7ca55663ef19c95ba9d6bc2a2789414762273d (patch)
tree401c9e16f87063d8a30cf04f73b6990e5ed00663 /src/TosaSerialize.cpp
parent150bc9bcdea84f8c24a17d5f2fcb38128afe50ab (diff)
downloadtosa_mlir_translator-1a7ca55663ef19c95ba9d6bc2a2789414762273d.tar.gz
[tosa_mlir_translator] tosa.fb name changes
This patch implements attribute name changes in tosa serialization library to align with tosa spec. Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I575e4601f12aee67ec6e3114c68a428b8db9d232
Diffstat (limited to 'src/TosaSerialize.cpp')
-rw-r--r--src/TosaSerialize.cpp14
1 files changed, 7 insertions, 7 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index 9a6cf84..5b0d2bd 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -105,13 +105,13 @@ static DType Type2DType(mlir::Type element_type) {
return DType_UNKNOWN;
}
-static DType Type2PoolAccumDType(mlir::Type element_type) {
- // def Tosa_AccType : AnyTypeOf<[I<32>, SI<32>, F16, F32]>;
+static DType Type2PoolAccDType(mlir::Type element_type) {
+ // def Tosa_AccType : AnyTypeOf<[I<32>, F16, F32]>;
if (element_type.isF32()) {
return DType_FP32;
} else if (element_type.isF16()) {
return DType_FP16;
- } else if (element_type.isInteger(32) || element_type.isSignedInteger(32)) {
+ } else if (element_type.isInteger(32)) {
return DType_INT32;
}
return DType_UNKNOWN;
@@ -461,11 +461,11 @@ TosaSerializationOperatorBuilder::BuildPoolOpFromMlirOp(mlir::Operation &op,
auto kernel = getDenseI64ArrayAttr<int>(op.getAttr("kernel"));
ASSERT_VECTOR_LENGTH(kernel, 2);
- DType accum_dtype = DType_FP32;
- // AvgPool has accum_dtype, MaxPool does not
+ DType acc_dtype = DType_FP32;
+ // AvgPool has acc_type, MaxPool does not
if (op.hasAttr("acc_type")) {
auto acc_type = op.getAttr("acc_type").cast<mlir::TypeAttr>().getValue();
- accum_dtype = Type2PoolAccumDType(acc_type);
+ acc_dtype = Type2PoolAccDType(acc_type);
}
std::string input_name = GetTensorName(op.getOperand(0));
@@ -482,7 +482,7 @@ TosaSerializationOperatorBuilder::BuildPoolOpFromMlirOp(mlir::Operation &op,
: 0;
TosaPoolAttribute attribute(pad, kernel, stride, input_zp, output_zp,
- accum_dtype);
+ acc_dtype);
TosaSerializationOperator *tyop =
new TosaSerializationOperator(opcode, Attribute_PoolAttribute, &attribute,