Source code for flwr_datasets.utils

# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils for FederatedDataset."""


import warnings
from typing import Optional, Union, cast

from datasets import Dataset, DatasetDict, concatenate_datasets
from flwr_datasets.partitioner import IidPartitioner, Partitioner
from flwr_datasets.preprocessor import Preprocessor
from flwr_datasets.preprocessor.merger import Merger

# Dataset names for which `_check_if_dataset_tested` stays silent; any other
# name triggers a warning there. Several datasets are listed under two names
# (e.g. "mnist" and "ylecun/mnist").
# NOTE(review): these look like Hugging Face Hub identifiers -- confirm.
tested_datasets = [
    "mnist",
    "ylecun/mnist",
    "cifar10",
    "uoft-cs/cifar10",
    "fashion_mnist",
    "zalando-datasets/fashion_mnist",
    "sasha/dog-food",
    "zh-plus/tiny-imagenet",
    "scikit-learn/adult-census-income",
    "cifar100",
    "uoft-cs/cifar100",
    "svhn",
    "ufldl-stanford/svhn",
    "sentiment140",
    "stanfordnlp/sentiment140",
    "speech_commands",
    "LIUM/tedlium",
    "flwrlabs/femnist",
    "flwrlabs/ucf101",
    "flwrlabs/ambient-acoustic-context",
    "jlh/uci-mushrooms",
    "Mike0307/MNIST-M",
    "flwrlabs/usps",
    "scikit-learn/iris",
    "flwrlabs/pacs",
    "flwrlabs/cinic10",
    "flwrlabs/caltech101",
    "flwrlabs/office-home",
    "flwrlabs/fed-isic2019",
]


def _instantiate_partitioners(
    partitioners: dict[str, Union[Partitioner, int]]
) -> dict[str, Partitioner]:
    """Resolve every entry of `partitioners` to a `Partitioner` instance.

    Parameters
    ----------
    partitioners : Dict[str, Union[Partitioner, int]]
        Mapping from a dataset split to either a ready `Partitioner` or an
        integer giving the number of IID partitions to create.

    Returns
    -------
    partitioners : Dict[str, Partitioner]
        Mapping from each split to a `Partitioner` object. Integers are
        replaced by `IidPartitioner(num_partitions=<int>)`.

    Raises
    ------
    ValueError
        If `partitioners` is not a dict, or if any of its values is neither a
        `Partitioner` nor an int.
    """
    # Reject a wrong container type first so the loop below stays flat.
    if not isinstance(partitioners, dict):
        raise ValueError(
            "Incorrect type of the 'partitioners' encountered. "
            "Expected Dict[str, Union[int, Partitioner]]. "
            f"Given {type(partitioners)}."
        )

    resolved: dict[str, Partitioner] = {}
    for split_name, spec in partitioners.items():
        if isinstance(spec, Partitioner):
            # Already instantiated: use as-is.
            resolved[split_name] = spec
            continue
        if isinstance(spec, int):
            # An int is shorthand for "this many IID partitions".
            resolved[split_name] = IidPartitioner(num_partitions=spec)
            continue
        raise ValueError(
            "Incorrect type of the 'partitioners' value encountered. "
            f"Expected Partitioner or int. Given {type(spec)}"
        )
    return resolved


def _instantiate_merger_if_needed(
    merger: Optional[Union[Preprocessor, dict[str, tuple[str, ...]]]]
) -> Optional[Preprocessor]:
    """Wrap a merge configuration in a `Merger`; pass anything else through.

    A truthy dict is treated as a merge configuration and turned into
    `Merger(merge_config=merger)`. Every other value (including `None` and an
    empty dict) is returned unchanged.
    """
    if merger and isinstance(merger, dict):
        return cast(Optional[Preprocessor], Merger(merge_config=merger))
    return cast(Optional[Preprocessor], merger)


