diff options
author | Tai Ly <tai.ly@arm.com> | 2024-02-14 21:12:10 +0000 |
---|---|---|
committer | Tai Ly <tai.ly@arm.com> | 2024-02-27 19:12:35 +0000 |
commit | 1a7ca55663ef19c95ba9d6bc2a2789414762273d (patch) | |
tree | 401c9e16f87063d8a30cf04f73b6990e5ed00663 /src/TosaSerialize.cpp | |
parent | 150bc9bcdea84f8c24a17d5f2fcb38128afe50ab (diff) | |
download | tosa_mlir_translator-1a7ca55663ef19c95ba9d6bc2a2789414762273d.tar.gz |
[tosa_mlir_translator] tosa.fb name changes
This patch implements attribute name changes in tosa
serialization library to align with tosa spec.
Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: I575e4601f12aee67ec6e3114c68a428b8db9d232
Diffstat (limited to 'src/TosaSerialize.cpp')
-rw-r--r-- | src/TosaSerialize.cpp | 14 |
1 files changed, 7 insertions, 7 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 9a6cf84..5b0d2bd 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -105,13 +105,13 @@ static DType Type2DType(mlir::Type element_type) { return DType_UNKNOWN; } -static DType Type2PoolAccumDType(mlir::Type element_type) { - // def Tosa_AccType : AnyTypeOf<[I<32>, SI<32>, F16, F32]>; +static DType Type2PoolAccDType(mlir::Type element_type) { + // def Tosa_AccType : AnyTypeOf<[I<32>, F16, F32]>; if (element_type.isF32()) { return DType_FP32; } else if (element_type.isF16()) { return DType_FP16; - } else if (element_type.isInteger(32) || element_type.isSignedInteger(32)) { + } else if (element_type.isInteger(32)) { return DType_INT32; } return DType_UNKNOWN; @@ -461,11 +461,11 @@ TosaSerializationOperatorBuilder::BuildPoolOpFromMlirOp(mlir::Operation &op, auto kernel = getDenseI64ArrayAttr<int>(op.getAttr("kernel")); ASSERT_VECTOR_LENGTH(kernel, 2); - DType accum_dtype = DType_FP32; - // AvgPool has accum_dtype, MaxPool does not + DType acc_dtype = DType_FP32; + // AvgPool has acc_type, MaxPool does not if (op.hasAttr("acc_type")) { auto acc_type = op.getAttr("acc_type").cast<mlir::TypeAttr>().getValue(); - accum_dtype = Type2PoolAccumDType(acc_type); + acc_dtype = Type2PoolAccDType(acc_type); } std::string input_name = GetTensorName(op.getOperand(0)); @@ -482,7 +482,7 @@ TosaSerializationOperatorBuilder::BuildPoolOpFromMlirOp(mlir::Operation &op, : 0; TosaPoolAttribute attribute(pad, kernel, stride, input_zp, output_zp, - accum_dtype); + acc_dtype); TosaSerializationOperator *tyop = new TosaSerializationOperator(opcode, Attribute_PoolAttribute, &attribute, |