diff options
Diffstat (limited to 'reference_model/src/ops/ewise_binary.h')
-rw-r--r-- | reference_model/src/ops/ewise_binary.h | 34 |
1 files changed, 18 insertions, 16 deletions
diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h index fd4d408..373dfb8 100644 --- a/reference_model/src/ops/ewise_binary.h +++ b/reference_model/src/ops/ewise_binary.h @@ -184,26 +184,28 @@ public: virtual int checkTensorAttributes(); virtual int eval(); - static constexpr DType TableDtype = (InDtype == DType_INT8) ? DType_INT8 : DType_INT16; - static constexpr DType OutDtype = (InDtype == DType_INT8) ? DType_INT8 : DType_INT32; - using InEigenType = typename GetEigenType<InDtype>::type; - using TableEigenType = typename GetEigenType<TableDtype>::type; - using OutEigenType = typename GetEigenType<OutDtype>::type; - using TIn = Eigen::Tensor<InEigenType, Rank>; - using TTable = Eigen::Tensor<TableEigenType, 1>; - using TOut = Eigen::Tensor<OutEigenType, Rank>; - static constexpr int32_t IntegerBits = 9; - static constexpr int32_t FractionBits = 7; - static constexpr int32_t NumTableEntries = (1 << IntegerBits); - static constexpr int32_t QInMin = GetQMin<InDtype>::value; - static constexpr int32_t QInMax = GetQMax<InDtype>::value; - static constexpr int32_t QOutMin = GetQMin<OutDtype>::value; - static constexpr int32_t QOutMax = GetQMax<OutDtype>::value; + static constexpr DType TableDtype = (InDtype == DType_INT8) ? DType_INT8 : DType_INT16; + static constexpr DType OutDtype = (InDtype == DType_INT8) ? DType_INT8 : DType_INT32; + static constexpr uint32_t TableNumEntries = (InDtype == DType_INT8) ? 256 : 513; + using InEigenType = typename GetEigenType<InDtype>::type; + using TableEigenType = typename GetEigenType<TableDtype>::type; + using OutEigenType = typename GetEigenType<OutDtype>::type; + using TIn = Eigen::Tensor<InEigenType, Rank>; + using TTable = Eigen::Tensor<TableEigenType, 1>; + using TOut = Eigen::Tensor<OutEigenType, Rank>; + static constexpr int32_t IntegerBits = 9; + static constexpr int32_t FractionBits = 7; + static constexpr int32_t NumTableEntries = (1 << IntegerBits); + static constexpr int32_t QInMin = GetQMin<InDtype>::value; + static constexpr int32_t QInMax = GetQMax<InDtype>::value; + static constexpr int32_t QOutMin = GetQMin<OutDtype>::value; + static constexpr int32_t QOutMax = GetQMax<OutDtype>::value; protected: TosaReference::TensorTemplate<TIn>* in; - TosaReference::TensorTemplate<TTable>* table; TosaReference::TensorTemplate<TOut>* out; + TosaTableAttribute* attribute; + std::array<TableEigenType, TableNumEntries> table; }; }; // namespace TosaReference |