diff options
Diffstat (limited to 'src/backends/gpuFsa/GpuFsaLayerSupport.cpp')
-rw-r--r-- | src/backends/gpuFsa/GpuFsaLayerSupport.cpp | 28 |
1 files changed, 7 insertions, 21 deletions
diff --git a/src/backends/gpuFsa/GpuFsaLayerSupport.cpp b/src/backends/gpuFsa/GpuFsaLayerSupport.cpp index 56af9c4d68..1ee80c964f 100644 --- a/src/backends/gpuFsa/GpuFsaLayerSupport.cpp +++ b/src/backends/gpuFsa/GpuFsaLayerSupport.cpp @@ -13,8 +13,7 @@ #include "layers/GpuFsaCast.hpp" #include "layers/GpuFsaConvolution2d.hpp" #include "layers/GpuFsaDepthwiseConvolution2d.hpp" -#include "layers/GpuFsaElementwiseBinaryAdd.hpp" -#include "layers/GpuFsaElementwiseBinarySub.hpp" +#include "layers/GpuFsaElementwiseBinary.hpp" #include "layers/GpuFsaPooling2d.hpp" #include "layers/GpuFsaResize.hpp" #endif @@ -150,28 +149,15 @@ bool GpuFsaLayerSupport::IsLayerSupported(const LayerType& type, if (infos.size() != 3) { throw InvalidArgumentException("Invalid number of ElementwiseBinary TensorInfos. " - "TensorInfos should be of format: {input0, input1, output0}."); + "TensorInfos should be of format: {input0, input1, output}."); } auto desc = PolymorphicDowncast<const ElementwiseBinaryDescriptor*>(&descriptor); - if (desc->m_Operation == BinaryOperation::Add) - { - FORWARD_LAYER_VALIDATE_FUNC(GpuFsaElementwiseBinaryAddValidate, - reasonIfUnsupported, - infos[0], - infos[1]); - } - else if (desc->m_Operation == BinaryOperation::Sub) - { - FORWARD_LAYER_VALIDATE_FUNC(GpuFsaElementwiseBinarySubValidate, - reasonIfUnsupported, - infos[0], - infos[1]); - } - else - { - throw InvalidArgumentException("Invalid ElementwiseBinary BinaryOperation operation."); - } + FORWARD_LAYER_VALIDATE_FUNC(GpuFsaElementwiseBinaryValidate, + reasonIfUnsupported, + infos[0], + infos[1], + *desc); } case LayerType::Pooling2d: { |