author	Eric Kunze <eric.kunze@arm.com>	2020-10-13 16:11:07 -0700
committer	Kevin Cheng <kevin.cheng@arm.com>	2020-10-14 11:11:43 -0700
commit	e5e2676409a936431f87d31fb74d825257b20804 (patch)
tree	304d93d993ef6417b02a515025f9030367682774 /reference_model/src/ops
parent	88b7860f180f91b5b66764c61cfd97d8bc53cece (diff)
download	reference_model-e5e2676409a936431f87d31fb74d825257b20804.tar.gz
Initial checkin of TOSA reference_model and tests
Change-Id: I2f8e7fa63e2ae40203e57d2cc8814bde3b312cb6
Signed-off-by: Eric Kunze <eric.kunze@arm.com>
Diffstat (limited to 'reference_model/src/ops')
-rw-r--r--	reference_model/src/ops/activation_funcs.cc	118
-rw-r--r--	reference_model/src/ops/activation_funcs.h	101
-rw-r--r--	reference_model/src/ops/comparison.cc	81
-rw-r--r--	reference_model/src/ops/comparison.h	71
-rw-r--r--	reference_model/src/ops/control_flow.cc	353
-rw-r--r--	reference_model/src/ops/control_flow.h	72
-rw-r--r--	reference_model/src/ops/custom.cc	40
-rw-r--r--	reference_model/src/ops/custom.h	38
-rw-r--r--	reference_model/src/ops/data_layout.cc	644
-rw-r--r--	reference_model/src/ops/data_layout.h	216
-rw-r--r--	reference_model/src/ops/data_nodes.cc	172
-rw-r--r--	reference_model/src/ops/data_nodes.h	86
-rw-r--r--	reference_model/src/ops/ewise_binary.cc	586
-rw-r--r--	reference_model/src/ops/ewise_binary.h	195
-rw-r--r--	reference_model/src/ops/ewise_ternary.cc	115
-rw-r--r--	reference_model/src/ops/ewise_ternary.h	83
-rw-r--r--	reference_model/src/ops/ewise_unary.cc	302
-rw-r--r--	reference_model/src/ops/ewise_unary.h	102
-rw-r--r--	reference_model/src/ops/image.cc	169
-rw-r--r--	reference_model/src/ops/image.h	53
-rw-r--r--	reference_model/src/ops/op_factory.cc	432
-rw-r--r--	reference_model/src/ops/op_factory.h	294
-rw-r--r--	reference_model/src/ops/reduction.cc	139
-rw-r--r--	reference_model/src/ops/reduction.h	109
-rw-r--r--	reference_model/src/ops/scatter_gather.cc	120
-rw-r--r--	reference_model/src/ops/scatter_gather.h	54
-rw-r--r--	reference_model/src/ops/template_types.h	277
-rw-r--r--	reference_model/src/ops/tensor_ops.cc	1229
-rw-r--r--	reference_model/src/ops/tensor_ops.h	253
-rw-r--r--	reference_model/src/ops/type_conversion.cc	299
-rw-r--r--	reference_model/src/ops/type_conversion.h	162
31 files changed, 6965 insertions, 0 deletions
diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc
new file mode 100644
index 0000000..bca9507
--- /dev/null
+++ b/reference_model/src/ops/activation_funcs.cc
@@ -0,0 +1,118 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "activation_funcs.h"
+#include "quant_util.h"
+#include "template_types.h"
+#include <cmath>
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <int Rank, DType Dtype>
+int OpClamp<Rank, Dtype>::register_fcn()
+{
+
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ {
+ InEigenType min = (InEigenType)attribute->min_fp();
+ InEigenType max = (InEigenType)attribute->max_fp();
+ this->fcn = [min, max](InEigenType a) -> OutEigenType { return a <= min ? min : a >= max ? max : a; };
+ }
+ break;
+ case DType_AINT8:
+ case DType_INT16:
+ {
+ InEigenType min = (InEigenType)attribute->min_int();
+ InEigenType max = (InEigenType)attribute->max_int();
+ this->fcn = [min, max](InEigenType a) -> OutEigenType { return a <= min ? min : a >= max ? max : a; };
+ }
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpReluN<Rank, Dtype>::register_fcn()
+{
+
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ {
+ InEigenType N = (InEigenType)attribute->max_fp();
+ this->fcn = [N](InEigenType a) -> OutEigenType { return a >= 0 ? (a <= N ? a : N) : 0; };
+ }
+ break;
+ case DType_INT32:
+ {
+ InEigenType N = (InEigenType)attribute->max_int();
+ this->fcn = [N](InEigenType a) -> OutEigenType { return a >= 0 ? (a <= N ? a : N) : 0; };
+ }
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpSigmoid<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a) -> OutEigenType { return (1.0 / (1.0 + (expf(-1.0 * a)))); };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpTanh<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a) -> OutEigenType { return tanhf(a); };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+// template explicit instantiation
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReluN, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReluN, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FLOAT);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FLOAT);
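The DEF_INSTANTIATE_* macros come from template_types.h (added later in this change). Their exact bodies live there, but DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FLOAT) presumably expands to one explicit instantiation per supported rank, roughly:

    // Hypothetical expansion of DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FLOAT);
    // explicit instantiations emit the out-of-line register_fcn() definitions
    // above for every rank the reference model supports (0 through 6)
    template class TosaReference::OpClamp<0, tosa::DType_FLOAT>;
    template class TosaReference::OpClamp<1, tosa::DType_FLOAT>;
    template class TosaReference::OpClamp<2, tosa::DType_FLOAT>;
    template class TosaReference::OpClamp<3, tosa::DType_FLOAT>;
    template class TosaReference::OpClamp<4, tosa::DType_FLOAT>;
    template class TosaReference::OpClamp<5, tosa::DType_FLOAT>;
    template class TosaReference::OpClamp<6, tosa::DType_FLOAT>;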
diff --git a/reference_model/src/ops/activation_funcs.h b/reference_model/src/ops/activation_funcs.h
new file mode 100644
index 0000000..b051b9d
--- /dev/null
+++ b/reference_model/src/ops/activation_funcs.h
@@ -0,0 +1,101 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_ACTIVATION_FUNCS_H
+#define OPS_ACTIVATION_FUNCS_H
+
+#include "ewise_unary.h"
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+template <int Rank, DType Dtype>
+class OpClamp : public UnaryNode<Rank, Dtype>
+{
+public:
+ OpClamp(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : UnaryNode<Rank, Dtype>(Op_CLAMP, id_)
+ {
+ INIT_ATTRIBUTE(Clamp);
+ register_fcn();
+ }
+ static constexpr int32_t QMin = GetQMin<Dtype>::value;
+ static constexpr int32_t QMax = GetQMax<Dtype>::value;
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ virtual int register_fcn();
+
+protected:
+ TosaClampAttribute* attribute;
+};
+
+template <int Rank, DType Dtype>
+class OpReluN : public UnaryNode<Rank, Dtype>
+{
+public:
+ OpReluN(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : UnaryNode<Rank, Dtype>(Op_RELUN, id_)
+ {
+ INIT_ATTRIBUTE(ReluN);
+ register_fcn();
+ }
+ static constexpr int32_t QMin = GetQMin<Dtype>::value;
+ static constexpr int32_t QMax = GetQMax<Dtype>::value;
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ virtual int register_fcn();
+
+protected:
+ TosaReluNAttribute* attribute;
+};
+
+template <int Rank, DType Dtype>
+class OpSigmoid : public UnaryNode<Rank, Dtype>
+{
+public:
+ OpSigmoid(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : UnaryNode<Rank, Dtype>(Op_SIGMOID, id_)
+ {
+ register_fcn();
+ }
+ static constexpr int32_t QMin = GetQMin<Dtype>::value;
+ static constexpr int32_t QMax = GetQMax<Dtype>::value;
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ virtual int register_fcn();
+};
+
+template <int Rank, DType Dtype>
+class OpTanh : public UnaryNode<Rank, Dtype>
+{
+public:
+ OpTanh(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : UnaryNode<Rank, Dtype>(Op_TANH, id_)
+ {
+ register_fcn();
+ }
+ static constexpr int32_t QMin = GetQMin<Dtype>::value;
+ static constexpr int32_t QMax = GetQMax<Dtype>::value;
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ virtual int register_fcn();
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/comparison.cc b/reference_model/src/ops/comparison.cc
new file mode 100644
index 0000000..402e152
--- /dev/null
+++ b/reference_model/src/ops/comparison.cc
@@ -0,0 +1,81 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "comparison.h"
+#include "arith_util.h"
+#include "quant_util.h"
+#include "template_types.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <int Rank, DType Dtype>
+int OpEqual<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a == b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpGreater<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpGreaterEqual<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a >= b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+// template explicit instantiation
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32);
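The lambdas registered above are applied elementwise by the BinaryNode base class (ewise_binary.cc, below). A minimal standalone sketch of that pattern, assuming only Eigen's Tensor module; the greaterEqual helper is hypothetical:

    #include <unsupported/Eigen/CXX11/Tensor>

    // Sketch of the elementwise application BinaryNode presumably performs:
    // apply the registered functor pairwise, producing a BOOL tensor.
    Eigen::Tensor<bool, 2> greaterEqual(const Eigen::Tensor<float, 2>& a,
                                        const Eigen::Tensor<float, 2>& b)
    {
        auto fcn = [](float x, float y) -> bool { return x >= y; };
        return a.binaryExpr(b, fcn);
    }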
diff --git a/reference_model/src/ops/comparison.h b/reference_model/src/ops/comparison.h
new file mode 100644
index 0000000..e75b1a6
--- /dev/null
+++ b/reference_model/src/ops/comparison.h
@@ -0,0 +1,71 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_COMPARISON_H
+#define OPS_COMPARISON_H
+
+#include "ewise_binary.h"
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+template <int Rank, DType Dtype>
+class OpEqual : public BinaryNode<Rank, Dtype, DType_BOOL>
+{
+public:
+ OpEqual(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : BinaryNode<Rank, Dtype, DType_BOOL>(Op_EQUAL, qinfo_, id_)
+ {
+ register_fcn();
+ }
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<DType_BOOL>::type;
+ virtual int register_fcn();
+};
+
+template <int Rank, DType Dtype>
+class OpGreater : public BinaryNode<Rank, Dtype, DType_BOOL>
+{
+public:
+ OpGreater(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : BinaryNode<Rank, Dtype, DType_BOOL>(Op_GREATER, qinfo_, id_)
+ {
+ register_fcn();
+ }
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<DType_BOOL>::type;
+ virtual int register_fcn();
+};
+
+template <int Rank, DType Dtype>
+class OpGreaterEqual : public BinaryNode<Rank, Dtype, DType_BOOL>
+{
+public:
+ OpGreaterEqual(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+        : BinaryNode<Rank, Dtype, DType_BOOL>(Op_GREATER_EQUAL, qinfo_, id_)
+ {
+ register_fcn();
+ }
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<DType_BOOL>::type;
+ virtual int register_fcn();
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc
new file mode 100644
index 0000000..9d5db40
--- /dev/null
+++ b/reference_model/src/ops/control_flow.cc
@@ -0,0 +1,353 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "control_flow.h"
+#include "subgraph_traverser.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+OpControlFlow::OpControlFlow(TosaSerializationHandler* tsh_, Op op_, uint64_t id_)
+ : GraphNode(op_, id_)
+{
+ tsh = tsh_;
+}
+
+OpControlFlow::~OpControlFlow()
+{}
+
+int OpControlFlow::evalBlock(TosaSerializationBasicBlock* block,
+ std::vector<TosaReference::Tensor*>& block_inputs,
+ std::vector<TosaReference::Tensor*>& block_outputs)
+{
+ std::string block_name = block->GetName();
+
+ DEBUG_MED(OP, "Evaluating block %s", block_name.c_str());
+
+ SubgraphTraverser gt(block, tsh);
+
+ if (gt.initializeGraph())
+ {
+ FATAL_ERROR("Unable to initialize graph traverser for block %s", block_name.c_str());
+ }
+
+ if (gt.linkTensorsAndNodes())
+ {
+ FATAL_ERROR("Failed to link tensors and nodes for block %s", block_name.c_str());
+ }
+
+ if (gt.validateGraph())
+ {
+ FATAL_ERROR("Failed to validate subgraph for block %s", block_name.c_str());
+ }
+
+ int num_input_tensors = gt.getNumInputTensors();
+ int num_output_tensors = gt.getNumOutputTensors();
+
+ for (size_t i = 0; i < block_inputs.size(); i++)
+ {
+ DEBUG_HIGH(OP, "Input[%ld]: %s", i, block_inputs[i]->getName().c_str());
+ }
+ for (size_t i = 0; i < block_outputs.size(); i++)
+ {
+ DEBUG_HIGH(OP, "Output[%ld]: %s", i, block_outputs[i]->getName().c_str());
+ }
+
+    ASSERT_MSG((size_t)num_input_tensors == block_inputs.size(),
+               "op block %s input count [%lu] does not match graph traverser's input count [%d]", block_name.c_str(),
+               block_inputs.size(), num_input_tensors);
+    ASSERT_MSG((size_t)num_output_tensors == block_outputs.size(),
+               "op block %s output count [%lu] does not match graph traverser's output count [%d]", block_name.c_str(),
+               block_outputs.size(), num_output_tensors);
+
+ // set graph traverser's input = basic block's input
+ for (int i = 0; i < num_input_tensors; i++)
+ {
+ TosaReference::Tensor* tensor = gt.getInputTensor(i);
+        ASSERT_MSG(!tensor->is_allocated(), "block %s input tensors are unexpectedly already allocated",
+                   block_name.c_str());
+
+ if (tensor->allocate())
+ {
+ WARNING("Fail to allocate tensor %s", tensor->getName().c_str());
+ return 1;
+ }
+
+ if (tensor->copyValueFrom(block_inputs[i]))
+ {
+ WARNING("Fail to copy tensor value %s -> %s", block_inputs[i]->getName().c_str(),
+ tensor->getName().c_str());
+ return 1;
+ }
+
+ // Push ready consumers to the next node list
+ for (auto gn : tensor->getConsumers())
+ {
+ if (gn->hasAllInputsReady() && !gn->getOnNextNodeList())
+ {
+ gt.addToNextNodeList(gn);
+ }
+ }
+ }
+
+ if (gt.evaluateAll())
+ {
+ FATAL_ERROR("Error evaluating network. Giving up.");
+ }
+
+    // make sure all output tensors were evaluated; dump any that are not valid
+ bool all_output_valid = true;
+ for (int i = 0; i < num_output_tensors; i++)
+ {
+ const TosaReference::Tensor* ct = gt.getOutputTensor(i);
+ ASSERT_MEM(ct);
+ if (!ct->getIsValid())
+ {
+ ct->dumpTensorParams(g_func_debug.func_debug_file);
+ if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
+ {
+ ct->dumpTensor(g_func_debug.func_debug_file);
+ }
+ all_output_valid = false;
+ }
+ }
+ if (!all_output_valid)
+ {
+ gt.dumpGraph(g_func_debug.func_debug_file);
+ FATAL_ERROR("SubgraphTraverser \"%s\" error: Output tensors are not all valid at the end of evaluation.",
+ block_name.c_str());
+ }
+
+ // set basic block's output = subgraph_traverser's output
+ for (int i = 0; i < num_output_tensors; i++)
+ {
+ TosaReference::Tensor* tensor = gt.getOutputTensor(i);
+ ASSERT_MSG(tensor->is_allocated(), "tensor %s is not allocated", tensor->getName().c_str());
+
+ if (block_outputs[i]->copyValueFrom(tensor))
+ {
+ WARNING("Fail to copy tensor value %s -> %s", tensor->getName().c_str(), outputs[i]->getName().c_str());
+ return 1;
+ }
+ }
+ return 0;
+}
+
+OpCondIf::OpCondIf(TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_)
+ : OpControlFlow(tsh_, Op_COND_IF, id_)
+{
+ INIT_ATTRIBUTE(CondIf);
+}
+
+OpCondIf::~OpCondIf()
+{
+ if (attribute)
+ delete attribute;
+}
+
+int OpCondIf::checkTensorAttributes()
+{
+ if (getInputs().size() < 1)
+ {
+ WARNING("OpCondIf: must have at least 1 operand");
+ return 1;
+ }
+
+ if (inputs[0]->getDtype() != DType_BOOL || inputs[0]->getRank() != 0)
+ {
+ WARNING("OpCondIf: invalid tensor dtype=%s, rank=%d", EnumNamesDType()[inputs[0]->getDtype()],
+ inputs[0]->getRank());
+ return 1;
+ }
+
+ cond = dynamic_cast<TosaReference::Tensor0<bool>*>(inputs[0]);
+ ASSERT_MEM(cond);
+
+ then_block = tsh->GetBlockByName(attribute->then_branch());
+ else_block = tsh->GetBlockByName(attribute->else_branch());
+
+ if (!then_block)
+ {
+ WARNING("OpCondIf: fail to resolve then_branch %s", attribute->then_branch().c_str());
+ return 1;
+ }
+
+ if (!else_block)
+ {
+ WARNING("OpCondIf: fail to resolve else_branch %s", attribute->else_branch().c_str());
+ return 1;
+ }
+
+ return 0;
+}
+
+int OpCondIf::eval()
+{
+ bool cond_val = cond->getTensor()(0);
+ std::vector<TosaReference::Tensor*> block_inputs(getInputs().begin() + 1, getInputs().end());
+
+ if (cond_val)
+ {
+ if (evalBlock(then_block, block_inputs, getOutputs()))
+ {
+ WARNING("OpCondIf: Fail to evaluate then branch block %s", attribute->then_branch().c_str());
+ return 1;
+ }
+ }
+ else
+ {
+ if (evalBlock(else_block, block_inputs, getOutputs()))
+ {
+ WARNING("OpCondIf: Fail to evaluate else branch block %s", attribute->else_branch().c_str());
+ return 1;
+ }
+ }
+
+ return GraphNode::eval();
+}
+
+OpWhileLoop::OpWhileLoop(TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_)
+ : OpControlFlow(tsh_, Op_WHILE_LOOP, id_)
+{
+ INIT_ATTRIBUTE(WhileLoop);
+}
+
+OpWhileLoop::~OpWhileLoop()
+{
+ if (attribute)
+ delete attribute;
+}
+
+int OpWhileLoop::checkTensorAttributes()
+{
+ if (getInputs().size() <= 0)
+ {
+ WARNING("OpWhileLoop: must have at least 1 operands");
+ return 1;
+ }
+
+ if (getInputs().size() != getOutputs().size())
+ {
+ WARNING("OpWhileLoop: inputs and outputs size must match");
+ return 1;
+ }
+
+ cond_block = tsh->GetBlockByName(attribute->cond_branch());
+ body_block = tsh->GetBlockByName(attribute->body_branch());
+
+ if (!cond_block)
+ {
+ WARNING("OpWhileLoop: fail to resolve cond_branch %s", attribute->cond_branch().c_str());
+ return 1;
+ }
+
+ if (!body_block)
+ {
+ WARNING("OpWhileLoop: fail to resolve body_branch %s", attribute->body_branch().c_str());
+ return 1;
+ }
+
+ if (cond_block->GetOutputs().size() != 1)
+ {
+ WARNING("OpWhileLoop: invalid cond_block output size %lu", cond_block->GetOutputs().size());
+ return 1;
+ }
+
+ TosaSerializationTensor* cond_output_tensor = cond_block->GetTensorByName(cond_block->GetOutputs()[0]);
+
+ if (!cond_output_tensor)
+ {
+ WARNING("OpWhileLoop: fail to resolve cond_block's output tensor %s", cond_block->GetOutputs()[0].c_str());
+ return 1;
+ }
+
+ if (cond_output_tensor->GetDtype() != DType_BOOL)
+ {
+ WARNING("OpWhileLoop: invalid cond_block's output tensor data type %s",
+ EnumNamesDType()[cond_output_tensor->GetDtype()]);
+ return 1;
+ }
+ if (cond_output_tensor->GetShape().size() != 0)
+ {
+ WARNING("OpWhileLoop: invalid cond_block's output rank %lu", cond_output_tensor->GetShape().size());
+ return 1;
+ }
+
+ return 0;
+}
+
+int OpWhileLoop::eval()
+{
+
+ TosaReference::Tensor0<bool> cond_output_ctensor(
+ std::string("cond_output"), DType_BOOL, std::vector<Usage>({ Usage_ACTIVATION }),
+ std::vector<Format>({ Format_UNKNOWN }), std::vector<int32_t>({}), false);
+
+ cond_output_ctensor.allocate();
+ std::vector<TosaReference::Tensor*> cond_block_outputs;
+ cond_block_outputs.push_back(&cond_output_ctensor);
+
+ size_t num_input_output = getInputs().size();
+ size_t eval_count = 0;
+
+ while (eval_count++ < MAX_WHILE_LOOP_ITERATION)
+ {
+ if (evalBlock(cond_block, getInputs(), cond_block_outputs))
+ {
+ WARNING("OpWhileLoop: Fail to evaluate cond block %s", attribute->cond_branch().c_str());
+ return 1;
+ }
+ bool cond_val = cond_output_ctensor.getTensor()(0);
+ DEBUG_HIGH(OP, "Conditional block value: %d", cond_val);
+
+ if (cond_val)
+ {
+ if (evalBlock(body_block, getInputs(), getOutputs()))
+ {
+ WARNING("OpWhileLoop: Fail to evaluate body block %s", attribute->body_branch().c_str());
+ return 1;
+ }
+
+            // assign output tensor values back to the input tensors for the next iteration
+ for (size_t i = 0; i < num_input_output; i++)
+ {
+ if (getInputs()[i]->copyValueFrom(getOutputs()[i]))
+ {
+ WARNING("Fail to copy tensor value %s -> %s", getOutputs()[i]->getName().c_str(),
+ getInputs()[i]->getName().c_str());
+ return 1;
+ }
+ }
+ }
+ else
+ {
+            // on the final iteration (or if the body block never ran),
+            // pass the input tensor values through to the outputs
+ for (size_t i = 0; i < num_input_output; i++)
+ {
+ if (getOutputs()[i]->copyValueFrom(getInputs()[i]))
+ {
+ WARNING("Fail to copy tensor value %s -> %s", getInputs()[i]->getName().c_str(),
+ getOutputs()[i]->getName().c_str());
+ return 1;
+ }
+ }
+ break;
+ }
+ }
+
+ return GraphNode::eval();
+}
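Stripped of the tensor plumbing, OpWhileLoop::eval() above implements ordinary while-loop semantics over the loop-carried inputs. A plain C++ analogue, with an int standing in for the input/output tensor list and illustrative cond/body lambdas:

    // Plain-C++ analogue of OpWhileLoop::eval(); `state` plays the role of
    // the operator's input/output tensor list, cond/body the two basic blocks.
    int runWhileLoop(int state)
    {
        const int MAX_ITER = 10000;                  // mirrors MAX_WHILE_LOOP_ITERATION
        auto cond = [](int s) { return s < 5; };     // illustrative condition block
        auto body = [](int s) { return s + 1; };     // illustrative body block

        for (int i = 0; i < MAX_ITER; i++)
        {
            if (!cond(state))
                break;               // final iteration: inputs pass through to outputs
            state = body(state);     // body outputs become the next iteration's inputs
        }
        return state;
    }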
diff --git a/reference_model/src/ops/control_flow.h b/reference_model/src/ops/control_flow.h
new file mode 100644
index 0000000..14c11bc
--- /dev/null
+++ b/reference_model/src/ops/control_flow.h
@@ -0,0 +1,72 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_CONTROL_FLOW_H
+#define OPS_CONTROL_FLOW_H
+
+#include "graph_node.h"
+
+#define MAX_WHILE_LOOP_ITERATION 10000
+
+namespace TosaReference
+{
+class OpControlFlow : public GraphNode
+{
+public:
+ OpControlFlow(TosaSerializationHandler* tsh_, Op op_, uint64_t id_);
+ ~OpControlFlow();
+
+ virtual int evalBlock(TosaSerializationBasicBlock* block,
+ std::vector<TosaReference::Tensor*>& block_inputs,
+ std::vector<TosaReference::Tensor*>& block_outputs);
+
+protected:
+ TosaSerializationHandler* tsh;
+};
+
+class OpCondIf : public OpControlFlow
+{
+public:
+ OpCondIf(TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_);
+ virtual ~OpCondIf();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+protected:
+ TosaCondIfAttribute* attribute;
+ TosaReference::Tensor0<bool>* cond;
+ TosaSerializationBasicBlock* then_block;
+ TosaSerializationBasicBlock* else_block;
+};
+
+class OpWhileLoop : public OpControlFlow
+{
+public:
+ OpWhileLoop(TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_);
+ virtual ~OpWhileLoop();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+protected:
+ TosaWhileLoopAttribute* attribute;
+ TosaSerializationBasicBlock* cond_block;
+ TosaSerializationBasicBlock* body_block;
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/custom.cc b/reference_model/src/ops/custom.cc
new file mode 100644
index 0000000..5c4f29b
--- /dev/null
+++ b/reference_model/src/ops/custom.cc
@@ -0,0 +1,40 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "custom.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+OpCustom::OpCustom(uint64_t id_)
+ : GraphNode(Op_CUSTOM, id_)
+{}
+
+OpCustom::~OpCustom()
+{}
+
+int OpCustom::checkTensorAttributes()
+{
+ return 0;
+}
+
+int OpCustom::eval()
+{
+ FATAL_ERROR_NODE("not supported yet");
+
+ // Evaluation is trivial for constants
+ return GraphNode::eval();
+}
diff --git a/reference_model/src/ops/custom.h b/reference_model/src/ops/custom.h
new file mode 100644
index 0000000..b1085a5
--- /dev/null
+++ b/reference_model/src/ops/custom.h
@@ -0,0 +1,38 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_CUSTOM_H
+#define OPS_CUSTOM_H
+
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+class OpCustom : public GraphNode
+{
+public:
+ OpCustom(uint64_t id_);
+ virtual ~OpCustom();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
new file mode 100644
index 0000000..32029b9
--- /dev/null
+++ b/reference_model/src/ops/data_layout.cc
@@ -0,0 +1,644 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "data_layout.h"
+#include "quant_util.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <int Rank, DType Dtype>
+OpConcat<Rank, Dtype>::OpConcat(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_CONCAT, id_)
+{
+ setRequiredOperands(2, 1);
+ setRequiredRank(1, 6);
+
+ INIT_ATTRIBUTE(Axis);
+}
+
+template <int Rank, DType Dtype>
+OpConcat<Rank, Dtype>::~OpConcat()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <int Rank, DType Dtype>
+int OpConcat<Rank, Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // output and input must be the same types and rank
+ // inputs[0] and inputs[1] should also match type and rank
+ if (inputs[0]->matchRankType(*outputs[0]) || inputs[1]->matchRankType(*outputs[0]))
+ {
+ printNodeValidationError("Concat operator input ranks and types must match");
+ return 1;
+ }
+
+ lhs = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ rhs = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ if (attribute->axis() < 0 || (size_t)attribute->axis() >= rhs->getShape().size())
+ {
+ printNodeValidationError("Axis is beyond input tensor rank");
+ return 1;
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpConcat<Rank, Dtype>::eval()
+{
+
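+    // Eigen tensors are column-major while the reference model treats data
+    // as row-major: reverse the dimension order, concatenate along the
+    // reversed axis, then reverse the result back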
+ int32_t reversed_axis = Rank - 1 - attribute->axis();
+
+ for (int32_t d = 0; d < Rank; d++)
+ {
+ reverser[d] = Rank - 1 - d;
+ }
+
+ TIn lhs_reversed = lhs->getTensor().shuffle(reverser);
+ TIn rhs_reversed = rhs->getTensor().shuffle(reverser);
+
+ TIn reversed_result = lhs_reversed.concatenate(rhs_reversed, reversed_axis);
+ out->getTensor() = reversed_result.shuffle(reverser);
+ // out->getTensor() = lhs->getTensor().concatenate(rhs->getTensor(), axis);
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+OpPad<Rank, Dtype>::OpPad(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_PAD, id_)
+{
+ setRequiredOperands(2, 1);
+ setRequiredRank(0, 6);
+
+ INIT_QINFO(Pad);
+}
+
+template <int Rank, DType Dtype>
+OpPad<Rank, Dtype>::~OpPad()
+{
+ if (qinfo)
+ delete qinfo;
+}
+
+template <int Rank, DType Dtype>
+int OpPad<Rank, Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // output and input must be the same types
+ if (inputs[0]->matchRankType(*outputs[0]))
+ {
+ printNodeValidationError("Failure to match input and output type and rank");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+ TosaReference::TensorTemplate<ETensor2<int32_t>>* paddings =
+ dynamic_cast<TosaReference::TensorTemplate<ETensor2<int32_t>>*>(inputs[1]);
+
+ for (int i = 0; i < Rank; i++)
+ {
+ paddings_array[i] = std::make_pair(paddings->getTensor()(i, 0), paddings->getTensor()(i, 1));
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpPad<Rank, Dtype>::eval()
+{
+ InEigenType pad_value = 0;
+ if (this->qinfo)
+ {
+ pad_value = (InEigenType)this->qinfo->input_zp();
+ }
+
+ this->out->getTensor() = this->in->getTensor().pad(this->paddings_array, pad_value);
+
+ return GraphNode::eval();
+}
+
+template <int InRank, int OutRank, DType Dtype>
+OpReshape<InRank, OutRank, Dtype>::OpReshape(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_RESHAPE, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(0, 6);
+
+ INIT_ATTRIBUTE(Reshape);
+}
+
+template <int InRank, int OutRank, DType Dtype>
+OpReshape<InRank, OutRank, Dtype>::~OpReshape()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <int InRank, int OutRank, DType Dtype>
+int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes()
+{
+ uint32_t minusOneCount = 0;
+
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // output and input must be the same types
+ if (inputs[0]->matchType(*outputs[0]))
+ {
+ printNodeValidationError("OpReshape: Input and output types must match");
+ return 1;
+ }
+
+ for (uint32_t d = 0; d < OutRank; d++)
+ {
+ if (attribute->shape()[d] == -1)
+ {
+ minusOneCount++;
+ }
+ }
+
+ if (minusOneCount > 1)
+ {
+ printNodeValidationError("OpReshape: new shape has more than one -1 dimension");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ return 0;
+}
+
+template <int InRank, int OutRank, DType Dtype>
+int OpReshape<InRank, OutRank, Dtype>::eval()
+{
+ uint32_t remainingSize = in->getElementCount();
+
+ // If there is a -1 dimension, find the remainder in one pass over the output shape
+ for (int32_t d = 0; d < OutRank; d++)
+ {
+ if (attribute->shape()[d] != -1)
+ {
+ remainingSize = remainingSize / attribute->shape()[d];
+ }
+ }
+
+ for (int32_t d = 0; d < OutRank; d++)
+ {
+ array_shape[d] = attribute->shape()[OutRank - 1 - d];
+ out_reverser[d] = OutRank - 1 - d;
+
+ // Jam in the remainder here
+ if (array_shape[d] == -1)
+ {
+ array_shape[d] = remainingSize;
+ }
+ }
+
+ for (int32_t d = 0; d < InRank; d++)
+ {
+ in_reverser[d] = InRank - 1 - d;
+ }
+
+    // Eigen tensors are column-major while the reference model treats data as
+    // row-major: reverse the dimension order before the reshape, then reverse again afterward
+
+    // input tensors of rank 0 or 1 can't (and don't need to) be .shuffle()d; handle them separately
+ TIn in_reversed;
+ if (InRank > 1)
+ {
+ in_reversed = in->getTensor().shuffle(in_reverser);
+ }
+ else
+ {
+ in_reversed = in->getTensor();
+ }
+
+ TOut in_reshaped = in_reversed.reshape(array_shape);
+
+    // output tensors of rank 0 or 1 can't be .shuffle()d either; handle them separately
+ if (OutRank > 1)
+ {
+ out->getTensor() = in_reshaped.shuffle(out_reverser);
+ }
+ else
+ {
+ out->getTensor() = in_reshaped;
+ }
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+OpReverse<Rank, Dtype>::OpReverse(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_REVERSE, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(1, 6);
+
+ INIT_ATTRIBUTE(Axis);
+}
+
+template <int Rank, DType Dtype>
+OpReverse<Rank, Dtype>::~OpReverse()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <int Rank, DType Dtype>
+int OpReverse<Rank, Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // output and input must be the same types
+ if (inputs[0]->matchRankTypeShape(*outputs[0]))
+ {
+ printNodeValidationError("Failure to match input and output rank/type/shape");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ ASSERT_MEM(in && out);
+
+ if (attribute->axis() < 0 || attribute->axis() >= inputs[0]->getRank())
+ {
+ printNodeValidationError("Reverse axis must between [0, input_rank - 1]");
+ return 1;
+ }
+
+    // transform the single axis attribute into a per-dimension boolean array
+    // e.g. rank=4, axis=1 gives a reverse array of [false, true, false, false]
+ for (int i = 0; i < Rank; i++)
+ {
+ reverse_array[i] = false;
+ }
+ reverse_array[attribute->axis()] = true;
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpReverse<Rank, Dtype>::eval()
+{
+ out->getTensor() = in->getTensor().reverse(reverse_array);
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+OpSlice<Rank, Dtype>::OpSlice(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_SLICE, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(0, 6);
+
+ INIT_ATTRIBUTE(Slice);
+}
+
+template <int Rank, DType Dtype>
+OpSlice<Rank, Dtype>::~OpSlice()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <int Rank, DType Dtype>
+int OpSlice<Rank, Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // output and input must be the same types
+ if (inputs[0]->matchType(*outputs[0]))
+ {
+ printNodeValidationError("Failure to match input and output type");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ for (size_t i = 0; i < attribute->begin().size(); i++)
+ {
+ begin_array[i] = attribute->begin()[i];
+ }
+
+ for (size_t i = 0; i < attribute->size().size(); i++)
+ {
+ if (attribute->size()[i] != 0)
+ {
+ size_array[i] = attribute->size()[i];
+ }
+        else
+        {
+            // TensorFlow uses a zero size to mean "keep the rest of this dimension";
+            // Eigen expects the actual extent, so substitute the input dimension's full size
+            size_array[i] = in->getTensor().dimension(i);
+        }
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpSlice<Rank, Dtype>::eval()
+{
+ out->getTensor() = in->getTensor().slice(begin_array, size_array);
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+OpTileBase<Rank, Dtype>::OpTileBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_TILE, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(0, 6);
+
+ INIT_ATTRIBUTE(Tile);
+}
+
+template <int Rank, DType Dtype>
+OpTileBase<Rank, Dtype>::~OpTileBase()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <int Rank, DType Dtype>
+int OpTileBase<Rank, Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // output and input must be the same ranks and types
+ if (inputs[0]->matchRankType(*outputs[0]))
+ {
+ printNodeValidationError("Failure to match input and output rank or type");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ if (attribute->multiples().size() != Rank)
+ {
+ printNodeValidationError("1D list 'multiples' must have size equal to input rank");
+ return 1;
+ }
+
+ for (int32_t d = 0; d < Rank; d++)
+ {
+ if (in->getShape()[d] * attribute->multiples()[d] != out->getShape()[d])
+ {
+ printNodeValidationError("unexpected output shape");
+ return 1;
+ }
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpTile<Rank, Dtype>::eval()
+{
+ // primary template shouldn't be called
+ FATAL_ERROR_NODE("OpTile rank=%i, dtype=%s: not implemented yet", Rank, EnumNamesDType()[Dtype]);
+}
+
+template <DType Dtype>
+int OpTile<1, Dtype>::eval()
+{
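+    // tile by modular indexing: output element od0 reads input element od0 % input_size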
+ for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
+ {
+ int32_t id0 = od0 % this->in->getShape()[0];
+ this->out->getTensor()(od0) = this->in->getTensor()(id0);
+ }
+
+ return GraphNode::eval();
+}
+
+template <DType Dtype>
+int OpTile<2, Dtype>::eval()
+{
+ for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
+ {
+ int32_t id0 = od0 % this->in->getShape()[0];
+ for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
+ {
+ int32_t id1 = od1 % this->in->getShape()[1];
+ this->out->getTensor()(od0, od1) = this->in->getTensor()(id0, id1);
+ }
+ }
+
+ return GraphNode::eval();
+}
+
+template <DType Dtype>
+int OpTile<3, Dtype>::eval()
+{
+ for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
+ {
+ int32_t id0 = od0 % this->in->getShape()[0];
+ for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
+ {
+ int32_t id1 = od1 % this->in->getShape()[1];
+ for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
+ {
+ int32_t id2 = od2 % this->in->getShape()[2];
+ this->out->getTensor()(od0, od1, od2) = this->in->getTensor()(id0, id1, id2);
+ }
+ }
+ }
+
+ return GraphNode::eval();
+}
+
+template <DType Dtype>
+int OpTile<4, Dtype>::eval()
+{
+ for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
+ {
+ int32_t id0 = od0 % this->in->getShape()[0];
+ for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
+ {
+ int32_t id1 = od1 % this->in->getShape()[1];
+ for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
+ {
+ int32_t id2 = od2 % this->in->getShape()[2];
+ for (int32_t od3 = 0; od3 < this->out->getShape()[3]; od3++)
+ {
+ int32_t id3 = od3 % this->in->getShape()[3];
+ this->out->getTensor()(od0, od1, od2, od3) = this->in->getTensor()(id0, id1, id2, id3);
+ }
+ }
+ }
+ }
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+OpTranspose<Rank, Dtype>::OpTranspose(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_TRANSPOSE, id_)
+{
+ setRequiredOperands(2, 1);
+ setRequiredRank(0, 6);
+}
+
+template <int Rank, DType Dtype>
+OpTranspose<Rank, Dtype>::~OpTranspose()
+{}
+
+template <int Rank, DType Dtype>
+int OpTranspose<Rank, Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // output and input must be the same types
+ if (inputs[0]->matchRankType(*outputs[0]))
+ {
+ printNodeValidationError("Failure to match input and output rank and type");
+ return 1;
+ }
+
+ if (inputs[0]->getElementCount() != outputs[0]->getElementCount())
+ {
+ printNodeValidationError("Failure to match input and output total element count");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+ perm_tensor = dynamic_cast<TosaReference::TensorTemplate<ETensor1<int32_t>>*>(inputs[1]);
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpTranspose<Rank, Dtype>::eval()
+{
+ for (int32_t d = 0; d < Rank; d++)
+ {
+ perm_array[d] = this->perm_tensor->getTensor().data()[d];
+ }
+
+ out->getTensor() = in->getTensor().shuffle(perm_array);
+
+ return GraphNode::eval();
+}
+
+// template explicit instantiation
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FLOAT)
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, AINT8)
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8)
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16)
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32)
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL)
+
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FLOAT);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, AINT8);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL);
+
+DEF_INSTANTIATE_RESHAPE(OpReshape, FLOAT);
+DEF_INSTANTIATE_RESHAPE(OpReshape, AINT8);
+DEF_INSTANTIATE_RESHAPE(OpReshape, INT8);
+DEF_INSTANTIATE_RESHAPE(OpReshape, INT16);
+DEF_INSTANTIATE_RESHAPE(OpReshape, INT32);
+DEF_INSTANTIATE_RESHAPE(OpReshape, BOOL);
+
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FLOAT);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, AINT8);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
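The reverse-shuffle pattern used by OpConcat and OpReshape above can be exercised in isolation. A small sketch of the same trick under the same assumption (Eigen tensors are column-major, the reference model treats data as row-major), for a fixed (2,3) -> (3,2) reshape; the helper name is hypothetical:

    #include <unsupported/Eigen/CXX11/Tensor>

    // Row-major reshape on Eigen's column-major tensors: reverse the dims,
    // reshape against the reversed target shape, then reverse back --
    // the same pattern as OpReshape::eval() above.
    Eigen::Tensor<int, 2> rowMajorReshape2x3To3x2(const Eigen::Tensor<int, 2>& in)
    {
        Eigen::array<Eigen::Index, 2> rev{ 1, 0 };        // dimension reverser
        Eigen::array<Eigen::Index, 2> shape_rev{ 2, 3 };  // target shape (3,2), reversed
        return in.shuffle(rev).reshape(shape_rev).shuffle(rev);
    }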
diff --git a/reference_model/src/ops/data_layout.h b/reference_model/src/ops/data_layout.h
new file mode 100644
index 0000000..100bd6b
--- /dev/null
+++ b/reference_model/src/ops/data_layout.h
@@ -0,0 +1,216 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_DATA_LAYOUT_H
+#define OPS_DATA_LAYOUT_H
+
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+template <int Rank, DType Dtype>
+class OpConcat : public GraphNode
+{
+public:
+ OpConcat(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpConcat();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ Eigen::array<int, Rank> reverser;
+ TosaReference::TensorTemplate<TIn>* lhs;
+ TosaReference::TensorTemplate<TIn>* rhs;
+ TosaAxisAttribute* attribute;
+ TosaReference::TensorTemplate<TOut>* out;
+};
+
+template <int Rank, DType Dtype>
+class OpPad : public GraphNode
+{
+public:
+ OpPad(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpPad();
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ Eigen::array<std::pair<ptrdiff_t, ptrdiff_t>, Rank> paddings_array;
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TOut>* out;
+ TosaPadQuantInfo* qinfo;
+};
+
+template <int InRank, int OutRank, DType Dtype>
+class OpReshape : public GraphNode
+{
+public:
+ OpReshape(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpReshape();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, InRank>;
+ using TOut = Eigen::Tensor<OutEigenType, OutRank>;
+
+protected:
+ Eigen::array<Eigen::Index, OutRank> array_shape;
+ Eigen::array<Eigen::Index, InRank> in_reverser;
+ Eigen::array<Eigen::Index, OutRank> out_reverser;
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReshapeAttribute* attribute;
+ TosaReference::TensorTemplate<TOut>* out;
+};
+
+template <int Rank, DType Dtype>
+class OpReverse : public GraphNode
+{
+public:
+ OpReverse(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpReverse();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ TosaAxisAttribute* attribute;
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TOut>* out;
+ Eigen::array<bool, Rank> reverse_array;
+};
+
+template <int Rank, DType Dtype>
+class OpSlice : public GraphNode
+{
+public:
+ OpSlice(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpSlice();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ TosaSliceAttribute* attribute;
+ Eigen::array<Eigen::Index, Rank> begin_array;
+ Eigen::array<Eigen::Index, Rank> size_array;
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TOut>* out;
+};
+
+template <int Rank, DType Dtype>
+class OpTileBase : public GraphNode
+{
+public:
+ OpTileBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpTileBase();
+
+ virtual int checkTensorAttributes();
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ TosaTileAttribute* attribute;
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TOut>* out;
+};
+
+// primary template for op tile
+template <int Rank, DType Dtype>
+class OpTile : public OpTileBase<Rank, Dtype>
+{
+public:
+ OpTile(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : OpTileBase<Rank, Dtype>(attribute_, qinfo_, id_)
+ {}
+
+protected:
+ virtual int eval();
+};
+
+// partial specialization for specific rank
+#define DEF_OP_TILE_RANK(N) \
+ template <DType Dtype> \
+ class OpTile<N, Dtype> : public OpTileBase<N, Dtype> \
+ { \
+ public: \
+ OpTile(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \
+ : OpTileBase<N, Dtype>(attribute_, qinfo_, id_) \
+ {} \
+ \
+ protected: \
+ virtual int eval(); \
+ };
+
+DEF_OP_TILE_RANK(1)
+DEF_OP_TILE_RANK(2)
+DEF_OP_TILE_RANK(3)
+DEF_OP_TILE_RANK(4)
+
+#undef DEF_OP_TILE_RANK
+
+template <int Rank, DType Dtype>
+class OpTranspose : public GraphNode
+{
+public:
+ OpTranspose(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpTranspose();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ Eigen::array<int, Rank> perm_array;
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<ETensor1<int32_t>>* perm_tensor;
+ TosaReference::TensorTemplate<TOut>* out;
+};
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/data_nodes.cc b/reference_model/src/ops/data_nodes.cc
new file mode 100644
index 0000000..2ee4935
--- /dev/null
+++ b/reference_model/src/ops/data_nodes.cc
@@ -0,0 +1,172 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "data_nodes.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+OpConst::OpConst(uint64_t id_)
+ : GraphNode(Op_CONST, id_)
+{
+ setRequiredOperands(0, 1);
+}
+
+OpConst::~OpConst()
+{}
+
+int OpConst::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ return 0;
+}
+
+int OpConst::eval()
+{
+ // Evaluation is trivial for constants
+ return GraphNode::eval();
+}
+
+OpPlaceholder::OpPlaceholder(uint64_t id_)
+ : GraphNode(Op_PLACEHOLDER, id_)
+{
+ setRequiredOperands(0, 1);
+}
+
+OpPlaceholder::~OpPlaceholder()
+{}
+
+int OpPlaceholder::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ return 0;
+}
+
+int OpPlaceholder::eval()
+{
+ // Evaluation is trivial for placeholders
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+OpIdentity<Rank, Dtype>::OpIdentity(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_IDENTITY, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(0, 6);
+}
+
+template <int Rank, DType Dtype>
+OpIdentity<Rank, Dtype>::~OpIdentity()
+{}
+
+template <int Rank, DType Dtype>
+int OpIdentity<Rank, Dtype>::checkTensorAttributes()
+{
+
+ if (inputs.size() != outputs.size())
+ {
+ printNodeValidationError("Input and output tensor list lengths are not equal");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ if (in->matchRankTypeShape(*out))
+ {
+ printNodeValidationError("Input and output tensor rank, type, or shape do not match");
+ return 1;
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpIdentity<Rank, Dtype>::eval()
+{
+ out->getTensor() = in->getTensor();
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+OpIdentityN<Rank, Dtype>::OpIdentityN(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_IDENTITYN, id_)
+{
+ setRequiredRank(0, 6);
+}
+
+template <int Rank, DType Dtype>
+OpIdentityN<Rank, Dtype>::~OpIdentityN()
+{}
+
+template <int Rank, DType Dtype>
+int OpIdentityN<Rank, Dtype>::checkTensorAttributes()
+{
+
+ if (inputs.size() != outputs.size())
+ {
+ printNodeValidationError("Input and output tensor list lengths are not equal");
+ return 1;
+ }
+
+ for (size_t i = 0; i < inputs.size(); i++)
+ {
+ ins.push_back(dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[i]));
+ outs.push_back(dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[i]));
+
+ if (ins[i]->matchRankTypeShape(*outs[i]))
+ {
+ printNodeValidationError("Input and output tensor rank, type, or shape do not match");
+ return 1;
+ }
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpIdentityN<Rank, Dtype>::eval()
+{
+ for (size_t i = 0; i < ins.size(); i++)
+ {
+ outs[i]->getTensor() = ins[i]->getTensor();
+ }
+
+ return GraphNode::eval();
+}
+
+// template explicit instantiation
+// note OpConst and OpPlaceholder are not templated
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, BOOL);
diff --git a/reference_model/src/ops/data_nodes.h b/reference_model/src/ops/data_nodes.h
new file mode 100644
index 0000000..bec4669
--- /dev/null
+++ b/reference_model/src/ops/data_nodes.h
@@ -0,0 +1,86 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_DATA_NODES_H
+#define OPS_DATA_NODES_H
+
+#include "graph_node.h"
+
+namespace TosaReference
+{
+
+class OpConst : public GraphNode
+{
+public:
+ OpConst(uint64_t id_);
+ virtual ~OpConst();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+};
+
+class OpPlaceholder : public GraphNode
+{
+public:
+ OpPlaceholder(uint64_t id_);
+ virtual ~OpPlaceholder();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+};
+
+template <int Rank, DType Dtype>
+class OpIdentity : public GraphNode
+{
+public:
+ OpIdentity(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpIdentity();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TOut>* out;
+};
+
+template <int Rank, DType Dtype>
+class OpIdentityN : public GraphNode
+{
+public:
+ OpIdentityN(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpIdentityN();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ std::vector<TosaReference::TensorTemplate<TIn>*> ins;
+ std::vector<TosaReference::TensorTemplate<TOut>*> outs;
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
new file mode 100644
index 0000000..4d4f8b9
--- /dev/null
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -0,0 +1,586 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ewise_binary.h"
+#include "arith_util.h"
+#include "quant_util.h"
+#include "template_types.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <int Rank, DType InDtype, DType OutDtype>
+BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(const Op& op_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(op_, id_)
+{
+ setRequiredOperands(2, 1);
+ setRequiredRank(0, 6);
+
+ a_rank = b_rank = max_input_rank = -1;
+ a = b = nullptr;
+ a_rank0 = b_rank0 = nullptr;
+ result = nullptr;
+
+ fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return OutEigenType(); };
+}
+
+template <int Rank, DType InDtype, DType OutDtype>
+BinaryNodeBase<Rank, InDtype, OutDtype>::~BinaryNodeBase()
+{}
+
+template <int Rank, DType InDtype, DType OutDtype>
+int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ a_rank = inputs[0]->getRank();
+ b_rank = inputs[1]->getRank();
+ if (a_rank != 0 && b_rank != 0 && a_rank != b_rank)
+ {
+ printNodeValidationError("Binary operator input ranks must match");
+ return 1;
+ }
+
+ max_input_rank = a_rank >= b_rank ? a_rank : b_rank;
+
+ // A & B must be the same types
+ if (inputs[0]->matchType(*inputs[1]))
+ {
+ printNodeValidationError("Binary operator input types must match");
+ return 1;
+ }
+
+ // Result's geometry must match, but the type may be wider
+ if (outputs[0]->getRank() != max_input_rank)
+ {
+ printNodeValidationError("Binary operator input and output geometry must match");
+ return 1;
+ }
+
+ if (a_rank == max_input_rank)
+ {
+ a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ }
+ else
+ {
+ a_rank0 = dynamic_cast<TosaReference::TensorTemplate<ETensor0<InEigenType>>*>(inputs[0]);
+ }
+
+ if (b_rank == max_input_rank)
+ {
+ b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
+ }
+ else
+ {
+ b_rank0 = dynamic_cast<TosaReference::TensorTemplate<ETensor0<InEigenType>>*>(inputs[1]);
+ }
+
+ result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ // either a or b can be rank 0
+ // a_rank0 and b_rank0 can't both be valid at the same time.
+ // if a and b are both rank 0, they are evaluated as 'a' and 'b', not as 'a_rank0' and 'b_rank0'
+ ASSERT_MEM((a || a_rank0) && (b || b_rank0) && !(a_rank0 && b_rank0) && result);
+
+ return 0;
+}
+
+template <int Rank, DType InDtype, DType OutDtype>
+int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast()
+{
+ auto output_shape = result->getTensor().dimensions();
+
+ std::vector<int> a_shape, b_shape;
+
+ if (a_rank == max_input_rank)
+ {
+ a_shape = a->getShape();
+ }
+ else
+ {
+ a_shape.assign(max_input_rank, 1);
+ }
+
+ if (b_rank == max_input_rank)
+ {
+ b_shape = b->getShape();
+ }
+ else
+ {
+ b_shape.assign(max_input_rank, 1);
+ }
+
+ for (int i = 0; i < max_input_rank; i++)
+ {
+ if (a_shape[i] != output_shape[i] && a_shape[i] == 1)
+ {
+ bcast_a[i] = output_shape[i];
+ }
+ else
+ {
+ bcast_a[i] = 1;
+ }
+ if (b_shape[i] != output_shape[i] && b_shape[i] == 1)
+ {
+ bcast_b[i] = output_shape[i];
+ }
+ else
+ {
+ bcast_b[i] = 1;
+ }
+ }
+
+ return 0;
+}
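+
+// For illustration (shapes assumed): if the output shape is {2, 3}, input 'a'
+// has shape {2, 3} and input 'b' has shape {1, 3}, the loop above computes
+// bcast_a = {1, 1} ('a' already matches the output) and bcast_b = {2, 1}
+// ('b' is replicated twice along dimension 0), so that eval() can call
+// .broadcast() on each operand before applying 'fcn' element-wise.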
+
+template <int Rank, DType InDtype, DType OutDtype>
+int BinaryNode<Rank, InDtype, OutDtype>::eval()
+{
+ this->broadcast();
+
+ Eigen::array<int, Rank> reshaper;
+ reshaper.fill(1);
+ TIn ia, ib;
+
+ if (this->a_rank == this->max_input_rank)
+ {
+ ia = this->a->getTensor().broadcast(this->bcast_a);
+ }
+ else
+ {
+ ia = this->a_rank0->getTensor().reshape(reshaper).broadcast(this->bcast_a);
+ }
+
+ if (this->b_rank == this->max_input_rank)
+ {
+ ib = this->b->getTensor().broadcast(this->bcast_b);
+ }
+ else
+ {
+ ib = this->b_rank0->getTensor().reshape(reshaper).broadcast(this->bcast_b);
+ }
+
+ this->result->getTensor() = ia.binaryExpr(ib, this->fcn);
+
+ return GraphNode::eval();
+}
+
+// this still needs to be partially specialized, or Eigen will fail a static assertion
+template <DType InDtype, DType OutDtype>
+int BinaryNode<0, InDtype, OutDtype>::eval()
+{
+ this->result->getTensor() = this->a->getTensor().binaryExpr(this->b->getTensor(), this->fcn);
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+int OpAdd<Rank, Dtype>::register_fcn()
+{
+ switch (InDtype)
+ {
+ case DType_FLOAT:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpArithmeticRightShift<Rank, Dtype>::register_fcn()
+{
+ int32_t num_bits = 0;
+ switch (Dtype)
+ {
+ case DType_INT8:
+ num_bits = 8;
+ break;
+ case DType_INT16:
+ num_bits = 16;
+ break;
+ case DType_INT32:
+ num_bits = 32;
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ this->fcn = [num_bits](InEigenType a, InEigenType b) -> OutEigenType {
+ uint32_t sign = a & (1 << (num_bits - 1));
+ uint32_t ones_mask = ONES_MASK(b) << (num_bits - b);
+ if (sign)
+ return ones_mask | (a >> b);
+ else
+ return (~ones_mask) & (a >> b);
+ };
+
+ return 0;
+}
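+
+// For illustration (INT8 values assumed): with a = 0xF0 (-16) and b = 2, the
+// sign bit is set and ones_mask = ONES_MASK(2) << 6 = 0xC0, so the lambda
+// returns ones_mask | (a >> 2) = 0xFC (-4), i.e. an arithmetic shift. With
+// a = 0x70 (+112), the sign bit is clear and it returns
+// (~ones_mask) & (a >> 2) = 0x1C (+28).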
+
+template <int Rank, DType Dtype>
+int OpBitwiseAnd<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_AINT8:
+ case DType_INT16:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a & b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpBitwiseOr<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_AINT8:
+ case DType_INT16:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a | b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpBitwiseXor<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_AINT8:
+ case DType_INT16:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpLogicalAnd<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_BOOL:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a && b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpLogicalLeftShift<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_INT8:
+ case DType_INT16:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a << b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpLogicalRightShift<Rank, Dtype>::register_fcn()
+{
+ int32_t num_bits = 0;
+ switch (Dtype)
+ {
+ case DType_INT8:
+ num_bits = 8;
+ break;
+ case DType_INT16:
+ num_bits = 16;
+ break;
+ case DType_INT32:
+ num_bits = 32;
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ this->fcn = [num_bits](InEigenType a, InEigenType b) -> OutEigenType {
+ uint32_t mask = ONES_MASK(num_bits) >> b;
+ return (a >> b) & mask;
+ };
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpLogicalOr<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_BOOL:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a || b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpLogicalXor<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_BOOL:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpMaximum<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpMinimum<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType InDtype, DType OutDtype>
+int OpMul<Rank, InDtype, OutDtype>::register_fcn()
+{
+ switch (InDtype)
+ {
+ case DType_FLOAT:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a * b; };
+ break;
+ case DType_INT8:
+ case DType_INT16:
+ this->fcn = [this](InEigenType lhs, InEigenType rhs) -> OutEigenType {
+ OutEigenType raw_output = (OutEigenType)lhs * (OutEigenType)rhs;
+
+ OutEigenType clamped_output = std::min<OutEigenType>(QMax, std::max<OutEigenType>(raw_output, QMin));
+
+ return clamped_output;
+ };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpPow<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return powf(a, b); };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpSub<Rank, Dtype>::register_fcn()
+{
+ switch (InDtype)
+ {
+ case DType_FLOAT:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a - b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank>
+OpTable<Rank>::OpTable(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_TABLE, id_)
+{
+ setRequiredOperands(2, 1);
+ setRequiredRank(0, 6);
+}
+
+template <int Rank>
+OpTable<Rank>::~OpTable()
+{}
+
+template <int Rank>
+int OpTable<Rank>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ if (inputs[1]->getRank() != 1 || inputs[1]->getElementCount() != 513 || inputs[1]->getDtype() != DType_INT16)
+ {
+ FATAL_ERROR_NODE("OpTable: must have INT16 table with 513 entries");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ table = dynamic_cast<TosaReference::TensorTemplate<TTable>*>(inputs[1]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ ASSERT_MEM(in && table && out);
+
+ return 0;
+}
+
+template <int Rank>
+int OpTable<Rank>::eval()
+{
+ this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
+ // 1. make sure input is in int16 range
+ int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
+
+ // 2. calculate index and interpolation fraction
+ int32_t index = (input_truncated >> 7) + (1 << (IntegerBits - 1));
+ index = std::min<int32_t>(std::max<int32_t>(index, 0), NumTableEntries - 1); // 9-bit index
+ int32_t frac = input_truncated & 0x7F; // 7-bit fraction
+
+ // 3. interpolate, generate 16.7 (23-bit) output
+ int32_t base = this->table->getTensor()(index);
+ int32_t next = this->table->getTensor()(index + 1);
+ int32_t value = (base << 7) + (next - base) * frac;
+
+ return value;
+ });
+
+ return GraphNode::eval();
+}
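+
+// For illustration (input value assumed): an input of 1000 gives
+// index = (1000 >> 7) + (1 << 8) = 263 and frac = 1000 & 0x7F = 104, so
+// value = (table[263] << 7) + (table[264] - table[263]) * 104.
+// Since the clamped index can reach NumTableEntries - 1 = 511 and the
+// interpolation reads table[index + 1], the table needs 513 entries,
+// matching the check in checkTensorAttributes().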
+
+// template explicit instantiation
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalAnd, BOOL);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FLOAT, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
+
+DEF_INSTANTIATE_ONE_RANK_0_6(OpTable);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FLOAT, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL);
diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h
new file mode 100644
index 0000000..00fb3d9
--- /dev/null
+++ b/reference_model/src/ops/ewise_binary.h
@@ -0,0 +1,195 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_EWISE_BINARY_H
+#define OPS_EWISE_BINARY_H
+
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+// class BinaryNodeBase: holds the functions common to all of the binary nodes
+// when a binary op is created, the virtual OpXXX::register_fcn() will be called
+// and 'fcn' will be registered with a lambda function taking two inputs
+// class BinaryNode: a level of indirection that partially specializes the template for rank 0
+// eval() called from the top level should invoke .binaryExpr(dims, fcn) here
+// this needs to be partially specialized or the
+// compiler will fail a static assertion when trying to broadcast a rank-0 tensor
+// class OpXXX: implements the per-element lambda function for each data type
+// unlike BinaryNode, this doesn't need to be partially specialized
+
+// Eigen::Tensor natively supports some binary element-wise operations (e.g. cwiseMax, or '+', etc.),
+// which might be faster since they can be implemented with SIMD instructions.
+// Registering a lambda and calling .binaryExpr() might sacrifice some performance here,
+// but it avoids partial specialization for every combination of {rankN, rank0} x {FLOAT/INT32, QU8, ...}.
+// Needs to be revisited if performance becomes a bottleneck here.
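+//
+// For illustration, a simplified sketch of how the pieces fit together for an
+// INT32 ADD (the exact types come from GetEigenType; this is a sketch, not
+// the full call sequence):
+//     OpAdd<Rank, DType_INT32>::register_fcn() installs
+//         this->fcn = [](int32_t a, int32_t b) -> int32_t { return a + b; };
+//     and BinaryNode<Rank, ...>::eval() broadcasts both operands to the
+//     output shape, then evaluates
+//         this->result->getTensor() = ia.binaryExpr(ib, this->fcn);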
+template <int Rank, DType InDtype, DType OutDtype>
+class BinaryNodeBase : public GraphNode
+{
+public:
+ BinaryNodeBase(const Op& nodeType, TosaQuantInfoBase* qinfo_, const uint64_t id_);
+ virtual ~BinaryNodeBase();
+
+ virtual int checkTensorAttributes() final;
+ virtual int eval() = 0;
+ virtual int register_fcn() = 0;
+
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ int broadcast();
+
+protected:
+ std::function<OutEigenType(InEigenType, InEigenType)> fcn;
+ Eigen::array<int, Rank> bcast_a;
+ Eigen::array<int, Rank> bcast_b;
+ TosaReference::TensorTemplate<TIn>* a;
+ TosaReference::TensorTemplate<TIn>* b;
+ TosaReference::TensorTemplate<ETensor0<InEigenType>>* a_rank0;
+ TosaReference::TensorTemplate<ETensor0<InEigenType>>* b_rank0;
+ TosaReference::TensorTemplate<TOut>* result;
+ int a_rank;
+ int b_rank;
+ int max_input_rank;
+};
+
+// primary class
+template <int Rank, DType InDtype, DType OutDtype>
+class BinaryNode : public BinaryNodeBase<Rank, InDtype, OutDtype>
+{
+public:
+ BinaryNode(const Op& op_, TosaQuantInfoBase* qinfo_, const uint64_t id_)
+ : BinaryNodeBase<Rank, InDtype, OutDtype>(op_, qinfo_, id_)
+ {}
+ virtual ~BinaryNode()
+ {}
+
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+};
+
+// partial specialization for rank 0
+template <DType InDtype, DType OutDtype>
+class BinaryNode<0, InDtype, OutDtype> : public BinaryNodeBase<0, InDtype, OutDtype>
+{
+public:
+ BinaryNode(const Op& op_, TosaQuantInfoBase* qinfo_, const uint64_t id_)
+ : BinaryNodeBase<0, InDtype, OutDtype>(op_, qinfo_, id_)
+ {}
+ virtual ~BinaryNode()
+ {}
+
+ virtual int eval();
+};
+
+#define DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Opname, OPNAME) \
+ template <int Rank, DType Dtype> \
+ class Op##Opname : public BinaryNode<Rank, Dtype, Dtype> \
+ { \
+ public: \
+ Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \
+ : BinaryNode<Rank, Dtype, Dtype>(Op_##OPNAME, qinfo_, id_) \
+ { \
+ register_fcn(); \
+ } \
+ static constexpr DType InDtype = Dtype; \
+ static constexpr DType OutDtype = Dtype; \
+ using InEigenType = typename GetEigenType<InDtype>::type; \
+ using OutEigenType = typename GetEigenType<OutDtype>::type; \
+ virtual int register_fcn(); \
+ };
+
+#define DEF_TEMPLATE_BINARY_OP_TWO_TYPE(Opname, OPNAME) \
+ template <int Rank, DType InDtype, DType OutDtype> \
+ class Op##Opname : public BinaryNode<Rank, InDtype, OutDtype> \
+ { \
+ public: \
+ Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \
+ : BinaryNode<Rank, InDtype, OutDtype>(Op_##OPNAME, qinfo_, id_) \
+ { \
+ register_fcn(); \
+ } \
+ static constexpr int32_t QMin = GetQMin<OutDtype>::value; \
+ static constexpr int32_t QMax = GetQMax<OutDtype>::value; \
+ using InEigenType = typename GetEigenType<InDtype>::type; \
+ using OutEigenType = typename GetEigenType<OutDtype>::type; \
+ virtual int register_fcn(); \
+ };
+
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Add, ADD)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(ArithmeticRightShift, ARITHMETIC_RIGHT_SHIFT)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseAnd, BITWISE_AND)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseOr, BITWISE_OR)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseXor, BITWISE_XOR)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalAnd, LOGICAL_AND)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalLeftShift, LOGICAL_LEFT_SHIFT)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalRightShift, LOGICAL_RIGHT_SHIFT)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalOr, LOGICAL_OR)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalXor, LOGICAL_XOR)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Maximum, MAXIMUM)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Minimum, MINIMUM)
+DEF_TEMPLATE_BINARY_OP_TWO_TYPE(Mul, MUL)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Pow, POW)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Sub, SUB)
+
+#undef DEF_TEMPLATE_BINARY_OP_ONE_TYPE
+#undef DEF_TEMPLATE_BINARY_OP_TWO_TYPE
+
+template <int Rank>
+class OpTable : public GraphNode
+{
+public:
+ OpTable(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpTable();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ static constexpr DType InDtype = DType_INT16;
+ static constexpr DType TableDtype = DType_INT16;
+ static constexpr DType OutDtype = DType_INT32;
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using TableEigenType = typename GetEigenType<TableDtype>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TTable = Eigen::Tensor<TableEigenType, 1>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+ static constexpr int32_t IntegerBits = 9;
+ static constexpr int32_t FractionBits = 7;
+ static constexpr int32_t NumTableEntries = (1 << IntegerBits);
+ static constexpr int32_t QInMin = GetQMin<InDtype>::value;
+ static constexpr int32_t QInMax = GetQMax<InDtype>::value;
+ static constexpr int32_t QOutMin = GetQMin<OutDtype>::value;
+ static constexpr int32_t QOutMax = GetQMax<OutDtype>::value;
+
+protected:
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TTable>* table;
+ TosaReference::TensorTemplate<TOut>* out;
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc
new file mode 100644
index 0000000..eded0d7
--- /dev/null
+++ b/reference_model/src/ops/ewise_ternary.cc
@@ -0,0 +1,115 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ewise_ternary.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <int Rank, DType Dtype>
+OpSelectBase<Rank, Dtype>::OpSelectBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_SELECT, id_)
+{
+ setRequiredOperands(3, 1);
+ setRequiredRank(0, 6);
+}
+
+template <int Rank, DType Dtype>
+OpSelectBase<Rank, Dtype>::~OpSelectBase()
+{}
+
+template <int Rank, DType Dtype>
+int OpSelectBase<Rank, Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(inputs[2]) ||
+ validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // cond must match the output rank; then/else values must match the output rank and type
+ if (inputs[0]->matchRank(*outputs[0]) || inputs[1]->matchRankType(*outputs[0]) ||
+ inputs[2]->matchRankType(*outputs[0]))
+ {
+ printNodeValidationError("Failure to match input and output rank and type");
+ return 1;
+ }
+
+ cond = dynamic_cast<TosaReference::TensorTemplate<TCond>*>(inputs[0]);
+ then_val = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
+ else_val = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[2]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(outputs[0]);
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpSelectBase<Rank, Dtype>::eval()
+{
+ FATAL_ERROR_NODE("shouldn't be called");
+}
+
+template <int Rank, DType Dtype>
+int OpSelect<Rank, Dtype>::broadcast()
+{
+ std::vector<int> cond_shape = this->cond->getShape();
+ std::vector<int> then_shape = this->then_val->getShape();
+ std::vector<int> else_shape = this->else_val->getShape();
+ std::vector<int> out_shape = this->out->getShape();
+
+ for (int i = 0; i < Rank; i++)
+ {
+ this->bcast_cond[i] = (cond_shape[i] == 1) ? std::max(then_shape[i], else_shape[i]) : 1;
+ this->bcast_then[i] = (then_shape[i] == 1) ? std::max(cond_shape[i], else_shape[i]) : 1;
+ this->bcast_else[i] = (else_shape[i] == 1) ? std::max(then_shape[i], cond_shape[i]) : 1;
+ ASSERT_MSG_NODE((this->bcast_cond[i] * cond_shape[i]) == out_shape[i], "SELECT broadcast invariant failed");
+ ASSERT_MSG_NODE((this->bcast_then[i] * then_shape[i]) == out_shape[i], "SELECT broadcast invariant failed");
+ ASSERT_MSG_NODE((this->bcast_else[i] * else_shape[i]) == out_shape[i], "SELECT broadcast invariant failed");
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpSelect<Rank, Dtype>::eval()
+{
+ this->broadcast();
+ this->out->getTensor() = this->cond->getTensor()
+ .broadcast(this->bcast_cond)
+ .select(this->then_val->getTensor().broadcast(this->bcast_then),
+ this->else_val->getTensor().broadcast(this->bcast_else));
+
+ return GraphNode::eval();
+}
+
+template <DType Dtype>
+int OpSelect<0, Dtype>::eval()
+{
+ this->out->getTensor() = this->cond->getTensor().select(this->then_val->getTensor(), this->else_val->getTensor());
+
+ return GraphNode::eval();
+}
+
+// template explicit instantiation
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BOOL);
diff --git a/reference_model/src/ops/ewise_ternary.h b/reference_model/src/ops/ewise_ternary.h
new file mode 100644
index 0000000..b354247
--- /dev/null
+++ b/reference_model/src/ops/ewise_ternary.h
@@ -0,0 +1,83 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_TERNARY_H
+#define OPS_TERNARY_H
+
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+// The Ternary Select op has the following operands:
+// 1. Cond: rank N, type=bool
+// 2. Then_val: Rank N, type=<V>
+// 3. Else_val: Rank N, type=<V>
+// 4. Result: Rank N, type=<V>
+// Cond, Then_val, Else_val need to be mutually-broadcastable
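+// For illustration (shapes assumed): with a {2, 3} result, Cond could have
+// shape {1, 3} and Then_val/Else_val shape {2, 1}; each operand is replicated
+// along its size-1 dimensions until all three match the result shape.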
+template <int Rank, DType Dtype>
+class OpSelectBase : public GraphNode
+{
+public:
+ OpSelectBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpSelectBase();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using CondEigenType = typename GetEigenType<DType_BOOL>::type;
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using TCond = Eigen::Tensor<CondEigenType, Rank>;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+
+protected:
+ TosaReference::TensorTemplate<TCond>* cond;
+ Eigen::array<int, Rank> bcast_cond;
+ Eigen::array<int, Rank> bcast_then;
+ Eigen::array<int, Rank> bcast_else;
+ TosaReference::TensorTemplate<TIn>* then_val;
+ TosaReference::TensorTemplate<TIn>* else_val;
+ TosaReference::TensorTemplate<TIn>* out;
+};
+
+// primary class
+template <int Rank, DType Dtype>
+class OpSelect : public OpSelectBase<Rank, Dtype>
+{
+public:
+ OpSelect(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : OpSelectBase<Rank, Dtype>(attribute_, qinfo_, id_)
+ {}
+ virtual int eval();
+ int broadcast();
+
+ using InEigenType = typename OpSelectBase<Rank, Dtype>::InEigenType;
+};
+
+// partial specialization for rank 0
+template <DType Dtype>
+class OpSelect<0, Dtype> : public OpSelectBase<0, Dtype>
+{
+public:
+ OpSelect(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : OpSelectBase<0, Dtype>(attribute_, qinfo_, id_)
+ {}
+ virtual int eval();
+};
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc
new file mode 100644
index 0000000..d7bddc0
--- /dev/null
+++ b/reference_model/src/ops/ewise_unary.cc
@@ -0,0 +1,302 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ewise_unary.h"
+#include "quant_util.h"
+#include "template_types.h"
+#include <cmath>
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <int Rank, DType Dtype>
+UnaryNode<Rank, Dtype>::UnaryNode(const Op& op_, uint64_t id_)
+ : GraphNode(op_, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(0, 6);
+
+ fcn = [](InEigenType a) -> OutEigenType { return OutEigenType(); };
+}
+
+template <int Rank, DType Dtype>
+UnaryNode<Rank, Dtype>::~UnaryNode()
+{}
+
+template <int Rank, DType Dtype>
+int UnaryNode<Rank, Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // output and input ranks must match
+ if (inputs[0]->matchRankSize(*outputs[0]))
+ {
+ printNodeValidationError("UnaryNode: input and output rank must match");
+ return 1;
+ }
+
+ a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ ASSERT_MEM(a && result);
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int UnaryNode<Rank, Dtype>::eval()
+{
+ this->result->getTensor() = this->a->getTensor().unaryExpr(this->fcn);
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+int OpAbs<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ case DType_INT32:
+ this->fcn = [](InEigenType a) -> OutEigenType { return a > (InEigenType)0 ? a : (-a); };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpBitwiseNot<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_AINT8:
+ case DType_INT16:
+ case DType_INT32:
+ this->fcn = [](InEigenType a) -> OutEigenType { return ~a; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpCeil<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a) -> OutEigenType { return ceilf(a); };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpClz<Rank, Dtype>::register_fcn()
+{
+ int32_t num_bits;
+ switch (Dtype)
+ {
+ case DType_INT32:
+ num_bits = 32;
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ this->fcn = [num_bits](int32_t a) -> int32_t {
+ int32_t leading_zeros = 0;
+ for (int bit = num_bits - 1; bit >= 0; bit--)
+ {
+ if (((a >> bit) & 0x1) == 0)
+ {
+ leading_zeros++;
+ }
+ else
+ {
+ break;
+ }
+ }
+ return leading_zeros;
+ };
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpExp<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a) -> OutEigenType { return expf(a); };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpFloor<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a) -> OutEigenType { return floorf(a); };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpLog<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a) -> OutEigenType { return logf(a); };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpLogicalNot<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_BOOL:
+ this->fcn = [](InEigenType a) -> OutEigenType { return !a; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpNegate<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a) -> OutEigenType {
+ InEigenType result = -(a);
+ return result;
+ };
+ break;
+ case DType_INT16:
+ case DType_INT32:
+ this->fcn = [](InEigenType a) -> OutEigenType {
+ InEigenType result = -(a);
+ return result;
+ };
+ break;
+ case DType_AINT8:
+ ASSERT(this->qinfo);
+ this->fcn = [this](InEigenType a) -> OutEigenType {
+ InEigenType result = -(a - this->qinfo->input_zp()) + this->qinfo->output_zp();
+ return result;
+ };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
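+
+// For illustration of the AINT8 case (zero points assumed): with
+// input_zp = 10 and output_zp = -5, a stored value a = 30 represents the real
+// value 20; the lambda returns -(30 - 10) + (-5) = -25, which represents -20
+// in the output quantization.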
+
+template <int Rank, DType Dtype>
+int OpReciprocal<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / a; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpRsqrt<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / sqrtf(a); };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+// template explicit instantiation
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FLOAT);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FLOAT);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FLOAT);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FLOAT);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FLOAT);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FLOAT);
diff --git a/reference_model/src/ops/ewise_unary.h b/reference_model/src/ops/ewise_unary.h
new file mode 100644
index 0000000..0db3cfb
--- /dev/null
+++ b/reference_model/src/ops/ewise_unary.h
@@ -0,0 +1,102 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_EWISE_UNARY_H
+#define OPS_EWISE_UNARY_H
+
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+template <int Rank, DType Dtype>
+class UnaryNode : public GraphNode
+{
+public:
+ UnaryNode(const Op& nodeType, const uint64_t id_);
+ virtual ~UnaryNode();
+
+ virtual int checkTensorAttributes() final;
+ virtual int eval() final;
+ virtual int register_fcn() = 0;
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ std::function<OutEigenType(InEigenType)> fcn;
+ TosaReference::TensorTemplate<TIn>* a;
+ TosaReference::TensorTemplate<TOut>* result;
+};
+
+#define DEF_TEMPLATE_UNARY_OP(Opname, OPNAME) \
+ template <int Rank, DType Dtype> \
+ class Op##Opname : public UnaryNode<Rank, Dtype> \
+ { \
+ public: \
+ Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \
+ : UnaryNode<Rank, Dtype>(Op_##OPNAME, id_) \
+ { \
+ register_fcn(); \
+ } \
+ static constexpr int32_t QMin = GetQMin<Dtype>::value; \
+ static constexpr int32_t QMax = GetQMax<Dtype>::value; \
+ using InEigenType = typename GetEigenType<Dtype>::type; \
+ using OutEigenType = typename GetEigenType<Dtype>::type; \
+ virtual int register_fcn(); \
+ };
+
+#define DEF_TEMPLATE_UNARY_OP_WITH_QUANT_INFO(Opname, OPNAME) \
+ template <int Rank, DType Dtype> \
+ class Op##Opname : public UnaryNode<Rank, Dtype> \
+ { \
+ public: \
+ Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \
+ : UnaryNode<Rank, Dtype>(Op_##OPNAME, id_) \
+ { \
+ INIT_QINFO(Unary); \
+ register_fcn(); \
+ } \
+ static constexpr int32_t QMin = GetQMin<Dtype>::value; \
+ static constexpr int32_t QMax = GetQMax<Dtype>::value; \
+ using InEigenType = typename GetEigenType<Dtype>::type; \
+ using OutEigenType = typename GetEigenType<Dtype>::type; \
+ virtual int register_fcn(); \
+ \
+ protected: \
+ TosaUnaryQuantInfo* qinfo; \
+ };
+
+DEF_TEMPLATE_UNARY_OP(Abs, ABS)
+DEF_TEMPLATE_UNARY_OP(BitwiseNot, BITWISE_NOT)
+DEF_TEMPLATE_UNARY_OP(Ceil, CEIL)
+DEF_TEMPLATE_UNARY_OP(Clz, CLZ)
+DEF_TEMPLATE_UNARY_OP(Exp, EXP)
+DEF_TEMPLATE_UNARY_OP(Floor, FLOOR)
+DEF_TEMPLATE_UNARY_OP(Log, LOG)
+DEF_TEMPLATE_UNARY_OP(LogicalNot, LOGICAL_NOT)
+DEF_TEMPLATE_UNARY_OP_WITH_QUANT_INFO(Negate, NEGATE)
+DEF_TEMPLATE_UNARY_OP(Reciprocal, RECIPROCAL)
+DEF_TEMPLATE_UNARY_OP(Rsqrt, RSQRT)
+
+#undef DEF_TEMPLATE_UNARY_OP
+#undef DEF_TEMPLATE_UNARY_OP_WITH_QUANT_INFO
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc
new file mode 100644
index 0000000..d3352ce
--- /dev/null
+++ b/reference_model/src/ops/image.cc
@@ -0,0 +1,169 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "image.h"
+#include "arith_util.h"
+#include "quant_util.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <DType InDtype, DType OutDtype>
+OpResize<InDtype, OutDtype>::OpResize(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_RESIZE, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(4, 4);
+
+ INIT_ATTRIBUTE(Resize);
+}
+
+template <DType InDtype, DType OutDtype>
+OpResize<InDtype, OutDtype>::~OpResize()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <DType InDtype, DType OutDtype>
+int OpResize<InDtype, OutDtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ return 1;
+
+ output_size = this->attribute->output_size();
+ stride = this->attribute->stride();
+ offset = this->attribute->offset();
+ shift = this->attribute->shift();
+ mode = this->attribute->mode();
+
+ int output_height = outputs[0]->getShape()[1];
+ int output_width = outputs[0]->getShape()[2];
+
+ if (this->mode == ResizeMode_BILINEAR)
+ {
+ if (OutDtype != DType_INT32 && OutDtype != DType_INT48)
+ {
+ printNodeValidationError("OpResize: invalid data type for BILINEAR");
+ return 1;
+ }
+ }
+ else
+ {
+ if (OutDtype != DType_INT8 && OutDtype != DType_INT16)
+ {
+ printNodeValidationError("OpResize: invalid data type for NEAREST");
+ return 1;
+ }
+ }
+
+ if (output_size[0] != output_height || output_size[1] != output_width)
+ {
+ printNodeValidationError("OpResize: attribute output_size doesn't match output [height, width]");
+ return 1;
+ }
+
+ if (shift < 1 || shift > 11)
+ {
+ printNodeValidationError("OpResize: attribute shift should be within [1, 11]");
+ return 1;
+ }
+
+ if (stride[0] <= 0 || stride[1] <= 0)
+ {
+ printNodeValidationError("OpResize: invalid attribute stride");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ ASSERT_MEM(in && out);
+
+ return 0;
+}
+
+template <DType InDtype, DType OutDtype>
+int OpResize<InDtype, OutDtype>::eval()
+{
+ int in_batch = in->getShape()[0];
+ int in_height = in->getShape()[1];
+ int in_width = in->getShape()[2];
+ int in_channels = in->getShape()[3];
+
+ int out_batch = out->getShape()[0];
+ int out_height = out->getShape()[1];
+ int out_width = out->getShape()[2];
+ int out_channels = out->getShape()[3];
+
+ ASSERT_MSG_NODE(in_batch == out_batch, "OpResize: output tensor batch mismatch");
+ ASSERT_MSG_NODE(in_channels == out_channels, "OpResize: output tensor channel mismatch");
+
+ for (int b = 0; b < out_batch; b++)
+ for (int c = 0; c < out_channels; c++)
+ for (int oy = 0; oy < out_height; oy++)
+ for (int ox = 0; ox < out_width; ox++)
+ {
+ int y = oy * stride[0] + offset[0];
+ int x = ox * stride[1] + offset[1];
+
+ int iy = y >> shift;
+ int dy = y - (iy << shift);
+ int ix = x >> shift;
+ int dx = x - (ix << shift);
+
+ int iy0 = MAX(iy, 0);
+ int iy1 = MIN(iy + 1, in_height - 1);
+ int ix0 = MAX(ix, 0);
+ int ix1 = MIN(ix + 1, in_width - 1);
+
+ ASSERT_MSG(iy0 <= iy1 && ix0 <= ix1, "OpResize: invalid index (iy0, iy1, ix0, ix1)=(%d,%d,%d,%d)",
+ iy0, iy1, ix0, ix1);
+
+ InEigenType v00 = in->getTensor()(b, iy0, ix0, c);
+ InEigenType v01 = in->getTensor()(b, iy0, ix1, c);
+ InEigenType v10 = in->getTensor()(b, iy1, ix0, c);
+ InEigenType v11 = in->getTensor()(b, iy1, ix1, c);
+
+ OutEigenType acc;
+ if (mode == ResizeMode_BILINEAR)
+ {
+ acc = (OutEigenType)v00 * ((1 << shift) - dy) * ((1 << shift) - dx);
+ acc = acc + (OutEigenType)v01 * ((1 << shift) - dy) * dx;
+ acc = acc + (OutEigenType)v10 * dy * ((1 << shift) - dx);
+ acc = acc + (OutEigenType)v11 * dy * dx;
+ }
+ else
+ {
+ iy = (dy >> (shift - 1)) != 0 ? iy1 : iy0;
+ ix = (dx >> (shift - 1)) != 0 ? ix1 : ix0;
+ acc = in->getTensor()(b, iy, ix, c);
+ }
+
+ out->getTensor()(b, oy, ox, c) = acc;
+ }
+
+ return GraphNode::eval();
+}
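+
+// For illustration (attribute values assumed): with shift = 4,
+// stride = {24, 24} and offset = {0, 0}, output row oy = 3 gives y = 72, so
+// iy = 72 >> 4 = 4 and dy = 72 - (4 << 4) = 8, i.e. the sample sits halfway
+// between input rows 4 and 5. In BILINEAR mode the four neighbours are
+// weighted by dx, dy and ((1 << shift) - dx/dy), yielding an accumulator
+// scaled by 1 << (2 * shift) = 256; in NEAREST mode, dy >> (shift - 1) tests
+// whether the fraction is at least one half to pick the nearer row.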
+
+// template explicit instantiation
+DEF_INSTANTIATE_TWO_TYPE(OpResize, INT8, INT32);
+DEF_INSTANTIATE_TWO_TYPE(OpResize, INT8, INT8);
+DEF_INSTANTIATE_TWO_TYPE(OpResize, INT16, INT48);
+DEF_INSTANTIATE_TWO_TYPE(OpResize, INT16, INT16);
diff --git a/reference_model/src/ops/image.h b/reference_model/src/ops/image.h
new file mode 100644
index 0000000..9d15d49
--- /dev/null
+++ b/reference_model/src/ops/image.h
@@ -0,0 +1,53 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_IMAGE_H
+#define OPS_IMAGE_H
+
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+template <DType InDtype, DType OutDtype>
+class OpResize : public GraphNode
+{
+public:
+ OpResize(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpResize();
+ virtual int checkTensorAttributes() final;
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, 4>;
+ using TOut = Eigen::Tensor<OutEigenType, 4>;
+
+protected:
+ TosaResizeAttribute* attribute;
+ std::vector<int32_t> output_size;
+ std::vector<int32_t> stride;
+ std::vector<int32_t> offset;
+ int32_t shift;
+ ResizeMode mode;
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TOut>* out;
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
new file mode 100644
index 0000000..bad0c40
--- /dev/null
+++ b/reference_model/src/ops/op_factory.cc
@@ -0,0 +1,432 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "op_factory.h"
+#include "activation_funcs.h"
+#include "comparison.h"
+#include "control_flow.h"
+#include "custom.h"
+#include "data_layout.h"
+#include "data_nodes.h"
+#include "ewise_binary.h"
+#include "ewise_ternary.h"
+#include "ewise_unary.h"
+#include "image.h"
+#include "reduction.h"
+#include "scatter_gather.h"
+#include "tensor_ops.h"
+#include "type_conversion.h"
+
+using namespace TosaReference;
+using namespace tosa;
+
+GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh,
+ Op opType,
+ TosaAttributeBase* attribute,
+ TosaQuantInfoBase* qinfo,
+ uint64_t id,
+ DType inputDType,
+ int inputRank,
+ DType outputDType,
+ int outputRank,
+ DType weightDType,
+ int weightRank)
+{
+ switch (opType)
+ {
+ // tensor_ops
+ case Op_ARGMAX:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FLOAT);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, AINT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
+ break;
+ case Op_AVG_POOL2D:
+ DEF_FACTORY_ONE_TYPE(OpAvgPool2d, FLOAT);
+ DEF_FACTORY_ONE_TYPE(OpAvgPool2d, AINT8);
+ DEF_FACTORY_ONE_TYPE(OpAvgPool2d, INT16);
+ break;
+ case Op_CONV2D:
+ DEF_FACTORY_TWO_TYPE(OpConv2d, FLOAT, FLOAT);
+ DEF_FACTORY_TWO_TYPE(OpConv2d, AINT8, INT4);
+ DEF_FACTORY_TWO_TYPE(OpConv2d, AINT8, INT8);
+ DEF_FACTORY_TWO_TYPE(OpConv2d, AINT8, AINT8);
+ DEF_FACTORY_TWO_TYPE(OpConv2d, INT16, INT8);
+ break;
+ case Op_DEPTHWISE_CONV2D:
+ DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, FLOAT, FLOAT);
+ DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, AINT8, INT4);
+ DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, AINT8, INT8);
+ DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, AINT8, AINT8);
+ DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, INT16, INT8);
+ break;
+ case Op_FULLY_CONNECTED:
+ DEF_FACTORY_TWO_TYPE(OpFullyConnected, FLOAT, FLOAT);
+ DEF_FACTORY_TWO_TYPE(OpFullyConnected, AINT8, INT4);
+ DEF_FACTORY_TWO_TYPE(OpFullyConnected, AINT8, INT8);
+ DEF_FACTORY_TWO_TYPE(OpFullyConnected, AINT8, AINT8);
+ DEF_FACTORY_TWO_TYPE(OpFullyConnected, INT16, INT8);
+ break;
+ case Op_MATMUL:
+ DEF_FACTORY_ONE_TYPE(OpMatMul, FLOAT);
+ DEF_FACTORY_ONE_TYPE(OpMatMul, AINT8);
+ DEF_FACTORY_ONE_TYPE(OpMatMul, INT16);
+ break;
+ case Op_MAX_POOL2D:
+ DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FLOAT);
+ DEF_FACTORY_ONE_TYPE(OpMaxPool2d, AINT8);
+ DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT16);
+ break;
+ case Op_TRANSPOSE_CONV2D:
+ DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, FLOAT, FLOAT);
+ DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, AINT8, INT4);
+ DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, AINT8, INT8);
+ DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, AINT8, AINT8);
+ DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, INT16, INT8);
+ break;
+
+ // activation_funcs
+ case Op_CLAMP:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16);
+ break;
+ case Op_RELUN:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReluN, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReluN, INT32);
+ break;
+ case Op_SIGMOID:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FLOAT);
+ break;
+ case Op_TANH:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FLOAT);
+ break;
+
+ // ewise_binary
+ case Op_ADD:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);
+ break;
+ case Op_ARITHMETIC_RIGHT_SHIFT:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT32);
+ break;
+ case Op_BITWISE_AND:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT32);
+ break;
+ case Op_BITWISE_OR:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT32);
+ break;
+ case Op_BITWISE_XOR:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT32);
+ break;
+ case Op_LOGICAL_AND:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalAnd, BOOL);
+ break;
+ case Op_LOGICAL_LEFT_SHIFT:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT32);
+ break;
+ case Op_LOGICAL_RIGHT_SHIFT:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT32);
+ break;
+ case Op_LOGICAL_OR:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL);
+ break;
+ case Op_LOGICAL_XOR:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL);
+ break;
+ case Op_MAXIMUM:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
+ break;
+ case Op_MINIMUM:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
+ break;
+ case Op_MUL:
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FLOAT, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);
+ break;
+ case Op_POW:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT);
+ break;
+ case Op_SUB:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
+ break;
+ case Op_TABLE:
+ DEF_FACTORY_ONE_RANK_0_6(OpTable);
+ break;
+
+ // ewise_unary
+ case Op_ABS:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32);
+ break;
+ case Op_BITWISE_NOT:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32);
+ break;
+ case Op_CEIL:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FLOAT);
+ break;
+ case Op_CLZ:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32);
+ break;
+ case Op_EXP:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FLOAT);
+ break;
+ case Op_FLOOR:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FLOAT);
+ break;
+ case Op_LOG:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FLOAT);
+ break;
+ case Op_LOGICAL_NOT:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL);
+ break;
+ case Op_NEGATE:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32);
+ break;
+ case Op_RECIPROCAL:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FLOAT);
+ break;
+ case Op_RSQRT:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FLOAT);
+ break;
+
+ // ewise_ternary
+ case Op_SELECT:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BOOL);
+ break;
+
+ // comparison
+ case Op_EQUAL:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32);
+ break;
+ case Op_GREATER:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32);
+ break;
+ case Op_GREATER_EQUAL:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32);
+ break;
+
+ // reduction
+ case Op_REDUCE_ALL:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAll, BOOL);
+ break;
+ case Op_REDUCE_ANY:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAny, BOOL);
+ break;
+ case Op_REDUCE_MAX:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FLOAT);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, AINT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32);
+ break;
+ case Op_REDUCE_MIN:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FLOAT);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, AINT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32);
+ break;
+ case Op_REDUCE_PRODUCT:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FLOAT);
+ break;
+ case Op_REDUCE_SUM:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FLOAT);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, INT32);
+ break;
+
+ // data layout
+ case Op_CONCAT:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FLOAT);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, AINT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL);
+ break;
+ case Op_PAD:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FLOAT);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, AINT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL);
+ break;
+ case Op_RESHAPE:
+ DEF_FACTORY_RESHAPE(OpReshape, FLOAT);
+ DEF_FACTORY_RESHAPE(OpReshape, AINT8);
+ DEF_FACTORY_RESHAPE(OpReshape, INT8);
+ DEF_FACTORY_RESHAPE(OpReshape, INT16);
+ DEF_FACTORY_RESHAPE(OpReshape, INT32);
+ DEF_FACTORY_RESHAPE(OpReshape, BOOL);
+ break;
+ case Op_REVERSE:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FLOAT);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, AINT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);
+ break;
+ case Op_SLICE:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
+ break;
+ case Op_TILE:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);
+ break;
+ case Op_TRANSPOSE:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
+ break;
+
+ // scatter_gather
+ case Op_GATHER:
+ {
+ // output.rank = input.rank - 1 + index.rank
+ int32_t index_rank = outputRank - inputRank + 1;
+ DEF_FACTORY_GATHER(OpGather, AINT8);
+ DEF_FACTORY_GATHER(OpGather, INT16);
+ DEF_FACTORY_GATHER(OpGather, INT32);
+ }
+ break;
+
+ // image
+ case Op_RESIZE:
+ DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, INT8, INT32);
+ DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, INT8, INT8);
+ DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, INT16, INT48);
+ DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, INT16, INT16);
+ break;
+
+ // data_nodes
+ case Op_CONST:
+ return new OpConst(id);
+ case Op_PLACEHOLDER:
+ return new OpPlaceholder(id);
+ case Op_IDENTITY:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL);
+ break;
+ case Op_IDENTITYN:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, BOOL);
+ break;
+
+ // type_conversion
+ case Op_CAST:
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BOOL);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT32);
+ break;
+ case Op_RESCALE:
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, UINT8);
+ break;
+
+ // custom
+ case Op_CUSTOM:
+ return new OpCustom(id);
+
+ // control_flow
+ case Op_COND_IF:
+ return new OpCondIf(tsh, attribute, id);
+ case Op_WHILE_LOOP:
+ return new OpWhileLoop(tsh, attribute, id);
+
+ // Ops not recognized
+ default:
+ goto done;
+
+ } // End of switch(opType)
+
+done:
+ return nullptr;
+}
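
A minimal sketch of how this factory is typically driven, assuming a caller that has
already deserialized a node's metadata; tsh, attribute and qinfo below are placeholder
variables, and only the newOp() signature is taken from this change:

    // dispatch a rank-2 FLOAT ADD node (illustrative values)
    GraphNode* node = TosaReference::OpFactory::newOp(
        tsh, tosa::Op_ADD, attribute, qinfo, /* id */ 0,
        /* inputDType  */ tosa::DType_FLOAT, /* inputRank  */ 2,
        /* outputDType */ tosa::DType_FLOAT, /* outputRank */ 2,
        /* weightDType */ tosa::DType_FLOAT, /* weightRank */ 0);
    if (!node)
    {
        // unmatched op/dtype/rank combinations fall through to 'goto done' above
    }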
diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h
new file mode 100644
index 0000000..cde6841
--- /dev/null
+++ b/reference_model/src/ops/op_factory.h
@@ -0,0 +1,294 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_OP_FACTORY_H
+#define OPS_OP_FACTORY_H
+
+#include "attribute.h"
+#include "graph_node.h"
+#include "quant_info.h"
+#include "template_types.h"
+#include "tosa_serialization_handler.h"
+
+#define DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, RANK, DTYPE) \
+ case RANK: \
+ return new OP<RANK, DType_##DTYPE>(attribute, qinfo, id);
+
+#define DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, RANK, DTYPE1, DTYPE2) \
+ case RANK: \
+ return new OP<RANK, DType_##DTYPE1, DType_##DTYPE2>(attribute, qinfo, id);
+
+#define DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, RANK1, RANK2, DTYPE) \
+ case RANK2: \
+ return new OP<RANK1, RANK2, DType_##DTYPE>(attribute, qinfo, id);
+
+#define DEF_FACTORY_TWO_RANK_TWO_TYPE(OP, RANK1, RANK2, DTYPE1, DTYPE2) \
+ case RANK2: \
+ return new OP<RANK1, RANK2, DType_##DTYPE1, DType_##DTYPE2>(attribute, qinfo, id);
+
+#define DEF_FACTORY_ONE_RANK_0_6(OP) \
+ switch (inputRank) \
+ { \
+ case 0: \
+ return new OP<0>(attribute, qinfo, id); \
+ case 1: \
+ return new OP<1>(attribute, qinfo, id); \
+ case 2: \
+ return new OP<2>(attribute, qinfo, id); \
+ case 3: \
+ return new OP<3>(attribute, qinfo, id); \
+ case 4: \
+ return new OP<4>(attribute, qinfo, id); \
+ case 5: \
+ return new OP<5>(attribute, qinfo, id); \
+ case 6: \
+ return new OP<6>(attribute, qinfo, id); \
+ }
+
+#define DEF_FACTORY_ONE_TYPE(OP, DTYPE) \
+ if (inputDType == DType_##DTYPE) \
+ { \
+ return new OP<DType_##DTYPE>(attribute, qinfo, id); \
+ }
+
+#define DEF_FACTORY_TWO_TYPE(OP, DTYPE1, DTYPE2) \
+ if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2) \
+ { \
+ return new OP<DType_##DTYPE1, DType_##DTYPE2>(attribute, qinfo, id); \
+ }
+
+#define DEF_FACTORY_TWO_TYPE_RESIZE(OP, DTYPE1, DTYPE2) \
+ if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
+ { \
+ return new OP<DType_##DTYPE1, DType_##DTYPE2>(attribute, qinfo, id); \
+ }
+
+#define DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \
+ if (inputDType == DType_##DTYPE) \
+ { \
+ switch (inputRank) \
+ { \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 0, DTYPE) \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 1, DTYPE) \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 2, DTYPE) \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 3, DTYPE) \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 4, DTYPE) \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 5, DTYPE) \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 6, DTYPE) \
+ } \
+ }
+
+#define DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \
+ if (inputDType == DType_##DTYPE) \
+ { \
+ switch (inputRank) \
+ { \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 1, DTYPE) \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 2, DTYPE) \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 3, DTYPE) \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 4, DTYPE) \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 5, DTYPE) \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 6, DTYPE) \
+ } \
+ }
+
+#define DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OP, DTYPE1, DTYPE2) \
+ if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
+ { \
+ switch (inputRank) \
+ { \
+ DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, 0, DTYPE1, DTYPE2) \
+ DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, 1, DTYPE1, DTYPE2) \
+ DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, 2, DTYPE1, DTYPE2) \
+ DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, 3, DTYPE1, DTYPE2) \
+ DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, 4, DTYPE1, DTYPE2) \
+ DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, 5, DTYPE1, DTYPE2) \
+ DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, 6, DTYPE1, DTYPE2) \
+ } \
+ }
+
+#define DEF_FACTORY_RESHAPE(OP, DTYPE) \
+ if (inputDType == DType_##DTYPE && outputDType == DType_##DTYPE) \
+ { \
+ switch (inputRank) \
+ { \
+ case 0: \
+ { \
+ switch (outputRank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 0, 0, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 0, 1, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 0, 2, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 0, 3, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 0, 4, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 0, 5, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 0, 6, DTYPE) \
+ } \
+ } \
+ case 1: \
+ { \
+ switch (outputRank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 0, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 1, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 2, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 3, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 4, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 5, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 6, DTYPE) \
+ } \
+ } \
+ case 2: \
+ { \
+ switch (outputRank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 0, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 1, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 2, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 3, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 4, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 5, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 6, DTYPE) \
+ } \
+ } \
+ case 3: \
+ { \
+ switch (outputRank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 0, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 1, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 2, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 3, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 4, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 5, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 6, DTYPE) \
+ } \
+ } \
+ case 4: \
+ { \
+ switch (outputRank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 0, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 1, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 2, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 3, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 4, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 5, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 6, DTYPE) \
+ } \
+ } \
+ case 5: \
+ { \
+ switch (outputRank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 0, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 1, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 2, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 3, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 4, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 5, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 6, DTYPE) \
+ } \
+ } \
+ case 6: \
+ { \
+ switch (outputRank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 0, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 1, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 2, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 3, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 4, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 5, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 6, DTYPE) \
+ } \
+ } \
+ } \
+ }
+
+#define DEF_FACTORY_GATHER(OP, DTYPE) \
+ if (inputDType == DType_##DTYPE && outputDType == DType_##DTYPE) \
+ { \
+ switch (inputRank) \
+ { \
+ case 1: \
+ switch (index_rank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 1, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 2, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 3, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 4, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 5, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 6, DTYPE); \
+ } \
+ case 2: \
+ switch (index_rank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 1, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 2, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 3, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 4, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 5, DTYPE); \
+ } \
+ case 3: \
+ switch (index_rank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 1, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 2, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 3, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 4, DTYPE); \
+ } \
+ case 4: \
+ switch (index_rank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 1, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 2, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 3, DTYPE); \
+ } \
+ case 5: \
+ switch (index_rank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 1, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 2, DTYPE); \
+ } \
+ case 6: \
+ switch (index_rank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 1, DTYPE); \
+ } \
+ } \
+ }
+
+namespace TosaReference
+{
+
+class OpFactory
+{
+public:
+ static GraphNode* newOp(tosa::TosaSerializationHandler* tsh,
+ tosa::Op opType,
+ tosa::TosaAttributeBase* attribute,
+ tosa::TosaQuantInfoBase* qinfo,
+ uint64_t id,
+ tosa::DType inputDType,
+ int inputRank,
+ tosa::DType outputDType,
+ int outputRank,
+ tosa::DType weightDType,
+ int weightRank);
+};
+}; // namespace TosaReference
+
+#endif
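
For reference, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT) expands inside
newOp() to the following (abbreviated to ranks 0 and 6; ranks 1 through 5 follow the
same pattern):

    if (inputDType == DType_FLOAT)
    {
        switch (inputRank)
        {
            case 0:
                return new OpAdd<0, DType_FLOAT>(attribute, qinfo, id);
            /* ... cases 1 through 5 ... */
            case 6:
                return new OpAdd<6, DType_FLOAT>(attribute, qinfo, id);
        }
    }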
diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc
new file mode 100644
index 0000000..a2adfdb
--- /dev/null
+++ b/reference_model/src/ops/reduction.cc
@@ -0,0 +1,139 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "reduction.h"
+#include "quant_util.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <int Rank, DType Dtype>
+ReduceNode<Rank, Dtype>::ReduceNode(const Op& op_, TosaAttributeBase* attribute_, uint64_t id_)
+ : GraphNode(op_, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(0, 4);
+
+ INIT_ATTRIBUTE(Axis);
+}
+
+template <int Rank, DType Dtype>
+ReduceNode<Rank, Dtype>::~ReduceNode()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <int Rank, DType Dtype>
+int ReduceNode<Rank, Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ if (attribute->axis() < 0 || attribute->axis() >= inputs[0]->getRank())
+ {
+        printNodeValidationError("Reduce axis must be within [0, input_rank - 1]");
+ return 1;
+ }
+
+ if (inputs[0]->matchRank(*outputs[0]))
+ {
+ printNodeValidationError("Input and output tensor ranks must match");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ ASSERT_MEM(in && out);
+
+ dims[0] = this->attribute->axis();
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpReduceAll<Rank, Dtype>::eval()
+{
+ this->out->getTensor() = this->in->getTensor().all(this->dims).reshape(this->out->getTensor().dimensions());
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+int OpReduceAny<Rank, Dtype>::eval()
+{
+ this->out->getTensor() = this->in->getTensor().any(this->dims).reshape(this->out->getTensor().dimensions());
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+int OpReduceMax<Rank, Dtype>::eval()
+{
+ this->out->getTensor() = this->in->getTensor().maximum(this->dims).reshape(this->out->getTensor().dimensions());
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+int OpReduceMin<Rank, Dtype>::eval()
+{
+ this->out->getTensor() = this->in->getTensor().minimum(this->dims).reshape(this->out->getTensor().dimensions());
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+int OpReduceProduct<Rank, Dtype>::eval()
+{
+ this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions());
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+int OpReduceSum<Rank, Dtype>::eval()
+{
+ this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions());
+
+ return GraphNode::eval();
+}
+
+// template explicit instantiation
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAll, BOOL);
+
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAny, BOOL);
+
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FLOAT);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, AINT8);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32);
+
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FLOAT);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, AINT8);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32);
+
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FLOAT);
+
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FLOAT);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, INT32);
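
The kernels above all rely on the same Eigen idiom: the reduction drops the reduced
dimension, and reshape() restores the declared same-rank output shape with that
dimension set to 1. A standalone sketch of the idiom, assuming the same Eigen
unsupported-Tensor include path this codebase uses:

    #include <Eigen/CXX11/Tensor>
    #include <iostream>

    int main()
    {
        Eigen::Tensor<float, 2> in(2, 3);
        in.setValues({ { 1, 2, 3 }, { 4, 5, 6 } });

        // sum over dimension 0: yields a rank-1 tensor of shape [3]
        Eigen::array<int, 1> dims = { 0 };

        // reshape to [1, 3] to match a keep-dims rank-2 output tensor
        Eigen::Tensor<float, 2> out =
            in.sum(dims).reshape(Eigen::array<Eigen::Index, 2>{ 1, 3 });

        std::cout << out << std::endl;    // 5 7 9
        return 0;
    }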
diff --git a/reference_model/src/ops/reduction.h b/reference_model/src/ops/reduction.h
new file mode 100644
index 0000000..cf75812
--- /dev/null
+++ b/reference_model/src/ops/reduction.h
@@ -0,0 +1,109 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_REDUCTION_H
+#define OPS_REDUCTION_H
+
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+template <int Rank, DType Dtype>
+class ReduceNode : public GraphNode
+{
+public:
+ ReduceNode(const Op& nodeType, TosaAttributeBase* attribute_, const uint64_t id_);
+ virtual ~ReduceNode();
+ virtual int checkTensorAttributes();
+ virtual int eval() = 0;
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ Eigen::array<int, 1> dims;
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TOut>* out;
+ TosaAxisAttribute* attribute;
+};
+
+template <int Rank, DType Dtype>
+class OpReduceAll : public ReduceNode<Rank, Dtype>
+{
+public:
+ OpReduceAll(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : ReduceNode<Rank, Dtype>(Op_REDUCE_ALL, attribute_, id_)
+ {}
+ virtual int eval();
+};
+
+template <int Rank, DType Dtype>
+class OpReduceAny : public ReduceNode<Rank, Dtype>
+{
+public:
+ OpReduceAny(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+        : ReduceNode<Rank, Dtype>(Op_REDUCE_ANY, attribute_, id_)
+ {}
+ virtual int eval();
+};
+
+template <int Rank, DType Dtype>
+class OpReduceMax : public ReduceNode<Rank, Dtype>
+{
+public:
+ OpReduceMax(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : ReduceNode<Rank, Dtype>(Op_REDUCE_MAX, attribute_, id_)
+ {}
+ virtual int eval();
+};
+
+template <int Rank, DType Dtype>
+class OpReduceMin : public ReduceNode<Rank, Dtype>
+{
+public:
+ OpReduceMin(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : ReduceNode<Rank, Dtype>(Op_REDUCE_MIN, attribute_, id_)
+ {}
+ virtual int eval();
+};
+
+template <int Rank, DType Dtype>
+class OpReduceProduct : public ReduceNode<Rank, Dtype>
+{
+public:
+ OpReduceProduct(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : ReduceNode<Rank, Dtype>(Op_REDUCE_PRODUCT, attribute_, id_)
+ {}
+ virtual int eval();
+};
+
+template <int Rank, DType Dtype>
+class OpReduceSum : public ReduceNode<Rank, Dtype>
+{
+public:
+ OpReduceSum(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : ReduceNode<Rank, Dtype>(Op_REDUCE_SUM, attribute_, id_)
+ {}
+ virtual int eval();
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc
new file mode 100644
index 0000000..c54204a
--- /dev/null
+++ b/reference_model/src/ops/scatter_gather.cc
@@ -0,0 +1,120 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "scatter_gather.h"
+#include "quant_util.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <int InRank, int IndexRank, DType Dtype>
+OpGather<InRank, IndexRank, Dtype>::OpGather(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_GATHER, id_)
+{
+ setRequiredOperands(2, 1);
+ setRequiredRank(1, 6);
+
+ INIT_ATTRIBUTE(Axis);
+}
+
+template <int InRank, int IndexRank, DType Dtype>
+OpGather<InRank, IndexRank, Dtype>::~OpGather()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <int InRank, int IndexRank, DType Dtype>
+int OpGather<InRank, IndexRank, Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // output and input must be the same types
+ if (inputs[0]->matchType(*outputs[0]))
+ {
+ printNodeValidationError("Failure to match input and output type");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ index = dynamic_cast<TosaReference::TensorTemplate<TIndex>*>(inputs[1]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ ASSERT_MEM(in && index && out);
+
+ return 0;
+}
+
+template <int InRank, int IndexRank, DType Dtype>
+int OpGather<InRank, IndexRank, Dtype>::eval()
+{
+ int axis = attribute->axis();
+
+ // calculate size left and right to axis
+ int left_size = 1;
+ for (int i = 0; i < axis; ++i)
+ {
+ left_size *= in->getShape()[i];
+ }
+
+ int right_size = 1;
+ for (int i = axis + 1; i < in->getRank(); ++i)
+ {
+ right_size *= in->getShape()[i];
+ }
+
+ InEigenType* input_data = in->getTensor().data();
+ int32_t* index_data = index->getTensor().data();
+ OutEigenType* output_data = out->getTensor().data();
+
+ int32_t axis_size = in->getShape()[axis];
+ int32_t index_count = index->getElementCount();
+
+    // sanity-check that each index value is valid
+    // this can only be done here, since index values are not known until runtime
+ for (size_t i = 0; i < index->getElementCount(); i++)
+ {
+ if (index_data[i] >= axis_size)
+ {
+ FATAL_ERROR_NODE("OpGather: index[%lu]=%i can't exceed axis_size=%i", i, index_data[i], axis_size);
+ }
+ }
+
+    // Eigen stores tensors in column-major order, so the dimensions left of the
+    // axis form one contiguous block; iterate over the dimensions right of the
+    // axis and over the index array, copying 'left_size' elements each time
+ for (int right = 0; right < right_size; ++right)
+ {
+ for (int i = 0; i < index_count; ++i)
+ {
+ std::memcpy(output_data + (right * index_count + i) * left_size,
+ input_data + (right * axis_size + index_data[i]) * left_size, sizeof(InEigenType) * left_size);
+ }
+ }
+
+ return GraphNode::eval();
+}
+
+// template explicit instantiation
+DEF_INSTANTIATE_GATHER(OpGather, AINT8);
+DEF_INSTANTIATE_GATHER(OpGather, INT16);
+DEF_INSTANTIATE_GATHER(OpGather, INT32);
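
A worked example of the copy arithmetic in eval() above, assuming Eigen's default
column-major layout (dimension 0 is contiguous): an input of shape [2, 4, 3] with
axis = 1 gives left_size = 2, axis_size = 4, right_size = 3; with a rank-1
index = {3, 1} (index_count = 2), the iteration right = 2, i = 1 copies:

    // destination offset: (right * index_count + i) * left_size = (2*2 + 1) * 2 = 10
    // source offset:      (right * axis_size + index[i]) * left_size = (2*4 + 1) * 2 = 18
    // i.e. output(0..1, 1, 2) = input(0..1, index[1] = 1, 2), two contiguous elements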
diff --git a/reference_model/src/ops/scatter_gather.h b/reference_model/src/ops/scatter_gather.h
new file mode 100644
index 0000000..d9b1263
--- /dev/null
+++ b/reference_model/src/ops/scatter_gather.h
@@ -0,0 +1,54 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_SCATTER_GATHER_H
+#define OPS_SCATTER_GATHER_H
+
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+// input and index can have different ranks,
+// and OutRank is inferred statically from them
+template <int InRank, int IndexRank, DType Dtype>
+class OpGather : public GraphNode
+{
+public:
+ OpGather(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpGather();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ static constexpr int OutRank = InRank - 1 + IndexRank;
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, InRank>;
+ using TIndex = Eigen::Tensor<int32_t, IndexRank>;
+ using TOut = Eigen::Tensor<OutEigenType, OutRank>;
+
+protected:
+ TosaAxisAttribute* attribute;
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TIndex>* index;
+ TosaReference::TensorTemplate<TOut>* out;
+};
+
+}; // namespace TosaReference
+
+#endif
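
The static rank arithmetic can be verified at compile time; a minimal sketch, with
the dtype chosen purely for illustration:

    using G = TosaReference::OpGather<3, 2, tosa::DType_INT32>;
    static_assert(G::OutRank == 4, "out.rank == in.rank - 1 + index.rank");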
diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h
new file mode 100644
index 0000000..1859e03
--- /dev/null
+++ b/reference_model/src/ops/template_types.h
@@ -0,0 +1,277 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OP_TEMPLATE_TYPES_H
+#define OP_TEMPLATE_TYPES_H
+
+#include "tosa_generated.h"
+#include <Eigen/CXX11/Tensor>
+
+using namespace tosa;
+
+namespace TosaReference
+{
+// Shorter alias templates for common Eigen::Tensor types
+template <typename T>
+using ETensor0 = Eigen::Tensor<T, 0>;
+template <typename T>
+using ETensor1 = Eigen::Tensor<T, 1>;
+template <typename T>
+using ETensor2 = Eigen::Tensor<T, 2>;
+template <typename T>
+using ETensor3 = Eigen::Tensor<T, 3>;
+template <typename T>
+using ETensor4 = Eigen::Tensor<T, 4>;
+template <typename T>
+using ETensor5 = Eigen::Tensor<T, 5>;
+template <typename T>
+using ETensor6 = Eigen::Tensor<T, 6>;
+
+// Forward declaration
+template <class T>
+class TensorTemplate;
+
+// Shortcut to hide the TensorTemplate class.
+// For example, declare Tensor1<float> to get a TensorTemplate
+// with an Eigen::Tensor<float, 1>
+template <typename T>
+using Tensor0 = TensorTemplate<ETensor0<T>>;
+template <typename T>
+using Tensor1 = TensorTemplate<ETensor1<T>>;
+template <typename T>
+using Tensor2 = TensorTemplate<ETensor2<T>>;
+template <typename T>
+using Tensor3 = TensorTemplate<ETensor3<T>>;
+template <typename T>
+using Tensor4 = TensorTemplate<ETensor4<T>>;
+template <typename T>
+using Tensor5 = TensorTemplate<ETensor5<T>>;
+template <typename T>
+using Tensor6 = TensorTemplate<ETensor6<T>>;
+
+template <DType type>
+struct GetEigenType;
+template <>
+struct GetEigenType<DType_FLOAT>
+{
+ using type = float;
+};
+template <>
+struct GetEigenType<DType_INT32>
+{
+ using type = int32_t;
+};
+template <>
+struct GetEigenType<DType_INT48>
+{
+ using type = int64_t;
+};
+template <>
+struct GetEigenType<DType_BOOL>
+{
+ using type = bool;
+};
+template <>
+struct GetEigenType<DType_AINT8>
+{
+ using type = int32_t;
+};
+template <>
+struct GetEigenType<DType_UINT8>
+{
+ using type = int32_t;
+};
+template <>
+struct GetEigenType<DType_INT4>
+{
+ using type = int32_t;
+};
+template <>
+struct GetEigenType<DType_INT8>
+{
+ using type = int32_t;
+};
+template <>
+struct GetEigenType<DType_INT16>
+{
+ using type = int32_t;
+};
+
+// Meta function to get number of bits
+template <DType T>
+struct GetNumBits
+{
+ static constexpr int32_t value = 0;
+};
+template <>
+struct GetNumBits<DType_BOOL>
+{
+ static constexpr int32_t value = 1;
+};
+template <>
+struct GetNumBits<DType_AINT8>
+{
+ static constexpr int32_t value = 8;
+};
+template <>
+struct GetNumBits<DType_UINT8>
+{
+ static constexpr int32_t value = 8;
+};
+template <>
+struct GetNumBits<DType_INT4>
+{
+ static constexpr int32_t value = 4;
+};
+template <>
+struct GetNumBits<DType_INT8>
+{
+ static constexpr int32_t value = 8;
+};
+template <>
+struct GetNumBits<DType_INT16>
+{
+ static constexpr int32_t value = 16;
+};
+template <>
+struct GetNumBits<DType_INT32>
+{
+ static constexpr int32_t value = 32;
+};
+template <>
+struct GetNumBits<DType_INT48>
+{
+ static constexpr int32_t value = 48;
+};
+
+// Meta functions to get quantized min/max at compile time
+template <DType T>
+struct GetQMin
+{
+ static constexpr int64_t value = 0L;
+};
+template <>
+struct GetQMin<DType_AINT8>
+{
+ static constexpr int64_t value = -128L;
+};
+template <>
+struct GetQMin<DType_UINT8>
+{
+ static constexpr int64_t value = 0L;
+};
+template <>
+struct GetQMin<DType_INT4>
+{
+ static constexpr int64_t value = -8L;
+};
+template <>
+struct GetQMin<DType_INT8>
+{
+ static constexpr int64_t value = -128L;
+};
+template <>
+struct GetQMin<DType_INT16>
+{
+ static constexpr int64_t value = -32768L;
+};
+template <>
+struct GetQMin<DType_INT32>
+{
+ static constexpr int64_t value = -(1L << 31);
+};
+template <>
+struct GetQMin<DType_INT48>
+{
+ static constexpr int64_t value = -(1L << 47);
+};
+
+template <DType T>
+struct GetQMax
+{
+ static constexpr int64_t value = 0L;
+};
+template <>
+struct GetQMax<DType_AINT8>
+{
+ static constexpr int64_t value = 127L;
+};
+template <>
+struct GetQMax<DType_UINT8>
+{
+ static constexpr int64_t value = 255L;
+};
+template <>
+struct GetQMax<DType_INT4>
+{
+ static constexpr int64_t value = 7L;
+};
+template <>
+struct GetQMax<DType_INT8>
+{
+ static constexpr int64_t value = 127L;
+};
+template <>
+struct GetQMax<DType_INT16>
+{
+ static constexpr int64_t value = 32767L;
+};
+template <>
+struct GetQMax<DType_INT32>
+{
+ static constexpr int64_t value = (1L << 31) - 1;
+};
+template <>
+struct GetQMax<DType_INT48>
+{
+ static constexpr int64_t value = (1L << 47) - 1;
+};
+
+template <DType TIn1, DType TIn2>
+struct GetAccDType;
+template <>
+struct GetAccDType<DType_AINT8, DType_AINT8>
+{
+ static constexpr DType value = DType_INT32;
+};
+template <>
+struct GetAccDType<DType_AINT8, DType_INT4>
+{
+ static constexpr DType value = DType_INT32;
+};
+template <>
+struct GetAccDType<DType_AINT8, DType_INT8>
+{
+ static constexpr DType value = DType_INT32;
+};
+template <>
+struct GetAccDType<DType_INT16, DType_INT8>
+{
+ static constexpr DType value = DType_INT48;
+};
+template <>
+struct GetAccDType<DType_INT16, DType_INT16>
+{
+ static constexpr DType value = DType_INT48;
+};
+template <>
+struct GetAccDType<DType_FLOAT, DType_FLOAT>
+{
+ static constexpr DType value = DType_FLOAT;
+};
+
+}; // namespace TosaReference
+
+#endif
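
These traits compose at compile time; the following checks follow directly from the
specializations above and could live in any translation unit that includes this
header:

    #include <cstdint>
    #include <type_traits>

    // AINT8 is computed in int32_t storage with an 8-bit quantized range
    static_assert(std::is_same<TosaReference::GetEigenType<tosa::DType_AINT8>::type,
                               int32_t>::value,
                  "AINT8 maps to int32_t storage");
    static_assert(TosaReference::GetNumBits<tosa::DType_AINT8>::value == 8, "");
    static_assert(TosaReference::GetQMin<tosa::DType_AINT8>::value == -128, "");
    static_assert(TosaReference::GetQMax<tosa::DType_AINT8>::value == 127, "");

    // 16-bit by 16-bit convolutions accumulate into INT48
    static_assert(TosaReference::GetAccDType<tosa::DType_INT16,
                                             tosa::DType_INT16>::value
                      == tosa::DType_INT48,
                  "INT16 x INT16 accumulates in INT48");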
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
new file mode 100644
index 0000000..a735334
--- /dev/null
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -0,0 +1,1229 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tensor_ops.h"
+#include "quant_util.h"
+#include "template_types.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <int Rank, DType Dtype>
+OpArgMax<Rank, Dtype>::OpArgMax(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_ARGMAX, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(0, 6);
+
+ INIT_ATTRIBUTE(Axis);
+}
+
+template <int Rank, DType Dtype>
+OpArgMax<Rank, Dtype>::~OpArgMax()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <int Rank, DType Dtype>
+int OpArgMax<Rank, Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpArgMax<Rank, Dtype>::eval()
+{
+ Eigen::Tensor<DenseIndex, Rank - 1> index = this->input->getTensor().argmax(attribute->axis());
+
+ this->output->getTensor() = index.unaryExpr([](DenseIndex in) -> OutEigenType { return (OutEigenType)in; });
+
+ return GraphNode::eval();
+}
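+
+// Note: Eigen's argmax(axis) drops the reduced dimension and yields a rank (Rank - 1)
+// tensor of Eigen::DenseIndex; the unaryExpr() above only casts each index to the
+// declared output element type.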
+
+template <DType Dtype>
+OpAvgPool2d<Dtype>::OpAvgPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_AVG_POOL2D, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(4);
+
+ INIT_ATTRIBUTE(Pool2d);
+ INIT_QINFO(Unary);
+}
+
+template <DType Dtype>
+OpAvgPool2d<Dtype>::~OpAvgPool2d()
+{
+    if (attribute)
+        delete attribute;
+    if (qinfo)
+        delete qinfo;
+}
+
+template <DType Dtype>
+int OpAvgPool2d<Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ if (inputs[0]->matchType(*outputs[0]))
+ {
+ printNodeValidationError("OpAvgPool2d: input and output tensor type mismatch");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ if (!in->hasFormat(Format_NHWC))
+ {
+ printNodeValidationError("OpAvgPool2d: unsupported tensor format");
+ return 1;
+ }
+
+ if (attribute->padding().size() != 4)
+ {
+ printNodeValidationError("OpAvgPool2d: illegal size for attribute padding");
+ return 1;
+ }
+
+ if (attribute->kernel().size() != 2)
+ {
+ printNodeValidationError("OpAvgPool2d: illegal size for attribute kernel");
+ return 1;
+ }
+
+ if (attribute->stride().size() != 2)
+ {
+ printNodeValidationError("OpAvgPool2d: illegal size for attribute stride");
+ return 1;
+ }
+
+ return 0;
+}
+
+template <DType Dtype>
+ETensor1<int32_t> OpAvgPool2d<Dtype>::calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride)
+{
+ ETensor1<int32_t> result(out_size);
+
+ int32_t total_pad = (out_size - 1) * stride + kernel_size - in_size;
+ total_pad = total_pad < 0 ? 0 : total_pad;
+
+ int32_t pad_left = total_pad >> 1;
+ int32_t pad_right = total_pad - pad_left;
+
+ result.setConstant(kernel_size);
+
+    // indices at or left of 'left_index' and at or right of 'right_index' mark
+    // output positions whose input window covers at least one padded element
+ int32_t left_index = pad_left / stride;
+ int32_t right_index = pad_right / stride;
+
+    // ultra-small activations are not handled yet
+ ASSERT_MSG_NODE((out_size - 1 - right_index) >= left_index, "AvgPool2d: Small activations not supported yet");
+
+    // subtract the number of padded positions covered by this output index
+ while (left_index >= 0)
+ {
+ result(left_index) -= (pad_left - left_index * stride);
+ left_index--;
+ }
+
+ while (right_index >= 0)
+ {
+ result(out_size - 1 - right_index) -= (pad_right - right_index * stride);
+ right_index--;
+ }
+
+ return result;
+}
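+
+// Worked example: in_size = 4, out_size = 4, kernel_size = 3, stride = 1 gives
+// total_pad = 2, pad_left = pad_right = 1 and left_index = right_index = 1, so:
+//   result starts as [3, 3, 3, 3]
+//   left pass:  result(1) -= 0, result(0) -= 1  ->  [2, 3, 3, 3]
+//   right pass: result(2) -= 0, result(3) -= 1  ->  [2, 3, 3, 2]
+// i.e. the edge windows cover two valid inputs, interior windows all three.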
+
+// assuming input and output tensors have the same scale, as in the tflite reference,
+// so no input or output rescaling is needed
+template <DType Dtype>
+int OpAvgPool2d<Dtype>::eval()
+{
+ int in_batch = this->in->getShape()[0];
+ int in_height = this->in->getShape()[1];
+ int in_width = this->in->getShape()[2];
+ int in_channels = this->in->getShape()[3];
+
+ int out_batch = this->out->getShape()[0];
+ int out_height = this->out->getShape()[1];
+ int out_width = this->out->getShape()[2];
+ int out_channels = this->out->getShape()[3];
+
+ ASSERT_MSG_NODE(in_batch == out_batch, "OpAvgPool2d: tensor batch mismatch %d != %d", in_batch, out_batch);
+
+ int padding_top = this->attribute->padding()[0];
+ int padding_bottom = this->attribute->padding()[1];
+ int padding_left = this->attribute->padding()[2];
+ int padding_right = this->attribute->padding()[3];
+ int kernel_h = this->attribute->kernel()[0];
+ int kernel_w = this->attribute->kernel()[1];
+ int stride_h = this->attribute->stride()[0];
+ int stride_w = this->attribute->stride()[1];
+
+ DEBUG_INFO(OP,
+ "perform AvgPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], "
+ "stride=[%d,%d], padding=[%d,%d,%d,%d]",
+ in_batch, in_height, in_width, in_channels, out_batch, out_height, out_width, out_channels, kernel_h,
+ kernel_w, stride_h, stride_w, padding_top, padding_bottom, padding_left, padding_right);
+
+ Eigen::array<Eigen::Index, 2> im2col_input_dims;
+ im2col_input_dims[0] = kernel_h * kernel_w;
+ im2col_input_dims[1] = out_batch * out_height * out_width * out_channels;
+
+ Eigen::array<Eigen::Index, 4> col2im_output_dims;
+ col2im_output_dims[0] = out_batch;
+ col2im_output_dims[1] = out_height;
+ col2im_output_dims[2] = out_width;
+ col2im_output_dims[3] = out_channels;
+
+ Eigen::array<std::pair<int32_t, int32_t>, 4> padding;
+ padding[0] = std::make_pair(0, 0);
+ padding[1] = std::make_pair(padding_top, padding_bottom);
+ padding[2] = std::make_pair(padding_left, padding_right);
+ padding[3] = std::make_pair(0, 0);
+
+ ETensor4<InEigenType> input_val = this->in->getTensor();
+ if (this->qinfo)
+ {
+ input_val = input_val - (InEigenType)this->qinfo->input_zp();
+ }
+
+ ETensor4<InEigenType> input_padded = input_val.pad(padding);
+
+    // assuming input and output have the same scale,
+    // so input and output scaling is not required
+    // TODO: check whether TOSA makes this assumption
+
+ // extract_image_patches() output [N, KH, KW, H * W, C]
+ // transpose to [KH, KW, N, H * W, C]
+ // reshape to [KH * KW, N * H * W * C]
+ ETensor2<InEigenType> input_extract_patches =
+ input_padded.extract_image_patches(kernel_h, kernel_w, stride_h, stride_w, 1, 1, Eigen::PADDING_VALID)
+ .shuffle(Eigen::array<Eigen::Index, 5>{ 1, 2, 0, 3, 4 })
+ .reshape(im2col_input_dims);
+
+ // 1D result with [N * H * W * C]
+ ETensor1<AccEigenType> out_1d(this->out->getElementCount());
+ out_1d.setZero();
+
+ // sum pool
+ for (size_t i = 0; i < this->out->getElementCount(); i++)
+ {
+ for (int32_t j = 0; j < kernel_h * kernel_w; j++)
+ {
+ out_1d(i) += (AccEigenType)input_extract_patches(j, i);
+ }
+ }
+
+    // reshape result to [N, H, W, C] and divide by div_map
+ ETensor4<AccEigenType> sum = out_1d.reshape(col2im_output_dims);
+
+ // calculate 1d height/width div_map (number of elements this pooling window covers)
+ // and outer product to get 2d div_map, then reshape/broadcast to [N, H, W, C]
+ ETensor1<int32_t> div_map_h = calculate_div_map_1d(in_height, out_height, kernel_h, stride_h);
+ ETensor1<int32_t> div_map_w = calculate_div_map_1d(in_width, out_width, kernel_w, stride_w);
+ Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> contract_dims = { Eigen::IndexPair<Eigen::Index>(1, 0) };
+ Eigen::array<Eigen::Index, 4> bcast{ out_batch, 1, 1, out_channels };
+
+ ETensor4<int32_t> div_map =
+ div_map_h.reshape(Eigen::array<Eigen::Index, 2>{ out_height, 1 })
+ .contract(div_map_w.reshape(Eigen::array<Eigen::Index, 2>{ 1, out_width }), contract_dims)
+ .reshape(Eigen::array<Eigen::Index, 4>{ 1, out_height, out_width, 1 })
+ .broadcast(bcast);
+
+ if (Dtype != DType_FLOAT)
+ {
+ this->out->getTensor() = sum.binaryExpr(div_map, [](AccEigenType value, int32_t div) -> OutEigenType {
+ int32_t multiplier, shift;
+ TosaReference::QuantUtil<AccDtype>::reciprocal_scale(div, multiplier, shift);
+
+ return (OutEigenType)TosaReference::QuantUtil<AccDtype>::apply_scale(value, multiplier, shift, false);
+ });
+ this->out->getTensor() = this->out->getTensor() + (OutEigenType)(this->qinfo->output_zp());
+ this->out->getTensor() = this->out->getTensor().cwiseMax((OutEigenType)QMin);
+ this->out->getTensor() = this->out->getTensor().cwiseMin((OutEigenType)QMax);
+ }
+ else
+ {
+ this->out->getTensor() = (sum / div_map.template cast<AccEigenType>()).template cast<OutEigenType>();
+ }
+
+ return GraphNode::eval();
+}
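+
+// Note on the quantized path above: the division by div_map is done in integer
+// arithmetic; reciprocal_scale(div, multiplier, shift) chooses a fixed-point
+// (multiplier, shift) pair approximating 1/div, which apply_scale() then applies
+// to each accumulated sum.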
+
+template <DType InDtype, DType WeightDtype>
+OpConv2d<InDtype, WeightDtype>::OpConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_CONV2D, id_)
+{
+ setRequiredOperands(3, 1);
+ setRequiredRank(4);
+
+ INIT_ATTRIBUTE(Conv2d);
+ INIT_QINFO(Conv);
+}
+
+template <DType InDtype, DType WeightDtype>
+OpConv2d<InDtype, WeightDtype>::~OpConv2d()
+{
+ if (attribute)
+ delete attribute;
+ if (qinfo)
+ delete qinfo;
+}
+
+template <DType InDtype, DType WeightDtype>
+int OpConv2d<InDtype, WeightDtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+    // 'bias' is checked separately since it doesn't make sense to require a rank ranging from 1 to 4
+    if (inputs[2]->getRank() != 1)
+    {
+        printNodeValidationError("OpConv2d: bias tensor must be rank 1");
+        return 1;
+    }
+
+    if (inputs[1]->getIsConst() == 0)
+    {
+        printNodeValidationError("OpConv2d: weight tensor is not const typed");
+        return 1;
+    }
+
+ input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
+ bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
+ output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+
+ if (!input->hasFormat(Format_NHWC))
+ {
+ printNodeValidationError("OpConv2d: unsupported input tensor format");
+ return 1;
+ }
+
+ if (!weight->hasFormat(Format_OHWI))
+ {
+ printNodeValidationError("OpConv2d: unsupported weight tensor format");
+ return 1;
+ }
+
+ if (attribute->padding().size() != 4)
+ {
+ printNodeValidationError("OpConv2d: illegal size for attribute padding");
+ return 1;
+ }
+
+ if (attribute->stride().size() != 2)
+ {
+ printNodeValidationError("OpConv2d: illegal size for attribute stride");
+ return 1;
+ }
+
+ if (attribute->dilation().size() != 2)
+ {
+ printNodeValidationError("OpConv2d: illegal size for attribute dilation");
+ return 1;
+ }
+
+ return 0;
+}
+
+template <DType InDtype, DType WeightDtype>
+int OpConv2d<InDtype, WeightDtype>::eval()
+{
+ int in_batch = this->input->getShape()[0];
+ int in_height = this->input->getShape()[1];
+ int in_width = this->input->getShape()[2];
+ int in_channels = this->input->getShape()[3];
+
+ int f_out_channels = this->weight->getShape()[0];
+ int f_height = this->weight->getShape()[1];
+ int f_width = this->weight->getShape()[2];
+ int f_in_channels = this->weight->getShape()[3];
+
+ int b_out_channels = this->bias->getShape()[0];
+
+ int out_batch = this->output->getShape()[0];
+ int out_height = this->output->getShape()[1];
+ int out_width = this->output->getShape()[2];
+ int out_channels = this->output->getShape()[3];
+
+ ASSERT_MSG_NODE(in_batch == out_batch, "OpConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
+ ASSERT_MSG_NODE(f_in_channels == in_channels, "OpConv2d: tensor input channel mismatch %d != %d", f_in_channels,
+ in_channels);
+ ASSERT_MSG_NODE(f_out_channels == out_channels, "OpConv2d: tensor output channel mismatch %d != %d", f_out_channels,
+ out_channels);
+ ASSERT_MSG_NODE(b_out_channels == out_channels, "OpConv2d: tensor output channel mismatch %d != %d", b_out_channels,
+ out_channels);
+
+ int padding_top = this->attribute->padding()[0];
+ int padding_bottom = this->attribute->padding()[1];
+ int padding_left = this->attribute->padding()[2];
+ int padding_right = this->attribute->padding()[3];
+ int stride_h = this->attribute->stride()[0];
+ int stride_w = this->attribute->stride()[1];
+ int dilation_h = this->attribute->dilation()[0];
+ int dilation_w = this->attribute->dilation()[1];
+
+ DEBUG_INFO(OP,
+ "perform OpConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], "
+ "stride=[%d,%d], dilation=[%d,%d], padding=[%d,%d,%d,%d]",
+ in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_out_channels, out_batch,
+ out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, padding_top,
+ padding_bottom, padding_left, padding_right);
+
+ // GEMM-conv2d, left matrix is input, right matrix is weight
+ Eigen::array<Eigen::Index, 2> im2col_input_dims;
+ im2col_input_dims[0] = out_batch * out_height * out_width;
+ im2col_input_dims[1] = f_height * f_width * f_in_channels;
+
+ Eigen::array<Eigen::Index, 2> im2col_weight_dims;
+ im2col_weight_dims[0] = f_height * f_width * f_in_channels;
+ im2col_weight_dims[1] = f_out_channels;
+
+ Eigen::array<Eigen::Index, 2> bias_reshaped_dims;
+ bias_reshaped_dims[0] = 1;
+ bias_reshaped_dims[1] = b_out_channels;
+
+ Eigen::array<Eigen::Index, 4> weight_zp_bcast_dims;
+ weight_zp_bcast_dims[0] = f_height;
+ weight_zp_bcast_dims[1] = f_width;
+ weight_zp_bcast_dims[2] = f_in_channels;
+
+ Eigen::array<Eigen::Index, 2> bias_bcast_dims;
+ bias_bcast_dims[0] = out_batch * out_height * out_width;
+ bias_bcast_dims[1] = 1;
+
+ Eigen::array<Eigen::Index, 4> col2im_output_dims;
+ col2im_output_dims[0] = out_batch;
+ col2im_output_dims[1] = out_height;
+ col2im_output_dims[2] = out_width;
+ col2im_output_dims[3] = out_channels;
+
+ Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> contract_dims = { Eigen::IndexPair<Eigen::Index>(1, 0) };
+
+ Eigen::array<std::pair<int32_t, int32_t>, 4> padding;
+ padding[0] = std::make_pair(0, 0);
+ padding[1] = std::make_pair(padding_top, padding_bottom);
+ padding[2] = std::make_pair(padding_left, padding_right);
+ padding[3] = std::make_pair(0, 0);
+
+ TIn input_val = this->input->getTensor();
+ TWeight weight_val = this->weight->getTensor();
+ if (this->qinfo)
+ {
+ input_val = input_val - (InEigenType)this->qinfo->input_zp();
+ weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp();
+ }
+
+ ETensor4<InEigenType> input_padded = input_val.pad(padding);
+
+ // extract_image_patches() output [N, KH, KW, H * W, C]
+ // need to transpose to [N, H * W, KH, KW, C]
+ ETensor5<InEigenType> input_extract_patches =
+ input_padded
+ .extract_image_patches(f_height, f_width, stride_h, stride_w, dilation_h, dilation_w, Eigen::PADDING_VALID)
+ .shuffle(Eigen::array<Eigen::Index, 5>{ 0, 3, 1, 2, 4 });
+
+ // reshape input to [N * H * W, KH * KW * C]
+ ETensor2<InEigenType> im2col_input = input_extract_patches.reshape(im2col_input_dims);
+
+ // transpose and reshape weight from [OC, H, W, IC] to [H * W * IC, OC]
+ ETensor2<WeightEigenType> im2col_weight =
+ weight_val.shuffle(Eigen::array<Eigen::Index, 4>({ 1, 2, 3, 0 })).reshape(im2col_weight_dims);
+
+    // no need to apply bias_multiplier (* bias_scale and >> bias_shift) since tflite already scales it;
+    // bias is reshaped from [C] to [1, C] and broadcast to [N * H * W, C]
+ ETensor2<AccEigenType> bias_2d = this->bias->getTensor().reshape(bias_reshaped_dims).broadcast(bias_bcast_dims);
+
+ // output matrix is [N * H * W, C]
+ ETensor2<AccEigenType> contracted_result =
+ im2col_input.template cast<AccEigenType>().contract(im2col_weight.template cast<AccEigenType>(), contract_dims);
+
+ // adding bias
+ ETensor2<AccEigenType> biased_output = contracted_result + bias_2d.template cast<AccEigenType>();
+
+ // reshape back to [N, H, W, C]
+ this->output->getTensor() = biased_output.reshape(col2im_output_dims);
+
+ if (AccDtype == DType_INT48)
+ {
+ this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
+ this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
+ }
+
+ return GraphNode::eval();
+}
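+
+// Worked shape example for the im2col GEMM above (illustrative values only):
+//   input [1,8,8,3], weight [16,3,3,3], stride [1,1], dilation [1,1], padding [1,1,1,1]
+//   -> output [1,8,8,16]
+//   im2col_input:  [N*OH*OW, KH*KW*IC] = [64, 27]
+//   im2col_weight: [KH*KW*IC, OC]      = [27, 16]
+//   contraction:   [64, 27] x [27, 16] = [64, 16], reshaped back to [1,8,8,16]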
+
+template <DType InDtype, DType WeightDtype>
+OpDepthwiseConv2d<InDtype, WeightDtype>::OpDepthwiseConv2d(TosaAttributeBase* attribute_,
+ TosaQuantInfoBase* qinfo_,
+ uint64_t id_)
+ : GraphNode(Op_DEPTHWISE_CONV2D, id_)
+{
+ setRequiredOperands(3, 1);
+ setRequiredRank(4);
+
+ INIT_ATTRIBUTE(Conv2d);
+ INIT_QINFO(Conv);
+}
+
+template <DType InDtype, DType WeightDtype>
+OpDepthwiseConv2d<InDtype, WeightDtype>::~OpDepthwiseConv2d()
+{
+ if (attribute)
+ delete attribute;
+ if (qinfo)
+ delete qinfo;
+}
+
+template <DType InDtype, DType WeightDtype>
+int OpDepthwiseConv2d<InDtype, WeightDtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+    // 'bias' is checked separately since it doesn't make sense to require a rank ranging from 1 to 4
+    if (inputs[2]->getRank() != 1)
+    {
+        printNodeValidationError("OpDepthwiseConv2d: bias tensor must be rank 1");
+        return 1;
+    }
+
+    if (inputs[1]->getIsConst() == 0)
+    {
+        printNodeValidationError("OpDepthwiseConv2d: weight tensor is not const typed");
+        return 1;
+    }
+
+ input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
+ bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
+ output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+
+ if (!input->hasFormat(Format_NHWC))
+ {
+ printNodeValidationError("OpDepthwiseConv2d: unsupported input tensor format");
+ return 1;
+ }
+
+ if (!weight->hasFormat(Format_HWIM))
+ {
+ printNodeValidationError("OpDepthwiseConv2d: unsupported weight tensor format");
+ return 1;
+ }
+
+ if (attribute->padding().size() != 4)
+ {
+ printNodeValidationError("OpDepthwiseConv2d: illegal size for attribute padding");
+ return 1;
+ }
+
+ if (attribute->stride().size() != 2)
+ {
+ printNodeValidationError("OpDepthwiseConv2d: illegal size for attribute stride");
+ return 1;
+ }
+
+ if (attribute->dilation().size() != 2)
+ {
+ printNodeValidationError("OpDepthwiseConv2d: illegal size for attribute dilation");
+ return 1;
+ }
+
+ return 0;
+}
+
+template <DType InDtype, DType WeightDtype>
+int OpDepthwiseConv2d<InDtype, WeightDtype>::eval()
+{
+ int in_batch = this->input->getShape()[0];
+ int in_height = this->input->getShape()[1];
+ int in_width = this->input->getShape()[2];
+ int in_channels = this->input->getShape()[3];
+
+ int f_height = this->weight->getShape()[0];
+ int f_width = this->weight->getShape()[1];
+ int f_in_channels = this->weight->getShape()[2];
+ int f_multiplier = this->weight->getShape()[3];
+
+ int b_out_channels = this->bias->getShape()[0];
+
+ int out_batch = this->output->getShape()[0];
+ int out_height = this->output->getShape()[1];
+ int out_width = this->output->getShape()[2];
+ int out_channels = this->output->getShape()[3];
+
+ ASSERT_MSG_NODE(in_batch == out_batch, "OpDepthwiseConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
+ ASSERT_MSG_NODE(f_in_channels == in_channels, "OpDepthwiseConv2d: tensor input channel mismatch %d != %d",
+ f_in_channels, in_channels);
+ ASSERT_MSG_NODE(in_channels * f_multiplier == out_channels,
+ "OpDepthwiseConv2d: tensor output channel mismatch %d != %d", in_channels * f_multiplier,
+ out_channels);
+ ASSERT_MSG_NODE(b_out_channels == out_channels, "OpDepthwiseConv2d: tensor b_out_channels mismatch %d != %d",
+ b_out_channels, out_channels);
+
+ int padding_top = this->attribute->padding()[0];
+ int padding_bottom = this->attribute->padding()[1];
+ int padding_left = this->attribute->padding()[2];
+ int padding_right = this->attribute->padding()[3];
+ int stride_h = this->attribute->stride()[0];
+ int stride_w = this->attribute->stride()[1];
+ int dilation_h = this->attribute->dilation()[0];
+ int dilation_w = this->attribute->dilation()[1];
+
+ DEBUG_INFO(OP,
+ "perform OpDepthwiseConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
+ "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], padding=[%d,%d,%d,%d]",
+ in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_multiplier, out_batch,
+ out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, padding_top,
+ padding_bottom, padding_left, padding_right);
+
+ Eigen::array<std::pair<int32_t, int32_t>, 4> padding;
+ padding[0] = std::make_pair(0, 0);
+ padding[1] = std::make_pair(padding_top, padding_bottom);
+ padding[2] = std::make_pair(padding_left, padding_right);
+ padding[3] = std::make_pair(0, 0);
+
+ TIn input_val = this->input->getTensor();
+ TWeight weight_val = this->weight->getTensor();
+ if (this->qinfo)
+ {
+ input_val = input_val - (InEigenType)this->qinfo->input_zp();
+ weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp();
+ }
+
+ ETensor4<InEigenType> input_padded = input_val.pad(padding);
+
+ // Depthwise convolution does not map cleanly onto a single GEMM, so instead:
+ // 1. use extract_image_patches() to handle stride/dilation/padding
+ // 2. perform direct convolution
+
+ // 1. extract_image_patches() output [N, KH, KW, OH * OW, IC]
+ ETensor5<InEigenType> input_extract_patches = input_padded.extract_image_patches(
+ f_height, f_width, stride_h, stride_w, dilation_h, dilation_w, Eigen::PADDING_VALID);
+
+ Eigen::array<Eigen::Index, 4> reshape_dim;
+ reshape_dim.fill(1);
+ reshape_dim[3] = b_out_channels;
+
+ Eigen::array<Eigen::Index, 4> bcast;
+ bcast[0] = out_batch;
+ bcast[1] = out_height;
+ bcast[2] = out_width;
+ bcast[3] = 1;
+
+ // initialize with bias
+ this->output->getTensor() = this->bias->getTensor().reshape(reshape_dim).broadcast(bcast);
+
+ // 2. direct depthwise convolution
+ for (int ob = 0; ob < out_batch; ob++)
+ {
+ for (int oh = 0; oh < out_height; oh++)
+ {
+ for (int ow = 0; ow < out_width; ow++)
+ {
+ for (int ic = 0; ic < in_channels; ic++)
+ {
+ for (int cm = 0; cm < f_multiplier; cm++)
+ {
+ for (int fh = 0; fh < f_height; fh++)
+ {
+ for (int fw = 0; fw < f_width; fw++)
+ {
+ this->output->getTensor()(ob, oh, ow, ic * f_multiplier + cm) +=
+ ((AccEigenType)input_extract_patches(ob, fh, fw, ow * out_height + oh, ic) *
+ (AccEigenType)weight_val(fh, fw, ic, cm));
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ if (AccDtype == DType_INT48)
+ {
+ this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
+ this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
+ }
+
+ return GraphNode::eval();
+}
+
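+// A scalar sketch of the direct depthwise convolution above, assuming float
+// data, unit stride/dilation, and no padding or zero points. Illustrative
+// only; it shows how each input channel ic fans out to M output channels.
+inline void depthwise_conv2d_sketch(const float* in,  // [N, IH, IW, IC]
+                                    const float* w,   // [KH, KW, IC, M]
+                                    float* out,       // [N, OH, OW, IC * M]
+                                    int N, int IH, int IW, int IC, int M, int KH, int KW)
+{
+    const int OH = IH - KH + 1;
+    const int OW = IW - KW + 1;
+    for (int n = 0; n < N; n++)
+        for (int oh = 0; oh < OH; oh++)
+            for (int ow = 0; ow < OW; ow++)
+                for (int ic = 0; ic < IC; ic++)
+                    for (int cm = 0; cm < M; cm++)
+                    {
+                        float acc = 0.0f;
+                        for (int kh = 0; kh < KH; kh++)
+                            for (int kw = 0; kw < KW; kw++)
+                                acc += in[((n * IH + (oh + kh)) * IW + (ow + kw)) * IC + ic] *
+                                       w[((kh * KW + kw) * IC + ic) * M + cm];
+                        // each (ic, cm) pair maps to output channel ic * M + cm
+                        out[((n * OH + oh) * OW + ow) * (IC * M) + ic * M + cm] = acc;
+                    }
+}
+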
+template <DType InDtype, DType WeightDtype>
+OpFullyConnected<InDtype, WeightDtype>::OpFullyConnected(TosaAttributeBase* attribute_,
+ TosaQuantInfoBase* qinfo_,
+ uint64_t id_)
+ : GraphNode(Op_FULLY_CONNECTED, id_)
+{
+ setRequiredOperands(3, 1);
+ setRequiredRank(2);
+
+ INIT_QINFO(Conv);
+}
+
+template <DType InDtype, DType WeightDtype>
+OpFullyConnected<InDtype, WeightDtype>::~OpFullyConnected()
+{
+ if (qinfo)
+ delete qinfo;
+}
+
+template <DType InDtype, DType WeightDtype>
+int OpFullyConnected<InDtype, WeightDtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
+ bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
+
+ if (input->getShape()[1] != weight->getShape()[1])
+ {
+ printNodeValidationError("OpFullyConnected operator input.shape[1] should match weight.shape[1]");
+ return 1;
+ }
+
+ if (weight->getShape()[0] != bias->getShape()[0])
+ {
+ printNodeValidationError("OpFullyConnected operator bias.shape[0] should match weight.shape[0]");
+ return 1;
+ }
+
+ output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+
+ return 0;
+}
+
+template <DType InDtype, DType WeightDtype>
+int OpFullyConnected<InDtype, WeightDtype>::eval()
+{
+ typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
+ Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };
+
+ Eigen::array<Eigen::Index, 2> weight_shuffle{ 1, 0 };
+
+ Eigen::array<Eigen::Index, 2> bias_reshape;
+ bias_reshape[0] = 1;
+ bias_reshape[1] = this->bias->getShape()[0];
+
+ Eigen::array<Eigen::Index, 2> bias_bcast;
+ bias_bcast[0] = this->input->getShape()[0];
+ bias_bcast[1] = 1;
+
+ TIn input_val = this->input->getTensor();
+ TWeight weight_val = this->weight->getTensor().shuffle(weight_shuffle);
+ if (this->qinfo)
+ {
+ input_val = input_val - (InEigenType)this->qinfo->input_zp();
+ weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp();
+ }
+
+ this->output->getTensor() =
+ input_val.template cast<AccEigenType>().contract(weight_val.template cast<AccEigenType>(), dims) +
+ this->bias->getTensor().reshape(bias_reshape).broadcast(bias_bcast);
+
+ if (AccDtype == DType_INT48)
+ {
+ this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
+ this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
+ }
+ return GraphNode::eval();
+}
+
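+// A scalar sketch of the quantized fully-connected computation above: both
+// zero points are subtracted before the widened accumulation and the bias is
+// added per output channel. Illustrative only; fixed int8/int32 types stand
+// in for the templated Eigen types.
+inline void fully_connected_sketch(const int8_t* in,  // [N, IC]
+                                   const int8_t* w,   // [OC, IC]
+                                   const int32_t* b,  // [OC]
+                                   int32_t* out,      // [N, OC]
+                                   int N, int IC, int OC, int32_t in_zp, int32_t w_zp)
+{
+    for (int n = 0; n < N; n++)
+        for (int oc = 0; oc < OC; oc++)
+        {
+            int32_t acc = b[oc];
+            for (int ic = 0; ic < IC; ic++)
+                acc += ((int32_t)in[n * IC + ic] - in_zp) * ((int32_t)w[oc * IC + ic] - w_zp);
+            out[n * OC + oc] = acc;
+        }
+}
+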
+template <DType Dtype>
+OpMatMul<Dtype>::OpMatMul(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_MATMUL, id_)
+{
+ setRequiredOperands(2, 1);
+ setRequiredRank(2);
+
+ INIT_QINFO(MatMul);
+}
+
+template <DType Dtype>
+OpMatMul<Dtype>::~OpMatMul()
+{
+ if (qinfo)
+ delete qinfo;
+}
+
+template <DType Dtype>
+int OpMatMul<Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
+
+ if (a->getShape()[1] != b->getShape()[0])
+ {
+ printNodeValidationError("OpMatMul operator a.shape[1] should match b.shape[0]");
+ return 1;
+ }
+
+ c = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+
+ return 0;
+}
+
+template <DType Dtype>
+int OpMatMul<Dtype>::eval()
+{
+ typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
+ Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };
+
+ TIn a_val = this->a->getTensor();
+ TIn b_val = this->b->getTensor();
+ if (this->qinfo)
+ {
+ a_val = a_val - (InEigenType)this->qinfo->a_zp();
+ b_val = b_val - (InEigenType)this->qinfo->b_zp();
+ }
+
+ this->c->getTensor() = a_val.template cast<AccEigenType>().contract(b_val.template cast<AccEigenType>(), dims);
+
+ if (AccDtype == DType_INT48)
+ {
+ this->c->getTensor() = this->c->getTensor().cwiseMax((AccEigenType)AccQMin);
+ this->c->getTensor() = this->c->getTensor().cwiseMin((AccEigenType)AccQMax);
+ }
+
+ return GraphNode::eval();
+}
+
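+// A scalar sketch of the op above: c[i][j] = sum_k (a[i][k] - a_zp) * (b[k][j] - b_zp),
+// accumulated at the wider precision. Illustrative only; fixed int8/int32
+// types stand in for the templated Eigen types.
+inline void matmul_sketch(const int8_t* a, const int8_t* b, int32_t* c,
+                          int M, int K, int N, int32_t a_zp, int32_t b_zp)
+{
+    for (int i = 0; i < M; i++)
+        for (int j = 0; j < N; j++)
+        {
+            int32_t acc = 0;
+            for (int k = 0; k < K; k++)
+                acc += ((int32_t)a[i * K + k] - a_zp) * ((int32_t)b[k * N + j] - b_zp);
+            c[i * N + j] = acc;
+        }
+}
+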
+template <DType Dtype>
+OpMaxPool2d<Dtype>::OpMaxPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_MAX_POOL2D, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(4);
+
+ INIT_ATTRIBUTE(Pool2d);
+}
+
+template <DType Dtype>
+OpMaxPool2d<Dtype>::~OpMaxPool2d()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <DType Dtype>
+int OpMaxPool2d<Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ if (inputs[0]->matchType(*outputs[0]))
+ {
+ printNodeValidationError("OpMaxPool2d: input and output tensor type mismatch");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ if (!in->hasFormat(Format_NHWC))
+ {
+ printNodeValidationError("OpMaxPool2d: unsupported tensor format");
+ return 1;
+ }
+
+ if (attribute->padding().size() != 4)
+ {
+ printNodeValidationError("OpMaxPool2d: illegal size for attribute padding");
+ return 1;
+ }
+
+ if (attribute->kernel().size() != 2)
+ {
+ printNodeValidationError("OpMaxPool2d: illegal size for attribute kernel");
+ return 1;
+ }
+
+ if (attribute->stride().size() != 2)
+ {
+ printNodeValidationError("OpMaxPool2d: illegal size for attribute stride");
+ return 1;
+ }
+
+ return 0;
+}
+
+template <DType Dtype>
+int OpMaxPool2d<Dtype>::eval()
+{
+ int in_batch = this->in->getShape()[0];
+ int in_height = this->in->getShape()[1];
+ int in_width = this->in->getShape()[2];
+ int in_channels = this->in->getShape()[3];
+
+ int out_batch = this->out->getShape()[0];
+ int out_height = this->out->getShape()[1];
+ int out_width = this->out->getShape()[2];
+ int out_channels = this->out->getShape()[3];
+
+ ASSERT_MSG_NODE(in_batch == out_batch, "OpMaxPool2d: tensor batch mismatch %d != %d", in_batch, out_batch);
+
+ int padding_top = this->attribute->padding()[0];
+ int padding_bottom = this->attribute->padding()[1];
+ int padding_left = this->attribute->padding()[2];
+ int padding_right = this->attribute->padding()[3];
+ int kernel_h = this->attribute->kernel()[0];
+ int kernel_w = this->attribute->kernel()[1];
+ int stride_h = this->attribute->stride()[0];
+ int stride_w = this->attribute->stride()[1];
+
+ DEBUG_INFO(OP,
+ "perform MaxPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], "
+ "stride=[%d,%d], padding=[%d,%d,%d,%d]",
+ in_batch, in_height, in_width, in_channels, out_batch, out_height, out_width, out_channels, kernel_h,
+ kernel_w, stride_h, stride_w, padding_top, padding_bottom, padding_left, padding_right);
+
+ Eigen::array<Eigen::Index, 2> im2col_input_dims;
+ im2col_input_dims[0] = kernel_h * kernel_w;
+ im2col_input_dims[1] = out_batch * out_height * out_width * out_channels;
+
+ Eigen::array<Eigen::Index, 4> col2im_output_dims;
+ col2im_output_dims[0] = out_batch;
+ col2im_output_dims[1] = out_height;
+ col2im_output_dims[2] = out_width;
+ col2im_output_dims[3] = out_channels;
+
+ Eigen::array<std::pair<int32_t, int32_t>, 4> padding;
+ padding[0] = std::make_pair(0, 0);
+ padding[1] = std::make_pair(padding_top, padding_bottom);
+ padding[2] = std::make_pair(padding_left, padding_right);
+ padding[3] = std::make_pair(0, 0);
+
+ ETensor4<InEigenType> input_padded = this->in->getTensor().pad(padding, std::numeric_limits<InEigenType>::lowest());
+
+ // extract_image_patches() output [N, KH, KW, H * W, C]
+ // transpose to [KH, KW, N, H * W, C]
+ // reshape to [KH * KW, N * H * W * C]
+ //
+ // Set the padding value to be the most negative value that can be
+ // represented by the datatype to ensure that any padding values will be equal
+ // to or smaller than the actual maximum in the KH x KW patch.
+ ETensor2<InEigenType> input_extract_patches =
+ input_padded
+ .extract_image_patches(kernel_h, kernel_w, stride_h, stride_w, 1, 1, Eigen::PADDING_VALID,
+ std::numeric_limits<InEigenType>::lowest())
+ .shuffle(Eigen::array<Eigen::Index, 5>{ 1, 2, 0, 3, 4 })
+ .reshape(im2col_input_dims);
+
+ // Get the maximum of the KH x KW patches along axis 0
+ Eigen::Tensor<DenseIndex, 1> tensor_argmax = input_extract_patches.argmax(0);
+
+ // 1D result with [N * H * W * C]
+ ETensor1<OutEigenType> out_1d(this->out->getElementCount());
+
+ // indexing the patches with the argmax indices gives the pooled result
+ for (size_t i = 0; i < this->out->getElementCount(); i++)
+ {
+ out_1d(i) = (OutEigenType)input_extract_patches(tensor_argmax(i), i);
+ }
+
+ // reshape result to [N, H, W, C]
+ this->out->getTensor() = out_1d.reshape(col2im_output_dims);
+
+ return GraphNode::eval();
+}
+
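+// A scalar sketch of one pooling window from the reduction above, assuming
+// float data and unit dilation. Starting the running maximum at lowest()
+// shows why padding with lowest() is safe: a padded position can never win.
+// Illustrative only.
+inline float max_pool_window_sketch(const float* padded,  // padded input, one channel
+                                    int padded_width, int oy, int ox,
+                                    int kernel_h, int kernel_w, int stride_h, int stride_w)
+{
+    float max_val = std::numeric_limits<float>::lowest();
+    for (int kh = 0; kh < kernel_h; kh++)
+        for (int kw = 0; kw < kernel_w; kw++)
+        {
+            const float v = padded[(oy * stride_h + kh) * padded_width + (ox * stride_w + kw)];
+            if (v > max_val)
+                max_val = v;
+        }
+    return max_val;
+}
+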
+template <DType InDtype, DType WeightDtype>
+OpTransposeConv2d<InDtype, WeightDtype>::OpTransposeConv2d(TosaAttributeBase* attribute_,
+ TosaQuantInfoBase* qinfo_,
+ uint64_t id_)
+ : GraphNode(Op_TRANSPOSE_CONV2D, id_)
+{
+ setRequiredOperands(3, 1);
+ setRequiredRank(4);
+
+ INIT_ATTRIBUTE(TransposeConv2d);
+ INIT_QINFO(Conv);
+}
+
+template <DType InDtype, DType WeightDtype>
+OpTransposeConv2d<InDtype, WeightDtype>::~OpTransposeConv2d()
+{
+ if (attribute)
+ delete attribute;
+ if (qinfo)
+ delete qinfo;
+}
+
+template <DType InDtype, DType WeightDtype>
+int OpTransposeConv2d<InDtype, WeightDtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ if (inputs[1]->getIsConst() == 0)
+ {
+ printNodeValidationError("OpTransposeConv2d: weight tensor is not const typed");
+ return 1;
+ }
+
+ input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
+ bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
+ output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+
+ if (!input->hasFormat(Format_NHWC))
+ {
+ printNodeValidationError("OpTransposeConv2d: unsupported input tensor format");
+ return 1;
+ }
+
+ if (!weight->hasFormat(Format_OHWI))
+ {
+ printNodeValidationError("OpTransposeConv2d: unsupported weight tensor format");
+ return 1;
+ }
+
+ if (attribute->outpad().size() != 2)
+ {
+ printNodeValidationError("OpTransposeConv2d: illegal size for attribute outpad");
+ return 1;
+ }
+
+ if (attribute->stride().size() != 2)
+ {
+ printNodeValidationError("OpTransposeConv2d: illegal size for attribute stride");
+ return 1;
+ }
+
+ if (attribute->dilation().size() != 2)
+ {
+ printNodeValidationError("OpTransposeConv2d: illegal size for attribute dilation");
+ return 1;
+ }
+
+ if (attribute->output_shape().size() != 4)
+ {
+ printNodeValidationError("OpTransposeConv2d: illegal size for attribute output_shape");
+ return 1;
+ }
+
+ for (int d = 0; d < 4; d++)
+ {
+ if (attribute->output_shape()[d] != this->output->getShape()[d])
+ {
+ printNodeValidationError("OpTransposeConv2d: illegal size for attribute output_shape");
+ return 1;
+ }
+ }
+
+ return 0;
+}
+
+template <DType InDtype, DType WeightDtype>
+int OpTransposeConv2d<InDtype, WeightDtype>::eval()
+{
+ int in_batch = this->input->getShape()[0];
+ int in_height = this->input->getShape()[1];
+ int in_width = this->input->getShape()[2];
+ int in_channels = this->input->getShape()[3];
+
+ int f_out_channels = this->weight->getShape()[0];
+ int f_height = this->weight->getShape()[1];
+ int f_width = this->weight->getShape()[2];
+ int f_in_channels = this->weight->getShape()[3];
+
+ int b_out_channels = this->bias->getShape()[0];
+
+ int out_batch = this->output->getShape()[0];
+ int out_height = this->output->getShape()[1];
+ int out_width = this->output->getShape()[2];
+ int out_channels = this->output->getShape()[3];
+
+ int padding_top = this->attribute->outpad()[0];
+ int padding_left = this->attribute->outpad()[1];
+ int stride_h = this->attribute->stride()[0];
+ int stride_w = this->attribute->stride()[1];
+ int dilation_h = this->attribute->dilation()[0];
+ int dilation_w = this->attribute->dilation()[1];
+
+ ASSERT_MSG_NODE(in_batch == out_batch, "OpTransposeConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
+ ASSERT_MSG_NODE(f_in_channels == in_channels, "OpTransposeConv2d: tensor input channel mismatch %d != %d",
+ f_in_channels, in_channels);
+ ASSERT_MSG_NODE(f_out_channels == out_channels, "OpTransposeConv2d: tensor output channel mismatch %d != %d",
+ f_out_channels, out_channels);
+ ASSERT_MSG_NODE(b_out_channels == out_channels, "OpTransposeConv2d: tensor b_out_channels mismatch %d != %d",
+ b_out_channels, out_channels);
+
+ DEBUG_INFO(OP,
+ "perform OpTransposeConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
+ "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], padding=[%d,%d]",
+ in_batch, in_height, in_width, in_channels, f_out_channels, f_height, f_width, f_in_channels, out_batch,
+ out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, padding_top,
+ padding_left);
+
+ TIn input_val = this->input->getTensor();
+ TWeight weight_val = this->weight->getTensor();
+ if (this->qinfo)
+ {
+ input_val = input_val - (InEigenType)this->qinfo->input_zp();
+ weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp();
+ }
+
+ Eigen::array<Eigen::Index, 4> reshape_dim;
+ reshape_dim.fill(1);
+ reshape_dim[3] = b_out_channels;
+
+ Eigen::array<Eigen::Index, 4> bcast;
+ bcast[0] = out_batch;
+ bcast[1] = out_height;
+ bcast[2] = out_width;
+ bcast[3] = 1;
+
+ // initialize with bias
+ this->output->getTensor() = this->bias->getTensor().reshape(reshape_dim).broadcast(bcast);
+
+ int out_x_origin, out_y_origin;
+ int out_x, out_y;
+
+ // reference implementation from: tensorflow/tensorflow/lite/kernels/internal/reference/reference_ops.h
+ for (int ob = 0; ob < out_batch; ob++)
+ {
+ for (int ih = 0; ih < in_height; ih++)
+ {
+ for (int iw = 0; iw < in_width; iw++)
+ {
+ out_x_origin = iw * stride_w - padding_left;
+ out_y_origin = ih * stride_h - padding_top;
+ for (int ic = 0; ic < in_channels; ic++)
+ {
+ for (int fh = 0; fh < f_height; fh++)
+ {
+ for (int fw = 0; fw < f_width; fw++)
+ {
+ out_x = out_x_origin + fw * dilation_w;
+ out_y = out_y_origin + fh * dilation_h;
+ for (int oc = 0; oc < out_channels; oc++)
+ {
+ if ((out_x >= 0 && out_x < out_width) && (out_y >= 0 && out_y < out_height))
+ {
+ this->output->getTensor()(ob, out_y, out_x, oc) +=
+ ((AccEigenType)input_val(ob, ih, iw, ic) *
+ (AccEigenType)weight_val(oc, fh, fw, ic));
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ if (AccDtype == DType_INT48)
+ {
+ this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
+ this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
+ }
+
+ return GraphNode::eval();
+}
+
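+// A single-channel scalar sketch of the scatter-style loop above, assuming
+// float data, unit dilation, and a bias-initialized output. Each input pixel
+// is multiplied against the whole kernel and accumulated at
+// out_x = iw * stride_w - padding_left + fw (likewise for y), skipping
+// out-of-range targets. Illustrative only.
+inline void transpose_conv2d_sketch(const float* in,  // [IH, IW]
+                                    const float* w,   // [KH, KW]
+                                    float* out,       // [OH, OW], bias-initialized
+                                    int IH, int IW, int OH, int OW, int KH, int KW,
+                                    int stride_h, int stride_w, int pad_top, int pad_left)
+{
+    for (int ih = 0; ih < IH; ih++)
+        for (int iw = 0; iw < IW; iw++)
+            for (int fh = 0; fh < KH; fh++)
+                for (int fw = 0; fw < KW; fw++)
+                {
+                    const int oy = ih * stride_h - pad_top + fh;
+                    const int ox = iw * stride_w - pad_left + fw;
+                    if (oy >= 0 && oy < OH && ox >= 0 && ox < OW)
+                        out[oy * OW + ox] += in[ih * IW + iw] * w[fh * KW + fw];
+                }
+}
+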
+// template explicit instantiation
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FLOAT);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, AINT8);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
+
+DEF_INSTANTIATE_ONE_TYPE(OpAvgPool2d, FLOAT)
+DEF_INSTANTIATE_ONE_TYPE(OpAvgPool2d, AINT8)
+DEF_INSTANTIATE_ONE_TYPE(OpAvgPool2d, INT16)
+
+DEF_INSTANTIATE_TWO_TYPE(OpConv2d, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE(OpConv2d, AINT8, INT4);
+DEF_INSTANTIATE_TWO_TYPE(OpConv2d, AINT8, INT8);
+DEF_INSTANTIATE_TWO_TYPE(OpConv2d, AINT8, AINT8);
+DEF_INSTANTIATE_TWO_TYPE(OpConv2d, INT16, INT8);
+
+DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, AINT8, INT4);
+DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, AINT8, INT8);
+DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, AINT8, AINT8);
+DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, INT16, INT8);
+
+DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, AINT8, INT4);
+DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, AINT8, INT8);
+DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, AINT8, AINT8);
+DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, INT16, INT8);
+
+DEF_INSTANTIATE_ONE_TYPE(OpMatMul, AINT8);
+DEF_INSTANTIATE_ONE_TYPE(OpMatMul, INT16);
+DEF_INSTANTIATE_ONE_TYPE(OpMatMul, FLOAT);
+
+DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FLOAT);
+DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, AINT8);
+DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16);
+
+DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, AINT8, INT4);
+DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, AINT8, INT8);
+DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, AINT8, AINT8);
+DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, INT16, INT8);
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h
new file mode 100644
index 0000000..26ce84b
--- /dev/null
+++ b/reference_model/src/ops/tensor_ops.h
@@ -0,0 +1,253 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_TENSOR_OPS_H
+#define OPS_TENSOR_OPS_H
+
+#include "graph_node.h"
+#include "quant_util.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+template <int Rank, DType Dtype>
+class OpArgMax : public GraphNode
+{
+public:
+ OpArgMax(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpArgMax();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<DType_INT32>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank - 1>;
+
+protected:
+ TosaAxisAttribute* attribute;
+ TosaReference::TensorTemplate<TIn>* input;
+ TosaReference::TensorTemplate<TOut>* output;
+};
+
+template <DType Dtype>
+class OpAvgPool2d : public GraphNode
+{
+public:
+ OpAvgPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpAvgPool2d();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ static constexpr DType AccDtype = GetAccDType<Dtype, Dtype>::value;
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using AccEigenType = typename GetEigenType<AccDtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, 4>;
+ using TOut = Eigen::Tensor<OutEigenType, 4>;
+
+ static constexpr int64_t QMin = GetQMin<Dtype>::value;
+ static constexpr int64_t QMax = GetQMax<Dtype>::value;
+
+protected:
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TOut>* out;
+ tosa::TosaPool2dAttribute* attribute;
+ tosa::TosaUnaryQuantInfo* qinfo;
+
+protected:
+ // returns a 1D [N] tensor describing how many valid input elements each output element covers
+ ETensor1<int32_t> calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride);
+};
+
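+// A scalar sketch of the idea behind OpAvgPool2d::calculate_div_map_1d()
+// above: for each output position, count the kernel taps that land inside the
+// un-padded input, so the averaging divisor excludes padding. Illustrative
+// only; the parameter names are assumptions, not the member's signature.
+inline int32_t avg_pool_valid_count_sketch(int out_pos, int in_size, int kernel_size,
+                                           int stride, int pad_before)
+{
+    int32_t count = 0;
+    for (int k = 0; k < kernel_size; k++)
+    {
+        const int in_pos = out_pos * stride - pad_before + k;
+        if (in_pos >= 0 && in_pos < in_size)
+            count++;
+    }
+    return count;
+}
+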
+template <DType InDtype, DType WeightDtype>
+class OpConv2d : public GraphNode
+{
+public:
+ OpConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpConv2d();
+
+ virtual int checkTensorAttributes() final;
+ virtual int eval() final;
+
+ static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
+
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using WeightEigenType = typename GetEigenType<WeightDtype>::type;
+ using AccEigenType = typename GetEigenType<AccDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, 4>;
+ using TWeight = Eigen::Tensor<WeightEigenType, 4>;
+ using TBias = Eigen::Tensor<AccEigenType, 1>;
+ using TAcc = Eigen::Tensor<AccEigenType, 4>;
+
+ static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
+ static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
+
+protected:
+ TosaReference::TensorTemplate<TIn>* input;
+ TosaReference::TensorTemplate<TWeight>* weight;
+ TosaReference::TensorTemplate<TBias>* bias;
+ TosaReference::TensorTemplate<TAcc>* output;
+ tosa::TosaConv2dAttribute* attribute;
+ tosa::TosaConvQuantInfo* qinfo;
+};
+
+template <DType InDtype, DType WeightDtype>
+class OpDepthwiseConv2d : public GraphNode
+{
+public:
+ OpDepthwiseConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpDepthwiseConv2d();
+
+ virtual int checkTensorAttributes() final;
+ virtual int eval() final;
+
+ static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
+
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using WeightEigenType = typename GetEigenType<WeightDtype>::type;
+ using AccEigenType = typename GetEigenType<AccDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, 4>;
+ using TWeight = Eigen::Tensor<WeightEigenType, 4>;
+ using TBias = Eigen::Tensor<AccEigenType, 1>;
+ using TAcc = Eigen::Tensor<AccEigenType, 4>;
+
+ static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
+ static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
+
+protected:
+ TosaReference::TensorTemplate<TIn>* input;
+ TosaReference::TensorTemplate<TWeight>* weight;
+ TosaReference::TensorTemplate<TBias>* bias;
+ TosaReference::TensorTemplate<TAcc>* output;
+ tosa::TosaConv2dAttribute* attribute;
+ tosa::TosaConvQuantInfo* qinfo;
+};
+
+template <DType InDtype, DType WeightDtype>
+class OpFullyConnected : public GraphNode
+{
+public:
+ OpFullyConnected(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpFullyConnected();
+
+ virtual int checkTensorAttributes() final;
+ virtual int eval() final;
+
+ static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using WeightEigenType = typename GetEigenType<WeightDtype>::type;
+ using AccEigenType = typename GetEigenType<AccDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, 2>;
+ using TWeight = Eigen::Tensor<WeightEigenType, 2>;
+ using TBias = Eigen::Tensor<AccEigenType, 1>;
+ using TAcc = Eigen::Tensor<AccEigenType, 2>;
+
+ static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
+ static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
+
+protected:
+ TosaReference::TensorTemplate<TIn>* input;
+ TosaReference::TensorTemplate<TWeight>* weight;
+ TosaReference::TensorTemplate<TBias>* bias;
+ TosaReference::TensorTemplate<TAcc>* output;
+ tosa::TosaConvQuantInfo* qinfo;
+};
+
+template <DType Dtype>
+class OpMatMul : public GraphNode
+{
+public:
+ OpMatMul(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpMatMul();
+
+ virtual int checkTensorAttributes() final;
+ virtual int eval() final;
+
+ static constexpr DType AccDtype = GetAccDType<Dtype, Dtype>::value;
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using AccEigenType = typename GetEigenType<AccDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, 2>;
+ using TAcc = Eigen::Tensor<AccEigenType, 2>;
+ static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
+ static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
+
+protected:
+ TosaReference::TensorTemplate<TIn>* a;
+ TosaReference::TensorTemplate<TIn>* b;
+ TosaReference::TensorTemplate<TAcc>* c;
+ tosa::TosaMatMulQuantInfo* qinfo;
+};
+
+template <DType Dtype>
+class OpMaxPool2d : public GraphNode
+{
+public:
+ OpMaxPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpMaxPool2d();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, 4>;
+ using TOut = Eigen::Tensor<OutEigenType, 4>;
+
+protected:
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TOut>* out;
+ tosa::TosaPool2dAttribute* attribute;
+};
+
+template <DType InDtype, DType WeightDtype>
+class OpTransposeConv2d : public GraphNode
+{
+public:
+ OpTransposeConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpTransposeConv2d();
+
+ virtual int checkTensorAttributes() final;
+ virtual int eval() final;
+
+ static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
+
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using WeightEigenType = typename GetEigenType<WeightDtype>::type;
+ using AccEigenType = typename GetEigenType<AccDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, 4>;
+ using TWeight = Eigen::Tensor<WeightEigenType, 4>;
+ using TBias = Eigen::Tensor<AccEigenType, 1>;
+ using TAcc = Eigen::Tensor<AccEigenType, 4>;
+
+ static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
+ static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
+
+protected:
+ TosaReference::TensorTemplate<TIn>* input;
+ TosaReference::TensorTemplate<TWeight>* weight;
+ TosaReference::TensorTemplate<TBias>* bias;
+ TosaReference::TensorTemplate<TAcc>* output;
+ tosa::TosaTransposeConv2dAttribute* attribute;
+ tosa::TosaConvQuantInfo* qinfo;
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
new file mode 100644
index 0000000..61a19f4
--- /dev/null
+++ b/reference_model/src/ops/type_conversion.cc
@@ -0,0 +1,299 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "type_conversion.h"
+#include "quant_util.h"
+#include "template_types.h"
+#include <cmath>
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <int Rank, DType InDtype, DType OutDtype>
+OpRescale<Rank, InDtype, OutDtype>::OpRescale(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_RESCALE, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(0, 6);
+ INIT_ATTRIBUTE(Rescale);
+}
+
+template <int Rank, DType InDtype, DType OutDtype>
+OpRescale<Rank, InDtype, OutDtype>::~OpRescale()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <int Rank, DType InDtype, DType OutDtype>
+int OpRescale<Rank, InDtype, OutDtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // output and input must be the same rank and size
+ if (inputs[0]->matchRankSize(*outputs[0]))
+ {
+ printNodeValidationError("OpRescale: input and output rank/size must match");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ ASSERT_MEM(in && out);
+
+ return 0;
+}
+
+template <int Rank, DType InDtype, DType OutDtype>
+int OpRescale<Rank, InDtype, OutDtype>::eval()
+{
+ int32_t input_zp = attribute->input_zp();
+ int32_t output_zp = attribute->output_zp();
+ std::vector<int32_t> multiplier = attribute->multiplier();
+ std::vector<int32_t> shift = attribute->shift();
+ //bool scale32 = attribute->scale32();
+ bool double_round = attribute->double_round();
+ bool per_channel = attribute->per_channel();
+
+ if (TosaReference::TypeChecker::is_symmetric(InDtype))
+ {
+ if (input_zp != 0)
+ {
+ FATAL_ERROR_NODE("input tensor is symmetric type %s but zeropoint is %d instead of 0",
+ EnumNamesDType()[InDtype], input_zp);
+ }
+ }
+
+ if (TosaReference::TypeChecker::is_symmetric(OutDtype))
+ {
+ if (output_zp != 0)
+ {
+ FATAL_ERROR_NODE("output tensor is symmetric type %s but zeropoint is %d instead of 0",
+ EnumNamesDType()[OutDtype], output_zp);
+ }
+ }
+
+ // reshape [d0, d1, ..., dn] into [d0 * d1 * ... * d(n-1), dn]
+ Eigen::array<Eigen::Index, 2> shape_2d;
+ shape_2d[0] = 1;
+ if (Rank > 0)
+ {
+ for (int i = 0; i < Rank - 1; i++)
+ {
+ shape_2d[0] *= this->in->getShape()[i];
+ }
+ shape_2d[1] = this->in->getShape()[Rank - 1];
+ }
+ else
+ {
+ shape_2d[1] = 1;
+ }
+ ETensor2<InEigenType> input_reshaped = this->in->getTensor().reshape(shape_2d);
+
+ ETensor2<OutEigenType> output_2d(shape_2d);
+
+ // TODO: pass scale32 in when 16-bit mode implemented
+ if (per_channel)
+ {
+ ETensor2<InEigenType> curr_channel_slice_prescaled;
+ ETensor2<OutEigenType> curr_channel_slice_postscaled;
+ int32_t channel_multiplier, channel_shift;
+ Eigen::array<Eigen::Index, 2> begin, size;
+ size = Eigen::array<Eigen::Index, 2>({ shape_2d[0], 1 });
+ for (int32_t i = 0; i < shape_2d[1]; i++)
+ {
+ begin = Eigen::array<Eigen::Index, 2>({ 0, i });
+ curr_channel_slice_prescaled = input_reshaped.slice(begin, size);
+ channel_multiplier = multiplier[i];
+ channel_shift = shift[i];
+ curr_channel_slice_postscaled =
+ curr_channel_slice_prescaled.unaryExpr([input_zp, output_zp, channel_multiplier, channel_shift,
+ double_round](InEigenType in_val) -> OutEigenType {
+ InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
+ int32_t scaled = TosaReference::QuantUtil<InDtype>::apply_scale(
+ input_zp_shifted, channel_multiplier, channel_shift, double_round);
+ OutEigenType out_val = (OutEigenType)(scaled + output_zp);
+ out_val = std::max<OutEigenType>(out_val, QMin);
+ out_val = std::min<OutEigenType>(out_val, QMax);
+ return out_val;
+ });
+
+ for (int32_t j = 0; j < shape_2d[0]; j++)
+ {
+ output_2d(j, i) = curr_channel_slice_postscaled(j, 0);
+ }
+ }
+ }
+ else
+ {
+ int32_t tensor_multiplier = multiplier[0];
+ int32_t tensor_shift = shift[0];
+ output_2d = input_reshaped.unaryExpr(
+ [input_zp, output_zp, tensor_multiplier, tensor_shift, double_round](InEigenType in_val) -> OutEigenType {
+ InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
+ int32_t scaled = TosaReference::QuantUtil<InDtype>::apply_scale(input_zp_shifted, tensor_multiplier,
+ tensor_shift, double_round);
+ OutEigenType out_val = (OutEigenType)(scaled + output_zp);
+ out_val = std::max<OutEigenType>(out_val, QMin);
+ out_val = std::min<OutEigenType>(out_val, QMax);
+ return out_val;
+ });
+ }
+
+ // reshape [d0 * d1 * ... * d(n-1), dn] back to [d0, d1, ..., dn]
+ Eigen::array<Eigen::Index, Rank> output_shape;
+ for (int i = 0; i < Rank; i++)
+ {
+ output_shape[i] = this->out->getShape()[i];
+ }
+ this->out->getTensor() = output_2d.reshape(output_shape);
+
+ return GraphNode::eval();
+}
+
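+// A scalar sketch of the per-element rescale above, with apply_scale()
+// approximated by a single-rounded 64-bit multiply and right shift (the real
+// routine lives in quant_util.h and also implements double rounding).
+// Illustrative only; assumes shift >= 1.
+inline int32_t rescale_sketch(int32_t in_val, int32_t input_zp, int32_t output_zp,
+                              int32_t multiplier, int32_t shift, int32_t qmin, int32_t qmax)
+{
+    const int64_t shifted = (int64_t)in_val - input_zp;
+    const int64_t round   = (int64_t)1 << (shift - 1);
+    int32_t scaled        = (int32_t)((shifted * multiplier + round) >> shift);
+    int32_t out_val       = scaled + output_zp;
+    out_val               = (out_val < qmin) ? qmin : out_val;
+    out_val               = (out_val > qmax) ? qmax : out_val;
+    return out_val;
+}
+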
+template <int Rank, DType InDtype, DType OutDtype>
+OpCast<Rank, InDtype, OutDtype>::OpCast(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_CAST, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(0, 6);
+}
+
+template <int Rank, DType InDtype, DType OutDtype>
+OpCast<Rank, InDtype, OutDtype>::~OpCast()
+{}
+
+template <int Rank, DType InDtype, DType OutDtype>
+int OpCast<Rank, InDtype, OutDtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // output and input must be the same rank and size
+ if (inputs[0]->matchRankSize(*outputs[0]))
+ {
+ printNodeValidationError("OpCast: input and output rank/size must match");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ ASSERT_MEM(in && out);
+
+ return 0;
+}
+
+template <int Rank, DType InDtype, DType OutDtype>
+int OpCast<Rank, InDtype, OutDtype>::eval()
+{
+ this->out->getTensor() = this->in->getTensor().unaryExpr(cast_helper.get_fcn());
+
+ return GraphNode::eval();
+}
+
+template <DType InDtype, DType OutDtype>
+CastHelper<InDtype, OutDtype>::CastHelper()
+{
+ fcn = [](InEigenType in) -> OutEigenType {
+ OutEigenType out = (OutEigenType)in; // implicit sign_extend() if sizeof(out_t) >= sizeof(in_t)
+ int64_t mask = (1L << OutBits) - 1;
+ out = out & mask;
+ return out;
+ };
+}
+
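+// A worked sketch of the masking above: only the low OutBits bits of the
+// converted value are kept. Illustrative only.
+inline int64_t cast_mask_sketch(int64_t in, int out_bits)
+{
+    const int64_t mask = ((int64_t)1 << out_bits) - 1;
+    return in & mask;    // e.g. in = 300 (0x12C) with out_bits = 8 -> 0x2C == 44
+}
+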
+template <DType InDtype>
+CastHelper<InDtype, DType_BOOL>::CastHelper()
+{
+ fcn = [](InEigenType in) -> bool { return in != 0; };
+}
+
+template <DType OutDtype>
+CastHelper<DType_BOOL, OutDtype>::CastHelper()
+{
+ fcn = [](bool in) -> OutEigenType {
+ OutEigenType out = in ? (OutEigenType)1 : (OutEigenType)0;
+ return out;
+ };
+}
+
+template <DType InDtype>
+CastHelper<InDtype, DType_FLOAT>::CastHelper()
+{
+ fcn = [](InEigenType in) -> float {
+ float out = (OutEigenType)in; // default cast to float is round_to_nearest_float()
+ return out;
+ };
+}
+
+template <DType OutDtype>
+CastHelper<DType_FLOAT, OutDtype>::CastHelper()
+{
+ fcn = [](float in) -> OutEigenType {
+ OutEigenType out = std::round(in);
+ out = std::max<OutEigenType>(out, OutMin);
+ out = std::min<OutEigenType>(out, OutMax);
+ return out;
+ };
+}
+
+// template explicit instantiation
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, UINT8);
diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h
new file mode 100644
index 0000000..6ec4d6d
--- /dev/null
+++ b/reference_model/src/ops/type_conversion.h
@@ -0,0 +1,162 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_TYPE_CONVERSION_H
+#define OPS_TYPE_CONVERSION_H
+
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+template <int Rank, DType InDtype, DType OutDtype>
+class OpRescale : public GraphNode
+{
+public:
+ OpRescale(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpRescale();
+
+ virtual int checkTensorAttributes() final;
+ virtual int eval() final;
+
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+ static constexpr int32_t QMin = GetQMin<OutDtype>::value;
+ static constexpr int32_t QMax = GetQMax<OutDtype>::value;
+
+protected:
+ TosaRescaleAttribute* attribute;
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TOut>* out;
+};
+
+template <DType InDtype, DType OutDtype>
+class CastHelper
+{
+public:
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ static constexpr int32_t OutBits = GetNumBits<OutDtype>::value;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <DType InDtype>
+class CastHelper<InDtype, DType_BOOL>
+{
+public:
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using OutEigenType = typename GetEigenType<DType_BOOL>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <DType OutDtype>
+class CastHelper<DType_BOOL, OutDtype>
+{
+public:
+ using InEigenType = typename GetEigenType<DType_BOOL>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ static constexpr int32_t OutMin = GetQMin<OutDtype>::value;
+ static constexpr int32_t OutMax = GetQMax<OutDtype>::value;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <DType InDtype>
+class CastHelper<InDtype, DType_FLOAT>
+{
+public:
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using OutEigenType = typename GetEigenType<DType_FLOAT>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <DType OutDtype>
+class CastHelper<DType_FLOAT, OutDtype>
+{
+public:
+ using InEigenType = typename GetEigenType<DType_FLOAT>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ static constexpr int32_t OutMin = GetQMin<OutDtype>::value;
+ static constexpr int32_t OutMax = GetQMax<OutDtype>::value;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <int Rank, DType InDtype, DType OutDtype>
+class OpCast : public GraphNode
+{
+public:
+ OpCast(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpCast();
+
+ virtual int checkTensorAttributes() final;
+ virtual int eval() final;
+
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ CastHelper<InDtype, OutDtype> cast_helper;
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TOut>* out;
+};
+
+}; // namespace TosaReference
+
+#endif