aboutsummaryrefslogtreecommitdiff
path: root/src/TosaSerialize.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/TosaSerialize.cpp')
-rw-r--r--src/TosaSerialize.cpp1735
1 files changed, 1147 insertions, 588 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index a6ea5ff..c3a9878 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 with LLVM Exceptions
// (the "License"); you may not use this file except in compliance with
@@ -21,9 +21,13 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
-#include "mlir/Pass/Pass.h" // from @llvm-project
-#include "mlir/Support/LogicalResult.h" // from @llvm-project
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h" // from @llvm-project
+#include "mlir/Support/LogicalResult.h" // from @llvm-project
+#include "mlir/Support/TypeID.h"
+#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tosa_serialization_handler.h"
+#include <algorithm>
#include <functional>
#include <map>
#include <unordered_map>
@@ -61,7 +65,7 @@ template <> struct equal_to<mlir::Value> {
} // namespace std
-ResizeMode ResizeModeStr2Enum(const std::string &mode_str) {
+static ResizeMode ResizeModeStr2Enum(const std::string &mode_str) {
if (mode_str == "NEAREST_NEIGHBOR")
return ResizeMode_NEAREST;
else if (mode_str == "BILINEAR")
@@ -70,14 +74,21 @@ ResizeMode ResizeModeStr2Enum(const std::string &mode_str) {
return ResizeMode_UNKNOWN;
}
-DType Type2DType(mlir::Type element_type) {
- if (element_type.isF64() || element_type.isF32() || element_type.isF16() ||
- element_type.isBF16()) {
- return DType_FLOAT;
+static DType Type2DType(mlir::Type element_type) {
+ if (element_type.isF64() || element_type.isF32()) {
+ return DType_FP32;
+ } else if (element_type.isFloat8E5M2()) {
+ return DType_FP8E5M2;
+ } else if (element_type.isFloat8E4M3FN()) {
+ return DType_FP8E4M3;
+ } else if (element_type.isF16()) {
+ return DType_FP16;
+ } else if (element_type.isBF16()) {
+ return DType_BF16;
} else if (element_type.isUnsignedInteger(8)) {
return DType_UINT8;
} else if (element_type.isInteger(4)) {
- return DType_INT8;
+ return DType_INT4;
} else if (element_type.isInteger(8)) {
return DType_INT8;
} else if (element_type.isInteger(16)) {
@@ -94,18 +105,79 @@ DType Type2DType(mlir::Type element_type) {
return DType_UNKNOWN;
}
+static DType Type2AccDType(mlir::Type element_type) {
+ // def Tosa_AccType : AnyTypeOf<[I<32>, I<48>, F16, F32]>;
+ if (element_type.isF32()) {
+ return DType_FP32;
+ } else if (element_type.isF16()) {
+ return DType_FP16;
+ } else if (element_type.isInteger(32)) {
+ return DType_INT32;
+ } else if (element_type.isInteger(48)) {
+ return DType_INT48;
+ }
+ return DType_UNKNOWN;
+}
class TosaSerializationBlockBuilder;
+class TosaSerializationRegionBuilder;
+
+std::unordered_map<std::string, mlir::Operation *> variable_tensor_op_map;
+std::unordered_map<std::string, std::string>
+ variable_tensor_flatbuffer_name_map;
+static int variable_tensor_index = 0;
+
+namespace {
+
+// for now, this is a global map of variables
+void RegisterVariableOp(mlir::Operation &op) {
+ std::string variable_tensor_flatbuffer_name =
+ "Variable_" + std::to_string(variable_tensor_index++);
+ std::string variable_tensor_mlir_name =
+ op.getAttr("name").cast<mlir::StringAttr>().getValue().str();
+ variable_tensor_op_map[variable_tensor_flatbuffer_name] = &op;
+ variable_tensor_flatbuffer_name_map[variable_tensor_mlir_name] =
+ variable_tensor_flatbuffer_name;
+}
+
+} // namespace
class TosaSerializationOperatorBuilder {
public:
TosaSerializationOperatorBuilder(
TosaSerializationBlockBuilder *_block_builder)
: block_builder(_block_builder) {}
+
template <typename T>
TosaSerializationOperator *build(mlir::Operation &op) const;
+ TosaSerializationHandler *GetTsh() const;
+ TosaSerializationRegionBuilder *GetRegionBuilder() const;
+ mlir::LogicalResult GetDataFromAttribute(mlir::Operation &op,
+ mlir::Attribute &attr,
+ mlir::Type element_type,
+ std::vector<uint8_t> &u8_data) const;
+
+ // populate u8_data with either int64_value or float_value depending on
+ // element_type
+ mlir::LogicalResult
+ GetU8DataFromIntOrFloatValue(int64_t int64_value, float fp_value,
+ mlir::Type element_type,
+ std::vector<uint8_t> &u8_data) const;
+
+ // populate u8_data with int_value depending on non-float element_type
+ mlir::LogicalResult
+ GetU8DataFromIntValues(const std::vector<int64_t> &int_values,
+ mlir::Type element_type,
+ std::vector<uint8_t> &u8_data) const;
+
+ // populate u8_data with fp_value depending on float element_type
+ mlir::LogicalResult
+ GetU8DataFromFloatValues(const std::vector<float> &fp_values,
+ mlir::Type element_type,
+ std::vector<uint8_t> &u8_data) const;
private:
std::string GetTensorName(mlir::Value val) const;
+ std::string GetVariableTensorName(mlir::Operation *op) const;
TosaSerializationOperator *BuildPoolOpFromMlirOp(mlir::Operation &op,
Op opcode) const;
TosaSerializationOperator *BuildEwiseBinaryOpFromMlirOp(mlir::Operation &op,
@@ -120,37 +192,275 @@ private:
// This builder assumes each region only has only one block
class TosaSerializationBlockBuilder {
public:
- friend class TosaSerializationOperatorBuilder;
- TosaSerializationBlockBuilder(TosaSerializationBasicBlock *_block,
- TosaSerializationHandler *_tsh,
- mlir::Region *_region)
- : block(_block), tsh(_tsh), region(_region) {}
+ // constructor
+ TosaSerializationBlockBuilder(TosaSerializationBasicBlock *_ser_block,
+ TosaSerializationRegionBuilder *_region_builder,
+ mlir::Block *_block)
+ : ser_block(_ser_block), region_builder(_region_builder), block(_block) {}
mlir::LogicalResult
- BuildAllOpsInRegion(std::vector<mlir::Value> &return_values);
- TosaSerializationBasicBlock *GetBlock() { return block; }
- TosaSerializationHandler *GetTsh() { return tsh; }
+ BuildAllOpsInBlock(std::vector<mlir::Value> &return_values);
+ TosaSerializationBasicBlock *GetBlock() const { return ser_block; }
+ TosaSerializationRegionBuilder *GetRegionBuilder() const {
+ return region_builder;
+ }
+ TosaSerializationHandler *GetTsh() const;
+ std::unordered_map<mlir::Value, std::string> &GetTensorMap() {
+ return tensor_map;
+ }
private:
TosaSerializationOperator *BuildTosaSerializationOperator(
const TosaSerializationOperatorBuilder &op_builder, mlir::Operation &op);
TosaSerializationTensor *
+ BuildTosaSerializationVariableTensor(mlir::RankedTensorType tensor_type,
+ const std::string &name,
+ const std::string &variable_mlir_name);
+ TosaSerializationTensor *
BuildTosaSerializationTensor(mlir::Value val, const std::string &name);
- TosaSerializationBasicBlock *block;
- TosaSerializationHandler *tsh;
- mlir::Region *region;
+ TosaSerializationBasicBlock *ser_block;
+ TosaSerializationRegionBuilder *region_builder;
+ mlir::Block *block;
std::unordered_map<mlir::Value, std::string> tensor_map;
std::unordered_map<mlir::Value, std::string> input_tensor_map;
};
+class TosaSerializationRegionBuilder {
+public:
+ // Constructor
+ TosaSerializationRegionBuilder(
+ TosaSerializationRegion *_ser_region, mlir::Region *_region,
+ TosaSerializationRegionBuilder *_parent_value_scope,
+ TosaSerializationHandler *_tsh)
+ : ser_region(_ser_region), region(_region),
+ parent_value_scope(_parent_value_scope), tsh(_tsh) {}
+ TosaSerializationHandler *GetTsh() const { return tsh; }
+ mlir::LogicalResult
+ BuildAllBlocksInRegion(bool is_top, std::vector<mlir::Value> &return_values);
+ TosaSerializationRegionBuilder *GetParentValueScope() const {
+ return parent_value_scope;
+ }
+ std::vector<TosaSerializationBlockBuilder *> &GetBlockBuilders() {
+ return block_builders;
+ }
+
+private:
+ TosaSerializationRegion *ser_region;
+ mlir::Region *region;
+ TosaSerializationRegionBuilder *parent_value_scope;
+ TosaSerializationHandler *tsh;
+ std::vector<TosaSerializationBlockBuilder *> block_builders;
+};
+
+TosaSerializationHandler *TosaSerializationOperatorBuilder::GetTsh() const {
+ return block_builder->GetTsh();
+}
+TosaSerializationHandler *TosaSerializationBlockBuilder::GetTsh() const {
+ return region_builder->GetTsh();
+}
+TosaSerializationRegionBuilder *
+TosaSerializationOperatorBuilder::GetRegionBuilder() const {
+ return block_builder->GetRegionBuilder();
+}
+
std::string
TosaSerializationOperatorBuilder::GetTensorName(mlir::Value val) const {
- if (block_builder->tensor_map.find(val) == block_builder->tensor_map.end()) {
- llvm::errs() << "ERROR: Failed to get mlir::Value from tensor_map";
+ auto value_scope = GetRegionBuilder();
+ while (value_scope) {
+ // Traverse through each block builder in the region
+ for (auto curr_block_builder : value_scope->GetBlockBuilders()) {
+ const auto &tensor_map = curr_block_builder->GetTensorMap();
+ if (tensor_map.count(val)) {
+ return tensor_map.at(val);
+ }
+ }
+ value_scope = value_scope->GetParentValueScope();
+ }
+ // Didn't find anything
+ llvm::errs() << "ERROR: Failed to get mlir::Value from tensor_map\n";
+ assert(0);
+}
+
+// Unpack 64-bit integer attribute element and pack into a std vector.
+template <class T>
+static std::vector<T> getDenseI64ArrayAttr(mlir::Attribute attr) {
+ auto array_ref = attr.cast<mlir::DenseI64ArrayAttr>().asArrayRef();
+
+ std::vector<T> vec;
+ for (auto val : array_ref) {
+ vec.push_back(val);
+ }
+
+ return vec;
+}
+
+// Unpack 8-bit integer attribute element and pack into a std vector.
+template <class T>
+static std::vector<T> getDenseI8ArrayAttr(mlir::Attribute attr) {
+ auto array_ref = attr.cast<mlir::DenseI8ArrayAttr>().asArrayRef();
+
+ std::vector<T> vec;
+ for (auto val : array_ref) {
+ vec.push_back(val);
+ }
+
+ return vec;
+}
+
+std::string TosaSerializationOperatorBuilder::GetVariableTensorName(
+ mlir::Operation *op) const {
+ std::string variable_tensor_mlir_name =
+ op->getAttr("name").cast<mlir::StringAttr>().getValue().str();
+
+ if (variable_tensor_flatbuffer_name_map.find(variable_tensor_mlir_name) ==
+ variable_tensor_flatbuffer_name_map.end()) {
+ llvm::errs() << "ERROR: Failed to find key " << variable_tensor_mlir_name
+ << " from variable_tensor_flatbuffer_name_map\n";
assert(0);
}
- return block_builder->tensor_map[val];
+ return variable_tensor_flatbuffer_name_map[variable_tensor_mlir_name];
+}
+
+mlir::LogicalResult TosaSerializationOperatorBuilder::GetDataFromAttribute(
+ mlir::Operation &op, mlir::Attribute &attr, mlir::Type element_type,
+ std::vector<uint8_t> &u8_data) const {
+ if (!element_type.isIntOrFloat()) {
+ return mlir::failure();
+ }
+ auto dense_attr = attr.dyn_cast<mlir::DenseElementsAttr>();
+
+ // handle float types
+ if (element_type.isa<mlir::FloatType>()) {
+ std::vector<float> fp_data;
+ auto val_attr = attr.dyn_cast<mlir::FloatAttr>();
+
+ if (dense_attr) {
+ for (auto val : dense_attr.getValues<mlir::APFloat>()) {
+ fp_data.push_back(val.convertToFloat());
+ }
+ } else if (val_attr) {
+ fp_data.push_back((float)val_attr.getValueAsDouble());
+ } else {
+ op.emitOpError("Unknown const attribute");
+ return mlir::failure();
+ }
+
+ return GetU8DataFromFloatValues(fp_data, element_type, u8_data);
+ }
+
+ // element_type is integer type
+
+ bool isInt48 = element_type.isInteger(48);
+ std::vector<int64_t> i64_data;
+
+ auto val_attr = attr.dyn_cast<mlir::IntegerAttr>();
+ if (dense_attr) {
+ for (auto valueIt : dense_attr.getValues<mlir::APInt>()) {
+ int64_t val = isInt48 ? static_cast<int64_t>(valueIt.getLimitedValue())
+ : valueIt.getSExtValue();
+ i64_data.push_back(val);
+ }
+ } else if (val_attr) {
+ i64_data.push_back(val_attr.getInt());
+ } else {
+ op.emitOpError("Unknown const attribute");
+ return mlir::failure();
+ }
+
+ return GetU8DataFromIntValues(i64_data, element_type, u8_data);
+}
+
+mlir::LogicalResult TosaSerializationOperatorBuilder::GetU8DataFromIntValues(
+ const std::vector<int64_t> &int64_values, mlir::Type element_type,
+ std::vector<uint8_t> &u8_data) const {
+ switch (element_type.getIntOrFloatBitWidth()) {
+ case 1: {
+ // bool use bool vec
+ std::vector<bool> bool_values;
+ for (auto v : int64_values) {
+ bool bool_value = v == 0 ? false : true;
+ bool_values.push_back(bool_value);
+ }
+ TosaSerializationHandler::ConvertBooltoU8(bool_values, u8_data);
+ break;
+ }
+ case 4:
+ case 8: {
+ // I4 and I8 use int8_t vec
+ std::vector<int8_t> i8_values;
+ for (auto v : int64_values) {
+ i8_values.push_back(static_cast<int8_t>(v));
+ }
+ if (element_type.isInteger(4)) {
+ TosaSerializationHandler::ConvertI4toU8(i8_values, u8_data);
+ } else {
+ TosaSerializationHandler::ConvertI8toU8(i8_values, u8_data);
+ }
+ break;
+ }
+ case 16: {
+ // I16 use int16_t vec
+ std::vector<int16_t> i16_values;
+ for (auto v : int64_values) {
+ i16_values.push_back(static_cast<int16_t>(v));
+ }
+ TosaSerializationHandler::ConvertI16toU8(i16_values, u8_data);
+ break;
+ }
+ case 32: {
+ // I32 use int32_t vec
+ std::vector<int32_t> i32_values;
+ for (auto v : int64_values) {
+ i32_values.push_back(static_cast<int32_t>(v));
+ }
+ TosaSerializationHandler::ConvertI32toU8(i32_values, u8_data);
+ break;
+ }
+ case 48: {
+ // I48 use int64_t vec
+ TosaSerializationHandler::ConvertI48toU8(int64_values, u8_data);
+ break;
+ }
+ default: {
+ // unsupported bit widths
+ return mlir::failure();
+ }
+ }
+ return mlir::success();
+}
+
+mlir::LogicalResult TosaSerializationOperatorBuilder::GetU8DataFromFloatValues(
+ const std::vector<float> &fp_values, mlir::Type element_type,
+ std::vector<uint8_t> &u8_data) const {
+ assert(
+ element_type
+ .isa<mlir::FloatType>()); // this should only be called for float type
+ if (element_type.isF16()) {
+ TosaSerializationHandler::ConvertF16toU8(fp_values, u8_data);
+ } else if (element_type.isBF16()) {
+ TosaSerializationHandler::ConvertBF16toU8(fp_values, u8_data);
+ } else if (element_type.isFloat8E4M3FN()) {
+ TosaSerializationHandler::ConvertFP8E4M3toU8(fp_values, u8_data);
+ } else if (element_type.isFloat8E5M2()) {
+ TosaSerializationHandler::ConvertFP8E5M2toU8(fp_values, u8_data);
+ } else if (element_type.isF32()) {
+ TosaSerializationHandler::ConvertF32toU8(fp_values, u8_data);
+ } else {
+ return mlir::failure();
+ }
+ return mlir::success();
+}
+
+mlir::LogicalResult
+TosaSerializationOperatorBuilder::GetU8DataFromIntOrFloatValue(
+ int64_t int64_value, float fp_value, mlir::Type element_type,
+ std::vector<uint8_t> &u8_data) const {
+ if (element_type.isa<mlir::FloatType>()) {
+ return GetU8DataFromFloatValues({fp_value}, element_type, u8_data);
+ } else {
+ return GetU8DataFromIntValues({int64_value}, element_type, u8_data);
+ }
}
// Main template to catch unimplemented translation.
@@ -176,43 +486,42 @@ TosaSerializationOperatorBuilder::build(mlir::Operation &op) const {
TosaSerializationOperator *
TosaSerializationOperatorBuilder::BuildPoolOpFromMlirOp(mlir::Operation &op,
Op opcode) const {
- std::vector<int> pad, stride, kernel;
-
- auto pad_attr = op.getAttr("pad").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &int_attr : pad_attr) {
- pad.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
+ auto pad = getDenseI64ArrayAttr<int>(op.getAttr("pad"));
ASSERT_VECTOR_LENGTH(pad, 4);
- auto stride_attr =
- op.getAttr("stride").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &int_attr : stride_attr) {
- stride.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
+ auto stride = getDenseI64ArrayAttr<int>(op.getAttr("stride"));
ASSERT_VECTOR_LENGTH(stride, 2);
- auto kernel_attr =
- op.getAttr("kernel").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &int_attr : kernel_attr) {
- kernel.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
+ auto kernel = getDenseI64ArrayAttr<int>(op.getAttr("kernel"));
ASSERT_VECTOR_LENGTH(kernel, 2);
+ DType acc_dtype = DType_FP32;
+ // AvgPool has acc_type, MaxPool does not
+ if (op.hasAttr("acc_type")) {
+ auto acc_type = op.getAttr("acc_type").cast<mlir::TypeAttr>().getValue();
+ acc_dtype = Type2AccDType(acc_type);
+ }
+
std::string input_name = GetTensorName(op.getOperand(0));
std::string output_name = GetTensorName(op.getResult(0));
- auto quant_info = op.getAttrOfType<mlir::tosa::UnaryOpQuantizationAttr>(
- "quantization_info");
+ int32_t input_zp =
+ op.hasAttr("input_zp")
+ ? input_zp = op.getAttr("input_zp").cast<mlir::IntegerAttr>().getInt()
+ : 0;
+ int32_t output_zp =
+ op.hasAttr("output_zp")
+ ? output_zp =
+ op.getAttr("output_zp").cast<mlir::IntegerAttr>().getInt()
+ : 0;
- int32_t input_zp = quant_info ? quant_info.input_zp().getInt() : 0;
- int32_t output_zp = quant_info ? quant_info.output_zp().getInt() : 0;
+ TosaPoolAttribute attribute(pad, kernel, stride, input_zp, output_zp,
+ acc_dtype);
- TosaPoolAttribute attribute(pad, kernel, stride, input_zp, output_zp);
-
- TosaSerializationOperator *tyop = new TosaSerializationOperator(
- opcode, Attribute_PoolAttribute, &attribute,
- std::vector<std::string>{input_name},
- std::vector<std::string>{output_name});
+ TosaSerializationOperator *tyop =
+ new TosaSerializationOperator(opcode, Attribute_PoolAttribute, &attribute,
+ std::vector<std::string>{input_name},
+ std::vector<std::string>{output_name});
return tyop;
}
@@ -239,8 +548,7 @@ TosaSerializationOperatorBuilder::BuildEwiseUnaryOpFromMlirOp(
std::string output_name = GetTensorName(op.getResult(0));
TosaSerializationOperator *tyop = new TosaSerializationOperator(
- opcode, Attribute_NONE, nullptr,
- std::vector<std::string>{input_name},
+ opcode, Attribute_NONE, nullptr, std::vector<std::string>{input_name},
std::vector<std::string>{output_name});
return tyop;
@@ -255,10 +563,10 @@ TosaSerializationOperatorBuilder::BuildReductionOpFromMlirOp(
int32_t axis = op.getAttr("axis").dyn_cast<mlir::IntegerAttr>().getInt();
TosaAxisAttribute attribute(axis);
- TosaSerializationOperator *tyop = new TosaSerializationOperator(
- opcode, Attribute_AxisAttribute, &attribute,
- std::vector<std::string>{input_name},
- std::vector<std::string>{output_name});
+ TosaSerializationOperator *tyop =
+ new TosaSerializationOperator(opcode, Attribute_AxisAttribute, &attribute,
+ std::vector<std::string>{input_name},
+ std::vector<std::string>{output_name});
return tyop;
}
@@ -302,7 +610,7 @@ BUILD_OP_ELEMENTWISE_BINARY(Add, ADD)
BUILD_OP_ELEMENTWISE_BINARY(BitwiseAnd, BITWISE_AND)
BUILD_OP_ELEMENTWISE_BINARY(BitwiseXor, BITWISE_XOR)
BUILD_OP_ELEMENTWISE_BINARY(BitwiseOr, BITWISE_OR)
-BUILD_OP_ELEMENTWISE_BINARY(Div, INTDIV)
+BUILD_OP_ELEMENTWISE_BINARY(IntDiv, INTDIV)
BUILD_OP_ELEMENTWISE_BINARY(LogicalAnd, LOGICAL_AND)
BUILD_OP_ELEMENTWISE_BINARY(LogicalLeftShift, LOGICAL_LEFT_SHIFT)
BUILD_OP_ELEMENTWISE_BINARY(LogicalRightShift, LOGICAL_RIGHT_SHIFT)
@@ -317,12 +625,14 @@ BUILD_OP_ELEMENTWISE_UNARY(Abs, ABS)
BUILD_OP_ELEMENTWISE_UNARY(BitwiseNot, BITWISE_NOT)
BUILD_OP_ELEMENTWISE_UNARY(Ceil, CEIL)
BUILD_OP_ELEMENTWISE_UNARY(Clz, CLZ)
+BUILD_OP_ELEMENTWISE_UNARY(Cos, COS)
BUILD_OP_ELEMENTWISE_UNARY(Exp, EXP)
BUILD_OP_ELEMENTWISE_UNARY(Floor, FLOOR)
BUILD_OP_ELEMENTWISE_UNARY(Log, LOG)
BUILD_OP_ELEMENTWISE_UNARY(LogicalNot, LOGICAL_NOT)
BUILD_OP_ELEMENTWISE_UNARY(Reciprocal, RECIPROCAL)
BUILD_OP_ELEMENTWISE_UNARY(Rsqrt, RSQRT)
+BUILD_OP_ELEMENTWISE_UNARY(Sin, SIN)
BUILD_OP_REDUCTION(ReduceAny, REDUCE_ANY)
BUILD_OP_REDUCTION(ReduceAll, REDUCE_ALL)
@@ -335,14 +645,20 @@ BUILD_OP_ELEMENTWISE_BINARY(Equal, EQUAL)
BUILD_OP_ELEMENTWISE_BINARY(Greater, GREATER)
BUILD_OP_ELEMENTWISE_BINARY(GreaterEqual, GREATER_EQUAL)
+BUILD_OP_ELEMENTWISE_UNARY(Erf, ERF)
BUILD_OP_ELEMENTWISE_UNARY(Sigmoid, SIGMOID)
BUILD_OP_ELEMENTWISE_UNARY(Tanh, TANH)
BUILD_OP_ELEMENTWISE_UNARY(Identity, IDENTITY)
BUILD_OP_ELEMENTWISE_UNARY(Cast, CAST)
+BUILD_OP_ELEMENTWISE_BINARY(AddShape, ADD_SHAPE)
+BUILD_OP_ELEMENTWISE_BINARY(SubShape, SUB_SHAPE)
+BUILD_OP_ELEMENTWISE_BINARY(MulShape, MUL_SHAPE)
+BUILD_OP_ELEMENTWISE_BINARY(DivShape, DIV_SHAPE)
+
template <>
TosaSerializationOperator *
-TosaSerializationOperatorBuilder::build<mlir::tosa::ConstOp>(
+TosaSerializationOperatorBuilder::build<mlir::tosa::ConstShapeOp>(
mlir::Operation &op) const {
std::string output_name = GetTensorName(op.getResult(0));
TosaSerializationTensor *ts =
@@ -353,142 +669,73 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ConstOp>(
return nullptr;
}
-#if 0
- // Gracefully handle constants of "constant unit" type which have no value
- // by creating a numpy value of 0.
- auto unit_val = op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::UnitAttr>();
-
- if (unit_val)
- {
- std::vector<float> data = { 0.0 };
- type = DType_FLOAT;
- TosaSerializationHandler::ConvertF32toU8(data, u8_data);
- }
-#endif
-
// Update tensor.data array with Const value attribute
+ mlir::Attribute value_attr = op.getAttr("value");
+ if (!value_attr) {
+ op.emitOpError("ERROR: tosa.const_shape doesn't have value");
+ return nullptr;
+ }
+ assert(ts->GetDtype() == DType::DType_SHAPE);
std::vector<uint8_t> u8_data;
- DType type = ts->GetDtype();
- if (type == DType_FLOAT) {
- std::vector<float> data;
- auto dense_attr = op.getAttr(llvm::StringRef("value"))
- .dyn_cast<mlir::DenseElementsAttr>();
- auto val_attr =
- op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::FloatAttr>();
- if (dense_attr) {
- for (auto val : dense_attr.getValues<float>()) {
- data.push_back(val);
- }
- } else if (val_attr) {
- data.push_back((float)val_attr.getValueAsDouble());
- } else {
- op.emitOpError("Unknown const attribute");
- return nullptr;
- }
- TosaSerializationHandler::ConvertF32toU8(data, u8_data);
- } else if (type == DType_INT8) {
- std::vector<int8_t> data;
- auto dense_attr = op.getAttr(llvm::StringRef("value"))
- .dyn_cast<mlir::DenseElementsAttr>();
- auto val_attr =
- op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::IntegerAttr>();
+ std::vector<int64_t> data;
+ auto dense_attr = op.getAttr(llvm::StringRef("value"))
+ .dyn_cast<mlir::DenseIntElementsAttr>();
+ if (!dense_attr) {
+ op.emitOpError("Unknown const attribute");
+ return nullptr;
+ }
- if (dense_attr) {
- for (auto val : dense_attr.getValues<int8_t>()) {
- data.push_back(val);
- }
- } else if (val_attr) {
- data.push_back(val_attr.getInt());
- } else {
- op.emitOpError("Unknown const attribute");
- return nullptr;
- }
- TosaSerializationHandler::ConvertI8toU8(data, u8_data);
- } else if (type == DType_INT16) {
- std::vector<int16_t> data;
- auto dense_attr = op.getAttr(llvm::StringRef("value"))
- .dyn_cast<mlir::DenseElementsAttr>();
- auto val_attr =
- op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::IntegerAttr>();
+ for (auto valueIt : dense_attr.getValues<mlir::APInt>()) {
+ int64_t val = valueIt.getSExtValue();
+ data.push_back(val);
+ }
- if (dense_attr) {
- for (auto val : dense_attr.getValues<int16_t>()) {
- data.push_back(val);
- }
- } else if (val_attr) {
- data.push_back(val_attr.getInt());
- } else {
- op.emitOpError("Unknown const attribute");
- return nullptr;
- }
- TosaSerializationHandler::ConvertI16toU8(data, u8_data);
- } else if (type == DType_INT32) {
- std::vector<int32_t> data;
- auto dense_attr = op.getAttr(llvm::StringRef("value"))
- .dyn_cast<mlir::DenseElementsAttr>();
- auto val_attr =
- op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::IntegerAttr>();
+ TosaSerializationHandler::ConvertI64toU8(data, u8_data);
- if (dense_attr) {
- for (auto val : dense_attr.getValues<int32_t>()) {
- data.push_back(val);
- }
- } else if (val_attr) {
- data.push_back(val_attr.getInt());
- } else {
- op.emitOpError("Unknown const attribute");
- return nullptr;
- }
- TosaSerializationHandler::ConvertI32toU8(data, u8_data);
- } else if (type == DType_INT48) {
- std::vector<int64_t> data;
- auto dense_attr = op.getAttr(llvm::StringRef("value"))
- .dyn_cast<mlir::DenseElementsAttr>();
- auto val_attr =
- op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::IntegerAttr>();
+ ts->SetData(u8_data);
- if (dense_attr) {
- for (auto valueIt : dense_attr.getValues<mlir::APInt>()) {
- uint64_t val = valueIt.getLimitedValue();
- data.push_back(val);
- }
- } else if (val_attr) {
- data.push_back(val_attr.getInt());
- } else {
- op.emitOpError("Unknown const attribute");
- return nullptr;
- }
- TosaSerializationHandler::ConvertI48toU8(data, u8_data);
- } else if (type == DType_BOOL) {
- std::vector<bool> data;
+ TosaSerializationOperator *tyop = new TosaSerializationOperator(
+ Op_CONST_SHAPE, Attribute_NONE, nullptr, std::vector<std::string>{},
+ std::vector<std::string>{output_name});
- auto dense_attr = op.getAttr(llvm::StringRef("value"))
- .dyn_cast<mlir::DenseElementsAttr>();
- auto val_attr =
- op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::BoolAttr>();
+ return tyop;
+}
- if (dense_attr) {
- for (auto val : dense_attr.getValues<bool>()) {
- data.push_back(val);
- }
- } else if (val_attr) {
- data.push_back(val_attr.getValue());
- } else {
- op.emitOpError("Unknown const attribute");
- return nullptr;
- }
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::ConstOp>(
+ mlir::Operation &op) const {
+ std::string output_name = GetTensorName(op.getResult(0));
+ TosaSerializationTensor *ts =
+ block_builder->GetBlock()->GetTensorByName(output_name);
+ if (!ts) {
+ op.emitOpError(
+ "ERROR: serialization tensor must be built before building operator");
+ return nullptr;
+ }
- TosaSerializationHandler::ConvertBooltoU8(data, u8_data);
- } else {
- op.emitOpError("Unknown element type of const attribute");
+ // Update tensor.data array with Const value attribute
+ mlir::Attribute value_attr = op.getAttr("value");
+ if (!value_attr) {
+ op.emitOpError("ERROR: tosa.const doesn't have value");
+ return nullptr;
+ }
+ std::vector<uint8_t> u8_data;
+ mlir::Attribute attr = op.getAttr(llvm::StringRef("value"));
+ mlir::Type element_type =
+ llvm::cast<mlir::ShapedType>(op.getResult(0).getType()).getElementType();
+
+ if (GetDataFromAttribute(op, attr, element_type, u8_data).failed()) {
+ op.emitOpError("ERROR: GetDataFromAttribute() fails when building value of "
+ "const tensor");
return nullptr;
}
- ts->SetData(u8_data);
+ ts->SetData(u8_data);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
- Op_CONST, Attribute_NONE, nullptr,
- std::vector<std::string>{}, std::vector<std::string>{output_name});
+ Op_CONST, Attribute_NONE, nullptr, std::vector<std::string>{},
+ std::vector<std::string>{output_name});
return tyop;
}
@@ -497,26 +744,13 @@ template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::Conv2DOp>(
mlir::Operation &op) const {
- std::vector<int> pad, stride, dilation;
-
- auto pad_attr = op.getAttr("pad").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &int_attr : pad_attr) {
- pad.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
+ auto pad = getDenseI64ArrayAttr<int>(op.getAttr("pad"));
ASSERT_VECTOR_LENGTH(pad, 4);
- auto stride_attr =
- op.getAttr("stride").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &int_attr : stride_attr) {
- stride.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
+ auto stride = getDenseI64ArrayAttr<int>(op.getAttr("stride"));
ASSERT_VECTOR_LENGTH(stride, 2);
- auto dilation_attr =
- op.getAttr("dilation").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &int_attr : dilation_attr) {
- dilation.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
+ auto dilation = getDenseI64ArrayAttr<int>(op.getAttr("dilation"));
ASSERT_VECTOR_LENGTH(dilation, 2);
std::string input0_name = GetTensorName(op.getOperand(0));
@@ -524,14 +758,25 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::Conv2DOp>(
std::string input2_name = GetTensorName(op.getOperand(2));
std::string output_name = GetTensorName(op.getResult(0));
+ int32_t input_zp =
+ op.hasAttr("input_zp")
+ ? op.getAttr("input_zp").cast<mlir::IntegerAttr>().getInt()
+ : 0;
+ int32_t weight_zp =
+ op.hasAttr("weight_zp")
+ ? op.getAttr("weight_zp").cast<mlir::IntegerAttr>().getInt()
+ : 0;
- auto quant_info =
- op.getAttrOfType<mlir::tosa::ConvOpQuantizationAttr>("quantization_info");
+ bool local_bound =
+ op.hasAttr("local_bound")
+ ? op.getAttr("local_bound").dyn_cast<mlir::BoolAttr>().getValue()
+ : false;
- int32_t input_zp = quant_info ? quant_info.input_zp().getInt() : 0;
- int32_t weight_zp = quant_info ? quant_info.weight_zp().getInt() : 0;
+ auto acc_type = op.getAttr("acc_type").cast<mlir::TypeAttr>().getValue();
+ auto acc_dtype = Type2AccDType(acc_type);
- TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp);
+ TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp,
+ local_bound, acc_dtype);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_CONV2D, Attribute_ConvAttribute, &attribute,
@@ -543,28 +788,61 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::Conv2DOp>(
template <>
TosaSerializationOperator *
-TosaSerializationOperatorBuilder::build<mlir::tosa::DepthwiseConv2DOp>(
+TosaSerializationOperatorBuilder::build<mlir::tosa::Conv3DOp>(
mlir::Operation &op) const {
- std::vector<int> pad, stride, dilation;
+ auto pad = getDenseI64ArrayAttr<int>(op.getAttr("pad"));
+ ASSERT_VECTOR_LENGTH(pad, 6);
- auto pad_attr = op.getAttr("pad").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &int_attr : pad_attr) {
- pad.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
+ auto stride = getDenseI64ArrayAttr<int>(op.getAttr("stride"));
+ ASSERT_VECTOR_LENGTH(stride, 3);
+
+ auto dilation = getDenseI64ArrayAttr<int>(op.getAttr("dilation"));
+ ASSERT_VECTOR_LENGTH(dilation, 3);
+
+ std::string input0_name = GetTensorName(op.getOperand(0));
+ std::string input1_name = GetTensorName(op.getOperand(1));
+ std::string input2_name = GetTensorName(op.getOperand(2));
+ std::string output_name = GetTensorName(op.getResult(0));
+
+ int32_t input_zp =
+ op.hasAttr("input_zp")
+ ? op.getAttr("input_zp").cast<mlir::IntegerAttr>().getInt()
+ : 0;
+ int32_t weight_zp =
+ op.hasAttr("weight_zp")
+ ? op.getAttr("weight_zp").cast<mlir::IntegerAttr>().getInt()
+ : 0;
+
+ bool local_bound =
+ op.hasAttr("local_bound")
+ ? op.getAttr("local_bound").dyn_cast<mlir::BoolAttr>().getValue()
+ : false;
+
+ auto acc_type = op.getAttr("acc_type").cast<mlir::TypeAttr>().getValue();
+ auto acc_dtype = Type2AccDType(acc_type);
+
+ TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp,
+ local_bound, acc_dtype);
+
+ TosaSerializationOperator *tyop = new TosaSerializationOperator(
+ Op_CONV3D, Attribute_ConvAttribute, &attribute,
+ std::vector<std::string>{input0_name, input1_name, input2_name},
+ std::vector<std::string>{output_name});
+
+ return tyop;
+}
+
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::DepthwiseConv2DOp>(
+ mlir::Operation &op) const {
+ auto pad = getDenseI64ArrayAttr<int>(op.getAttr("pad"));
ASSERT_VECTOR_LENGTH(pad, 4);
- auto stride_attr =
- op.getAttr("stride").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &int_attr : stride_attr) {
- stride.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
+ auto stride = getDenseI64ArrayAttr<int>(op.getAttr("stride"));
ASSERT_VECTOR_LENGTH(stride, 2);
- auto dilation_attr =
- op.getAttr("dilation").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &int_attr : dilation_attr) {
- dilation.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
+ auto dilation = getDenseI64ArrayAttr<int>(op.getAttr("dilation"));
ASSERT_VECTOR_LENGTH(dilation, 2);
std::string input0_name = GetTensorName(op.getOperand(0));
@@ -572,13 +850,25 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::DepthwiseConv2DOp>(
std::string input2_name = GetTensorName(op.getOperand(2));
std::string output_name = GetTensorName(op.getResult(0));
- auto quant_info =
- op.getAttrOfType<mlir::tosa::ConvOpQuantizationAttr>("quantization_info");
+ int32_t input_zp =
+ op.hasAttr("input_zp")
+ ? op.getAttr("input_zp").cast<mlir::IntegerAttr>().getInt()
+ : 0;
+ int32_t weight_zp =
+ op.hasAttr("weight_zp")
+ ? op.getAttr("weight_zp").cast<mlir::IntegerAttr>().getInt()
+ : 0;
+
+ bool local_bound =
+ op.hasAttr("local_bound")
+ ? op.getAttr("local_bound").dyn_cast<mlir::BoolAttr>().getValue()
+ : false;
- int32_t input_zp = quant_info ? quant_info.input_zp().getInt() : 0;
- int32_t weight_zp = quant_info ? quant_info.weight_zp().getInt() : 0;
+ auto acc_type = op.getAttr("acc_type").cast<mlir::TypeAttr>().getValue();
+ auto acc_dtype = Type2AccDType(acc_type);
- TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp);
+ TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp,
+ local_bound, acc_dtype);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_DEPTHWISE_CONV2D, Attribute_ConvAttribute, &attribute,
@@ -592,41 +882,39 @@ template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::TransposeConv2DOp>(
mlir::Operation &op) const {
- std::vector<int> outpad, stride, dilation, output_shape;
+ auto out_pad = getDenseI64ArrayAttr<int>(op.getAttr("out_pad"));
+ ASSERT_VECTOR_LENGTH(out_pad, 4);
- auto outpad_attr =
- op.getAttr("out_pad").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &int_attr : outpad_attr) {
- outpad.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
- ASSERT_VECTOR_LENGTH(outpad, 4);
-
- auto stride_attr =
- op.getAttr("stride").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &int_attr : stride_attr) {
- stride.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
+ auto stride = getDenseI64ArrayAttr<int>(op.getAttr("stride"));
ASSERT_VECTOR_LENGTH(stride, 2);
- auto output_shape_attr =
- op.getAttr("out_shape").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &int_attr : output_shape_attr) {
- output_shape.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
- ASSERT_VECTOR_LENGTH(output_shape, 4);
-
std::string input0_name = GetTensorName(op.getOperand(0));
std::string input1_name = GetTensorName(op.getOperand(1));
std::string input2_name = GetTensorName(op.getOperand(2));
std::string output_name = GetTensorName(op.getResult(0));
- auto quant_info =
- op.getAttrOfType<mlir::tosa::ConvOpQuantizationAttr>("quantization_info");
+ int32_t input_zp =
+ op.hasAttr("input_zp")
+ ? op.getAttr("input_zp").cast<mlir::IntegerAttr>().getInt()
+ : 0;
+ int32_t weight_zp =
+ op.hasAttr("weight_zp")
+ ? op.getAttr("weight_zp").cast<mlir::IntegerAttr>().getInt()
+ : 0;
- int32_t input_zp = quant_info ? quant_info.input_zp().getInt() : 0;
- int32_t weight_zp = quant_info ? quant_info.weight_zp().getInt() : 0;
+ mlir::RankedTensorType tensor =
+ op.getOperand(0).getType().cast<mlir::RankedTensorType>();
- TosaTransposeConvAttribute attribute(outpad, stride, output_shape, input_zp, weight_zp);
+ bool local_bound =
+ op.hasAttr("local_bound")
+ ? op.getAttr("local_bound").dyn_cast<mlir::BoolAttr>().getValue()
+ : false;
+
+ auto acc_type = op.getAttr("acc_type").cast<mlir::TypeAttr>().getValue();
+ auto acc_dtype = Type2AccDType(acc_type);
+
+ TosaTransposeConvAttribute attribute(out_pad, stride, input_zp, weight_zp,
+ local_bound, acc_dtype);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_TRANSPOSE_CONV2D, Attribute_TransposeConvAttribute, &attribute,
@@ -645,11 +933,15 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::FullyConnectedOp>(
std::string input2_name = GetTensorName(op.getOperand(2));
std::string output_name = GetTensorName(op.getResult(0));
- auto quant_info =
- op.getAttrOfType<mlir::tosa::ConvOpQuantizationAttr>("quantization_info");
+ int32_t input_zp =
+ op.hasAttr("input_zp")
+ ? op.getAttr("input_zp").cast<mlir::IntegerAttr>().getInt()
+ : 0;
+ int32_t weight_zp =
+ op.hasAttr("weight_zp")
+ ? op.getAttr("weight_zp").cast<mlir::IntegerAttr>().getInt()
+ : 0;
- int32_t input_zp = quant_info ? quant_info.input_zp().getInt() : 0;
- int32_t weight_zp = quant_info ? quant_info.weight_zp().getInt() : 0;
TosaFullyConnectedAttribute attribute(input_zp, weight_zp);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
@@ -668,11 +960,12 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::MatMulOp>(
std::string input1_name = GetTensorName(op.getOperand(1));
std::string output_name = GetTensorName(op.getResult(0));
- auto quant_info = op.getAttrOfType<mlir::tosa::MatMulOpQuantizationAttr>(
- "quantization_info");
-
- int32_t A_zp = quant_info ? quant_info.a_zp().getInt() : 0;
- int32_t B_zp = quant_info ? quant_info.b_zp().getInt() : 0;
+ int32_t A_zp = op.hasAttr("a_zp")
+ ? op.getAttr("a_zp").cast<mlir::IntegerAttr>().getInt()
+ : 0;
+ int32_t B_zp = op.hasAttr("b_zp")
+ ? op.getAttr("b_zp").cast<mlir::IntegerAttr>().getInt()
+ : 0;
TosaMatMulAttribute attribute(A_zp, B_zp);
@@ -705,23 +998,49 @@ template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::ClampOp>(
mlir::Operation &op) const {
- int32_t min_int =
- op.getAttr("min_int").dyn_cast<mlir::IntegerAttr>().getInt();
- int32_t max_int =
- op.getAttr("max_int").dyn_cast<mlir::IntegerAttr>().getInt();
- float min_fp = op.getAttr("min_fp")
- .dyn_cast<mlir::FloatAttr>()
- .getValue()
- .convertToFloat();
- float max_fp = op.getAttr("max_fp")
- .dyn_cast<mlir::FloatAttr>()
- .getValue()
- .convertToFloat();
+ auto min_val_attr = op.getAttr("min_val");
+ auto max_val_attr = op.getAttr("max_val");
+
+ mlir::Type input_element_type =
+ llvm::cast<mlir::ShapedType>(op.getOperand(0).getType()).getElementType();
+ if (auto quantType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(
+ input_element_type)) {
+ input_element_type = quantType.getStorageType();
+ }
+
+ std::vector<uint8_t> min_val, max_val;
+ float min_fp, max_fp;
+ int64_t min_int, max_int;
+
+ if (input_element_type.isa<mlir::FloatType>()) {
+ min_fp =
+ mlir::cast<mlir::FloatAttr>(min_val_attr).getValue().convertToFloat();
+ max_fp =
+ mlir::cast<mlir::FloatAttr>(max_val_attr).getValue().convertToFloat();
+ min_int = max_int = 0;
+ } else {
+ assert(input_element_type.isa<mlir::IntegerType>());
+ min_int = mlir::cast<mlir::IntegerAttr>(min_val_attr).getInt();
+ max_int = mlir::cast<mlir::IntegerAttr>(max_val_attr).getInt();
+ min_fp = max_fp = 0.f;
+ }
+
+ if (GetU8DataFromIntOrFloatValue(min_int, min_fp, input_element_type, min_val)
+ .failed()) {
+ op.emitOpError("Failed to serialize min value");
+ return nullptr;
+ }
+
+ if (GetU8DataFromIntOrFloatValue(max_int, max_fp, input_element_type, max_val)
+ .failed()) {
+ op.emitOpError("Failed to serialize max value");
+ return nullptr;
+ }
std::string input_name = GetTensorName(op.getOperand(0));
std::string output_name = GetTensorName(op.getResult(0));
- TosaClampAttribute attribute(min_int, max_int, min_fp, max_fp);
+ TosaClampAttribute attribute(min_val, max_val);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_CLAMP, Attribute_ClampAttribute, &attribute,
@@ -767,8 +1086,27 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ConcatOp>(
TosaAxisAttribute attribute(axis);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
- Op_CONCAT, Attribute_AxisAttribute, &attribute,
- inputs, std::vector<std::string>{output_name});
+ Op_CONCAT, Attribute_AxisAttribute, &attribute, inputs,
+ std::vector<std::string>{output_name});
+
+ return tyop;
+}
+
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::ConcatShapeOp>(
+ mlir::Operation &op) const {
+ std::vector<std::string> inputs;
+ for (uint32_t i = 0; i < op.getNumOperands(); i++) {
+ std::string input_name = GetTensorName(op.getOperand(i));
+ inputs.push_back(input_name);
+ }
+
+ std::string output_name = GetTensorName(op.getResult(0));
+
+ TosaSerializationOperator *tyop = new TosaSerializationOperator(
+ Op_CONCAT_SHAPE, Attribute_NONE, nullptr, inputs,
+ std::vector<std::string>{output_name});
return tyop;
}
@@ -780,13 +1118,16 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::NegateOp>(
std::string input_name = GetTensorName(op.getOperand(0));
std::string output_name = GetTensorName(op.getResult(0));
- auto quant_info = op.getAttrOfType<mlir::tosa::UnaryOpQuantizationAttr>(
- "quantization_info");
-
- int32_t input_zp = quant_info ? quant_info.input_zp().getInt() : 0;
- int32_t output_zp = quant_info ? quant_info.output_zp().getInt() : 0;
+ int32_t input1_zp =
+ op.hasAttr("input1_zp")
+ ? op.getAttr("input1_zp").cast<mlir::IntegerAttr>().getInt()
+ : 0;
+ int32_t output_zp =
+ op.hasAttr("output_zp")
+ ? op.getAttr("output_zp").cast<mlir::IntegerAttr>().getInt()
+ : 0;
- TosaNegateAttribute attribute(input_zp, output_zp);
+ TosaNegateAttribute attribute(input1_zp, output_zp);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_NEGATE, Attribute_NegateAttribute, &attribute,
@@ -800,20 +1141,12 @@ TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::ReshapeOp>(
mlir::Operation &op) const {
std::string input_name = GetTensorName(op.getOperand(0));
+ std::string shape_name = GetTensorName(op.getOperand(1));
std::string output_name = GetTensorName(op.getResult(0));
- std::vector<int> shape;
- auto shape_attr =
- op.getAttr("new_shape").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &int_attr : shape_attr) {
- shape.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
-
- TosaReshapeAttribute attribute(shape);
-
TosaSerializationOperator *tyop = new TosaSerializationOperator(
- Op_RESHAPE, Attribute_ReshapeAttribute, &attribute,
- std::vector<std::string>{input_name},
+ Op_RESHAPE, Attribute_NONE, nullptr,
+ std::vector<std::string>{input_name, shape_name},
std::vector<std::string>{output_name});
return tyop;
@@ -824,38 +1157,80 @@ TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::PadOp>(
mlir::Operation &op) const {
std::string input_name = GetTensorName(op.getOperand(0));
+ std::string padding_name = GetTensorName(op.getOperand(1));
std::string output_name = GetTensorName(op.getResult(0));
+ auto pad_op = llvm::cast<mlir::tosa::PadOp>(op);
- // Match padding tensor as compile-time constant attribute
- // TODO: fix when MLIR dialect changes
- mlir::ElementsAttr paddings_elems;
- if (!matchPattern(op.getOperand(1), m_Constant(&paddings_elems)))
- return nullptr;
+ auto input_zp_attr = pad_op.getInputZpAttr();
+ // pad_const includes the zero point if the tensor uses a zero point.
+ int32_t pad_const_int = input_zp_attr ? input_zp_attr.getInt() : 0;
+ float pad_const_fp = 0.f;
+
+ if (auto tensor = pad_op.getPadConst()) {
+ // Match pad_const tensor as compile-time constant attribute if present.
+ mlir::DenseElementsAttr attr;
+ if (!matchPattern(tensor, m_Constant(&attr)))
+ return nullptr;
+
+ assert(attr.getNumElements() == 1);
+ auto elementTy = attr.getElementType();
+
+ if (elementTy.isa<mlir::IntegerType>()) {
+ pad_const_int = (attr.getValues<mlir::APInt>()[0]).getSExtValue();
+ } else if (elementTy.isa<mlir::FloatType>()) {
+ pad_const_fp = (attr.getValues<mlir::APFloat>()[0]).convertToFloat();
+ } else {
+ op.emitOpError("Unknown const attribute");
+ return nullptr;
+ }
+ }
+
+ std::vector<uint8_t> pad_const;
+ mlir::Type input_element_type =
+ llvm::cast<mlir::ShapedType>(op.getOperand(0).getType()).getElementType();
- std::vector<int> paddings;
- for (int32_t val : paddings_elems.getValues<int32_t>()) {
- paddings.push_back(val);
+ if (GetU8DataFromIntOrFloatValue(pad_const_int, pad_const_fp,
+ input_element_type, pad_const)
+ .failed()) {
+ op.emitOpError("Failed to serialize pad_const value");
+ return nullptr;
}
- TosaPadAttribute attribute(paddings, 0 /* pad_const_int */,
- 0.0f /* pad_const_fp */);
+ TosaPadAttribute attribute(pad_const);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_PAD, Attribute_PadAttribute, &attribute,
- std::vector<std::string>{input_name},
+ std::vector<std::string>{input_name, padding_name},
std::vector<std::string>{output_name});
return tyop;
}
template <>
TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::DimOp>(
+ mlir::Operation &op) const {
+ std::string input_name = GetTensorName(op.getOperand(0));
+ std::string output_name = GetTensorName(op.getResult(0));
+
+ int32_t axis = op.getAttr("axis").dyn_cast<mlir::IntegerAttr>().getInt();
+ TosaAxisAttribute attribute(axis);
+
+ TosaSerializationOperator *tyop =
+ new TosaSerializationOperator(Op_DIM, Attribute_AxisAttribute, &attribute,
+ std::vector<std::string>{input_name},
+ std::vector<std::string>{output_name});
+
+ return tyop;
+}
+
+template <>
+TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::TransposeOp>(
mlir::Operation &op) const {
std::string input_name = GetTensorName(op.getOperand(0));
std::string output_name = GetTensorName(op.getResult(0));
// Match perm tensor as compile-time constant attribute
- // TODO: fix when MLIR dialect changes
mlir::ElementsAttr perm_elems;
if (!matchPattern(op.getOperand(1), m_Constant(&perm_elems)))
return nullptr;
@@ -879,26 +1254,14 @@ template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::SliceOp>(
mlir::Operation &op) const {
- std::vector<int> start, size;
- auto begin_attr = op.getAttr("start").dyn_cast<mlir::ArrayAttr>().getValue();
- auto size_attr = op.getAttr("size").dyn_cast<mlir::ArrayAttr>().getValue();
-
- for (auto &int_attr : begin_attr) {
- start.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
-
- for (auto &int_attr : size_attr) {
- size.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
-
- TosaSliceAttribute attribute(start, size);
-
std::string input_name = GetTensorName(op.getOperand(0));
+ std::string start_name = GetTensorName(op.getOperand(1));
+ std::string size_name = GetTensorName(op.getOperand(2));
std::string output_name = GetTensorName(op.getResult(0));
TosaSerializationOperator *tyop = new TosaSerializationOperator(
- Op_SLICE, Attribute_SliceAttribute, &attribute,
- std::vector<std::string>{input_name},
+ Op_SLICE, Attribute_NONE, nullptr,
+ std::vector<std::string>{input_name, start_name, size_name},
std::vector<std::string>{output_name});
return tyop;
@@ -908,21 +1271,13 @@ template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::TileOp>(
mlir::Operation &op) const {
- std::string input_name = GetTensorName(op.getOperand(0));
+ std::string input0_name = GetTensorName(op.getOperand(0));
+ std::string input1_name = GetTensorName(op.getOperand(1));
std::string output_name = GetTensorName(op.getResult(0));
- std::vector<int> multiples;
- auto multiples_attr =
- op.getAttr("multiples").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &int_attr : multiples_attr) {
- multiples.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
-
- TosaTileAttribute attribute(multiples);
-
TosaSerializationOperator *tyop = new TosaSerializationOperator(
- Op_TILE, Attribute_TileAttribute, &attribute,
- std::vector<std::string>{input_name},
+ Op_TILE, Attribute_NONE, nullptr,
+ std::vector<std::string>{input0_name, input1_name},
std::vector<std::string>{output_name});
return tyop;
@@ -966,60 +1321,21 @@ TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::ResizeOp>(
mlir::Operation &op) const {
std::string input_name = GetTensorName(op.getOperand(0));
+ std::string scale_name = GetTensorName(op.getOperand(1));
+ std::string offset_name = GetTensorName(op.getOperand(2));
+ std::string border_name = GetTensorName(op.getOperand(3));
std::string output_name = GetTensorName(op.getResult(0));
- std::vector<int> output_size;
- auto output_size_attr =
- op.getAttr("output_size").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &int_attr : output_size_attr) {
- output_size.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
- ASSERT_VECTOR_LENGTH(output_size, 2);
-
- std::vector<int> stride;
- auto stride_attr =
- op.getAttr("stride").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &int_attr : stride_attr) {
- stride.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
- ASSERT_VECTOR_LENGTH(stride, 2);
-
- std::vector<int> offset;
- auto offset_attr =
- op.getAttr("offset").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &int_attr : offset_attr) {
- offset.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
- ASSERT_VECTOR_LENGTH(offset, 2);
-
- int32_t shift = op.getAttr("shift").dyn_cast<mlir::IntegerAttr>().getInt();
-
- std::vector<float> stride_fp;
- auto stride_fp_attr =
- op.getAttr("stride_fp").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &fp_attr : stride_fp_attr) {
- stride_fp.push_back(fp_attr.dyn_cast<mlir::FloatAttr>().getValueAsDouble());
- }
- ASSERT_VECTOR_LENGTH(stride_fp, 2);
-
- std::vector<float> offset_fp;
- auto offset_fp_attr =
- op.getAttr("offset_fp").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &fp_attr : offset_fp_attr) {
- offset_fp.push_back(fp_attr.dyn_cast<mlir::FloatAttr>().getValueAsDouble());
- }
- ASSERT_VECTOR_LENGTH(offset_fp, 2);
-
auto mode_str =
op.getAttr("mode").dyn_cast<mlir::StringAttr>().getValue().str();
ResizeMode mode = ResizeModeStr2Enum(mode_str);
- TosaResizeAttribute attribute(output_size, stride, offset, shift, stride_fp,
- offset_fp, mode);
+ TosaResizeAttribute attribute({}, {}, {}, mode);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_RESIZE, Attribute_ResizeAttribute, &attribute,
- std::vector<std::string>{input_name},
+ std::vector<std::string>{input_name, scale_name, offset_name,
+ border_name},
std::vector<std::string>{output_name});
return tyop;
@@ -1048,17 +1364,15 @@ template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::MulOp>(
mlir::Operation &op) const {
- std::string input0_name = GetTensorName(op.getOperand(0));
- std::string input1_name = GetTensorName(op.getOperand(1));
- std::string output_name = GetTensorName(op.getResult(0));
-
- int32_t shift = op.getAttr("shift").dyn_cast<mlir::IntegerAttr>().getInt();
- TosaMulAttribute attribute(shift);
+ mlir::tosa::MulOp mul_op = mlir::cast<mlir::tosa::MulOp>(op);
+ std::string input0_name = GetTensorName(mul_op.getInput1());
+ std::string input1_name = GetTensorName(mul_op.getInput2());
+ std::string output_name = GetTensorName(mul_op.getOutput());
+ std::string shift_name = GetTensorName(mul_op.getShift());
TosaSerializationOperator *tyop = new TosaSerializationOperator(
- Op_MUL, Attribute_MulAttribute, &attribute,
- std::vector<std::string>{input0_name, input1_name},
+ Op_MUL, Attribute_NONE, nullptr, {input0_name, input1_name, shift_name},
std::vector<std::string>{output_name});
return tyop;
@@ -1078,8 +1392,7 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ArithmeticRightShiftOp>(
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_ARITHMETIC_RIGHT_SHIFT, Attribute_ArithmeticRightShiftAttribute,
- &attribute,
- std::vector<std::string>{input0_name, input1_name},
+ &attribute, std::vector<std::string>{input0_name, input1_name},
std::vector<std::string>{output_name});
return tyop;
@@ -1093,7 +1406,6 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::TableOp>(
std::string output_name = GetTensorName(op.getResult(0));
// Match table tensor as compile-time constant attribute
- // TODO: fix when MLIR dialect changes
mlir::ElementsAttr table_elems;
if (!matchPattern(op.getOperand(1), m_Constant(&table_elems)))
return nullptr;
@@ -1127,28 +1439,27 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::RescaleOp>(
bool per_channel =
op.getAttr("per_channel").dyn_cast<mlir::BoolAttr>().getValue();
- std::vector<int> multiplier, shift;
- auto multiplier_attr =
- op.getAttr("multiplier").dyn_cast<mlir::ArrayAttr>().getValue();
- auto shift_attr = op.getAttr("shift").dyn_cast<mlir::ArrayAttr>().getValue();
+ bool input_unsigned =
+ op.getAttr("input_unsigned").dyn_cast<mlir::BoolAttr>().getValue();
+ bool output_unsigned =
+ op.getAttr("output_unsigned").dyn_cast<mlir::BoolAttr>().getValue();
- for (auto &int_attr : multiplier_attr) {
- multiplier.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
+ auto input = op.getOperand(0);
+ auto input_ty = input.getType().cast<mlir::RankedTensorType>();
+ auto output = op.getResult(0);
+ auto output_ty = output.getType().cast<mlir::RankedTensorType>();
- for (auto &int_attr : shift_attr) {
- shift.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
+ std::string input_name = GetTensorName(input);
+ std::string multiplier_name = GetTensorName(op.getOperand(1));
+ std::string shift_name = GetTensorName(op.getOperand(2));
+ std::string output_name = GetTensorName(output);
- std::string input_name = GetTensorName(op.getOperand(0));
- std::string output_name = GetTensorName(op.getResult(0));
-
- TosaRescaleAttribute attribute(input_zp, output_zp, multiplier, shift,
- scale32, double_round, per_channel);
+ TosaRescaleAttribute attribute(input_zp, output_zp, scale32, double_round,
+ per_channel, input_unsigned, output_unsigned);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_RESCALE, Attribute_RescaleAttribute, &attribute,
- std::vector<std::string>{input_name},
+ std::vector<std::string>{input_name, multiplier_name, shift_name},
std::vector<std::string>{output_name});
return tyop;
@@ -1161,39 +1472,77 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::CustomOp>(
std::string input_name = GetTensorName(op.getOperand(0));
std::string output_name = GetTensorName(op.getResult(0));
+ const std::string implementation_attrs = op.getAttr("implementation_attrs")
+ .cast<mlir::StringAttr>()
+ .getValue()
+ .str();
+
+ std::vector<uint8_t> attrs_data(implementation_attrs.size());
+ memcpy(attrs_data.data(), implementation_attrs.data(), attrs_data.size());
+ TosaCustomAttribute attribute(
+ op.getAttr("operator_name").cast<mlir::StringAttr>().getValue().str(),
+ op.getAttr("domain_name").cast<mlir::StringAttr>().getValue().str(),
+ attrs_data);
+
TosaSerializationOperator *tyop = new TosaSerializationOperator(
- Op_CUSTOM, Attribute_NONE, nullptr,
+ Op_CUSTOM, Attribute_CustomAttribute, &attribute,
std::vector<std::string>{input_name},
std::vector<std::string>{output_name});
return tyop;
}
+namespace {
+
+// serialize a region and all its blocks, and return region's return values
+TosaSerializationRegion *
+BuildRegion(mlir::Region &region, const std::string region_name,
+ const bool isolated_from_above,
+ TosaSerializationRegionBuilder *curr_region_builder,
+ TosaSerializationHandler *tsh,
+ std::vector<mlir::Value> &return_values, bool is_top = false) {
+ TosaSerializationRegion *ser_region =
+ new TosaSerializationRegion(region_name, {});
+ assert(ser_region);
+ tsh->GetRegions().push_back(ser_region);
+
+ TosaSerializationRegionBuilder *parent_value_scope =
+ isolated_from_above ? nullptr : curr_region_builder;
+
+ TosaSerializationRegionBuilder region_builder(ser_region, &region,
+ parent_value_scope, tsh);
+ if (region_builder.BuildAllBlocksInRegion(is_top, return_values).failed()) {
+ return nullptr;
+ }
+ return ser_region;
+}
+
+static int input_tensor_index = 0;
+static int intermediate_tensor_index = 0;
+static int output_tensor_index = 0;
+
+} // namespace
+
template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::IfOp>(
mlir::Operation &op) const {
+ const std::string op_name = op.getName().getStringRef().str();
+ const bool isolated_from_above =
+ op.hasTrait<mlir::OpTrait::IsIsolatedFromAbove>();
+ auto curr_region_builder = GetRegionBuilder();
std::vector<std::string> input_names, output_names;
+ std::vector<mlir::Value> then_yields, else_yields;
+ auto tsh = GetTsh();
mlir::Region &then_region = op.getRegion(0);
mlir::Region &else_region = op.getRegion(1);
- std::vector<mlir::Value> then_yields, else_yields;
- TosaSerializationBasicBlock *then_block = nullptr;
- TosaSerializationBasicBlock *else_block = nullptr;
-
- // Building then branch block
- std::string then_block_name =
- "bb" + std::to_string(block_builder->GetTsh()->GetBlocks().size());
- then_block = new TosaSerializationBasicBlock(
- then_block_name, std::vector<TosaSerializationOperator *>(),
- std::vector<TosaSerializationTensor *>(), std::vector<std::string>(),
- std::vector<std::string>());
- assert(then_block);
- block_builder->GetTsh()->GetBlocks().push_back(then_block);
-
- TosaSerializationBlockBuilder then_block_builder(
- then_block, block_builder->GetTsh(), &then_region);
- if (then_block_builder.BuildAllOpsInRegion(then_yields).failed()) {
+
+ const std::string then_region_name = op_name + "_then_region";
+ TosaSerializationRegion *ser_then_region =
+ BuildRegion(then_region, then_region_name, isolated_from_above,
+ curr_region_builder, tsh, then_yields);
+ if (!ser_then_region) {
return nullptr;
}
if (then_yields.size() != op.getNumResults()) {
@@ -1202,19 +1551,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::IfOp>(
return nullptr;
}
- // Building else branch block
- std::string else_block_name =
- "bb" + std::to_string(block_builder->GetTsh()->GetBlocks().size());
- else_block = new TosaSerializationBasicBlock(
- else_block_name, std::vector<TosaSerializationOperator *>(),
- std::vector<TosaSerializationTensor *>(), std::vector<std::string>(),
- std::vector<std::string>());
- assert(else_block);
- block_builder->GetTsh()->GetBlocks().push_back(else_block);
-
- TosaSerializationBlockBuilder else_block_builder(
- else_block, block_builder->GetTsh(), &else_region);
- if (else_block_builder.BuildAllOpsInRegion(else_yields).failed()) {
+ const std::string else_region_name = op_name + "_else_region";
+ TosaSerializationRegion *ser_else_region =
+ BuildRegion(else_region, else_region_name, isolated_from_above,
+ curr_region_builder, tsh, else_yields);
+ if (!ser_else_region) {
return nullptr;
}
if (else_yields.size() != op.getNumResults()) {
@@ -1223,7 +1564,7 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::IfOp>(
return nullptr;
}
- TosaCondIfAttribute attribute(then_block->GetName(), else_block->GetName());
+ TosaCondIfAttribute attribute(then_region_name, else_region_name);
for (size_t i = 0; i < op.getNumOperands(); i++) {
std::string input_name = GetTensorName(op.getOperand(i));
@@ -1235,9 +1576,9 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::IfOp>(
output_names.push_back(output_name);
}
- TosaSerializationOperator *tyop = new TosaSerializationOperator(
- Op_COND_IF, Attribute_CondIfAttribute, &attribute,
- input_names, output_names);
+ TosaSerializationOperator *tyop =
+ new TosaSerializationOperator(Op_COND_IF, Attribute_CondIfAttribute,
+ &attribute, input_names, output_names);
return tyop;
}
@@ -1246,27 +1587,22 @@ template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::WhileOp>(
mlir::Operation &op) const {
+ const std::string op_name = op.getName().getStringRef().str();
+ const bool isolated_from_above =
+ op.hasTrait<mlir::OpTrait::IsIsolatedFromAbove>();
+ auto curr_region_builder = GetRegionBuilder();
std::vector<std::string> input_names, output_names;
+ auto tsh = GetTsh();
mlir::Region &cond_region = op.getRegion(0);
mlir::Region &body_region = op.getRegion(1);
std::vector<mlir::Value> cond_yields, body_yields;
- TosaSerializationBasicBlock *cond_block = nullptr;
- TosaSerializationBasicBlock *body_block = nullptr;
-
- // Building cond branch block
- std::string cond_block_name =
- "bb" + std::to_string(block_builder->GetTsh()->GetBlocks().size());
- cond_block = new TosaSerializationBasicBlock(
- cond_block_name, std::vector<TosaSerializationOperator *>(),
- std::vector<TosaSerializationTensor *>(), std::vector<std::string>(),
- std::vector<std::string>());
- assert(cond_block);
- block_builder->GetTsh()->GetBlocks().push_back(cond_block);
-
- TosaSerializationBlockBuilder cond_block_builder(
- cond_block, block_builder->GetTsh(), &cond_region);
- if (cond_block_builder.BuildAllOpsInRegion(cond_yields).failed()) {
+
+ const std::string cond_region_name = op_name + "_cond_region";
+ TosaSerializationRegion *ser_cond_region =
+ BuildRegion(cond_region, cond_region_name, isolated_from_above,
+ curr_region_builder, tsh, cond_yields);
+ if (!ser_cond_region) {
return nullptr;
}
if (cond_yields.size() != 1) {
@@ -1274,19 +1610,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::WhileOp>(
return nullptr;
}
- // Building body branch block
- std::string body_block_name =
- "bb" + std::to_string(block_builder->GetTsh()->GetBlocks().size());
- body_block = new TosaSerializationBasicBlock(
- body_block_name, std::vector<TosaSerializationOperator *>(),
- std::vector<TosaSerializationTensor *>(), std::vector<std::string>(),
- std::vector<std::string>());
- assert(body_block);
- block_builder->GetTsh()->GetBlocks().push_back(body_block);
-
- TosaSerializationBlockBuilder body_block_builder(
- body_block, block_builder->GetTsh(), &body_region);
- if (body_block_builder.BuildAllOpsInRegion(body_yields).failed()) {
+ const std::string body_region_name = op_name + "_body_region";
+ TosaSerializationRegion *ser_body_region =
+ BuildRegion(body_region, body_region_name, isolated_from_above,
+ curr_region_builder, tsh, body_yields);
+ if (!ser_body_region) {
return nullptr;
}
if (body_yields.size() != op.getNumResults()) {
@@ -1295,8 +1623,7 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::WhileOp>(
return nullptr;
}
- TosaWhileLoopAttribute attribute(cond_block->GetName(),
- body_block->GetName());
+ TosaWhileLoopAttribute attribute(cond_region_name, body_region_name);
for (size_t i = 0; i < op.getNumOperands(); i++) {
std::string input_name = GetTensorName(op.getOperand(i));
@@ -1308,135 +1635,286 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::WhileOp>(
output_names.push_back(output_name);
}
+ TosaSerializationOperator *tyop =
+ new TosaSerializationOperator(Op_WHILE_LOOP, Attribute_WhileLoopAttribute,
+ &attribute, input_names, output_names);
+
+ return tyop;
+}
+
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::RFFT2dOp>(
+ mlir::Operation &op) const {
+ std::string input_name = GetTensorName(op.getOperand(0));
+ std::string output_real_name = GetTensorName(op.getResult(0));
+ std::string output_imag_name = GetTensorName(op.getResult(1));
+
+ bool local_bound =
+ op.hasAttr("local_bound")
+ ? op.getAttr("local_bound").dyn_cast<mlir::BoolAttr>().getValue()
+ : false;
+
+ TosaRFFTAttribute attribute(local_bound);
+
+ TosaSerializationOperator *tyop = new TosaSerializationOperator(
+ Op_RFFT2D, Attribute_RFFTAttribute, &attribute,
+ std::vector<std::string>{input_name},
+ std::vector<std::string>{output_real_name, output_imag_name});
+
+ return tyop;
+}
+
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::FFT2dOp>(
+ mlir::Operation &op) const {
+
+ bool inverse = op.getAttr("inverse").dyn_cast<mlir::BoolAttr>().getValue();
+
+ bool local_bound =
+ op.hasAttr("local_bound")
+ ? op.getAttr("local_bound").dyn_cast<mlir::BoolAttr>().getValue()
+ : false;
+
+ std::string input_real_name = GetTensorName(op.getOperand(0));
+ std::string input_imag_name = GetTensorName(op.getOperand(1));
+ std::string output_real_name = GetTensorName(op.getResult(0));
+ std::string output_imag_name = GetTensorName(op.getResult(1));
+
+ TosaFFTAttribute attribute(inverse, local_bound);
+
TosaSerializationOperator *tyop = new TosaSerializationOperator(
- Op_WHILE_LOOP, Attribute_WhileLoopAttribute, &attribute,
- input_names, output_names);
+ Op_FFT2D, Attribute_FFTAttribute, &attribute,
+ std::vector<std::string>{input_real_name, input_imag_name},
+ std::vector<std::string>{output_real_name, output_imag_name});
+
+ return tyop;
+}
+
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::VariableReadOp>(
+ mlir::Operation &op) const {
+
+ std::string input_name = GetVariableTensorName(&op);
+ std::string output_name = GetTensorName(op.getResult(0));
+
+ TosaSerializationOperator *tyop =
+ new TosaSerializationOperator(Op_IDENTITY, Attribute_NONE, nullptr,
+ std::vector<std::string>{input_name},
+ std::vector<std::string>{output_name});
+
+ return tyop;
+}
+
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::VariableWriteOp>(
+ mlir::Operation &op) const {
+
+ std::string input_name = GetTensorName(op.getOperand(0));
+ std::string output_name = GetVariableTensorName(&op);
+
+ TosaSerializationOperator *tyop =
+ new TosaSerializationOperator(Op_IDENTITY, Attribute_NONE, nullptr,
+ std::vector<std::string>{input_name},
+ std::vector<std::string>{output_name});
return tyop;
}
/* End translating TOSA operator */
+mlir::LogicalResult TosaSerializationRegionBuilder::BuildAllBlocksInRegion(
+ bool is_top, std::vector<mlir::Value> &return_values) {
+ std::string region_name = ser_region->GetName();
+ int block_index = 0;
+ for (auto &block : this->region->getBlocks()) {
+ // must name first block of top region "main"
+ const std::string block_name =
+ (is_top && block_index == 0)
+ ? "main"
+ : (region_name + "_bb" + std::to_string(block_index++));
+ TosaSerializationBasicBlock *ser_block = new TosaSerializationBasicBlock(
+ block_name, region_name, std::vector<TosaSerializationOperator *>(),
+ std::vector<TosaSerializationTensor *>(), std::vector<std::string>(),
+ std::vector<std::string>());
+
+ // build the block
+ TosaSerializationBlockBuilder block_builder(ser_block, this, &block);
+ // Region Builders need access to block builders
+ block_builders.push_back(&block_builder);
+
+ if (block_builder.BuildAllOpsInBlock(return_values).failed()) {
+ return mlir::failure();
+ }
+
+ if (return_values.empty()) {
+ llvm::errs() << "BWarning: graph doesn't have return values\n";
+ }
-mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInRegion(
+ // Add serialized block to serialized region
+ ser_region->GetBlocks().push_back(ser_block);
+ }
+
+ return mlir::success();
+}
+
+mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock(
std::vector<mlir::Value> &return_values) {
TosaSerializationOperator *ser_operator = nullptr;
TosaSerializationTensor *ser_tensor = nullptr;
size_t num_blocks_in_region = 0;
- static int input_tensor_index = 0;
- static int intermediate_tensor_index = 0;
- static int output_tensor_index = 0;
TosaSerializationOperatorBuilder op_builder(this);
- for (auto &bb : region->getBlocks()) {
- num_blocks_in_region++;
-
- if (num_blocks_in_region > 1) {
- llvm::errs() << "Invalid MLIR: multiple blocks in a region\n";
- return mlir::failure();
- }
-
- // We always have one block for each region right now
- assert(bb.isEntryBlock());
-
- // Specify block input tensor name
- for (auto args : bb.getArguments()) {
- std::string block_input_name =
- "TosaInput_" + std::to_string(input_tensor_index++);
- block->GetInputs().push_back(block_input_name);
- tensor_map[args] = block_input_name;
- input_tensor_map[args] = block_input_name;
- }
+ // Specify block input tensor name
+ for (auto args : block->getArguments()) {
+ std::string block_input_name =
+ "TosaInput_" + std::to_string(input_tensor_index++);
+ ser_block->GetInputs().push_back(block_input_name);
+ tensor_map[args] = block_input_name;
+ input_tensor_map[args] = block_input_name;
+ }
+
+ // Build tensor_map
+ for (auto &op : block->getOperations()) {
+ if (llvm::isa<mlir::tosa::VariableOp>(op)) {
+ RegisterVariableOp(op);
+ } else if (!(llvm::isa<mlir::tosa::YieldOp>(op) ||
+ llvm::isa<mlir::func::ReturnOp>(op) ||
+ llvm::isa<mlir::tensor::CastOp>(op))) {
+ for (uint32_t i = 0; i < op.getNumResults(); i++) {
+ std::string intermediate_tensor_name =
+ "layer_" + std::to_string(intermediate_tensor_index++);
+ tensor_map[op.getResult(i)] = intermediate_tensor_name;
+ }
+ } else {
+ if (llvm::isa<mlir::tensor::CastOp>(op))
+ continue;
+ // Override return tensor name
+ for (auto val : op.getOperands()) {
+ // Workaround to skip mlir::tensor::CastOp before return
+ mlir::Operation *val_defining_op = val.getDefiningOp();
+ if (val_defining_op) {
+ if (llvm::isa<mlir::tensor::CastOp>(*val_defining_op))
+ val = val_defining_op->getOperand(0);
+ }
- // Build tensor_map
- for (auto &op : bb) {
- if (!(llvm::isa<mlir::tosa::YieldOp>(op) ||
- llvm::isa<mlir::func::ReturnOp>(op) ||
- llvm::isa<mlir::tensor::CastOp>(op))) {
- for (uint32_t i = 0; i < op.getNumResults(); i++) {
- std::string intermediate_tensor_name =
- "layer_" + std::to_string(intermediate_tensor_index++);
- tensor_map[op.getResult(i)] = intermediate_tensor_name;
+ // Sanity check. This mlir::Value should be built in map since graph
+ // is DAG
+ if (tensor_map.find(val) == tensor_map.end()) {
+ llvm::errs() << "ERROR: Can't find built mlir::Value key.\n";
+ return mlir::failure();
}
- } else {
- if (llvm::isa<mlir::tensor::CastOp>(op))
- continue;
- // Override return tensor name
- for (auto val : op.getOperands()) {
- // Workaround to skip mlir::tensor::CastOp before return
- mlir::Operation *val_defining_op = val.getDefiningOp();
- if (val_defining_op) {
- if (llvm::isa<mlir::tensor::CastOp>(*val_defining_op))
- val = val_defining_op->getOperand(0);
- }
-
- // Sanity check. This mlir::Value should be built in map since graph
- // is DAG
- if (tensor_map.find(val) == tensor_map.end()) {
- llvm::errs() << "ERROR: Can't find built mlir::Value key.\n";
- return mlir::failure();
- }
-
- // If returned value is block input, short-circuit the tensor name
- // Otherwise, build a new output name and override the origin tensor
- // name
- if (input_tensor_map.find(val) != input_tensor_map.end()) {
- block->GetOutputs().push_back(input_tensor_map[val]);
- return_values.push_back(val);
- } else {
- std::string output_name =
- "TosaOutput_" + std::to_string(output_tensor_index++);
- tensor_map[val] = output_name;
- block->GetOutputs().push_back(output_name);
- return_values.push_back(val);
- }
+
+ // If returned value is block input, short-circuit the tensor name
+ // Otherwise, build a new output name and override the origin tensor
+ // name
+ if (input_tensor_map.find(val) != input_tensor_map.end()) {
+ ser_block->GetOutputs().push_back(input_tensor_map[val]);
+ return_values.push_back(val);
+ } else {
+ std::string output_name =
+ "TosaOutput_" + std::to_string(output_tensor_index++);
+ tensor_map[val] = output_name;
+ ser_block->GetOutputs().push_back(output_name);
+ return_values.push_back(val);
}
}
}
+ }
- // Build tensor
+ // Build variable tensor
+ for (auto pair : variable_tensor_op_map) {
+ mlir::Operation *op = pair.second;
+ mlir::Value val = op->getResult(0);
+ mlir::RankedTensorType tensor_type = op->getAttr("type")
+ .cast<mlir::TypeAttr>()
+ .getValue()
+ .cast<mlir::RankedTensorType>();
- // The tensor_map is sorted by hashed mlir::Value types.
- // For serialization, sort tensors alphabetically by name for a
- // deterministic and human-friendly ordering.
- std::map<std::string, mlir::Value> tensor_name_sort;
- for (auto pair : tensor_map)
- tensor_name_sort[pair.second] = pair.first;
+ std::string variable_mlir_name =
+ op->getAttr("name").cast<mlir::StringAttr>().getValue().str();
- for (auto pair : tensor_name_sort) {
- ser_tensor = BuildTosaSerializationTensor(pair.second /* val */,
- pair.first /* name */);
- if (!ser_tensor) {
- llvm::errs() << "ERROR: Failed to build TosaSerializationTensor\n";
- return mlir::failure();
- }
- block->GetTensors().push_back(ser_tensor);
+ ser_tensor = BuildTosaSerializationVariableTensor(
+ tensor_type /* tensor_type */, pair.first /* flatbuffer name */,
+ variable_mlir_name);
+ if (!ser_tensor) {
+ llvm::errs() << "ERROR: Failed to build TosaSerializationTensor\n";
+ return mlir::failure();
}
-
- // Build operator
- for (auto &op : bb) {
- if (llvm::isa<mlir::tosa::YieldOp>(op) ||
- llvm::isa<mlir::func::ReturnOp>(op) ||
- llvm::isa<mlir::tensor::CastOp>(op))
- continue;
- ser_operator = BuildTosaSerializationOperator(op_builder, op);
- if (!ser_operator) {
- llvm::errs() << "ERROR: Failed to build TosaSerializationOperator\n";
+ // Initialize if "initial_value" attribute exists. If not, set data to all
+ // zeros
+ mlir::Attribute initial_value = op->getAttr("initial_value");
+ std::vector<uint8_t> u8_data;
+ if (initial_value) {
+ if (initial_value.isa<mlir::DenseElementsAttr>()) {
+ if (op_builder
+ .GetDataFromAttribute(*op, initial_value,
+ tensor_type.getElementType(), u8_data)
+ .failed()) {
+ llvm::errs() << "ERROR: GetDataFromAttribute() fails when building "
+ "initial_value of variable tensor\n";
+ return mlir::failure();
+ }
+ } else {
+ llvm::errs() << "ERROR: Unknown initial_value attribute type\n";
return mlir::failure();
}
- block->GetOperators().push_back(ser_operator);
+ } else {
+ TosaSerializationHandler::ForceAlignTensorData(u8_data);
}
+ ser_tensor->SetData(u8_data);
+ ser_block->GetTensors().push_back(ser_tensor);
+ }
+
+ // Build tensor
+
+ // The tensor_map is sorted by hashed mlir::Value types.
+ // For serialization, sort tensors alphabetically by name for a
+ // deterministic and human-friendly ordering.
+ std::map<std::string, mlir::Value> tensor_name_sort;
+ for (auto pair : tensor_map)
+ tensor_name_sort[pair.second] = pair.first;
+
+ for (auto pair : tensor_name_sort) {
+ mlir::RankedTensorType tensor_type =
+ pair.second.getType().cast<mlir::RankedTensorType>();
+ ser_tensor = BuildTosaSerializationTensor(pair.second /* val */,
+ pair.first /* name */);
+ if (!ser_tensor) {
+ llvm::errs() << "ERROR: Failed to build TosaSerializationTensor\n";
+ return mlir::failure();
+ }
+ ser_block->GetTensors().push_back(ser_tensor);
+ }
+
+ // Build operator
+ for (auto &op : block->getOperations()) {
+ if (llvm::isa<mlir::tosa::YieldOp>(op) ||
+ llvm::isa<mlir::func::ReturnOp>(op) ||
+ llvm::isa<mlir::tensor::CastOp>(op) ||
+ llvm::isa<mlir::tosa::VariableOp>(op))
+ continue;
+ ser_operator = BuildTosaSerializationOperator(op_builder, op);
+ if (!ser_operator) {
+ llvm::errs() << "ERROR: Failed to build TosaSerializationOperator\n";
+ return mlir::failure();
+ }
+ ser_block->GetOperators().push_back(ser_operator);
}
-
return mlir::success();
}
TosaSerializationOperator *
TosaSerializationBlockBuilder::BuildTosaSerializationOperator(
const TosaSerializationOperatorBuilder &op_builder, mlir::Operation &op) {
- std::string full_op_name = op.getName().getStringRef().str();
TosaSerializationOperator *target_operator = nullptr;
- if (false) {
+ if (llvm::isa<mlir::tosa::VariableReadOp>(op)) {
+ target_operator = op_builder.build<mlir::tosa::VariableReadOp>(op);
+ } else if (llvm::isa<mlir::tosa::VariableWriteOp>(op)) {
+ target_operator = op_builder.build<mlir::tosa::VariableWriteOp>(op);
}
#define DEF_OPERATOR(MLIR_OP) \
else if (llvm::isa<mlir::tosa::MLIR_OP##Op>(op)) { \
@@ -1455,17 +1933,22 @@ TosaSerializationBlockBuilder::BuildTosaSerializationOperator(
return nullptr;
}
+ if (llvm::isa<mlir::tosa::VariableReadOp>(op) ||
+ llvm::isa<mlir::tosa::VariableWriteOp>(op)) {
+ return target_operator;
+ }
+
// Sanity check the number of inputs/outputs of TOSA dialect matches the
// number of TOSA flatbuffer
if (op.getNumOperands() != target_operator->GetInputTensorNames().size()) {
- llvm::errs() << "WARNING. MLIR operator has " << op.getNumOperands()
+ llvm::errs() << "WARNING: MLIR operator has " << op.getNumOperands()
<< " input tensors != Flatbuffer "
"operator has "
<< target_operator->GetInputTensorNames().size()
<< " input tensors\n";
}
if (op.getNumResults() != target_operator->GetOutputTensorNames().size()) {
- llvm::errs() << "WARNING. MLIR operator has " << op.getNumResults()
+ llvm::errs() << "WARNING: MLIR operator has " << op.getNumResults()
<< " output tensors != Flatbuffer "
"operator has "
<< target_operator->GetOutputTensorNames().size()
@@ -1476,30 +1959,83 @@ TosaSerializationBlockBuilder::BuildTosaSerializationOperator(
}
TosaSerializationTensor *
+TosaSerializationBlockBuilder::BuildTosaSerializationVariableTensor(
+ mlir::RankedTensorType tensor_type, const std::string &name,
+ const std::string &variable_mlir_name) {
+ // If tensor already created before, use that tensor directly, create a new
+ // one otherwise
+ TosaSerializationTensor *ts = ser_block->GetTensorByName(name);
+ if (ts) {
+ return nullptr;
+ }
+
+ std::vector<int32_t> shape(tensor_type.getShape().begin(),
+ tensor_type.getShape().end());
+
+ DType type = Type2DType(tensor_type.getElementType());
+
+ ts = new TosaSerializationTensor(name, shape, type, std::vector<uint8_t>(),
+ /* is_variable = */ true,
+ /* is_unranked = */ false,
+ variable_mlir_name);
+
+ return ts;
+}
+
+TosaSerializationTensor *
TosaSerializationBlockBuilder::BuildTosaSerializationTensor(
mlir::Value val, const std::string &name) {
// If tensor already created before, use that tensor directly, create a new
// one otherwise
- TosaSerializationTensor *ts = block->GetTensorByName(name);
+ TosaSerializationTensor *ts = ser_block->GetTensorByName(name);
if (ts) {
return nullptr;
}
- mlir::RankedTensorType tensor =
- val.getType().dyn_cast<mlir::RankedTensorType>();
- std::vector<int32_t> shape(tensor.getShape().begin(),
- tensor.getShape().end());
- DType type = Type2DType(tensor.getElementType());
+ // handling of tosa.shape values
+ if (auto shape_ty = val.getType().dyn_cast<mlir::tosa::shapeType>()) {
+ auto rank = shape_ty.getRank();
+ std::vector<int32_t> shape;
+ if (rank > 0) {
+ shape.push_back(rank);
+ }
+ ts = new TosaSerializationTensor(name,
+ /* shape = */ shape,
+ /* type = */ DType::DType_SHAPE,
+ /* data = */ std::vector<uint8_t>());
+ return ts;
+ }
- ts = new TosaSerializationTensor(name, shape, type, std::vector<uint8_t>());
+ auto ttype = val.getType().dyn_cast<mlir::TensorType>();
+ if (!ttype) {
+ llvm::errs() << "TOSA serialization, supplied value is not of TensorType\n";
+ return nullptr;
+ }
+
+ const bool is_unranked = !ttype.hasRank();
+ std::vector<int32_t> shape;
+ if (!is_unranked) {
+ auto shaped = val.getType().dyn_cast<mlir::ShapedType>();
+ assert(shaped);
+ for (int idx = 0; idx < ttype.getRank(); idx++) {
+ if (shaped.isDynamicDim(idx)) {
+ shape.push_back(0); // size of 0 represents dynamic dimension
+ } else {
+ auto dim = shaped.getDimSize(idx);
+ shape.push_back(dim);
+ }
+ }
+ }
+
+ DType type = Type2DType(ttype.getElementType());
+ ts = new TosaSerializationTensor(name, shape, type, std::vector<uint8_t>(),
+ /* variable = */ false, is_unranked);
return ts;
}
mlir::LogicalResult translate2FlatBuffer(mlir::func::FuncOp &func,
TosaSerializationHandler &tsh) {
- TosaSerializationBasicBlock *main_block;
-
mlir::Region *main_region = func.getCallableRegion();
std::vector<mlir::Value> main_returns;
@@ -1508,21 +2044,22 @@ mlir::LogicalResult translate2FlatBuffer(mlir::func::FuncOp &func,
return mlir::failure();
}
- if (!tsh.GetBlocks().empty()) {
- llvm::errs() << "Internal Error: TosaSerializationHandler's block list "
+ if (!tsh.GetRegions().empty()) {
+ llvm::errs() << "Internal Error: TosaSerializationHandler's region list "
"must be empty\n";
return mlir::failure();
}
- main_block = new TosaSerializationBasicBlock(
- std::string("main"), std::vector<TosaSerializationOperator *>(),
- std::vector<TosaSerializationTensor *>(), std::vector<std::string>(),
- std::vector<std::string>());
- assert(main_block);
- tsh.GetBlocks().push_back(main_block);
+ // reset static counters
+ input_tensor_index = 0;
+ intermediate_tensor_index = 0;
+ output_tensor_index = 0;
- TosaSerializationBlockBuilder block_builder(main_block, &tsh, main_region);
- if (block_builder.BuildAllOpsInRegion(main_returns).failed()) {
+ TosaSerializationRegion *ser_main_region =
+ BuildRegion(*main_region, "main", /* isolated_from_above = */ true,
+ /* parent_value_scope = */ nullptr, &tsh, main_returns,
+ /* is_top = */ true);
+ if (!ser_main_region) {
return mlir::failure();
}
@@ -1580,20 +2117,42 @@ mlir::LogicalResult dumpTosaJSON(mlir::func::FuncOp &func) {
return mlir::success();
}
-namespace mlir {
+#define GEN_PASS_DEF_TOSASERIALIZATIONPASS
+namespace mlir {
namespace tosa {
-
namespace {
class TosaSerialize : public TosaSerializationPassBase<TosaSerialize> {
public:
void runOnOperation() final {
- auto function = getOperation();
+ auto moduleOp = getOperation();
- if (dumpTosaFlatbuffer(function).failed()) {
- llvm::errs() << "Failed to generate TOSA flatbuffer...\n";
- return signalPassFailure();
+ // iterate through each op in the moduleOp, call dumpTosaFlatbuffer if
+ // that's a func.funcOp
+
+ auto regions = moduleOp->getRegions();
+ auto region_size = regions.size();
+
+ auto region_0 = regions.begin();
+ auto block_size = region_0->getBlocks().size();
+
+ auto block_0 = region_0->getBlocks().begin();
+
+ auto op_size = block_0->getOperations().size();
+
+ for (auto it = block_0->getOperations().begin();
+ it != block_0->getOperations().end(); ++it) {
+ // read variableOps that are declared outside of functionOps
+ if (llvm::isa<mlir::tosa::VariableOp>(*it)) {
+ RegisterVariableOp(*it);
+ } else if (llvm::isa<mlir::func::FuncOp>(*it)) {
+ auto funcOp = dyn_cast<mlir::func::FuncOp>((*it));
+ if (dumpTosaFlatbuffer(funcOp).failed()) {
+ llvm::errs() << "Failed to generate TOSA flatbuffer...\n";
+ return signalPassFailure();
+ }
+ }
}
}
};
@@ -1614,7 +2173,7 @@ public:
} // anonymous namespace
// Creates an instance of the TOSA flatbuffer generation pass
-std::unique_ptr<Pass> createTosaSerializePass() {
+std::unique_ptr<OperationPass<ModuleOp>> createTosaSerializePass() {
return std::make_unique<TosaSerialize>();
}