Quickstart JAX#

์ด ํŠœํ† ๋ฆฌ์–ผ์—์„œ๋Š” Flower๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๊ธฐ์กด JAX ์›Œํฌ๋กœ๋“œ์˜ ์—ฐํ•ฉ ๋ฒ„์ „์„ ๊ตฌ์ถ•ํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ๋ณด์—ฌ๋“œ๋ฆฝ๋‹ˆ๋‹ค. JAX๋ฅผ ์‚ฌ์šฉํ•ด scikit-learn ๋ฐ์ดํ„ฐ ์„ธํŠธ์—์„œ ์„ ํ˜• ํšŒ๊ท€ ๋ชจ๋ธ์„ ํ›ˆ๋ จํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค. ์˜ˆ์ œ๋Š” โ€˜ํŒŒ์ดํ† ์น˜ - Centralized์—์„œ Federated์œผ๋กœ <https://github.com/adap/flower/blob/main/examples/pytorch-from-centralized-to-federated>`_ ์›Œํฌ์Šค๋ฃจ์™€ ์œ ์‚ฌํ•˜๊ฒŒ ๊ตฌ์„ฑํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค. ๋จผ์ €, JAX๋ฅผ ์‚ฌ์šฉํ•œ ์„ ํ˜• ํšŒ๊ท€ ํŠœํ† ๋ฆฌ์–ผ`์„ ๊ธฐ๋ฐ˜์œผ๋กœ centralized ํ•™์Šต ์ ‘๊ทผ ๋ฐฉ์‹์„ ๊ตฌ์ถ•ํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๋Ÿฐ ๋‹ค์Œ centralized ํŠธ๋ ˆ์ด๋‹ ์ฝ”๋“œ๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ federated ๋ฐฉ์‹์œผ๋กœ ํŠธ๋ ˆ์ด๋‹์„ ์‹คํ–‰ํ•ฉ๋‹ˆ๋‹ค.

Before we start building our JAX example, we need to install the packages :code:`jax`, :code:`jaxlib`, :code:`scikit-learn`, and :code:`flwr`:

$ pip install jax jaxlib scikit-learn flwr

Linear Regression with JAX#

We begin with a brief description of the centralized training code based on a linear regression model. If you want a more in-depth explanation of what's going on, have a look at the official `JAX documentation <https://jax.readthedocs.io/>`_.

์ „ํ†ต์ ์ธ(์ค‘์•™ ์ง‘์ค‘์‹) ์„ ํ˜• ํšŒ๊ท€ ํ›ˆ๋ จ์— ํ•„์š”ํ•œ ๋ชจ๋“  ๊ตฌ์„ฑ ์š”์†Œ๊ฐ€ ํฌํ•จ๋œ jax_training.py`๋ผ๋Š” ์ƒˆ ํŒŒ์ผ์„ ์ƒ์„ฑํ•ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค. ๋จผ์ €, JAX ํŒจํ‚ค์ง€์ธ :code:`jax`์™€ :code:`jaxlib`๋ฅผ ๊ฐ€์ ธ์™€์•ผ ํ•ฉ๋‹ˆ๋‹ค. ๋˜ํ•œ ๋ฐ์ดํ„ฐ ์„ธํŠธ์— :code:`make_regression`์„ ์‚ฌ์šฉํ•˜๊ณ  ๋ฐ์ดํ„ฐ ์„ธํŠธ๋ฅผ ํ•™์Šต ๋ฐ ํ…Œ์ŠคํŠธ ์„ธํŠธ๋กœ ๋ถ„ํ• ํ•˜๊ธฐ ์œ„ํ•ด :code:`train_test_split`์„ ์‚ฌ์šฉํ•˜๋ฏ€๋กœ :code:`sklearn`์„ ๊ฐ€์ ธ์™€์•ผ ํ•ฉ๋‹ˆ๋‹ค. ์—ฐํ•ฉ ํ•™์Šต์„ ์œ„ํ•ด ์•„์ง :code:`flwr ํŒจํ‚ค์ง€๋ฅผ ๊ฐ€์ ธ์˜ค์ง€ ์•Š์€ ๊ฒƒ์„ ๋ณผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด ์ž‘์—…์€ ๋‚˜์ค‘์— ์ˆ˜ํ–‰๋ฉ๋‹ˆ๋‹ค.

