Skip to content

Commit

Permalink
Merge pull request #113 from tlambert-forks/fft
Browse files Browse the repository at this point in the history
add reikna-backed fftn, ifftn, and fftshift
  • Loading branch information
haesleinhuepf authored Sep 3, 2021
2 parents db9009e + 432971c commit 484ac4a
Show file tree
Hide file tree
Showing 9 changed files with 456 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_and_deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ jobs:
python --version
conda install -y pyopencl
python -m pip install --upgrade pip
pip install setuptools wheel pytest pytest-cov pytest-benchmark dask
pip install setuptools wheel pytest pytest-cov pytest-benchmark scipy dask
pip install -e .
- name: Test
shell: bash -l {0}
Expand Down
1 change: 1 addition & 0 deletions pyclesperanto_prototype/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@
from ._tier5 import *
from ._tier8 import *
from ._tier9 import *
from ._fft import *

__version__ = "0.10.0"
20 changes: 20 additions & 0 deletions pyclesperanto_prototype/_fft/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from ._fft import fftn, ifftn
from ._fftshift import fftshift
from ._fftconvolve import fftconvolve

# clij2 aliases
convolve_fft = fftconvolve
inverse_fft = ifftn
forward_fft = fftn
fft = fftn

__all__ = [
"convolve_fft",
"fft",
"fftconvolve",
"fftn",
"fftshift",
"forward_fft",
"ifftn",
"inverse_fft",
]
164 changes: 164 additions & 0 deletions pyclesperanto_prototype/_fft/_fft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
from typing import Optional, Tuple, Union

import numpy as np
import reikna.cluda as cluda
from pyopencl.array import Array
from reikna.core import Type
from reikna.fft import FFT
from reikna.transformations import broadcast_const, combine_complex

from .. import create, get_device, push

# FFT plan cache
_PLAN_CACHE = {}


def _get_fft_plan(arr, axes=None, fast_math=True):
"""Cache and return a reikna FFT plan suitable for `arr` type and shape."""
axes = _normalize_axes(arr.shape, axes)
plan_key = (arr.shape, arr.dtype, axes, fast_math)

if plan_key not in _PLAN_CACHE:
if np.iscomplexobj(arr):
plan = FFT(arr, axes=axes)
else:
plan = FFT(Type(cluda.dtypes.complex_for(arr.dtype), arr.shape), axes=axes)
# joins two real inputs into complex output
cc = combine_complex(plan.parameter.input)
# broadcasts 0 to the imaginary component
bc = broadcast_const(cc.imag, 0)
plan.parameter.input.connect(
cc, cc.output, real_input=cc.real, imag_input=cc.imag
)
plan.parameter.imag_input.connect(bc, bc.output)

thread = cluda.ocl_api().Thread(get_device().queue)
_PLAN_CACHE[plan_key] = plan.compile(thread, fast_math=fast_math)

return _PLAN_CACHE[plan_key]


def fftn(
input_arr: Union[np.ndarray, Array],
output_arr: Union[np.ndarray, Array] = None,
axes: Optional[Tuple[int, ...]] = None,
inplace: bool = False,
fast_math: bool = True,
_inverse: bool = False,
) -> Array:
"""Perform fast Fourier transformation on `input_array`.
Parameters
----------
input_arr : numpy or OCL array
A numpy or OCL array to transform. If an OCL array is provided, it must already
be of type `complex64`. If a numpy array is provided, it will be converted
to `float32` before the transformation is performed.
output_arr : numpy or OCL array, optional
An optional array/buffer to use for output, by default None
axes : tuple of int, optional
T tuple with axes over which to perform the transform.
If not given, the transform is performed over all the axes., by default None
inplace : bool, optional
Whether to place output data in the `input_arr` buffer, by default False
fast_math : bool, optional
Whether to enable fast (less precise) mathematical operations during
compilation, by default True
_inverse : bool, optional
Perform inverse FFT, by default False. (prefer using `ifftn`)
Returns
-------
OCLArray
result of transformation (still on GPU). Use `.get()` or `cle.pull`
to retrieve from GPU.
If `inplace` or `output_arr` where used, data will also be placed in
the corresponding buffer as a side effect.
Raises
------
TypeError
If OCL array is provided that is not of type complex64. Or if an unrecognized
array is provided.
ValueError
If inplace is used for numpy array, or both `output_arr` and `inplace` are used.
"""
_isnumpy = isinstance(input_arr, np.ndarray)
if isinstance(input_arr, Array):
if input_arr.dtype != np.complex64:
raise TypeError("OCLArray input_arr has to be of complex64 type")
elif _isnumpy:
if inplace:
raise ValueError("inplace FFT not supported for numpy arrays")
input_arr = input_arr.astype(np.float32, copy=False)
else:
raise TypeError("input_arr must be either OCLArray or np.ndarray")
if output_arr is not None and inplace:
raise ValueError("`output_arr` cannot be provided if `inplace` is True")

