From 18a3ae8d2f9999d05b22a38b85b02284cfcdc4ae Mon Sep 17 00:00:00 2001 From: Anirudh Date: Thu, 2 Jan 2025 12:09:50 -0500 Subject: [PATCH] fix: allow model.predict to handle numpy array inputs --- roboflow/models/inference.py | 14 +++++++++++++- tests/models/test_instance_segmentation.py | 20 ++++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/roboflow/models/inference.py b/roboflow/models/inference.py index 7d2f5653..ad8bf391 100644 --- a/roboflow/models/inference.py +++ b/roboflow/models/inference.py @@ -62,7 +62,7 @@ def __get_image_params(self, image_path): Get parameters about an image (i.e. dimensions) for use in an inference request. Args: - image_path (str): path to the image you'd like to perform prediction on + image_path (Union[str, np.ndarray]): path to image or numpy array Returns: Tuple containing a dict of querystring params and a dict of requests kwargs @@ -70,6 +70,18 @@ def __get_image_params(self, image_path): Raises: Exception: Image path is not valid """ + import numpy as np + + if isinstance(image_path, np.ndarray): + # Convert numpy array to PIL Image + image = Image.fromarray(image_path) + dimensions = image.size + image_dims = {"width": str(dimensions[0]), "height": str(dimensions[1])} + buffered = io.BytesIO() + image.save(buffered, quality=90, format="JPEG") + data = MultipartEncoder(fields={"file": ("imageToUpload", buffered.getvalue(), "image/jpeg")}) + return {}, {"data": data, "headers": {"Content-Type": data.content_type}}, image_dims + validate_image_path(image_path) hosted_image = urllib.parse.urlparse(image_path).scheme in ("http", "https") diff --git a/tests/models/test_instance_segmentation.py b/tests/models/test_instance_segmentation.py index 6c1dfc60..a98d9b3a 100644 --- a/tests/models/test_instance_segmentation.py +++ b/tests/models/test_instance_segmentation.py @@ -142,3 +142,23 @@ def test_predict_with_non_200_response_raises_http_error(self): with self.assertRaises(HTTPError): instance.predict(image_path) + + @responses.activate + def test_predict_with_numpy_array(self): + # Create a simple numpy array image + import numpy as np + + image_array = np.zeros((100, 100, 3), dtype=np.uint8) # Create a black image + image_array[30:70, 30:70] = 255 # Add a white square + + instance = InstanceSegmentationModel(self.api_key, self.version_id) + + responses.add(responses.POST, self.api_url, json=MOCK_RESPONSE) + group = instance.predict(image_array) + self.assertIsInstance(group, PredictionGroup) + + request = responses.calls[0].request + self.assertEqual(request.method, "POST") + self.assertRegex(request.url, rf"^{self.api_url}") + self.assertDictEqual(request.params, self._default_params) + self.assertIsNotNone(request.body)