From 67e9fc539dd014745f8e2559b967489b8479a8f8 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Thu, 12 Jan 2023 11:30:33 +0000 Subject: Support translation of FFT2d Signed-off-by: Luke Hutton Change-Id: I4aae94438380d394b9c13015aa69ac52f9b73f74 --- include/operator.def | 1 + src/TosaDeserialize.cpp | 36 +++++++++++++++++++++++++++++------- src/TosaSerialize.cpp | 22 ++++++++++++++++++++++ 3 files changed, 52 insertions(+), 7 deletions(-) diff --git a/include/operator.def b/include/operator.def index 51eaf82..3c766c4 100644 --- a/include/operator.def +++ b/include/operator.def @@ -27,6 +27,7 @@ DEF_OPERATOR(AvgPool2d) DEF_OPERATOR(Conv2D) DEF_OPERATOR(Conv3D) DEF_OPERATOR(DepthwiseConv2D) +DEF_OPERATOR(FFT2d) DEF_OPERATOR(FullyConnected) DEF_OPERATOR(MatMul) DEF_OPERATOR(MaxPool2d) diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index f6f78fc..d30a6f2 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -1057,6 +1057,28 @@ TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const { {mlir_op->getResult(0), mlir_op->getResult(1)}); } +template <> +std::vector +TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const { + mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]); + mlir::RankedTensorType output0_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + mlir::RankedTensorType output1_type = + tensor_type_map->at(op->GetOutputTensorNames()[1]); + + assert(op->GetAttributeType() == Attribute_FFTAttribute); + TosaFFTAttribute *attr = + static_cast(op->GetAttribute()); + auto inverse = op_builder->getBoolAttr(attr->inverse()); + + mlir::Operation *mlir_op = op_builder->create( + loc, output0_type, output1_type, input0_val, input1_val, inverse); + block->push_back(mlir_op); + return std::vector( + {mlir_op->getResult(0), mlir_op->getResult(1)}); +} + class TosaMlirRegionBuilder { public: TosaMlirRegionBuilder(TosaSerializationRegion* _ser_region, @@ -1125,7 +1147,7 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock( tensor_type_map[ts_name] = type; tensor_built[ts_name] = false; } - + for (auto op : ser_block->GetOperators()) { operator_built[op] = false; for (auto ts_name : op->GetInputTensorNames()) { @@ -1200,7 +1222,7 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock( llvm::errs() << "ERROR: unsupported opcode=" << EnumNamesOp()[op->GetOp()] << "\n"; return mlir::failure(); } - + // Sanity check if number of built mlir::Value is expected if (op->GetOutputTensorNames().size() != output_values.size()) { llvm::errs() << "ERROR: number of built mlir::Value is not matching number of operator output tensor\n"; @@ -1262,7 +1284,7 @@ mlir::LogicalResult TosaMlirRegionBuilder::BuildAllBlocksInRegion( mlir::LogicalResult buildTosaMlir(mlir::func::FuncOp& func, mlir::MLIRContext& context, tosa::TosaSerializationHandler& tsh) { - + mlir::Region* main_region = func.getCallableRegion(); if (!main_region) { llvm::errs() << "Invalid MLIR: doesn't have valid \"main\" region\n"; @@ -1274,7 +1296,7 @@ mlir::LogicalResult buildTosaMlir(mlir::func::FuncOp& func, "must contain exactly one region\n"; return mlir::failure(); } - + TosaSerializationRegion* ser_main_region = tsh.GetRegions().front(); auto loc = func.getLoc(); @@ -1328,7 +1350,7 @@ public: auto function = getOperation(); auto& context = getContext(); - + if (buildTosaMlir(function, context, tsh).failed()) { llvm::errs() << "Failed to deserialize flatbuffer " << tosa_deserialize_filename << "\n"; return signalPassFailure(); @@ -1341,7 +1363,7 @@ class TosaDeserializeJSON public: void runOnOperation() final { TosaSerializationHandler tsh; - + // must load tosa schema before loading json file if (loadTosaSchema(tsh).failed()) { return signalPassFailure(); @@ -1354,7 +1376,7 @@ public: auto function = getOperation(); auto& context = getContext(); - + if (buildTosaMlir(function, context, tsh).failed()) { llvm::errs() << "Failed to deserialize flatbuffer " << tosa_deserialize_filename << "\n"; return signalPassFailure(); diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index e69fcba..8a95b68 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -1367,6 +1367,28 @@ TosaSerializationOperatorBuilder::build( return tyop; } +template<> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + + bool inverse = op.getAttr("inverse").dyn_cast().getValue(); + + std::string input_real_name = GetTensorName(op.getOperand(0)); + std::string input_imag_name = GetTensorName(op.getOperand(1)); + std::string output_real_name = GetTensorName(op.getResult(0)); + std::string output_imag_name = GetTensorName(op.getResult(1)); + + TosaFFTAttribute attribute(inverse); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_FFT2D, Attribute_FFTAttribute, &attribute, + std::vector{input_real_name, input_imag_name}, + std::vector{output_real_name, output_imag_name}); + + return tyop; +} + /* End translating TOSA operator */ mlir::LogicalResult TosaSerializationRegionBuilder::BuildAllBlocksInRegion(std::vector& return_values) { std::string region_name = ser_region->GetName(); -- cgit v1.2.1