echostream_node.asyncio

  1from __future__ import annotations
  2
  3import asyncio
  4import threading
  5from concurrent.futures import ThreadPoolExecutor
  6from datetime import datetime, timezone
  7from functools import partial
  8from gzip import GzipFile, compress
  9from io import BytesIO
 10from os import environ
 11from signal import SIG_IGN, SIGTERM, signal
 12from typing import TYPE_CHECKING, Any, AsyncGenerator, BinaryIO, Union
 13
 14import dynamic_function_loader
 15import simplejson as json
 16from aws_error_utils import catch_aws_error
 17from gql import Client as GqlClient
 18from gql.transport.aiohttp import AIOHTTPTransport
 19from gql_appsync_cognito_authentication import AppSyncCognitoAuthentication
 20from httpx import AsyncClient as HttpxClient
 21from httpx_auth import AWS4Auth
 22
 23from .. import _GET_BULK_DATA_STORAGE_GQL, _GET_NODE_GQL, BatchItemFailures
 24from .. import BulkDataStorage as BaseBulkDataStorage
 25from .. import Edge, LambdaEvent, LambdaSqsRecords, Message, MessageType
 26from .. import Node as BaseNode
 27from .. import PresignedPost, getLogger
 28
 29if TYPE_CHECKING:
 30    from mypy_boto3_sqs.type_defs import (
 31        DeleteMessageBatchRequestEntryTypeDef,
 32        SendMessageBatchRequestEntryTypeDef,
 33    )
 34else:
 35    DeleteMessageBatchRequestEntryTypeDef = dict
 36    SendMessageBatchRequestEntryTypeDef = dict
 37
 38
 39class _AuditRecordQueue(asyncio.Queue):
 40    def __init__(self, message_type: MessageType, node: Node) -> None:
 41        super().__init__()
 42
 43        async def sender() -> None:
 44            cancelled = asyncio.Event()
 45
 46            async def batcher() -> AsyncGenerator[
 47                list[dict],
 48                None,
 49            ]:
 50                batch: list[dict] = list()
 51                while not (cancelled.is_set() and self.empty()):
 52                    try:
 53                        try:
 54                            batch.append(
 55                                await asyncio.wait_for(self.get(), timeout=node.timeout)
 56                            )
 57                        except asyncio.TimeoutError:
 58                            if batch:
 59                                yield batch
 60                                batch = list()
 61                        else:
 62                            if len(batch) == 500:
 63                                yield batch
 64                                batch = list()
 65                    except asyncio.CancelledError:
 66                        cancelled.set()
 67                        if batch:
 68                            yield batch
 69                            batch = list()
 70
 71            async with HttpxClient() as client:
 72                async for batch in batcher():
 73                    credentials = (
 74                        node._session.get_credentials().get_frozen_credentials()
 75                    )
 76                    auth = AWS4Auth(
 77                        access_id=credentials.access_key,
 78                        region=node._session.region_name,
 79                        secret_key=credentials.secret_key,
 80                        service="lambda",
 81                        security_token=credentials.token,
 82                    )
 83                    url = node._audit_records_endpoint
 84                    post_args = dict(
 85                        auth=auth,
 86                        url=f"{url}{'' if url.endswith('/') else '/'}{node.name}",
 87                    )
 88                    body = dict(
 89                        messageType=message_type.name,
 90                        auditRecords=batch,
 91                    )
 92                    if len(batch) <= 10:
 93                        post_args["json"] = body
 94                    else:
 95                        post_args["content"] = compress(
 96                            json.dumps(body, separators=(",", ":")).encode()
 97                        )
 98                        post_args["headers"] = {
 99                            "Content-Encoding": "gzip",
100                            "Content-Type": "application/json",
101                        }
102                    try:
103                        response = await client.post(**post_args)
104                        response.raise_for_status()
105                        await response.aclose()
106                    except asyncio.CancelledError:
107                        cancelled.set()
108                    except Exception:
109                        getLogger().exception("Error creating audit records")
110                    finally:
111                        for _ in range(len(batch)):
112                            self.task_done()
113
114        asyncio.create_task(sender(), name=f"AuditRecordsSender")
115
116    async def get(self) -> dict:
117        return await super().get()
118
119
class _BulkDataStorage(BaseBulkDataStorage):
    """
    A single bulk data storage location that uploads through an httpx client.
    """

    def __init__(
        self,
        bulk_data_storage: dict[str, Union[str, PresignedPost]],
        client: HttpxClient,
    ) -> None:
        super().__init__(bulk_data_storage)
        # Shared client owned by the queue's filler task.
        self.__client = client

    async def handle_bulk_data(self, data: Union[bytearray, bytes, BinaryIO]) -> str:
        """
        Gzip ``data`` and upload it using this storage's presigned POST.

        ``data`` may be bytes-like (``bytes``, ``bytearray``, ``memoryview``)
        or a binary file-like object with a ``read()`` method. File-like
        objects are read in full before compression.

        Returns the presigned GET URL for the uploaded data. An HTTP error
        status from the upload is raised by ``response.raise_for_status()``.
        """
        # typing.BinaryIO is only an annotation class: real file objects
        # (io.BytesIO, open(..., "rb"), ...) are not instances of it. The
        # former isinstance(data, BinaryIO) check was therefore always False,
        # and file-like input reached GzipFile.write() and raised TypeError.
        if not isinstance(data, (bytes, bytearray, memoryview)):
            data = data.read()
        with BytesIO() as buffer:
            with GzipFile(mode="wb", fileobj=buffer) as gzf:
                gzf.write(data)
            buffer.seek(0)
            response = await self.__client.post(
                self.presigned_post.url,
                data=self.presigned_post.fields,
                files=dict(file=("bulk_data", buffer)),
            )
            response.raise_for_status()
            await response.aclose()
        return self.presigned_get
144
145
class _BulkDataStorageQueue(asyncio.Queue):
    """
    Pool of pre-fetched bulk data storage locations.

    A background filler task runs the ``GetBulkDataStorage`` GraphQL query
    (under ``node._lock``) whenever ``get()`` sees fewer than 20 queued
    entries, and enqueues every location returned. ``get()`` discards
    expired entries.
    """

    # Seconds to wait before retrying a failed fetch while callers are waiting.
    _RETRY_DELAY_SECONDS = 1.0

    def __init__(self, node: Node) -> None:
        super().__init__()
        # Set by get() to request a refill; cleared by the filler once served.
        self.__fill = asyncio.Event()

        async def filler() -> None:
            async with HttpxClient() as client:
                cancelled = asyncio.Event()
                while not cancelled.is_set():
                    bulk_data_storages: list[dict] = list()
                    try:
                        await self.__fill.wait()
                        async with node._lock:
                            async with node._gql_client as session:
                                bulk_data_storages = (
                                    await session.execute(
                                        _GET_BULK_DATA_STORAGE_GQL,
                                        variable_values={
                                            "tenant": node.tenant,
                                            "useAccelerationEndpoint": node.bulk_data_acceleration,
                                        },
                                    )
                                )["GetBulkDataStorage"]
                    except asyncio.CancelledError:
                        cancelled.set()
                    except Exception:
                        getLogger().exception("Error getting bulk data storage")
                        if self.empty():
                            # A caller blocked in get() is waiting on this
                            # fetch. Clearing the fill request here would
                            # strand it until some other get() call, so keep
                            # the request set and retry after a short back-off.
                            try:
                                await asyncio.sleep(self._RETRY_DELAY_SECONDS)
                            except asyncio.CancelledError:
                                cancelled.set()
                            continue
                    for bulk_data_storage in bulk_data_storages:
                        self.put_nowait(_BulkDataStorage(bulk_data_storage, client))
                    self.__fill.clear()

        # Keep a strong reference: the event loop only holds weak references
        # to tasks, so an unreferenced task may be garbage collected mid-run.
        self.__filler_task = asyncio.create_task(
            filler(), name="BulkDataStorageQueueFiller"
        )

    async def get(self) -> _BulkDataStorage:
        """
        Return the next unexpired bulk data storage, waiting if none is queued.

        A refill is requested whenever fewer than 20 entries are queued.
        """
        # Iterative rather than recursive so a run of expired entries cannot
        # grow the call stack.
        while True:
            if self.qsize() < 20:
                self.__fill.set()
            bulk_data_storage: _BulkDataStorage = await super().get()
            if not bulk_data_storage.expired:
                return bulk_data_storage
