aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJan Eilers <jan.eilers@arm.com>2020-07-13 13:40:24 +0100
committerJan Eilers <jan.eilers@arm.com>2020-07-14 15:28:51 +0100
commitbac9b35df7c59f6b5b61e1d233a49bdb88a973ba (patch)
tree798a14e6ffaabe18476b066bdd73e932924caad3
parent171ca7c6063dea09607a4bbe866ea9e94bccd831 (diff)
downloadarmnn-bac9b35df7c59f6b5b61e1d233a49bdb88a973ba.tar.gz
IVGCVSW-4847, Github #393 Fix TfLite reshape operator
* Change order of reading target shape. Checks built-in option first then input. Signed-off-by: Jan Eilers <jan.eilers@arm.com> Change-Id: Iddc39188ebfb7f71e33c35847de7506a02d807af
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp65
1 files changed, 37 insertions, 28 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index b1ec0e54c2..69430134df 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -2125,48 +2125,57 @@ void TfLiteParser::ParseReshape(size_t subgraphIndex, size_t operatorIndex)
armnn::TensorInfo actualOutputTensorInfo = ToTensorInfo(outputs[0]);
CheckMatchingQuantization(inputTensorInfo, actualOutputTensorInfo, layerName, "Input 0", "Output 0");
+ // Extracting new shape for the output
+ // There are two ways it can be passed
+ // * First is to define the target shape in the operator built-in options
+ // * Second is to pass it as a second input tensor
std::vector<int32_t> targetShape;
- if (inputs.size() > 1 && inputs[1] != nullptr)
+ bool targetShapeFound = false;
+ // Check if built-in options were given
+ if (options != nullptr)
{
- if (inputs[1]->is_variable)
+ // make sure the parameter is given
+ if (options->new_shape.empty() == false)
{
- ARMNN_THROW_PARSE_EXCEPTION( "Target shapes defined in non-const input tensors is not supported");
+ targetShape = options->new_shape;
+ targetShapeFound = true;
}
+ }
- if (inputs[1]->shape.size() != 1)
+ // If there is no built-in option given or if the built-in new_shape parameter was empty
+ if (!targetShapeFound)
+ {
+ // Check for a second input tensor
+ if (inputs.size() > 1 && inputs[1] != nullptr)
{
- ARMNN_THROW_PARSE_EXCEPTION("Target 'shape' input is not a 1D tensor");
- }
+ if (inputs[1]->is_variable)
+ {
+ ARMNN_THROW_PARSE_EXCEPTION( "Target shapes defined in non-const input tensors is not supported");
+ }
- if (inputs[1]->type != tflite::TensorType_INT32)
- {
- ARMNN_THROW_PARSE_EXCEPTION("Target 'shape' input is not an int32 type");
- }
+ if (inputs[1]->shape.size() != 1)
+ {
+ ARMNN_THROW_PARSE_EXCEPTION("Target 'shape' input is not a 1D tensor");
+ }
- auto bufferPtr = GetBuffer(m_Model, inputs[1]->buffer);
- auto vals = reinterpret_cast<const int32_t*>(bufferPtr->data.data());
- for (int i=0; i < inputs[1]->shape[0]; i++)
- {
- targetShape.push_back(vals[i]);
- }
+ if (inputs[1]->type != tflite::TensorType_INT32)
+ {
+ ARMNN_THROW_PARSE_EXCEPTION("Target 'shape' input is not an int32 type");
+ }
- if (options != nullptr &&
- options->new_shape.empty() == false &&
- options->new_shape != targetShape)
- {
- ARMNN_THROW_PARSE_EXCEPTION("Target shape defined in reshape parameters and as input tensor but "
- "the values do not match");
+ // Extract target shape from input
+ auto bufferPtr = GetBuffer(m_Model, inputs[1]->buffer);
+ auto values = reinterpret_cast<const int32_t*>(bufferPtr->data.data());
+ for (int i=0; i < inputs[1]->shape[0]; ++i)
+ {
+ targetShape.push_back(values[i]);
+ }
}
- }
- else
- {
- if (options == nullptr)
+ else
{
ARMNN_THROW_PARSE_EXCEPTION("Target shape not defined in reshape parameters or input tensor. "
"At least one method required");
}
-
- targetShape = options->new_shape;
}
armnn::TensorInfo reshapeOutputTensorInfo =