aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/shape.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/shape.cc')
-rw-r--r--reference_model/src/ops/shape.cc198
1 files changed, 198 insertions, 0 deletions
diff --git a/reference_model/src/ops/shape.cc b/reference_model/src/ops/shape.cc
new file mode 100644
index 0000000..b087dd8
--- /dev/null
+++ b/reference_model/src/ops/shape.cc
@@ -0,0 +1,198 @@
+// Copyright (c) 2023-2024, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "shape.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+OpConstShape::OpConstShape(SubgraphTraverser* sgt_, uint64_t id_)
+ : GraphNode(sgt_, Op_CONST, id_)
+{
+ setRequiredOperands(0, 1);
+}
+
+OpConstShape::~OpConstShape()
+{}
+
+int OpConstShape::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ return 0;
+}
+
+int OpConstShape::eval()
+{
+ for (auto ct : getOutputs())
+ {
+ if (!ct->getIsValid())
+ {
+ std::string err = "Constant Shape tensor " + ct->getName() + " not correctly initialized";
+ printNodeValidationError(err.c_str());
+ return 1;
+ }
+ }
+
+ // Evaluation is trivial for constants
+ return GraphNode::eval();
+}
+
+OpConcatShape::OpConcatShape(SubgraphTraverser* sgt_, uint64_t id_)
+ : GraphNode(sgt_, Op_CONCAT_SHAPE, id_)
+{
+ setRequiredOperands(-1, 1);
+ setRequiredRank(1, 1);
+}
+
+OpConcatShape::~OpConcatShape()
+{}
+
+int OpConcatShape::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (inputs.empty())
+ {
+ printNodeValidationError("ConcatShape operator must have at least one input tensor");
+ return 1;
+ }
+
+ int32_t num_inputs = inputs.size();
+ int32_t elements_count = 0;
+ for (int32_t i = 0; i < num_inputs; i++)
+ {
+ if (validateRequiredRank(inputs[i]))
+ return 1;
+ ins.push_back(dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[i]));
+ elements_count += inputs[i]->getShape()[0];
+ }
+
+ ERROR_IF(elements_count != outputs[0]->getShape()[0],
+ "OpConcatShape: sum of input elements not equal to output number of elements");
+
+ num_dims = outputs[0]->getShape()[0];
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ return 0;
+}
+
+int OpConcatShape::eval()
+{
+ ETensor1<EigenType> out_tensor(num_dims);
+ int32_t out_idx = 0;
+ for (size_t i = 0; i < ins.size(); i++)
+ {
+ // all tosa.shape values are 1-d tensors
+ // interate in_idx in range of [0, rank of 1-d tensor]
+ for (int32_t in_idx = 0; in_idx < inputs[i]->getShape()[0]; in_idx++)
+ {
+ out_tensor(out_idx) = ins[i]->getTensor()(in_idx);
+ out_idx++;
+ }
+ }
+ out->getTensor() = out_tensor;
+ return GraphNode::eval();
+}
+
+ShapeBinaryNodeBase::ShapeBinaryNodeBase(SubgraphTraverser* sgt_, const Op& op_, uint64_t id_)
+ : GraphNode(sgt_, op_, id_)
+{
+ setRequiredOperands(2, 1);
+ setRequiredRank(1, 1);
+
+ fcn = [](EigenType a, EigenType b) -> EigenType { return EigenType(); };
+}
+
+ShapeBinaryNodeBase::~ShapeBinaryNodeBase()
+{}
+
+int ShapeBinaryNodeBase::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
+ return 1;
+
+ num_dims = outputs[0]->getShape()[0];
+
+ if (inputs[0]->getShape()[0] != num_dims)
+ {
+ std::string err = "Binary shape operators " + std::string(EnumNamesOp()[nodeType]) +
+ " lhs input and output rank/shape must match";
+ printNodeValidationError(err.c_str());
+ return 1;
+ }
+
+ if (inputs[1]->getShape()[0] != num_dims)
+ {
+ std::string err = "Binary shape operators " + std::string(EnumNamesOp()[nodeType]) +
+ " rhs input and output rank/shape must match";
+ printNodeValidationError(err.c_str());
+ return 1;
+ }
+
+ a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
+ result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ ASSERT_MEM(a && b && result);
+
+ return 0;
+}
+
+int ShapeBinaryNodeBase::eval()
+{
+ auto ia = a->getTensor();
+ auto ib = b->getTensor();
+ ETensor1<EigenType> out_tens(num_dims);
+ for (int32_t i = 0; i < num_dims; i++)
+ {
+ EigenType lhs = ia(i);
+ EigenType rhs = ib(i);
+ out_tens(i) = (lhs < 0 || rhs < 0) ? static_cast<EigenType>(-1) : fcn(lhs, rhs);
+ }
+
+ result->getTensor() = out_tens;
+ return GraphNode::eval();
+}
+
+int OpAddShape::register_fcn()
+{
+ fcn = [](EigenType a, EigenType b) -> EigenType { return a + b; };
+ return 0;
+}
+
+int OpSubShape::register_fcn()
+{
+ fcn = [](EigenType a, EigenType b) -> EigenType { return a - b; };
+ return 0;
+}
+
+int OpMulShape::register_fcn()
+{
+ fcn = [](EigenType a, EigenType b) -> EigenType { return a * b; };
+ return 0;
+}
+
+int OpDivShape::register_fcn()
+{
+ fcn = [](EigenType a, EigenType b) -> EigenType {
+ return (b == static_cast<EigenType>(0)) ? static_cast<EigenType>(-1) : (a / b);
+ };
+ return 0;
+} \ No newline at end of file