aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormathad01 <matthew.haddon@arm.com>2021-04-20 16:12:45 +0100
committerJan Eilers <jan.eilers@arm.com>2021-04-22 08:21:43 +0000
commitbf7edb619cd5cf0ae84342299abe7c27f3ba6e7d (patch)
treeaace117c1b6de99b257552be74b8d8dc2b39718d
parente11e63d749b0909f13f9a39c8d34ef5523255170 (diff)
downloadarmnn-bf7edb619cd5cf0ae84342299abe7c27f3ba6e7d.tar.gz
IVGCVSW-5418 ExecuteNetwork test for MobileBERT
* Removed check in TfLiteParser and Delegate that requires both weights and biases to be constant or non-constant simultaneously * Updated TfLiteParser FullyConnected layer test to properly use non-constant weights * MobileBERT Float32 model now runs on TfLiteParser Signed-off-by: mathad01 <matthew.haddon@arm.com> Change-Id: I1d75eea466caa90cd695ad353160362df2f69483
-rw-r--r--delegate/src/FullyConnected.hpp9
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp19
-rw-r--r--src/armnnTfLiteParser/test/FullyConnected.cpp7
3 files changed, 11 insertions, 24 deletions
diff --git a/delegate/src/FullyConnected.hpp b/delegate/src/FullyConnected.hpp
index 2b45c48a89..e94304fb21 100644
--- a/delegate/src/FullyConnected.hpp
+++ b/delegate/src/FullyConnected.hpp
@@ -77,15 +77,6 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
{
return kTfLiteError;
}
- if ((isConstantWeights && !tflite::IsConstantTensor(&tfLiteBiasTensor))
- || (!isConstantWeights && tflite::IsConstantTensor(&tfLiteBiasTensor)))
- {
- TF_LITE_MAYBE_KERNEL_LOG(
- tfLiteContext,
- "TfLiteArmnnDelegate: Weights and bias are not compatible"
- " in operator #%d node #%d: ", operatorCode, nodeIndex);
- return kTfLiteError;
- }
biasTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteBiasTensor);
}
else
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 9b1fa9075c..5070d5b22f 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -2374,15 +2374,6 @@ void TfLiteParserImpl::ParseFullyConnected(size_t subgraphIndex, size_t operator
desc.m_ConstantWeights = IsConstTensor(inputs[1]);
- // Either both weights and biases need to be inputs or both weights and biases need to be constant
- if (inputs.size() == 3 && desc.m_ConstantWeights != IsConstTensor(inputs[2]))
- {
- throw ParseException(
- fmt::format("Weights and bias are not compatible."
- "Node {}",
- CHECK_LOCATION().AsString()));
- }
-
auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
std::vector<unsigned int> tensorIndexesToRegister = {inputTensorIndexes[0]};
if (desc.m_ConstantWeights)
@@ -3600,7 +3591,15 @@ TfLiteParserImpl::CreateConstTensorAndStoreData(TfLiteParserImpl::BufferRawPtr b
bool TfLiteParserImpl::IsConstTensor(TensorRawPtr tensorPtr)
{
CHECK_TENSOR_PTR(tensorPtr);
- return !tensorPtr->is_variable;
+ bool isConst = true;
+
+ auto buffer = GetBuffer(m_Model, tensorPtr->buffer);
+ if (buffer->data.size() == 0)
+ {
+ isConst = false;
+ }
+
+ return isConst;
}
diff --git a/src/armnnTfLiteParser/test/FullyConnected.cpp b/src/armnnTfLiteParser/test/FullyConnected.cpp
index 333e17fafd..1ce1b2f74f 100644
--- a/src/armnnTfLiteParser/test/FullyConnected.cpp
+++ b/src/armnnTfLiteParser/test/FullyConnected.cpp
@@ -224,7 +224,7 @@ struct FullyConnectedNonConstWeightsFixture : public ParserFlatbuffersFixture
"is_variable": true
}, )";
- biasBuffer = R"(,{ "data": [ 10, 0, 0, 0 ] } )";
+ biasBuffer = R"(,{ "data": [] } )";
outputs = "3";
}
m_JsonString = R"(
@@ -250,7 +250,6 @@ struct FullyConnectedNonConstWeightsFixture : public ParserFlatbuffersFixture
"details_type": 0,
"quantized_dimension": 0
},
- "is_variable": false
},
{
"shape": )" + filterShape + R"(,
@@ -263,7 +262,6 @@ struct FullyConnectedNonConstWeightsFixture : public ParserFlatbuffersFixture
"details_type": 0,
"quantized_dimension": 0
},
- "is_variable": true
},
)" + biasTensor + R"(
{
@@ -281,7 +279,6 @@ struct FullyConnectedNonConstWeightsFixture : public ParserFlatbuffersFixture
"details_type": 0,
"quantized_dimension": 0
},
- "is_variable": false
}
],
"inputs": )" + inputTensors + R"(,
@@ -309,7 +306,7 @@ struct FullyConnectedNonConstWeightsFixture : public ParserFlatbuffersFixture
"data": []
},
{
- "data": [ 2, 3, 4, 5 ]
+ "data": []
}
)" + biasBuffer + R"(
]