# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Message."""
from __future__ import annotations
from logging import WARNING
from typing import Any, Optional, cast, overload
from flwr.common.date import now
from flwr.common.logger import warn_deprecated_feature
from .constant import MESSAGE_TTL_TOLERANCE, MessageType, MessageTypeLegacy
from .logger import log
from .record import RecordDict
DEFAULT_TTL = 43200 # This is 12 hours
MESSAGE_INIT_ERROR_MESSAGE = (
"Invalid arguments for Message. Expected one of the documented "
"signatures: Message(content: RecordDict, dst_node_id: int, message_type: str,"
" *, [ttl: float, group_id: str]) or Message(content: RecordDict | error: Error,"
" *, reply_to: Message, [ttl: float])."
)
class MessageInitializationError(TypeError):
"""Error raised when initializing a message with invalid arguments."""
def __init__(self, message: str | None = None) -> None:
super().__init__(message or MESSAGE_INIT_ERROR_MESSAGE)
[๋ฌธ์]
class Error:
"""The class storing information about an error that occurred.
Parameters
----------
code : int
An identifier for the error.
reason : Optional[str]
A reason for why the error arose (e.g. an exception stack-trace)
"""
def __init__(self, code: int, reason: str | None = None) -> None:
var_dict = {
"_code": code,
"_reason": reason,
}
self.__dict__.update(var_dict)
@property
def code(self) -> int:
"""Error code."""
return cast(int, self.__dict__["_code"])
@property
def reason(self) -> str | None:
"""Reason reported about the error."""
return cast(Optional[str], self.__dict__["_reason"])
def __repr__(self) -> str:
"""Return a string representation of this instance."""
view = ", ".join([f"{k.lstrip('_')}={v!r}" for k, v in self.__dict__.items()])
return f"{self.__class__.__qualname__}({view})"
def __eq__(self, other: object) -> bool:
"""Compare two instances of the class."""
if not isinstance(other, self.__class__):
raise NotImplementedError
return self.__dict__ == other.__dict__
[๋ฌธ์]
class Message:
"""Represents a message exchanged between ClientApp and ServerApp.
This class encapsulates the payload and metadata necessary for communication
between a ClientApp and a ServerApp.
Parameters
----------
content : Optional[RecordDict] (default: None)
Holds records either sent by another entity (e.g. sent by the server-side
logic to a client, or vice-versa) or that will be sent to it.
error : Optional[Error] (default: None)
A dataclass that captures information about an error that took place
when processing another message.
dst_node_id : Optional[int] (default: None)
An identifier for the node receiving this message.
message_type : Optional[str] (default: None)
A string that encodes the action to be executed on
the receiving end.
ttl : Optional[float] (default: None)
Time-to-live (TTL) for this message in seconds. If `None` (default),
the TTL is set to 43,200 seconds (12 hours).
group_id : Optional[str] (default: None)
An identifier for grouping messages. In some settings, this is used as
the FL round.
reply_to : Optional[Message] (default: None)
The instruction message to which this message is a reply. This message does
not retain the original message's content but derives its metadata from it.
"""
@overload
def __init__( # pylint: disable=too-many-arguments # noqa: E704
self,
content: RecordDict,
dst_node_id: int,
message_type: str,
*,
ttl: float | None = None,
group_id: str | None = None,
) -> None: ...
@overload
def __init__( # noqa: E704
self, content: RecordDict, *, reply_to: Message, ttl: float | None = None
) -> None: ...
@overload
def __init__( # noqa: E704
self, error: Error, *, reply_to: Message, ttl: float | None = None
) -> None: ...
def __init__( # pylint: disable=too-many-arguments
self,
*args: Any,
dst_node_id: int | None = None,
message_type: str | None = None,
content: RecordDict | None = None,
error: Error | None = None,
ttl: float | None = None,
group_id: str | None = None,
reply_to: Message | None = None,
metadata: Metadata | None = None,
) -> None:
# Set positional arguments
content, error, dst_node_id, message_type = _extract_positional_args(
*args,
content=content,
error=error,
dst_node_id=dst_node_id,
message_type=message_type,
)
_check_arg_types(
dst_node_id=dst_node_id,
message_type=message_type,
content=content,
error=error,
ttl=ttl,
group_id=group_id,
reply_to=reply_to,
metadata=metadata,
)
# Set metadata directly (This is for internal use only)
if metadata is not None:
# When metadata is set, all other arguments must be None,
# except `content`, `error`, or `content_or_error`
if any(
x is not None
for x in [dst_node_id, message_type, ttl, group_id, reply_to]
):
raise MessageInitializationError(
f"Invalid arguments for {Message.__qualname__}. "
"Expected only `metadata` to be set when creating a message "
"with provided metadata."
)
# Create metadata for an instruction message
elif reply_to is None:
# Check arguments
# `content`, `dst_node_id` and `message_type` must be set
if not (
isinstance(content, RecordDict)
and isinstance(dst_node_id, int)
and isinstance(message_type, str)
):
raise MessageInitializationError()
# Set metadata
metadata = Metadata(
run_id=0, # Will be set before pushed
message_id="", # Will be set by the SuperLink
src_node_id=0, # Will be set before pushed
dst_node_id=dst_node_id,
# Instruction messages do not reply to any message
reply_to_message_id="",
group_id=group_id or "",
created_at=now().timestamp(),
ttl=ttl or DEFAULT_TTL,
message_type=message_type,
)
# Create metadata for a reply message
else:
# Check arguments
# `dst_node_id`, `message_type` and `group_id` must not be set
if any(x is not None for x in [dst_node_id, message_type, group_id]):
raise MessageInitializationError()
# Set metadata
current = now().timestamp()
metadata = Metadata(
run_id=reply_to.metadata.run_id,
message_id="", # Will be set by the SuperLink
src_node_id=reply_to.metadata.dst_node_id,
dst_node_id=reply_to.metadata.src_node_id,
reply_to_message_id=reply_to.metadata.message_id,
group_id=reply_to.metadata.group_id,
created_at=current,
ttl=_limit_reply_ttl(current, ttl, reply_to),
message_type=reply_to.metadata.message_type,
)
metadata.delivered_at = "" # Backward compatibility
var_dict = {
"_metadata": metadata,
"_content": content,
"_error": error,
}
self.__dict__.update(var_dict)
@property
def metadata(self) -> Metadata:
"""A dataclass including information about the message to be executed."""
return cast(Metadata, self.__dict__["_metadata"])
@property
def content(self) -> RecordDict:
"""The content of this message."""
if self.__dict__["_content"] is None:
raise ValueError(
"Message content is None. Use <message>.has_content() "
"to check if a message has content."
)
return cast(RecordDict, self.__dict__["_content"])
@content.setter
def content(self, value: RecordDict) -> None:
"""Set content."""
if self.__dict__["_error"] is None:
self.__dict__["_content"] = value
else:
raise ValueError("A message with an error set cannot have content.")
@property
def error(self) -> Error:
"""Error captured by this message."""
if self.__dict__["_error"] is None:
raise ValueError(
"Message error is None. Use <message>.has_error() "
"to check first if a message carries an error."
)
return cast(Error, self.__dict__["_error"])
@error.setter
def error(self, value: Error) -> None:
"""Set error."""
if self.has_content():
raise ValueError("A message with content set cannot carry an error.")
self.__dict__["_error"] = value
[๋ฌธ์]
def has_content(self) -> bool:
"""Return True if message has content, else False."""
return self.__dict__["_content"] is not None
[๋ฌธ์]
def has_error(self) -> bool:
"""Return True if message has an error, else False."""
return self.__dict__["_error"] is not None
[๋ฌธ์]
def create_error_reply(self, error: Error, ttl: float | None = None) -> Message:
"""Construct a reply message indicating an error happened.
Parameters
----------
error : Error
The error that was encountered.
ttl : Optional[float] (default: None)
Time-to-live for this message in seconds. If unset, it will be set based
on the remaining time for the received message before it expires. This
follows the equation:
ttl = msg.meta.ttl - (reply.meta.created_at - msg.meta.created_at)
Returns
-------
message : Message
A Message containing only the relevant error and metadata.
"""
warn_deprecated_feature(
"`Message.create_error_reply` is deprecated. "
"Instead of calling `some_message.create_error_reply(some_error, ttl=...)`"
", use `Message(some_error, reply_to=some_message, ttl=...)`."
)
if ttl is not None:
return Message(error, reply_to=self, ttl=ttl)
return Message(error, reply_to=self)
[๋ฌธ์]
def create_reply(self, content: RecordDict, ttl: float | None = None) -> Message:
"""Create a reply to this message with specified content and TTL.
The method generates a new `Message` as a reply to this message.
It inherits 'run_id', 'src_node_id', 'dst_node_id', and 'message_type' from
this message and sets 'reply_to_message_id' to the ID of this message.
Parameters
----------
content : RecordDict
The content for the reply message.
ttl : Optional[float] (default: None)
Time-to-live for this message in seconds. If unset, it will be set based
on the remaining time for the received message before it expires. This
follows the equation:
ttl = msg.meta.ttl - (reply.meta.created_at - msg.meta.created_at)
Returns
-------
Message
A new `Message` instance representing the reply.
"""
warn_deprecated_feature(
"`Message.create_reply` is deprecated. "
"Instead of calling `some_message.create_reply(some_content, ttl=...)`"
", use `Message(some_content, reply_to=some_message, ttl=...)`."
)
if ttl is not None:
return Message(content, reply_to=self, ttl=ttl)
return Message(content, reply_to=self)
def __repr__(self) -> str:
"""Return a string representation of this instance."""
view = ", ".join(
[
f"{k.lstrip('_')}={v!r}"
for k, v in self.__dict__.items()
if v is not None
]
)
return f"{self.__class__.__qualname__}({view})"
def make_message(
metadata: Metadata, content: RecordDict | None = None, error: Error | None = None
) -> Message:
"""Create a message with the provided metadata, content, and error."""
return Message(metadata=metadata, content=content, error=error) # type: ignore
def _limit_reply_ttl(
current: float, reply_ttl: float | None, reply_to: Message
) -> float:
"""Limit the TTL of a reply message such that it does exceed the expiration time of
the message it replies to."""
# Calculate the maximum allowed TTL
max_allowed_ttl = reply_to.metadata.created_at + reply_to.metadata.ttl - current
if reply_ttl is not None and reply_ttl - max_allowed_ttl > MESSAGE_TTL_TOLERANCE:
log(
WARNING,
"The reply TTL of %.2f seconds exceeded the "
"allowed maximum of %.2f seconds. "
"The TTL has been updated to the allowed maximum.",
reply_ttl,
max_allowed_ttl,
)
return max_allowed_ttl
return reply_ttl or max_allowed_ttl
def _extract_positional_args(
*args: Any,
content: RecordDict | None,
error: Error | None,
dst_node_id: int | None,
message_type: str | None,
) -> tuple[RecordDict | None, Error | None, int | None, str | None]:
"""Extract positional arguments for the `Message` constructor."""
content_or_error = args[0] if args else None
if len(args) > 1:
if dst_node_id is not None:
raise MessageInitializationError()
dst_node_id = args[1]
if len(args) > 2:
if message_type is not None:
raise MessageInitializationError()
message_type = args[2]
if len(args) > 3:
raise MessageInitializationError()
# One and only one of `content_or_error`, `content` and `error` must be set
if sum(x is not None for x in [content_or_error, content, error]) != 1:
raise MessageInitializationError()
# Set `content` or `error` based on `content_or_error`
if content_or_error is not None: # This means `content` and `error` are None
if isinstance(content_or_error, RecordDict):
content = content_or_error
elif isinstance(content_or_error, Error):
error = content_or_error
else:
raise MessageInitializationError()
return content, error, dst_node_id, message_type
def _check_arg_types( # pylint: disable=too-many-arguments, R0917
dst_node_id: int | None = None,
message_type: str | None = None,
content: RecordDict | None = None,
error: Error | None = None,
ttl: float | None = None,
group_id: str | None = None,
reply_to: Message | None = None,
metadata: Metadata | None = None,
) -> None:
"""Check argument types for the `Message` constructor."""
# pylint: disable=too-many-boolean-expressions
if (
(dst_node_id is None or isinstance(dst_node_id, int))
and (message_type is None or isinstance(message_type, str))
and (content is None or isinstance(content, RecordDict))
and (error is None or isinstance(error, Error))
and (ttl is None or isinstance(ttl, (int, float)))
and (group_id is None or isinstance(group_id, str))
and (reply_to is None or isinstance(reply_to, Message))
and (metadata is None or isinstance(metadata, Metadata))
):
return
raise MessageInitializationError()
def validate_message_type(message_type: str) -> bool:
"""Validate if the message type is valid.
A valid message type format must be one of the following:
- "<category>"
- "<category>.<action>"
where `category` must be one of "train", "evaluate", or "query",
and `action` must be a valid Python identifier.
"""
# Check if conforming to the format "<category>"
valid_types = {
MessageType.TRAIN,
MessageType.EVALUATE,
MessageType.QUERY,
MessageType.SYSTEM,
}
if message_type in valid_types:
return True
# Check if conforming to the format "<category>.<action>"
if message_type.count(".") != 1:
return False
category, action = message_type.split(".")
if category in valid_types and action.isidentifier():
return True
return False
def validate_legacy_message_type(message_type: str) -> bool:
"""Validate if the legacy message type is valid."""
# Backward compatibility for legacy message types
if message_type in (
MessageTypeLegacy.GET_PARAMETERS,
MessageTypeLegacy.GET_PROPERTIES,
"reconnect",
):
return True
return False