echostream_node.threading

  1from __future__ import annotations
  2
  3from concurrent.futures import Executor, ThreadPoolExecutor, as_completed, wait
  4from datetime import datetime, timezone
  5from functools import partial
  6from gzip import GzipFile, compress
  7from io import BytesIO
  8from os import environ
  9from queue import Empty, Queue
 10from signal import SIGTERM, signal
 11from threading import Event, RLock, Thread
 12from time import sleep
 13from typing import TYPE_CHECKING, Any, BinaryIO, Generator, Union
 14
 15import dynamic_function_loader
 16import simplejson as json
 17from aws_error_utils import catch_aws_error
 18from gql import Client as GqlClient
 19from gql.transport.requests import RequestsHTTPTransport
 20from httpx import Client as HttpxClient
 21from httpx_auth import AWS4Auth
 22from pycognito.utils import RequestsSrpAuth
 23
 24from .. import _GET_BULK_DATA_STORAGE_GQL, _GET_NODE_GQL, BatchItemFailures
 25from .. import BulkDataStorage as BaseBulkDataStorage
 26from .. import Edge, LambdaEvent, LambdaSqsRecords, Message, MessageType
 27from .. import Node as BaseNode
 28from .. import PresignedPost, getLogger
 29
 30if TYPE_CHECKING:
 31    from mypy_boto3_sqs.type_defs import (
 32        DeleteMessageBatchRequestEntryTypeDef,
 33        SendMessageBatchRequestEntryTypeDef,
 34    )
 35else:
 36    DeleteMessageBatchRequestEntryTypeDef = dict
 37    SendMessageBatchRequestEntryTypeDef = dict
 38
 39
 40class _Queue(Queue):
 41    def tasks_done(self, count: int) -> None:
 42        with self.all_tasks_done:
 43            unfinished = self.unfinished_tasks - count
 44            if unfinished < 0:
 45                raise ValueError("count larger than unfinished tasks")
 46            if unfinished == 0:
 47                self.all_tasks_done.notify_all()
 48            self.unfinished_tasks = unfinished
 49
 50
 51class _AuditRecordQueue(_Queue):
 52    def __init__(self, message_type: MessageType, node: Node) -> None:
 53        super().__init__()
 54
 55        def sender() -> None:
 56            with HttpxClient() as client:
 57                while True:
 58                    batch: list[dict] = list()
 59                    while len(batch) < 500:
 60                        try:
 61                            batch.append(self.get(timeout=node.timeout))
 62                        except Empty:
 63                            break
 64                    if not batch:
 65                        continue
 66                    credentials = (
 67                        node._session.get_credentials().get_frozen_credentials()
 68                    )
 69                    auth = AWS4Auth(
 70                        access_id=credentials.access_key,
 71                        region=node._session.region_name,
 72                        secret_key=credentials.secret_key,
 73                        service="lambda",
 74                        security_token=credentials.token,
 75                    )
 76                    url = node._audit_records_endpoint
 77                    post_args = dict(
 78                        auth=auth,
 79                        url=f"{url}{'' if url.endswith(
 80                            '/') else '/'}{node.name}",
 81                    )
 82                    body = dict(
 83                        messageType=message_type.name,
 84                        auditRecords=batch,
 85                    )
 86                    if len(batch) <= 10:
 87                        post_args["json"] = body
 88                    else:
 89                        post_args["content"] = compress(
 90                            json.dumps(body, separators=(",", ":")).encode()
 91                        )
 92                        post_args["headers"] = {
 93                            "Content-Encoding": "gzip",
 94                            "Content-Type": "application/json",
 95                        }
 96                    try:
 97                        response = client.post(**post_args)
 98                        response.raise_for_status()
 99                        response.close()
100                    except Exception:
101                        getLogger().exception("Error creating audit records")
102                    finally:
103                        self.tasks_done(len(batch))
104
105        Thread(daemon=True, name=f"AuditRecordsSender", target=sender).start()
106
107    def get(self, block: bool = True, timeout: float = None) -> dict:
108        return super().get(block=block, timeout=timeout)
109
110
class _BulkDataStorage(BaseBulkDataStorage):
    """
    Bulk data storage slot that uploads through a shared ``httpx`` client.
    """

    def __init__(
        self,
        bulk_data_storage: dict[str, Union[str, PresignedPost]],
        client: HttpxClient,
    ) -> None:
        """
        bulk_data_storage: storage description passed to the base class.
        client: HTTP client used for the upload.
        """
        super().__init__(bulk_data_storage)
        self.__client = client

    def handle_bulk_data(self, data: Union[bytearray, bytes, BinaryIO]) -> str:
        """
        Gzips ``data``, uploads it with the presigned POST and returns
        the presigned GET URL.

        ``data`` may be a bytes-like object or a readable binary stream;
        a stream is read to its end first.

        Raises whatever ``raise_for_status()`` raises when the upload
        is rejected.
        """
        # typing.BinaryIO is not a base class of real file objects, so an
        # isinstance() check against it never matches; detect streams by
        # their read() method instead.
        if hasattr(data, "read"):
            data = data.read()
        with BytesIO() as buffer:
            with GzipFile(mode="wb", fileobj=buffer) as gzf:
                gzf.write(data)
            # Rewind so the upload starts at the first compressed byte.
            buffer.seek(0)
            response = self.__client.post(
                self.presigned_post.url,
                data=self.presigned_post.fields,
                files=dict(file=("bulk_data", buffer)),
            )
            try:
                response.raise_for_status()
            finally:
                response.close()
        return self.presigned_get
135
136
class _BulkDataStorageQueue(Queue):
    """
    Queue of ready-to-use bulk data storage slots.

    A daemon thread refills the queue from the GraphQL API each time
    the fill event is set.
    """

    def __init__(self, node: Node) -> None:
        """Creates the queue and starts the filler thread for ``node``."""
        super().__init__()
        self.__fill = Event()

        def fetch_storages() -> list[dict]:
            # One GraphQL round trip, made while holding the node's lock.
            with node._lock:
                with node._gql_client as session:
                    result = session.execute(
                        _GET_BULK_DATA_STORAGE_GQL,
                        variable_values=dict(
                            tenant=node.tenant,
                            useAccelerationEndpoint=node.bulk_data_acceleration,
                        ),
                    )
                    return result["GetBulkDataStorage"]

        def filler() -> None:
            with HttpxClient() as client:
                while True:
                    # Sleep until a consumer asks for more storage.
                    self.__fill.wait()
                    try:
                        storages = fetch_storages()
                    except Exception:
                        getLogger().exception("Error getting bulk data storage")
                    else:
                        for storage in storages:
                            self.put_nowait(_BulkDataStorage(storage, client))
                    self.__fill.clear()

        filler_thread = Thread(
            daemon=True, name="BulkDataStorageQueueFiller", target=filler
        )
        filler_thread.start()

    def get(self, block: bool = True, timeout: float = None) -> _BulkDataStorage:
        """
        Returns the next storage slot that has not expired.

        Requests a refill whenever fewer than 20 slots are queued and
        silently discards expired slots.
        """
        while True:
            if self.qsize() < 20:
                self.__fill.set()
            candidate: _BulkDataStorage = super().get(block=block, timeout=timeout)
            if not candidate.expired:
                return candidate
