echostream_node.asyncio
Module source:

```python
from __future__ import annotations

import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timezone
from functools import partial
from gzip import GzipFile, compress
from io import BytesIO
from os import environ
from signal import SIG_IGN, SIGTERM, signal
from typing import TYPE_CHECKING, Any, AsyncGenerator, BinaryIO, Union

import dynamic_function_loader
import simplejson as json
from aws_error_utils import catch_aws_error
from gql import Client as GqlClient
from gql.transport.aiohttp import AIOHTTPTransport
from gql_appsync_cognito_authentication import AppSyncCognitoAuthentication
from httpx import AsyncClient as HttpxClient
from httpx_auth import AWS4Auth

from .. import _GET_BULK_DATA_STORAGE_GQL, _GET_NODE_GQL, BatchItemFailures
from .. import BulkDataStorage as BaseBulkDataStorage
from .. import Edge, LambdaEvent, LambdaSqsRecords, Message, MessageType
from .. import Node as BaseNode
from .. import PresignedPost, getLogger

if TYPE_CHECKING:
    from mypy_boto3_sqs.type_defs import (
        DeleteMessageBatchRequestEntryTypeDef,
        SendMessageBatchRequestEntryTypeDef,
    )
else:
    DeleteMessageBatchRequestEntryTypeDef = dict
    SendMessageBatchRequestEntryTypeDef = dict


class _AuditRecordQueue(asyncio.Queue):
    def __init__(self, message_type: MessageType, node: Node) -> None:
        super().__init__()

        async def sender() -> None:
            cancelled = asyncio.Event()

            async def batcher() -> AsyncGenerator[list[dict], None]:
                batch: list[dict] = list()
                while not (cancelled.is_set() and self.empty()):
                    try:
                        try:
                            batch.append(
                                await asyncio.wait_for(self.get(), timeout=node.timeout)
                            )
                        except asyncio.TimeoutError:
                            if batch:
                                yield batch
                                batch = list()
                        else:
                            if len(batch) == 500:
                                yield batch
                                batch = list()
                    except asyncio.CancelledError:
                        cancelled.set()
                        if batch:
                            yield batch
                            batch = list()

            async with HttpxClient() as client:
                async for batch in batcher():
                    credentials = (
                        node._session.get_credentials().get_frozen_credentials()
                    )
                    auth = AWS4Auth(
                        access_id=credentials.access_key,
                        region=node._session.region_name,
                        secret_key=credentials.secret_key,
                        service="lambda",
                        security_token=credentials.token,
                    )
                    url = node._audit_records_endpoint
                    post_args = dict(
                        auth=auth,
                        url=f"{url}{'' if url.endswith('/') else '/'}{node.name}",
                    )
                    body = dict(
                        messageType=message_type.name,
                        auditRecords=batch,
                    )
                    if len(batch) <= 10:
                        post_args["json"] = body
                    else:
                        post_args["content"] = compress(
                            json.dumps(body, separators=(",", ":")).encode()
                        )
                        post_args["headers"] = {
                            "Content-Encoding": "gzip",
                            "Content-Type": "application/json",
                        }
                    try:
                        response = await client.post(**post_args)
                        response.raise_for_status()
                        await response.aclose()
                    except asyncio.CancelledError:
                        cancelled.set()
                    except Exception:
                        getLogger().exception("Error creating audit records")
                    finally:
                        for _ in range(len(batch)):
                            self.task_done()

        asyncio.create_task(sender(), name="AuditRecordsSender")

    async def get(self) -> dict:
        return await super().get()


class _BulkDataStorage(BaseBulkDataStorage):
    def __init__(
        self,
        bulk_data_storage: dict[str, Union[str, PresignedPost]],
        client: HttpxClient,
    ) -> None:
        super().__init__(bulk_data_storage)
        self.__client = client

    async def handle_bulk_data(self, data: Union[bytearray, bytes, BinaryIO]) -> str:
        # typing.BinaryIO is not a reliable isinstance() target for ordinary
        # file objects, so detect file-like data by exclusion
        if not isinstance(data, (bytearray, bytes)):
            data = data.read()
        with BytesIO() as buffer:
            with GzipFile(mode="wb", fileobj=buffer) as gzf:
                gzf.write(data)
            buffer.seek(0)
            response = await self.__client.post(
                self.presigned_post.url,
                data=self.presigned_post.fields,
                files=dict(file=("bulk_data", buffer)),
            )
            response.raise_for_status()
            await response.aclose()
        return self.presigned_get


class _BulkDataStorageQueue(asyncio.Queue):
    def __init__(self, node: Node) -> None:
        super().__init__()
        self.__fill = asyncio.Event()

        async def filler() -> None:
            async with HttpxClient() as client:
                cancelled = asyncio.Event()
                while not cancelled.is_set():
                    bulk_data_storages: list[dict] = list()
                    try:
                        await self.__fill.wait()
                        async with node._lock:
                            async with node._gql_client as session:
                                bulk_data_storages = (
                                    await session.execute(
                                        _GET_BULK_DATA_STORAGE_GQL,
                                        variable_values={
                                            "tenant": node.tenant,
                                            "useAccelerationEndpoint": node.bulk_data_acceleration,
                                        },
                                    )
                                )["GetBulkDataStorage"]
                    except asyncio.CancelledError:
                        cancelled.set()
                    except Exception:
                        getLogger().exception("Error getting bulk data storage")
                    for bulk_data_storage in bulk_data_storages:
                        self.put_nowait(_BulkDataStorage(bulk_data_storage, client))
                    self.__fill.clear()

        asyncio.create_task(filler(), name="BulkDataStorageQueueFiller")

    async def get(self) -> _BulkDataStorage:
        if self.qsize() < 20:
            self.__fill.set()
        bulk_data_storage: _BulkDataStorage = await super().get()
        return bulk_data_storage if not bulk_data_storage.expired else await self.get()


class _TargetMessageQueue(asyncio.Queue):
    def __init__(self, node: Node, edge: Edge) -> None:
        super().__init__()

        async def sender() -> None:
            cancelled = asyncio.Event()

            async def batcher() -> (
                AsyncGenerator[list[SendMessageBatchRequestEntryTypeDef], None]
            ):
                batch: list[SendMessageBatchRequestEntryTypeDef] = list()
                batch_length = 0
                id = 0
                while not (cancelled.is_set() and self.empty()):
                    try:
                        try:
                            message = await asyncio.wait_for(
                                self.get(), timeout=node.timeout
                            )
                        except asyncio.TimeoutError:
                            if batch:
                                yield batch
                                batch = list()
                                batch_length = 0
                                id = 0
                        else:
                            if batch_length + len(message) > 262144:
                                yield batch
                                batch = list()
                                batch_length = 0
                                id = 0
                            batch.append(
                                SendMessageBatchRequestEntryTypeDef(
                                    Id=str(id), **message._sqs_message(node)
                                )
                            )
                            if len(batch) == 10:
                                yield batch
                                batch = list()
                                batch_length = 0
                                id = 0
                            id += 1
                            batch_length += len(message)
                    except asyncio.CancelledError:
                        cancelled.set()
                        if batch:
                            yield batch
                            batch = list()
                            batch_length = 0
                            id = 0

            async for entries in batcher():
                try:
                    response = await asyncio.get_running_loop().run_in_executor(
                        None,
                        partial(
                            node._sqs_client.send_message_batch,
                            Entries=entries,
                            QueueUrl=edge.queue,
                        ),
                    )
                except asyncio.CancelledError:
                    cancelled.set()
                except Exception:
                    getLogger().exception(f"Error sending messages to {edge.name}")
                else:
                    # only inspect the response when the call succeeded;
                    # SQS returns Id as the string index of the failed entry
                    for failed in response.get("Failed", list()):
                        id = failed.pop("Id")
                        getLogger().error(
                            f"Unable to send message {entries[int(id)]} to {edge.name}, reason {failed}"
                        )
                finally:
                    for _ in range(len(entries)):
                        self.task_done()

        asyncio.create_task(sender(), name=f"TargetMessageSender({edge.name})")

    async def get(self) -> Message:
        return await super().get()


class Node(BaseNode):
    """
    Base class for all implemented asyncio Nodes.

    Nodes of this class must be instantiated outside of
    the asyncio event loop.
    """

    def __init__(
        self,
        *,
        appsync_endpoint: str = None,
        bulk_data_acceleration: bool = False,
        client_id: str = None,
        name: str = None,
        password: str = None,
        tenant: str = None,
        timeout: float = None,
        user_pool_id: str = None,
        username: str = None,
    ) -> None:
        super().__init__(
            appsync_endpoint=appsync_endpoint,
            bulk_data_acceleration=bulk_data_acceleration,
            client_id=client_id,
            name=name,
            password=password,
            tenant=tenant,
            timeout=timeout,
            user_pool_id=user_pool_id,
            username=username,
        )
        self.__audit_records_queues: dict[str, _AuditRecordQueue] = dict()
        self.__bulk_data_storage_queue: _BulkDataStorageQueue = None
        self.__gql_client = GqlClient(
            fetch_schema_from_transport=True,
            transport=AIOHTTPTransport(
                auth=AppSyncCognitoAuthentication(self.__cognito),
                url=appsync_endpoint or environ["APPSYNC_ENDPOINT"],
            ),
        )
        self.__lock: asyncio.Lock = None
        self.__target_message_queues: dict[str, _TargetMessageQueue] = dict()

    @property
    def _gql_client(self) -> GqlClient:
        return self.__gql_client

    @property
    def _lock(self) -> asyncio.Lock:
        return self.__lock

    def audit_message(
        self,
        /,
        message: Message,
        *,
        extra_attributes: dict[str, Any] = None,
        source: str = None,
    ) -> None:
        """
        Audits the provided message. If extra_attributes is
        supplied, they will be added to the message's audit
        dict. If source is provided, it will be recorded in
        the audit.
        """
        if self.stopped:
            raise RuntimeError(f"{self.name} is stopped")
        if not self.audit:
            return
        extra_attributes = extra_attributes or dict()
        message_type = message.message_type
        record = dict(
            datetime=datetime.now(timezone.utc).isoformat(),
            previousTrackingIds=message.previous_tracking_ids,
            sourceNode=source,
            trackingId=message.tracking_id,
        )
        if attributes := (
            message_type.auditor(message=message.body) | extra_attributes
        ):
            record["attributes"] = attributes
        try:
            self.__audit_records_queues[message_type.name].put_nowait(record)
        except KeyError:
            raise ValueError(f"Unrecognized message type {message_type.name}")

    def audit_messages(
        self,
        /,
        messages: list[Message],
        *,
        extra_attributes: list[dict[str, Any]] = None,
        source: str = None,
    ) -> None:
        """
        Audits the provided messages. If extra_attributes is
        supplied they will be added to the respective message's audit
        dict and they must have the same count as messages.
        If source is provided, it will be recorded in the audit.
        """
        if extra_attributes and len(extra_attributes) != len(messages):
            raise ValueError(
                "messages and extra_attributes must have the same number of items"
            )
        # default to empty attribute dicts so zip() pairs every message
        # when extra_attributes is not supplied
        extra_attributes = extra_attributes or [dict()] * len(messages)
        for message, attributes in zip(messages, extra_attributes):
            self.audit_message(message, extra_attributes=attributes, source=source)

    async def handle_bulk_data(self, data: Union[bytearray, bytes]) -> str:
        """
        Posts data as bulk data and returns a GET URL for data retrieval.
        Normally this returned URL will be used as a "ticket" in messages
        that require bulk data.
        """
        return await (await self.__bulk_data_storage_queue.get()).handle_bulk_data(data)

    async def handle_received_message(self, *, message: Message, source: str) -> None:
        """
        Callback called when a message is received. Subclasses that receive messages
        should override this method.
        """
        pass

    async def join(self) -> None:
        """
        Joins the calling thread with this Node. Will block until all
        join conditions are satisfied.
        """
        await asyncio.gather(
            *[
                target_message_queue.join()
                for target_message_queue in self.__target_message_queues.values()
            ],
            *[
                audit_records_queue.join()
                for audit_records_queue in self.__audit_records_queues.values()
            ],
        )

    def send_message(self, /, message: Message, *, targets: set[Edge] = None) -> None:
        """
        Send the message to the specified targets. If no targets are specified
        the message will be sent to all targets.
        """
        self.send_messages([message], targets=targets)

    def send_messages(
        self, /, messages: list[Message], *, targets: set[Edge] = None
    ) -> None:
        """
        Send the messages to the specified targets. If no targets are specified
        the messages will be sent to all targets.
        """
        if self.stopped:
            raise RuntimeError(f"{self.name} is stopped")
        if messages:
            for target in targets or self.targets:
                if target_message_queue := self.__target_message_queues.get(
                    target.name
                ):
                    for message in messages:
                        target_message_queue.put_nowait(message)
                else:
                    getLogger().warning(f"Target {target.name} does not exist")

    async def start(self) -> None:
        """
        Starts this Node. Must be called prior to any other usage.
        """
        getLogger().info(f"Starting Node {self.name}")
        self.__lock = asyncio.Lock()
        self.__bulk_data_storage_queue = _BulkDataStorageQueue(self)
        async with self.__lock:
            async with self._gql_client as session:
                data: dict[str, Union[str, dict]] = (
                    await session.execute(
                        _GET_NODE_GQL,
                        variable_values=dict(name=self.name, tenant=self.tenant),
                    )
                )["GetNode"]
        self._audit = data["tenant"].get("audit") or False
        self.config = (
            json.loads(data["tenant"].get("config") or "{}")
            | json.loads((data.get("app") or dict()).get("config") or "{}")
            | json.loads(data.get("config") or "{}")
        )
        self._stopped = data.get("stopped")
        if receive_message_type := data.get("receiveMessageType"):
            self._receive_message_type = MessageType(
                auditor=dynamic_function_loader.load(receive_message_type["auditor"]),
                name=receive_message_type["name"],
            )
            if not self.stopped and self.audit:
                self.__audit_records_queues[receive_message_type["name"]] = (
                    _AuditRecordQueue(self.receive_message_type, self)
                )
        if send_message_type := data.get("sendMessageType"):
            self._send_message_type = MessageType(
                auditor=dynamic_function_loader.load(send_message_type["auditor"]),
                name=send_message_type["name"],
            )
            if not self.stopped and self.audit:
                self.__audit_records_queues[send_message_type["name"]] = (
                    _AuditRecordQueue(self.send_message_type, self)
                )
        if self.node_type == "AppChangeReceiverNode":
            if edge := data.get("receiveEdge"):
                self._sources = {Edge(name=edge["source"]["name"], queue=edge["queue"])}
            else:
                self._sources = set()
        else:
            self._sources = {
                Edge(name=edge["source"]["name"], queue=edge["queue"])
                for edge in (data.get("receiveEdges") or list())
            }
        self._targets = {
            Edge(name=edge["target"]["name"], queue=edge["queue"])
            for edge in (data.get("sendEdges") or list())
        }
        if not self.stopped:
            self.__target_message_queues = {
                edge.name: _TargetMessageQueue(self, edge) for edge in self._targets
            }


class _DeleteMessageQueue(asyncio.Queue):
    def __init__(self, edge: Edge, node: AppNode) -> None:
        super().__init__()

        async def deleter() -> None:
            cancelled = asyncio.Event()

            async def batcher() -> AsyncGenerator[list[str], None]:
                batch: list[str] = list()
                while not (cancelled.is_set() and self.empty()):
                    try:
                        try:
                            batch.append(
                                await asyncio.wait_for(self.get(), timeout=node.timeout)
                            )
                        except asyncio.TimeoutError:
                            if batch:
                                yield batch
                                batch = list()
                        else:
                            if len(batch) == 10:
                                yield batch
                                batch = list()
                    except asyncio.CancelledError:
                        cancelled.set()
                        if batch:
                            yield batch
                            batch = list()

            async for receipt_handles in batcher():
                try:
                    response = await asyncio.get_running_loop().run_in_executor(
                        None,
                        partial(
                            node._sqs_client.delete_message_batch,
                            Entries=[
                                DeleteMessageBatchRequestEntryTypeDef(
                                    Id=str(id), ReceiptHandle=receipt_handle
                                )
                                for id, receipt_handle in enumerate(receipt_handles)
                            ],
                            QueueUrl=edge.queue,
                        ),
                    )
                except asyncio.CancelledError:
                    cancelled.set()
                except Exception:
                    getLogger().exception(f"Error deleting messages from {edge.name}")
                else:
                    # only inspect the response when the call succeeded;
                    # SQS returns Id as the string index of the failed entry
                    for failed in response.get("Failed", list()):
                        id = failed.pop("Id")
                        getLogger().error(
                            f"Unable to delete message {receipt_handles[int(id)]} from {edge.name}, reason {failed}"
                        )
                finally:
                    for _ in range(len(receipt_handles)):
                        self.task_done()

        asyncio.create_task(deleter(), name=f"SourceMessageDeleter({edge.name})")

    async def get(self) -> str:
        return await super().get()


class _SourceMessageReceiver:
    def __init__(self, edge: Edge, node: AppNode) -> None:
        self.__delete_message_queue = _DeleteMessageQueue(edge, node)

        async def handle_received_message(
            message: Message, receipt_handle: str
        ) -> bool:
            try:
                await node.handle_received_message(message=message, source=edge.name)
            except asyncio.CancelledError:
                raise
            except Exception:
                getLogger().exception(
                    f"Error handling received message for {edge.name}"
                )
                return False
            else:
                self.__delete_message_queue.put_nowait(receipt_handle)
            return True

        async def receive() -> None:
            getLogger().info(f"Receiving messages from {edge.name}")
            while True:
                try:
                    response = await asyncio.get_running_loop().run_in_executor(
                        None,
                        partial(
                            node._sqs_client.receive_message,
                            AttributeNames=["All"],
                            MaxNumberOfMessages=10,
                            MessageAttributeNames=["All"],
                            QueueUrl=edge.queue,
                            WaitTimeSeconds=20,
                        ),
                    )
                except asyncio.CancelledError:
                    raise
                except catch_aws_error("AWS.SimpleQueueService.NonExistentQueue"):
                    getLogger().warning(f"Queue {edge.queue} does not exist, exiting")
                    break
                except Exception:
                    getLogger().exception(
                        f"Error receiving messages from {edge.name}, retrying"
                    )
                    await asyncio.sleep(20)
                else:
                    if not (sqs_messages := response.get("Messages")):
                        continue
                    getLogger().info(f"Received {len(sqs_messages)} from {edge.name}")

                    message_handlers = [
                        handle_received_message(
                            Message(
                                body=sqs_message["Body"],
                                group_id=sqs_message["Attributes"]["MessageGroupId"],
                                message_type=node.receive_message_type,
                                tracking_id=sqs_message["MessageAttributes"]
                                .get("trackingId", {})
                                .get("StringValue"),
                                previous_tracking_ids=sqs_message["MessageAttributes"]
                                .get("prevTrackingIds", {})
                                .get("StringValue"),
                            ),
                            sqs_message["ReceiptHandle"],
                        )
                        for sqs_message in sqs_messages
                    ]

                    async def handle_received_messages() -> None:
                        if node._concurrent_processing:
                            await asyncio.gather(*message_handlers)
                        else:
                            for message_handler in message_handlers:
                                if not await message_handler:
                                    break

                    asyncio.create_task(
                        handle_received_messages(), name="handle_received_messages"
                    )

            getLogger().info(f"Stopping receiving messages from {edge.name}")

        self.__task = asyncio.create_task(
            receive(), name=f"SourceMessageReceiver({edge.name})"
        )

    async def join(self) -> None:
        await asyncio.wait([self.__task])
        await self.__delete_message_queue.join()


class AppNode(Node):
    def __init__(
        self,
        *,
        appsync_endpoint: str = None,
        bulk_data_acceleration: bool = False,
        client_id: str = None,
        concurrent_processing: bool = False,
        name: str = None,
        password: str = None,
        tenant: str = None,
        timeout: float = None,
        user_pool_id: str = None,
        username: str = None,
    ) -> None:
        super().__init__(
            appsync_endpoint=appsync_endpoint,
            bulk_data_acceleration=bulk_data_acceleration,
            client_id=client_id,
            name=name,
            password=password,
            tenant=tenant,
            timeout=timeout,
            user_pool_id=user_pool_id,
            username=username,
        )
        self.__concurrent_processing = concurrent_processing
        self.__source_message_receivers: list[_SourceMessageReceiver] = list()

    @property
    def _concurrent_processing(self) -> bool:
        return self.__concurrent_processing

    async def join(self) -> None:
        """
        Joins the calling thread with this Node. Will block until all
        join conditions are satisfied.
        """
        await asyncio.gather(
            *[
                source_message_receiver.join()
                for source_message_receiver in self.__source_message_receivers
            ]
        )
        await super().join()

    async def start(self) -> None:
        """
        Starts this Node. Must be called prior to any other usage.
        """
        await super().start()
        if not self.stopped:
            self.__source_message_receivers = [
                _SourceMessageReceiver(edge, self) for edge in self._sources
            ]

    async def start_and_run_forever(self) -> None:
        """Will start this Node and run until the containing Task is cancelled"""
        await self.start()
        await self.join()


class LambdaNode(Node):
    def __init__(
        self,
        *,
        appsync_endpoint: str = None,
        bulk_data_acceleration: bool = False,
        client_id: str = None,
        concurrent_processing: bool = False,
        name: str = None,
        password: str = None,
        report_batch_item_failures: bool = False,
        tenant: str = None,
        timeout: float = None,
        user_pool_id: str = None,
        username: str = None,
    ) -> None:
        super().__init__(
            appsync_endpoint=appsync_endpoint,
            bulk_data_acceleration=bulk_data_acceleration,
            client_id=client_id,
            name=name,
            password=password,
            tenant=tenant,
            timeout=timeout or 0.01,
            user_pool_id=user_pool_id,
            username=username,
        )
        self.__concurrent_processing = concurrent_processing
        self.__loop = self._create_event_loop()
        self.__queue_name_to_source: dict[str, str] = None
        self.__report_batch_item_failures = report_batch_item_failures

        # Set up the asyncio loop
        signal(SIGTERM, self._shutdown_handler)

        self.__started = threading.Event()
        self.__loop.create_task(self.start())

        # Run the event loop in a separate thread, or else we will block the main
        # Lambda execution!
        threading.Thread(name="event_loop", target=self.__run_event_loop).start()

        # Wait until the started event is set before returning control to
        # Lambda
        self.__started.wait()

    def __run_event_loop(self) -> None:
        getLogger().info("Starting event loop")
        asyncio.set_event_loop(self.__loop)

        pending_exception_to_raise: Exception = None

        def exception_handler(loop: asyncio.AbstractEventLoop, context: dict) -> None:
            nonlocal pending_exception_to_raise
            pending_exception_to_raise = context.get("exception")
            getLogger().error(
                "Unhandled exception; stopping loop: %r",
                context.get("message"),
                exc_info=pending_exception_to_raise,
            )
            loop.stop()

        self.__loop.set_exception_handler(exception_handler)
        executor = ThreadPoolExecutor()
        self.__loop.set_default_executor(executor)

        self.__loop.run_forever()
        getLogger().info("Entering shutdown phase")
        getLogger().info("Cancelling pending tasks")
        if tasks := asyncio.all_tasks(self.__loop):
            for task in tasks:
                getLogger().debug(f"Cancelling task: {task}")
                task.cancel()
            getLogger().info("Running pending tasks till complete")
            self.__loop.run_until_complete(
                asyncio.gather(*tasks, return_exceptions=True)
            )
        getLogger().info("Waiting for executor shutdown")
        executor.shutdown(wait=True)
        getLogger().info("Shutting down async generators")
        self.__loop.run_until_complete(self.__loop.shutdown_asyncgens())
        getLogger().info("Closing the loop.")
        self.__loop.close()
        getLogger().info("Loop is closed")
        if pending_exception_to_raise:
            getLogger().info("Reraising unhandled exception")
            raise pending_exception_to_raise

    def _create_event_loop(self) -> asyncio.AbstractEventLoop:
        return asyncio.new_event_loop()

    def _get_source(self, queue_arn: str) -> str:
        return self.__queue_name_to_source[queue_arn.split(":")[-1:][0]]

    async def _handle_event(self, event: LambdaEvent) -> BatchItemFailures:
        """
        Handles the AWS Lambda event passed into the containing
        AWS Lambda function during invocation.

        This is intended to be the only called method in your
        containing AWS Lambda function.
        """
        records: LambdaSqsRecords = None
        if not (records := event.get("Records")):
            getLogger().warning(f"No Records found in event {event}")
            return

        source = self._get_source(records[0]["eventSourceARN"])
        getLogger().info(f"Received {len(records)} messages from {source}")
        batch_item_failures: list[str] = (
            [record["messageId"] for record in records]
            if self.__report_batch_item_failures
            else None
        )

        async def handle_received_message(message: Message, message_id: str) -> None:
            try:
                await self.handle_received_message(message=message, source=source)
            except asyncio.CancelledError:
                raise
            except Exception:
                if not self.__report_batch_item_failures:
                    raise
                getLogger().exception(f"Error handling received message for {source}")
            else:
                if self.__report_batch_item_failures:
                    batch_item_failures.remove(message_id)

        message_handlers = [
            handle_received_message(
                Message(
                    body=record["body"],
                    group_id=record["attributes"]["MessageGroupId"],
                    message_type=self.receive_message_type,
                    previous_tracking_ids=record["messageAttributes"]
                    .get("prevTrackingIds", {})
                    .get("stringValue"),
                    tracking_id=record["messageAttributes"]
                    .get("trackingId", {})
                    .get("stringValue"),
                ),
                record["messageId"],
            )
            for record in records
        ]

        if self.__concurrent_processing:
            await asyncio.gather(*message_handlers)
        else:
            for message_handler in message_handlers:
                await message_handler
        await self.join()
        if self.__report_batch_item_failures and batch_item_failures:
            return dict(
                batchItemFailures=[
                    dict(itemIdentifier=message_id)
                    for message_id in batch_item_failures
                ]
            )

    def _shutdown_handler(self, signum: int, frame: object) -> None:
        signal(SIGTERM, SIG_IGN)
        getLogger().warning("Received SIGTERM, stopping the loop")
        self.__loop.stop()

    def handle_event(self, event: LambdaEvent) -> BatchItemFailures:
        return asyncio.run_coroutine_threadsafe(
            self._handle_event(event), self.__loop
        ).result()

    async def start(self) -> None:
        await super().start()
        self.__queue_name_to_source = {
            edge.queue.split("/")[-1:][0]: edge.name for edge in self._sources
        }
        self.__started.set()
```
class Node(BaseNode):
Base class for all implemented asyncio Nodes.
Nodes of this class must be instantiated outside of the asyncio event loop.
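For illustration, a minimal lifecycle sketch. It assumes the omitted constructor arguments come from the documented environment variables (e.g. APPSYNC_ENDPOINT); MyNode and main are hypothetical names:

```python
import asyncio

from echostream_node.asyncio import Node


class MyNode(Node):  # hypothetical subclass
    pass


# Instantiate OUTSIDE the event loop, per the class docstring...
node = MyNode()


async def main() -> None:
    await node.start()  # ...then start and join it inside the loop
    await node.join()


asyncio.run(main())
```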
Node(*, appsync_endpoint: str = None, bulk_data_acceleration: bool = False, client_id: str = None, name: str = None, password: str = None, tenant: str = None, timeout: float = None, user_pool_id: str = None, username: str = None)
Node.audit_message(self, /, message: Message, *, extra_attributes: dict[str, Any] = None, source: str = None) -> None
Audits the provided message. If extra_attributes is supplied, they will be added to the message's audit dict. If source is provided, it will be recorded in the audit.
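A usage sketch with hypothetical values; auditing must be enabled for the tenant, otherwise the call is a no-op:

```python
# Queues an audit record for the message; raises RuntimeError if the node
# is stopped and ValueError if the message type is unrecognized.
node.audit_message(
    message,
    extra_attributes={"orderId": "1234"},  # merged into the audit attributes
    source="UpstreamNode",                 # recorded as sourceNode in the audit record
)
```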
Node.audit_messages(self, /, messages: list[Message], *, extra_attributes: list[dict[str, Any]] = None, source: str = None) -> None
Audits the provided messages. If extra_attributes is supplied, they will be added to the respective message's audit dict and must have the same count as messages. If source is provided, it will be recorded in the audit.
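The list form, again with hypothetical values; note that extra_attributes, when given, must be exactly one dict per message:

```python
node.audit_messages(
    [message_one, message_two],
    extra_attributes=[{"orderId": "1"}, {"orderId": "2"}],  # one dict per message
    source="UpstreamNode",
)
```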
async Node.handle_bulk_data(self, data: Union[bytearray, bytes]) -> str
Posts data as bulk data and returns a GET URL for data retrieval. Normally this returned URL will be used as a "ticket" in messages that require bulk data.
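A sketch of the ticket pattern. It assumes Message accepts the keyword fields used elsewhere in this module's source; the payload value is hypothetical:

```python
# Store a payload too large for an SQS message, then send its URL instead.
payload = b"\x00" * 10_000_000
ticket_url = await node.handle_bulk_data(payload)
# the receiving node GETs the URL to retrieve the actual data
message = Message(body=ticket_url, message_type=node.send_message_type)
node.send_message(message)
```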
async Node.handle_received_message(self, *, message: Message, source: str) -> None
Callback called when a message is received. Subclasses that receive messages should override this method.
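A minimal override sketch (PrintingNode is a hypothetical name). In AppNode's receive loop, a raised exception prevents the message's receipt handle from being deleted, so SQS will redeliver the message:

```python
class PrintingNode(AppNode):
    async def handle_received_message(self, *, message: Message, source: str) -> None:
        # returning normally causes the underlying SQS message to be deleted
        print(f"{source} -> {message.body}")
```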
async Node.join(self) -> None
Joins the calling thread with this Node. Will block until all join conditions are satisfied.
Node.send_message(self, /, message: Message, *, targets: set[Edge] = None) -> None
Send the message to the specified targets. If no targets are specified the message will be sent to all targets.
Node.send_messages(self, /, messages: list[Message], *, targets: set[Edge] = None) -> None
Send the messages to the specified targets. If no targets are specified the messages will be sent to all targets.
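A sketch with hypothetical message bodies and target name; targets defaults to all of the node's target edges:

```python
messages = [
    Message(body=body, message_type=node.send_message_type)
    for body in ("first", "second")
]
node.send_messages(messages)  # fan out to every target edge

# Or restrict delivery to a single, named edge:
node.send_messages(
    messages,
    targets={edge for edge in node.targets if edge.name == "MyTargetNode"},
)
```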
async Node.start(self) -> None
Starts this Node. Must be called prior to any other usage.
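start() performs the GetNode GraphQL query and wires up the audit and target queues; until it completes, the node's config and edges are not populated. A typical ordering, as a sketch:

```python
async def run(node: Node) -> None:
    await node.start()  # fetch config, message types, and edges
    ...                 # node is now usable: send, audit, receive
    await node.join()   # drain queued sends/audits before exiting
```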
class AppNode(Node):
AppNode(*, appsync_endpoint: str = None, bulk_data_acceleration: bool = False, client_id: str = None, concurrent_processing: bool = False, name: str = None, password: str = None, tenant: str = None, timeout: float = None, user_pool_id: str = None, username: str = None)
async AppNode.join(self) -> None
Joins the calling thread with this Node. Will block until all join conditions are satisfied.
async AppNode.start(self) -> None
Starts this Node. Must be called prior to any other usage.
async AppNode.start_and_run_forever(self) -> None
Will start this Node and run until the containing Task is cancelled
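This is the usual entry point for a long-running app node; a sketch with a hypothetical subclass name:

```python
if __name__ == "__main__":
    node = MyAppNode()  # instantiated outside the event loop
    try:
        asyncio.run(node.start_and_run_forever())
    except KeyboardInterrupt:
        pass
```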
class LambdaNode(Node):
LambdaNode(*, appsync_endpoint: str = None, bulk_data_acceleration: bool = False, client_id: str = None, concurrent_processing: bool = False, name: str = None, password: str = None, report_batch_item_failures: bool = False, tenant: str = None, timeout: float = None, user_pool_id: str = None, username: str = None)
async LambdaNode.start(self) -> None
Starts this Node. Must be called prior to any other usage.
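Putting it together in the containing AWS Lambda function, with hypothetical names: the node is constructed once at module scope (its constructor starts the private event-loop thread), and handle_event is the only call the Lambda handler makes:

```python
from echostream_node.asyncio import LambdaNode


class MyLambdaNode(LambdaNode):  # hypothetical subclass
    async def handle_received_message(self, *, message, source) -> None:
        print(f"{source}: {message.body}")


node = MyLambdaNode(report_batch_item_failures=True)


def lambda_handler(event, context):
    # Returns a batchItemFailures dict when partial-batch failure
    # reporting is enabled and some messages failed.
    return node.handle_event(event)
```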