diff options
Diffstat (limited to 'reference_model/src/ops/ewise_binary.h')
-rw-r--r-- | reference_model/src/ops/ewise_binary.h | 36 |
1 files changed, 19 insertions, 17 deletions
diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h index 020ddb5..5f6e531 100644 --- a/reference_model/src/ops/ewise_binary.h +++ b/reference_model/src/ops/ewise_binary.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -38,7 +38,7 @@ namespace TosaReference // the way of registering lambda + .binaryExpr() might sacrifice performance here // but it can avoid partially specialization for combination of {rankN, rank0} x {FP32/INT32, QU8, ...} // needs to revisit if performance becomes a bottleneck here -template <int Rank, DType InDtype, DType OutDtype> +template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> class BinaryNodeBase : public GraphNode { public: @@ -67,7 +67,7 @@ protected: }; // primary class -template <int Rank, DType InDtype, DType OutDtype> +template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> class BinaryNode : public BinaryNodeBase<Rank, InDtype, OutDtype> { public: @@ -86,7 +86,7 @@ public: }; // partial specialization for rank 0 -template <DType InDtype, DType OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> class BinaryNode<0, InDtype, OutDtype> : public BinaryNodeBase<0, InDtype, OutDtype> { public: @@ -100,19 +100,19 @@ public: }; #define DEF_TEMPLATE_BINARY_OP_DEFAULT(Opname, OPNAME) \ - template <int Rank, DType Dtype> \ + template <int Rank, TOSA_REF_TYPE Dtype> \ class Op##Opname : public BinaryNode<Rank, Dtype, Dtype> \ { \ public: \ - Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \ - : BinaryNode<Rank, Dtype, Dtype>(sgt_, Op_##OPNAME, id_) \ + Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \ + : BinaryNode<Rank, Dtype, Dtype>(sgt_, Op_##OPNAME, id_) \ { \ register_fcn(); \ } \ - static constexpr DType InDtype = Dtype; \ - static constexpr DType OutDtype = Dtype; \ - using InEigenType = typename GetEigenType<InDtype>::type; \ - using OutEigenType = typename GetEigenType<OutDtype>::type; \ + static constexpr TOSA_REF_TYPE InDtype = Dtype; \ + static constexpr TOSA_REF_TYPE OutDtype = Dtype; \ + using InEigenType = typename GetEigenType<InDtype>::type; \ + using OutEigenType = typename GetEigenType<OutDtype>::type; \ virtual int register_fcn(); \ }; @@ -133,7 +133,7 @@ DEF_TEMPLATE_BINARY_OP_DEFAULT(Sub, SUB) #undef DEF_TEMPLATE_BINARY_OP_DEFAULT -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> class OpArithmeticRightShift : public BinaryNode<Rank, Dtype, Dtype> { public: @@ -154,7 +154,7 @@ protected: TosaArithmeticRightShiftAttribute* attribute; }; -template <int Rank, DType InDtype, DType OutDtype> +template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> class OpMul : public BinaryNode<Rank, InDtype, OutDtype> { public: @@ -175,7 +175,7 @@ protected: TosaMulAttribute* attribute; }; -template <int Rank, DType InDtype> +template <int Rank, TOSA_REF_TYPE InDtype> class OpTable : public GraphNode { public: @@ -185,9 +185,11 @@ public: virtual int checkTensorAttributes(); virtual int eval(); - static constexpr DType TableDtype = (InDtype == DType_INT8) ? DType_INT8 : DType_INT16; - static constexpr DType OutDtype = (InDtype == DType_INT8) ? DType_INT8 : DType_INT32; - static constexpr uint32_t TableNumEntries = (InDtype == DType_INT8) ? 256 : 513; + static constexpr TOSA_REF_TYPE TableDtype = + (InDtype == TOSA_REF_TYPE_INT8) ? TOSA_REF_TYPE_INT8 : TOSA_REF_TYPE_INT16; + static constexpr TOSA_REF_TYPE OutDtype = + (InDtype == TOSA_REF_TYPE_INT8) ? TOSA_REF_TYPE_INT8 : TOSA_REF_TYPE_INT32; + static constexpr uint32_t TableNumEntries = (InDtype == TOSA_REF_TYPE_INT8) ? 256 : 513; using InEigenType = typename GetEigenType<InDtype>::type; using TableEigenType = typename GetEigenType<TableDtype>::type; using OutEigenType = typename GetEigenType<OutDtype>::type; |