231 lines
6.8 KiB
Python
231 lines
6.8 KiB
Python
"""Discord adapter for interacting with FastAPI backend worker."""
|
|
|
|
import asyncio
|
|
from os import getenv
|
|
|
|
import aiohttp
|
|
import interactions
|
|
|
|
from ..env import load_project_dotenv
|
|
|
|
|
|
# Must run before the getenv() lookups below (presumably it loads the
# project's .env file into the process environment — confirm in ..env).
load_project_dotenv()

# Address of the backend worker; each part can be overridden via environment.
BACKEND_HOST = getenv("YOUDIS_BACKEND_HOST", "127.0.0.1")
BACKEND_PORT = int(getenv("YOUDIS_BACKEND_PORT", "8000"))
# Base URL that request paths (starting with "/") are appended to.
BACKEND_URL = f"http://{BACKEND_HOST}:{BACKEND_PORT}".rstrip("/")

# Seconds to sleep between two consecutive job status polls.
POLL_INTERVAL_SECONDS = float(getenv("YOUDIS_POLL_INTERVAL_SECONDS", "2"))

# Integer passed to the Discord client as its ``default_scope``.
DEFAULT_SCOPE = int(getenv("DISCORD_BOT_SCOPE", "2147491904"))
|
|
|
|
|
|
# Shared runtime state used by the helpers and listeners in this module.
# Maps a backend job id to the asyncio task that DMs its status updates.
poll_tasks: dict[str, asyncio.Task] = dict()
# aiohttp session reused for all backend calls; stays None until first use.
http_session: aiohttp.ClientSession | None = None

# The Discord client that the slash commands in this module are attached to.
bot = interactions.Client(
    default_scope=DEFAULT_SCOPE,
    intents=interactions.Intents.DEFAULT,
)
|
|
|
|
|
|
def backend_url(path: str) -> str:
    """Return the absolute backend URL for *path*.

    *path* is appended verbatim to ``BACKEND_URL``, so it is expected to
    begin with a ``/``.
    """
    return "{}{}".format(BACKEND_URL, path)
|
|
|
|
|
|
async def get_session() -> aiohttp.ClientSession:
    """Return the module-wide aiohttp session, creating it on demand.

    A fresh session is created on first use and again whenever the stored
    one has been closed; otherwise the existing session is reused.
    """
    global http_session
    current = http_session
    if current is not None and not current.closed:
        return current
    http_session = aiohttp.ClientSession()
    return http_session
|
|
|
|
|
|
async def request_json(method: str, path: str, **kwargs):
    """Send a request to the backend and return ``(status, decoded_json)``.

    Args:
        method: HTTP method name, e.g. ``"GET"`` or ``"POST"``.
        path: Backend path beginning with ``/``; resolved via ``backend_url``.
        **kwargs: Passed through unchanged to ``ClientSession.request``.

    Returns:
        A ``(status_code, data)`` tuple. For error responses (status >= 400)
        whose body is not decodable JSON, ``data`` is an empty dict so the
        caller can still report the HTTP status.

    Raises:
        aiohttp.ClientError: On any transport failure. Timeouts and a
            malformed JSON body on a successful response are re-raised as
            ``aiohttp.ClientError`` subclasses, so callers need only a
            single ``except aiohttp.ClientError`` clause.
    """
    session = await get_session()
    try:
        async with session.request(method, backend_url(path), **kwargs) as response:
            status = response.status
            try:
                data = await response.json()
            except aiohttp.ContentTypeError:
                # Body is not JSON (e.g. a plain-text 500). For error statuses
                # the status code is the useful information; keep it.
                if status < 400:
                    raise
                data = {}
            except ValueError as exc:
                # JSON content type but an undecodable body; ValueError is not
                # a ClientError, so translate it into one callers handle.
                if status < 400:
                    raise aiohttp.ClientPayloadError(
                        f"backend returned malformed JSON for {method} {path}"
                    ) from exc
                data = {}
            return status, data
    except asyncio.TimeoutError as exc:
        if isinstance(exc, aiohttp.ClientError):
            # Already a ClientError flavour of timeout; leave it untouched.
            raise
        # A plain asyncio.TimeoutError would bypass ``except ClientError``.
        raise aiohttp.ServerTimeoutError(
            f"backend request timed out: {method} {path}"
        ) from exc
