#!/usr/bin/env python

# Copyright 2022 Adrian Granados <adrian@adriangranados.com>
#
# Wireshark extcap interface for remote Wi-Fi captures.
#
# This extcap interface is basically a wrapper for sshdump so that
# we can specify additional capture options, such as the channel
# we want to capture on, and to simplify the configuration of the
# extcap interface so that the user does not have to deal with
# remote capture commands, etc.
#
# Based on Wireshark's extcap_example.py.

from __future__ import print_function

import os
import sys
import argparse
import subprocess

# Process exit codes used by this script.
# NOTE: ERROR_USAGE is 0, so the usage() fallback exits with a success status.
ERROR_USAGE          = 0  # no action flag was given; the usage text is printed
ERROR_ARG            = 1  # command-line argument parsing failed
ERROR_INTERFACE      = 2  # --capture requested for an interface other than wifidump/zbdump
ERROR_FIFO           = 3  # not referenced by the code visible in this file

# Frequencies (MHz) offered as values of the wifidump channel selector by
# extcap_config(). Rows are grouped by the band label freq_to_band() assigns:
# first row "2.4 GHz", next two rows "5 GHz", last row "6 GHz (PSC)".
freqs_wifi = [2412, 2417, 2422, 2427, 2432, 2437, 2442, 2447, 2452, 2457, 2462, 2467, 2472, 2484,
                5180, 5200, 5220, 5240, 5260, 5280, 5300, 5320, 5500, 5520, 5540, 5560, 5580,
                5600, 5620, 5640, 5660, 5680, 5700, 5720, 5745, 5765, 5785, 5805, 5825,
                5975, 6055, 6135, 6215, 6295, 6375, 6455, 6535, 6615, 6695, 6775, 6855, 6935, 7015, 7095]

# Shell script template used as the remote capture command for the
# "wifidump" interface. extcap_capture() fills in the upper-case placeholders
# on the script's last line (INTERFACE, CHANNEL_FREQUENCY, CHANNEL_WIDTH,
# CENTER_FREQUENCY) with plain str.replace() and hands the result to sshdump.
#
# The script switches the interface to monitor mode, tunes it with
# "iw dev ... set freq" and streams packets to stdout via "tcpdump -U -w -".
# When ethtool reports the driver as "iwlwifi", a separate path is taken that
# adds a "<iface>mon" monitor interface and captures on that instead.
#
# NOTE: the text between the triple quotes is a runtime string; anything
# changed inside it changes what is executed on the remote host.
wifidump_script = """
function iface_down {
    sudo ip link set $1 down > /dev/null 2>&1
}

function iface_up {
    sudo ip link set $1 up > /dev/null 2>&1
}

function iface_monitor {
    sudo iw dev $1 set monitor control otherbss > /dev/null 2>&1 ||
    sudo iw dev $1 set type monitor control otherbss > /dev/null 2>&1
}

function iface_scan {
    iface_down $1 &&
    sudo iw dev $1 set type managed > /dev/null 2>&1 &&
    iface_up $1 &&
    sudo iw dev $1 scan > /dev/null 2>&1
}

function iface_config {
    if [ $2 -eq $4 ]; then
        sudo iw dev $1 set freq $2 $3 2>&1
    else
        sudo iw dev $1 set freq $2 $3 $4 2>&1
    fi
}

function iface_start {
    sudo tcpdump -i $1 -U -w -
}

function capture_generic {
    if ! { iwconfig $1 | grep Monitor > /dev/null 2>&1; }; then
        iface_down    $1 &&
        iface_monitor $1 &&
        iface_up      $1
    else
        iface_monitor $1
    fi

    iface_config  $1 $2 $3 $4 &&
    iface_start   $1
}

function capture_iwlwifi {
    INDEX=`sudo iw dev $1 info | grep wiphy | grep -Eo '[0-9]+'`
    sudo iw phy phy${INDEX} channels | grep $2 | grep -i disabled > /dev/null 2>&1 &&
    iface_scan $1
    MON=${1}mon
    sudo iw $1 interface add $MON type monitor flags none > /dev/null 2>&1
    iface_up $MON &&
    iface_down $1 &&
    iface_config $MON $2 $3 $4 &&
    iface_start $MON
}

function capture {
    driver=`/usr/sbin/ethtool -i $1 | grep driver | awk '{ print $2 }'`
    if [ $driver = "iwlwifi" ]; then
        capture_iwlwifi $1 $2 $3 $4
    else
        capture_generic $1 $2 $3 $4
    fi
}

capture INTERFACE CHANNEL_FREQUENCY CHANNEL_WIDTH CENTER_FREQUENCY
"""

