From bac9b35df7c59f6b5b61e1d233a49bdb88a973ba Mon Sep 17 00:00:00 2001 From: Jan Eilers Date: Mon, 13 Jul 2020 13:40:24 +0100 Subject: IVGCVSW-4847, Github #393 Fix TfLite reshape operator * Change order of reading target shape. Checks built-in option first then input. Signed-off-by: Jan Eilers Change-Id: Iddc39188ebfb7f71e33c35847de7506a02d807af --- src/armnnTfLiteParser/TfLiteParser.cpp | 65 +++++++++++++++++++--------------- 1 file changed, 37 insertions(+), 28 deletions(-) diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index b1ec0e54c2..69430134df 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -2125,48 +2125,57 @@ void TfLiteParser::ParseReshape(size_t subgraphIndex, size_t operatorIndex) armnn::TensorInfo actualOutputTensorInfo = ToTensorInfo(outputs[0]); CheckMatchingQuantization(inputTensorInfo, actualOutputTensorInfo, layerName, "Input 0", "Output 0"); + // Extracting new shape for the output + // There are two ways it can be passed + // * First is to define the target shape in the operator built-in options + // * Second is to pass it as a second input tensor std::vector targetShape; - if (inputs.size() > 1 && inputs[1] != nullptr) + bool targetShapeFound = false; + // Check if built-in options were given + if (options != nullptr) { - if (inputs[1]->is_variable) + // make sure the parameter is given + if (options->new_shape.empty() == false) { - ARMNN_THROW_PARSE_EXCEPTION( "Target shapes defined in non-const input tensors is not supported"); + targetShape = options->new_shape; + targetShapeFound = true; } + } - if (inputs[1]->shape.size() != 1) + // If there is no built-in option given or if the built-in new_shape parameter was empty + if (!targetShapeFound) + { + // Check for a second input tensor + if (inputs.size() > 1 && inputs[1] != nullptr) { - ARMNN_THROW_PARSE_EXCEPTION("Target 'shape' input is not a 1D tensor"); - } + if (inputs[1]->is_variable) + { + ARMNN_THROW_PARSE_EXCEPTION( "Target shapes defined in non-const input tensors is not supported"); + } - if (inputs[1]->type != tflite::TensorType_INT32) - { - ARMNN_THROW_PARSE_EXCEPTION("Target 'shape' input is not an int32 type"); - } + if (inputs[1]->shape.size() != 1) + { + ARMNN_THROW_PARSE_EXCEPTION("Target 'shape' input is not a 1D tensor"); + } - auto bufferPtr = GetBuffer(m_Model, inputs[1]->buffer); - auto vals = reinterpret_cast(bufferPtr->data.data()); - for (int i=0; i < inputs[1]->shape[0]; i++) - { - targetShape.push_back(vals[i]); - } + if (inputs[1]->type != tflite::TensorType_INT32) + { + ARMNN_THROW_PARSE_EXCEPTION("Target 'shape' input is not an int32 type"); + } - if (options != nullptr && - options->new_shape.empty() == false && - options->new_shape != targetShape) - { - ARMNN_THROW_PARSE_EXCEPTION("Target shape defined in reshape parameters and as input tensor but " - "the values do not match"); + // Extract target shape from input + auto bufferPtr = GetBuffer(m_Model, inputs[1]->buffer); + auto values = reinterpret_cast(bufferPtr->data.data()); + for (int i=0; i < inputs[1]->shape[0]; ++i) + { + targetShape.push_back(values[i]); + } } - } - else - { - if (options == nullptr) + else { ARMNN_THROW_PARSE_EXCEPTION("Target shape not defined in reshape parameters or input tensor. " "At least one method required"); } - - targetShape = options->new_shape; } armnn::TensorInfo reshapeOutputTensorInfo = -- cgit v1.2.1