184
185
class _TargetMessageQueue(asyncio.Queue):
    """
    Outbound message queue for a single target edge.

    A background task drains the queue into SQS ``send_message_batch`` calls
    against ``edge.queue``. A batch is flushed when:

    - it reaches 10 entries;
    - adding the next message would push the summed ``len(message)`` past
      262144; or
    - no new message arrives within ``node.timeout`` seconds.

    ``task_done()`` is called for every message whether or not the send
    succeeded.
    """

    def __init__(self, node: Node, edge: Edge) -> None:
        super().__init__()

        async def sender() -> None:
            cancelled = asyncio.Event()

            async def batcher() -> (
                AsyncGenerator[list[SendMessageBatchRequestEntryTypeDef], None]
            ):
                batch: list[SendMessageBatchRequestEntryTypeDef] = list()
                batch_length = 0
                # Keep draining after cancellation until the queue is empty.
                while not (cancelled.is_set() and self.empty()):
                    try:
                        try:
                            message = await asyncio.wait_for(
                                self.get(), timeout=node.timeout
                            )
                        except asyncio.TimeoutError:
                            if batch:
                                yield batch
                            batch = list()
                            batch_length = 0
                        else:
                            message_length = len(message)
                            # Flush first if this message would overflow the
                            # size limit. Never yield an empty batch: an
                            # oversized message is sent on its own and any
                            # rejection is logged by the sender loop.
                            if batch and batch_length + message_length > 262144:
                                yield batch
                                batch = list()
                                batch_length = 0
                            # Ids only need to be unique within one batch;
                            # using the position keeps them 0..9.
                            batch.append(
                                SendMessageBatchRequestEntryTypeDef(
                                    Id=str(len(batch)), **message._sqs_message(node)
                                )
                            )
                            batch_length += message_length
                            if len(batch) == 10:
                                yield batch
                                batch = list()
                                batch_length = 0
                    except asyncio.CancelledError:
                        cancelled.set()
                        if batch:
                            yield batch
                        batch = list()
                        batch_length = 0

            async for entries in batcher():
                # Reset per batch so a failed call never reports the previous
                # batch's failures (or hits an unbound name on the first one).
                response: dict = dict()
                try:
                    response = await asyncio.get_running_loop().run_in_executor(
                        None,
                        partial(
                            node._sqs_client.send_message_batch,
                            Entries=entries,
                            QueueUrl=edge.queue,
                        ),
                    )
                except asyncio.CancelledError:
                    cancelled.set()
                except Exception:
                    getLogger().exception(f"Error sending messages to {edge.name}")
                finally:
                    # Failed entries report the string Id we assigned; map it
                    # back to the entry instead of indexing the list with a str.
                    entries_by_id = {entry["Id"]: entry for entry in entries}
                    for failed in response.get("Failed", list()):
                        failed_id = failed.pop("Id")
                        getLogger().error(
                            f"Unable to send message {entries_by_id.get(failed_id)} to {edge.name}, reason {failed}"
                        )
                    for _ in range(len(entries)):
                        self.task_done()

        # Keep a strong reference: the event loop only holds weak references
        # to tasks, so an unreferenced task may be garbage collected mid-run.
        self.__sender_task = asyncio.create_task(
            sender(), name=f"TargetMessageSender({edge.name})"
        )

    async def get(self) -> Message:
        """Return the next queued outbound message."""
        return await super().get()
264
265
class Node(BaseNode):
    """
    Base class for all implemented asyncio Nodes.

    Nodes of this class must be instantiated outside of
    the asyncio event loop.
    """

    def __init__(
        self,
        *,
        appsync_endpoint: str = None,
        bulk_data_acceleration: bool = False,
        client_id: str = None,
        name: str = None,
        password: str = None,
        tenant: str = None,
        timeout: float = None,
        user_pool_id: str = None,
        username: str = None,
    ) -> None:
        super().__init__(
            appsync_endpoint=appsync_endpoint,
            bulk_data_acceleration=bulk_data_acceleration,
            client_id=client_id,
            name=name,
            password=password,
            tenant=tenant,
            timeout=timeout,
            user_pool_id=user_pool_id,
            username=username,
        )
        # Audit record queues keyed by message type name; populated in start().
        self.__audit_records_queues: dict[str, _AuditRecordQueue] = dict()
        # Created in start(), once an event loop is running.
        self.__bulk_data_storage_queue: _BulkDataStorageQueue = None
        # NOTE(review): self.__cognito mangles to _Node__cognito and is not
        # assigned in this class; presumably it is set by the base class,
        # which is also named Node — confirm.
        self.__gql_client = GqlClient(
            fetch_schema_from_transport=True,
            transport=AIOHTTPTransport(
                auth=AppSyncCognitoAuthentication(self.__cognito),
                url=appsync_endpoint or environ["APPSYNC_ENDPOINT"],
            ),
        )
        # Guards use of the GraphQL client session; created in start().
        self.__lock: asyncio.Lock = None
        # Outbound queues keyed by target edge name; populated in start().
        self.__target_message_queues: dict[str, _TargetMessageQueue] = dict()

    @property
    def _gql_client(self) -> GqlClient:
        return self.__gql_client

    @property
    def _lock(self) -> asyncio.Lock:
        return self.__lock

    def audit_message(
        self,
        /,
        message: Message,
        *,
        extra_attributes: dict[str, Any] = None,
        source: str = None,
    ) -> None:
        """
        Audits the provided message. If extra_attributes is
        supplied, they will be added to the message's audit
        dict. If source is provided, it will be recorded in
        the audit.

        Raises RuntimeError if this Node is stopped, and ValueError
        if the message's type has no audit queue.
        """
        if self.stopped:
            raise RuntimeError(f"{self.name} is stopped")
        if not self.audit:
            return
        extra_attributes = extra_attributes or dict()
        message_type = message.message_type
        record = dict(
            datetime=datetime.now(timezone.utc).isoformat(),
            previousTrackingIds=message.previous_tracking_ids,
            sourceNode=source,
            trackingId=message.tracking_id,
        )
        # Caller-supplied attributes override the auditor's on key clashes.
        if attributes := (
            message_type.auditor(message=message.body) | extra_attributes
        ):
            record["attributes"] = attributes
        try:
            self.__audit_records_queues[message_type.name].put_nowait(record)
        except KeyError:
            raise ValueError(f"Unrecognized message type {message_type.name}") from None

    def audit_messages(
        self,
        /,
        messages: list[Message],
        *,
        extra_attributes: list[dict[str, Any]] = None,
        source: str = None,
    ) -> None:
        """
        Audits the provided messages. If extra_attributes is
        supplied they will be added to the respective message's audit
        dict and they must have the same count as messages.
        If source is provided, it will be recorded in the audit.
        """
        if extra_attributes and len(extra_attributes) != len(messages):
            raise ValueError(
                "messages and extra_attributes must have the same number of items"
            )
        # extra_attributes is optional. Previously zip(messages, None) raised
        # TypeError whenever it was omitted, so pair each message with None.
        for message, attributes in zip(
            messages, extra_attributes or [None] * len(messages)
        ):
            self.audit_message(message, extra_attributes=attributes, source=source)

    async def handle_bulk_data(self, data: Union[bytearray, bytes]) -> str:
        """
        Posts data as bulk data and returns a GET URL for data retrieval.
        Normally this returned URL will be used as a "ticket" in messages
        that require bulk data.
        """
        return await (await self.__bulk_data_storage_queue.get()).handle_bulk_data(data)

    async def handle_received_message(self, *, message: Message, source: str) -> None:
        """
        Callback called when a message is received. Subclasses that receive messages
        should override this method.
        """
        pass

    async def join(self) -> None:
        """
        Joins the calling thread with this Node. Will block until all
        join conditions are satisfied.
        """
        await asyncio.gather(
            *[
                target_message_queue.join()
                for target_message_queue in self.__target_message_queues.values()
            ],
            *[
                audit_records_queue.join()
                for audit_records_queue in self.__audit_records_queues.values()
            ],
        )

    def send_message(self, /, message: Message, *, targets: set[Edge] = None) -> None:
        """
        Send the message to the specified targets. If no targets are specified
        the message will be sent to all targets.
        """
        self.send_messages([message], targets=targets)

    def send_messages(
        self, /, messages: list[Message], *, targets: set[Edge] = None
    ) -> None:
        """
        Send the messages to the specified targets. If no targets are specified
        the messages will be sent to all targets.

        Raises RuntimeError if this Node is stopped. Unknown targets are
        logged and skipped.
        """
        if self.stopped:
            raise RuntimeError(f"{self.name} is stopped")
        if messages:
            for target in targets or self.targets:
                if target_message_queue := self.__target_message_queues.get(
                    target.name
                ):
                    for message in messages:
                        target_message_queue.put_nowait(message)
                else:
                    getLogger().warning(f"Target {target.name} does not exist")

    async def start(self) -> None:
        """
        Starts this Node. Must be called prior to any other usage.
        """
        getLogger().info(f"Starting Node {self.name}")
        self.__lock = asyncio.Lock()
        self.__bulk_data_storage_queue = _BulkDataStorageQueue(self)
        async with self.__lock:
            async with self._gql_client as session:
                data: dict[str, Union[str, dict]] = (
                    await session.execute(
                        _GET_NODE_GQL,
                        variable_values=dict(name=self.name, tenant=self.tenant),
                    )
                )["GetNode"]
        self._audit = data["tenant"].get("audit") or False
        # Later operands of | win: node config overrides app config, which
        # overrides tenant config.
        self.config = (
            json.loads(data["tenant"].get("config") or "{}")
            | json.loads((data.get("app") or dict()).get("config") or "{}")
            | json.loads(data.get("config") or "{}")
        )
        self._stopped = data.get("stopped")
        if receive_message_type := data.get("receiveMessageType"):
            self._receive_message_type = MessageType(
                auditor=dynamic_function_loader.load(receive_message_type["auditor"]),
                name=receive_message_type["name"],
            )
            if not self.stopped and self.audit:
                self.__audit_records_queues[receive_message_type["name"]] = (
                    _AuditRecordQueue(self.receive_message_type, self)
                )
        if send_message_type := data.get("sendMessageType"):
            self._send_message_type = MessageType(
                auditor=dynamic_function_loader.load(send_message_type["auditor"]),
                name=send_message_type["name"],
            )
            if not self.stopped and self.audit:
                self.__audit_records_queues[send_message_type["name"]] = (
                    _AuditRecordQueue(self.send_message_type, self)
                )
        # AppChangeReceiverNode has a single optional receive edge; all other
        # node types have a list of receive edges.
        if self.node_type == "AppChangeReceiverNode":
            if edge := data.get("receiveEdge"):
                self._sources = {Edge(name=edge["source"]["name"], queue=edge["queue"])}
            else:
                self._sources = set()
        else:
            self._sources = {
                Edge(name=edge["source"]["name"], queue=edge["queue"])
                for edge in (data.get("receiveEdges") or list())
            }
        self._targets = {
            Edge(name=edge["target"]["name"], queue=edge["queue"])
            for edge in (data.get("sendEdges") or list())
        }
        if not self.stopped:
            self.__target_message_queues = {
                edge.name: _TargetMessageQueue(self, edge) for edge in self._targets
            }