# Shell script template used as the remote capture command for the "zbdump"
# interface. It kills any running whsniff, waits three seconds, then starts
# whsniff on the requested channel. extcap_capture() replaces the CHANNEL
# placeholder with the selected channel before handing the script to sshdump.
zbdump_script = """
sudo /usr/bin/killall /usr/local/bin/whsniff > /dev/null
sleep 3
sudo /usr/local/bin/whsniff -c CHANNEL
"""

"""
This code was taken from http://stackoverflow.com/questions/5943249/python-argparse-and-controlling-overriding-the-exit-status-code and was originally developed by Rob Cowie (http://stackoverflow.com/users/46690/rob-cowie).
"""
class ArgumentParser(argparse.ArgumentParser):
    """argparse parser that re-raises ArgumentError instead of exiting.

    When argparse reports a problem with a specific argument, the original
    ArgumentError is re-raised with an extra ``argument`` attribute holding
    the matching Action (or None if no action matches), so the caller can
    choose its own message and exit status.
    """

    def _get_action_from_name(self, name):
        """Given a name, get the Action instance registered with this parser.

        If only it were made available in the ArgumentError object. It is
        passed as its first arg...

        Returns None when ``name`` is None or no registered action matches
        it by joined option strings, metavar or dest.
        """
        if name is None:
            return None
        for action in self._actions:
            candidates = ('/'.join(action.option_strings), action.metavar, action.dest)
            if name in candidates:
                return action
        return None

    def error(self, message):
        """Re-raise the in-flight ArgumentError, else defer to argparse.

        argparse also calls error() while handling other exception types
        and with no exception active at all; only an ArgumentError carries
        ``argument_name``, so anything else falls through to the stock
        behaviour (print usage and exit with status 2).
        """
        exc = sys.exc_info()[1]
        if isinstance(exc, argparse.ArgumentError):
            exc.argument = self._get_action_from_name(exc.argument_name)
            raise exc
        super(ArgumentParser, self).error(message)

### EXTCAP FUNCTIONALITY

