From 81beae3a870004795275e9266bc43d845b9f78db Mon Sep 17 00:00:00 2001 From: Matthew Sloyan Date: Tue, 13 Jul 2021 19:46:11 +0100 Subject: IVGCVSW-6119 ConstTensorsAsInput: FullyConnected * Constant weights and biases are now stored as Constant layers. * Updated Serializer, Deserializer and unit tests to reflect this. * Updated TfLiteDelegate, TfLiteParser and OnnxParser. * Updated Schema with IsConstant and ConstantTensorsAsInputs. * Updated Ref backend to handle constant weights and bias as inputs rather than reading from member variables. * Added dynamic or constant input EndToEnd tests. !android-nn-driver:5959 Signed-off-by: Matthew Sloyan Change-Id: Ibf3cf437df1100e4b322b0d303c575c6339f9696 --- src/backends/reference/test/RefCreateWorkloadTests.cpp | 18 ++++++++++++++++++ src/backends/reference/test/RefEndToEndTests.cpp | 12 +++++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) (limited to 'src/backends/reference/test') diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp index 4293ef54f3..fae8d0cdd4 100644 --- a/src/backends/reference/test/RefCreateWorkloadTests.cpp +++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp @@ -486,6 +486,24 @@ TEST_CASE("RefCreateFullyConnectedWithBlobWorkloadTest") TensorInfo({ 3, 7 }, armnn::DataType::Float32, outputQScale)); } +TEST_CASE("CreateFullyConnectedWorkloadWeightsBiasesAsInputsFloat32") +{ + Graph graph; + RefWorkloadFactory factory = GetFactory(); + + auto workload = + CreateFullyConnectedWorkloadWeightsBiasesAsInputsTest(factory, graph); + + // Checks that outputs and inputs are as we expect them (see definition of CreateFullyConnectedWorkloadTest). + float inputsQScale = 0.0f; + float outputQScale = 0.0f; + CheckInputsOutput(std::move(workload), + TensorInfo({ 3, 1, 4, 5 }, armnn::DataType::Float32, inputsQScale), + TensorInfo({ 7, 20 }, armnn::DataType::Float32, inputsQScale), + TensorInfo({ 3, 7 }, armnn::DataType::Float32, outputQScale)); +} + template static void RefCreateFullyConnectedWorkloadTest() { diff --git a/src/backends/reference/test/RefEndToEndTests.cpp b/src/backends/reference/test/RefEndToEndTests.cpp index 69a2048078..424df977c8 100644 --- a/src/backends/reference/test/RefEndToEndTests.cpp +++ b/src/backends/reference/test/RefEndToEndTests.cpp @@ -600,11 +600,21 @@ TEST_CASE("RefFillEndToEndTestInt32") FillEndToEnd(defaultBackends); } -TEST_CASE("RefFullyConnectedEndToEndTestInt32") +TEST_CASE("RefFullyConnectedEndToEndTestFloat32") { FullyConnectedWithDynamicWeightsEndToEnd(defaultBackends); } +TEST_CASE("RefFullyConnectedEndToEndTestNonConstantWeightsConstantBiasesFloat32") +{ + FullyConnectedWithDynamicOrConstantInputsEndToEnd(defaultBackends, true, true); +} + +TEST_CASE("RefFullyConnectedEndToEndTestConstantWeightsNonConstantBiasesFloat32") +{ + FullyConnectedWithDynamicOrConstantInputsEndToEnd(defaultBackends, true, false); +} + TEST_CASE("RefGatherFloatTest") { GatherEndToEnd(defaultBackends); -- cgit v1.2.1