plot_comparison_label_distribution

plot_comparison_label_distribution(partitioner_list: List[Partitioner], label_name: str | List[str], plot_type: str = 'bar', size_unit: str = 'percent', max_num_partitions: int | None = 30, partition_id_axis: str = 'y', figsize: Tuple[float, float] | None = None, subtitle: str = 'Comparison of Per Partition Label Distribution', titles: List[str] | None = None, cmap: str | Colormap | None = None, legend: bool = False, legend_title: str | None = None, verbose_labels: bool = True, plot_kwargs_list: List[Dict[str, Any] | None] | None = None, legend_kwargs: Dict[str, Any] | None = None) → Tuple[Figure, List[Axes], List[DataFrame]]

Compare the label distribution across multiple partitioners.

Parameters:
  • partitioner_list (List[Partitioner]) – List of partitioners to be compared.

  • label_name (Union[str, List[str]]) – Name of the label column. Pass a single string to use the same column for every partitioner, or a list with one column name per partitioner.

  • plot_type (str) – Type of plot, either “bar” or “heatmap”.

  • size_unit (str) – “absolute” for raw counts, or “percent” to normalize each partition’s label counts so that they sum to 100%.

  • max_num_partitions (Optional[int]) – Maximum number of partitions to include in the plot. If None, all partitions are included.

  • partition_id_axis (str) – Axis on which the partition IDs will be marked, either “x” or “y”.

  • figsize (Optional[Tuple[float, float]]) – Size of the figure. If None, a default size is calculated.

  • subtitle (str) – Subtitle for the figure. Defaults to “Comparison of Per Partition Label Distribution”.

  • titles (Optional[List[str]]) – Titles for each subplot. If None, no titles are set.

  • cmap (Optional[Union[str, mcolors.Colormap]]) – Colormap used to color the plot.

  • legend (bool) – Whether to include a legend. If True, the legend is placed on the right-hand side, after all the plots.

  • legend_title (Optional[str]) – Title for the legend. If None, a default title is chosen based on the plot type.

  • verbose_labels (bool) – Whether to use verbose (human-readable) versions of the labels, such as class names instead of integer label IDs.

  • plot_kwargs_list (Optional[List[Optional[Dict[str, Any]]]]) – List of keyword-argument dictionaries, one for each partitioner, with any additional key-value pairs accepted by the underlying plot function that are not exposed directly as parameters. If a parameter is specified both here and as an explicit argument (e.g. cmap), the explicit argument takes precedence over the value in the dictionary.

  • legend_kwargs (Optional[Dict[str, Any]]) – Additional key-value pairs passed to figure.legend for bar plots, or to cbar_kws for heatmaps, for options that are not exposed directly as parameters. If a parameter is specified both here and as an explicit argument (e.g. legend_title), the explicit argument takes precedence over the value in the dictionary.
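The precedence rule for plot_kwargs_list and legend_kwargs can be pictured as a dictionary merge in which explicitly passed arguments overwrite duplicate keys. The following is a minimal sketch under that assumption, not the library’s code; the helper name merge_kwargs is hypothetical, and treating an explicit None as “not provided” is an assumed design choice.

```python
from typing import Any, Dict, Optional


def merge_kwargs(
    user_kwargs: Optional[Dict[str, Any]], explicit: Dict[str, Any]
) -> Dict[str, Any]:
    """Hypothetical helper: combine user-supplied kwargs with explicit arguments.

    Explicit arguments are unpacked last, so they win when the same key appears
    in both dictionaries. Assumption: an explicit value of None means "not
    provided", so it does not override the user's dictionary.
    """
    provided = {key: value for key, value in explicit.items() if value is not None}
    return {**(user_kwargs or {}), **provided}


# "cmap" appears in both places; the explicit argument takes precedence.
plot_kwargs = {"cmap": "viridis", "linewidth": 0.5}
merged = merge_kwargs(plot_kwargs, {"cmap": "RdYlGn"})
print(merged)  # {'cmap': 'RdYlGn', 'linewidth': 0.5}
```

Filtering out None lets an argument left at its default defer to the dictionary; whether plot_comparison_label_distribution behaves exactly this way for None values is an assumption.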

Returns:

  • fig (Figure) – The figure object containing the plots.

  • axes_list (List[Axes]) – List of Axes objects for the plots.

  • dataframe_list (List[pd.DataFrame]) – List of DataFrames used for each plot.
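Each DataFrame in dataframe_list holds per-partition label sizes in the chosen size_unit. The pure-Python sketch below illustrates the “percent” normalization on a nested dictionary (partition ID to label to count) standing in for one such DataFrame; the function name to_percent and the toy counts are hypothetical, and this is an illustration rather than the library’s implementation.

```python
from typing import Dict


def to_percent(counts: Dict[int, Dict[str, int]]) -> Dict[int, Dict[str, float]]:
    """Rescale each partition's label counts so that they sum to 100."""
    percent: Dict[int, Dict[str, float]] = {}
    for partition_id, label_counts in counts.items():
        total = sum(label_counts.values())
        # Guard against empty partitions to avoid division by zero.
        percent[partition_id] = {
            label: (100.0 * n / total if total else 0.0)
            for label, n in label_counts.items()
        }
    return percent


# Toy absolute counts for two partitions (hypothetical data).
counts = {0: {"cat": 30, "dog": 10}, 1: {"cat": 5, "dog": 45}}
print(to_percent(counts))
# {0: {'cat': 75.0, 'dog': 25.0}, 1: {'cat': 10.0, 'dog': 90.0}}
```

Assuming a pandas DataFrame df of absolute counts with partitions as rows and labels as columns, the same row-wise rescaling is df.div(df.sum(axis=1), axis=0) * 100.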

Examples

Compare the effect of different alpha (concentration) values in DirichletPartitioner.

>>> from flwr_datasets import FederatedDataset
>>> from flwr_datasets.partitioner import DirichletPartitioner
>>> from flwr_datasets.visualization import plot_comparison_label_distribution
>>>
>>> partitioner_list = []
>>> alpha_list = [10_000.0, 100.0, 1.0, 0.1, 0.01, 0.00001]
>>> for alpha in alpha_list:
...     fds = FederatedDataset(
...         dataset="cifar10",
...         partitioners={
...             "train": DirichletPartitioner(
...                 num_partitions=20,
...                 partition_by="label",
...                 alpha=alpha,
...                 min_partition_size=0,
...             ),
...         },
...     )
...     partitioner_list.append(fds.partitioners["train"])
>>> fig, axes, dataframe_list = plot_comparison_label_distribution(
...     partitioner_list=partitioner_list,
...     label_name="label",
...     titles=[f"Concentration = {alpha}" for alpha in alpha_list],
... )
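Conceptually, Dirichlet-based partitioning splits each label’s samples across partitions according to proportions drawn from a symmetric Dirichlet distribution, so a large alpha spreads a label almost evenly while a small alpha concentrates it in a few partitions. The sketch below illustrates that effect with the standard library only; it is not DirichletPartitioner’s actual code, and the names dirichlet and mean_largest_share are hypothetical. It uses 20 partitions to match the example.

```python
import random
from typing import List


def dirichlet(alpha: float, k: int, rng: random.Random) -> List[float]:
    """Draw one sample from a symmetric Dirichlet(alpha, ..., alpha) of size k."""
    while True:
        # Normalizing independent Gamma(alpha, 1) draws yields a Dirichlet sample.
        gammas = [rng.gammavariate(alpha, 1.0) for _ in range(k)]
        total = sum(gammas)
        if total > 0:  # very small alphas can underflow every draw to 0.0; retry
            return [g / total for g in gammas]


def mean_largest_share(
    alpha: float, k: int = 20, trials: int = 500, seed: int = 0
) -> float:
    """Average proportion received by the most-favored partition over many draws."""
    rng = random.Random(seed)
    return sum(max(dirichlet(alpha, k, rng)) for _ in range(trials)) / trials


for alpha in [10_000.0, 100.0, 1.0, 0.1, 0.01]:
    # Near 1/k = 0.05 for large alpha; approaches 1.0 as alpha shrinks.
    print(f"alpha={alpha}: mean largest share = {mean_largest_share(alpha):.3f}")
```

The printed shares give a rough numeric counterpart to the plots: the smaller the concentration, the more a single partition dominates each label.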