def extcap_config(interface, option):
    """Print the extcap configuration sentences for an interface to stdout.

    For "wifidump" this emits the remote interface, channel and channel
    width arguments plus one selector value per entry of freqs_wifi; for
    "zbdump" it emits a channel selector with channels 11 through 26. The
    SSH server and authentication arguments are printed for any interface.

    interface -- extcap interface name ("wifidump" or "zbdump")
    option    -- unused; kept so existing callers keep working
    """

    def freq_to_channel(freq_mhz):
        """Map a frequency in MHz to its channel number (None if unknown)."""
        if freq_mhz == 2484:
            return 14
        if 2412 <= freq_mhz <= 2484:
            return (freq_mhz - 2412) // 5 + 1
        if 5160 <= freq_mhz <= 5885:
            return (freq_mhz - 5180) // 5 + 36
        if 5955 <= freq_mhz <= 7115:
            return (freq_mhz - 5955) // 5 + 1
        return None

    def freq_to_band(freq_mhz):
        """Map a frequency in MHz to the band label shown in the selector."""
        if 2412 <= freq_mhz <= 2484:
            return "2.4 GHz"
        if 5160 <= freq_mhz <= 5885:
            return "5 GHz"
        if 5955 <= freq_mhz <= 7115:
            return "6 GHz (PSC)"
        return None

    if interface == "wifidump":
        # Args
        # Capture Tab
        # The first sentence used to carry a misspelled "toolip" key; it is
        # now "tooltip" like every other argument.
        print("arg {number=0}{call=--remote-interface}{display=Remote Wi-Fi interface name}{type=string}{tooltip=The name of the Wi-Fi interface to run the capture on}{required=true}{default=wlan0}{group=Capture}")
        print("arg {number=1}{call=--remote-channel-or-frequency}{display=Remote Wi-Fi channel}{type=selector}{tooltip=The channel to capture on}{group=Capture}")
        print("arg {number=2}{call=--remote-channel-width}{display=Remote Wi-Fi channel width}{type=selector}{tooltip=The width of the channel to capture on}{group=Capture}")

        # Channel values
        for freq in freqs_wifi:
            print("value {arg=1}{value=" + str(freq) + "}{display=" + freq_to_band(freq) + ", Channel " + str(freq_to_channel(freq)) + "}")

        # Channel width values
        for width in (20, 40, 80, 160):
            print("value {arg=2}{value=" + str(width) + "}{display=" + str(width) + " MHz}")

    elif interface == "zbdump":
        print("arg {number=0}{call=--remote-channel-or-frequency}{display=Remote Zigbee channel}{type=selector}{tooltip=The channel to capture on}{group=Capture}")

        for ch in range(11, 27):
            print("value {arg=0}{value=" + str(ch) + "}{display=Channel " + str(ch) + "}")

    # Server Tab
    print("arg {number=10}{call=--remote-host}{display=Remote SSH server address}{type=string}{tooltip=The remote SSH host. It can be both an IP address or a hostname}{required=true}{group=Server}")
    print("arg {number=11}{call=--remote-port}{display=Remote SSH server port}{type=unsigned}{tooltip=The remote SSH host port (1-65535)}{range=1,65535}{default=22}{group=Server}")

    # Authentication Tab
    print("arg {number=20}{call=--remote-username}{display=Remote SSH server username}{type=string}{tooltip=The remote SSH username. If not provided, the current user will be used}{group=Authentication}")
    print("arg {number=21}{call=--remote-password}{display=Remote SSH server password}{type=password}{tooltip=The SSH password, used when other methods (SSH agent or key files) are unavailable.}{group=Authentication}")
    print("arg {number=22}{call=--sshkey}{display=Path to SSH private key}{type=fileselect}{tooltip=The path on the local filesystem of the private ssh key}{group=Authentication}")

def extcap_version():
    """Print the extcap version sentence to stdout."""
    version_sentence = "extcap {version=1.0}"
    print(version_sentence)

def extcap_interfaces():
    """Print the extcap version followed by the two provided interfaces."""
    sentences = (
        "extcap {version=1.0}",
        "interface {value=wifidump}{display=Wi-Fi remote capture}",
        "interface {value=zbdump}{display=Zigbee remote capture}",
    )
    for sentence in sentences:
        print(sentence)

def extcap_dlts(interface):
    """Print the DLT sentence for the given interface.

    Nothing is printed for an interface other than "wifidump" or "zbdump".
    """
    dlt_sentences = (
        ("wifidump", "dlt {number=147}{name=wifidump}{display=Remote capture dependent DLT}"),
        ("zbdump", "dlt {number=195}{name=zbdump}{display=IEEE 802.15.4 with FCS DLT}"),
    )
    for known_interface, sentence in dlt_sentences:
        if interface == known_interface:
            print(sentence)
            break

