aboutsummaryrefslogtreecommitdiff
path: root/src/TosaDeserialize.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/TosaDeserialize.cpp')
-rw-r--r--src/TosaDeserialize.cpp12
1 files changed, 11 insertions, 1 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index 335a997..a4b7eda 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -137,6 +137,16 @@ BuildDenseI32ElementsAttr(mlir::OpBuilder *op_builder,
}
template <class T>
+mlir::DenseI8ArrayAttr BuildDenseI8ArrayAttr(mlir::OpBuilder *op_builder,
+ const std::vector<T> &values) {
+ std::vector<int8_t> vec;
+ for (auto val : values) {
+ vec.push_back(val);
+ }
+ return op_builder->getDenseI8ArrayAttr(vec);
+}
+
+template <class T>
mlir::DenseI32ArrayAttr BuildDenseI32ArrayAttr(mlir::OpBuilder *op_builder,
const std::vector<T> &values) {
std::vector<int32_t> vec;
@@ -1051,7 +1061,7 @@ std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_RESCALE>(
auto output_zp = op_builder->getI32IntegerAttr(attr->output_zp());
auto multiplier = BuildDenseI32ArrayAttr(op_builder, attr->multiplier());
- auto shift = BuildDenseI32ArrayAttr(op_builder, attr->shift());
+ auto shift = BuildDenseI8ArrayAttr(op_builder, attr->shift());
auto scale32 = op_builder->getBoolAttr(attr->scale32());
auto double_round = op_builder->getBoolAttr(attr->double_round());