489
490
class _DeleteMessageQueue(asyncio.Queue):
    """
    Queue of SQS receipt handles to delete from one source edge's queue.

    A background task drains the queue into ``delete_message_batch`` calls
    against ``edge.queue``. A batch is flushed at 10 handles, or when no new
    handle arrives within ``node.timeout`` seconds. ``task_done()`` is called
    for every handle whether or not the delete succeeded.
    """

    def __init__(self, edge: Edge, node: AppNode) -> None:
        super().__init__()

        async def deleter() -> None:
            cancelled = asyncio.Event()

            async def batcher() -> AsyncGenerator[
                list[str],
                None,
            ]:
                batch: list[str] = list()
                # Keep draining after cancellation until the queue is empty.
                while not (cancelled.is_set() and self.empty()):
                    try:
                        try:
                            batch.append(
                                await asyncio.wait_for(self.get(), timeout=node.timeout)
                            )
                        except asyncio.TimeoutError:
                            if batch:
                                yield batch
                                batch = list()
                        else:
                            if len(batch) == 10:
                                yield batch
                                batch = list()
                    except asyncio.CancelledError:
                        cancelled.set()
                        if batch:
                            yield batch
                            batch = list()

            async for receipt_handles in batcher():
                # Reset per batch so a failed call never reports the previous
                # batch's failures (or hits an unbound name on the first one).
                response: dict = dict()
                try:
                    response = await asyncio.get_running_loop().run_in_executor(
                        None,
                        partial(
                            node._sqs_client.delete_message_batch,
                            Entries=[
                                DeleteMessageBatchRequestEntryTypeDef(
                                    Id=str(id), ReceiptHandle=receipt_handle
                                )
                                for id, receipt_handle in enumerate(receipt_handles)
                            ],
                            QueueUrl=edge.queue,
                        ),
                    )
                except asyncio.CancelledError:
                    cancelled.set()
                except Exception:
                    getLogger().exception(f"Error deleting messages from {edge.name}")
                finally:
                    for failed in response.get("Failed", list()):
                        # Ids were assigned as str(index), so convert back
                        # before indexing (a str index raised TypeError).
                        failed_id = failed.pop("Id")
                        getLogger().error(
                            f"Unable to delete message {receipt_handles[int(failed_id)]} from {edge.name}, reason {failed}"
                        )
                    for _ in range(len(receipt_handles)):
                        self.task_done()

        # Keep a strong reference: the event loop only holds weak references
        # to tasks, so an unreferenced task may be garbage collected mid-run.
        self.__deleter_task = asyncio.create_task(
            deleter(), name=f"SourceMessageDeleter({edge.name})"
        )

    async def get(self) -> str:
        """Return the next queued receipt handle."""
        return await super().get()
555
556
class _SourceMessageReceiver:
    """
    Long-polls one source edge's SQS queue and dispatches received messages.

    Each received message is passed to ``node.handle_received_message``. On
    success its receipt handle is queued for batched deletion. When
    ``node._concurrent_processing`` is false, the messages from one receive
    call are handled in order, and handling stops at the first failure.
    The receive loop ends only when the queue does not exist, or when the
    task is cancelled.
    """

    def __init__(self, edge: Edge, node: AppNode) -> None:
        self.__delete_message_queue = _DeleteMessageQueue(edge, node)
        # Strong references to in-flight handler tasks. The event loop only
        # holds weak references, so unreferenced tasks could be collected.
        self.__handler_tasks: set[asyncio.Task] = set()

        async def handle_received_message(
            message: Message, receipt_handle: str
        ) -> bool:
            # Returns True (and schedules the SQS delete) only on success.
            try:
                await node.handle_received_message(message=message, source=edge.name)
            except asyncio.CancelledError:
                raise
            except Exception:
                getLogger().exception(
                    f"Error handling recevied message for {edge.name}"
                )
                return False
            else:
                self.__delete_message_queue.put_nowait(receipt_handle)
            return True

        async def receive() -> None:
            getLogger().info(f"Receiving messages from {edge.name}")
            while True:
                try:
                    response = await asyncio.get_running_loop().run_in_executor(
                        None,
                        partial(
                            node._sqs_client.receive_message,
                            AttributeNames=["All"],
                            MaxNumberOfMessages=10,
                            MessageAttributeNames=["All"],
                            QueueUrl=edge.queue,
                            WaitTimeSeconds=20,
                        ),
                    )
                except asyncio.CancelledError:
                    raise
                except catch_aws_error("AWS.SimpleQueueService.NonExistentQueue"):
                    getLogger().warning(f"Queue {edge.queue} does not exist, exiting")
                    break
                except Exception:
                    getLogger().exception(
                        f"Error receiving messages from {edge.name}, retrying"
                    )
                    await asyncio.sleep(20)
                else:
                    if not (sqs_messages := response.get("Messages")):
                        continue
                    getLogger().info(f"Received {len(sqs_messages)} from {edge.name}")

                    # Messages are built eagerly, as before. The handler
                    # coroutines are now created lazily, so in sequential
                    # mode the ones after a failure are never created, rather
                    # than created and left un-awaited.
                    received = [
                        (
                            Message(
                                body=sqs_message["Body"],
                                group_id=sqs_message["Attributes"]["MessageGroupId"],
                                message_type=node.receive_message_type,
                                tracking_id=sqs_message["MessageAttributes"]
                                .get("trackingId", {})
                                .get("StringValue"),
                                previous_tracking_ids=sqs_message["MessageAttributes"]
                                .get("prevTrackingIds", {})
                                .get("StringValue"),
                            ),
                            sqs_message["ReceiptHandle"],
                        )
                        for sqs_message in sqs_messages
                    ]

                    # Bind this batch as a default so the closure cannot see
                    # a later iteration's list.
                    async def handle_received_messages(
                        received: list[tuple[Message, str]] = received,
                    ) -> None:
                        if node._concurrent_processing:
                            await asyncio.gather(
                                *[
                                    handle_received_message(message, receipt_handle)
                                    for message, receipt_handle in received
                                ]
                            )
                        else:
                            for message, receipt_handle in received:
                                if not await handle_received_message(
                                    message, receipt_handle
                                ):
                                    break

                    handler_task = asyncio.create_task(
                        handle_received_messages(), name="handle_received_messages"
                    )
                    self.__handler_tasks.add(handler_task)
                    handler_task.add_done_callback(self.__handler_tasks.discard)

            getLogger().info(f"Stopping receiving messages from {edge.name}")

        self.__task = asyncio.create_task(
            receive(), name=f"SourceMessageReceiver({edge.name})"
        )

    async def join(self) -> None:
        """
        Wait for the receive loop to end, then for any in-flight handlers,
        then for all queued deletions to be processed.
        """
        await asyncio.wait([self.__task])
        # Handlers enqueue deletions, so they must finish before the delete
        # queue can be meaningfully joined.
        if self.__handler_tasks:
            await asyncio.wait(set(self.__handler_tasks))
        await self.__delete_message_queue.join()
