aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/ewise_binary.h
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/ewise_binary.h')
-rw-r--r--reference_model/src/ops/ewise_binary.h36
1 files changed, 19 insertions, 17 deletions
diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h
index 020ddb5..5f6e531 100644
--- a/reference_model/src/ops/ewise_binary.h
+++ b/reference_model/src/ops/ewise_binary.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -38,7 +38,7 @@ namespace TosaReference
// the way of registering lambda + .binaryExpr() might sacrifice performance here
// but it can avoid partially specialization for combination of {rankN, rank0} x {FP32/INT32, QU8, ...}
// needs to revisit if performance becomes a bottleneck here
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
class BinaryNodeBase : public GraphNode
{
public:
@@ -67,7 +67,7 @@ protected:
};
// primary class
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
class BinaryNode : public BinaryNodeBase<Rank, InDtype, OutDtype>
{
public:
@@ -86,7 +86,7 @@ public:
};
// partial specialization for rank 0
-template <DType InDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
class BinaryNode<0, InDtype, OutDtype> : public BinaryNodeBase<0, InDtype, OutDtype>
{
public:
@@ -100,19 +100,19 @@ public:
};
#define DEF_TEMPLATE_BINARY_OP_DEFAULT(Opname, OPNAME) \
- template <int Rank, DType Dtype> \
+ template <int Rank, TOSA_REF_TYPE Dtype> \
class Op##Opname : public BinaryNode<Rank, Dtype, Dtype> \
{ \
public: \
- Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \
- : BinaryNode<Rank, Dtype, Dtype>(sgt_, Op_##OPNAME, id_) \
+ Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \
+ : BinaryNode<Rank, Dtype, Dtype>(sgt_, Op_##OPNAME, id_) \
{ \
register_fcn(); \
} \
- static constexpr DType InDtype = Dtype; \
- static constexpr DType OutDtype = Dtype; \
- using InEigenType = typename GetEigenType<InDtype>::type; \
- using OutEigenType = typename GetEigenType<OutDtype>::type; \
+ static constexpr TOSA_REF_TYPE InDtype = Dtype; \
+ static constexpr TOSA_REF_TYPE OutDtype = Dtype; \
+ using InEigenType = typename GetEigenType<InDtype>::type; \
+ using OutEigenType = typename GetEigenType<OutDtype>::type; \
virtual int register_fcn(); \
};
@@ -133,7 +133,7 @@ DEF_TEMPLATE_BINARY_OP_DEFAULT(Sub, SUB)
#undef DEF_TEMPLATE_BINARY_OP_DEFAULT
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpArithmeticRightShift : public BinaryNode<Rank, Dtype, Dtype>
{
public:
@@ -154,7 +154,7 @@ protected:
TosaArithmeticRightShiftAttribute* attribute;
};
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
class OpMul : public BinaryNode<Rank, InDtype, OutDtype>
{
public:
@@ -175,7 +175,7 @@ protected:
TosaMulAttribute* attribute;
};
-template <int Rank, DType InDtype>
+template <int Rank, TOSA_REF_TYPE InDtype>
class OpTable : public GraphNode
{
public:
@@ -185,9 +185,11 @@ public:
virtual int checkTensorAttributes();
virtual int eval();
- static constexpr DType TableDtype = (InDtype == DType_INT8) ? DType_INT8 : DType_INT16;
- static constexpr DType OutDtype = (InDtype == DType_INT8) ? DType_INT8 : DType_INT32;
- static constexpr uint32_t TableNumEntries = (InDtype == DType_INT8) ? 256 : 513;
+ static constexpr TOSA_REF_TYPE TableDtype =
+ (InDtype == TOSA_REF_TYPE_INT8) ? TOSA_REF_TYPE_INT8 : TOSA_REF_TYPE_INT16;
+ static constexpr TOSA_REF_TYPE OutDtype =
+ (InDtype == TOSA_REF_TYPE_INT8) ? TOSA_REF_TYPE_INT8 : TOSA_REF_TYPE_INT32;
+ static constexpr uint32_t TableNumEntries = (InDtype == TOSA_REF_TYPE_INT8) ? 256 : 513;
using InEigenType = typename GetEigenType<InDtype>::type;
using TableEigenType = typename GetEigenType<TableDtype>::type;
using OutEigenType = typename GetEigenType<OutDtype>::type;