176
177
class _TargetMessageQueue(_Queue):
    """
    Queue of outbound messages for one target edge.

    Creating the queue starts a daemon thread that groups queued messages
    into batches and sends each batch to the edge's SQS queue.
    """

    def __init__(self, node: Node, edge: Edge) -> None:
        """
        node: supplies the poll timeout, the SQS client and the context
            used to convert messages into SQS entries.
        edge: target edge; its ``queue`` is the destination queue URL.
        """
        super().__init__()
        # A batch is flushed when it holds this many entries ...
        max_batch_entries = 10
        # ... or when adding the next message would push the summed
        # len() of its messages past this limit.
        max_batch_length = 262144

        def batcher() -> Generator[
            list[SendMessageBatchRequestEntryTypeDef],
            None,
            None,
        ]:
            batch: list[SendMessageBatchRequestEntryTypeDef] = list()
            batch_length = 0
            while True:
                try:
                    message = self.get(timeout=node.timeout)
                except Empty:
                    # Queue went quiet: flush whatever has accumulated.
                    if batch:
                        yield batch
                        batch = list()
                        batch_length = 0
                    continue
                message_length = len(message)
                # Never yield an empty batch, even for an oversized message.
                if batch and batch_length + message_length > max_batch_length:
                    yield batch
                    batch = list()
                    batch_length = 0
                # The entry Id is the entry's index within its batch so the
                # sender can map a failure report back to the entry.
                batch.append(
                    SendMessageBatchRequestEntryTypeDef(
                        Id=str(len(batch)), **message._sqs_message(node)
                    )
                )
                batch_length += message_length
                if len(batch) == max_batch_entries:
                    yield batch
                    batch = list()
                    batch_length = 0

        def sender() -> None:
            for entries in batcher():
                try:
                    response = node._sqs_client.send_message_batch(
                        Entries=entries, QueueUrl=edge.queue
                    )
                    for failed in response.get("Failed", list()):
                        # Ids are strings; convert back to the list index.
                        failed_index = int(failed.pop("Id"))
                        getLogger().error(
                            f"Unable to send message {entries[failed_index]} "
                            f"to {edge.name}, reason {failed}"
                        )
                except Exception:
                    getLogger().exception(
                        f"Error sending messages to {edge.name}")
                finally:
                    # One queued message per entry, sent or not.
                    self.tasks_done(len(entries))

        Thread(
            daemon=True, name=f"TargetMessageSender({edge.name})", target=sender
        ).start()

    def get(self, block: bool = True, timeout: float = None) -> Message:
        """Returns the next message; typed wrapper around ``Queue.get``."""
        return super().get(block=block, timeout=timeout)
241
242
class Node(BaseNode):
    """
    Common base class of the thread-based Nodes.
    """

    def __init__(
        self,
        *,
        appsync_endpoint: str = None,
        bulk_data_acceleration: bool = False,
        client_id: str = None,
        name: str = None,
        password: str = None,
        tenant: str = None,
        timeout: float = None,
        user_pool_id: str = None,
        username: str = None,
    ) -> None:
        super().__init__(
            appsync_endpoint=appsync_endpoint,
            bulk_data_acceleration=bulk_data_acceleration,
            client_id=client_id,
            name=name,
            password=password,
            tenant=tenant,
            timeout=timeout,
            user_pool_id=user_pool_id,
            username=username,
        )
        # Audit queues are keyed by message type name and are created
        # in start().
        self.__audit_records_queues: dict[str, _AuditRecordQueue] = dict()
        self.__bulk_data_storage_queue = _BulkDataStorageQueue(self)
        srp_auth = RequestsSrpAuth(cognito=self.__cognito, http_header_prefix="")
        endpoint = appsync_endpoint or environ["APPSYNC_ENDPOINT"]
        self.__gql_client = GqlClient(
            fetch_schema_from_transport=True,
            transport=RequestsHTTPTransport(auth=srp_auth, url=endpoint),
        )
        self.__lock = RLock()
        # Target queues are keyed by edge name and are created in start().
        self.__target_message_queues: dict[str, _TargetMessageQueue] = dict()

    @property
    def _gql_client(self) -> GqlClient:
        return self.__gql_client

    @property
    def _lock(self) -> RLock:
        return self.__lock

    def audit_message(
        self,
        /,
        message: Message,
        *,
        extra_attributes: dict[str, Any] = None,
        source: str = None,
    ) -> None:
        """
        Queues an audit record for the given message. Does nothing when
        auditing is off for this Node. Any extra_attributes are merged
        over the attributes produced by the message type's auditor, and
        source, if given, is recorded as the source node.

        Raises RuntimeError if this Node is stopped and ValueError if
        the message's type has no audit queue.
        """
        if self.stopped:
            raise RuntimeError(f"{self.name} is stopped")
        if not self.audit:
            return
        extra_attributes = extra_attributes or dict()
        message_type = message.message_type
        record = {
            "datetime": datetime.now(timezone.utc).isoformat(),
            "previousTrackingIds": message.previous_tracking_ids,
            "sourceNode": source,
            "trackingId": message.tracking_id,
        }
        merged_attributes = (
            message_type.auditor(message=message.body) | extra_attributes
        )
        if merged_attributes:
            record["attributes"] = merged_attributes
        try:
            self.__audit_records_queues[message_type.name].put_nowait(record)
        except KeyError:
            raise ValueError(f"Unrecognized message type {message_type.name}")

    def audit_messages(
        self,
        /,
        messages: list[Message],
        *,
        extra_attributes: list[dict[str, Any]] = None,
        source: str = None,
    ) -> None:
        """
        Audits each of the given messages in turn. When extra_attributes
        is supplied it must hold one dict per message, matched by
        position. If source is given, it is recorded in every audit.

        Raises ValueError if the two lists differ in length.
        """
        if not extra_attributes:
            extra_attributes = [dict()] * len(messages)
        elif len(extra_attributes) != len(messages):
            raise ValueError(
                "messages and extra_attributes must have the same number of items"
            )
        for message, attributes in zip(messages, extra_attributes):
            self.audit_message(
                message, extra_attributes=attributes, source=source)

    def handle_bulk_data(self, data: Union[bytearray, bytes]) -> str:
        """
        Uploads data as bulk data and returns the GET URL from which it
        can be retrieved. The URL is typically placed in a message as a
        "ticket" for the bulk data.
        """
        storage = self.__bulk_data_storage_queue.get()
        return storage.handle_bulk_data(data)

    def handle_received_message(self, *, message: Message, source: str) -> None:
        """
        Hook invoked for every received message. The default does
        nothing; Nodes that receive messages override it.
        """
        return None

    def join(self) -> None:
        """
        Blocks the caller until every queued outbound message and every
        queued audit record has been processed.
        """
        for outbound_queue in self.__target_message_queues.values():
            outbound_queue.join()
        for audit_queue in self.__audit_records_queues.values():
            audit_queue.join()

    def send_message(self, /, message: Message, *, targets: set[Edge] = None) -> None:
        """
        Sends a single message to the given targets, or to all targets
        when none are given.
        """
        self.send_messages([message], targets=targets)

    def send_messages(
        self, /, messages: list[Message], *, targets: set[Edge] = None
    ) -> None:
        """
        Queues the messages for delivery to the given targets, or to all
        targets when none are given. Unknown targets are logged and
        skipped.

        Raises RuntimeError if this Node is stopped.
        """
        if self.stopped:
            raise RuntimeError(f"{self.name} is stopped")
        if not messages:
            return
        for target in targets or self.targets:
            outbound_queue = self.__target_message_queues.get(target.name)
            if not outbound_queue:
                getLogger().warning(f"Target {target.name} does not exist")
                continue
            for message in messages:
                outbound_queue.put_nowait(message)

    def start(self) -> None:
        """
        Starts this Node by loading its definition through GraphQL and
        creating the queues it needs. Call before any other use.
        """
        getLogger().info(f"Starting Node {self.name}")
        with self._lock, self._gql_client as session:
            data: dict[str, Union[str, dict]] = session.execute(
                _GET_NODE_GQL,
                variable_values=dict(name=self.name, tenant=self.tenant),
            )["GetNode"]
        tenant_data = data["tenant"]
        self._audit = tenant_data.get("audit") or False
        # Later sources win: tenant config, then app config, then node config.
        self.config = (
            json.loads(tenant_data.get("config") or "{}")
            | json.loads((data.get("app") or dict()).get("config") or "{}")
            | json.loads(data.get("config") or "{}")
        )
        self._stopped = data.get("stopped")

        def build_message_type(spec: dict) -> MessageType:
            return MessageType(
                auditor=dynamic_function_loader.load(spec["auditor"]),
                name=spec["name"],
            )

        if receive_spec := data.get("receiveMessageType"):
            self._receive_message_type = build_message_type(receive_spec)
            if not self.stopped and self.audit:
                self.__audit_records_queues[receive_spec["name"]] = (
                    _AuditRecordQueue(self._receive_message_type, self)
                )
        if send_spec := data.get("sendMessageType"):
            self._send_message_type = build_message_type(send_spec)
            if not self.stopped and self.audit:
                self.__audit_records_queues[send_spec["name"]] = (
                    _AuditRecordQueue(self._send_message_type, self)
                )

        if self.node_type == "AppChangeReceiverNode":
            # This node type has at most one inbound edge.
            receive_edge = data.get("receiveEdge")
            if receive_edge:
                self._sources = {
                    Edge(
                        name=receive_edge["source"]["name"],
                        queue=receive_edge["queue"],
                    )
                }
            else:
                self._sources = set()
        else:
            self._sources = {
                Edge(name=inbound["source"]["name"], queue=inbound["queue"])
                for inbound in (data.get("receiveEdges") or list())
            }
        self._targets = {
            Edge(name=outbound["target"]["name"], queue=outbound["queue"])
            for outbound in (data.get("sendEdges") or list())
        }
        if not self.stopped:
            self.__target_message_queues = {
                target.name: _TargetMessageQueue(self, target)
                for target in self._targets
            }

    def stop(self) -> None:
        """Stops the Node's processing. The base implementation does nothing."""
        return None