def extcap_capture(interface, remote_host, remote_port, remote_username, remote_password, sshkey, remote_interface, remote_channel_or_frequency, remote_channel_width, fifo, capture_filter):
    """Start a remote capture by running sshdump with a generated command.

    For "wifidump" the wifidump_script template is filled in with the remote
    interface, control frequency, channel width and the computed center
    frequency. For "zbdump" the zbdump_script template is filled in with the
    channel. The resulting script is passed to the "sshdump" executable
    located in the same directory as this script (taken from sys.argv[0]).
    Any other interface is ignored and the function returns immediately.

    All parameters are the command-line values as strings, or None when the
    option was not given. Optional SSH settings default to empty strings,
    the port to "22", the channel width to "20" and the remote interface to
    "wlan0".

    Exits the process with ERROR_ARG after printing a message to stderr when
    the host, fifo or channel is missing, when the Wi-Fi frequency or width
    is not an integer, or when the requested width is not available at the
    requested frequency.
    """

    def fail(message):
        # Report on stderr and stop with the argument-error status.
        print(message, file=sys.stderr)
        sys.exit(ERROR_ARG)

    def center_freq(freq_mhz, ch_width_mhz):
        """Return the center frequency (MHz) of the wide channel, or -1.

        freq_mhz is the 20 MHz control frequency and ch_width_mhz the
        requested width. -1 means no channel of that width contains the
        control frequency.
        """
        if ch_width_mhz == 20:
            return freq_mhz

        # Per width: (lowest, highest) control frequency of each contiguous
        # run of channels. Wide channels are laid out back to back from the
        # lowest frequency of the run. The 80 MHz 5 GHz run extends to 5720
        # so that 5680/5700/5720 resolve to the 5660-5720 block (5690).
        spans_by_width = {
            40: ((5180, 5720), (5745, 5765), (5785, 5805), (5955, 7095)),
            80: ((5180, 5720), (5745, 5805), (5955, 7055)),
            160: ((5180, 5640), (5955, 7055)),
        }
        for low, high in spans_by_width.get(ch_width_mhz, ()):
            if low <= freq_mhz <= high:
                block_start = low + ((freq_mhz - low) // ch_width_mhz) * ch_width_mhz
                # The highest control frequency of a block sits 20 MHz below
                # the start of the next block; anything above it is off-grid.
                if freq_mhz <= block_start + ch_width_mhz - 20:
                    return block_start + (ch_width_mhz - 20) // 2
                return -1
        return -1

    if interface not in ("wifidump", "zbdump"):
        return

    # These values have no sensible default; without them int(), str.replace()
    # or subprocess would fail with an unhelpful TypeError.
    missing = [flag for flag, value in (("--remote-host", remote_host),
                                        ("--remote-channel-or-frequency", remote_channel_or_frequency),
                                        ("--fifo", fifo)) if not value]
    if missing:
        fail("Missing required argument(s): " + ", ".join(missing))

    remote_port = remote_port or "22"
    remote_username = remote_username or ""
    remote_password = remote_password or ""
    sshkey = sshkey or ""
    capture_filter = capture_filter or ""

    # NOTE: the values below are substituted verbatim into a shell script
    # that runs on the remote host.
    if interface == "wifidump":
        # Prepare remote command for Wi-Fi capture
        remote_channel_width = remote_channel_width or "20"
        remote_interface = remote_interface or "wlan0"

        try:
            freq_mhz = int(remote_channel_or_frequency)
            width_mhz = int(remote_channel_width)
        except ValueError:
            fail("Channel frequency and width must be integers (MHz): %s, %s"
                 % (remote_channel_or_frequency, remote_channel_width))

        center_mhz = center_freq(freq_mhz, width_mhz)
        if center_mhz < 0:
            fail("A %d MHz wide channel is not available at %d MHz" % (width_mhz, freq_mhz))

        capture_cmd = wifidump_script
        for placeholder, value in (("INTERFACE", remote_interface),
                                   ("CHANNEL_FREQUENCY", str(freq_mhz)),
                                   ("CHANNEL_WIDTH", str(width_mhz)),
                                   ("CENTER_FREQUENCY", str(center_mhz))):
            capture_cmd = capture_cmd.replace(placeholder, value)
    else:
        # Prepare remote command for Zigbee capture
        capture_cmd = zbdump_script.replace("CHANNEL", remote_channel_or_frequency)

    sshdump_path = os.path.join(os.path.dirname(sys.argv[0]), "sshdump")
    sshdump_args = [sshdump_path,
                    "--extcap-interface", "sshdump",
                    "--extcap-capture-filter", capture_filter,
                    "--remote-capture-command", capture_cmd,
                    "--remote-host", remote_host,
                    "--remote-port", remote_port,
                    "--remote-username", remote_username,
                    "--remote-password", remote_password,
                    "--sshkey", sshkey,
                    "--fifo", fifo,
                    "--capture"]

    subprocess.call(sshdump_args)

###

def usage():
    """Print a one-line summary of the supported command-line flags."""
    script_name = sys.argv[0]
    template = "Usage: %s <--extcap-interfaces | --extcap-dlts | --extcap-interface | --extcap-config | --capture | --extcap-capture-filter | --fifo>"
    print(template % script_name)

if __name__ == '__main__':

    # Script entry point: parse the extcap command line and dispatch to the
    # matching extcap_* function.

    # Placeholder passed to extcap_config(), which takes but does not use it.
    option = ""

    parser = ArgumentParser(
        prog="wifidump",
        description="WLAN remote capture"
    )

    #Extcap Arguments
    parser.add_argument("--capture", help="Run the capture", action="store_true")
    parser.add_argument("--extcap-interfaces", help="List the extcap interfaces", action="store_true")
    parser.add_argument("--extcap-interface", help="Specify the extcap interface")
    parser.add_argument("--extcap-dlts", help="List the DLTs", action="store_true")
    parser.add_argument("--extcap-config", help="List the additional configuration for an extcap interface",action="store_true")
    parser.add_argument("--extcap-capture-filter", help="The capture filter")
    parser.add_argument("--fifo", help="Dump data to file or fifo")
    parser.add_argument("--extcap-version", help="Print tool version", nargs='?', default="")

    # Interface Arguments
    parser.add_argument("--remote-host", help="The remote SSH host")
    parser.add_argument("--remote-port", help="The remote SSH port", default="22")
    parser.add_argument("--remote-username", help="The remote SSH host username")
    parser.add_argument("--remote-password", help="The remote SSH port password (if not specified, ssh-agent and ssh-key are used)")
    parser.add_argument("--sshkey", help="The path of the ssh key for passwordless authentication")
    parser.add_argument("--remote-interface", help="The name of remote interface to be used for capturing (e.g. wlan0)")
    parser.add_argument("--remote-channel-or-frequency", help="The remote channel or frequency (MHz) to capture")
    parser.add_argument("--remote-channel-width", help="The width of the remote channel (20, 40, 80, 160) in MHz")

    try:
        args, unknown = parser.parse_known_args()
    except argparse.ArgumentError as exc:
        # The parser attaches the matching action as exc.argument, but it is
        # None when the name cannot be mapped back; fall back to the raw name.
        action = getattr(exc, "argument", None)
        label = action.dest if action is not None else exc.argument_name
        print("%s: %s" % (label, exc.message), file=sys.stderr)
        sys.exit(ERROR_ARG)

    # A version request on its own prints the version and stops.
    if args.extcap_version and not args.extcap_interfaces:
        extcap_version()
        sys.exit(0)

    if not args.extcap_interfaces and args.extcap_interface is None:
        # Message goes to stderr and the exit status is 1, as before; the
        # status and the message are now passed in their proper positions.
        parser.exit(ERROR_ARG, "An interface must be provided or the selection must be displayed\n")
    if args.extcap_interfaces:
        extcap_interfaces()
        sys.exit(0)

    if args.extcap_config:
        extcap_config(args.extcap_interface, option)
    elif args.extcap_dlts:
        extcap_dlts(args.extcap_interface)
    elif args.capture:
        if args.extcap_interface in ("wifidump", "zbdump"):
            extcap_capture(args.extcap_interface, args.remote_host, args.remote_port, args.remote_username, args.remote_password, args.sshkey,
                           args.remote_interface, args.remote_channel_or_frequency, args.remote_channel_width, args.fifo, args.extcap_capture_filter)
        else:
            sys.exit(ERROR_INTERFACE)
    else:
        usage()
        sys.exit(ERROR_USAGE)
