Skip to content

Commit

Permalink
Update problem plot for the MultiFittingProblem (#602)
Browse files Browse the repository at this point in the history
* Add domain to MultiFittingProblem

* Update for early termination

* Update CHANGELOG.md

* Use markers only for MultiFittingProblems

* Update CHANGELOG.md

---------

Co-authored-by: Brady Planden <[email protected]>
  • Loading branch information
NicolaCourtier and BradyPlanden authored Jan 10, 2025
1 parent bfd4c7f commit c76a5ff
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 34 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

## Bug Fixes

- [#602](https://github.com/pybop-team/PyBOP/pull/602) - Aligns the standard quick plot of `MultiFittingProblem` outputs.

## Breaking Changes

# [v24.12](https://github.com/pybop-team/PyBOP/tree/v24.12) - 2024-12-21
Expand Down
56 changes: 28 additions & 28 deletions pybop/plot/problem.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import jax.numpy as jnp
import numpy as np

from pybop import DesignProblem, FittingProblem
from pybop import DesignProblem, FittingProblem, MultiFittingProblem
from pybop.parameters.parameter import Inputs
from pybop.plot.standard_plots import StandardPlot

Expand Down Expand Up @@ -37,7 +37,8 @@ def quick(problem, problem_inputs: Inputs = None, show=True, **layout_kwargs):
problem_inputs = problem.parameters.verify(problem_inputs)

# Extract the time data and evaluate the model's output and target values
xaxis_data = problem.domain_data
domain = problem.domain
domain_data = problem.domain_data
model_output = problem.evaluate(problem_inputs)
target_output = problem.get_target()

Expand All @@ -49,47 +50,46 @@ def quick(problem, problem_inputs: Inputs = None, show=True, **layout_kwargs):

# Create a plot for each output
figure_list = []
for i in problem.signal:
default_layout_options = dict(
title="Scatter Plot",
xaxis_title="Time / s",
yaxis_title=StandardPlot.remove_brackets(i),
)

for signal in problem.signal:
# Create a plot dictionary
if isinstance(problem, DesignProblem):
trace_name = "Optimised"
opt_domain_data = model_output["Time [s]"]
else:
trace_name = "Model"
opt_domain_data = xaxis_data

plot_dict = StandardPlot(
x=opt_domain_data,
y=model_output[i],
layout_options=default_layout_options,
trace_names=trace_name,
layout_options=dict(
title="Scatter Plot",
xaxis_title="Time / s",
yaxis_title=StandardPlot.remove_brackets(signal),
)
)

model_trace = plot_dict.create_trace(
x=model_output[domain]
if domain in model_output.keys()
else domain_data[: len(model_output[signal])],
y=model_output[signal],
name="Optimised" if isinstance(problem, DesignProblem) else "Model",
mode="markers" if isinstance(problem, MultiFittingProblem) else "lines",
showlegend=True,
)
plot_dict.traces.append(model_trace)

target_trace = plot_dict.create_trace(
x=xaxis_data,
y=target_output[i],
x=domain_data,
y=target_output[signal],
name="Reference",
mode="markers",
showlegend=True,
)
plot_dict.traces.append(target_trace)

if isinstance(problem, FittingProblem) and len(model_output[i]) == len(
target_output[i]
if isinstance(problem, FittingProblem) and len(model_output[signal]) == len(
target_output[signal]
):
# Compute the standard deviation as proxy for uncertainty
plot_dict.sigma = np.std(model_output[i] - target_output[i])
plot_dict.sigma = np.std(model_output[signal] - target_output[signal])

# Convert x and upper and lower limits into lists to create a filled trace
x = xaxis_data.tolist()
y_upper = (model_output[i] + plot_dict.sigma).tolist()
y_lower = (model_output[i] - plot_dict.sigma).tolist()
x = domain_data.tolist()
y_upper = (model_output[signal] + plot_dict.sigma).tolist()
y_lower = (model_output[signal] - plot_dict.sigma).tolist()

fill_trace = plot_dict.create_trace(
x=x + x[::-1],
Expand Down
35 changes: 29 additions & 6 deletions pybop/problems/multi_fitting_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,11 @@ def __init__(self, *args):
combined_parameters.join(problem.parameters)

# Combine the target datasets
domain = self.problems[0].domain
combined_domain_data = []
combined_signal = []
for problem in self.problems:
domain = problem.domain if problem.domain == domain else "Mixed domain"
for signal in problem.signal:
combined_domain_data.extend(problem.domain_data)
combined_signal.extend(problem.target[signal])
Expand All @@ -50,6 +52,7 @@ def __init__(self, *args):
signal=["Combined signal"],
)

self.domain = domain
combined_dataset = Dataset(
{
self.domain: np.asarray(combined_domain_data),
Expand Down Expand Up @@ -93,17 +96,27 @@ def evaluate(self, inputs: Inputs):
inputs = self.parameters.verify(inputs)
self.parameters.update(values=list(inputs.values()))

combined_domain = []
combined_signal = []

for problem in self.problems:
problem_inputs = problem.parameters.as_dict()
signal_values = problem.evaluate(problem_inputs)
problem_output = problem.evaluate(problem_inputs)
domain_data = (
problem_output[problem.domain]
if problem.domain in problem_output.keys()
else problem.domain_data[: len(problem_output[problem.signal[0]])]
)

# Collect signals
for signal in problem.signal:
combined_signal.extend(signal_values[signal])
combined_domain.extend(domain_data)
combined_signal.extend(problem_output[signal])

return {"Combined signal": np.asarray(combined_signal)}
return {
self.domain: np.asarray(combined_domain),
"Combined signal": np.asarray(combined_signal),
}

def evaluateS1(self, inputs: Inputs):
"""
Expand All @@ -123,19 +136,29 @@ def evaluateS1(self, inputs: Inputs):
inputs = self.parameters.verify(inputs)
self.parameters.update(values=list(inputs.values()))

combined_domain = []
combined_signal = []
all_derivatives = []

for problem in self.problems:
problem_inputs = problem.parameters.as_dict()
signal_values, dyi = problem.evaluateS1(problem_inputs)
problem_output, dyi = problem.evaluateS1(problem_inputs)
domain_data = (
problem_output[problem.domain]
if problem.domain in problem_output.keys()
else problem.domain_data[: len(problem_output[problem.signal[0]])]
)

# Collect signals and derivatives
for signal in problem.signal:
combined_signal.extend(signal_values[signal])
combined_domain.extend(domain_data)
combined_signal.extend(problem_output[signal])
all_derivatives.append(dyi)

y = {"Combined signal": np.asarray(combined_signal)}
y = {
self.domain: np.asarray(combined_domain),
"Combined signal": np.asarray(combined_signal),
}
dy = np.concatenate(all_derivatives) if all_derivatives else None

return (y, dy)

0 comments on commit c76a5ff

Please sign in to comment.