Skip to content

Commit

Permalink
Merge pull request #15 from hmcezar/weighted-solvent
Browse files Browse the repository at this point in the history
Weighted solute-solvent contributions
  • Loading branch information
hmcezar authored Dec 26, 2024
2 parents fd0badb + 7a1dc77 commit 5378d56
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 65 deletions.
55 changes: 22 additions & 33 deletions clusttraj/distmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ def build_distance_matrix(clust_opt: ClustOptions) -> np.ndarray:
itertools.repeat(clust_opt.trajfile),
itertools.repeat(clust_opt.no_hydrogen),
itertools.repeat(clust_opt.reorder_alg),
itertools.repeat(clust_opt.reorder_solvent_only),
itertools.repeat(clust_opt.solute_natoms),
itertools.repeat(clust_opt.weight_solute),
itertools.repeat(clust_opt.reorder_excl),
itertools.repeat(clust_opt.final_kabsch),
)
Expand All @@ -86,7 +88,9 @@ def compute_distmat_line(
reorder: Union[
Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray], None
],
reorder_solvent_only: bool,
nsatoms: int,
weight_solute: float,
reorderexcl: np.ndarray,
final_kabsch: bool,
) -> List[float]:
Expand Down Expand Up @@ -179,7 +183,7 @@ def compute_distmat_line(
P = np.dot(P, U)

# reorder solute atoms
if reorder:
if reorder and not reorder_solvent_only:
# find the solute atoms that are not excluded
soluexcl = np.where(reorderexcl < natoms)
soluteview = np.delete(np.arange(natoms), reorderexcl[soluexcl])
Expand Down Expand Up @@ -220,36 +224,6 @@ def compute_distmat_line(
# rotate the whole system with this rotation
P = np.dot(P, U)

# consider only the solvent atoms in the reorder (without exclusions)
# solvexcl = np.where(reorderexcl >= natoms)
# solvview = np.delete(np.arange(natoms, len(P)), reorderexcl[solvexcl])
# Pview = P[solvview]
# Paview = Pa[solvview]

# # reorder just these atoms
# prr = reorder(Qa[solvview], Paview, Q[solvview], Pview)
# Pview = Pview[prr]
# Paview = Paview[prr]

# # build the total molecule with the reordered atoms
# whereins = np.where(
# np.isin(np.arange(natoms, len(P)), reorderexcl[solvexcl]) == True
# )
# Psolv = np.insert(
# Pview,
# [x - whereins[0].tolist().index(x) for x in whereins[0]],
# P[reorderexcl[solvexcl]],
# axis=0,
# )
# Pasolv = np.insert(
# Paview,
# [x - whereins[0].tolist().index(x) for x in whereins[0]],
# Pa[reorderexcl[solvexcl]],
# axis=0,
# )

# Pr = np.concatenate((P[:natoms], Psolv))
# Pra = np.concatenate((Pa[:natoms], Pasolv))
else:
# Kabsch rotation
U = rmsd.kabsch(P, Q)
Expand Down Expand Up @@ -283,10 +257,25 @@ def compute_distmat_line(
else:
Pr = P

# compute the weights
if weight_solute:
W = np.zeros(Pr.shape[0])
W[:natoms] = weight_solute / natoms
W[natoms:] = (1.0 - weight_solute) / (Pr.shape[0] - natoms)

# for solute solvent alignement, compute RMSD without Kabsch
if nsatoms and reorder and not final_kabsch:
distmat.append(rmsd.rmsd(Pr, Q))
if weight_solute:
diff = Pr - Q
distmat.append(
np.sqrt(np.dot(W, np.sum(diff * diff, axis=1)) / Pr.shape[0])
)
else:
distmat.append(rmsd.rmsd(Pr, Q))
else:
distmat.append(rmsd.kabsch_rmsd(Pr, Q))
if weight_solute:
distmat.append(rmsd.kabsch_weighted_rmsd(Pr, Q, W))
else:
distmat.append(rmsd.kabsch_rmsd(Pr, Q))

return distmat
83 changes: 51 additions & 32 deletions clusttraj/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class ClustOptions:
] = None
out_conf_fmt: str = None
reorder: bool = None
reorder_solvent_only: bool = None
exclusions: bool = None
no_hydrogen: bool = None
input_distmat: bool = None
Expand All @@ -53,6 +54,7 @@ class ClustOptions:
out_conf_name: str = None
summary_name: str = None
solute_natoms: int = None
weight_solute: float = None
reorder_excl: np.ndarray = None
optimal_cut: np.ndarray = None
verbose: bool = None
Expand Down Expand Up @@ -97,10 +99,16 @@ def __str__(self) -> str:
# reordering options
if self.reorder:
if self.solute_natoms:
return_str += "\nUsing solute-solvent reordering\n"
return_str += "\nUsing solute-solvent adjustment/reordering\n"
if self.final_kabsch:
return_str += "Using final Kabsch rotation before computing RMSD\n"
return_str += f"Number of solute atoms: {self.solute_natoms}\n"
if self.weight_solute:
return_str += f"Weight of the solute atoms: {self.weight_solute}\n"
else:
return_str += "Unweighted RMSD according to solute/solvent.\n"
if self.reorder_solvent_only:
return_str += "Reordering only solvent atoms\n"
else:
return_str += "\nReordering all atom at the same time\n"

Expand Down Expand Up @@ -274,6 +282,12 @@ def configure_runtime(args_in: List[str]) -> ClustOptions:
type=check_positive,
help="list of atoms that are ignored when reordering",
)
parser.add_argument(
"-rs",
"--reorder-solvent-only",
action="store_true",
help="reorder only the solvent atoms",
)
parser.add_argument(
"--reorder-alg",
action="store",
Expand All @@ -288,6 +302,13 @@ def configure_runtime(args_in: List[str]) -> ClustOptions:
type=check_positive,
help="number of solute atoms, to ignore these atoms in the reordering process",
)
parser.add_argument(
"-ws",
"--weight-solute",
metavar="WEIGHT SOLUTE",
type=float,
help="weight of the solute atoms in the RMSD calculation",
)
parser.add_argument(
"-odl",
"--optimal-ordering",
Expand Down Expand Up @@ -525,6 +546,20 @@ def configure_runtime(args_in: List[str]) -> ClustOptions:
"The list of atoms to exclude for reordering only makes sense if reordering is enabled. Ignoring the list."
)

if args.reorder_solvent_only and not args.natoms_solute:
parser.error(
"You must specify the number of solute atoms with -ns to use the -rs option."
)

if args.weight_solute and not args.natoms_solute:
parser.error(
"You must specify the number of solute atoms with -ns to use the -ws option."
)

if args.weight_solute:
if args.weight_solute < 0.0 or args.weight_solute > 1.0:
parser.error("The weight of the solute atoms must be between 0 and 1.")

return parse_args(args)


Expand All @@ -544,6 +579,7 @@ def parse_args(args: argparse.Namespace) -> ClustOptions:

options_dict = {
"solute_natoms": args.natoms_solute,
"weight_solute": args.weight_solute,
"reorder_excl": (
np.asarray([x - 1 for x in args.reorder_exclusions], np.int32)
if args.reorder_exclusions
Expand All @@ -552,6 +588,7 @@ def parse_args(args: argparse.Namespace) -> ClustOptions:
"exclusions": bool(args.reorder_exclusions),
"reorder_alg_name": args.reorder_alg,
"reorder_alg": None,
"reorder_solvent_only": bool(args.reorder_solvent_only),
"reorder": bool(args.reorder),
"input_distmat": bool(args.input),
"distmat_name": args.outputdistmat if not args.input else args.input,
Expand Down Expand Up @@ -612,7 +649,9 @@ def save_clusters_config(
reorder: Union[
Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray], None
],
reorder_solvent_only: bool,
nsatoms: int,
weight_solute: float,
outbasename: str,
outfmt: str,
reorderexcl: np.ndarray,
Expand Down Expand Up @@ -763,7 +802,7 @@ def save_clusters_config(
p_all = np.dot(p_all, U)

# reorder solute atoms
if reorder:
if reorder and not reorder_solvent_only:
# find the solute atoms that are not excluded
soluexcl = np.where(reorderexcl < natoms)
soluteview = np.delete(np.arange(natoms), reorderexcl[soluexcl])
Expand Down Expand Up @@ -807,35 +846,6 @@ def save_clusters_config(
P = np.dot(P, U)
p_all = np.dot(p_all, U)

# consider only the solvent atoms in the reorder (without exclusions)
# solvexcl = np.where(reorderexcl >= natoms)
# solvview = np.delete(np.arange(natoms, len(P)), reorderexcl[solvexcl])
# Pview = P[solvview]
# Paview = Pa[solvview]

# # reorder just these atoms
# prr = reorder(Qa[solvview], Paview, Q[solvview], Pview)
# Pview = Pview[prr]
# Paview = Paview[prr]
# # build the total molecule with the reordered atoms
# whereins = np.where(
# np.isin(np.arange(natoms, len(P)), reorderexcl[solvexcl]) == True
# )
# Psolv = np.insert(
# Pview,
# [x - whereins[0].tolist().index(x) for x in whereins[0]],
# P[reorderexcl[solvexcl]],
# axis=0,
# )
# Pasolv = np.insert(
# Paview,
# [x - whereins[0].tolist().index(x) for x in whereins[0]],
# Pa[reorderexcl[solvexcl]],
# axis=0,
# )

# Pr = np.concatenate((P[:natoms], Psolv))
# Pra = np.concatenate((Pa[:natoms], Pasolv))
else:
# Kabsch rotation
U = rmsd.kabsch(P, Q)
Expand Down Expand Up @@ -873,7 +883,16 @@ def save_clusters_config(
else:
Pr = P

if nsatoms and reorder and not final_kabsch:
# compute the weights
if weight_solute and final_kabsch:
W = np.zeros(Pr.shape[0])
W[:natoms] = weight_solute / natoms
W[natoms:] = (1.0 - weight_solute) / (Pr.shape[0] - natoms)

R, T, _ = rmsd.kabsch_weighted(Q, Pr, W)
p_all = np.dot(p_all, R.T) + T

elif nsatoms and reorder and final_kabsch:
# rotate whole configuration (considering hydrogens even with noh)
U = rmsd.kabsch(Pr, Q)
p_all = np.dot(p_all, U)
Expand Down
2 changes: 2 additions & 0 deletions clusttraj/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ def main(args: List[str] = None) -> None:
distmat,
clust_opt.no_hydrogen,
clust_opt.reorder_alg,
clust_opt.reorder_solvent_only,
clust_opt.solute_natoms,
clust_opt.weight_solute,
clust_opt.out_conf_name,
clust_opt.out_conf_fmt,
clust_opt.reorder_excl,
Expand Down
2 changes: 2 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def options_dict(tmp_path):
"reorder_alg_name": "hungarian",
"reorder_alg": None,
"reorder": False,
"reorder_solvent_only": False,
"input_distmat": False,
"distmat_name": "test/ref/test_distmat.npy",
"summary_name": os.path.join(tmp_path, "clusters.out"),
Expand All @@ -48,6 +49,7 @@ def options_dict(tmp_path):
"no_hydrogen": True,
"opt_order": False,
"solute_natoms": 17,
"weight_solute": None,
"overwrite": True,
"final_kabsch": False,
"silhouette_score": False,
Expand Down
31 changes: 31 additions & 0 deletions test/test_distmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ def test_compute_distmat_line(options_dict, clust_opt, first_conf_traj, test_dis
clust_opt.trajfile,
clust_opt.no_hydrogen,
None,
clust_opt.reorder_solvent_only,
clust_opt.solute_natoms,
clust_opt.weight_solute,
clust_opt.reorder_excl,
clust_opt.final_kabsch,
)
Expand All @@ -35,6 +37,35 @@ def test_compute_distmat_line(options_dict, clust_opt, first_conf_traj, test_dis
assert line[1] == pytest.approx(test_distmat[1], abs=1e-8)


def test_compute_distmat_line_reorder_weight_solute(
options_dict, clust_opt, first_conf_traj
):
# test just the options
# TODO: need distmat to check
clust_opt.reorder_alg = "hungarian"
clust_opt.reorder_solvent_only = True
clust_opt.solute_natoms = 17
clust_opt.weight_solute = 0.8
clust_opt.reorder_excl = [1, 2, 3]
clust_opt.final_kabsch = True
clust_opt.no_hydrogen = False

line = compute_distmat_line(
0,
get_mol_info(first_conf_traj),
clust_opt.trajfile,
clust_opt.no_hydrogen,
None,
clust_opt.reorder_solvent_only,
clust_opt.solute_natoms,
clust_opt.weight_solute,
clust_opt.reorder_excl,
clust_opt.final_kabsch,
)

assert len(line) == 2


def test_build_distance_matrix(clust_opt, test_distmat):
distmat = build_distance_matrix(clust_opt)
assert len(distmat) == 3
Expand Down
8 changes: 8 additions & 0 deletions test/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def test_ClustOptions(options_dict):
assert clust_opt.reorder_alg_name == "hungarian"
assert clust_opt.reorder_alg is None
assert clust_opt.reorder is False
assert clust_opt.reorder_solvent_only is False
assert clust_opt.input_distmat is False
assert clust_opt.distmat_name == "test/ref/test_distmat.npy"
assert os.path.basename(clust_opt.summary_name) == "clusters.out"
Expand All @@ -40,6 +41,7 @@ def test_ClustOptions(options_dict):
assert clust_opt.no_hydrogen is True
assert clust_opt.opt_order is False
assert clust_opt.solute_natoms == 17
assert clust_opt.weight_solute is None
assert clust_opt.overwrite is True
assert clust_opt.final_kabsch is False
assert clust_opt.silhouette_score is False
Expand All @@ -65,9 +67,11 @@ def test_extant_file():
def test_parse_args():
args = argparse.Namespace(
natoms_solute=10,
weight_solute=None,
reorder_exclusions=[1, 2, 3],
reorder_alg="hungarian",
reorder=False,
reorder_solvent_only=False,
input=True,
outputdistmat="distmat.npy",
outputclusters="clusters.dat",
Expand All @@ -92,9 +96,11 @@ def test_parse_args():

args = argparse.Namespace(
natoms_solute=10,
weight_solute=None,
reorder_exclusions=[1, 2, 3],
reorder_alg="hungarian",
reorder=True,
reorder_solvent_only=False,
input=True,
outputdistmat="distmat.npy",
outputclusters="clusters.dat",
Expand Down Expand Up @@ -167,7 +173,9 @@ def test_save_clusters_config(clust_opt, clusters_seq, test_distmat):
test_distmat,
clust_opt.no_hydrogen,
clust_opt.reorder_alg,
clust_opt.reorder_solvent_only,
clust_opt.solute_natoms,
clust_opt.weight_solute,
clust_opt.out_conf_name,
clust_opt.out_conf_fmt,
clust_opt.reorder_excl,
Expand Down

0 comments on commit 5378d56

Please sign in to comment.