465
466
class _DeleteMessageQueue(_Queue):
    """
    Queue of SQS receipt handles awaiting deletion from one source edge.

    Creating the queue starts a daemon thread that deletes the handles
    from the edge's queue in batches of up to 10.
    """

    def __init__(self, edge: Edge, node: AppNode) -> None:
        """
        edge: source edge; its ``queue`` is the queue URL to delete from.
        node: supplies the poll timeout and the SQS client.
        """
        super().__init__()

        def deleter() -> None:
            while True:
                # Collect up to 10 handles, stopping early as soon as the
                # queue stays empty for node.timeout.
                receipt_handles: list[str] = list()
                while len(receipt_handles) < 10:
                    try:
                        receipt_handles.append(self.get(timeout=node.timeout))
                    except Empty:
                        break
                if not receipt_handles:
                    continue
                try:
                    response = node._sqs_client.delete_message_batch(
                        Entries=[
                            DeleteMessageBatchRequestEntryTypeDef(
                                Id=str(index), ReceiptHandle=receipt_handle
                            )
                            for index, receipt_handle in enumerate(receipt_handles)
                        ],
                        QueueUrl=edge.queue,
                    )
                    for failed in response.get("Failed", list()):
                        # Ids are strings; convert back to the list index.
                        failed_index = int(failed.pop("Id"))
                        getLogger().error(
                            f"Unable to delete message {receipt_handles[failed_index]} "
                            f"from {edge.name}, reason {failed}"
                        )
                except Exception:
                    getLogger().exception(
                        f"Error deleting messages from {edge.name}")
                finally:
                    # One queued handle per entry, deleted or not.
                    self.tasks_done(len(receipt_handles))

        Thread(
            daemon=True, name=f"SourceMessageDeleter({edge.name})", target=deleter
        ).start()

    def get(self, block: bool = True, timeout: float = None) -> str:
        """Returns the next receipt handle; typed wrapper around ``Queue.get``."""
        return super().get(block=block, timeout=timeout)
509
510
class _SourceMessageReceiver(Thread):
    """
    Thread that long-polls one source edge's SQS queue and dispatches
    every received message to the node. The thread starts itself on
    construction.
    """

    def __init__(self, edge: Edge, node: AppNode) -> None:
        """
        edge: source edge; its ``queue`` is the queue URL to poll.
        node: supplies the SQS client, the optional executor, the
            receive message type and the ``handle_received_message`` hook.
        """
        self.__continue = Event()
        self.__continue.set()
        self.__delete_message_queue = _DeleteMessageQueue(edge, node)

        def handle_received_message(message: Message, receipt_handle: str) -> bool:
            # Returns True when the node handled the message; only then is
            # the message queued for deletion.
            try:
                node.handle_received_message(message=message, source=edge.name)
            except Exception:
                getLogger().exception(
                    f"Error handling recevied message for {edge.name}"
                )
                return False
            else:
                self.__delete_message_queue.put_nowait(receipt_handle)
            return True

        def handle_received_messages(message_handlers: list) -> None:
            # The batch is passed in as an argument rather than read from
            # the receive loop's variable, which is rebound on every poll.
            if executor := node._executor:
                wait(
                    [
                        executor.submit(message_handler)
                        for message_handler in message_handlers
                    ]
                )
            else:
                # Sequential mode: stop at the first failure.
                for message_handler in message_handlers:
                    if not message_handler():
                        break

        def receive() -> None:
            self.__continue.wait()
            getLogger().info(f"Receiving messages from {edge.name}")
            while self.__continue.is_set():
                try:
                    response = node._sqs_client.receive_message(
                        AttributeNames=["All"],
                        MaxNumberOfMessages=10,
                        MessageAttributeNames=["All"],
                        QueueUrl=edge.queue,
                        WaitTimeSeconds=20,
                    )
                except catch_aws_error("AWS.SimpleQueueService.NonExistentQueue"):
                    getLogger().warning(
                        f"Queue {edge.queue} does not exist, exiting")
                    break
                except Exception:
                    getLogger().exception(
                        f"Error receiving messages from {edge.name}, retrying"
                    )
                    sleep(20)
                else:
                    if not (sqs_messages := response.get("Messages")):
                        continue
                    getLogger().info(
                        f"Received {len(sqs_messages)} from {edge.name}")

                    message_handlers = [
                        partial(
                            handle_received_message,
                            Message(
                                body=sqs_message["Body"],
                                group_id=sqs_message["Attributes"]["MessageGroupId"],
                                message_type=node.receive_message_type,
                                tracking_id=sqs_message["MessageAttributes"]
                                .get("trackingId", {})
                                .get("StringValue"),
                                previous_tracking_ids=sqs_message["MessageAttributes"]
                                .get("prevTrackingIds", {})
                                .get("StringValue"),
                            ),
                            sqs_message["ReceiptHandle"],
                        )
                        for sqs_message in sqs_messages
                    ]

                    # Handle the batch on its own thread so polling can
                    # continue while the messages are processed.
                    Thread(
                        name="handle_received_messages",
                        target=handle_received_messages,
                        args=(message_handlers,),
                    ).start()

            getLogger().info(f"Stopping receiving messages from {edge.name}")

        super().__init__(
            name=f"SourceMessageReceiver({edge.name})", target=receive)
        self.start()

    def join(self) -> None:
        """
        Waits for the polling thread to finish, then for all queued
        deletions to be processed.
        """
        super().join()
        self.__delete_message_queue.join()

    def stop(self) -> None:
        """Asks the polling loop to exit after its current iteration."""
        self.__continue.clear()
