# email-amazon/unified-worker/email-worker/unified_worker.py
# (202 lines · 6.9 KiB · Python)
#!/usr/bin/env python3
"""
Unified Worker - coordinates all domain pollers
"""
import sys
import threading
import time
from typing import Dict, List, Optional

from aws import DynamoDBHandler, S3Handler, SESHandler, SQSHandler
from config import config, load_domains
from domain_poller import DomainPoller
from logger import log
from metrics.prometheus import MetricsCollector
from smtp import SMTPPool
from smtp.delivery import EmailDelivery
from worker import MessageProcessor
class UnifiedWorker:
    """Main worker coordinating all domain pollers.

    Owns the shared AWS handlers, the SMTP connection pool, and one
    daemon polling thread per configured domain.

    Lifecycle: ``setup()`` -> optional ``set_metrics()`` ->
    ``print_startup_banner()`` -> ``start()`` (blocks) -> ``stop()``.
    """

    def __init__(self):
        # Signals every poller thread (and the start() loop) to shut down.
        self.stop_event = threading.Event()
        self.domains: List[str] = []
        # domain -> SQS queue URL (filled in by setup()).
        self.queue_urls: Dict[str, str] = {}
        self.poller_threads: List[threading.Thread] = []
        # Shared stats across all pollers: domain -> processed count.
        # Guarded by stats_lock (pollers write, _log_status_table reads).
        self.domain_stats: Dict[str, int] = {}
        self.stats_lock = threading.Lock()
        # AWS handlers, shared by every poller.
        self.s3 = S3Handler()
        self.sqs = SQSHandler()
        self.ses = SESHandler()
        self.dynamodb = DynamoDBHandler()
        # SMTP pool (connections opened later by setup()).
        self.smtp_pool = SMTPPool(config.smtp_host, config.smtp_port, config.smtp_pool_size)
        # Email delivery
        self.delivery = EmailDelivery(self.smtp_pool)
        # Metrics collector; None until set_metrics() is called.
        self.metrics: Optional["MetricsCollector"] = None
        # Message processor shared by all pollers.
        self.processor = MessageProcessor(
            self.s3,
            self.sqs,
            self.ses,
            self.dynamodb,
            self.delivery,
            None  # Metrics will be set later via set_metrics()
        )

    def setup(self):
        """Load domains, resolve their SQS queues, warm the SMTP pool.

        Exits the process (status 1) when no domains or no queues are
        available — the worker cannot do anything useful without them.
        """
        self.domains = load_domains()
        if not self.domains:
            log("❌ No domains configured!", 'ERROR')
            sys.exit(1)
        # Resolve the SQS queue URL for each domain; warn and skip
        # domains whose queue does not exist.
        for domain in self.domains:
            url = self.sqs.get_queue_url(domain)
            if url:
                self.queue_urls[domain] = url
                log(f"{domain} -> queue found")
            else:
                log(f"{domain} -> Queue not found!", 'WARNING')
        if not self.queue_urls:
            log("❌ No valid queues found!", 'ERROR')
            sys.exit(1)
        # Initialize SMTP pool
        self.smtp_pool.initialize()
        log(f"Initialized with {len(self.queue_urls)} domains")

    def start(self):
        """Start one poller thread per domain and block until stopped.

        The main thread wakes every 10 s to check the stop event and
        logs a status summary every 5 minutes.  Returns when
        ``stop_event`` is set or on KeyboardInterrupt.
        """
        # Initialize stats for all domains up front so the status table
        # can report zero-activity domains without KeyErrors.
        for domain in self.queue_urls.keys():
            self.domain_stats[domain] = 0
        # Create a poller (and daemon thread) for each domain.
        for domain, queue_url in self.queue_urls.items():
            poller = DomainPoller(
                domain=domain,
                queue_url=queue_url,
                message_processor=self.processor,
                sqs=self.sqs,
                metrics=self.metrics,
                stop_event=self.stop_event,
                stats_dict=self.domain_stats,
                stats_lock=self.stats_lock
            )
            thread = threading.Thread(
                target=poller.poll,
                name=f"poller-{domain}",
                daemon=True  # never block interpreter exit
            )
            thread.start()
            self.poller_threads.append(thread)
        log(f"Started {len(self.poller_threads)} domain pollers")
        # Periodic status log (every 5 minutes).  monotonic() is immune
        # to wall-clock adjustments (NTP steps, manual changes), unlike
        # time.time().
        last_status_log = time.monotonic()
        status_interval = 300  # 5 minutes
        try:
            while not self.stop_event.is_set():
                self.stop_event.wait(timeout=10)
                # Log status summary every 5 minutes
                if time.monotonic() - last_status_log > status_interval:
                    self._log_status_table()
                    last_status_log = time.monotonic()
        except KeyboardInterrupt:
            pass

    def _log_status_table(self):
        """Log a compact one-line status summary of all pollers."""
        active_threads = sum(1 for t in self.poller_threads if t.is_alive())
        # Snapshot the shared stats under the lock; do the formatting
        # outside it so pollers are blocked as briefly as possible.
        with self.stats_lock:
            stats_snapshot = dict(self.domain_stats)
        total_processed = sum(stats_snapshot.values())
        # Build compact stats: only show domains with activity.
        stats_parts = []
        for domain in sorted(self.queue_urls.keys()):
            count = stats_snapshot.get(domain, 0)
            if count > 0:
                # Shorten domain for display
                short_domain = domain.split('.')[0][:12]
                stats_parts.append(f"{short_domain}:{count}")
        if stats_parts:
            stats_line = " | ".join(stats_parts)
        else:
            stats_line = "no activity"
        log(
            f"📊 Status: {active_threads}/{len(self.poller_threads)} active, "
            f"total:{total_processed} | {stats_line}"
        )

    def stop(self):
        """Stop gracefully: signal pollers, join them, close SMTP pool."""
        log("⚠ Stopping worker...")
        self.stop_event.set()
        # Wait for poller threads (max 10 seconds each)
        for thread in self.poller_threads:
            thread.join(timeout=10)
            if thread.is_alive():
                log(f"Warning: {thread.name} did not stop gracefully", 'WARNING')
        self.smtp_pool.close_all()
        log("👋 Worker stopped")

    def set_metrics(self, metrics: "MetricsCollector"):
        """Attach a metrics collector.

        Must be called before start() for pollers to receive it; also
        wires the collector into the shared message processor.
        """
        self.metrics = metrics
        self.processor.metrics = metrics

    def print_startup_banner(self):
        """Print a human-readable startup summary to the log."""
        log(f"\n{'='*70}")
        log(f"🚀 UNIFIED EMAIL WORKER")
        log(f"{'='*70}")
        log(f" Domains: {len(self.queue_urls)}")
        log(f" DynamoDB: {'Connected' if self.dynamodb.available else 'Not Available'}")
        if config.lmtp_enabled:
            log(f" Delivery: LMTP -> {config.lmtp_host}:{config.lmtp_port} (bypasses transport_maps)")
        else:
            log(f" Delivery: SMTP -> {config.smtp_host}:{config.smtp_port}")
        log(f" Poll Interval: {config.poll_interval}s")
        log(f" Visibility: {config.visibility_timeout}s")
        log(f"")
        log(f" Features:")
        log(f" ✓ Bounce Detection & Header Rewriting")
        # BUG FIX: these ternaries previously produced '' in BOTH
        # branches, so enabled/disabled status was invisible in the log.
        log(f" {'✓' if self.dynamodb.available else '✗'} Auto-Reply / Out-of-Office")
        log(f" {'✓' if self.dynamodb.available else '✗'} Email Forwarding")
        log(f" {'✓' if self.dynamodb.available else '✗'} Blocked Senders (Wildcard)")
        log(f" {'✓' if self.metrics else '✗'} Prometheus Metrics")
        log(f" {'✓' if config.lmtp_enabled else '✗'} LMTP Direct Delivery")
        log(f"")
        log(f" Active Domains:")
        for domain in sorted(self.queue_urls.keys()):
            log(f"{domain}")
        log(f"{'='*70}\n")