from typing import Dict, List, Tuple, Callable
import jax
import jax.numpy as jnp
import numpy as np
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split

key = jax.random.PRNGKey(0)

The :code:`load_data()` function loads the mentioned training and test sets.

def load_data() -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
    # create our dataset and start with similar datasets for different clients
    X, y = make_regression(n_features=3, random_state=0)
    X, X_test, y, y_test = train_test_split(X, y)
    return X, y, X_test, y_test

๋ชจ๋ธ ์•„ํ‚คํ…์ฒ˜(๋งค์šฐ ๊ฐ„๋‹จํ•œ ์„ ํ˜• ํšŒ๊ท€ ๋ชจ๋ธ)๋Š” :code:`load_model()`์— ์ •์˜๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค.

def load_model(model_shape) -> Dict:
    # model weights
    params = {
        'b' : jax.random.uniform(key),
        'w' : jax.random.uniform(key, model_shape)
    }
    return params

์ด์ œ ํ›ˆ๋ จ ์ง‘ํ•ฉ์„ ๋ฐ˜๋ณตํ•˜๊ณ  ๊ฐ ํ›ˆ๋ จ ์˜ˆ์ œ ๋ฐฐ์น˜์— ๋Œ€ํ•ด ์†์‹ค์„ ์ธก์ •ํ•˜๋Š”(ํ•จ์ˆ˜ loss_fn()) ํ›ˆ๋ จ(ํ•จ์ˆ˜ train())์„ ์ •์˜ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. JAX๋Š” grad() ํ•จ์ˆ˜(main() ํ•จ์ˆ˜์— ์ •์˜๋˜๊ณ  :code:`train()`์—์„œ ํ˜ธ์ถœ๋จ)๋กœ ํŒŒ์ƒ๋ฌผ์„ ์ทจํ•˜๋ฏ€๋กœ ์†์‹ค ํ•จ์ˆ˜๋Š” ๋ถ„๋ฆฌ๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค.

def loss_fn(params, X, y) -> jnp.ndarray:
    err = jnp.dot(X, params['w']) + params['b'] - y
    return jnp.mean(jnp.square(err))  # mse

def train(params, grad_fn, X, y) -> Tuple[Dict, float, int]:
    num_examples = X.shape[0]
    for epochs in range(10):
        grads = grad_fn(params, X, y)
        params = jax.tree_util.tree_map(lambda p, g: p - 0.05 * g, params, grads)
        loss = loss_fn(params, X, y)
        # if epochs % 10 == 9:
        #     print(f'For Epoch {epochs} loss {loss}')
    return params, loss, num_examples

๋ชจ๋ธ์˜ ํ‰๊ฐ€๋Š” evaluation() ํ•จ์ˆ˜์— ์ •์˜๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค. ์ด ํ•จ์ˆ˜๋Š” ๋ชจ๋“  ํ…Œ์ŠคํŠธ ์˜ˆ์ œ๋ฅผ ๊ฐ€์ ธ์™€ ์„ ํ˜• ํšŒ๊ท€ ๋ชจ๋ธ์˜ ์†์‹ค์„ ์ธก์ •ํ•ฉ๋‹ˆ๋‹ค.

def evaluation(params, grad_fn, X_test, y_test) -> Tuple[float, int]:
    num_examples = X_test.shape[0]
    # loss_fn() already returns the mean squared error
    loss_test = loss_fn(params, X_test, y_test)
    # print(f'Test loss {loss_test}')
    return loss_test, num_examples

