diff options
author | James Ward <james.ward@arm.com> | 2022-08-12 20:48:56 +0100 |
---|---|---|
committer | James Ward <james.ward@arm.com> | 2022-10-11 11:56:02 +0100 |
commit | 8b39043c70332e1e4c95ee6a9616aec40dd3baf1 (patch) | |
tree | fea519246b698eb944b9d58537fc90bc30481d11 /reference_model/src/ops/tensor_ops.h | |
parent | ba5fad356a926d5e1c6e0fe6b546a310230cc5a8 (diff) | |
download | reference_model-8b39043c70332e1e4c95ee6a9616aec40dd3baf1.tar.gz |
Reference model changes for fp16 support
Change-Id: I72f21fcfa153046274969d327313e3349981dbe6
Signed-off-by: James Ward <james.ward@arm.com>
Diffstat (limited to 'reference_model/src/ops/tensor_ops.h')
-rw-r--r-- | reference_model/src/ops/tensor_ops.h | 81 |
1 files changed, 38 insertions, 43 deletions
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h index 24eadeb..fd6dd25 100644 --- a/reference_model/src/ops/tensor_ops.h +++ b/reference_model/src/ops/tensor_ops.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -45,7 +45,7 @@ protected: TosaReference::TensorTemplate<TOut>* output; }; -template <DType Dtype> +template <DType Dtype, DType AccDtype> class OpAvgPool2d : public GraphNode { public: @@ -55,9 +55,8 @@ public: virtual int checkTensorAttributes(); virtual int eval(); - static constexpr DType AccDtype = GetAccDType<Dtype, Dtype>::value; using InEigenType = typename GetEigenType<Dtype>::type; - using AccEigenType = typename GetEigenType<AccDtype>::type; + using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType using OutEigenType = typename GetEigenType<Dtype>::type; using TIn = Eigen::Tensor<InEigenType, 4>; using TOut = Eigen::Tensor<OutEigenType, 4>; @@ -75,7 +74,7 @@ protected: ETensor1<int32_t> calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t padding_left, int32_t padding_right); }; -template <DType InDtype, DType WeightDtype> +template <DType InDtype, DType WeightDtype, DType AccDtype> class OpConv2d : public GraphNode { public: @@ -85,15 +84,14 @@ public: virtual int checkTensorAttributes() final; virtual int eval() final; - static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value; - using InEigenType = typename GetEigenType<InDtype>::type; using WeightEigenType = typename GetEigenType<WeightDtype>::type; - using AccEigenType = typename GetEigenType<AccDtype>::type; + using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType + using OutEigenType = typename GetEigenType<AccDtype>::type; using TIn = Eigen::Tensor<InEigenType, 4>; using TWeight = Eigen::Tensor<WeightEigenType, 4>; - using TBias = Eigen::Tensor<AccEigenType, 1>; - using TAcc = Eigen::Tensor<AccEigenType, 4>; + using TBias = Eigen::Tensor<OutEigenType, 1>; + using TOut = Eigen::Tensor<OutEigenType, 4>; static constexpr int64_t AccQMin = GetQMin<AccDtype>::value; static constexpr int64_t AccQMax = GetQMax<AccDtype>::value; @@ -102,11 +100,11 @@ protected: TosaReference::TensorTemplate<TIn>* input; TosaReference::TensorTemplate<TWeight>* weight; TosaReference::TensorTemplate<TBias>* bias; - TosaReference::TensorTemplate<TAcc>* output; + TosaReference::TensorTemplate<TOut>* output; tosa::TosaConvAttribute* attribute; }; -template <DType InDtype, DType WeightDtype> +template <DType InDtype, DType WeightDtype, DType AccDtype> class OpConv3d : public GraphNode { public: @@ -116,15 +114,14 @@ public: virtual int checkTensorAttributes() final; virtual int eval() final; - static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value; - using InEigenType = typename GetEigenType<InDtype>::type; using WeightEigenType = typename GetEigenType<WeightDtype>::type; - using AccEigenType = typename GetEigenType<AccDtype>::type; + using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType + using OutEigenType = typename GetEigenType<AccDtype>::type; using TIn = Eigen::Tensor<InEigenType, 5>; using TWeight = Eigen::Tensor<WeightEigenType, 5>; - using TBias = Eigen::Tensor<AccEigenType, 1>; - using TAcc = Eigen::Tensor<AccEigenType, 5>; + using TBias = Eigen::Tensor<OutEigenType, 1>; + using TOut = Eigen::Tensor<OutEigenType, 5>; static constexpr int64_t AccQMin = GetQMin<AccDtype>::value; static constexpr int64_t AccQMax = GetQMax<AccDtype>::value; @@ -133,11 +130,11 @@ protected: TosaReference::TensorTemplate<TIn>* input; TosaReference::TensorTemplate<TWeight>* weight; TosaReference::TensorTemplate<TBias>* bias; - TosaReference::TensorTemplate<TAcc>* output; + TosaReference::TensorTemplate<TOut>* output; tosa::TosaConvAttribute* attribute; }; -template <DType InDtype, DType WeightDtype> +template <DType InDtype, DType WeightDtype, DType AccDtype> class OpDepthwiseConv2d : public GraphNode { public: @@ -147,15 +144,14 @@ public: virtual int checkTensorAttributes() final; virtual int eval() final; - static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value; - using InEigenType = typename GetEigenType<InDtype>::type; using WeightEigenType = typename GetEigenType<WeightDtype>::type; - using AccEigenType = typename GetEigenType<AccDtype>::type; + using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType + using OutEigenType = typename GetEigenType<AccDtype>::type; using TIn = Eigen::Tensor<InEigenType, 4>; using TWeight = Eigen::Tensor<WeightEigenType, 4>; - using TBias = Eigen::Tensor<AccEigenType, 1>; - using TAcc = Eigen::Tensor<AccEigenType, 4>; + using TBias = Eigen::Tensor<OutEigenType, 1>; + using TOut = Eigen::Tensor<OutEigenType, 4>; static constexpr int64_t AccQMin = GetQMin<AccDtype>::value; static constexpr int64_t AccQMax = GetQMax<AccDtype>::value; @@ -164,11 +160,11 @@ protected: TosaReference::TensorTemplate<TIn>* input; TosaReference::TensorTemplate<TWeight>* weight; TosaReference::TensorTemplate<TBias>* bias; - TosaReference::TensorTemplate<TAcc>* output; + TosaReference::TensorTemplate<TOut>* output; tosa::TosaConvAttribute* attribute; }; -template <DType InDtype, DType WeightDtype> +template <DType InDtype, DType WeightDtype, DType AccDtype> class OpFullyConnected : public GraphNode { public: @@ -178,14 +174,14 @@ public: virtual int checkTensorAttributes() final; virtual int eval() final; - static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value; using InEigenType = typename GetEigenType<InDtype>::type; using WeightEigenType = typename GetEigenType<WeightDtype>::type; - using AccEigenType = typename GetEigenType<AccDtype>::type; + using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType + using OutEigenType = typename GetEigenType<AccDtype>::type; using TIn = Eigen::Tensor<InEigenType, 2>; using TWeight = Eigen::Tensor<WeightEigenType, 2>; - using TBias = Eigen::Tensor<AccEigenType, 1>; - using TAcc = Eigen::Tensor<AccEigenType, 2>; + using TBias = Eigen::Tensor<OutEigenType, 1>; + using TOut = Eigen::Tensor<OutEigenType, 2>; static constexpr int64_t AccQMin = GetQMin<AccDtype>::value; static constexpr int64_t AccQMax = GetQMax<AccDtype>::value; @@ -194,12 +190,12 @@ protected: TosaReference::TensorTemplate<TIn>* input; TosaReference::TensorTemplate<TWeight>* weight; TosaReference::TensorTemplate<TBias>* bias; - TosaReference::TensorTemplate<TAcc>* output; + TosaReference::TensorTemplate<TOut>* output; tosa::TosaFullyConnectedAttribute* attribute; }; -template <DType Dtype> +template <DType Dtype, DType AccDtype> class OpMatMul : public GraphNode { public: @@ -209,11 +205,11 @@ public: virtual int checkTensorAttributes() final; virtual int eval() final; - static constexpr DType AccDtype = GetAccDType<Dtype, Dtype>::value; using InEigenType = typename GetEigenType<Dtype>::type; - using AccEigenType = typename GetEigenType<AccDtype>::type; + using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType + using OutEigenType = typename GetEigenType<AccDtype>::type; using TIn = Eigen::Tensor<InEigenType, 3>; - using TAcc = Eigen::Tensor<AccEigenType, 3>; + using TOut = Eigen::Tensor<OutEigenType, 3>; using TInRank2 = Eigen::Tensor<InEigenType, 2>; using TAccRank2 = Eigen::Tensor<AccEigenType, 2>; static constexpr int64_t AccQMin = GetQMin<AccDtype>::value; @@ -222,7 +218,7 @@ public: protected: TosaReference::TensorTemplate<TIn>* a; TosaReference::TensorTemplate<TIn>* b; - TosaReference::TensorTemplate<TAcc>* output; + TosaReference::TensorTemplate<TOut>* output; int64_t N; int64_t H; int64_t W; @@ -252,7 +248,7 @@ protected: tosa::TosaPoolAttribute* attribute; }; -template <DType InDtype, DType WeightDtype> +template <DType InDtype, DType WeightDtype, DType AccDtype> class OpTransposeConv2d : public GraphNode { public: @@ -262,15 +258,14 @@ public: virtual int checkTensorAttributes() final; virtual int eval() final; - static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value; - using InEigenType = typename GetEigenType<InDtype>::type; using WeightEigenType = typename GetEigenType<WeightDtype>::type; - using AccEigenType = typename GetEigenType<AccDtype>::type; + using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType + using OutEigenType = typename GetEigenType<AccDtype>::type; using TIn = Eigen::Tensor<InEigenType, 4>; using TWeight = Eigen::Tensor<WeightEigenType, 4>; - using TBias = Eigen::Tensor<AccEigenType, 1>; - using TAcc = Eigen::Tensor<AccEigenType, 4>; + using TBias = Eigen::Tensor<OutEigenType, 1>; + using TOut = Eigen::Tensor<OutEigenType, 4>; static constexpr int64_t AccQMin = GetQMin<AccDtype>::value; static constexpr int64_t AccQMax = GetQMax<AccDtype>::value; @@ -279,7 +274,7 @@ protected: TosaReference::TensorTemplate<TIn>* input; TosaReference::TensorTemplate<TWeight>* weight; TosaReference::TensorTemplate<TBias>* bias; - TosaReference::TensorTemplate<TAcc>* output; + TosaReference::TensorTemplate<TOut>* output; TosaTransposeConvAttribute* attribute; }; |