diff options
Diffstat (limited to 'src/armnn/backends/RefWorkloads/RefMultiplicationFloat32Workload.cpp')
-rw-r--r-- | src/armnn/backends/RefWorkloads/RefMultiplicationFloat32Workload.cpp | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/src/armnn/backends/RefWorkloads/RefMultiplicationFloat32Workload.cpp b/src/armnn/backends/RefWorkloads/RefMultiplicationFloat32Workload.cpp index ed68b1f6db..d7c54d9aad 100644 --- a/src/armnn/backends/RefWorkloads/RefMultiplicationFloat32Workload.cpp +++ b/src/armnn/backends/RefWorkloads/RefMultiplicationFloat32Workload.cpp @@ -17,12 +17,15 @@ void RefMultiplicationFloat32Workload::Execute() const { ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefMultiplicationFloat32Workload_Execute"); - const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]); + const TensorShape& inShape0 = GetTensorInfo(m_Data.m_Inputs[0]).GetShape(); + const TensorShape& inShape1 = GetTensorInfo(m_Data.m_Inputs[1]).GetShape(); + const TensorShape& outShape = GetTensorInfo(m_Data.m_Outputs[0]).GetShape(); float* outputData = GetOutputTensorDataFloat(0, m_Data); const float* inputData0 = GetInputTensorDataFloat(0, m_Data); const float* inputData1 = GetInputTensorDataFloat(1, m_Data); - Multiplication(inputData0, inputData1, inputInfo0.GetNumElements(), outputData); + + Multiplication(inShape0, inShape1, outShape, inputData0, inputData1, outputData); } } //namespace armnn |