646
647
class AppNode(Node):
    """
    Node that consumes messages from the SQS queues of its source edges.

    Once started (and not stopped), one receiver per source edge polls that
    edge's queue and hands each message to ``handle_received_message()``.
    """

    def __init__(
        self,
        *,
        appsync_endpoint: str = None,
        bulk_data_acceleration: bool = False,
        client_id: str = None,
        concurrent_processing: bool = False,
        name: str = None,
        password: str = None,
        tenant: str = None,
        timeout: float = None,
        user_pool_id: str = None,
        username: str = None,
    ) -> None:
        super().__init__(
            username=username,
            user_pool_id=user_pool_id,
            timeout=timeout,
            tenant=tenant,
            password=password,
            name=name,
            client_id=client_id,
            bulk_data_acceleration=bulk_data_acceleration,
            appsync_endpoint=appsync_endpoint,
        )
        # When true, messages from a single receive call are handled together
        # via asyncio.gather instead of one at a time.
        self.__concurrent_processing = concurrent_processing
        # Filled by start() with one receiver per source edge.
        self.__source_message_receivers: list[_SourceMessageReceiver] = []

    @property
    def _concurrent_processing(self) -> bool:
        """Whether received messages are processed concurrently."""
        return self.__concurrent_processing

    async def join(self) -> None:
        """
        Joins the calling thread with this Node. Will block until all
        join conditions are satisfied.
        """
        receivers = self.__source_message_receivers
        await asyncio.gather(*(receiver.join() for receiver in receivers))
        await super().join()

    async def start(self) -> None:
        """
        Starts this Node. Must be called prior to any other usage.
        """
        await super().start()
        if self.stopped:
            return
        self.__source_message_receivers = [
            _SourceMessageReceiver(source_edge, self) for source_edge in self._sources
        ]

    async def start_and_run_forever(self) -> None:
        """Will start this Node and run until the containing Task is cancelled"""
        await self.start()
        await self.join()
708
709
710class LambdaNode(Node):
    def __init__(
        self,
        *,
        appsync_endpoint: str = None,
        bulk_data_acceleration: bool = False,
        client_id: str = None,
        concurrent_processing: bool = False,
        name: str = None,
        password: str = None,
        report_batch_item_failures: bool = False,
        tenant: str = None,
        timeout: float = None,
        user_pool_id: str = None,
        username: str = None,
    ) -> None:
        """
        Create the node and bring up its private event loop.

        The loop is created via ``_create_event_loop()``. ``start()`` is
        scheduled on it, and the loop is run on a separate thread named
        "event_loop". This constructor then blocks until the ``__started``
        event is set.

        ``timeout`` falls back to 0.01 when it is not given or is falsy.
        SIGTERM is routed to ``self._shutdown_handler``.
        ``signal.signal()`` may only be called from the main thread, so the
        node must be constructed there.

        NOTE(review): nothing in this constructor sets ``__started``;
        presumably ``start()`` or other code outside this view sets it —
        confirm, otherwise this call never returns.
        """
        super().__init__(
            appsync_endpoint=appsync_endpoint,
            bulk_data_acceleration=bulk_data_acceleration,
            client_id=client_id,
            name=name,
            password=password,
            tenant=tenant,
            timeout=timeout or 0.01,
            user_pool_id=user_pool_id,
            username=username,
        )
        self.__concurrent_processing = concurrent_processing
        self.__loop = self._create_event_loop()
        # Populated elsewhere; consumed by _get_source().
        self.__queue_name_to_source: dict[str, str] = None
        self.__report_batch_item_failures = report_batch_item_failures

        # Install the SIGTERM handler (must happen on the main thread)
        signal(SIGTERM, self._shutdown_handler)

        # Must exist before the loop thread starts so start() can signal it.
        self.__started = threading.Event()
        # Scheduled now; runs once the loop thread calls run_forever().
        self.__loop.create_task(self.start())

        # Run the event loop in a separate thread, or else we will block the main
        # Lambda execution!
        threading.Thread(name="event_loop", target=self.__run_event_loop).start()

        # Wait until the started event is set before returning control to
        # Lambda
        self.__started.wait()
755
    def __run_event_loop(self) -> None:
        """
        Thread target: run the node's event loop until it is stopped, then
        tear it down.

        Teardown order:

        1. Cancel every remaining task and wait for all of them.
        2. Shut down the default ``ThreadPoolExecutor`` (waiting for workers).
        3. Finalize async generators.
        4. Close the loop.

        If the loop was stopped by the exception handler with an exception
        in the context, that exception is re-raised at the end. This happens
        in this "event_loop" thread, not in the thread that constructed the
        node.
        """
        getLogger().info("Starting event loop")
        asyncio.set_event_loop(self.__loop)

        # Set by exception_handler; re-raised after teardown completes.
        # If the handler fires more than once, the most recent value wins.
        pending_exception_to_raise: Exception = None

        def exception_handler(loop: asyncio.AbstractEventLoop, context: dict) -> None:
            # Any unhandled exception reported to the loop stops it, which
            # begins the shutdown sequence below. context may lack an
            # "exception" key, in which case nothing is re-raised.
            # NOTE(review): this handler stays installed during the
            # run_until_complete() calls below; a loop.stop() there would
            # abort them with RuntimeError — confirm this is acceptable.
            nonlocal pending_exception_to_raise
            pending_exception_to_raise = context.get("exception")
            getLogger().error(
                "Unhandled exception; stopping loop: %r",
                context.get("message"),
                exc_info=pending_exception_to_raise,
            )
            loop.stop()

        self.__loop.set_exception_handler(exception_handler)
        # Explicit executor so it can be shut down (and waited on) below;
        # run_in_executor(None, ...) calls use it.
        executor = ThreadPoolExecutor()
        self.__loop.set_default_executor(executor)

        self.__loop.run_forever()
        getLogger().info("Entering shutdown phase")
        getLogger().info("Cancelling pending tasks")
        if tasks := asyncio.all_tasks(self.__loop):
            for task in tasks:
                getLogger().debug(f"Cancelling task: {task}")
                task.cancel()
            getLogger().info("Running pending tasks till complete")
            # return_exceptions=True so one task's failure does not abort
            # waiting for the rest.
            self.__loop.run_until_complete(
                asyncio.gather(*tasks, return_exceptions=True)
            )
        getLogger().info("Waiting for executor shutdown")
        executor.shutdown(wait=True)
        getLogger().info("Shutting down async generators")
        self.__loop.run_until_complete(self.__loop.shutdown_asyncgens())
        getLogger().info("Closing the loop.")
        self.__loop.close()
        getLogger().info("Loop is closed")
        if pending_exception_to_raise:
            getLogger().info("Reraising unhandled exception")
            raise pending_exception_to_raise