def _check_if_dataset_tested(dataset: str) -> None:
    """Warn when `dataset` is not one of the names in `tested_datasets`.

    Parameters
    ----------
    dataset : str
        Name of the dataset to look up in `tested_datasets`.
    """
    if dataset in tested_datasets:
        return
    message = f"The currently tested dataset are {tested_datasets}. Given: {dataset}."
    warnings.warn(message, stacklevel=1)


def divide_dataset(
    dataset: Dataset,
    division: Union[list[float], tuple[float, ...], dict[str, float]],
) -> Union[list[Dataset], DatasetDict]:
    """Divide the dataset according to the `division`.

    The division supports a varying number of splits, which you can name. The
    splits are created from the beginning of the dataset.

    Parameters
    ----------
    dataset : Dataset
        Dataset to be divided.
    division: Union[List[float], Tuple[float, ...], Dict[str, float]]
        Configuration specifying how the dataset is divided. Each fraction has
        to be >0 and <=1. They have to sum up to at most 1 (smaller sum is
        possible).

    Returns
    -------
    divided_dataset : Union[List[Dataset], DatasetDict]
        If `division` is `List` or `Tuple` then `List[Dataset]` is returned
        else if `division` is `Dict` then `DatasetDict` is returned.

    Raises
    ------
    TypeError
        If `division` is not a list, tuple or dict.

    Examples
    --------
    Use `divide_dataset` with division specified as a list.

    >>> from flwr_datasets import FederatedDataset
    >>> from flwr_datasets.utils import divide_dataset
    >>>
    >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": 100})
    >>> partition = fds.load_partition(0)
    >>> division = [0.8, 0.2]
    >>> train, test = divide_dataset(dataset=partition, division=division)

    Use `divide_dataset` with division specified as a dict
    (this accomplishes the same goal as the example with a list above).

    >>> from flwr_datasets import FederatedDataset
    >>> from flwr_datasets.utils import divide_dataset
    >>>
    >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": 100})
    >>> partition = fds.load_partition(0)
    >>> division = {"train": 0.8, "test": 0.2}
    >>> train_test = divide_dataset(dataset=partition, division=division)
    >>> train, test = train_test["train"], train_test["test"]
    """
    _check_division_config_correctness(division)
    index_ranges = _create_division_indices_ranges(len(dataset), division)

    if isinstance(division, (list, tuple)):
        # Positional splits: one Dataset per fraction, in the given order.
        return [dataset.select(index_range) for index_range in index_ranges]

    if isinstance(division, dict):
        # Named splits: pair every key with the range built from its fraction.
        return DatasetDict(
            {
                split_name: dataset.select(index_range)
                for split_name, index_range in zip(division, index_ranges)
            }
        )

    raise TypeError(
        f"The type of the `division` should be dict, "
        f"tuple or list but is {type(division)} instead."
    )
def _create_division_indices_ranges(
    dataset_length: int,
    division: Union[list[float], tuple[float, ...], dict[str, float]],
) -> list[range]:
    """Translate the fractions of `division` into consecutive index ranges.

    Parameters
    ----------
    dataset_length : int
        Number of samples in the dataset that is being divided.
    division : Union[List[float], Tuple[float, ...], Dict[str, float]]
        Fractions of the dataset, given positionally or by name. For a dict the
        values are used in the dict's iteration order.

    Returns
    -------
    ranges : List[range]
        One `range` per fraction. The first range starts at 0 and every range
        starts where the previous one ended.

    Raises
    ------
    TypeError
        If `division` is not a list, tuple or dict.
    """
    if isinstance(division, (list, tuple)):
        fractions = list(division)
    elif isinstance(division, dict):
        fractions = list(division.values())
    else:
        raise TypeError(
            f"The type of the `division` should be dict, "
            f"tuple or list but is {type(division)} instead. "
        )

    ranges: list[range] = []
    start_idx = 0
    for fraction in fractions:
        # int() truncates, so a split never receives more than its share.
        end_idx = start_idx + int(dataset_length * fraction)
        ranges.append(range(start_idx, end_idx))
        start_idx = end_idx
    return ranges


