Skip to content

Commit

Permalink
Change image default dtype from float32 to uint8 (#1175)
Browse files Browse the repository at this point in the history
- Currently, there is discrepancy between the return image data types:
`ImageFromBytes.data` (`np.float32`), `ImageFromNumpy.data`
(`np.float32`), and `ImageFromFile.data` (`np.uint8`).
- This makes the data loader based on the Arrow data format (using
`ImageFromBytes.data`) slower since the image preprocessing will be
conducted on the `np.float32` data (4x larger than `np.uint8`).
- This PR forces `np.uint8` data to be returned for all `Image` classes.

Signed-off-by: Kim, Vinnam <[email protected]>
  • Loading branch information
vinnamkim authored Oct 19, 2023
1 parent fd53680 commit 46a97d7
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 18 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Enhancements
- Enhance Datumaro data format stream importer performance
(<https://github.com/openvinotoolkit/datumaro/pull/1153>)
- Change image default dtype from float32 to uint8
(<https://github.com/openvinotoolkit/datumaro/pull/1175>)

### Bug fixes
- Fix errata in the voc document. Color values in the labelmap.txt should be separated by commas, not colons.
Expand Down
14 changes: 6 additions & 8 deletions src/datumaro/components/media.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def __init__(

@property
def data(self) -> Optional[np.ndarray]:
"""Image data in BGRA HWC [0; 255] (float) format"""
"""Image data in BGRA HWC [0; 255] (uint8) format"""

if not self.has_data:
return None
Expand Down Expand Up @@ -375,12 +375,12 @@ def __init__(

@property
def data(self) -> Optional[np.ndarray]:
"""Image data in BGRA HWC [0; 255] (float) format"""
"""Image data in BGRA HWC [0; 255] (uint8) format"""

data = super().data

if isinstance(data, np.ndarray):
data = data.astype(np.float32)
if isinstance(data, np.ndarray) and data.dtype != np.uint8:
data = np.clip(data, 0.0, 255.0).astype(np.uint8)
if self._size is None and data is not None:
if not 2 <= data.ndim <= 3:
raise MediaShapeError("An image should have 2 (gray) or 3 (bgra) dims.")
Expand Down Expand Up @@ -420,14 +420,12 @@ def _guess_ext(cls, data: bytes) -> Optional[str]:

@property
def data(self) -> Optional[np.ndarray]:
"""Image data in BGRA HWC [0; 255] (float) format"""
"""Image data in BGRA HWC [0; 255] (uint8) format"""

data = super().data

if isinstance(data, bytes):
data = decode_image(data)
if isinstance(data, np.ndarray):
data = data.astype(np.float32)
data = decode_image(data, dtype=np.uint8)
if self._size is None and data is not None:
if not 2 <= data.ndim <= 3:
raise MediaShapeError("An image should have 2 (gray) or 3 (bgra) dims.")
Expand Down
3 changes: 0 additions & 3 deletions src/datumaro/plugins/framework_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,6 @@ def __init__(
def __getitem__(self, idx):
image, label = self._gen_item(idx)

if image.dtype == np.uint8 or image.max() > 1:
image = image.astype(np.float32) / 255

if len(image.shape) == 2:
image = np.expand_dims(image, axis=-1)

Expand Down
7 changes: 4 additions & 3 deletions tests/unit/operations/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@

@pytest.fixture
def fxt_image_dataset_expected_mean_std():
np.random.seed(3003)
expected_mean = [100, 50, 150]
expected_std = [20, 50, 10]
expected_std = [2, 1, 3]

return expected_mean, expected_std

Expand Down Expand Up @@ -90,9 +91,9 @@ def test_image_stats(
actual_std = actual["subsets"]["default"]["image std"][::-1]

for em, am in zip(expected_mean, actual_mean):
assert am == pytest.approx(em, 1e-2)
assert am == pytest.approx(em, 5e-1)
for estd, astd in zip(expected_std, actual_std):
assert astd == pytest.approx(estd, 1e-2)
assert astd == pytest.approx(estd, 1e-1)

@mark_requirement(Requirements.DATUM_BUG_873)
def test_invalid_media_type(
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_framework_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,8 @@ def test_can_convert_torch_framework(
label = np.sum(masks, axis=0, dtype=np.uint8)

if fxt_convert_kwargs.get("transform", None):
assert np.array_equal(image, dm_torch_item[0].reshape(5, 5, 3).numpy())
actual = dm_torch_item[0].permute(1, 2, 0).mul(255.0).to(torch.uint8).numpy()
assert np.array_equal(image, actual)
else:
assert np.array_equal(image, dm_torch_item[0])

Expand Down
7 changes: 4 additions & 3 deletions tests/unit/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@
class TestOperations(TestCase):
@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_mean_std(self):
np.random.seed(3000)
expected_mean = [100, 50, 150]
expected_std = [20, 50, 10]
expected_std = [2, 1, 3]

dataset = Dataset.from_iterable(
[
Expand All @@ -62,9 +63,9 @@ def test_mean_std(self):
actual_mean, actual_std = mean_std(dataset)

for em, am in zip(expected_mean, actual_mean):
self.assertAlmostEqual(em, am, places=0)
assert np.allclose(em, am, atol=0.6)
for estd, astd in zip(expected_std, actual_std):
self.assertAlmostEqual(estd, astd, places=0)
assert np.allclose(estd, astd, atol=0.1)

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_stats(self):
Expand Down

0 comments on commit 46a97d7

Please sign in to comment.