diff options
Diffstat (limited to 'src/TosaSerialize.cpp')
-rw-r--r-- | src/TosaSerialize.cpp | 16 |
1 files changed, 14 insertions, 2 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 73c84e8..a3e21f9 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -244,6 +244,19 @@ static std::vector<T> getDenseI64ArrayAttr(mlir::Attribute attr) { return vec; } +// Unpack 8-bit integer attribute element and pack into a std vector. +template <class T> +static std::vector<T> getDenseI8ArrayAttr(mlir::Attribute attr) { + auto array_ref = attr.cast<mlir::DenseI8ArrayAttr>().asArrayRef(); + + std::vector<T> vec; + for (auto val : array_ref) { + vec.push_back(val); + } + + return vec; +} + // Main template to catch unimplemented translation. template <typename T> TosaSerializationOperator * @@ -1179,8 +1192,7 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::RescaleOp>( auto multiplier = op.getAttr("multiplier").dyn_cast<mlir::DenseI32ArrayAttr>().asArrayRef(); - auto shift = - op.getAttr("shift").dyn_cast<mlir::DenseI32ArrayAttr>().asArrayRef(); + auto shift = getDenseI8ArrayAttr<int32_t>(op.getAttr("shift")); std::string input_name = GetTensorName(op.getOperand(0)); std::string output_name = GetTensorName(op.getResult(0)); |