From 8690a0873fac28acccbb0acb15c16e8337e14df1 Mon Sep 17 00:00:00 2001
From: Tai Ly
Date: Mon, 18 Dec 2023 20:40:24 +0000
Subject: [reference model] Add shape operators

- fixed up reshape conformance tests to use shape input instead of attribute
- fixed up tile conformance tests to use shape input instead of attribute
- fixed output and output rank of dim op
- allow rank 0 and rank 1 tensors for tosa.shape values (for shape = {})
- added initialization of rank 0 const_shape tensors (for shape = {})
- Update conformance tests to use new rescale attributes

Signed-off-by: Tai Ly
Signed-off-by: Won Jeon
Change-Id: I6cce0d2a9ab066fe20a2abf9d2cfde3eb3d8c18b
---
 reference_model/src/ops/data_layout.cc |  55 +++------
 reference_model/src/ops/data_layout.h  |  16 +--
 reference_model/src/ops/op_factory.cc  |  16 ++-
 reference_model/src/ops/shape.cc       | 198 +++++++++++++++++++++++++++++++++
 reference_model/src/ops/shape.h        | 120 ++++++++++++++++++++
 5 files changed, 358 insertions(+), 47 deletions(-)
 create mode 100644 reference_model/src/ops/shape.cc
 create mode 100644 reference_model/src/ops/shape.h

diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index fa99d21..a4b4e0a 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2023, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -258,7 +258,7 @@ int OpDim<Rank, Dtype>::eval()
     int32_t axis    = attribute->axis();
     int64_t out_val = in->getShape()[axis];
 
-    this->out->getTensor().setConstant(out_val);
+    this->out->getTensor().setValues({ out_val });
 
     return GraphNode::eval();
 }
@@ -267,17 +267,12 @@ template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
 OpReshape<InRank, OutRank, Dtype>::OpReshape(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
     : GraphNode(sgt_, Op_RESHAPE, id_)
 {
-    setRequiredOperands(1, 1);
-
-    INIT_ATTRIBUTE(Reshape);
+    setRequiredOperands(2, 1);
 }
 
 template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
 OpReshape<InRank, OutRank, Dtype>::~OpReshape()
-{
-    if (attribute)
-        delete attribute;
-}
+{}
 
 template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
 int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes()
@@ -297,25 +292,17 @@ int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes()
         return 1;
     }
 
-    // Check for unsupported -1 shape inferencing
-    for (int32_t d = 0; d < OutRank; d++)
-    {
-        auto curr_new_dim = attribute->new_shape()[d];
-        ERROR_IF(curr_new_dim == -1, "OpReshape: inferred dimensions in output shape are unsupported")
-    }
-
     ERROR_IF(inputs[0]->getElementCount() != outputs[0]->getElementCount(),
              "Input tensor size does not match output tensor size");
 
-    for (uint32_t d = 0; d < OutRank; d++)
-    {
-        auto curr_new_dim = attribute->new_shape()[d];
-        ERROR_IF(curr_new_dim != outputs[0]->getShape()[d], "OpReshape: new_shape doesn't match output shape");
-    }
-
     in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
     out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
 
+    // note: do not assert mem on shape input, because it may be {} for reshape to scalar
+    // and also, because the shape input is not actually used in eval()
+
+    ASSERT_MEM(in && out)
+
     return 0;
 }
@@ -506,18 +493,13 @@ template <int Rank, TOSA_REF_TYPE Dtype>
 OpTileBase<Rank, Dtype>::OpTileBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
     : GraphNode(sgt_, Op_TILE, id_)
 {
-    setRequiredOperands(1, 1);
+    setRequiredOperands(2, 1);
     setRequiredRank(1);
-
-    INIT_ATTRIBUTE(Tile);
 }
 
 template <int Rank, TOSA_REF_TYPE Dtype>
 OpTileBase<Rank, Dtype>::~OpTileBase()
-{
-    if (attribute)
-        delete attribute;
-}
+{}
 
 template <int Rank, TOSA_REF_TYPE Dtype>
 int OpTileBase<Rank, Dtype>::checkTensorAttributes()
@@ -541,23 +523,18 @@ int OpTileBase<Rank, Dtype>::checkTensorAttributes()
         return 1;
     }
 
-    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
-    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+    in        = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+    multiples = dynamic_cast<TosaReference::TensorTemplate<TInMultiples>*>(inputs[1]);
+    out       = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
 
-    ASSERT_MEM(in && out);
+    ASSERT_MEM(in && multiples && out);
 
-    if (attribute->multiples().size() != Rank)
+    if (multiples->getElementCount() != Rank)
     {
         printNodeValidationError("1D list 'multiples' must have size equal to input rank");
         return 1;
     }
 
-    for (int32_t d = 0; d < Rank; d++)
-    {
-        ERROR_IF(in->getShape()[d] * attribute->multiples()[d] != out->getShape()[d],
-                 "Output shape not equal to input * multiples;")
-    }
-
     return 0;
 }
diff --git a/reference_model/src/ops/data_layout.h b/reference_model/src/ops/data_layout.h
index 024f9a2..9341709 100644
--- a/reference_model/src/ops/data_layout.h
+++ b/reference_model/src/ops/data_layout.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2023, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -79,7 +79,7 @@ public:
     using InEigenType  = typename GetEigenType<Dtype>::type;
     using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_SHAPE>::type;
     using TIn          = Eigen::Tensor<InEigenType, Rank>;
-    using TOut         = Eigen::Tensor<OutEigenType, 0>;
+    using TOut         = Eigen::Tensor<OutEigenType, 1>;
 
 protected:
     TosaReference::TensorTemplate<TIn>* in;
@@ -107,7 +107,6 @@ protected:
     Eigen::array<Eigen::Index, InRank> in_reverser;
     Eigen::array<Eigen::Index, OutRank> out_reverser;
     TosaReference::TensorTemplate<TIn>* in;
-    TosaReshapeAttribute* attribute;
    TosaReference::TensorTemplate<TOut>* out;
 };
 
@@ -165,14 +164,17 @@ public:
     virtual int checkTensorAttributes();
 
-    using InEigenType  = typename GetEigenType<Dtype>::type;
-    using OutEigenType = typename GetEigenType<Dtype>::type;
-    using TIn          = Eigen::Tensor<InEigenType, Rank>;
-    using TOut         = Eigen::Tensor<OutEigenType, Rank>;
+    using InEigenType      = typename GetEigenType<Dtype>::type;
+    using InEigenShapeType = typename GetEigenType<TOSA_REF_TYPE_SHAPE>::type;
+    using OutEigenType     = typename GetEigenType<Dtype>::type;
+    using TIn              = Eigen::Tensor<InEigenType, Rank>;
+    using TInMultiples     = Eigen::Tensor<InEigenShapeType, 1>;
+    using TOut             = Eigen::Tensor<OutEigenType, Rank>;
 
 protected:
     TosaTileAttribute* attribute;
     TosaReference::TensorTemplate<TIn>* in;
+    TosaReference::TensorTemplate<TInMultiples>* multiples;
     TosaReference::TensorTemplate<TOut>* out;
 };
 
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index 34db903..af8332e 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2023, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -26,6 +26,7 @@
 #include "image.h"
 #include "reduction.h"
 #include "scatter_gather.h"
+#include "shape.h"
 #include "tensor_ops.h"
 #include "type_conversion.h"
 
@@ -600,6 +601,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
         case Op_WHILE_LOOP:
             return new OpWhileLoop(sgt, tsh, attribute, id);
 
+        case Op_CONST_SHAPE:
+            return new OpConstShape(sgt, id);
+        case Op_CONCAT_SHAPE:
+            return new OpConcatShape(sgt, id);
+        case Op_ADD_SHAPE:
+            return new OpAddShape(sgt, id);
+        case Op_SUB_SHAPE:
+            return new OpSubShape(sgt, id);
+        case Op_MUL_SHAPE:
+            return new OpMulShape(sgt, id);
+        case Op_DIV_SHAPE:
+            return new OpDivShape(sgt, id);
+
         // Ops not recognized
         default:
             goto done;
diff --git a/reference_model/src/ops/shape.cc b/reference_model/src/ops/shape.cc
new file mode 100644
index 0000000..b087dd8
--- /dev/null
+++ b/reference_model/src/ops/shape.cc
@@ -0,0 +1,198 @@
+// Copyright (c) 2023-2024, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "shape.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+OpConstShape::OpConstShape(SubgraphTraverser* sgt_, uint64_t id_)
+    : GraphNode(sgt_, Op_CONST, id_)
+{
+    setRequiredOperands(0, 1);
+}
+
+OpConstShape::~OpConstShape()
+{}
+
+int OpConstShape::checkTensorAttributes()
+{
+    if (validateRequiredOperands())
+        return 1;
+
+    return 0;
+}
+
+int OpConstShape::eval()
+{
+    for (auto ct : getOutputs())
+    {
+        if (!ct->getIsValid())
+        {
+            std::string err = "Constant Shape tensor " + ct->getName() + " not correctly initialized";
+            printNodeValidationError(err.c_str());
+            return 1;
+        }
+    }
+
+    // Evaluation is trivial for constants
+    return GraphNode::eval();
+}
+
+OpConcatShape::OpConcatShape(SubgraphTraverser* sgt_, uint64_t id_)
+    : GraphNode(sgt_, Op_CONCAT_SHAPE, id_)
+{
+    setRequiredOperands(-1, 1);
+    setRequiredRank(1, 1);
+}
+
+OpConcatShape::~OpConcatShape()
+{}
+
+int OpConcatShape::checkTensorAttributes()
+{
+    if (validateRequiredOperands())
+        return 1;
+
+    if (inputs.empty())
+    {
+        printNodeValidationError("ConcatShape operator must have at least one input tensor");
+        return 1;
+    }
+
+    int32_t num_inputs     = inputs.size();
+    int32_t elements_count = 0;
+    for (int32_t i = 0; i < num_inputs; i++)
+    {
+        if (validateRequiredRank(inputs[i]))
+            return 1;
+        ins.push_back(dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[i]));
+        elements_count += inputs[i]->getShape()[0];
+    }
+
+    ERROR_IF(elements_count != outputs[0]->getShape()[0],
+             "OpConcatShape: sum of input elements not equal to output number of elements");
+
+    num_dims = outputs[0]->getShape()[0];
+    out      = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+    return 0;
+}
+
+int OpConcatShape::eval()
+{
+    ETensor1<EigenType> out_tensor(num_dims);
+    int32_t out_idx = 0;
+    for (size_t i = 0; i < ins.size(); i++)
+    {
+        // all tosa.shape values are 1-d tensors
+        // iterate in_idx in range of [0, rank of 1-d tensor]
+        for (int32_t in_idx = 0; in_idx < inputs[i]->getShape()[0]; in_idx++)
+        {
+            out_tensor(out_idx) = ins[i]->getTensor()(in_idx);
+            out_idx++;
+        }
+    }
+    out->getTensor() = out_tensor;
+
+    return GraphNode::eval();
+}
+
+ShapeBinaryNodeBase::ShapeBinaryNodeBase(SubgraphTraverser* sgt_, const Op& op_, uint64_t id_)
+    : GraphNode(sgt_, op_, id_)
+{
+    setRequiredOperands(2, 1);
+    setRequiredRank(1, 1);
+
+    fcn = [](EigenType a, EigenType b) -> EigenType { return EigenType(); };
+}
+
+ShapeBinaryNodeBase::~ShapeBinaryNodeBase()
+{}
+
+int ShapeBinaryNodeBase::checkTensorAttributes()
+{
+    if (validateRequiredOperands())
+        return 1;
+    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
+        return 1;
+
+    num_dims = outputs[0]->getShape()[0];
+
+    if (inputs[0]->getShape()[0] != num_dims)
+    {
+        std::string err = "Binary shape operators " + std::string(EnumNamesOp()[nodeType]) +
+                          " lhs input and output rank/shape must match";
+        printNodeValidationError(err.c_str());
+        return 1;
+    }
+
+    if (inputs[1]->getShape()[0] != num_dims)
+    {
+        std::string err = "Binary shape operators " + std::string(EnumNamesOp()[nodeType]) +
+                          " rhs input and output rank/shape must match";
+        printNodeValidationError(err.c_str());
+        return 1;
+    }
+
+    a      = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+    b      = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
+    result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+    ASSERT_MEM(a && b && result);
+
+    return 0;
+}
+
+int ShapeBinaryNodeBase::eval()
+{
+    auto ia = a->getTensor();
+    auto ib = b->getTensor();
+    ETensor1<EigenType> out_tens(num_dims);
+    for (int32_t i = 0; i < num_dims; i++)
+    {
+        EigenType lhs = ia(i);
+        EigenType rhs = ib(i);
+        out_tens(i)   = (lhs < 0 || rhs < 0) ? static_cast<EigenType>(-1) : fcn(lhs, rhs);
+    }
+
+    result->getTensor() = out_tens;
+    return GraphNode::eval();
+}
+
+int OpAddShape::register_fcn()
+{
+    fcn = [](EigenType a, EigenType b) -> EigenType { return a + b; };
+    return 0;
+}
+
+int OpSubShape::register_fcn()
+{
+    fcn = [](EigenType a, EigenType b) -> EigenType { return a - b; };
+    return 0;
+}
+
+int OpMulShape::register_fcn()
+{
+    fcn = [](EigenType a, EigenType b) -> EigenType { return a * b; };
+    return 0;
+}
+
+int OpDivShape::register_fcn()
+{
+    fcn = [](EigenType a, EigenType b) -> EigenType {
+        return (b == static_cast<EigenType>(0)) ? static_cast<EigenType>(-1) : (a / b);
+    };
+    return 0;
+}
\ No newline at end of file
diff --git a/reference_model/src/ops/shape.h b/reference_model/src/ops/shape.h
new file mode 100644
index 0000000..38ecda8
--- /dev/null
+++ b/reference_model/src/ops/shape.h
@@ -0,0 +1,120 @@
+// Copyright (c) 2023-2024, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_SHAPES_H
+#define OPS_SHAPES_H
+
+#include "graph_node.h"
+
+namespace TosaReference
+{
+
+class OpConstShape : public GraphNode
+{
+public:
+    OpConstShape(SubgraphTraverser* sgt_, uint64_t id_);
+    virtual ~OpConstShape();
+
+    virtual int checkTensorAttributes();
+    virtual int eval();
+};
+
+class OpConcatShape : public GraphNode
+{
+public:
+    OpConcatShape(SubgraphTraverser* sgt_, uint64_t id_);
+    virtual ~OpConcatShape();
+
+    virtual int checkTensorAttributes();
+    virtual int eval();
+
+    using EigenType = typename GetEigenType<TOSA_REF_TYPE_SHAPE>::type;
+    using TIn       = Eigen::Tensor<EigenType, 1>;
+    using TOut      = Eigen::Tensor<EigenType, 1>;
+
+protected:
+    int32_t num_dims;    // number of dimensions in concat_shape output
+    std::vector<TosaReference::TensorTemplate<TIn>*> ins;
+    TosaReference::TensorTemplate<TOut>* out;
+};
+
+class ShapeBinaryNodeBase : public GraphNode
+{
+public:
+    ShapeBinaryNodeBase(SubgraphTraverser* sgt_, const Op& op_, uint64_t id_);
+    virtual ~ShapeBinaryNodeBase();
+
+    virtual int checkTensorAttributes() final;
+    virtual int eval();
+    virtual int register_fcn() = 0;
+
+    using EigenType = typename GetEigenType<TOSA_REF_TYPE_SHAPE>::type;
+    using TIn       = Eigen::Tensor<EigenType, 1>;
+    using TOut      = Eigen::Tensor<EigenType, 1>;
+
+protected:
+    int32_t num_dims;    // number of dimensions in shape op's result
+    std::function<EigenType(EigenType, EigenType)> fcn;
+    TosaReference::TensorTemplate<TIn>* a;
+    TosaReference::TensorTemplate<TIn>* b;
+    TosaReference::TensorTemplate<TOut>* result;
+};
+
+class OpAddShape : public ShapeBinaryNodeBase
+{
+public:
+    OpAddShape(SubgraphTraverser* sgt_, uint64_t id_)
+        : ShapeBinaryNodeBase(sgt_, Op_ADD_SHAPE, id_)
+    {
+        register_fcn();
+    }
+    virtual int register_fcn();
+};
+
+class OpSubShape : public ShapeBinaryNodeBase
+{
+public:
+    OpSubShape(SubgraphTraverser* sgt_, uint64_t id_)
+        : ShapeBinaryNodeBase(sgt_, Op_SUB_SHAPE, id_)
+    {
+        register_fcn();
+    }
+    virtual int register_fcn();
+};
+
+class OpMulShape : public ShapeBinaryNodeBase
+{
+public:
+    OpMulShape(SubgraphTraverser* sgt_, uint64_t id_)
+        : ShapeBinaryNodeBase(sgt_, Op_MUL_SHAPE, id_)
+    {
+        register_fcn();
+    }
+    virtual int register_fcn();
+};
+
+class OpDivShape : public ShapeBinaryNodeBase
+{
+public:
+    OpDivShape(SubgraphTraverser* sgt_, uint64_t id_)
+        : ShapeBinaryNodeBase(sgt_, Op_DIV_SHAPE, id_)
+    {
+        register_fcn();
+    }
+    virtual int register_fcn();
+};
+
+};    // namespace TosaReference
+
+#endif
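
As a quick illustration of the semantics the new binary shape operators implement (element-wise evaluation over rank-1 int64 shape values, where any negative operand propagates the "unknown dimension" value -1, and DIV_SHAPE additionally maps division by zero to -1), here is a minimal standalone C++ sketch. It is not reference model code: std::vector<int64_t> stands in for the Eigen rank-1 shape tensors, and the names evalShapeBinary, ShapeValue and BinaryFcn are invented for this example only.

// Minimal sketch (not reference model code): demonstrates the evaluation rule
// shared by ADD_SHAPE/SUB_SHAPE/MUL_SHAPE/DIV_SHAPE in the patch above.
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

using ShapeValue = std::vector<int64_t>;
using BinaryFcn  = std::function<int64_t(int64_t, int64_t)>;

// Element-wise apply: a negative element on either side means "unknown
// dimension", so the result element becomes -1 instead of calling fcn.
// Assumes a and b have the same number of elements, as checkTensorAttributes()
// enforces in the patch.
ShapeValue evalShapeBinary(const ShapeValue& a, const ShapeValue& b, const BinaryFcn& fcn)
{
    ShapeValue result(a.size());
    for (size_t i = 0; i < a.size(); i++)
    {
        result[i] = (a[i] < 0 || b[i] < 0) ? -1 : fcn(a[i], b[i]);
    }
    return result;
}

int main()
{
    // MUL_SHAPE-style function
    auto mul = [](int64_t x, int64_t y) -> int64_t { return x * y; };
    // DIV_SHAPE-style function: division by zero also yields -1
    auto div = [](int64_t x, int64_t y) -> int64_t { return (y == 0) ? int64_t(-1) : x / y; };

    ShapeValue lhs{ 2, -1, 3 };
    ShapeValue rhs{ 4, 5, 0 };

    for (int64_t v : evalShapeBinary(lhs, rhs, mul))
        std::cout << v << " ";    // prints: 8 -1 0
    std::cout << "\n";

    for (int64_t v : evalShapeBinary(lhs, rhs, div))
        std::cout << v << " ";    // prints: 0 -1 -1
    std::cout << "\n";
    return 0;
}

This mirrors the structure of ShapeBinaryNodeBase::eval() above, which applies the negative-operand check before dispatching to the per-operator fcn installed by register_fcn().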