#!/usr/bin/python3 -su
# -*- coding: utf-8 -*-
# vim: set ts=4 sw=4 sts=4 et :
# pylint: disable=broad-exception-caught
"""
replace-ips - Search and replace IP addresses in specified files.

Every Whonix configuration file listed in this module is searched for the
last known IP address, and each occurrence found is replaced with the current
IP address provided by Qubes.

Initially, the known IP addresses are 10.152.152.10 and 10.152.152.11 as
defaults in Whonix configuration files. They are also checked each time this
module is run in case the configuration files were modified due to a system
update. The default IPv6 addresses fd19:c33d:88bc::10 and fd19:c33d:88bc::11
are also checked.

Qubes feature request: optional static IP addresses
https://github.com/QubesOS/qubes-issues/issues/1477

Copyright (C) 2014 - 2015 Jason Mehring <nrgaway@gmail.com>
License: GPL-2+
Authors: Jason Mehring

  This program is free software; you can redistribute it and/or
  modify it under the terms of the GNU General Public License
  as published by the Free Software Foundation; either version 2
  of the License, or (at your option) any later version.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with this program.  If not, see <http://www.gnu.org/licenses/>.
"""

import os
import re
import subprocess
import ipaddress
from typing import Pattern

# Cache files holding the last known addresses. main() reads them through
# get_ip_address() and passes them to replace_ip() as `ip_file`, which
# stores the new address there after a successful replacement.
WHONIX_IP_GATEWAY: str = "/var/cache/qubes-whonix/whonix-ip-gateway"
WHONIX_IP6_GATEWAY: str = "/var/cache/qubes-whonix/whonix-ip6-gateway"
WHONIX_IP_LOCAL: str = "/var/cache/qubes-whonix/whonix-ip-local"
WHONIX_IP6_LOCAL: str = "/var/cache/qubes-whonix/whonix-ip6-local"

# This is a list of all Whonix files that contain IP addresses that will
# be searched and replaced with the currently assigned IP address.
# Files are processed in the order listed; files that do not exist are
# skipped by replace_ip().
FILES: list[str] = [
    "/usr/share/whonix-gw-network-conf/network_internal_ip.txt",
    "/etc/resolv.conf",
    "/etc/resolv.conf.whonix",
    "/etc/resolv.conf.anondist",
    "/etc/rinetd.conf.anondist",
    "/etc/tor/torrc",
    "/usr/local/etc/torrc.d/50_user.conf",
    "/usr/share/anon-gw-anonymizer-config/torrc.examples",
    "/home/user/.torchat/torchat.ini",
    "/home/user/.xchat2/xchat.conf",
    "/home/user/.config/hexchat/hexchat.conf",
    "/usr/lib/leaktest-workstation/simple_ping.py",
    "/usr/share/anon-apps-config/kioslaverc",
    "/usr/share/anon-torchat/.torchat/torchat.ini",
    "/usr/share/tor/tor-service-defaults-torrc.anondist.base",
    "/usr/share/tor/tor-service-defaults-torrc.anondist",
]

# Default addresses found in Whonix configuration files (see the module
# docstring). main() uses them as the fallback when a cache file is missing
# or invalid, and also includes them in the lists of addresses to replace.
default_ip_gateway: str = "10.152.152.10"
default_ip6_gateway: str = "fd19:c33d:88bc::10"
default_ip_local: str = "10.152.152.11"
default_ip6_local: str = "fd19:c33d:88bc::11"


def whonix_mode() -> str:
    """Report which kind of Whonix system this script is running on.

    The mode is derived from marker files, checked in priority order: the
    Qubes template marker first, then the gateway marker, then the
    workstation marker. The first marker that exists decides the result.

    Returns 'template', 'gateway' or 'workstation', or 'unknown' when none
    of the marker files exists.
    """
    marker_files: tuple[tuple[str, str], ...] = (
        ("/run/qubes/this-is-templatevm", "template"),
        ("/usr/share/anon-gw-base-files/gateway", "gateway"),
        ("/usr/share/anon-ws-base-files/workstation", "workstation"),
    )
    for marker_path, mode_name in marker_files:
        if os.path.exists(marker_path):
            return mode_name
    return "unknown"


