# Sources whose articles should be categorised as "tech" when they have no
# specific category or are filed under the generic "allgemein" bucket.
# Kept immutable: code that builds SQL from this collection takes its
# placeholder count from len() and its bound values from sorted(), so the
# contents must not change at runtime.
TECH_SOURCES = frozenset({"Golem", "Heise", "Computerbase"})


def _row_to_dict(row: "asyncpg.Record") -> Dict[str, Any]:
    """Convert an asyncpg Record to a plain dictionary with JSON-safe values.

    Two adjustments are applied to the copied mapping:

    * a non-``None`` ``published_at`` value is replaced by the result of its
      ``isoformat()`` method;
    * ``category`` is rewritten to ``"tech"`` when ``source`` is one of
      ``TECH_SOURCES`` and the stored category is ``None`` or ``"allgemein"``.

    Args:
        row: Anything ``dict()`` accepts; in production an ``asyncpg.Record``.
            The annotation is quoted so it is not evaluated when the function
            is defined.

    Returns:
        A new dictionary; ``row`` itself is left untouched.
    """
    d: Dict[str, Any] = dict(row)

    # A single lookup covers both "column absent" and "column is NULL".
    published_at = d.get("published_at")
    if published_at is not None:
        d["published_at"] = published_at.isoformat()

    # Override category for known tech sources; a specific category that is
    # already set (anything other than None / "allgemein") is kept as-is.
    if d.get("source") in TECH_SOURCES and d.get("category") in (None, "allgemein"):
        d["category"] = "tech"

    return d
+87,17 @@ async def get_news_count( ) if category is not None: - query += f" AND category = ${param_idx}" - params.append(category) + if category == "tech": + src_placeholders = ", ".join(f"${param_idx + i}" for i in range(len(TECH_SOURCES))) + query += f" AND source IN ({src_placeholders}) AND (category IS NULL OR category = 'allgemein')" + params.extend(sorted(TECH_SOURCES)) + elif category == "allgemein": + src_placeholders = ", ".join(f"${param_idx + i}" for i in range(len(TECH_SOURCES))) + query += f" AND category = 'allgemein' AND source NOT IN ({src_placeholders})" + params.extend(sorted(TECH_SOURCES)) + else: + query += f" AND category = ${param_idx}" + params.append(category) async with pool.acquire() as conn: row = await conn.fetchrow(query, *params) @@ -90,4 +120,9 @@ async def get_categories(max_age_hours: int = 48) -> List[str]: async with pool.acquire() as conn: rows = await conn.fetch(query) - return [row["category"] for row in rows] + cats = [row["category"] for row in rows] + # Inject "tech" if any tech source has articles + if "allgemein" in cats and "tech" not in cats: + cats.append("tech") + cats.sort() + return cats