def _check_division_config_types_correctness(
    division: Union[list[float], tuple[float, ...], dict[str, float]]
) -> None:
    """Check that `division` is a list, tuple or dict holding only floats.

    Raises
    ------
    TypeError
        If `division` is not a list, tuple or dict, or if any of its values
        (list/tuple elements or dict values) is not a float.
    """
    if isinstance(division, (list, tuple)):
        if not all(isinstance(x, float) for x in division):
            raise TypeError(
                "List or tuple values of `division` must contain only floats, "
                "other types are not allowed."
            )
    elif isinstance(division, dict):
        if not all(isinstance(x, float) for x in division.values()):
            raise TypeError(
                "Dict values of `division` must be only floats, "
                "other types are not allowed."
            )
    else:
        raise TypeError("`division` must be a list, tuple, or dict.")


# Slack applied when comparing the sum of fractions against 1, so that binary
# floating-point rounding (e.g. 0.7 + 0.2 + 0.1) is not reported as a sum that
# is below or above 1.
_FRACTION_SUM_TOLERANCE = 1e-9


def _check_division_config_values_correctness(
    division: Union[list[float], tuple[float, ...], dict[str, float]]
) -> None:
    """Check that the fractions are in (0, 1] and sum up to at most 1.

    The sum is compared against 1 with a tolerance of
    `_FRACTION_SUM_TOLERANCE` to absorb floating-point rounding.

    Raises
    ------
    ValueError
        If any fraction is outside of (0, 1] or if the fractions sum up to
        more than 1.
    TypeError
        If `division` is not a list, tuple or dict.

    Warns
    -----
    UserWarning
        If the fractions sum up to less than 1, i.e. part of the data is left
        unused.
    """
    if isinstance(division, (list, tuple)):
        if not all(0 < x <= 1 for x in division):
            raise ValueError(
                "All fractions for the division must be greater than 0 and smaller or "
                "equal to 1."
            )
        fraction_sum = sum(division)
        if fraction_sum > 1 + _FRACTION_SUM_TOLERANCE:
            raise ValueError("Sum of fractions for division must not exceed 1.")
        if fraction_sum < 1 - _FRACTION_SUM_TOLERANCE:
            warnings.warn(
                f"Sum of fractions for division is {fraction_sum}, which is below 1. "
                f"Make sure that's the desired behavior. Some data will not be used "
                f"in the current specification.",
                stacklevel=1,
            )
    elif isinstance(division, dict):
        values = list(division.values())
        if not all(0 < x <= 1 for x in values):
            raise ValueError(
                "All fractions must be greater than 0 and smaller or equal to 1."
            )
        fraction_sum = sum(values)
        if fraction_sum > 1 + _FRACTION_SUM_TOLERANCE:
            raise ValueError("Sum of fractions must not exceed 1.")
        if fraction_sum < 1 - _FRACTION_SUM_TOLERANCE:
            # Report the sum itself, not the list of individual fractions.
            warnings.warn(
                f"Sum of fractions in `division` is {fraction_sum}, which is below 1. "
                f"Make sure that's the desired behavior. Some data will not be used "
                f"in the current specification.",
                stacklevel=1,
            )
    else:
        raise TypeError("`division` must be a list, tuple, or dict.")


def _check_division_config_correctness(
    division: Union[list[float], tuple[float, ...], dict[str, float]]
) -> None:
    """Validate `division`: first the types, then the values of the fractions.

    Raises
    ------
    TypeError
        If `division` or any of its values has an unsupported type.
    ValueError
        If any fraction is outside of (0, 1] or the fractions sum up to more
        than 1.
    """
    _check_division_config_types_correctness(division)
    _check_division_config_values_correctness(division)
