author     saoste01 <saoirse.stewart@arm.com>          2018-10-10 09:44:51 +0100
committer  Matthew Bentham <matthew.bentham@arm.com>   2018-10-22 17:25:39 +0100
commit     b847148f5d97f0514098414c010f764da3f75af6 (patch)
tree       b04840f72539eae882aefdb2bec9285237f42210 /1.1
parent     07dedda93954646d9ed8bf3d468a89634f88f112 (diff)
download   android-nn-driver-b847148f5d97f0514098414c010f764da3f75af6.tar.gz
IVGCVSW-1961: Add converter method for SQUEEZE to V1.1 section of HalPolicy
Change-Id: I15dffef32d394b13e57df134000b7dca4b8788af
Diffstat (limited to '1.1')
-rw-r--r--  1.1/HalPolicy.cpp  79
-rw-r--r--  1.1/HalPolicy.hpp   1
2 files changed, 79 insertions(+), 1 deletion(-)
diff --git a/1.1/HalPolicy.cpp b/1.1/HalPolicy.cpp
index a94f3058..1b1c06ea 100644
--- a/1.1/HalPolicy.cpp
+++ b/1.1/HalPolicy.cpp
@@ -33,6 +33,8 @@ bool HalPolicy::ConvertOperation(const Operation& operation, const Model& model,
            return ConvertMean(operation, model, data);
        case V1_1::OperationType::PAD:
            return ConvertPad(operation, model, data);
+        case V1_1::OperationType::SQUEEZE:
+            return ConvertSqueeze(operation, model, data);
        default:
            return Fail("%s: Operation type %s not supported in ArmnnDriver",
                        __func__, toString(operation.type).c_str());
@@ -272,5 +274,80 @@ bool HalPolicy::ConvertPad(const Operation& operation, const Model& model, Conve
    return SetupAndTrackLayerOutputSlot(operation, 0, *layer, model, data);
}
+bool HalPolicy::ConvertSqueeze(const Operation& operation, const Model& model, ConversionData& data)
+{
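+    // Default axis sequence used when the optional axis operand is absent:
+    // every dimension of the (at most 4D) input is a candidate for squeezing.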
+    static const uint32_t dimensionSequence[] = { 0, 1, 2, 3 };
+    LayerInputHandle input = ConvertToLayerInputHandle(operation, 0, model, data);
+
+    if (!input.IsValid())
+    {
+        return Fail("%s: Operation has invalid inputs", __func__);
+    }
+
+    const armnn::TensorInfo& inputInfo = input.GetTensorInfo();
+
+    unsigned int rank = inputInfo.GetNumDimensions();
+    if (rank > 4)
+    {
+        return Fail("%s: Inputs with rank greater than 4 are not supported (input rank: %i)", __func__, rank);
+    }
+
+    // NOTE: Axis is an optional parameter to SQUEEZE, therefore we do not want to generate a failure
+    // if the operand index is out of bounds.
+    const Operand* axisOperand = GetInputOperand(operation, 1, model, false);
+
+    std::vector<int32_t> axis;
+    if (!axisOperand)
+    {
+        axis.assign(dimensionSequence,
+                    dimensionSequence + inputInfo.GetNumDimensions());
+    }
+    else if (!GetTensorInt32Values(*axisOperand, axis, model, data))
+    {
+        return Fail("%s: Operation has an invalid or unsupported axis operand", __func__);
+    }
+
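+    // Build the output shape: a dimension is dropped only if it is listed in axis and has size 1;
+    // all other dimensions are kept.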
+    std::vector<uint32_t> outputDims;
+    for (unsigned int i = 0; i < rank; i++)
+    {
+        bool skipSqueeze = (std::find(axis.begin(), axis.end(), i) == axis.end());
+        auto currentDimension = inputInfo.GetShape()[i];
+
+        if (skipSqueeze || currentDimension != 1)
+        {
+            outputDims.push_back(currentDimension);
+        }
+    }
+
+    armnn::TensorShape outShape = armnn::TensorShape(static_cast<unsigned int>(outputDims.size()), outputDims.data());
+
+    armnn::TensorInfo outputInfo = inputInfo;
+    outputInfo.SetShape(outShape);
+
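+    // SQUEEZE is implemented as a Reshape to the computed squeezed output shape.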
+    armnn::ReshapeDescriptor reshapeDesc;
+    reshapeDesc.m_TargetShape = outputInfo.GetShape();
+
+    const Operand* output = GetOutputOperand(operation, 0, model);
+    if (!output)
+    {
+        return Fail("%s: Could not read output 0", __func__);
+    }
+
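+    // Check that the selected backend supports a Reshape with this input before adding the layer.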
+    if (!IsLayerSupported(__func__,
+                          armnn::IsReshapeSupported,
+                          data.m_Compute,
+                          inputInfo))
+    {
+        return false;
+    }
+
+    armnn::IConnectableLayer* const layer = data.m_Network->AddReshapeLayer(reshapeDesc);
+    assert(layer != nullptr);
+    input.Connect(layer->GetInputSlot(0));
+    layer->GetOutputSlot(0).SetTensorInfo(outputInfo);
+
+    return SetupAndTrackLayerOutputSlot(operation, 0, *layer, model, data);
+}
+
} // namespace hal_1_1
-} // namespace armnn_driver
\ No newline at end of file
+} // namespace armnn_driver
diff --git a/1.1/HalPolicy.hpp b/1.1/HalPolicy.hpp
index a9189106..06cc5743 100644
--- a/1.1/HalPolicy.hpp
+++ b/1.1/HalPolicy.hpp
@@ -29,6 +29,7 @@ private:
    static bool ConvertSub(const Operation& operation, const Model& model, ConversionData& data);
    static bool ConvertMean(const Operation& operation, const Model& model, ConversionData& data);
    static bool ConvertPad(const Operation& operation, const Model& model, ConversionData& data);
+    static bool ConvertSqueeze(const Operation& operation, const Model& model, ConversionData& data);
};
} // namespace hal_1_1