๋ฐ์ดํ„ฐ ๋กœ๋”ฉ, ๋ชจ๋ธ ์•„ํ‚คํ…์ฒ˜, ํ›ˆ๋ จ ๋ฐ ํ‰๊ฐ€๋ฅผ ์ •์˜ํ–ˆ์œผ๋ฏ€๋กœ ์ด์ œ ๋ชจ๋“  ๊ฒƒ์„ ์ข…ํ•ฉํ•˜์—ฌ JAX๋ฅผ ์‚ฌ์šฉ ๋ชจ๋ธ์„ ํ›ˆ๋ จํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด๋ฏธ ์–ธ๊ธ‰ํ–ˆ๋“ฏ์ด jax.grad() ํ•จ์ˆ˜๋Š” :code:`main()`์— ์ •์˜๋˜์–ด :code:`train()`์— ์ „๋‹ฌ๋ฉ๋‹ˆ๋‹ค.

def main():
    X, y, X_test, y_test = load_data()
    model_shape = X.shape[1:]
    grad_fn = jax.grad(loss_fn)
    print("Model Shape", model_shape)
    params = load_model(model_shape)
    params, loss, num_examples = train(params, grad_fn, X, y)
    evaluation(params, grad_fn, X_test, y_test)


if __name__ == "__main__":
    main()

We can now run our (centralized) JAX linear regression workload:

$ python3 jax_training.py

So far this should all look fairly familiar if you've used JAX before. Let's take the next step and use what we've built to create a simple federated learning system consisting of one server and two clients.

JAX์™€ Flower์˜ ๋งŒ๋‚จ#

๊ธฐ์กด ์›Œํฌ๋กœ๋“œ๋ฅผ ์—ฐํ•ฉํ•˜๋Š” ๊ฐœ๋…์€ ํ•ญ์ƒ ๋™์ผํ•˜๊ณ  ์ดํ•ดํ•˜๊ธฐ ์‰ฝ์Šต๋‹ˆ๋‹ค. ์„œ๋ฒ„*๋ฅผ ์‹œ์ž‘ํ•œ ๋‹ค์Œ *์„œ๋ฒ„*์— ์—ฐ๊ฒฐ๋œ *ํด๋ผ์ด์–ธํŠธ*์— ๋Œ€ํ•ด :code:`jax_training.py`์˜ ์ฝ”๋“œ๋ฅผ ์‚ฌ์šฉํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. *์„œ๋ฒ„*๋Š” ๋ชจ๋ธ ํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ํด๋ผ์ด์–ธํŠธ๋กœ ์ „์†กํ•ฉ๋‹ˆ๋‹ค. ํด๋ผ์ด์–ธํŠธ๋Š” ํ•™์Šต์„ ์‹คํ–‰ํ•˜๊ณ  ํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ์—…๋ฐ์ดํŠธํ•ฉ๋‹ˆ๋‹ค. ์—…๋ฐ์ดํŠธ๋œ ํŒŒ๋ผ๋ฏธํ„ฐ๋Š” *์„œ๋ฒ„*๋กœ ๋‹ค์‹œ ์ „์†ก๋˜๋ฉฐ, ์ˆ˜์‹ ๋œ ๋ชจ๋“  ํŒŒ๋ผ๋ฏธํ„ฐ ์—…๋ฐ์ดํŠธ์˜ ํ‰๊ท ์„ ๊ตฌํ•ฉ๋‹ˆ๋‹ค. ์ด๋Š” ์—ฐํ•ฉ ํ•™์Šต ํ”„๋กœ์„ธ์Šค์˜ ํ•œ ๋ผ์šด๋“œ๋ฅผ ์„ค๋ช…ํ•˜๋ฉฐ, ์ด ๊ณผ์ •์„ ์—ฌ๋Ÿฌ ๋ผ์šด๋“œ์— ๊ฑธ์ณ ๋ฐ˜๋ณตํ•ฉ๋‹ˆ๋‹ค.

