From c76a5ff24377b4b7c3448a3e4fe8bb8ae7166e5a Mon Sep 17 00:00:00 2001 From: NicolaCourtier <45851982+NicolaCourtier@users.noreply.github.com> Date: Fri, 10 Jan 2025 13:07:48 +0000 Subject: [PATCH] Update problem plot for the MultiFittingProblem (#602) * Add domain to MultiFittingProblem * Update for early termination * Update CHANGELOG.md * Use markers only for MultiFittingProblems * Update CHANGELOG.md --------- Co-authored-by: Brady Planden <55357039+BradyPlanden@users.noreply.github.com> --- CHANGELOG.md | 2 + pybop/plot/problem.py | 56 ++++++++++++------------- pybop/problems/multi_fitting_problem.py | 35 +++++++++++++--- 3 files changed, 59 insertions(+), 34 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 564deceb..b47d1675 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,8 @@ ## Bug Fixes +- [#602](https://github.com/pybop-team/PyBOP/pull/602) - Aligns the standard quick plot of `MultiFittingProblem` outputs. + ## Breaking Changes # [v24.12](https://github.com/pybop-team/PyBOP/tree/v24.12) - 2024-12-21 diff --git a/pybop/plot/problem.py b/pybop/plot/problem.py index fc869822..a14b82f0 100644 --- a/pybop/plot/problem.py +++ b/pybop/plot/problem.py @@ -1,7 +1,7 @@ import jax.numpy as jnp import numpy as np -from pybop import DesignProblem, FittingProblem +from pybop import DesignProblem, FittingProblem, MultiFittingProblem from pybop.parameters.parameter import Inputs from pybop.plot.standard_plots import StandardPlot @@ -37,7 +37,8 @@ def quick(problem, problem_inputs: Inputs = None, show=True, **layout_kwargs): problem_inputs = problem.parameters.verify(problem_inputs) # Extract the time data and evaluate the model's output and target values - xaxis_data = problem.domain_data + domain = problem.domain + domain_data = problem.domain_data model_output = problem.evaluate(problem_inputs) target_output = problem.get_target() @@ -49,47 +50,46 @@ def quick(problem, problem_inputs: Inputs = None, show=True, **layout_kwargs): # Create a plot for each output figure_list = [] - for i in problem.signal: - default_layout_options = dict( - title="Scatter Plot", - xaxis_title="Time / s", - yaxis_title=StandardPlot.remove_brackets(i), - ) - + for signal in problem.signal: # Create a plot dictionary - if isinstance(problem, DesignProblem): - trace_name = "Optimised" - opt_domain_data = model_output["Time [s]"] - else: - trace_name = "Model" - opt_domain_data = xaxis_data - plot_dict = StandardPlot( - x=opt_domain_data, - y=model_output[i], - layout_options=default_layout_options, - trace_names=trace_name, + layout_options=dict( + title="Scatter Plot", + xaxis_title="Time / s", + yaxis_title=StandardPlot.remove_brackets(signal), + ) + ) + + model_trace = plot_dict.create_trace( + x=model_output[domain] + if domain in model_output.keys() + else domain_data[: len(model_output[signal])], + y=model_output[signal], + name="Optimised" if isinstance(problem, DesignProblem) else "Model", + mode="markers" if isinstance(problem, MultiFittingProblem) else "lines", + showlegend=True, ) + plot_dict.traces.append(model_trace) target_trace = plot_dict.create_trace( - x=xaxis_data, - y=target_output[i], + x=domain_data, + y=target_output[signal], name="Reference", mode="markers", showlegend=True, ) plot_dict.traces.append(target_trace) - if isinstance(problem, FittingProblem) and len(model_output[i]) == len( - target_output[i] + if isinstance(problem, FittingProblem) and len(model_output[signal]) == len( + target_output[signal] ): # Compute the standard deviation as proxy for uncertainty - plot_dict.sigma = np.std(model_output[i] - target_output[i]) + plot_dict.sigma = np.std(model_output[signal] - target_output[signal]) # Convert x and upper and lower limits into lists to create a filled trace - x = xaxis_data.tolist() - y_upper = (model_output[i] + plot_dict.sigma).tolist() - y_lower = (model_output[i] - plot_dict.sigma).tolist() + x = domain_data.tolist() + y_upper = (model_output[signal] + plot_dict.sigma).tolist() + y_lower = (model_output[signal] - plot_dict.sigma).tolist() fill_trace = plot_dict.create_trace( x=x + x[::-1], diff --git a/pybop/problems/multi_fitting_problem.py b/pybop/problems/multi_fitting_problem.py index 72dd70c3..e2dd38c2 100644 --- a/pybop/problems/multi_fitting_problem.py +++ b/pybop/problems/multi_fitting_problem.py @@ -37,9 +37,11 @@ def __init__(self, *args): combined_parameters.join(problem.parameters) # Combine the target datasets + domain = self.problems[0].domain combined_domain_data = [] combined_signal = [] for problem in self.problems: + domain = problem.domain if problem.domain == domain else "Mixed domain" for signal in problem.signal: combined_domain_data.extend(problem.domain_data) combined_signal.extend(problem.target[signal]) @@ -50,6 +52,7 @@ def __init__(self, *args): signal=["Combined signal"], ) + self.domain = domain combined_dataset = Dataset( { self.domain: np.asarray(combined_domain_data), @@ -93,17 +96,27 @@ def evaluate(self, inputs: Inputs): inputs = self.parameters.verify(inputs) self.parameters.update(values=list(inputs.values())) + combined_domain = [] combined_signal = [] for problem in self.problems: problem_inputs = problem.parameters.as_dict() - signal_values = problem.evaluate(problem_inputs) + problem_output = problem.evaluate(problem_inputs) + domain_data = ( + problem_output[problem.domain] + if problem.domain in problem_output.keys() + else problem.domain_data[: len(problem_output[problem.signal[0]])] + ) # Collect signals for signal in problem.signal: - combined_signal.extend(signal_values[signal]) + combined_domain.extend(domain_data) + combined_signal.extend(problem_output[signal]) - return {"Combined signal": np.asarray(combined_signal)} + return { + self.domain: np.asarray(combined_domain), + "Combined signal": np.asarray(combined_signal), + } def evaluateS1(self, inputs: Inputs): """ @@ -123,19 +136,29 @@ def evaluateS1(self, inputs: Inputs): inputs = self.parameters.verify(inputs) self.parameters.update(values=list(inputs.values())) + combined_domain = [] combined_signal = [] all_derivatives = [] for problem in self.problems: problem_inputs = problem.parameters.as_dict() - signal_values, dyi = problem.evaluateS1(problem_inputs) + problem_output, dyi = problem.evaluateS1(problem_inputs) + domain_data = ( + problem_output[problem.domain] + if problem.domain in problem_output.keys() + else problem.domain_data[: len(problem_output[problem.signal[0]])] + ) # Collect signals and derivatives for signal in problem.signal: - combined_signal.extend(signal_values[signal]) + combined_domain.extend(domain_data) + combined_signal.extend(problem_output[signal]) all_derivatives.append(dyi) - y = {"Combined signal": np.asarray(combined_signal)} + y = { + self.domain: np.asarray(combined_domain), + "Combined signal": np.asarray(combined_signal), + } dy = np.concatenate(all_derivatives) if all_derivatives else None return (y, dy)