797
798    def _create_event_loop(self) -> asyncio.AbstractEventLoop:
799        return asyncio.new_event_loop()
800
801    def _get_source(self, queue_arn: str) -> str:
802        return self.__queue_name_to_source[queue_arn.split(":")[-1:][0]]
803
804    async def _handle_event(self, event: LambdaEvent) -> BatchItemFailures:
805        """
806        Handles the AWS Lambda event passed into the containing
807        AWS Lambda function during invocation.
808
809        This is intended to be the only called method in your
810        containing AWS Lambda function.
811        """
812        records: LambdaSqsRecords = None
813        if not (records := event.get("Records")):
814            getLogger().warning(f"No Records found in event {event}")
815            return
816
817        source = self._get_source(records[0]["eventSourceARN"])
818        getLogger().info(f"Received {len(records)} messages from {source}")
819        batch_item_failures: list[str] = (
820            [record["messageId"] for record in records]
821            if self.__report_batch_item_failures
822            else None
823        )
824
825        async def handle_received_message(message: Message, message_id: str) -> None:
826            try:
827                await self.handle_received_message(message=message, source=source)
828            except asyncio.CancelledError:
829                raise
830            except Exception:
831                if not self.__report_batch_item_failures:
832                    raise
833                getLogger().exception(f"Error handling recevied message for {source}")
834            else:
835                if self.__report_batch_item_failures:
836                    batch_item_failures.remove(message_id)
837
838        message_handlers = [
839            handle_received_message(
840                Message(
841                    body=record["body"],
842                    group_id=record["attributes"]["MessageGroupId"],
843                    message_type=self.receive_message_type,
844                    previous_tracking_ids=record["messageAttributes"]
845                    .get("prevTrackingIds", {})
846                    .get("stringValue"),
847                    tracking_id=record["messageAttributes"]
848                    .get("trackingId", {})
849                    .get("stringValue"),
850                ),
851                record["messageId"],
852            )
853            for record in records
854        ]
855
856        if self.__concurrent_processing:
857            await asyncio.gather(*message_handlers)
858        else:
859            for message_handler in message_handlers:
860                await message_handler
861        await self.join()
862        if self.__report_batch_item_failures and batch_item_failures:
863            return dict(
864                batchItemFailures=[
865                    dict(itemIdentifier=message_id)
866                    for message_id in batch_item_failures
867                ]
868            )
869
870    def _shutdown_handler(self, signum: int, frame: object) -> None:
871        signal(SIGTERM, SIG_IGN)
872        getLogger().warning("Received SIGTERM, stopping the loop")
873        self.__loop.stop()
874
875    def handle_event(self, event: LambdaEvent) -> BatchItemFailures:
876        return asyncio.run_coroutine_threadsafe(
877            self._handle_event(event), self.__loop
878        ).result()
879
880    async def start(self) -> None:
881        await super().start()
882        self.__queue_name_to_source = {
883            edge.queue.split("/")[-1:][0]: edge.name for edge in self._sources
884        }
885        self.__started.set()
class Node(echostream_node.Node):
267class Node(BaseNode):
268    """
269    Base class for all implemented asyncio Nodes.
270
271    Nodes of this class must be instantiated outside of
272    the asyncio event loop.
273    """
274
275    def __init__(
276        self,
277        *,
278        appsync_endpoint: str = None,
279        bulk_data_acceleration: bool = False,
280        client_id: str = None,
281        name: str = None,
282        password: str = None,
283        tenant: str = None,
284        timeout: float = None,
285        user_pool_id: str = None,
286        username: str = None,
287    ) -> None:
288        super().__init__(
289            appsync_endpoint=appsync_endpoint,
290            bulk_data_acceleration=bulk_data_acceleration,
291            client_id=client_id,
292            name=name,
293            password=password,
294            tenant=tenant,
295            timeout=timeout,
296            user_pool_id=user_pool_id,
297            username=username,
298        )
299        self.__audit_records_queues: dict[str, _AuditRecordQueue] = dict()
300        self.__bulk_data_storage_queue: _BulkDataStorageQueue = None
301        self.__gql_client = GqlClient(
302            fetch_schema_from_transport=True,
303            transport=AIOHTTPTransport(
304                auth=AppSyncCognitoAuthentication(self.__cognito),
305                url=appsync_endpoint or environ["APPSYNC_ENDPOINT"],
306            ),
307        )
308        self.__lock: asyncio.Lock = None
309        self.__target_message_queues: dict[str, _TargetMessageQueue] = dict()
310
311    @property
312    def _gql_client(self) -> GqlClient:
313        return self.__gql_client
314
315    @property
316    def _lock(self) -> asyncio.Lock:
317        return self.__lock
318
319    def audit_message(
320        self,
321        /,
322        message: Message,
323        *,
324        extra_attributes: dict[str, Any] = None,
325        source: str = None,
326    ) -> None:
327        """
328        Audits the provided message. If extra_attibutes is
329        supplied, they will be added to the message's audit
330        dict. If source is provided, it will be recorded in
331        the audit.
332        """
333        if self.stopped:
334            raise RuntimeError(f"{self.name} is stopped")
335        if not self.audit:
336            return
337        extra_attributes = extra_attributes or dict()
338        message_type = message.message_type
339        record = dict(
340            datetime=datetime.now(timezone.utc).isoformat(),
341            previousTrackingIds=message.previous_tracking_ids,
342            sourceNode=source,
343            trackingId=message.tracking_id,
344        )
345        if attributes := (
346            message_type.auditor(message=message.body) | extra_attributes
347        ):
348            record["attributes"] = attributes
349        try:
350            self.__audit_records_queues[message_type.name].put_nowait(record)
351        except KeyError:
352            raise ValueError(f"Unrecognized message type {message_type.name}")
353
354    def audit_messages(
355        self,
356        /,
357        messages: list[Message],
358        *,
359        extra_attributes: list[dict[str, Any]] = None,
360        source: str = None,
361    ) -> None:
362        """
363        Audits the provided messages. If extra_attibutes is
364        supplied they will be added to the respective message's audit
365        dict and they must have the same count as messages.
366        If source is provided, it will be recorded in the audit.
367        """
368        if extra_attributes and len(extra_attributes) != len(messages):
369            raise ValueError(
370                "messages and extra_attributes must have the same number of items"
371            )
372        for message, attributes in zip(messages, extra_attributes):
373            self.audit_message(message, extra_attributes=attributes, source=source)
374
375    async def handle_bulk_data(self, data: Union[bytearray, bytes]) -> str:
376        """
377        Posts data as bulk data and returns a GET URL for data retrieval.
378        Normally this returned URL will be used as a "ticket" in messages
379        that require bulk data.
380        """
381        return await (await self.__bulk_data_storage_queue.get()).handle_bulk_data(data)
382
383    async def handle_received_message(self, *, message: Message, source: str) -> None:
384        """
385        Callback called when a message is received. Subclasses that receive messages
386        should override this method.
387        """
388        pass
389
390    async def join(self) -> None:
391        """
392        Joins the calling thread with this Node. Will block until all
393        join conditions are satified.
394        """
395        await asyncio.gather(
396            *[
397                target_message_queue.join()
398                for target_message_queue in self.__target_message_queues.values()
399            ],
400            *[
401                audit_records_queue.join()
402                for audit_records_queue in self.__audit_records_queues.values()
403            ],
404        )
405
406    def send_message(self, /, message: Message, *, targets: set[Edge] = None) -> None:
407        """
408        Send the message to the specified targets. If no targets are specified
409        the message will be sent to all targets.
410        """
411        self.send_messages([message], targets=targets)
412
413    def send_messages(
414        self, /, messages: list[Message], *, targets: set[Edge] = None
415    ) -> None:
416        """
417        Send the messages to the specified targets. If no targets are specified
418        the messages will be sent to all targets.
419        """
420        if self.stopped:
421            raise RuntimeError(f"{self.name} is stopped")
422        if messages:
423            for target in targets or self.targets:
424                if target_message_queue := self.__target_message_queues.get(
425                    target.name
426                ):
427                    for message in messages:
428                        target_message_queue.put_nowait(message)
429                else:
430                    getLogger().warning(f"Target {target.name} does not exist")
431
432    async def start(self) -> None:
433        """
434        Starts this Node. Must be called prior to any other usage.
435        """
436        getLogger().info(f"Starting Node {self.name}")
437        self.__lock = asyncio.Lock()
438        self.__bulk_data_storage_queue = _BulkDataStorageQueue(self)
439        async with self.__lock:
440            async with self._gql_client as session:
441                data: dict[str, Union[str, dict]] = (
442                    await session.execute(
443                        _GET_NODE_GQL,
444                        variable_values=dict(name=self.name, tenant=self.tenant),
445                    )
446                )["GetNode"]
447        self._audit = data["tenant"].get("audit") or False
448        self.config = (
449            json.loads(data["tenant"].get("config") or "{}")
450            | json.loads((data.get("app") or dict()).get("config") or "{}")
451            | json.loads(data.get("config") or "{}")
452        )
453        self._stopped = data.get("stopped")
454        if receive_message_type := data.get("receiveMessageType"):
455            self._receive_message_type = MessageType(
456                auditor=dynamic_function_loader.load(receive_message_type["auditor"]),
457                name=receive_message_type["name"],
458            )
459            if not self.stopped and self.audit:
460                self.__audit_records_queues[receive_message_type["name"]] = (
461                    _AuditRecordQueue(self.receive_message_type, self)
462                )
463        if send_message_type := data.get("sendMessageType"):
464            self._send_message_type = MessageType(
465                auditor=dynamic_function_loader.load(send_message_type["auditor"]),
466                name=send_message_type["name"],
467            )
468            if not self.stopped and self.audit:
469                self.__audit_records_queues[send_message_type["name"]] = (
470                    _AuditRecordQueue(self.send_message_type, self)
471                )
472        if self.node_type == "AppChangeReceiverNode":
473            if edge := data.get("receiveEdge"):
474                self._sources = {Edge(name=edge["source"]["name"], queue=edge["queue"])}
475            else:
476                self._sources = set()
477        else:
478            self._sources = {
479                Edge(name=edge["source"]["name"], queue=edge["queue"])
480                for edge in (data.get("receiveEdges") or list())
481            }
482        self._targets = {
483            Edge(name=edge["target"]["name"], queue=edge["queue"])
484            for edge in (data.get("sendEdges") or list())
485        }
486        if not self.stopped:
487            self.__target_message_queues = {
488                edge.name: _TargetMessageQueue(self, edge) for edge in self._targets
489            }

