Add custom gradient #19279

Merged: 2 commits, Mar 11, 2024
4 changes: 4 additions & 0 deletions keras/backend/jax/core.py
@@ -347,6 +347,10 @@ def unstack(x, num=None, axis=0):
]


def custom_gradient(fun):
return jax.custom_gradient(fun=fun)


def device_scope(device_name):
if isinstance(device_name, str):
# We support string value like "cpu:0", "gpu:1", etc.
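
Not part of the diff: for context, a minimal sketch of the `jax.custom_gradient` contract that the wrapper above delegates to. The decorated function returns a pair `(output, grad_fn)`, the same convention documented for `keras.ops.custom_gradient` further down in this PR.

```python
# Illustrative sketch only; mirrors the docstring example in keras/ops/core.py.
import jax
import jax.numpy as jnp


@jax.custom_gradient
def log1pexp(x):
    e = jnp.exp(x)

    def grad(upstream):
        # Numerically stable form of e / (1 + e).
        return upstream * (1.0 - 1.0 / (1.0 + e))

    return jnp.log(1.0 + e), grad


print(jax.grad(log1pexp)(100.0))  # 1.0 -- the naive gradient would be NaN here
```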
6 changes: 6 additions & 0 deletions keras/backend/numpy/core.py
@@ -229,3 +229,9 @@ def stop_gradient(x):
def unstack(x, num=None, axis=0):
x = np.moveaxis(x, axis, 0)
return [x[i] for i in range(x.shape[0])]


def custom_gradient(fun):
raise NotImplementedError(
"`custom_gradient` is not supported with numpy backend"
)
4 changes: 4 additions & 0 deletions keras/backend/tensorflow/core.py
@@ -256,6 +256,10 @@ def unstack(x, num=None, axis=0):
return tf.unstack(x, num=num, axis=axis)


def custom_gradient(fun):
return tf.custom_gradient(f=fun)


class name_scope(base_name_scope):
def __init__(self, name, **kwargs):
super().__init__(name, **kwargs)
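
Similarly, and also not part of the diff, a minimal sketch of the `tf.custom_gradient` primitive this wrapper forwards to, using the same example as the docstring below:

```python
# Illustrative sketch only; shows the (output, grad_fn) contract directly in TF.
import tensorflow as tf


@tf.custom_gradient
def log1pexp(x):
    e = tf.exp(x)

    def grad(upstream):
        # Numerically stable form of e / (1 + e).
        return upstream * (1.0 - 1.0 / (1.0 + e))

    return tf.math.log(1.0 + e), grad


x = tf.constant(100.0)
with tf.GradientTape() as tape:
    tape.watch(x)
    y = log1pexp(x)
print(tape.gradient(y, x))  # tf.Tensor(1.0, ...) -- no NaN
```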
7 changes: 7 additions & 0 deletions keras/backend/torch/core.py
@@ -435,3 +435,10 @@ def stop_gradient(variable):

def unstack(x, num=None, axis=0):
return x.unbind(axis)


def custom_gradient(fun):
# TODO: Support this function
raise NotImplementedError(
"`custom_gradient` is not supported with torch backend"
)
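
A speculative note, not part of this PR: PyTorch exposes custom gradients through `torch.autograd.Function` rather than a decorator that returns `(output, grad_fn)`, which is presumably why a generic wrapper is left as a TODO here. A rough sketch of the torch-native equivalent of the `log1pexp` example, for comparison only:

```python
# Speculative sketch, not from the PR: the torch-native way to express the
# same custom gradient as the log1pexp example used elsewhere in this diff.
import torch


class Log1pExp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        e = torch.exp(x)
        ctx.save_for_backward(e)
        return torch.log(1.0 + e)

    @staticmethod
    def backward(ctx, upstream):
        (e,) = ctx.saved_tensors
        # Numerically stable form of e / (1 + e).
        return upstream * (1.0 - 1.0 / (1.0 + e))


x = torch.tensor(100.0, requires_grad=True)
Log1pExp.apply(x).backward()
print(x.grad)  # tensor(1.) -- no NaN
```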
45 changes: 45 additions & 0 deletions keras/ops/core.py
@@ -11,6 +11,7 @@
convert_to_numpy
cond
is_tensor
custom_gradient
"""

import numpy as np
@@ -623,3 +624,47 @@ def is_tensor(x):
`True` if `x` is a tensor, otherwise `False`.
"""
return backend.core.is_tensor(x)


@keras_export("keras.ops.custom_gradient")
def custom_gradient(f):
"""Decorator to define a function with a custom gradient.

This decorator allows fine-grained control over the gradients of a sequence
of operations. This may be useful for multiple reasons, including providing
a more efficient or numerically stable gradient for a sequence of
operations.

Note that `custom_gradient` is only supported by the TensorFlow and JAX backends.

Args:
f: Function `f(*x)` that returns a tuple `(y, grad_fn)` where:
- `x` is a sequence of (nested structures of) tensor inputs to the
function.
- `y` is a (nested structure of) tensor outputs of applying
operations in `f` to `x`.
- `grad_fn` is a function with the signature `g(*grad_ys)` which
returns a list of tensors the same size as (flattened) `x`: the
derivatives of tensors in `y` with respect to the tensors in
`x`. `grad_ys` is a sequence of tensors the same size as
(flattened) `y` holding the initial value gradients for each
tensor in `y`.

Returns:
A function `h(x)` which returns the same value as `f(x)[0]` and whose
gradient is determined by `f(x)[1]`.

Example:

```python
@ops.custom_gradient
def log1pexp(x):
e = ops.exp(x)

def grad(upstream):
return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))

return ops.log(1 + e), grad
```
"""
return backend.core.custom_gradient(f)
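
To make the numerical-stability motivation above concrete (illustrative only, not part of the diff): the derivative of `log(1 + exp(x))` is `exp(x) / (1 + exp(x))`, which in float32 evaluates to `inf / inf = NaN` for large `x`, while the rewritten form `1 - 1 / (1 + exp(x))` stays finite. A quick numpy check of the two expressions:

```python
# Illustrative numpy check of the two gradient expressions; not part of the PR.
import numpy as np

x = np.float32(100.0)
e = np.exp(x)                    # overflows to inf in float32
print(e / (1.0 + e))             # nan  (inf / inf)
print(1.0 - 1.0 / (1.0 + e))     # 1.0  (well defined)
```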
38 changes: 38 additions & 0 deletions keras/ops/core_test.py
@@ -500,6 +500,44 @@ def test_is_tensor(self):
self.assertTrue(ops.is_tensor(x))
self.assertFalse(ops.is_tensor([1, 2, 3]))

@pytest.mark.skipif(
backend.backend() not in ("tensorflow", "jax"),
reason=f"{backend.backend()} doesn't support `custom_gradient`.",
)
def test_custom_gradient(self):
@ops.custom_gradient
def log1pexp(x):
e = ops.exp(x)

def grad(upstream):
return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))

return ops.log(1 + e), grad

def log1pexp_nan(x):
return ops.log(1 + ops.exp(x))

x = ops.convert_to_tensor(100.0)
if backend.backend() == "tensorflow":
import tensorflow as tf

with tf.GradientTape() as tape1:
tape1.watch(x)
y = log1pexp(x)
with tf.GradientTape() as tape2:
tape2.watch(x)
z = log1pexp_nan(x)
dy_dx = tape1.gradient(y, x)
dz_dx = tape2.gradient(z, x)
elif backend.backend() == "jax":
import jax

dy_dx = jax.grad(log1pexp)(x)
dz_dx = jax.grad(log1pexp_nan)(x)

self.assertEqual(ops.convert_to_numpy(dy_dx), 1.0)
self.assertTrue(ops.isnan(dz_dx))


class CoreOpsDtypeTest(testing.TestCase, parameterized.TestCase):
import jax # enable bfloat16 for numpy