aboutsummaryrefslogtreecommitdiff
path: root/src/TosaSerialize.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/TosaSerialize.cpp')
-rw-r--r--src/TosaSerialize.cpp16
1 files changed, 14 insertions, 2 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index 73c84e8..a3e21f9 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -244,6 +244,19 @@ static std::vector<T> getDenseI64ArrayAttr(mlir::Attribute attr) {
return vec;
}
+// Unpack 8-bit integer attribute element and pack into a std vector.
+template <class T>
+static std::vector<T> getDenseI8ArrayAttr(mlir::Attribute attr) {
+ auto array_ref = attr.cast<mlir::DenseI8ArrayAttr>().asArrayRef();
+
+ std::vector<T> vec;
+ for (auto val : array_ref) {
+ vec.push_back(val);
+ }
+
+ return vec;
+}
+
// Main template to catch unimplemented translation.
template <typename T>
TosaSerializationOperator *
@@ -1179,8 +1192,7 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::RescaleOp>(
auto multiplier =
op.getAttr("multiplier").dyn_cast<mlir::DenseI32ArrayAttr>().asArrayRef();
- auto shift =
- op.getAttr("shift").dyn_cast<mlir::DenseI32ArrayAttr>().asArrayRef();
+ auto shift = getDenseI8ArrayAttr<int32_t>(op.getAttr("shift"));
std::string input_name = GetTensorName(op.getOperand(0));
std::string output_name = GetTensorName(op.getResult(0));