def ip_line_fixup(line: str, last_ip: str, current_ip: str) -> str:
    """Replace `last_ip` and its network address with `current_ip` in a line.

    Two substitutions are applied, in this order:

    1. every occurrence of `last_ip` becomes `current_ip`;
    2. every occurrence of the network address derived from `last_ip` (its
       last field replaced by ``0``, e.g. ``10.152.152.0`` for
       ``10.152.152.10``) becomes the network address derived from
       `current_ip` in the same way.

    An occurrence only counts if it is neither directly preceded nor
    directly followed by a decimal digit, so ``10.152.152.10`` is not
    matched inside ``10.152.152.100`` or ``110.152.152.10``.

    line:
        Text to process. Returned unchanged if it is empty or if its first
        non-whitespace character is ``#`` (a comment).

    last_ip:
        Address to search for. A ``:`` in it selects ``:`` as the field
        separator (IPv6), otherwise ``.`` is used (IPv4). An empty value
        leaves `line` unchanged.

    current_ip:
        Replacement address; expected to be of the same family as `last_ip`.

    Returns the line with all substitutions applied.
    """
    comment_regex: Pattern[str] = re.compile(r"\s*#")
    # An empty search address would match at almost every position.
    if line == "" or last_ip == "":
        return line
    if comment_regex.match(line):
        return line

    separator: str = ":" if ":" in last_ip else "."
    last_network: str = last_ip.rsplit(separator, 1)[0] + separator + "0"
    current_network: str = (
        current_ip.rsplit(separator, 1)[0] + separator + "0"
    )

    for old_addr, new_addr in (
        (last_ip, current_ip),
        (last_network, current_network),
    ):
        line = re.sub(
            # Digit guards on both sides keep a longer address that merely
            # contains `old_addr` from being rewritten.
            rf"(?<!\d){re.escape(old_addr)}(?=\D|$)",
            # A callable inserts the replacement literally, without
            # backslash or group-reference processing.
            lambda _match, replacement=new_addr: replacement,
            line,
            flags=re.MULTILINE,
        )

    return line

# pylint: disable=too-many-locals,too-many-branches
def replace_ip(
    ips: list[str],
    current_ip: str,
    files: list[str],
    ip_file: str
) -> bool:
    """Searches and replaces IP addresses in the provided files.

    Every address is validated first; if any of them is not a valid IP
    address an error is printed and no file is touched. Each existing file
    is then processed line by line with `ip_line_fixup` and rewritten only
    if its content changed. Files that cannot be read or written are
    reported and skipped so that the remaining files are still processed.

    ips:
        List of IP addresses to replace.

    current_ip:
        IP replacement address.

    files:
        List of files to search. The list must contain full pathnames.
        Files that do not exist are silently skipped.

    ip_file:
        Full path to the filename used to store the last known value of the IP
        address. The `current_ip` is stored in this file and used the next
        time this module is executed. It is only written if at least one
        file was changed.

    Returns True if at least one file was rewritten and `current_ip` was
    stored in `ip_file`; False otherwise (invalid address, nothing changed,
    or writing `ip_file` failed).
    """
    for ip_item in ips:
        try:
            ipaddress.ip_address(ip_item)
        except ValueError:
            print(f"ERROR: ip_item variable invalid: {ip_item}")
            return False

    try:
        ipaddress.ip_address(current_ip)
    except ValueError:
        print(f"ERROR: current_ip variable invalid: {current_ip}")
        return False

    protocol: str = "IPv6" if ":" in current_ip else "IPv4"
    replaced: bool = False

    for filename in files:
        if not os.path.exists(filename):
            continue

        try:
            with open(
                filename, "r", encoding="utf-8", errors="surrogateescape"
            ) as infile:
                text: str = infile.read()
        except OSError:
            print(
                f"ERROR: {protocol}: file existing but failed to open for "
                f"reading: {filename}"
            )
            continue

        fixed_lines: list[str] = []
        for line in text.split("\n"):
            for last_ip in ips:
                line = ip_line_fixup(line, last_ip, current_ip)
            fixed_lines.append(line)
        replaced_text: str = "\n".join(fixed_lines)

        if text == replaced_text:
            print(
                f"INFO: {protocol}: filename from filelist unchanged: "
                f"{filename}"
            )
            continue

        try:
            # The error handler must match the one used for reading:
            # undecodable bytes were read as lone surrogates, and without
            # "surrogateescape" encoding them raises UnicodeEncodeError
            # after the file has already been truncated.
            with open(
                filename, "w", encoding="utf-8", errors="surrogateescape"
            ) as outfile:
                outfile.write(replaced_text)
        except OSError:
            # Keep going: better to update the remaining files than to
            # give up on the first one that cannot be written.
            print(
                f"ERROR: {protocol}: filename from filelist existing "
                f"but failed to open filename for writing: {filename}"
            )
            continue

        replaced = True
        print(
            f"INFO: {protocol}: filename from filelist updated: "
            f"{filename}"
        )

    if not replaced:
        return False

    try:
        with open(ip_file, "w", encoding="utf-8") as outfile:
            outfile.write(current_ip)
    except OSError:
        print(f"ERROR: {protocol}: writing to ip_file failed: {ip_file}")
        return False

    print(f"INFO: {protocol}: ip_file updated: {ip_file}")
    return True


