aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/backends/NeonWorkloads/NeonFullyConnectedFloat32Workload.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/backends/NeonWorkloads/NeonFullyConnectedFloat32Workload.hpp')
-rw-r--r--src/armnn/backends/NeonWorkloads/NeonFullyConnectedFloat32Workload.hpp15
1 files changed, 12 insertions, 3 deletions
diff --git a/src/armnn/backends/NeonWorkloads/NeonFullyConnectedFloat32Workload.hpp b/src/armnn/backends/NeonWorkloads/NeonFullyConnectedFloat32Workload.hpp
index 9c722dc573..684b5e0753 100644
--- a/src/armnn/backends/NeonWorkloads/NeonFullyConnectedFloat32Workload.hpp
+++ b/src/armnn/backends/NeonWorkloads/NeonFullyConnectedFloat32Workload.hpp
@@ -14,7 +14,13 @@
namespace armnn
{
-class NeonFullyConnectedFloat32Workload : public Float32Workload<FullyConnectedQueueDescriptor>
+arm_compute::Status NeonFullyConnectedWorkloadValidate(const TensorInfo& input,
+ const TensorInfo& output,
+ const TensorInfo& weights,
+ const TensorInfo& biases,
+ const FullyConnectedDescriptor& descriptor);
+
+class NeonFullyConnectedFloat32Workload : public FloatWorkload<FullyConnectedQueueDescriptor>
{
public:
NeonFullyConnectedFloat32Workload(const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info,
@@ -23,8 +29,11 @@ public:
private:
mutable arm_compute::NEFullyConnectedLayer m_FullyConnectedLayer;
- arm_compute::Tensor m_WeightsTensor;
- arm_compute::Tensor m_BiasesTensor;
+
+ std::unique_ptr<arm_compute::Tensor> m_WeightsTensor;
+ std::unique_ptr<arm_compute::Tensor> m_BiasesTensor;
+
+ void FreeUnusedTensors();
};
} //namespace armnn