compute_counts

compute_counts(partitioner: Partitioner, column_name: str, verbose_names: bool = False, max_num_partitions: int | None = None) → DataFrame

Compute the counts of unique values in a given column for each partition.

All possible labels in the dataset are taken into account when computing the counts for each partition: a label that has no values in a partition is assigned a count of 0.
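The zero-filling behaviour can be illustrated with a minimal pure-Python sketch. The helper `count_with_zero_fill` is hypothetical and not part of flwr_datasets; it assumes the full set of labels is known up front (for a ClassLabel column, this would be the list of class indices).

```python
from collections import Counter
from collections.abc import Hashable, Iterable


def count_with_zero_fill(
    values: Iterable[Hashable], all_labels: Iterable[Hashable]
) -> dict[Hashable, int]:
    """Count occurrences of each label, reporting 0 for labels that are absent.

    Hypothetical helper illustrating the behaviour described above;
    it is not the flwr_datasets implementation.
    """
    observed = Counter(values)
    # Iterate over every known label so that missing ones still get an entry.
    return {label: observed.get(label, 0) for label in all_labels}


# A partition that contains no samples of label 1.
partition_labels = [0, 0, 2, 2, 2]
print(count_with_zero_fill(partition_labels, all_labels=[0, 1, 2]))
# {0: 2, 1: 0, 2: 3}
```

Iterating over `all_labels` rather than over the observed values is what guarantees that every partition reports the same set of labels, so per-partition results line up as columns of one table.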

Parameters:
  • partitioner (Partitioner) – Partitioner with an assigned dataset.

  • column_name (str) – Name of the column whose unique values are counted (e.g. the column holding the labels).

  • verbose_names (bool) – Whether to use verbose versions of the values in the column specified by column_name. Verbose values can be extracted only if the column is a feature of type ClassLabel.

  • max_num_partitions (Optional[int]) – The maximum number of partitions to use. If it is greater than the total number of partitions in the partitioner, it has no effect. If left as None, all partitions are used.

Returns:

dataframe – DataFrame where the row index represents the partition ID and the column index represents the unique values found in the column specified by column_name (e.g. the labels). The value dataframe.loc[i, j] is the count of label j in the partition with index i.

Return type:

pd.DataFrame
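To illustrate how a table with this layout is read, the sketch below builds a small DataFrame by hand (the counts are made up for illustration, not produced by compute_counts), looks up a single cell with `.loc`, and converts counts into per-partition label proportions with `DataFrame.div`. The normalisation step is an illustrative post-processing choice, not something compute_counts does itself.

```python
import pandas as pd

# Toy table with the layout described above: rows are partition IDs,
# columns are label values, cells are counts (illustrative numbers only).
counts_dataframe = pd.DataFrame(
    {0: [5, 0, 2], 1: [1, 4, 0], 2: [0, 3, 6]},
)

# Count of label 2 in partition 1.
print(counts_dataframe.loc[1, 2])  # 3

# Number of samples in each partition (row sums).
partition_sizes = counts_dataframe.sum(axis=1)

# Divide each row by its total to get label proportions per partition.
proportions = counts_dataframe.div(partition_sizes, axis=0)
print(proportions.round(2))
```

Passing `axis=0` to `div` aligns the row sums with the row index, so each partition is normalised by its own size rather than by a column total.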

Examples

Generate a DataFrame with the label counts resulting from a DirichletPartitioner applied to CIFAR-10:

>>> from flwr_datasets import FederatedDataset
>>> from flwr_datasets.partitioner import DirichletPartitioner
>>> from flwr_datasets.metrics import compute_counts
>>>
>>> fds = FederatedDataset(
...     dataset="cifar10",
...     partitioners={
...         "train": DirichletPartitioner(
...             num_partitions=20,
...             partition_by="label",
...             alpha=0.3,
...             min_partition_size=0,
...         ),
...     },
... )
>>> partitioner = fds.partitioners["train"]
>>> counts_dataframe = compute_counts(
...     partitioner=partitioner,
...     column_name="label",
... )