diff options
23 files changed, 398 insertions, 140 deletions
diff --git a/examples/test_add_1x4x4x4_f32/flatbuffer-tf/test_add_1x4x4x4_f32.tosa b/examples/test_add_1x4x4x4_f32/flatbuffer-tf/test_add_1x4x4x4_f32.tosa Binary files differindex c240a91..03050de 100644 --- a/examples/test_add_1x4x4x4_f32/flatbuffer-tf/test_add_1x4x4x4_f32.tosa +++ b/examples/test_add_1x4x4x4_f32/flatbuffer-tf/test_add_1x4x4x4_f32.tosa diff --git a/examples/test_add_1x4x4x4_f32/flatbuffer-tflite/test_add_1x4x4x4_f32.tosa b/examples/test_add_1x4x4x4_f32/flatbuffer-tflite/test_add_1x4x4x4_f32.tosa Binary files differindex c240a91..03050de 100644 --- a/examples/test_add_1x4x4x4_f32/flatbuffer-tflite/test_add_1x4x4x4_f32.tosa +++ b/examples/test_add_1x4x4x4_f32/flatbuffer-tflite/test_add_1x4x4x4_f32.tosa diff --git a/examples/test_add_1x4x4x4_f32/model.pb b/examples/test_add_1x4x4x4_f32/model.pb index 479979c..f76ec12 100644 --- a/examples/test_add_1x4x4x4_f32/model.pb +++ b/examples/test_add_1x4x4x4_f32/model.pb @@ -92,5 +92,5 @@ node { } } versions { - producer: 1247 + producer: 1286 } diff --git a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa Binary files differindex c4224d3..70d269f 100644 --- a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa +++ b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa diff --git a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa Binary files differindex 98480be..3ddbae8 100644 --- a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa +++ b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa diff --git a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/model.pb b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/model.pb index e95ba45..3cddba7 100644 --- a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/model.pb +++ b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/model.pb @@ -137,5 +137,5 @@ node { } } versions { - producer: 1247 + producer: 1286 } diff --git a/examples/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosa b/examples/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosa Binary files differindex 5570543..ab8bb78 100644 --- a/examples/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosa +++ b/examples/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosa diff --git a/reference_model/src/graph_node.cc b/reference_model/src/graph_node.cc index 5f85932..1781e40 100644 --- a/reference_model/src/graph_node.cc +++ b/reference_model/src/graph_node.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -33,6 +33,7 @@ GraphNode::GraphNode(SubgraphTraverser* parent_sgt_, const Op& nodeType_, const clearOnNextNodeList(); setRequiredOperands(-1, -1); setRequiredRank(-1); + inMainBlock = false; } GraphNode::~GraphNode() diff --git a/reference_model/src/graph_node.h b/reference_model/src/graph_node.h index 874f1d8..b227d17 100644 --- a/reference_model/src/graph_node.h +++ b/reference_model/src/graph_node.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -163,7 +163,6 @@ public: int addInputTensor(Tensor* tens); int addOutputTensor(Tensor* tens); - // Validate that the input tensors match properly // in their types, attributes, rank, etc well enough to be // processed. @@ -252,6 +251,22 @@ public: return nodeType; } + SubgraphTraverser* getParentSGT() + { + return parent_sgt; + } + + int setInMainBlock(bool isInMainBlock) + { + inMainBlock = isInMainBlock; + return 0; + } + + bool getInMainBlock() + { + return inMainBlock; + } + // Helper functions. int idiv_check(int input1, int input2, int& result); @@ -328,6 +343,8 @@ protected: // -1 means n/a int requiredRankMin; int requiredRankMax; + + bool inMainBlock; }; }; // namespace TosaReference diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp index 0941f2b..0375a48 100644 --- a/reference_model/src/main.cpp +++ b/reference_model/src/main.cpp @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -83,7 +83,7 @@ int main(int argc, char** argv) FATAL_ERROR("Unable to load graph"); } - SubgraphTraverser main_gt(tsh.GetMainBlock(), &tsh); + SubgraphTraverser main_gt(tsh.GetMainRegion()->GetBlockByName("main"), &tsh, nullptr); if (main_gt.initializeGraph()) { diff --git a/reference_model/src/model_runner_impl.cc b/reference_model/src/model_runner_impl.cc index 1109dd6..fa39c75 100644 --- a/reference_model/src/model_runner_impl.cc +++ b/reference_model/src/model_runner_impl.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2022, ARM Limited. +// Copyright (c) 2022-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -45,7 +45,7 @@ void ModelRunnerImpl::setFuncDebug(func_debug_t& func_debug) GraphStatus ModelRunnerImpl::initialize(TosaSerializationHandler& serialization_handler) { validateTosaVersion(serialization_handler); - return initialize(serialization_handler.GetMainBlock(), &serialization_handler); + return initialize(serialization_handler.GetMainRegion()->GetBlocks()[0], &serialization_handler); } GraphStatus ModelRunnerImpl::initialize(TosaSerializationBasicBlock& bb) @@ -284,7 +284,7 @@ GraphStatus ModelRunnerImpl::initialize(TosaSerializationBasicBlock* bb, // Make nullptr in case ModelRunnerImpl is being initialized again with a different graph. _main_gt = nullptr; - _main_gt = new SubgraphTraverser(bb, serialization_handler); + _main_gt = new SubgraphTraverser(bb, serialization_handler, nullptr); if (_main_gt == nullptr) { diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc index dfad9b8..af348ca 100644 --- a/reference_model/src/operators.cc +++ b/reference_model/src/operators.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2022, ARM Limited. +// Copyright (c) 2022-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -94,7 +94,7 @@ extern "C" { input->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("argmax", { op }, { input, output }, { input->GetName() }, + tosa::TosaSerializationBasicBlock block("argmax", "main", { op }, { input, output }, { input->GetName() }, { output->GetName() }); // Setup model @@ -137,7 +137,7 @@ extern "C" &attr, { input->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("avg_pool2d", { op }, { input, output }, { input->GetName() }, + tosa::TosaSerializationBasicBlock block("avg_pool2d", "main", { op }, { input, output }, { input->GetName() }, { output->GetName() }); // Setup model @@ -185,7 +185,7 @@ extern "C" { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("conv2d", { op }, { input, weight, bias, output }, + tosa::TosaSerializationBasicBlock block("conv2d", "main", { op }, { input, weight, bias, output }, { input->GetName(), weight->GetName(), bias->GetName() }, { output->GetName() }); @@ -236,7 +236,7 @@ extern "C" { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("conv3d", { op }, { input, weight, bias, output }, + tosa::TosaSerializationBasicBlock block("conv3d", "main", { op }, { input, weight, bias, output }, { input->GetName(), weight->GetName(), bias->GetName() }, { output->GetName() }); @@ -287,7 +287,7 @@ extern "C" { input->GetName(), weight->GetName(), bias->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("depthwise_conv2d", { op }, { input, weight, bias, output }, + tosa::TosaSerializationBasicBlock block("depthwise_conv2d", "main", { op }, { input, weight, bias, output }, { input->GetName(), weight->GetName(), bias->GetName() }, { output->GetName() }); @@ -328,8 +328,8 @@ extern "C" { input->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("fully_connected", { op }, { input, output }, { input->GetName() }, - { output->GetName() }); + tosa::TosaSerializationBasicBlock block("fully_connected", "main", { op }, { input, output }, + { input->GetName() }, { output->GetName() }); // Setup model TosaReference::ModelRunnerImpl runner; @@ -367,8 +367,8 @@ extern "C" &attr, { a->GetName(), b->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("matmul", { op }, { a, b, output }, { a->GetName(), b->GetName() }, - { output->GetName() }); + tosa::TosaSerializationBasicBlock block("matmul", "main", { op }, { a, b, output }, + { a->GetName(), b->GetName() }, { output->GetName() }); // Setup model TosaReference::ModelRunnerImpl runner; @@ -411,7 +411,7 @@ extern "C" &attr, { input->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("max_pool2d", { op }, { input, output }, { input->GetName() }, + tosa::TosaSerializationBasicBlock block("max_pool2d", "main", { op }, { input, output }, { input->GetName() }, { output->GetName() }); // Setup model @@ -463,7 +463,7 @@ extern "C" { input->GetName(), weight->GetName(), bias->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("transpose_conv2d", { op }, { input, weight, bias, output }, + tosa::TosaSerializationBasicBlock block("transpose_conv2d", "main", { op }, { input, weight, bias, output }, { input->GetName(), weight->GetName(), bias->GetName() }, { output->GetName() }); @@ -506,7 +506,7 @@ extern "C" &attr, { input->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("clamp", { op }, { input, output }, { input->GetName() }, + tosa::TosaSerializationBasicBlock block("clamp", "main", { op }, { input, output }, { input->GetName() }, { output->GetName() }); // Setup model @@ -537,7 +537,7 @@ extern "C" { input->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("sigmoid", { op }, { input, output }, { input->GetName() }, + tosa::TosaSerializationBasicBlock block("sigmoid", "main", { op }, { input, output }, { input->GetName() }, { output->GetName() }); // Setup model @@ -568,7 +568,7 @@ extern "C" { input->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("tanh", { op }, { input, output }, { input->GetName() }, + tosa::TosaSerializationBasicBlock block("tanh", "main", { op }, { input, output }, { input->GetName() }, { output->GetName() }); // Setup model @@ -600,7 +600,7 @@ extern "C" { input1->GetName(), input2->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("add", { op }, { input1, input2, output }, + tosa::TosaSerializationBasicBlock block("add", "main", { op }, { input1, input2, output }, { input1->GetName(), input2->GetName() }, { output->GetName() }); // Setup model @@ -638,7 +638,7 @@ extern "C" { input1->GetName(), input2->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("arithmetic_right_shift", { op }, { input1, input2, output }, + tosa::TosaSerializationBasicBlock block("arithmetic_right_shift", "main", { op }, { input1, input2, output }, { input1->GetName(), input2->GetName() }, { output->GetName() }); // Setup model @@ -672,7 +672,7 @@ extern "C" { input1->GetName(), input2->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("bitwise_and", { op }, { input1, input2, output }, + tosa::TosaSerializationBasicBlock block("bitwise_and", "main", { op }, { input1, input2, output }, { input1->GetName(), input2->GetName() }, { output->GetName() }); // Setup model @@ -706,7 +706,7 @@ extern "C" { input1->GetName(), input2->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("bitwise_or", { op }, { input1, input2, output }, + tosa::TosaSerializationBasicBlock block("bitwise_or", "main", { op }, { input1, input2, output }, { input1->GetName(), input2->GetName() }, { output->GetName() }); // Setup model @@ -740,7 +740,7 @@ extern "C" { input1->GetName(), input2->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("bitwise_xor", { op }, { input1, input2, output }, + tosa::TosaSerializationBasicBlock block("bitwise_xor", "main", { op }, { input1, input2, output }, { input1->GetName(), input2->GetName() }, { output->GetName() }); // Setup model @@ -773,7 +773,7 @@ extern "C" { input1->GetName(), input2->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("intdiv", { op }, { input1, input2, output }, + tosa::TosaSerializationBasicBlock block("intdiv", "main", { op }, { input1, input2, output }, { input1->GetName(), input2->GetName() }, { output->GetName() }); // Setup model @@ -807,7 +807,7 @@ extern "C" { input1->GetName(), input2->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("logical_and", { op }, { input1, input2, output }, + tosa::TosaSerializationBasicBlock block("logical_and", "main", { op }, { input1, input2, output }, { input1->GetName(), input2->GetName() }, { output->GetName() }); // Setup model @@ -843,7 +843,7 @@ extern "C" { input1->GetName(), input2->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("logical_left_shift", { op }, { input1, input2, output }, + tosa::TosaSerializationBasicBlock block("logical_left_shift", "main", { op }, { input1, input2, output }, { input1->GetName(), input2->GetName() }, { output->GetName() }); // Setup model @@ -879,7 +879,7 @@ extern "C" &attr, { input1->GetName(), input2->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("logical_right_shift", { op }, { input1, input2, output }, + tosa::TosaSerializationBasicBlock block("logical_right_shift", "main", { op }, { input1, input2, output }, { input1->GetName(), input2->GetName() }, { output->GetName() }); // Setup model @@ -913,7 +913,7 @@ extern "C" { input1->GetName(), input2->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("logical_or", { op }, { input1, input2, output }, + tosa::TosaSerializationBasicBlock block("logical_or", "main", { op }, { input1, input2, output }, { input1->GetName(), input2->GetName() }, { output->GetName() }); // Setup model @@ -947,7 +947,7 @@ extern "C" { input1->GetName(), input2->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("logical_xor", { op }, { input1, input2, output }, + tosa::TosaSerializationBasicBlock block("logical_xor", "main", { op }, { input1, input2, output }, { input1->GetName(), input2->GetName() }, { output->GetName() }); // Setup model @@ -981,7 +981,7 @@ extern "C" { input1->GetName(), input2->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("maximum", { op }, { input1, input2, output }, + tosa::TosaSerializationBasicBlock block("maximum", "main", { op }, { input1, input2, output }, { input1->GetName(), input2->GetName() }, { output->GetName() }); // Setup model @@ -1015,7 +1015,7 @@ extern "C" { input1->GetName(), input2->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("minimum", { op }, { input1, input2, output }, + tosa::TosaSerializationBasicBlock block("minimum", "main", { op }, { input1, input2, output }, { input1->GetName(), input2->GetName() }, { output->GetName() }); // Setup model @@ -1052,7 +1052,7 @@ extern "C" { input1->GetName(), input2->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("mul", { op }, { input1, input2, output }, + tosa::TosaSerializationBasicBlock block("mul", "main", { op }, { input1, input2, output }, { input1->GetName(), input2->GetName() }, { output->GetName() }); // Setup model @@ -1085,7 +1085,7 @@ extern "C" { input1->GetName(), input2->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("pow", { op }, { input1, input2, output }, + tosa::TosaSerializationBasicBlock block("pow", "main", { op }, { input1, input2, output }, { input1->GetName(), input2->GetName() }, { output->GetName() }); // Setup model @@ -1118,7 +1118,7 @@ extern "C" { input1->GetName(), input2->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("sub", { op }, { input1, input2, output }, + tosa::TosaSerializationBasicBlock block("sub", "main", { op }, { input1, input2, output }, { input1->GetName(), input2->GetName() }, { output->GetName() }); // Setup model @@ -1154,7 +1154,7 @@ extern "C" &attr, { input->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("table", { op }, { input, output }, { input->GetName() }, + tosa::TosaSerializationBasicBlock block("table", "main", { op }, { input, output }, { input->GetName() }, { output->GetName() }); // Setup model @@ -1185,7 +1185,7 @@ extern "C" { input1->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("abs", { op }, { input1, output }, { input1->GetName() }, + tosa::TosaSerializationBasicBlock block("abs", "main", { op }, { input1, output }, { input1->GetName() }, { output->GetName() }); // Setup model @@ -1216,8 +1216,8 @@ extern "C" { input1->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("bitwise_not", { op }, { input1, output }, { input1->GetName() }, - { output->GetName() }); + tosa::TosaSerializationBasicBlock block("bitwise_not", "main", { op }, { input1, output }, + { input1->GetName() }, { output->GetName() }); // Setup model TosaReference::ModelRunnerImpl runner; @@ -1247,7 +1247,7 @@ extern "C" { input1->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("ceil", { op }, { input1, output }, { input1->GetName() }, + tosa::TosaSerializationBasicBlock block("ceil", "main", { op }, { input1, output }, { input1->GetName() }, { output->GetName() }); // Setup model @@ -1278,7 +1278,7 @@ extern "C" { input1->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("clz", { op }, { input1, output }, { input1->GetName() }, + tosa::TosaSerializationBasicBlock block("clz", "main", { op }, { input1, output }, { input1->GetName() }, { output->GetName() }); // Setup model @@ -1309,7 +1309,7 @@ extern "C" { input1->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("exp", { op }, { input1, output }, { input1->GetName() }, + tosa::TosaSerializationBasicBlock block("exp", "main", { op }, { input1, output }, { input1->GetName() }, { output->GetName() }); // Setup model @@ -1340,7 +1340,7 @@ extern "C" { input1->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("floor", { op }, { input1, output }, { input1->GetName() }, + tosa::TosaSerializationBasicBlock block("floor", "main", { op }, { input1, output }, { input1->GetName() }, { output->GetName() }); // Setup model @@ -1371,7 +1371,7 @@ extern "C" { input1->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("log", { op }, { input1, output }, { input1->GetName() }, + tosa::TosaSerializationBasicBlock block("log", "main", { op }, { input1, output }, { input1->GetName() }, { output->GetName() }); // Setup model @@ -1402,8 +1402,8 @@ extern "C" { input1->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("logical_not", { op }, { input1, output }, { input1->GetName() }, - { output->GetName() }); + tosa::TosaSerializationBasicBlock block("logical_not", "main", { op }, { input1, output }, + { input1->GetName() }, { output->GetName() }); // Setup model TosaReference::ModelRunnerImpl runner; @@ -1438,7 +1438,7 @@ extern "C" &attr, { input1->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("negate", { op }, { input1, output }, { input1->GetName() }, + tosa::TosaSerializationBasicBlock block("negate", "main", { op }, { input1, output }, { input1->GetName() }, { output->GetName() }); // Setup model @@ -1469,7 +1469,7 @@ extern "C" { input1->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("reciprocal", { op }, { input1, output }, { input1->GetName() }, + tosa::TosaSerializationBasicBlock block("reciprocal", "main", { op }, { input1, output }, { input1->GetName() }, { output->GetName() }); // Setup model @@ -1500,7 +1500,7 @@ extern "C" { input1->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("rsqrt", { op }, { input1, output }, { input1->GetName() }, + tosa::TosaSerializationBasicBlock block("rsqrt", "main", { op }, { input1, output }, { input1->GetName() }, { output->GetName() }); // Setup model @@ -1537,7 +1537,7 @@ extern "C" { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("select", { op }, { input1, input2, input3, output }, + tosa::TosaSerializationBasicBlock block("select", "main", { op }, { input1, input2, input3, output }, { input1->GetName(), input2->GetName(), input3->GetName() }, { output->GetName() }); @@ -1572,7 +1572,7 @@ extern "C" { input1->GetName(), input2->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("equal", { op }, { input1, input2, output }, + tosa::TosaSerializationBasicBlock block("equal", "main", { op }, { input1, input2, output }, { input1->GetName(), input2->GetName() }, { output->GetName() }); // Setup model @@ -1606,7 +1606,7 @@ extern "C" { input1->GetName(), input2->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("greater", { op }, { input1, input2, output }, + tosa::TosaSerializationBasicBlock block("greater", "main", { op }, { input1, input2, output }, { input1->GetName(), input2->GetName() }, { output->GetName() }); // Setup model @@ -1641,7 +1641,7 @@ extern "C" { input1->GetName(), input2->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("greater_equal", { op }, { input1, input2, output }, + tosa::TosaSerializationBasicBlock block("greater_equal", "main", { op }, { input1, input2, output }, { input1->GetName(), input2->GetName() }, { output->GetName() }); // Setup model @@ -1674,7 +1674,7 @@ extern "C" { input->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("reduce_all", { op }, { input, output }, { input->GetName() }, + tosa::TosaSerializationBasicBlock block("reduce_all", "main", { op }, { input, output }, { input->GetName() }, { output->GetName() }); // Setup model @@ -1706,7 +1706,7 @@ extern "C" { input->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("reduce_any", { op }, { input, output }, { input->GetName() }, + tosa::TosaSerializationBasicBlock block("reduce_any", "main", { op }, { input, output }, { input->GetName() }, { output->GetName() }); // Setup model @@ -1738,7 +1738,7 @@ extern "C" { input->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("reduce_max", { op }, { input, output }, { input->GetName() }, + tosa::TosaSerializationBasicBlock block("reduce_max", "main", { op }, { input, output }, { input->GetName() }, { output->GetName() }); // Setup model @@ -1770,7 +1770,7 @@ extern "C" { input->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("reduce_min", { op }, { input, output }, { input->GetName() }, + tosa::TosaSerializationBasicBlock block("reduce_min", "main", { op }, { input, output }, { input->GetName() }, { output->GetName() }); // Setup model @@ -1802,8 +1802,8 @@ extern "C" &attr, { input->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("reduce_product", { op }, { input, output }, { input->GetName() }, - { output->GetName() }); + tosa::TosaSerializationBasicBlock block("reduce_product", "main", { op }, { input, output }, + { input->GetName() }, { output->GetName() }); // Setup model TosaReference::ModelRunnerImpl runner; @@ -1834,7 +1834,7 @@ extern "C" { input->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("reduce_sum", { op }, { input, output }, { input->GetName() }, + tosa::TosaSerializationBasicBlock block("reduce_sum", "main", { op }, { input, output }, { input->GetName() }, { output->GetName() }); // Setup model @@ -1865,7 +1865,7 @@ extern "C" { input1->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("concat", { op }, { input1, output }, { input1->GetName() }, + tosa::TosaSerializationBasicBlock block("concat", "main", { op }, { input1, output }, { input1->GetName() }, { output->GetName() }); // Setup model @@ -1904,7 +1904,7 @@ extern "C" { input1->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("pad", { op }, { input1, output }, { input1->GetName() }, + tosa::TosaSerializationBasicBlock block("pad", "main", { op }, { input1, output }, { input1->GetName() }, { output->GetName() }); // Setup model @@ -1939,7 +1939,7 @@ extern "C" &attr, { input1->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("reshape", { op }, { input1, output }, { input1->GetName() }, + tosa::TosaSerializationBasicBlock block("reshape", "main", { op }, { input1, output }, { input1->GetName() }, { output->GetName() }); // Setup model @@ -1970,7 +1970,7 @@ extern "C" { input->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("reverse", { op }, { input, output }, { input->GetName() }, + tosa::TosaSerializationBasicBlock block("reverse", "main", { op }, { input, output }, { input->GetName() }, { output->GetName() }); // Setup model @@ -2008,7 +2008,7 @@ extern "C" &attr, { input1->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("slice", { op }, { input1, output }, { input1->GetName() }, + tosa::TosaSerializationBasicBlock block("slice", "main", { op }, { input1, output }, { input1->GetName() }, { output->GetName() }); // Setup model @@ -2045,7 +2045,7 @@ extern "C" &attr, { input1->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("tile", { op }, { input1, output }, { input1->GetName() }, + tosa::TosaSerializationBasicBlock block("tile", "main", { op }, { input1, output }, { input1->GetName() }, { output->GetName() }); // Setup model @@ -2081,7 +2081,7 @@ extern "C" &attr, { input1->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("transpose", { op }, { input1, output }, { input1->GetName() }, + tosa::TosaSerializationBasicBlock block("transpose", "main", { op }, { input1, output }, { input1->GetName() }, { output->GetName() }); // Setup model @@ -2114,7 +2114,7 @@ extern "C" { values->GetName(), indices->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("gather", { op }, { values, indices, output }, + tosa::TosaSerializationBasicBlock block("gather", "main", { op }, { values, indices, output }, { values->GetName(), indices->GetName() }, { output->GetName() }); // Setup model @@ -2152,7 +2152,7 @@ extern "C" { values_out->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("scatter", { op }, { values_in, indices, input, values_out }, + tosa::TosaSerializationBasicBlock block("scatter", "main", { op }, { values_in, indices, input, values_out }, { values_in->GetName(), indices->GetName(), input->GetName() }, { values_out->GetName() }); @@ -2195,7 +2195,7 @@ extern "C" &attr, { input->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("resize", { op }, { input, output }, { input->GetName() }, + tosa::TosaSerializationBasicBlock block("resize", "main", { op }, { input, output }, { input->GetName() }, { output->GetName() }); // Setup model @@ -2226,7 +2226,7 @@ extern "C" { input->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("cast", { op }, { input, output }, { input->GetName() }, + tosa::TosaSerializationBasicBlock block("cast", "main", { op }, { input, output }, { input->GetName() }, { output->GetName() }); // Setup model @@ -2274,7 +2274,7 @@ extern "C" &attr, { input->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("rescale", { op }, { input, output }, { input->GetName() }, + tosa::TosaSerializationBasicBlock block("rescale", "main", { op }, { input, output }, { input->GetName() }, { output->GetName() }); // Setup model @@ -2305,7 +2305,7 @@ extern "C" { input1->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("identity", { op }, { input1, output }, { input1->GetName() }, + tosa::TosaSerializationBasicBlock block("identity", "main", { op }, { input1, output }, { input1->GetName() }, { output->GetName() }); // Setup model diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc index 7105caf..942652d 100644 --- a/reference_model/src/ops/control_flow.cc +++ b/reference_model/src/ops/control_flow.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,7 +15,6 @@ #include "control_flow.h" #include "subgraph_traverser.h" - using namespace TosaReference; using namespace Eigen; using namespace tosa; @@ -37,7 +36,7 @@ int OpControlFlow::evalBlock(TosaSerializationBasicBlock* block, DEBUG_MED(OP, "Evaluating block %s", block_name.c_str()); - SubgraphTraverser block_sgt(block, tsh); + SubgraphTraverser block_sgt(block, tsh, this->parent_sgt); ERROR_IF(block_sgt.initializeGraph(), "evalBlock(): Unable to initialize graph traverser for %s", block_name.c_str()); @@ -182,8 +181,10 @@ int OpCondIf::checkTensorAttributes() cond = dynamic_cast<TosaReference::Tensor0<bool>*>(inputs[0]); ASSERT_MEM(cond); - then_block = tsh->GetBlockByName(attribute->then_branch()); - else_block = tsh->GetBlockByName(attribute->else_branch()); + auto region_name = getParentSGT()->getRegionName(); + auto curr_region = tsh->GetRegionByName(region_name); + then_block = curr_region->GetBlockByName(attribute->then_branch()); + else_block = curr_region->GetBlockByName(attribute->else_branch()); ERROR_IF(!then_block, "OpCondIf: fail to resolve then_branch %s", attribute->then_branch().c_str()); @@ -193,6 +194,7 @@ int OpCondIf::checkTensorAttributes() // Skip the first rank 0 bool tensor on input list int32_t num_input_tensor = getInputs().size() - 1; int32_t num_output_tensor = getOutputs().size(); + ERROR_IF((int32_t)then_block->GetInputs().size() != num_input_tensor, "OpCondIf: then_block has unexpected number of input"); ERROR_IF((int32_t)else_block->GetInputs().size() != num_input_tensor, @@ -307,8 +309,10 @@ int OpWhileLoop::checkTensorAttributes() return 1; } - cond_block = tsh->GetBlockByName(attribute->cond_branch()); - body_block = tsh->GetBlockByName(attribute->body_branch()); + auto region_name = getParentSGT()->getRegionName(); + auto curr_region = tsh->GetRegionByName(region_name); + cond_block = curr_region->GetBlockByName(attribute->cond_branch()); + body_block = curr_region->GetBlockByName(attribute->body_branch()); ERROR_IF(!cond_block, "OpWhileLoop: fail to resolve cond_branch %s", attribute->cond_branch().c_str()); ERROR_IF(!body_block, "OpWhileLoop: fail to resolve body_branch %s", attribute->body_branch().c_str()); @@ -403,12 +407,7 @@ int OpWhileLoop::eval() // assigning output tensors value back to input tensors value for next iteration for (size_t i = 0; i < num_input_output; i++) { - if (getInputs()[i]->copyValueFrom(getOutputs()[i])) - { - WARNING("Fail to copy tensor value %s -> %s", getOutputs()[i]->getName().c_str(), - getInputs()[i]->getName().c_str()); - return 1; - } + getInputs()[i] = getOutputs()[i]; } } else diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index 112e641..8867ada 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -37,13 +37,14 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -SubgraphTraverser::SubgraphTraverser(TosaSerializationBasicBlock* _block, TosaSerializationHandler* _tsh) +SubgraphTraverser::SubgraphTraverser(TosaSerializationBasicBlock* _block, TosaSerializationHandler* _tsh, SubgraphTraverser* _parent_sgt) { - graph_status = GraphStatus::TOSA_VALID; + graph_status = GraphStatus::TOSA_VALID; block = _block; - tsh = _tsh; + tsh = _tsh; + parent_sgt = _parent_sgt; tensors.clear(); nodes.clear(); nextNodeList.clear(); @@ -120,6 +121,17 @@ int SubgraphTraverser::initializeGraph() { int idx = 0; + std::vector<TosaSerializationTensor*> ser_tensor_vec; + // Get all the serialized tensors from TosaSerializationHandler. + for (auto block: tsh->GetMainRegion()->GetBlocks()) + { + for (auto ser_tensor : block->GetTensors()) + { + ser_tensor_vec.push_back(ser_tensor); + } + } + + std::vector<GraphNode*> non_const_node_vec; for (auto op : block->GetOperators()) { // translated TosaSerializationOperator to GraphNode @@ -159,7 +171,13 @@ int SubgraphTraverser::initializeGraph() EnumNamesOp()[op->GetOp()], input_index); std::string input_name = op->GetInputTensorNames()[input_index]; - TosaSerializationTensor* input_tensor = block->GetTensorByName(input_name); + TosaSerializationTensor* input_tensor = nullptr; + for (auto ser_tensor : ser_tensor_vec) { + if (ser_tensor->GetName() == input_name) { + input_tensor = ser_tensor; + } + } + SUBGRAPH_ERROR_IF( !input_tensor, "SubgraphTraverser::initializeGraph(): fail to get input tensor %s from TosaSerializationHandler", @@ -175,7 +193,13 @@ int SubgraphTraverser::initializeGraph() "SubgraphTraverser::initializeGraph(): Op=%s, weight_index %d must be within [0, num_input - 1]", EnumNamesOp()[op->GetOp()], weight_index); std::string weight_name = op->GetInputTensorNames()[weight_index]; - TosaSerializationTensor* weight_tensor = block->GetTensorByName(weight_name); + TosaSerializationTensor* weight_tensor = nullptr; + for (auto ser_tensor : ser_tensor_vec) { + if (ser_tensor->GetName() == weight_name) { + weight_tensor = ser_tensor; + } + } + SUBGRAPH_ERROR_IF( !weight_tensor, "SubgraphTraverser::initializeGraph(): fail to get weight tensor %s from TosaSerializationHandler", @@ -199,8 +223,19 @@ int SubgraphTraverser::initializeGraph() DEBUG_INFO(GT, "Creating operator id_%03u, %8s, %lu input tensors, %lu output tensors", idx, EnumNamesOp()[op->GetOp()], op->GetInputTensorNames().size(), op->GetOutputTensorNames().size()); - GraphNode* node = OpFactory::newOp(this, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype, + GraphNode* node = nullptr; + if (this->parent_sgt) { + node = OpFactory::newOp(this->parent_sgt, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype, + input_rank, output_dtype, output_rank, weight_dtype, weight_rank); + node->setInMainBlock(false); + } else { + node = OpFactory::newOp(this, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype, input_rank, output_dtype, output_rank, weight_dtype, weight_rank); + if (node) { + node->setInMainBlock(true); + } + } + if (!node) { if (weight_index == -1) @@ -257,6 +292,8 @@ int SubgraphTraverser::initializeGraph() if (op->GetInputTensorNames().empty() && !node->getOnNextNodeList()) { addToNextNodeList(node); + } else if (!node->getInMainBlock()) { + non_const_node_vec.push_back(node); } idx++; @@ -271,7 +308,6 @@ int SubgraphTraverser::initializeGraph() SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::initializeGraph(): Unsupported tensor name=%s, type=%s, rank=%d", ts->GetName().c_str(), EnumNamesDType()[ts->GetDtype()], (int)ts->GetShape().size()); - // update this->tensors addTensor(tensor); } @@ -296,7 +332,7 @@ int SubgraphTraverser::initializeGraph() for (auto& output_name : block->GetOutputs()) { TosaReference::Tensor* tensor = findTensorByName(output_name); - DEBUG_INFO(GT, "output tensor name=%s\n", output_name.c_str()); + DEBUG_INFO(GT, "output tensor name=%s", output_name.c_str()); if (tensor) { tensor->setIsSubgraphOutput(); @@ -314,6 +350,22 @@ int SubgraphTraverser::initializeGraph() dumpNextNodeList(g_func_debug.func_debug_file); } + // If the node is not in mainblock and not const + for (auto node : non_const_node_vec) { + bool all_inputs_from_parent = true; + for (std::string& name : node->getInputNames()) + { + TosaReference::Tensor* t = findTensorByName(name); + if (!t->getIsParentGraphOutput()) { + all_inputs_from_parent = false; + } + } + // In the children block, when a node has all its inputs from parent + // block, we have to manually add this node to the evaluation list + if (all_inputs_from_parent && !node->getOnNextNodeList()) { + addToNextNodeList(node); + } + } return 0; } @@ -510,29 +562,40 @@ int SubgraphTraverser::evaluateNextNode() } // free input tensor if all of its consumers have all of their outputs ready and it's not block's output - for (auto tensor : currNode->getInputs()) - { - bool in_use = false; - for (auto node : tensor->getConsumers()) + if (!currNode->getInMainBlock()) { // we don't free it if the node is in main block and has nested blocks + for (auto tensor : currNode->getInputs()) { - if (!node->hasAllOutputsReady()) + bool in_use = false; + + auto tensor_check = findTensorByName(tensor->getName()); + if (tensor_check->getIsParentGraphOutput()) { + // if it's parent's block output tensor, we can't free it + continue; + } + + for (auto node : tensor->getConsumers()) { - in_use = true; + // If the node is inside a loop, the input tensor is still needed + if (!node->hasAllOutputsReady()) + { + in_use = true; + } + } - } - for (auto name : block->GetOutputs()) - { - if (name == tensor->getName()) + for (auto name : block->GetOutputs()) { - in_use = true; + if (name == tensor->getName()) + { + in_use = true; + } + } + + if (!in_use) + { + tensor->deallocate(); } - } - if (!in_use) - { - tensor->deallocate(); } } - // Search the output tensors of this node to see if // there are now new ready nodes available from completing this node for (TosaReference::Tensor* tensor : currNode->getOutputs()) @@ -642,17 +705,35 @@ int SubgraphTraverser::addNode(GraphNode* newNode) TosaReference::Tensor* SubgraphTraverser::findTensorByName(const std::string& name) const { + TosaReference::Tensor* res_tensor = nullptr; + for (TosaReference::Tensor* currTensor : tensors) { if (currTensor->getName() == name) { - return currTensor; + res_tensor = currTensor; + return res_tensor; } } - WARNING("SubgraphTraverser::findTensorByName(): Unable to find tensor with name: %s\n", name.c_str()); + if (parent_sgt) + { + for (TosaReference::Tensor* currTensor : parent_sgt->tensors) + { + if (currTensor->getName() == name) + { + res_tensor = currTensor; + res_tensor->setIsParentGraphOutput(); + } + } + } - return nullptr; + if (!res_tensor) + { + WARNING("SubgraphTraverser::findTensorByName(): Unable to find tensor with name: %s\n", name.c_str()); + return nullptr; + } + return res_tensor; } int SubgraphTraverser::linkTensorsAndNodes() @@ -704,7 +785,6 @@ int SubgraphTraverser::validateGraph() for (TosaReference::Tensor* currTensor : tensors) { - // It's okay for block input tensor not being consumed by operators. // This is common in control flow op execution. if (!currTensor->getIsSubgraphInput()) diff --git a/reference_model/src/subgraph_traverser.h b/reference_model/src/subgraph_traverser.h index 7940ee4..543b008 100644 --- a/reference_model/src/subgraph_traverser.h +++ b/reference_model/src/subgraph_traverser.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -30,7 +30,7 @@ namespace TosaReference class SubgraphTraverser { public: - SubgraphTraverser(TosaSerializationBasicBlock* block, TosaSerializationHandler* tsh); + SubgraphTraverser(TosaSerializationBasicBlock* block, TosaSerializationHandler* tsh, SubgraphTraverser* parent_sgt); ~SubgraphTraverser(); int initializeGraph(); @@ -59,6 +59,10 @@ public: { return block->GetName(); } + std::string getRegionName() const + { + return block->GetRegionName(); + } int getNumInputTensors() const; Tensor* getInputTensor(const unsigned int idx) const; Tensor* getInputTensorByName(const std::string name) const; @@ -77,6 +81,10 @@ private: GraphStatus graph_status; + // pointer to the parent subgraph traversal if exists + // e.g., Control Flow Ops will have nested blocks (subgraph traversals) + SubgraphTraverser* parent_sgt; + // pointer to serialization library and corresponding basic block TosaSerializationBasicBlock* block; TosaSerializationHandler* tsh; diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc index 0678bbd..7af2069 100644 --- a/reference_model/src/tensor.cc +++ b/reference_model/src/tensor.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -32,11 +32,18 @@ TosaReference::Tensor::Tensor(std::string tensorName_, DType tensorDtype_, std:: consumers.clear(); isSubgraphInput = false; isSubgraphOutput = false; + isParentGraphOutput = false; } TosaReference::Tensor::~Tensor() {} +int TosaReference::Tensor::setIsParentGraphOutput() +{ + isParentGraphOutput = true; + return 0; +} + int TosaReference::Tensor::setIsSubgraphInput() { isSubgraphInput = true; diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h index 08e865a..d5f1de8 100644 --- a/reference_model/src/tensor.h +++ b/reference_model/src/tensor.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -40,6 +40,11 @@ public: int setIsSubgraphInput(); int setIsSubgraphOutput(); + int setIsParentGraphOutput(); + + int getIsParentGraphOutput() const { + return isParentGraphOutput; + } int getIsSubgraphInput() const { @@ -261,6 +266,8 @@ protected: int isSubgraphOutput; bool isAllocated; + bool isParentGraphOutput; + GraphNode* producer; std::vector<GraphNode*> consumers; diff --git a/scripts/operator_api/templates/operators_cc.j2 b/scripts/operator_api/templates/operators_cc.j2 index 6b0ed6e..3f2acb5 100644 --- a/scripts/operator_api/templates/operators_cc.j2 +++ b/scripts/operator_api/templates/operators_cc.j2 @@ -1,5 +1,5 @@ -// Copyright (c) 2022, ARM Limited. +// Copyright (c) 2022-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -64,11 +64,11 @@ tosa::ResizeMode translate_client_tosa_mode(tosa_mode_t mode) { switch(mode) { case tosa_mode_nearest: return tosa::ResizeMode_NEAREST; - case tosa_mode_max: + case tosa_mode_max: case tosa_mode_bilinear: return tosa::ResizeMode_BILINEAR; default: - return tosa::ResizeMode_UNKNOWN; + return tosa::ResizeMode_UNKNOWN; } } @@ -131,7 +131,7 @@ extern "C" }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("{{operator.name}}", { op }, + tosa::TosaSerializationBasicBlock block("{{operator.name}}", "main", { op }, { {%- for input in operator.inputs: -%} {{input}}, diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib -Subproject 6388a097de4350cc70472921c272074190fd7c9 +Subproject ca7ce0e94b3ee7339f31b47baa3a3fb4522243a diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py index 90bda34..767989e 100644 --- a/verif/frameworks/tensor_gen.py +++ b/verif/frameworks/tensor_gen.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, ARM Limited. +# Copyright (c) 2020-2023, ARM Limited. # SPDX-License-Identifier: Apache-2.0 import numpy as np import tensorflow as tf @@ -252,3 +252,25 @@ class TGen: tf_placeholders.append(("placeholder_2", TGen.getRand(shape, dtype, rng))) return tf_placeholders, tf_consts + + @staticmethod + def tgRecurrent(op, ifm_shape, dtype, rng): + # Require rank 3 shape for recurrent networks + if len(ifm_shape) != 3: + return [], [] + pl, const = op["operands"] + + tf_placeholders = [] + tf_consts = [] + + for i in range(pl): + tf_placeholders.append( + ("placeholder_{}".format(i), TGen.getRand(ifm_shape, dtype, rng)) + ) + + for i in range(const): + tf_consts.append( + ("const_{}".format(i), TGen.getRand(ifm_shape, dtype, rng)) + ) + + return tf_placeholders, tf_consts diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py index c7c5cd7..8870f41 100644 --- a/verif/frameworks/test_builder.py +++ b/verif/frameworks/test_builder.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, ARM Limited. +# Copyright (c) 2020-2023, ARM Limited. # SPDX-License-Identifier: Apache-2.0 import numpy as np import tensorflow as tf @@ -1164,3 +1164,82 @@ class TBuilder: def eval(self, a): return tf.bitwise.right_shift(a, self.shift, name=self.result_name) + + class While: + def __init__(self, name): + self.result_name = name + + def while_cond(self, x): + return tf.reduce_sum(x) < self.cap + + def while_body(self, x): + return tf.add(x, tf.math.sigmoid(x)) + + def eval(self, a): + self.cap = tf.cast( + tf.constant( + 2.0, + shape=[ + 1, + ], + ), + a.dtype, + ) + + result = tf.while_loop( + self.while_cond, self.while_body, [a], name=self.result_name + ) + + return result[0] + + class LSTM: + def __init__(self, name): + self.result_name = name + self.lstm = tf.keras.layers.LSTM( + 2, + activation="tanh", + unroll=False, + recurrent_activation="sigmoid", + use_bias=True, + recurrent_initializer="ones", + kernel_initializer="ones", + ) + + def eval(self, a): + return self.lstm(a) + + class GRU: + def __init__(self, name): + self.result_name = name + self.lstm = tf.keras.layers.GRU( + 2, + recurrent_activation="sigmoid", + use_bias=True, + recurrent_initializer="ones", + kernel_initializer="ones", + ) + + def eval(self, a): + return self.lstm(a) + + class RNN: + def __init__(self, name): + self.result_name = name + basic_cell = tf.keras.layers.SimpleRNNCell( + units=2, + activation="sigmoid", + use_bias=True, + recurrent_initializer="ones", + ) + self.rnn = tf.keras.layers.RNN(basic_cell, unroll=False) + + def eval(self, a): + return self.rnn(a) + + class FullyConnected: + def __init__(self, name): + self.result_name = name + self.dense = tf.keras.layers.Dense(2) + + def eval(self, a): + return self.dense(a) diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py index 760def6..26af5dd 100755 --- a/verif/frameworks/tosa_verif_framework_generator.py +++ b/verif/frameworks/tosa_verif_framework_generator.py @@ -807,6 +807,42 @@ TF_OP_LIST = { ] }, }, + "while": { + "operands": (1, 0), + "build_fcn": (TBuilder.While, TGen.tgBasic, ArgGen.agNone), + "types": { + "tflite": list(TYPE_F), + }, + }, + "lstm": { + "operands": (1, 0), + "build_fcn": (TBuilder.LSTM, TGen.tgRecurrent, ArgGen.agNone), + "types": { + "tflite": [ + tf.float32, + # tf.int32 + ] + }, + }, + "gru": { + "operands": (1, 0), + "build_fcn": (TBuilder.GRU, TGen.tgRecurrent, ArgGen.agNone), + "types": { + "tflite": [ + tf.float32, + # tf.int32 + ] + }, + }, + "rnn": { + "operands": (1, 0), + "build_fcn": (TBuilder.RNN, TGen.tgRecurrent, ArgGen.agNone), + "types": { + "tflite": [ + tf.float32, + ] + }, + }, } # Shapes to be tested; default can be overwritten diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 515e8bb..d799eb0 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, ARM Limited. +# Copyright (c) 2020-2023, ARM Limited. # SPDX-License-Identifier: Apache-2.0 import os from copy import deepcopy @@ -1845,7 +1845,7 @@ class TosaTestGen: # Finally, build the op and the two blocks self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr) - self.ser.startBasicBlock(then_block) + self.ser.addBasicBlock(then_block) # Build the actual then/else tensors inside their blocks if error_name == ErrorIf.CondIfOutputListThenGraphMismatch: then_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr) @@ -1853,7 +1853,7 @@ class TosaTestGen: then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr) self.ser.addOutputTensor(then_tens) - self.ser.startBasicBlock(else_block) + self.ser.addBasicBlock(else_block) if error_name == ErrorIf.CondIfOutputListElseGraphMismatch: else_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr) else: @@ -1865,7 +1865,7 @@ class TosaTestGen: validator_fcns, error_name, op=op, - basicBlocks=self.ser.basicBlocks, + basicBlocks=self.ser.currRegion.basicBlocks, cond=cond_tens, ): return None @@ -1914,7 +1914,7 @@ class TosaTestGen: assert False, f"No tests for DType: {a.dtype}" for block, op in ((then_block, then_op), (else_block, else_op)): - self.ser.startBasicBlock(block) + self.ser.addBasicBlock(block) if ( error_name == ErrorIf.CondIfInputListThenGraphMismatch and block == then_block @@ -1948,7 +1948,7 @@ class TosaTestGen: op=op, a=a, b=b, - basicBlocks=self.ser.basicBlocks, + basicBlocks=self.ser.currRegion.basicBlocks, cond=cond_tens, ): return None @@ -2005,7 +2005,8 @@ class TosaTestGen: incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3]) # COND block (input: iter, output: cond_tens ) - self.ser.startBasicBlock(cond_block) + self.ser.addBasicBlock(cond_block) + if error_name == ErrorIf.InputListCondGraphMismatch: self.ser.addInputTensor(incorrect_iter) self.ser.addInputTensor(a) @@ -2034,7 +2035,8 @@ class TosaTestGen: # BODY block (input: a, acc, iter, output: a, acc, iter) # Note that local intermediate tensors need to be declared here for the outputs - self.ser.startBasicBlock(body_block) + self.ser.addBasicBlock(body_block) + if error_name == ErrorIf.InputListBodyGraphInputMismatch: self.ser.addInputTensor(incorrect_iter) self.ser.addInputTensor(a) @@ -2068,7 +2070,7 @@ class TosaTestGen: validator_fcns, error_name, op=op, - basicBlocks=self.ser.basicBlocks, + basicBlocks=self.ser.currRegion.basicBlocks, ): return None |