Skip to content

Commit

Permalink
pert added to fitness plot
Browse files Browse the repository at this point in the history
  • Loading branch information
gabora committed Jun 17, 2024
1 parent f4732cb commit 0847cb4
Showing 1 changed file with 71 additions and 16 deletions.
87 changes: 71 additions & 16 deletions corneto/methods/signal/cellnopt_ilp.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,7 @@ def plot_solution_network_active_edges(G, P, iexp):
)



def plot_fitness(G, exp_list, P, measured_only=False,**kwargs):
def plot_fitness(G, exp_list, P, measured_only=False, **kwargs):
"""Plot the fitness of the model simulation vs measurements.
PARAMETERS:
Expand All @@ -368,21 +367,52 @@ def plot_fitness(G, exp_list, P, measured_only=False,**kwargs):
import matplotlib.pyplot as plt

N_exps = len(exp_list)

# Ensure that all experiments have the input and output variables
for exp in exp_list.values():
if "input" not in exp:
raise ValueError("Input not found in experiment")
if "output" not in exp:
raise ValueError("Output not found in experiment")

# Collect the input and output variables
input_matrix, input_vars = collect_field_into_matrix(exp_list, "input")
output_matrix, output_vars = collect_field_into_matrix(exp_list, "output")

# Check if inhibition is present in any of the experiments
inhibition_present = any("inhibition" in exp for exp in exp_list.values())
if inhibition_present:
inhibition_matrix, inhibition_vars = collect_field_into_matrix(
exp_list, "inhibition"
)
perturbation_matrix = np.hstack((input_matrix, inhibition_matrix))
perturbation_vars = input_vars + inhibition_vars
else:
perturbation_matrix = input_matrix
perturbation_vars = input_vars

# Create the figure
# Set colors: input colors are blue, inhibition colors are red
perturbation_colors = ["blue"] * len(input_vars)
if inhibition_present:
perturbation_colors += ["red"] * len(inhibition_vars)

N_nodes = len(G.V)
output_names = list(
{key for exp in exp_list.values() for key in exp["output"].keys()}
)

# depending on the flag measured_only, we can plot only the measured nodes or all nodes
if measured_only:
fig, axs = plt.subplots(
N_exps - 1, len(output_names), squeeze=False, **kwargs
)
fig, axs = plt.subplots(N_exps - 1, len(output_names)+1, squeeze=False, **kwargs)
else:
fig, axs = plt.subplots(N_exps - 1, N_nodes, squeeze=False, **kwargs)
fig, axs = plt.subplots(N_exps - 1, N_nodes+1, squeeze=False, **kwargs)

fig.tight_layout(pad=0.0)

# Adjust the space between subplots
plt.subplots_adjust(wspace=0.1, hspace=0.1)

for exp, iexp in zip(exp_list, range(N_exps)):
if iexp == 0:
continue
Expand All @@ -399,7 +429,7 @@ def plot_fitness(G, exp_list, P, measured_only=False,**kwargs):
min(P.expr.vertex_value.value[imarker_inG, iexp], 1),
],
"bo-",
label=G.V[imarker_inG]
label=G.V[imarker_inG],
)

if G.V[imarker_inG] in exp_list[exp]["output"].keys():
Expand All @@ -413,9 +443,13 @@ def plot_fitness(G, exp_list, P, measured_only=False,**kwargs):
)
axs[iexp - 1, imarker].set_ylim([-0.05, 1.1])
if iexp == 1:
axs[iexp - 1, imarker].set_title(G.V[imarker_inG])
axs[iexp - 1, imarker].set_title(output_names[imarker])
if iexp != N_exps - 1:
axs[iexp - 1, imarker].set_xticks([])
if imarker == 0:
axs[iexp - 1, imarker].set_ylabel(f"Experiment {iexp}")
axs[iexp - 1, imarker].set_ylabel(f"Exp. {iexp}")
else:
axs[iexp - 1, imarker].set_yticks([])
else:
for imarker in range(N_nodes):
axs[iexp - 1, imarker].plot(
Expand All @@ -426,8 +460,8 @@ def plot_fitness(G, exp_list, P, measured_only=False,**kwargs):
],
"bo-",
label=G.V[imarker],
#color="blue",
#linestyle="o-",
# color="blue",
# linestyle="o-",
)

if G.V[imarker] in exp_list[exp]["output"].keys():
Expand All @@ -442,12 +476,36 @@ def plot_fitness(G, exp_list, P, measured_only=False,**kwargs):
axs[iexp - 1, imarker].set_ylim([-0.05, 1.1])
if iexp == 1:
axs[iexp - 1, imarker].set_title(G.V[imarker])
if iexp != N_exps - 1:
axs[iexp - 1, imarker].set_xticks([])
if imarker == 0:
axs[iexp - 1, imarker].set_ylabel(f"Experiment {iexp}")
axs[iexp - 1, imarker].set_ylabel(f"Exp. {iexp}")
else:
axs[iexp - 1, imarker].set_yticks([])
# Plot perturbation
axs[iexp - 1, len(output_vars)].bar(
range(len(perturbation_vars)),
perturbation_matrix[iexp],
color=perturbation_colors,
)
if iexp == N_exps - 1:
axs[iexp - 1, len(output_vars)].set_xticks(range(len(perturbation_vars)))
axs[iexp - 1, len(output_vars)].set_xticklabels(
perturbation_vars, rotation=45
)
else:
# No xtick label
axs[iexp - 1, len(output_vars)].set_xticks([])

axs[iexp - 1, len(output_vars)].set_ylim([-0.01, 1.1])
axs[iexp - 1, len(output_vars)].set_yticks([])
if iexp == 1:
axs[iexp - 1, len(output_vars)].set_title("Pert.")

plt.show()



def collect_field_into_matrix(experiments, field_name="input"):
"""Collects the field_name values into matrix.
Expand Down Expand Up @@ -541,10 +599,7 @@ def plot_data(exp_list):

axs[iexp - 1, imarker].plot(
[0, 10],
[
output_matrix[0, imarker],
output_matrix[iexp, imarker]
],
[output_matrix[0, imarker], output_matrix[iexp, imarker]],
"ro-",
)
axs[iexp - 1, imarker].set_ylim([-0.01, 1.1])
Expand Down

0 comments on commit 0847cb4

Please sign in to comment.