aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/TosaDeserialize.cpp12
-rw-r--r--src/TosaSerialize.cpp16
2 files changed, 25 insertions, 3 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index 335a997..a4b7eda 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -137,6 +137,16 @@ BuildDenseI32ElementsAttr(mlir::OpBuilder *op_builder,
}
template <class T>
+mlir::DenseI8ArrayAttr BuildDenseI8ArrayAttr(mlir::OpBuilder *op_builder,
+ const std::vector<T> &values) {
+ std::vector<int8_t> vec;
+ for (auto val : values) {
+ vec.push_back(val);
+ }
+ return op_builder->getDenseI8ArrayAttr(vec);
+}
+
+template <class T>
mlir::DenseI32ArrayAttr BuildDenseI32ArrayAttr(mlir::OpBuilder *op_builder,
const std::vector<T> &values) {
std::vector<int32_t> vec;
@@ -1051,7 +1061,7 @@ std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_RESCALE>(
auto output_zp = op_builder->getI32IntegerAttr(attr->output_zp());
auto multiplier = BuildDenseI32ArrayAttr(op_builder, attr->multiplier());
- auto shift = BuildDenseI32ArrayAttr(op_builder, attr->shift());
+ auto shift = BuildDenseI8ArrayAttr(op_builder, attr->shift());
auto scale32 = op_builder->getBoolAttr(attr->scale32());
auto double_round = op_builder->getBoolAttr(attr->double_round());
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index 73c84e8..a3e21f9 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -244,6 +244,19 @@ static std::vector<T> getDenseI64ArrayAttr(mlir::Attribute attr) {
return vec;
}
+// Unpack 8-bit integer attribute element and pack into a std vector.
+template <class T>
+static std::vector<T> getDenseI8ArrayAttr(mlir::Attribute attr) {
+ auto array_ref = attr.cast<mlir::DenseI8ArrayAttr>().asArrayRef();
+
+ std::vector<T> vec;
+ for (auto val : array_ref) {
+ vec.push_back(val);
+ }
+
+ return vec;
+}
+
// Main template to catch unimplemented translation.
template <typename T>
TosaSerializationOperator *
@@ -1179,8 +1192,7 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::RescaleOp>(
auto multiplier =
op.getAttr("multiplier").dyn_cast<mlir::DenseI32ArrayAttr>().asArrayRef();
- auto shift =
- op.getAttr("shift").dyn_cast<mlir::DenseI32ArrayAttr>().asArrayRef();
+ auto shift = getDenseI8ArrayAttr<int32_t>(op.getAttr("shift"));
std::string input_name = GetTensorName(op.getOperand(0));
std::string output_name = GetTensorName(op.getResult(0));