605
606
class AppNode(Node):
    """
    A Node meant to run as a long-lived application, either on its own
    or embedded in a larger application.
    """

    def __init__(
        self,
        *,
        appsync_endpoint: str = None,
        bulk_data_acceleration: bool = False,
        client_id: str = None,
        executor: Executor = None,
        name: str = None,
        password: str = None,
        tenant: str = None,
        timeout: float = None,
        user_pool_id: str = None,
        username: str = None,
    ) -> None:
        super().__init__(
            appsync_endpoint=appsync_endpoint,
            bulk_data_acceleration=bulk_data_acceleration,
            client_id=client_id,
            name=name,
            password=password,
            tenant=tenant,
            timeout=timeout,
            user_pool_id=user_pool_id,
            username=username,
        )
        self.__executor = executor
        self.__source_message_receivers: list[_SourceMessageReceiver] = list()
        # Set by stop(); join() blocks on it.
        self.__stop = Event()

    @property
    def _executor(self) -> Executor:
        return self.__executor

    def join(self) -> None:
        """
        Blocks until stop() has been called, then waits for every source
        receiver and for the base Node's queues to finish.
        """
        self.__stop.wait()
        for receiver in self.__source_message_receivers:
            receiver.join()
        super().join()

    def start(self) -> None:
        """
        Starts the base Node and, unless the Node is stopped, one
        receiver per source edge.
        """
        super().start()
        self.__stop.clear()
        if self.stopped:
            return
        self.__source_message_receivers = [
            _SourceMessageReceiver(source_edge, self)
            for source_edge in self._sources
        ]

    def start_and_run_forever(self) -> None:
        """Starts this Node and blocks until it has been stopped."""
        self.start()
        self.join()

    def stop(self) -> None:
        """
        Signals join() to proceed and asks every source receiver to stop.
        """
        self.__stop.set()
        for receiver in self.__source_message_receivers:
            receiver.stop()
678
679
class LambdaNode(Node):
    """
    A Node designed to live inside an AWS Lambda function. Instances
    start themselves as part of construction.
    """

    def __init__(
        self,
        *,
        appsync_endpoint: str = None,
        bulk_data_acceleration: bool = False,
        client_id: str = None,
        concurrent_processing: bool = False,
        name: str = None,
        password: str = None,
        report_batch_item_failures: bool = False,
        tenant: str = None,
        timeout: float = None,
        user_pool_id: str = None,
        username: str = None,
    ) -> None:
        super().__init__(
            appsync_endpoint=appsync_endpoint,
            bulk_data_acceleration=bulk_data_acceleration,
            client_id=client_id,
            name=name,
            password=password,
            tenant=tenant,
            timeout=timeout or 0.01,
            user_pool_id=user_pool_id,
            username=username,
        )
        self.start()
        signal(SIGTERM, self._shutdown_handler)
        if concurrent_processing:
            self.__executor: Executor = ThreadPoolExecutor()
        else:
            self.__executor: Executor = None
        # Maps the last path segment of each source queue URL to the
        # name of the source edge.
        self.__queue_name_to_source = {
            source_edge.queue.split("/")[-1]: source_edge.name
            for source_edge in self._sources
        }
        self.__report_batch_item_failures = report_batch_item_failures

    def _get_source(self, queue_arn: str) -> str:
        # The queue name is the final colon-separated field of the ARN.
        queue_name = queue_arn.split(":")[-1]
        return self.__queue_name_to_source[queue_name]

    def _shutdown_handler(self, signum: int, frame: object) -> None:
        # Drain pending work before the process goes away.
        getLogger().info("Received SIGTERM, shutting down")
        self.join()
        getLogger().info("Shutdown complete")

    def handle_event(self, event: LambdaEvent) -> BatchItemFailures:
        """
        Processes the event that AWS Lambda passed to the containing
        function.

        This should be the only method your AWS Lambda function calls.
        When batch item failure reporting is enabled, the ids of the
        messages that could not be handled are returned; otherwise the
        first handling error is raised.
        """
        records: LambdaSqsRecords = event.get("Records")
        if not records:
            getLogger().warning(f"No Records found in event {event}")
            return

        source = self._get_source(records[0]["eventSourceARN"])
        getLogger().info(f"Received {len(records)} messages from {source}")
        report_failures = self.__report_batch_item_failures
        # Every message starts out as failed and is removed on success.
        batch_item_failures: list[str] = None
        if report_failures:
            batch_item_failures = [record["messageId"] for record in records]

        def process(message: Message, message_id: str) -> None:
            try:
                self.handle_received_message(message=message, source=source)
            except Exception:
                if not report_failures:
                    raise
                getLogger().exception(
                    f"Error handling recevied message for {source}")
            else:
                if report_failures:
                    batch_item_failures.remove(message_id)

        # Build every Message up front, before any of them is handled.
        work_items = [
            (
                Message(
                    body=record["body"],
                    group_id=record["attributes"]["MessageGroupId"],
                    message_type=self.receive_message_type,
                    previous_tracking_ids=record["messageAttributes"]
                    .get("prevTrackingIds", {})
                    .get("stringValue"),
                    tracking_id=record["messageAttributes"]
                    .get("trackingId", {})
                    .get("stringValue"),
                ),
                record["messageId"],
            )
            for record in records
        ]

        executor = self.__executor
        if executor:
            futures = [
                executor.submit(process, message, message_id)
                for message, message_id in work_items
            ]
            for future in as_completed(futures):
                if exception := future.exception():
                    raise exception
        else:
            for message, message_id in work_items:
                process(message, message_id)
        # Flush queued outbound messages and audit records before returning.
        self.join()
        if report_failures and batch_item_failures:
            return dict(
                batchItemFailures=[
                    dict(itemIdentifier=message_id)
                    for message_id in batch_item_failures
                ]
            )