plan = _get_fft_plan(input_arr, axes=axes, fast_math=fast_math)

if _isnumpy:
arr_dev = push(input_arr)
res_dev = create(arr_dev, np.complex64)
else:
arr_dev = input_arr
if inplace:
res_dev = arr_dev
else:
res_dev = (
create(arr_dev, np.complex64) if output_arr is None else output_arr
)

plan(res_dev, arr_dev, inverse=_inverse)

if _isnumpy and output_arr is not None:
output_arr[:] = res_dev.get()
return res_dev


def _normalize_axes(dshape, axes):
"""Convert possibly negative axes to positive axes."""
if axes is None:
return None
try:
return tuple(np.arange(len(dshape))[list(axes)])
except Exception as e:
raise TypeError(f"Cannot normalize axes {axes}: {e}")


def ifftn(
input_arr,
output_arr=None,
axes=None,
inplace=False,
fast_math=True,
):
"""Perform inverse Fourier transformation on `input_array`.
Parameters
----------
input_arr : numpy or OCL array
A numpy or OCL array to transform. If an OCL array is provided, it must already
be of type `complex64`. If a numpy array is provided, it will be converted
to `float32` before the transformation is performed.
output_arr : numpy or OCL array, optional
An optional array/buffer to use for output, by default None.
axes : tuple of int, optional
T tuple with axes over which to perform the transform.
If not given, the transform is performed over all the axes., by default None.
inplace : bool, optional
Whether to place output data in the `input_arr` buffer, by default False.
fast_math : bool, optional
Whether to enable fast (less precise) mathematical operations during
compilation, by default True.
Returns
-------
OCLArray
result of transformation (still on GPU). Use `.get()` or `cle.pull`
to retrieve from GPU.
If `inplace` or `output_arr` where used, data will also be placed in
the corresponding buffer as a side effect.
"""
return fftn(input_arr, output_arr, axes, inplace, fast_math, True)
106 changes: 106 additions & 0 deletions pyclesperanto_prototype/_fft/_fftconvolve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from pyopencl.elementwise import ElementwiseKernel
from .. import get_device
from ._fft import fftn, ifftn
from .._tier0._pycl import OCLArray
import numpy as np
from .. import push, crop

_mult_complex = ElementwiseKernel(
get_device().context,
"cfloat_t *a, cfloat_t * b",
"a[i] = cfloat_mul(a[i], b[i])",
"mult",
)


def _fix_shape(arr, shape):
if arr.shape == shape:
return arr
if isinstance(arr, OCLArray):
# TODO
raise NotImplementedError("Cannot not resize/convert array to complex type..")
# result = OCLArray.empty(shape, arr.dtype)
# set(result, 0)
# paste(arr, result)
if isinstance(arr, np.ndarray):
result = np.zeros(shape, dtype=arr.dtype)
result[tuple(slice(i) for i in arr.shape)] = arr
return result


def fftconvolve(
data,
kernel,
mode="same",
axes=None,
output_arr=None,
inplace=False,
kernel_is_fft=False,
):
if mode not in {"valid", "same", "full"}:
raise ValueError("acceptable mode flags are 'valid', 'same', or 'full'")
if data.ndim != kernel.ndim:
raise ValueError("data and kernel should have the same dimensionality")

