From 9a57b9fe6f9832fa0406daac367fd3fc09afa018 Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Fri, 30 Jun 2023 23:48:34 +0000 Subject: [tosa_mlir_translator] Fix Rescale shift data type Changed to support new data type of Rescale shift attr: DenseI8ArrayAttr (instead of DenseI32ArrayAttr) LLVM_REFSPEC: refs/changes/55/532955/1 TF_REFSPEC: refs/changes/50/700450/3 Signed-off-by: Tai Ly Change-Id: I8f176ab95e167a8c4a0d3da605384509cf083d5e --- src/TosaDeserialize.cpp | 12 +++++++++++- src/TosaSerialize.cpp | 16 ++++++++++++++-- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index 335a997..a4b7eda 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -136,6 +136,16 @@ BuildDenseI32ElementsAttr(mlir::OpBuilder *op_builder, return mlir::DenseElementsAttr::get(type, llvm::ArrayRef(vec)); } +template +mlir::DenseI8ArrayAttr BuildDenseI8ArrayAttr(mlir::OpBuilder *op_builder, + const std::vector &values) { + std::vector vec; + for (auto val : values) { + vec.push_back(val); + } + return op_builder->getDenseI8ArrayAttr(vec); +} + template mlir::DenseI32ArrayAttr BuildDenseI32ArrayAttr(mlir::OpBuilder *op_builder, const std::vector &values) { @@ -1051,7 +1061,7 @@ std::vector TosaMlirOperatorBuilder::build( auto output_zp = op_builder->getI32IntegerAttr(attr->output_zp()); auto multiplier = BuildDenseI32ArrayAttr(op_builder, attr->multiplier()); - auto shift = BuildDenseI32ArrayAttr(op_builder, attr->shift()); + auto shift = BuildDenseI8ArrayAttr(op_builder, attr->shift()); auto scale32 = op_builder->getBoolAttr(attr->scale32()); auto double_round = op_builder->getBoolAttr(attr->double_round()); diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 73c84e8..a3e21f9 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -244,6 +244,19 @@ static std::vector getDenseI64ArrayAttr(mlir::Attribute attr) { return vec; } +// Unpack 8-bit integer attribute element and pack into a std vector. +template +static std::vector getDenseI8ArrayAttr(mlir::Attribute attr) { + auto array_ref = attr.cast().asArrayRef(); + + std::vector vec; + for (auto val : array_ref) { + vec.push_back(val); + } + + return vec; +} + // Main template to catch unimplemented translation. template TosaSerializationOperator * @@ -1179,8 +1192,7 @@ TosaSerializationOperatorBuilder::build( auto multiplier = op.getAttr("multiplier").dyn_cast().asArrayRef(); - auto shift = - op.getAttr("shift").dyn_cast().asArrayRef(); + auto shift = getDenseI8ArrayAttr(op.getAttr("shift")); std::string input_name = GetTensorName(op.getOperand(0)); std::string output_name = GetTensorName(op.getResult(0)); -- cgit v1.2.1