diff options
author | James Conroy <james.conroy@arm.com> | 2019-07-17 11:27:46 +0100 |
---|---|---|
committer | Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> | 2019-07-24 10:40:13 +0100 |
commit | ee18dc8d1725f472850ab0c398fd7cbc4b850891 (patch) | |
tree | b57738b18781d512f5438ca5154652571393e4e8 /include/armnn/QuantizedLstmParams.hpp | |
parent | 7b1845206d723a91aec811edaf7cb0cf832dfd25 (diff) | |
download | armnn-ee18dc8d1725f472850ab0c398fd7cbc4b850891.tar.gz |
IVGCVSW-3469 Add front end for Quantized LSTM layer
* Added new layer QuantizedLstm (Android Q)
* Made necessary changes to APIs
* Added unit tests
Change-Id: I3b9f16b0e7e49f51932cf204c87cb7118798123a
Signed-off-by: James Conroy <james.conroy@arm.com>
Diffstat (limited to 'include/armnn/QuantizedLstmParams.hpp')
-rw-r--r-- | include/armnn/QuantizedLstmParams.hpp | 218 |
1 files changed, 218 insertions, 0 deletions
diff --git a/include/armnn/QuantizedLstmParams.hpp b/include/armnn/QuantizedLstmParams.hpp new file mode 100644 index 0000000000..b3033acc9a --- /dev/null +++ b/include/armnn/QuantizedLstmParams.hpp @@ -0,0 +1,218 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// +#pragma once + +#include "TensorFwd.hpp" +#include "Exceptions.hpp" + +namespace armnn +{ + +struct QuantizedLstmInputParams +{ + QuantizedLstmInputParams() + : m_InputToInputWeights(nullptr) + , m_InputToForgetWeights(nullptr) + , m_InputToCellWeights(nullptr) + , m_InputToOutputWeights(nullptr) + + , m_RecurrentToInputWeights(nullptr) + , m_RecurrentToForgetWeights(nullptr) + , m_RecurrentToCellWeights(nullptr) + , m_RecurrentToOutputWeights(nullptr) + + , m_InputGateBias(nullptr) + , m_ForgetGateBias(nullptr) + , m_CellBias(nullptr) + , m_OutputGateBias(nullptr) + { + } + + const ConstTensor* m_InputToInputWeights; + const ConstTensor* m_InputToForgetWeights; + const ConstTensor* m_InputToCellWeights; + const ConstTensor* m_InputToOutputWeights; + + const ConstTensor* m_RecurrentToInputWeights; + const ConstTensor* m_RecurrentToForgetWeights; + const ConstTensor* m_RecurrentToCellWeights; + const ConstTensor* m_RecurrentToOutputWeights; + + const ConstTensor* m_InputGateBias; + const ConstTensor* m_ForgetGateBias; + const ConstTensor* m_CellBias; + const ConstTensor* m_OutputGateBias; + + const ConstTensor& deref(const ConstTensor* tensorPtr) const + { + if (tensorPtr != nullptr) + { + const ConstTensor &temp = *tensorPtr; + return temp; + } + throw InvalidArgumentException("QuantizedLstmInputParams: Can't dereference a null pointer"); + } + + const ConstTensor& get_InputToInputWeights() const + { + return deref(m_InputToInputWeights); + } + + const ConstTensor& get_InputToForgetWeights() const + { + return deref(m_InputToForgetWeights); + } + + const ConstTensor& get_InputToCellWeights() const + { + return deref(m_InputToCellWeights); + } + + const ConstTensor& get_InputToOutputWeights() const + { + return deref(m_InputToOutputWeights); + } + + const ConstTensor& get_RecurrentToInputWeights() const + { + return deref(m_RecurrentToInputWeights); + } + + const ConstTensor& get_RecurrentToForgetWeights() const + { + return deref(m_RecurrentToForgetWeights); + } + + const ConstTensor& get_RecurrentToCellWeights() const + { + return deref(m_RecurrentToCellWeights); + } + + const ConstTensor& get_RecurrentToOutputWeights() const + { + return deref(m_RecurrentToOutputWeights); + } + + const ConstTensor& get_InputGateBias() const + { + return deref(m_InputGateBias); + } + + const ConstTensor& get_ForgetGateBias() const + { + return deref(m_ForgetGateBias); + } + + const ConstTensor& get_CellBias() const + { + return deref(m_CellBias); + } + + const ConstTensor& get_OutputGateBias() const + { + return deref(m_OutputGateBias); + } +}; + +struct QuantizedLstmInputParamsInfo +{ + QuantizedLstmInputParamsInfo() + : m_InputToInputWeights(nullptr) + , m_InputToForgetWeights(nullptr) + , m_InputToCellWeights(nullptr) + , m_InputToOutputWeights(nullptr) + + , m_RecurrentToInputWeights(nullptr) + , m_RecurrentToForgetWeights(nullptr) + , m_RecurrentToCellWeights(nullptr) + , m_RecurrentToOutputWeights(nullptr) + + , m_InputGateBias(nullptr) + , m_ForgetGateBias(nullptr) + , m_CellBias(nullptr) + , m_OutputGateBias(nullptr) + { + } + + const TensorInfo* m_InputToInputWeights; + const TensorInfo* m_InputToForgetWeights; + const TensorInfo* m_InputToCellWeights; + const TensorInfo* m_InputToOutputWeights; + + const TensorInfo* m_RecurrentToInputWeights; + const TensorInfo* m_RecurrentToForgetWeights; + const TensorInfo* m_RecurrentToCellWeights; + const TensorInfo* m_RecurrentToOutputWeights; + + const TensorInfo* m_InputGateBias; + const TensorInfo* m_ForgetGateBias; + const TensorInfo* m_CellBias; + const TensorInfo* m_OutputGateBias; + + + const TensorInfo& deref(const TensorInfo* tensorInfo) const + { + if (tensorInfo != nullptr) + { + const TensorInfo &temp = *tensorInfo; + return temp; + } + throw InvalidArgumentException("Can't dereference a null pointer"); + } + + const TensorInfo& get_InputToInputWeights() const + { + return deref(m_InputToInputWeights); + } + const TensorInfo& get_InputToForgetWeights() const + { + return deref(m_InputToForgetWeights); + } + const TensorInfo& get_InputToCellWeights() const + { + return deref(m_InputToCellWeights); + } + const TensorInfo& get_InputToOutputWeights() const + { + return deref(m_InputToOutputWeights); + } + + const TensorInfo& get_RecurrentToInputWeights() const + { + return deref(m_RecurrentToInputWeights); + } + const TensorInfo& get_RecurrentToForgetWeights() const + { + return deref(m_RecurrentToForgetWeights); + } + const TensorInfo& get_RecurrentToCellWeights() const + { + return deref(m_RecurrentToCellWeights); + } + const TensorInfo& get_RecurrentToOutputWeights() const + { + return deref(m_RecurrentToOutputWeights); + } + + const TensorInfo& get_InputGateBias() const + { + return deref(m_InputGateBias); + } + const TensorInfo& get_ForgetGateBias() const + { + return deref(m_ForgetGateBias); + } + const TensorInfo& get_CellBias() const + { + return deref(m_CellBias); + } + const TensorInfo& get_OutputGateBias() const + { + return deref(m_OutputGateBias); + } +}; + +} // namespace armnn + |