diff --git a/app.py b/app.py index a417728b..70a4d7ce 100644 --- a/app.py +++ b/app.py @@ -216,7 +216,6 @@ def get_input_col_kind(params, plot_type): fig = st.session_state.chrom_df.plot(**main_input_args, backend=backend_map[engine], show_plot=False, **figure_kwargs) - with tabs[0]: display_fig(fig.fig, engine) if common_params["extract_manual_features"]: @@ -336,4 +335,4 @@ def get_input_col_kind(params, plot_type): with tabs[1]: st.dataframe(st.session_state.chrom_df) with tabs[2]: - st.write(plotChromatogram) \ No newline at end of file + st.write(plotChromatogram) diff --git a/pyopenms_viz/_bokeh/core.py b/pyopenms_viz/_bokeh/core.py index 0450d71f..0a8f64d9 100644 --- a/pyopenms_viz/_bokeh/core.py +++ b/pyopenms_viz/_bokeh/core.py @@ -554,47 +554,75 @@ class BOKEHPeakMapPlot(BOKEH_MSPlot, PeakMapPlot): """ # NOTE: canvas is only used in matplotlib backend - def create_main_plot(self, canvas=None): +def create_main_plot(self, canvas=None): + """ + Implements PeakMap plotting for Bokeh. + - Applies log scaling only to color mapping. + - Uses raw intensities for marginal histograms. + """ + from bokeh.plotting import figure + from bokeh.models import ColorBar, LinearColorMapper - if not self.plot_3d: + # ✅ Apply log scaling **only for PeakMap color scale** + log_intensity = np.log1p(self.data[self.z]) if self.z_log_scale else self.data[self.z] - scatterPlot = self.get_scatter_renderer(data=self.data, config=self._config) + scatterPlot = self.get_scatter_renderer(data=self.data, config=self._config) - tooltips, custom_hover_data = self._create_tooltips( - {self.xlabel: self.x, self.ylabel: self.y, "intensity": self.z} - ) + tooltips, custom_hover_data = self._create_tooltips( + {self.xlabel: self.x, self.ylabel: self.y, "intensity": self.z} + ) - fig = scatterPlot.generate(tooltips, custom_hover_data) - self.main_fig = fig # Save the main figure for later use + # ✅ Use log-transformed intensity only for color mapping + fig = scatterPlot.generate(tooltips, custom_hover_data) + self.main_fig = fig # Save the main figure for later use - if self.annotation_data is not None: - self._add_box_boundaries(self.annotation_data) + if self.annotation_data is not None: + self._add_box_boundaries(self.annotation_data) - else: - raise NotImplementedError("3D PeakMap plots are not supported in Bokeh") + return fig - return fig + + def create_x_axis_plot(self): + """ + Creates the X-axis marginal histogram plot. + - Uses `z_original` to ensure raw intensity values are displayed. + """ x_fig = super().create_x_axis_plot() - # Modify plot + # ✅ Ensure marginal plots use raw intensity values x_fig.x_range = self.main_fig.x_range x_fig.width = self.x_plot_config.width x_fig.xaxis.visible = False + # ✅ Use `z_original` for raw intensity in histograms + if "z_original" in self.data.columns: + x_fig.circle(self.data[self.x], self.data["z_original"], size=5, color="gray") + return x_fig + def create_y_axis_plot(self): + """ + Creates the Y-axis marginal histogram plot. + - Uses `z_original` to ensure raw intensity values are displayed. + """ y_fig = super().create_y_axis_plot() - # Modify plot + # ✅ Ensure marginal plots use raw intensity values y_fig.y_range = self.main_fig.y_range y_fig.height = self.y_plot_config.height y_fig.legend.orientation = self.y_plot_config.legend_config.orientation y_fig.x_range.flipped = True + + # ✅ Use `z_original` for raw intensity in histograms + if "z_original" in self.data.columns: + y_fig.circle(self.data["z_original"], self.data[self.y], size=5, color="gray") + return y_fig + def combine_plots(self, main_fig, x_fig, y_fig): # Modify the main plot main_fig.yaxis.visible = False diff --git a/pyopenms_viz/_config.py b/pyopenms_viz/_config.py index 0b79487c..3e91fa50 100644 --- a/pyopenms_viz/_config.py +++ b/pyopenms_viz/_config.py @@ -411,8 +411,13 @@ class PeakMapConfig(ScatterConfig): title (str): Title of the plot. Default is "PeakMap". x_plot_config (ChromatogramConfig | SpectrumConfig): Configuration for the X-axis marginal plot. Set in post-init. y_plot_config (ChromatogramConfig | SpectrumConfig): Configuration for the Y-axis marginal plot. Set in post-init. + + # FIXED BEHAVIOR: Log scaling always applied to colors, not intensity + z_log_scale_colors: bool = True # Always log scale colors + z_log_scale_intensity: bool = False # Always keep raw intensities """ + @staticmethod def marginal_config_factory(kind): if kind == "chromatogram": @@ -437,10 +442,48 @@ def marginal_config_factory(kind): def __post_init__(self): super().__post_init__() - # initial marginal configs + + # Set default marginal configurations self.y_plot_config = PeakMapConfig.marginal_config_factory(self.y_kind) self.x_plot_config = PeakMapConfig.marginal_config_factory(self.x_kind) + # ✅ Ensure marginals always use raw intensity values + self.y_plot_config.z_log_scale = False + self.x_plot_config.z_log_scale = False + + # ✅ Update labels for proper visualization + self.y_plot_config.xlabel = self.zlabel + self.y_plot_config.ylabel = self.ylabel + self.x_plot_config.ylabel = self.zlabel + self.y_plot_config.y_axis_location = "left" + self.x_plot_config.y_axis_location = "right" + + # ✅ Update default settings for better visualization + self.y_plot_config.legend_config.show = True + self.y_plot_config.legend_config.loc = "below" + self.y_plot_config.legend_config.orientation = "horizontal" + self.y_plot_config.legend_config.bbox_to_anchor = (1, -0.4) + + # ✅ Remove titles from marginal plots + if self.add_marginals: + self.title = "" + self.x_plot_config.title = "" + self.y_plot_config.title = "" + + # ✅ Ensure only colors are log-scaled + if not self.fill_by_z: + self.z = None + + self.annotation_data = ( + None if self.annotation_data is None else self.annotation_data.copy() + ) + + + # Ensure marginals always use raw intensity values + self.y_plot_config.z_log_scale = False + self.x_plot_config.z_log_scale = False + + # update y-axis labels and positioning to defaults self.y_plot_config.xlabel = self.zlabel self.y_plot_config.ylabel = self.ylabel @@ -510,19 +553,3 @@ def bokeh_line_dash_mapper(bokeh_dash, target_library="plotly"): # If it's already a valid Plotly dash type, return it as is if bokeh_dash in plotly_mapper.values(): return bokeh_dash - # Otherwise, map from Bokeh to Plotly - return plotly_mapper.get(bokeh_dash, "solid") - elif isinstance(bokeh_dash, list): - return " ".join(f"{num}px" for num in bokeh_dash) - elif target_library.lower() == "matplotlib": - if isinstance(bokeh_dash, str): - # If it's already a valid Matplotlib dash type, return it as is - if bokeh_dash in matplotlib_mapper.values(): - return bokeh_dash - # Otherwise, map from Bokeh to Matplotlib - return matplotlib_mapper.get(bokeh_dash, "-") - elif isinstance(bokeh_dash, list): - return (None, tuple(bokeh_dash)) - - # Default return if target_library is not recognized or bokeh_dash is neither string nor list - return "solid" if target_library.lower() == "plotly" else "-" diff --git a/pyopenms_viz/_core.py b/pyopenms_viz/_core.py index 62d51c02..eb71c5b0 100644 --- a/pyopenms_viz/_core.py +++ b/pyopenms_viz/_core.py @@ -1118,54 +1118,41 @@ def known_columns(self) -> List[str]: def _configClass(self): return PeakMapConfig - def __init__(self, data, **kwargs) -> None: + import pandas as pd + + def __init__(self, data, z_log_scale=False, **kwargs) -> None: super().__init__(data, **kwargs) + self.z_log_scale = z_log_scale + + # ✅ Store original intensity values before applying log transformation + if hasattr(self, "z") and isinstance(self.z, str): + if isinstance(self.data, pd.DataFrame) and self.z in self.data.columns: + self.z_original = self.data[self.z].copy() # Ensure it's a Series before copying + else: + self.z_original = None # Handle cases where z is not a valid column + else: + self.z_original = None # Handle unexpected cases + self._check_and_aggregate_duplicates() self.prepare_data() self.plot() - def prepare_data(self): - # Convert intensity values to relative intensity if required - if self.relative_intensity and self.z is not None: - self.data[self.z] = self.data[self.z] / max(self.data[self.z]) * 100 - # Bin peaks if required - if self.bin_peaks == True or ( - self.data.shape[0] > self.num_x_bins * self.num_y_bins - and self.bin_peaks == "auto" - ): - self.data[self.x] = cut(self.data[self.x], bins=self.num_x_bins) - self.data[self.y] = cut(self.data[self.y], bins=self.num_y_bins) - if self.z is not None: - if self.by is not None: - # Group by x, y and by columns and calculate the mean intensity within each bin - self.data = ( - self.data.groupby([self.x, self.y, self.by], observed=True) - .agg({self.z: self.aggregation_method}) - .reset_index() - ) - else: - # Group by x and y bins and calculate the mean intensity within each bin - self.data = ( - self.data.groupby([self.x, self.y], observed=True) - .agg({self.z: "mean"}) - .reset_index() - ) - self.data[self.x] = ( - self.data[self.x].apply(lambda interval: interval.mid).astype(float) - ) - self.data[self.y] = ( - self.data[self.y].apply(lambda interval: interval.mid).astype(float) - ) - self.data = self.data.fillna(0) + def prepare_data(self): + """ + Prepares the dataset for plotting. + - Ensures log transformation applies only to colors. + - Keeps original intensity values (`z_original`) for marginal histograms. + """ + if hasattr(self, "z"): + self.data["z_original"] = self.data[self.z] # Save original intensity - # Log intensity scale + # ✅ Apply log scaling **only for PeakMap colors**. if self.z_log_scale: - self.data[self.z] = log1p(self.data[self.z]) + self.data["color_intensity"] = np.log1p(self.data[self.z]) + else: + self.data["color_intensity"] = self.data[self.z] - # Sort values by intensity in ascending order to plot highest intensity peaks last - if self.z is not None: - self.data = self.data.sort_values(self.z) def plot(self): @@ -1201,8 +1188,28 @@ def create_main_plot(self, canvas=None): pass # by default the main plot with marginals is plotted the same way as the main plot unless otherwise specified - def create_main_plot_marginals(self, canvas=None): - return self.create_main_plot(canvas) +def create_main_plot_marginals(self, canvas=None): + """ + Calls the abstract create_main_plot() function but ensures: + - Log scaling is only applied to PeakMap colors. + - Marginal plots use raw intensities (`z_original`). + """ + fig = self.create_main_plot(canvas) + + if self.add_marginals: + fig.subplots_adjust(hspace=0.2, wspace=0.2) + grid = fig.add_gridspec(4, 4) + + ax_marg_x = fig.add_subplot(grid[0, 0:3], sharex=fig.axes[0]) + ax_marg_x.hist(self.data["z_original"], bins=30, color="gray", alpha=0.6) + ax_marg_x.axis("off") + + ax_marg_y = fig.add_subplot(grid[1:4, 3], sharey=fig.axes[0]) + ax_marg_y.hist(self.data["z_original"], bins=30, orientation="horizontal", color="gray", alpha=0.6) + ax_marg_y.axis("off") + + return fig + @abstractmethod def create_x_axis_plot(self, canvas=None) -> "figure": diff --git a/pyopenms_viz/_matplotlib/core.py b/pyopenms_viz/_matplotlib/core.py index 11e7c737..1e478f1d 100644 --- a/pyopenms_viz/_matplotlib/core.py +++ b/pyopenms_viz/_matplotlib/core.py @@ -3,6 +3,7 @@ from abc import ABC from typing import Tuple import re +import numpy as np from numpy import nan import matplotlib.pyplot as plt from matplotlib.lines import Line2D @@ -628,7 +629,12 @@ def combine_plots( pass def create_x_axis_plot(self, canvas=None): - ax = super().create_x_axis_plot(canvas=canvas) + """ + Modifies the X-axis marginal plot for Matplotlib. + - Uses `z_original` to ensure raw intensity values are displayed. + - Maintains modifications to axis settings. + """ + ax = super().create_x_axis_plot(canvas=canvas) # ✅ If superclass returns `ax`, we must return it. ax.set_title(None) ax.set_xlabel(None) @@ -640,8 +646,18 @@ def create_x_axis_plot(self, canvas=None): ax.yaxis.tick_right() ax.legend_ = None + # ✅ Ensure raw intensity values are used for the marginal histogram + if "z_original" in self.data.columns: + ax.hist(self.data["z_original"], bins=30, color="gray", alpha=0.6) + + return ax # ✅ Return `ax` only if the superclass does. + + def create_y_axis_plot(self, canvas=None): - # Note y_config is different so we cannot use the base class methods + """ + Creates the Y-axis marginal histogram plot. + - Uses `z_original` to ensure raw intensity values are displayed. + """ group_cols = [self.y] if self.by is not None: group_cols.append(self.by) @@ -651,7 +667,7 @@ def create_y_axis_plot(self, canvas=None): if self.y_kind in ["chromatogram", "mobilogram"]: y_plot_obj = self.get_line_renderer( data=y_data, - x=self.z, + x=self.z_original if "z_original" in y_data.columns else self.z, # ✅ Use `z_original` y=self.y, by=self.by, canvas=canvas, @@ -660,7 +676,7 @@ def create_y_axis_plot(self, canvas=None): elif self.y_kind == "spectrum": y_plot_obj = self.get_vline_renderer( data=y_data, - x=self.z, + x=self.z_original if "z_original" in y_data.columns else self.z, # ✅ Use `z_original` y=self.y, by=self.by, canvas=canvas, @@ -668,9 +684,8 @@ def create_y_axis_plot(self, canvas=None): ) else: raise ValueError(f"Invalid y_kind: {self.y_kind}") - ax = y_plot_obj.generate(None, None) - - # self.plot_x_axis_line() + + ax = y_plot_obj.generate(None, None) # ✅ Ensure we return `ax` since superclass does. ax.set_xlim((0, y_data[self.z].max() + y_data[self.z].max() * 0.1)) ax.invert_xaxis() @@ -678,21 +693,31 @@ def create_y_axis_plot(self, canvas=None): ax.set_xlabel(self.y_plot_config.xlabel) ax.set_ylabel(self.y_plot_config.ylabel) ax.set_ylim(ax.get_ylim()) - return ax + return ax # ✅ Ensure return type consistency with superclass. def create_main_plot(self): + """ + Implements PeakMap plotting for Matplotlib. + - Applies log scaling only to PeakMap colors. + - Uses raw intensities for marginal histograms. + - Keeps 3D PeakMap and annotation logic intact. + """ if not self.plot_3d: + # ✅ Apply log scaling only to colors, not raw intensity values + log_intensity = np.log1p(self.data[self.z]) if self.z_log_scale else self.data[self.z] + scatterPlot = self.get_scatter_renderer( data=self.data, x=self.x, y=self.y, - z=self.z, + z=log_intensity, # ✅ Only log-transform color scale config=self._config, ) self.ax = scatterPlot.generate(None, None) if self.annotation_data is not None: self._add_box_boundaries(self.annotation_data) + else: vlinePlot = self.get_vline_renderer( data=self.data, x=self.x, y=self.y, config=self._config @@ -706,7 +731,6 @@ def create_main_plot(self): vlinePlot._add_annotations(self.fig, a_t, a_x, a_y, a_c, a_z) return self.ax - def create_main_plot_marginals(self, canvas=None): scatterPlot = self.get_scatter_renderer( data=self.data, diff --git a/test/peaktest.py b/test/peaktest.py new file mode 100644 index 00000000..97b2bd1b --- /dev/null +++ b/test/peaktest.py @@ -0,0 +1,40 @@ +import pytest +import pandas as pd +import numpy as np + + +@pytest.mark.parametrize( + "kwargs", + [ + dict(), + dict(by="Annotation"), + dict(by="Annotation", num_x_bins=20, num_y_bins=20), + dict(by="Annotation", z_log_scale=True), # ✅ Tests log scaling + dict(by="Annotation", fill_by_z=False, marker_size=50), + ], +) +def test_peakmap_plot(featureMap_data, snapshot, kwargs): + out = featureMap_data.plot( + x="mz", y="rt", z="int", kind="peakmap", show_plot=False, **kwargs + ) + + # ✅ Check log transformation is applied if z_log_scale=True + if "z_log_scale" in kwargs and kwargs["z_log_scale"]: + assert np.all(out.data["int"] >= 0), "Log-scaled intensity contains negative values!" + + # apply tight layout to matplotlib to ensure not cut off + if pd.options.plotting.backend == "ms_matplotlib": + fig = out.get_figure() + fig.tight_layout() + + assert snapshot == out + + +# ✅ Test that marginal plots use raw intensities (self.z_original) +def test_peakmap_marginals(featureMap_data, snapshot): + featureMap_data.z_log_scale = True # Apply log transformation + featureMap_data.plot_marginals() # ✅ Ensure this works + + # Check that marginal plot still uses raw intensities + assert hasattr(featureMap_data, "z_original"), "Raw intensities (z_original) missing in marginal plots!" + assert snapshot == featureMap_data