aboutsummaryrefslogtreecommitdiff
path: root/src/backends/gpuFsa/GpuFsaLayerSupport.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/gpuFsa/GpuFsaLayerSupport.cpp')
-rw-r--r--src/backends/gpuFsa/GpuFsaLayerSupport.cpp28
1 files changed, 7 insertions, 21 deletions
diff --git a/src/backends/gpuFsa/GpuFsaLayerSupport.cpp b/src/backends/gpuFsa/GpuFsaLayerSupport.cpp
index 56af9c4d68..1ee80c964f 100644
--- a/src/backends/gpuFsa/GpuFsaLayerSupport.cpp
+++ b/src/backends/gpuFsa/GpuFsaLayerSupport.cpp
@@ -13,8 +13,7 @@
#include "layers/GpuFsaCast.hpp"
#include "layers/GpuFsaConvolution2d.hpp"
#include "layers/GpuFsaDepthwiseConvolution2d.hpp"
-#include "layers/GpuFsaElementwiseBinaryAdd.hpp"
-#include "layers/GpuFsaElementwiseBinarySub.hpp"
+#include "layers/GpuFsaElementwiseBinary.hpp"
#include "layers/GpuFsaPooling2d.hpp"
#include "layers/GpuFsaResize.hpp"
#endif
@@ -150,28 +149,15 @@ bool GpuFsaLayerSupport::IsLayerSupported(const LayerType& type,
if (infos.size() != 3)
{
throw InvalidArgumentException("Invalid number of ElementwiseBinary TensorInfos. "
- "TensorInfos should be of format: {input0, input1, output0}.");
+ "TensorInfos should be of format: {input0, input1, output}.");
}
auto desc = PolymorphicDowncast<const ElementwiseBinaryDescriptor*>(&descriptor);
- if (desc->m_Operation == BinaryOperation::Add)
- {
- FORWARD_LAYER_VALIDATE_FUNC(GpuFsaElementwiseBinaryAddValidate,
- reasonIfUnsupported,
- infos[0],
- infos[1]);
- }
- else if (desc->m_Operation == BinaryOperation::Sub)
- {
- FORWARD_LAYER_VALIDATE_FUNC(GpuFsaElementwiseBinarySubValidate,
- reasonIfUnsupported,
- infos[0],
- infos[1]);
- }
- else
- {
- throw InvalidArgumentException("Invalid ElementwiseBinary BinaryOperation operation.");
- }
+ FORWARD_LAYER_VALIDATE_FUNC(GpuFsaElementwiseBinaryValidate,
+ reasonIfUnsupported,
+ infos[0],
+ infos[1],
+ *desc);
}
case LayerType::Pooling2d:
{