Base class for all implemented asyncio Nodes.

Nodes of this class must be instantiated outside of the asyncio event loop.

Node( *, appsync_endpoint: str = None, bulk_data_acceleration: bool = False, client_id: str = None, name: str = None, password: str = None, tenant: str = None, timeout: float = None, user_pool_id: str = None, username: str = None)
275    def __init__(
276        self,
277        *,
278        appsync_endpoint: str = None,
279        bulk_data_acceleration: bool = False,
280        client_id: str = None,
281        name: str = None,
282        password: str = None,
283        tenant: str = None,
284        timeout: float = None,
285        user_pool_id: str = None,
286        username: str = None,
287    ) -> None:
288        super().__init__(
289            appsync_endpoint=appsync_endpoint,
290            bulk_data_acceleration=bulk_data_acceleration,
291            client_id=client_id,
292            name=name,
293            password=password,
294            tenant=tenant,
295            timeout=timeout,
296            user_pool_id=user_pool_id,
297            username=username,
298        )
299        self.__audit_records_queues: dict[str, _AuditRecordQueue] = dict()
300        self.__bulk_data_storage_queue: _BulkDataStorageQueue = None
301        self.__gql_client = GqlClient(
302            fetch_schema_from_transport=True,
303            transport=AIOHTTPTransport(
304                auth=AppSyncCognitoAuthentication(self.__cognito),
305                url=appsync_endpoint or environ["APPSYNC_ENDPOINT"],
306            ),
307        )
308        self.__lock: asyncio.Lock = None
309        self.__target_message_queues: dict[str, _TargetMessageQueue] = dict()
def audit_message( self, /, message: echostream_node.Message, *, extra_attributes: dict[str, typing.Any] = None, source: str = None) -> None:
319    def audit_message(
320        self,
321        /,
322        message: Message,
323        *,
324        extra_attributes: dict[str, Any] = None,
325        source: str = None,
326    ) -> None:
327        """
328        Audits the provided message. If extra_attibutes is
329        supplied, they will be added to the message's audit
330        dict. If source is provided, it will be recorded in
331        the audit.
332        """
333        if self.stopped:
334            raise RuntimeError(f"{self.name} is stopped")
335        if not self.audit:
336            return
337        extra_attributes = extra_attributes or dict()
338        message_type = message.message_type
339        record = dict(
340            datetime=datetime.now(timezone.utc).isoformat(),
341            previousTrackingIds=message.previous_tracking_ids,
342            sourceNode=source,
343            trackingId=message.tracking_id,
344        )
345        if attributes := (
346            message_type.auditor(message=message.body) | extra_attributes
347        ):
348            record["attributes"] = attributes
349        try:
350            self.__audit_records_queues[message_type.name].put_nowait(record)
351        except KeyError:
352            raise ValueError(f"Unrecognized message type {message_type.name}")

Audits the provided message. If extra_attributes is supplied, its entries will be merged into the message's audit attributes. If source is provided, it will be recorded in the audit.

def audit_messages( self, /, messages: list[echostream_node.Message], *, extra_attributes: list[dict[str, typing.Any]] = None, source: str = None) -> None:
354    def audit_messages(
355        self,
356        /,
357        messages: list[Message],
358        *,
359        extra_attributes: list[dict[str, Any]] = None,
360        source: str = None,
361    ) -> None:
362        """
363        Audits the provided messages. If extra_attibutes is
364        supplied they will be added to the respective message's audit
365        dict and they must have the same count as messages.
366        If source is provided, it will be recorded in the audit.
367        """
368        if extra_attributes and len(extra_attributes) != len(messages):
369            raise ValueError(
370                "messages and extra_attributes must have the same number of items"
371            )
372        for message, attributes in zip(messages, extra_attributes):
373            self.audit_message(message, extra_attributes=attributes, source=source)

Audits the provided messages. If extra_attributes is supplied, each of its entries will be added to the corresponding message's audit attributes, and it must contain the same number of items as messages. If source is provided, it will be recorded in the audit.

async def handle_bulk_data(self, data: Union[bytearray, bytes]) -> str:
375    async def handle_bulk_data(self, data: Union[bytearray, bytes]) -> str:
376        """
377        Posts data as bulk data and returns a GET URL for data retrieval.
378        Normally this returned URL will be used as a "ticket" in messages
379        that require bulk data.
380        """
381        return await (await self.__bulk_data_storage_queue.get()).handle_bulk_data(data)

Posts data as bulk data and returns a GET URL for data retrieval. Normally this returned URL will be used as a "ticket" in messages that require bulk data.

async def handle_received_message(self, *, message: echostream_node.Message, source: str) -> None:
383    async def handle_received_message(self, *, message: Message, source: str) -> None:
384        """
385        Callback called when a message is received. Subclasses that receive messages
386        should override this method.
387        """
388        pass

Callback called when a message is received. Subclasses that receive messages should override this method.

async def join(self) -> None:
390    async def join(self) -> None:
391        """
392        Joins the calling thread with this Node. Will block until all
393        join conditions are satified.
394        """
395        await asyncio.gather(
396            *[
397                target_message_queue.join()
398                for target_message_queue in self.__target_message_queues.values()
399            ],
400            *[
401                audit_records_queue.join()
402                for audit_records_queue in self.__audit_records_queues.values()
403            ],
404        )

