-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Subclass PyTorchYolo
and PyTorchDetectionTransformer
off PyTorchObjectDetector
#2321
Subclass PyTorchYolo
and PyTorchDetectionTransformer
off PyTorchObjectDetector
#2321
Conversation
baeec58
to
008679e
Compare
Codecov ReportAttention:
❗ Your organization needs to install the Codecov GitHub app to enable full functionality. Additional details and impacted files@@ Coverage Diff @@
## dev_1.18.0 #2321 +/- ##
==============================================
+ Coverage 85.39% 85.54% +0.14%
==============================================
Files 327 327
Lines 30205 29888 -317
Branches 5589 5528 -61
==============================================
- Hits 25793 25567 -226
+ Misses 2964 2900 -64
+ Partials 1448 1421 -27
|
@@ -247,8 +249,41 @@ def _preprocess_and_convert_inputs( | |||
|
|||
return x_preprocessed, y_preprocessed | |||
|
|||
def _translate_labels(self, labels: List[Dict[str, "torch.Tensor"]]) -> Any: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we define the return type more accurately?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since _translate_labels()
and _translate_predictions()
need to be overridden for each object detector model, this can take many different types (e.g., dictionary for FRCNN, tensor for Yolo, list of tensor for Detr). Replacing the Any
would require a very long Union with every possible type for all subtypes which would need to updated if any new object detector is added. Therefore, it makes the most sense to keep this an Any
type.
labels_translated = [{k: v.to(self.device) for k, v in y_i.items()} for y_i in labels] | ||
return labels_translated | ||
|
||
def _translate_predictions(self, predictions: Any) -> List[Dict[str, np.ndarray]]: # pylint: disable=R0201 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we remove the pylint comment and decorate this function with @staticmethod
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since _translate_predictions
needs to be overridden for all subclasses, it cannot be a static method. For the specific cases of Faster RCNN, it does not use any of the properties, but this is not the case for YOLO and Detr.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @f4str Thank you very much for your pull request! I have added a few questions to get a better understanding of all the changes. Please let me know what you think?
|
||
if isinstance(x, torch.Tensor): | ||
return loss | ||
return loss # type: ignore |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is there a type ignore?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be resolved now
|
||
return loss.detach().cpu().numpy() | ||
return loss.detach().cpu().numpy() # type: ignore |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is there a type ignore?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be resolved now
|
||
|
||
@pytest.fixture() | ||
def get_pytorch_detr(get_default_cifar10_subset): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add a docstring.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
@@ -247,8 +249,41 @@ def _preprocess_and_convert_inputs( | |||
|
|||
return x_preprocessed, y_preprocessed | |||
|
|||
def _translate_labels(self, labels: List[Dict[str, "torch.Tensor"]]) -> Any: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a way to ensure that inheriting classes remember to update the label and precision translations?
self.set_dropout(False) | ||
self.set_multihead_attention(False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this new?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, these are required for PyTorchDetectionTransformer
since it is a subclass of PyTorchObjectDetector
. Since PyTorchObjectDetector
and PyTorchYolo
do not have attention layers, this is just a no-op and does not affect functionality.
from art.utils import load_dataset | ||
from art.estimators.object_detection.pytorch_detection_transformer import PyTorchDetectionTransformer | ||
@pytest.mark.only_with_platform("pytorch") | ||
def test_predict(art_warning, get_pytorch_detr): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why did the expected values for predictions change?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PyTorchFasterRCNN
previously did not support the preprocessing
parameter. This is probably legacy before the original refactor from before as it is fully supported now. Therefore, this test case is no longer applicable.
Signed-off-by: Farhan Ahmed <[email protected]>
Signed-off-by: Farhan Ahmed <[email protected]>
Signed-off-by: Farhan Ahmed <[email protected]>
Signed-off-by: Farhan Ahmed <[email protected]>
Signed-off-by: Farhan Ahmed <[email protected]>
Signed-off-by: Farhan Ahmed <[email protected]>
Signed-off-by: Farhan Ahmed <[email protected]>
Signed-off-by: Farhan Ahmed <[email protected]>
Signed-off-by: Farhan Ahmed <[email protected]>
Signed-off-by: Farhan Ahmed <[email protected]>
Signed-off-by: Farhan Ahmed <[email protected]>
Signed-off-by: Farhan Ahmed <[email protected]>
Signed-off-by: Farhan Ahmed <[email protected]>
Signed-off-by: Farhan Ahmed <[email protected]>
Signed-off-by: Farhan Ahmed <[email protected]>
750bd30
to
400f839
Compare
Signed-off-by: Farhan Ahmed <[email protected]>
Signed-off-by: Farhan Ahmed <[email protected]>
Signed-off-by: Farhan Ahmed <[email protected]>
Signed-off-by: Farhan Ahmed <[email protected]>
Signed-off-by: Farhan Ahmed <[email protected]>
Signed-off-by: Farhan Ahmed <[email protected]>
Signed-off-by: Farhan Ahmed <[email protected]>
Signed-off-by: Farhan Ahmed <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @f4str Thank you very much! The changes look good to me.
Description
Modify both the
PyTorchYolo
andPyTorchDetectionTransformer
estimators to be subclassed off thePyTorchObjectDetector
estimator. This reduces a lot of redundant code and allows most functionality to be shared among the estimators. Additionally, this adds the option to train thePyTorchDetectionTransformer
since the code is inherited accordingly.The
PyTorchObjectSeeker
class was also cleaned up accordingly since there is now one unified superclass for all PyTorch object detectors.Additionally updated the unit tests for all PyTorch object detection estimators to use a fixed input. This reduces randomness and ensures that outputs can remain consistent throughout tests.
Fixes #2267
Type of change
Please check all relevant options.
Testing
Please describe the tests that you ran to verify your changes. Consider listing any relevant details of your test configuration.
PyTorchObjectDetector
PyTorchFasterRCNN
PyTorchYolo
PyTorchDetectionTransformer
Test Configuration:
Checklist