diff --git a/journal/journal.py b/journal/journal.py index fdb3a3e..2a78fe3 100644 --- a/journal/journal.py +++ b/journal/journal.py @@ -1,5 +1,5 @@ +import re from datetime import datetime, timezone -from typing import Optional from maubot import MessageEvent, Plugin from maubot.handlers import command @@ -25,6 +25,22 @@ async def upgrade_v1(conn: Connection) -> None: await conn.execute("CREATE INDEX entries_ts ON entries (ts DESC)") +# Match `!journal` followed by any whitespace (space, tab, OR newline) +# and capture everything after. Maubot's @command.new parser only treats +# *space* as the command/args delimiter, so `!journal\n` gets +# parsed as a command name of "journal\n" and matches nothing, +# silently dropping multi-line entries. A passive regex matcher with +# DOTALL bypasses the parser quirk and catches every form. +_JOURNAL_RE = re.compile(r"^!journal(?:[ \t\r\n]+(.*))?$", re.DOTALL) + +_USAGE = ( + "Usage:\n" + "- `!journal ` — record an entry (multi-line OK)\n" + "- `!journal show [@user]` — last 10 entries, optionally filtered\n" + "- `!journal today` — all entries from today (UTC)" +) + + def _fmt(rows) -> str: if not rows: return "No entries." @@ -40,49 +56,55 @@ class JournalBot(Plugin): def get_db_upgrade_table(cls) -> UpgradeTable: return upgrade_table - @command.new( - "journal", - help="Farm journal — record what you did today", - require_subcommand=False, - arg_fallthrough=False, - ) - @command.argument("text", pass_raw=True, required=False) - async def journal(self, evt: MessageEvent, text: str = "") -> None: - if not text: - await evt.reply( - "Usage:\n" - "- `!journal ` — record an entry\n" - "- `!journal show [@user]` — last 10 entries (optionally filtered by user)\n" - "- `!journal today` — all entries from today" - ) + @command.passive(regex=_JOURNAL_RE) + async def journal(self, evt: MessageEvent, match) -> None: + rest = (match[1] or "").strip() + + if not rest: + await evt.reply(_USAGE) return + # Subcommand detection on the first whitespace-delimited token of + # the first line — only catches `show`/`today` if they appear + # alone on the first line (with optional arg). Anything else + # (including pasted multi-line lists) is recorded as-is. + first_line, _, _ = rest.partition("\n") + first_token, _, after = first_line.partition(" ") + + if first_token == "show": + user = after.strip() or None + if user: + rows = await self.database.fetch( + "SELECT user, ts, text FROM entries" + " WHERE user = $1 ORDER BY ts DESC LIMIT 10", + user, + ) + else: + rows = await self.database.fetch( + "SELECT user, ts, text FROM entries ORDER BY ts DESC LIMIT 10", + ) + await evt.reply(_fmt(rows)) + return + + if first_token == "today": + midnight = datetime.now(timezone.utc).replace( + hour=0, minute=0, second=0, microsecond=0 + ) + cutoff_ms = int(midnight.timestamp() * 1000) + rows = await self.database.fetch( + "SELECT user, ts, text FROM entries" + " WHERE ts >= $1 ORDER BY ts ASC", + cutoff_ms, + ) + await evt.reply(_fmt(rows)) + return + + # Default: record the full rest (multi-line preserved) await self.database.execute( "INSERT INTO entries (user, room, ts, text) VALUES ($1, $2, $3, $4)", - evt.sender, evt.room_id, evt.timestamp, text, + evt.sender, + evt.room_id, + evt.timestamp, + rest, ) await evt.reply(f"📓 Logged for {evt.sender}.") - - @journal.subcommand("show", help="Show recent entries, optionally filtered by user") - @command.argument("user", required=False) - async def show(self, evt: MessageEvent, user: Optional[str] = None) -> None: - if user: - rows = await self.database.fetch( - "SELECT user, ts, text FROM entries WHERE user = $1 ORDER BY ts DESC LIMIT 10", - user, - ) - else: - rows = await self.database.fetch( - "SELECT user, ts, text FROM entries ORDER BY ts DESC LIMIT 10", - ) - await evt.reply(_fmt(rows)) - - @journal.subcommand("today", help="All entries from today (UTC) across users") - async def today(self, evt: MessageEvent) -> None: - midnight = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0) - cutoff_ms = int(midnight.timestamp() * 1000) - rows = await self.database.fetch( - "SELECT user, ts, text FROM entries WHERE ts >= $1 ORDER BY ts ASC", - cutoff_ms, - ) - await evt.reply(_fmt(rows)) diff --git a/journal/maubot.yaml b/journal/maubot.yaml index 9352d83..74e2ae3 100644 --- a/journal/maubot.yaml +++ b/journal/maubot.yaml @@ -1,6 +1,6 @@ maubot: 0.1.0 id: dev.aiolabs.journal -version: 0.1.3 +version: 0.2.0 license: AGPL-3.0-or-later modules: - journal