aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLDepthwiseVectorToTensorKernel.cpp
diff options
context:
space:
mode:
authorGiorgio Arena <giorgio.arena@arm.com>2018-04-23 16:16:21 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:51:17 +0000
commitad0c7388f6261989a268ffb2d042f2bd80736e3f (patch)
tree84a0f1accc9a7c4b820f150e4265525c08a67ccf /src/core/CL/kernels/CLDepthwiseVectorToTensorKernel.cpp
parent1ed442a9b4024741860106cd96f5f7535a38fd04 (diff)
downloadComputeLibrary-ad0c7388f6261989a268ffb2d042f2bd80736e3f.tar.gz
COMPMID-1068 Create validate method to CLDepthWiseConvolution
Change-Id: I3301b66a8a072c6ecd0d7f2dabef350017b55ac4 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/128677 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/core/CL/kernels/CLDepthwiseVectorToTensorKernel.cpp')
-rw-r--r--src/core/CL/kernels/CLDepthwiseVectorToTensorKernel.cpp47
1 files changed, 37 insertions, 10 deletions
diff --git a/src/core/CL/kernels/CLDepthwiseVectorToTensorKernel.cpp b/src/core/CL/kernels/CLDepthwiseVectorToTensorKernel.cpp
index 83fc168f45..26336ebf79 100644
--- a/src/core/CL/kernels/CLDepthwiseVectorToTensorKernel.cpp
+++ b/src/core/CL/kernels/CLDepthwiseVectorToTensorKernel.cpp
@@ -34,6 +34,34 @@
using namespace arm_compute;
+namespace
+{
+TensorShape compute_output_shape(const TensorShape &input, size_t conv_w, size_t conv_h)
+{
+ TensorShape output_shape(input);
+ output_shape.set(0, conv_w);
+ output_shape.set(1, conv_h);
+ output_shape.set(2, input.x() / (conv_w * conv_h));
+
+ return output_shape;
+}
+
+Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, size_t conv_w, size_t conv_h)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::S32, DataType::F16, DataType::F32);
+
+ if(output->total_size() != 0)
+ {
+ TensorShape output_shape = compute_output_shape(input->tensor_shape(), conv_w, conv_h);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
+ }
+
+ return Status{};
+}
+} // namespace
+
CLDepthwiseVectorToTensorKernel::CLDepthwiseVectorToTensorKernel()
: _input(nullptr), _output(nullptr)
{
@@ -41,20 +69,13 @@ CLDepthwiseVectorToTensorKernel::CLDepthwiseVectorToTensorKernel()
void CLDepthwiseVectorToTensorKernel::configure(const ICLTensor *input, ICLTensor *output, size_t conv_w, size_t conv_h)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::S32, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_NULLPTR(output);
-
- TensorShape output_shape = input->info()->tensor_shape();
- output_shape.set(0, conv_w);
- output_shape.set(1, conv_h);
- output_shape.set(2, input->info()->tensor_shape()[0] / (conv_w * conv_h));
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
// Output auto inizialitation if not yet initialized
+ TensorShape output_shape = compute_output_shape(input->info()->tensor_shape(), conv_w, conv_h);
auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(output_shape));
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(output->info()->tensor_shape(), output_shape);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), conv_w, conv_h));
_input = input;
_output = output;
@@ -75,6 +96,12 @@ void CLDepthwiseVectorToTensorKernel::configure(const ICLTensor *input, ICLTenso
ICLKernel::configure(win);
}
+Status CLDepthwiseVectorToTensorKernel::validate(const ITensorInfo *input, const ITensorInfo *output, size_t conv_w, size_t conv_h)
+{
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, conv_w, conv_h));
+ return Status{};
+}
+
void CLDepthwiseVectorToTensorKernel::run(const Window &window, cl::CommandQueue &queue)
{
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);