echostream_node.threading

from __future__ import annotations

from concurrent.futures import Executor, ThreadPoolExecutor, as_completed, wait
from datetime import datetime, timezone
from functools import partial
from gzip import GzipFile, compress
from io import BytesIO
from os import environ
from queue import Empty, Queue
from signal import SIGTERM, signal
from threading import Event, RLock, Thread
from time import sleep
from typing import TYPE_CHECKING, Any, BinaryIO, Generator, Union

import dynamic_function_loader
import simplejson as json
from aws_error_utils import catch_aws_error
from gql import Client as GqlClient
from gql.transport.requests import RequestsHTTPTransport
from httpx import Client as HttpxClient
from httpx_auth import AWS4Auth
from pycognito.utils import RequestsSrpAuth

from .. import _GET_BULK_DATA_STORAGE_GQL, _GET_NODE_GQL, BatchItemFailures
from .. import BulkDataStorage as BaseBulkDataStorage
from .. import Edge, LambdaEvent, LambdaSqsRecords, Message, MessageType
from .. import Node as BaseNode
from .. import PresignedPost, getLogger

if TYPE_CHECKING:
    from mypy_boto3_sqs.type_defs import (
        DeleteMessageBatchRequestEntryTypeDef,
        SendMessageBatchRequestEntryTypeDef,
    )
else:
    DeleteMessageBatchRequestEntryTypeDef = dict
    SendMessageBatchRequestEntryTypeDef = dict


class _Queue(Queue):
    def tasks_done(self, count: int) -> None:
        with self.all_tasks_done:
            unfinished = self.unfinished_tasks - count
            if unfinished < 0:
                raise ValueError("count larger than unfinished tasks")
            if unfinished == 0:
                self.all_tasks_done.notify_all()
            self.unfinished_tasks = unfinished


class _AuditRecordQueue(_Queue):
    def __init__(self, message_type: MessageType, node: Node) -> None:
        super().__init__()

        def sender() -> None:
            with HttpxClient() as client:
                while True:
                    batch: list[dict] = list()
                    while len(batch) < 500:
                        try:
                            batch.append(self.get(timeout=node.timeout))
                        except Empty:
                            break
                    if not batch:
                        continue
                    credentials = (
                        node._session.get_credentials().get_frozen_credentials()
                    )
                    auth = AWS4Auth(
                        access_id=credentials.access_key,
                        region=node._session.region_name,
                        secret_key=credentials.secret_key,
                        service="lambda",
                        security_token=credentials.token,
                    )
                    url = node._audit_records_endpoint
                    post_args = dict(
                        auth=auth,
                        url=f"{url}{'' if url.endswith('/') else '/'}{node.name}",
                    )
                    body = dict(
                        messageType=message_type.name,
                        auditRecords=batch,
                    )
                    if len(batch) <= 10:
                        post_args["json"] = body
                    else:
                        post_args["content"] = compress(
                            json.dumps(body, separators=(",", ":")).encode()
                        )
                        post_args["headers"] = {
                            "Content-Encoding": "gzip",
                            "Content-Type": "application/json",
                        }
                    try:
                        response = client.post(**post_args)
                        response.raise_for_status()
                        response.close()
                    except Exception:
                        getLogger().exception("Error creating audit records")
                    finally:
                        self.tasks_done(len(batch))

        Thread(daemon=True, name="AuditRecordsSender", target=sender).start()

    def get(self, block: bool = True, timeout: float = None) -> dict:
        return super().get(block=block, timeout=timeout)


class _BulkDataStorage(BaseBulkDataStorage):
    def __init__(
        self,
        bulk_data_storage: dict[str, Union[str, PresignedPost]],
        client: HttpxClient,
    ) -> None:
        super().__init__(bulk_data_storage)
        self.__client = client

    def handle_bulk_data(self, data: Union[bytearray, bytes, BinaryIO]) -> str:
        if isinstance(data, BinaryIO):
            data = data.read()
        with BytesIO() as buffer:
            with GzipFile(mode="wb", fileobj=buffer) as gzf:
                gzf.write(data)
            buffer.seek(0)
            response = self.__client.post(
                self.presigned_post.url,
                data=self.presigned_post.fields,
                files=dict(file=("bulk_data", buffer)),
            )
            response.raise_for_status()
            response.close()
        return self.presigned_get


class _BulkDataStorageQueue(Queue):
    def __init__(self, node: Node) -> None:
        super().__init__()
        self.__fill = Event()

        def filler() -> None:
            with HttpxClient() as client:
                while True:
                    self.__fill.wait()
                    try:
                        with node._lock:
                            with node._gql_client as session:
                                bulk_data_storages: list[dict] = session.execute(
                                    _GET_BULK_DATA_STORAGE_GQL,
                                    variable_values={
                                        "tenant": node.tenant,
                                        "useAccelerationEndpoint": node.bulk_data_acceleration,
                                    },
                                )["GetBulkDataStorage"]
                    except Exception:
                        getLogger().exception("Error getting bulk data storage")
                    else:
                        for bulk_data_storage in bulk_data_storages:
                            self.put_nowait(_BulkDataStorage(
                                bulk_data_storage, client))
                    self.__fill.clear()

        Thread(daemon=True, name="BulkDataStorageQueueFiller",
               target=filler).start()

    def get(self, block: bool = True, timeout: float = None) -> _BulkDataStorage:
        if self.qsize() < 20:
            self.__fill.set()
        bulk_data_storage: _BulkDataStorage = super().get(block=block, timeout=timeout)
        return (
            bulk_data_storage
            if not bulk_data_storage.expired
            else self.get(block=block, timeout=timeout)
        )


class _TargetMessageQueue(_Queue):
    def __init__(self, node: Node, edge: Edge) -> None:
        super().__init__()

        def batcher() -> Generator[
            list[SendMessageBatchRequestEntryTypeDef],
            None,
            list[SendMessageBatchRequestEntryTypeDef],
        ]:
            batch: list[SendMessageBatchRequestEntryTypeDef] = list()
            batch_length = 0
            id = 0
            while True:
                try:
                    message = self.get(timeout=node.timeout)
                    if batch_length + len(message) > 262144:
                        yield batch
                        batch = list()
                        batch_length = 0
                        id = 0
                    batch.append(
                        SendMessageBatchRequestEntryTypeDef(
                            Id=str(id), **message._sqs_message(node)
                        )
                    )
                    # Increment before the full-batch check so Ids stay
                    # aligned with list indices within each batch.
                    id += 1
                    batch_length += len(message)
                    if len(batch) == 10:
                        yield batch
                        batch = list()
                        batch_length = 0
                        id = 0
                except Empty:
                    if batch:
                        yield batch
                    batch = list()
                    batch_length = 0
                    id = 0

        def sender() -> None:
            for entries in batcher():
                try:
                    response = node._sqs_client.send_message_batch(
                        Entries=entries, QueueUrl=edge.queue
                    )
                    for failed in response.get("Failed", list()):
                        # SQS returns Id as a string; convert it back to the
                        # list index assigned in batcher() above.
                        id = int(failed.pop("Id"))
                        getLogger().error(
                            f"Unable to send message {entries[id]} to "
                            f"{edge.name}, reason {failed}"
                        )
                except Exception:
                    getLogger().exception(
                        f"Error sending messages to {edge.name}")
                finally:
                    self.tasks_done(len(entries))

        Thread(
            daemon=True, name=f"TargetMessageSender({edge.name})", target=sender
        ).start()

    def get(self, block: bool = True, timeout: float = None) -> Message:
        return super().get(block=block, timeout=timeout)


class Node(BaseNode):
    """
    Base class for all threading Nodes.
    """

    def __init__(
        self,
        *,
        appsync_endpoint: str = None,
        bulk_data_acceleration: bool = False,
        client_id: str = None,
        name: str = None,
        password: str = None,
        tenant: str = None,
        timeout: float = None,
        user_pool_id: str = None,
        username: str = None,
    ) -> None:
        super().__init__(
            appsync_endpoint=appsync_endpoint,
            bulk_data_acceleration=bulk_data_acceleration,
            client_id=client_id,
            name=name,
            password=password,
            tenant=tenant,
            timeout=timeout,
            user_pool_id=user_pool_id,
            username=username,
        )
        self.__audit_records_queues: dict[str, _AuditRecordQueue] = dict()
        self.__bulk_data_storage_queue = _BulkDataStorageQueue(self)
        self.__gql_client = GqlClient(
            fetch_schema_from_transport=True,
            transport=RequestsHTTPTransport(
                auth=RequestsSrpAuth(cognito=self.__cognito,
                                     http_header_prefix=""),
                url=appsync_endpoint or environ["APPSYNC_ENDPOINT"],
            ),
        )
        self.__lock = RLock()
        self.__target_message_queues: dict[str, _TargetMessageQueue] = dict()

    @property
    def _gql_client(self) -> GqlClient:
        return self.__gql_client

    @property
    def _lock(self) -> RLock:
        return self.__lock

    def audit_message(
        self,
        /,
        message: Message,
        *,
        extra_attributes: dict[str, Any] = None,
        source: str = None,
    ) -> None:
        """
        Audits the provided message. If extra_attributes is
        supplied, they will be added to the message's audit
        dict. If source is provided, it will be recorded in
        the audit.
        """
        if self.stopped:
            raise RuntimeError(f"{self.name} is stopped")
        if not self.audit:
            return
        extra_attributes = extra_attributes or dict()
        message_type = message.message_type
        record = dict(
            datetime=datetime.now(timezone.utc).isoformat(),
            previousTrackingIds=message.previous_tracking_ids,
            sourceNode=source,
            trackingId=message.tracking_id,
        )
        if attributes := (
            message_type.auditor(message=message.body) | extra_attributes
        ):
            record["attributes"] = attributes
        try:
            self.__audit_records_queues[message_type.name].put_nowait(record)
        except KeyError:
            raise ValueError(f"Unrecognized message type {message_type.name}")

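    # A minimal usage sketch (the attribute values and source name are
    # hypothetical): after processing a received message, a node typically
    # audits it back with the name of the node it came from:
    #
    #     node.audit_message(
    #         message,
    #         extra_attributes={"recordCount": 10},
    #         source="upstream-node",
    #     )
    #
    # The record is only queued here; _AuditRecordQueue's background sender
    # thread batches and posts it.
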
    def audit_messages(
        self,
        /,
        messages: list[Message],
        *,
        extra_attributes: list[dict[str, Any]] = None,
        source: str = None,
    ) -> None:
        """
        Audits the provided messages. If extra_attributes is
        supplied they will be added to the respective message's audit
        dict and they must have the same count as messages.
        If source is provided, it will be recorded in the audit.
        """
        if extra_attributes and len(extra_attributes) != len(messages):
            raise ValueError(
                "messages and extra_attributes must have the same number of items"
            )
        # Default to empty attribute dicts so zip() works when no
        # extra_attributes are supplied.
        extra_attributes = extra_attributes or [dict()] * len(messages)
        for message, attributes in zip(messages, extra_attributes):
            self.audit_message(
                message, extra_attributes=attributes, source=source)

    def handle_bulk_data(self, data: Union[bytearray, bytes]) -> str:
        """
        Posts data as bulk data and returns a GET URL for data retrieval.
        Normally this returned URL will be used as a "ticket" in messages
        that require bulk data.
        """
        return self.__bulk_data_storage_queue.get().handle_bulk_data(data)

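    # A minimal sketch of the bulk-data flow (the payload and the way the
    # Message is constructed are hypothetical): large payloads are posted
    # out-of-band and only the returned presigned GET URL travels in the
    # message body:
    #
    #     url = node.handle_bulk_data(b"...large payload...")
    #     message = build_message(url)  # however the app builds Messages
    #     node.send_message(message)
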
    def handle_received_message(self, *, message: Message, source: str) -> None:
        """
        Callback called when a message is received. Subclasses that receive messages
        should override this method.
        """
        pass

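    # A minimal override sketch (the subclass and its logic are hypothetical):
    #
    #     class MyNode(AppNode):
    #         def handle_received_message(self, *, message: Message, source: str) -> None:
    #             getLogger().info(f"Got {message.tracking_id} from {source}")
    #             self.send_message(message)
    #
    # For AppNode sources, returning normally queues the underlying SQS
    # message for deletion; raising leaves it on the queue for redelivery.
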
    def join(self) -> None:
        """
        Joins the calling thread with this Node. Will block until all
        join conditions are satisfied.
        """
        for target_message_queue in self.__target_message_queues.values():
            target_message_queue.join()
        for audit_records_queue in self.__audit_records_queues.values():
            audit_records_queue.join()

    def send_message(self, /, message: Message, *, targets: set[Edge] = None) -> None:
        """
        Send the message to the specified targets. If no targets are specified
        the message will be sent to all targets.
        """
        self.send_messages([message], targets=targets)

    def send_messages(
        self, /, messages: list[Message], *, targets: set[Edge] = None
    ) -> None:
        """
        Send the messages to the specified targets. If no targets are specified
        the messages will be sent to all targets.
        """
        if self.stopped:
            raise RuntimeError(f"{self.name} is stopped")
        if messages:
            for target in targets or self.targets:
                if target_message_queue := self.__target_message_queues.get(
                    target.name
                ):
                    for message in messages:
                        target_message_queue.put_nowait(message)
                else:
                    getLogger().warning(f"Target {target.name} does not exist")

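    # A short sketch of targeted sending (the edge name is hypothetical):
    # by default messages fan out to every target edge, but a subset of
    # self.targets can be selected explicitly:
    #
    #     alerts = {edge for edge in node.targets if edge.name == "alerts"}
    #     node.send_messages(batch_of_messages, targets=alerts)
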
    def start(self) -> None:
        """
        Starts this Node. Must be called prior to any other usage.
        """
        getLogger().info(f"Starting Node {self.name}")
        with self._lock:
            with self._gql_client as session:
                data: dict[str, Union[str, dict]] = session.execute(
                    _GET_NODE_GQL,
                    variable_values=dict(name=self.name, tenant=self.tenant),
                )["GetNode"]
        self._audit = data["tenant"].get("audit") or False
        self.config = (
            json.loads(data["tenant"].get("config") or "{}")
            | json.loads((data.get("app") or dict()).get("config") or "{}")
            | json.loads(data.get("config") or "{}")
        )
        self._stopped = data.get("stopped")
        if receive_message_type := data.get("receiveMessageType"):
            self._receive_message_type = MessageType(
                auditor=dynamic_function_loader.load(
                    receive_message_type["auditor"]),
                name=receive_message_type["name"],
            )
            if not self.stopped and self.audit:
                self.__audit_records_queues[receive_message_type["name"]] = (
                    _AuditRecordQueue(self._receive_message_type, self)
                )
        if send_message_type := data.get("sendMessageType"):
            self._send_message_type = MessageType(
                auditor=dynamic_function_loader.load(
                    send_message_type["auditor"]),
                name=send_message_type["name"],
            )
            if not self.stopped and self.audit:
                self.__audit_records_queues[send_message_type["name"]] = (
                    _AuditRecordQueue(self._send_message_type, self)
                )
        if self.node_type == "AppChangeReceiverNode":
            if edge := data.get("receiveEdge"):
                self._sources = {
                    Edge(name=edge["source"]["name"], queue=edge["queue"])}
            else:
                self._sources = set()
        else:
            self._sources = {
                Edge(name=edge["source"]["name"], queue=edge["queue"])
                for edge in (data.get("receiveEdges") or list())
            }
        self._targets = {
            Edge(name=edge["target"]["name"], queue=edge["queue"])
            for edge in (data.get("sendEdges") or list())
        }
        if not self.stopped:
            self.__target_message_queues = {
                edge.name: _TargetMessageQueue(self, edge) for edge in self._targets
            }

    def stop(self) -> None:
        """Stops the Node's processing."""
        pass


class _DeleteMessageQueue(_Queue):
    def __init__(self, edge: Edge, node: AppNode) -> None:
        super().__init__()

        def deleter() -> None:
            while True:
                receipt_handles: list[str] = list()
                while len(receipt_handles) < 10:
                    try:
                        receipt_handles.append(self.get(timeout=node.timeout))
                    except Empty:
                        break
                if not receipt_handles:
                    continue
                try:
                    response = node._sqs_client.delete_message_batch(
                        Entries=[
                            DeleteMessageBatchRequestEntryTypeDef(
                                Id=str(id), ReceiptHandle=receipt_handle
                            )
                            for id, receipt_handle in enumerate(receipt_handles)
                        ],
                        QueueUrl=edge.queue,
                    )
                    for failed in response.get("Failed", list()):
                        # SQS returns Id as a string; convert it back to the
                        # enumerate() index used above.
                        id = int(failed.pop("Id"))
                        getLogger().error(
                            f"Unable to delete message {receipt_handles[id]} from "
                            f"{edge.name}, reason {failed}"
                        )
                except Exception:
                    getLogger().exception(
                        f"Error deleting messages from {edge.name}")
                finally:
                    self.tasks_done(len(receipt_handles))

        Thread(
            daemon=True, name=f"SourceMessageDeleter({edge.name})", target=deleter
        ).start()

    def get(self, block: bool = True, timeout: float = None) -> str:
        return super().get(block=block, timeout=timeout)


class _SourceMessageReceiver(Thread):
    def __init__(self, edge: Edge, node: AppNode) -> None:
        self.__continue = Event()
        self.__continue.set()
        self.__delete_message_queue = _DeleteMessageQueue(edge, node)

        def handle_received_message(message: Message, receipt_handle: str) -> bool:
            try:
                node.handle_received_message(message=message, source=edge.name)
            except Exception:
                getLogger().exception(
                    f"Error handling received message for {edge.name}"
                )
                return False
            else:
                self.__delete_message_queue.put_nowait(receipt_handle)
            return True

        def receive() -> None:
            self.__continue.wait()
            getLogger().info(f"Receiving messages from {edge.name}")
            while self.__continue.is_set():
                try:
                    response = node._sqs_client.receive_message(
                        AttributeNames=["All"],
                        MaxNumberOfMessages=10,
                        MessageAttributeNames=["All"],
                        QueueUrl=edge.queue,
                        WaitTimeSeconds=20,
                    )
                except catch_aws_error("AWS.SimpleQueueService.NonExistentQueue"):
                    getLogger().warning(
                        f"Queue {edge.queue} does not exist, exiting")
                    break
                except Exception:
                    getLogger().exception(
                        f"Error receiving messages from {edge.name}, retrying"
                    )
                    sleep(20)
                else:
                    if not (sqs_messages := response.get("Messages")):
                        continue
                    getLogger().info(
                        f"Received {len(sqs_messages)} messages from {edge.name}")

                    message_handlers = [
                        partial(
                            handle_received_message,
                            Message(
                                body=sqs_message["Body"],
                                group_id=sqs_message["Attributes"]["MessageGroupId"],
                                message_type=node.receive_message_type,
                                tracking_id=sqs_message["MessageAttributes"]
                                .get("trackingId", {})
                                .get("StringValue"),
                                previous_tracking_ids=sqs_message["MessageAttributes"]
                                .get("prevTrackingIds", {})
                                .get("StringValue"),
                            ),
                            sqs_message["ReceiptHandle"],
                        )
                        for sqs_message in sqs_messages
                    ]

                    def handle_received_messages() -> None:
                        if executor := node._executor:
                            wait(
                                [
                                    executor.submit(message_handler)
                                    for message_handler in message_handlers
                                ]
                            )
                        else:
                            for message_handler in message_handlers:
                                if not message_handler():
                                    break

                    Thread(
                        name="handle_received_messages",
                        target=handle_received_messages,
                    ).start()

            getLogger().info(f"Stopping receiving messages from {edge.name}")

        super().__init__(
            name=f"SourceMessageReceiver({edge.name})", target=receive)
        self.start()

    def join(self) -> None:
        super().join()
        self.__delete_message_queue.join()

    def stop(self) -> None:
        self.__continue.clear()


class AppNode(Node):
    """
    A daemon Node intended to be used as either a stand-alone application
    or as a part of a larger application.
    """

    def __init__(
        self,
        *,
        appsync_endpoint: str = None,
        bulk_data_acceleration: bool = False,
        client_id: str = None,
        executor: Executor = None,
        name: str = None,
        password: str = None,
        tenant: str = None,
        timeout: float = None,
        user_pool_id: str = None,
        username: str = None,
    ) -> None:
        super().__init__(
            appsync_endpoint=appsync_endpoint,
            bulk_data_acceleration=bulk_data_acceleration,
            client_id=client_id,
            name=name,
            password=password,
            tenant=tenant,
            timeout=timeout,
            user_pool_id=user_pool_id,
            username=username,
        )
        self.__executor = executor
        self.__source_message_receivers: list[_SourceMessageReceiver] = list()
        self.__stop = Event()

    @property
    def _executor(self) -> Executor:
        return self.__executor

    def join(self) -> None:
        """
        Joins all of this AppNode's source message receivers so that the
        calling thread can wait for their execution to complete.
        """
        self.__stop.wait()
        for app_node_receiver in self.__source_message_receivers:
            app_node_receiver.join()
        super().join()

    def start(self) -> None:
        """
        Starts this Node and begins receiving messages from its sources.
        """
        super().start()
        self.__stop.clear()
        if not self.stopped:
            self.__source_message_receivers = [
                _SourceMessageReceiver(edge, self) for edge in self._sources
            ]

    def start_and_run_forever(self) -> None:
        """Will start this Node and run until stop is called"""
        self.start()
        self.join()

    def stop(self) -> None:
        """
        Stops the Node gracefully
        """
        self.__stop.set()
        for app_node_receiver in self.__source_message_receivers:
            app_node_receiver.stop()


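# A minimal stand-alone application sketch (the subclass and its processing
# logic are hypothetical; constructor arguments left as None are typically
# resolved from the environment by the base Node, as APPSYNC_ENDPOINT is
# above):
#
#     class MyAppNode(AppNode):
#         def handle_received_message(self, *, message: Message, source: str) -> None:
#             self.send_message(message)
#
#     if __name__ == "__main__":
#         MyAppNode().start_and_run_forever()

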
class LambdaNode(Node):
    """
    A Node class intended to be implemented in an AWS Lambda function.
    Nodes that inherit from this class are automatically started on
    creation.
    """

    def __init__(
        self,
        *,
        appsync_endpoint: str = None,
        bulk_data_acceleration: bool = False,
        client_id: str = None,
        concurrent_processing: bool = False,
        name: str = None,
        password: str = None,
        report_batch_item_failures: bool = False,
        tenant: str = None,
        timeout: float = None,
        user_pool_id: str = None,
        username: str = None,
    ) -> None:
        super().__init__(
            appsync_endpoint=appsync_endpoint,
            bulk_data_acceleration=bulk_data_acceleration,
            client_id=client_id,
            name=name,
            password=password,
            tenant=tenant,
            timeout=timeout or 0.01,
            user_pool_id=user_pool_id,
            username=username,
        )
        self.start()
        signal(SIGTERM, self._shutdown_handler)
        self.__executor: Executor = (
            ThreadPoolExecutor() if concurrent_processing else None
        )
        self.__queue_name_to_source = {
            edge.queue.split("/")[-1]: edge.name for edge in self._sources
        }
        self.__report_batch_item_failures = report_batch_item_failures

    def _get_source(self, queue_arn: str) -> str:
        return self.__queue_name_to_source[queue_arn.split(":")[-1]]

    def _shutdown_handler(self, signum: int, frame: object) -> None:
        getLogger().info("Received SIGTERM, shutting down")
        self.join()
        getLogger().info("Shutdown complete")

    def handle_event(self, event: LambdaEvent) -> BatchItemFailures:
        """
        Handles the AWS Lambda event passed into the containing
        AWS Lambda function during invocation.

        This is intended to be the only called method in your
        containing AWS Lambda function.
        """
        records: LambdaSqsRecords = None
        if not (records := event.get("Records")):
            getLogger().warning(f"No Records found in event {event}")
            return

        source = self._get_source(records[0]["eventSourceARN"])
        getLogger().info(f"Received {len(records)} messages from {source}")
        batch_item_failures: list[str] = (
            [record["messageId"] for record in records]
            if self.__report_batch_item_failures
            else None
        )

        def handle_received_message(message: Message, message_id: str) -> None:
            try:
                self.handle_received_message(message=message, source=source)
            except Exception:
                if not self.__report_batch_item_failures:
                    raise
                getLogger().exception(
                    f"Error handling received message for {source}")
            else:
                if self.__report_batch_item_failures:
                    batch_item_failures.remove(message_id)

        message_handlers = [
            partial(
                handle_received_message,
                Message(
                    body=record["body"],
                    group_id=record["attributes"]["MessageGroupId"],
                    message_type=self.receive_message_type,
                    previous_tracking_ids=record["messageAttributes"]
                    .get("prevTrackingIds", {})
                    .get("stringValue"),
                    tracking_id=record["messageAttributes"]
                    .get("trackingId", {})
                    .get("stringValue"),
                ),
                record["messageId"],
            )
            for record in records
        ]

        if executor := self.__executor:
            for future in as_completed(
                [
                    executor.submit(message_handler)
                    for message_handler in message_handlers
                ]
            ):
                if exception := future.exception():
                    raise exception
        else:
            for message_handler in message_handlers:
                message_handler()
        self.join()
        if self.__report_batch_item_failures and batch_item_failures:
            return dict(
                batchItemFailures=[
                    dict(itemIdentifier=message_id)
                    for message_id in batch_item_failures
                ]
            )
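
# A minimal AWS Lambda module sketch (the subclass and its processing logic
# are hypothetical). Construction starts the node, so it should happen once
# at module load, outside the handler:
#
#     class MyLambdaNode(LambdaNode):
#         def handle_received_message(self, *, message: Message, source: str) -> None:
#             self.send_message(message)
#
#     my_node = MyLambdaNode(report_batch_item_failures=True)
#
#     def lambda_handler(event, context):
#         # Returning the batchItemFailures dict lets SQS retry only the
#         # failed records when partial batch responses are enabled.
#         return my_node.handle_event(event)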
class Node(echostream_node.Node):
244class Node(BaseNode):
245    """
246    Base class for all threading Nodes.
247    """
248
249    def __init__(
250        self,
251        *,
252        appsync_endpoint: str = None,
253        bulk_data_acceleration: bool = False,
254        client_id: str = None,
255        name: str = None,
256        password: str = None,
257        tenant: str = None,
258        timeout: float = None,
259        user_pool_id: str = None,
260        username: str = None,
261    ) -> None:
262        super().__init__(
263            appsync_endpoint=appsync_endpoint,
264            bulk_data_acceleration=bulk_data_acceleration,
265            client_id=client_id,
266            name=name,
267            password=password,
268            tenant=tenant,
269            timeout=timeout,
270            user_pool_id=user_pool_id,
271            username=username,
272        )
273        self.__audit_records_queues: dict[str, _AuditRecordQueue] = dict()
274        self.__bulk_data_storage_queue = _BulkDataStorageQueue(self)
275        self.__gql_client = GqlClient(
276            fetch_schema_from_transport=True,
277            transport=RequestsHTTPTransport(
278                auth=RequestsSrpAuth(cognito=self.__cognito,
279                                     http_header_prefix=""),
280                url=appsync_endpoint or environ["APPSYNC_ENDPOINT"],
281            ),
282        )
283        self.__lock = RLock()
284        self.__target_message_queues: dict[str, _TargetMessageQueue] = dict()
285
286    @property
287    def _gql_client(self) -> GqlClient:
288        return self.__gql_client
289
290    @property
291    def _lock(self) -> RLock:
292        return self.__lock
293
294    def audit_message(
295        self,
296        /,
297        message: Message,
298        *,
299        extra_attributes: dict[str, Any] = None,
300        source: str = None,
301    ) -> None:
302        """
303        Audits the provided message. If extra_attibutes is
304        supplied, they will be added to the message's audit
305        dict. If source is provided, it will be recorded in
306        the audit.
307        """
308        if self.stopped:
309            raise RuntimeError(f"{self.name} is stopped")
310        if not self.audit:
311            return
312        extra_attributes = extra_attributes or dict()
313        message_type = message.message_type
314        record = dict(
315            datetime=datetime.now(timezone.utc).isoformat(),
316            previousTrackingIds=message.previous_tracking_ids,
317            sourceNode=source,
318            trackingId=message.tracking_id,
319        )
320        if attributes := (
321            message_type.auditor(message=message.body) | extra_attributes
322        ):
323            record["attributes"] = attributes
324        try:
325            self.__audit_records_queues[message_type.name].put_nowait(record)
326        except KeyError:
327            raise ValueError(f"Unrecognized message type {message_type.name}")
328
329    def audit_messages(
330        self,
331        /,
332        messages: list[Message],
333        *,
334        extra_attributes: list[dict[str, Any]] = None,
335        source: str = None,
336    ) -> None:
337        """
338        Audits the provided messages. If extra_attibutes is
339        supplied they will be added to the respective message's audit
340        dict and they must have the same count as messages.
341        If source is provided, it will be recorded in the audit.
342        """
343        if extra_attributes and len(extra_attributes) != len(messages):
344            raise ValueError(
345                "messages and extra_attributes must have the same number of items"
346            )
347        for message, attributes in zip(messages, extra_attributes):
348            self.audit_message(
349                message, extra_attributes=attributes, source=source)
350
351    def handle_bulk_data(self, data: Union[bytearray, bytes]) -> str:
352        """
353        Posts data as bulk data and returns a GET URL for data retrieval.
354        Normally this returned URL will be used as a "ticket" in messages
355        that require bulk data.
356        """
357        return self.__bulk_data_storage_queue.get().handle_bulk_data(data)
358
359    def handle_received_message(self, *, message: Message, source: str) -> None:
360        """
361        Callback called when a message is received. Subclasses that receive messages
362        should override this method.
363        """
364        pass
365
366    def join(self) -> None:
367        """
368        Joins the calling thread with this Node. Will block until all
369        join conditions are satified.
370        """
371        for target_message_queue in self.__target_message_queues.values():
372            target_message_queue.join()
373        for audit_records_queue in self.__audit_records_queues.values():
374            audit_records_queue.join()
375
376    def send_message(self, /, message: Message, *, targets: set[Edge] = None) -> None:
377        """
378        Send the message to the specified targets. If no targets are specified
379        the message will be sent to all targets.
380        """
381        self.send_messages([message], targets=targets)
382
383    def send_messages(
384        self, /, messages: list[Message], *, targets: set[Edge] = None
385    ) -> None:
386        """
387        Send the messages to the specified targets. If no targets are specified
388        the messages will be sent to all targets.
389        """
390        if self.stopped:
391            raise RuntimeError(f"{self.name} is stopped")
392        if messages:
393            for target in targets or self.targets:
394                if target_message_queue := self.__target_message_queues.get(
395                    target.name
396                ):
397                    for message in messages:
398                        target_message_queue.put_nowait(message)
399                else:
400                    getLogger().warning(f"Target {target.name} does not exist")
401
402    def start(self) -> None:
403        """
404        Starts this Node. Must be called prior to any other usage.
405        """
406        getLogger().info(f"Starting Node {self.name}")
407        with self._lock:
408            with self._gql_client as session:
409                data: dict[str, Union[str, dict]] = session.execute(
410                    _GET_NODE_GQL,
411                    variable_values=dict(name=self.name, tenant=self.tenant),
412                )["GetNode"]
413        self._audit = data["tenant"].get("audit") or False
414        self.config = (
415            json.loads(data["tenant"].get("config") or "{}")
416            | json.loads((data.get("app") or dict()).get("config") or "{}")
417            | json.loads(data.get("config") or "{}")
418        )
419        self._stopped = data.get("stopped")
420        if receive_message_type := data.get("receiveMessageType"):
421            self._receive_message_type = MessageType(
422                auditor=dynamic_function_loader.load(
423                    receive_message_type["auditor"]),
424                name=receive_message_type["name"],
425            )
426            if not self.stopped and self.audit:
427                self.__audit_records_queues[receive_message_type["name"]] = (
428                    _AuditRecordQueue(self._receive_message_type, self)
429                )
430        if send_message_type := data.get("sendMessageType"):
431            self._send_message_type = MessageType(
432                auditor=dynamic_function_loader.load(
433                    send_message_type["auditor"]),
434                name=send_message_type["name"],
435            )
436            if not self.stopped and self.audit:
437                self.__audit_records_queues[send_message_type["name"]] = (
438                    _AuditRecordQueue(self._send_message_type, self)
439                )
440        if self.node_type == "AppChangeReceiverNode":
441            if edge := data.get("receiveEdge"):
442                self._sources = {
443                    Edge(name=edge["source"]["name"], queue=edge["queue"])}
444            else:
445                self._sources = set()
446        else:
447            self._sources = {
448                Edge(name=edge["source"]["name"], queue=edge["queue"])
449                for edge in (data.get("receiveEdges") or list())
450            }
451        self._targets = {
452            Edge(name=edge["target"]["name"], queue=edge["queue"])
453            for edge in (data.get("sendEdges") or list())
454        }
455        if not self.stopped:
456            self.__target_message_queues = {
457                edge.name: _TargetMessageQueue(self, edge) for edge in self._targets
458            }
459
460    def stop(self) -> None:
461        """Stops the Node's processing."""
462        pass

Base class for all threading Nodes.

Node( *, appsync_endpoint: str = None, bulk_data_acceleration: bool = False, client_id: str = None, name: str = None, password: str = None, tenant: str = None, timeout: float = None, user_pool_id: str = None, username: str = None)
249    def __init__(
250        self,
251        *,
252        appsync_endpoint: str = None,
253        bulk_data_acceleration: bool = False,
254        client_id: str = None,
255        name: str = None,
256        password: str = None,
257        tenant: str = None,
258        timeout: float = None,
259        user_pool_id: str = None,
260        username: str = None,
261    ) -> None:
262        super().__init__(
263            appsync_endpoint=appsync_endpoint,
264            bulk_data_acceleration=bulk_data_acceleration,
265            client_id=client_id,
266            name=name,
267            password=password,
268            tenant=tenant,
269            timeout=timeout,
270            user_pool_id=user_pool_id,
271            username=username,
272        )
273        self.__audit_records_queues: dict[str, _AuditRecordQueue] = dict()
274        self.__bulk_data_storage_queue = _BulkDataStorageQueue(self)
275        self.__gql_client = GqlClient(
276            fetch_schema_from_transport=True,
277            transport=RequestsHTTPTransport(
278                auth=RequestsSrpAuth(cognito=self.__cognito,
279                                     http_header_prefix=""),
280                url=appsync_endpoint or environ["APPSYNC_ENDPOINT"],
281            ),
282        )
283        self.__lock = RLock()
284        self.__target_message_queues: dict[str, _TargetMessageQueue] = dict()
def audit_message( self, /, message: echostream_node.Message, *, extra_attributes: dict[str, typing.Any] = None, source: str = None) -> None:
294    def audit_message(
295        self,
296        /,
297        message: Message,
298        *,
299        extra_attributes: dict[str, Any] = None,
300        source: str = None,
301    ) -> None:
302        """
303        Audits the provided message. If extra_attibutes is
304        supplied, they will be added to the message's audit
305        dict. If source is provided, it will be recorded in
306        the audit.
307        """
308        if self.stopped:
309            raise RuntimeError(f"{self.name} is stopped")
310        if not self.audit:
311            return
312        extra_attributes = extra_attributes or dict()
313        message_type = message.message_type
314        record = dict(
315            datetime=datetime.now(timezone.utc).isoformat(),
316            previousTrackingIds=message.previous_tracking_ids,
317            sourceNode=source,
318            trackingId=message.tracking_id,
319        )
320        if attributes := (
321            message_type.auditor(message=message.body) | extra_attributes
322        ):
323            record["attributes"] = attributes
324        try:
325            self.__audit_records_queues[message_type.name].put_nowait(record)
326        except KeyError:
327            raise ValueError(f"Unrecognized message type {message_type.name}")

Audits the provided message. If extra_attibutes is supplied, they will be added to the message's audit dict. If source is provided, it will be recorded in the audit.

def audit_messages( self, /, messages: list[echostream_node.Message], *, extra_attributes: list[dict[str, typing.Any]] = None, source: str = None) -> None:
329    def audit_messages(
330        self,
331        /,
332        messages: list[Message],
333        *,
334        extra_attributes: list[dict[str, Any]] = None,
335        source: str = None,
336    ) -> None:
337        """
338        Audits the provided messages. If extra_attibutes is
339        supplied they will be added to the respective message's audit
340        dict and they must have the same count as messages.
341        If source is provided, it will be recorded in the audit.
342        """
343        if extra_attributes and len(extra_attributes) != len(messages):
344            raise ValueError(
345                "messages and extra_attributes must have the same number of items"
346            )
347        for message, attributes in zip(messages, extra_attributes):
348            self.audit_message(
349                message, extra_attributes=attributes, source=source)

Audits the provided messages. If extra_attibutes is supplied they will be added to the respective message's audit dict and they must have the same count as messages. If source is provided, it will be recorded in the audit.

def handle_bulk_data(self, data: Union[bytearray, bytes]) -> str:
351    def handle_bulk_data(self, data: Union[bytearray, bytes]) -> str:
352        """
353        Posts data as bulk data and returns a GET URL for data retrieval.
354        Normally this returned URL will be used as a "ticket" in messages
355        that require bulk data.
356        """
357        return self.__bulk_data_storage_queue.get().handle_bulk_data(data)

Posts data as bulk data and returns a GET URL for data retrieval. Normally this returned URL will be used as a "ticket" in messages that require bulk data.

def handle_received_message(self, *, message: echostream_node.Message, source: str) -> None:
359    def handle_received_message(self, *, message: Message, source: str) -> None:
360        """
361        Callback called when a message is received. Subclasses that receive messages
362        should override this method.
363        """
364        pass

Callback called when a message is received. Subclasses that receive messages should override this method.

def join(self) -> None:
366    def join(self) -> None:
367        """
368        Joins the calling thread with this Node. Will block until all
369        join conditions are satified.
370        """
371        for target_message_queue in self.__target_message_queues.values():
372            target_message_queue.join()
373        for audit_records_queue in self.__audit_records_queues.values():
374            audit_records_queue.join()

Joins the calling thread with this Node. Will block until all join conditions are satified.

def send_message( self, /, message: echostream_node.Message, *, targets: set[echostream_node.Edge] = None) -> None:
376    def send_message(self, /, message: Message, *, targets: set[Edge] = None) -> None:
377        """
378        Send the message to the specified targets. If no targets are specified
379        the message will be sent to all targets.
380        """
381        self.send_messages([message], targets=targets)

Send the message to the specified targets. If no targets are specified the message will be sent to all targets.

def send_messages( self, /, messages: list[echostream_node.Message], *, targets: set[echostream_node.Edge] = None) -> None:
383    def send_messages(
384        self, /, messages: list[Message], *, targets: set[Edge] = None
385    ) -> None:
386        """
387        Send the messages to the specified targets. If no targets are specified
388        the messages will be sent to all targets.
389        """
390        if self.stopped:
391            raise RuntimeError(f"{self.name} is stopped")
392        if messages:
393            for target in targets or self.targets:
394                if target_message_queue := self.__target_message_queues.get(
395                    target.name
396                ):
397                    for message in messages:
398                        target_message_queue.put_nowait(message)
399                else:
400                    getLogger().warning(f"Target {target.name} does not exist")

Send the messages to the specified targets. If no targets are specified the messages will be sent to all targets.

def start(self) -> None:
402    def start(self) -> None:
403        """
404        Starts this Node. Must be called prior to any other usage.
405        """
406        getLogger().info(f"Starting Node {self.name}")
407        with self._lock:
408            with self._gql_client as session:
409                data: dict[str, Union[str, dict]] = session.execute(
410                    _GET_NODE_GQL,
411                    variable_values=dict(name=self.name, tenant=self.tenant),
412                )["GetNode"]
413        self._audit = data["tenant"].get("audit") or False
414        self.config = (
415            json.loads(data["tenant"].get("config") or "{}")
416            | json.loads((data.get("app") or dict()).get("config") or "{}")
417            | json.loads(data.get("config") or "{}")
418        )
419        self._stopped = data.get("stopped")
420        if receive_message_type := data.get("receiveMessageType"):
421            self._receive_message_type = MessageType(
422                auditor=dynamic_function_loader.load(
423                    receive_message_type["auditor"]),
424                name=receive_message_type["name"],
425            )
426            if not self.stopped and self.audit:
427                self.__audit_records_queues[receive_message_type["name"]] = (
428                    _AuditRecordQueue(self._receive_message_type, self)
429                )
430        if send_message_type := data.get("sendMessageType"):
431            self._send_message_type = MessageType(
432                auditor=dynamic_function_loader.load(
433                    send_message_type["auditor"]),
434                name=send_message_type["name"],
435            )
436            if not self.stopped and self.audit:
437                self.__audit_records_queues[send_message_type["name"]] = (
438                    _AuditRecordQueue(self._send_message_type, self)
439                )
440        if self.node_type == "AppChangeReceiverNode":
441            if edge := data.get("receiveEdge"):
442                self._sources = {
443                    Edge(name=edge["source"]["name"], queue=edge["queue"])}
444            else:
445                self._sources = set()
446        else:
447            self._sources = {
448                Edge(name=edge["source"]["name"], queue=edge["queue"])
449                for edge in (data.get("receiveEdges") or list())
450            }
451        self._targets = {
452            Edge(name=edge["target"]["name"], queue=edge["queue"])
453            for edge in (data.get("sendEdges") or list())
454        }
455        if not self.stopped:
456            self.__target_message_queues = {
457                edge.name: _TargetMessageQueue(self, edge) for edge in self._targets
458            }

Starts this Node. Must be called prior to any other usage.

def stop(self) -> None:
460    def stop(self) -> None:
461        """Stops the Node's processing."""
462        pass

Stops the Node's processing.

class AppNode(Node):
605class AppNode(Node):
606    """
607    A daemon Node intended to be used as either a stand-alone application
608    or as a part of a larger application.
609    """
610
611    def __init__(
612        self,
613        *,
614        appsync_endpoint: str = None,
615        bulk_data_acceleration: bool = False,
616        client_id: str = None,
617        executor: Executor = None,
618        name: str = None,
619        password: str = None,
620        tenant: str = None,
621        timeout: float = None,
622        user_pool_id: str = None,
623        username: str = None,
624    ) -> None:
625        super().__init__(
626            appsync_endpoint=appsync_endpoint,
627            bulk_data_acceleration=bulk_data_acceleration,
628            client_id=client_id,
629            name=name,
630            password=password,
631            tenant=tenant,
632            timeout=timeout,
633            user_pool_id=user_pool_id,
634            username=username,
635        )
636        self.__executor = executor
637        self.__source_message_receivers: list[_SourceMessageReceiver] = list()
638        self.__stop = Event()
639
640    @property
641    def _executor(self) -> Executor:
642        return self.__executor
643
644    def join(self) -> None:
645        """
646        Method to join all the app node receivers so that main thread can wait for their execution to complete.
647        """
648        self.__stop.wait()
649        for app_node_receiver in self.__source_message_receivers:
650            app_node_receiver.join()
651        super().join()
652
653    def start(self) -> None:
654        """
655        Calls start of Node class
656        """
657        super().start()
658        self.__stop.clear()
659        if not self.stopped:
660            self.__source_message_receivers = [
661                _SourceMessageReceiver(edge, self) for edge in self._sources
662            ]
663
664    def start_and_run_forever(self) -> None:
665        """Will start this Node and run until stop is called"""
666        self.start()
667        self.join()
668
669    def stop(self) -> None:
670        """
671        Stops the Node gracefully
672        """
673        self.__stop.set()
674        for app_node_receiver in self.__source_message_receivers:
675            app_node_receiver.stop()

A daemon Node intended to be used as either a stand-alone application or as a part of a larger application.

AppNode( *, appsync_endpoint: str = None, bulk_data_acceleration: bool = False, client_id: str = None, executor: concurrent.futures._base.Executor = None, name: str = None, password: str = None, tenant: str = None, timeout: float = None, user_pool_id: str = None, username: str = None)
611    def __init__(
612        self,
613        *,
614        appsync_endpoint: str = None,
615        bulk_data_acceleration: bool = False,
616        client_id: str = None,
617        executor: Executor = None,
618        name: str = None,
619        password: str = None,
620        tenant: str = None,
621        timeout: float = None,
622        user_pool_id: str = None,
623        username: str = None,
624    ) -> None:
625        super().__init__(
626            appsync_endpoint=appsync_endpoint,
627            bulk_data_acceleration=bulk_data_acceleration,
628            client_id=client_id,
629            name=name,
630            password=password,
631            tenant=tenant,
632            timeout=timeout,
633            user_pool_id=user_pool_id,
634            username=username,
635        )
636        self.__executor = executor
637        self.__source_message_receivers: list[_SourceMessageReceiver] = list()
638        self.__stop = Event()
def join(self) -> None:
644    def join(self) -> None:
645        """
646        Method to join all the app node receivers so that main thread can wait for their execution to complete.
647        """
648        self.__stop.wait()
649        for app_node_receiver in self.__source_message_receivers:
650            app_node_receiver.join()
651        super().join()

Method to join all the app node receivers so that main thread can wait for their execution to complete.

def start(self) -> None:
653    def start(self) -> None:
654        """
655        Calls start of Node class
656        """
657        super().start()
658        self.__stop.clear()
659        if not self.stopped:
660            self.__source_message_receivers = [
661                _SourceMessageReceiver(edge, self) for edge in self._sources
662            ]

Calls start of Node class

def start_and_run_forever(self) -> None:
664    def start_and_run_forever(self) -> None:
665        """Will start this Node and run until stop is called"""
666        self.start()
667        self.join()

Will start this Node and run until stop is called

def stop(self) -> None:

Stops the Node gracefully
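
Taken together, these lifecycle methods suggest the usage pattern sketched below. This is a minimal, hypothetical example, not part of the library: MyAppNode and its message-handling body are illustrative, while AppNode and start_and_run_forever are documented above, and the keyword-only signature of handle_received_message is inferred from its call site in LambdaNode.handle_event below.

    from echostream_node import Message
    from echostream_node.threading import AppNode

    class MyAppNode(AppNode):
        def handle_received_message(self, *, message: Message, source: str) -> None:
            # Application-specific processing goes here (illustrative).
            print(f"Received from {source}: {message.body}")

    if __name__ == "__main__":
        # start() creates a receiver per source edge; join() then blocks
        # until stop() is called, e.g. from a signal handler.
        MyAppNode().start_and_run_forever()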

class LambdaNode(Node):
678class LambdaNode(Node):
679    """
680    A Node class intended to be implemented in an AWS Lambda function.
681    Nodes that inherit from this class are automatically started on
682    creation.
683    """
684
685    def __init__(
686        self,
687        *,
688        appsync_endpoint: str = None,
689        bulk_data_acceleration: bool = False,
690        client_id: str = None,
691        concurrent_processing: bool = False,
692        name: str = None,
693        password: str = None,
694        report_batch_item_failures: bool = False,
695        tenant: str = None,
696        timeout: float = None,
697        user_pool_id: str = None,
698        username: str = None,
699    ) -> None:
700        super().__init__(
701            appsync_endpoint=appsync_endpoint,
702            bulk_data_acceleration=bulk_data_acceleration,
703            client_id=client_id,
704            name=name,
705            password=password,
706            tenant=tenant,
707            timeout=timeout or 0.01,
708            user_pool_id=user_pool_id,
709            username=username,
710        )
711        self.start()
712        signal(SIGTERM, self._shutdown_handler)
713        self.__executor: Executor = (
714            ThreadPoolExecutor() if concurrent_processing else None
715        )
716        self.__queue_name_to_source = {
717            edge.queue.split("/")[-1]: edge.name for edge in self._sources
718        }
719        self.__report_batch_item_failures = report_batch_item_failures
720
721    def _get_source(self, queue_arn: str) -> str:
722        return self.__queue_name_to_source[queue_arn.split(":")[-1]]
723
724    def _shutdown_handler(self, signum: int, frame: object) -> None:
725        getLogger().info("Received SIGTERM, shutting down")
726        self.join()
727        getLogger().info("Shutdown complete")
728
729    def handle_event(self, event: LambdaEvent) -> BatchItemFailures:
730        """
731        Handles the AWS Lambda event passed into the containing
732        AWS Lambda function during invocation.
733
734        This is intended to be the only method called in your
735        containing AWS Lambda function.
736        """
737        records: LambdaSqsRecords = None
738        if not (records := event.get("Records")):
739            getLogger().warning(f"No Records found in event {event}")
740            return
741
742        source = self._get_source(records[0]["eventSourceARN"])
743        getLogger().info(f"Received {len(records)} messages from {source}")
744        batch_item_failures: list[str] = (
745            [record["messageId"] for record in records]
746            if self.__report_batch_item_failures
747            else None
748        )
749
750        def handle_received_message(message: Message, message_id: str) -> None:
751            try:
752                self.handle_received_message(message=message, source=source)
753            except Exception:
754                if not self.__report_batch_item_failures:
755                    raise
756                getLogger().exception(
757                    f"Error handling received message for {source}")
758            else:
759                if self.__report_batch_item_failures:
760                    batch_item_failures.remove(message_id)
761
762        message_handlers = [
763            partial(
764                handle_received_message,
765                Message(
766                    body=record["body"],
767                    group_id=record["attributes"]["MessageGroupId"],
768                    message_type=self.receive_message_type,
769                    previous_tracking_ids=record["messageAttributes"]
770                    .get("prevTrackingIds", {})
771                    .get("stringValue"),
772                    tracking_id=record["messageAttributes"]
773                    .get("trackingId", {})
774                    .get("stringValue"),
775                ),
776                record["messageId"],
777            )
778            for record in records
779        ]
780
781        if executor := self.__executor:
782            for future in as_completed(
783                [
784                    executor.submit(message_handler)
785                    for message_handler in message_handlers
786                ]
787            ):
788                if exception := future.exception():
789                    raise exception
790        else:
791            for message_handler in message_handlers:
792                message_handler()
793        self.join()
794        if self.__report_batch_item_failures and batch_item_failures:
795            return dict(
796                batchItemFailures=[
797                    dict(itemIdentifier=message_id)
798                    for message_id in batch_item_failures
799                ]
800            )

A Node class intended to be implemented in an AWS Lambda function. Nodes that inherit from this class are automatically started on creation.

LambdaNode( *, appsync_endpoint: str = None, bulk_data_acceleration: bool = False, client_id: str = None, concurrent_processing: bool = False, name: str = None, password: str = None, report_batch_item_failures: bool = False, tenant: str = None, timeout: float = None, user_pool_id: str = None, username: str = None)
def handle_event( self, event: Union[bool, dict, float, int, list, str, tuple, NoneType]) -> dict[str, list[dict[str, str]]]:

Handles the AWS Lambda event passed into the containing AWS Lambda function during invocation.

This is intended to be the only method called in your containing AWS Lambda function.
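
The sketch below shows one way handle_event can be wired into a containing AWS Lambda function; it is a hedged example, not library code. MyLambdaNode, MY_NODE, and lambda_handler are illustrative names, while LambdaNode, handle_event, and the constructor parameters are documented above.

    from echostream_node import Message
    from echostream_node.threading import LambdaNode

    class MyLambdaNode(LambdaNode):
        def handle_received_message(self, *, message: Message, source: str) -> None:
            # Application-specific processing goes here (illustrative).
            print(f"Received from {source}: {message.body}")

    # LambdaNode starts itself on creation, so create it once at module
    # import time and reuse it across warm invocations.
    MY_NODE = MyLambdaNode(
        concurrent_processing=True,       # process records in a thread pool
        report_batch_item_failures=True,  # enable SQS partial-batch responses
    )

    def lambda_handler(event, context):
        # With report_batch_item_failures enabled, handle_event returns
        # {"batchItemFailures": [{"itemIdentifier": "<messageId>"}, ...]}
        # for records whose processing raised; otherwise it returns None.
        return MY_NODE.handle_event(event)

Note that partial-batch reporting also requires the function's SQS event source mapping to include the ReportBatchItemFailures function response type.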