aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/tensor_ops.h
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/tensor_ops.h')
-rw-r--r--reference_model/src/ops/tensor_ops.h24
1 files changed, 12 insertions, 12 deletions
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h
index 9ef4a58..df53f2b 100644
--- a/reference_model/src/ops/tensor_ops.h
+++ b/reference_model/src/ops/tensor_ops.h
@@ -24,7 +24,7 @@ using namespace tosa;
namespace TosaReference
{
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpArgMax : public GraphNode
{
public:
@@ -35,7 +35,7 @@ public:
virtual int eval();
using InEigenType = typename GetEigenType<Dtype>::type;
- using OutEigenType = typename GetEigenType<DType_INT32>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_INT32>::type;
using TIn = Eigen::Tensor<InEigenType, Rank>;
using TOut = Eigen::Tensor<OutEigenType, Rank - 1>;
@@ -45,7 +45,7 @@ protected:
TosaReference::TensorTemplate<TOut>* output;
};
-template <DType Dtype, DType AccDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
class OpAvgPool2d : public GraphNode
{
public:
@@ -74,7 +74,7 @@ protected:
ETensor1<int32_t> calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t padding_left, int32_t padding_right);
};
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
class OpConv2d : public GraphNode
{
public:
@@ -104,7 +104,7 @@ protected:
tosa::TosaConvAttribute* attribute;
};
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
class OpConv3d : public GraphNode
{
public:
@@ -134,7 +134,7 @@ protected:
tosa::TosaConvAttribute* attribute;
};
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
class OpDepthwiseConv2d : public GraphNode
{
public:
@@ -164,7 +164,7 @@ protected:
tosa::TosaConvAttribute* attribute;
};
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
class OpFullyConnected : public GraphNode
{
public:
@@ -195,7 +195,7 @@ protected:
tosa::TosaFullyConnectedAttribute* attribute;
};
-template <DType Dtype, DType OutDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
class OpMatMul : public GraphNode
{
public:
@@ -227,7 +227,7 @@ protected:
tosa::TosaMatMulAttribute* attribute;
};
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
class OpMaxPool2d : public GraphNode
{
public:
@@ -248,7 +248,7 @@ protected:
tosa::TosaPoolAttribute* attribute;
};
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
class OpFFT2d : public GraphNode
{
public:
@@ -271,7 +271,7 @@ protected:
tosa::TosaFFTAttribute* attribute;
};
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
class OpRFFT2d : public GraphNode
{
public:
@@ -292,7 +292,7 @@ protected:
TosaReference::TensorTemplate<TOut>* out_imag;
};
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
class OpTransposeConv2d : public GraphNode
{
public: