#!/usr/bin/python3 -su

## Copyright (C) 2025 - 2025 ENCRYPTED SUPPORT LLC <adrelanos@whonix.org>
## See the file COPYING for copying conditions.

# pylint: disable=broad-exception-caught

"""
find_wl_compositor.py - determines the XDG_RUNTIME_DIR and WAYLAND_DISPLAY
environment variable values needed to connect to the compositor on the active
VT.

We use the following procedure to find the correct compositor:

- Determine if XDG_RUNTIME_DIR and WAYLAND_DISPLAY are already set. If
  they are, we have everything we need already set up, skip the rest of
  the autodetect process. Otherwise, continue.
- Determine the active VT.
- Find all applications that are "running on" the TTY corresponding to
  that VT (any application whose controlling terminal, as reported by
  /proc/<pid>/stat, is that particular /dev/tty device is considered to
  be "running on" that VT).
- Determine which one of those applications is a Wayland compositor by
  matching against their process name (comm). We tolerate multiple
  matches so that things still work even if someone happens to have an
  application named 'sway', 'labwc', or similar that isn't a Wayland
  compositor.
- Extract the XDG_RUNTIME_DIR variable from each matched process to
  determine where it most likely has put its Wayland socket.
- Look in all found XDG_RUNTIME_DIR directories for any UNIX sockets
  named according to the pattern 'wayland-*'.
- Attempt to connect to each of these sockets, starting from the earliest
  and ending at the latest, querying the PID of the process listening on
  each socket as we go.
- Once a socket is found that matches a PID found earlier, we have found
  the Wayland socket for the current VT. Export XDG_RUNTIME_DIR and
  WAYLAND_DISPLAY to point to that socket.

This should work, as long as no one makes an application named after a
popular Wayland compositor that also opens a Wayland socket in
XDG_RUNTIME_DIR. Unfortunately, there doesn't appear to be any good way
to determine the PID of the process that put the VT into graphics mode,
so it seems like making an educated guess like this is the best that can
be done.

This method should be pretty hard to fool by accident - because we sort
the list of Wayland sockets we find and commit to using the earliest one
that works, even if the user is running a nested labwc instance, we will
almost certainly end up connecting to the "master" instance (which will
have an earlier wayland-* socket open than the nested instance). There
are probably ways of fooling this mechanism maliciously, but if an attacker
can do that, they probably already have enough access to the machine to do
much worse things.
"""

import os
import sys
import subprocess
import re
import socket
import struct
import logging
import time
from pathlib import Path
from typing import Pattern, Match, NoReturn

