Quickstart JAX#
์ด ํํ ๋ฆฌ์ผ์์๋ Flower๋ฅผ ์ฌ์ฉํ์ฌ ๊ธฐ์กด JAX ์ํฌ๋ก๋์ ์ฐํฉ ๋ฒ์ ์ ๊ตฌ์ถํ๋ ๋ฐฉ๋ฒ์ ๋ณด์ฌ๋๋ฆฝ๋๋ค. JAX๋ฅผ ์ฌ์ฉํด scikit-learn ๋ฐ์ดํฐ ์ธํธ์์ ์ ํ ํ๊ท ๋ชจ๋ธ์ ํ๋ จํ๊ณ ์์ต๋๋ค. ์์ ๋ โํ์ดํ ์น - Centralized์์ Federated์ผ๋ก <https://github.com/adap/flower/blob/main/examples/pytorch-from-centralized-to-federated>`_ ์ํฌ์ค๋ฃจ์ ์ ์ฌํ๊ฒ ๊ตฌ์ฑํ๊ฒ ์ต๋๋ค. ๋จผ์ , JAX๋ฅผ ์ฌ์ฉํ ์ ํ ํ๊ท ํํ ๋ฆฌ์ผ`์ ๊ธฐ๋ฐ์ผ๋ก centralized ํ์ต ์ ๊ทผ ๋ฐฉ์์ ๊ตฌ์ถํฉ๋๋ค. ๊ทธ๋ฐ ๋ค์ centralized ํธ๋ ์ด๋ ์ฝ๋๋ฅผ ๊ธฐ๋ฐ์ผ๋ก federated ๋ฐฉ์์ผ๋ก ํธ๋ ์ด๋์ ์คํํฉ๋๋ค.
JAX ์์ ๋น๋๋ฅผ ์์ํ๊ธฐ ์ ์ jax
, jaxlib
, scikit-learn
, flwr
ํจํค์ง๋ฅผ ์ค์นํด์ผ ํฉ๋๋ค:
$ pip install jax jaxlib scikit-learn flwr
JAX๋ฅผ ์ฌ์ฉํ ์ ํ ํ๊ท#
๋จผ์ ์ ํ ํ๊ท
๋ชจ๋ธ์ ๊ธฐ๋ฐ์ผ๋ก ํ๋ ์ค์ ์ง์ค์ ํ๋ จ ์ฝ๋์ ๋ํ ๊ฐ๋ตํ ์ค๋ช
๋ถํฐ ์์ํ๊ฒ ์ต๋๋ค. ๋ ์์ธํ ์ค๋ช
์ ์ํ์๋ฉด ๊ณต์ `JAX ๋ฌธ์ <https://jax.readthedocs.io/>`_๋ฅผ ์ฐธ์กฐํ์ธ์.
์ ํต์ ์ธ(์ค์ ์ง์ค์) ์ ํ ํ๊ท ํ๋ จ์ ํ์ํ ๋ชจ๋ ๊ตฌ์ฑ ์์๊ฐ ํฌํจ๋ jax_training.py`๋ผ๋ ์ ํ์ผ์ ์์ฑํด ๋ณด๊ฒ ์ต๋๋ค. ๋จผ์ , JAX ํจํค์ง์ธ :code:`jax`์ :code:`jaxlib`๋ฅผ ๊ฐ์ ธ์์ผ ํฉ๋๋ค. ๋ํ ๋ฐ์ดํฐ ์ธํธ์ :code:`make_regression`์ ์ฌ์ฉํ๊ณ ๋ฐ์ดํฐ ์ธํธ๋ฅผ ํ์ต ๋ฐ ํ
์คํธ ์ธํธ๋ก ๋ถํ ํ๊ธฐ ์ํด :code:`train_test_split`์ ์ฌ์ฉํ๋ฏ๋ก :code:`sklearn`์ ๊ฐ์ ธ์์ผ ํฉ๋๋ค. ์ฐํฉ ํ์ต์ ์ํด ์์ง :code:`flwr
ํจํค์ง๋ฅผ ๊ฐ์ ธ์ค์ง ์์ ๊ฒ์ ๋ณผ ์ ์์ต๋๋ค. ์ด ์์
์ ๋์ค์ ์ํ๋ฉ๋๋ค.
from typing import Dict, List, Tuple, Callable
import jax
import jax.numpy as jnp
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
key = jax.random.PRNGKey(0)
code:load_data() ํจ์๋ ์์ ์ธ๊ธํ ํธ๋ ์ด๋ ๋ฐ ํ ์คํธ ์ธํธ๋ฅผ ๋ก๋ํฉ๋๋ค.
def load_data() -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
# create our dataset and start with similar datasets for different clients
X, y = make_regression(n_features=3, random_state=0)
X, X_test, y, y_test = train_test_split(X, y)
return X, y, X_test, y_test
๋ชจ๋ธ ์ํคํ
์ฒ(๋งค์ฐ ๊ฐ๋จํ ์ ํ ํ๊ท
๋ชจ๋ธ)๋ :code:`load_model()`์ ์ ์๋์ด ์์ต๋๋ค.
def load_model(model_shape) -> Dict:
# model weights
params = {
'b' : jax.random.uniform(key),
'w' : jax.random.uniform(key, model_shape)
}
return params
์ด์ ํ๋ จ ์งํฉ์ ๋ฐ๋ณตํ๊ณ ๊ฐ ํ๋ จ ์์ ๋ฐฐ์น์ ๋ํด ์์ค์ ์ธก์ ํ๋(ํจ์ loss_fn()
) ํ๋ จ(ํจ์ train()
)์ ์ ์ํด์ผ ํฉ๋๋ค. JAX๋ grad()
ํจ์(main()
ํจ์์ ์ ์๋๊ณ :code:`train()`์์ ํธ์ถ๋จ)๋ก ํ์๋ฌผ์ ์ทจํ๋ฏ๋ก ์์ค ํจ์๋ ๋ถ๋ฆฌ๋์ด ์์ต๋๋ค.
def loss_fn(params, X, y) -> Callable:
err = jnp.dot(X, params['w']) + params['b'] - y
return jnp.mean(jnp.square(err)) # mse
def train(params, grad_fn, X, y) -> Tuple[np.array, float, int]:
num_examples = X.shape[0]
for epochs in range(10):
grads = grad_fn(params, X, y)
params = jax.tree_multimap(lambda p, g: p - 0.05 * g, params, grads)
loss = loss_fn(params,X, y)
# if epochs % 10 == 9:
# print(f'For Epoch {epochs} loss {loss}')
return params, loss, num_examples
๋ชจ๋ธ์ ํ๊ฐ๋ evaluation()
ํจ์์ ์ ์๋์ด ์์ต๋๋ค. ์ด ํจ์๋ ๋ชจ๋ ํ
์คํธ ์์ ๋ฅผ ๊ฐ์ ธ์ ์ ํ ํ๊ท ๋ชจ๋ธ์ ์์ค์ ์ธก์ ํฉ๋๋ค.
def evaluation(params, grad_fn, X_test, y_test) -> Tuple[float, int]:
num_examples = X_test.shape[0]
err_test = loss_fn(params, X_test, y_test)
loss_test = jnp.mean(jnp.square(err_test))
# print(f'Test loss {loss_test}')
return loss_test, num_examples
๋ฐ์ดํฐ ๋ก๋ฉ, ๋ชจ๋ธ ์ํคํ
์ฒ, ํ๋ จ ๋ฐ ํ๊ฐ๋ฅผ ์ ์ํ์ผ๋ฏ๋ก ์ด์ ๋ชจ๋ ๊ฒ์ ์ข
ํฉํ์ฌ JAX๋ฅผ ์ฌ์ฉ ๋ชจ๋ธ์ ํ๋ จํ ์ ์์ต๋๋ค. ์ด๋ฏธ ์ธ๊ธํ๋ฏ์ด jax.grad()
ํจ์๋ :code:`main()`์ ์ ์๋์ด :code:`train()`์ ์ ๋ฌ๋ฉ๋๋ค.
def main():
X, y, X_test, y_test = load_data()
model_shape = X.shape[1:]
grad_fn = jax.grad(loss_fn)
print("Model Shape", model_shape)
params = load_model(model_shape)
params, loss, num_examples = train(params, grad_fn, X, y)
evaluation(params, grad_fn, X_test, y_test)
if __name__ == "__main__":
main()
์ด์ (์ค์ ์ง์ค์) JAX ์ ํ ํ๊ท ์ํฌ๋ก๋๋ฅผ ์คํํ ์ ์์ต๋๋ค:
python3 jax_training.py
์ง๊ธ๊น์ง๋ JAX๋ฅผ ์ฌ์ฉํด ๋ณธ ์ ์ด ์๋ค๋ฉด ์ด ๋ชจ๋ ๊ฒ์ด ์๋นํ ์ต์ํด ๋ณด์ผ ๊ฒ์ ๋๋ค. ๋ค์ ๋จ๊ณ๋ก ๋์ด๊ฐ์ ์ฐ๋ฆฌ๊ฐ ๊ตฌ์ถํ ๊ฒ์ ์ฌ์ฉํ์ฌ ํ๋์ ์๋ฒ์ ๋ ๊ฐ์ ํด๋ผ์ด์ธํธ๋ก ๊ตฌ์ฑ๋ ๊ฐ๋จํ ์ฐํฉ ํ์ต ์์คํ ์ ๋ง๋ค์ด ๋ณด๊ฒ ์ต๋๋ค.
JAX์ Flower์ ๋ง๋จ#
๊ธฐ์กด ์ํฌ๋ก๋๋ฅผ ์ฐํฉํ๋ ๊ฐ๋ ์ ํญ์ ๋์ผํ๊ณ ์ดํดํ๊ธฐ ์ฝ์ต๋๋ค. ์๋ฒ*๋ฅผ ์์ํ ๋ค์ *์๋ฒ*์ ์ฐ๊ฒฐ๋ *ํด๋ผ์ด์ธํธ*์ ๋ํด :code:`jax_training.py`์ ์ฝ๋๋ฅผ ์ฌ์ฉํด์ผ ํฉ๋๋ค. *์๋ฒ*๋ ๋ชจ๋ธ ํ๋ผ๋ฏธํฐ๋ฅผ ํด๋ผ์ด์ธํธ๋ก ์ ์กํฉ๋๋ค. ํด๋ผ์ด์ธํธ๋ ํ์ต์ ์คํํ๊ณ ํ๋ผ๋ฏธํฐ๋ฅผ ์ ๋ฐ์ดํธํฉ๋๋ค. ์ ๋ฐ์ดํธ๋ ํ๋ผ๋ฏธํฐ๋ *์๋ฒ*๋ก ๋ค์ ์ ์ก๋๋ฉฐ, ์์ ๋ ๋ชจ๋ ํ๋ผ๋ฏธํฐ ์ ๋ฐ์ดํธ์ ํ๊ท ์ ๊ตฌํฉ๋๋ค. ์ด๋ ์ฐํฉ ํ์ต ํ๋ก์ธ์ค์ ํ ๋ผ์ด๋๋ฅผ ์ค๋ช ํ๋ฉฐ, ์ด ๊ณผ์ ์ ์ฌ๋ฌ ๋ผ์ด๋์ ๊ฑธ์ณ ๋ฐ๋ณตํฉ๋๋ค.
์ด ์์ ๋ ํ๋์ *์๋ฒ*์ ๋ ๊ฐ์ *ํด๋ผ์ด์ธํธ*๋ก ๊ตฌ์ฑ๋ฉ๋๋ค. ๋จผ์ server.py`๋ฅผ ์ค์ ํด ๋ณด๊ฒ ์ต๋๋ค. *server*๋ Flower ํจํค์ง :code:`flwr`๋ฅผ ๊ฐ์ ธ์์ผ ํฉ๋๋ค. ๋ค์์ผ๋ก, :code:`start_server
ํจ์๋ฅผ ์ฌ์ฉํ์ฌ ์๋ฒ๋ฅผ ์์ํ๊ณ ์ธ ์ฐจ๋ก์ ์ฐํฉ ํ์ต์ ์ํํ๋๋ก ์ง์ํฉ๋๋ค.
import flwr as fl
if __name__ == "__main__":
fl.server.start_server(server_address="0.0.0.0:8080", config=fl.server.ServerConfig(num_rounds=3))
์ด๋ฏธ *์๋ฒ*๋ฅผ ์์ํ ์ ์์ต๋๋ค:
python3 server.py
๋ง์ง๋ง์ผ๋ก, client.py`์์ *client* ๋ก์ง์ ์ ์ํ๊ณ :code:`jax_training.py`์์ ์ด์ ์ ์ ์ํ JAX ๊ต์ก์ ๊ธฐ๋ฐ์ผ๋ก ๋น๋ํฉ๋๋ค. *ํด๋ผ์ด์ธํธ*๋ :code:`flwr`์ ๊ฐ์ ธ์์ผ ํ๋ฉฐ, JAX ๋ชจ๋ธ์ ํ๋ผ๋ฏธํฐ๋ฅผ ์
๋ฐ์ดํธํ๊ธฐ ์ํด :code:`jax
๋ฐ :code:`jaxlib`๋ ๊ฐ์ ธ์์ผ ํฉ๋๋ค:
from typing import Dict, List, Callable, Tuple
import flwr as fl
import numpy as np
import jax
import jax.numpy as jnp
import jax_training
Flower *ํด๋ผ์ด์ธํธ*๋ฅผ ๊ตฌํํ๋ค๋ ๊ฒ์ ๊ธฐ๋ณธ์ ์ผ๋ก flwr.client.Client
๋๋ :code:`flwr.client.NumPyClient`์ ์๋ธํด๋์ค๋ฅผ ๊ตฌํํ๋ ๊ฒ์ ์๋ฏธํฉ๋๋ค. ๊ตฌํ์ :code:`flwr.client.NumPyClient`๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ํ๋ฉฐ, ์ด๋ฅผ :code:`FlowerClient`๋ผ๊ณ ๋ถ๋ฅผ ๊ฒ์
๋๋ค. :code:`NumPyClient`๋ ํ์ํ ์ผ๋ถ ๋ณด์ผ๋ฌํ๋ ์ด๋ฅผ ํผํ ์ ์๊ธฐ ๋๋ฌธ์ NumPy ์ํธ ์ด์ฉ์ฑ์ด ์ข์ ํ๋ ์์ํฌ(์: JAX)๋ฅผ ์ฌ์ฉํ๋ ๊ฒฝ์ฐ :code:`Client`๋ณด๋ค ๊ตฌํํ๊ธฐ๊ฐ ์ฝ๊ฐ ๋ ์ฝ์ต๋๋ค. code:`FlowerClient`๋ ๋ชจ๋ธ ๋งค๊ฐ๋ณ์๋ฅผ ๊ฐ์ ธ์ค๊ฑฐ๋ ์ค์ ํ๋ ๋ฉ์๋ 2๊ฐ, ๋ชจ๋ธ ํ์ต์ ์ํ ๋ฉ์๋ 1๊ฐ, ๋ชจ๋ธ ํ
์คํธ๋ฅผ ์ํ ๋ฉ์๋ 1๊ฐ ๋ฑ ์ด 4๊ฐ์ ๋ฉ์๋๋ฅผ ๊ตฌํํด์ผ ํฉ๋๋ค:
set_parameters (์ ํ์ฌํญ)
fit
์๋ฒ์์ ๋ฐ์ ํ๋ผ๋ฏธํฐ๋ก ๋ก์ปฌ ๋ชจ๋ธ์ ํ๋ผ๋ฏธํฐ๋ฅผ ์ ๋ฐ์ดํธํฉ๋๋ค
๋ก์ปฌ ํ๋ จ ์ธํธ์์ ๋ชจ๋ธ์ ํ๋ จํฉ๋๋ค
์ ๋ฐ์ดํธ๋ ๋ก์ปฌ ๋ชจ๋ธ ํ๋ผ๋ฏธํฐ๋ฅผ ๊ฐ์ ธ์ ์๋ฒ๋ก ๋ฐํํฉ๋๋ค
evaluate
์๋ฒ์์ ๋ฐ์ ํ๋ผ๋ฏธํฐ๋ก ๋ก์ปฌ ๋ชจ๋ธ์ ํ๋ผ๋ฏธํฐ๋ฅผ ์ ๋ฐ์ดํธํฉ๋๋ค
๋ก์ปฌ ํ ์คํธ ์ธํธ์์ ์ ๋ฐ์ดํธ๋ ๋ชจ๋ธ์ ํ๊ฐํฉ๋๋ค
๋ก์ปฌ ์์ค์ ์๋ฒ๋ก ๋ฐํํฉ๋๋ค
์ด๋ ค์ด ๋ถ๋ถ์ JAX ๋ชจ๋ธ ๋งค๊ฐ๋ณ์๋ฅผ :code:`DeviceArray`์์ :code:`NumPy ndarray`๋ก ๋ณํํ์ฌ `NumPyClient`์ ํธํ๋๋๋ก ํ๋ ๊ฒ์ ๋๋ค.
๋ ๊ฐ์ NumPyClient
๋ฉ์๋์ธ fit`๊ณผ :code:`evaluate`๋ ์ด์ ์ :code:`jax_training.py`์ ์ ์๋ ํจ์ :code:`train()`๊ณผ :code:`evaluate()`๋ฅผ ์ฌ์ฉํฉ๋๋ค. ๋ฐ๋ผ์ ์ฌ๊ธฐ์ ์ฐ๋ฆฌ๊ฐ ์ค์ ๋ก ํ๋ ์ผ์ ์ด๋ฏธ ์ ์๋ ํจ์ ์ค ํ๋ จ๊ณผ ํ๊ฐ๋ฅผ ์ํด ํธ์ถํ ํจ์๋ฅผ :code:`NumPyClient
์๋ธํด๋์ค๋ฅผ ํตํด Flower์๊ฒ ์๋ ค์ฃผ๋ ๊ฒ์
๋๋ค. ์ ๋ฌ๋๋ ๋ฐ์ดํฐ ์ ํ์ ๋ ์ ์ดํดํ ์ ์๋๋ก ์ ํ type annotation์ ํฌํจํ์ต๋๋ค.
class FlowerClient(fl.client.NumPyClient):
"""Flower client implementing using linear regression and JAX."""
def __init__(
self,
params: Dict,
grad_fn: Callable,
train_x: List[np.ndarray],
train_y: List[np.ndarray],
test_x: List[np.ndarray],
test_y: List[np.ndarray],
) -> None:
self.params= params
self.grad_fn = grad_fn
self.train_x = train_x
self.train_y = train_y
self.test_x = test_x
self.test_y = test_y
def get_parameters(self, config) -> Dict:
# Return model parameters as a list of NumPy ndarrays
parameter_value = []
for _, val in self.params.items():
parameter_value.append(np.array(val))
return parameter_value
def set_parameters(self, parameters: List[np.ndarray]) -> Dict:
# Collect model parameters and update the parameters of the local model
value=jnp.ndarray
params_item = list(zip(self.params.keys(),parameters))
for item in params_item:
key = item[0]
value = item[1]
self.params[key] = value
return self.params
def fit(
self, parameters: List[np.ndarray], config: Dict
) -> Tuple[List[np.ndarray], int, Dict]:
# Set model parameters, train model, return updated model parameters
print("Start local training")
self.params = self.set_parameters(parameters)
self.params, loss, num_examples = jax_training.train(self.params, self.grad_fn, self.train_x, self.train_y)
results = {"loss": float(loss)}
print("Training results", results)
return self.get_parameters(config={}), num_examples, results
def evaluate(
self, parameters: List[np.ndarray], config: Dict
) -> Tuple[float, int, Dict]:
# Set model parameters, evaluate the model on a local test dataset, return result
print("Start evaluation")
self.params = self.set_parameters(parameters)
loss, num_examples = jax_training.evaluation(self.params,self.grad_fn, self.test_x, self.test_y)
print("Evaluation accuracy & loss", loss)
return (
float(loss),
num_examples,
{"loss": float(loss)},
)
์ฐํฉ ํ๋ก์ธ์ค๋ฅผ ์ ์ํ์ผ๋ฉด ์ด์ ์คํํ ์ ์์ต๋๋ค.
def main() -> None:
"""Load data, start MNISTClient."""
# Load data
train_x, train_y, test_x, test_y = jax_training.load_data()
grad_fn = jax.grad(jax_training.loss_fn)
# Load model (from centralized training) and initialize parameters
model_shape = train_x.shape[1:]
params = jax_training.load_model(model_shape)
# Start Flower client
client = FlowerClient(params, grad_fn, train_x, train_y, test_x, test_y)
fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())
if __name__ == "__main__":
main()
์ฌ๊ธฐ๊น์ง์ ๋๋ค. ์ด์ ๋ ๊ฐ์ ํฐ๋ฏธ๋ ์ฐฝ์ ์ถ๊ฐ๋ก ์ด๊ณ ๋ค์์ ์คํํ ์ ์์ต๋๋ค
python3 client.py
๋ฅผ ์ ๋ ฅํ๊ณ (๊ทธ ์ ์ ์๋ฒ๊ฐ ๊ณ์ ์คํ ์ค์ธ์ง ํ์ธํ์ธ์) ๋ ํด๋ผ์ด์ธํธ์์ ์ฐํฉ ํ์ต์ ์คํํ๋ JAX ํ๋ก์ ํธ๋ฅผ ํ์ธํฉ๋๋ค. ์ถํํฉ๋๋ค!
๋ค์ ๋จ๊ณ#
์ด ์์ ์ ์์ค ์ฝ๋๋ ์๊ฐ์ด ์ง๋จ์ ๋ฐ๋ผ ๊ฐ์ ๋์์ผ๋ฉฐ ์ฌ๊ธฐ์์ ํ์ธํ ์ ์์ต๋๋ค: โQuickstart JAX <https://github.com/adap/flower/blob/main/examples/quickstart-jax>`_. ๋ ํด๋ผ์ด์ธํธ๊ฐ ๋์ผํ ๋ฐ์ดํฐ ์ธํธ๋ฅผ ๋ก๋ํ๊ธฐ ๋๋ฌธ์ ์ด ์์ ๋ ๋ค์ ๋จ์ํ๋์ด ์์ต๋๋ค.
์ด์ ์ด ์ฃผ์ ๋ฅผ ๋ ์์ธํ ์ดํด๋ณผ ์ค๋น๊ฐ ๋์์ต๋๋ค. ๋ ์ ๊ตํ ๋ชจ๋ธ์ ์ฌ์ฉํ๊ฑฐ๋ ๋ค๋ฅธ ๋ฐ์ดํฐ ์งํฉ์ ์ฌ์ฉํด ๋ณด๋ ๊ฒ์ ์ด๋จ๊น์? ํด๋ผ์ด์ธํธ๋ฅผ ๋ ์ถ๊ฐํ๋ ๊ฒ์ ์ด๋จ๊น์?