diff options
Diffstat (limited to 'src/TosaDeserialize.cpp')
-rw-r--r-- | src/TosaDeserialize.cpp | 14 |
1 files changed, 10 insertions, 4 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index 87c363f..5704b04 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -1234,12 +1234,18 @@ TosaMlirOperatorBuilder::build<Op_MUL>(TosaSerializationOperator *op) const { assert(op->GetAttributeType() == Attribute_MulAttribute); // double check attribute type - TosaMulAttribute *attr = static_cast<TosaMulAttribute *>(op->GetAttribute()); - auto shift = op_builder->getI8IntegerAttr(attr->shift()); + mlir::ValueRange operands; + if (output_type.getElementType().isInteger(32)) { + // Integer multiply carries shift argument. + mlir::Value shift_val = tensor_map->at(op->GetInputTensorNames()[2]); + operands = {input0_val, input1_val, shift_val}; + } else { + operands = {input0_val, input1_val}; + } - mlir::Operation *mlir_op = op_builder->create<mlir::tosa::MulOp>( - loc, output_type, input0_val, input1_val, shift); + mlir::Operation *mlir_op = + op_builder->create<mlir::tosa::MulOp>(loc, output_type, operands); block->push_back(mlir_op); return std::vector<mlir::Value>({mlir_op->getResult(0)}); } |