From 8b39043c70332e1e4c95ee6a9616aec40dd3baf1 Mon Sep 17 00:00:00 2001 From: James Ward Date: Fri, 12 Aug 2022 20:48:56 +0100 Subject: Reference model changes for fp16 support Change-Id: I72f21fcfa153046274969d327313e3349981dbe6 Signed-off-by: James Ward --- reference_model/src/ops/tensor_ops.h | 81 +++++++++++++++++------------------- 1 file changed, 38 insertions(+), 43 deletions(-) (limited to 'reference_model/src/ops/tensor_ops.h') diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h index 24eadeb..fd6dd25 100644 --- a/reference_model/src/ops/tensor_ops.h +++ b/reference_model/src/ops/tensor_ops.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -45,7 +45,7 @@ protected: TosaReference::TensorTemplate* output; }; -template +template class OpAvgPool2d : public GraphNode { public: @@ -55,9 +55,8 @@ public: virtual int checkTensorAttributes(); virtual int eval(); - static constexpr DType AccDtype = GetAccDType::value; using InEigenType = typename GetEigenType::type; - using AccEigenType = typename GetEigenType::type; + using AccEigenType = typename GetAccEigenType::type; // Note: different from GetEigenType using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TOut = Eigen::Tensor; @@ -75,7 +74,7 @@ protected: ETensor1 calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t padding_left, int32_t padding_right); }; -template +template class OpConv2d : public GraphNode { public: @@ -85,15 +84,14 @@ public: virtual int checkTensorAttributes() final; virtual int eval() final; - static constexpr DType AccDtype = GetAccDType::value; - using InEigenType = typename GetEigenType::type; using WeightEigenType = typename GetEigenType::type; - using AccEigenType = typename GetEigenType::type; + using AccEigenType = typename GetAccEigenType::type; // Note: different from GetEigenType + using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TWeight = Eigen::Tensor; - using TBias = Eigen::Tensor; - using TAcc = Eigen::Tensor; + using TBias = Eigen::Tensor; + using TOut = Eigen::Tensor; static constexpr int64_t AccQMin = GetQMin::value; static constexpr int64_t AccQMax = GetQMax::value; @@ -102,11 +100,11 @@ protected: TosaReference::TensorTemplate* input; TosaReference::TensorTemplate* weight; TosaReference::TensorTemplate* bias; - TosaReference::TensorTemplate* output; + TosaReference::TensorTemplate* output; tosa::TosaConvAttribute* attribute; }; -template +template class OpConv3d : public GraphNode { public: @@ -116,15 +114,14 @@ public: virtual int checkTensorAttributes() final; virtual int eval() final; - static constexpr DType AccDtype = GetAccDType::value; - using InEigenType = typename GetEigenType::type; using WeightEigenType = typename GetEigenType::type; - using AccEigenType = typename GetEigenType::type; + using AccEigenType = typename GetAccEigenType::type; // Note: different from GetEigenType + using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TWeight = Eigen::Tensor; - using TBias = Eigen::Tensor; - using TAcc = Eigen::Tensor; + using TBias = Eigen::Tensor; + using TOut = Eigen::Tensor; static constexpr int64_t AccQMin = GetQMin::value; static constexpr int64_t AccQMax = GetQMax::value; @@ -133,11 +130,11 @@ protected: TosaReference::TensorTemplate* input; TosaReference::TensorTemplate* weight; TosaReference::TensorTemplate* bias; - TosaReference::TensorTemplate* output; + TosaReference::TensorTemplate* output; tosa::TosaConvAttribute* attribute; }; -template +template class OpDepthwiseConv2d : public GraphNode { public: @@ -147,15 +144,14 @@ public: virtual int checkTensorAttributes() final; virtual int eval() final; - static constexpr DType AccDtype = GetAccDType::value; - using InEigenType = typename GetEigenType::type; using WeightEigenType = typename GetEigenType::type; - using AccEigenType = typename GetEigenType::type; + using AccEigenType = typename GetAccEigenType::type; // Note: different from GetEigenType + using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TWeight = Eigen::Tensor; - using TBias = Eigen::Tensor; - using TAcc = Eigen::Tensor; + using TBias = Eigen::Tensor; + using TOut = Eigen::Tensor; static constexpr int64_t AccQMin = GetQMin::value; static constexpr int64_t AccQMax = GetQMax::value; @@ -164,11 +160,11 @@ protected: TosaReference::TensorTemplate* input; TosaReference::TensorTemplate* weight; TosaReference::TensorTemplate* bias; - TosaReference::TensorTemplate* output; + TosaReference::TensorTemplate* output; tosa::TosaConvAttribute* attribute; }; -template +template class OpFullyConnected : public GraphNode { public: @@ -178,14 +174,14 @@ public: virtual int checkTensorAttributes() final; virtual int eval() final; - static constexpr DType AccDtype = GetAccDType::value; using InEigenType = typename GetEigenType::type; using WeightEigenType = typename GetEigenType::type; - using AccEigenType = typename GetEigenType::type; + using AccEigenType = typename GetAccEigenType::type; // Note: different from GetEigenType + using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TWeight = Eigen::Tensor; - using TBias = Eigen::Tensor; - using TAcc = Eigen::Tensor; + using TBias = Eigen::Tensor; + using TOut = Eigen::Tensor; static constexpr int64_t AccQMin = GetQMin::value; static constexpr int64_t AccQMax = GetQMax::value; @@ -194,12 +190,12 @@ protected: TosaReference::TensorTemplate* input; TosaReference::TensorTemplate* weight; TosaReference::TensorTemplate* bias; - TosaReference::TensorTemplate* output; + TosaReference::TensorTemplate* output; tosa::TosaFullyConnectedAttribute* attribute; }; -template +template class OpMatMul : public GraphNode { public: @@ -209,11 +205,11 @@ public: virtual int checkTensorAttributes() final; virtual int eval() final; - static constexpr DType AccDtype = GetAccDType::value; using InEigenType = typename GetEigenType::type; - using AccEigenType = typename GetEigenType::type; + using AccEigenType = typename GetAccEigenType::type; // Note: different from GetEigenType + using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; - using TAcc = Eigen::Tensor; + using TOut = Eigen::Tensor; using TInRank2 = Eigen::Tensor; using TAccRank2 = Eigen::Tensor; static constexpr int64_t AccQMin = GetQMin::value; @@ -222,7 +218,7 @@ public: protected: TosaReference::TensorTemplate* a; TosaReference::TensorTemplate* b; - TosaReference::TensorTemplate* output; + TosaReference::TensorTemplate* output; int64_t N; int64_t H; int64_t W; @@ -252,7 +248,7 @@ protected: tosa::TosaPoolAttribute* attribute; }; -template +template class OpTransposeConv2d : public GraphNode { public: @@ -262,15 +258,14 @@ public: virtual int checkTensorAttributes() final; virtual int eval() final; - static constexpr DType AccDtype = GetAccDType::value; - using InEigenType = typename GetEigenType::type; using WeightEigenType = typename GetEigenType::type; - using AccEigenType = typename GetEigenType::type; + using AccEigenType = typename GetAccEigenType::type; // Note: different from GetEigenType + using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TWeight = Eigen::Tensor; - using TBias = Eigen::Tensor; - using TAcc = Eigen::Tensor; + using TBias = Eigen::Tensor; + using TOut = Eigen::Tensor; static constexpr int64_t AccQMin = GetQMin::value; static constexpr int64_t AccQMax = GetQMax::value; @@ -279,7 +274,7 @@ protected: TosaReference::TensorTemplate* input; TosaReference::TensorTemplate* weight; TosaReference::TensorTemplate* bias; - TosaReference::TensorTemplate* output; + TosaReference::TensorTemplate* output; TosaTransposeConvAttribute* attribute; }; -- cgit v1.2.1