From fa591272c05d2a24412b4b5b4398ded17be0912e Mon Sep 17 00:00:00 2001 From: TatWai Chong Date: Thu, 29 Feb 2024 00:17:29 -0800 Subject: Change the shift operand of mul to be available to all data types. Change-Id: I436912af95e4aef1b67b140079070168d158ff49 Signed-off-by: TatWai Chong --- src/TosaDeserialize.cpp | 15 ++++----------- src/TosaSerialize.cpp | 19 ++++--------------- 2 files changed, 8 insertions(+), 26 deletions(-) diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index bd9fc9d..301d0da 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -1229,23 +1229,16 @@ std::vector TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const { mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]); mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]); + mlir::Value shift_val = tensor_map->at(op->GetInputTensorNames()[2]); + mlir::RankedTensorType output_type = tensor_type_map->at(op->GetOutputTensorNames()[0]); assert(op->GetAttributeType() == Attribute_MulAttribute); // double check attribute type - mlir::ValueRange operands; - if (output_type.getElementType().isInteger(32)) { - // Integer multiply carries shift argument. - mlir::Value shift_val = tensor_map->at(op->GetInputTensorNames()[2]); - operands = {input0_val, input1_val, shift_val}; - } else { - operands = {input0_val, input1_val}; - } - - mlir::Operation *mlir_op = - op_builder->create(loc, output_type, operands); + mlir::Operation *mlir_op = op_builder->create( + loc, output_type, input0_val, input1_val, shift_val); block->push_back(mlir_op); return std::vector({mlir_op->getResult(0)}); } diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 5b0d2bd..0d5e044 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -1286,22 +1286,11 @@ TosaSerializationOperatorBuilder::build( std::string input0_name = GetTensorName(mul_op.getInput1()); std::string input1_name = GetTensorName(mul_op.getInput2()); std::string output_name = GetTensorName(mul_op.getOutput()); + std::string shift_name = GetTensorName(mul_op.getShift()); - std::vector operands; - if (mul_op.getOutput() - .getType() - .cast() - .getElementType() - .isInteger(32)) { - std::string shift_name = GetTensorName(mul_op.getShift()); - operands = {input0_name, input1_name, shift_name}; - } else { - operands = {input0_name, input1_name}; - } - - TosaSerializationOperator *tyop = - new TosaSerializationOperator(Op_MUL, Attribute_NONE, nullptr, operands, - std::vector{output_name}); + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_MUL, Attribute_NONE, nullptr, {input0_name, input1_name, shift_name}, + std::vector{output_name}); return tyop; } -- cgit v1.2.1