From e5e2676409a936431f87d31fb74d825257b20804 Mon Sep 17 00:00:00 2001 From: Eric Kunze Date: Tue, 13 Oct 2020 16:11:07 -0700 Subject: Initial checkin of TOSA reference_model and tests Change-Id: I2f8e7fa63e2ae40203e57d2cc8814bde3b312cb6 Signed-off-by: Eric Kunze --- reference_model/src/ops/data_layout.h | 216 ++++++++++++++++++++++++++++++++++ 1 file changed, 216 insertions(+) create mode 100644 reference_model/src/ops/data_layout.h (limited to 'reference_model/src/ops/data_layout.h') diff --git a/reference_model/src/ops/data_layout.h b/reference_model/src/ops/data_layout.h new file mode 100644 index 0000000..100bd6b --- /dev/null +++ b/reference_model/src/ops/data_layout.h @@ -0,0 +1,216 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OPS_DATA_LAYOUT_H +#define OPS_DATA_LAYOUT_H + +#include "graph_node.h" + +using namespace tosa; + +namespace TosaReference +{ + +template +class OpConcat : public GraphNode +{ +public: + OpConcat(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpConcat(); + + virtual int checkTensorAttributes(); + virtual int eval(); + + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using TIn = Eigen::Tensor; + using TOut = Eigen::Tensor; + +protected: + Eigen::array reverser; + TosaReference::TensorTemplate* lhs; + TosaReference::TensorTemplate* rhs; + TosaAxisAttribute* attribute; + TosaReference::TensorTemplate* out; +}; + +template +class OpPad : public GraphNode +{ +public: + OpPad(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpPad(); + virtual int checkTensorAttributes(); + virtual int eval(); + + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using TIn = Eigen::Tensor; + using TOut = Eigen::Tensor; + +protected: + Eigen::array, Rank> paddings_array; + TosaReference::TensorTemplate* in; + TosaReference::TensorTemplate* out; + TosaPadQuantInfo* qinfo; +}; + +template +class OpReshape : public GraphNode +{ +public: + OpReshape(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpReshape(); + + virtual int checkTensorAttributes(); + virtual int eval(); + + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using TIn = Eigen::Tensor; + using TOut = Eigen::Tensor; + +protected: + Eigen::array array_shape; + Eigen::array in_reverser; + Eigen::array out_reverser; + TosaReference::TensorTemplate* in; + TosaReshapeAttribute* attribute; + TosaReference::TensorTemplate* out; +}; + +template +class OpReverse : public GraphNode +{ +public: + OpReverse(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpReverse(); + + virtual int checkTensorAttributes(); + virtual int eval(); + + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using TIn = Eigen::Tensor; + using TOut = Eigen::Tensor; + +protected: + TosaAxisAttribute* attribute; + TosaReference::TensorTemplate* in; + TosaReference::TensorTemplate* out; + Eigen::array reverse_array; +}; + +template +class OpSlice : public GraphNode +{ +public: + OpSlice(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpSlice(); + + virtual int checkTensorAttributes(); + virtual int eval(); + + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using TIn = Eigen::Tensor; + using TOut = Eigen::Tensor; + +protected: + TosaSliceAttribute* attribute; + Eigen::array begin_array; + Eigen::array size_array; + TosaReference::TensorTemplate* in; + TosaReference::TensorTemplate* out; +}; + +template +class OpTileBase : public GraphNode +{ +public: + OpTileBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpTileBase(); + + virtual int checkTensorAttributes(); + + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using TIn = Eigen::Tensor; + using TOut = Eigen::Tensor; + +protected: + TosaTileAttribute* attribute; + TosaReference::TensorTemplate* in; + TosaReference::TensorTemplate* out; +}; + +// primary template for op tile +template +class OpTile : public OpTileBase +{ +public: + OpTile(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : OpTileBase(attribute_, qinfo_, id_) + {} + +protected: + virtual int eval(); +}; + +// partial specialization for specific rank +#define DEF_OP_TILE_RANK(N) \ + template \ + class OpTile : public OpTileBase \ + { \ + public: \ + OpTile(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ + : OpTileBase(attribute_, qinfo_, id_) \ + {} \ + \ + protected: \ + virtual int eval(); \ + }; + +DEF_OP_TILE_RANK(1) +DEF_OP_TILE_RANK(2) +DEF_OP_TILE_RANK(3) +DEF_OP_TILE_RANK(4) + +#undef DEF_OP_TILE_RANK + +template +class OpTranspose : public GraphNode +{ +public: + OpTranspose(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpTranspose(); + + virtual int checkTensorAttributes(); + virtual int eval(); + + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using TIn = Eigen::Tensor; + using TOut = Eigen::Tensor; + +protected: + Eigen::array perm_array; + TosaReference::TensorTemplate* in; + TosaReference::TensorTemplate>* perm_tensor; + TosaReference::TensorTemplate* out; +}; +}; // namespace TosaReference + +#endif -- cgit v1.2.1