diff options
Diffstat (limited to 'include/armnn/QuantizedLstmParams.hpp')
-rw-r--r-- | include/armnn/QuantizedLstmParams.hpp | 218 |
1 files changed, 218 insertions, 0 deletions
diff --git a/include/armnn/QuantizedLstmParams.hpp b/include/armnn/QuantizedLstmParams.hpp new file mode 100644 index 0000000000..b3033acc9a --- /dev/null +++ b/include/armnn/QuantizedLstmParams.hpp @@ -0,0 +1,218 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// +#pragma once + +#include "TensorFwd.hpp" +#include "Exceptions.hpp" + +namespace armnn +{ + +struct QuantizedLstmInputParams +{ + QuantizedLstmInputParams() + : m_InputToInputWeights(nullptr) + , m_InputToForgetWeights(nullptr) + , m_InputToCellWeights(nullptr) + , m_InputToOutputWeights(nullptr) + + , m_RecurrentToInputWeights(nullptr) + , m_RecurrentToForgetWeights(nullptr) + , m_RecurrentToCellWeights(nullptr) + , m_RecurrentToOutputWeights(nullptr) + + , m_InputGateBias(nullptr) + , m_ForgetGateBias(nullptr) + , m_CellBias(nullptr) + , m_OutputGateBias(nullptr) + { + } + + const ConstTensor* m_InputToInputWeights; + const ConstTensor* m_InputToForgetWeights; + const ConstTensor* m_InputToCellWeights; + const ConstTensor* m_InputToOutputWeights; + + const ConstTensor* m_RecurrentToInputWeights; + const ConstTensor* m_RecurrentToForgetWeights; + const ConstTensor* m_RecurrentToCellWeights; + const ConstTensor* m_RecurrentToOutputWeights; + + const ConstTensor* m_InputGateBias; + const ConstTensor* m_ForgetGateBias; + const ConstTensor* m_CellBias; + const ConstTensor* m_OutputGateBias; + + const ConstTensor& deref(const ConstTensor* tensorPtr) const + { + if (tensorPtr != nullptr) + { + const ConstTensor &temp = *tensorPtr; + return temp; + } + throw InvalidArgumentException("QuantizedLstmInputParams: Can't dereference a null pointer"); + } + + const ConstTensor& get_InputToInputWeights() const + { + return deref(m_InputToInputWeights); + } + + const ConstTensor& get_InputToForgetWeights() const + { + return deref(m_InputToForgetWeights); + } + + const ConstTensor& get_InputToCellWeights() const + { + return deref(m_InputToCellWeights); + } + + const ConstTensor& get_InputToOutputWeights() const + { + return deref(m_InputToOutputWeights); + } + + const ConstTensor& get_RecurrentToInputWeights() const + { + return deref(m_RecurrentToInputWeights); + } + + const ConstTensor& get_RecurrentToForgetWeights() const + { + return deref(m_RecurrentToForgetWeights); + } + + const ConstTensor& get_RecurrentToCellWeights() const + { + return deref(m_RecurrentToCellWeights); + } + + const ConstTensor& get_RecurrentToOutputWeights() const + { + return deref(m_RecurrentToOutputWeights); + } + + const ConstTensor& get_InputGateBias() const + { + return deref(m_InputGateBias); + } + + const ConstTensor& get_ForgetGateBias() const + { + return deref(m_ForgetGateBias); + } + + const ConstTensor& get_CellBias() const + { + return deref(m_CellBias); + } + + const ConstTensor& get_OutputGateBias() const + { + return deref(m_OutputGateBias); + } +}; + +struct QuantizedLstmInputParamsInfo +{ + QuantizedLstmInputParamsInfo() + : m_InputToInputWeights(nullptr) + , m_InputToForgetWeights(nullptr) + , m_InputToCellWeights(nullptr) + , m_InputToOutputWeights(nullptr) + + , m_RecurrentToInputWeights(nullptr) + , m_RecurrentToForgetWeights(nullptr) + , m_RecurrentToCellWeights(nullptr) + , m_RecurrentToOutputWeights(nullptr) + + , m_InputGateBias(nullptr) + , m_ForgetGateBias(nullptr) + , m_CellBias(nullptr) + , m_OutputGateBias(nullptr) + { + } + + const TensorInfo* m_InputToInputWeights; + const TensorInfo* m_InputToForgetWeights; + const TensorInfo* m_InputToCellWeights; + const TensorInfo* m_InputToOutputWeights; + + const TensorInfo* m_RecurrentToInputWeights; + const TensorInfo* m_RecurrentToForgetWeights; + const TensorInfo* m_RecurrentToCellWeights; + const TensorInfo* m_RecurrentToOutputWeights; + + const TensorInfo* m_InputGateBias; + const TensorInfo* m_ForgetGateBias; + const TensorInfo* m_CellBias; + const TensorInfo* m_OutputGateBias; + + + const TensorInfo& deref(const TensorInfo* tensorInfo) const + { + if (tensorInfo != nullptr) + { + const TensorInfo &temp = *tensorInfo; + return temp; + } + throw InvalidArgumentException("Can't dereference a null pointer"); + } + + const TensorInfo& get_InputToInputWeights() const + { + return deref(m_InputToInputWeights); + } + const TensorInfo& get_InputToForgetWeights() const + { + return deref(m_InputToForgetWeights); + } + const TensorInfo& get_InputToCellWeights() const + { + return deref(m_InputToCellWeights); + } + const TensorInfo& get_InputToOutputWeights() const + { + return deref(m_InputToOutputWeights); + } + + const TensorInfo& get_RecurrentToInputWeights() const + { + return deref(m_RecurrentToInputWeights); + } + const TensorInfo& get_RecurrentToForgetWeights() const + { + return deref(m_RecurrentToForgetWeights); + } + const TensorInfo& get_RecurrentToCellWeights() const + { + return deref(m_RecurrentToCellWeights); + } + const TensorInfo& get_RecurrentToOutputWeights() const + { + return deref(m_RecurrentToOutputWeights); + } + + const TensorInfo& get_InputGateBias() const + { + return deref(m_InputGateBias); + } + const TensorInfo& get_ForgetGateBias() const + { + return deref(m_ForgetGateBias); + } + const TensorInfo& get_CellBias() const + { + return deref(m_CellBias); + } + const TensorInfo& get_OutputGateBias() const + { + return deref(m_OutputGateBias); + } +}; + +} // namespace armnn + |