class Node(echostream_node.Node):
class Node(BaseNode):
    """
    Base class for all threading Nodes.

    Outgoing messages are placed on one queue per target and audit
    records on one queue per message type; ``join`` calls ``join()``
    on each of those queues.
    """

    def __init__(
        self,
        *,
        appsync_endpoint: str = None,
        bulk_data_acceleration: bool = False,
        client_id: str = None,
        name: str = None,
        password: str = None,
        tenant: str = None,
        timeout: float = None,
        user_pool_id: str = None,
        username: str = None,
    ) -> None:
        """
        Initializes the base Node and creates this Node's GraphQL
        client, lock and (initially empty) queue registries.

        All arguments are forwarded unchanged to the base class. When
        ``appsync_endpoint`` is falsy the GraphQL URL is read from the
        ``APPSYNC_ENDPOINT`` environment variable (``KeyError`` if it
        is not set).
        """
        super().__init__(
            appsync_endpoint=appsync_endpoint,
            bulk_data_acceleration=bulk_data_acceleration,
            client_id=client_id,
            name=name,
            password=password,
            tenant=tenant,
            timeout=timeout,
            user_pool_id=user_pool_id,
            username=username,
        )
        self.__audit_records_queues: dict[str, _AuditRecordQueue] = {}
        self.__bulk_data_storage_queue = _BulkDataStorageQueue(self)
        # NOTE(review): ``self.__cognito`` is name-mangled to
        # ``_Node__cognito``; presumably it is set by the base class,
        # which is also named ``Node`` - confirm.
        srp_auth = RequestsSrpAuth(
            cognito=self.__cognito, http_header_prefix=""
        )
        transport = RequestsHTTPTransport(
            auth=srp_auth,
            url=appsync_endpoint or environ["APPSYNC_ENDPOINT"],
        )
        self.__gql_client = GqlClient(
            fetch_schema_from_transport=True,
            transport=transport,
        )
        self.__lock = RLock()
        self.__target_message_queues: dict[str, _TargetMessageQueue] = {}

    @property
    def _gql_client(self) -> GqlClient:
        """The GraphQL client created in ``__init__``."""
        return self.__gql_client

    @property
    def _lock(self) -> RLock:
        """The re-entrant lock created in ``__init__``."""
        return self.__lock

    def audit_message(
        self,
        /,
        message: Message,
        *,
        extra_attributes: dict[str, Any] = None,
        source: str = None,
    ) -> None:
        """
        Audits the provided message.

        Builds an audit record from the message's tracking ids, the
        current UTC time and ``source``, then queues it on the audit
        record queue registered for the message's type. The record's
        ``attributes`` are the message type's auditor output merged
        with ``extra_attributes`` (which win on key clashes); they are
        omitted when that merge is empty. Does nothing when auditing
        is disabled.

        Raises:
            RuntimeError: if this Node is stopped.
            ValueError: if no audit record queue is registered for the
                message's type.
        """
        if self.stopped:
            raise RuntimeError(f"{self.name} is stopped")
        if not self.audit:
            return
        extra_attributes = extra_attributes or dict()
        message_type = message.message_type
        record = dict(
            datetime=datetime.now(timezone.utc).isoformat(),
            previousTrackingIds=message.previous_tracking_ids,
            sourceNode=source,
            trackingId=message.tracking_id,
        )
        if attributes := (
            message_type.auditor(message=message.body) | extra_attributes
        ):
            record["attributes"] = attributes
        # Only the registry lookup is guarded, so a KeyError raised
        # while enqueueing is not misreported as an unknown type.
        try:
            audit_records_queue = self.__audit_records_queues[
                message_type.name
            ]
        except KeyError:
            raise ValueError(
                f"Unrecognized message type {message_type.name}"
            ) from None
        audit_records_queue.put_nowait(record)

    def audit_messages(
        self,
        /,
        messages: list[Message],
        *,
        extra_attributes: list[dict[str, Any]] = None,
        source: str = None,
    ) -> None:
        """
        Audits the provided messages by calling ``audit_message`` for
        each one.

        If extra_attributes is supplied, its items are paired with the
        messages by position and it must have the same number of items
        as messages. If source is provided, it is recorded in every
        audit.

        Raises:
            ValueError: if extra_attributes is non-empty and its length
                differs from that of messages.
        """
        if extra_attributes:
            if len(extra_attributes) != len(messages):
                raise ValueError(
                    "messages and extra_attributes must have the same number of items"
                )
        else:
            # One fresh dict per message; ``[dict()] * n`` would alias
            # a single dict across every message.
            extra_attributes = [dict() for _ in messages]
        for message, attributes in zip(messages, extra_attributes):
            self.audit_message(
                message, extra_attributes=attributes, source=source)

    def handle_bulk_data(self, data: Union[bytearray, bytes]) -> str:
        """
        Posts data as bulk data and returns a GET URL for data retrieval.
        Normally this returned URL will be used as a "ticket" in messages
        that require bulk data.
        """
        bulk_data_storage = self.__bulk_data_storage_queue.get()
        return bulk_data_storage.handle_bulk_data(data)

    def handle_received_message(self, *, message: Message, source: str) -> None:
        """
        Callback called when a message is received. This base
        implementation does nothing; subclasses that receive messages
        should override it.
        """
        return None

    def join(self) -> None:
        """
        Joins the calling thread with this Node by calling ``join()``
        on every target message queue and then on every audit record
        queue.
        """
        for queues in (
            self.__target_message_queues,
            self.__audit_records_queues,
        ):
            for pending_queue in queues.values():
                pending_queue.join()

    def send_message(self, /, message: Message, *, targets: set[Edge] = None) -> None:
        """
        Send the message to the specified targets. If no targets are
        specified the message will be sent to all targets.
        """
        self.send_messages([message], targets=targets)

    def send_messages(
        self, /, messages: list[Message], *, targets: set[Edge] = None
    ) -> None:
        """
        Send the messages to the specified targets. If no targets are
        specified the messages will be sent to all targets. A target
        without a message queue is logged as a warning and skipped.

        Raises:
            RuntimeError: if this Node is stopped.
        """
        if self.stopped:
            raise RuntimeError(f"{self.name} is stopped")
        if not messages:
            return
        for target in targets or self.targets:
            target_message_queue = self.__target_message_queues.get(
                target.name
            )
            if not target_message_queue:
                getLogger().warning(f"Target {target.name} does not exist")
                continue
            for message in messages:
                target_message_queue.put_nowait(message)

    def start(self) -> None:
        """
        Starts this Node. Must be called prior to any other usage.

        Runs the GetNode GraphQL query and, from its result, sets the
        audit flag, the merged config, the stopped flag, the message
        types, and the source and target edges. Audit record queues
        and target message queues are only created when the Node is
        not stopped.
        """
        getLogger().info(f"Starting Node {self.name}")
        with self._lock, self._gql_client as session:
            data: dict[str, Union[str, dict]] = session.execute(
                _GET_NODE_GQL,
                variable_values=dict(name=self.name, tenant=self.tenant),
            )["GetNode"]
        self._audit = data["tenant"].get("audit") or False
        # Later operands win on key clashes: tenant < app < node.
        tenant_config = json.loads(data["tenant"].get("config") or "{}")
        app_config = json.loads(
            (data.get("app") or dict()).get("config") or "{}"
        )
        node_config = json.loads(data.get("config") or "{}")
        self.config = tenant_config | app_config | node_config
        self._stopped = data.get("stopped")
        if receive_message_type := data.get("receiveMessageType"):
            self._receive_message_type = MessageType(
                auditor=dynamic_function_loader.load(
                    receive_message_type["auditor"]),
                name=receive_message_type["name"],
            )
            if not self.stopped and self.audit:
                self.__audit_records_queues[receive_message_type["name"]] = (
                    _AuditRecordQueue(self._receive_message_type, self)
                )
        if send_message_type := data.get("sendMessageType"):
            self._send_message_type = MessageType(
                auditor=dynamic_function_loader.load(
                    send_message_type["auditor"]),
                name=send_message_type["name"],
            )
            if not self.stopped and self.audit:
                self.__audit_records_queues[send_message_type["name"]] = (
                    _AuditRecordQueue(self._send_message_type, self)
                )
        # An AppChangeReceiverNode has at most one receive edge, read
        # from "receiveEdge"; other node types read "receiveEdges".
        if self.node_type == "AppChangeReceiverNode":
            receive_edges = [edge] if (edge := data.get("receiveEdge")) else []
        else:
            receive_edges = data.get("receiveEdges") or list()
        self._sources = {
            Edge(name=edge["source"]["name"], queue=edge["queue"])
            for edge in receive_edges
        }
        self._targets = {
            Edge(name=edge["target"]["name"], queue=edge["queue"])
            for edge in (data.get("sendEdges") or list())
        }
        if not self.stopped:
            self.__target_message_queues = {
                edge.name: _TargetMessageQueue(self, edge) for edge in self._targets
            }

    def stop(self) -> None:
        """Stops the Node's processing. This base implementation does nothing."""
        return None

