-
Notifications
You must be signed in to change notification settings - Fork 46
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #113 from tlambert-forks/fft
add reikna-backed fftn, ifftn, and fftshift
- Loading branch information
Showing
9 changed files
with
456 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,5 +6,6 @@ | |
from ._tier5 import * | ||
from ._tier8 import * | ||
from ._tier9 import * | ||
from ._fft import * | ||
|
||
__version__ = "0.10.0" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from ._fft import fftn, ifftn | ||
from ._fftshift import fftshift | ||
from ._fftconvolve import fftconvolve | ||
|
||
# clij2 aliases | ||
convolve_fft = fftconvolve | ||
inverse_fft = ifftn | ||
forward_fft = fftn | ||
fft = fftn | ||
|
||
__all__ = [ | ||
"convolve_fft", | ||
"fft", | ||
"fftconvolve", | ||
"fftn", | ||
"fftshift", | ||
"forward_fft", | ||
"ifftn", | ||
"inverse_fft", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
from typing import Optional, Tuple, Union | ||
|
||
import numpy as np | ||
import reikna.cluda as cluda | ||
from pyopencl.array import Array | ||
from reikna.core import Type | ||
from reikna.fft import FFT | ||
from reikna.transformations import broadcast_const, combine_complex | ||
|
||
from .. import create, get_device, push | ||
|
||
# FFT plan cache | ||
_PLAN_CACHE = {} | ||
|
||
|
||
def _get_fft_plan(arr, axes=None, fast_math=True): | ||
"""Cache and return a reikna FFT plan suitable for `arr` type and shape.""" | ||
axes = _normalize_axes(arr.shape, axes) | ||
plan_key = (arr.shape, arr.dtype, axes, fast_math) | ||
|
||
if plan_key not in _PLAN_CACHE: | ||
if np.iscomplexobj(arr): | ||
plan = FFT(arr, axes=axes) | ||
else: | ||
plan = FFT(Type(cluda.dtypes.complex_for(arr.dtype), arr.shape), axes=axes) | ||
# joins two real inputs into complex output | ||
cc = combine_complex(plan.parameter.input) | ||
# broadcasts 0 to the imaginary component | ||
bc = broadcast_const(cc.imag, 0) | ||
plan.parameter.input.connect( | ||
cc, cc.output, real_input=cc.real, imag_input=cc.imag | ||
) | ||
plan.parameter.imag_input.connect(bc, bc.output) | ||
|
||
thread = cluda.ocl_api().Thread(get_device().queue) | ||
_PLAN_CACHE[plan_key] = plan.compile(thread, fast_math=fast_math) | ||
|
||
return _PLAN_CACHE[plan_key] | ||
|
||
|
||
def fftn( | ||
input_arr: Union[np.ndarray, Array], | ||
output_arr: Union[np.ndarray, Array] = None, | ||
axes: Optional[Tuple[int, ...]] = None, | ||
inplace: bool = False, | ||
fast_math: bool = True, | ||
_inverse: bool = False, | ||
) -> Array: | ||
"""Perform fast Fourier transformation on `input_array`. | ||
Parameters | ||
---------- | ||
input_arr : numpy or OCL array | ||
A numpy or OCL array to transform. If an OCL array is provided, it must already | ||
be of type `complex64`. If a numpy array is provided, it will be converted | ||
to `float32` before the transformation is performed. | ||
output_arr : numpy or OCL array, optional | ||
An optional array/buffer to use for output, by default None | ||
axes : tuple of int, optional | ||
T tuple with axes over which to perform the transform. | ||
If not given, the transform is performed over all the axes., by default None | ||
inplace : bool, optional | ||
Whether to place output data in the `input_arr` buffer, by default False | ||
fast_math : bool, optional | ||
Whether to enable fast (less precise) mathematical operations during | ||
compilation, by default True | ||
_inverse : bool, optional | ||
Perform inverse FFT, by default False. (prefer using `ifftn`) | ||
Returns | ||
------- | ||
OCLArray | ||
result of transformation (still on GPU). Use `.get()` or `cle.pull` | ||
to retrieve from GPU. | ||
If `inplace` or `output_arr` where used, data will also be placed in | ||
the corresponding buffer as a side effect. | ||
Raises | ||
------ | ||
TypeError | ||
If OCL array is provided that is not of type complex64. Or if an unrecognized | ||
array is provided. | ||
ValueError | ||
If inplace is used for numpy array, or both `output_arr` and `inplace` are used. | ||
""" | ||
_isnumpy = isinstance(input_arr, np.ndarray) | ||
if isinstance(input_arr, Array): | ||
if input_arr.dtype != np.complex64: | ||
raise TypeError("OCLArray input_arr has to be of complex64 type") | ||
elif _isnumpy: | ||
if inplace: | ||
raise ValueError("inplace FFT not supported for numpy arrays") | ||
input_arr = input_arr.astype(np.float32, copy=False) | ||
else: | ||
raise TypeError("input_arr must be either OCLArray or np.ndarray") | ||
if output_arr is not None and inplace: | ||
raise ValueError("`output_arr` cannot be provided if `inplace` is True") | ||
|
||
plan = _get_fft_plan(input_arr, axes=axes, fast_math=fast_math) | ||
|
||
if _isnumpy: | ||
arr_dev = push(input_arr) | ||
res_dev = create(arr_dev, np.complex64) | ||
else: | ||
arr_dev = input_arr | ||
if inplace: | ||
res_dev = arr_dev | ||
else: | ||
res_dev = ( | ||
create(arr_dev, np.complex64) if output_arr is None else output_arr | ||
) | ||
|
||
plan(res_dev, arr_dev, inverse=_inverse) | ||
|
||
if _isnumpy and output_arr is not None: | ||
output_arr[:] = res_dev.get() | ||
return res_dev | ||
|
||
|
||
def _normalize_axes(dshape, axes): | ||
"""Convert possibly negative axes to positive axes.""" | ||
if axes is None: | ||
return None | ||
try: | ||
return tuple(np.arange(len(dshape))[list(axes)]) | ||
except Exception as e: | ||
raise TypeError(f"Cannot normalize axes {axes}: {e}") | ||
|
||
|
||
def ifftn( | ||
input_arr, | ||
output_arr=None, | ||
axes=None, | ||
inplace=False, | ||
fast_math=True, | ||
): | ||
"""Perform inverse Fourier transformation on `input_array`. | ||
Parameters | ||
---------- | ||
input_arr : numpy or OCL array | ||
A numpy or OCL array to transform. If an OCL array is provided, it must already | ||
be of type `complex64`. If a numpy array is provided, it will be converted | ||
to `float32` before the transformation is performed. | ||
output_arr : numpy or OCL array, optional | ||
An optional array/buffer to use for output, by default None. | ||
axes : tuple of int, optional | ||
T tuple with axes over which to perform the transform. | ||
If not given, the transform is performed over all the axes., by default None. | ||
inplace : bool, optional | ||
Whether to place output data in the `input_arr` buffer, by default False. | ||
fast_math : bool, optional | ||
Whether to enable fast (less precise) mathematical operations during | ||
compilation, by default True. | ||
Returns | ||
------- | ||
OCLArray | ||
result of transformation (still on GPU). Use `.get()` or `cle.pull` | ||
to retrieve from GPU. | ||
If `inplace` or `output_arr` where used, data will also be placed in | ||
the corresponding buffer as a side effect. | ||
""" | ||
return fftn(input_arr, output_arr, axes, inplace, fast_math, True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
from pyopencl.elementwise import ElementwiseKernel | ||
from .. import get_device | ||
from ._fft import fftn, ifftn | ||
from .._tier0._pycl import OCLArray | ||
import numpy as np | ||
from .. import push, crop | ||
|
||
_mult_complex = ElementwiseKernel( | ||
get_device().context, | ||
"cfloat_t *a, cfloat_t * b", | ||
"a[i] = cfloat_mul(a[i], b[i])", | ||
"mult", | ||
) | ||
|
||
|
||
def _fix_shape(arr, shape): | ||
if arr.shape == shape: | ||
return arr | ||
if isinstance(arr, OCLArray): | ||
# TODO | ||
raise NotImplementedError("Cannot not resize/convert array to complex type..") | ||
# result = OCLArray.empty(shape, arr.dtype) | ||
# set(result, 0) | ||
# paste(arr, result) | ||
if isinstance(arr, np.ndarray): | ||
result = np.zeros(shape, dtype=arr.dtype) | ||
result[tuple(slice(i) for i in arr.shape)] = arr | ||
return result | ||
|
||
|
||
def fftconvolve( | ||
data, | ||
kernel, | ||
mode="same", | ||
axes=None, | ||
output_arr=None, | ||
inplace=False, | ||
kernel_is_fft=False, | ||
): | ||
if mode not in {"valid", "same", "full"}: | ||
raise ValueError("acceptable mode flags are 'valid', 'same', or 'full'") | ||
if data.ndim != kernel.ndim: | ||
raise ValueError("data and kernel should have the same dimensionality") | ||
|
||
# expand arrays | ||
s1 = data.shape | ||
s2 = kernel.shape | ||
axes = tuple(range(len(s1))) if axes is None else tuple(axes) | ||
shape = [ | ||
max((s1[i], s2[i])) if i not in axes else s1[i] + s2[i] - 1 | ||
for i in range(data.ndim) | ||
] | ||
data = _fix_shape(data, shape) | ||
kernel = _fix_shape(kernel, shape) | ||
|
||
if data.shape != kernel.shape: | ||
raise ValueError("in1 and in2 must have the same shape") | ||
|
||
data_g = push(data, np.complex64) | ||
kernel_g = push(kernel, np.complex64) | ||
|
||
if inplace: | ||
output_arr = data_g | ||
else: | ||
if output_arr is None: | ||
output_arr = OCLArray.empty(data_g.shape, data_g.dtype) | ||
output_arr.copy_buffer(data_g) | ||
|
||
if not kernel_is_fft: | ||
kern_g = OCLArray.empty(kernel_g.shape, kernel_g.dtype) | ||
kern_g.copy_buffer(kernel_g) | ||
fftn(kern_g, inplace=True) | ||
else: | ||
kern_g = kernel_g | ||
|
||
fftn(output_arr, inplace=True, axes=axes) | ||
_mult_complex(output_arr, kern_g) | ||
ifftn(output_arr, inplace=True, axes=axes) | ||
|
||
_out = output_arr.real if np.isrealobj(data) else output_arr | ||
|
||
if mode == "full": | ||
return _out | ||
elif mode == "same": | ||
return _crop_centered(_out, s1) | ||
else: | ||
shape_valid = [ | ||
_out.shape[a] if a not in axes else s1[a] - s2[a] + 1 | ||
for a in range(_out.ndim) | ||
] | ||
return _crop_centered(_out, shape_valid) | ||
return _out | ||
|
||
|
||
def _crop_centered(arr, newshape): | ||
# Return the center newshape portion of the array. | ||
newshape = np.asarray(newshape) | ||
currshape = np.array(arr.shape) | ||
startind = (currshape - newshape) // 2 | ||
return crop( | ||
arr, | ||
start_y=startind[0], | ||
start_x=startind[1], | ||
height=newshape[0], | ||
width=newshape[1], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import numpy as np | ||
import reikna.cluda as cluda | ||
from reikna.fft import FFTShift | ||
|
||
from .. import create, get_device, push | ||
|
||
|
||
def fftshift(arr, axes=None, output_arr=None, inplace=False): | ||
shift = FFTShift(arr, axes=axes) | ||
thread = cluda.ocl_api().Thread(get_device().queue) | ||
shiftc = shift.compile(thread) | ||
|
||
arr_dev = push(arr) if isinstance(arr, np.ndarray) else arr | ||
if inplace: | ||
res_dev = arr_dev | ||
else: | ||
res_dev = create(arr_dev) if output_arr is None else output_arr | ||
shiftc(res_dev, arr_dev) | ||
return res_dev |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.