aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--tests/validation/CPP/TopKV.cpp10
1 files changed, 5 insertions, 5 deletions
diff --git a/tests/validation/CPP/TopKV.cpp b/tests/validation/CPP/TopKV.cpp
index ee11cbc54c..02178192aa 100644
--- a/tests/validation/CPP/TopKV.cpp
+++ b/tests/validation/CPP/TopKV.cpp
@@ -129,8 +129,8 @@ TEST_CASE(Float, framework::DatasetMode::ALL)
topkv.run();
// Validate against the expected values
- SimpleTensor<float> expected_output(TensorShape(20), DataType::U8);
- fill_tensor(expected_output, std::vector<uint8_t> { 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1 });
+ SimpleTensor<uint8_t> expected_output(TensorShape(20), DataType::U8);
+ fill_tensor(expected_output, std::vector<uint8_t> { 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0 });
validate(Accessor(output), expected_output);
}
@@ -138,7 +138,7 @@ TEST_CASE(Quantized, framework::DatasetMode::ALL)
{
const unsigned int k = 5;
- Tensor predictions = create_tensor<Tensor>(TensorShape(10, 20), DataType::F32);
+ Tensor predictions = create_tensor<Tensor>(TensorShape(10, 20), DataType::QASYMM8, 1, QuantizationInfo());
Tensor targets = create_tensor<Tensor>(TensorShape(20), DataType::U32);
predictions.allocator()->allocate();
@@ -182,8 +182,8 @@ TEST_CASE(Quantized, framework::DatasetMode::ALL)
topkv.run();
// Validate against the expected values
- SimpleTensor<float> expected_output(TensorShape(20), DataType::U8);
- fill_tensor(expected_output, std::vector<uint8_t> { 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0 });
+ SimpleTensor<uint8_t> expected_output(TensorShape(20), DataType::U8);
+ fill_tensor(expected_output, std::vector<uint8_t> { 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0 });
validate(Accessor(output), expected_output);
}