diff options
Diffstat (limited to 'src/backends/reference/test/RefCreateWorkloadTests.cpp')
-rw-r--r-- | src/backends/reference/test/RefCreateWorkloadTests.cpp | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp index 4293ef54f3..fae8d0cdd4 100644 --- a/src/backends/reference/test/RefCreateWorkloadTests.cpp +++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp @@ -486,6 +486,24 @@ TEST_CASE("RefCreateFullyConnectedWithBlobWorkloadTest") TensorInfo({ 3, 7 }, armnn::DataType::Float32, outputQScale)); } +TEST_CASE("CreateFullyConnectedWorkloadWeightsBiasesAsInputsFloat32") +{ + Graph graph; + RefWorkloadFactory factory = GetFactory(); + + auto workload = + CreateFullyConnectedWorkloadWeightsBiasesAsInputsTest<RefFullyConnectedWorkload, + armnn::DataType::Float32>(factory, graph); + + // Checks that outputs and inputs are as we expect them (see definition of CreateFullyConnectedWorkloadTest). + float inputsQScale = 0.0f; + float outputQScale = 0.0f; + CheckInputsOutput(std::move(workload), + TensorInfo({ 3, 1, 4, 5 }, armnn::DataType::Float32, inputsQScale), + TensorInfo({ 7, 20 }, armnn::DataType::Float32, inputsQScale), + TensorInfo({ 3, 7 }, armnn::DataType::Float32, outputQScale)); +} + template <typename FullyConnectedWorkloadType, armnn::DataType DataType> static void RefCreateFullyConnectedWorkloadTest() { |