diff --git a/src/rpad/visualize_3d/plots.py b/src/rpad/visualize_3d/plots.py index fac7e75..466df96 100644 --- a/src/rpad/visualize_3d/plots.py +++ b/src/rpad/visualize_3d/plots.py @@ -8,7 +8,9 @@ from rpad.visualize_3d.primitives import pointcloud -def _3d_scene(data: npt.ArrayLike) -> Dict: +def _3d_scene( + data: npt.ArrayLike, domain_scale: float = 1.0, nogrid: bool = False +) -> Dict: """Create a plotly 3D scene dictionary that gives you a big cube, so aspect ratio is preserved""" # Create a 3D scene which is a cube w/ equal aspect ratio and fits all the data. data = np.array(data) @@ -19,13 +21,45 @@ def _3d_scene(data: npt.ArrayLike) -> Dict: max_x = np.abs(data[:, 0] - mean[0]).max() max_y = np.abs(data[:, 1] - mean[1]).max() max_z = np.abs(data[:, 2] - mean[2]).max() - all_max = max(max(max_x, max_y), max_z) - scene = dict( - xaxis=dict(nticks=10, range=[mean[0] - all_max, mean[0] + all_max]), - yaxis=dict(nticks=10, range=[mean[1] - all_max, mean[1] + all_max]), - zaxis=dict(nticks=10, range=[mean[2] - all_max, mean[2] + all_max]), - aspectratio=dict(x=1, y=1, z=1), - ) + all_max = max(max(max_x, max_y), max_z) * domain_scale + if nogrid: + scene = dict( + xaxis=dict( + nticks=10, + range=[mean[0] - all_max, mean[0] + all_max], + showgrid=False, + zeroline=False, + showline=False, + showticklabels=False, + visible=False, + ), + yaxis=dict( + nticks=10, + range=[mean[1] - all_max, mean[1] + all_max], + showgrid=False, + zeroline=False, + showline=False, + showticklabels=False, + visible=False, + ), + zaxis=dict( + nticks=10, + range=[mean[2] - all_max, mean[2] + all_max], + showgrid=False, + zeroline=False, + showline=False, + showticklabels=False, + visible=False, + ), + aspectratio=dict(x=1, y=1, z=1), + ) + else: + scene = dict( + xaxis=dict(nticks=10, range=[mean[0] - all_max, mean[0] + all_max]), + yaxis=dict(nticks=10, range=[mean[1] - all_max, mean[1] + all_max]), + zaxis=dict(nticks=10, range=[mean[2] - all_max, mean[2] + all_max]), + aspectratio=dict(x=1, y=1, z=1), + ) return scene @@ -122,7 +156,7 @@ def _flow_traces( mode="markers", marker={"size": 3, "color": "darkred"}, scene=scene, - showlegend=False, + showlegend=True, legendgroup=legendgroup, )