diff options
Diffstat (limited to 'src/armnn/test/TensorTest.cpp')
-rw-r--r-- | src/armnn/test/TensorTest.cpp | 42 |
1 files changed, 40 insertions, 2 deletions
diff --git a/src/armnn/test/TensorTest.cpp b/src/armnn/test/TensorTest.cpp index 1ecad503d4..8d8751f614 100644 --- a/src/armnn/test/TensorTest.cpp +++ b/src/armnn/test/TensorTest.cpp @@ -145,18 +145,56 @@ TEST_CASE("TensorVsConstTensor") const int immutableDatum = 3; armnn::Tensor uninitializedTensor; + uninitializedTensor.GetInfo().SetConstant(true); armnn::ConstTensor uninitializedTensor2; uninitializedTensor2 = uninitializedTensor; - armnn::Tensor t(TensorInfo(), &mutableDatum); - armnn::ConstTensor ct(TensorInfo(), &immutableDatum); + armnn::TensorInfo emptyTensorInfo; + emptyTensorInfo.SetConstant(true); + armnn::Tensor t(emptyTensorInfo, &mutableDatum); + armnn::ConstTensor ct(emptyTensorInfo, &immutableDatum); // Checks that both Tensor and ConstTensor can be passed as a ConstTensor. CheckTensor(t); CheckTensor(ct); } +TEST_CASE("ConstTensor_EmptyConstructorTensorInfoSet") +{ + armnn::ConstTensor t; + CHECK(t.GetInfo().IsConstant() == true); +} + +TEST_CASE("ConstTensor_TensorInfoNotConstantError") +{ + armnn::TensorInfo tensorInfo ({ 1 }, armnn::DataType::Float32); + std::vector<float> tensorData = { 1.0f }; + try + { + armnn::ConstTensor ct(tensorInfo, tensorData); + FAIL("InvalidArgumentException should have been thrown"); + } + catch(const InvalidArgumentException& exc) + { + CHECK(strcmp(exc.what(), "Invalid attempt to construct ConstTensor from non-constant TensorInfo.") == 0); + } +} + +TEST_CASE("PassTensorToConstTensor_TensorInfoNotConstantError") +{ + try + { + armnn::ConstTensor t = ConstTensor(Tensor()); + FAIL("InvalidArgumentException should have been thrown"); + } + catch(const InvalidArgumentException& exc) + { + CHECK(strcmp(exc.what(), "Invalid attempt to construct ConstTensor from " + "Tensor due to non-constant TensorInfo") == 0); + } +} + TEST_CASE("ModifyTensorInfo") { TensorInfo info; |