aboutsummaryrefslogtreecommitdiff
path: root/ConversionUtils.hpp
diff options
context:
space:
mode:
authorSadik Armagan <sadik.armagan@arm.com>2020-05-19 14:10:30 +0100
committerSadik Armagan <sadik.armagan@arm.com>2020-05-19 14:10:30 +0100
commit813f23049d73177edfc1f1cff71147c39f4b695e (patch)
treecef15224e1e7a4bd8117bdc9f57a71be84e4325d /ConversionUtils.hpp
parentabc95d04dfb2462ffb42bc1facde4f45ecc65319 (diff)
downloadandroid-nn-driver-813f23049d73177edfc1f1cff71147c39f4b695e.tar.gz
IVGCVSW-4453 Add Support for ANEURALNETWORKS_QLSTM to HAL 1.3 Driver
* Add QLSTM support for Android NN Driver * Add overrideOutputInfo parameter to SetupAndTrackLayerOutputSlot * Add optional condition to GetInputScalar * Refactor Quantized 16 Bit LSTM impl Change-Id: Ie8fa98ad5ee4a62174ef91ca80f1df62b7fde937 Signed-off-by: Keith Davis <keith.davis@arm.com> Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Diffstat (limited to 'ConversionUtils.hpp')
-rw-r--r--ConversionUtils.hpp29
1 files changed, 21 insertions, 8 deletions
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp
index 8313d045..5a111317 100644
--- a/ConversionUtils.hpp
+++ b/ConversionUtils.hpp
@@ -877,35 +877,40 @@ bool GetInputScalar(const HalOperation& operation,
HalOperandType type,
OutputType& outValue,
const HalModel& model,
- const ConversionData& data)
+ const ConversionData& data,
+ bool optional = false)
{
using HalOperand = typename HalPolicy::Operand;
const HalOperand* operand = GetInputOperand<HalPolicy>(operation, inputIndex, model);
- if (!operand)
+ if (!optional && !operand)
{
return Fail("%s: invalid input operand at index %i", __func__, inputIndex);
}
- if (operand->type != type)
+ if (!optional && operand->type != type)
{
return Fail("%s: unexpected operand type: %s (should be %s)",
__func__, toString(operand->type).c_str(), toString(type).c_str());
}
- if (operand->location.length != sizeof(OutputType))
+ if (!optional && operand->location.length != sizeof(OutputType))
{
return Fail("%s: incorrect operand location length: %i (should be %i)",
__func__, operand->location.length, sizeof(OutputType));
}
const void* valueAddress = GetOperandValueReadOnlyAddress<HalPolicy>(*operand, model, data);
- if (!valueAddress)
+ if (!optional && !valueAddress)
{
return Fail("%s: failed to get address for operand", __func__);
}
- outValue = *(static_cast<const OutputType*>(valueAddress));
+ if(!optional)
+ {
+ outValue = *(static_cast<const OutputType*>(valueAddress));
+ }
+
return true;
}
@@ -1374,7 +1379,8 @@ bool SetupAndTrackLayerOutputSlot(const HalOperation& operation,
armnn::IConnectableLayer& layer,
uint32_t layerOutputIndex,
const HalModel& model,
- ConversionData& data)
+ ConversionData& data,
+ const armnn::TensorInfo* overrideOutputInfo = nullptr)
{
using HalOperand = typename HalPolicy::Operand;
@@ -1389,7 +1395,14 @@ bool SetupAndTrackLayerOutputSlot(const HalOperation& operation,
const uint32_t operandIndex = operation.outputs[operationOutputIndex];
data.m_OutputSlotForOperand[operandIndex] = &outputSlot;
- outputSlot.SetTensorInfo(GetTensorInfoForOperand(*outputOperand));
+ if (overrideOutputInfo == nullptr)
+ {
+ outputSlot.SetTensorInfo(GetTensorInfoForOperand(*outputOperand));
+ }
+ else
+ {
+ outputSlot.SetTensorInfo(*overrideOutputInfo);
+ }
return true;
}