aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJerry Ge <jerry.ge@arm.com>2022-10-27 09:57:00 -0700
committerEric Kunze <eric.kunze@arm.com>2023-01-13 19:09:21 +0000
commit9e94af8f10f0a21a117b3bc7ea42004844fdc3bb (patch)
tree868ab73bb67d4827963a4b43f28d8a8a49f50307
parentdd8d9c251db0fece6453d86116052ad7f3e2d697 (diff)
downloadreference_model-9e94af8f10f0a21a117b3bc7ea42004844fdc3bb.tar.gz
Reference model update for control flow operators support
Rationale for making this change: - In the original design, for control flow operators like WhileOp, child blocks couldn't read the tensor variables (global consts) in the root level block, this patch added the machanism for child blocks to access their parent level block's tensors. - This change also relies on another serialization change on adding another layer of abtraction called Region: - Serialization patch: [region] Add TosaSerializationRegion to serialization_lib - Updated the corresponding python version of the serialization code: TosaSerializerRegion to python version of serialization_lib - This change also relies on the TOSA MLIR Translator change: Add RegionBuilder to TOSA MLIR Translator - Added the WhileOp related test cases: While, LSTM, GRU, RNN - Other related fixes Signed-off-by: Jerry Ge <jerry.ge@arm.com> Change-Id: I13ae33628ad07e41d248e88652ce1328654694ab
-rw-r--r--examples/test_add_1x4x4x4_f32/flatbuffer-tf/test_add_1x4x4x4_f32.tosabin492 -> 524 bytes
-rw-r--r--examples/test_add_1x4x4x4_f32/flatbuffer-tflite/test_add_1x4x4x4_f32.tosabin492 -> 524 bytes
-rw-r--r--examples/test_add_1x4x4x4_f32/model.pb2
-rw-r--r--examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosabin1584 -> 1616 bytes
-rw-r--r--examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosabin1296 -> 1328 bytes
-rw-r--r--examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/model.pb2
-rw-r--r--examples/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosabin1232 -> 1264 bytes
-rw-r--r--reference_model/src/graph_node.cc3
-rw-r--r--reference_model/src/graph_node.h21
-rw-r--r--reference_model/src/main.cpp4
-rw-r--r--reference_model/src/model_runner_impl.cc6
-rw-r--r--reference_model/src/operators.cc138
-rw-r--r--reference_model/src/ops/control_flow.cc25
-rw-r--r--reference_model/src/subgraph_traverser.cc138
-rw-r--r--reference_model/src/subgraph_traverser.h12
-rw-r--r--reference_model/src/tensor.cc9
-rw-r--r--reference_model/src/tensor.h9
-rw-r--r--scripts/operator_api/templates/operators_cc.j28
m---------thirdparty/serialization_lib0
-rw-r--r--verif/frameworks/tensor_gen.py24
-rw-r--r--verif/frameworks/test_builder.py81
-rwxr-xr-xverif/frameworks/tosa_verif_framework_generator.py36
-rw-r--r--verif/generator/tosa_test_gen.py20
23 files changed, 398 insertions, 140 deletions
diff --git a/examples/test_add_1x4x4x4_f32/flatbuffer-tf/test_add_1x4x4x4_f32.tosa b/examples/test_add_1x4x4x4_f32/flatbuffer-tf/test_add_1x4x4x4_f32.tosa
index c240a91..03050de 100644
--- a/examples/test_add_1x4x4x4_f32/flatbuffer-tf/test_add_1x4x4x4_f32.tosa
+++ b/examples/test_add_1x4x4x4_f32/flatbuffer-tf/test_add_1x4x4x4_f32.tosa
Binary files differ
diff --git a/examples/test_add_1x4x4x4_f32/flatbuffer-tflite/test_add_1x4x4x4_f32.tosa b/examples/test_add_1x4x4x4_f32/flatbuffer-tflite/test_add_1x4x4x4_f32.tosa
index c240a91..03050de 100644
--- a/examples/test_add_1x4x4x4_f32/flatbuffer-tflite/test_add_1x4x4x4_f32.tosa
+++ b/examples/test_add_1x4x4x4_f32/flatbuffer-tflite/test_add_1x4x4x4_f32.tosa
Binary files differ
diff --git a/examples/test_add_1x4x4x4_f32/model.pb b/examples/test_add_1x4x4x4_f32/model.pb
index 479979c..f76ec12 100644
--- a/examples/test_add_1x4x4x4_f32/model.pb
+++ b/examples/test_add_1x4x4x4_f32/model.pb
@@ -92,5 +92,5 @@ node {
}
}
versions {
- producer: 1247
+ producer: 1286
}
diff --git a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa
index c4224d3..70d269f 100644
--- a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa
+++ b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa
Binary files differ
diff --git a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa
index 98480be..3ddbae8 100644
--- a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa
+++ b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa
Binary files differ
diff --git a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/model.pb b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/model.pb
index e95ba45..3cddba7 100644
--- a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/model.pb
+++ b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/model.pb
@@ -137,5 +137,5 @@ node {
}
}
versions {
- producer: 1247
+ producer: 1286
}
diff --git a/examples/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosa b/examples/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosa
index 5570543..ab8bb78 100644
--- a/examples/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosa
+++ b/examples/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosa
Binary files differ
diff --git a/reference_model/src/graph_node.cc b/reference_model/src/graph_node.cc
index 5f85932..1781e40 100644
--- a/reference_model/src/graph_node.cc
+++ b/reference_model/src/graph_node.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -33,6 +33,7 @@ GraphNode::GraphNode(SubgraphTraverser* parent_sgt_, const Op& nodeType_, const
clearOnNextNodeList();
setRequiredOperands(-1, -1);
setRequiredRank(-1);
+ inMainBlock = false;
}
GraphNode::~GraphNode()
diff --git a/reference_model/src/graph_node.h b/reference_model/src/graph_node.h
index 874f1d8..b227d17 100644
--- a/reference_model/src/graph_node.h
+++ b/reference_model/src/graph_node.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -163,7 +163,6 @@ public:
int addInputTensor(Tensor* tens);
int addOutputTensor(Tensor* tens);
-
// Validate that the input tensors match properly
// in their types, attributes, rank, etc well enough to be
// processed.
@@ -252,6 +251,22 @@ public:
return nodeType;
}
+ SubgraphTraverser* getParentSGT()
+ {
+ return parent_sgt;
+ }
+
+ int setInMainBlock(bool isInMainBlock)
+ {
+ inMainBlock = isInMainBlock;
+ return 0;
+ }
+
+ bool getInMainBlock()
+ {
+ return inMainBlock;
+ }
+
// Helper functions.
int idiv_check(int input1, int input2, int& result);
@@ -328,6 +343,8 @@ protected:
// -1 means n/a
int requiredRankMin;
int requiredRankMax;
+
+ bool inMainBlock;
};
}; // namespace TosaReference
diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp
index 0941f2b..0375a48 100644
--- a/reference_model/src/main.cpp
+++ b/reference_model/src/main.cpp
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -83,7 +83,7 @@ int main(int argc, char** argv)
FATAL_ERROR("Unable to load graph");
}
- SubgraphTraverser main_gt(tsh.GetMainBlock(), &tsh);
+ SubgraphTraverser main_gt(tsh.GetMainRegion()->GetBlockByName("main"), &tsh, nullptr);
if (main_gt.initializeGraph())
{
diff --git a/reference_model/src/model_runner_impl.cc b/reference_model/src/model_runner_impl.cc
index 1109dd6..fa39c75 100644
--- a/reference_model/src/model_runner_impl.cc
+++ b/reference_model/src/model_runner_impl.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2022, ARM Limited.
+// Copyright (c) 2022-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -45,7 +45,7 @@ void ModelRunnerImpl::setFuncDebug(func_debug_t& func_debug)
GraphStatus ModelRunnerImpl::initialize(TosaSerializationHandler& serialization_handler)
{
validateTosaVersion(serialization_handler);
- return initialize(serialization_handler.GetMainBlock(), &serialization_handler);
+ return initialize(serialization_handler.GetMainRegion()->GetBlocks()[0], &serialization_handler);
}
GraphStatus ModelRunnerImpl::initialize(TosaSerializationBasicBlock& bb)
@@ -284,7 +284,7 @@ GraphStatus ModelRunnerImpl::initialize(TosaSerializationBasicBlock* bb,
// Make nullptr in case ModelRunnerImpl is being initialized again with a different graph.
_main_gt = nullptr;
- _main_gt = new SubgraphTraverser(bb, serialization_handler);
+ _main_gt = new SubgraphTraverser(bb, serialization_handler, nullptr);
if (_main_gt == nullptr)
{
diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc
index dfad9b8..af348ca 100644
--- a/reference_model/src/operators.cc
+++ b/reference_model/src/operators.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2022, ARM Limited.
+// Copyright (c) 2022-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -94,7 +94,7 @@ extern "C"
{ input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("argmax", { op }, { input, output }, { input->GetName() },
+ tosa::TosaSerializationBasicBlock block("argmax", "main", { op }, { input, output }, { input->GetName() },
{ output->GetName() });
// Setup model
@@ -137,7 +137,7 @@ extern "C"
&attr, { input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("avg_pool2d", { op }, { input, output }, { input->GetName() },
+ tosa::TosaSerializationBasicBlock block("avg_pool2d", "main", { op }, { input, output }, { input->GetName() },
{ output->GetName() });
// Setup model
@@ -185,7 +185,7 @@ extern "C"
{ output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("conv2d", { op }, { input, weight, bias, output },
+ tosa::TosaSerializationBasicBlock block("conv2d", "main", { op }, { input, weight, bias, output },
{ input->GetName(), weight->GetName(), bias->GetName() },
{ output->GetName() });
@@ -236,7 +236,7 @@ extern "C"
{ output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("conv3d", { op }, { input, weight, bias, output },
+ tosa::TosaSerializationBasicBlock block("conv3d", "main", { op }, { input, weight, bias, output },
{ input->GetName(), weight->GetName(), bias->GetName() },
{ output->GetName() });
@@ -287,7 +287,7 @@ extern "C"
{ input->GetName(), weight->GetName(), bias->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("depthwise_conv2d", { op }, { input, weight, bias, output },
+ tosa::TosaSerializationBasicBlock block("depthwise_conv2d", "main", { op }, { input, weight, bias, output },
{ input->GetName(), weight->GetName(), bias->GetName() },
{ output->GetName() });
@@ -328,8 +328,8 @@ extern "C"
{ input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("fully_connected", { op }, { input, output }, { input->GetName() },
- { output->GetName() });
+ tosa::TosaSerializationBasicBlock block("fully_connected", "main", { op }, { input, output },
+ { input->GetName() }, { output->GetName() });
// Setup model
TosaReference::ModelRunnerImpl runner;
@@ -367,8 +367,8 @@ extern "C"
&attr, { a->GetName(), b->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("matmul", { op }, { a, b, output }, { a->GetName(), b->GetName() },
- { output->GetName() });
+ tosa::TosaSerializationBasicBlock block("matmul", "main", { op }, { a, b, output },
+ { a->GetName(), b->GetName() }, { output->GetName() });
// Setup model
TosaReference::ModelRunnerImpl runner;
@@ -411,7 +411,7 @@ extern "C"
&attr, { input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("max_pool2d", { op }, { input, output }, { input->GetName() },
+ tosa::TosaSerializationBasicBlock block("max_pool2d", "main", { op }, { input, output }, { input->GetName() },
{ output->GetName() });
// Setup model
@@ -463,7 +463,7 @@ extern "C"
{ input->GetName(), weight->GetName(), bias->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("transpose_conv2d", { op }, { input, weight, bias, output },
+ tosa::TosaSerializationBasicBlock block("transpose_conv2d", "main", { op }, { input, weight, bias, output },
{ input->GetName(), weight->GetName(), bias->GetName() },
{ output->GetName() });
@@ -506,7 +506,7 @@ extern "C"
&attr, { input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("clamp", { op }, { input, output }, { input->GetName() },
+ tosa::TosaSerializationBasicBlock block("clamp", "main", { op }, { input, output }, { input->GetName() },
{ output->GetName() });
// Setup model
@@ -537,7 +537,7 @@ extern "C"
{ input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("sigmoid", { op }, { input, output }, { input->GetName() },
+ tosa::TosaSerializationBasicBlock block("sigmoid", "main", { op }, { input, output }, { input->GetName() },
{ output->GetName() });
// Setup model
@@ -568,7 +568,7 @@ extern "C"
{ input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("tanh", { op }, { input, output }, { input->GetName() },
+ tosa::TosaSerializationBasicBlock block("tanh", "main", { op }, { input, output }, { input->GetName() },
{ output->GetName() });
// Setup model
@@ -600,7 +600,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("add", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("add", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -638,7 +638,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("arithmetic_right_shift", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("arithmetic_right_shift", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -672,7 +672,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("bitwise_and", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("bitwise_and", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -706,7 +706,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("bitwise_or", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("bitwise_or", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -740,7 +740,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("bitwise_xor", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("bitwise_xor", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -773,7 +773,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("intdiv", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("intdiv", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -807,7 +807,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("logical_and", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("logical_and", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -843,7 +843,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("logical_left_shift", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("logical_left_shift", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -879,7 +879,7 @@ extern "C"
&attr, { input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("logical_right_shift", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("logical_right_shift", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -913,7 +913,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("logical_or", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("logical_or", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -947,7 +947,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("logical_xor", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("logical_xor", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -981,7 +981,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("maximum", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("maximum", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -1015,7 +1015,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("minimum", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("minimum", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -1052,7 +1052,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("mul", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("mul", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -1085,7 +1085,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("pow", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("pow", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -1118,7 +1118,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("sub", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("sub", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -1154,7 +1154,7 @@ extern "C"
&attr, { input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("table", { op }, { input, output }, { input->GetName() },
+ tosa::TosaSerializationBasicBlock block("table", "main", { op }, { input, output }, { input->GetName() },
{ output->GetName() });
// Setup model
@@ -1185,7 +1185,7 @@ extern "C"
{ input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("abs", { op }, { input1, output }, { input1->GetName() },
+ tosa::TosaSerializationBasicBlock block("abs", "main", { op }, { input1, output }, { input1->GetName() },
{ output->GetName() });
// Setup model
@@ -1216,8 +1216,8 @@ extern "C"
{ input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("bitwise_not", { op }, { input1, output }, { input1->GetName() },
- { output->GetName() });
+ tosa::TosaSerializationBasicBlock block("bitwise_not", "main", { op }, { input1, output },
+ { input1->GetName() }, { output->GetName() });
// Setup model
TosaReference::ModelRunnerImpl runner;
@@ -1247,7 +1247,7 @@ extern "C"
{ input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("ceil", { op }, { input1, output }, { input1->GetName() },
+ tosa::TosaSerializationBasicBlock block("ceil", "main", { op }, { input1, output }, { input1->GetName() },
{ output->GetName() });
// Setup model
@@ -1278,7 +1278,7 @@ extern "C"
{ input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("clz", { op }, { input1, output }, { input1->GetName() },
+ tosa::TosaSerializationBasicBlock block("clz", "main", { op }, { input1, output }, { input1->GetName() },
{ output->GetName() });
// Setup model
@@ -1309,7 +1309,7 @@ extern "C"
{ input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("exp", { op }, { input1, output }, { input1->GetName() },
+ tosa::TosaSerializationBasicBlock block("exp", "main", { op }, { input1, output }, { input1->GetName() },
{ output->GetName() });
// Setup model
@@ -1340,7 +1340,7 @@ extern "C"
{ input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("floor", { op }, { input1, output }, { input1->GetName() },
+ tosa::TosaSerializationBasicBlock block("floor", "main", { op }, { input1, output }, { input1->GetName() },
{ output->GetName() });
// Setup model
@@ -1371,7 +1371,7 @@ extern "C"
{ input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("log", { op }, { input1, output }, { input1->GetName() },
+ tosa::TosaSerializationBasicBlock block("log", "main", { op }, { input1, output }, { input1->GetName() },
{ output->GetName() });
// Setup model
@@ -1402,8 +1402,8 @@ extern "C"
{ input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("logical_not", { op }, { input1, output }, { input1->GetName() },
- { output->GetName() });
+ tosa::TosaSerializationBasicBlock block("logical_not", "main", { op }, { input1, output },
+ { input1->GetName() }, { output->GetName() });
// Setup model
TosaReference::ModelRunnerImpl runner;
@@ -1438,7 +1438,7 @@ extern "C"
&attr, { input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("negate", { op }, { input1, output }, { input1->GetName() },
+ tosa::TosaSerializationBasicBlock block("negate", "main", { op }, { input1, output }, { input1->GetName() },
{ output->GetName() });
// Setup model
@@ -1469,7 +1469,7 @@ extern "C"
{ input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("reciprocal", { op }, { input1, output }, { input1->GetName() },
+ tosa::TosaSerializationBasicBlock block("reciprocal", "main", { op }, { input1, output }, { input1->GetName() },
{ output->GetName() });
// Setup model
@@ -1500,7 +1500,7 @@ extern "C"
{ input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("rsqrt", { op }, { input1, output }, { input1->GetName() },
+ tosa::TosaSerializationBasicBlock block("rsqrt", "main", { op }, { input1, output }, { input1->GetName() },
{ output->GetName() });
// Setup model
@@ -1537,7 +1537,7 @@ extern "C"
{ output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("select", { op }, { input1, input2, input3, output },
+ tosa::TosaSerializationBasicBlock block("select", "main", { op }, { input1, input2, input3, output },
{ input1->GetName(), input2->GetName(), input3->GetName() },
{ output->GetName() });
@@ -1572,7 +1572,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("equal", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("equal", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -1606,7 +1606,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("greater", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("greater", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -1641,7 +1641,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("greater_equal", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("greater_equal", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -1674,7 +1674,7 @@ extern "C"
{ input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("reduce_all", { op }, { input, output }, { input->GetName() },
+ tosa::TosaSerializationBasicBlock block("reduce_all", "main", { op }, { input, output }, { input->GetName() },
{ output->GetName() });
// Setup model
@@ -1706,7 +1706,7 @@ extern "C"
{ input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("reduce_any", { op }, { input, output }, { input->GetName() },
+ tosa::TosaSerializationBasicBlock block("reduce_any", "main", { op }, { input, output }, { input->GetName() },
{ output->GetName() });
// Setup model
@@ -1738,7 +1738,7 @@ extern "C"
{ input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("reduce_max", { op }, { input, output }, { input->GetName() },
+ tosa::TosaSerializationBasicBlock block("reduce_max", "main", { op }, { input, output }, { input->GetName() },
{ output->GetName() });
// Setup model
@@ -1770,7 +1770,7 @@ extern "C"
{ input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("reduce_min", { op }, { input, output }, { input->GetName() },
+ tosa::TosaSerializationBasicBlock block("reduce_min", "main", { op }, { input, output }, { input->GetName() },
{ output->GetName() });
// Setup model
@@ -1802,8 +1802,8 @@ extern "C"
&attr, { input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("reduce_product", { op }, { input, output }, { input->GetName() },
- { output->GetName() });
+ tosa::TosaSerializationBasicBlock block("reduce_product", "main", { op }, { input, output },
+ { input->GetName() }, { output->GetName() });
// Setup model
TosaReference::ModelRunnerImpl runner;
@@ -1834,7 +1834,7 @@ extern "C"
{ input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("reduce_sum", { op }, { input, output }, { input->GetName() },
+ tosa::TosaSerializationBasicBlock block("reduce_sum", "main", { op }, { input, output }, { input->GetName() },
{ output->GetName() });
// Setup model
@@ -1865,7 +1865,7 @@ extern "C"
{ input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("concat", { op }, { input1, output }, { input1->GetName() },
+ tosa::TosaSerializationBasicBlock block("concat", "main", { op }, { input1, output }, { input1->GetName() },
{ output->GetName() });
// Setup model
@@ -1904,7 +1904,7 @@ extern "C"
{ input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("pad", { op }, { input1, output }, { input1->GetName() },
+ tosa::TosaSerializationBasicBlock block("pad", "main", { op }, { input1, output }, { input1->GetName() },
{ output->GetName() });
// Setup model
@@ -1939,7 +1939,7 @@ extern "C"
&attr, { input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("reshape", { op }, { input1, output }, { input1->GetName() },
+ tosa::TosaSerializationBasicBlock block("reshape", "main", { op }, { input1, output }, { input1->GetName() },
{ output->GetName() });
// Setup model
@@ -1970,7 +1970,7 @@ extern "C"
{ input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("reverse", { op }, { input, output }, { input->GetName() },
+ tosa::TosaSerializationBasicBlock block("reverse", "main", { op }, { input, output }, { input->GetName() },
{ output->GetName() });
// Setup model
@@ -2008,7 +2008,7 @@ extern "C"
&attr, { input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("slice", { op }, { input1, output }, { input1->GetName() },
+ tosa::TosaSerializationBasicBlock block("slice", "main", { op }, { input1, output }, { input1->GetName() },
{ output->GetName() });
// Setup model
@@ -2045,7 +2045,7 @@ extern "C"
&attr, { input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("tile", { op }, { input1, output }, { input1->GetName() },
+ tosa::TosaSerializationBasicBlock block("tile", "main", { op }, { input1, output }, { input1->GetName() },
{ output->GetName() });
// Setup model
@@ -2081,7 +2081,7 @@ extern "C"
&attr, { input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("transpose", { op }, { input1, output }, { input1->GetName() },
+ tosa::TosaSerializationBasicBlock block("transpose", "main", { op }, { input1, output }, { input1->GetName() },
{ output->GetName() });
// Setup model
@@ -2114,7 +2114,7 @@ extern "C"
{ values->GetName(), indices->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("gather", { op }, { values, indices, output },
+ tosa::TosaSerializationBasicBlock block("gather", "main", { op }, { values, indices, output },
{ values->GetName(), indices->GetName() }, { output->GetName() });
// Setup model
@@ -2152,7 +2152,7 @@ extern "C"
{ values_out->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("scatter", { op }, { values_in, indices, input, values_out },
+ tosa::TosaSerializationBasicBlock block("scatter", "main", { op }, { values_in, indices, input, values_out },
{ values_in->GetName(), indices->GetName(), input->GetName() },
{ values_out->GetName() });
@@ -2195,7 +2195,7 @@ extern "C"
&attr, { input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("resize", { op }, { input, output }, { input->GetName() },
+ tosa::TosaSerializationBasicBlock block("resize", "main", { op }, { input, output }, { input->GetName() },
{ output->GetName() });
// Setup model
@@ -2226,7 +2226,7 @@ extern "C"
{ input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("cast", { op }, { input, output }, { input->GetName() },
+ tosa::TosaSerializationBasicBlock block("cast", "main", { op }, { input, output }, { input->GetName() },
{ output->GetName() });
// Setup model
@@ -2274,7 +2274,7 @@ extern "C"
&attr, { input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("rescale", { op }, { input, output }, { input->GetName() },
+ tosa::TosaSerializationBasicBlock block("rescale", "main", { op }, { input, output }, { input->GetName() },
{ output->GetName() });
// Setup model
@@ -2305,7 +2305,7 @@ extern "C"
{ input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("identity", { op }, { input1, output }, { input1->GetName() },
+ tosa::TosaSerializationBasicBlock block("identity", "main", { op }, { input1, output }, { input1->GetName() },
{ output->GetName() });
// Setup model
diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc
index 7105caf..942652d 100644
--- a/reference_model/src/ops/control_flow.cc
+++ b/reference_model/src/ops/control_flow.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -15,7 +15,6 @@
#include "control_flow.h"
#include "subgraph_traverser.h"
-
using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
@@ -37,7 +36,7 @@ int OpControlFlow::evalBlock(TosaSerializationBasicBlock* block,
DEBUG_MED(OP, "Evaluating block %s", block_name.c_str());
- SubgraphTraverser block_sgt(block, tsh);
+ SubgraphTraverser block_sgt(block, tsh, this->parent_sgt);
ERROR_IF(block_sgt.initializeGraph(), "evalBlock(): Unable to initialize graph traverser for %s",
block_name.c_str());
@@ -182,8 +181,10 @@ int OpCondIf::checkTensorAttributes()
cond = dynamic_cast<TosaReference::Tensor0<bool>*>(inputs[0]);
ASSERT_MEM(cond);
- then_block = tsh->GetBlockByName(attribute->then_branch());
- else_block = tsh->GetBlockByName(attribute->else_branch());
+ auto region_name = getParentSGT()->getRegionName();
+ auto curr_region = tsh->GetRegionByName(region_name);
+ then_block = curr_region->GetBlockByName(attribute->then_branch());
+ else_block = curr_region->GetBlockByName(attribute->else_branch());
ERROR_IF(!then_block, "OpCondIf: fail to resolve then_branch %s", attribute->then_branch().c_str());
@@ -193,6 +194,7 @@ int OpCondIf::checkTensorAttributes()
// Skip the first rank 0 bool tensor on input list
int32_t num_input_tensor = getInputs().size() - 1;
int32_t num_output_tensor = getOutputs().size();
+
ERROR_IF((int32_t)then_block->GetInputs().size() != num_input_tensor,
"OpCondIf: then_block has unexpected number of input");
ERROR_IF((int32_t)else_block->GetInputs().size() != num_input_tensor,
@@ -307,8 +309,10 @@ int OpWhileLoop::checkTensorAttributes()
return 1;
}
- cond_block = tsh->GetBlockByName(attribute->cond_branch());
- body_block = tsh->GetBlockByName(attribute->body_branch());
+ auto region_name = getParentSGT()->getRegionName();
+ auto curr_region = tsh->GetRegionByName(region_name);
+ cond_block = curr_region->GetBlockByName(attribute->cond_branch());
+ body_block = curr_region->GetBlockByName(attribute->body_branch());
ERROR_IF(!cond_block, "OpWhileLoop: fail to resolve cond_branch %s", attribute->cond_branch().c_str());
ERROR_IF(!body_block, "OpWhileLoop: fail to resolve body_branch %s", attribute->body_branch().c_str());
@@ -403,12 +407,7 @@ int OpWhileLoop::eval()
// assigning output tensors value back to input tensors value for next iteration
for (size_t i = 0; i < num_input_output; i++)
{
- if (getInputs()[i]->copyValueFrom(getOutputs()[i]))
- {
- WARNING("Fail to copy tensor value %s -> %s", getOutputs()[i]->getName().c_str(),
- getInputs()[i]->getName().c_str());
- return 1;
- }
+ getInputs()[i] = getOutputs()[i];
}
}
else
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index 112e641..8867ada 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -37,13 +37,14 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-SubgraphTraverser::SubgraphTraverser(TosaSerializationBasicBlock* _block, TosaSerializationHandler* _tsh)
+SubgraphTraverser::SubgraphTraverser(TosaSerializationBasicBlock* _block, TosaSerializationHandler* _tsh, SubgraphTraverser* _parent_sgt)
{
- graph_status = GraphStatus::TOSA_VALID;
+ graph_status = GraphStatus::TOSA_VALID;
block = _block;
- tsh = _tsh;
+ tsh = _tsh;
+ parent_sgt = _parent_sgt;
tensors.clear();
nodes.clear();
nextNodeList.clear();
@@ -120,6 +121,17 @@ int SubgraphTraverser::initializeGraph()
{
int idx = 0;
+ std::vector<TosaSerializationTensor*> ser_tensor_vec;
+ // Get all the serialized tensors from TosaSerializationHandler.
+ for (auto block: tsh->GetMainRegion()->GetBlocks())
+ {
+ for (auto ser_tensor : block->GetTensors())
+ {
+ ser_tensor_vec.push_back(ser_tensor);
+ }
+ }
+
+ std::vector<GraphNode*> non_const_node_vec;
for (auto op : block->GetOperators())
{
// translated TosaSerializationOperator to GraphNode
@@ -159,7 +171,13 @@ int SubgraphTraverser::initializeGraph()
EnumNamesOp()[op->GetOp()], input_index);
std::string input_name = op->GetInputTensorNames()[input_index];
- TosaSerializationTensor* input_tensor = block->GetTensorByName(input_name);
+ TosaSerializationTensor* input_tensor = nullptr;
+ for (auto ser_tensor : ser_tensor_vec) {
+ if (ser_tensor->GetName() == input_name) {
+ input_tensor = ser_tensor;
+ }
+ }
+
SUBGRAPH_ERROR_IF(
!input_tensor,
"SubgraphTraverser::initializeGraph(): fail to get input tensor %s from TosaSerializationHandler",
@@ -175,7 +193,13 @@ int SubgraphTraverser::initializeGraph()
"SubgraphTraverser::initializeGraph(): Op=%s, weight_index %d must be within [0, num_input - 1]",
EnumNamesOp()[op->GetOp()], weight_index);
std::string weight_name = op->GetInputTensorNames()[weight_index];
- TosaSerializationTensor* weight_tensor = block->GetTensorByName(weight_name);
+ TosaSerializationTensor* weight_tensor = nullptr;
+ for (auto ser_tensor : ser_tensor_vec) {
+ if (ser_tensor->GetName() == weight_name) {
+ weight_tensor = ser_tensor;
+ }
+ }
+
SUBGRAPH_ERROR_IF(
!weight_tensor,
"SubgraphTraverser::initializeGraph(): fail to get weight tensor %s from TosaSerializationHandler",
@@ -199,8 +223,19 @@ int SubgraphTraverser::initializeGraph()
DEBUG_INFO(GT, "Creating operator id_%03u, %8s, %lu input tensors, %lu output tensors", idx,
EnumNamesOp()[op->GetOp()], op->GetInputTensorNames().size(), op->GetOutputTensorNames().size());
- GraphNode* node = OpFactory::newOp(this, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype,
+ GraphNode* node = nullptr;
+ if (this->parent_sgt) {
+ node = OpFactory::newOp(this->parent_sgt, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype,
+ input_rank, output_dtype, output_rank, weight_dtype, weight_rank);
+ node->setInMainBlock(false);
+ } else {
+ node = OpFactory::newOp(this, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype,
input_rank, output_dtype, output_rank, weight_dtype, weight_rank);
+ if (node) {
+ node->setInMainBlock(true);
+ }
+ }
+
if (!node)
{
if (weight_index == -1)
@@ -257,6 +292,8 @@ int SubgraphTraverser::initializeGraph()
if (op->GetInputTensorNames().empty() && !node->getOnNextNodeList())
{
addToNextNodeList(node);
+ } else if (!node->getInMainBlock()) {
+ non_const_node_vec.push_back(node);
}
idx++;
@@ -271,7 +308,6 @@ int SubgraphTraverser::initializeGraph()
SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::initializeGraph(): Unsupported tensor name=%s, type=%s, rank=%d",
ts->GetName().c_str(), EnumNamesDType()[ts->GetDtype()], (int)ts->GetShape().size());
- // update this->tensors
addTensor(tensor);
}
@@ -296,7 +332,7 @@ int SubgraphTraverser::initializeGraph()
for (auto& output_name : block->GetOutputs())
{
TosaReference::Tensor* tensor = findTensorByName(output_name);
- DEBUG_INFO(GT, "output tensor name=%s\n", output_name.c_str());
+ DEBUG_INFO(GT, "output tensor name=%s", output_name.c_str());
if (tensor)
{
tensor->setIsSubgraphOutput();
@@ -314,6 +350,22 @@ int SubgraphTraverser::initializeGraph()
dumpNextNodeList(g_func_debug.func_debug_file);
}
+ // If the node is not in mainblock and not const
+ for (auto node : non_const_node_vec) {
+ bool all_inputs_from_parent = true;
+ for (std::string& name : node->getInputNames())
+ {
+ TosaReference::Tensor* t = findTensorByName(name);
+ if (!t->getIsParentGraphOutput()) {
+ all_inputs_from_parent = false;
+ }
+ }
+ // In the children block, when a node has all its inputs from parent
+ // block, we have to manually add this node to the evaluation list
+ if (all_inputs_from_parent && !node->getOnNextNodeList()) {
+ addToNextNodeList(node);
+ }
+ }
return 0;
}
@@ -510,29 +562,40 @@ int SubgraphTraverser::evaluateNextNode()
}
// free input tensor if all of its consumers have all of their outputs ready and it's not block's output
- for (auto tensor : currNode->getInputs())
- {
- bool in_use = false;
- for (auto node : tensor->getConsumers())
+ if (!currNode->getInMainBlock()) { // we don't free it if the node is in main block and has nested blocks
+ for (auto tensor : currNode->getInputs())
{
- if (!node->hasAllOutputsReady())
+ bool in_use = false;
+
+ auto tensor_check = findTensorByName(tensor->getName());
+ if (tensor_check->getIsParentGraphOutput()) {
+ // if it's parent's block output tensor, we can't free it
+ continue;
+ }
+
+ for (auto node : tensor->getConsumers())
{
- in_use = true;
+ // If the node is inside a loop, the input tensor is still needed
+ if (!node->hasAllOutputsReady())
+ {
+ in_use = true;
+ }
+
}
- }
- for (auto name : block->GetOutputs())
- {
- if (name == tensor->getName())
+ for (auto name : block->GetOutputs())
{
- in_use = true;
+ if (name == tensor->getName())
+ {
+ in_use = true;
+ }
+ }
+
+ if (!in_use)
+ {
+ tensor->deallocate();
}
- }
- if (!in_use)
- {
- tensor->deallocate();
}
}
-
// Search the output tensors of this node to see if
// there are now new ready nodes available from completing this node
for (TosaReference::Tensor* tensor : currNode->getOutputs())
@@ -642,17 +705,35 @@ int SubgraphTraverser::addNode(GraphNode* newNode)
TosaReference::Tensor* SubgraphTraverser::findTensorByName(const std::string& name) const
{
+ TosaReference::Tensor* res_tensor = nullptr;
+
for (TosaReference::Tensor* currTensor : tensors)
{
if (currTensor->getName() == name)
{
- return currTensor;
+ res_tensor = currTensor;
+ return res_tensor;
}
}
- WARNING("SubgraphTraverser::findTensorByName(): Unable to find tensor with name: %s\n", name.c_str());
+ if (parent_sgt)
+ {
+ for (TosaReference::Tensor* currTensor : parent_sgt->tensors)
+ {
+ if (currTensor->getName() == name)
+ {
+ res_tensor = currTensor;
+ res_tensor->setIsParentGraphOutput();
+ }
+ }
+ }
- return nullptr;
+ if (!res_tensor)
+ {
+ WARNING("SubgraphTraverser::findTensorByName(): Unable to find tensor with name: %s\n", name.c_str());
+ return nullptr;
+ }
+ return res_tensor;
}
int SubgraphTraverser::linkTensorsAndNodes()
@@ -704,7 +785,6 @@ int SubgraphTraverser::validateGraph()
for (TosaReference::Tensor* currTensor : tensors)
{
-
// It's okay for block input tensor not being consumed by operators.
// This is common in control flow op execution.
if (!currTensor->getIsSubgraphInput())
diff --git a/reference_model/src/subgraph_traverser.h b/reference_model/src/subgraph_traverser.h
index 7940ee4..543b008 100644
--- a/reference_model/src/subgraph_traverser.h
+++ b/reference_model/src/subgraph_traverser.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -30,7 +30,7 @@ namespace TosaReference
class SubgraphTraverser
{
public:
- SubgraphTraverser(TosaSerializationBasicBlock* block, TosaSerializationHandler* tsh);
+ SubgraphTraverser(TosaSerializationBasicBlock* block, TosaSerializationHandler* tsh, SubgraphTraverser* parent_sgt);
~SubgraphTraverser();
int initializeGraph();
@@ -59,6 +59,10 @@ public:
{
return block->GetName();
}
+ std::string getRegionName() const
+ {
+ return block->GetRegionName();
+ }
int getNumInputTensors() const;
Tensor* getInputTensor(const unsigned int idx) const;
Tensor* getInputTensorByName(const std::string name) const;
@@ -77,6 +81,10 @@ private:
GraphStatus graph_status;
+ // pointer to the parent subgraph traversal if exists
+ // e.g., Control Flow Ops will have nested blocks (subgraph traversals)
+ SubgraphTraverser* parent_sgt;
+
// pointer to serialization library and corresponding basic block
TosaSerializationBasicBlock* block;
TosaSerializationHandler* tsh;
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc
index 0678bbd..7af2069 100644
--- a/reference_model/src/tensor.cc
+++ b/reference_model/src/tensor.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -32,11 +32,18 @@ TosaReference::Tensor::Tensor(std::string tensorName_, DType tensorDtype_, std::
consumers.clear();
isSubgraphInput = false;
isSubgraphOutput = false;
+ isParentGraphOutput = false;
}
TosaReference::Tensor::~Tensor()
{}
+int TosaReference::Tensor::setIsParentGraphOutput()
+{
+ isParentGraphOutput = true;
+ return 0;
+}
+
int TosaReference::Tensor::setIsSubgraphInput()
{
isSubgraphInput = true;
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
index 08e865a..d5f1de8 100644
--- a/reference_model/src/tensor.h
+++ b/reference_model/src/tensor.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -40,6 +40,11 @@ public:
int setIsSubgraphInput();
int setIsSubgraphOutput();
+ int setIsParentGraphOutput();
+
+ int getIsParentGraphOutput() const {
+ return isParentGraphOutput;
+ }
int getIsSubgraphInput() const
{
@@ -261,6 +266,8 @@ protected:
int isSubgraphOutput;
bool isAllocated;
+ bool isParentGraphOutput;
+
GraphNode* producer;
std::vector<GraphNode*> consumers;
diff --git a/scripts/operator_api/templates/operators_cc.j2 b/scripts/operator_api/templates/operators_cc.j2
index 6b0ed6e..3f2acb5 100644
--- a/scripts/operator_api/templates/operators_cc.j2
+++ b/scripts/operator_api/templates/operators_cc.j2
@@ -1,5 +1,5 @@
-// Copyright (c) 2022, ARM Limited.
+// Copyright (c) 2022-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -64,11 +64,11 @@ tosa::ResizeMode translate_client_tosa_mode(tosa_mode_t mode) {
switch(mode) {
case tosa_mode_nearest:
return tosa::ResizeMode_NEAREST;
- case tosa_mode_max:
+ case tosa_mode_max:
case tosa_mode_bilinear:
return tosa::ResizeMode_BILINEAR;
default:
- return tosa::ResizeMode_UNKNOWN;
+ return tosa::ResizeMode_UNKNOWN;
}
}
@@ -131,7 +131,7 @@ extern "C"
});
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("{{operator.name}}", { op },
+ tosa::TosaSerializationBasicBlock block("{{operator.name}}", "main", { op },
{
{%- for input in operator.inputs: -%}
{{input}},
diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib
-Subproject 6388a097de4350cc70472921c272074190fd7c9
+Subproject ca7ce0e94b3ee7339f31b47baa3a3fb4522243a
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py
index 90bda34..767989e 100644
--- a/verif/frameworks/tensor_gen.py
+++ b/verif/frameworks/tensor_gen.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import tensorflow as tf
@@ -252,3 +252,25 @@ class TGen:
tf_placeholders.append(("placeholder_2", TGen.getRand(shape, dtype, rng)))
return tf_placeholders, tf_consts
+
+ @staticmethod
+ def tgRecurrent(op, ifm_shape, dtype, rng):
+ # Require rank 3 shape for recurrent networks
+ if len(ifm_shape) != 3:
+ return [], []
+ pl, const = op["operands"]
+
+ tf_placeholders = []
+ tf_consts = []
+
+ for i in range(pl):
+ tf_placeholders.append(
+ ("placeholder_{}".format(i), TGen.getRand(ifm_shape, dtype, rng))
+ )
+
+ for i in range(const):
+ tf_consts.append(
+ ("const_{}".format(i), TGen.getRand(ifm_shape, dtype, rng))
+ )
+
+ return tf_placeholders, tf_consts
diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py
index c7c5cd7..8870f41 100644
--- a/verif/frameworks/test_builder.py
+++ b/verif/frameworks/test_builder.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import tensorflow as tf
@@ -1164,3 +1164,82 @@ class TBuilder:
def eval(self, a):
return tf.bitwise.right_shift(a, self.shift, name=self.result_name)
+
+ class While:
+ def __init__(self, name):
+ self.result_name = name
+
+ def while_cond(self, x):
+ return tf.reduce_sum(x) < self.cap
+
+ def while_body(self, x):
+ return tf.add(x, tf.math.sigmoid(x))
+
+ def eval(self, a):
+ self.cap = tf.cast(
+ tf.constant(
+ 2.0,
+ shape=[
+ 1,
+ ],
+ ),
+ a.dtype,
+ )
+
+ result = tf.while_loop(
+ self.while_cond, self.while_body, [a], name=self.result_name
+ )
+
+ return result[0]
+
+ class LSTM:
+ def __init__(self, name):
+ self.result_name = name
+ self.lstm = tf.keras.layers.LSTM(
+ 2,
+ activation="tanh",
+ unroll=False,
+ recurrent_activation="sigmoid",
+ use_bias=True,
+ recurrent_initializer="ones",
+ kernel_initializer="ones",
+ )
+
+ def eval(self, a):
+ return self.lstm(a)
+
+ class GRU:
+ def __init__(self, name):
+ self.result_name = name
+ self.lstm = tf.keras.layers.GRU(
+ 2,
+ recurrent_activation="sigmoid",
+ use_bias=True,
+ recurrent_initializer="ones",
+ kernel_initializer="ones",
+ )
+
+ def eval(self, a):
+ return self.lstm(a)
+
+ class RNN:
+ def __init__(self, name):
+ self.result_name = name
+ basic_cell = tf.keras.layers.SimpleRNNCell(
+ units=2,
+ activation="sigmoid",
+ use_bias=True,
+ recurrent_initializer="ones",
+ )
+ self.rnn = tf.keras.layers.RNN(basic_cell, unroll=False)
+
+ def eval(self, a):
+ return self.rnn(a)
+
+ class FullyConnected:
+ def __init__(self, name):
+ self.result_name = name
+ self.dense = tf.keras.layers.Dense(2)
+
+ def eval(self, a):
+ return self.dense(a)
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index 760def6..26af5dd 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -807,6 +807,42 @@ TF_OP_LIST = {
]
},
},
+ "while": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.While, TGen.tgBasic, ArgGen.agNone),
+ "types": {
+ "tflite": list(TYPE_F),
+ },
+ },
+ "lstm": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.LSTM, TGen.tgRecurrent, ArgGen.agNone),
+ "types": {
+ "tflite": [
+ tf.float32,
+ # tf.int32
+ ]
+ },
+ },
+ "gru": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.GRU, TGen.tgRecurrent, ArgGen.agNone),
+ "types": {
+ "tflite": [
+ tf.float32,
+ # tf.int32
+ ]
+ },
+ },
+ "rnn": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.RNN, TGen.tgRecurrent, ArgGen.agNone),
+ "types": {
+ "tflite": [
+ tf.float32,
+ ]
+ },
+ },
}
# Shapes to be tested; default can be overwritten
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 515e8bb..d799eb0 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import os
from copy import deepcopy
@@ -1845,7 +1845,7 @@ class TosaTestGen:
# Finally, build the op and the two blocks
self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)
- self.ser.startBasicBlock(then_block)
+ self.ser.addBasicBlock(then_block)
# Build the actual then/else tensors inside their blocks
if error_name == ErrorIf.CondIfOutputListThenGraphMismatch:
then_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
@@ -1853,7 +1853,7 @@ class TosaTestGen:
then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
self.ser.addOutputTensor(then_tens)
- self.ser.startBasicBlock(else_block)
+ self.ser.addBasicBlock(else_block)
if error_name == ErrorIf.CondIfOutputListElseGraphMismatch:
else_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
else:
@@ -1865,7 +1865,7 @@ class TosaTestGen:
validator_fcns,
error_name,
op=op,
- basicBlocks=self.ser.basicBlocks,
+ basicBlocks=self.ser.currRegion.basicBlocks,
cond=cond_tens,
):
return None
@@ -1914,7 +1914,7 @@ class TosaTestGen:
assert False, f"No tests for DType: {a.dtype}"
for block, op in ((then_block, then_op), (else_block, else_op)):
- self.ser.startBasicBlock(block)
+ self.ser.addBasicBlock(block)
if (
error_name == ErrorIf.CondIfInputListThenGraphMismatch
and block == then_block
@@ -1948,7 +1948,7 @@ class TosaTestGen:
op=op,
a=a,
b=b,
- basicBlocks=self.ser.basicBlocks,
+ basicBlocks=self.ser.currRegion.basicBlocks,
cond=cond_tens,
):
return None
@@ -2005,7 +2005,8 @@ class TosaTestGen:
incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
# COND block (input: iter, output: cond_tens )
- self.ser.startBasicBlock(cond_block)
+ self.ser.addBasicBlock(cond_block)
+
if error_name == ErrorIf.InputListCondGraphMismatch:
self.ser.addInputTensor(incorrect_iter)
self.ser.addInputTensor(a)
@@ -2034,7 +2035,8 @@ class TosaTestGen:
# BODY block (input: a, acc, iter, output: a, acc, iter)
# Note that local intermediate tensors need to be declared here for the outputs
- self.ser.startBasicBlock(body_block)
+ self.ser.addBasicBlock(body_block)
+
if error_name == ErrorIf.InputListBodyGraphInputMismatch:
self.ser.addInputTensor(incorrect_iter)
self.ser.addInputTensor(a)
@@ -2068,7 +2070,7 @@ class TosaTestGen:
validator_fcns,
error_name,
op=op,
- basicBlocks=self.ser.basicBlocks,
+ basicBlocks=self.ser.currRegion.basicBlocks,
):
return None