# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Message."""
from __future__ import annotations
import time
from logging import WARNING
from typing import Optional, cast
from .constant import MESSAGE_TTL_TOLERANCE
from .logger import log
from .record import RecordSet
# Default time-to-live, in seconds, used as the initial TTL of reply messages
# when the caller does not pass an explicit `ttl`.
DEFAULT_TTL = 3600
class Error:
    """A dataclass that stores information about an error that occurred.

    Parameters
    ----------
    code : int
        An identifier for the error.
    reason : Optional[str]
        A reason for why the error arose (e.g. an exception stack-trace)
    """

    def __init__(self, code: int, reason: str | None = None) -> None:
        # Values live in the instance __dict__ under underscore-prefixed keys
        # and are exposed through read-only properties below.
        var_dict = {
            "_code": code,
            "_reason": reason,
        }
        self.__dict__.update(var_dict)

    @property
    def code(self) -> int:
        """Error code."""
        return cast(int, self.__dict__["_code"])

    @property
    def reason(self) -> str | None:
        """Reason reported about the error."""
        return cast(Optional[str], self.__dict__["_reason"])

    def __repr__(self) -> str:
        """Return a string representation of this instance."""
        view = ", ".join(f"{k.lstrip('_')}={v!r}" for k, v in self.__dict__.items())
        return f"{self.__class__.__qualname__}({view})"

    def __eq__(self, other: object) -> bool:
        """Compare two instances of the class.

        Returns ``NotImplemented`` for objects of another type so Python can
        try the reflected comparison and fall back to identity, instead of
        raising on e.g. ``error == None``.
        """
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self.__dict__ == other.__dict__
class Message:
    """State of your application from the viewpoint of the entity using it.

    Parameters
    ----------
    metadata : Metadata
        A dataclass including information about the message to be executed.
    content : Optional[RecordSet]
        Holds records either sent by another entity (e.g. sent by the server-side
        logic to a client, or vice-versa) or that will be sent to it.
    error : Optional[Error]
        A dataclass that captures information about an error that took place
        when processing another message.

    Raises
    ------
    ValueError
        If both or neither of ``content`` and ``error`` are given.

    Notes
    -----
    Construction sets ``metadata.created_at`` to the current ``time.time()``,
    mutating the ``metadata`` object that was passed in.
    """

    def __init__(
        self,
        metadata: Metadata,
        content: RecordSet | None = None,
        error: Error | None = None,
    ) -> None:
        # Exactly one of `content` / `error` must be provided.
        if not (content is None) ^ (error is None):
            raise ValueError("Either `content` or `error` must be set, but not both.")

        metadata.created_at = time.time()  # Set the message creation timestamp
        var_dict = {
            "_metadata": metadata,
            "_content": content,
            "_error": error,
        }
        self.__dict__.update(var_dict)

    @property
    def metadata(self) -> Metadata:
        """A dataclass including information about the message to be executed."""
        return cast(Metadata, self.__dict__["_metadata"])

    @property
    def content(self) -> RecordSet:
        """The content of this message.

        Raises
        ------
        ValueError
            If the message has no content.
        """
        if self.__dict__["_content"] is None:
            raise ValueError(
                "Message content is None. Use <message>.has_content() "
                "to check if a message has content."
            )
        return cast(RecordSet, self.__dict__["_content"])

    @content.setter
    def content(self, value: RecordSet) -> None:
        """Set content.

        Raises
        ------
        ValueError
            If the message already carries an error.
        """
        if self.__dict__["_error"] is None:
            self.__dict__["_content"] = value
        else:
            raise ValueError("A message with an error set cannot have content.")

    @property
    def error(self) -> Error:
        """Error captured by this message.

        Raises
        ------
        ValueError
            If the message carries no error.
        """
        if self.__dict__["_error"] is None:
            raise ValueError(
                "Message error is None. Use <message>.has_error() "
                "to check first if a message carries an error."
            )
        return cast(Error, self.__dict__["_error"])

    @error.setter
    def error(self, value: Error) -> None:
        """Set error.

        Raises
        ------
        ValueError
            If the message already has content.
        """
        if self.has_content():
            raise ValueError("A message with content set cannot carry an error.")
        self.__dict__["_error"] = value

    def has_content(self) -> bool:
        """Return True if message has content, else False."""
        return self.__dict__["_content"] is not None

    def has_error(self) -> bool:
        """Return True if message has an error, else False."""
        return self.__dict__["_error"] is not None

    def create_error_reply(self, error: Error, ttl: float | None = None) -> Message:
        """Construct a reply message indicating an error happened.

        Parameters
        ----------
        error : Error
            The error that was encountered.
        ttl : Optional[float] (default: None)
            Time-to-live for this message in seconds. If unset, it will be set based
            on the remaining time for the received message before it expires. This
            follows the equation:

            ttl = msg.meta.ttl - (reply.meta.created_at - msg.meta.created_at)

        Returns
        -------
        message : Message
            A Message containing only the relevant error and metadata.
        """
        # If no TTL passed, use default for message creation (will update after
        # message creation)
        ttl_ = DEFAULT_TTL if ttl is None else ttl
        # Create reply with error
        message = Message(metadata=_create_reply_metadata(self, ttl_), error=error)
        self._finalize_reply_ttl(message, ttl)
        return message

    def create_reply(self, content: RecordSet, ttl: float | None = None) -> Message:
        """Create a reply to this message with specified content and TTL.

        The method generates a new `Message` as a reply to this message.
        It inherits 'run_id', 'src_node_id', 'dst_node_id', and 'message_type' from
        this message and sets 'reply_to_message' to the ID of this message.

        Parameters
        ----------
        content : RecordSet
            The content for the reply message.
        ttl : Optional[float] (default: None)
            Time-to-live for this message in seconds. If unset, it will be set based
            on the remaining time for the received message before it expires. This
            follows the equation:

            ttl = msg.meta.ttl - (reply.meta.created_at - msg.meta.created_at)

        Returns
        -------
        Message
            A new `Message` instance representing the reply.
        """
        # If no TTL passed, use default for message creation (will update after
        # message creation)
        ttl_ = DEFAULT_TTL if ttl is None else ttl
        message = Message(
            metadata=_create_reply_metadata(self, ttl_),
            content=content,
        )
        self._finalize_reply_ttl(message, ttl)
        return message

    def __repr__(self) -> str:
        """Return a string representation of this instance (unset fields omitted)."""
        view = ", ".join(
            [
                f"{k.lstrip('_')}={v!r}"
                for k, v in self.__dict__.items()
                if v is not None
            ]
        )
        return f"{self.__class__.__qualname__}({view})"

    def _finalize_reply_ttl(self, message: Message, ttl: float | None) -> None:
        """Set the final TTL of a freshly created reply and cap it.

        Parameters
        ----------
        message : Message
            The reply that was just created from this message.
        ttl : Optional[float]
            The TTL requested by the caller. If None, the reply gets the
            remaining lifetime of this message at the reply's creation time.
        """
        if ttl is None:
            # Set TTL equal to the remaining time for the received message to expire
            ttl = self.metadata.ttl - (
                message.metadata.created_at - self.metadata.created_at
            )
        message.metadata.ttl = ttl
        self._limit_task_res_ttl(message)

    def _limit_task_res_ttl(self, message: Message) -> None:
        """Limit the TaskRes TTL to not exceed the expiration time of the TaskIns it
        replies to.

        Parameters
        ----------
        message : Message
            The reply (TaskRes) whose TTL is capped; ``self`` is the message
            (TaskIns) being replied to.
        """
        # Calculate the maximum allowed TTL: time left until `self` expires,
        # measured from the reply's creation time.
        max_allowed_ttl = (
            self.metadata.created_at + self.metadata.ttl - message.metadata.created_at
        )
        # Only clamp (and warn) when the excess is beyond the tolerance.
        if message.metadata.ttl - max_allowed_ttl > MESSAGE_TTL_TOLERANCE:
            log(
                WARNING,
                "The reply TTL of %.2f seconds exceeded the "
                "allowed maximum of %.2f seconds. "
                "The TTL has been updated to the allowed maximum.",
                message.metadata.ttl,
                max_allowed_ttl,
            )
            message.metadata.ttl = max_allowed_ttl
def _create_reply_metadata(msg: Message, ttl: float) -> Metadata:
    """Build the metadata for a reply to ``msg``.

    The reply swaps ``msg``'s source and destination node IDs, points
    ``reply_to_message`` at ``msg``'s ``message_id``, and copies the run ID,
    group ID and message type. Its own ``message_id`` is set to an empty string
    and its TTL to ``ttl``.
    """
    incoming = msg.metadata
    return Metadata(
        run_id=incoming.run_id,
        message_id="",
        # Source and destination are swapped: the reply travels back.
        src_node_id=incoming.dst_node_id,
        dst_node_id=incoming.src_node_id,
        reply_to_message=incoming.message_id,
        group_id=incoming.group_id,
        ttl=ttl,
        message_type=incoming.message_type,
    )