diff options
Diffstat (limited to 'src/backends/tosaReference')
3 files changed, 88 insertions, 0 deletions
diff --git a/src/backends/tosaReference/TosaRefLayerSupport.cpp b/src/backends/tosaReference/TosaRefLayerSupport.cpp index 5cda85af20..daa27f63dc 100644 --- a/src/backends/tosaReference/TosaRefLayerSupport.cpp +++ b/src/backends/tosaReference/TosaRefLayerSupport.cpp @@ -303,6 +303,24 @@ static bool IsTosaLayerSupported(TosaSerializationOperator* op, return RunTosaLayerChecksSingleDataType( op, inputs, outputs, supportedAttributes, supportedTypes, reasonIfUnsupported); } + case tosa::Op_SLICE: + { + std::vector<Attribute> supportedAttributes = { Attribute_SliceAttribute }; + + std::vector<DType> supportedTypes = + { + DType_FP16, + DType_FP32, + DType_INT8, + DType_INT16, + DType_INT32, + DType_BOOL + }; + + // Check the attribute, data types and bounds for inputs and outputs. + return RunTosaLayerChecksSingleDataType( + op, inputs, outputs, supportedAttributes, supportedTypes, reasonIfUnsupported); + } default: SetValueChecked(reasonIfUnsupported, "Operation is currently unsupported by the TOSA Reference Backend."); return false; @@ -351,6 +369,7 @@ bool TosaRefLayerSupport::IsLayerSupported(const LayerType& type, } case LayerType::Pooling2d: case LayerType::Reshape: + case LayerType::Slice: // Setup inputs and outputs inputInfos.push_back(&infos[0]); outputInfos.push_back(&infos[1]); diff --git a/src/backends/tosaReference/test/TosaRefEndToEndTests.cpp b/src/backends/tosaReference/test/TosaRefEndToEndTests.cpp index aaf8a678e3..2f1231013a 100644 --- a/src/backends/tosaReference/test/TosaRefEndToEndTests.cpp +++ b/src/backends/tosaReference/test/TosaRefEndToEndTests.cpp @@ -9,6 +9,7 @@ #include "backendsCommon/test/Convolution2dEndToEndTestImpl.hpp" #include "backendsCommon/test/Pooling2dEndToEndTestImpl.hpp" #include "backendsCommon/test/ReshapeEndToEndTestImpl.hpp" +#include "backendsCommon/test/SliceEndToEndTestImpl.hpp" #include <doctest/doctest.h> @@ -91,4 +92,20 @@ TEST_CASE("TosaRefReshapeEndtoEndTestFloat16") ReshapeEndToEndFloat16<DataType::Float16>(tosaDefaultBackends); } +// Slice +TEST_CASE("TosaRefSliceEndtoEndTestFloat32") +{ + SliceEndToEnd<DataType::Float32>(tosaDefaultBackends); +} + +TEST_CASE("TosaRefSliceEndtoEndTestInt32") +{ + SliceEndToEnd<DataType::Signed32>(tosaDefaultBackends); +} + +TEST_CASE("TosaRefSliceEndtoEndTestFloat16") +{ + SliceEndToEndFloat16<DataType::Float16>(tosaDefaultBackends); +} + }
\ No newline at end of file diff --git a/src/backends/tosaReference/test/TosaRefLayerSupportTests.cpp b/src/backends/tosaReference/test/TosaRefLayerSupportTests.cpp index 86b01d8d0c..a1bab83e72 100644 --- a/src/backends/tosaReference/test/TosaRefLayerSupportTests.cpp +++ b/src/backends/tosaReference/test/TosaRefLayerSupportTests.cpp @@ -329,4 +329,56 @@ TEST_CASE("IsLayerSupportedTosaReferenceReshapeUnsupported") "has an unsupported data type: DType_UNKNOWN") != std::string::npos); } +TEST_CASE("IsLayerSupportedTosaReferenceSlice") +{ + TensorShape inShape = {3,2,3}; + TensorShape outShape = {2,1,3}; + TensorInfo in(inShape, DataType::Float32); + TensorInfo out(outShape, DataType::Float32); + + SliceDescriptor descriptor; + descriptor.m_Begin = {1,0,0 }; + descriptor.m_Size = {2,1,3 }; + + TosaRefLayerSupport supportChecker; + std::string reasonIfNotSupported; + auto supported = supportChecker.IsLayerSupported(LayerType::Slice, + {in, out}, + descriptor, + EmptyOptional(), + EmptyOptional(), + reasonIfNotSupported); + + CHECK(supported); +} + +TEST_CASE("IsLayerSupportedTosaReferenceSliceUnsupported") +{ + TensorShape inShape = {3,2,3}; + TensorShape outShape = {2,1,3}; + TensorInfo in(inShape, DataType::Signed64); + TensorInfo out(outShape, DataType::Signed64); + + SliceDescriptor descriptor; + descriptor.m_Begin = {1,0,0}; + descriptor.m_Size = {2,1,3}; + + TosaRefLayerSupport supportChecker; + std::string reasonIfNotSupported; + auto supported = supportChecker.IsLayerSupported(LayerType::Slice, + {in, out}, + descriptor, + EmptyOptional(), + EmptyOptional(), + reasonIfNotSupported); + + CHECK(!supported); + REQUIRE(reasonIfNotSupported.find( + "TOSA Reference Operator: Op_SLICE for input: input0_") != std::string::npos); + REQUIRE(reasonIfNotSupported.find( + "TOSA Reference Operator: Op_SLICE for output: output0_") != std::string::npos); + REQUIRE(reasonIfNotSupported.find( + "has an unsupported data type: DType_UNKNOWN") != std::string::npos); +} + } |