|
|
|
|
|
|
def format_status_message(job: dict) -> str:
    """Render a job record as one ``" | "``-separated status line.

    The line always begins with ``state=<state>``. The ``phase``,
    ``disposition``, ``message`` and ``result_path`` entries follow in that
    order, each one only when its value is truthy; ``message`` is inserted
    as-is, the others as ``name=value`` (``result_path`` is shown as ``path``).
    """
    state = job.get("state")
    phase = job.get("phase")
    disposition = job.get("disposition")
    message = job.get("message")
    result_path = job.get("result_path")

    # (value deciding inclusion, text to show) in display order.
    optional_segments = (
        (phase, f"phase={phase}"),
        (disposition, f"disposition={disposition}"),
        (message, message),
        (result_path, f"path={result_path}"),
    )
    shown = [text for value, text in optional_segments if value]
    return " | ".join([f"state={state}", *shown])
|
|
|
|
|
|
async def dm(ctx: interactions.SlashContext, message: str) -> None:
    """Send *message* to the author of the command context via ``author.send``."""
    recipient = ctx.author
    await recipient.send(message)
|
|
|
|
|
|
async def respond(ctx: interactions.SlashContext, message: str) -> None:
    """Post *message* in the context's channel, or DM the author if there is none."""
    channel = ctx.channel
    if channel is None:
        await dm(ctx, message)
    else:
        await channel.send(message)
|
|
|
|
|
|
async def poll_job_updates(ctx: interactions.SlashContext, job_id: str) -> None:
    """Poll the backend for *job_id* and DM the author whenever its status changes.

    Polling stops when the backend answers with a non-200 status, reports no
    job, reports a different job as current, or the job reaches one of the
    states ``completed``, ``failed`` or ``cancelled``. The first three cases
    send one final explanatory DM. Whatever the outcome, the entry for
    *job_id* is removed from ``poll_tasks`` on exit.
    """
    finished_states = {"completed", "failed", "cancelled"}
    previous_line = None
    try:
        while True:
            http_status, body = await request_json("GET", "/jobs/current")
            if http_status != 200:
                farewell = f"backend status check failed: HTTP {http_status}"
                break

            current = body.get("job")
            if not current:
                farewell = f"job {job_id} is no longer visible from the backend"
                break
            if current.get("job_id") != job_id:
                farewell = f"job {job_id} is no longer the current backend job"
                break

            # Only message the user when the rendered status actually changed.
            line = format_status_message(current)
            if line != previous_line:
                await dm(ctx, line)
                previous_line = line

            if current.get("state") in finished_states:
                farewell = None  # final status line was already delivered
                break

            await asyncio.sleep(POLL_INTERVAL_SECONDS)

        if farewell is not None:
            await dm(ctx, farewell)
    except asyncio.CancelledError:
        # Cancellation must propagate so the task ends up marked as cancelled.
        raise
    except aiohttp.ClientError as exc:
        await dm(ctx, f"backend poll failed: {exc}")
    finally:
        poll_tasks.pop(job_id, None)
|
|
|
|
|
|
def ensure_poll_task(ctx: interactions.SlashContext, job_id: str) -> None:
    """Start a status poller for *job_id* unless an unfinished one is registered.

    The new task is stored in ``poll_tasks`` under *job_id*.
    """
    registered = poll_tasks.get(job_id)
    if registered is None or registered.done():
        poll_tasks[job_id] = asyncio.create_task(poll_job_updates(ctx, job_id))
|
|
|
|
@bot.listen()
async def on_startup():
    """Listener that opens the shared HTTP session and prints the backend URL."""
    await get_session()
    banner = f"discord adapter configured for backend {BACKEND_URL}"
    print(banner)
|
|
|
|
@bot.listen()
async def on_shutdown():
    """Listener that cancels every status poller and closes the HTTP session."""
    global http_session

    # Snapshot the tasks first so the registry can be emptied safely afterwards.
    for poller in tuple(poll_tasks.values()):
        poller.cancel()
    poll_tasks.clear()

    session = http_session
    still_open = session is not None and not session.closed
    if still_open:
        await session.close()
    http_session = None