Joins the calling thread with this Node. Will block until all join conditions are satisfied.

def send_message( self, /, message: echostream_node.Message, *, targets: set[echostream_node.Edge] = None) -> None:
406    def send_message(self, /, message: Message, *, targets: set[Edge] = None) -> None:
407        """
408        Send the message to the specified targets. If no targets are specified
409        the message will be sent to all targets.
410        """
411        self.send_messages([message], targets=targets)

Send the message to the specified targets. If no targets are specified the message will be sent to all targets.

def send_messages( self, /, messages: list[echostream_node.Message], *, targets: set[echostream_node.Edge] = None) -> None:
413    def send_messages(
414        self, /, messages: list[Message], *, targets: set[Edge] = None
415    ) -> None:
416        """
417        Send the messages to the specified targets. If no targets are specified
418        the messages will be sent to all targets.
419        """
420        if self.stopped:
421            raise RuntimeError(f"{self.name} is stopped")
422        if messages:
423            for target in targets or self.targets:
424                if target_message_queue := self.__target_message_queues.get(
425                    target.name
426                ):
427                    for message in messages:
428                        target_message_queue.put_nowait(message)
429                else:
430                    getLogger().warning(f"Target {target.name} does not exist")

Send the messages to the specified targets. If no targets are specified the messages will be sent to all targets.

async def start(self) -> None:
432    async def start(self) -> None:
433        """
434        Starts this Node. Must be called prior to any other usage.
435        """
436        getLogger().info(f"Starting Node {self.name}")
437        self.__lock = asyncio.Lock()
438        self.__bulk_data_storage_queue = _BulkDataStorageQueue(self)
439        async with self.__lock:
440            async with self._gql_client as session:
441                data: dict[str, Union[str, dict]] = (
442                    await session.execute(
443                        _GET_NODE_GQL,
444                        variable_values=dict(name=self.name, tenant=self.tenant),
445                    )
446                )["GetNode"]
447        self._audit = data["tenant"].get("audit") or False
448        self.config = (
449            json.loads(data["tenant"].get("config") or "{}")
450            | json.loads((data.get("app") or dict()).get("config") or "{}")
451            | json.loads(data.get("config") or "{}")
452        )
453        self._stopped = data.get("stopped")
454        if receive_message_type := data.get("receiveMessageType"):
455            self._receive_message_type = MessageType(
456                auditor=dynamic_function_loader.load(receive_message_type["auditor"]),
457                name=receive_message_type["name"],
458            )
459            if not self.stopped and self.audit:
460                self.__audit_records_queues[receive_message_type["name"]] = (
461                    _AuditRecordQueue(self.receive_message_type, self)
462                )
463        if send_message_type := data.get("sendMessageType"):
464            self._send_message_type = MessageType(
465                auditor=dynamic_function_loader.load(send_message_type["auditor"]),
466                name=send_message_type["name"],
467            )
468            if not self.stopped and self.audit:
469                self.__audit_records_queues[send_message_type["name"]] = (
470                    _AuditRecordQueue(self.send_message_type, self)
471                )
472        if self.node_type == "AppChangeReceiverNode":
473            if edge := data.get("receiveEdge"):
474                self._sources = {Edge(name=edge["source"]["name"], queue=edge["queue"])}
475            else:
476                self._sources = set()
477        else:
478            self._sources = {
479                Edge(name=edge["source"]["name"], queue=edge["queue"])
480                for edge in (data.get("receiveEdges") or list())
481            }
482        self._targets = {
483            Edge(name=edge["target"]["name"], queue=edge["queue"])
484            for edge in (data.get("sendEdges") or list())
485        }
486        if not self.stopped:
487            self.__target_message_queues = {
488                edge.name: _TargetMessageQueue(self, edge) for edge in self._targets
489            }

Starts this Node. Must be called prior to any other usage.

class AppNode(Node):
649class AppNode(Node):
650    def __init__(
651        self,
652        *,
653        appsync_endpoint: str = None,
654        bulk_data_acceleration: bool = False,
655        client_id: str = None,
656        concurrent_processing: bool = False,
657        name: str = None,
658        password: str = None,
659        tenant: str = None,
660        timeout: float = None,
661        user_pool_id: str = None,
662        username: str = None,
663    ) -> None:
664        super().__init__(
665            appsync_endpoint=appsync_endpoint,
666            bulk_data_acceleration=bulk_data_acceleration,
667            client_id=client_id,
668            name=name,
669            password=password,
670            tenant=tenant,
671            timeout=timeout,
672            user_pool_id=user_pool_id,
673            username=username,
674        )
675        self.__concurrent_processing = concurrent_processing
676        self.__source_message_receivers: list[_SourceMessageReceiver] = list()
677
678    @property
679    def _concurrent_processing(self) -> bool:
680        return self.__concurrent_processing
681
682    async def join(self) -> None:
683        """
684        Joins the calling thread with this Node. Will block until all
685        join conditions are satified.
686        """
687        await asyncio.gather(
688            *[
689                source_message_receiver.join()
690                for source_message_receiver in self.__source_message_receivers
691            ]
692        )
693        await super().join()
694
695    async def start(self) -> None:
696        """
697        Starts this Node. Must be called prior to any other usage.
698        """
699        await super().start()
700        if not self.stopped:
701            self.__source_message_receivers = [
702                _SourceMessageReceiver(edge, self) for edge in self._sources
703            ]
704
705    async def start_and_run_forever(self) -> None:
706        """Will start this Node and run until the containing Task is cancelled"""
707        await self.start()
708        await self.join()

Base class for all implemented asyncio Nodes.

Nodes of this class must be instantiated outside of the asyncio event loop.

AppNode( *, appsync_endpoint: str = None, bulk_data_acceleration: bool = False, client_id: str = None, concurrent_processing: bool = False, name: str = None, password: str = None, tenant: str = None, timeout: float = None, user_pool_id: str = None, username: str = None)
650    def __init__(
651        self,
652        *,
653        appsync_endpoint: str = None,
654        bulk_data_acceleration: bool = False,
655        client_id: str = None,
656        concurrent_processing: bool = False,
657        name: str = None,
658        password: str = None,
659        tenant: str = None,
660        timeout: float = None,
661        user_pool_id: str = None,
662        username: str = None,
663    ) -> None:
664        super().__init__(
665            appsync_endpoint=appsync_endpoint,
666            bulk_data_acceleration=bulk_data_acceleration,
667            client_id=client_id,
668            name=name,
669            password=password,
670            tenant=tenant,
671            timeout=timeout,
672            user_pool_id=user_pool_id,
673            username=username,
674        )
675        self.__concurrent_processing = concurrent_processing
676        self.__source_message_receivers: list[_SourceMessageReceiver] = list()
async def join(self) -> None:
682    async def join(self) -> None:
683        """
684        Joins the calling thread with this Node. Will block until all
685        join conditions are satified.
686        """
687        await asyncio.gather(
688            *[
689                source_message_receiver.join()
690                for source_message_receiver in self.__source_message_receivers
691            ]
692        )
693        await super().join()

Joins the calling thread with this Node. Will block until all join conditions are satisfied.

async def start(self) -> None:
695    async def start(self) -> None:
696        """
697        Starts this Node. Must be called prior to any other usage.
698        """
699        await super().start()
700        if not self.stopped:
701            self.__source_message_receivers = [
702                _SourceMessageReceiver(edge, self) for edge in self._sources
703            ]

Starts this Node. Must be called prior to any other usage.

async def start_and_run_forever(self) -> None:
705    async def start_and_run_forever(self) -> None:
706        """Will start this Node and run until the containing Task is cancelled"""
707        await self.start()
708        await self.join()

Starts this Node and runs until the containing Task is cancelled.

