aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/test_add_1x4x4x4_f32/flatbuffer-tf/test_add_1x4x4x4_f32.tosabin492 -> 524 bytes
-rw-r--r--examples/test_add_1x4x4x4_f32/flatbuffer-tflite/test_add_1x4x4x4_f32.tosabin492 -> 524 bytes
-rw-r--r--examples/test_add_1x4x4x4_f32/model.pb2
-rw-r--r--examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosabin1584 -> 1616 bytes
-rw-r--r--examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosabin1296 -> 1328 bytes
-rw-r--r--examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/model.pb2
-rw-r--r--examples/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosabin1232 -> 1264 bytes
-rw-r--r--reference_model/src/graph_node.cc3
-rw-r--r--reference_model/src/graph_node.h21
-rw-r--r--reference_model/src/main.cpp4
-rw-r--r--reference_model/src/model_runner_impl.cc6
-rw-r--r--reference_model/src/operators.cc138
-rw-r--r--reference_model/src/ops/control_flow.cc25
-rw-r--r--reference_model/src/subgraph_traverser.cc138
-rw-r--r--reference_model/src/subgraph_traverser.h12
-rw-r--r--reference_model/src/tensor.cc9
-rw-r--r--reference_model/src/tensor.h9
-rw-r--r--scripts/operator_api/templates/operators_cc.j28
m---------thirdparty/serialization_lib0
-rw-r--r--verif/frameworks/tensor_gen.py24
-rw-r--r--verif/frameworks/test_builder.py81
-rwxr-xr-xverif/frameworks/tosa_verif_framework_generator.py36
-rw-r--r--verif/generator/tosa_test_gen.py20
23 files changed, 398 insertions, 140 deletions
diff --git a/examples/test_add_1x4x4x4_f32/flatbuffer-tf/test_add_1x4x4x4_f32.tosa b/examples/test_add_1x4x4x4_f32/flatbuffer-tf/test_add_1x4x4x4_f32.tosa
index c240a91..03050de 100644
--- a/examples/test_add_1x4x4x4_f32/flatbuffer-tf/test_add_1x4x4x4_f32.tosa
+++ b/examples/test_add_1x4x4x4_f32/flatbuffer-tf/test_add_1x4x4x4_f32.tosa
Binary files differ
diff --git a/examples/test_add_1x4x4x4_f32/flatbuffer-tflite/test_add_1x4x4x4_f32.tosa b/examples/test_add_1x4x4x4_f32/flatbuffer-tflite/test_add_1x4x4x4_f32.tosa
index c240a91..03050de 100644
--- a/examples/test_add_1x4x4x4_f32/flatbuffer-tflite/test_add_1x4x4x4_f32.tosa
+++ b/examples/test_add_1x4x4x4_f32/flatbuffer-tflite/test_add_1x4x4x4_f32.tosa
Binary files differ
diff --git a/examples/test_add_1x4x4x4_f32/model.pb b/examples/test_add_1x4x4x4_f32/model.pb
index 479979c..f76ec12 100644
--- a/examples/test_add_1x4x4x4_f32/model.pb
+++ b/examples/test_add_1x4x4x4_f32/model.pb
@@ -92,5 +92,5 @@ node {
}
}
versions {
- producer: 1247
+ producer: 1286
}
diff --git a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa
index c4224d3..70d269f 100644
--- a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa
+++ b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa
Binary files differ
diff --git a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa
index 98480be..3ddbae8 100644
--- a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa
+++ b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa
Binary files differ
diff --git a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/model.pb b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/model.pb
index e95ba45..3cddba7 100644
--- a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/model.pb
+++ b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/model.pb
@@ -137,5 +137,5 @@ node {
}
}
versions {
- producer: 1247
+ producer: 1286
}
diff --git a/examples/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosa b/examples/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosa
index 5570543..ab8bb78 100644
--- a/examples/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosa
+++ b/examples/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosa
Binary files differ
diff --git a/reference_model/src/graph_node.cc b/reference_model/src/graph_node.cc
index 5f85932..1781e40 100644
--- a/reference_model/src/graph_node.cc
+++ b/reference_model/src/graph_node.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -33,6 +33,7 @@ GraphNode::GraphNode(SubgraphTraverser* parent_sgt_, const Op& nodeType_, const
clearOnNextNodeList();
setRequiredOperands(-1, -1);
setRequiredRank(-1);
+ inMainBlock = false;
}
GraphNode::~GraphNode()
diff --git a/reference_model/src/graph_node.h b/reference_model/src/graph_node.h
index 874f1d8..b227d17 100644
--- a/reference_model/src/graph_node.h
+++ b/reference_model/src/graph_node.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -163,7 +163,6 @@ public:
int addInputTensor(Tensor* tens);
int addOutputTensor(Tensor* tens);
-
// Validate that the input tensors match properly
// in their types, attributes, rank, etc well enough to be
// processed.
@@ -252,6 +251,22 @@ public:
return nodeType;
}
+ SubgraphTraverser* getParentSGT()
+ {
+ return parent_sgt;
+ }
+
+ int setInMainBlock(bool isInMainBlock)
+ {
+ inMainBlock = isInMainBlock;
+ return 0;
+ }
+
+ bool getInMainBlock()
+ {
+ return inMainBlock;
+ }
+
// Helper functions.
int idiv_check(int input1, int input2, int& result);
@@ -328,6 +343,8 @@ protected:
// -1 means n/a
int requiredRankMin;
int requiredRankMax;
+
+ bool inMainBlock;
};
}; // namespace TosaReference
diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp
index 0941f2b..0375a48 100644
--- a/reference_model/src/main.cpp
+++ b/reference_model/src/main.cpp
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -83,7 +83,7 @@ int main(int argc, char** argv)
FATAL_ERROR("Unable to load graph");
}
- SubgraphTraverser main_gt(tsh.GetMainBlock(), &tsh);
+ SubgraphTraverser main_gt(tsh.GetMainRegion()->GetBlockByName("main"), &tsh, nullptr);
if (main_gt.initializeGraph())
{
diff --git a/reference_model/src/model_runner_impl.cc b/reference_model/src/model_runner_impl.cc
index 1109dd6..fa39c75 100644
--- a/reference_model/src/model_runner_impl.cc
+++ b/reference_model/src/model_runner_impl.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2022, ARM Limited.
+// Copyright (c) 2022-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -45,7 +45,7 @@ void ModelRunnerImpl::setFuncDebug(func_debug_t& func_debug)
GraphStatus ModelRunnerImpl::initialize(TosaSerializationHandler& serialization_handler)
{
validateTosaVersion(serialization_handler);
- return initialize(serialization_handler.GetMainBlock(), &serialization_handler);
+ return initialize(serialization_handler.GetMainRegion()->GetBlocks()[0], &serialization_handler);
}
GraphStatus ModelRunnerImpl::initialize(TosaSerializationBasicBlock& bb)
@@ -284,7 +284,7 @@ GraphStatus ModelRunnerImpl::initialize(TosaSerializationBasicBlock* bb,
// Make nullptr in case ModelRunnerImpl is being initialized again with a different graph.
_main_gt = nullptr;
- _main_gt = new SubgraphTraverser(bb, serialization_handler);
+ _main_gt = new SubgraphTraverser(bb, serialization_handler, nullptr);
if (_main_gt == nullptr)
{
diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc
index dfad9b8..af348ca 100644
--- a/reference_model/src/operators.cc
+++ b/reference_model/src/operators.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2022, ARM Limited.
+// Copyright (c) 2022-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -94,7 +94,7 @@ extern "C"
{ input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("argmax", { op }, { input, output }, { input->GetName() },
+ tosa::TosaSerializationBasicBlock block("argmax", "main", { op }, { input, output }, { input->GetName() },
{ output->GetName() });
// Setup model
@@ -137,7 +137,7 @@ extern "C"
&attr, { input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("avg_pool2d", { op }, { input, output }, { input->GetName() },
+ tosa::TosaSerializationBasicBlock block("avg_pool2d", "main", { op }, { input, output }, { input->GetName() },
{ output->GetName() });
// Setup model
@@ -185,7 +185,7 @@ extern "C"
{ output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("conv2d", { op }, { input, weight, bias, output },
+ tosa::TosaSerializationBasicBlock block("conv2d", "main", { op }, { input, weight, bias, output },
{ input->GetName(), weight->GetName(), bias->GetName() },
{ output->GetName() });
@@ -236,7 +236,7 @@ extern "C"
{ output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("conv3d", { op }, { input, weight, bias, output },
+ tosa::TosaSerializationBasicBlock block("conv3d", "main", { op }, { input, weight, bias, output },
{ input->GetName(), weight->GetName(), bias->GetName() },
{ output->GetName() });
@@ -287,7 +287,7 @@ extern "C"
{ input->GetName(), weight->GetName(), bias->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("depthwise_conv2d", { op }, { input, weight, bias, output },
+ tosa::TosaSerializationBasicBlock block("depthwise_conv2d", "main", { op }, { input, weight, bias, output },
{ input->GetName(), weight->GetName(), bias->GetName() },
{ output->GetName() });
@@ -328,8 +328,8 @@ extern "C"
{ input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("fully_connected", { op }, { input, output }, { input->GetName() },
- { output->GetName() });
+ tosa::TosaSerializationBasicBlock block("fully_connected", "main", { op }, { input, output },
+ { input->GetName() }, { output->GetName() });
// Setup model
TosaReference::ModelRunnerImpl runner;
@@ -367,8 +367,8 @@ extern "C"
&attr, { a->GetName(), b->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("matmul", { op }, { a, b, output }, { a->GetName(), b->GetName() },
- { output->GetName() });
+ tosa::TosaSerializationBasicBlock block("matmul", "main", { op }, { a, b, output },
+ { a->GetName(), b->GetName() }, { output->GetName() });
// Setup model
TosaReference::ModelRunnerImpl runner;
@@ -411,7 +411,7 @@ extern "C"
&attr, { input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("max_pool2d", { op }, { input, output }, { input->GetName() },
+ tosa::TosaSerializationBasicBlock block("max_pool2d", "main", { op }, { input, output }, { input->GetName() },
{ output->GetName() });
// Setup model
@@ -463,7 +463,7 @@ extern "C"
{ input->GetName(), weight->GetName(), bias->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("transpose_conv2d", { op }, { input, weight, bias, output },
+ tosa::TosaSerializationBasicBlock block("transpose_conv2d", "main", { op }, { input, weight, bias, output },
{ input->GetName(), weight->GetName(), bias->GetName() },
{ output->GetName() });
@@ -506,7 +506,7 @@ extern "C"
&attr, { input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("clamp", { op }, { input, output }, { input->GetName() },
+ tosa::TosaSerializationBasicBlock block("clamp", "main", { op }, { input, output }, { input->GetName() },
{ output->GetName() });
// Setup model
@@ -537,7 +537,7 @@ extern "C"
{ input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("sigmoid", { op }, { input, output }, { input->GetName() },
+ tosa::TosaSerializationBasicBlock block("sigmoid", "main", { op }, { input, output }, { input->GetName() },
{ output->GetName() });
// Setup model
@@ -568,7 +568,7 @@ extern "C"
{ input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("tanh", { op }, { input, output }, { input->GetName() },
+ tosa::TosaSerializationBasicBlock block("tanh", "main", { op }, { input, output }, { input->GetName() },
{ output->GetName() });
// Setup model
@@ -600,7 +600,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("add", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("add", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -638,7 +638,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("arithmetic_right_shift", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("arithmetic_right_shift", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -672,7 +672,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("bitwise_and", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("bitwise_and", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -706,7 +706,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("bitwise_or", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("bitwise_or", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -740,7 +740,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("bitwise_xor", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("bitwise_xor", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -773,7 +773,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("intdiv", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("intdiv", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -807,7 +807,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("logical_and", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("logical_and", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -843,7 +843,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("logical_left_shift", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("logical_left_shift", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -879,7 +879,7 @@ extern "C"
&attr, { input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("logical_right_shift", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("logical_right_shift", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -913,7 +913,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("logical_or", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("logical_or", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -947,7 +947,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("logical_xor", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("logical_xor", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -981,7 +981,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("maximum", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("maximum", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -1015,7 +1015,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("minimum", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("minimum", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -1052,7 +1052,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("mul", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("mul", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -1085,7 +1085,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("pow", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("pow", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -1118,7 +1118,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("sub", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("sub", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -1154,7 +1154,7 @@ extern "C"
&attr, { input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("table", { op }, { input, output }, { input->GetName() },
+ tosa::TosaSerializationBasicBlock block("table", "main", { op }, { input, output }, { input->GetName() },
{ output->GetName() });
// Setup model
@@ -1185,7 +1185,7 @@ extern "C"
{ input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("abs", { op }, { input1, output }, { input1->GetName() },
+ tosa::TosaSerializationBasicBlock block("abs", "main", { op }, { input1, output }, { input1->GetName() },
{ output->GetName() });
// Setup model
@@ -1216,8 +1216,8 @@ extern "C"
{ input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("bitwise_not", { op }, { input1, output }, { input1->GetName() },
- { output->GetName() });
+ tosa::TosaSerializationBasicBlock block("bitwise_not", "main", { op }, { input1, output },
+ { input1->GetName() }, { output->GetName() });
// Setup model
TosaReference::ModelRunnerImpl runner;
@@ -1247,7 +1247,7 @@ extern "C"
{ input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("ceil", { op }, { input1, output }, { input1->GetName() },
+ tosa::TosaSerializationBasicBlock block("ceil", "main", { op }, { input1, output }, { input1->GetName() },
{ output->GetName() });
// Setup model
@@ -1278,7 +1278,7 @@ extern "C"
{ input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("clz", { op }, { input1, output }, { input1->GetName() },
+ tosa::TosaSerializationBasicBlock block("clz", "main", { op }, { input1, output }, { input1->GetName() },
{ output->GetName() });
// Setup model
@@ -1309,7 +1309,7 @@ extern "C"
{ input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("exp", { op }, { input1, output }, { input1->GetName() },
+ tosa::TosaSerializationBasicBlock block("exp", "main", { op }, { input1, output }, { input1->GetName() },
{ output->GetName() });
// Setup model
@@ -1340,7 +1340,7 @@ extern "C"
{ input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("floor", { op }, { input1, output }, { input1->GetName() },
+ tosa::TosaSerializationBasicBlock block("floor", "main", { op }, { input1, output }, { input1->GetName() },
{ output->GetName() });
// Setup model
@@ -1371,7 +1371,7 @@ extern "C"
{ input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("log", { op }, { input1, output }, { input1->GetName() },
+ tosa::TosaSerializationBasicBlock block("log", "main", { op }, { input1, output }, { input1->GetName() },
{ output->GetName() });
// Setup model
@@ -1402,8 +1402,8 @@ extern "C"
{ input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("logical_not", { op }, { input1, output }, { input1->GetName() },
- { output->GetName() });
+ tosa::TosaSerializationBasicBlock block("logical_not", "main", { op }, { input1, output },
+ { input1->GetName() }, { output->GetName() });
// Setup model
TosaReference::ModelRunnerImpl runner;
@@ -1438,7 +1438,7 @@ extern "C"
&attr, { input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("negate", { op }, { input1, output }, { input1->GetName() },
+ tosa::TosaSerializationBasicBlock block("negate", "main", { op }, { input1, output }, { input1->GetName() },
{ output->GetName() });
// Setup model
@@ -1469,7 +1469,7 @@ extern "C"
{ input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("reciprocal", { op }, { input1, output }, { input1->GetName() },
+ tosa::TosaSerializationBasicBlock block("reciprocal", "main", { op }, { input1, output }, { input1->GetName() },
{ output->GetName() });
// Setup model
@@ -1500,7 +1500,7 @@ extern "C"
{ input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("rsqrt", { op }, { input1, output }, { input1->GetName() },
+ tosa::TosaSerializationBasicBlock block("rsqrt", "main", { op }, { input1, output }, { input1->GetName() },
{ output->GetName() });
// Setup model
@@ -1537,7 +1537,7 @@ extern "C"
{ output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("select", { op }, { input1, input2, input3, output },
+ tosa::TosaSerializationBasicBlock block("select", "main", { op }, { input1, input2, input3, output },
{ input1->GetName(), input2->GetName(), input3->GetName() },
{ output->GetName() });
@@ -1572,7 +1572,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("equal", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("equal", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -1606,7 +1606,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("greater", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("greater", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -1641,7 +1641,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("greater_equal", { op }, { input1, input2, output },
+ tosa::TosaSerializationBasicBlock block("greater_equal", "main", { op }, { input1, input2, output },
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
@@ -1674,7 +1674,7 @@ extern "C"
{ input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("reduce_all", { op }, { input, output }, { input->GetName() },
+ tosa::TosaSerializationBasicBlock block("reduce_all", "main", { op }, { input, output }, { input->GetName() },
{ output->GetName() });
// Setup model
@@ -1706,7 +1706,7 @@ extern "C"
{ input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("reduce_any", { op }, { input, output }, { input->GetName() },
+ tosa::TosaSerializationBasicBlock block("reduce_any", "main", { op }, { input, output }, { input->GetName() },
{ output->GetName() });
// Setup model
@@ -1738,7 +1738,7 @@ extern "C"
{ input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("reduce_max", { op }, { input, output }, { input->GetName() },
+ tosa::TosaSerializationBasicBlock block("reduce_max", "main", { op }, { input, output }, { input->GetName() },
{ output->GetName() });
// Setup model
@@ -1770,7 +1770,7 @@ extern "C"
{ input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("reduce_min", { op }, { input, output }, { input->GetName() },
+ tosa::TosaSerializationBasicBlock block("reduce_min", "main", { op }, { input, output }, { input->GetName() },
{ output->GetName() });
// Setup model
@@ -1802,8 +1802,8 @@ extern "C"
&attr, { input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("reduce_product", { op }, { input, output }, { input->GetName() },
- { output->GetName() });
+ tosa::TosaSerializationBasicBlock block("reduce_product", "main", { op }, { input, output },
+ { input->GetName() }, { output->GetName() });
// Setup model
TosaReference::ModelRunnerImpl runner;
@@ -1834,7 +1834,7 @@ extern "C"
{ input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("reduce_sum", { op }, { input, output }, { input->GetName() },
+ tosa::TosaSerializationBasicBlock block("reduce_sum", "main", { op }, { input, output }, { input->GetName() },
{ output->GetName() });
// Setup model
@@ -1865,7 +1865,7 @@ extern "C"
{ input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("concat", { op }, { input1, output }, { input1->GetName() },
+ tosa::TosaSerializationBasicBlock block("concat", "main", { op }, { input1, output }, { input1->GetName() },
{ output->GetName() });
// Setup model
@@ -1904,7 +1904,7 @@ extern "C"
{ input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("pad", { op }, { input1, output }, { input1->GetName() },
+ tosa::TosaSerializationBasicBlock block("pad", "main", { op }, { input1, output }, { input1->GetName() },
{ output->GetName() });
// Setup model
@@ -1939,7 +1939,7 @@ extern "C"
&attr, { input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("reshape", { op }, { input1, output }, { input1->GetName() },
+ tosa::TosaSerializationBasicBlock block("reshape", "main", { op }, { input1, output }, { input1->GetName() },
{ output->GetName() });
// Setup model
@@ -1970,7 +1970,7 @@ extern "C"
{ input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("reverse", { op }, { input, output }, { input->GetName() },
+ tosa::TosaSerializationBasicBlock block("reverse", "main", { op }, { input, output }, { input->GetName() },
{ output->GetName() });
// Setup model
@@ -2008,7 +2008,7 @@ extern "C"
&attr, { input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("slice", { op }, { input1, output }, { input1->GetName() },
+ tosa::TosaSerializationBasicBlock block("slice", "main", { op }, { input1, output }, { input1->GetName() },
{ output->GetName() });
// Setup model
@@ -2045,7 +2045,7 @@ extern "C"
&attr, { input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("tile", { op }, { input1, output }, { input1->GetName() },
+ tosa::TosaSerializationBasicBlock block("tile", "main", { op }, { input1, output }, { input1->GetName() },
{ output->GetName() });
// Setup model
@@ -2081,7 +2081,7 @@ extern "C"
&attr, { input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("transpose", { op }, { input1, output }, { input1->GetName() },
+ tosa::TosaSerializationBasicBlock block("transpose", "main", { op }, { input1, output }, { input1->GetName() },
{ output->GetName() });
// Setup model
@@ -2114,7 +2114,7 @@ extern "C"
{ values->GetName(), indices->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("gather", { op }, { values, indices, output },
+ tosa::TosaSerializationBasicBlock block("gather", "main", { op }, { values, indices, output },
{ values->GetName(), indices->GetName() }, { output->GetName() });
// Setup model
@@ -2152,7 +2152,7 @@ extern "C"
{ values_out->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("scatter", { op }, { values_in, indices, input, values_out },
+ tosa::TosaSerializationBasicBlock block("scatter", "main", { op }, { values_in, indices, input, values_out },
{ values_in->GetName(), indices->GetName(), input->GetName() },
{ values_out->GetName() });
@@ -2195,7 +2195,7 @@ extern "C"
&attr, { input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("resize", { op }, { input, output }, { input->GetName() },
+ tosa::TosaSerializationBasicBlock block("resize", "main", { op }, { input, output }, { input->GetName() },
{ output->GetName() });
// Setup model
@@ -2226,7 +2226,7 @@ extern "C"
{ input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("cast", { op }, { input, output }, { input->GetName() },
+ tosa::TosaSerializationBasicBlock block("cast", "main", { op }, { input, output }, { input->GetName() },
{ output->GetName() });
// Setup model
@@ -2274,7 +2274,7 @@ extern "C"
&attr, { input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("rescale", { op }, { input, output }, { input->GetName() },
+ tosa::TosaSerializationBasicBlock block("rescale", "main", { op }, { input, output }, { input->GetName() },
{ output->GetName() });
// Setup model
@@ -2305,7 +2305,7 @@ extern "C"
{ input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("identity", { op }, { input1, output }, { input1->GetName() },
+ tosa::TosaSerializationBasicBlock block("identity", "main", { op }, { input1, output }, { input1->GetName() },
{ output->GetName() });
// Setup model
diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc
index 7105caf..942652d 100644
--- a/reference_model/src/ops/control_flow.cc
+++ b/reference_model/src/ops/control_flow.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -15,7 +15,6 @@
#include "control_flow.h"
#include "subgraph_traverser.h"
-
using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
@@ -37,7 +36,7 @@ int OpControlFlow::evalBlock(TosaSerializationBasicBlock* block,
DEBUG_MED(OP, "Evaluating block %s", block_name.c_str());
- SubgraphTraverser block_sgt(block, tsh);
+ SubgraphTraverser block_sgt(block, tsh, this->parent_sgt);
ERROR_IF(block_sgt.initializeGraph(), "evalBlock(): Unable to initialize graph traverser for %s",
block_name.c_str());
@@ -182,8 +181,10 @@ int OpCondIf::checkTensorAttributes()
cond = dynamic_cast<TosaReference::Tensor0<bool>*>(inputs[0]);
ASSERT_MEM(cond);
- then_block = tsh->GetBlockByName(attribute->then_branch());
- else_block = tsh->GetBlockByName(attribute->else_branch());
+ auto region_name = getParentSGT()->getRegionName();
+ auto curr_region = tsh->GetRegionByName(region_name);
+ then_block = curr_region->GetBlockByName(attribute->then_branch());
+ else_block = curr_region->GetBlockByName(attribute->else_branch());
ERROR_IF(!then_block, "OpCondIf: fail to resolve then_branch %s", attribute->then_branch().c_str());
@@ -193,6 +194,7 @@ int OpCondIf::checkTensorAttributes()
// Skip the first rank 0 bool tensor on input list
int32_t num_input_tensor = getInputs().size() - 1;
int32_t num_output_tensor = getOutputs().size();
+
ERROR_IF((int32_t)then_block->GetInputs().size() != num_input_tensor,
"OpCondIf: then_block has unexpected number of input");
ERROR_IF((int32_t)else_block->GetInputs().size() != num_input_tensor,
@@ -307,8 +309,10 @@ int OpWhileLoop::checkTensorAttributes()
return 1;
}
- cond_block = tsh->GetBlockByName(attribute->cond_branch());
- body_block = tsh->GetBlockByName(attribute->body_branch());
+ auto region_name = getParentSGT()->getRegionName();
+ auto curr_region = tsh->GetRegionByName(region_name);
+ cond_block = curr_region->GetBlockByName(attribute->cond_branch());
+ body_block = curr_region->GetBlockByName(attribute->body_branch());
ERROR_IF(!cond_block, "OpWhileLoop: fail to resolve cond_branch %s", attribute->cond_branch().c_str());
ERROR_IF(!body_block, "OpWhileLoop: fail to resolve body_branch %s", attribute->body_branch().c_str());
@@ -403,12 +407,7 @@ int OpWhileLoop::eval()
// assigning output tensors value back to input tensors value for next iteration
for (size_t i = 0; i < num_input_output; i++)
{
- if (getInputs()[i]->copyValueFrom(getOutputs()[i]))
- {
- WARNING("Fail to copy tensor value %s -> %s", getOutputs()[i]->getName().c_str(),
- getInputs()[i]->getName().c_str());
- return 1;
- }
+ getInputs()[i] = getOutputs()[i];
}
}
else
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index 112e641..8867ada 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -37,13 +37,14 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-SubgraphTraverser::SubgraphTraverser(TosaSerializationBasicBlock* _block, TosaSerializationHandler* _tsh)
+SubgraphTraverser::SubgraphTraverser(TosaSerializationBasicBlock* _block, TosaSerializationHandler* _tsh, SubgraphTraverser* _parent_sgt)
{
- graph_status = GraphStatus::TOSA_VALID;
+ graph_status = GraphStatus::TOSA_VALID;
block = _block;
- tsh = _tsh;
+ tsh = _tsh;
+ parent_sgt = _parent_sgt;
tensors.clear();
nodes.clear();
nextNodeList.clear();
@@ -120,6 +121,17 @@ int SubgraphTraverser::initializeGraph()
{
int idx = 0;
+ std::vector<TosaSerializationTensor*> ser_tensor_vec;
+ // Get all the serialized tensors from TosaSerializationHandler.
+ for (auto block: tsh->GetMainRegion()->GetBlocks())
+ {
+ for (auto ser_tensor : block->GetTensors())
+ {
+ ser_tensor_vec.push_back(ser_tensor);
+ }
+ }
+
+ std::vector<GraphNode*> non_const_node_vec;
for (auto op : block->GetOperators())
{
// translated TosaSerializationOperator to GraphNode
@@ -159,7 +171,13 @@ int SubgraphTraverser::initializeGraph()
EnumNamesOp()[op->GetOp()], input_index);
std::string input_name = op->GetInputTensorNames()[input_index];
- TosaSerializationTensor* input_tensor = block->GetTensorByName(input_name);
+ TosaSerializationTensor* input_tensor = nullptr;
+ for (auto ser_tensor : ser_tensor_vec) {
+ if (ser_tensor->GetName() == input_name) {
+ input_tensor = ser_tensor;
+ }
+ }
+
SUBGRAPH_ERROR_IF(
!input_tensor,
"SubgraphTraverser::initializeGraph(): fail to get input tensor %s from TosaSerializationHandler",
@@ -175,7 +193,13 @@ int SubgraphTraverser::initializeGraph()
"SubgraphTraverser::initializeGraph(): Op=%s, weight_index %d must be within [0, num_input - 1]",
EnumNamesOp()[op->GetOp()], weight_index);
std::string weight_name = op->GetInputTensorNames()[weight_index];
- TosaSerializationTensor* weight_tensor = block->GetTensorByName(weight_name);
+ TosaSerializationTensor* weight_tensor = nullptr;
+ for (auto ser_tensor : ser_tensor_vec) {
+ if (ser_tensor->GetName() == weight_name) {
+ weight_tensor = ser_tensor;
+ }
+ }
+
SUBGRAPH_ERROR_IF(
!weight_tensor,
"SubgraphTraverser::initializeGraph(): fail to get weight tensor %s from TosaSerializationHandler",
@@ -199,8 +223,19 @@ int SubgraphTraverser::initializeGraph()
DEBUG_INFO(GT, "Creating operator id_%03u, %8s, %lu input tensors, %lu output tensors", idx,
EnumNamesOp()[op->GetOp()], op->GetInputTensorNames().size(), op->GetOutputTensorNames().size());
- GraphNode* node = OpFactory::newOp(this, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype,
+ GraphNode* node = nullptr;
+ if (this->parent_sgt) {
+ node = OpFactory::newOp(this->parent_sgt, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype,
+ input_rank, output_dtype, output_rank, weight_dtype, weight_rank);
+ node->setInMainBlock(false);
+ } else {
+ node = OpFactory::newOp(this, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype,
input_rank, output_dtype, output_rank, weight_dtype, weight_rank);
+ if (node) {
+ node->setInMainBlock(true);
+ }
+ }
+
if (!node)
{
if (weight_index == -1)
@@ -257,6 +292,8 @@ int SubgraphTraverser::initializeGraph()
if (op->GetInputTensorNames().empty() && !node->getOnNextNodeList())
{
addToNextNodeList(node);
+ } else if (!node->getInMainBlock()) {
+ non_const_node_vec.push_back(node);
}
idx++;
@@ -271,7 +308,6 @@ int SubgraphTraverser::initializeGraph()
SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::initializeGraph(): Unsupported tensor name=%s, type=%s, rank=%d",
ts->GetName().c_str(), EnumNamesDType()[ts->GetDtype()], (int)ts->GetShape().size());
- // update this->tensors
addTensor(tensor);
}
@@ -296,7 +332,7 @@ int SubgraphTraverser::initializeGraph()
for (auto& output_name : block->GetOutputs())
{
TosaReference::Tensor* tensor = findTensorByName(output_name);
- DEBUG_INFO(GT, "output tensor name=%s\n", output_name.c_str());
+ DEBUG_INFO(GT, "output tensor name=%s", output_name.c_str());
if (tensor)
{
tensor->setIsSubgraphOutput();
@@ -314,6 +350,22 @@ int SubgraphTraverser::initializeGraph()
dumpNextNodeList(g_func_debug.func_debug_file);
}
+ // If the node is not in mainblock and not const
+ for (auto node : non_const_node_vec) {
+ bool all_inputs_from_parent = true;
+ for (std::string& name : node->getInputNames())
+ {
+ TosaReference::Tensor* t = findTensorByName(name);
+ if (!t->getIsParentGraphOutput()) {
+ all_inputs_from_parent = false;
+ }
+ }
+ // In the children block, when a node has all its inputs from parent
+ // block, we have to manually add this node to the evaluation list
+ if (all_inputs_from_parent && !node->getOnNextNodeList()) {
+ addToNextNodeList(node);
+ }
+ }
return 0;
}
@@ -510,29 +562,40 @@ int SubgraphTraverser::evaluateNextNode()
}
// free input tensor if all of its consumers have all of their outputs ready and it's not block's output
- for (auto tensor : currNode->getInputs())
- {
- bool in_use = false;
- for (auto node : tensor->getConsumers())
+ if (!currNode->getInMainBlock()) { // we don't free it if the node is in main block and has nested blocks
+ for (auto tensor : currNode->getInputs())
{
- if (!node->hasAllOutputsReady())
+ bool in_use = false;
+
+ auto tensor_check = findTensorByName(tensor->getName());
+ if (tensor_check->getIsParentGraphOutput()) {
+ // if it's parent's block output tensor, we can't free it
+ continue;
+ }
+
+ for (auto node : tensor->getConsumers())
{
- in_use = true;
+ // If the node is inside a loop, the input tensor is still needed
+ if (!node->hasAllOutputsReady())
+ {
+ in_use = true;
+ }
+
}
- }
- for (auto name : block->GetOutputs())
- {
- if (name == tensor->getName())
+ for (auto name : block->GetOutputs())
{
- in_use = true;
+ if (name == tensor->getName())
+ {
+ in_use = true;
+ }
+ }
+
+ if (!in_use)
+ {
+ tensor->deallocate();
}
- }
- if (!in_use)
- {
- tensor->deallocate();
}
}
-
// Search the output tensors of this node to see if
// there are now new ready nodes available from completing this node
for (TosaReference::Tensor* tensor : currNode->getOutputs())
@@ -642,17 +705,35 @@ int SubgraphTraverser::addNode(GraphNode* newNode)
TosaReference::Tensor* SubgraphTraverser::findTensorByName(const std::string& name) const
{
+ TosaReference::Tensor* res_tensor = nullptr;
+
for (TosaReference::Tensor* currTensor : tensors)
{
if (currTensor->getName() == name)
{
- return currTensor;
+ res_tensor = currTensor;
+ return res_tensor;
}
}
- WARNING("SubgraphTraverser::findTensorByName(): Unable to find tensor with name: %s\n", name.c_str());
+ if (parent_sgt)
+ {
+ for (TosaReference::Tensor* currTensor : parent_sgt->tensors)
+ {
+ if (currTensor->getName() == name)
+ {
+ res_tensor = currTensor;
+ res_tensor->setIsParentGraphOutput();
+ }
+ }
+ }
- return nullptr;
+ if (!res_tensor)
+ {
+ WARNING("SubgraphTraverser::findTensorByName(): Unable to find tensor with name: %s\n", name.c_str());
+ return nullptr;
+ }
+ return res_tensor;
}
int SubgraphTraverser::linkTensorsAndNodes()
@@ -704,7 +785,6 @@ int SubgraphTraverser::validateGraph()
for (TosaReference::Tensor* currTensor : tensors)
{
-
// It's okay for block input tensor not being consumed by operators.
// This is common in control flow op execution.
if (!currTensor->getIsSubgraphInput())
diff --git a/reference_model/src/subgraph_traverser.h b/reference_model/src/subgraph_traverser.h
index 7940ee4..543b008 100644
--- a/reference_model/src/subgraph_traverser.h
+++ b/reference_model/src/subgraph_traverser.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -30,7 +30,7 @@ namespace TosaReference
class SubgraphTraverser
{
public:
- SubgraphTraverser(TosaSerializationBasicBlock* block, TosaSerializationHandler* tsh);
+ SubgraphTraverser(TosaSerializationBasicBlock* block, TosaSerializationHandler* tsh, SubgraphTraverser* parent_sgt);
~SubgraphTraverser();
int initializeGraph();
@@ -59,6 +59,10 @@ public:
{
return block->GetName();
}
+ std::string getRegionName() const
+ {
+ return block->GetRegionName();
+ }
int getNumInputTensors() const;
Tensor* getInputTensor(const unsigned int idx) const;
Tensor* getInputTensorByName(const std::string name) const;
@@ -77,6 +81,10 @@ private:
GraphStatus graph_status;
+ // pointer to the parent subgraph traversal if exists
+ // e.g., Control Flow Ops will have nested blocks (subgraph traversals)
+ SubgraphTraverser* parent_sgt;
+
// pointer to serialization library and corresponding basic block
TosaSerializationBasicBlock* block;
TosaSerializationHandler* tsh;
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc
index 0678bbd..7af2069 100644
--- a/reference_model/src/tensor.cc
+++ b/reference_model/src/tensor.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -32,11 +32,18 @@ TosaReference::Tensor::Tensor(std::string tensorName_, DType tensorDtype_, std::
consumers.clear();
isSubgraphInput = false;
isSubgraphOutput = false;
+ isParentGraphOutput = false;
}
TosaReference::Tensor::~Tensor()
{}
+int TosaReference::Tensor::setIsParentGraphOutput()
+{
+ isParentGraphOutput = true;
+ return 0;
+}
+
int TosaReference::Tensor::setIsSubgraphInput()
{
isSubgraphInput = true;
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
index 08e865a..d5f1de8 100644
--- a/reference_model/src/tensor.h
+++ b/reference_model/src/tensor.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -40,6 +40,11 @@ public:
int setIsSubgraphInput();
int setIsSubgraphOutput();
+ int setIsParentGraphOutput();
+
+ int getIsParentGraphOutput() const {
+ return isParentGraphOutput;
+ }
int getIsSubgraphInput() const
{
@@ -261,6 +266,8 @@ protected:
int isSubgraphOutput;
bool isAllocated;
+ bool isParentGraphOutput;
+
GraphNode* producer;
std::vector<GraphNode*> consumers;
diff --git a/scripts/operator_api/templates/operators_cc.j2 b/scripts/operator_api/templates/operators_cc.j2
index 6b0ed6e..3f2acb5 100644
--- a/scripts/operator_api/templates/operators_cc.j2
+++ b/scripts/operator_api/templates/operators_cc.j2
@@ -1,5 +1,5 @@
-// Copyright (c) 2022, ARM Limited.
+// Copyright (c) 2022-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -64,11 +64,11 @@ tosa::ResizeMode translate_client_tosa_mode(tosa_mode_t mode) {
switch(mode) {
case tosa_mode_nearest:
return tosa::ResizeMode_NEAREST;
- case tosa_mode_max:
+ case tosa_mode_max:
case tosa_mode_bilinear:
return tosa::ResizeMode_BILINEAR;
default:
- return tosa::ResizeMode_UNKNOWN;
+ return tosa::ResizeMode_UNKNOWN;
}
}
@@ -131,7 +131,7 @@ extern "C"
});
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("{{operator.name}}", { op },
+ tosa::TosaSerializationBasicBlock block("{{operator.name}}", "main", { op },
{
{%- for input in operator.inputs: -%}
{{input}},
diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib
-Subproject 6388a097de4350cc70472921c272074190fd7c9
+Subproject ca7ce0e94b3ee7339f31b47baa3a3fb4522243a
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py
index 90bda34..767989e 100644
--- a/verif/frameworks/tensor_gen.py
+++ b/verif/frameworks/tensor_gen.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import tensorflow as tf
@@ -252,3 +252,25 @@ class TGen:
tf_placeholders.append(("placeholder_2", TGen.getRand(shape, dtype, rng)))
return tf_placeholders, tf_consts
+
+ @staticmethod
+ def tgRecurrent(op, ifm_shape, dtype, rng):
+ # Require rank 3 shape for recurrent networks
+ if len(ifm_shape) != 3:
+ return [], []
+ pl, const = op["operands"]
+
+ tf_placeholders = []
+ tf_consts = []
+
+ for i in range(pl):
+ tf_placeholders.append(
+ ("placeholder_{}".format(i), TGen.getRand(ifm_shape, dtype, rng))
+ )
+
+ for i in range(const):
+ tf_consts.append(
+ ("const_{}".format(i), TGen.getRand(ifm_shape, dtype, rng))
+ )
+
+ return tf_placeholders, tf_consts
diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py
index c7c5cd7..8870f41 100644
--- a/verif/frameworks/test_builder.py
+++ b/verif/frameworks/test_builder.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import tensorflow as tf
@@ -1164,3 +1164,82 @@ class TBuilder:
def eval(self, a):
return tf.bitwise.right_shift(a, self.shift, name=self.result_name)
+
+ class While:
+ def __init__(self, name):
+ self.result_name = name
+
+ def while_cond(self, x):
+ return tf.reduce_sum(x) < self.cap
+
+ def while_body(self, x):
+ return tf.add(x, tf.math.sigmoid(x))
+
+ def eval(self, a):
+ self.cap = tf.cast(
+ tf.constant(
+ 2.0,
+ shape=[
+ 1,
+ ],
+ ),
+ a.dtype,
+ )
+
+ result = tf.while_loop(
+ self.while_cond, self.while_body, [a], name=self.result_name
+ )
+
+ return result[0]
+
+ class LSTM:
+ def __init__(self, name):
+ self.result_name = name
+ self.lstm = tf.keras.layers.LSTM(
+ 2,
+ activation="tanh",
+ unroll=False,
+ recurrent_activation="sigmoid",
+ use_bias=True,
+ recurrent_initializer="ones",
+ kernel_initializer="ones",
+ )
+
+ def eval(self, a):
+ return self.lstm(a)
+
+ class GRU:
+ def __init__(self, name):
+ self.result_name = name
+ self.lstm = tf.keras.layers.GRU(
+ 2,
+ recurrent_activation="sigmoid",
+ use_bias=True,
+ recurrent_initializer="ones",
+ kernel_initializer="ones",
+ )
+
+ def eval(self, a):
+ return self.lstm(a)
+
+ class RNN:
+ def __init__(self, name):
+ self.result_name = name
+ basic_cell = tf.keras.layers.SimpleRNNCell(
+ units=2,
+ activation="sigmoid",
+ use_bias=True,
+ recurrent_initializer="ones",
+ )
+ self.rnn = tf.keras.layers.RNN(basic_cell, unroll=False)
+
+ def eval(self, a):
+ return self.rnn(a)
+
+ class FullyConnected:
+ def __init__(self, name):
+ self.result_name = name
+ self.dense = tf.keras.layers.Dense(2)
+
+ def eval(self, a):
+ return self.dense(a)
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index 760def6..26af5dd 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -807,6 +807,42 @@ TF_OP_LIST = {
]
},
},
+ "while": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.While, TGen.tgBasic, ArgGen.agNone),
+ "types": {
+ "tflite": list(TYPE_F),
+ },
+ },
+ "lstm": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.LSTM, TGen.tgRecurrent, ArgGen.agNone),
+ "types": {
+ "tflite": [
+ tf.float32,
+ # tf.int32
+ ]
+ },
+ },
+ "gru": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.GRU, TGen.tgRecurrent, ArgGen.agNone),
+ "types": {
+ "tflite": [
+ tf.float32,
+ # tf.int32
+ ]
+ },
+ },
+ "rnn": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.RNN, TGen.tgRecurrent, ArgGen.agNone),
+ "types": {
+ "tflite": [
+ tf.float32,
+ ]
+ },
+ },
}
# Shapes to be tested; default can be overwritten
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 515e8bb..d799eb0 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import os
from copy import deepcopy
@@ -1845,7 +1845,7 @@ class TosaTestGen:
# Finally, build the op and the two blocks
self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)
- self.ser.startBasicBlock(then_block)
+ self.ser.addBasicBlock(then_block)
# Build the actual then/else tensors inside their blocks
if error_name == ErrorIf.CondIfOutputListThenGraphMismatch:
then_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
@@ -1853,7 +1853,7 @@ class TosaTestGen:
then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
self.ser.addOutputTensor(then_tens)
- self.ser.startBasicBlock(else_block)
+ self.ser.addBasicBlock(else_block)
if error_name == ErrorIf.CondIfOutputListElseGraphMismatch:
else_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
else:
@@ -1865,7 +1865,7 @@ class TosaTestGen:
validator_fcns,
error_name,
op=op,
- basicBlocks=self.ser.basicBlocks,
+ basicBlocks=self.ser.currRegion.basicBlocks,
cond=cond_tens,
):
return None
@@ -1914,7 +1914,7 @@ class TosaTestGen:
assert False, f"No tests for DType: {a.dtype}"
for block, op in ((then_block, then_op), (else_block, else_op)):
- self.ser.startBasicBlock(block)
+ self.ser.addBasicBlock(block)
if (
error_name == ErrorIf.CondIfInputListThenGraphMismatch
and block == then_block
@@ -1948,7 +1948,7 @@ class TosaTestGen:
op=op,
a=a,
b=b,
- basicBlocks=self.ser.basicBlocks,
+ basicBlocks=self.ser.currRegion.basicBlocks,
cond=cond_tens,
):
return None
@@ -2005,7 +2005,8 @@ class TosaTestGen:
incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
# COND block (input: iter, output: cond_tens )
- self.ser.startBasicBlock(cond_block)
+ self.ser.addBasicBlock(cond_block)
+
if error_name == ErrorIf.InputListCondGraphMismatch:
self.ser.addInputTensor(incorrect_iter)
self.ser.addInputTensor(a)
@@ -2034,7 +2035,8 @@ class TosaTestGen:
# BODY block (input: a, acc, iter, output: a, acc, iter)
# Note that local intermediate tensors need to be declared here for the outputs
- self.ser.startBasicBlock(body_block)
+ self.ser.addBasicBlock(body_block)
+
if error_name == ErrorIf.InputListBodyGraphInputMismatch:
self.ser.addInputTensor(incorrect_iter)
self.ser.addInputTensor(a)
@@ -2068,7 +2070,7 @@ class TosaTestGen:
validator_fcns,
error_name,
op=op,
- basicBlocks=self.ser.basicBlocks,
+ basicBlocks=self.ser.currRegion.basicBlocks,
):
return None