aboutsummaryrefslogtreecommitdiff
path: root/src/TosaDeserialize.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/TosaDeserialize.cpp')
-rw-r--r--src/TosaDeserialize.cpp14
1 files changed, 10 insertions, 4 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index 87c363f..5704b04 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -1234,12 +1234,18 @@ TosaMlirOperatorBuilder::build<Op_MUL>(TosaSerializationOperator *op) const {
assert(op->GetAttributeType() ==
Attribute_MulAttribute); // double check attribute type
- TosaMulAttribute *attr = static_cast<TosaMulAttribute *>(op->GetAttribute());
- auto shift = op_builder->getI8IntegerAttr(attr->shift());
+ mlir::ValueRange operands;
+ if (output_type.getElementType().isInteger(32)) {
+ // Integer multiply carries shift argument.
+ mlir::Value shift_val = tensor_map->at(op->GetInputTensorNames()[2]);
+ operands = {input0_val, input1_val, shift_val};
+ } else {
+ operands = {input0_val, input1_val};
+ }
- mlir::Operation *mlir_op = op_builder->create<mlir::tosa::MulOp>(
- loc, output_type, input0_val, input1_val, shift);
+ mlir::Operation *mlir_op =
+ op_builder->create<mlir::tosa::MulOp>(loc, output_type, operands);
block->push_back(mlir_op);
return std::vector<mlir::Value>({mlir_op->getResult(0)});
}