# expand arrays
s1 = data.shape
s2 = kernel.shape
axes = tuple(range(len(s1))) if axes is None else tuple(axes)
shape = [
max((s1[i], s2[i])) if i not in axes else s1[i] + s2[i] - 1
for i in range(data.ndim)
]
data = _fix_shape(data, shape)
kernel = _fix_shape(kernel, shape)

if data.shape != kernel.shape:
raise ValueError("in1 and in2 must have the same shape")

data_g = push(data, np.complex64)
kernel_g = push(kernel, np.complex64)

if inplace:
output_arr = data_g
else:
if output_arr is None:
output_arr = OCLArray.empty(data_g.shape, data_g.dtype)
output_arr.copy_buffer(data_g)

if not kernel_is_fft:
kern_g = OCLArray.empty(kernel_g.shape, kernel_g.dtype)
kern_g.copy_buffer(kernel_g)
fftn(kern_g, inplace=True)
else:
kern_g = kernel_g

fftn(output_arr, inplace=True, axes=axes)
_mult_complex(output_arr, kern_g)
ifftn(output_arr, inplace=True, axes=axes)

_out = output_arr.real if np.isrealobj(data) else output_arr

if mode == "full":
return _out
elif mode == "same":
return _crop_centered(_out, s1)
else:
shape_valid = [
_out.shape[a] if a not in axes else s1[a] - s2[a] + 1
for a in range(_out.ndim)
]
return _crop_centered(_out, shape_valid)
return _out


def _crop_centered(arr, newshape):
# Return the center newshape portion of the array.
newshape = np.asarray(newshape)
currshape = np.array(arr.shape)
startind = (currshape - newshape) // 2
return crop(
arr,
start_y=startind[0],
start_x=startind[1],
height=newshape[0],
width=newshape[1],
)
19 changes: 19 additions & 0 deletions pyclesperanto_prototype/_fft/_fftshift.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import numpy as np
import reikna.cluda as cluda
from reikna.fft import FFTShift

from .. import create, get_device, push


def fftshift(arr, axes=None, output_arr=None, inplace=False):
shift = FFTShift(arr, axes=axes)
thread = cluda.ocl_api().Thread(get_device().queue)
shiftc = shift.compile(thread)

arr_dev = push(arr) if isinstance(arr, np.ndarray) else arr
if inplace:
res_dev = arr_dev
else:
res_dev = create(arr_dev) if output_arr is None else output_arr
shiftc(res_dev, arr_dev)
return res_dev
19 changes: 10 additions & 9 deletions pyclesperanto_prototype/_tier0/_push.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,47 +2,48 @@
from ._pycl import OCLArray


def push(any_array):
def push(any_array, dtype=None):
"""Copies an image to GPU memory and returns its handle
.. deprecated:: 0.6.0
`push` behaviour will be changed pyclesperanto_prototype 0.7.0 to do the same as
`push_zyx` because it's faster and having both doing different things is confusing.
Parameters
----------
image : numpy array
dtype : np.dtype, optional
Returns
-------
OCLArray
Examples
--------
>>> import pyclesperanto_prototype as cle
>>> cle.push(image)
References
----------
.. [1] https://clij.github.io/clij2-docs/reference_push
"""
if isinstance(any_array, OCLArray):
return any_array

if isinstance(any_array, list) or isinstance(any_array, tuple):
if isinstance(any_array, (list, tuple)):
any_array = np.asarray(any_array)

if hasattr(any_array, 'shape') and hasattr(any_array, 'dtype') and hasattr(any_array, 'get'):
any_array = np.asarray(any_array.get())

float_arr = any_array.astype(np.float32)
float_arr = any_array.astype(np.float32 if dtype is None else dtype)
return OCLArray.from_array(float_arr)


def push_zyx(any_array):
import warnings

warnings.warn(
"Deprecated: `push_zyx()` is now deprecated as it does the same as `push()`.",
DeprecationWarning
DeprecationWarning,
)
return push(any_array)
return push(any_array)
Loading

0 comments on commit 484ac4a

Please sign in to comment.