From e5e2676409a936431f87d31fb74d825257b20804 Mon Sep 17 00:00:00 2001 From: Eric Kunze Date: Tue, 13 Oct 2020 16:11:07 -0700 Subject: Initial checkin of TOSA reference_model and tests Change-Id: I2f8e7fa63e2ae40203e57d2cc8814bde3b312cb6 Signed-off-by: Eric Kunze --- reference_model/src/ops/ewise_ternary.cc | 115 +++++++++++++++++++++++++++++++ 1 file changed, 115 insertions(+) create mode 100644 reference_model/src/ops/ewise_ternary.cc (limited to 'reference_model/src/ops/ewise_ternary.cc') diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc new file mode 100644 index 0000000..eded0d7 --- /dev/null +++ b/reference_model/src/ops/ewise_ternary.cc @@ -0,0 +1,115 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ewise_ternary.h" + +using namespace TosaReference; +using namespace Eigen; +using namespace tosa; + +template +OpSelectBase::OpSelectBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_SELECT, id_) +{ + setRequiredOperands(3, 1); + setRequiredRank(0, 6); +} + +template +OpSelectBase::~OpSelectBase() +{} + +template +int OpSelectBase::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(inputs[2]) || + validateRequiredRank(outputs[0])) + { + return 1; + } + + // output and input must be the same types + if (inputs[0]->matchRank(*outputs[0]) || inputs[1]->matchRankType(*outputs[0]) || + inputs[2]->matchRankType(*outputs[0])) + { + printNodeValidationError("Failure to match input and output rank and type"); + return 1; + } + + cond = dynamic_cast*>(inputs[0]); + then_val = dynamic_cast*>(inputs[1]); + else_val = dynamic_cast*>(inputs[2]); + out = dynamic_cast*>(outputs[0]); + + return 0; +} + +template +int OpSelectBase::eval() +{ + FATAL_ERROR_NODE("shouldn't be called"); +} + +template +int OpSelect::broadcast() +{ + std::vector cond_shape = this->cond->getShape(); + std::vector then_shape = this->then_val->getShape(); + std::vector else_shape = this->else_val->getShape(); + std::vector out_shape = this->out->getShape(); + + for (int i = 0; i < Rank; i++) + { + this->bcast_cond[i] = (cond_shape[i] == 1) ? std::max(then_shape[i], else_shape[i]) : 1; + this->bcast_then[i] = (then_shape[i] == 1) ? std::max(cond_shape[i], else_shape[i]) : 1; + this->bcast_else[i] = (else_shape[i] == 1) ? std::max(then_shape[i], cond_shape[i]) : 1; + ASSERT_MSG_NODE((this->bcast_cond[i] * cond_shape[i]) == out_shape[i], "SELECT broadcast invariant failed"); + ASSERT_MSG_NODE((this->bcast_then[i] * then_shape[i]) == out_shape[i], "SELECT broadcast invariant failed"); + ASSERT_MSG_NODE((this->bcast_else[i] * else_shape[i]) == out_shape[i], "SELECT broadcast invariant failed"); + } + + return 0; +} + +template +int OpSelect::eval() +{ + this->broadcast(); + this->out->getTensor() = this->cond->getTensor() + .broadcast(this->bcast_cond) + .select(this->then_val->getTensor().broadcast(this->bcast_then), + this->else_val->getTensor().broadcast(this->bcast_else)); + + return GraphNode::eval(); +} + +template +int OpSelect<0, Dtype>::eval() +{ + this->out->getTensor() = this->cond->getTensor().select(this->then_val->getTensor(), this->else_val->getTensor()); + + return GraphNode::eval(); +} + +// template explicit instantiation +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, AINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BOOL); -- cgit v1.2.1