| # import numpy as np | |
| # import plotly.graph_objects as go | |
| # from scipy.interpolate import griddata | |
| # def gen_three_D_plot(detectability_val, distortion_val, euclidean_val): | |
| # detectability = np.array(detectability_val) | |
| # distortion = np.array(distortion_val) | |
| # euclidean = np.array(euclidean_val) | |
| # # Find the closest point to the origin | |
| # distances_to_origin = np.linalg.norm(np.array([distortion, detectability, euclidean]).T, axis=1) | |
| # closest_point_index = np.argmin(distances_to_origin) | |
| # # Determine the closest points to each axis | |
| # closest_to_x_axis = np.argmin(distortion) | |
| # closest_to_y_axis = np.argmin(detectability) | |
| # closest_to_z_axis = np.argmin(euclidean) | |
| # # Use the detected closest point as the "sweet spot" | |
| # sweet_spot_detectability = detectability[closest_point_index] | |
| # sweet_spot_distortion = distortion[closest_point_index] | |
| # sweet_spot_euclidean = euclidean[closest_point_index] | |
| # # Create a meshgrid from the data | |
| # x_grid, y_grid = np.meshgrid(np.linspace(min(detectability), max(detectability), 30), | |
| # np.linspace(min(distortion), max(distortion), 30)) | |
| # # Interpolate z values (Euclidean distances) to fit the grid | |
| # z_grid = griddata((detectability, distortion), euclidean, (x_grid, y_grid), method='linear') | |
| # if z_grid is None: | |
| # raise ValueError("griddata could not generate a valid interpolation. Check your input data.") | |
| # # Create the 3D contour plot with the Plasma color scale | |
| # fig = go.Figure(data=go.Surface( | |
| # z=z_grid, | |
| # x=x_grid, | |
| # y=y_grid, | |
| # contours={ | |
| # "z": {"show": True, "start": min(euclidean), "end": max(euclidean), "size": 0.1, "usecolormap": True} | |
| # }, | |
| # colorscale='Plasma' | |
| # )) | |
| # # Add a marker for the sweet spot | |
| # fig.add_trace(go.Scatter3d( | |
| # x=[sweet_spot_detectability], | |
| # y=[sweet_spot_distortion], | |
| # z=[sweet_spot_euclidean], | |
| # mode='markers+text', | |
| # marker=dict(size=10, color='red', symbol='circle'), | |
| # text=["Sweet Spot"], | |
| # textposition="top center" | |
| # )) | |
| # # Set axis labels | |
| # fig.update_layout( | |
| # scene=dict( | |
| # xaxis_title='Detectability Score', | |
| # yaxis_title='Distortion Score', | |
| # zaxis_title='Euclidean Distance' | |
| # ), | |
| # margin=dict(l=0, r=0, b=0, t=0) | |
| # ) | |
| # return fig | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| from scipy.interpolate import griddata | |
def _normalize(values):
    """Min-max scale a 1-D array to [0, 1]; a constant array maps to all zeros."""
    low = values.min()
    span = values.max() - low
    if span == 0:
        # Avoid 0/0 -> NaN, which would poison the composite score and make
        # np.argmax silently return index 0.
        return np.zeros_like(values, dtype=float)
    return (values - low) / span


def _find_sweet_spot(detectability, distortion, euclidean):
    """Return the index that maximizes detectability while minimizing the rest.

    The score is norm(detectability) - (norm(distortion) + norm(euclidean)),
    with each metric min-max normalized so they are weighted equally.
    """
    composite_score = _normalize(detectability) - (
        _normalize(distortion) + _normalize(euclidean)
    )
    return int(np.argmax(composite_score))


def gen_three_D_plot(detectability_val, distortion_val, euclidean_val, *,
                     grid_size=30, contour_size=0.1):
    """Build a 3-D surface of Euclidean distance over detectability x distortion.

    The scattered (detectability, distortion) -> euclidean samples are linearly
    interpolated onto a regular ``grid_size`` x ``grid_size`` grid and drawn as a
    Plasma-coloured surface with z-contours. A red marker labels the
    "sweet spot": the sample with the highest normalized detectability minus
    normalized distortion and normalized Euclidean distance.

    Args:
        detectability_val: Sequence of detectability scores (x axis).
        distortion_val: Sequence of distortion scores (y axis), same length.
        euclidean_val: Sequence of Euclidean distances (z axis), same length.
        grid_size: Number of grid points along each horizontal axis.
        contour_size: Step between z-contour lines, in Euclidean-distance units.

    Returns:
        A ``plotly.graph_objects.Figure`` with the surface and sweet-spot marker.

    Raises:
        ValueError: If the inputs are empty, not 1-D, of unequal length, or if
            interpolation yields no finite grid values. scipy may additionally
            raise if there are fewer than three non-collinear
            (detectability, distortion) points.
    """
    detectability = np.asarray(detectability_val, dtype=float)
    distortion = np.asarray(distortion_val, dtype=float)
    euclidean = np.asarray(euclidean_val, dtype=float)

    if detectability.ndim != 1 or not (
        detectability.shape == distortion.shape == euclidean.shape
    ):
        raise ValueError(
            "detectability, distortion and euclidean must be 1-D sequences of equal length"
        )
    if detectability.size == 0:
        raise ValueError("at least one data point is required")

    sweet_spot_index = _find_sweet_spot(detectability, distortion, euclidean)
    sweet_spot_detectability = detectability[sweet_spot_index]
    sweet_spot_distortion = distortion[sweet_spot_index]
    sweet_spot_euclidean = euclidean[sweet_spot_index]

    # Regular grid spanning the bounding box of the sampled points.
    x_grid, y_grid = np.meshgrid(
        np.linspace(detectability.min(), detectability.max(), grid_size),
        np.linspace(distortion.min(), distortion.max(), grid_size),
    )
    z_grid = griddata((detectability, distortion), euclidean, (x_grid, y_grid), method='linear')
    # griddata never returns None: grid points outside the convex hull of the
    # input are filled with NaN, so an all-NaN grid means nothing was interpolated.
    if np.isnan(z_grid).all():
        raise ValueError("griddata could not generate a valid interpolation. Check your input data.")

    # 3-D surface with z-contours, coloured with the Plasma scale.
    fig = go.Figure(data=go.Surface(
        z=z_grid,
        x=x_grid,
        y=y_grid,
        contours={
            "z": {
                "show": True,
                "start": float(euclidean.min()),
                "end": float(euclidean.max()),
                "size": contour_size,
                "usecolormap": True,
            }
        },
        colorscale='Plasma'
    ))

    # Marker for the sweet spot.
    fig.add_trace(go.Scatter3d(
        x=[sweet_spot_detectability],
        y=[sweet_spot_distortion],
        z=[sweet_spot_euclidean],
        mode='markers+text',
        marker=dict(size=10, color='red', symbol='circle'),
        text=["Sweet Spot"],
        textposition="top center"
    ))

    fig.update_layout(
        scene=dict(
            xaxis_title='Detectability Score',
            yaxis_title='Distortion Score',
            zaxis_title='Euclidean Distance'
        ),
        margin=dict(l=0, r=0, b=0, t=0)
    )
    return fig