class LambdaNode(Node):
711class LambdaNode(Node):
712    def __init__(
713        self,
714        *,
715        appsync_endpoint: str = None,
716        bulk_data_acceleration: bool = False,
717        client_id: str = None,
718        concurrent_processing: bool = False,
719        name: str = None,
720        password: str = None,
721        report_batch_item_failures: bool = False,
722        tenant: str = None,
723        timeout: float = None,
724        user_pool_id: str = None,
725        username: str = None,
726    ) -> None:
727        super().__init__(
728            appsync_endpoint=appsync_endpoint,
729            bulk_data_acceleration=bulk_data_acceleration,
730            client_id=client_id,
731            name=name,
732            password=password,
733            tenant=tenant,
734            timeout=timeout or 0.01,
735            user_pool_id=user_pool_id,
736            username=username,
737        )
738        self.__concurrent_processing = concurrent_processing
739        self.__loop = self._create_event_loop()
740        self.__queue_name_to_source: dict[str, str] = None
741        self.__report_batch_item_failures = report_batch_item_failures
742
743        # Set up the asyncio loop
744        signal(SIGTERM, self._shutdown_handler)
745
746        self.__started = threading.Event()
747        self.__loop.create_task(self.start())
748
749        # Run the event loop in a seperate thread, or else we will block the main
750        # Lambda execution!
751        threading.Thread(name="event_loop", target=self.__run_event_loop).start()
752
753        # Wait until the started event is set before returning control to
754        # Lambda
755        self.__started.wait()
756
757    def __run_event_loop(self) -> None:
758        getLogger().info("Starting event loop")
759        asyncio.set_event_loop(self.__loop)
760
761        pending_exception_to_raise: Exception = None
762
763        def exception_handler(loop: asyncio.AbstractEventLoop, context: dict) -> None:
764            nonlocal pending_exception_to_raise
765            pending_exception_to_raise = context.get("exception")
766            getLogger().error(
767                "Unhandled exception; stopping loop: %r",
768                context.get("message"),
769                exc_info=pending_exception_to_raise,
770            )
771            loop.stop()
772
773        self.__loop.set_exception_handler(exception_handler)
774        executor = ThreadPoolExecutor()
775        self.__loop.set_default_executor(executor)
776
777        self.__loop.run_forever()
778        getLogger().info("Entering shutdown phase")
779        getLogger().info("Cancelling pending tasks")
780        if tasks := asyncio.all_tasks(self.__loop):
781            for task in tasks:
782                getLogger().debug(f"Cancelling task: {task}")
783                task.cancel()
784            getLogger().info("Running pending tasks till complete")
785            self.__loop.run_until_complete(
786                asyncio.gather(*tasks, return_exceptions=True)
787            )
788        getLogger().info("Waiting for executor shutdown")
789        executor.shutdown(wait=True)
790        getLogger().info("Shutting down async generators")
791        self.__loop.run_until_complete(self.__loop.shutdown_asyncgens())
792        getLogger().info("Closing the loop.")
793        self.__loop.close()
794        getLogger().info("Loop is closed")
795        if pending_exception_to_raise:
796            getLogger().info("Reraising unhandled exception")
797            raise pending_exception_to_raise
798
799    def _create_event_loop(self) -> asyncio.AbstractEventLoop:
800        return asyncio.new_event_loop()
801
802    def _get_source(self, queue_arn: str) -> str:
803        return self.__queue_name_to_source[queue_arn.split(":")[-1:][0]]
804
805    async def _handle_event(self, event: LambdaEvent) -> BatchItemFailures:
806        """
807        Handles the AWS Lambda event passed into the containing
808        AWS Lambda function during invocation.
809
810        This is intended to be the only called method in your
811        containing AWS Lambda function.
812        """
813        records: LambdaSqsRecords = None
814        if not (records := event.get("Records")):
815            getLogger().warning(f"No Records found in event {event}")
816            return
817
818        source = self._get_source(records[0]["eventSourceARN"])
819        getLogger().info(f"Received {len(records)} messages from {source}")
820        batch_item_failures: list[str] = (
821            [record["messageId"] for record in records]
822            if self.__report_batch_item_failures
823            else None
824        )
825
826        async def handle_received_message(message: Message, message_id: str) -> None:
827            try:
828                await self.handle_received_message(message=message, source=source)
829            except asyncio.CancelledError:
830                raise
831            except Exception:
832                if not self.__report_batch_item_failures:
833                    raise
834                getLogger().exception(f"Error handling recevied message for {source}")
835            else:
836                if self.__report_batch_item_failures:
837                    batch_item_failures.remove(message_id)
838
839        message_handlers = [
840            handle_received_message(
841                Message(
842                    body=record["body"],
843                    group_id=record["attributes"]["MessageGroupId"],
844                    message_type=self.receive_message_type,
845                    previous_tracking_ids=record["messageAttributes"]
846                    .get("prevTrackingIds", {})
847                    .get("stringValue"),
848                    tracking_id=record["messageAttributes"]
849                    .get("trackingId", {})
850                    .get("stringValue"),
851                ),
852                record["messageId"],
853            )
854            for record in records
855        ]
856
857        if self.__concurrent_processing:
858            await asyncio.gather(*message_handlers)
859        else:
860            for message_handler in message_handlers:
861                await message_handler
862        await self.join()
863        if self.__report_batch_item_failures and batch_item_failures:
864            return dict(
865                batchItemFailures=[
866                    dict(itemIdentifier=message_id)
867                    for message_id in batch_item_failures
868                ]
869            )
870
871    def _shutdown_handler(self, signum: int, frame: object) -> None:
872        signal(SIGTERM, SIG_IGN)
873        getLogger().warning("Received SIGTERM, stopping the loop")
874        self.__loop.stop()
875
876    def handle_event(self, event: LambdaEvent) -> BatchItemFailures:
877        return asyncio.run_coroutine_threadsafe(
878            self._handle_event(event), self.__loop
879        ).result()
880
881    async def start(self) -> None:
882        await super().start()
883        self.__queue_name_to_source = {
884            edge.queue.split("/")[-1:][0]: edge.name for edge in self._sources
885        }
886        self.__started.set()

Base class for all implemented asyncio Nodes.

Nodes of this class must be instantiated outside of the asyncio event loop.

LambdaNode( *, appsync_endpoint: str = None, bulk_data_acceleration: bool = False, client_id: str = None, concurrent_processing: bool = False, name: str = None, password: str = None, report_batch_item_failures: bool = False, tenant: str = None, timeout: float = None, user_pool_id: str = None, username: str = None)
712    def __init__(
713        self,
714        *,
715        appsync_endpoint: str = None,
716        bulk_data_acceleration: bool = False,
717        client_id: str = None,
718        concurrent_processing: bool = False,
719        name: str = None,
720        password: str = None,
721        report_batch_item_failures: bool = False,
722        tenant: str = None,
723        timeout: float = None,
724        user_pool_id: str = None,
725        username: str = None,
726    ) -> None:
727        super().__init__(
728            appsync_endpoint=appsync_endpoint,
729            bulk_data_acceleration=bulk_data_acceleration,
730            client_id=client_id,
731            name=name,
732            password=password,
733            tenant=tenant,
734            timeout=timeout or 0.01,
735            user_pool_id=user_pool_id,
736            username=username,
737        )
738        self.__concurrent_processing = concurrent_processing
739        self.__loop = self._create_event_loop()
740        self.__queue_name_to_source: dict[str, str] = None
741        self.__report_batch_item_failures = report_batch_item_failures
742
743        # Set up the asyncio loop
744        signal(SIGTERM, self._shutdown_handler)
745
746        self.__started = threading.Event()
747        self.__loop.create_task(self.start())
748
749        # Run the event loop in a seperate thread, or else we will block the main
750        # Lambda execution!
751        threading.Thread(name="event_loop", target=self.__run_event_loop).start()
752
753        # Wait until the started event is set before returning control to
754        # Lambda
755        self.__started.wait()
def handle_event( self, event: Union[bool, dict, float, int, list, str, tuple, NoneType]) -> dict[str, list[dict[str, str]]]:
876    def handle_event(self, event: LambdaEvent) -> BatchItemFailures:
877        return asyncio.run_coroutine_threadsafe(
878            self._handle_event(event), self.__loop
879        ).result()
async def start(self) -> None:
881    async def start(self) -> None:
882        await super().start()
883        self.__queue_name_to_source = {
884            edge.queue.split("/")[-1:][0]: edge.name for edge in self._sources
885        }
886        self.__started.set()

Starts this Node. Must be called prior to any other usage.