aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLuke Hutton <luke.hutton@arm.com>2023-01-12 11:30:33 +0000
committerLuke Hutton <luke.hutton@arm.com>2023-04-22 12:56:16 +0100
commit67e9fc539dd014745f8e2559b967489b8479a8f8 (patch)
tree6ca5e5eb6552f606dfc32e8a3cd8a87a7b9b4ad8
parente37ee8680fdc4b633888b31c2f7e76b65dc2c479 (diff)
downloadtosa_mlir_translator-67e9fc539dd014745f8e2559b967489b8479a8f8.tar.gz
Support translation of FFT2d
Signed-off-by: Luke Hutton <luke.hutton@arm.com> Change-Id: I4aae94438380d394b9c13015aa69ac52f9b73f74
-rw-r--r--include/operator.def1
-rw-r--r--src/TosaDeserialize.cpp36
-rw-r--r--src/TosaSerialize.cpp22
3 files changed, 52 insertions, 7 deletions
diff --git a/include/operator.def b/include/operator.def
index 51eaf82..3c766c4 100644
--- a/include/operator.def
+++ b/include/operator.def
@@ -27,6 +27,7 @@ DEF_OPERATOR(AvgPool2d)
DEF_OPERATOR(Conv2D)
DEF_OPERATOR(Conv3D)
DEF_OPERATOR(DepthwiseConv2D)
+DEF_OPERATOR(FFT2d)
DEF_OPERATOR(FullyConnected)
DEF_OPERATOR(MatMul)
DEF_OPERATOR(MaxPool2d)
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index f6f78fc..d30a6f2 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -1057,6 +1057,28 @@ TosaMlirOperatorBuilder::build<Op_RFFT2D>(TosaSerializationOperator *op) const {
{mlir_op->getResult(0), mlir_op->getResult(1)});
}
+template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_FFT2D>(TosaSerializationOperator *op) const {
+ mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::RankedTensorType output0_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+ mlir::RankedTensorType output1_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[1]);
+
+ assert(op->GetAttributeType() == Attribute_FFTAttribute);
+ TosaFFTAttribute *attr =
+ static_cast<TosaFFTAttribute *>(op->GetAttribute());
+ auto inverse = op_builder->getBoolAttr(attr->inverse());
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::FFT2dOp>(
+ loc, output0_type, output1_type, input0_val, input1_val, inverse);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>(
+ {mlir_op->getResult(0), mlir_op->getResult(1)});
+}
+
class TosaMlirRegionBuilder {
public:
TosaMlirRegionBuilder(TosaSerializationRegion* _ser_region,
@@ -1125,7 +1147,7 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock(
tensor_type_map[ts_name] = type;
tensor_built[ts_name] = false;
}
-
+
for (auto op : ser_block->GetOperators()) {
operator_built[op] = false;
for (auto ts_name : op->GetInputTensorNames()) {
@@ -1200,7 +1222,7 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock(
llvm::errs() << "ERROR: unsupported opcode=" << EnumNamesOp()[op->GetOp()] << "\n";
return mlir::failure();
}
-
+
// Sanity check if number of built mlir::Value is expected
if (op->GetOutputTensorNames().size() != output_values.size()) {
llvm::errs() << "ERROR: number of built mlir::Value is not matching number of operator output tensor\n";
@@ -1262,7 +1284,7 @@ mlir::LogicalResult TosaMlirRegionBuilder::BuildAllBlocksInRegion(
mlir::LogicalResult buildTosaMlir(mlir::func::FuncOp& func,
mlir::MLIRContext& context,
tosa::TosaSerializationHandler& tsh) {
-
+
mlir::Region* main_region = func.getCallableRegion();
if (!main_region) {
llvm::errs() << "Invalid MLIR: doesn't have valid \"main\" region\n";
@@ -1274,7 +1296,7 @@ mlir::LogicalResult buildTosaMlir(mlir::func::FuncOp& func,
"must contain exactly one region\n";
return mlir::failure();
}
-
+
TosaSerializationRegion* ser_main_region = tsh.GetRegions().front();
auto loc = func.getLoc();
@@ -1328,7 +1350,7 @@ public:
auto function = getOperation();
auto& context = getContext();
-
+
if (buildTosaMlir(function, context, tsh).failed()) {
llvm::errs() << "Failed to deserialize flatbuffer " << tosa_deserialize_filename << "\n";
return signalPassFailure();
@@ -1341,7 +1363,7 @@ class TosaDeserializeJSON
public:
void runOnOperation() final {
TosaSerializationHandler tsh;
-
+
// must load tosa schema before loading json file
if (loadTosaSchema(tsh).failed()) {
return signalPassFailure();
@@ -1354,7 +1376,7 @@ public:
auto function = getOperation();
auto& context = getContext();
-
+
if (buildTosaMlir(function, context, tsh).failed()) {
llvm::errs() << "Failed to deserialize flatbuffer " << tosa_deserialize_filename << "\n";
return signalPassFailure();
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index e69fcba..8a95b68 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -1367,6 +1367,28 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::RFFT2dOp>(
return tyop;
}
+template<>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::FFT2dOp>(
+ mlir::Operation &op) const {
+
+ bool inverse = op.getAttr("inverse").dyn_cast<mlir::BoolAttr>().getValue();
+
+ std::string input_real_name = GetTensorName(op.getOperand(0));
+ std::string input_imag_name = GetTensorName(op.getOperand(1));
+ std::string output_real_name = GetTensorName(op.getResult(0));
+ std::string output_imag_name = GetTensorName(op.getResult(1));
+
+ TosaFFTAttribute attribute(inverse);
+
+ TosaSerializationOperator *tyop = new TosaSerializationOperator(
+ Op_FFT2D, Attribute_FFTAttribute, &attribute,
+ std::vector<std::string>{input_real_name, input_imag_name},
+ std::vector<std::string>{output_real_name, output_imag_name});
+
+ return tyop;
+}
+
/* End translating TOSA operator */
mlir::LogicalResult TosaSerializationRegionBuilder::BuildAllBlocksInRegion(std::vector<mlir::Value>& return_values) {
std::string region_name = ser_region->GetName();