From e5e2676409a936431f87d31fb74d825257b20804 Mon Sep 17 00:00:00 2001 From: Eric Kunze Date: Tue, 13 Oct 2020 16:11:07 -0700 Subject: Initial checkin of TOSA reference_model and tests Change-Id: I2f8e7fa63e2ae40203e57d2cc8814bde3b312cb6 Signed-off-by: Eric Kunze --- reference_model/src/ops/activation_funcs.cc | 118 +++ reference_model/src/ops/activation_funcs.h | 101 +++ reference_model/src/ops/comparison.cc | 81 ++ reference_model/src/ops/comparison.h | 71 ++ reference_model/src/ops/control_flow.cc | 353 ++++++++ reference_model/src/ops/control_flow.h | 72 ++ reference_model/src/ops/custom.cc | 40 + reference_model/src/ops/custom.h | 38 + reference_model/src/ops/data_layout.cc | 644 ++++++++++++++ reference_model/src/ops/data_layout.h | 216 +++++ reference_model/src/ops/data_nodes.cc | 172 ++++ reference_model/src/ops/data_nodes.h | 86 ++ reference_model/src/ops/ewise_binary.cc | 586 +++++++++++++ reference_model/src/ops/ewise_binary.h | 195 +++++ reference_model/src/ops/ewise_ternary.cc | 115 +++ reference_model/src/ops/ewise_ternary.h | 83 ++ reference_model/src/ops/ewise_unary.cc | 302 +++++++ reference_model/src/ops/ewise_unary.h | 102 +++ reference_model/src/ops/image.cc | 169 ++++ reference_model/src/ops/image.h | 53 ++ reference_model/src/ops/op_factory.cc | 432 ++++++++++ reference_model/src/ops/op_factory.h | 294 +++++++ reference_model/src/ops/reduction.cc | 139 +++ reference_model/src/ops/reduction.h | 109 +++ reference_model/src/ops/scatter_gather.cc | 120 +++ reference_model/src/ops/scatter_gather.h | 54 ++ reference_model/src/ops/template_types.h | 277 ++++++ reference_model/src/ops/tensor_ops.cc | 1229 +++++++++++++++++++++++++++ reference_model/src/ops/tensor_ops.h | 253 ++++++ reference_model/src/ops/type_conversion.cc | 299 +++++++ reference_model/src/ops/type_conversion.h | 162 ++++ 31 files changed, 6965 insertions(+) create mode 100644 reference_model/src/ops/activation_funcs.cc create mode 100644 reference_model/src/ops/activation_funcs.h create mode 100644 reference_model/src/ops/comparison.cc create mode 100644 reference_model/src/ops/comparison.h create mode 100644 reference_model/src/ops/control_flow.cc create mode 100644 reference_model/src/ops/control_flow.h create mode 100644 reference_model/src/ops/custom.cc create mode 100644 reference_model/src/ops/custom.h create mode 100644 reference_model/src/ops/data_layout.cc create mode 100644 reference_model/src/ops/data_layout.h create mode 100644 reference_model/src/ops/data_nodes.cc create mode 100644 reference_model/src/ops/data_nodes.h create mode 100644 reference_model/src/ops/ewise_binary.cc create mode 100644 reference_model/src/ops/ewise_binary.h create mode 100644 reference_model/src/ops/ewise_ternary.cc create mode 100644 reference_model/src/ops/ewise_ternary.h create mode 100644 reference_model/src/ops/ewise_unary.cc create mode 100644 reference_model/src/ops/ewise_unary.h create mode 100644 reference_model/src/ops/image.cc create mode 100644 reference_model/src/ops/image.h create mode 100644 reference_model/src/ops/op_factory.cc create mode 100644 reference_model/src/ops/op_factory.h create mode 100644 reference_model/src/ops/reduction.cc create mode 100644 reference_model/src/ops/reduction.h create mode 100644 reference_model/src/ops/scatter_gather.cc create mode 100644 reference_model/src/ops/scatter_gather.h create mode 100644 reference_model/src/ops/template_types.h create mode 100644 reference_model/src/ops/tensor_ops.cc create mode 100644 reference_model/src/ops/tensor_ops.h create mode 100644 
reference_model/src/ops/type_conversion.cc
create mode 100644 reference_model/src/ops/type_conversion.h

diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc
new file mode 100644
index 0000000..bca9507
--- /dev/null
+++ b/reference_model/src/ops/activation_funcs.cc
@@ -0,0 +1,118 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "activation_funcs.h"
+#include "quant_util.h"
+#include "template_types.h"
+#include <cmath>
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <int Rank, DType Dtype>
+int OpClamp<Rank, Dtype>::register_fcn()
+{
+
+    switch (Dtype)
+    {
+        case DType_FLOAT:
+        {
+            InEigenType min = (InEigenType)attribute->min_fp();
+            InEigenType max = (InEigenType)attribute->max_fp();
+            this->fcn = [min, max](InEigenType a) -> OutEigenType { return a <= min ? min : a >= max ? max : a; };
+        }
+        break;
+        case DType_AINT8:
+        case DType_INT16:
+        {
+            InEigenType min = (InEigenType)attribute->min_int();
+            InEigenType max = (InEigenType)attribute->max_int();
+            this->fcn = [min, max](InEigenType a) -> OutEigenType { return a <= min ? min : a >= max ? max : a; };
+        }
+        break;
+        default:
+            FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+    }
+
+    return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpReluN<Rank, Dtype>::register_fcn()
+{
+
+    switch (Dtype)
+    {
+        case DType_FLOAT:
+        {
+            InEigenType N = (InEigenType)attribute->max_fp();
+            this->fcn = [N](InEigenType a) -> OutEigenType { return a >= 0 ? (a <= N ? a : N) : 0; };
+        }
+        break;
+        case DType_INT32:
+        {
+            InEigenType N = (InEigenType)attribute->max_int();
+            this->fcn = [N](InEigenType a) -> OutEigenType { return a >= 0 ? (a <= N ?
a : N) : 0; }; + } + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template +int OpSigmoid::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + this->fcn = [](InEigenType a) -> OutEigenType { return (1.0 / (1.0 + (expf(-1.0 * a)))); }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template +int OpTanh::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + this->fcn = [](InEigenType a) -> OutEigenType { return tanhf(a); }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +// template explicit instantiation +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, AINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReluN, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReluN, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FLOAT); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FLOAT); diff --git a/reference_model/src/ops/activation_funcs.h b/reference_model/src/ops/activation_funcs.h new file mode 100644 index 0000000..b051b9d --- /dev/null +++ b/reference_model/src/ops/activation_funcs.h @@ -0,0 +1,101 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
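Editor's sketch: the clamp lambda registered above, a <= min ? min : a >= max ? max : a, is ordinary min/max saturation. A minimal standalone C++ check of that equivalence, independent of the reference model's types:

    #include <algorithm>
    #include <cassert>

    // Same semantics as the lambda OpClamp registers:
    // a <= lo ? lo : (a >= hi ? hi : a)
    template <typename T>
    T clamp(T a, T lo, T hi)
    {
        return a <= lo ? lo : (a >= hi ? hi : a);
    }

    int main()
    {
        assert(clamp(-5, 0, 6) == 0); // below range -> min
        assert(clamp(3, 0, 6) == 3);  // in range -> unchanged
        assert(clamp(9, 0, 6) == 6);  // above range -> max
        // agrees with the min/max formulation
        assert(clamp(3, 0, 6) == std::min(std::max(3, 0), 6));
        return 0;
    }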
+ +#ifndef OPS_ACTIVATION_FUNCS_H +#define OPS_ACTIVATION_FUNCS_H + +#include "ewise_unary.h" +#include "graph_node.h" + +using namespace tosa; + +namespace TosaReference +{ + +template +class OpClamp : public UnaryNode +{ +public: + OpClamp(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : UnaryNode(Op_CLAMP, id_) + { + INIT_ATTRIBUTE(Clamp); + register_fcn(); + } + static constexpr int32_t QMin = GetQMin::value; + static constexpr int32_t QMax = GetQMax::value; + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + virtual int register_fcn(); + +protected: + TosaClampAttribute* attribute; +}; + +template +class OpReluN : public UnaryNode +{ +public: + OpReluN(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : UnaryNode(Op_RELUN, id_) + { + INIT_ATTRIBUTE(ReluN); + register_fcn(); + } + static constexpr int32_t QMin = GetQMin::value; + static constexpr int32_t QMax = GetQMax::value; + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + virtual int register_fcn(); + +protected: + TosaReluNAttribute* attribute; +}; + +template +class OpSigmoid : public UnaryNode +{ +public: + OpSigmoid(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : UnaryNode(Op_SIGMOID, id_) + { + register_fcn(); + } + static constexpr int32_t QMin = GetQMin::value; + static constexpr int32_t QMax = GetQMax::value; + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + virtual int register_fcn(); +}; + +template +class OpTanh : public UnaryNode +{ +public: + OpTanh(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : UnaryNode(Op_TANH, id_) + { + register_fcn(); + } + static constexpr int32_t QMin = GetQMin::value; + static constexpr int32_t QMax = GetQMax::value; + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + virtual int register_fcn(); +}; + +}; // namespace TosaReference + +#endif diff --git a/reference_model/src/ops/comparison.cc b/reference_model/src/ops/comparison.cc new file mode 100644 index 0000000..402e152 --- /dev/null +++ b/reference_model/src/ops/comparison.cc @@ -0,0 +1,81 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
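Editor's sketch: the QMin/QMax constants in the classes above come from GetQMin/GetQMax traits in template_types.h, which this excerpt does not show. A hypothetical sketch of what such a trait could look like; the exact names, value type, and values are assumptions:

    #include <cstdint>

    // Hypothetical stand-in for the traits template_types.h presumably defines.
    enum DType { DType_AINT8, DType_INT16 };

    template <DType D> struct GetQMin;
    template <DType D> struct GetQMax;

    template <> struct GetQMin<DType_AINT8> { static constexpr int32_t value = -128; };
    template <> struct GetQMax<DType_AINT8> { static constexpr int32_t value = 127; };
    template <> struct GetQMin<DType_INT16> { static constexpr int32_t value = -32768; };
    template <> struct GetQMax<DType_INT16> { static constexpr int32_t value = 32767; };

    static_assert(GetQMin<DType_AINT8>::value == -128, "8-bit quantized minimum");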
+ +#include "comparison.h" +#include "arith_util.h" +#include "quant_util.h" +#include "template_types.h" + +using namespace TosaReference; +using namespace Eigen; +using namespace tosa; + +template +int OpEqual::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + case DType_INT32: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a == b; }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template +int OpGreater::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + case DType_INT32: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b; }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template +int OpGreaterEqual::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + case DType_INT32: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a >= b; }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +// template explicit instantiation +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32); diff --git a/reference_model/src/ops/comparison.h b/reference_model/src/ops/comparison.h new file mode 100644 index 0000000..e75b1a6 --- /dev/null +++ b/reference_model/src/ops/comparison.h @@ -0,0 +1,71 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#ifndef OPS_COMPARISON_H
+#define OPS_COMPARISON_H
+
+#include "ewise_binary.h"
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+template <int Rank, DType Dtype>
+class OpEqual : public BinaryNode<Rank, Dtype, DType_BOOL>
+{
+public:
+    OpEqual(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+        : BinaryNode<Rank, Dtype, DType_BOOL>(Op_EQUAL, qinfo_, id_)
+    {
+        register_fcn();
+    }
+    using InEigenType  = typename GetEigenType<Dtype>::type;
+    using OutEigenType = typename GetEigenType<DType_BOOL>::type;
+    virtual int register_fcn();
+};
+
+template <int Rank, DType Dtype>
+class OpGreater : public BinaryNode<Rank, Dtype, DType_BOOL>
+{
+public:
+    OpGreater(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+        : BinaryNode<Rank, Dtype, DType_BOOL>(Op_GREATER, qinfo_, id_)
+    {
+        register_fcn();
+    }
+    using InEigenType  = typename GetEigenType<Dtype>::type;
+    using OutEigenType = typename GetEigenType<DType_BOOL>::type;
+    virtual int register_fcn();
+};
+
+template <int Rank, DType Dtype>
+class OpGreaterEqual : public BinaryNode<Rank, Dtype, DType_BOOL>
+{
+public:
+    OpGreaterEqual(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+        : BinaryNode<Rank, Dtype, DType_BOOL>(Op_GREATER_EQUAL, qinfo_, id_)
+    {
+        register_fcn();
+    }
+    using InEigenType  = typename GetEigenType<Dtype>::type;
+    using OutEigenType = typename GetEigenType<DType_BOOL>::type;
+    virtual int register_fcn();
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc
new file mode 100644
index 0000000..9d5db40
--- /dev/null
+++ b/reference_model/src/ops/control_flow.cc
@@ -0,0 +1,353 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
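Editor's sketch: the comparison nodes above rely on Eigen's elementwise binaryExpr to apply the registered lambda and produce a boolean tensor. A minimal sketch, assuming Eigen's unsupported Tensor module is available:

    #include <unsupported/Eigen/CXX11/Tensor>
    #include <iostream>

    int main()
    {
        Eigen::Tensor<int, 1> a(3), b(3);
        a.setValues({ 1, 5, 3 });
        b.setValues({ 2, 5, 1 });

        // Same pattern the comparison nodes use: apply the registered
        // lambda elementwise, producing a bool tensor.
        Eigen::Tensor<bool, 1> gt =
            a.binaryExpr(b, [](int x, int y) -> bool { return x > y; });

        std::cout << gt(0) << gt(1) << gt(2) << "\n"; // prints 001
        return 0;
    }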
+ +#include "control_flow.h" +#include "subgraph_traverser.h" + +using namespace TosaReference; +using namespace Eigen; +using namespace tosa; + +OpControlFlow::OpControlFlow(TosaSerializationHandler* tsh_, Op op_, uint64_t id_) + : GraphNode(op_, id_) +{ + tsh = tsh_; +} + +OpControlFlow::~OpControlFlow() +{} + +int OpControlFlow::evalBlock(TosaSerializationBasicBlock* block, + std::vector& block_inputs, + std::vector& block_outputs) +{ + std::string block_name = block->GetName(); + + DEBUG_MED(OP, "Evaluating block %s", block_name.c_str()); + + SubgraphTraverser gt(block, tsh); + + if (gt.initializeGraph()) + { + FATAL_ERROR("Unable to initialize graph traverser for block %s", block_name.c_str()); + } + + if (gt.linkTensorsAndNodes()) + { + FATAL_ERROR("Failed to link tensors and nodes for block %s", block_name.c_str()); + } + + if (gt.validateGraph()) + { + FATAL_ERROR("Failed to validate subgraph for block %s", block_name.c_str()); + } + + int num_input_tensors = gt.getNumInputTensors(); + int num_output_tensors = gt.getNumOutputTensors(); + + for (size_t i = 0; i < block_inputs.size(); i++) + { + DEBUG_HIGH(OP, "Input[%ld]: %s", i, block_inputs[i]->getName().c_str()); + } + for (size_t i = 0; i < block_outputs.size(); i++) + { + DEBUG_HIGH(OP, "Output[%ld]: %s", i, block_outputs[i]->getName().c_str()); + } + + ASSERT_MSG((size_t)num_input_tensors == block_inputs.size(), + "op block %s inputs[%lu] does not match with graph traverser's inputs[%d]", block_name.c_str(), + block_inputs.size(), num_input_tensors); + ASSERT_MSG((size_t)num_output_tensors == block_outputs.size(), + "op block %s outputs[%lu] does not match with graph traverser's outputs[%d]", block_name.c_str(), + block_outputs.size(), num_output_tensors); + + // set graph traverser's input = basic block's input + for (int i = 0; i < num_input_tensors; i++) + { + TosaReference::Tensor* tensor = gt.getInputTensor(i); + ASSERT_MSG(!tensor->is_allocated(), "block %s input tensors are unexpectedly initialized before", + block_name.c_str()); + + if (tensor->allocate()) + { + WARNING("Fail to allocate tensor %s", tensor->getName().c_str()); + return 1; + } + + if (tensor->copyValueFrom(block_inputs[i])) + { + WARNING("Fail to copy tensor value %s -> %s", block_inputs[i]->getName().c_str(), + tensor->getName().c_str()); + return 1; + } + + // Push ready consumers to the next node list + for (auto gn : tensor->getConsumers()) + { + if (gn->hasAllInputsReady() && !gn->getOnNextNodeList()) + { + gt.addToNextNodeList(gn); + } + } + } + + if (gt.evaluateAll()) + { + FATAL_ERROR("Error evaluating network. 
Giving up."); + } + + // make sure output tensor is evaluated and show its value + bool all_output_valid = true; + for (int i = 0; i < num_output_tensors; i++) + { + const TosaReference::Tensor* ct = gt.getOutputTensor(i); + ASSERT_MEM(ct); + if (!ct->getIsValid()) + { + ct->dumpTensorParams(g_func_debug.func_debug_file); + if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT)) + { + ct->dumpTensor(g_func_debug.func_debug_file); + } + all_output_valid = false; + } + } + if (!all_output_valid) + { + gt.dumpGraph(g_func_debug.func_debug_file); + FATAL_ERROR("SubgraphTraverser \"%s\" error: Output tensors are not all valid at the end of evaluation.", + block_name.c_str()); + } + + // set basic block's output = subgraph_traverser's output + for (int i = 0; i < num_output_tensors; i++) + { + TosaReference::Tensor* tensor = gt.getOutputTensor(i); + ASSERT_MSG(tensor->is_allocated(), "tensor %s is not allocated", tensor->getName().c_str()); + + if (block_outputs[i]->copyValueFrom(tensor)) + { + WARNING("Fail to copy tensor value %s -> %s", tensor->getName().c_str(), outputs[i]->getName().c_str()); + return 1; + } + } + return 0; +} + +OpCondIf::OpCondIf(TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_) + : OpControlFlow(tsh_, Op_COND_IF, id_) +{ + INIT_ATTRIBUTE(CondIf); +} + +OpCondIf::~OpCondIf() +{ + if (attribute) + delete attribute; +} + +int OpCondIf::checkTensorAttributes() +{ + if (getInputs().size() < 1) + { + WARNING("OpCondIf: must have at least 1 operand"); + return 1; + } + + if (inputs[0]->getDtype() != DType_BOOL || inputs[0]->getRank() != 0) + { + WARNING("OpCondIf: invalid tensor dtype=%s, rank=%d", EnumNamesDType()[inputs[0]->getDtype()], + inputs[0]->getRank()); + return 1; + } + + cond = dynamic_cast*>(inputs[0]); + ASSERT_MEM(cond); + + then_block = tsh->GetBlockByName(attribute->then_branch()); + else_block = tsh->GetBlockByName(attribute->else_branch()); + + if (!then_block) + { + WARNING("OpCondIf: fail to resolve then_branch %s", attribute->then_branch().c_str()); + return 1; + } + + if (!else_block) + { + WARNING("OpCondIf: fail to resolve else_branch %s", attribute->else_branch().c_str()); + return 1; + } + + return 0; +} + +int OpCondIf::eval() +{ + bool cond_val = cond->getTensor()(0); + std::vector block_inputs(getInputs().begin() + 1, getInputs().end()); + + if (cond_val) + { + if (evalBlock(then_block, block_inputs, getOutputs())) + { + WARNING("OpCondIf: Fail to evaluate then branch block %s", attribute->then_branch().c_str()); + return 1; + } + } + else + { + if (evalBlock(else_block, block_inputs, getOutputs())) + { + WARNING("OpCondIf: Fail to evaluate else branch block %s", attribute->else_branch().c_str()); + return 1; + } + } + + return GraphNode::eval(); +} + +OpWhileLoop::OpWhileLoop(TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_) + : OpControlFlow(tsh_, Op_WHILE_LOOP, id_) +{ + INIT_ATTRIBUTE(WhileLoop); +} + +OpWhileLoop::~OpWhileLoop() +{ + if (attribute) + delete attribute; +} + +int OpWhileLoop::checkTensorAttributes() +{ + if (getInputs().size() <= 0) + { + WARNING("OpWhileLoop: must have at least 1 operands"); + return 1; + } + + if (getInputs().size() != getOutputs().size()) + { + WARNING("OpWhileLoop: inputs and outputs size must match"); + return 1; + } + + cond_block = tsh->GetBlockByName(attribute->cond_branch()); + body_block = tsh->GetBlockByName(attribute->body_branch()); + + if (!cond_block) + { + WARNING("OpWhileLoop: fail to resolve cond_branch %s", attribute->cond_branch().c_str()); + return 1; + 
+    }
+
+    if (!body_block)
+    {
+        WARNING("OpWhileLoop: fail to resolve body_branch %s", attribute->body_branch().c_str());
+        return 1;
+    }
+
+    if (cond_block->GetOutputs().size() != 1)
+    {
+        WARNING("OpWhileLoop: invalid cond_block output size %lu", cond_block->GetOutputs().size());
+        return 1;
+    }
+
+    TosaSerializationTensor* cond_output_tensor = cond_block->GetTensorByName(cond_block->GetOutputs()[0]);
+
+    if (!cond_output_tensor)
+    {
+        WARNING("OpWhileLoop: fail to resolve cond_block's output tensor %s", cond_block->GetOutputs()[0].c_str());
+        return 1;
+    }
+
+    if (cond_output_tensor->GetDtype() != DType_BOOL)
+    {
+        WARNING("OpWhileLoop: invalid cond_block's output tensor data type %s",
+                EnumNamesDType()[cond_output_tensor->GetDtype()]);
+        return 1;
+    }
+    if (cond_output_tensor->GetShape().size() != 0)
+    {
+        WARNING("OpWhileLoop: invalid cond_block's output rank %lu", cond_output_tensor->GetShape().size());
+        return 1;
+    }
+
+    return 0;
+}
+
+int OpWhileLoop::eval()
+{
+
+    TosaReference::Tensor0<bool> cond_output_ctensor(
+        std::string("cond_output"), DType_BOOL, std::vector<Usage>({ Usage_ACTIVATION }),
+        std::vector<Format>({ Format_UNKNOWN }), std::vector<int>({}), false);
+
+    cond_output_ctensor.allocate();
+    std::vector<TosaReference::Tensor*> cond_block_outputs;
+    cond_block_outputs.push_back(&cond_output_ctensor);
+
+    size_t num_input_output = getInputs().size();
+    size_t eval_count       = 0;
+
+    while (eval_count++ < MAX_WHILE_LOOP_ITERATION)
+    {
+        if (evalBlock(cond_block, getInputs(), cond_block_outputs))
+        {
+            WARNING("OpWhileLoop: Fail to evaluate cond block %s", attribute->cond_branch().c_str());
+            return 1;
+        }
+        bool cond_val = cond_output_ctensor.getTensor()(0);
+        DEBUG_HIGH(OP, "Conditional block value: %d", cond_val);
+
+        if (cond_val)
+        {
+            if (evalBlock(body_block, getInputs(), getOutputs()))
+            {
+                WARNING("OpWhileLoop: Fail to evaluate body block %s", attribute->body_branch().c_str());
+                return 1;
+            }
+
+            // assign output tensor values back to the input tensors for the next iteration
+            for (size_t i = 0; i < num_input_output; i++)
+            {
+                if (getInputs()[i]->copyValueFrom(getOutputs()[i]))
+                {
+                    WARNING("Fail to copy tensor value %s -> %s", getOutputs()[i]->getName().c_str(),
+                            getInputs()[i]->getName().c_str());
+                    return 1;
+                }
+            }
+        }
+        else
+        {
+            // on the last iteration, or when the body block never runs,
+            // pass the input tensor values through to the output tensors
+            for (size_t i = 0; i < num_input_output; i++)
+            {
+                if (getOutputs()[i]->copyValueFrom(getInputs()[i]))
+                {
+                    WARNING("Fail to copy tensor value %s -> %s", getInputs()[i]->getName().c_str(),
+                            getOutputs()[i]->getName().c_str());
+                    return 1;
+                }
+            }
+            break;
+        }
+    }
+
+    return GraphNode::eval();
+}
diff --git a/reference_model/src/ops/control_flow.h b/reference_model/src/ops/control_flow.h
new file mode 100644
index 0000000..14c11bc
--- /dev/null
+++ b/reference_model/src/ops/control_flow.h
@@ -0,0 +1,72 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
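Editor's sketch: OpWhileLoop::eval above repeatedly evaluates the cond block, runs the body block while the condition holds, copies outputs back to inputs between iterations, and bails out after MAX_WHILE_LOOP_ITERATION passes. A scalar model of the same control structure:

    #include <cassert>
    #include <functional>

    // Scalar model of OpWhileLoop::eval: state flows inputs -> body -> inputs
    // until cond fails, with an iteration cap as a safety net.
    int while_loop(int state,
                   const std::function<bool(int)>& cond,
                   const std::function<int(int)>& body,
                   int max_iterations = 10000)
    {
        for (int i = 0; i < max_iterations && cond(state); i++)
            state = body(state); // copy outputs back to inputs
        return state;            // the final state becomes the op's outputs
    }

    int main()
    {
        // Count up to 10: cond = (x < 10), body = (x + 1)
        assert(while_loop(0, [](int x) { return x < 10; },
                             [](int x) { return x + 1; }) == 10);
        return 0;
    }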
+
+#ifndef OPS_CONTROL_FLOW_H
+#define OPS_CONTROL_FLOW_H
+
+#include "graph_node.h"
+
+#define MAX_WHILE_LOOP_ITERATION 10000
+
+namespace TosaReference
+{
+class OpControlFlow : public GraphNode
+{
+public:
+    OpControlFlow(TosaSerializationHandler* tsh_, Op op_, uint64_t id_);
+    ~OpControlFlow();
+
+    virtual int evalBlock(TosaSerializationBasicBlock* block,
+                          std::vector<TosaReference::Tensor*>& block_inputs,
+                          std::vector<TosaReference::Tensor*>& block_outputs);
+
+protected:
+    TosaSerializationHandler* tsh;
+};
+
+class OpCondIf : public OpControlFlow
+{
+public:
+    OpCondIf(TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_);
+    virtual ~OpCondIf();
+
+    virtual int checkTensorAttributes();
+    virtual int eval();
+
+protected:
+    TosaCondIfAttribute* attribute;
+    TosaReference::Tensor0<bool>* cond;
+    TosaSerializationBasicBlock* then_block;
+    TosaSerializationBasicBlock* else_block;
+};
+
+class OpWhileLoop : public OpControlFlow
+{
+public:
+    OpWhileLoop(TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_);
+    virtual ~OpWhileLoop();
+
+    virtual int checkTensorAttributes();
+    virtual int eval();
+
+protected:
+    TosaWhileLoopAttribute* attribute;
+    TosaSerializationBasicBlock* cond_block;
+    TosaSerializationBasicBlock* body_block;
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/custom.cc b/reference_model/src/ops/custom.cc
new file mode 100644
index 0000000..5c4f29b
--- /dev/null
+++ b/reference_model/src/ops/custom.cc
@@ -0,0 +1,40 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "custom.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+OpCustom::OpCustom(uint64_t id_)
+    : GraphNode(Op_CUSTOM, id_)
+{}
+
+OpCustom::~OpCustom()
+{}
+
+int OpCustom::checkTensorAttributes()
+{
+    return 0;
+}
+
+int OpCustom::eval()
+{
+    // custom op evaluation is not implemented; the line below is not reached
+    FATAL_ERROR_NODE("not supported yet");
+
+    return GraphNode::eval();
+}
diff --git a/reference_model/src/ops/custom.h b/reference_model/src/ops/custom.h
new file mode 100644
index 0000000..b1085a5
--- /dev/null
+++ b/reference_model/src/ops/custom.h
@@ -0,0 +1,38 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
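Editor's sketch: OpCondIf above evaluates exactly one of two blocks, selected by a rank-0 boolean tensor, with the remaining operands forwarded to the chosen block. In scalar form:

    #include <cassert>
    #include <functional>

    // Scalar model of OpCondIf::eval: a scalar boolean picks which block
    // consumes the remaining operands.
    int cond_if(bool cond, int operand,
                const std::function<int(int)>& then_block,
                const std::function<int(int)>& else_block)
    {
        return cond ? then_block(operand) : else_block(operand);
    }

    int main()
    {
        auto twice = [](int x) { return x * 2; };
        auto dec   = [](int x) { return x - 1; };
        assert(cond_if(true, 3, twice, dec) == 6);
        assert(cond_if(false, 3, twice, dec) == 2);
        return 0;
    }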
+ +#ifndef OPS_CUSTOM_H +#define OPS_CUSTOM_H + +#include "graph_node.h" + +using namespace tosa; + +namespace TosaReference +{ + +class OpCustom : public GraphNode +{ +public: + OpCustom(uint64_t id_); + virtual ~OpCustom(); + + virtual int checkTensorAttributes(); + virtual int eval(); +}; + +}; // namespace TosaReference + +#endif diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc new file mode 100644 index 0000000..32029b9 --- /dev/null +++ b/reference_model/src/ops/data_layout.cc @@ -0,0 +1,644 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "data_layout.h" +#include "quant_util.h" + +using namespace TosaReference; +using namespace Eigen; +using namespace tosa; + +template +OpConcat::OpConcat(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_CONCAT, id_) +{ + setRequiredOperands(2, 1); + setRequiredRank(1, 6); + + INIT_ATTRIBUTE(Axis); +} + +template +OpConcat::~OpConcat() +{ + if (attribute) + delete attribute; +} + +template +int OpConcat::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + // output and input must be the same types and rank + // inputs[0] and inputs[1] should also match type and rank + if (inputs[0]->matchRankType(*outputs[0]) || inputs[1]->matchRankType(*outputs[0])) + { + printNodeValidationError("Concat operator input ranks and types must match"); + return 1; + } + + lhs = dynamic_cast*>(inputs[0]); + rhs = dynamic_cast*>(inputs[1]); + out = dynamic_cast*>(outputs[0]); + + if (attribute->axis() < 0 || (size_t)attribute->axis() >= rhs->getShape().size()) + { + printNodeValidationError("Axis is beyond input tensor rank"); + return 1; + } + + return 0; +} + +template +int OpConcat::eval() +{ + + int32_t reversed_axis = Rank - 1 - attribute->axis(); + + for (int32_t d = 0; d < Rank; d++) + { + reverser[d] = Rank - 1 - d; + } + + TIn lhs_reversed = lhs->getTensor().shuffle(reverser); + TIn rhs_reversed = rhs->getTensor().shuffle(reverser); + + TIn reversed_result = lhs_reversed.concatenate(rhs_reversed, reversed_axis); + out->getTensor() = reversed_result.shuffle(reverser); + // out->getTensor() = lhs->getTensor().concatenate(rhs->getTensor(), axis); + + return GraphNode::eval(); +} + +template +OpPad::OpPad(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_PAD, id_) +{ + setRequiredOperands(2, 1); + setRequiredRank(0, 6); + + INIT_QINFO(Pad); +} + +template +OpPad::~OpPad() +{ + if (qinfo) + delete qinfo; +} + +template +int OpPad::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + // output and input must be the same types + if (inputs[0]->matchRankType(*outputs[0])) + { + printNodeValidationError("Failure to match input and output type and 
rank"); + return 1; + } + + in = dynamic_cast*>(inputs[0]); + out = dynamic_cast*>(outputs[0]); + TosaReference::TensorTemplate>* paddings = + dynamic_cast>*>(inputs[1]); + + for (int i = 0; i < Rank; i++) + { + paddings_array[i] = std::make_pair(paddings->getTensor()(i, 0), paddings->getTensor()(i, 1)); + } + + return 0; +} + +template +int OpPad::eval() +{ + InEigenType pad_value = 0; + if (this->qinfo) + { + pad_value = (InEigenType)this->qinfo->input_zp(); + } + + this->out->getTensor() = this->in->getTensor().pad(this->paddings_array, pad_value); + + return GraphNode::eval(); +} + +template +OpReshape::OpReshape(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_RESHAPE, id_) +{ + setRequiredOperands(1, 1); + setRequiredRank(0, 6); + + INIT_ATTRIBUTE(Reshape); +} + +template +OpReshape::~OpReshape() +{ + if (attribute) + delete attribute; +} + +template +int OpReshape::checkTensorAttributes() +{ + uint32_t minusOneCount = 0; + + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + // output and input must be the same types + if (inputs[0]->matchType(*outputs[0])) + { + printNodeValidationError("OpReshape: Input and output types must match"); + return 1; + } + + for (uint32_t d = 0; d < OutRank; d++) + { + if (attribute->shape()[d] == -1) + { + minusOneCount++; + } + } + + if (minusOneCount > 1) + { + printNodeValidationError("OpReshape: new shape has more than one -1 dimension"); + return 1; + } + + in = dynamic_cast*>(inputs[0]); + out = dynamic_cast*>(outputs[0]); + + return 0; +} + +template +int OpReshape::eval() +{ + uint32_t remainingSize = in->getElementCount(); + + // If there is a -1 dimension, find the remainder in one pass over the output shape + for (int32_t d = 0; d < OutRank; d++) + { + if (attribute->shape()[d] != -1) + { + remainingSize = remainingSize / attribute->shape()[d]; + } + } + + for (int32_t d = 0; d < OutRank; d++) + { + array_shape[d] = attribute->shape()[OutRank - 1 - d]; + out_reverser[d] = OutRank - 1 - d; + + // Jam in the remainder here + if (array_shape[d] == -1) + { + array_shape[d] = remainingSize; + } + } + + for (int32_t d = 0; d < InRank; d++) + { + in_reverser[d] = InRank - 1 - d; + } + + // Eigen Tensor is col-major, and we're referencing row-major result + // need to reverse it to row-major before reshape, and perform another reverse afterward + + // input tensor rank 0 can't do .shuffle(), need to be handled otherwise + TIn in_reversed; + if (InRank > 1) + { + in_reversed = in->getTensor().shuffle(in_reverser); + } + else + { + in_reversed = in->getTensor(); + } + + TOut in_reshaped = in_reversed.reshape(array_shape); + + // output tensor can be rank 0, .reshape() and .shuffle() don't work, need to be handled otherwise + if (OutRank > 1) + { + out->getTensor() = in_reshaped.shuffle(out_reverser); + } + else + { + out->getTensor() = in_reshaped; + } + + return GraphNode::eval(); +} + +template +OpReverse::OpReverse(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_REVERSE, id_) +{ + setRequiredOperands(1, 1); + setRequiredRank(1, 6); + + INIT_ATTRIBUTE(Axis); +} + +template +OpReverse::~OpReverse() +{ + if (attribute) + delete attribute; +} + +template +int OpReverse::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + // output and input must be the same 
types
+    if (inputs[0]->matchRankTypeShape(*outputs[0]))
+    {
+        printNodeValidationError("Failure to match input and output rank/type/shape");
+        return 1;
+    }
+
+    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+    ASSERT_MEM(in && out);
+
+    if (attribute->axis() < 0 || attribute->axis() >= inputs[0]->getRank())
+    {
+        printNodeValidationError("Reverse axis must be between [0, input_rank - 1]");
+        return 1;
+    }
+
+    // transform the axis attribute into a per-dimension true/false list
+    // e.g. rank=4, axis=1, reverse array would be [false, true, false, false]
+    for (int i = 0; i < Rank; i++)
+    {
+        reverse_array[i] = false;
+    }
+    reverse_array[attribute->axis()] = true;
+
+    return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpReverse<Rank, Dtype>::eval()
+{
+    out->getTensor() = in->getTensor().reverse(reverse_array);
+
+    return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+OpSlice<Rank, Dtype>::OpSlice(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+    : GraphNode(Op_SLICE, id_)
+{
+    setRequiredOperands(1, 1);
+    setRequiredRank(0, 6);
+
+    INIT_ATTRIBUTE(Slice);
+}
+
+template <int Rank, DType Dtype>
+OpSlice<Rank, Dtype>::~OpSlice()
+{
+    if (attribute)
+        delete attribute;
+}
+
+template <int Rank, DType Dtype>
+int OpSlice<Rank, Dtype>::checkTensorAttributes()
+{
+    if (validateRequiredOperands())
+        return 1;
+
+    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+    {
+        return 1;
+    }
+
+    // output and input must be the same types
+    if (inputs[0]->matchType(*outputs[0]))
+    {
+        printNodeValidationError("Failure to match input and output type");
+        return 1;
+    }
+
+    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+    for (size_t i = 0; i < attribute->begin().size(); i++)
+    {
+        begin_array[i] = attribute->begin()[i];
+    }
+
+    for (size_t i = 0; i < attribute->size().size(); i++)
+    {
+        if (attribute->size()[i] != 0)
+        {
+            size_array[i] = attribute->size()[i];
+        }
+        else
+        {
+            // Tensorflow assigns a zero size to dimensions that are kept
+            // Eigen expects size to be the full size of the dimension
+            size_array[i] = in->getTensor().dimension(i);
+        }
+    }
+
+    return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpSlice<Rank, Dtype>::eval()
+{
+    out->getTensor() = in->getTensor().slice(begin_array, size_array);
+
+    return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+OpTileBase<Rank, Dtype>::OpTileBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+    : GraphNode(Op_TILE, id_)
+{
+    setRequiredOperands(1, 1);
+    setRequiredRank(0, 6);
+
+    INIT_ATTRIBUTE(Tile);
+}
+
+template <int Rank, DType Dtype>
+OpTileBase<Rank, Dtype>::~OpTileBase()
+{
+    if (attribute)
+        delete attribute;
+}
+
+template <int Rank, DType Dtype>
+int OpTileBase<Rank, Dtype>::checkTensorAttributes()
+{
+    if (validateRequiredOperands())
+        return 1;
+
+    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+    {
+        return 1;
+    }
+
+    // output and input must be the same ranks and types
+    if (inputs[0]->matchRankType(*outputs[0]))
+    {
+        printNodeValidationError("Failure to match input and output rank or type");
+        return 1;
+    }
+
+    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+    if (attribute->multiples().size() != Rank)
+    {
+        printNodeValidationError("1D list 'multiples' must have size equal to input rank");
+        return 1;
+    }
+
+    for (int32_t d = 0; d < Rank; d++)
+    {
+        if (in->getShape()[d] * attribute->multiples()[d] != out->getShape()[d])
+        {
+            printNodeValidationError("unexpected output shape");
+            return 1;
+        }
+    }
+
+    return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpTile<Rank, Dtype>::eval()
+{
+    // primary template shouldn't be called
+    FATAL_ERROR_NODE("OpTile rank=%i, dtype=%s: not implemented yet", Rank, EnumNamesDType()[Dtype]);
+}
+
+template <DType Dtype>
+int OpTile<1, Dtype>::eval()
+{
+    for (int32_t od0 = 0;
od0 < this->out->getShape()[0]; od0++) + { + int32_t id0 = od0 % this->in->getShape()[0]; + this->out->getTensor()(od0) = this->in->getTensor()(id0); + } + + return GraphNode::eval(); +} + +template +int OpTile<2, Dtype>::eval() +{ + for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++) + { + int32_t id0 = od0 % this->in->getShape()[0]; + for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++) + { + int32_t id1 = od1 % this->in->getShape()[1]; + this->out->getTensor()(od0, od1) = this->in->getTensor()(id0, id1); + } + } + + return GraphNode::eval(); +} + +template +int OpTile<3, Dtype>::eval() +{ + for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++) + { + int32_t id0 = od0 % this->in->getShape()[0]; + for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++) + { + int32_t id1 = od1 % this->in->getShape()[1]; + for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++) + { + int32_t id2 = od2 % this->in->getShape()[2]; + this->out->getTensor()(od0, od1, od2) = this->in->getTensor()(id0, id1, id2); + } + } + } + + return GraphNode::eval(); +} + +template +int OpTile<4, Dtype>::eval() +{ + for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++) + { + int32_t id0 = od0 % this->in->getShape()[0]; + for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++) + { + int32_t id1 = od1 % this->in->getShape()[1]; + for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++) + { + int32_t id2 = od2 % this->in->getShape()[2]; + for (int32_t od3 = 0; od3 < this->out->getShape()[3]; od3++) + { + int32_t id3 = od3 % this->in->getShape()[3]; + this->out->getTensor()(od0, od1, od2, od3) = this->in->getTensor()(id0, id1, id2, id3); + } + } + } + } + + return GraphNode::eval(); +} + +template +OpTranspose::OpTranspose(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_TRANSPOSE, id_) +{ + setRequiredOperands(2, 1); + setRequiredRank(0, 6); +} + +template +OpTranspose::~OpTranspose() +{} + +template +int OpTranspose::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + // output and input must be the same types + if (inputs[0]->matchRankType(*outputs[0])) + { + printNodeValidationError("Failure to match input and output rank and type"); + return 1; + } + + if (inputs[0]->getElementCount() != outputs[0]->getElementCount()) + { + printNodeValidationError("Failure to match input and output total element count"); + return 1; + } + + in = dynamic_cast*>(inputs[0]); + out = dynamic_cast*>(outputs[0]); + perm_tensor = dynamic_cast>*>(inputs[1]); + + return 0; +} + +template +int OpTranspose::eval() +{ + for (int32_t d = 0; d < Rank; d++) + { + perm_array[d] = this->perm_tensor->getTensor().data()[d]; + } + + out->getTensor() = in->getTensor().shuffle(perm_array); + + return GraphNode::eval(); +} + +// template explicit instantiation +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FLOAT) +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, AINT8) +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8) +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16) +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32) +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL) + +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FLOAT); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, AINT8); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16); 
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL); + +DEF_INSTANTIATE_RESHAPE(OpReshape, FLOAT); +DEF_INSTANTIATE_RESHAPE(OpReshape, AINT8); +DEF_INSTANTIATE_RESHAPE(OpReshape, INT8); +DEF_INSTANTIATE_RESHAPE(OpReshape, INT16); +DEF_INSTANTIATE_RESHAPE(OpReshape, INT32); +DEF_INSTANTIATE_RESHAPE(OpReshape, BOOL); + +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FLOAT); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, AINT8); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, AINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, AINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BOOL); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, AINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL); diff --git a/reference_model/src/ops/data_layout.h b/reference_model/src/ops/data_layout.h new file mode 100644 index 0000000..100bd6b --- /dev/null +++ b/reference_model/src/ops/data_layout.h @@ -0,0 +1,216 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
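Editor's sketch: OpReshape::eval above infers a single -1 output dimension by dividing the input element count by the product of the known output dimensions. A standalone sketch of that arithmetic:

    #include <cassert>
    #include <vector>

    // Infer the one -1 entry in a new shape, as OpReshape::eval does:
    // divide the total element count by the product of the known dimensions.
    std::vector<int> infer_shape(int element_count, std::vector<int> shape)
    {
        int remaining = element_count;
        for (int d : shape)
            if (d != -1)
                remaining /= d;
        for (int& d : shape)
            if (d == -1)
                d = remaining; // jam in the remainder here
        return shape;
    }

    int main()
    {
        // 24 elements reshaped to {2, -1, 3} -> {2, 4, 3}
        std::vector<int> s = infer_shape(24, { 2, -1, 3 });
        assert(s[0] == 2 && s[1] == 4 && s[2] == 3);
        return 0;
    }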
+ +#ifndef OPS_DATA_LAYOUT_H +#define OPS_DATA_LAYOUT_H + +#include "graph_node.h" + +using namespace tosa; + +namespace TosaReference +{ + +template +class OpConcat : public GraphNode +{ +public: + OpConcat(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpConcat(); + + virtual int checkTensorAttributes(); + virtual int eval(); + + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using TIn = Eigen::Tensor; + using TOut = Eigen::Tensor; + +protected: + Eigen::array reverser; + TosaReference::TensorTemplate* lhs; + TosaReference::TensorTemplate* rhs; + TosaAxisAttribute* attribute; + TosaReference::TensorTemplate* out; +}; + +template +class OpPad : public GraphNode +{ +public: + OpPad(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpPad(); + virtual int checkTensorAttributes(); + virtual int eval(); + + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using TIn = Eigen::Tensor; + using TOut = Eigen::Tensor; + +protected: + Eigen::array, Rank> paddings_array; + TosaReference::TensorTemplate* in; + TosaReference::TensorTemplate* out; + TosaPadQuantInfo* qinfo; +}; + +template +class OpReshape : public GraphNode +{ +public: + OpReshape(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpReshape(); + + virtual int checkTensorAttributes(); + virtual int eval(); + + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using TIn = Eigen::Tensor; + using TOut = Eigen::Tensor; + +protected: + Eigen::array array_shape; + Eigen::array in_reverser; + Eigen::array out_reverser; + TosaReference::TensorTemplate* in; + TosaReshapeAttribute* attribute; + TosaReference::TensorTemplate* out; +}; + +template +class OpReverse : public GraphNode +{ +public: + OpReverse(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpReverse(); + + virtual int checkTensorAttributes(); + virtual int eval(); + + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using TIn = Eigen::Tensor; + using TOut = Eigen::Tensor; + +protected: + TosaAxisAttribute* attribute; + TosaReference::TensorTemplate* in; + TosaReference::TensorTemplate* out; + Eigen::array reverse_array; +}; + +template +class OpSlice : public GraphNode +{ +public: + OpSlice(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpSlice(); + + virtual int checkTensorAttributes(); + virtual int eval(); + + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using TIn = Eigen::Tensor; + using TOut = Eigen::Tensor; + +protected: + TosaSliceAttribute* attribute; + Eigen::array begin_array; + Eigen::array size_array; + TosaReference::TensorTemplate* in; + TosaReference::TensorTemplate* out; +}; + +template +class OpTileBase : public GraphNode +{ +public: + OpTileBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpTileBase(); + + virtual int checkTensorAttributes(); + + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using TIn = Eigen::Tensor; + using TOut = Eigen::Tensor; + +protected: + TosaTileAttribute* attribute; + TosaReference::TensorTemplate* in; + TosaReference::TensorTemplate* out; +}; + +// primary template for op tile +template +class OpTile : 
public OpTileBase +{ +public: + OpTile(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : OpTileBase(attribute_, qinfo_, id_) + {} + +protected: + virtual int eval(); +}; + +// partial specialization for specific rank +#define DEF_OP_TILE_RANK(N) \ + template \ + class OpTile : public OpTileBase \ + { \ + public: \ + OpTile(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ + : OpTileBase(attribute_, qinfo_, id_) \ + {} \ + \ + protected: \ + virtual int eval(); \ + }; + +DEF_OP_TILE_RANK(1) +DEF_OP_TILE_RANK(2) +DEF_OP_TILE_RANK(3) +DEF_OP_TILE_RANK(4) + +#undef DEF_OP_TILE_RANK + +template +class OpTranspose : public GraphNode +{ +public: + OpTranspose(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpTranspose(); + + virtual int checkTensorAttributes(); + virtual int eval(); + + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using TIn = Eigen::Tensor; + using TOut = Eigen::Tensor; + +protected: + Eigen::array perm_array; + TosaReference::TensorTemplate* in; + TosaReference::TensorTemplate>* perm_tensor; + TosaReference::TensorTemplate* out; +}; +}; // namespace TosaReference + +#endif diff --git a/reference_model/src/ops/data_nodes.cc b/reference_model/src/ops/data_nodes.cc new file mode 100644 index 0000000..2ee4935 --- /dev/null +++ b/reference_model/src/ops/data_nodes.cc @@ -0,0 +1,172 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
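Editor's sketch: the rank-specialized OpTile loops all follow one rule: each output coordinate maps to the input coordinate taken modulo the input shape. A rank-2 standalone model:

    #include <cassert>
    #include <vector>

    // Rank-2 model of the OpTile loops: every output coordinate maps to an
    // input coordinate by taking it modulo the input shape.
    std::vector<int> tile2d(const std::vector<int>& in, int ih, int iw, int mh, int mw)
    {
        int oh = ih * mh, ow = iw * mw;
        std::vector<int> out(oh * ow);
        for (int od0 = 0; od0 < oh; od0++)
            for (int od1 = 0; od1 < ow; od1++)
                out[od0 * ow + od1] = in[(od0 % ih) * iw + (od1 % iw)];
        return out;
    }

    int main()
    {
        // 1x2 input {7, 8} tiled by multiples {2, 2} -> 2x4 output
        std::vector<int> out    = tile2d({ 7, 8 }, 1, 2, 2, 2);
        std::vector<int> expect = { 7, 8, 7, 8, 7, 8, 7, 8 };
        assert(out == expect);
        return 0;
    }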
+ +#include "data_nodes.h" + +using namespace TosaReference; +using namespace Eigen; +using namespace tosa; + +OpConst::OpConst(uint64_t id_) + : GraphNode(Op_CONST, id_) +{ + setRequiredOperands(0, 1); +} + +OpConst::~OpConst() +{} + +int OpConst::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + return 0; +} + +int OpConst::eval() +{ + // Evaluation is trivial for constants + return GraphNode::eval(); +} + +OpPlaceholder::OpPlaceholder(uint64_t id_) + : GraphNode(Op_PLACEHOLDER, id_) +{ + setRequiredOperands(0, 1); +} + +OpPlaceholder::~OpPlaceholder() +{} + +int OpPlaceholder::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + return 0; +} + +int OpPlaceholder::eval() +{ + // Evaluation is trivial for placeholders + return GraphNode::eval(); +} + +template +OpIdentity::OpIdentity(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_IDENTITY, id_) +{ + setRequiredOperands(1, 1); + setRequiredRank(0, 6); +} + +template +OpIdentity::~OpIdentity() +{} + +template +int OpIdentity::checkTensorAttributes() +{ + + if (inputs.size() != outputs.size()) + { + printNodeValidationError("Input and output tensor list lengths are not equal"); + return 1; + } + + in = dynamic_cast*>(inputs[0]); + out = dynamic_cast*>(outputs[0]); + + if (in->matchRankTypeShape(*out)) + { + printNodeValidationError("Input and output tensor rank, type, or shape do not match"); + return 1; + } + + return 0; +} + +template +int OpIdentity::eval() +{ + out->getTensor() = in->getTensor(); + + return GraphNode::eval(); +} + +template +OpIdentityN::OpIdentityN(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_IDENTITYN, id_) +{ + setRequiredRank(0, 6); +} + +template +OpIdentityN::~OpIdentityN() +{} + +template +int OpIdentityN::checkTensorAttributes() +{ + + if (inputs.size() != outputs.size()) + { + printNodeValidationError("Input and output tensor list lengths are not equal"); + return 1; + } + + for (size_t i = 0; i < inputs.size(); i++) + { + ins.push_back(dynamic_cast*>(inputs[i])); + outs.push_back(dynamic_cast*>(outputs[i])); + + if (ins[i]->matchRankTypeShape(*outs[i])) + { + printNodeValidationError("Input and output tensor rank, type, or shape do not match"); + return 1; + } + } + + return 0; +} + +template +int OpIdentityN::eval() +{ + for (size_t i = 0; i < ins.size(); i++) + { + outs[i]->getTensor() = ins[i]->getTensor(); + } + + return GraphNode::eval(); +} + +// template explicit instantiation +// note OpConst and OpPlaceholder are not templated + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, AINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, AINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, BOOL); diff --git a/reference_model/src/ops/data_nodes.h b/reference_model/src/ops/data_nodes.h new file mode 100644 index 0000000..bec4669 --- /dev/null +++ b/reference_model/src/ops/data_nodes.h @@ -0,0 
+1,86 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OPS_DATA_NODES_H +#define OPS_DATA_NODES_H + +#include "graph_node.h" + +namespace TosaReference +{ + +class OpConst : public GraphNode +{ +public: + OpConst(uint64_t id_); + virtual ~OpConst(); + + virtual int checkTensorAttributes(); + virtual int eval(); +}; + +class OpPlaceholder : public GraphNode +{ +public: + OpPlaceholder(uint64_t id_); + virtual ~OpPlaceholder(); + + virtual int checkTensorAttributes(); + virtual int eval(); +}; + +template +class OpIdentity : public GraphNode +{ +public: + OpIdentity(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpIdentity(); + + virtual int checkTensorAttributes(); + virtual int eval(); + + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using TIn = Eigen::Tensor; + using TOut = Eigen::Tensor; + +protected: + TosaReference::TensorTemplate* in; + TosaReference::TensorTemplate* out; +}; + +template +class OpIdentityN : public GraphNode +{ +public: + OpIdentityN(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpIdentityN(); + + virtual int checkTensorAttributes(); + virtual int eval(); + + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using TIn = Eigen::Tensor; + using TOut = Eigen::Tensor; + +protected: + std::vector*> ins; + std::vector*> outs; +}; + +}; // namespace TosaReference + +#endif diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc new file mode 100644 index 0000000..4d4f8b9 --- /dev/null +++ b/reference_model/src/ops/ewise_binary.cc @@ -0,0 +1,586 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
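Editor's sketch: the ewise_binary.cc that follows broadcasts a size-1 (or rank-0) operand against the output shape by computing a per-dimension replication factor. A minimal sketch of that computation:

    #include <cassert>
    #include <vector>

    // Per-dimension replication factors, as BinaryNodeBase::broadcast() below
    // computes them: a size-1 dimension is replicated up to the output size,
    // anything else is left alone (factor 1).
    std::vector<int> bcast_factors(const std::vector<int>& in_shape,
                                   const std::vector<int>& out_shape)
    {
        std::vector<int> bcast(in_shape.size());
        for (size_t i = 0; i < in_shape.size(); i++)
            bcast[i] = (in_shape[i] != out_shape[i] && in_shape[i] == 1) ? out_shape[i] : 1;
        return bcast;
    }

    int main()
    {
        // {1, 3} broadcast against an output of {4, 3} replicates dim 0 by 4.
        std::vector<int> f = bcast_factors({ 1, 3 }, { 4, 3 });
        assert(f[0] == 4 && f[1] == 1);
        return 0;
    }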
+ +#include "ewise_binary.h" +#include "arith_util.h" +#include "quant_util.h" +#include "template_types.h" + +using namespace TosaReference; +using namespace Eigen; +using namespace tosa; + +template +BinaryNodeBase::BinaryNodeBase(const Op& op_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(op_, id_) +{ + setRequiredOperands(2, 1); + setRequiredRank(0, 6); + + a_rank = b_rank = max_input_rank = -1; + a = b = nullptr; + a_rank0 = b_rank0 = nullptr; + result = nullptr; + + fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return OutEigenType(); }; +} + +template +BinaryNodeBase::~BinaryNodeBase() +{} + +template +int BinaryNodeBase::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + a_rank = inputs[0]->getRank(); + b_rank = inputs[1]->getRank(); + if (a_rank != 0 && b_rank != 0 && a_rank != b_rank) + { + printNodeValidationError("Binary operator input ranks must match"); + return 1; + } + + max_input_rank = a_rank >= b_rank ? a_rank : b_rank; + + // A & B must be the same types + if (inputs[0]->matchType(*inputs[1])) + { + printNodeValidationError("Binary operator input types must match"); + return 1; + } + + // Result's geometry must match, but the type may be wider + if (outputs[0]->getRank() != max_input_rank) + { + printNodeValidationError("Binary operator input and output genometry must match"); + return 1; + } + + if (a_rank == max_input_rank) + { + a = dynamic_cast*>(inputs[0]); + } + else + { + a_rank0 = dynamic_cast>*>(inputs[0]); + } + + if (b_rank == max_input_rank) + { + b = dynamic_cast*>(inputs[1]); + } + else + { + b_rank0 = dynamic_cast>*>(inputs[1]); + } + + result = dynamic_cast*>(outputs[0]); + + // either a or b can be rank0 + // a_rank0 and b_rank0 can't be valid at the same time. 
+    // if a and b are both rank0, they should be evaluated as 'a' and 'b', instead of 'a_rank0' and 'b_rank0'
+    ASSERT_MEM((a || a_rank0) && (b || b_rank0) && !(a_rank0 && b_rank0) && result);
+
+    return 0;
+}
+
+template <int Rank, DType InDtype, DType OutDtype>
+int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast()
+{
+    auto output_shape = result->getTensor().dimensions();
+
+    std::vector<int> a_shape, b_shape;
+
+    if (a_rank == max_input_rank)
+    {
+        a_shape = a->getShape();
+    }
+    else
+    {
+        a_shape.assign(max_input_rank, 1);
+    }
+
+    if (b_rank == max_input_rank)
+    {
+        b_shape = b->getShape();
+    }
+    else
+    {
+        b_shape.assign(max_input_rank, 1);
+    }
+
+    for (int i = 0; i < max_input_rank; i++)
+    {
+        if (a_shape[i] != output_shape[i] && a_shape[i] == 1)
+        {
+            bcast_a[i] = output_shape[i];
+        }
+        else
+        {
+            bcast_a[i] = 1;
+        }
+        if (b_shape[i] != output_shape[i] && b_shape[i] == 1)
+        {
+            bcast_b[i] = output_shape[i];
+        }
+        else
+        {
+            bcast_b[i] = 1;
+        }
+    }
+
+    return 0;
+}
+
+template <int Rank, DType InDtype, DType OutDtype>
+int BinaryNode<Rank, InDtype, OutDtype>::eval()
+{
+    this->broadcast();
+
+    Eigen::array<int, Rank> reshaper;
+    reshaper.fill(1);
+    TIn ia, ib;
+
+    if (this->a_rank == this->max_input_rank)
+    {
+        ia = this->a->getTensor().broadcast(this->bcast_a);
+    }
+    else
+    {
+        ia = this->a_rank0->getTensor().reshape(reshaper).broadcast(this->bcast_a);
+    }
+
+    if (this->b_rank == this->max_input_rank)
+    {
+        ib = this->b->getTensor().broadcast(this->bcast_b);
+    }
+    else
+    {
+        ib = this->b_rank0->getTensor().reshape(reshaper).broadcast(this->bcast_b);
+    }
+
+    this->result->getTensor() = ia.binaryExpr(ib, this->fcn);
+
+    return GraphNode::eval();
+}
+
+// still need to partial specialize this, or Eigen will throw static assertion
+template <DType InDtype, DType OutDtype>
+int BinaryNode<0, InDtype, OutDtype>::eval()
+{
+    this->result->getTensor() = this->a->getTensor().binaryExpr(this->b->getTensor(), this->fcn);
+
+    return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+int OpAdd<Rank, Dtype>::register_fcn()
+{
+    switch (InDtype)
+    {
+        case DType_FLOAT:
+        case DType_INT32:
+            this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; };
+            break;
+        default:
+            FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]);
+    }
+
+    return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpArithmeticRightShift<Rank, Dtype>::register_fcn()
+{
+    int32_t num_bits = 0;
+    switch (Dtype)
+    {
+        case DType_INT8:
+            num_bits = 8;
+            break;
+        case DType_INT16:
+            num_bits = 16;
+            break;
+        case DType_INT32:
+            num_bits = 32;
+            break;
+        default:
+            FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+    }
+
+    this->fcn = [num_bits](InEigenType a, InEigenType b) -> OutEigenType {
+        uint32_t sign = a & (1 << (num_bits - 1));
+        uint32_t ones_mask = ONES_MASK(b) << (num_bits - b);
+        if (sign)
+            return ones_mask | (a >> b);
+        else
+            return (~ones_mask) & (a >> b);
+    };
+
+    return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpBitwiseAnd<Rank, Dtype>::register_fcn()
+{
+    switch (Dtype)
+    {
+        case DType_AINT8:
+        case DType_INT16:
+        case DType_INT32:
+            this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a & b; };
+            break;
+        default:
+            FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+    }
+
+    return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpBitwiseOr<Rank, Dtype>::register_fcn()
+{
+    switch (Dtype)
+    {
+        case DType_AINT8:
+        case DType_INT16:
+        case DType_INT32:
+            this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a | b; };
+            break;
+        default:
+            FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+    }
+
+    return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpBitwiseXor<Rank, Dtype>::register_fcn()
+{
+    switch (Dtype)
+    {
+        case DType_AINT8:
+        case DType_INT16:
+        case DType_INT32:
+            this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return
a ^ b; }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template +int OpLogicalAnd::register_fcn() +{ + switch (Dtype) + { + case DType_BOOL: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a && b; }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template +int OpLogicalLeftShift::register_fcn() +{ + switch (Dtype) + { + case DType_INT8: + case DType_INT16: + case DType_INT32: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a << b; }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template +int OpLogicalRightShift::register_fcn() +{ + int32_t num_bits = 0; + switch (Dtype) + { + case DType_INT8: + num_bits = 8; + break; + case DType_INT16: + num_bits = 16; + break; + case DType_INT32: + num_bits = 32; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + this->fcn = [num_bits](InEigenType a, InEigenType b) -> OutEigenType { + uint32_t mask = ONES_MASK(num_bits) >> b; + return (a >> b) & mask; + }; + + return 0; +} + +template +int OpLogicalOr::register_fcn() +{ + switch (Dtype) + { + case DType_BOOL: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a || b; }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template +int OpLogicalXor::register_fcn() +{ + switch (Dtype) + { + case DType_BOOL: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template +int OpMaximum::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + case DType_INT32: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template +int OpMinimum::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + case DType_INT32: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? 
a : b; }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template +int OpMul::register_fcn() +{ + switch (InDtype) + { + case DType_FLOAT: + case DType_INT32: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a * b; }; + break; + case DType_INT8: + case DType_INT16: + this->fcn = [this](InEigenType lhs, InEigenType rhs) -> OutEigenType { + OutEigenType raw_output = (OutEigenType)lhs * (OutEigenType)rhs; + + OutEigenType clamped_output = std::min(QMax, std::max(raw_output, QMin)); + + return clamped_output; + }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]); + } + + return 0; +} + +template +int OpPow::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return powf(a, b); }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template +int OpSub::register_fcn() +{ + switch (InDtype) + { + case DType_FLOAT: + case DType_INT32: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a - b; }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]); + } + + return 0; +} + +template +OpTable::OpTable(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_TABLE, id_) +{ + setRequiredOperands(2, 1); + setRequiredRank(0, 6); +} + +template +OpTable::~OpTable() +{} + +template +int OpTable::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + if (inputs[1]->getRank() != 1 || inputs[1]->getElementCount() != 513 || inputs[1]->getDtype() != DType_INT16) + { + FATAL_ERROR_NODE("OpTable: must have INT16 table with 513 entries"); + return 1; + } + + in = dynamic_cast*>(inputs[0]); + table = dynamic_cast*>(inputs[1]); + out = dynamic_cast*>(outputs[0]); + + ASSERT_MEM(in && table && out); + + return 0; +} + +template +int OpTable::eval() +{ + this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType { + // 1. make sure input is int16 range + int32_t input_truncated = std::min(std::max(in, QInMin), QInMax); + + // 2. calculate index and interpolation fraction + int32_t index = (input_truncated >> 7) + (1 << (IntegerBits - 1)); + index = std::min(std::max(index, 0), NumTableEntries - 1); // 9-bit index + int32_t frac = (input_truncated)&0x7F; // 7-bit fraction + + // 3. 
interpolate, generate 16.7 (23-bit) output + int32_t base = this->table->getTensor()(index); + int32_t next = this->table->getTensor()(index + 1); + int32_t value = (base << 7) + (next - base) * frac; + + return value; + }); + + return GraphNode::eval(); +} + +// template explicit instantiation +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, AINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, AINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, AINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalAnd, BOOL); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FLOAT, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32); + +DEF_INSTANTIATE_ONE_RANK_0_6(OpTable); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FLOAT, BOOL); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL); diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h new file mode 100644 index 0000000..00fb3d9 --- /dev/null +++ b/reference_model/src/ops/ewise_binary.h @@ -0,0 +1,195 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
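+// [Editor's note] A standalone sketch of the OpTable::eval() arithmetic above,
+// outside the class machinery: 9-bit biased index, 7-bit interpolation
+// fraction, 16.7 fixed-point output. Illustrative only, not part of this
+// change; the function and table below are made up for demonstration.
+//
+// #include <algorithm>
+// #include <cstdint>
+// #include <cstdio>
+//
+// static int32_t table_lookup(int32_t in, const int16_t table[513])
+// {
+//     in = std::min(std::max(in, -32768), 32767);     // clamp to int16 range
+//     int32_t index = (in >> 7) + 256;                // bias by 1 << (9 - 1)
+//     index = std::min(std::max(index, 0), 511);      // 9-bit index
+//     int32_t frac = in & 0x7F;                       // 7-bit fraction
+//     int32_t base = table[index];
+//     int32_t next = table[index + 1];
+//     return (base << 7) + (next - base) * frac;      // 16.7 fixed point
+// }
+//
+// int main()
+// {
+//     int16_t table[513];
+//     for (int i = 0; i < 513; i++)
+//         table[i] = (int16_t)(i - 256);              // identity ramp
+//     std::printf("%d\n", table_lookup(1000, table)); // prints 1000
+//     return 0;
+// }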
+
+#ifndef OPS_EWISE_BINARY_H
+#define OPS_EWISE_BINARY_H
+
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+// class BinaryNodeBase: holds functions common to all binary nodes.
+//     When a binary op is created, the virtual OpXXX::register_fcn() is called
+//     and 'fcn' is registered with a lambda taking the two input elements.
+// class BinaryNode: a level of indirection so the template can be partially
+//     specialized for rank 0. eval(), invoked from the top level, applies
+//     .binaryExpr(dims, fcn) here. This must be partially specialized, or the
+//     compiler fails statically when asked to broadcast a rank-0 tensor.
+// class OpXXX: implements the per-element lambda for each data type.
+//     Unlike BinaryNode, this doesn't need to be partially specialized.
+
+// Eigen::Tensor supports some binary element-wise ops natively (e.g. cwiseMax(), '+', etc.),
+// which might be faster since they can be implemented with SIMD instructions.
+// Registering a lambda and calling .binaryExpr() may sacrifice some performance here,
+// but it avoids partial specialization for every combination of
+// {rankN, rank0} x {FLOAT/INT32, QU8, ...}.
+// Revisit if performance becomes a bottleneck here.
+template <int Rank, DType InDtype, DType OutDtype>
+class BinaryNodeBase : public GraphNode
+{
+public:
+    BinaryNodeBase(const Op& nodeType, TosaQuantInfoBase* qinfo_, const uint64_t id_);
+    virtual ~BinaryNodeBase();
+
+    virtual int checkTensorAttributes() final;
+    virtual int eval() = 0;
+    virtual int register_fcn() = 0;
+
+    using InEigenType = typename GetEigenType<InDtype>::type;
+    using OutEigenType = typename GetEigenType<OutDtype>::type;
+    using TIn = Eigen::Tensor<InEigenType, Rank>;
+    using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+    int broadcast();
+
+protected:
+    std::function<OutEigenType(InEigenType, InEigenType)> fcn;
+    Eigen::array<int, Rank> bcast_a;
+    Eigen::array<int, Rank> bcast_b;
+    TosaReference::TensorTemplate<TIn>* a;
+    TosaReference::TensorTemplate<TIn>* b;
+    TosaReference::TensorTemplate<Eigen::Tensor<InEigenType, 0>>* a_rank0;
+    TosaReference::TensorTemplate<Eigen::Tensor<InEigenType, 0>>* b_rank0;
+    TosaReference::TensorTemplate<TOut>* result;
+    int a_rank;
+    int b_rank;
+    int max_input_rank;
+};
+
+// primary class
+template <int Rank, DType InDtype, DType OutDtype>
+class BinaryNode : public BinaryNodeBase<Rank, InDtype, OutDtype>
+{
+public:
+    BinaryNode(const Op& op_, TosaQuantInfoBase* qinfo_, const uint64_t id_)
+        : BinaryNodeBase<Rank, InDtype, OutDtype>(op_, qinfo_, id_)
+    {}
+    virtual ~BinaryNode()
+    {}
+
+    virtual int eval();
+
+    using InEigenType = typename GetEigenType<InDtype>::type;
+    using OutEigenType = typename GetEigenType<OutDtype>::type;
+    using TIn = Eigen::Tensor<InEigenType, Rank>;
+    using TOut = Eigen::Tensor<OutEigenType, Rank>;
+};
+
+// partial specialization for rank 0
+template <DType InDtype, DType OutDtype>
+class BinaryNode<0, InDtype, OutDtype> : public BinaryNodeBase<0, InDtype, OutDtype>
+{
+public:
+    BinaryNode(const Op& op_, TosaQuantInfoBase* qinfo_, const uint64_t id_)
+        : BinaryNodeBase<0, InDtype, OutDtype>(op_, qinfo_, id_)
+    {}
+    virtual ~BinaryNode()
+    {}
+
+    virtual int eval();
+};
+
+#define DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Opname, OPNAME)                        \
+    template <int Rank, DType Dtype>                                           \
+    class Op##Opname : public BinaryNode<Rank, Dtype, Dtype>                   \
+    {                                                                          \
+    public:                                                                    \
+        Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_,
uint64_t id_) \ + : BinaryNode(Op_##OPNAME, qinfo_, id_) \ + { \ + register_fcn(); \ + } \ + static constexpr int32_t QMin = GetQMin::value; \ + static constexpr int32_t QMax = GetQMax::value; \ + using InEigenType = typename GetEigenType::type; \ + using OutEigenType = typename GetEigenType::type; \ + virtual int register_fcn(); \ + }; + +DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Add, ADD) +DEF_TEMPLATE_BINARY_OP_ONE_TYPE(ArithmeticRightShift, ARITHMETIC_RIGHT_SHIFT) +DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseAnd, BITWISE_AND) +DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseOr, BITWISE_OR) +DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseXor, BITWISE_XOR) +DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalAnd, LOGICAL_AND) +DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalLeftShift, LOGICAL_LEFT_SHIFT) +DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalRightShift, LOGICAL_RIGHT_SHIFT) +DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalOr, LOGICAL_OR) +DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalXor, LOGICAL_XOR) +DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Maximum, MAXIMUM) +DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Minimum, MINIMUM) +DEF_TEMPLATE_BINARY_OP_TWO_TYPE(Mul, MUL) +DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Pow, POW) +DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Sub, SUB) + +#undef DEF_TEMPLATE_BINARY_OP_ONE_TYPE +#undef DEF_TEMPLATE_BINARY_OP_TWO_TYPE + +template +class OpTable : public GraphNode +{ +public: + OpTable(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpTable(); + + virtual int checkTensorAttributes(); + virtual int eval(); + + static constexpr DType InDtype = DType_INT16; + static constexpr DType TableDtype = DType_INT16; + static constexpr DType OutDtype = DType_INT32; + using InEigenType = typename GetEigenType::type; + using TableEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using TIn = Eigen::Tensor; + using TTable = Eigen::Tensor; + using TOut = Eigen::Tensor; + static constexpr int32_t IntegerBits = 9; + static constexpr int32_t FractionBits = 7; + static constexpr int32_t NumTableEntries = (1 << IntegerBits); + static constexpr int32_t QInMin = GetQMin::value; + static constexpr int32_t QInMax = GetQMax::value; + static constexpr int32_t QOutMin = GetQMin::value; + static constexpr int32_t QOutMax = GetQMax::value; + +protected: + TosaReference::TensorTemplate* in; + TosaReference::TensorTemplate* table; + TosaReference::TensorTemplate* out; +}; + +}; // namespace TosaReference + +#endif diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc new file mode 100644 index 0000000..eded0d7 --- /dev/null +++ b/reference_model/src/ops/ewise_ternary.cc @@ -0,0 +1,115 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
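+// [Editor's note] A standalone sketch of the broadcast-factor rule implemented
+// by BinaryNodeBase::broadcast() above: an axis of size 1 is replicated up to
+// the output size, every other axis keeps a factor of 1. Illustrative only,
+// not part of this change; the helper name is made up.
+//
+// #include <cstdio>
+// #include <vector>
+//
+// static std::vector<int> bcast_factors(const std::vector<int>& in_shape,
+//                                       const std::vector<int>& out_shape)
+// {
+//     std::vector<int> bcast(in_shape.size(), 1);
+//     for (size_t i = 0; i < in_shape.size(); i++)
+//     {
+//         if (in_shape[i] == 1 && in_shape[i] != out_shape[i])
+//             bcast[i] = out_shape[i];
+//     }
+//     return bcast;
+// }
+//
+// int main()
+// {
+//     std::vector<int> f = bcast_factors({ 1, 3 }, { 4, 3 });
+//     std::printf("%d %d\n", f[0], f[1]);    // prints "4 1"
+//     return 0;
+// }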
+ +#include "ewise_ternary.h" + +using namespace TosaReference; +using namespace Eigen; +using namespace tosa; + +template +OpSelectBase::OpSelectBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_SELECT, id_) +{ + setRequiredOperands(3, 1); + setRequiredRank(0, 6); +} + +template +OpSelectBase::~OpSelectBase() +{} + +template +int OpSelectBase::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(inputs[2]) || + validateRequiredRank(outputs[0])) + { + return 1; + } + + // output and input must be the same types + if (inputs[0]->matchRank(*outputs[0]) || inputs[1]->matchRankType(*outputs[0]) || + inputs[2]->matchRankType(*outputs[0])) + { + printNodeValidationError("Failure to match input and output rank and type"); + return 1; + } + + cond = dynamic_cast*>(inputs[0]); + then_val = dynamic_cast*>(inputs[1]); + else_val = dynamic_cast*>(inputs[2]); + out = dynamic_cast*>(outputs[0]); + + return 0; +} + +template +int OpSelectBase::eval() +{ + FATAL_ERROR_NODE("shouldn't be called"); +} + +template +int OpSelect::broadcast() +{ + std::vector cond_shape = this->cond->getShape(); + std::vector then_shape = this->then_val->getShape(); + std::vector else_shape = this->else_val->getShape(); + std::vector out_shape = this->out->getShape(); + + for (int i = 0; i < Rank; i++) + { + this->bcast_cond[i] = (cond_shape[i] == 1) ? std::max(then_shape[i], else_shape[i]) : 1; + this->bcast_then[i] = (then_shape[i] == 1) ? std::max(cond_shape[i], else_shape[i]) : 1; + this->bcast_else[i] = (else_shape[i] == 1) ? std::max(then_shape[i], cond_shape[i]) : 1; + ASSERT_MSG_NODE((this->bcast_cond[i] * cond_shape[i]) == out_shape[i], "SELECT broadcast invariant failed"); + ASSERT_MSG_NODE((this->bcast_then[i] * then_shape[i]) == out_shape[i], "SELECT broadcast invariant failed"); + ASSERT_MSG_NODE((this->bcast_else[i] * else_shape[i]) == out_shape[i], "SELECT broadcast invariant failed"); + } + + return 0; +} + +template +int OpSelect::eval() +{ + this->broadcast(); + this->out->getTensor() = this->cond->getTensor() + .broadcast(this->bcast_cond) + .select(this->then_val->getTensor().broadcast(this->bcast_then), + this->else_val->getTensor().broadcast(this->bcast_else)); + + return GraphNode::eval(); +} + +template +int OpSelect<0, Dtype>::eval() +{ + this->out->getTensor() = this->cond->getTensor().select(this->then_val->getTensor(), this->else_val->getTensor()); + + return GraphNode::eval(); +} + +// template explicit instantiation +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, AINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BOOL); diff --git a/reference_model/src/ops/ewise_ternary.h b/reference_model/src/ops/ewise_ternary.h new file mode 100644 index 0000000..b354247 --- /dev/null +++ b/reference_model/src/ops/ewise_ternary.h @@ -0,0 +1,83 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OPS_TERNARY_H +#define OPS_TERNARY_H + +#include "graph_node.h" + +using namespace tosa; + +namespace TosaReference +{ + +// The Ternary Select op has the following operands: +// 1. Cond: rank N, type=bool +// 2. Then_val: Rank N, type= +// 3. Else_val: Rank N, type= +// 4. Result: Rank N, type= +// Cond, Then_val, Else_val need to be mutually-broadcastable +template +class OpSelectBase : public GraphNode +{ +public: + OpSelectBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpSelectBase(); + + virtual int checkTensorAttributes(); + virtual int eval(); + + using CondEigenType = typename GetEigenType::type; + using InEigenType = typename GetEigenType::type; + using TCond = Eigen::Tensor; + using TIn = Eigen::Tensor; + +protected: + TosaReference::TensorTemplate* cond; + Eigen::array bcast_cond; + Eigen::array bcast_then; + Eigen::array bcast_else; + TosaReference::TensorTemplate* then_val; + TosaReference::TensorTemplate* else_val; + TosaReference::TensorTemplate* out; +}; + +// primary class +template +class OpSelect : public OpSelectBase +{ +public: + OpSelect(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : OpSelectBase(attribute_, qinfo_, id_) + {} + virtual int eval(); + int broadcast(); + + using InEigenType = typename OpSelectBase::InEigenType; +}; + +// partial specialization for rank 0 +template +class OpSelect<0, Dtype> : public OpSelectBase<0, Dtype> +{ +public: + OpSelect(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : OpSelectBase<0, Dtype>(attribute_, qinfo_, id_) + {} + virtual int eval(); +}; +}; // namespace TosaReference + +#endif diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc new file mode 100644 index 0000000..d7bddc0 --- /dev/null +++ b/reference_model/src/ops/ewise_unary.cc @@ -0,0 +1,302 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
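+// [Editor's note] A standalone sketch of what OpSelect::eval() above reduces
+// to once the three operands share a shape: Eigen's select() picks elementwise
+// from then/else based on the boolean condition tensor. Illustrative only, not
+// part of this change; assumes the Eigen unsupported Tensor module.
+//
+// #include <unsupported/Eigen/CXX11/Tensor>
+// #include <cstdio>
+//
+// int main()
+// {
+//     Eigen::Tensor<bool, 1> cond(3);
+//     Eigen::Tensor<float, 1> then_val(3), else_val(3);
+//     cond.setValues({ true, false, true });
+//     then_val.setValues({ 1.0f, 2.0f, 3.0f });
+//     else_val.setValues({ 10.0f, 20.0f, 30.0f });
+//
+//     Eigen::Tensor<float, 1> out = cond.select(then_val, else_val);
+//     std::printf("%g %g %g\n", out(0), out(1), out(2));    // prints "1 20 3"
+//     return 0;
+// }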
+ +#include "ewise_unary.h" +#include "quant_util.h" +#include "template_types.h" +#include + +using namespace TosaReference; +using namespace Eigen; +using namespace tosa; + +template +UnaryNode::UnaryNode(const Op& op_, uint64_t id_) + : GraphNode(op_, id_) +{ + setRequiredOperands(1, 1); + setRequiredRank(0, 6); + + fcn = [](InEigenType a) -> OutEigenType { return OutEigenType(); }; +} + +template +UnaryNode::~UnaryNode() +{} + +template +int UnaryNode::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + // output and input must be the same types + if (inputs[0]->matchRankSize(*outputs[0])) + { + printNodeValidationError("UnaryNode: input and output rank must match"); + return 1; + } + + a = dynamic_cast*>(inputs[0]); + result = dynamic_cast*>(outputs[0]); + + ASSERT_MEM(a && result); + + return 0; +} + +template +int UnaryNode::eval() +{ + this->result->getTensor() = this->a->getTensor().unaryExpr(this->fcn); + + return GraphNode::eval(); +} + +template +int OpAbs::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + case DType_INT32: + this->fcn = [](InEigenType a) -> OutEigenType { return a > (InEigenType)0 ? a : (-a); }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template +int OpBitwiseNot::register_fcn() +{ + switch (Dtype) + { + case DType_AINT8: + case DType_INT16: + case DType_INT32: + this->fcn = [](InEigenType a) -> OutEigenType { return ~a; }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template +int OpCeil::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + this->fcn = [](InEigenType a) -> OutEigenType { return ceilf(a); }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template +int OpClz::register_fcn() +{ + int32_t num_bits; + switch (Dtype) + { + case DType_INT32: + num_bits = 32; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + this->fcn = [num_bits](int32_t a) -> int32_t { + int32_t leading_zeros = 0; + for (int bit = num_bits - 1; bit >= 0; bit--) + { + if (((a >> bit) & 0x1) == 0) + { + leading_zeros++; + } + else + { + break; + } + } + return leading_zeros; + }; + + return 0; +} + +template +int OpExp::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + this->fcn = [](InEigenType a) -> OutEigenType { return expf(a); }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template +int OpFloor::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + this->fcn = [](InEigenType a) -> OutEigenType { return floorf(a); }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template +int OpLog::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + this->fcn = [](InEigenType a) -> OutEigenType { return logf(a); }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template +int OpLogicalNot::register_fcn() +{ + switch (Dtype) + { + case DType_BOOL: + this->fcn = [](InEigenType a) -> OutEigenType { return !a; }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template +int OpNegate::register_fcn() +{ + switch (Dtype) + { + 
case DType_FLOAT: + this->fcn = [](InEigenType a) -> OutEigenType { + InEigenType result = -(a); + return result; + }; + break; + case DType_INT16: + case DType_INT32: + this->fcn = [](InEigenType a) -> OutEigenType { + InEigenType result = -(a); + return result; + }; + break; + case DType_AINT8: + ASSERT(this->qinfo); + this->fcn = [this](InEigenType a) -> OutEigenType { + InEigenType result = -(a - this->qinfo->input_zp()) + this->qinfo->output_zp(); + return result; + }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template +int OpReciprocal::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / a; }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template +int OpRsqrt::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / sqrtf(a); }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +// template explicit instantiation +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, AINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FLOAT); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FLOAT); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FLOAT); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FLOAT); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, AINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FLOAT); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FLOAT); diff --git a/reference_model/src/ops/ewise_unary.h b/reference_model/src/ops/ewise_unary.h new file mode 100644 index 0000000..0db3cfb --- /dev/null +++ b/reference_model/src/ops/ewise_unary.h @@ -0,0 +1,102 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
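+// [Editor's note] A standalone sketch of the AINT8 NEGATE arithmetic
+// registered above: shift into zero-point-relative space, negate, shift back
+// out. Illustrative only, not part of this change; the zero-point values are
+// made-up examples.
+//
+// #include <cstdint>
+// #include <cstdio>
+//
+// static int32_t negate_quantized(int32_t a, int32_t input_zp, int32_t output_zp)
+// {
+//     return -(a - input_zp) + output_zp;
+// }
+//
+// int main()
+// {
+//     // With input_zp = 10, the stored value 10 represents real 0; negating
+//     // real 0 must land exactly on the output zero point.
+//     std::printf("%d\n", negate_quantized(10, 10, -5));    // prints -5
+//     return 0;
+// }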
+ +#ifndef OPS_EWISE_UNARY_H +#define OPS_EWISE_UNARY_H + +#include "graph_node.h" + +using namespace tosa; + +namespace TosaReference +{ +template +class UnaryNode : public GraphNode +{ +public: + UnaryNode(const Op& nodeType, const uint64_t id_); + virtual ~UnaryNode(); + + virtual int checkTensorAttributes() final; + virtual int eval() final; + virtual int register_fcn() = 0; + + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using TIn = Eigen::Tensor; + using TOut = Eigen::Tensor; + +protected: + std::function fcn; + TosaReference::TensorTemplate* a; + TosaReference::TensorTemplate* result; +}; + +#define DEF_TEMPLATE_UNARY_OP(Opname, OPNAME) \ + template \ + class Op##Opname : public UnaryNode \ + { \ + public: \ + Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ + : UnaryNode(Op_##OPNAME, id_) \ + { \ + register_fcn(); \ + } \ + static constexpr int32_t QMin = GetQMin::value; \ + static constexpr int32_t QMax = GetQMax::value; \ + using InEigenType = typename GetEigenType::type; \ + using OutEigenType = typename GetEigenType::type; \ + virtual int register_fcn(); \ + }; + +#define DEF_TEMPLATE_UNARY_OP_WITH_QUANT_INFO(Opname, OPNAME) \ + template \ + class Op##Opname : public UnaryNode \ + { \ + public: \ + Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ + : UnaryNode(Op_##OPNAME, id_) \ + { \ + INIT_QINFO(Unary); \ + register_fcn(); \ + } \ + static constexpr int32_t QMin = GetQMin::value; \ + static constexpr int32_t QMax = GetQMax::value; \ + using InEigenType = typename GetEigenType::type; \ + using OutEigenType = typename GetEigenType::type; \ + virtual int register_fcn(); \ + \ + protected: \ + TosaUnaryQuantInfo* qinfo; \ + }; + +DEF_TEMPLATE_UNARY_OP(Abs, ABS) +DEF_TEMPLATE_UNARY_OP(BitwiseNot, BITWISE_NOT) +DEF_TEMPLATE_UNARY_OP(Ceil, CEIL) +DEF_TEMPLATE_UNARY_OP(Clz, CLZ) +DEF_TEMPLATE_UNARY_OP(Exp, EXP) +DEF_TEMPLATE_UNARY_OP(Floor, FLOOR) +DEF_TEMPLATE_UNARY_OP(Log, LOG) +DEF_TEMPLATE_UNARY_OP(LogicalNot, LOGICAL_NOT) +DEF_TEMPLATE_UNARY_OP_WITH_QUANT_INFO(Negate, NEGATE) +DEF_TEMPLATE_UNARY_OP(Reciprocal, RECIPROCAL) +DEF_TEMPLATE_UNARY_OP(Rsqrt, RSQRT) + +#undef DEF_TEMPLATE_UNARY_OP +#undef DEF_TEMPLATE_UNARY_OP_WITH_QUANT_INFO + +}; // namespace TosaReference + +#endif diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc new file mode 100644 index 0000000..d3352ce --- /dev/null +++ b/reference_model/src/ops/image.cc @@ -0,0 +1,169 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
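+// [Editor's note] For readability, roughly what DEF_TEMPLATE_UNARY_OP(Floor,
+// FLOOR) above expands to. The template parameters are reconstructed from the
+// corresponding definitions in ewise_unary.cc; this is an illustration, not a
+// literal preprocessor dump.
+//
+//     template <int Rank, DType Dtype>
+//     class OpFloor : public UnaryNode<Rank, Dtype, Dtype>
+//     {
+//     public:
+//         OpFloor(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+//             : UnaryNode<Rank, Dtype, Dtype>(Op_FLOOR, id_)
+//         {
+//             register_fcn();
+//         }
+//         static constexpr int32_t QMin = GetQMin<Dtype>::value;
+//         static constexpr int32_t QMax = GetQMax<Dtype>::value;
+//         using InEigenType = typename GetEigenType<Dtype>::type;
+//         using OutEigenType = typename GetEigenType<Dtype>::type;
+//         virtual int register_fcn();
+//     };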
+ +#include "image.h" +#include "arith_util.h" +#include "quant_util.h" + +using namespace TosaReference; +using namespace Eigen; +using namespace tosa; + +template +OpResize::OpResize(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_RESIZE, id_) +{ + setRequiredOperands(1, 1); + setRequiredRank(4, 4); + + INIT_ATTRIBUTE(Resize); +} + +template +OpResize::~OpResize() +{ + if (attribute) + delete attribute; +} + +template +int OpResize::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) + return 1; + + output_size = this->attribute->output_size(); + stride = this->attribute->stride(); + offset = this->attribute->offset(); + shift = this->attribute->shift(); + mode = this->attribute->mode(); + + int output_height = outputs[0]->getShape()[1]; + int output_width = outputs[0]->getShape()[2]; + + if (this->mode == ResizeMode_BILINEAR) + { + if (OutDtype != DType_INT32 && OutDtype != DType_INT48) + { + printNodeValidationError("OpResize: invalid data type for BILINEAR"); + return 1; + } + } + else + { + if (OutDtype != DType_INT8 && OutDtype != DType_INT16) + { + printNodeValidationError("OpResize: invalid data type for NEAREST"); + return 1; + } + } + + if (output_size[0] != output_height || output_size[1] != output_width) + { + printNodeValidationError("OpResize: attribute output_size doesn't match output [height, width]"); + return 1; + } + + if (shift < 1 || shift > 11) + { + printNodeValidationError("OpResize: attribute shift should be within [1, 11]"); + return 1; + } + + if (stride[0] <= 0 || stride[1] <= 0) + { + printNodeValidationError("OpResize: invalid attribute stride"); + return 1; + } + + in = dynamic_cast*>(inputs[0]); + out = dynamic_cast*>(outputs[0]); + + ASSERT_MEM(in && out); + + return 0; +} + +template +int OpResize::eval() +{ + int in_batch = in->getShape()[0]; + int in_height = in->getShape()[1]; + int in_width = in->getShape()[2]; + int in_channels = in->getShape()[3]; + + int out_batch = out->getShape()[0]; + int out_height = out->getShape()[1]; + int out_width = out->getShape()[2]; + int out_channels = out->getShape()[3]; + + ASSERT_MSG_NODE(in_batch == out_batch, "OpResize: output tensor batch mismatch"); + ASSERT_MSG_NODE(in_channels == out_channels, "OpResize: output tensor channel mismatch"); + + for (int b = 0; b < out_batch; b++) + for (int c = 0; c < out_channels; c++) + for (int oy = 0; oy < out_height; oy++) + for (int ox = 0; ox < out_width; ox++) + { + int y = oy * stride[0] + offset[0]; + int x = ox * stride[1] + offset[1]; + + int iy = y >> shift; + int dy = y - (iy << shift); + int ix = x >> shift; + int dx = x - (ix << shift); + + int iy0 = MAX(iy, 0); + int iy1 = MIN(iy + 1, in_height - 1); + int ix0 = MAX(ix, 0); + int ix1 = MIN(ix + 1, in_width - 1); + + ASSERT_MSG(iy0 <= iy1 && ix0 <= ix1, "OpResize: invalid index (iy0, iy1, ix0, ix1)=(%d,%d,%d,%d)", + iy0, iy1, ix0, ix1); + + InEigenType v00 = in->getTensor()(b, iy0, ix0, c); + InEigenType v01 = in->getTensor()(b, iy0, ix1, c); + InEigenType v10 = in->getTensor()(b, iy1, ix0, c); + InEigenType v11 = in->getTensor()(b, iy1, ix1, c); + + OutEigenType acc; + if (mode == ResizeMode_BILINEAR) + { + acc = (OutEigenType)v00 * ((1 << shift) - dy) * ((1 << shift) - dx); + acc = acc + (OutEigenType)v01 * ((1 << shift) - dy) * dx; + acc = acc + (OutEigenType)v10 * dy * ((1 << shift) - dx); + acc = acc + (OutEigenType)v11 * dy * dx; + } + else + { + iy = (dy >> (shift - 1)) 
!= 0 ? iy1 : iy0; + ix = (dx >> (shift - 1)) != 0 ? ix1 : ix0; + acc = in->getTensor()(b, iy, ix, c); + } + + out->getTensor()(b, oy, ox, c) = acc; + } + + return GraphNode::eval(); +} + +// template explicit instantiation +DEF_INSTANTIATE_TWO_TYPE(OpResize, INT8, INT32); +DEF_INSTANTIATE_TWO_TYPE(OpResize, INT8, INT8); +DEF_INSTANTIATE_TWO_TYPE(OpResize, INT16, INT48); +DEF_INSTANTIATE_TWO_TYPE(OpResize, INT16, INT16); diff --git a/reference_model/src/ops/image.h b/reference_model/src/ops/image.h new file mode 100644 index 0000000..9d15d49 --- /dev/null +++ b/reference_model/src/ops/image.h @@ -0,0 +1,53 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OPS_IMAGE_H +#define OPS_IMAGE_H + +#include "graph_node.h" + +using namespace tosa; + +namespace TosaReference +{ + +template +class OpResize : public GraphNode +{ +public: + OpResize(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpResize(); + virtual int checkTensorAttributes() final; + virtual int eval(); + + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using TIn = Eigen::Tensor; + using TOut = Eigen::Tensor; + +protected: + TosaResizeAttribute* attribute; + std::vector output_size; + std::vector stride; + std::vector offset; + int32_t shift; + ResizeMode mode; + TosaReference::TensorTemplate* in; + TosaReference::TensorTemplate* out; +}; + +}; // namespace TosaReference + +#endif diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc new file mode 100644 index 0000000..bad0c40 --- /dev/null +++ b/reference_model/src/ops/op_factory.cc @@ -0,0 +1,432 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
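+// [Editor's note] A standalone sketch of the fixed-point BILINEAR weighting in
+// OpResize::eval() above: each corner is scaled by the product of the
+// complementary fractional offsets, and the result carries 2 * shift fraction
+// bits. Illustrative only, not part of this change; the sample values are made up.
+//
+// #include <cstdint>
+// #include <cstdio>
+//
+// static int64_t bilinear(int64_t v00, int64_t v01, int64_t v10, int64_t v11,
+//                         int64_t dy, int64_t dx, int32_t shift)
+// {
+//     int64_t unit = 1 << shift;
+//     return v00 * (unit - dy) * (unit - dx) + v01 * (unit - dy) * dx +
+//            v10 * dy * (unit - dx) + v11 * dy * dx;
+// }
+//
+// int main()
+// {
+//     // shift = 8, dy = dx = 128: exactly half way, so the result is the mean
+//     // of the four corners scaled by 2^16: 100 << 16 = 6553600.
+//     std::printf("%lld\n", (long long)bilinear(0, 100, 100, 200, 128, 128, 8));
+//     return 0;
+// }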
+ +#include "op_factory.h" +#include "activation_funcs.h" +#include "comparison.h" +#include "control_flow.h" +#include "custom.h" +#include "data_layout.h" +#include "data_nodes.h" +#include "ewise_binary.h" +#include "ewise_ternary.h" +#include "ewise_unary.h" +#include "image.h" +#include "reduction.h" +#include "scatter_gather.h" +#include "tensor_ops.h" +#include "type_conversion.h" + +using namespace TosaReference; +using namespace tosa; + +GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh, + Op opType, + TosaAttributeBase* attribute, + TosaQuantInfoBase* qinfo, + uint64_t id, + DType inputDType, + int inputRank, + DType outputDType, + int outputRank, + DType weightDType, + int weightRank) +{ + switch (opType) + { + // tensor_ops + case Op_ARGMAX: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FLOAT); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, AINT8); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16); + break; + case Op_AVG_POOL2D: + DEF_FACTORY_ONE_TYPE(OpAvgPool2d, FLOAT); + DEF_FACTORY_ONE_TYPE(OpAvgPool2d, AINT8); + DEF_FACTORY_ONE_TYPE(OpAvgPool2d, INT16); + break; + case Op_CONV2D: + DEF_FACTORY_TWO_TYPE(OpConv2d, FLOAT, FLOAT); + DEF_FACTORY_TWO_TYPE(OpConv2d, AINT8, INT4); + DEF_FACTORY_TWO_TYPE(OpConv2d, AINT8, INT8); + DEF_FACTORY_TWO_TYPE(OpConv2d, AINT8, AINT8); + DEF_FACTORY_TWO_TYPE(OpConv2d, INT16, INT8); + break; + case Op_DEPTHWISE_CONV2D: + DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, FLOAT, FLOAT); + DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, AINT8, INT4); + DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, AINT8, INT8); + DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, AINT8, AINT8); + DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, INT16, INT8); + break; + case Op_FULLY_CONNECTED: + DEF_FACTORY_TWO_TYPE(OpFullyConnected, FLOAT, FLOAT); + DEF_FACTORY_TWO_TYPE(OpFullyConnected, AINT8, INT4); + DEF_FACTORY_TWO_TYPE(OpFullyConnected, AINT8, INT8); + DEF_FACTORY_TWO_TYPE(OpFullyConnected, AINT8, AINT8); + DEF_FACTORY_TWO_TYPE(OpFullyConnected, INT16, INT8); + break; + case Op_MATMUL: + DEF_FACTORY_ONE_TYPE(OpMatMul, FLOAT); + DEF_FACTORY_ONE_TYPE(OpMatMul, AINT8); + DEF_FACTORY_ONE_TYPE(OpMatMul, INT16); + break; + case Op_MAX_POOL2D: + DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FLOAT); + DEF_FACTORY_ONE_TYPE(OpMaxPool2d, AINT8); + DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT16); + break; + case Op_TRANSPOSE_CONV2D: + DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, FLOAT, FLOAT); + DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, AINT8, INT4); + DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, AINT8, INT8); + DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, AINT8, AINT8); + DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, INT16, INT8); + break; + + // activation_funcs + case Op_CLAMP: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, AINT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16); + break; + case Op_RELUN: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReluN, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReluN, INT32); + break; + case Op_SIGMOID: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FLOAT); + break; + case Op_TANH: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FLOAT); + break; + + // ewise_binary + case Op_ADD: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32); + break; + case Op_ARITHMETIC_RIGHT_SHIFT: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT16); + 
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT32);
+            break;
+        case Op_BITWISE_AND:
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, AINT8);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT16);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT32);
+            break;
+        case Op_BITWISE_OR:
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, AINT8);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT16);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT32);
+            break;
+        case Op_BITWISE_XOR:
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, AINT8);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT16);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT32);
+            break;
+        case Op_LOGICAL_AND:
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalAnd, BOOL);
+            break;
+        case Op_LOGICAL_LEFT_SHIFT:
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT8);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT16);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT32);
+            break;
+        case Op_LOGICAL_RIGHT_SHIFT:
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT8);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT16);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT32);
+            break;
+        case Op_LOGICAL_OR:
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL);
+            break;
+        case Op_LOGICAL_XOR:
+            // was OpLogicalOr, a copy-paste error: XOR must dispatch to OpLogicalXor
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL);
+            break;
+        case Op_MAXIMUM:
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FLOAT);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
+            break;
+        case Op_MINIMUM:
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FLOAT);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
+            break;
+        case Op_MUL:
+            DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FLOAT, FLOAT);
+            DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
+            DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
+            DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);
+            break;
+        case Op_POW:
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT);
+            break;
+        case Op_SUB:
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
+            break;
+        case Op_TABLE:
+            DEF_FACTORY_ONE_RANK_0_6(OpTable);
+            break;
+
+        // ewise_unary
+        case Op_ABS:
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FLOAT);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32);
+            break;
+        case Op_BITWISE_NOT:
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, AINT8);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT16);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32);
+            break;
+        case Op_CEIL:
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FLOAT);
+            break;
+        case Op_CLZ:
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32);
+            break;
+        case Op_EXP:
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FLOAT);
+            break;
+        case Op_FLOOR:
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FLOAT);
+            break;
+        case Op_LOG:
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FLOAT);
+            break;
+        case Op_LOGICAL_NOT:
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL);
+            break;
+        case Op_NEGATE:
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FLOAT);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, AINT8);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32);
+            break;
+        case Op_RECIPROCAL:
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FLOAT);
+            break;
+        
case Op_RSQRT: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FLOAT); + break; + + // ewise_ternary + case Op_SELECT: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, AINT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BOOL); + break; + + // comparison + case Op_EQUAL: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32); + break; + case Op_GREATER: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32); + break; + case Op_GREATER_EQUAL: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32); + break; + + // reduction + case Op_REDUCE_ALL: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAll, BOOL); + break; + case Op_REDUCE_ANY: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAny, BOOL); + break; + case Op_REDUCE_MAX: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FLOAT); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, AINT8); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32); + break; + case Op_REDUCE_MIN: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FLOAT); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, AINT8); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32); + break; + case Op_REDUCE_PRODUCT: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FLOAT); + break; + case Op_REDUCE_SUM: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FLOAT); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, INT32); + break; + + // data layout + case Op_CONCAT: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FLOAT); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, AINT8); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL); + break; + case Op_PAD: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FLOAT); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, AINT8); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL); + break; + case Op_RESHAPE: + DEF_FACTORY_RESHAPE(OpReshape, FLOAT); + DEF_FACTORY_RESHAPE(OpReshape, AINT8); + DEF_FACTORY_RESHAPE(OpReshape, INT8); + DEF_FACTORY_RESHAPE(OpReshape, INT16); + DEF_FACTORY_RESHAPE(OpReshape, INT32); + DEF_FACTORY_RESHAPE(OpReshape, BOOL); + break; + case Op_REVERSE: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FLOAT); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, AINT8); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL); + break; + case Op_SLICE: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, AINT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8); + 
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL); + break; + case Op_TILE: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, AINT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BOOL); + break; + case Op_TRANSPOSE: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, AINT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32); + break; + + // scatter_gather + case Op_GATHER: + { + // output.rank = input.rank - 1 + index.rank + int32_t index_rank = outputRank - inputRank + 1; + DEF_FACTORY_GATHER(OpGather, AINT8); + DEF_FACTORY_GATHER(OpGather, INT16); + DEF_FACTORY_GATHER(OpGather, INT32); + } + break; + + // image + case Op_RESIZE: + DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, INT8, INT32); + DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, INT8, INT8); + DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, INT16, INT48); + DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, INT16, INT16); + break; + + // data_nodes + case Op_CONST: + return new OpConst(id); + case Op_PLACEHOLDER: + return new OpPlaceholder(id); + case Op_IDENTITY: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, AINT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL); + break; + case Op_IDENTITYN: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, AINT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, BOOL); + break; + + // type_conversion + case Op_CAST: + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BOOL); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT32); + break; + case Op_RESCALE: + 
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, AINT8); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, AINT8); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, AINT8); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, AINT8); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, AINT8); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, UINT8); + break; + + // custom + case Op_CUSTOM: + return new OpCustom(id); + + // control_flow + case Op_COND_IF: + return new OpCondIf(tsh, attribute, id); + case Op_WHILE_LOOP: + return new OpWhileLoop(tsh, attribute, id); + + // Ops not recognized + default: + goto done; + + } // End of switch(opType) + +done: + return nullptr; +} diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h new file mode 100644 index 0000000..cde6841 --- /dev/null +++ b/reference_model/src/ops/op_factory.h @@ -0,0 +1,294 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
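Editor's note: the long switch above maps a runtime (op, dtype, rank) triple onto a concrete template instantiation. As a standalone sketch of that dispatch pattern, with invented names (MyOp, newMyOp) rather than the reference model's own classes:

    #include <cstdio>

    struct GraphNodeBase
    {
        virtual ~GraphNodeBase() = default;
        virtual int eval() = 0;
    };

    template <int Rank>
    struct MyOp : GraphNodeBase
    {
        int eval() override { std::printf("eval at rank %d\n", Rank); return 0; }
    };

    // Runtime rank -> compile-time template parameter, mirroring the
    // rank-dispatch macros defined in op_factory.h below.
    GraphNodeBase* newMyOp(int inputRank)
    {
        switch (inputRank)
        {
            case 0: return new MyOp<0>;
            case 1: return new MyOp<1>;
            case 2: return new MyOp<2>;
            case 3: return new MyOp<3>;
            case 4: return new MyOp<4>;
            case 5: return new MyOp<5>;
            case 6: return new MyOp<6>;
        }
        return nullptr; // unrecognized rank, as in OpFactory::newOp
    }

The factory macros in op_factory.h generate exactly this kind of switch, with the op's element type carried as an additional template parameter.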
+ +#ifndef OPS_OP_FACTORY_H +#define OPS_OP_FACTORY_H + +#include "attribute.h" +#include "graph_node.h" +#include "quant_info.h" +#include "template_types.h" +#include "tosa_serialization_handler.h" + +#define DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, RANK, DTYPE) \ + case RANK: \ + return new OP(attribute, qinfo, id); + +#define DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, RANK, DTYPE1, DTYPE2) \ + case RANK: \ + return new OP(attribute, qinfo, id); + +#define DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, RANK1, RANK2, DTYPE) \ + case RANK2: \ + return new OP(attribute, qinfo, id); + +#define DEF_FACTORY_TWO_RANK_TWO_TYPE(OP, RANK1, RANK2, DTYPE1, DTYPE2) \ + case RANK2: \ + return new OP(attribute, qinfo, id); + +#define DEF_FACTORY_ONE_RANK_0_6(OP) \ + switch (inputRank) \ + { \ + case 0: \ + return new OP<0>(attribute, qinfo, id); \ + case 1: \ + return new OP<1>(attribute, qinfo, id); \ + case 2: \ + return new OP<2>(attribute, qinfo, id); \ + case 3: \ + return new OP<3>(attribute, qinfo, id); \ + case 4: \ + return new OP<4>(attribute, qinfo, id); \ + case 5: \ + return new OP<5>(attribute, qinfo, id); \ + case 6: \ + return new OP<6>(attribute, qinfo, id); \ + } + +#define DEF_FACTORY_ONE_TYPE(OP, DTYPE) \ + if (inputDType == DType_##DTYPE) \ + { \ + return new OP(attribute, qinfo, id); \ + } + +#define DEF_FACTORY_TWO_TYPE(OP, DTYPE1, DTYPE2) \ + if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2) \ + { \ + return new OP(attribute, qinfo, id); \ + } + +#define DEF_FACTORY_TWO_TYPE_RESIZE(OP, DTYPE1, DTYPE2) \ + if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \ + { \ + return new OP(attribute, qinfo, id); \ + } + +#define DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \ + if (inputDType == DType_##DTYPE) \ + { \ + switch (inputRank) \ + { \ + DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 0, DTYPE) \ + DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 1, DTYPE) \ + DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 2, DTYPE) \ + DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 3, DTYPE) \ + DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 4, DTYPE) \ + DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 5, DTYPE) \ + DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 6, DTYPE) \ + } \ + } + +#define DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \ + if (inputDType == DType_##DTYPE) \ + { \ + switch (inputRank) \ + { \ + DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 1, DTYPE) \ + DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 2, DTYPE) \ + DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 3, DTYPE) \ + DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 4, DTYPE) \ + DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 5, DTYPE) \ + DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 6, DTYPE) \ + } \ + } + +#define DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OP, DTYPE1, DTYPE2) \ + if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \ + { \ + switch (inputRank) \ + { \ + DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, 0, DTYPE1, DTYPE2) \ + DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, 1, DTYPE1, DTYPE2) \ + DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, 2, DTYPE1, DTYPE2) \ + DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, 3, DTYPE1, DTYPE2) \ + DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, 4, DTYPE1, DTYPE2) \ + DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, 5, DTYPE1, DTYPE2) \ + DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, 6, DTYPE1, DTYPE2) \ + } \ + } + +#define DEF_FACTORY_RESHAPE(OP, DTYPE) \ + if (inputDType == DType_##DTYPE && outputDType == DType_##DTYPE) \ + { \ + switch (inputRank) \ + { \ + case 0: \ + { \ + switch (outputRank) \ + { \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 0, 0, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 0, 1, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 0, 2, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 0, 3, DTYPE) \ + 
DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 0, 4, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 0, 5, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 0, 6, DTYPE) \ + } \ + } \ + case 1: \ + { \ + switch (outputRank) \ + { \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 0, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 1, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 2, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 3, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 4, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 5, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 6, DTYPE) \ + } \ + } \ + case 2: \ + { \ + switch (outputRank) \ + { \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 0, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 1, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 2, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 3, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 4, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 5, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 6, DTYPE) \ + } \ + } \ + case 3: \ + { \ + switch (outputRank) \ + { \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 0, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 1, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 2, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 3, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 4, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 5, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 6, DTYPE) \ + } \ + } \ + case 4: \ + { \ + switch (outputRank) \ + { \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 0, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 1, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 2, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 3, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 4, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 5, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 6, DTYPE) \ + } \ + } \ + case 5: \ + { \ + switch (outputRank) \ + { \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 0, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 1, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 2, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 3, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 4, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 5, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 6, DTYPE) \ + } \ + } \ + case 6: \ + { \ + switch (outputRank) \ + { \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 0, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 1, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 2, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 3, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 4, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 5, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 6, DTYPE) \ + } \ + } \ + } \ + } + +#define DEF_FACTORY_GATHER(OP, DTYPE) \ + if (inputDType == DType_##DTYPE && outputDType == DType_##DTYPE) \ + { \ + switch (inputRank) \ + { \ + case 1: \ + switch (index_rank) \ + { \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 1, DTYPE); \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 2, DTYPE); \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 3, DTYPE); \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 4, DTYPE); \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 5, DTYPE); \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 6, DTYPE); \ + } \ + case 2: \ + switch (index_rank) \ + { \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 1, DTYPE); \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 2, DTYPE); \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 3, DTYPE); \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 4, DTYPE); \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 5, DTYPE); \ + } \ + case 3: \ + switch (index_rank) \ + { \ + 
DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 1, DTYPE); \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 2, DTYPE); \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 3, DTYPE); \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 4, DTYPE); \ + } \ + case 4: \ + switch (index_rank) \ + { \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 1, DTYPE); \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 2, DTYPE); \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 3, DTYPE); \ + } \ + case 5: \ + switch (index_rank) \ + { \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 1, DTYPE); \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 2, DTYPE); \ + } \ + case 6: \ + switch (index_rank) \ + { \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 1, DTYPE); \ + } \ + } \ + } + +namespace TosaReference +{ + +class OpFactory +{ +public: + static GraphNode* newOp(tosa::TosaSerializationHandler* tsh, + tosa::Op opType, + tosa::TosaAttributeBase* attribute, + tosa::TosaQuantInfoBase* qinfo, + uint64_t id, + tosa::DType inputDType, + int inputRank, + tosa::DType outputDType, + int outputRank, + tosa::DType weightDType, + int weightRank); +}; +}; // namespace TosaReference + +#endif diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc new file mode 100644 index 0000000..a2adfdb --- /dev/null +++ b/reference_model/src/ops/reduction.cc @@ -0,0 +1,139 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
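Editor's note: the OpFactory declaration above is the factory's whole public surface. A hypothetical call site (argument values invented for illustration) would look roughly like the following; a caller must treat a nullptr return as an unsupported op/dtype/rank combination, since unrecognized ops fall through to `goto done` in newOp:

    // Hypothetical call site; in the real graph builder the attribute, qinfo
    // and serialization handler come from the parsed flatbuffer graph.
    GraphNode* node = OpFactory::newOp(tsh, Op_REDUCE_SUM, attribute, qinfo,
                                       /* id */ 42,
                                       /* inputDType  */ DType_FLOAT, /* inputRank  */ 3,
                                       /* outputDType */ DType_FLOAT, /* outputRank */ 3,
                                       /* weightDType */ DType_FLOAT, /* weightRank */ 0);
    if (!node)
    {
        // unsupported op / dtype / rank combination
    }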
+ +#include "reduction.h" +#include "quant_util.h" + +using namespace TosaReference; +using namespace Eigen; +using namespace tosa; + +template +ReduceNode::ReduceNode(const Op& op_, TosaAttributeBase* attribute_, uint64_t id_) + : GraphNode(op_, id_) +{ + setRequiredOperands(1, 1); + setRequiredRank(0, 4); + + INIT_ATTRIBUTE(Axis); +} + +template +ReduceNode::~ReduceNode() +{ + if (attribute) + delete attribute; +} + +template +int ReduceNode::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + if (attribute->axis() < 0 || attribute->axis() >= inputs[0]->getRank()) + { + printNodeValidationError("Reduce axis must between [0, input_rank - 1]"); + return 1; + } + + if (inputs[0]->matchRank(*outputs[0])) + { + printNodeValidationError("Input and output tensor ranks must match"); + return 1; + } + + in = dynamic_cast*>(inputs[0]); + out = dynamic_cast*>(outputs[0]); + + ASSERT_MEM(in && out); + + dims[0] = this->attribute->axis(); + + return 0; +} + +template +int OpReduceAll::eval() +{ + this->out->getTensor() = this->in->getTensor().all(this->dims).reshape(this->out->getTensor().dimensions()); + + return GraphNode::eval(); +} + +template +int OpReduceAny::eval() +{ + this->out->getTensor() = this->in->getTensor().any(this->dims).reshape(this->out->getTensor().dimensions()); + + return GraphNode::eval(); +} + +template +int OpReduceMax::eval() +{ + this->out->getTensor() = this->in->getTensor().maximum(this->dims).reshape(this->out->getTensor().dimensions()); + + return GraphNode::eval(); +} + +template +int OpReduceMin::eval() +{ + this->out->getTensor() = this->in->getTensor().minimum(this->dims).reshape(this->out->getTensor().dimensions()); + + return GraphNode::eval(); +} + +template +int OpReduceProduct::eval() +{ + this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions()); + + return GraphNode::eval(); +} + +template +int OpReduceSum::eval() +{ + this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions()); + + return GraphNode::eval(); +} + +// template explicit instantiation +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAll, BOOL); + +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAny, BOOL); + +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FLOAT); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, AINT8); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32); + +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FLOAT); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, AINT8); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32); + +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FLOAT); + +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FLOAT); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, INT32); diff --git a/reference_model/src/ops/reduction.h b/reference_model/src/ops/reduction.h new file mode 100644 index 0000000..cf75812 --- /dev/null +++ b/reference_model/src/ops/reduction.h @@ -0,0 +1,109 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OPS_REDUCTION_H +#define OPS_REDUCTION_H + +#include "graph_node.h" + +using namespace tosa; + +namespace TosaReference +{ + +template +class ReduceNode : public GraphNode +{ +public: + ReduceNode(const Op& nodeType, TosaAttributeBase* attribute_, const uint64_t id_); + virtual ~ReduceNode(); + virtual int checkTensorAttributes(); + virtual int eval() = 0; + + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using TIn = Eigen::Tensor; + using TOut = Eigen::Tensor; + +protected: + Eigen::array dims; + TosaReference::TensorTemplate* in; + TosaReference::TensorTemplate* out; + TosaAxisAttribute* attribute; +}; + +template +class OpReduceAll : public ReduceNode +{ +public: + OpReduceAll(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : ReduceNode(Op_REDUCE_ALL, attribute_, id_) + {} + virtual int eval(); +}; + +template +class OpReduceAny : public ReduceNode +{ +public: + OpReduceAny(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : ReduceNode(Op_REDUCE_ALL, attribute_, id_) + {} + virtual int eval(); +}; + +template +class OpReduceMax : public ReduceNode +{ +public: + OpReduceMax(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : ReduceNode(Op_REDUCE_MAX, attribute_, id_) + {} + virtual int eval(); +}; + +template +class OpReduceMin : public ReduceNode +{ +public: + OpReduceMin(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : ReduceNode(Op_REDUCE_MIN, attribute_, id_) + {} + virtual int eval(); +}; + +template +class OpReduceProduct : public ReduceNode +{ +public: + OpReduceProduct(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : ReduceNode(Op_REDUCE_PRODUCT, attribute_, id_) + {} + virtual int eval(); +}; + +template +class OpReduceSum : public ReduceNode +{ +public: + OpReduceSum(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : ReduceNode(Op_REDUCE_SUM, attribute_, id_) + {} + virtual int eval(); +}; + +}; // namespace TosaReference + +#endif diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc new file mode 100644 index 0000000..c54204a --- /dev/null +++ b/reference_model/src/ops/scatter_gather.cc @@ -0,0 +1,120 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
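Editor's note: one detail worth making explicit from the reduce eval() implementations above is that Eigen's reductions drop the reduced dimension, so each eval() reshapes the result back to the output tensor's full-rank dimensions, where the reduced axis has size 1. A self-contained illustration in plain Eigen:

    #include <unsupported/Eigen/CXX11/Tensor>
    #include <iostream>

    int main()
    {
        Eigen::Tensor<float, 2> t(2, 3);
        t.setValues({{1, 2, 3}, {4, 5, 6}});

        // reduce over axis 0: the raw result is rank 1 with 3 elements...
        Eigen::array<int, 1> dims = {0};
        // ...then reshape restores rank 2 with the reduced axis kept as
        // size 1, exactly as in ReduceNode's eval() implementations
        Eigen::Tensor<float, 2> r =
            t.sum(dims).reshape(Eigen::array<Eigen::Index, 2>{1, 3});

        std::cout << r << "\n"; // prints the 1x3 result: 5 7 9
        return 0;
    }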
+ +#include "scatter_gather.h" +#include "quant_util.h" + +using namespace TosaReference; +using namespace Eigen; +using namespace tosa; + +template +OpGather::OpGather(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_GATHER, id_) +{ + setRequiredOperands(2, 1); + setRequiredRank(1, 6); + + INIT_ATTRIBUTE(Axis); +} + +template +OpGather::~OpGather() +{ + if (attribute) + delete attribute; +} + +template +int OpGather::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + // output and input must be the same types + if (inputs[0]->matchType(*outputs[0])) + { + printNodeValidationError("Failure to match input and output type"); + return 1; + } + + in = dynamic_cast*>(inputs[0]); + index = dynamic_cast*>(inputs[1]); + out = dynamic_cast*>(outputs[0]); + + ASSERT_MEM(in && index && out); + + return 0; +} + +template +int OpGather::eval() +{ + int axis = attribute->axis(); + + // calculate size left and right to axis + int left_size = 1; + for (int i = 0; i < axis; ++i) + { + left_size *= in->getShape()[i]; + } + + int right_size = 1; + for (int i = axis + 1; i < in->getRank(); ++i) + { + right_size *= in->getShape()[i]; + } + + InEigenType* input_data = in->getTensor().data(); + int32_t* index_data = index->getTensor().data(); + OutEigenType* output_data = out->getTensor().data(); + + int32_t axis_size = in->getShape()[axis]; + int32_t index_count = index->getElementCount(); + + // sanity check if index is valid + // need to check until this point since index is not known until runtime + for (size_t i = 0; i < index->getElementCount(); i++) + { + if (index_data[i] >= axis_size) + { + FATAL_ERROR_NODE("OpGather: index[%lu]=%i can't exceed axis_size=%i", i, index_data[i], axis_size); + } + } + + // Eigen stores tensor in column-major + // so we iterate through dimension right to axis and the index array + // do memory copy with size of left size each time + for (int right = 0; right < right_size; ++right) + { + for (int i = 0; i < index_count; ++i) + { + std::memcpy(output_data + (right * index_count + i) * left_size, + input_data + (right * axis_size + index_data[i]) * left_size, sizeof(InEigenType) * left_size); + } + } + + return GraphNode::eval(); +} + +// template explicit instantiation +DEF_INSTANTIATE_GATHER(OpGather, AINT8); +DEF_INSTANTIATE_GATHER(OpGather, INT16); +DEF_INSTANTIATE_GATHER(OpGather, INT32); diff --git a/reference_model/src/ops/scatter_gather.h b/reference_model/src/ops/scatter_gather.h new file mode 100644 index 0000000..d9b1263 --- /dev/null +++ b/reference_model/src/ops/scatter_gather.h @@ -0,0 +1,54 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef OPS_SCATTER_GATHER_H +#define OPS_SCATTER_GATHER_H + +#include "graph_node.h" + +using namespace tosa; + +namespace TosaReference +{ + +// input and index can have different rank +// and infer OutRank statically +template +class OpGather : public GraphNode +{ +public: + OpGather(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpGather(); + + virtual int checkTensorAttributes(); + virtual int eval(); + + static constexpr int OutRank = InRank - 1 + IndexRank; + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using TIn = Eigen::Tensor; + using TIndex = Eigen::Tensor; + using TOut = Eigen::Tensor; + +protected: + TosaAxisAttribute* attribute; + TosaReference::TensorTemplate* in; + TosaReference::TensorTemplate* index; + TosaReference::TensorTemplate* out; +}; + +}; // namespace TosaReference + +#endif diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h new file mode 100644 index 0000000..1859e03 --- /dev/null +++ b/reference_model/src/ops/template_types.h @@ -0,0 +1,277 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OP_TEMPLATE_TYPES_H +#define OP_TEMPLATE_TYPES_H + +#include "tosa_generated.h" +#include + +using namespace tosa; + +namespace TosaReference +{ +// Shorter aliase templates for common Eigen::Tensor types +template +using ETensor0 = Eigen::Tensor; +template +using ETensor1 = Eigen::Tensor; +template +using ETensor2 = Eigen::Tensor; +template +using ETensor3 = Eigen::Tensor; +template +using ETensor4 = Eigen::Tensor; +template +using ETensor5 = Eigen::Tensor; +template +using ETensor6 = Eigen::Tensor; + +// Forward declaration +template +class TensorTemplate; + +// Shortcut to hide the TensorTemplate class. 
+// For example, declare Tensor1 to get a TensorTemplate +// with an Eigen::Tensor +template +using Tensor0 = TensorTemplate>; +template +using Tensor1 = TensorTemplate>; +template +using Tensor2 = TensorTemplate>; +template +using Tensor3 = TensorTemplate>; +template +using Tensor4 = TensorTemplate>; +template +using Tensor5 = TensorTemplate>; +template +using Tensor6 = TensorTemplate>; + +template +struct GetEigenType; +template <> +struct GetEigenType +{ + using type = float; +}; +template <> +struct GetEigenType +{ + using type = int32_t; +}; +template <> +struct GetEigenType +{ + using type = int64_t; +}; +template <> +struct GetEigenType +{ + using type = bool; +}; +template <> +struct GetEigenType +{ + using type = int32_t; +}; +template <> +struct GetEigenType +{ + using type = int32_t; +}; +template <> +struct GetEigenType +{ + using type = int32_t; +}; +template <> +struct GetEigenType +{ + using type = int32_t; +}; +template <> +struct GetEigenType +{ + using type = int32_t; +}; + +// Meta function to get number of bits +template +struct GetNumBits +{ + static constexpr int32_t value = 0; +}; +template <> +struct GetNumBits +{ + static constexpr int32_t value = 1; +}; +template <> +struct GetNumBits +{ + static constexpr int32_t value = 8; +}; +template <> +struct GetNumBits +{ + static constexpr int32_t value = 8; +}; +template <> +struct GetNumBits +{ + static constexpr int32_t value = 4; +}; +template <> +struct GetNumBits +{ + static constexpr int32_t value = 8; +}; +template <> +struct GetNumBits +{ + static constexpr int32_t value = 16; +}; +template <> +struct GetNumBits +{ + static constexpr int32_t value = 32; +}; +template <> +struct GetNumBits +{ + static constexpr int32_t value = 48; +}; + +// Meta function to get quantized min/max in compile time +template +struct GetQMin +{ + static constexpr int64_t value = 0L; +}; +template <> +struct GetQMin +{ + static constexpr int64_t value = -128L; +}; +template <> +struct GetQMin +{ + static constexpr int64_t value = 0L; +}; +template <> +struct GetQMin +{ + static constexpr int64_t value = -8L; +}; +template <> +struct GetQMin +{ + static constexpr int64_t value = -128L; +}; +template <> +struct GetQMin +{ + static constexpr int64_t value = -32768L; +}; +template <> +struct GetQMin +{ + static constexpr int64_t value = -(1L << 31); +}; +template <> +struct GetQMin +{ + static constexpr int64_t value = -(1L << 47); +}; + +template +struct GetQMax +{ + static constexpr int64_t value = 0L; +}; +template <> +struct GetQMax +{ + static constexpr int64_t value = 127L; +}; +template <> +struct GetQMax +{ + static constexpr int64_t value = 255L; +}; +template <> +struct GetQMax +{ + static constexpr int64_t value = 7L; +}; +template <> +struct GetQMax +{ + static constexpr int64_t value = 127L; +}; +template <> +struct GetQMax +{ + static constexpr int64_t value = 32767L; +}; +template <> +struct GetQMax +{ + static constexpr int64_t value = (1L << 31) - 1; +}; +template <> +struct GetQMax +{ + static constexpr int64_t value = (1L << 47) - 1; +}; + +template +struct GetAccDType; +template <> +struct GetAccDType +{ + static constexpr DType value = DType_INT32; +}; +template <> +struct GetAccDType +{ + static constexpr DType value = DType_INT32; +}; +template <> +struct GetAccDType +{ + static constexpr DType value = DType_INT32; +}; +template <> +struct GetAccDType +{ + static constexpr DType value = DType_INT48; +}; +template <> +struct GetAccDType +{ + static constexpr DType value = DType_INT48; +}; +template <> +struct 
GetAccDType +{ + static constexpr DType value = DType_FLOAT; +}; + +}; // namespace TosaReference + +#endif diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc new file mode 100644 index 0000000..a735334 --- /dev/null +++ b/reference_model/src/ops/tensor_ops.cc @@ -0,0 +1,1229 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensor_ops.h" +#include "quant_util.h" +#include "template_types.h" + +using namespace TosaReference; +using namespace Eigen; +using namespace tosa; + +template +OpArgMax::OpArgMax(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_ARGMAX, id_) +{ + setRequiredOperands(1, 1); + setRequiredRank(0, 6); + + INIT_ATTRIBUTE(Axis); +} + +template +OpArgMax::~OpArgMax() +{ + if (attribute) + delete attribute; +} + +template +int OpArgMax::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + input = dynamic_cast*>(inputs[0]); + output = dynamic_cast*>(outputs[0]); + + return 0; +} + +template +int OpArgMax::eval() +{ + Eigen::Tensor index = this->input->getTensor().argmax(attribute->axis()); + + this->output->getTensor() = index.unaryExpr([](DenseIndex in) -> OutEigenType { return (OutEigenType)in; }); + + return GraphNode::eval(); +} + +template +OpAvgPool2d::OpAvgPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_AVG_POOL2D, id_) +{ + setRequiredOperands(1, 1); + setRequiredRank(4); + + INIT_ATTRIBUTE(Pool2d); + INIT_QINFO(Unary); +} + +template +OpAvgPool2d::~OpAvgPool2d() +{ + if (attribute) + delete attribute; +} + +template +int OpAvgPool2d::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + if (inputs[0]->matchType(*outputs[0])) + { + printNodeValidationError("OpAvgPool2d: input and output tensor type mismatch"); + return 1; + } + + in = dynamic_cast*>(inputs[0]); + out = dynamic_cast*>(outputs[0]); + + if (!in->hasFormat(Format_NHWC)) + { + printNodeValidationError("OpAvgPool2d: unsupported tensor format"); + return 1; + } + + if (attribute->padding().size() != 4) + { + printNodeValidationError("OpAvgPool2d: illegal size for attribute padding"); + return 1; + } + + if (attribute->kernel().size() != 2) + { + printNodeValidationError("OpAvgPool2d: illegal size for attribute kernel"); + return 1; + } + + if (attribute->stride().size() != 2) + { + printNodeValidationError("OpAvgPool2d: illegal size for attribute stride"); + return 1; + } + + return 0; +} + +template +ETensor1 OpAvgPool2d::calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride) +{ + ETensor1 result(out_size); + + int32_t total_pad = (out_size - 1) * stride + kernel_size - in_size; + total_pad = total_pad < 0 ? 
0 : total_pad; + + int32_t pad_left = total_pad >> 1; + int32_t pad_right = total_pad - pad_left; + + result.setConstant(kernel_size); + + // the index left to 'left_index' and index right to 'right_index' indicates + // the input window of this output covers a pad bit + int32_t left_index = pad_left / stride; + int32_t right_index = pad_right / stride; + + // not handle ultra small activation yet + ASSERT_MSG_NODE((out_size - 1 - right_index) >= left_index, "AvgPool2d: Small activations not supported yet"); + + // minus the number of pad bit this index cover + while (left_index >= 0) + { + result(left_index) -= (pad_left - left_index * stride); + left_index--; + } + + while (right_index >= 0) + { + result(out_size - 1 - right_index) -= (pad_right - right_index * stride); + right_index--; + } + + return result; +} + +// assuming input and output tensor have same scales like tflite reference +// so no need to scale input and output +template +int OpAvgPool2d::eval() +{ + int in_batch = this->in->getShape()[0]; + int in_height = this->in->getShape()[1]; + int in_width = this->in->getShape()[2]; + int in_channels = this->in->getShape()[3]; + + int out_batch = this->out->getShape()[0]; + int out_height = this->out->getShape()[1]; + int out_width = this->out->getShape()[2]; + int out_channels = this->out->getShape()[3]; + + ASSERT_MSG_NODE(in_batch == out_batch, "OpAvgPool2d: tensor batch mismatch %d != %d", in_batch, out_batch); + + int padding_top = this->attribute->padding()[0]; + int padding_bottom = this->attribute->padding()[1]; + int padding_left = this->attribute->padding()[2]; + int padding_right = this->attribute->padding()[3]; + int kernel_h = this->attribute->kernel()[0]; + int kernel_w = this->attribute->kernel()[1]; + int stride_h = this->attribute->stride()[0]; + int stride_w = this->attribute->stride()[1]; + + DEBUG_INFO(OP, + "perform AvgPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], " + "stride=[%d,%d], padding=[%d,%d,%d,%d]", + in_batch, in_height, in_width, in_channels, out_batch, out_height, out_width, out_channels, kernel_h, + kernel_w, stride_h, stride_w, padding_top, padding_bottom, padding_left, padding_right); + + Eigen::array im2col_input_dims; + im2col_input_dims[0] = kernel_h * kernel_w; + im2col_input_dims[1] = out_batch * out_height * out_width * out_channels; + + Eigen::array col2im_output_dims; + col2im_output_dims[0] = out_batch; + col2im_output_dims[1] = out_height; + col2im_output_dims[2] = out_width; + col2im_output_dims[3] = out_channels; + + Eigen::array, 4> padding; + padding[0] = std::make_pair(0, 0); + padding[1] = std::make_pair(padding_top, padding_bottom); + padding[2] = std::make_pair(padding_left, padding_right); + padding[3] = std::make_pair(0, 0); + + ETensor4 input_val = this->in->getTensor(); + if (this->qinfo) + { + input_val = input_val - (InEigenType)this->qinfo->input_zp(); + } + + ETensor4 input_padded = input_val.pad(padding); + + // assuming input and output have same scales + // so input and output scaling is not required + // TODO: check if this assumption TOSA made + + // extract_image_patches() output [N, KH, KW, H * W, C] + // transpose to [KH, KW, N, H * W, C] + // reshape to [KH * KW, N * H * W * C] + ETensor2 input_extract_patches = + input_padded.extract_image_patches(kernel_h, kernel_w, stride_h, stride_w, 1, 1, Eigen::PADDING_VALID) + .shuffle(Eigen::array{ 1, 2, 0, 3, 4 }) + .reshape(im2col_input_dims); + + // 1D result with [N * H * W * C] + ETensor1 out_1d(this->out->getElementCount()); + 
out_1d.setZero(); + + // sum pool + for (size_t i = 0; i < this->out->getElementCount(); i++) + { + for (int32_t j = 0; j < kernel_h * kernel_w; j++) + { + out_1d(i) += (AccEigenType)input_extract_patches(j, i); + } + } + + // reshape result to [N, H, W, C] and divide with div_map + ETensor4 sum = out_1d.reshape(col2im_output_dims); + + // calculate 1d height/width div_map (number of elements this pooling window covers) + // and outer product to get 2d div_map, then reshape/broadcast to [N, H, W, C] + ETensor1 div_map_h = calculate_div_map_1d(in_height, out_height, kernel_h, stride_h); + ETensor1 div_map_w = calculate_div_map_1d(in_width, out_width, kernel_w, stride_w); + Eigen::array, 1> contract_dims = { Eigen::IndexPair(1, 0) }; + Eigen::array bcast{ out_batch, 1, 1, out_channels }; + + ETensor4 div_map = + div_map_h.reshape(Eigen::array{ out_height, 1 }) + .contract(div_map_w.reshape(Eigen::array{ 1, out_width }), contract_dims) + .reshape(Eigen::array{ 1, out_height, out_width, 1 }) + .broadcast(bcast); + + if (Dtype != DType_FLOAT) + { + this->out->getTensor() = sum.binaryExpr(div_map, [](AccEigenType value, int32_t div) -> OutEigenType { + int32_t multiplier, shift; + TosaReference::QuantUtil::reciprocal_scale(div, multiplier, shift); + + return (OutEigenType)TosaReference::QuantUtil::apply_scale(value, multiplier, shift, false); + }); + this->out->getTensor() = this->out->getTensor() + (OutEigenType)(this->qinfo->output_zp()); + this->out->getTensor() = this->out->getTensor().cwiseMax((OutEigenType)QMin); + this->out->getTensor() = this->out->getTensor().cwiseMin((OutEigenType)QMax); + } + else + { + this->out->getTensor() = (sum / div_map.template cast()).template cast(); + } + + return GraphNode::eval(); +} + +template +OpConv2d::OpConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_CONV2D, id_) +{ + setRequiredOperands(3, 1); + setRequiredRank(4); + + INIT_ATTRIBUTE(Conv2d); + INIT_QINFO(Conv); +} + +template +OpConv2d::~OpConv2d() +{ + if (attribute) + delete attribute; + if (qinfo) + delete qinfo; +} + +template +int OpConv2d::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + // 'bias' checked separatedly since it doens't make sense to make required rank ranging from 1 to 4 + if (inputs[2]->getRank() != 1) + { + printNodeValidationError("OpConv2d: bias tensor must be rank 1"); + } + + if (inputs[1]->getIsConst() == 0) + { + printNodeValidationError("OpConv2d: weight tensor is not const typed"); + } + + input = dynamic_cast*>(inputs[0]); + weight = dynamic_cast*>(inputs[1]); + bias = dynamic_cast*>(inputs[2]); + output = dynamic_cast*>(outputs[0]); + + if (!input->hasFormat(Format_NHWC)) + { + printNodeValidationError("OpConv2d: unsupported input tensor format"); + return 1; + } + + if (!weight->hasFormat(Format_OHWI)) + { + printNodeValidationError("OpConv2d: unsupported weight tensor format"); + return 1; + } + + if (attribute->padding().size() != 4) + { + printNodeValidationError("OpConv2d: illegal size for attribute padding"); + return 1; + } + + if (attribute->stride().size() != 2) + { + printNodeValidationError("OpConv2d: illegal size for attribute stride"); + return 1; + } + + if (attribute->dilation().size() != 2) + { + printNodeValidationError("OpConv2d: illegal size for attribute dilation"); + return 1; + } + + return 0; +} + +template +int OpConv2d::eval() +{ + int 
in_batch = this->input->getShape()[0]; + int in_height = this->input->getShape()[1]; + int in_width = this->input->getShape()[2]; + int in_channels = this->input->getShape()[3]; + + int f_out_channels = this->weight->getShape()[0]; + int f_height = this->weight->getShape()[1]; + int f_width = this->weight->getShape()[2]; + int f_in_channels = this->weight->getShape()[3]; + + int b_out_channels = this->bias->getShape()[0]; + + int out_batch = this->output->getShape()[0]; + int out_height = this->output->getShape()[1]; + int out_width = this->output->getShape()[2]; + int out_channels = this->output->getShape()[3]; + + ASSERT_MSG_NODE(in_batch == out_batch, "OpConv2d: tensor batch mismatch %d != %d", in_batch, out_batch); + ASSERT_MSG_NODE(f_in_channels == in_channels, "OpConv2d: tensor input channel mismatch %d != %d", f_in_channels, + in_channels); + ASSERT_MSG_NODE(f_out_channels == out_channels, "OpConv2d: tensor output channel mismatch %d != %d", f_out_channels, + out_channels); + ASSERT_MSG_NODE(b_out_channels == out_channels, "OpConv2d: tensor output channel mismatch %d != %d", b_out_channels, + out_channels); + + int padding_top = this->attribute->padding()[0]; + int padding_bottom = this->attribute->padding()[1]; + int padding_left = this->attribute->padding()[2]; + int padding_right = this->attribute->padding()[3]; + int stride_h = this->attribute->stride()[0]; + int stride_w = this->attribute->stride()[1]; + int dilation_h = this->attribute->dilation()[0]; + int dilation_w = this->attribute->dilation()[1]; + + DEBUG_INFO(OP, + "perform OpConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], " + "stride=[%d,%d], dilation=[%d,%d], padding=[%d,%d,%d,%d]", + in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_out_channels, out_batch, + out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, padding_top, + padding_bottom, padding_left, padding_right); + + // GEMM-conv2d, left matrix is input, right matrix is weight + Eigen::array im2col_input_dims; + im2col_input_dims[0] = out_batch * out_height * out_width; + im2col_input_dims[1] = f_height * f_width * f_in_channels; + + Eigen::array im2col_weight_dims; + im2col_weight_dims[0] = f_height * f_width * f_in_channels; + im2col_weight_dims[1] = f_out_channels; + + Eigen::array bias_reshaped_dims; + bias_reshaped_dims[0] = 1; + bias_reshaped_dims[1] = b_out_channels; + + Eigen::array weight_zp_bcast_dims; + weight_zp_bcast_dims[0] = f_height; + weight_zp_bcast_dims[1] = f_width; + weight_zp_bcast_dims[2] = f_in_channels; + + Eigen::array bias_bcast_dims; + bias_bcast_dims[0] = out_batch * out_height * out_width; + bias_bcast_dims[1] = 1; + + Eigen::array col2im_output_dims; + col2im_output_dims[0] = out_batch; + col2im_output_dims[1] = out_height; + col2im_output_dims[2] = out_width; + col2im_output_dims[3] = out_channels; + + Eigen::array, 1> contract_dims = { Eigen::IndexPair(1, 0) }; + + Eigen::array, 4> padding; + padding[0] = std::make_pair(0, 0); + padding[1] = std::make_pair(padding_top, padding_bottom); + padding[2] = std::make_pair(padding_left, padding_right); + padding[3] = std::make_pair(0, 0); + + TIn input_val = this->input->getTensor(); + TWeight weight_val = this->weight->getTensor(); + if (this->qinfo) + { + input_val = input_val - (InEigenType)this->qinfo->input_zp(); + weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp(); + } + + ETensor4 input_padded = input_val.pad(padding); + + // extract_image_patches() output [N, 
KH, KW, H * W, C] + // need to transpose to [N, H * W, KH, KW, C] + ETensor5 input_extract_patches = + input_padded + .extract_image_patches(f_height, f_width, stride_h, stride_w, dilation_h, dilation_w, Eigen::PADDING_VALID) + .shuffle(Eigen::array{ 0, 3, 1, 2, 4 }); + + // reshape input to [N * H * W, KH * KW * C] + ETensor2 im2col_input = input_extract_patches.reshape(im2col_input_dims); + + // transpose and reshape weight from [OC, H, W, IC] to [H * W * IC, OC] + ETensor2 im2col_weight = + weight_val.shuffle(Eigen::array({ 1, 2, 3, 0 })).reshape(im2col_weight_dims); + + // don't need to apply bias_multiplier ( * bias_scale and >> bias_shift) since tflite already scale it + // and reshaped from [C] to [1, C], and broadcast to [N * H * W, C] + ETensor2 bias_2d = this->bias->getTensor().reshape(bias_reshaped_dims).broadcast(bias_bcast_dims); + + // output matrix is [N * H * W, C] + ETensor2 contracted_result = + im2col_input.template cast().contract(im2col_weight.template cast(), contract_dims); + + // adding bias + ETensor2 biased_output = contracted_result + bias_2d.template cast(); + + // reshape back to [N, H, W, C] + this->output->getTensor() = biased_output.reshape(col2im_output_dims); + + if (AccDtype == DType_INT48) + { + this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin); + this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax); + } + + return GraphNode::eval(); +} + +template +OpDepthwiseConv2d::OpDepthwiseConv2d(TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(Op_DEPTHWISE_CONV2D, id_) +{ + setRequiredOperands(3, 1); + setRequiredRank(4); + + INIT_ATTRIBUTE(Conv2d); + INIT_QINFO(Conv); +} + +template +OpDepthwiseConv2d::~OpDepthwiseConv2d() +{ + if (attribute) + delete attribute; + if (qinfo) + delete qinfo; +} + +template +int OpDepthwiseConv2d::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + // 'bias' checked separatedly since it doens't make sense to make required rank ranging from 1 to 4 + if (inputs[2]->getRank() != 1) + { + printNodeValidationError("OpDepthwiseConv2d: bias tensor must be rank 1"); + } + + if (inputs[1]->getIsConst() == 0) + { + printNodeValidationError("OpDepthwiseConv2d: weight tensor is not const typed"); + } + + input = dynamic_cast*>(inputs[0]); + weight = dynamic_cast*>(inputs[1]); + bias = dynamic_cast*>(inputs[2]); + output = dynamic_cast*>(outputs[0]); + + if (!input->hasFormat(Format_NHWC)) + { + printNodeValidationError("OpDepthwiseConv2d: unsupported input tensor format"); + return 1; + } + + if (!weight->hasFormat(Format_HWIM)) + { + printNodeValidationError("OpDepthwiseConv2d: unsupported weight tensor format"); + return 1; + } + + if (attribute->padding().size() != 4) + { + printNodeValidationError("OpDepthwiseConv2d: illegal size for attribute padding"); + return 1; + } + + if (attribute->stride().size() != 2) + { + printNodeValidationError("OpDepthwiseConv2d: illegal size for attribute stride"); + return 1; + } + + if (attribute->dilation().size() != 2) + { + printNodeValidationError("OpDepthwiseConv2d: illegal size for attribute dilation"); + return 1; + } + + return 0; +} + +template +int OpDepthwiseConv2d::eval() +{ + int in_batch = this->input->getShape()[0]; + int in_height = this->input->getShape()[1]; + int in_width = this->input->getShape()[2]; + int in_channels = 
this->input->getShape()[3]; + + int f_height = this->weight->getShape()[0]; + int f_width = this->weight->getShape()[1]; + int f_in_channels = this->weight->getShape()[2]; + int f_multiplier = this->weight->getShape()[3]; + + int b_out_channels = this->bias->getShape()[0]; + + int out_batch = this->output->getShape()[0]; + int out_height = this->output->getShape()[1]; + int out_width = this->output->getShape()[2]; + int out_channels = this->output->getShape()[3]; + + ASSERT_MSG_NODE(in_batch == out_batch, "OpDepthwiseConv2d: tensor batch mismatch %d != %d", in_batch, out_batch); + ASSERT_MSG_NODE(f_in_channels == in_channels, "OpDepthwiseConv2d: tensor input channel mismatch %d != %d", + f_in_channels, in_channels); + ASSERT_MSG_NODE(in_channels * f_multiplier == out_channels, + "OpDepthwiseConv2d: tensor output channel mismatch %d != %d", in_channels * f_multiplier, + out_channels); + ASSERT_MSG_NODE(b_out_channels == out_channels, "OpDepthwiseConv2d: tensor b_out_channels mismatch %d != %d", + b_out_channels, out_channels); + + int padding_top = this->attribute->padding()[0]; + int padding_bottom = this->attribute->padding()[1]; + int padding_left = this->attribute->padding()[2]; + int padding_right = this->attribute->padding()[3]; + int stride_h = this->attribute->stride()[0]; + int stride_w = this->attribute->stride()[1]; + int dilation_h = this->attribute->dilation()[0]; + int dilation_w = this->attribute->dilation()[1]; + + DEBUG_INFO(OP, + "perform OpDepthwiseConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], " + "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], padding=[%d,%d,%d,%d]", + in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_multiplier, out_batch, + out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, padding_top, + padding_bottom, padding_left, padding_right); + + Eigen::array, 4> padding; + padding[0] = std::make_pair(0, 0); + padding[1] = std::make_pair(padding_top, padding_bottom); + padding[2] = std::make_pair(padding_left, padding_right); + padding[3] = std::make_pair(0, 0); + + TIn input_val = this->input->getTensor(); + TWeight weight_val = this->weight->getTensor(); + if (this->qinfo) + { + input_val = input_val - (InEigenType)this->qinfo->input_zp(); + weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp(); + } + + ETensor4 input_padded = input_val.pad(padding); + + // GEMM doesn't fit well with DepthwiseConv2d + // 1. use extract_image_patches() to handle stride/dilation/padding + // 2. perform direct convolution + + // 1. extract_image_patches() output [N, KH, KW, OH * OW, IC] + ETensor5 input_extract_patches = input_padded.extract_image_patches( + f_height, f_width, stride_h, stride_w, dilation_h, dilation_w, Eigen::PADDING_VALID); + + Eigen::array reshape_dim; + reshape_dim.fill(1); + reshape_dim[3] = b_out_channels; + + Eigen::array bcast; + bcast[0] = out_batch; + bcast[1] = out_height; + bcast[2] = out_width; + bcast[3] = 1; + + // initialize with bias + this->output->getTensor() = this->bias->getTensor().reshape(reshape_dim).broadcast(bcast); + + // 2. 
direct depthwise convolution + for (int ob = 0; ob < out_batch; ob++) + { + for (int oh = 0; oh < out_height; oh++) + { + for (int ow = 0; ow < out_width; ow++) + { + for (int ic = 0; ic < in_channels; ic++) + { + for (int cm = 0; cm < f_multiplier; cm++) + { + for (int fh = 0; fh < f_height; fh++) + { + for (int fw = 0; fw < f_width; fw++) + { + this->output->getTensor()(ob, oh, ow, ic * f_multiplier + cm) += + ((AccEigenType)input_extract_patches(ob, fh, fw, ow * out_height + oh, ic) * + (AccEigenType)weight_val(fh, fw, ic, cm)); + } + } + } + } + } + } + } + + if (AccDtype == DType_INT48) + { + this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin); + this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax); + } + + return GraphNode::eval(); +} + +template +OpFullyConnected::OpFullyConnected(TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(Op_FULLY_CONNECTED, id_) +{ + setRequiredOperands(3, 1); + setRequiredRank(2); + + INIT_QINFO(Conv); +} + +template +OpFullyConnected::~OpFullyConnected() +{ + if (qinfo) + delete qinfo; +} + +template +int OpFullyConnected::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + input = dynamic_cast*>(inputs[0]); + weight = dynamic_cast*>(inputs[1]); + bias = dynamic_cast*>(inputs[2]); + + if (input->getShape()[1] != weight->getShape()[1]) + { + printNodeValidationError("OpFullyConnected operator input.shape[1] should match weight.shape[1]"); + return 1; + } + + if (weight->getShape()[0] != bias->getShape()[0]) + { + printNodeValidationError("OpFullyConnected operator bias.shape[0] should match weight.shape[0]"); + return 1; + } + + output = dynamic_cast*>(outputs[0]); + + return 0; +} + +template +int OpFullyConnected::eval() +{ + typedef Eigen::Tensor::DimensionPair DimPair; + Eigen::array dims{ { DimPair(1, 0) } }; + + Eigen::array weight_shuffle{ 1, 0 }; + + Eigen::array bias_reshape; + bias_reshape[0] = 1; + bias_reshape[1] = this->bias->getShape()[0]; + + Eigen::array bias_bcast; + bias_bcast[0] = this->input->getShape()[0]; + bias_bcast[1] = 1; + + TIn input_val = this->input->getTensor(); + TWeight weight_val = this->weight->getTensor().shuffle(weight_shuffle); + if (this->qinfo) + { + input_val = input_val - (InEigenType)this->qinfo->input_zp(); + weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp(); + } + + this->output->getTensor() = + input_val.template cast().contract(weight_val.template cast(), dims) + + this->bias->getTensor().reshape(bias_reshape).broadcast(bias_bcast); + + if (AccDtype == DType_INT48) + { + this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin); + this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax); + } + return GraphNode::eval(); +} + +template +OpMatMul::OpMatMul(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_MATMUL, id_) +{ + setRequiredOperands(2, 1); + setRequiredRank(2); + + INIT_QINFO(MatMul); +} + +template +OpMatMul::~OpMatMul() +{ + if (qinfo) + delete qinfo; +} + +template +int OpMatMul::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + a = dynamic_cast*>(inputs[0]); + b = 
dynamic_cast*>(inputs[1]); + + if (a->getShape()[1] != b->getShape()[0]) + { + printNodeValidationError("OpMatMul operator a.shape[1] should match b.shape[0]"); + return 1; + } + + c = dynamic_cast*>(outputs[0]); + + return 0; +} + +template +int OpMatMul::eval() +{ + typedef Eigen::Tensor::DimensionPair DimPair; + Eigen::array dims{ { DimPair(1, 0) } }; + + TIn a_val = this->a->getTensor(); + TIn b_val = this->b->getTensor(); + if (this->qinfo) + { + a_val = a_val - (InEigenType)this->qinfo->a_zp(); + b_val = b_val - (InEigenType)this->qinfo->b_zp(); + } + + this->c->getTensor() = a_val.template cast().contract(b_val.template cast(), dims); + + if (AccDtype == DType_INT48) + { + this->c->getTensor() = this->c->getTensor().cwiseMax((AccEigenType)AccQMin); + this->c->getTensor() = this->c->getTensor().cwiseMin((AccEigenType)AccQMax); + } + + return GraphNode::eval(); +} + +template +OpMaxPool2d::OpMaxPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_MAX_POOL2D, id_) +{ + setRequiredOperands(1, 1); + setRequiredRank(4); + + INIT_ATTRIBUTE(Pool2d); +} + +template +OpMaxPool2d::~OpMaxPool2d() +{ + if (attribute) + delete attribute; +} + +template +int OpMaxPool2d::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + if (inputs[0]->matchType(*outputs[0])) + { + printNodeValidationError("OpMaxPool2d: input and output tensor type mismatch"); + return 1; + } + + in = dynamic_cast*>(inputs[0]); + out = dynamic_cast*>(outputs[0]); + + if (!in->hasFormat(Format_NHWC)) + { + printNodeValidationError("OpMaxPool2d: unsupported tensor format"); + return 1; + } + + if (attribute->padding().size() != 4) + { + printNodeValidationError("OpMaxPool2d: illegal size for attribute padding"); + return 1; + } + + if (attribute->kernel().size() != 2) + { + printNodeValidationError("OpMaxPool2d: illegal size for attribute kernel"); + return 1; + } + + if (attribute->stride().size() != 2) + { + printNodeValidationError("OpMaxPool2d: illegal size for attribute stride"); + return 1; + } + + return 0; +} + +template +int OpMaxPool2d::eval() +{ + int in_batch = this->in->getShape()[0]; + int in_height = this->in->getShape()[1]; + int in_width = this->in->getShape()[2]; + int in_channels = this->in->getShape()[3]; + + int out_batch = this->out->getShape()[0]; + int out_height = this->out->getShape()[1]; + int out_width = this->out->getShape()[2]; + int out_channels = this->out->getShape()[3]; + + ASSERT_MSG_NODE(in_batch == out_batch, "OpMaxPool2d: tensor batch mismatch %d != %d", in_batch, out_batch); + + int padding_top = this->attribute->padding()[0]; + int padding_bottom = this->attribute->padding()[1]; + int padding_left = this->attribute->padding()[2]; + int padding_right = this->attribute->padding()[3]; + int kernel_h = this->attribute->kernel()[0]; + int kernel_w = this->attribute->kernel()[1]; + int stride_h = this->attribute->stride()[0]; + int stride_w = this->attribute->stride()[1]; + + DEBUG_INFO(OP, + "perform MaxPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], " + "stride=[%d,%d], padding=[%d,%d,%d,%d]", + in_batch, in_height, in_width, in_channels, out_batch, out_height, out_width, out_channels, kernel_h, + kernel_w, stride_h, stride_w, padding_top, padding_bottom, padding_left, padding_right); + + Eigen::array im2col_input_dims; + im2col_input_dims[0] = kernel_h * kernel_w; + im2col_input_dims[1] = out_batch * 
out_height * out_width * out_channels; + + Eigen::array col2im_output_dims; + col2im_output_dims[0] = out_batch; + col2im_output_dims[1] = out_height; + col2im_output_dims[2] = out_width; + col2im_output_dims[3] = out_channels; + + Eigen::array, 4> padding; + padding[0] = std::make_pair(0, 0); + padding[1] = std::make_pair(padding_top, padding_bottom); + padding[2] = std::make_pair(padding_left, padding_right); + padding[3] = std::make_pair(0, 0); + + ETensor4 input_padded = this->in->getTensor().pad(padding, std::numeric_limits::lowest()); + + // extract_image_patches() output [N, KH, KW, H * W, C] + // transpose to [KH, KW, N, H * W, C] + // reshape to [KH * KW, N * H * W * C] + // + // Set the padding value to be the most negative value that can be + // represented by the datatype to ensure that any padding values will be equal + // to or smaller than the actual maximum in the KH x KW patch. + ETensor2 input_extract_patches = + input_padded + .extract_image_patches(kernel_h, kernel_w, stride_h, stride_w, 1, 1, Eigen::PADDING_VALID, + std::numeric_limits::lowest()) + .shuffle(Eigen::array{ 1, 2, 0, 3, 4 }) + .reshape(im2col_input_dims); + + // Get the maximum of the KHxHW patches along axis 0 + Eigen::Tensor tensor_argmax = input_extract_patches.argmax(0); + + // 1D result with [N * H * W * C] + ETensor1 out_1d(this->out->getElementCount()); + + // index input_patches with argmax array should give the result + for (size_t i = 0; i < this->out->getElementCount(); i++) + { + out_1d(i) = (OutEigenType)input_extract_patches(tensor_argmax(i), i); + } + + // reshape result to [N, H, W, C] + this->out->getTensor() = out_1d.reshape(col2im_output_dims); + + return GraphNode::eval(); +} + +template +OpTransposeConv2d::OpTransposeConv2d(TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(Op_TRANSPOSE_CONV2D, id_) +{ + setRequiredOperands(3, 1); + setRequiredRank(4); + + INIT_ATTRIBUTE(TransposeConv2d); + INIT_QINFO(Conv); +} + +template +OpTransposeConv2d::~OpTransposeConv2d() +{ + if (attribute) + delete attribute; + if (qinfo) + delete qinfo; +} + +template +int OpTransposeConv2d::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + if (inputs[1]->getIsConst() == 0) + { + printNodeValidationError("OpTransposeConv2d: weight tensor is not const typed"); + } + + input = dynamic_cast*>(inputs[0]); + weight = dynamic_cast*>(inputs[1]); + bias = dynamic_cast*>(inputs[2]); + output = dynamic_cast*>(outputs[0]); + + if (!input->hasFormat(Format_NHWC)) + { + printNodeValidationError("OpTransposeConv2d: unsupported input tensor format"); + return 1; + } + + if (!weight->hasFormat(Format_OHWI)) + { + printNodeValidationError("OpTransposeConv2d: unsupported weight tensor format"); + return 1; + } + + if (attribute->outpad().size() != 2) + { + printNodeValidationError("OpTransposeConv2d: illegal size for attribute outpad"); + return 1; + } + + if (attribute->stride().size() != 2) + { + printNodeValidationError("OpTransposeConv2d: illegal size for attribute stride"); + return 1; + } + + if (attribute->dilation().size() != 2) + { + printNodeValidationError("OpTransposeConv2d: illegal size for attribute dilation"); + return 1; + } + + if (attribute->output_shape().size() != 4) + { + printNodeValidationError("OpTransposeConv2d: illegal size for attribute output_shape"); + return 1; + } + + for (int d = 0; d < 4; d++) + { + 
+        if (attribute->output_shape()[d] != this->output->getShape()[d])
+        {
+            printNodeValidationError("OpTransposeConv2d: illegal size for attribute output_shape");
+            return 1;
+        }
+    }
+
+    return 0;
+}
+
+template <DType InDtype, DType WeightDtype>
+int OpTransposeConv2d<InDtype, WeightDtype>::eval()
+{
+    int in_batch    = this->input->getShape()[0];
+    int in_height   = this->input->getShape()[1];
+    int in_width    = this->input->getShape()[2];
+    int in_channels = this->input->getShape()[3];
+
+    int f_out_channels = this->weight->getShape()[0];
+    int f_height       = this->weight->getShape()[1];
+    int f_width        = this->weight->getShape()[2];
+    int f_in_channels  = this->weight->getShape()[3];
+
+    int b_out_channels = this->bias->getShape()[0];
+
+    int out_batch    = this->output->getShape()[0];
+    int out_height   = this->output->getShape()[1];
+    int out_width    = this->output->getShape()[2];
+    int out_channels = this->output->getShape()[3];
+
+    int padding_top  = this->attribute->outpad()[0];
+    int padding_left = this->attribute->outpad()[1];
+    int stride_h     = this->attribute->stride()[0];
+    int stride_w     = this->attribute->stride()[1];
+    int dilation_h   = this->attribute->dilation()[0];
+    int dilation_w   = this->attribute->dilation()[1];
+
+    ASSERT_MSG_NODE(in_batch == out_batch, "OpTransposeConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
+    ASSERT_MSG_NODE(f_in_channels == in_channels, "OpTransposeConv2d: tensor input channel mismatch %d != %d",
+                    f_in_channels, in_channels);
+    ASSERT_MSG_NODE(f_out_channels == out_channels, "OpTransposeConv2d: tensor output channel mismatch %d != %d",
+                    f_out_channels, out_channels);
+    ASSERT_MSG_NODE(b_out_channels == out_channels, "OpTransposeConv2d: tensor b_out_channels mismatch %d != %d",
+                    b_out_channels, out_channels);
+
+    DEBUG_INFO(OP,
+               "perform OpTransposeConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
+               "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], padding=[%d,%d]",
+               in_batch, in_height, in_width, in_channels, f_out_channels, f_height, f_width, f_in_channels,
+               out_batch, out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w,
+               padding_top, padding_left);
+
+    TIn input_val      = this->input->getTensor();
+    TWeight weight_val = this->weight->getTensor();
+    if (this->qinfo)
+    {
+        input_val  = input_val - (InEigenType)this->qinfo->input_zp();
+        weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp();
+    }
+
+    Eigen::array<Eigen::Index, 4> reshape_dim;
+    reshape_dim.fill(1);
+    reshape_dim[3] = b_out_channels;
+
+    Eigen::array<Eigen::Index, 4> bcast;
+    bcast[0] = out_batch;
+    bcast[1] = out_height;
+    bcast[2] = out_width;
+    bcast[3] = 1;
+
+    // initialize the output with the broadcast bias
+    this->output->getTensor() = this->bias->getTensor().reshape(reshape_dim).broadcast(bcast);
+
+    int out_x_origin, out_y_origin;
+    int out_x, out_y;
+
+    // reference implementation from: tensorflow/tensorflow/lite/kernels/internal/reference/reference_ops.h
+    for (int ob = 0; ob < out_batch; ob++)
+    {
+        for (int ih = 0; ih < in_height; ih++)
+        {
+            for (int iw = 0; iw < in_width; iw++)
+            {
+                out_x_origin = iw * stride_w - padding_left;
+                out_y_origin = ih * stride_h - padding_top;
+                for (int ic = 0; ic < in_channels; ic++)
+                {
+                    for (int fh = 0; fh < f_height; fh++)
+                    {
+                        for (int fw = 0; fw < f_width; fw++)
+                        {
+                            out_x = out_x_origin + fw * dilation_w;
+                            out_y = out_y_origin + fh * dilation_h;
+                            for (int oc = 0; oc < out_channels; oc++)
+                            {
+                                if ((out_x >= 0 && out_x < out_width) && (out_y >= 0 && out_y < out_height))
+                                {
+                                    this->output->getTensor()(ob, out_y, out_x, oc) +=
+                                        ((AccEigenType)input_val(ob, ih, iw, ic) *
+                                         (AccEigenType)weight_val(oc, fh, fw, ic));
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    if (AccDtype == DType_INT48)
+    {
+        this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
+        this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
+    }
+
+    return GraphNode::eval();
+}
+
+// template explicit instantiation
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FLOAT);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, AINT8);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
+
+DEF_INSTANTIATE_ONE_TYPE(OpAvgPool2d, FLOAT)
+DEF_INSTANTIATE_ONE_TYPE(OpAvgPool2d, AINT8)
+DEF_INSTANTIATE_ONE_TYPE(OpAvgPool2d, INT16)
+
+DEF_INSTANTIATE_TWO_TYPE(OpConv2d, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE(OpConv2d, AINT8, INT4);
+DEF_INSTANTIATE_TWO_TYPE(OpConv2d, AINT8, INT8);
+DEF_INSTANTIATE_TWO_TYPE(OpConv2d, AINT8, AINT8);
+DEF_INSTANTIATE_TWO_TYPE(OpConv2d, INT16, INT8);
+
+DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, AINT8, INT4);
+DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, AINT8, INT8);
+DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, AINT8, AINT8);
+DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, INT16, INT8);
+
+DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, AINT8, INT4);
+DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, AINT8, INT8);
+DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, AINT8, AINT8);
+DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, INT16, INT8);
+
+DEF_INSTANTIATE_ONE_TYPE(OpMatMul, AINT8);
+DEF_INSTANTIATE_ONE_TYPE(OpMatMul, INT16);
+DEF_INSTANTIATE_ONE_TYPE(OpMatMul, FLOAT);
+
+DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FLOAT);
+DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, AINT8);
+DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16);
+
+DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, AINT8, INT4);
+DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, AINT8, INT8);
+DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, AINT8, AINT8);
+DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, INT16, INT8);
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h
new file mode 100644
index 0000000..26ce84b
--- /dev/null
+++ b/reference_model/src/ops/tensor_ops.h
@@ -0,0 +1,253 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
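+
+// Declarations for the tensor-level operator nodes (argmax, pooling,
+// convolution variants, fully-connected, matmul). Each node derives its
+// Eigen tensor types from its DType template parameters via template_types.h;
+// integer variants pick a wider accumulator type through GetAccDType.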
+
+#ifndef OPS_TENSOR_OPS_H
+#define OPS_TENSOR_OPS_H
+
+#include "graph_node.h"
+#include "quant_util.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+template <int Rank, DType Dtype>
+class OpArgMax : public GraphNode
+{
+public:
+    OpArgMax(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+    virtual ~OpArgMax();
+
+    virtual int checkTensorAttributes();
+    virtual int eval();
+
+    using InEigenType  = typename GetEigenType<Dtype>::type;
+    using OutEigenType = typename GetEigenType<DType_INT32>::type;
+    using TIn          = Eigen::Tensor<InEigenType, Rank>;
+    using TOut         = Eigen::Tensor<OutEigenType, Rank - 1>;
+
+protected:
+    TosaAxisAttribute* attribute;
+    TosaReference::TensorTemplate<TIn>* input;
+    TosaReference::TensorTemplate<TOut>* output;
+};
+
+template <DType Dtype>
+class OpAvgPool2d : public GraphNode
+{
+public:
+    OpAvgPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+    virtual ~OpAvgPool2d();
+
+    virtual int checkTensorAttributes();
+    virtual int eval();
+
+    static constexpr DType AccDtype = GetAccDType<Dtype, Dtype>::value;
+    using InEigenType  = typename GetEigenType<Dtype>::type;
+    using AccEigenType = typename GetEigenType<AccDtype>::type;
+    using OutEigenType = typename GetEigenType<Dtype>::type;
+    using TIn          = Eigen::Tensor<InEigenType, 4>;
+    using TOut         = Eigen::Tensor<OutEigenType, 4>;
+
+    static constexpr int64_t QMin = GetQMin<Dtype>::value;
+    static constexpr int64_t QMax = GetQMax<Dtype>::value;
+
+protected:
+    TosaReference::TensorTemplate<TIn>* in;
+    TosaReference::TensorTemplate<TOut>* out;
+    tosa::TosaPool2dAttribute* attribute;
+    tosa::TosaUnaryQuantInfo* qinfo;
+
+protected:
+    // return a 1D [N] tensor describing how many valid input elements each
+    // output element covers in that dimension of the input space
+    ETensor1<int32_t> calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride);
+};
+
+template <DType InDtype, DType WeightDtype>
+class OpConv2d : public GraphNode
+{
+public:
+    OpConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+    virtual ~OpConv2d();
+
+    virtual int checkTensorAttributes() final;
+    virtual int eval() final;
+
+    static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
+
+    using InEigenType     = typename GetEigenType<InDtype>::type;
+    using WeightEigenType = typename GetEigenType<WeightDtype>::type;
+    using AccEigenType    = typename GetEigenType<AccDtype>::type;
+    using TIn             = Eigen::Tensor<InEigenType, 4>;
+    using TWeight         = Eigen::Tensor<WeightEigenType, 4>;
+    using TBias           = Eigen::Tensor<AccEigenType, 1>;
+    using TAcc            = Eigen::Tensor<AccEigenType, 4>;
+
+    static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
+    static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
+
+protected:
+    TosaReference::TensorTemplate<TIn>* input;
+    TosaReference::TensorTemplate<TWeight>* weight;
+    TosaReference::TensorTemplate<TBias>* bias;
+    TosaReference::TensorTemplate<TAcc>* output;
+    tosa::TosaConv2dAttribute* attribute;
+    tosa::TosaConvQuantInfo* qinfo;
+};
+
+template <DType InDtype, DType WeightDtype>
+class OpDepthwiseConv2d : public GraphNode
+{
+public:
+    OpDepthwiseConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+    virtual ~OpDepthwiseConv2d();
+
+    virtual int checkTensorAttributes() final;
+    virtual int eval() final;
+
+    static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
+
+    using InEigenType     = typename GetEigenType<InDtype>::type;
+    using WeightEigenType = typename GetEigenType<WeightDtype>::type;
+    using AccEigenType    = typename GetEigenType<AccDtype>::type;
+    using TIn             = Eigen::Tensor<InEigenType, 4>;
+    using TWeight         = Eigen::Tensor<WeightEigenType, 4>;
+    using TBias           = Eigen::Tensor<AccEigenType, 1>;
+    using TAcc            = Eigen::Tensor<AccEigenType, 4>;
+
+    static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
+    static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
+
+protected:
+    TosaReference::TensorTemplate<TIn>* input;
+    TosaReference::TensorTemplate<TWeight>* weight;
+    TosaReference::TensorTemplate<TBias>* bias;
+    TosaReference::TensorTemplate<TAcc>* output;
+    tosa::TosaConv2dAttribute* attribute;
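+    // note: depthwise reuses the Conv2d attribute and quantization info types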
+    tosa::TosaConvQuantInfo* qinfo;
+};
+
+template <DType InDtype, DType WeightDtype>
+class OpFullyConnected : public GraphNode
+{
+public:
+    OpFullyConnected(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+    virtual ~OpFullyConnected();
+
+    virtual int checkTensorAttributes() final;
+    virtual int eval() final;
+
+    static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
+    using InEigenType     = typename GetEigenType<InDtype>::type;
+    using WeightEigenType = typename GetEigenType<WeightDtype>::type;
+    using AccEigenType    = typename GetEigenType<AccDtype>::type;
+    using TIn             = Eigen::Tensor<InEigenType, 2>;
+    using TWeight         = Eigen::Tensor<WeightEigenType, 2>;
+    using TBias           = Eigen::Tensor<AccEigenType, 1>;
+    using TAcc            = Eigen::Tensor<AccEigenType, 2>;
+
+    static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
+    static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
+
+protected:
+    TosaReference::TensorTemplate<TIn>* input;
+    TosaReference::TensorTemplate<TWeight>* weight;
+    TosaReference::TensorTemplate<TBias>* bias;
+    TosaReference::TensorTemplate<TAcc>* output;
+    tosa::TosaConvQuantInfo* qinfo;
+};
+
+template <DType Dtype>
+class OpMatMul : public GraphNode
+{
+public:
+    OpMatMul(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+    virtual ~OpMatMul();
+
+    virtual int checkTensorAttributes() final;
+    virtual int eval() final;
+
+    static constexpr DType AccDtype = GetAccDType<Dtype, Dtype>::value;
+    using InEigenType  = typename GetEigenType<Dtype>::type;
+    using AccEigenType = typename GetEigenType<AccDtype>::type;
+    using TIn          = Eigen::Tensor<InEigenType, 2>;
+    using TAcc         = Eigen::Tensor<AccEigenType, 2>;
+    static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
+    static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
+
+protected:
+    TosaReference::TensorTemplate<TIn>* a;
+    TosaReference::TensorTemplate<TIn>* b;
+    TosaReference::TensorTemplate<TAcc>* c;
+    tosa::TosaMatMulQuantInfo* qinfo;
+};
+
+template <DType Dtype>
+class OpMaxPool2d : public GraphNode
+{
+public:
+    OpMaxPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+    virtual ~OpMaxPool2d();
+
+    virtual int checkTensorAttributes();
+    virtual int eval();
+
+    using InEigenType  = typename GetEigenType<Dtype>::type;
+    using OutEigenType = typename GetEigenType<Dtype>::type;
+    using TIn          = Eigen::Tensor<InEigenType, 4>;
+    using TOut         = Eigen::Tensor<OutEigenType, 4>;
+
+protected:
+    TosaReference::TensorTemplate<TIn>* in;
+    TosaReference::TensorTemplate<TOut>* out;
+    tosa::TosaPool2dAttribute* attribute;
+};
+
+template <DType InDtype, DType WeightDtype>
+class OpTransposeConv2d : public GraphNode
+{
+public:
+    OpTransposeConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+    virtual ~OpTransposeConv2d();
+
+    virtual int checkTensorAttributes() final;
+    virtual int eval() final;
+
+    static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
+
+    using InEigenType     = typename GetEigenType<InDtype>::type;
+    using WeightEigenType = typename GetEigenType<WeightDtype>::type;
+    using AccEigenType    = typename GetEigenType<AccDtype>::type;
+    using TIn             = Eigen::Tensor<InEigenType, 4>;
+    using TWeight         = Eigen::Tensor<WeightEigenType, 4>;
+    using TBias           = Eigen::Tensor<AccEigenType, 1>;
+    using TAcc            = Eigen::Tensor<AccEigenType, 4>;
+
+    static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
+    static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
+
+protected:
+    TosaReference::TensorTemplate<TIn>* input;
+    TosaReference::TensorTemplate<TWeight>* weight;
+    TosaReference::TensorTemplate<TBias>* bias;
+    TosaReference::TensorTemplate<TAcc>* output;
+    TosaTransposeConv2dAttribute* attribute;
+    TosaConvQuantInfo* qinfo;
+};
+
+};    // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
new file mode 100644
index 0000000..61a19f4
--- /dev/null
+++ b/reference_model/src/ops/type_conversion.cc
@@ -0,0 +1,299 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "type_conversion.h"
+#include "quant_util.h"
+#include "template_types.h"
+#include <cmath>
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <int Rank, DType InDtype, DType OutDtype>
+OpRescale<Rank, InDtype, OutDtype>::OpRescale(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+    : GraphNode(Op_RESCALE, id_)
+{
+    setRequiredOperands(1, 1);
+    setRequiredRank(0, 6);
+    INIT_ATTRIBUTE(Rescale);
+}
+
+template <int Rank, DType InDtype, DType OutDtype>
+OpRescale<Rank, InDtype, OutDtype>::~OpRescale()
+{
+    if (attribute)
+        delete attribute;
+}
+
+template <int Rank, DType InDtype, DType OutDtype>
+int OpRescale<Rank, InDtype, OutDtype>::checkTensorAttributes()
+{
+    if (validateRequiredOperands())
+        return 1;
+
+    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+    {
+        return 1;
+    }
+
+    // output and input must be the same rank and size
+    if (inputs[0]->matchRankSize(*outputs[0]))
+    {
+        printNodeValidationError("OpRescale: input and output rank/size must match");
+        return 1;
+    }
+
+    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+    ASSERT_MEM(in && out);
+
+    return 0;
+}
+
+template <int Rank, DType InDtype, DType OutDtype>
+int OpRescale<Rank, InDtype, OutDtype>::eval()
+{
+    int32_t input_zp                = attribute->input_zp();
+    int32_t output_zp               = attribute->output_zp();
+    std::vector<int32_t> multiplier = attribute->multiplier();
+    std::vector<int32_t> shift      = attribute->shift();
+    //bool scale32 = attribute->scale32();
+    bool double_round = attribute->double_round();
+    bool per_channel  = attribute->per_channel();
+
+    if (TosaReference::TypeChecker::is_symmetric(InDtype))
+    {
+        if (input_zp != 0)
+        {
+            FATAL_ERROR_NODE("input tensor is symmetric type %s but zeropoint is %d instead of 0",
+                             EnumNamesDType()[InDtype], input_zp);
+        }
+    }
+
+    if (TosaReference::TypeChecker::is_symmetric(OutDtype))
+    {
+        if (output_zp != 0)
+        {
+            FATAL_ERROR_NODE("output tensor is symmetric type %s but zeropoint is %d instead of 0",
+                             EnumNamesDType()[OutDtype], output_zp);
+        }
+    }
+
+    // reshape [d0, d1, ..., dn] into [d0 * d1 ..., dn]
+    Eigen::array<Eigen::Index, 2> shape_2d;
+    shape_2d[0] = 1;
+    if (Rank > 0)
+    {
+        for (int i = 0; i < Rank - 1; i++)
+        {
+            shape_2d[0] *= this->in->getShape()[i];
+        }
+        shape_2d[1] = this->in->getShape()[Rank - 1];
+    }
+    else
+    {
+        shape_2d[1] = 1;
+    }
+    ETensor2<InEigenType> input_reshaped = this->in->getTensor().reshape(shape_2d);
+
+    ETensor2<OutEigenType> output_2d(shape_2d);
+
+    // TODO: pass scale32 in when 16-bit mode implemented
+    if (per_channel)
+    {
+        ETensor2<InEigenType> curr_channel_slice_prescaled;
+        ETensor2<OutEigenType> curr_channel_slice_postscaled;
+        int32_t channel_multiplier, channel_shift;
+        Eigen::array<Eigen::Index, 2> begin, size;
+        size = Eigen::array<Eigen::Index, 2>({ shape_2d[0], 1 });
+        for (int32_t i = 0; i < shape_2d[1]; i++)
+        {
+            begin                        = Eigen::array<Eigen::Index, 2>({ 0, i });
+            curr_channel_slice_prescaled = input_reshaped.slice(begin, size);
+            channel_multiplier           = multiplier[i];
+            channel_shift                = shift[i];
+            curr_channel_slice_postscaled =
+                curr_channel_slice_prescaled.unaryExpr([input_zp, output_zp, channel_multiplier, channel_shift,
+                                                        double_round](InEigenType in_val) -> OutEigenType {
+                    InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
+                    int32_t scaled =
+                        TosaReference::QuantUtil::apply_scale(input_zp_shifted, channel_multiplier, channel_shift,
+                                                              double_round);
+                    OutEigenType out_val = (OutEigenType)(scaled + output_zp);
+                    out_val              = std::max(out_val, QMin);
+                    out_val              = std::min(out_val, QMax);
+                    return out_val;
+                });
+
+            for (int32_t j = 0; j < shape_2d[0]; j++)
+            {
+                output_2d(j, i) = curr_channel_slice_postscaled(j, 0);
+            }
+        }
+    }
+    else
+    {
+        int32_t tensor_multiplier = multiplier[0];
+        int32_t tensor_shift      = shift[0];
+        output_2d = input_reshaped.unaryExpr(
+            [input_zp, output_zp, tensor_multiplier, tensor_shift, double_round](InEigenType in_val) -> OutEigenType {
+                InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
+                int32_t scaled = TosaReference::QuantUtil::apply_scale(input_zp_shifted, tensor_multiplier,
+                                                                       tensor_shift, double_round);
+                OutEigenType out_val = (OutEigenType)(scaled + output_zp);
+                out_val              = std::max(out_val, QMin);
+                out_val              = std::min(out_val, QMax);
+                return out_val;
+            });
+    }
+
+    // reshape [d0 * d1 ..., dn] back to [d0, d1, ..., dn]
+    Eigen::array<Eigen::Index, Rank> output_shape;
+    for (int i = 0; i < Rank; i++)
+    {
+        output_shape[i] = this->out->getShape()[i];
+    }
+    this->out->getTensor() = output_2d.reshape(output_shape);
+
+    return GraphNode::eval();
+}
+
+template <int Rank, DType InDtype, DType OutDtype>
+OpCast<Rank, InDtype, OutDtype>::OpCast(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+    : GraphNode(Op_CAST, id_)
+{
+    setRequiredOperands(1, 1);
+    setRequiredRank(0, 6);
+}
+
+template <int Rank, DType InDtype, DType OutDtype>
+OpCast<Rank, InDtype, OutDtype>::~OpCast()
+{}
+
+template <int Rank, DType InDtype, DType OutDtype>
+int OpCast<Rank, InDtype, OutDtype>::checkTensorAttributes()
+{
+    if (validateRequiredOperands())
+        return 1;
+
+    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+    {
+        return 1;
+    }
+
+    // output and input must be the same rank and size
+    if (inputs[0]->matchRankSize(*outputs[0]))
+    {
+        printNodeValidationError("OpCast: input and output rank/size must match");
+        return 1;
+    }
+
+    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+    ASSERT_MEM(in && out);
+
+    return 0;
+}
+
+template <int Rank, DType InDtype, DType OutDtype>
+int OpCast<Rank, InDtype, OutDtype>::eval()
+{
+    this->out->getTensor() = this->in->getTensor().unaryExpr(cast_helper.get_fcn());
+
+    return GraphNode::eval();
+}
+
+template <DType InDtype, DType OutDtype>
+CastHelper<InDtype, OutDtype>::CastHelper()
+{
+    fcn = [](InEigenType in) -> OutEigenType {
+        OutEigenType out = (OutEigenType)in;    // implicit sign_extend() if sizeof(out_t) >= sizeof(in_t)
+        int64_t mask     = (1L << OutBits) - 1;
+        out              = out & mask;
+        return out;
+    };
+}
+
+template <DType InDtype>
+CastHelper<InDtype, DType_BOOL>::CastHelper()
+{
+    fcn = [](InEigenType in) -> bool { return (in != 0) ? true : false; };
+}
+
+template <DType OutDtype>
+CastHelper<DType_BOOL, OutDtype>::CastHelper()
+{
+    fcn = [](bool in) -> OutEigenType {
+        OutEigenType out = in ? (OutEigenType)1 : (OutEigenType)0;
+        return out;
+    };
+}
+
+template <DType InDtype>
+CastHelper<InDtype, DType_FLOAT>::CastHelper()
+{
+    fcn = [](InEigenType in) -> float {
+        float out = (OutEigenType)in;    // default cast to float is round_to_nearest_float()
+        return out;
+    };
+}
+
+template <DType OutDtype>
+CastHelper<DType_FLOAT, OutDtype>::CastHelper()
+{
+    fcn = [](float in) -> OutEigenType {
+        OutEigenType out = std::round(in);
+        out              = std::max(out, OutMin);
+        out              = std::min(out, OutMax);
+        return out;
+    };
+}
+
+// template explicit instantiation
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, UINT8);
diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h
new file mode 100644
index 0000000..6ec4d6d
--- /dev/null
+++ b/reference_model/src/ops/type_conversion.h
@@ -0,0 +1,162 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
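+
+// CAST and RESCALE node declarations. CastHelper is specialized on the
+// (input, output) DType pair so each legal cast resolves its elementwise
+// conversion function at compile time; OpCast then applies it with
+// unaryExpr() in eval(). Per element, RESCALE computes roughly:
+//   out = clamp(apply_scale(in - input_zp, multiplier, shift) + output_zp,
+//               QMin, QMax)
+// where apply_scale() is the TOSA fixed-point scaling (multiply plus
+// rounded right shift, with optional double rounding).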
+
+#ifndef OPS_TYPE_CONVERSION_H
+#define OPS_TYPE_CONVERSION_H
+
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+template <int Rank, DType InDtype, DType OutDtype>
+class OpRescale : public GraphNode
+{
+public:
+    OpRescale(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+    virtual ~OpRescale();
+
+    virtual int checkTensorAttributes() final;
+    virtual int eval() final;
+
+    using InEigenType  = typename GetEigenType<InDtype>::type;
+    using OutEigenType = typename GetEigenType<OutDtype>::type;
+    using TIn          = Eigen::Tensor<InEigenType, Rank>;
+    using TOut         = Eigen::Tensor<OutEigenType, Rank>;
+
+    static constexpr int32_t QMin = GetQMin<OutDtype>::value;
+    static constexpr int32_t QMax = GetQMax<OutDtype>::value;
+
+protected:
+    TosaRescaleAttribute* attribute;
+    TosaReference::TensorTemplate<TIn>* in;
+    TosaReference::TensorTemplate<TOut>* out;
+};
+
+template <DType InDtype, DType OutDtype>
+class CastHelper
+{
+public:
+    using InEigenType  = typename GetEigenType<InDtype>::type;
+    using OutEigenType = typename GetEigenType<OutDtype>::type;
+    using FcnType      = std::function<OutEigenType(InEigenType)>;
+    static constexpr int32_t OutBits = GetNumBits<OutDtype>::value;
+    CastHelper();
+    const FcnType& get_fcn() const
+    {
+        return fcn;
+    }
+
+private:
+    FcnType fcn;
+};
+
+template <DType InDtype>
+class CastHelper<InDtype, DType_BOOL>
+{
+public:
+    using InEigenType  = typename GetEigenType<InDtype>::type;
+    using OutEigenType = typename GetEigenType<DType_BOOL>::type;
+    using FcnType      = std::function<OutEigenType(InEigenType)>;
+    CastHelper();
+    const FcnType& get_fcn() const
+    {
+        return fcn;
+    }
+
+private:
+    FcnType fcn;
+};
+
+template <DType OutDtype>
+class CastHelper<DType_BOOL, OutDtype>
+{
+public:
+    using InEigenType  = typename GetEigenType<DType_BOOL>::type;
+    using OutEigenType = typename GetEigenType<OutDtype>::type;
+    using FcnType      = std::function<OutEigenType(InEigenType)>;
+    static constexpr int32_t OutMin = GetQMin<OutDtype>::value;
+    static constexpr int32_t OutMax = GetQMax<OutDtype>::value;
+    CastHelper();
+    const FcnType& get_fcn() const
+    {
+        return fcn;
+    }
+
+private:
+    FcnType fcn;
+};
+
+template <DType InDtype>
+class CastHelper<InDtype, DType_FLOAT>
+{
+public:
+    using InEigenType  = typename GetEigenType<InDtype>::type;
+    using OutEigenType = typename GetEigenType<DType_FLOAT>::type;
+    using FcnType      = std::function<OutEigenType(InEigenType)>;
+    CastHelper();
+    const FcnType& get_fcn() const
+    {
+        return fcn;
+    }
+
+private:
+    FcnType fcn;
+};
+
+template <DType OutDtype>
+class CastHelper<DType_FLOAT, OutDtype>
+{
+public:
+    using InEigenType  = typename GetEigenType<DType_FLOAT>::type;
+    using OutEigenType = typename GetEigenType<OutDtype>::type;
+    using FcnType      = std::function<OutEigenType(InEigenType)>;
+    static constexpr int32_t OutMin = GetQMin<OutDtype>::value;
+    static constexpr int32_t OutMax = GetQMax<OutDtype>::value;
+    CastHelper();
+    const FcnType& get_fcn() const
+    {
+        return fcn;
+    }
+
+private:
+    FcnType fcn;
+};
+
+template <int Rank, DType InDtype, DType OutDtype>
+class OpCast : public GraphNode
+{
+public:
+    OpCast(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+    virtual ~OpCast();
+
+    virtual int checkTensorAttributes() final;
+    virtual int eval() final;
+
+    using InEigenType  = typename GetEigenType<InDtype>::type;
+    using OutEigenType = typename GetEigenType<OutDtype>::type;
+    using TIn          = Eigen::Tensor<InEigenType, Rank>;
+    using TOut         = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+    CastHelper<InDtype, OutDtype> cast_helper;
+    TosaReference::TensorTemplate<TIn>* in;
+    TosaReference::TensorTemplate<TOut>* out;
+};
+
+};    // namespace TosaReference
+
+#endif
-- 
cgit v1.2.1