15 struct TensorInfoFixture
19 unsigned int sizes[] = {6,7,8,9};
20 m_TensorInfo =
TensorInfo(4, sizes, DataType::Float32);
22 ~TensorInfoFixture() {};
30 CHECK(listInitializedShape == m_TensorInfo.
GetShape());
36 CHECK(m_TensorInfo.
GetShape()[0] == 6);
37 CHECK(m_TensorInfo.
GetShape()[1] == 7);
38 CHECK(m_TensorInfo.
GetShape()[2] == 8);
39 CHECK(m_TensorInfo.
GetShape()[3] == 9);
45 CHECK(copyConstructed.GetNumDimensions() == 4);
46 CHECK(copyConstructed.GetShape()[0] == 6);
47 CHECK(copyConstructed.GetShape()[1] == 7);
48 CHECK(copyConstructed.GetShape()[2] == 8);
49 CHECK(copyConstructed.GetShape()[3] == 9);
55 CHECK(copyConstructed == m_TensorInfo);
61 unsigned int sizes[] = {2,3,4,5};
62 other =
TensorInfo(4, sizes, DataType::Float32);
64 CHECK(other != m_TensorInfo);
71 CHECK(copy == m_TensorInfo);
74 TEST_CASE(
"CopyNoQuantizationTensorInfo")
93 CHECK(infoA != infoB);
95 CHECK(infoA == infoB);
104 TEST_CASE(
"CopyDifferentQuantizationTensorInfo")
121 CHECK((infoA.
GetDataType() == DataType::QAsymmU8));
126 CHECK(infoA != infoB);
128 CHECK(infoA == infoB);
131 CHECK((infoA.
GetDataType() == DataType::QAsymmU8));
142 TEST_CASE(
"TensorVsConstTensor")
144 int mutableDatum = 2;
145 const int immutableDatum = 3;
151 uninitializedTensor2 = uninitializedTensor;
163 TEST_CASE(
"ConstTensor_EmptyConstructorTensorInfoSet")
169 TEST_CASE(
"ConstTensor_TensorInfoNotConstantError")
172 std::vector<float> tensorData = { 1.0f };
176 FAIL(
"InvalidArgumentException should have been thrown");
180 CHECK(strcmp(exc.
what(),
"Invalid attempt to construct ConstTensor from non-constant TensorInfo.") == 0);
184 TEST_CASE(
"PassTensorToConstTensor_TensorInfoNotConstantError")
189 FAIL(
"InvalidArgumentException should have been thrown");
193 CHECK(strcmp(exc.
what(),
"Invalid attempt to construct ConstTensor from " 194 "Tensor due to non-constant TensorInfo") == 0);
198 TEST_CASE(
"ModifyTensorInfo")
211 TEST_CASE(
"TensorShapeOperatorBrackets")
217 CHECK(shape[2] == 2);
219 CHECK(shape[2] == 20);
222 CHECK(constShape[2] == 2);
225 TEST_CASE(
"TensorInfoPerAxisQuantization")
228 TensorInfo tensorInfo0({ 1, 1 }, DataType::Float32, 2.0f, 1);
229 CHECK(!tensorInfo0.HasMultipleQuantizationScales());
230 CHECK(tensorInfo0.GetQuantizationScale() == 2.0f);
231 CHECK(tensorInfo0.GetQuantizationOffset() == 1);
232 CHECK(tensorInfo0.GetQuantizationScales()[0] == 2.0f);
233 CHECK(!tensorInfo0.GetQuantizationDim().has_value());
236 std::vector<float> perAxisScales{ 3.0f, 4.0f };
238 CHECK(tensorInfo0.HasMultipleQuantizationScales());
239 CHECK(tensorInfo0.GetQuantizationScales() == perAxisScales);
242 tensorInfo0.SetQuantizationScale(5.0f);
243 CHECK(!tensorInfo0.HasMultipleQuantizationScales());
244 CHECK(tensorInfo0.GetQuantizationScales()[0] == 5.0f);
248 CHECK(tensorInfo0.GetQuantizationDim().value() == 1);
251 perAxisScales = { 6.0f, 7.0f };
252 TensorInfo tensorInfo1({ 1, 1 }, DataType::Float32, perAxisScales, 1);
253 CHECK(tensorInfo1.HasMultipleQuantizationScales());
254 CHECK(tensorInfo1.GetQuantizationOffset() == 0);
255 CHECK(tensorInfo1.GetQuantizationScales() == perAxisScales);
256 CHECK(tensorInfo1.GetQuantizationDim().value() == 1);
259 TEST_CASE(
"TensorShape_scalar")
261 float mutableDatum = 3.1416f;
268 float scalarValue = *
reinterpret_cast<float*
>(tensor.GetMemoryArea());
269 CHECK_MESSAGE(mutableDatum == scalarValue,
"Scalar value is " << scalarValue);
274 CHECK(shape_equal == shape);
275 CHECK(shape_different != shape);
276 CHECK_MESSAGE(1 == shape.GetNumElements(),
"Number of elements is " << shape.GetNumElements());
277 CHECK_MESSAGE(1 == shape.GetNumDimensions(),
"Number of dimensions is " << shape.GetNumDimensions());
278 CHECK(
true == shape.GetDimensionSpecificity(0));
279 CHECK(shape.AreAllDimensionsSpecified());
280 CHECK(shape.IsAtLeastOneDimensionSpecified());
282 CHECK(1 == shape[0]);
283 CHECK(1 == tensor.GetShape()[0]);
284 CHECK(1 == tensor.GetInfo().GetShape()[0]);
287 float newMutableDatum = 42.f;
288 std::memcpy(tensor.GetMemoryArea(), &newMutableDatum,
sizeof(float));
289 scalarValue = *
reinterpret_cast<float*
>(tensor.GetMemoryArea());
290 CHECK_MESSAGE(newMutableDatum == scalarValue,
"Scalar value is " << scalarValue);
293 TEST_CASE(
"TensorShape_DynamicTensorType1_unknownNumberDimensions")
295 float mutableDatum = 3.1416f;
309 CHECK(shape_equal == shape);
310 CHECK(shape_different != shape);
313 TEST_CASE(
"TensorShape_DynamicTensorType1_unknownAllDimensionsSizes")
315 float mutableDatum = 3.1416f;
322 CHECK_MESSAGE(0 == shape.GetNumElements(),
"Number of elements is " << shape.GetNumElements());
323 CHECK_MESSAGE(3 == shape.GetNumDimensions(),
"Number of dimensions is " << shape.GetNumDimensions());
324 CHECK(
false == shape.GetDimensionSpecificity(0));
325 CHECK(
false == shape.GetDimensionSpecificity(1));
326 CHECK(
false == shape.GetDimensionSpecificity(2));
327 CHECK(!shape.AreAllDimensionsSpecified());
328 CHECK(!shape.IsAtLeastOneDimensionSpecified());
333 CHECK(shape_equal == shape);
334 CHECK(shape_different != shape);
337 TEST_CASE(
"TensorShape_DynamicTensorType1_unknownSomeDimensionsSizes")
339 std::vector<float> mutableDatum { 42.f, 42.f, 42.f,
347 CHECK_MESSAGE(6 == shape.GetNumElements(),
"Number of elements is " << shape.GetNumElements());
348 CHECK_MESSAGE(3 == shape.GetNumDimensions(),
"Number of dimensions is " << shape.GetNumDimensions());
349 CHECK(
true == shape.GetDimensionSpecificity(0));
350 CHECK(
false == shape.GetDimensionSpecificity(1));
351 CHECK(
true == shape.GetDimensionSpecificity(2));
352 CHECK(!shape.AreAllDimensionsSpecified());
353 CHECK(shape.IsAtLeastOneDimensionSpecified());
359 CHECK(2 == shape[0]);
360 CHECK(2 == tensor.GetShape()[0]);
361 CHECK(2 == tensor.GetInfo().GetShape()[0]);
364 CHECK(3 == shape[2]);
365 CHECK(3 == tensor.GetShape()[2]);
366 CHECK(3 == tensor.GetInfo().GetShape()[2]);
371 CHECK(shape_equal == shape);
372 CHECK(shape_different != shape);
375 TEST_CASE(
"TensorShape_DynamicTensorType1_transitionFromUnknownToKnownDimensionsSizes")
377 std::vector<float> mutableDatum { 42.f, 42.f, 42.f,
385 shape.SetNumDimensions(3);
387 CHECK_MESSAGE(3 == shape.GetNumDimensions(),
"Number of dimensions is " << shape.GetNumDimensions());
388 CHECK(
false == shape.GetDimensionSpecificity(0));
389 CHECK(
false == shape.GetDimensionSpecificity(1));
390 CHECK(
false == shape.GetDimensionSpecificity(2));
391 CHECK(!shape.AreAllDimensionsSpecified());
392 CHECK(!shape.IsAtLeastOneDimensionSpecified());
395 shape.SetDimensionSize(0, 2);
396 shape.SetDimensionSize(2, 3);
397 CHECK_MESSAGE(3 == shape.GetNumDimensions(),
"Number of dimensions is " << shape.GetNumDimensions());
398 CHECK_MESSAGE(6 == shape.GetNumElements(),
"Number of elements is " << shape.GetNumElements());
399 CHECK(
true == shape.GetDimensionSpecificity(0));
400 CHECK(
false == shape.GetDimensionSpecificity(1));
401 CHECK(
true == shape.GetDimensionSpecificity(2));
402 CHECK(!shape.AreAllDimensionsSpecified());
403 CHECK(shape.IsAtLeastOneDimensionSpecified());
407 CHECK(2 == shape[0]);
408 CHECK(2 == tensor2.GetShape()[0]);
409 CHECK(2 == tensor2.GetInfo().GetShape()[0]);
415 CHECK(3 == shape[2]);
416 CHECK(3 == tensor2.GetShape()[2]);
417 CHECK(3 == tensor2.GetInfo().GetShape()[2]);
422 CHECK(shape_equal == shape);
423 CHECK(shape_different != shape);
426 shape.SetDimensionSize(1, 5);
427 CHECK_MESSAGE(3 == shape.GetNumDimensions(),
"Number of dimensions is " << shape.GetNumDimensions());
428 CHECK_MESSAGE(30 == shape.GetNumElements(),
"Number of elements is " << shape.GetNumElements());
429 CHECK(
true == shape.GetDimensionSpecificity(0));
430 CHECK(
true == shape.GetDimensionSpecificity(1));
431 CHECK(
true == shape.GetDimensionSpecificity(2));
432 CHECK(shape.AreAllDimensionsSpecified());
433 CHECK(shape.IsAtLeastOneDimensionSpecified());
436 TEST_CASE(
"Tensor_emptyConstructors")
439 CHECK_MESSAGE( 0 == shape.GetNumDimensions(),
"Number of dimensions is " << shape.GetNumDimensions());
440 CHECK_MESSAGE( 0 == shape.GetNumElements(),
"Number of elements is " << shape.GetNumElements());
442 CHECK( shape.AreAllDimensionsSpecified());
446 CHECK_MESSAGE( 0 == tensor.GetNumDimensions(),
"Number of dimensions is " << tensor.GetNumDimensions());
447 CHECK_MESSAGE( 0 == tensor.GetNumElements(),
"Number of elements is " << tensor.GetNumElements());
448 CHECK_MESSAGE( 0 == tensor.GetShape().GetNumDimensions(),
"Number of dimensions is " <<
449 tensor.GetShape().GetNumDimensions());
450 CHECK_MESSAGE( 0 == tensor.GetShape().GetNumElements(),
"Number of dimensions is " <<
451 tensor.GetShape().GetNumElements());
453 CHECK( tensor.GetShape().AreAllDimensionsSpecified());
const TensorShape & GetShape() const
Optional< unsigned int > GetQuantizationDim() const
virtual const char * what() const noexcept override
void SetShape(const TensorShape &newShape)
A tensor defined by a TensorInfo (shape and data type) and a mutable backing store.
TEST_CASE_FIXTURE(ClContextControlFixture, "CopyBetweenNeonAndGpu")
int32_t GetQuantizationOffset() const
float GetQuantizationScale() const
DataType GetDataType() const
bool has_value() const noexcept
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
void SetQuantizationScale(float scale)
const TensorInfo & GetInfo() const
void SetDataType(DataType type)
void SetQuantizationDim(const Optional< unsigned int > &quantizationDim)
void SetConstant(const bool IsConstant=true)
Marks the data corresponding to this tensor info as constant.
void SetQuantizationOffset(int32_t offset)
void SetQuantizationScales(const std::vector< float > &scales)
unsigned int GetNumDimensions() const