Base class for all threading Nodes.

Node( *, appsync_endpoint: str = None, bulk_data_acceleration: bool = False, client_id: str = None, name: str = None, password: str = None, tenant: str = None, timeout: float = None, user_pool_id: str = None, username: str = None)
249    def __init__(
250        self,
251        *,
252        appsync_endpoint: str = None,
253        bulk_data_acceleration: bool = False,
254        client_id: str = None,
255        name: str = None,
256        password: str = None,
257        tenant: str = None,
258        timeout: float = None,
259        user_pool_id: str = None,
260        username: str = None,
261    ) -> None:
262        super().__init__(
263            appsync_endpoint=appsync_endpoint,
264            bulk_data_acceleration=bulk_data_acceleration,
265            client_id=client_id,
266            name=name,
267            password=password,
268            tenant=tenant,
269            timeout=timeout,
270            user_pool_id=user_pool_id,
271            username=username,
272        )
273        self.__audit_records_queues: dict[str, _AuditRecordQueue] = dict()
274        self.__bulk_data_storage_queue = _BulkDataStorageQueue(self)
275        self.__gql_client = GqlClient(
276            fetch_schema_from_transport=True,
277            transport=RequestsHTTPTransport(
278                auth=RequestsSrpAuth(cognito=self.__cognito,
279                                     http_header_prefix=""),
280                url=appsync_endpoint or environ["APPSYNC_ENDPOINT"],
281            ),
282        )
283        self.__lock = RLock()
284        self.__target_message_queues: dict[str, _TargetMessageQueue] = dict()
def audit_message( self, /, message: echostream_node.Message, *, extra_attributes: dict[str, typing.Any] = None, source: str = None) -> None:
294    def audit_message(
295        self,
296        /,
297        message: Message,
298        *,
299        extra_attributes: dict[str, Any] = None,
300        source: str = None,
301    ) -> None:
302        """
303        Audits the provided message. If extra_attibutes is
304        supplied, they will be added to the message's audit
305        dict. If source is provided, it will be recorded in
306        the audit.
307        """
308        if self.stopped:
309            raise RuntimeError(f"{self.name} is stopped")
310        if not self.audit:
311            return
312        extra_attributes = extra_attributes or dict()
313        message_type = message.message_type
314        record = dict(
315            datetime=datetime.now(timezone.utc).isoformat(),
316            previousTrackingIds=message.previous_tracking_ids,
317            sourceNode=source,
318            trackingId=message.tracking_id,
319        )
320        if attributes := (
321            message_type.auditor(message=message.body) | extra_attributes
322        ):
323            record["attributes"] = attributes
324        try:
325            self.__audit_records_queues[message_type.name].put_nowait(record)
326        except KeyError:
327            raise ValueError(f"Unrecognized message type {message_type.name}")

Audits the provided message. If extra_attributes is supplied, they will be added to the message's audit dict. If source is provided, it will be recorded in the audit.

def audit_messages( self, /, messages: list[echostream_node.Message], *, extra_attributes: list[dict[str, typing.Any]] = None, source: str = None) -> None:
329    def audit_messages(
330        self,
331        /,
332        messages: list[Message],
333        *,
334        extra_attributes: list[dict[str, Any]] = None,
335        source: str = None,
336    ) -> None:
337        """
338        Audits the provided messages. If extra_attibutes is
339        supplied they will be added to the respective message's audit
340        dict and they must have the same count as messages.
341        If source is provided, it will be recorded in the audit.
342        """
343        if extra_attributes:
344            if len(extra_attributes) != len(messages):
345                raise ValueError(
346                    "messages and extra_attributes must have the same number of items"
347                )
348        else:
349            extra_attributes = [dict()] * len(messages)
350        for message, attributes in zip(messages, extra_attributes):
351            self.audit_message(
352                message, extra_attributes=attributes, source=source)

Audits the provided messages. If extra_attributes is supplied, they will be added to the respective message's audit dict, and there must be exactly one entry per message. If source is provided, it will be recorded in the audit.

def handle_bulk_data(self, data: Union[bytearray, bytes]) -> str:
354    def handle_bulk_data(self, data: Union[bytearray, bytes]) -> str:
355        """
356        Posts data as bulk data and returns a GET URL for data retrieval.
357        Normally this returned URL will be used as a "ticket" in messages
358        that require bulk data.
359        """
360        return self.__bulk_data_storage_queue.get().handle_bulk_data(data)

Posts data as bulk data and returns a GET URL for data retrieval. Normally this returned URL will be used as a "ticket" in messages that require bulk data.

def handle_received_message(self, *, message: echostream_node.Message, source: str) -> None:
362    def handle_received_message(self, *, message: Message, source: str) -> None:
363        """
364        Callback called when a message is received. Subclasses that receive messages
365        should override this method.
366        """
367        pass

Callback called when a message is received. Subclasses that receive messages should override this method.

def join(self) -> None:
369    def join(self) -> None:
370        """
371        Joins the calling thread with this Node. Will block until all
372        join conditions are satified.
373        """
374        for target_message_queue in self.__target_message_queues.values():
375            target_message_queue.join()
376        for audit_records_queue in self.__audit_records_queues.values():
377            audit_records_queue.join()

Joins the calling thread with this Node. Will block until all join conditions are satisfied.

def send_message( self, /, message: echostream_node.Message, *, targets: set[echostream_node.Edge] = None) -> None:
379    def send_message(self, /, message: Message, *, targets: set[Edge] = None) -> None:
380        """
381        Send the message to the specified targets. If no targets are specified
382        the message will be sent to all targets.
383        """
384        self.send_messages([message], targets=targets)

Send the message to the specified targets. If no targets are specified the message will be sent to all targets.

def send_messages( self, /, messages: list[echostream_node.Message], *, targets: set[echostream_node.Edge] = None) -> None:
386    def send_messages(
387        self, /, messages: list[Message], *, targets: set[Edge] = None
388    ) -> None:
389        """
390        Send the messages to the specified targets. If no targets are specified
391        the messages will be sent to all targets.
392        """
393        if self.stopped:
394            raise RuntimeError(f"{self.name} is stopped")
395        if messages:
396            for target in targets or self.targets:
397                if target_message_queue := self.__target_message_queues.get(
398                    target.name
399                ):
400                    for message in messages:
401                        target_message_queue.put_nowait(message)
402                else:
403                    getLogger().warning(f"Target {target.name} does not exist")

Send the messages to the specified targets. If no targets are specified the messages will be sent to all targets.