## Prefix every log message with the name of the function that emitted it
## and the severity level.
logging.basicConfig(
    format="%(funcName)s: %(levelname)s: %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)

## Process names (as read from /proc/<pid>/comm) that are treated as Wayland
## compositors. Entries are compared for exact equality, so each one must be
## spelled exactly like the compositor's executable name.
known_compositor_list: list[str] = [
    "labwc",
    "sway",
    "cage",
    "cagebreak",
    "dwl",
    ## The compositor is named 'kiwmi'; the former 'kiwimi' spelling could
    ## never match its process name.
    "kiwmi",
    "river",
    "waybox",
    "wayfire",
    "woodland",
]

## File that receives the detected XDG_RUNTIME_DIR and WAYLAND_DISPLAY values.
kloak_env_save_str: str = "/run/kloak_wl_compositor_data"
## Flag file whose existence is awaited before compositor detection starts.
kloak_wayland_flag_str: str = (
    "/run/desktop_config_dist_wayland_session_started"
)


def wlsortmorph(obj: tuple[str, str]) -> str:
    """
    Sort helper function for numerically sorting Wayland socket locations.

    obj is a tuple like ("/run/user/1000", "wayland-1"). The returned key is
    the same location with the trailing UID left-padded with zeros to 10
    digits and the socket index left-padded to 4 digits, e.g.
    "/run/user/0000001000/wayland-0001", so that a plain string sort orders
    entries numerically. Numbers longer than the padding width are left
    as-is and may therefore sort out of numeric order.

    If either element does not have the expected shape (a two-component
    directory followed by a numeric component, and 'wayland-<digits>'), the
    key "z" is returned, which sorts after every well-formed key because
    those all start with '/'.
    """

    match_uid: re.Match[str] | None = re.match(
        r"^(/[^/]+/[^/]+/)(\d+)$", obj[0]
    )
    if match_uid is None:
        return "z"

    match_wl: re.Match[str] | None = re.match(r"^(wayland-)(\d+)$", obj[1])
    if match_wl is None:
        return "z"

    ## Build the key from plain local variables rather than nesting quoted
    ## expressions inside an f-string; reusing the enclosing quote character
    ## inside an f-string is only valid syntax from Python 3.12 onwards.
    padded_dir: str = match_uid.group(1) + match_uid.group(2).zfill(10)
    padded_sock: str = match_wl.group(1) + match_wl.group(2).zfill(4)

    return f"{padded_dir}/{padded_sock}"


def query_sock_pid(sock_path: str) -> str | None:
    """
    Given a path to a UNIX socket, returns the PID of the program listening on
    it.
    """

    try:
        with socket.socket(
            socket.AF_UNIX, socket.SOCK_STREAM | socket.SOCK_CLOEXEC
        ) as sock:
            sock.settimeout(0.5)
            sock.connect(sock_path)
            struct_fmt: str = "=iII"
            rslt_buf: bytes = sock.getsockopt(
                socket.SOL_SOCKET,
                socket.SO_PEERCRED,
                struct.calcsize(struct_fmt),
            )
            pid: int
            pid, _, _ = struct.unpack(struct_fmt, rslt_buf)
            return str(pid)
    except Exception:
        ## Logging not desirable here, socket may have been closed or deleted.
        return None


# pylint: disable=too-many-locals,too-many-branches,too-many-statements
def main() -> NoReturn:
    """
    Locate the Wayland compositor on the active VT and save how to reach it.

    Steps, in order:

    - Block until the file named by kloak_wayland_flag_str exists.
    - Remove any existing kloak_env_save_str file.
    - Ask /usr/bin/fgconsole for the active VT number.
    - Scan /proc for processes whose controlling terminal has major number 4
      and a minor number equal to the active VT, keeping those whose comm
      is listed in known_compositor_list.
    - Collect the XDG_RUNTIME_DIR values from those processes' environments
      and list the 'wayland-*' sockets in each such directory.
    - Probe the sockets in wlsortmorph order and pick the first one whose
      listening PID is one of the candidate compositor PIDs.
    - Write XDG_RUNTIME_DIR and WAYLAND_DISPLAY lines to kloak_env_save_str.

    This function never returns; it exits with status 0 on success and
    status 1 if any step fails.
    """

    ## Do nothing until the flag file shows up.
    while not Path(kloak_wayland_flag_str).is_file():
        time.sleep(1)

    ## Remove any previously saved result before starting detection.
    try:
        Path(kloak_env_save_str).unlink(missing_ok=True)
    except Exception:
        logger.error("Could not force-unlink '%s'!", kloak_env_save_str)
        sys.exit(1)

    try:
        active_vt: int = int(
            subprocess.run(
                ["/usr/bin/fgconsole"],
                capture_output=True,
                text=True,
                check=True,
                encoding="utf-8",
            ).stdout.strip()
        )
    except Exception as e:
        logger.error("Unable to determine active VT!", exc_info=e)
        sys.exit(1)

    if active_vt == 0:
        logger.error("Active VT appears to be 0, cannot continue!")
        sys.exit(1)

    ## List of (pid, comm) pairs for every process found on the active VT.
    process_on_vt_list: list[tuple[str, str]] = []

    try:
        for proc_entry in Path("/proc").iterdir():
            if not proc_entry.is_dir():
                continue
            if not proc_entry.name.isdigit():
                continue

            proc_stat_path: Path = Path(f"/proc/{proc_entry.name}/stat")
            proc_comm_path: Path = Path(f"/proc/{proc_entry.name}/comm")

            try:
                proc_stat_cont: str = proc_stat_path.read_text(
                    encoding="ascii"
                )
                ## The second field of the stat file is the process name in
                ## parentheses, which may itself contain spaces or
                ## parentheses. Collapse it to "()" so that splitting on
                ## spaces yields stable field positions; index 6 is then
                ## the controlling terminal's device number.
                proc_control_tty_int: int = int(
                    re.sub(r"(?s)\(.*\)", "()", proc_stat_cont).split(" ")[6]
                )
                proc_control_tty_minor: int = proc_control_tty_int & 0xFF
                proc_control_tty_major: int = (
                    proc_control_tty_int & 0xFF00
                ) >> 8
            except Exception:
                ## No logging here, this will happen if we aren't able to read
                ## the path in question, which could happen for any number of
                ## reasons (for instance, the app we're inspecting may have
                ## just terminated).
                continue

            ## Only keep processes whose controlling terminal is the active
            ## VT (major 4, minor equal to the VT number).
            if (
                proc_control_tty_major != 4
                or proc_control_tty_minor != active_vt
            ):
                continue

            try:
                comm_str: str = proc_comm_path.read_text(encoding="ascii")
            except Exception:
                ## Logging not desirable here, process name may not be valid
                ## ASCII.
                continue
            ## Remove the single trailing newline, if there is one.
            if comm_str.endswith("\n"):
                comm_str = comm_str[:-1]
            process_on_vt_list.append((proc_entry.name, comm_str))
    except Exception as e:
        logger.error(
            "Unable to scan /proc for Wayland compositors!", exc_info=e
        )
        sys.exit(1)

    if not process_on_vt_list:
        logger.error("No processes found on active VT!")
        sys.exit(1)

    wl_pid_str_list: list[str] = [
        pid_str
        for pid_str, comm_str in process_on_vt_list
        if comm_str in known_compositor_list
    ]

    if not wl_pid_str_list:
        logger.error("No Wayland compositors found on active VT!")
        sys.exit(1)

    xdg_runtime_dir_set: set[str] = set()
    for pid_str in wl_pid_str_list:
        try:
            current_process_environ: list[bytes] = (
                Path(f"/proc/{pid_str}/environ").read_bytes().split(b"\0")
            )
        except Exception:
            ## Logging not desirable here, process may have terminated.
            continue
        for current_process_env_var in current_process_environ:
            if not current_process_env_var.startswith(b"XDG_RUNTIME_DIR="):
                continue
            current_process_env_val: bytes = current_process_env_var.split(
                b"=", maxsplit=1
            )[1]
            try:
                current_process_env_str: str = current_process_env_val.decode(
                    encoding="utf-8"
                )
            except Exception:
                ## Logging not desirable here, process may have invalid text
                ## in its environment file.
                continue
            if current_process_env_str != "":
                xdg_runtime_dir_set.add(current_process_env_str)

    ## List of (XDG_RUNTIME_DIR, socket name) pairs for every candidate
    ## Wayland socket.
    wayland_socket_list: list[tuple[str, str]] = []
    for xdg_runtime_dir in xdg_runtime_dir_set:
        xdg_runtime_dir_path: Path = Path(xdg_runtime_dir)
        try:
            for xdg_runtime_dir_entry in xdg_runtime_dir_path.iterdir():
                try:
                    if not xdg_runtime_dir_entry.is_socket():
                        continue
                except Exception:
                    ## Logging not desirable here, this can error out if we
                    ## attempt to determine the file type of a FUSE
                    ## mountpoint.
                    continue
                if not xdg_runtime_dir_entry.name.startswith("wayland-"):
                    continue
                wayland_socket_list.append(
                    (str(xdg_runtime_dir), xdg_runtime_dir_entry.name)
                )
        except Exception:
            ## Logging not desirable here, it's possible the user we're trying
            ## to read the XDG_RUNTIME_DIR of just logged out and the
            ## directory has been deleted out from under us.
            continue

    ## Probe the sockets in sorted order and commit to the first one that is
    ## being listened on by one of the candidate compositor processes.
    wayland_socket_list.sort(key=wlsortmorph)
    wl_pid_str_set: set[str] = set(wl_pid_str_list)
    xdg_runtime_dir_var: str | None = None
    wayland_display_var: str | None = None
    for socket_dir, socket_name in wayland_socket_list:
        socket_pid: str | None = query_sock_pid(
            str(Path(socket_dir, socket_name))
        )
        if socket_pid is not None and socket_pid in wl_pid_str_set:
            xdg_runtime_dir_var = socket_dir
            wayland_display_var = socket_name
            break

    if xdg_runtime_dir_var is None or wayland_display_var is None:
        logger.error("No Wayland compositors found on active VT!")
        sys.exit(1)

    try:
        ## The file is created with mode 0o600 and truncated if it already
        ## exists.
        with os.fdopen(
            os.open(
                kloak_env_save_str,
                os.O_RDWR | os.O_CREAT | os.O_TRUNC | os.O_CLOEXEC,
                mode=0o600,
            ),
            "r+",
            encoding="utf-8",
        ) as output_file:
            print(f"XDG_RUNTIME_DIR={xdg_runtime_dir_var}", file=output_file)
            print(f"WAYLAND_DISPLAY={wayland_display_var}", file=output_file)
    except Exception as e:
        ## Log the path, not the file object: the file object is not bound
        ## if opening the file is what failed.
        logger.error(
            "Cannot write Wayland compositor information to '%s'!",
            kloak_env_save_str,
            exc_info=e,
        )
        sys.exit(1)

    sys.exit(0)


if __name__ == "__main__":
    main()