์ด ์˜ˆ์ œ๋Š” ํ•˜๋‚˜์˜ *์„œ๋ฒ„*์™€ ๋‘ ๊ฐœ์˜ *ํด๋ผ์ด์–ธํŠธ*๋กœ ๊ตฌ์„ฑ๋ฉ๋‹ˆ๋‹ค. ๋จผ์ € server.py`๋ฅผ ์„ค์ •ํ•ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค. *server*๋Š” Flower ํŒจํ‚ค์ง€ :code:`flwr`๋ฅผ ๊ฐ€์ ธ์™€์•ผ ํ•ฉ๋‹ˆ๋‹ค. ๋‹ค์Œ์œผ๋กœ, :code:`start_server ํ•จ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์„œ๋ฒ„๋ฅผ ์‹œ์ž‘ํ•˜๊ณ  ์„ธ ์ฐจ๋ก€์˜ ์—ฐํ•ฉ ํ•™์Šต์„ ์ˆ˜ํ–‰ํ•˜๋„๋ก ์ง€์‹œํ•ฉ๋‹ˆ๋‹ค.

import flwr as fl

if __name__ == "__main__":
    fl.server.start_server(server_address="0.0.0.0:8080", config=fl.server.ServerConfig(num_rounds=3))

We can already start the *server*:

$ python3 server.py

๋งˆ์ง€๋ง‰์œผ๋กœ, client.py`์—์„œ *client* ๋กœ์ง์„ ์ •์˜ํ•˜๊ณ  :code:`jax_training.py`์—์„œ ์ด์ „์— ์ •์˜ํ•œ JAX ๊ต์œก์„ ๊ธฐ๋ฐ˜์œผ๋กœ ๋นŒ๋“œํ•ฉ๋‹ˆ๋‹ค. *ํด๋ผ์ด์–ธํŠธ*๋Š” :code:`flwr`์„ ๊ฐ€์ ธ์™€์•ผ ํ•˜๋ฉฐ, JAX ๋ชจ๋ธ์˜ ํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ์—…๋ฐ์ดํŠธํ•˜๊ธฐ ์œ„ํ•ด :code:`jax ๋ฐ :code:`jaxlib`๋„ ๊ฐ€์ ธ์™€์•ผ ํ•ฉ๋‹ˆ๋‹ค:

from typing import Dict, List, Callable, Tuple

import flwr as fl
import numpy as np
import jax
import jax.numpy as jnp

import jax_training

Flower *ํด๋ผ์ด์–ธํŠธ*๋ฅผ ๊ตฌํ˜„ํ•œ๋‹ค๋Š” ๊ฒƒ์€ ๊ธฐ๋ณธ์ ์œผ๋กœ flwr.client.Client ๋˜๋Š” :code:`flwr.client.NumPyClient`์˜ ์„œ๋ธŒํด๋ž˜์Šค๋ฅผ ๊ตฌํ˜„ํ•˜๋Š” ๊ฒƒ์„ ์˜๋ฏธํ•ฉ๋‹ˆ๋‹ค. ๊ตฌํ˜„์€ :code:`flwr.client.NumPyClient`๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ ํ•˜๋ฉฐ, ์ด๋ฅผ :code:`FlowerClient`๋ผ๊ณ  ๋ถ€๋ฅผ ๊ฒƒ์ž…๋‹ˆ๋‹ค. :code:`NumPyClient`๋Š” ํ•„์š”ํ•œ ์ผ๋ถ€ ๋ณด์ผ๋Ÿฌํ”Œ๋ ˆ์ด๋ฅผ ํ”ผํ•  ์ˆ˜ ์žˆ๊ธฐ ๋•Œ๋ฌธ์— NumPy ์ƒํ˜ธ ์šด์šฉ์„ฑ์ด ์ข‹์€ ํ”„๋ ˆ์ž„์›Œํฌ(์˜ˆ: JAX)๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ๊ฒฝ์šฐ :code:`Client`๋ณด๋‹ค ๊ตฌํ˜„ํ•˜๊ธฐ๊ฐ€ ์•ฝ๊ฐ„ ๋” ์‰ฝ์Šต๋‹ˆ๋‹ค. code:`FlowerClient`๋Š” ๋ชจ๋ธ ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ๊ฐ€์ ธ์˜ค๊ฑฐ๋‚˜ ์„ค์ •ํ•˜๋Š” ๋ฉ”์„œ๋“œ 2๊ฐœ, ๋ชจ๋ธ ํ•™์Šต์„ ์œ„ํ•œ ๋ฉ”์„œ๋“œ 1๊ฐœ, ๋ชจ๋ธ ํ…Œ์ŠคํŠธ๋ฅผ ์œ„ํ•œ ๋ฉ”์„œ๋“œ 1๊ฐœ ๋“ฑ ์ด 4๊ฐœ์˜ ๋ฉ”์„œ๋“œ๋ฅผ ๊ตฌํ˜„ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค:

  1. :code:`set_parameters` (optional)
    • set the model parameters of the local model that are received from the server

    • transform the parameters to NumPy :code:`ndarray`'s

    • loop over the list of model parameters received as NumPy :code:`ndarray`'s (think list of neural network layers)

  2. :code:`get_parameters`
    • get the model parameters and return them as a list of NumPy :code:`ndarray`'s (which is what :code:`flwr.client.NumPyClient` expects)

  3. :code:`fit`
    • update the parameters of the local model with the parameters received from the server

    • train the model on the local training set

    • get the updated local model parameters and return them to the server

  4. :code:`evaluate`
    • update the parameters of the local model with the parameters received from the server

    • evaluate the updated model on the local test set

    • return the local loss to the server