|
|
|
|
|
|
@interactions.slash_command(name="youtube", description="submit a youtube download to the backend")
@interactions.slash_option(
    name="url",
    opt_type=interactions.OptionType.STRING,
    required=True,
    description="url target",
)
async def youtube(ctx: interactions.SlashContext, url: str):
    """Submit *url* to the backend as a new job and report the outcome.

    Transport errors and non-200 answers are reported to the author by DM.
    Otherwise the reply depends on the ``state`` the backend returned:
    ``accepted`` starts a status poller for the job, ``busy`` and any other
    state post a short notice via ``respond`` plus details by DM.
    """
    author = ctx.author
    submission = {
        "url": url,
        "requester_id": str(author.id),
        "requester_name": author.username,
        "origin": f"discord:{ctx.guild_id or 'dm'}:{ctx.channel_id}",
    }

    try:
        status_code, job = await request_json("POST", "/jobs", json=submission)
    except aiohttp.ClientError as exc:
        await dm(ctx, f"backend request failed: {exc}")
        return

    if status_code != 200:
        await dm(ctx, f"backend request failed: HTTP {status_code}")
        return

    state = job.get("state")
    if state == "accepted":
        job_id = job.get("job_id", "unknown")
        await respond(ctx, f"Submitted <{url}> to the backend. Status updates via DM.")
        await dm(ctx, f"accepted job {job_id} for <{url}>")
        ensure_poll_task(ctx, job_id)
    elif state == "busy":
        await respond(ctx, "Backend is busy with another job. Details via DM.")
        await dm(ctx, f"busy: {job.get('message')}")
    else:
        await respond(ctx, "Backend rejected the request. Details via DM.")
        await dm(ctx, format_status_message(job))
|
|
|
|
|
|
@interactions.slash_command(name="interrupt", description="cancel the current backend job")
@interactions.check(interactions.is_owner())
async def interrupt(ctx: interactions.SlashContext):
    """Ask the backend to cancel its current job and DM the result to the author.

    A 404 answer is reported as "no active job"; any other non-200 answer,
    or a transport error, is reported as a failed cancel.
    """
    try:
        status_code, payload = await request_json("POST", "/jobs/current/cancel")
    except aiohttp.ClientError as exc:
        await dm(ctx, f"backend cancel failed: {exc}")
        return

    if status_code == 200:
        reply = format_status_message(payload)
    elif status_code == 404:
        reply = "no active backend job to interrupt"
    else:
        reply = f"backend cancel failed: HTTP {status_code}"
    await dm(ctx, reply)
|
|
|
|
|
|
@interactions.slash_command(name="status", description="show the current backend job status")
async def status(ctx: interactions.SlashContext):
    """DM the author the backend's current (or most recent) job status.

    The job line is prefixed with ``active`` when the backend payload's
    ``active`` entry is truthy and with ``last`` otherwise.
    """
    try:
        status_code, payload = await request_json("GET", "/jobs/current")
    except aiohttp.ClientError as exc:
        await dm(ctx, f"backend status failed: {exc}")
        return

    if status_code != 200:
        reply = f"backend status failed: HTTP {status_code}"
    else:
        job = payload.get("job")
        if job:
            prefix = "active" if payload.get("active") else "last"
            reply = f"{prefix} job: {format_status_message(job)}"
        else:
            reply = "backend has no active or recent job"
    await dm(ctx, reply)
|
|
|
|
|
|
def main() -> None:
    """Register the slash commands and start the bot.

    Raises:
        ValueError: If the ``DISCORD_BOT_TOKEN`` environment variable is
            unset or empty.
    """
    api_token = getenv("DISCORD_BOT_TOKEN")
    if not api_token:
        raise ValueError("API token not set. Retrieve from your Discord bot.")

    # Registration order is kept: youtube, status, interrupt.
    for command in (youtube, status, interrupt):
        bot.add_command(command)
    bot.start(api_token)


if __name__ == "__main__":
    main()
|