-
Notifications
You must be signed in to change notification settings - Fork 227
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add missing PyTorch/JAX export for logical_or
, logical_and
, and relu
#433
base: master
Are you sure you want to change the base?
Conversation
logical_or
, logical_and
, and relu
logical_or
, logical_and
, and relu
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In order to avoid the following error:
KeyError: 'Function ITE was not found in Torch function mappings.Please add it to extra_torch_mappings in the format, e.g., {sympy.sqrt: torch.sqrt}.'
This function should be added to the mappings:
def if_then_else(*conds):
a, b, c = conds
return torch.where(a, torch.where(b, True, False), torch.where(c, True, False))
extra_torch_mappings = {sympy.logic.boolalg.ITE: if_then_else}
Additionally, lines 87-89 can be replaced with:
output += torch.where(
cond.bool() & ~already_used, expr, torch.zeros_like(expr)
)
already_used = already_used | cond.bool()
as when cond is a float (1.) the code fails without cond.bool()
If you go to the files tab, you can click the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added if_then_else
as ITE
mapping.
Changed cond
to cond.bool()
to avoid error when cond it Float (1.).
Co-authored-by: tbuckworth <[email protected]>
Co-authored-by: tbuckworth <[email protected]>
for more information, see https://pre-commit.ci
Thanks to @j-thib for pointing this out, I didn't realize the SymPy maps were not built-in.
However this gets fairly complicated as we need to map
sympy.Piecewise
into torch/jax code.TODO:
Piecewise
Piecewise