Skip to content

Commit

Permalink
Improve performance of hamming loss
Browse files Browse the repository at this point in the history
  • Loading branch information
pablormier committed Sep 15, 2022
1 parent 415d0be commit 7819f1b
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 68 deletions.
8 changes: 6 additions & 2 deletions corneto/backend/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,6 @@ def solve(
if len(p.weights) != len(p.objectives):
raise ValueError("Number of weights must match number of objectives")
# auto-convert to a weighted sum
# TODO: PICOS has issues with sum, change to matrix notation?
# type: ignore
o = sum(
p.weights[i] * p.objectives[i] if p.weights[i] != 0.0 else 0.0 # type: ignore
Expand Down Expand Up @@ -671,7 +670,12 @@ def _build_problem(self, other: ProblemDef) -> ProblemDef:
idx_one = np.where(x == 1)[0]
idx_zero = np.where(x == 0)[0]
P = ProblemDef()
P.add_objectives((sum(y[idx_zero] - x[idx_zero]) + sum(x[idx_one] - y[idx_one])) * self.penalty, inplace=True) # type: ignore
diff_zeros = y[idx_zero] - x[idx_zero]
diff_ones = x[idx_one] - y[idx_one]
hamming_dist = np.ones(diff_zeros.shape) @ diff_zeros + np.ones(diff_ones.shape) @ diff_ones
#P.add_objectives((sum(diff_zeros) + sum(diff_ones)) * self.penalty, inplace=True) # type: ignore
P.add_objectives(hamming_dist, weights=self.penalty, inplace=True) # type: ignore

return P


Expand Down
12 changes: 3 additions & 9 deletions corneto/methods/carnival.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,9 +276,9 @@ def carnival_constraints(
# Constrain the product species of the reactions. They can be only up or
# down if at least one of the reactions that have the node as product
# carry some signal.
S_clipped = rn.stoichiometry.clip(0, 1)
p += N_act <= S_clipped @ R_act
p += N_inh <= S_clipped @ R_inh
incidence_matrix = rn.stoichiometry.clip(0, 1)
p += N_act <= incidence_matrix @ R_act
p += N_inh <= incidence_matrix @ R_inh
return p


Expand Down Expand Up @@ -331,18 +331,12 @@ def carnival_loss(
)
# TODO: Issues with sum and PICOS (https://gitlab.com/picos-api/picos/-/issues/330)
# override sum with picos.sum method
#losses.append(sum(Fi)) # type: ignore
#losses.append(Fi.T @ np.ones((1, Fi.shape[0])))
losses.append(np.ones(Fi.shape) @ Fi)
weights.append(l0_penalty_reaction)
if l1_penalty_reaction > 0:
#losses.append(sum(F)) # type: ignore
#losses.append(F.T @ np.ones((1, F.shape[0])))
losses.append(np.ones(F.shape) @ F)
weights.append(l1_penalty_reaction)
if l0_penalty_species > 0:
#losses.append(sum(N_act + N_inh))
#losses.append((N_act + N_inh).T @ np.ones((1, N_act.shape[0])))
losses.append(np.ones(N_act.shape) @ (N_act + N_inh))
weights.append(l0_penalty_species)
# Add objective and weights to p
Expand Down
114 changes: 57 additions & 57 deletions tests/notebooks/tutorial.ipynb

Large diffs are not rendered by default.

0 comments on commit 7819f1b

Please sign in to comment.