aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Tai Ly <tai.ly@arm.com> 2023-09-22 19:30:11 +0000
committer Tai Ly <tai.ly@arm.com> 2023-10-16 17:11:30 +0000
commit 28026df342e035d71907c5e6cf1ee29f3974afea (patch)
tree a29c6bb28380ead925405dbbeaa61f1d1bb74260
parent a7d41ccfc92f0a00ee6b4e96f217eb8f27956b00 (diff)
download tosa_mlir_translator-28026df342e035d71907c5e6cf1ee29f3974afea.tar.gz
Implement deserialization of stateful ops
This patch implements deserialization of the variable, variable.read and variable.write ops. The variable ops are deserialized before the function.

Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: I029bdf087576e97cacab469386863e6d1baf855c
-rw-r--r-- src/TosaDeserialize.cpp 346
-rw-r--r-- src/TosaSerialize.cpp 17
m--------- third_party/serialization_lib 0
3 files changed, 259 insertions, 104 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index 79f0c78..031c57f 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -54,6 +54,51 @@ const std::string kMainFunctionName = "main";
namespace {
+// a global map from flatbuffer variable names to serialized tensors
+std::unordered_map<std::string, TosaSerializationTensor *> variable_tensor_map;
+
+void RegisterVariableTensor(TosaSerializationTensor *ts) {
+ assert(ts->GetVariable());
+ // insert variable tensor ts only if not already present
+ variable_tensor_map.insert({ts->GetName(), ts});
+}
+
+bool IsVariableTensor(const std::string flatbuffer_tensor_name) {
+ return variable_tensor_map.count(flatbuffer_tensor_name);
+}
+
+// return the variable name of variable tensor ts (its tensor name if unset)
+const std::string GetVariableTensorName(TosaSerializationTensor *ts) {
+ assert(ts->GetVariable());
+ const auto name = ts->GetVariableName();
+ if (name == "") {
+ // for legacy flatbuffers which may not have variable_name fields
+ return ts->GetName();
+ }
+ return name;
+}
+
+// return the variable name corresponding to flatbuffer_tensor_name
+const std::string
+GetVariableTensorName(const std::string flatbuffer_tensor_name) {
+ if (!IsVariableTensor(flatbuffer_tensor_name)) {
+ llvm::errs() << "ERROR: Variable tensor " << flatbuffer_tensor_name
+ << " is not found in variable_tensor_map";
+ return "";
+ }
+ return GetVariableTensorName(variable_tensor_map[flatbuffer_tensor_name]);
+}
+
+bool IsVariableReadOp(TosaSerializationOperator *op) {
+ return (op->GetOp() == tosa::Op::Op_IDENTITY) &&
+ IsVariableTensor(op->GetInputTensorNames()[0]);
+}
+
+bool IsVariableWriteOp(TosaSerializationOperator *op) {
+ return (op->GetOp() == tosa::Op::Op_IDENTITY) &&
+ IsVariableTensor(op->GetOutputTensorNames()[0]);
+}
+
// construct tensor type from dtype and shape of TosaSerializationTensor
mlir::LogicalResult BuildTensorType(mlir::OpBuilder *op_builder,
TosaSerializationTensor *ts,
@@ -108,6 +153,107 @@ mlir::LogicalResult BuildTensorType(mlir::OpBuilder *op_builder,
return mlir::success();
}
+mlir::DenseElementsAttr
+ConstructConstAttr(const mlir::RankedTensorType &output_type,
+ TosaSerializationTensor *ts, const std::string &op_name) {
+ const auto &data = ts->GetData();
+ auto &shape = ts->GetShape();
+ // compute output data size
+ uint32_t out_size = 1;
+ for (const auto dim : shape) {
+ out_size *= dim;
+ }
+ mlir::DenseElementsAttr value_attr;
+ switch (ts->GetDtype()) {
+ case DType_FP32: {
+ std::vector<float> float_data;
+ TosaSerializationHandler::ConvertU8toF32(data, out_size, float_data);
+ value_attr =
+ mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(float_data));
+ break;
+ }
+ case DType_INT4: {
+ std::vector<int8_t> int4_data;
+ TosaSerializationHandler::ConvertU8toI4(data, out_size, int4_data);
+ value_attr =
+ mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int4_data));
+ break;
+ }
+ case DType_INT8: {
+ std::vector<int8_t> int8_data;
+ TosaSerializationHandler::ConvertU8toI8(data, out_size, int8_data);
+ value_attr =
+ mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int8_data));
+ break;
+ }
+ case DType_INT16: {
+ std::vector<int16_t> int16_data;
+ TosaSerializationHandler::ConvertU8toI16(data, out_size, int16_data);
+ value_attr =
+ mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int16_data));
+ break;
+ }
+ case DType_INT32: {
+ std::vector<int32_t> int32_data;
+ TosaSerializationHandler::ConvertU8toI32(data, out_size, int32_data);
+ value_attr =
+ mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int32_data));
+ break;
+ }
+ case DType_INT48: {
+ std::vector<int64_t> int48_data;
+ TosaSerializationHandler::ConvertU8toI48(data, out_size, int48_data);
+ std::vector<mlir::APInt> apint_data;
+ for (const auto v : int48_data) {
+ mlir::APInt apint_value(48, static_cast<uint64_t>(v),
+ /* isSigned = */ false);
+ apint_data.push_back(apint_value);
+ }
+ value_attr =
+ mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(apint_data));
+ break;
+ }
+ case DType_BOOL: {
+ std::vector<bool> bool_data;
+ TosaSerializationHandler::ConvertU8toBool(data, out_size, bool_data);
+ llvm::SmallVector<bool> bool_values(bool_data.begin(), bool_data.end());
+ value_attr = mlir::DenseElementsAttr::get(output_type, bool_values);
+ break;
+ }
+ default: {
+ llvm::errs() << "ERROR: " << op_name
+ << " contains unsupported element type\n";
+ return nullptr;
+ }
+ }
+
+ return value_attr;
+}
+
+mlir::LogicalResult ConstructVariableOps(mlir::ModuleOp &module) {
+ if (variable_tensor_map.empty()) {
+ return mlir::success();
+ }
+ auto loc = module.getLoc();
+ auto op_builder = mlir::OpBuilder(module.getBodyRegion());
+ for (auto [flatbuffer_name, ts] : variable_tensor_map) {
+ auto name = GetVariableTensorName(ts);
+ mlir::RankedTensorType type;
+ if (BuildTensorType(&op_builder, ts, type).failed()) {
+ return mlir::failure();
+ }
+
+ mlir::Attribute value_attr = nullptr;
+ if (!ts->GetData().empty()) {
+ value_attr = ConstructConstAttr(type, ts, name);
+ }
+ op_builder.create<mlir::tosa::VariableOp>(loc, llvm::StringRef(name), type,
+ value_attr);
+ }
+
+ return mlir::success();
+}
+
template <class T>
mlir::DenseElementsAttr BuildDenseI8ElementsAttr(mlir::OpBuilder *op_builder,
const std::vector<T> &values) {
@@ -221,6 +367,13 @@ public:
template <Op OPCODE>
std::vector<mlir::Value> build(TosaSerializationOperator *op) const;
+ std::vector<mlir::Value> BuildVariableOp(TosaSerializationOperator *op) const;
+
+ std::vector<mlir::Value>
+ BuildVariableReadOp(TosaSerializationOperator *op) const;
+
+ void BuildVariableWriteOp(TosaSerializationOperator *op) const;
+
std::string get_string(TosaSerializationOperator *op) const {
std::string op_string;
op_string += "operator opcode=";
@@ -334,6 +487,37 @@ std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_AVG_POOL2D>(
return std::vector<mlir::Value>({mlir_op->getResult(0)});
}
+std::vector<mlir::Value> TosaMlirOperatorBuilder::BuildVariableReadOp(
+ TosaSerializationOperator *op) const {
+ auto input_tensor_name = op->GetInputTensorNames()[0];
+ auto output_tensor_name = op->GetOutputTensorNames()[0];
+
+ assert(IsVariableTensor(input_tensor_name));
+
+ auto variable_name = GetVariableTensorName(input_tensor_name);
+ mlir::RankedTensorType output_type = tensor_type_map->at(output_tensor_name);
+ assert(op->GetAttributeType() ==
+ Attribute_NONE); // double check that there is no attribute
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::VariableReadOp>(
+ loc, output_type, llvm::StringRef(variable_name));
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+void TosaMlirOperatorBuilder::BuildVariableWriteOp(
+ TosaSerializationOperator *op) const {
+ auto input_tensor_name = op->GetInputTensorNames()[0];
+ auto output_tensor_name = op->GetOutputTensorNames()[0];
+
+ assert(IsVariableTensor(output_tensor_name));
+ auto variable_name = GetVariableTensorName(output_tensor_name);
+ mlir::Value input_val = tensor_map->at(input_tensor_name);
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::VariableWriteOp>(
+ loc, llvm::StringRef(variable_name), input_val);
+ block->push_back(mlir_op);
+}
+
template <class MLIR_OP>
std::vector<mlir::Value> TosaMlirOperatorBuilder::BuildEwiseUnaryOp(
TosaSerializationOperator *op) const {
@@ -460,73 +644,8 @@ TosaMlirOperatorBuilder::build<Op_CONST>(TosaSerializationOperator *op) const {
const auto &output_name = op->GetOutputTensorNames()[0];
mlir::RankedTensorType output_type = tensor_type_map->at(output_name);
TosaSerializationTensor *ts = ser_block->GetTensorByName(output_name);
- const auto &data = ts->GetData();
- auto &shape = ts->GetShape();
- // compute output data size
- uint32_t out_size = 1;
- for (const auto dim : shape) {
- out_size *= dim;
- }
- mlir::DenseElementsAttr value_attr;
- switch (ts->GetDtype()) {
- case DType_FP32: {
- std::vector<float> float_data;
- TosaSerializationHandler::ConvertU8toF32(data, out_size, float_data);
- value_attr =
- mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(float_data));
- break;
- }
- case DType_INT4: {
- std::vector<int8_t> int4_data;
- TosaSerializationHandler::ConvertU8toI4(data, out_size, int4_data);
- value_attr =
- mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int4_data));
- break;
- }
- case DType_INT8: {
- std::vector<int8_t> int8_data;
- TosaSerializationHandler::ConvertU8toI8(data, out_size, int8_data);
- value_attr =
- mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int8_data));
- break;
- }
- case DType_INT16: {
- std::vector<int16_t> int16_data;
- TosaSerializationHandler::ConvertU8toI16(data, out_size, int16_data);
- value_attr =
- mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int16_data));
- break;
- }
- case DType_INT32: {
- std::vector<int32_t> int32_data;
- TosaSerializationHandler::ConvertU8toI32(data, out_size, int32_data);
- value_attr =
- mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int32_data));
- break;
- }
- case DType_INT48: {
- std::vector<int64_t> int48_data;
- TosaSerializationHandler::ConvertU8toI48(data, out_size, int48_data);
- std::vector<mlir::APInt> apint_data;
- for (const auto v : int48_data) {
- mlir::APInt apint_value(48, static_cast<uint64_t>(v),
- /* isSigned = */ false);
- apint_data.push_back(apint_value);
- }
- value_attr =
- mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(apint_data));
- break;
- }
- case DType_BOOL: {
- std::vector<bool> bool_data;
- TosaSerializationHandler::ConvertU8toBool(data, out_size, bool_data);
- llvm::SmallVector<bool> bool_values(bool_data.begin(), bool_data.end());
- value_attr = mlir::DenseElementsAttr::get(output_type, bool_values);
- break;
- }
- default:
- llvm::errs() << "ERROR: " << get_string(op)
- << " contains unsupported element type\n";
+ auto value_attr = ConstructConstAttr(output_type, ts, get_string(op));
+ if (!value_attr) {
return {};
}
mlir::Operation *mlir_op =
@@ -1420,6 +1539,9 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock(
this, &tensor_map, &tensor_type_map);
for (auto ts : ser_block->GetTensors()) {
+ if (ts->GetVariable()) {
+ RegisterVariableTensor(ts);
+ }
mlir::RankedTensorType type;
if (BuildTensorType(op_builder, ts, type).failed()) {
return mlir::failure();
@@ -1432,27 +1554,6 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock(
}
}
- // Update operator_queue with operators whose inputs are all built
- auto queue_ready_operators = [&]() {
- // note: it is important to queue in order of original operators
- // because an operator input may be defined in upper value scopes
- for (auto consumer_op : ser_block->GetOperators()) {
- if (operator_built.count(consumer_op)) {
- continue;
- }
- bool all_inputs_ready = true;
- for (const auto &input_name : consumer_op->GetInputTensorNames()) {
- if (!tensor_map.count(input_name)) {
- all_inputs_ready = false;
- break;
- }
- }
- if (all_inputs_ready) {
- operator_queue.push(consumer_op);
- }
- }
- };
-
// Initialize tensor_map/operator_queue based on block input arguments
for (const std::string &block_input_name : ser_block->GetInputs()) {
mlir::Type type = tensor_type_map[block_input_name];
@@ -1470,11 +1571,7 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock(
tensor_map[block_input_name] = input_value;
}
- queue_ready_operators();
-
- while (!operator_queue.empty()) {
- TosaSerializationOperator *op = operator_queue.front();
- operator_queue.pop();
+ for (auto op : ser_block->GetOperators()) {
// skip if operator has been built
if (operator_built.count(op)) {
@@ -1485,7 +1582,10 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock(
operator_built.insert(op);
std::vector<mlir::Value> output_values;
- if (false) {
+ if (IsVariableReadOp(op)) {
+ output_values = tosa_op_builder.BuildVariableReadOp(op);
+ } else if (IsVariableWriteOp(op)) {
+ tosa_op_builder.BuildVariableWriteOp(op);
}
#define DEF_SCHEMA_OPERATOR(SCHEMA_OP_NAME) \
else if (op->GetOp() == Op_##SCHEMA_OP_NAME) { \
@@ -1499,6 +1599,12 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock(
return mlir::failure();
}
+ if (IsVariableWriteOp(op)) {
+ // the sanity check below does not apply to a variable write op: it builds
+ // no output values, whereas the original identity op has an output tensor
+ continue;
+ }
+
// Sanity check if number of built mlir::Value is expected
if (op->GetOutputTensorNames().size() != output_values.size()) {
llvm::errs() << "ERROR: number of built mlir::Value is not matching "
@@ -1516,8 +1622,6 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock(
}
tensor_map[op_output_name] = output_values[i];
}
- // look for any more ready consumers
- queue_ready_operators();
}
// Construct return values
@@ -1533,8 +1637,15 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock(
return_operands.push_back(output_value);
return_values.push_back(output_value);
}
- auto terminator_op =
- op_builder->create<mlir::func::ReturnOp>(loc, return_operands);
+ mlir::Operation *terminator_op;
+ auto parent_op = block->getParentOp();
+ if (mlir::isa<mlir::func::FuncOp>(parent_op)) {
+ terminator_op =
+ op_builder->create<mlir::func::ReturnOp>(loc, return_operands);
+ } else {
+ terminator_op =
+ op_builder->create<mlir::tosa::YieldOp>(loc, return_operands);
+ }
block->push_back(terminator_op);
// need topological sorting?
@@ -1624,7 +1735,20 @@ mlir::NamedAttribute DefaultEntryFuncitonAttr(mlir::Builder &builder,
builder.getStringAttr(names));
}
-// erase function attrs and empty function region'd body
+// erase all ops in block except for FuncOp
+void ClearNonFuncOps(mlir::Block *block) {
+ std::vector<mlir::Operation *> to_delete;
+ for (auto &op : block->getOperations()) {
+ if (!mlir::isa<mlir::func::FuncOp>(op)) {
+ to_delete.push_back(&op);
+ }
+ }
+ for (mlir::Operation *op : to_delete) {
+ op->erase();
+ }
+}
+
+// erase function attrs and empty function region's body
void ResetFunction(mlir::func::FuncOp &function, mlir::MLIRContext &context) {
function->setAttrs(mlir::DictionaryAttr::get(&context, {}));
mlir::Region *main_region = function.getCallableRegion();
@@ -1637,6 +1761,9 @@ mlir::LogicalResult CloneIntoModuleAndFunction(
mlir::MLIRContext &context, mlir::func::FuncOp &to_function,
mlir::ModuleOp &to_module, mlir::func::FuncOp &from_function,
mlir::ModuleOp &from_module) {
+ auto from_block = from_function.getOperation()->getBlock();
+ auto to_block = to_function.getOperation()->getBlock();
+ ClearNonFuncOps(to_block);
// copy all attrs from new_module to module
to_module->setAttrs(from_module->getAttrDictionary());
// erase attrs and body of function
@@ -1644,6 +1771,22 @@ mlir::LogicalResult CloneIntoModuleAndFunction(
// clone new_func attrs and region into function
mlir::IRMapping mapping;
from_function.cloneInto(to_function, mapping);
+
+ // copy variable ops in from_block to to_block
+ // collect variable ops in from_block (pushed to the front in reverse below)
+ std::vector<mlir::Operation *> variable_ops;
+ for (mlir::Operation &op : *from_block) {
+ if (mlir::isa<mlir::tosa::VariableOp>(op)) {
+ variable_ops.push_back(&op);
+ }
+ }
+ auto cloneOptions =
+ mlir::Operation::CloneOptions::all().cloneRegions(false).cloneOperands(
+ false);
+ for (auto iter = variable_ops.rbegin(); iter != variable_ops.rend(); iter++) {
+ auto op = *iter;
+ to_block->push_front(op->clone(mapping, cloneOptions));
+ }
return mlir::success();
}
@@ -1745,6 +1888,11 @@ BuildMlirFromTosaFile(const char *file_name, mlir::MLIRContext *context,
mlir::ArrayAttr::get(
context, {mlir::StringAttr::get(context, kDefaultExportedName)}));
+ // deserialize variable ops in the new module just before adding func op
+ if (ConstructVariableOps(module).failed()) {
+ return nullptr;
+ }
+
// add func to module
module.push_back(std::move(func));
return mlir::OwningOpRef<mlir::ModuleOp>(module);
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index 45a1b37..8aef8fd 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -210,7 +210,8 @@ private:
const TosaSerializationOperatorBuilder &op_builder, mlir::Operation &op);
TosaSerializationTensor *
BuildTosaSerializationVariableTensor(mlir::RankedTensorType tensor_type,
- const std::string &name);
+ const std::string &name,
+ const std::string &variable_mlir_name);
TosaSerializationTensor *
BuildTosaSerializationTensor(mlir::Value val, const std::string &name);
@@ -305,7 +306,6 @@ static std::vector<T> getDenseI8ArrayAttr(mlir::Attribute attr) {
std::string TosaSerializationOperatorBuilder::GetVariableTensorName(
mlir::Operation *op) const {
- mlir::Attribute variable_op_name_attr = op->getAttr("name");
std::string variable_tensor_mlir_name =
op->getAttr("name").cast<mlir::StringAttr>().getValue().str();
@@ -1766,8 +1766,13 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock(
.cast<mlir::TypeAttr>()
.getValue()
.cast<mlir::RankedTensorType>();
+
+ std::string variable_mlir_name =
+ op->getAttr("name").cast<mlir::StringAttr>().getValue().str();
+
ser_tensor = BuildTosaSerializationVariableTensor(
- tensor_type /* tensor_type */, pair.first /* flatbuffer name */);
+ tensor_type /* tensor_type */, pair.first /* flatbuffer name */,
+ variable_mlir_name);
if (!ser_tensor) {
llvm::errs() << "ERROR: Failed to build TosaSerializationTensor\n";
return mlir::failure();
@@ -1897,7 +1902,8 @@ TosaSerializationBlockBuilder::BuildTosaSerializationOperator(
TosaSerializationTensor *
TosaSerializationBlockBuilder::BuildTosaSerializationVariableTensor(
- mlir::RankedTensorType tensor_type, const std::string &name) {
+ mlir::RankedTensorType tensor_type, const std::string &name,
+ const std::string &variable_mlir_name) {
// If tensor already created before, use that tensor directly, create a new
// one otherwise
TosaSerializationTensor *ts = ser_block->GetTensorByName(name);
@@ -1912,7 +1918,8 @@ TosaSerializationBlockBuilder::BuildTosaSerializationVariableTensor(
ts = new TosaSerializationTensor(name, shape, type, std::vector<uint8_t>(),
/* is_variable = */ true,
- /* is_unranked = */ false);
+ /* is_unranked = */ false,
+ variable_mlir_name);
return ts;
}
diff --git a/third_party/serialization_lib b/third_party/serialization_lib
-Subproject c0a60300951a59c33d2afaea0f6ca0889cabf34
+Subproject 5917fc7a9392da8fd1e8c68b2d00b89709a3158