์–ด๋ ค์šด ๋ถ€๋ถ„์€ JAX ๋ชจ๋ธ ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ :code:`DeviceArray`์—์„œ :code:`NumPy ndarray`๋กœ ๋ณ€ํ™˜ํ•˜์—ฌ `NumPyClient`์™€ ํ˜ธํ™˜๋˜๋„๋ก ํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.

๋‘ ๊ฐœ์˜ NumPyClient ๋ฉ”์„œ๋“œ์ธ fit`๊ณผ :code:`evaluate`๋Š” ์ด์ „์— :code:`jax_training.py`์— ์ •์˜๋œ ํ•จ์ˆ˜ :code:`train()`๊ณผ :code:`evaluate()`๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค. ๋”ฐ๋ผ์„œ ์—ฌ๊ธฐ์„œ ์šฐ๋ฆฌ๊ฐ€ ์‹ค์ œ๋กœ ํ•˜๋Š” ์ผ์€ ์ด๋ฏธ ์ •์˜๋œ ํ•จ์ˆ˜ ์ค‘ ํ›ˆ๋ จ๊ณผ ํ‰๊ฐ€๋ฅผ ์œ„ํ•ด ํ˜ธ์ถœํ•  ํ•จ์ˆ˜๋ฅผ :code:`NumPyClient ์„œ๋ธŒํด๋ž˜์Šค๋ฅผ ํ†ตํ•ด Flower์—๊ฒŒ ์•Œ๋ ค์ฃผ๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. ์ „๋‹ฌ๋˜๋Š” ๋ฐ์ดํ„ฐ ์œ ํ˜•์„ ๋” ์ž˜ ์ดํ•ดํ•  ์ˆ˜ ์žˆ๋„๋ก ์œ ํ˜• type annotation์„ ํฌํ•จํ–ˆ์Šต๋‹ˆ๋‹ค.