def get_ip_address(filename: str, default: str = "") -> str:
    """Return the IP address stored in `filename`, or `default`.

    filename:
        Full path of the file to read. Its content, stripped of surrounding
        whitespace, must be a valid IPv4 or IPv6 address.

    default:
        Value returned when the file does not exist, cannot be read, or
        does not contain a valid IP address.

    An INFO line describing the outcome is printed in every case.
    """
    if os.path.exists(filename):
        try:
            with open(
                filename, "r", encoding="utf-8", errors="surrogateescape"
            ) as ip_source:
                candidate: str = ip_source.read().strip()
        except OSError:
            print(f"INFO: IP filename exists but opening failed: {filename}")
            return default

        try:
            ipaddress.ip_address(candidate)
        except ValueError:
            print(
                f"INFO: IP filename reading succeeded but not a valid IP "
                f"(returning default): {filename}"
            )
            return default

        print(f"INFO: IP filename reading succeeded and valid IP: {filename}")
        return candidate

    print(
        f"INFO: IP filename does not exist (returning default): {filename}"
    )
    return default


def maybe_reload_tor() -> None:
    """Restart Tor if the systemd unit tor@default is currently active.

    `systemctl is-active tor@default` is queried first. If it exits with a
    non-zero status the unit is treated as not running and nothing is done.
    Otherwise the unit is restarted (not reloaded, see the comment below).

    Failures are reported on standard output and never raised: neither a
    `systemctl` that cannot be executed nor a failing restart aborts the
    script.
    """
    try:
        unit_state: bytes = subprocess.check_output(
            ["systemctl", "is-active", "tor@default"]
        )
    except subprocess.CalledProcessError:
        print(
            "INFO: Systemd unit tor@default is not running, therefore not "
            "restarting."
        )
        return
    except OSError as err:
        # systemctl itself could not be started (e.g. not installed).
        print(f"ERROR: executing 'systemctl is-active tor@default' failed: {err}")
        return

    if not unit_state:
        return

    print("INFO: executing: systemctl restart tor@default")
    # Restarting instead of reloading due to upstream Tor bug
    # https://trac.torproject.org/projects/tor/ticket/16161
    try:
        exit_code: int = subprocess.call(
            ["systemctl", "restart", "tor@default"]
        )
    except OSError as err:
        print(f"ERROR: executing 'systemctl restart tor@default' failed: {err}")
        return

    if exit_code != 0:
        print(
            "ERROR: 'systemctl restart tor@default' failed with exit code "
            f"{exit_code}"
        )

# pylint: disable=too-many-branches,too-many-statements
def _qubesdb_read(key: str) -> str:
    """Return the output of `qubesdb-read KEY` without trailing whitespace.

    Raises OSError if `qubesdb-read` cannot be executed and
    subprocess.CalledProcessError if it exits with a non-zero status.
    """
    return subprocess.check_output(["qubesdb-read", key]).decode().rstrip()


