about summary refs log tree commit diff
diff options
context:
space:
mode:
authorDerek Lamberti <derek.lamberti@arm.com>2019-12-19 15:45:35 +0000
committerDerek Lamberti <derek.lamberti@arm.com>2019-12-19 16:05:25 +0000
commit57ea6d15d3d32d263192e94bc67302e96cf1178f (patch)
treeeb6a4f526149affb9d306f5d07143b6f3fd5b508
parent5dd88ab93e644c525a21ba820456e87cf3a6fb22 (diff)
downloadandroid-nn-driver-57ea6d15d3d32d263192e94bc67302e96cf1178f.tar.gz
IVGCVSW-4301 Correctly validate reshape for broadcastable inputs
Signed-off-by: Derek Lamberti <derek.lamberti@arm.com>
Change-Id: I4db6ea4ed0a192c85f124c4a9ced60b1666a3870
-rw-r--r--1.2/HalPolicy.cpp6
-rw-r--r--ConversionUtils.hpp12
2 files changed, 9 insertions, 9 deletions
diff --git a/1.2/HalPolicy.cpp b/1.2/HalPolicy.cpp
index 2665ea9a..97449c0b 100644
--- a/1.2/HalPolicy.cpp
+++ b/1.2/HalPolicy.cpp
@@ -555,7 +555,7 @@ bool HalPolicy::ConvertMaximum(const Operation& operation, const Model& model, C
armnn::IConnectableLayer* layer = data.m_Network->AddMaximumLayer();
assert(layer != nullptr);
- bool isReshapeSupported = BroadcastTensor(input0, input1, outInfo, layer, data);
+ bool isReshapeSupported = BroadcastTensor(input0, input1, layer, data);
if (!isReshapeSupported)
{
return false;
@@ -610,7 +610,7 @@ bool HalPolicy::ConvertMinimum(const Operation& operation, const Model& model, C
armnn::IConnectableLayer* const layer = data.m_Network->AddMinimumLayer();
assert(layer != nullptr);
- bool isReshapeSupported = BroadcastTensor(input0, input1, outputInfo, layer, data);
+ bool isReshapeSupported = BroadcastTensor(input0, input1, layer, data);
if (!isReshapeSupported)
{
return false;
@@ -773,7 +773,7 @@ bool HalPolicy::ConvertPrelu(const Operation& operation, const Model& model, Con
return Fail("%s: AddPreluLayer failed", __func__);
}
- bool isReshapeSupported = BroadcastTensor(input, alpha, outputInfo, layer, data);
+ bool isReshapeSupported = BroadcastTensor(input, alpha, layer, data);
if (!isReshapeSupported)
{
return false;
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp
index bbd2f07a..afaf1af7 100644
--- a/ConversionUtils.hpp
+++ b/ConversionUtils.hpp
@@ -232,7 +232,7 @@ armnn::IConnectableLayer& AddReshapeLayer(armnn::INetwork& network, LayerHandleT
return *reshapeLayer;
}
-bool BroadcastTensor(LayerInputHandle& input0, LayerInputHandle& input1, const armnn::TensorInfo& outputInfo,
+bool BroadcastTensor(LayerInputHandle& input0, LayerInputHandle& input1,
armnn::IConnectableLayer* startLayer, ConversionData& data)
{
BOOST_ASSERT(startLayer != nullptr);
@@ -282,8 +282,8 @@ bool BroadcastTensor(LayerInputHandle& input0, LayerInputHandle& input1, const a
IsReshapeSupported,
data.m_Backends,
isSupported,
+ smallInfo,
reshapedInfo,
- outputInfo,
reshapeDescriptor);
if (!isSupported)
{
@@ -1555,7 +1555,7 @@ bool ConvertAdd(const Operation& operation, const Model& model, ConversionData&
if (endLayer != nullptr)
{
- bool isReshapeSupported = BroadcastTensor(input0, input1, outputInfo, startLayer, data);
+ bool isReshapeSupported = BroadcastTensor(input0, input1, startLayer, data);
if (!isReshapeSupported)
{
return false;
@@ -2219,7 +2219,7 @@ bool ConvertDiv(const Operation& operation, const Model& model, ConversionData&
if (endLayer)
{
- bool isReshapeSupported = BroadcastTensor(input0, input1, outputInfo, startLayer, data);
+ bool isReshapeSupported = BroadcastTensor(input0, input1, startLayer, data);
if (!isReshapeSupported)
{
return false;
@@ -2665,7 +2665,7 @@ bool ConvertMul(const Operation& operation, const Model& model, ConversionData&
if (endLayer != nullptr)
{
- bool isReshapeSupported = BroadcastTensor(input0, input1, outputInfo, startLayer, data);
+ bool isReshapeSupported = BroadcastTensor(input0, input1, startLayer, data);
if (!isReshapeSupported)
{
return false;
@@ -2874,7 +2874,7 @@ bool ConvertSub(const Operation& operation, const Model& model, ConversionData&
if (endLayer)
{
- bool isReshapeSupported = BroadcastTensor(input0, input1, outputInfo, startLayer, data);
+ bool isReshapeSupported = BroadcastTensor(input0, input1, startLayer, data);
if (!isReshapeSupported)
{
return false;