Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of NCH (Nearest Convex Hull) classifier #253

Merged
merged 47 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
624e942
Initial version of NearestConvexHull.
toncho11 Mar 4, 2024
cf3f6d9
Added script for testing.
toncho11 Mar 4, 2024
fd85f0a
First version that runs.
toncho11 Mar 4, 2024
4548b78
Improved code.
toncho11 Mar 4, 2024
0de3c40
Added support for parallel processing.
toncho11 Mar 5, 2024
37491eb
renamed
toncho11 Mar 5, 2024
1c1d17b
New version that uses a new class that implements a NCH classifier.
toncho11 Mar 5, 2024
dc5633e
small update
toncho11 Mar 5, 2024
1c4ae29
Merge branch 'pyRiemann:main' into main
toncho11 Mar 5, 2024
e07cd39
Updated to newest code - the new version of the distance function.
toncho11 Mar 5, 2024
a80ea8d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2024
0f8136f
reinforce constraint on weights
gcattan Mar 5, 2024
f7cbe9f
- remove constraints on weights
gcattan Mar 5, 2024
60f5d58
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2024
9aa2fb4
Added n_max_hull parameter. MOABB support tested.
toncho11 Mar 6, 2024
66aaca1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2024
f4f02bf
added multiple hulls.
toncho11 Mar 6, 2024
4465fc8
Multiple hull support. Stash and merge.
toncho11 Mar 7, 2024
c9357cb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2024
8ab59af
Code cleanups.
toncho11 Mar 7, 2024
3ae1b1e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2024
d7c6e1a
Improved code.
toncho11 Mar 8, 2024
4e8be61
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 8, 2024
f4f836c
updated default parameters
toncho11 Mar 8, 2024
4d20109
General improvements.
toncho11 Mar 12, 2024
edc3561
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 12, 2024
1ffc5bc
removed commented code
toncho11 Mar 12, 2024
705164b
Small adjustments.
toncho11 Mar 12, 2024
c562d2a
Better class separation.
toncho11 Mar 12, 2024
b9d30fc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 12, 2024
25bed43
Added support for n_samples_per_hull = -1 which takes all the samples…
toncho11 Mar 12, 2024
8eae303
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 12, 2024
f0d5153
Update pyriemann_qiskit/classification.py
toncho11 Mar 13, 2024
1d458c9
Update pyriemann_qiskit/classification.py
toncho11 Mar 13, 2024
a3533e4
Update pyriemann_qiskit/classification.py
toncho11 Mar 13, 2024
94e91b9
Update pyriemann_qiskit/classification.py
toncho11 Mar 13, 2024
a3cde26
Improvements proposed by Quentin.
toncho11 Mar 13, 2024
32bc8a3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2024
3676e63
Added comment for the optimizer.
toncho11 Mar 13, 2024
a040bed
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2024
b9837ae
Added some comments in classification.
toncho11 Mar 13, 2024
01d655c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2024
1eac251
Implemented min hull.
toncho11 Mar 15, 2024
67d3dc4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 15, 2024
fe9deb1
Reverted to previous version as requested by Gregoire.
toncho11 Mar 18, 2024
a7da2c5
fix lint issues
gcattan Mar 18, 2024
7047f46
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions examples/ERP/classify_P300_nch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
"""
====================================================================
Classification of P300 datasets from MOABB using NCH
====================================================================

Demonstrates classification with QunatumNCH.
Evaluation is done using MOABB.

If parameter "shots" is None then a classical SVM is used similar to the one
in scikit learn.
If "shots" is not None and IBM Qunatum token is provided with "q_account_token"
then a real Quantum computer will be used.
You also need to adjust the "n_components" in the PCA procedure to the number
of qubits supported by the real quantum computer you are going to use.
A list of real quantum computers is available in your IBM quantum account.

"""
# Author: Anton Andreev
# Modified from plot_classify_EEG_tangentspace.py of pyRiemann
# License: BSD (3-clause)

from pyriemann.estimation import XdawnCovariances
from pyriemann.tangentspace import TangentSpace
from sklearn.pipeline import make_pipeline
from matplotlib import pyplot as plt
import warnings
import seaborn as sns
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from moabb import set_log_level
from moabb.datasets import bi2012, BNCI2014009, bi2013a
from moabb.evaluations import WithinSessionEvaluation, CrossSubjectEvaluation
from moabb.paradigms import P300
from pyriemann_qiskit.pipelines import (
QuantumClassifierWithDefaultRiemannianPipeline,
)
from sklearn.decomposition import PCA
from pyriemann_qiskit.classification import QuanticNCH
from pyriemann.classification import MDM

print(__doc__)

##############################################################################
# getting rid of the warnings about the future
warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=RuntimeWarning)

warnings.filterwarnings("ignore")

set_log_level("info")

##############################################################################
# Create Pipelines
# ----------------
#
# Pipelines must be a dict of sklearn pipeline transformer.

##############################################################################
# We have to do this because the classes are called 'Target' and 'NonTarget'
# but the evaluation function uses a LabelEncoder, transforming them
# to 0 and 1
labels_dict = {"Target": 1, "NonTarget": 0}

paradigm = P300(resample=128)

datasets = [bi2013a()] # MOABB provides several other P300 datasets

# reduce the number of subjects, the Quantum pipeline takes a lot of time
# if executed on the entire dataset
n_subjects = 1
for dataset in datasets:
dataset.subject_list = dataset.subject_list[0:n_subjects]

overwrite = True # set to True if we want to overwrite cached results

pipelines = {}

pipelines["NCH"] = make_pipeline(
# applies XDawn and calculates the covariance matrix, output it matrices
XdawnCovariances(
nfilter=3,
classes=[labels_dict["Target"]],
estimator="lwf",
xdawn_estimator="scm",
),
QuanticNCH(n_hulls=3, n_samples_per_hull=15, n_jobs=12, quantum=False),
)

# this is a non quantum pipeline
pipelines["XD+MDM"] = make_pipeline(
XdawnCovariances(
nfilter=3,
classes=[labels_dict["Target"]],
estimator="lwf",
xdawn_estimator="scm",
),
MDM(),
)

print("Total pipelines to evaluate: ", len(pipelines))

evaluation = WithinSessionEvaluation(
paradigm=paradigm, datasets=datasets, suffix="examples", overwrite=overwrite
)

results = evaluation.process(pipelines)

print("Averaging the session performance:")
print(results.groupby("pipeline").mean("score")[["score", "time"]])

##############################################################################
# Plot Results
# ----------------
#
# Here we plot the results to compare the two pipelines

fig, ax = plt.subplots(facecolor="white", figsize=[8, 4])

sns.stripplot(
data=results,
y="score",
x="pipeline",
ax=ax,
jitter=True,
alpha=0.5,
zorder=1,
palette="Set1",
)
sns.pointplot(data=results, y="score", x="pipeline", ax=ax, palette="Set1")

ax.set_ylabel("ROC AUC")
ax.set_ylim(0.3, 1)

plt.show()
Loading
Loading