class FlowerClient(fl.client.NumPyClient):
    """Flower client implementing using linear regression and JAX."""

    def __init__(
        self,
        params: Dict,
        grad_fn: Callable,
        train_x: List[np.ndarray],
        train_y: List[np.ndarray],
        test_x: List[np.ndarray],
        test_y: List[np.ndarray],
    ) -> None:
        self.params = params
        self.grad_fn = grad_fn
        self.train_x = train_x
        self.train_y = train_y
        self.test_x = test_x
        self.test_y = test_y

    def get_parameters(self, config) -> List[np.ndarray]:
        # Return model parameters as a list of NumPy ndarrays
        parameter_value = []
        for _, val in self.params.items():
            parameter_value.append(np.array(val))
        return parameter_value

    def set_parameters(self, parameters: List[np.ndarray]) -> Dict:
        # Collect model parameters and update the parameters of the local model
        for key, value in zip(self.params.keys(), parameters):
            self.params[key] = value
        return self.params


    def fit(
        self, parameters: List[np.ndarray], config: Dict
    ) -> Tuple[List[np.ndarray], int, Dict]:
        # Set model parameters, train model, return updated model parameters
        print("Start local training")
        self.params = self.set_parameters(parameters)
        self.params, loss, num_examples = jax_training.train(self.params, self.grad_fn, self.train_x, self.train_y)
        results = {"loss": float(loss)}
        print("Training results", results)
        return self.get_parameters(config={}), num_examples, results

    def evaluate(
        self, parameters: List[np.ndarray], config: Dict
    ) -> Tuple[float, int, Dict]:
        # Set model parameters, evaluate the model on a local test dataset, return result
        print("Start evaluation")
        self.params = self.set_parameters(parameters)
        loss, num_examples = jax_training.evaluation(self.params, self.grad_fn, self.test_x, self.test_y)
        print("Evaluation accuracy & loss", loss)
        return (
            float(loss),
            num_examples,
            {"loss": float(loss)},
        )

Having defined the federation process, we can run it.

def main() -> None:
    """Load data, start MNISTClient."""

    # Load data
    train_x, train_y, test_x, test_y = jax_training.load_data()
    grad_fn = jax.grad(jax_training.loss_fn)

    # Load model (from centralized training) and initialize parameters
    model_shape = train_x.shape[1:]
    params = jax_training.load_model(model_shape)

    # Start Flower client
    client = FlowerClient(params, grad_fn, train_x, train_y, test_x, test_y)
    fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())

if __name__ == "__main__":
    main()

And that's it. You can now open two additional terminal windows and run

$ python3 client.py

in each window (make sure that the server is still running before you do so) and see your JAX project run federated learning across two clients. Congratulations!

๋‹ค์Œ ๋‹จ๊ณ„#

์ด ์˜ˆ์ œ์˜ ์†Œ์Šค ์ฝ”๋“œ๋Š” ์‹œ๊ฐ„์ด ์ง€๋‚จ์— ๋”ฐ๋ผ ๊ฐœ์„ ๋˜์—ˆ์œผ๋ฉฐ ์—ฌ๊ธฐ์—์„œ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค: โ€˜Quickstart JAX <https://github.com/adap/flower/blob/main/examples/quickstart-jax>`_. ๋‘ ํด๋ผ์ด์–ธํŠธ๊ฐ€ ๋™์ผํ•œ ๋ฐ์ดํ„ฐ ์„ธํŠธ๋ฅผ ๋กœ๋“œํ•˜๊ธฐ ๋•Œ๋ฌธ์— ์ด ์˜ˆ์ œ๋Š” ๋‹ค์†Œ ๋‹จ์ˆœํ™”๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค.

์ด์ œ ์ด ์ฃผ์ œ๋ฅผ ๋” ์ž์„ธํžˆ ์‚ดํŽด๋ณผ ์ค€๋น„๊ฐ€ ๋˜์—ˆ์Šต๋‹ˆ๋‹ค. ๋” ์ •๊ตํ•œ ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•˜๊ฑฐ๋‚˜ ๋‹ค๋ฅธ ๋ฐ์ดํ„ฐ ์ง‘ํ•ฉ์„ ์‚ฌ์šฉํ•ด ๋ณด๋Š” ๊ฒƒ์€ ์–ด๋–จ๊นŒ์š”? ํด๋ผ์ด์–ธํŠธ๋ฅผ ๋” ์ถ”๊ฐ€ํ•˜๋Š” ๊ฒƒ์€ ์–ด๋–จ๊นŒ์š”?