aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTatWai Chong <tatwai.chong@arm.com>2024-02-29 00:17:29 -0800
committerTatWai Chong <tatwai.chong@arm.com>2024-03-01 14:01:59 -0800
commitfa591272c05d2a24412b4b5b4398ded17be0912e (patch)
tree53a904684d13e00e079b9a9dac57a917b988048d
parent7720f24131f5672a1137cc7b17edf017e66b6ae7 (diff)
downloadtosa_mlir_translator-fa591272c05d2a24412b4b5b4398ded17be0912e.tar.gz
Change the shift operand of mul to be available to all data types.
Change-Id: I436912af95e4aef1b67b140079070168d158ff49 Signed-off-by: TatWai Chong <tatwai.chong@arm.com>
-rw-r--r--src/TosaDeserialize.cpp15
-rw-r--r--src/TosaSerialize.cpp19
2 files changed, 8 insertions, 26 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index bd9fc9d..301d0da 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -1229,23 +1229,16 @@ std::vector<mlir::Value>
TosaMlirOperatorBuilder::build<Op_MUL>(TosaSerializationOperator *op) const {
mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::Value shift_val = tensor_map->at(op->GetInputTensorNames()[2]);
+
mlir::RankedTensorType output_type =
tensor_type_map->at(op->GetOutputTensorNames()[0]);
assert(op->GetAttributeType() ==
Attribute_MulAttribute); // double check attribute type
- mlir::ValueRange operands;
- if (output_type.getElementType().isInteger(32)) {
- // Integer multiply carries shift argument.
- mlir::Value shift_val = tensor_map->at(op->GetInputTensorNames()[2]);
- operands = {input0_val, input1_val, shift_val};
- } else {
- operands = {input0_val, input1_val};
- }
-
- mlir::Operation *mlir_op =
- op_builder->create<mlir::tosa::MulOp>(loc, output_type, operands);
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::MulOp>(
+ loc, output_type, input0_val, input1_val, shift_val);
block->push_back(mlir_op);
return std::vector<mlir::Value>({mlir_op->getResult(0)});
}
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index 5b0d2bd..0d5e044 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -1286,22 +1286,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::MulOp>(
std::string input0_name = GetTensorName(mul_op.getInput1());
std::string input1_name = GetTensorName(mul_op.getInput2());
std::string output_name = GetTensorName(mul_op.getOutput());
+ std::string shift_name = GetTensorName(mul_op.getShift());
- std::vector<std::string> operands;
- if (mul_op.getOutput()
- .getType()
- .cast<mlir::TensorType>()
- .getElementType()
- .isInteger(32)) {
- std::string shift_name = GetTensorName(mul_op.getShift());
- operands = {input0_name, input1_name, shift_name};
- } else {
- operands = {input0_name, input1_name};
- }
-
- TosaSerializationOperator *tyop =
- new TosaSerializationOperator(Op_MUL, Attribute_NONE, nullptr, operands,
- std::vector<std::string>{output_name});
+ TosaSerializationOperator *tyop = new TosaSerializationOperator(
+ Op_MUL, Attribute_NONE, nullptr, {input0_name, input1_name, shift_name},
+ std::vector<std::string>{output_name});
return tyop;
}