aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPablo Tello <pablo.tello@arm.com>2019-11-28 15:21:41 +0000
committerPablo Tello <pablo.tello@arm.com>2019-12-13 17:34:31 +0000
commit972603fb38e2f5a75b2389bd8a265b6e3173a02f (patch)
treec006c89dd47e54ed8f41aad597e6a00fb4e178de
parent624fe9f02d4cf229cc7f33d9c970b39fcf560d1f (diff)
downloadandroid-nn-driver-972603fb38e2f5a75b2389bd8a265b6e3173a02f.tar.gz
MLCE-133: Driver infers scratchBuffer shape.
Change-Id: I7977d697772349b8ea7eb300937409ce0a3a4dee Signed-off-by: Pablo Tello <pablo.tello@arm.com>
-rw-r--r--1.2/HalPolicy.cpp64
1 files changed, 60 insertions, 4 deletions
diff --git a/1.2/HalPolicy.cpp b/1.2/HalPolicy.cpp
index 60bbf1d5..ac78e96b 100644
--- a/1.2/HalPolicy.cpp
+++ b/1.2/HalPolicy.cpp
@@ -2101,6 +2101,36 @@ bool HalPolicy::ConvertTanH(const Operation& operation, const Model& model, Conv
return ::ConvertTanH<hal_1_2::HalPolicy>(operation, model, data);
}
+template<typename HalPolicy,
+ typename HalOperation = typename HalPolicy::Operation,
+ typename HalModel = typename HalPolicy::Model>
+bool SetupAndTrackLayerOutputSlotAndOverrideTensorInfo(const HalOperation& operation,
+ uint32_t operationOutputIndex,
+ armnn::IConnectableLayer& layer,
+ uint32_t layerOutputIndex,
+ const HalModel& model,
+ ConversionData& data,
+ const armnn::TensorInfo tensor_info)
+{
+ using HalOperand = typename HalPolicy::Operand;
+
+ const HalOperand* outputOperand = GetOutputOperand<HalPolicy>(operation, operationOutputIndex, model);
+ if ((outputOperand == nullptr) || (operationOutputIndex >= layer.GetNumOutputSlots()))
+ {
+ return false;
+ }
+
+ armnn::IOutputSlot& outputSlot = layer.GetOutputSlot(layerOutputIndex);
+
+ const uint32_t operandIndex = operation.outputs[operationOutputIndex];
+ data.m_OutputSlotForOperand[operandIndex] = &outputSlot;
+
+ outputSlot.SetTensorInfo(tensor_info);
+
+ return true;
+}
+
+
bool HalPolicy::ConvertLstm(const Operation& operation, const Model& model, ConversionData& data)
{
ALOGV("hal_1_2::HalPolicy::ConvertLstm()");
@@ -2399,8 +2429,28 @@ bool HalPolicy::ConvertLstm(const Operation& operation, const Model& model, Conv
const TensorInfo& cellStateOutInfo = GetTensorInfoForOperand(*cellStateOut);
const TensorInfo& outputInfo = GetTensorInfoForOperand(*output);
- if (IsDynamicTensor(scratchBufferInfo) ||
- IsDynamicTensor(outputStateOutInfo) ||
+ // Check if the scratch buffer shape was initialized,
+ // In some cases the shape could be (0,0) which requires the driver
+ // to infer the shape and set it up accordingly.
+ // The code below does that.
+ TensorInfo fixSbInfo = scratchBufferInfo;
+ if (IsDynamicTensor(scratchBufferInfo))
+ {
+ auto & s = fixSbInfo.GetShape();
+ s[0] = outputStateInInfo.GetShape()[0];
+ if (desc.m_CifgEnabled)
+ {
+ // 2D tensor with dimensions [num_units * 3, batch_size] with CIFG
+ s[1] = cellStateOutInfo.GetShape()[1]*3;
+ }
+ else
+ {
+ // scratch_buffer [num_units * 4, batch_size] without CIFG
+ s[1] = cellStateOutInfo.GetShape()[1]*4;
+ }
+ }
+
+ if (IsDynamicTensor(outputStateOutInfo) ||
IsDynamicTensor(cellStateOutInfo) ||
IsDynamicTensor(outputInfo))
{
@@ -2467,7 +2517,7 @@ bool HalPolicy::ConvertLstm(const Operation& operation, const Model& model, Conv
inputInfo,
outputStateInInfo,
cellStateInInfo,
- scratchBufferInfo,
+ fixSbInfo,
outputStateOutInfo,
cellStateOutInfo,
outputInfo,
@@ -2485,7 +2535,13 @@ bool HalPolicy::ConvertLstm(const Operation& operation, const Model& model, Conv
outputStateIn.Connect(layer->GetInputSlot(1));
cellStateIn.Connect(layer->GetInputSlot(2));
- return (SetupAndTrackLayerOutputSlot<hal_1_2::HalPolicy>(operation, 0, *layer, 0, model, data) &&
+
+ return (
+ (IsDynamicTensor(scratchBufferInfo)?
+ SetupAndTrackLayerOutputSlotAndOverrideTensorInfo<hal_1_2::HalPolicy>(
+ operation, 0, *layer, 0, model, data,fixSbInfo):
+ SetupAndTrackLayerOutputSlot<hal_1_2::HalPolicy>(
+ operation, 0, *layer, 0, model, data)) &&
SetupAndTrackLayerOutputSlot<hal_1_2::HalPolicy>(operation, 1, *layer, 1, model, data) &&
SetupAndTrackLayerOutputSlot<hal_1_2::HalPolicy>(operation, 2, *layer, 2, model, data) &&
SetupAndTrackLayerOutputSlot<hal_1_2::HalPolicy>(operation, 3, *layer, 3, model, data));