Skip to content

Commit

Permalink
Added newer version of compare_essential_dynamics.py with ipca
Browse files Browse the repository at this point in the history
  • Loading branch information
oliserand committed Oct 19, 2023
1 parent 2af0436 commit 61aa6c4
Showing 1 changed file with 22 additions and 4 deletions.
26 changes: 22 additions & 4 deletions src/compare_essential_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
import mdtraj as md
from sklearn.decomposition import PCA
from sklearn.decomposition import IncrementalPCA
from sklearn.cluster import KMeans
import matplotlib
import seaborn as sns
Expand All @@ -20,7 +21,7 @@
__version__ = "1.2"
__maintainer__ = "Olivier Sheik Amamuddy"
__email__ = "[email protected]"
__date__ = "24th May 2021"
__date__ = "16th Oct 2021"

def parse_args():
"""
Expand Down Expand Up @@ -50,6 +51,8 @@ def parse_args():
parser.add_argument('--n_clusters', type=int, default=3,
help="The expected number of protein clusters to\
extract.")
parser.add_argument('--ipca', action="store_true",
help="Use incremental PCA. Reduces memory usage (default=off)")
parser.add_argument('--ignn', type=int, default=0,
help="The number of N-terminus residues to ignore in \
PCA calculations (default=0)")
Expand Down Expand Up @@ -175,13 +178,26 @@ def plot_graphs(pcs, outbasename, traj, title="Essential dynamics plot",
plt.tight_layout()
plt.savefig("{}.png".format(outbasename))

def write_pcs(outfilename, pcs_matrix, explained_variance_array):
    """Write a matrix of principal-component values to a CSV file.

    The first row of the file is a header holding one label per entry of
    ``explained_variance_array``. Each label has the form ``PC<k>:<v>``,
    where ``k`` counts from 1 and ``v`` is the entry rounded to three
    decimal places. The rows of ``pcs_matrix`` follow, comma-separated.

    Parameters
    ----------
    outfilename : destination passed straight to ``np.savetxt``.
    pcs_matrix : array-like written as the data rows.
    explained_variance_array : iterable of numbers used to build the header.
    """
    header = ",".join(
        "PC{}:{}".format(rank, np.round(value, 3))
        for rank, value in enumerate(explained_variance_array, start=1)
    )
    # comments="" keeps numpy from prefixing the header row with "# ".
    np.savetxt(outfilename, pcs_matrix, header=header, delimiter=",",
               comments="")
    print("INFO: Wrote PCs in {}".format(outfilename))

def main(args):
"""
Program main
"""
sns.set_style("whitegrid")
plt.rcParams["font.family"] = 'serif'
pca = PCA()
if args.ipca:
pca = IncrementalPCA()
else:
pca = PCA()
ntrajectories = len(args.trajectories)
num_frames = []
all_frames = []
Expand Down Expand Up @@ -209,7 +225,7 @@ def main(args):
check_compatibility(num_frames)
xyz_concat = np.concatenate(all_frames)
pca.fit(xyz_concat)
percent_variance = pca.explained_variance_ratio_[:2]*100
percent_variance = pca.explained_variance_ratio_*100
pcs = pca.transform(xyz_concat)
# Setting plot limits
xmin = np.min(pcs[:, 0])
Expand All @@ -230,7 +246,9 @@ def main(args):
# Plotting
plot_graphs(pcs=pcs_currtraj, traj=traj, outbasename=outbasename,
xlim=(xmin, xmax), ylim=(ymin, ymax),
percent_variance=percent_variance, n_clusters=n_clusters)
percent_variance=percent_variance[:2], n_clusters=n_clusters)
# Save PCs
write_pcs("{}_pcs.csv".format(outbasename), pcs_currtraj, percent_variance)

if __name__ == "__main__":
ARGS = parse_args()
Expand Down

0 comments on commit 61aa6c4

Please sign in to comment.