aboutsummaryrefslogtreecommitdiff
path: root/src/TosaSerialize.cpp
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-06-30 23:48:34 +0000
committerEric Kunze <eric.kunze@arm.com>2023-07-12 15:20:43 +0000
commit9a57b9fe6f9832fa0406daac367fd3fc09afa018 (patch)
treed1c16bd1c2b914bd1e2184d77da2ecba40fb94ee /src/TosaSerialize.cpp
parent0720dfa614ff041d1530e7bb2b7155d1047c71ff (diff)
downloadtosa_mlir_translator-9a57b9fe6f9832fa0406daac367fd3fc09afa018.tar.gz
[tosa_mlir_translator] Fix Rescale shift data type
Changed to support new data type of Rescale shift attr: DenseI8ArrayAttr (instead of DenseI32ArrayAttr) LLVM_REFSPEC: refs/changes/55/532955/1 TF_REFSPEC: refs/changes/50/700450/3 Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I8f176ab95e167a8c4a0d3da605384509cf083d5e
Diffstat (limited to 'src/TosaSerialize.cpp')
-rw-r--r--src/TosaSerialize.cpp16
1 files changed, 14 insertions, 2 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index 73c84e8..a3e21f9 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -244,6 +244,19 @@ static std::vector<T> getDenseI64ArrayAttr(mlir::Attribute attr) {
return vec;
}
+// Unpack 8-bit integer attribute element and pack into a std vector.
+template <class T>
+static std::vector<T> getDenseI8ArrayAttr(mlir::Attribute attr) {
+ auto array_ref = attr.cast<mlir::DenseI8ArrayAttr>().asArrayRef();
+
+ std::vector<T> vec;
+ for (auto val : array_ref) {
+ vec.push_back(val);
+ }
+
+ return vec;
+}
+
// Main template to catch unimplemented translation.
template <typename T>
TosaSerializationOperator *
@@ -1179,8 +1192,7 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::RescaleOp>(
auto multiplier =
op.getAttr("multiplier").dyn_cast<mlir::DenseI32ArrayAttr>().asArrayRef();
- auto shift =
- op.getAttr("shift").dyn_cast<mlir::DenseI32ArrayAttr>().asArrayRef();
+ auto shift = getDenseI8ArrayAttr<int32_t>(op.getAttr("shift"));
std::string input_name = GetTensorName(op.getOperand(0));
std::string output_name = GetTensorName(op.getResult(0));