-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add TAXPoseD #18
base: main
Are you sure you want to change the base?
Add TAXPoseD #18
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good. try running the unit tests and formatter
pytest
and
black
(forget how this works...)
third_party/dcp/model.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this file looks kind of weird, no? why is the diff so scattered.... any way to make the diff clearer?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure, will move these changes into taxpose/nets which should also have a cleaner diff
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Moved these changes out of third_party/... and into a new file in taxpose/nets/taxposed_dgcnn.py, maybe a better file name exists, could also probably merge the other DGCNN implementations (DGCNN_GC, VN_DGCNN) into a single custom dgcnns file in the future.
third_party/dcp/model.py
Outdated
@@ -280,18 +281,100 @@ def forward(self, x): | |||
|
|||
|
|||
class DGCNN(nn.Module): | |||
def __init__(self, emb_dims=512): | |||
def __init__(self, emb_dims=512, input_dims=3, num_heads=1, conditioning_size=0, last_relu=True): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
two thoughts:
-
In general, I'm trying to keep this file as close as possible to the original code we took from the DCP paper. In its current state it's a true third-party import, and changing it kind of contaminates that. If you want to add additional models, put them in taxpose/nets, and mention at the top that it's derived from this file. I know it's a bit clunky, but I really want to make sure we keep track of various low-level modifications to networks or code we borrow.
-
What is the motivation behind teh changes to DGCNN here? is it to just make it conditional? in my mind that's a pretty distinct (custom) model that should live in taxpose/nets. Are there other logical changes? i.e. last_relu...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- Makes complete sense, will move these changes into a new class in taxpose/nets
- Motivation is primarily to handle conditioning, either through an extra channels in the input or added to the later conv layers. Additionally, supports having multiple heads when TAXPoseD vae encoder needs to predict a normal distribution.
taxpose/utils/se3.py
Outdated
) | ||
constrained_axix_angle = rot_ratio * axis_angle_random # max angle is rot_var | ||
R = axis_angle_to_matrix(constrained_axix_angle) | ||
def random_se3(N, rot_var=np.pi/180 * 5, trans_var=0.1, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: this won't pass black formatting
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
reverted this change as i realized it was unnecessary
taxpose/nets/transformer_flow.py
Outdated
if emb_nn == "dgcnn": | ||
self.emb_nn_action = DGCNN(emb_dims=self.emb_dims) | ||
self.emb_nn_anchor = DGCNN(emb_dims=self.emb_dims) | ||
self.emb_nn_action = DGCNN(emb_dims=self.emb_dims, input_dims=self.input_dims, conditioning_size=self.conditioning_size) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not sure if i'd prefer a different case here - like c-dgcnn or something.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Understood, will add a new case for this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added a case for "cond_dgcnn"
action_points = input[0].permute(0, 2, 1)[:, :self.input_dims] # B,3,num_points | ||
anchor_points = input[1].permute(0, 2, 1)[:, :self.input_dims] | ||
|
||
if action_center is None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is action_center / anchor_center here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
taxposed vae encoder will output distributions over the action/anchor point clouds which we sample a point from each to encode the demo. Jenny's found some benefit to centering the point clouds about the sampled points instead of mean, so on the taxposed side we pass in either the mean or sampled points for action/anchor point cloud centering as action_center/anchor_center
@@ -606,8 +629,8 @@ def forward(self, *input): | |||
head_action_output = self.head_action( | |||
action_embedding_tf, | |||
anchor_embedding_tf, | |||
action_points, | |||
anchor_points, | |||
action_points[:, :3, :], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what gets reshaped to make this necessary?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
taxposed action_points/anchor_points will be [Batch size, XYZ + Conditioning size, Num. points], we want the full size inputs for the dgcnn embeddings but for the heads it seems we only need XYZ, additionally seems necessary for dimensions to align.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is fine, totally isolated from other code in the repo.
softmax_temperature=cfg.task.phase.softmax_temperature, | ||
flow_supervision=cfg.training.flow_supervision, | ||
) | ||
if cfg.model.name in ["taxposed", "taxposed_mlat_s100", "taxposed_mlat_s256"]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
might want to extract this conditional into a function above. just so that it's clear what's actually being set up in the main function. (i.e. a network, a model, a dataset, and training).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
created functions for network creation, model creation, and maybe loading weights so main function is clear
Seems pretty cool - can you post links to some training runs? Curious to see if the mlat stuff worked as well! also in general, the more you can isolate changes for these codepaths the better (which is why i commented only on files that touch normal taxpose training). I'm a bit paranoid about logic/structural changes to existing code which is known to work w/o running full training + eval on a few tasks. So the less that happens, the easier it is to say that nothing meaningful has changed about the code. But in general I'm really impressed with how quickly you got this going. The config complexity alone is annoying enough as it is, so good work figuring that out so fast! |
Also, can you put some example commands for how to train this thing in scripts/README.md? You might even try to get this running on the cluster o.O |
Updated code with black formatting and ran pytests, only cases that fail are due to missing checkpoint configs, mainly for evaluations and loading pretrained weights. Ran the first training stage for TAXPoseD on the RLBench tasks with and without mlat: https://wandb.ai/r-pad/taxposed |
Adding back NDF dataset support with some feature additions
…s, add taxposed launch commands to readme
f4e7f48
to
062ffb0
Compare
Add TAXPoseD with support for base TAXPose decoder and RLBench training.
Only changes that should possibly affect non TAXPoseD code are the changes to scripts/train_residual_flow.py, taxpose/nets/transformer_flow.py, and third_party/dcp/model.
Currently using a separate base config file (configs/train_ndf_multimodal.yaml), let me know if you would prefer a solution that doesn't need this.