def main() -> None:
    """Update the Whonix configuration files to the current Qubes addresses.

    The last known addresses are loaded from the cache files (falling back
    to the built-in defaults), the current addresses are queried with
    `qubesdb-read`, and `replace_ip` is run separately for IPv4 and IPv6.
    What is queried and replaced depends on `whonix_mode`:

    gateway:
        `/qubes-netvm-gateway` and `/qubes-netvm-gateway6` are read. Each
        value that could be read replaces the last known gateway address,
        the default gateway address and the default local address of its
        family. If either replacement reports a change,
        `maybe_reload_tor` is called.

    workstation:
        `/qubes-ip` and `/qubes-gateway` (and `/qubes-ip6` and
        `/qubes-gateway6`) are read. Per address family, the local and the
        gateway address are replaced only if both values could be read.

    In any other mode nothing is replaced. A failing `qubesdb-read` only
    produces a warning.
    """
    ## IP HARDCODED, but this does not matter for Non-Qubes-Whonix. This script
    ## is currently only used in Qubes-Whonix.
    last_ip_gateway: str = get_ip_address(WHONIX_IP_GATEWAY, default_ip_gateway)
    last_ip6_gateway: str = get_ip_address(
        WHONIX_IP6_GATEWAY, default_ip6_gateway
    )
    last_ip_local: str = get_ip_address(WHONIX_IP_LOCAL, default_ip_local)
    last_ip6_local: str = get_ip_address(WHONIX_IP6_LOCAL, default_ip6_local)
    assert last_ip_gateway != ""
    assert last_ip6_gateway != ""
    assert last_ip_local != ""
    assert last_ip6_local != ""

    print(f"INFO: last_ip_gateway    : {last_ip_gateway}")
    print(f"INFO: last_ip6_gateway   : {last_ip6_gateway}")
    print(f"INFO: last_ip_local      : {last_ip_local}")
    print(f"INFO: last_ip6_local     : {last_ip6_local}")

    # None means "could not be read from QubesDB".
    current_ip_local: str | None = None
    current_ip_gateway: str | None = None
    current_ip6_local: str | None = None
    current_ip6_gateway: str | None = None
    chg4: bool = False
    chg6: bool = False

    # Directory of the cache files written by replace_ip(). Failing to
    # create it is reported but does not stop the replacement itself.
    try:
        os.makedirs("/var/cache/qubes-whonix", exist_ok=True)
    except OSError:
        print("ERROR: could not create folder '/var/cache/qubes-whonix'.")

    mode: str = whonix_mode()

    if mode == "gateway":
        try:
            current_ip_gateway = _qubesdb_read("/qubes-netvm-gateway")
        except (OSError, subprocess.CalledProcessError):
            print("WARNING: 'qubesdb-read /qubes-netvm-gateway' failed!")
        print(f"INFO: current_ip_gateway : {current_ip_gateway}")

        try:
            current_ip6_gateway = _qubesdb_read("/qubes-netvm-gateway6")
        except (OSError, subprocess.CalledProcessError):
            print("WARNING: 'qubesdb-read /qubes-netvm-gateway6' failed!")
        print(f"INFO: current_ip6_gateway: {current_ip6_gateway}")

        if current_ip_gateway is not None:
            chg4 = replace_ip(
                [last_ip_gateway, default_ip_gateway, default_ip_local],
                current_ip_gateway,
                FILES,
                WHONIX_IP_GATEWAY,
            )

        if current_ip6_gateway is not None:
            chg6 = replace_ip(
                [last_ip6_gateway, default_ip6_gateway, default_ip6_local],
                current_ip6_gateway,
                FILES,
                WHONIX_IP6_GATEWAY,
            )

        if chg4 or chg6:
            maybe_reload_tor()

    elif mode == "workstation":
        # Both values of a family are read in one try block: if the first
        # read fails the second is not attempted and the family is skipped.
        try:
            current_ip_local = _qubesdb_read("/qubes-ip")
            current_ip_gateway = _qubesdb_read("/qubes-gateway")
        except (OSError, subprocess.CalledProcessError):
            print(
                "WARNING: 'qubesdb-read /qubes-ip' or 'qubesdb-read "
                "/qubes-gateway' failed!"
            )
        print(f"INFO: current_ip_local   : {current_ip_local}")
        print(f"INFO: current_ip_gateway : {current_ip_gateway}")

        try:
            current_ip6_local = _qubesdb_read("/qubes-ip6")
            current_ip6_gateway = _qubesdb_read("/qubes-gateway6")
        except (OSError, subprocess.CalledProcessError):
            print(
                "WARNING: 'qubesdb-read /qubes-ip6' or 'qubesdb-read "
                "/qubes-gateway6' failed!"
            )
        print(f"INFO: current_ip6_local  : {current_ip6_local}")
        print(f"INFO: current_ip6_gateway: {current_ip6_gateway}")

        if current_ip_local is not None and current_ip_gateway is not None:
            replace_ip(
                [last_ip_local, default_ip_local],
                current_ip_local,
                FILES,
                WHONIX_IP_LOCAL,
            )
            replace_ip(
                [last_ip_gateway, default_ip_gateway],
                current_ip_gateway,
                FILES,
                WHONIX_IP_GATEWAY,
            )

        if current_ip6_local is not None and current_ip6_gateway is not None:
            replace_ip(
                [last_ip6_local, default_ip6_local],
                current_ip6_local,
                FILES,
                WHONIX_IP6_LOCAL,
            )
            replace_ip(
                [last_ip6_gateway, default_ip6_gateway],
                current_ip6_gateway,
                FILES,
                WHONIX_IP6_GATEWAY,
            )


# Script entry point: main() runs only when this file is executed directly,
# bracketed by START and END marker lines printed to standard output.
if __name__ == "__main__":
    print("/usr/lib/qubes-whonix/replace-ips INFO: START")
    main()
    print("/usr/lib/qubes-whonix/replace-ips INFO: END")
