diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c15e7b5..93101dc 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -8,7 +8,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v4 diff --git a/README.md b/README.md index a5788bf..7dda0a0 100644 --- a/README.md +++ b/README.md @@ -16,23 +16,25 @@ The following libraries are required: - [NumPy](http://www.numpy.org/) - [OpenBabel](http://openbabel.org/) - [RMSD](https://github.com/charnley/rmsd) -- [QML](https://github.com/qmlcode/qml) - [SciPy](https://www.scipy.org/) - [scikit-learn](http://scikit-learn.org/stable/index.html) - [matplotlib](https://matplotlib.org/) -We use the development branch of `QML` as one of the reordering algorithms. -Since the development of `QML` has been quite slow, we provide our [own branch](https://github.com/hmcezar/qml/tree/develop) in which installation using modern versions of `numpy` is possible. +We also have [qmllib](https://github.com/qmlcode/qmllib) as an optional dependency as one of the reordering algorithms. For `openbabel`, we use the `pip` package `openbabel-wheel` which provides pre-built `openbabel` packages for Linux and MacOS. More details can be seen in the [projects' GitHub page](https://github.com/njzjz/openbabel-wheel). - You can install the package using `pip` ```bash pip install clusttraj ``` +If you want to use the `qmllib` reordering algorithm, you can install it with: +```bash +pip install clusttraj[qml] +``` + ## Usage To see all the options run the script with the `-h` command option: ```bash diff --git a/clusttraj/distmat.py b/clusttraj/distmat.py index 200fbae..5e475e0 100644 --- a/clusttraj/distmat.py +++ b/clusttraj/distmat.py @@ -195,7 +195,9 @@ def compute_distmat_line( # whereins = np.where( # np.isin(np.arange(natoms), reorderexcl[soluexcl]) is True # ) - whereins = np.where(np.atleast_1d(np.isin(np.arange(natoms), reorderexcl[soluexcl]))) + whereins = np.where( + np.atleast_1d(np.isin(np.arange(natoms), reorderexcl[soluexcl])) + ) Psolu = np.insert( Pview, [x - whereins[0].tolist().index(x) for x in whereins[0]], @@ -255,8 +257,14 @@ def compute_distmat_line( # reorder the solvent atoms separately if reorder: + # if the solute is specified, reorder just the solvent atoms in this step + if nsatoms: + exclusions = np.unique(np.concatenate((np.arange(natoms), reorderexcl))) + else: + exclusions = reorderexcl + # get the view without the excluded atoms - view = np.delete(np.arange(len(P)), reorderexcl) + view = np.delete(np.arange(len(P)), exclusions) Pview = P[view] Paview = Pa[view] @@ -264,12 +272,11 @@ def compute_distmat_line( Pview = Pview[prr] # build the total molecule with the reordered atoms - # whereins = np.where(np.isin(np.arange(len(P)), reorderexcl) is True) - whereins = np.where(np.atleast_1d(np.isin(np.arange(len(P)), reorderexcl))) + whereins = np.where(np.atleast_1d(np.isin(np.arange(len(P)), exclusions))) Pr = np.insert( Pview, [x - whereins[0].tolist().index(x) for x in whereins[0]], - P[reorderexcl], + P[exclusions], axis=0, ) diff --git a/clusttraj/io.py b/clusttraj/io.py index e1063f6..7c7c51e 100644 --- a/clusttraj/io.py +++ b/clusttraj/io.py @@ -15,7 +15,7 @@ from dataclasses import dataclass from .utils import get_mol_info -if importlib.util.find_spec("qml"): +if importlib.util.find_spec("qmllib"): has_qml = True else: has_qml = False @@ -377,7 +377,7 @@ def configure_runtime(args_in: List[str]) -> ClustOptions: if (args.reorder_alg == "qml") and (not has_qml): parser.error( - "You must have the development branch of qml installed in order to use it as a reorder method." + "You must have the optional dependency `qmllib` installed in order to use it as a reorder method." ) if args.clusters_configurations: @@ -544,9 +544,11 @@ def parse_args(args: argparse.Namespace) -> ClustOptions: options_dict = { "solute_natoms": args.natoms_solute, - "reorder_excl": np.asarray([x - 1 for x in args.reorder_exclusions], np.int32) - if args.reorder_exclusions - else np.asarray([], np.int32), + "reorder_excl": ( + np.asarray([x - 1 for x in args.reorder_exclusions], np.int32) + if args.reorder_exclusions + else np.asarray([], np.int32) + ), "exclusions": bool(args.reorder_exclusions), "reorder_alg_name": args.reorder_alg, "reorder_alg": None, @@ -556,12 +558,12 @@ def parse_args(args: argparse.Namespace) -> ClustOptions: "out_clust_name": args.outputclusters, "summary_name": basenameout + ".out", "save_confs": bool(args.clusters_configurations), - "out_conf_name": basenameout + "_confs" - if args.clusters_configurations - else None, - "out_conf_fmt": args.clusters_configurations - if args.clusters_configurations - else None, + "out_conf_name": ( + basenameout + "_confs" if args.clusters_configurations else None + ), + "out_conf_fmt": ( + args.clusters_configurations if args.clusters_configurations else None + ), "plot": bool(args.plot), "evo_name": basenameout + "_evo.pdf" if args.plot else None, "dendrogram_name": basenameout + "_dendrogram.pdf" if args.plot else None, @@ -777,7 +779,9 @@ def save_clusters_config( # whereins = np.where( # np.isin(np.arange(natoms), reorderexcl[soluexcl]) is True # ) - whereins = np.where(np.atleast_1d(np.isin(np.arange(natoms), reorderexcl))) + whereins = np.where( + np.atleast_1d(np.isin(np.arange(natoms), reorderexcl)) + ) Psolu = np.insert( Pview, [x - whereins[0].tolist().index(x) for x in whereins[0]], @@ -840,8 +844,16 @@ def save_clusters_config( # reorder the solvent atoms separately if reorder: + # if the solute is specified, reorder just the solvent atoms in this step + if nsatoms: + exclusions = np.unique( + np.concatenate((np.arange(natoms), reorderexcl)) + ) + else: + exclusions = reorderexcl + # get the view without the excluded atoms - view = np.delete(np.arange(len(P)), reorderexcl) + view = np.delete(np.arange(len(P)), exclusions) Pview = P[view] Paview = Pa[view] @@ -849,12 +861,13 @@ def save_clusters_config( Pview = Pview[prr] # build the total molecule with the reordered atoms - # whereins = np.where(np.isin(np.arange(len(P)), reorderexcl) is True) - whereins = np.where(np.atleast_1d(np.isin(np.arange(len(P)), reorderexcl))) + whereins = np.where( + np.atleast_1d(np.isin(np.arange(len(P)), exclusions)) + ) Pr = np.insert( Pview, [x - whereins[0].tolist().index(x) for x in whereins[0]], - P[reorderexcl], + P[exclusions], axis=0, ) else: @@ -883,4 +896,6 @@ def save_clusters_config( # closes the file for the cnum cluster outfile.close() - # type: ignore + + +# type: ignore diff --git a/clusttraj/metrics.py b/clusttraj/metrics.py index 909eb54..2dd4a73 100644 --- a/clusttraj/metrics.py +++ b/clusttraj/metrics.py @@ -9,11 +9,9 @@ from scipy.cluster.hierarchy import cophenet from typing import Tuple import numpy as np -# from .io import ClustOptions def compute_metrics( - # clust_opt: ClustOptions, distmat: np.ndarray, z_matrix: np.ndarray, clusters: np.ndarray, @@ -21,7 +19,6 @@ def compute_metrics( """Compute metrics to assess the performance of the clustering procedure. Args: - # clust_opt (ClustOptions): The clustering options. distmat: The distance matrix. z_matrix (np.ndarray): The Z-matrix from hierarchical clustering procedure. clusters (np.ndarray): The cluster classifications for each sample. diff --git a/clusttraj/plot.py b/clusttraj/plot.py index 67c9f7c..9e15429 100644 --- a/clusttraj/plot.py +++ b/clusttraj/plot.py @@ -12,10 +12,7 @@ from .io import ClustOptions -def plot_clust_evo( - clust_opt: ClustOptions, - clusters: np.ndarray -) -> None: +def plot_clust_evo(clust_opt: ClustOptions, clusters: np.ndarray) -> None: """Plot the evolution of cluster classification over the given samples. Args: @@ -25,30 +22,34 @@ def plot_clust_evo( Returns: None """ - + # Define a color for the lines line_color = (0, 0, 0, 0.5) # plot evolution with o cluster in trajectory plt.figure(figsize=(10, 6)) - + # Set the y-axis to only show integers plt.gca().yaxis.set_major_locator(MaxNLocator(integer=True)) # Increase tick size and font size - plt.tick_params(axis='both', which='major', direction='in', labelsize=12) + plt.tick_params(axis="both", which="major", direction="in", labelsize=12) plt.plot(range(1, len(clusters) + 1), clusters, markersize=4, color=line_color) - plt.scatter(range(1, len(clusters) + 1), clusters, marker="o", c=clusters, cmap=plt.cm.nipy_spectral) + plt.scatter( + range(1, len(clusters) + 1), + clusters, + marker="o", + c=clusters, + cmap=plt.cm.nipy_spectral, + ) plt.xlabel("Sample Index", fontsize=14) plt.ylabel("Cluster classification", fontsize=14) plt.savefig(clust_opt.evo_name, bbox_inches="tight") def plot_dendrogram( - clust_opt: ClustOptions, - clusters: np.ndarray, - Z: np.ndarray + clust_opt: ClustOptions, clusters: np.ndarray, Z: np.ndarray ) -> None: """Plot a dendrogram based on hierarchical clustering. @@ -65,7 +66,7 @@ def plot_dendrogram( plt.title("Hierarchical Clustering Dendrogram", fontsize=20) # plt.xlabel("Sample Index", fontsize=14) plt.ylabel(r"RMSD ($\AA$)", fontsize=18) - plt.tick_params(axis='y', labelsize=18) + plt.tick_params(axis="y", labelsize=18) # Define a color for the dashed and non-cluster lines line_color = (0, 0, 0, 0.5) @@ -73,10 +74,14 @@ def plot_dendrogram( # Add a horizontal line at the minimum RMSD value and set the threshold if clust_opt.silhouette_score: if isinstance(clust_opt.optimal_cut, (np.ndarray, list)): - plt.axhline(clust_opt.optimal_cut[0], linestyle="--", linewidth=2, color=line_color) + plt.axhline( + clust_opt.optimal_cut[0], linestyle="--", linewidth=2, color=line_color + ) threshold = clust_opt.optimal_cut[0] elif isinstance(clust_opt.optimal_cut, (float, np.float32, np.float64)): - plt.axhline(clust_opt.optimal_cut, linestyle="--", linewidth=2, color=line_color) + plt.axhline( + clust_opt.optimal_cut, linestyle="--", linewidth=2, color=line_color + ) threshold = clust_opt.optimal_cut else: raise ValueError("optimal_cut must be a float or np.ndarray") @@ -86,9 +91,9 @@ def plot_dendrogram( # Use the 'nipy_spectral' cmap to color the dendrogram unique_clusters = np.unique(clusters) - cmap = cm.get_cmap('nipy_spectral', len(unique_clusters)) + cmap = cm.get_cmap("nipy_spectral", len(unique_clusters)) colors = [to_hex(cmap(i)) for i in range(cmap.N)] - + hierarchy.set_link_color_palette(colors) # Plot the dendrogram @@ -98,18 +103,14 @@ def plot_dendrogram( # leaf_font_size=8.0, # Font size for the x axis labels no_labels=True, color_threshold=threshold, - above_threshold_color=line_color + above_threshold_color=line_color, ) # Save the dendrogram to a file plt.savefig(clust_opt.dendrogram_name, bbox_inches="tight") -def plot_mds( - clust_opt: ClustOptions, - clusters: np.ndarray, - distmat: np.ndarray -) -> None: +def plot_mds(clust_opt: ClustOptions, clusters: np.ndarray, distmat: np.ndarray) -> None: """Plot the multidimensional scaling (MDS) of the distance matrix. Args: @@ -139,7 +140,7 @@ def plot_mds( coords = mds.fit_transform(squareform(distmat)) # Set the figure size - plt.figure(figsize=(6,6)) + plt.figure(figsize=(6, 6)) # Configure tick parameters plt.tick_params( @@ -165,9 +166,7 @@ def plot_mds( def plot_tsne( - clust_opt: ClustOptions, - clusters: np.ndarray, - distmat: np.ndarray + clust_opt: ClustOptions, clusters: np.ndarray, distmat: np.ndarray ) -> None: """Plot the t-distributed Stochastic Neighbor Embedding 2D plot of the clustering. @@ -194,11 +193,11 @@ def plot_tsne( # Define a list of unique colors for each cluster unique_clusters = np.unique(clusters) - cmap = cm.get_cmap('nipy_spectral', len(unique_clusters)) + cmap = cm.get_cmap("nipy_spectral", len(unique_clusters)) colors = [cmap(i) for i in range(len(unique_clusters))] # Set the figure size - plt.figure(figsize=(6,6)) + plt.figure(figsize=(6, 6)) # Configure tick parameters plt.tick_params( diff --git a/pyproject.toml b/pyproject.toml index c60c430..ca978c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,10 +16,10 @@ classifiers = [ "Intended Audience :: Science/Research", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Environment :: Console", "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", "Operating System :: OS Independent", @@ -43,8 +43,8 @@ clusttraj = "clusttraj.main:main" [project.optional-dependencies] test = ["pytest", "pytest-cov[all]"] docs = ["sphinx", "sphinx_rtd_theme"] -lint = ["ruff", "black"] -qml = ["qml"] # you can install from git+https://github.com/hmcezar/qml@develop +lint = ["black"] +qml = ["qmllib"] all = ["clusttraj[test,docs,lint,qml]"] [tool.setuptools] @@ -52,7 +52,3 @@ packages = ["clusttraj"] [tool.black] line-length = 89 - -[tool.ruff] -line-length = 89 -ignore = ["E501"] diff --git a/requirements.txt b/requirements.txt index 0f73374..19a7d0c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ numpy -rmsd +rmsd>=1.6.0 scipy scikit-learn matplotlib diff --git a/test/test_plot.py b/test/test_plot.py index 8e66f9d..3198b16 100644 --- a/test/test_plot.py +++ b/test/test_plot.py @@ -8,8 +8,8 @@ def test_plot_clust_evo(clust_opt, clusters_seq): assert os.path.exists(clust_opt.evo_name) -def test_plot_dendrogram(clust_opt, Z_matrix): - assert plot_dendrogram(clust_opt, Z_matrix) is None +def test_plot_dendrogram(clust_opt, clusters_seq, Z_matrix): + assert plot_dendrogram(clust_opt, clusters_seq, Z_matrix) is None assert os.path.exists(clust_opt.dendrogram_name)