Merge pull request #69 from arnaudvl/sr
SR method
arnaudvl authored Nov 25, 2019
2 parents 7c6ae6d + 5f1a522 commit e39bc70
Showing 14 changed files with 1,038 additions and 24 deletions.
40 changes: 24 additions & 16 deletions README.md
@@ -33,6 +33,19 @@ This will install `alibi-detect` with all its dependencies:

### Outlier Detection

The following table shows the advised use cases for each algorithm. The column *Feature Level* indicates whether the outlier scoring and detection can be done and returned at the feature level, e.g. per pixel for an image:

| Detector | Tabular | Image | Time Series | Text | Categorical Features | Online | Feature Level |
| :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Isolation Forest | ✔ | | | | ✔ | | |
| Mahalanobis Distance | ✔ | | | | ✔ | ✔ | |
| VAE | ✔ | ✔ | | | | | ✔ |
| AEGMM | ✔ | ✔ | | | | | |
| VAEGMM | ✔ | ✔ | | | | | |
| Prophet | | | ✔ | | | | |
| Spectral Residual | | | ✔ | | | ✔ | ✔ |


- Isolation Forest ([FT Liu et al., 2008](https://cs.nju.edu.cn/zhouzh/zhouzh.files/publication/icdm08b.pdf))
- [Documentation](https://docs.seldon.io/projects/alibi-detect/en/latest/methods/iforest.html)
- Examples:
@@ -62,33 +75,28 @@ This will install `alibi-detect` with all its dependencies:
- [Documentation](https://docs.seldon.io/projects/alibi-detect/en/latest/methods/prophet.html)
- Examples:
[Weather Forecast](https://docs.seldon.io/projects/alibi-detect/en/latest/examples/od_prophet_weather.html)

- Spectral Residual Time Series Outlier Detector ([Ren et al., 2019](https://arxiv.org/abs/1906.03821))
- [Documentation](https://docs.seldon.io/projects/alibi-detect/en/latest/methods/sr.html)
- Examples:
[Synthetic Dataset](https://docs.seldon.io/projects/alibi-detect/en/latest/examples/od_sr_synth.html)
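
A minimal usage sketch of the new detector (the synthetic series and parameter values mirror `alibi_detect/od/tests/test_sr.py` below; the threshold of 2.5 is illustrative, not tuned):

```python
import numpy as np
from alibi_detect.od import SpectralResidual

# univariate signal with a few injected spikes
t = np.linspace(0, 0.5, 1000)
X = np.sin(40 * 2 * np.pi * t) + 0.5 * np.sin(90 * 2 * np.pi * t)
X[np.random.randint(0, 1000, 10)] = 50.

od = SpectralResidual(threshold=2.5,    # relative saliency distance above which a point is flagged
                      window_amp=20,    # window for the average log amplitude
                      window_local=20,  # window for the local average of the saliency map
                      n_est_points=10)  # estimated points padded to the end of the sequence
preds = od.predict(X, t, return_instance_score=True)
print(preds['data']['is_outlier'].sum())  # count of flagged points
```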


### Adversarial Detection

Advised use cases:

| Detector | Tabular | Image | Time Series | Text | Categorical Features | Online | Feature Level |
| :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Adversarial VAE | ✔ | ✔ | | | | | |


- Adversarial Variational Auto-Encoder (paper coming soon)
- [Documentation](https://docs.seldon.io/projects/alibi-detect/en/latest/methods/adversarialvae.html)
- Examples:
[MNIST](https://docs.seldon.io/projects/alibi-detect/en/latest/examples/ad_advvae_mnist.html)


## Integrations

The integrations folder contains various wrapper tools to allow the alibi-detect algorithms to be used in production machine learning systems with [examples](https://github.com/SeldonIO/alibi-detect/tree/master/integrations/samples/kfserving) on how to deploy outlier and adversarial detectors with [KFServing](https://www.kubeflow.org/docs/components/serving/kfserving/).
4 changes: 3 additions & 1 deletion alibi_detect/od/__init__.py
@@ -4,12 +4,14 @@
from .vae import OutlierVAE
from .vaegmm import OutlierVAEGMM
from .prophet import OutlierProphet
from .sr import SpectralResidual

__all__ = [
    "OutlierAEGMM",
    "IForest",
    "Mahalanobis",
    "OutlierVAE",
    "OutlierVAEGMM",
    "OutlierProphet",
    "SpectralResidual"
]
215 changes: 215 additions & 0 deletions alibi_detect/od/sr.py
@@ -0,0 +1,215 @@
import logging
import numpy as np
from typing import Dict
from alibi_detect.base import BaseDetector, ThresholdMixin, outlier_prediction_dict

logger = logging.getLogger(__name__)

EPSILON = 1e-8


class SpectralResidual(BaseDetector, ThresholdMixin):

    def __init__(self,
                 threshold: float = None,
                 window_amp: int = None,
                 window_local: int = None,
                 n_est_points: int = None,
                 n_grad_points: int = 5,
                 ) -> None:
        """
        Outlier detector for time-series data using the spectral residual algorithm.
        Based on "Time-Series Anomaly Detection Service at Microsoft" (Ren et al., 2019)
        https://arxiv.org/abs/1906.03821

        Parameters
        ----------
        threshold
            Threshold used to classify outliers. Relative saliency map distance from the moving average.
        window_amp
            Window for the average log amplitude.
        window_local
            Window for the local average of the saliency map.
        n_est_points
            Number of estimated points padded to the end of the sequence.
        n_grad_points
            Number of points used for the gradient estimation of the additional points padded
            to the end of the sequence.
        """
        super().__init__()

        if threshold is None:
            logger.warning('No threshold level set. Need to infer threshold using `infer_threshold`.')

        self.threshold = threshold
        self.window_amp = window_amp
        self.window_local = window_local
        # moving average filters for the log amplitude spectrum and the saliency map
        self.conv_amp = np.ones((1, window_amp)).reshape(-1,) / window_amp
        self.conv_local = np.ones((1, window_local)).reshape(-1,) / window_local
        self.n_est_points = n_est_points
        self.n_grad_points = n_grad_points

        # set metadata
        self.meta['detector_type'] = 'online'
        self.meta['data_type'] = 'time-series'

    def infer_threshold(self,
                        X: np.ndarray,
                        t: np.ndarray = None,
                        threshold_perc: float = 95.
                        ) -> None:
        """
        Update threshold by a value inferred from the percentage of instances considered to be
        outliers in a sample of the dataset.

        Parameters
        ----------
        X
            Batch of instances.
        t
            Time steps. Defaults to an integer range over the instances.
        threshold_perc
            Percentage of X considered to be normal based on the outlier score.
        """
        if t is None:
            t = np.arange(X.shape[0])

        # compute outlier scores
        iscore = self.score(X, t)

        # update threshold
        self.threshold = np.percentile(iscore, threshold_perc)

    def saliency_map(self, X: np.ndarray) -> np.ndarray:
        """
        Compute saliency map.

        Parameters
        ----------
        X
            Time series of instances.

        Returns
        -------
        Array with saliency map values.
        """
        fft = np.fft.fft(X)
        amp = np.abs(fft)
        log_amp = np.log(amp)
        phase = np.angle(fft)
        # spectral residual: log amplitude minus its moving average
        ma_log_amp = np.convolve(log_amp, self.conv_amp, 'same')
        res_amp = log_amp - ma_log_amp
        # map the residual spectrum back to the time domain, keeping the original phase
        sr = np.abs(np.fft.ifft(np.exp(res_amp + 1j * phase)))
        return sr

    def compute_grads(self, X: np.ndarray, t: np.ndarray) -> np.ndarray:
        """
        Slope of the straight line between different points of the time series
        multiplied by the average time step size.

        Parameters
        ----------
        X
            Time series of instances.
        t
            Time steps.

        Returns
        -------
        Array with slope values.
        """
        dX = X[-1] - X[-self.n_grad_points-1:-1]
        dt = t[-1] - t[-self.n_grad_points-1:-1]
        mean_grads = np.mean(dX / dt) * np.mean(dt)
        return mean_grads

    def add_est_points(self, X: np.ndarray, t: np.ndarray) -> np.ndarray:
        """
        Pad the time series with additional points since the method works better if the anomaly point
        is towards the center of the sliding window.

        Parameters
        ----------
        X
            Time series of instances.
        t
            Time steps.

        Returns
        -------
        Padded version of X.
        """
        grads = self.compute_grads(X, t)
        X_add = X[-self.n_grad_points] + grads
        X_pad = np.concatenate([X, np.tile(X_add, self.n_est_points)])
        return X_pad

    def score(self, X: np.ndarray, t: np.ndarray = None) -> np.ndarray:
        """
        Compute outlier scores.

        Parameters
        ----------
        X
            Time series of instances.
        t
            Time steps.

        Returns
        -------
        Array with outlier scores for each instance in the batch.
        """
        if t is None:
            t = np.arange(X.shape[0])

        if len(X.shape) == 2:
            n_samples, n_dim = X.shape
            X = X.reshape(-1,)
            if X.shape[0] != n_samples:
                raise ValueError('Only univariate time series allowed for SR method. Number of features '
                                 'of time series equals {}.'.format(n_dim))

        X_pad = self.add_est_points(X, t)  # add padding
        sr = self.saliency_map(X_pad)  # compute saliency map
        sr = sr[:-self.n_est_points]  # remove padding again
        # outlier score: relative distance of the saliency map from its local moving average
        ma_sr = np.convolve(sr, self.conv_local, 'same')
        iscore = (sr - ma_sr) / (ma_sr + EPSILON)
        return iscore

    def predict(self,
                X: np.ndarray,
                t: np.ndarray = None,
                return_instance_score: bool = True) \
            -> Dict[str, Dict]:
        """
        Compute outlier scores and transform into outlier predictions.

        Parameters
        ----------
        X
            Time series of instances.
        t
            Time steps.
        return_instance_score
            Whether to return instance level outlier scores.

        Returns
        -------
        Dictionary containing 'meta' and 'data' dictionaries.
        'meta' has the model's metadata.
        'data' contains the outlier predictions and instance level outlier scores.
        """
        if t is None:
            t = np.arange(X.shape[0])

        # compute outlier scores
        iscore = self.score(X.reshape(-1, ), t)

        # values above threshold are outliers
        outlier_pred = (iscore > self.threshold).astype(int)

        # populate output dict
        od = outlier_prediction_dict()
        od['meta'] = self.meta
        od['data']['is_outlier'] = outlier_pred
        if return_instance_score:
            od['data']['instance_score'] = iscore
        return od
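
For intuition, the core transform in `saliency_map` can be reproduced standalone. A sketch with plain NumPy (the toy signal and the `window_amp` value here are assumptions for illustration, not part of the detector):

```python
import numpy as np

def spectral_residual(x: np.ndarray, window_amp: int = 20) -> np.ndarray:
    # log amplitude and phase of the Fourier spectrum
    fft = np.fft.fft(x)
    log_amp, phase = np.log(np.abs(fft)), np.angle(fft)
    # spectral residual: log amplitude minus its moving average
    conv = np.ones(window_amp) / window_amp
    res_amp = log_amp - np.convolve(log_amp, conv, 'same')
    # back to the time domain, keeping the original phase
    return np.abs(np.fft.ifft(np.exp(res_amp + 1j * phase)))

x = np.sin(np.linspace(0, 20 * np.pi, 500))
x[250] = 10.  # single spike
print(spectral_residual(x).argmax())  # the spike should dominate the saliency map
```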
55 changes: 55 additions & 0 deletions alibi_detect/od/tests/test_sr.py
@@ -0,0 +1,55 @@
from itertools import product
import numpy as np
import pytest
from alibi_detect.od import SpectralResidual

# create normal time series and one with perturbations
t = np.linspace(0, 0.5, 1000)
X = np.sin(40 * 2 * np.pi * t) + 0.5 * np.sin(90 * 2 * np.pi * t)
idx_pert = np.random.randint(0, 1000, 10)
X_pert = X.copy()
X_pert[idx_pert] = 50

window_amp = [10, 20]
window_local = [20, 30]
n_est_points = [10, 20]
return_instance_score = [True, False]

tests = list(product(window_amp, window_local, n_est_points, return_instance_score))
n_tests = len(tests)


@pytest.fixture
def sr_params(request):
    return tests[request.param]


@pytest.mark.parametrize('sr_params', list(range(n_tests)), indirect=True)
def test_sr(sr_params):
    window_amp, window_local, n_est_points, return_instance_score = sr_params

    threshold = 2.5
    od = SpectralResidual(threshold=threshold,
                          window_amp=window_amp,
                          window_local=window_local,
                          n_est_points=n_est_points)

    assert od.threshold == threshold
    assert od.meta == {'name': 'SpectralResidual',
                       'detector_type': 'online',
                       'data_type': 'time-series'}
    # few or no false positives on the clean series
    preds_in = od.predict(X, t, return_instance_score=return_instance_score)
    assert preds_in['data']['is_outlier'].sum() <= 2.
    if return_instance_score:
        assert preds_in['data']['is_outlier'].sum() == (preds_in['data']['instance_score']
                                                        > od.threshold).astype(int).sum()
    else:
        assert preds_in['data']['instance_score'] is None
    # most of the injected perturbations are flagged
    preds_out = od.predict(X_pert, t, return_instance_score=return_instance_score)
    assert preds_out['data']['is_outlier'].sum() >= idx_pert.shape[0] - 2
    if return_instance_score:
        assert preds_out['data']['is_outlier'].sum() == (preds_out['data']['instance_score']
                                                         > od.threshold).astype(int).sum()
    else:
        assert preds_out['data']['instance_score'] is None
    assert preds_out['meta'] == od.meta
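
These parametrized cases can be run in isolation with a standard pytest invocation, e.g. `pytest alibi_detect/od/tests/test_sr.py` from the repository root.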
5 changes: 5 additions & 0 deletions alibi_detect/utils/perturbation.py
@@ -122,6 +122,9 @@ def inject_outlier_ts(X: np.ndarray,
    -------
    Bunch object with the perturbed time series and the outlier labels.
    """
    n_dim = len(X.shape)
    if n_dim == 1:
        X = X.reshape(-1, 1)
    n_samples, n_ts = X.shape
    X_outlier = X.copy()
    is_outlier = np.zeros(n_samples)
@@ -141,4 +144,6 @@
        rnd = np.random.normal(size=n_outlier)
        X_outlier[outlier_idx, s] += np.sign(rnd) * np.maximum(np.abs(rnd * n_std), min_std) * stdev
        is_outlier[outlier_idx] = 1
    if n_dim == 1:
        X_outlier = X_outlier.reshape(n_samples,)
    return Bunch(data=X_outlier, target=is_outlier, target_names=['normal', 'outlier'])
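
With this change a univariate series no longer needs manual reshaping before and after perturbation. A quick sketch (the keyword arguments besides `X` are assumptions based on the names used in the function body, since the full signature is outside this hunk):

```python
import numpy as np
from alibi_detect.utils.perturbation import inject_outlier_ts

X = np.sin(np.linspace(0, 10 * np.pi, 1000))  # 1-D time series
# perc_outlier, n_std and min_std are assumed parameter names
data = inject_outlier_ts(X, perc_outlier=10, n_std=2., min_std=1.)
X_out, y_out = data.data, data.target
assert X_out.shape == X.shape  # 1-D in, 1-D out after this change
```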