223 lines
6.9 KiB
Python
223 lines
6.9 KiB
Python
|
|
"""MQTT client service for real-time entity updates.
|
||
|
|
|
||
|
|
Connects to an MQTT broker, subscribes to configured topics, and stores
the latest message per topic. The dashboard WebSocket pushes MQTT
state changes to connected clients.
|
||
|
|
"""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
import json
|
||
|
|
import logging
|
||
|
|
import time
|
||
|
|
from dataclasses import dataclass, field
|
||
|
|
from typing import Any, Callable, Dict, List, Optional
|
||
|
|
|
||
|
|
import aiomqtt
|
||
|
|
|
||
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class MqttMessage:
    """Single retained MQTT message.

    "Retained" here means kept in ``MqttService``'s in-memory store, where
    each topic maps to its most recent ``MqttMessage``.
    """

    # Topic string the message arrived on.
    topic: str
    # JSON-decoded value when the payload parsed as JSON, otherwise the
    # undecoded payload (text for byte payloads).
    payload: Any
    # ``time.time()`` at construction, i.e. seconds since the epoch taken
    # when the object was created — not a broker-supplied timestamp.
    timestamp: float = field(default_factory=time.time)
|
||
|
|
|
||
|
|
|
||
|
|
class MqttService:
    """Async MQTT client that maintains a live topic→value store.

    ``start()`` spawns a background task that connects to the broker,
    subscribes to the configured topic filters and records the most recent
    message per topic. Callbacks registered through ``on_message()`` are
    invoked for every incoming message. The task reconnects on its own
    with exponential backoff until ``stop()`` is called.
    """

    def __init__(self) -> None:
        # Latest message per topic; each new message overwrites the old one.
        self._store: Dict[str, MqttMessage] = {}
        # Live client; only set while a broker connection is open.
        self._client: Optional[aiomqtt.Client] = None
        # Background task running ``_run_loop``; None when not started.
        self._task: Optional[asyncio.Task] = None
        # Callbacks (sync or async) fired for every incoming message.
        self._listeners: List[Callable[[MqttMessage], Any]] = []
        self._connected = False
        # Connection parameters captured by ``start()``.
        self._config: Dict[str, Any] = {}

    @property
    def connected(self) -> bool:
        """Whether a broker connection is currently open."""
        return self._connected

    @property
    def store(self) -> Dict[str, MqttMessage]:
        """The live topic → latest-message mapping (not a copy)."""
        return self._store

    def on_message(self, callback: Callable[[MqttMessage], Any]) -> None:
        """Register a callback fired on every incoming MQTT message.

        The callback receives the ``MqttMessage``; if it returns a
        coroutine, that coroutine is awaited before the next listener runs.
        """
        self._listeners.append(callback)

    async def start(
        self,
        host: str,
        port: int = 1883,
        username: Optional[str] = None,
        password: Optional[str] = None,
        topics: Optional[List[str]] = None,
        client_id: str = "daily-briefing",
    ) -> None:
        """Connect to the broker and start listening in the background.

        Returns as soon as the background task is scheduled; the connection
        itself is established asynchronously. An empty ``host`` disables
        MQTT and makes this call a no-op.

        Args:
            host: Broker hostname; falsy means "MQTT disabled".
            port: Broker TCP port.
            username: Optional username for broker authentication.
            password: Optional password for broker authentication.
            topics: Topic filters to subscribe to; defaults to ``["#"]``.
            client_id: Client identifier presented to the broker.
        """
        if not host:
            logger.info("MQTT disabled — no MQTT_HOST configured")
            return

        # A repeated start() must not orphan the previous loop: two loops
        # would connect with the same client id and store duplicate state.
        await self._cancel_task()

        self._config = dict(
            host=host,
            port=port,
            username=username,
            password=password,
            # Copy so later mutation of the caller's list cannot change
            # what a reconnect subscribes to.
            topics=list(topics) if topics else ["#"],
            client_id=client_id,
        )
        self._task = asyncio.create_task(self._run_loop())
        logger.info("MQTT background task started (broker %s:%d)", host, port)

    async def stop(self) -> None:
        """Disconnect and cancel the background task."""
        await self._cancel_task()
        self._connected = False
        self._client = None
        logger.info("MQTT service stopped")

    async def publish(
        self,
        topic: str,
        payload: Any,
        retain: bool = False,
    ) -> None:
        """Publish a message to the broker (e.g. for controlling devices).

        ``str``, ``bytes`` and ``bytearray`` payloads are sent unchanged;
        anything else is serialized with ``json.dumps`` first.

        Raises:
            RuntimeError: If there is no open broker connection.
        """
        if not self._client or not self._connected:
            raise RuntimeError("MQTT not connected")
        if isinstance(payload, (str, bytes, bytearray)):
            # Already wire-ready; bytearray is not JSON-serializable.
            msg = payload
        else:
            msg = json.dumps(payload)
        await self._client.publish(topic, msg, retain=retain)
        logger.debug("MQTT published → %s", topic)

    def get_state(self) -> Dict[str, Any]:
        """Return all stored MQTT states for the dashboard API.

        Returns:
            Mapping of topic to ``{"value": ..., "timestamp": ...}``.
        """
        return {
            topic: {"value": msg.payload, "timestamp": msg.timestamp}
            for topic, msg in self._store.items()
        }

    def get_entities(self) -> List[Dict[str, Any]]:
        """Return a flat list of MQTT entities grouped for the frontend.

        Each entry carries the topic, its latest value and timestamp, plus
        a display name and a category derived from the topic string.
        """
        return [
            {
                "topic": topic,
                "value": msg.payload,
                "timestamp": msg.timestamp,
                "name": _topic_to_name(topic),
                "category": _topic_to_category(topic),
            }
            for topic, msg in self._store.items()
        ]

    # -- internal ---

    async def _cancel_task(self) -> None:
        """Cancel the background task, if any, and wait for it to finish."""
        task, self._task = self._task, None
        if task is None:
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            # Expected outcome of the cancel() above.
            pass

    async def _run_loop(self) -> None:
        """Reconnecting event loop.

        Connects, subscribes and dispatches messages until cancelled. After
        a failure it waits ``backoff`` seconds before retrying; the delay
        doubles per consecutive failure up to 60 s and resets to 1 s once a
        connection succeeds.
        """
        cfg = self._config
        backoff = 1

        while True:
            try:
                async with aiomqtt.Client(
                    hostname=cfg["host"],
                    port=cfg["port"],
                    username=cfg.get("username"),
                    password=cfg.get("password"),
                    identifier=cfg["client_id"],
                ) as client:
                    self._client = client
                    self._connected = True
                    backoff = 1
                    logger.info(
                        "MQTT connected to %s:%d — subscribing to %d topic(s)",
                        cfg["host"],
                        cfg["port"],
                        len(cfg["topics"]),
                    )

                    for t in cfg["topics"]:
                        await client.subscribe(t)

                    async for message in client.messages:
                        await self._handle_message(message)

            except aiomqtt.MqttError as exc:
                logger.warning(
                    "MQTT connection lost (%s) — reconnecting in %ds",
                    exc,
                    backoff,
                )
            except asyncio.CancelledError:
                # Propagate cancellation instead of swallowing it so the
                # task ends up in the cancelled state.
                raise
            except Exception:
                logger.exception("Unexpected MQTT error — reconnecting in %ds", backoff)
            finally:
                # Whatever ended the session, the client is unusable now.
                self._connected = False
                self._client = None

            await asyncio.sleep(backoff)
            backoff = min(backoff * 2, 60)

    async def _handle_message(self, message: aiomqtt.Message) -> None:
        """Decode one incoming message, store it and notify listeners."""
        topic = str(message.topic)
        raw = message.payload

        # Try JSON decode
        if isinstance(raw, (bytes, bytearray)):
            raw = raw.decode("utf-8", errors="replace")
        try:
            payload = json.loads(raw)
        except (ValueError, TypeError, RecursionError):
            # ValueError covers JSONDecodeError and other parse-time value
            # errors; TypeError covers non-text payloads; RecursionError
            # covers pathologically nested documents. None of these may
            # tear down the broker connection, so keep the undecoded value.
            payload = raw

        msg = MqttMessage(topic=topic, payload=payload)
        self._store[topic] = msg

        # Notify listeners. Iterate over a snapshot so a callback that
        # registers another listener does not affect this dispatch.
        for cb in tuple(self._listeners):
            try:
                result = cb(msg)
                if asyncio.iscoroutine(result):
                    await result
            except Exception:
                # A faulty listener must not break delivery to the others.
                logger.exception("MQTT listener error on topic %s", topic)
|
||
|
|
|
||
|
|
|
||
|
|
def _topic_to_name(topic: str) -> str:
    """Derive a human-readable name from the MQTT topic.

    Takes the last path segment (ignoring trailing slashes), replaces
    underscores and hyphens with spaces and title-cases the result. If
    that leaves nothing — e.g. the topic consists only of slashes — the
    topic itself is returned.
    """
    # str.split never returns an empty list, so check the derived name
    # itself rather than the list of parts.
    last_segment = topic.rstrip("/").rsplit("/", 1)[-1]
    name = last_segment.replace("_", " ").replace("-", " ").title()
    return name or topic
|
||
|
|
|
||
|
|
|
||
|
|
def _topic_to_category(topic: str) -> str:
    """Derive a category from the first topic segment.

    Leading and trailing slashes are ignored. A topic with only one
    segment has no category of its own and is filed under ``"other"``.
    """
    first_segment, separator, _ = topic.strip("/").partition("/")
    # A separator is present exactly when there are at least two segments.
    return first_segment if separator else "other"
|
||
|
|
|
||
|
|
|
||
|
|
# Singleton: one MqttService instance created when this module is imported.
# Construction only initializes empty state; no broker connection is made
# until start() is awaited.
mqtt_service = MqttService()
|