// Copyright (c) 2020-2021, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef OP_TEMPLATE_TYPES_H #define OP_TEMPLATE_TYPES_H #include "tosa_generated.h" #include using namespace tosa; namespace TosaReference { // Shorter aliase templates for common Eigen::Tensor types template using ETensor0 = Eigen::Tensor; template using ETensor1 = Eigen::Tensor; template using ETensor2 = Eigen::Tensor; template using ETensor3 = Eigen::Tensor; template using ETensor4 = Eigen::Tensor; template using ETensor5 = Eigen::Tensor; template using ETensor6 = Eigen::Tensor; // Forward declaration template class TensorTemplate; // Shortcut to hide the TensorTemplate class. // For example, declare Tensor1 to get a TensorTemplate // with an Eigen::Tensor template using Tensor0 = TensorTemplate>; template using Tensor1 = TensorTemplate>; template using Tensor2 = TensorTemplate>; template using Tensor3 = TensorTemplate>; template using Tensor4 = TensorTemplate>; template using Tensor5 = TensorTemplate>; template using Tensor6 = TensorTemplate>; template struct GetEigenType; template <> struct GetEigenType { using type = float; }; template <> struct GetEigenType { using type = int32_t; }; template <> struct GetEigenType { using type = int64_t; }; template <> struct GetEigenType { using type = bool; }; template <> struct GetEigenType { using type = int32_t; }; template <> struct GetEigenType { using type = int32_t; }; template <> struct GetEigenType { using type = int32_t; }; template <> struct GetEigenType { using type = int32_t; }; // Meta function to get number of bits template struct GetNumBits { static constexpr int32_t value = 0; }; template <> struct GetNumBits { static constexpr int32_t value = 1; }; template <> struct GetNumBits { static constexpr int32_t value = 8; }; template <> struct GetNumBits { static constexpr int32_t value = 4; }; template <> struct GetNumBits { static constexpr int32_t value = 8; }; template <> struct GetNumBits { static constexpr int32_t value = 16; }; template <> struct GetNumBits { static constexpr int32_t value = 32; }; template <> struct GetNumBits { static constexpr int32_t value = 48; }; // Meta function to get quantized min/max in compile time template struct GetQMin { static constexpr int64_t value = 0L; }; template <> struct GetQMin { static constexpr int64_t value = 0L; }; template <> struct GetQMin { static constexpr int64_t value = -8L; }; template <> struct GetQMin { static constexpr int64_t value = -128L; }; template <> struct GetQMin { static constexpr int64_t value = -32768L; }; template <> struct GetQMin { static constexpr int64_t value = -(1L << 31); }; template <> struct GetQMin { static constexpr int64_t value = -(1L << 47); }; template struct GetQMax { static constexpr int64_t value = 0L; }; template <> struct GetQMax { static constexpr int64_t value = 255L; }; template <> struct GetQMax { static constexpr int64_t value = 7L; }; template <> struct GetQMax { static constexpr int64_t value = 127L; }; template <> struct GetQMax { static constexpr int64_t value = 32767L; }; template <> struct GetQMax { static constexpr int64_t value = (1L << 31) - 1; }; template <> struct GetQMax { static constexpr int64_t value = (1L << 47) - 1; }; template struct GetAccDType; template <> struct GetAccDType { static constexpr DType value = DType_INT32; }; template <> struct GetAccDType { static constexpr DType value = DType_INT32; }; template <> struct GetAccDType { static constexpr DType value = DType_INT48; }; template <> struct GetAccDType { static constexpr DType value = DType_INT48; }; template <> struct GetAccDType { static constexpr DType value = DType_FLOAT; }; }; // namespace TosaReference #endif