From 4bc57740f03704179d2611f5d41572612bc42e9a Mon Sep 17 00:00:00 2001 From: TatWai Chong Date: Thu, 8 Feb 2024 13:54:21 -0800 Subject: Change the shift of mul to the tensor type Note that the shift value only apply to i32 data tensor. Change-Id: I4368146d9e0cdc2243ff822f59489ef5b78148ec Signed-off-by: TatWai Chong --- src/TosaDeserialize.cpp | 14 ++++++++++---- src/TosaSerialize.cpp | 27 ++++++++++++++++++--------- 2 files changed, 28 insertions(+), 13 deletions(-) diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index 87c363f..5704b04 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -1234,12 +1234,18 @@ TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const { assert(op->GetAttributeType() == Attribute_MulAttribute); // double check attribute type - TosaMulAttribute *attr = static_cast(op->GetAttribute()); - auto shift = op_builder->getI8IntegerAttr(attr->shift()); + mlir::ValueRange operands; + if (output_type.getElementType().isInteger(32)) { + // Integer multiply carries shift argument. + mlir::Value shift_val = tensor_map->at(op->GetInputTensorNames()[2]); + operands = {input0_val, input1_val, shift_val}; + } else { + operands = {input0_val, input1_val}; + } - mlir::Operation *mlir_op = op_builder->create( - loc, output_type, input0_val, input1_val, shift); + mlir::Operation *mlir_op = + op_builder->create(loc, output_type, operands); block->push_back(mlir_op); return std::vector({mlir_op->getResult(0)}); } diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index fc6655b..05c7812 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -1272,18 +1272,27 @@ template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build( mlir::Operation &op) const { - std::string input0_name = GetTensorName(op.getOperand(0)); - std::string input1_name = GetTensorName(op.getOperand(1)); - std::string output_name = GetTensorName(op.getResult(0)); - int32_t shift = op.getAttr("shift").dyn_cast().getInt(); + mlir::tosa::MulOp mul_op = mlir::cast(op); + std::string input0_name = GetTensorName(mul_op.getInput1()); + std::string input1_name = GetTensorName(mul_op.getInput2()); + std::string output_name = GetTensorName(mul_op.getOutput()); - TosaMulAttribute attribute(shift); + std::vector operands; + if (mul_op.getOutput() + .getType() + .cast() + .getElementType() + .isInteger(32)) { + std::string shift_name = GetTensorName(mul_op.getShift()); + operands = {input0_name, input1_name, shift_name}; + } else { + operands = {input0_name, input1_name}; + } - TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_MUL, Attribute_MulAttribute, &attribute, - std::vector{input0_name, input1_name}, - std::vector{output_name}); + TosaSerializationOperator *tyop = + new TosaSerializationOperator(Op_MUL, Attribute_NONE, nullptr, operands, + std::vector{output_name}); return tyop; } -- cgit v1.2.1