about summary refs log tree commit diff
diff options
context:
space:
mode:
author    Sadik Armagan <sadik.armagan@arm.com>  2020-09-03 11:33:07 +0100
committer Sadik Armagan <sadik.armagan@arm.com>  2020-09-03 11:33:07 +0100
commit    dbda4b7eb52ba3c271b06cc5172a40bf84aafde8 (patch)
tree      e195783e782222d0a62ccc468cd1b4b6181ccb1e
parent    346e8119e51cd9a3fbb56eb6fa85ad48a256c8eb (diff)
download  android-nn-driver-dbda4b7eb52ba3c271b06cc5172a40bf84aafde8.tar.gz
IVGCVSW-5272 'Update ConvertLstm function to use ShapeInferenceMethod'
* Enabled Dynamic Tensors on LSTM operator

Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Change-Id: I9cae539559570a44088a986870d3d3e41aee9468
-rw-r--r--  ConversionUtils.hpp      40
-rw-r--r--  ConversionUtils_1_2.hpp  99
2 files changed, 66 insertions, 73 deletions
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp
index cdb57d1f..fe8e026e 100644
--- a/ConversionUtils.hpp
+++ b/ConversionUtils.hpp
@@ -1399,7 +1399,8 @@ bool SetupAndTrackLayerOutputSlot(const HalOperation& operation,
const HalModel& model,
ConversionData& data,
const armnn::TensorInfo* overrideOutputInfo = nullptr,
- const std::function <void (const armnn::TensorInfo&, bool&)>& validateFunc = nullptr)
+ const std::function <void (const armnn::TensorInfo&, bool&)>& validateFunc = nullptr,
+ bool inferOutputShapes = false)
{
using HalOperand = typename HalPolicy::Operand;
@@ -1410,7 +1411,6 @@ bool SetupAndTrackLayerOutputSlot(const HalOperation& operation,
}
armnn::IOutputSlot& outputSlot = layer.GetOutputSlot(layerOutputIndex);
-
if (overrideOutputInfo == nullptr)
{
outputSlot.SetTensorInfo(GetTensorInfoForOperand(*outputOperand));
@@ -1420,32 +1420,30 @@ bool SetupAndTrackLayerOutputSlot(const HalOperation& operation,
outputSlot.SetTensorInfo(*overrideOutputInfo);
}
- // Type one dynamic tensors require the previous layer's output shape for inference
- if (!layer.GetInputSlot(0).GetConnection() &&
- IsDynamicTensor(outputSlot.GetTensorInfo()))
- {
- return false;
- }
-
bool isSupported = false;
- if (validateFunc &&
- layer.GetInputSlot(0).GetConnection() &&
- IsDynamicTensor(outputSlot.GetTensorInfo()))
+ if (validateFunc && (IsDynamicTensor(outputSlot.GetTensorInfo()) || inferOutputShapes))
{
+ // Type one dynamic tensors require the previous layer's output shape for inference
+ for (unsigned int inputSlotIndex = 0; inputSlotIndex < layer.GetNumInputSlots(); ++inputSlotIndex)
+ {
+ if(!layer.GetInputSlot(inputSlotIndex).GetConnection())
+ {
+ return false;
+ }
+ }
// IsTensorInfoSet will infer the dynamic output shape
outputSlot.IsTensorInfoSet();
// Once the shape is inferred we can validate it
validateFunc(outputSlot.GetTensorInfo(), isSupported);
- if(!isSupported)
- {
- for (unsigned int inputSlotIndex = 0; inputSlotIndex < layer.GetNumInputSlots(); ++inputSlotIndex)
- {
- layer.GetInputSlot(inputSlotIndex).GetConnection()->Disconnect(layer.GetInputSlot(inputSlotIndex));
- }
-
- return false;
- }
+ if(!isSupported)
+ {
+ for (unsigned int inputSlotIndex = 0; inputSlotIndex < layer.GetNumInputSlots(); ++inputSlotIndex)
+ {
+ layer.GetInputSlot(inputSlotIndex).GetConnection()->Disconnect(layer.GetInputSlot(inputSlotIndex));
+ }
+ return false;
+ }
}
const uint32_t operandIndex = operation.outputs[operationOutputIndex];
diff --git a/ConversionUtils_1_2.hpp b/ConversionUtils_1_2.hpp
index 210caa1b..4e0ccb6e 100644
--- a/ConversionUtils_1_2.hpp
+++ b/ConversionUtils_1_2.hpp
@@ -2522,36 +2522,6 @@ bool ConvertLstm(const HalOperation& operation, const HalModel& model, Conversio
const TensorInfo& cellStateOutInfo = GetTensorInfoForOperand(*cellStateOut);
const TensorInfo& outputInfo = GetTensorInfoForOperand(*output);
- // Check if the scratch buffer shape was initialized,
- // In some cases the shape could be (0,0) which requires the driver
- // to infer the shape and set it up accordingly.
- // The code below does that.
- TensorInfo fixSbInfo = scratchBufferInfo;
- if (IsDynamicTensor(scratchBufferInfo))
- {
- auto & s = fixSbInfo.GetShape();
- s[0] = outputStateInInfo.GetShape()[0];
- if (desc.m_CifgEnabled)
- {
- // 2D tensor with dimensions [num_units * 3, batch_size] with CIFG
- s[1] = cellStateOutInfo.GetShape()[1]*3;
- }
- else
- {
- // scratch_buffer [num_units * 4, batch_size] without CIFG
- s[1] = cellStateOutInfo.GetShape()[1]*4;
- }
- }
-
- if (IsDynamicTensor(outputStateOutInfo) ||
- IsDynamicTensor(cellStateOutInfo) ||
- IsDynamicTensor(outputInfo))
- {
- return Fail("%s: Dynamic output tensors are not supported %d %d %d %d", __func__,
- IsDynamicTensor(scratchBufferInfo), IsDynamicTensor(outputStateOutInfo),
- IsDynamicTensor(cellStateOutInfo), IsDynamicTensor(outputInfo));
- }
-
// Basic parameters
LstmInputParamsInfo paramsInfo;
paramsInfo.m_InputToForgetWeights = &(params.m_InputToForgetWeights->GetInfo());
@@ -2603,20 +2573,36 @@ bool ConvertLstm(const HalOperation& operation, const HalModel& model, Conversio
}
bool isSupported = false;
+ auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
+ {
+ FORWARD_LAYER_SUPPORT_FUNC(__func__,
+ IsLstmSupported,
+ data.m_Backends,
+ isSupported,
+ inputInfo,
+ outputStateInInfo,
+ cellStateInInfo,
+ scratchBufferInfo,
+ outputStateOutInfo,
+ cellStateOutInfo,
+ outputInfo,
+ desc,
+ paramsInfo);
+ };
- FORWARD_LAYER_SUPPORT_FUNC(__func__,
- IsLstmSupported,
- data.m_Backends,
- isSupported,
- inputInfo,
- outputStateInInfo,
- cellStateInInfo,
- fixSbInfo,
- outputStateOutInfo,
- cellStateOutInfo,
- outputInfo,
- desc,
- paramsInfo);
+ bool isDynamic = false;
+ if (!IsDynamicTensor(outputStateOutInfo) &&
+ !IsDynamicTensor(scratchBufferInfo) &&
+ !IsDynamicTensor(cellStateOutInfo) &&
+ !IsDynamicTensor(outputInfo))
+ {
+ validateFunc(outputInfo, isSupported);
+ }
+ else
+ {
+ isDynamic = true;
+ isSupported = AreDynamicTensorsSupported();
+ }
if (!isSupported)
{
@@ -2630,15 +2616,24 @@ bool ConvertLstm(const HalOperation& operation, const HalModel& model, Conversio
outputStateIn.Connect(layer->GetInputSlot(1));
cellStateIn.Connect(layer->GetInputSlot(2));
- return (
- (IsDynamicTensor(scratchBufferInfo)?
- SetupAndTrackLayerOutputSlotAndOverrideTensorInfo<HalPolicy>(
- operation, 0, *layer, 0, model, data,fixSbInfo):
- SetupAndTrackLayerOutputSlot<HalPolicy>(
- operation, 0, *layer, 0, model, data)) &&
- SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 1, *layer, 1, model, data) &&
- SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 2, *layer, 2, model, data) &&
- SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 3, *layer, 3, model, data));
+ if (!isDynamic)
+ {
+ return (
+ SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, 0, model, data) &&
+ SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 1, *layer, 1, model, data) &&
+ SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 2, *layer, 2, model, data) &&
+ SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 3, *layer, 3, model, data));
+ }
+ else
+ {
+ return (
+ SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, 0, model, data) &&
+ SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 1, *layer, 1, model, data) &&
+ SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 2, *layer, 2, model, data) &&
+ SetupAndTrackLayerOutputSlot<HalPolicy>(
+ operation, 3, *layer, 3, model, data, nullptr, validateFunc, true));
+ }
+
}
template<typename HalPolicy,