Source code for flwr_datasets.visualization.comparison_label_distribution

# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Comparison of label distribution plotting."""


from typing import Any, Literal, Optional, Union

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.axes import Axes
from matplotlib.figure import Figure

from flwr_datasets.common import EventType, event
from flwr_datasets.partitioner import Partitioner
from flwr_datasets.visualization.constants import PLOT_TYPES
from flwr_datasets.visualization.label_distribution import plot_label_distributions


# pylint: disable=too-many-arguments,too-many-locals
# mypy: disable-error-code="call-overload"
def plot_comparison_label_distribution(
    partitioner_list: list[Partitioner],
    label_name: Union[str, list[str]],
    plot_type: Literal["bar", "heatmap"] = "bar",
    size_unit: Literal["percent", "absolute"] = "percent",
    max_num_partitions: Optional[int] = 30,
    partition_id_axis: Literal["x", "y"] = "y",
    figsize: Optional[tuple[float, float]] = None,
    subtitle: str = "Comparison of Per Partition Label Distribution",
    titles: Optional[list[str]] = None,
    cmap: Optional[Union[str, mcolors.Colormap]] = None,
    legend: bool = False,
    legend_title: Optional[str] = None,
    verbose_labels: bool = True,
    plot_kwargs_list: Optional[list[Optional[dict[str, Any]]]] = None,
    legend_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[Figure, list[Axes], list[pd.DataFrame]]:
    """Compare the label distribution across multiple partitioners.

    One subplot is drawn per partitioner, side by side in a single row.

    Parameters
    ----------
    partitioner_list : List[Partitioner]
        List of partitioners to be compared. Must not be empty.
    label_name : Union[str, List[str]]
        Column name or list of column names identifying labels for each
        partitioner. A single string is reused for every partitioner; a list
        must provide at least one entry per partitioner.
    plot_type : Literal["bar", "heatmap"]
        Type of plot, either "bar" or "heatmap".
    size_unit : Literal["percent", "absolute"]
        "absolute" for raw counts, or "percent" to normalize values to 100%.
    max_num_partitions : Optional[int]
        Maximum number of partitions to include in the plot. If None, all
        partitions are included.
    partition_id_axis : Literal["x", "y"]
        Axis on which the partition IDs will be marked, either "x" or "y".
    figsize : Optional[Tuple[float, float]]
        Size of the figure. If None, a default size is calculated.
    subtitle : str
        Subtitle for the figure. Defaults to "Comparison of Per Partition Label
        Distribution"
    titles : Optional[List[str]]
        Titles for each subplot. If None, no titles are set. If given, it must
        provide at least one entry per partitioner.
    cmap : Optional[Union[str, mcolors.Colormap]]
        Colormap for determining the colorspace of the plot.
    legend : bool
        Whether to include a legend. If True, it will be included right-hand side
        after all the plots.
    legend_title : Optional[str]
        Title for the legend. If None, the defaults will be taken based on the
        type of plot.
    verbose_labels : bool
        Whether to use verbose versions of the labels. The value is forwarded to
        every subplot.
    plot_kwargs_list: Optional[List[Optional[Dict[str, Any]]]]
        List of plot_kwargs. Any key value pair that can be passed to a plot
        function that are not supported directly. In case of the parameter
        doubling (e.g. specifying cmap here too) the chosen value will be taken
        from the explicit arguments (e.g. cmap specified as an argument to this
        function not the value in this dictionary). If given, it must provide at
        least one entry per partitioner.
    legend_kwargs: Optional[Dict[str, Any]]
        Any key value pair that can be passed to a figure.legend in case of bar
        plot or cbar_kws in case of heatmap that are not supported directly. In
        case of parameter doubling (e.g. specifying legend_title here too) the
        chosen value will be taken from the explicit arguments (e.g. legend_title
        specified as an argument to this function not the value in this
        dictionary).

    Returns
    -------
    fig : Figure
        The figure object containing the plots.
    axes_list : List[Axes]
        The Axes objects for the plots, one per partitioner (the one-dimensional
        sequence produced by ``plt.subplots``).
    dataframe_list : List[pd.DataFrame]
        List of DataFrames used for each plot.

    Raises
    ------
    TypeError
        If ``label_name`` is neither a string nor a list.
    ValueError
        If ``partitioner_list`` is empty, or if ``label_name``, ``titles`` or
        ``plot_kwargs_list`` has fewer entries than ``partitioner_list``.

    Examples
    --------
    Compare the difference of using different alpha (concentration) parameters in
    DirichletPartitioner.

    >>> from flwr_datasets import FederatedDataset
    >>> from flwr_datasets.partitioner import DirichletPartitioner
    >>> from flwr_datasets.visualization import plot_comparison_label_distribution
    >>>
    >>> partitioner_list = []
    >>> alpha_list = [10_000.0, 100.0, 1.0, 0.1, 0.01, 0.00001]
    >>> for alpha in alpha_list:
    ...     fds = FederatedDataset(
    ...         dataset="cifar10",
    ...         partitioners={
    ...             "train": DirichletPartitioner(
    ...                 num_partitions=20,
    ...                 partition_by="label",
    ...                 alpha=alpha,
    ...                 min_partition_size=0,
    ...             ),
    ...         },
    ...     )
    ...     partitioner_list.append(fds.partitioners["train"])
    >>> fig, axes, dataframe_list = plot_comparison_label_distribution(
    ...     partitioner_list=partitioner_list,
    ...     label_name="label",
    ...     titles=[f"Concentration = {alpha}" for alpha in alpha_list],
    ... )
    """
    event(
        EventType.PLOT_COMPARISON_LABEL_DISTRIBUTION_CALLED,
        {
            "num_compare": len(partitioner_list),
            "plot_type": plot_type,
        },
    )
    num_partitioners = len(partitioner_list)
    if num_partitioners == 0:
        raise ValueError("`partitioner_list` must contain at least one partitioner.")

    if isinstance(label_name, str):
        label_name = [label_name] * num_partitioners
    elif not isinstance(label_name, list):
        raise TypeError(
            f"Label name has to be of type List[str] or str but given "
            f"{type(label_name)}"
        )

    if titles is None:
        titles = ["" for _ in range(num_partitioners)]
    if plot_kwargs_list is None:
        plot_kwargs_list = [None] * num_partitioners

    # Validate before any figure is created: a too-short list would otherwise
    # make `zip` silently skip partitioners (or fail late when setting titles).
    for argument_name, values in (
        ("label_name", label_name),
        ("titles", titles),
        ("plot_kwargs_list", plot_kwargs_list),
    ):
        if len(values) < num_partitioners:
            raise ValueError(
                f"`{argument_name}` has {len(values)} entries but "
                f"{num_partitioners} partitioners were given."
            )

    figsize = _initialize_comparison_figsize(figsize, num_partitioners)
    axes_sharing = _initialize_axis_sharing(size_unit, plot_type, partition_id_axis)
    # `squeeze=False` always yields a 2-D array of Axes, so a single partitioner
    # no longer produces a bare Axes object that cannot be indexed or iterated.
    fig, axes_grid = plt.subplots(
        nrows=1,
        ncols=num_partitioners,
        figsize=figsize,
        layout="constrained",
        squeeze=False,
        **axes_sharing,
    )
    axes = axes_grid[0]

    last_idx = num_partitioners - 1
    dataframe_list = []
    for idx, (partitioner, single_label_name, plot_kwargs) in enumerate(
        zip(partitioner_list, label_name, plot_kwargs_list)
    ):
        # Only the right-most subplot may carry the legend.
        if idx == last_idx:
            legend_options: dict[str, Any] = {
                "legend": legend,
                "legend_title": legend_title,
                "legend_kwargs": legend_kwargs,
            }
        else:
            legend_options = {"legend": False}
        *_, dataframe = plot_label_distributions(
            partitioner=partitioner,
            label_name=single_label_name,
            plot_type=plot_type,
            size_unit=size_unit,
            partition_id_axis=partition_id_axis,
            axis=axes[idx],
            max_num_partitions=max_num_partitions,
            cmap=cmap,
            verbose_labels=verbose_labels,
            plot_kwargs=plot_kwargs,
            **legend_options,
        )
        dataframe_list.append(dataframe)

    # Do not use the xlabel and ylabel on each subplot plot
    # (instead use global = per figure xlabel and ylabel)
    for idx, axis in enumerate(axes):
        axis.set_xlabel("")
        axis.set_ylabel("")
        axis.set_title(titles[idx])
    _set_tick_on_value_axes(axes, partition_id_axis, size_unit)

    # Set up figure xlabel and ylabel
    xlabel, ylabel = _initialize_comparison_xy_labels(
        plot_type, size_unit, partition_id_axis
    )
    fig.supxlabel(xlabel)
    fig.supylabel(ylabel)
    fig.suptitle(subtitle)

    # NOTE(review): the figure is created with layout="constrained" above;
    # confirm that additionally calling tight_layout() here is intended.
    fig.tight_layout()
    return fig, axes, dataframe_list
def _initialize_comparison_figsize(
    figsize: Optional[tuple[float, float]], num_partitioners: int
) -> tuple[float, float]:
    """Return the figure size, computing a default when none is given.

    An explicit ``figsize`` is returned unchanged. Otherwise the width starts at
    4 and grows by 2 for every partitioner after the first, and the height is
    fixed at 4.8.
    """
    if figsize is None:
        width = 4 + (num_partitioners - 1) * 2
        return (width, 4.8)
    return figsize


def _initialize_comparison_xy_labels(
    plot_type: Literal["bar", "heatmap"],
    size_unit: Literal["percent", "absolute"],
    partition_id_axis: Literal["x", "y"],
) -> tuple[str, str]:
    """Return the figure-level ``(xlabel, ylabel)`` pair.

    The partition label goes on the x-axis and the value/label text on the
    y-axis, unless ``partition_id_axis`` is "y", in which case they are swapped.

    Raises
    ------
    ValueError
        If ``plot_type`` is neither "bar" nor "heatmap".
    """
    if plot_type == "bar":
        if size_unit == "percent":
            other_label = "Class distribution"
        else:
            other_label = "Class Count"
    elif plot_type == "heatmap":
        other_label = "Label"
    else:
        raise ValueError(
            f"Invalid plot_type: {plot_type}. Must be one of {PLOT_TYPES}."
        )

    partition_label = "Partition ID"
    if partition_id_axis == "y":
        return other_label, partition_label
    return partition_label, other_label


def _initialize_axis_sharing(
    size_unit: Literal["percent", "absolute"],
    plot_type: Literal["bar", "heatmap"],
    partition_id_axis: Literal["x", "y"],
) -> dict[str, bool]:
    """Return the axis-sharing keyword arguments for ``plt.subplots``.

    Nothing is shared (an empty dict is returned) when ``size_unit`` is
    "percent" or when ``plot_type`` is "heatmap". Otherwise the axis opposite to
    the partition-ID axis is shared between the subplots.
    """
    if size_unit == "percent" or plot_type == "heatmap":
        return {}

    # Fallback for an unrecognised `partition_id_axis`: share nothing, explicitly.
    sharing: dict[str, bool] = {"sharex": False, "sharey": False}
    if partition_id_axis == "x":
        sharing = {"sharey": True}
    elif partition_id_axis == "y":
        sharing = {"sharex": True}
    return sharing


def _set_tick_on_value_axes(
    axes: list[Axes],
    partition_id_axis: Literal["x", "y"],
    size_unit: Literal["percent", "absolute"],
) -> None:
    """Remove the y-ticks from every subplot except the first one.

    Nothing is changed when ``partition_id_axis`` is "x" and ``size_unit`` is
    "absolute": in that case the y-axis (and thus its ticks) is shared, so the
    ticks must remain set and the numbers are displayed only on the first plot.
    """
    keeps_shared_ticks = partition_id_axis == "x" and size_unit == "absolute"
    if keeps_shared_ticks:
        return
    for later_axis in axes[1:]:
        later_axis.set_yticks([])