def concatenate_divisions(
    partitioner: Partitioner,
    partition_division: Union[list[float], tuple[float, ...], dict[str, float]],
    division_id: Union[int, str],
) -> Dataset:
    """Create a dataset by concatenation of divisions from all partitions.

    The divisions are created based on the `partition_division` and accessed
    based on the `division_id`. This function can be used to create e.g.
    centralized dataset from federated on-edge test sets.

    Parameters
    ----------
    partitioner : Partitioner
        Partitioner object with assigned dataset.
    partition_division : Union[List[float], Tuple[float, ...], Dict[str, float]]
        Fractions specifying the division of the partitions of a `partitioner`.
        You can think of this as on-edge division of the data into multiple
        divisions (e.g. into train and validation). E.g. [0.8, 0.2] or
        {"partition_train": 0.8, "partition_test": 0.2}.
    division_id : Union[int, str]
        The way to access the division (from a List or DatasetDict). If your
        `partition_division` is specified as a list, then `division_id`
        represents an index to an element in that list. If `partition_division`
        is passed as a `Dict`, then `division_id` is a key of such dictionary.

    Returns
    -------
    concatenated_divisions : Dataset
        A dataset created as concatenation of the divisions from all
        partitions.

    Raises
    ------
    TypeError
        If `partition_division` has an unsupported type, or if it is a list or
        tuple and `division_id` is not an int.
    IndexError
        If `partition_division` is a list or tuple and `division_id` is out of
        its range.
    KeyError
        If `partition_division` is a dict and `division_id` is not one of its
        keys.
    ValueError
        If `partition_division` holds invalid fractions, or if the selected
        division is empty in every partition.

    Examples
    --------
    Use `concatenate_divisions` with division specified as a list.

    >>> from flwr_datasets import FederatedDataset
    >>> from flwr_datasets.utils import concatenate_divisions
    >>>
    >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": 100})
    >>> concatenated_divisions = concatenate_divisions(
    ...     partitioner=fds.partitioners["train"],
    ...     partition_division=[0.8, 0.2],
    ...     division_id=1
    ... )
    >>> print(concatenated_divisions)

    Use `concatenate_divisions` with division specified as a dict.
    This accomplishes the same goal as the example with a list above.

    >>> from flwr_datasets import FederatedDataset
    >>> from flwr_datasets.utils import concatenate_divisions
    >>>
    >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": 100})
    >>> concatenated_divisions = concatenate_divisions(
    ...     partitioner=fds.partitioners["train"],
    ...     partition_division={"train": 0.8, "test": 0.2},
    ...     division_id="test"
    ... )
    >>> print(concatenated_divisions)
    """
    _check_division_config_correctness(partition_division)

    # `division_id` does not depend on the partition, so validate it once and
    # before any partition is loaded.
    if isinstance(partition_division, (list, tuple)):
        if not isinstance(division_id, int):
            raise TypeError(
                "The `division_id` needs to be an int in case of "
                "`partition_division` specification as List."
            )
        num_divisions = len(partition_division)
        # Negative indices stay valid, as with regular list indexing.
        if not -num_divisions <= division_id < num_divisions:
            raise IndexError(
                f"The `division_id` {division_id} is out of range for "
                f"`partition_division` with {num_divisions} element(s)."
            )
    elif isinstance(partition_division, dict):
        if division_id not in partition_division:
            raise KeyError(
                f"The `division_id` {division_id!r} is not a key of "
                f"`partition_division`. Available keys: "
                f"{list(partition_division)}."
            )
    else:
        raise TypeError(
            "The type of partition needs to be List of DatasetDict in this "
            "context."
        )

    num_partitions = partitioner.num_partitions
    divisions = []
    zero_len_divisions = 0
    for partition_id in range(num_partitions):
        partition = partitioner.load_partition(partition_id)
        division = divide_dataset(partition, partition_division)[division_id]
        if len(division) == 0:
            zero_len_divisions += 1
        divisions.append(division)

    if zero_len_divisions == num_partitions:
        raise ValueError(
            "The concatenated dataset is of length 0. Please change the "
            "`partition_division` parameter to change this behavior."
        )
    if zero_len_divisions != 0:
        warnings.warn(
            f"{zero_len_divisions} division(s) have length zero.", stacklevel=1
        )
    return concatenate_datasets(divisions)