def start(self) -> None:
405    def start(self) -> None:
406        """
407        Starts this Node. Must be called prior to any other usage.
408        """
409        getLogger().info(f"Starting Node {self.name}")
410        with self._lock:
411            with self._gql_client as session:
412                data: dict[str, Union[str, dict]] = session.execute(
413                    _GET_NODE_GQL,
414                    variable_values=dict(name=self.name, tenant=self.tenant),
415                )["GetNode"]
416        self._audit = data["tenant"].get("audit") or False
417        self.config = (
418            json.loads(data["tenant"].get("config") or "{}")
419            | json.loads((data.get("app") or dict()).get("config") or "{}")
420            | json.loads(data.get("config") or "{}")
421        )
422        self._stopped = data.get("stopped")
423        if receive_message_type := data.get("receiveMessageType"):
424            self._receive_message_type = MessageType(
425                auditor=dynamic_function_loader.load(
426                    receive_message_type["auditor"]),
427                name=receive_message_type["name"],
428            )
429            if not self.stopped and self.audit:
430                self.__audit_records_queues[receive_message_type["name"]] = (
431                    _AuditRecordQueue(self._receive_message_type, self)
432                )
433        if send_message_type := data.get("sendMessageType"):
434            self._send_message_type = MessageType(
435                auditor=dynamic_function_loader.load(
436                    send_message_type["auditor"]),
437                name=send_message_type["name"],
438            )
439            if not self.stopped and self.audit:
440                self.__audit_records_queues[send_message_type["name"]] = (
441                    _AuditRecordQueue(self._send_message_type, self)
442                )
443        if self.node_type == "AppChangeReceiverNode":
444            if edge := data.get("receiveEdge"):
445                self._sources = {
446                    Edge(name=edge["source"]["name"], queue=edge["queue"])}
447            else:
448                self._sources = set()
449        else:
450            self._sources = {
451                Edge(name=edge["source"]["name"], queue=edge["queue"])
452                for edge in (data.get("receiveEdges") or list())
453            }
454        self._targets = {
455            Edge(name=edge["target"]["name"], queue=edge["queue"])
456            for edge in (data.get("sendEdges") or list())
457        }
458        if not self.stopped:
459            self.__target_message_queues = {
460                edge.name: _TargetMessageQueue(self, edge) for edge in self._targets
461            }

Starts this Node. Must be called prior to any other usage.

def stop(self) -> None:
463    def stop(self) -> None:
464        """Stops the Node's processing."""
465        pass

Stops the Node's processing.

class AppNode(Node):
class AppNode(Node):
    """
    A daemon Node intended to be used as either a stand-alone application
    or as a part of a larger application.

    One ``_SourceMessageReceiver`` is created per source edge when the
    Node is started; ``stop`` sets an internal event that releases
    ``join``.
    """

    def __init__(
        self,
        *,
        appsync_endpoint: str = None,
        bulk_data_acceleration: bool = False,
        client_id: str = None,
        executor: Executor = None,
        name: str = None,
        password: str = None,
        tenant: str = None,
        timeout: float = None,
        user_pool_id: str = None,
        username: str = None,
    ) -> None:
        """
        Initializes the underlying Node and stores ``executor``, which
        is exposed through the ``_executor`` property. All other
        arguments are forwarded unchanged to ``Node.__init__``.
        """
        super().__init__(
            appsync_endpoint=appsync_endpoint,
            bulk_data_acceleration=bulk_data_acceleration,
            client_id=client_id,
            name=name,
            password=password,
            tenant=tenant,
            timeout=timeout,
            user_pool_id=user_pool_id,
            username=username,
        )
        self.__stop = Event()
        self.__source_message_receivers: list[_SourceMessageReceiver] = []
        self.__executor = executor

    @property
    def _executor(self) -> Executor:
        """The executor passed to ``__init__`` (may be ``None``)."""
        return self.__executor

    def join(self) -> None:
        """
        Blocks until ``stop`` has been called, then joins every source
        message receiver and finally performs the base Node's join.
        """
        self.__stop.wait()
        for receiver in self.__source_message_receivers:
            receiver.join()
        super().join()

    def start(self) -> None:
        """
        Starts the underlying Node, clears the stop event and, unless
        the Node is stopped, creates one source message receiver per
        source edge.
        """
        super().start()
        self.__stop.clear()
        if self.stopped:
            return
        self.__source_message_receivers = [
            _SourceMessageReceiver(edge, self) for edge in self._sources
        ]

    def start_and_run_forever(self) -> None:
        """Starts this Node and then blocks in ``join`` until ``stop`` is called."""
        self.start()
        self.join()

    def stop(self) -> None:
        """
        Sets the stop event (releasing ``join``) and stops every source
        message receiver.
        """
        self.__stop.set()
        for receiver in self.__source_message_receivers:
            receiver.stop()

A daemon Node intended to be used as either a stand-alone application or as a part of a larger application.

AppNode( *, appsync_endpoint: str = None, bulk_data_acceleration: bool = False, client_id: str = None, executor: concurrent.futures._base.Executor = None, name: str = None, password: str = None, tenant: str = None, timeout: float = None, user_pool_id: str = None, username: str = None)
614    def __init__(
615        self,
616        *,
617        appsync_endpoint: str = None,
618        bulk_data_acceleration: bool = False,
619        client_id: str = None,
620        executor: Executor = None,
621        name: str = None,
622        password: str = None,
623        tenant: str = None,
624        timeout: float = None,
625        user_pool_id: str = None,
626        username: str = None,
627    ) -> None:
628        super().__init__(
629            appsync_endpoint=appsync_endpoint,
630            bulk_data_acceleration=bulk_data_acceleration,
631            client_id=client_id,
632            name=name,
633            password=password,
634            tenant=tenant,
635            timeout=timeout,
636            user_pool_id=user_pool_id,
637            username=username,
638        )
639        self.__executor = executor
640        self.__source_message_receivers: list[_SourceMessageReceiver] = list()
641        self.__stop = Event()
def join(self) -> None:
647    def join(self) -> None:
648        """
649        Method to join all the app node receivers so that main thread can wait for their execution to complete.
650        """
651        self.__stop.wait()
652        for app_node_receiver in self.__source_message_receivers:
653            app_node_receiver.join()
654        super().join()

Joins all of the app node's receivers so that the main thread can wait for their execution to complete.

def start(self) -> None:
656    def start(self) -> None:
657        """
658        Calls start of Node class
659        """
660        super().start()
661        self.__stop.clear()
662        if not self.stopped:
663            self.__source_message_receivers = [
664                _SourceMessageReceiver(edge, self) for edge in self._sources
665            ]

Calls start of Node class

def start_and_run_forever(self) -> None:
667    def start_and_run_forever(self) -> None:
668        """Will start this Node and run until stop is called"""
669        self.start()
670        self.join()

Will start this Node and run until stop is called

def stop(self) -> None:
672    def stop(self) -> None:
673        """
674        Stops the Node gracefully
675        """
676        self.__stop.set()
677        for app_node_receiver in self.__source_message_receivers:
678            app_node_receiver.stop()

Stops the Node gracefully

class LambdaNode(Node):
class LambdaNode(Node):
    """
    A Node class intended to be implemented in an AWS Lambda function.
    Nodes that inherit from this class are automatically started on
    creation.
    """

    def __init__(
        self,
        *,
        appsync_endpoint: str = None,
        bulk_data_acceleration: bool = False,
        client_id: str = None,
        concurrent_processing: bool = False,
        name: str = None,
        password: str = None,
        report_batch_item_failures: bool = False,
        tenant: str = None,
        timeout: float = None,
        user_pool_id: str = None,
        username: str = None,
    ) -> None:
        """
        Initializes and immediately starts the Node, then installs a
        SIGTERM handler that joins the Node.

        ``concurrent_processing`` creates a ``ThreadPoolExecutor`` that
        ``handle_event`` submits message handlers to.
        ``report_batch_item_failures`` makes ``handle_event`` log
        handler errors and return the failed message ids instead of
        raising. A falsy ``timeout`` is replaced by ``0.01``. The
        remaining arguments are forwarded to ``Node.__init__``.
        """
        super().__init__(
            appsync_endpoint=appsync_endpoint,
            bulk_data_acceleration=bulk_data_acceleration,
            client_id=client_id,
            name=name,
            password=password,
            tenant=tenant,
            timeout=timeout or 0.01,
            user_pool_id=user_pool_id,
            username=username,
        )
        self.start()
        signal(SIGTERM, self._shutdown_handler)
        if concurrent_processing:
            self.__executor: Executor = ThreadPoolExecutor()
        else:
            self.__executor: Executor = None
        # Key each source by the last "/"-separated segment of its
        # queue value (presumably the queue name - verify).
        self.__queue_name_to_source = {
            edge.queue.split("/")[-1]: edge.name for edge in self._sources
        }
        self.__report_batch_item_failures = report_batch_item_failures

    def _get_source(self, queue_arn: str) -> str:
        """
        Returns the source name registered for the last ":"-separated
        segment of ``queue_arn``; raises ``KeyError`` if there is none.
        """
        queue_name = queue_arn.split(":")[-1]
        return self.__queue_name_to_source[queue_name]

    def _shutdown_handler(self, signum: int, frame: object) -> None:
        """SIGTERM handler: logs, joins this Node, then logs completion."""
        logger = getLogger()
        logger.info("Received SIGTERM, shutting down")
        self.join()
        logger.info("Shutdown complete")

    def handle_event(self, event: LambdaEvent) -> BatchItemFailures:
        """
        Handles the AWS Lambda event passed into the containing
        AWS Lambda function during invocation.

        This is intended to be the only called method in your
        containing AWS Lambda function.

        Each record is converted to a ``Message`` and passed to
        ``handle_received_message``, through the executor when one was
        configured and sequentially otherwise, after which the Node is
        joined. The source is resolved from the first record's
        ``eventSourceARN``.

        Returns ``None`` when the event has no records or nothing is
        to be reported; otherwise a ``batchItemFailures`` dict listing
        the ids of the messages whose handler raised (only when batch
        item failure reporting is enabled).
        """
        records: LambdaSqsRecords = event.get("Records")
        if not records:
            getLogger().warning(f"No Records found in event {event}")
            return None

        source = self._get_source(records[0]["eventSourceARN"])
        getLogger().info(f"Received {len(records)} messages from {source}")
        # Every message starts out as a failure; successfully handled
        # messages are removed as they complete.
        if self.__report_batch_item_failures:
            batch_item_failures: list[str] = [
                record["messageId"] for record in records
            ]
        else:
            batch_item_failures: list[str] = None

        def process(message: Message, message_id: str) -> None:
            try:
                self.handle_received_message(message=message, source=source)
            except Exception:
                if not self.__report_batch_item_failures:
                    raise
                getLogger().exception(
                    f"Error handling recevied message for {source}")
            else:
                if self.__report_batch_item_failures:
                    batch_item_failures.remove(message_id)

        def to_message(record) -> Message:
            return Message(
                body=record["body"],
                group_id=record["attributes"]["MessageGroupId"],
                message_type=self.receive_message_type,
                previous_tracking_ids=record["messageAttributes"]
                .get("prevTrackingIds", {})
                .get("stringValue"),
                tracking_id=record["messageAttributes"]
                .get("trackingId", {})
                .get("stringValue"),
            )

        # All messages are built before any handler runs.
        work_items = [
            (to_message(record), record["messageId"]) for record in records
        ]

        executor = self.__executor
        if executor:
            futures = [
                executor.submit(process, message, message_id)
                for message, message_id in work_items
            ]
            for future in as_completed(futures):
                exception = future.exception()
                if exception:
                    raise exception
        else:
            for message, message_id in work_items:
                process(message, message_id)
        self.join()
        if self.__report_batch_item_failures and batch_item_failures:
            return {
                "batchItemFailures": [
                    {"itemIdentifier": message_id}
                    for message_id in batch_item_failures
                ]
            }
        return None

A Node class intended to be implemented in an AWS Lambda function. Nodes that inherit from this class are automatically started on creation.

LambdaNode( *, appsync_endpoint: str = None, bulk_data_acceleration: bool = False, client_id: str = None, concurrent_processing: bool = False, name: str = None, password: str = None, report_batch_item_failures: bool = False, tenant: str = None, timeout: float = None, user_pool_id: str = None, username: str = None)
688    def __init__(
689        self,
690        *,
691        appsync_endpoint: str = None,
692        bulk_data_acceleration: bool = False,
693        client_id: str = None,
694        concurrent_processing: bool = False,
695        name: str = None,
696        password: str = None,
697        report_batch_item_failures: bool = False,
698        tenant: str = None,
699        timeout: float = None,
700        user_pool_id: str = None,
701        username: str = None,
702    ) -> None:
703        super().__init__(
704            appsync_endpoint=appsync_endpoint,
705            bulk_data_acceleration=bulk_data_acceleration,
706            client_id=client_id,
707            name=name,
708            password=password,
709            tenant=tenant,
710            timeout=timeout or 0.01,
711            user_pool_id=user_pool_id,
712            username=username,
713        )
714        self.start()
715        signal(SIGTERM, self._shutdown_handler)
716        self.__executor: Executor = (
717            ThreadPoolExecutor() if concurrent_processing else None
718        )
719        self.__queue_name_to_source = {
720            edge.queue.split("/")[-1:][0]: edge.name for edge in self._sources
721        }
722        self.__report_batch_item_failures = report_batch_item_failures
def handle_event( self, event: Union[bool, dict, float, int, list, str, tuple, NoneType]) -> dict[str, list[dict[str, str]]]:
732    def handle_event(self, event: LambdaEvent) -> BatchItemFailures:
733        """
734        Handles the AWS Lambda event passed into the containing
735        AWS Lambda function during invocation.
736
737        This is intended to be the only called method in your
738        containing AWS Lambda function.
739        """
740        records: LambdaSqsRecords = None
741        if not (records := event.get("Records")):
742            getLogger().warning(f"No Records found in event {event}")
743            return
744
745        source = self._get_source(records[0]["eventSourceARN"])
746        getLogger().info(f"Received {len(records)} messages from {source}")
747        batch_item_failures: list[str] = (
748            [record["messageId"] for record in records]
749            if self.__report_batch_item_failures
750            else None
751        )
752
753        def handle_received_message(message: Message, message_id: str) -> None:
754            try:
755                self.handle_received_message(message=message, source=source)
756            except Exception:
757                if not self.__report_batch_item_failures:
758                    raise
759                getLogger().exception(
760                    f"Error handling recevied message for {source}")
761            else:
762                if self.__report_batch_item_failures:
763                    batch_item_failures.remove(message_id)
764
765        message_handlers = [
766            partial(
767                handle_received_message,
768                Message(
769                    body=record["body"],
770                    group_id=record["attributes"]["MessageGroupId"],
771                    message_type=self.receive_message_type,
772                    previous_tracking_ids=record["messageAttributes"]
773                    .get("prevTrackingIds", {})
774                    .get("stringValue"),
775                    tracking_id=record["messageAttributes"]
776                    .get("trackingId", {})
777                    .get("stringValue"),
778                ),
779                record["messageId"],
780            )
781            for record in records
782        ]
783
784        if executor := self.__executor:
785            for future in as_completed(
786                [
787                    executor.submit(message_handler)
788                    for message_handler in message_handlers
789                ]
790            ):
791                if exception := future.exception():
792                    raise exception
793        else:
794            for message_handler in message_handlers:
795                message_handler()
796        self.join()
797        if self.__report_batch_item_failures and batch_item_failures:
798            return dict(
799                batchItemFailures=[
800                    dict(itemIdentifier=message_id)
801                    for message_id in batch_item_failures
802                ]
803            )

Handles the AWS Lambda event passed into the containing AWS Lambda function during invocation.

This is intended to be the only called method in your containing AWS Lambda function.