# Lint as: python3
"""
Copyright (c) 2020 Project CHIP Authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import ipaddress
import json
import logging
import os
import re
from typing import Union, List

import requests
import sys
import time
import traceback
from enum import IntEnum
from urllib.parse import urljoin


class TestResult(IntEnum):
    """Overall outcome reported by CHIPVirtualHome.run_test()."""

    # setup()/test_routine() raised nothing, and log saving and home teardown succeeded.
    OK = 0
    # An assertXxx helper raised AssertionError during setup() or test_routine().
    TEST_FAILURE = 1
    # Any other exception: in the test itself, while saving logs, or while destroying the home.
    SYSTEM_FAILURE = 2


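'''
CHIPVirtualHome is a base class for tests that run against a single Cirque home.
Child classes should implement:
- setup()
- test_routine()
- tear_down()  (note: run_test() does not call this; the home itself is
  cleaned up by destroy_home())
'''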


class CHIPVirtualHome:
    """Base class for tests that run against a single Cirque virtual home.

    Subclasses provide ``setup()`` and ``test_routine()``. ``run_test()`` drives
    them, then saves device logs and destroys the home. The Cirque service is
    reached over HTTP at ``cirque_url``.
    """

    def __init__(self, cirque_url, device_config):
        """Describe a home that has not been created yet.

        Args:
            cirque_url: Base URL of the Cirque service. Endpoint paths are
                resolved against it with ``urljoin``.
            device_config: Device description dict posted to ``create_home``.
                ``initialize_home()`` replaces it with the devices Cirque
                actually created, keyed by device id.
        """
        self.home_id = None
        # Subclasses may assign their own logger. This default keeps the
        # helper methods usable when they do not.
        self.logger = logging.getLogger(__name__)
        self.cirque_url = cirque_url
        self.device_config = device_config
        self.device_ids = []
        self.devices = []
        self.non_ap_devices = []
        self.thread_devices = []
        self.ap_devices = []

    # The entrance of the whole test
    def run_test(self, save_logs=True):
        """Run the whole test: setup, routine, log collection and home teardown.

        Args:
            save_logs: When true, device logs are written via save_device_logs().

        Returns:
            TestResult.OK on success.
            TestResult.TEST_FAILURE when an assertXxx check failed.
            TestResult.SYSTEM_FAILURE for any other error, including errors
            while saving logs or destroying the home.
        """
        test_ret = TestResult.OK
        try:
            try:
                self.setup()
                self.test_routine()
            except AssertionError:
                # AssertionError is thrown in self.assertXxx function
                test_ret = TestResult.TEST_FAILURE
                traceback.print_exc(file=sys.stderr)
            except Exception:
                # Other errors indicate a failure in system.
                test_ret = TestResult.SYSTEM_FAILURE
                traceback.print_exc(file=sys.stderr)
            if save_logs:
                try:
                    self.save_device_logs()
                except Exception:
                    test_ret = TestResult.SYSTEM_FAILURE
                    traceback.print_exc(file=sys.stderr)
        finally:
            # Always try to release the Cirque home, even when the run is
            # interrupted (e.g. KeyboardInterrupt), so homes are not leaked.
            try:
                self.destroy_home()
            except Exception:
                test_ret = TestResult.SYSTEM_FAILURE
                traceback.print_exc(file=sys.stderr)
        return test_ret

    def query_api(self, end_point, args=(), binary=False):
        """Send a GET request to a Cirque endpoint.

        Args:
            end_point: Endpoint name, e.g. 'home_devices'.
            args: Extra path segments appended after the endpoint.
            binary: Return the raw response bytes instead of decoded JSON.
        """
        ret = requests.get(self._build_request_url(end_point, args))
        if binary:
            return ret.content
        return ret.json()

    def execute_device_cmd(self, device_id, cmd, stream=False):
        """Run a command on a device through Cirque's 'device_cmd' endpoint.

        Returns:
            The raw ``requests`` response when ``stream`` is true. Otherwise,
            the decoded JSON dict, which is checked to contain 'return_code'
            and 'output'. A non-zero return code is only logged, not raised.

        Raises:
            Exception: The response lacks 'return_code' or 'output'.
        """
        self.logger.info(
            "device: {} exec: {}".format(self.get_device_pretty_id(device_id), cmd))
        ret = requests.get(self._build_request_url('device_cmd', [self.home_id, device_id, cmd]),
                           params={'stream': stream},
                           stream=stream)
        if stream:
            return ret

        ret_struct = ret.json()
        command_ret_code = ret_struct.get('return_code', None)
        # Compare against None explicitly: a return code of 0 is valid.
        if command_ret_code is None:
            self.logger.error("cannot get command return code")
            raise Exception("cannot get command return code")
        self.logger.info("command return code: {}".format(command_ret_code))
        command_output = ret_struct.get('output', None)
        # Compare against None explicitly: an empty output string is valid.
        if command_output is None:
            self.logger.error("cannot get command output")
            raise Exception("cannot get command output")
        self.logger.info("command output: \n{}".format(command_output))
        return ret_struct

    def sequenceMatch(self, string, patterns):
        """Return True if every pattern occurs in ``string``, in the given order.

        Each pattern is searched for only after the end of the previous match.
        """
        last_find = 0
        for s in patterns:
            self.logger.info('Finding string: "{}"'.format(s))
            this_find = string.find(s, last_find)
            if this_find < 0:
                self.logger.info('Not found')
                return False
            self.logger.info("Found at index={}".format(this_find))
            last_find = this_find + len(s)
        return True

    def reset_thread_devices(self, devices: Union[List[str], str]):
        """
        Reset device's thread settings and verify state.

        Waits for the border-router agent and the CHIP server to start. Then
        runs ``ot-ctl factoryreset`` and waits for the "disabled" role.
        Raises AssertionError if any of those waits times out.
        """
        if isinstance(devices, str):
            devices = [devices]
        for device_id in devices:
            # Wait for otbr-agent and CHIP server start
            self.assertTrue(self.wait_for_device_output(
                device_id, "Border router agent started.", 10))
            self.assertTrue(self.wait_for_device_output(
                device_id, "CHIP:SVR: Server Listening...", 15))
            # Clear default Thread network commissioning data
            self.logger.info("Resetting thread network on {}".format(
                self.get_device_pretty_id(device_id)))
            self.execute_device_cmd(device_id, 'ot-ctl factoryreset')
            self.check_device_thread_state(
                device_id=device_id, expected_role="disabled", timeout=10)

    def check_device_thread_state(self, device_id, expected_role, timeout, poll_interval=0.5):
        """Poll ``ot-ctl state`` until the device reports one of the expected roles.

        Args:
            device_id: Cirque device id.
            expected_role: A role name or a list of acceptable role names.
            timeout: Seconds to keep polling.
            poll_interval: Seconds to sleep between polls.

        Raises:
            AssertionError: No expected role was reported within ``timeout``.
        """
        if isinstance(expected_role, str):
            expected_role = [expected_role]
        self.logger.info(
            f"Waiting for expected role. {self.get_device_pretty_id(device_id)}: {expected_role}")
        deadline = time.time() + timeout
        while time.time() < deadline:
            reply = self.execute_device_cmd(device_id, 'ot-ctl state')
            # The role is the first token. Empty output must not raise IndexError.
            fields = reply['output'].split()
            if fields and fields[0] in expected_role:
                return
            time.sleep(poll_interval)

        message = f"Device {self.get_device_pretty_id(device_id)} does not reach expected role"
        self.logger.error(message)
        raise AssertionError(message)

    def form_thread_network(self, device_id: str, expected_role: Union[str, List[str]], timeout: int = 15,
                            dataset: str = ""):
        """
        Start Thread Network with provided dataset. If dataset is not provided then default will be set.
        Function that will be also verifying if device start in expected role.
        """
        if not dataset:
            dataset = "0e080000000000010000" + \
                      "000300000c" + \
                      "35060004001fffe0" + \
                      "0208fedcba9876543210" + \
                      "0708fd00000000001234" + \
                      "0510ffeeddccbbaa99887766554433221100" + \
                      "030e54657374696e674e6574776f726b" + \
                      "0102d252" + \
                      "041081cb3b2efa781cc778397497ff520fa50c0302a0ff"

        ot_init_commands = [
            "ot-ctl thread stop",
            "ot-ctl ifconfig down",
            f"ot-ctl dataset set active {dataset}",
            "ot-ctl ifconfig up",
            "ot-ctl thread start",
            "ot-ctl dataset active",
        ]
        self.logger.info(
            f"Setting Thread dataset for {self.get_device_pretty_id(device_id)}: {dataset}")
        for cmd in ot_init_commands:
            self.execute_device_cmd(device_id, cmd)
        self.check_device_thread_state(
            device_id=device_id, expected_role=expected_role, timeout=timeout)

    def connect_to_thread_network(self):
        '''
        The dataset in this function is used to replace the default dataset generated by openthread.
        When the test writer is calling this function to setup a thread network, it means they just
        want a working IPv6 network or a working thread network and don't care about the detail of
        this network.

        The network counts as formed when exactly one device is leader and every
        Thread device is leader, router or child. Up to 30 one-second polls are
        made. Raises AssertionError if the network never forms.
        '''
        self.logger.info("Running commands to form default Thread network")
        for device in self.thread_devices:
            # Best effort: continue even if the agent message is not seen.
            if not self.wait_for_device_output(
                    device['id'], "Border router agent started.", 5):
                self.logger.warning("Border router agent start not observed on {}".format(
                    self.get_device_pretty_id(device['id'])))

        ot_init_commands = [
            "ot-ctl thread stop",
            "ot-ctl ifconfig down",
            "ot-ctl dataset set active 0e080000000000010000000300000d35060004001fffe00208dead00beef00cafe0708fd01234567890abc051000112233445566778899aabbccddeeff030a4f70656e546872656164010212340410ad463152f9622c7297ec6c6c543a63e70c0302a0ff",
            "ot-ctl ifconfig up",
            "ot-ctl thread start",
            "ot-ctl dataset active",  # Emit
        ]
        for device in self.thread_devices:
            # Set default openthread provisioning
            for cmd in ot_init_commands:
                self.execute_device_cmd(device['id'], cmd)
        self.logger.info("Waiting for Thread network to be formed...")
        thread_network_formed = False
        for _ in range(30):
            roles = []
            for device in self.thread_devices:
                # We can only check the status of ot-agent by query its state.
                reply = self.execute_device_cmd(device['id'], 'ot-ctl state')
                fields = reply['output'].split()
                # Empty output counts as "no role yet" rather than raising IndexError.
                roles.append(fields[0] if fields else '')
            thread_network_formed = (roles.count('leader') == 1) and (roles.count(
                'leader') + roles.count('router') + roles.count('child') == len(self.thread_devices))
            if thread_network_formed:
                break
            time.sleep(1)
        self.assertTrue(thread_network_formed)
        self.logger.info("Thread network formed")

    def enable_wifi_on_device(self):
        """Connect every non-AP device to the home's Wi-Fi AP.

        Fetches the SSID/PSK from Cirque. On each device, it writes them to the
        wpa_supplicant config, kills any running wpa_supplicant and starts a new
        one. It then sleeps 5 seconds.
        """
        ssid, psk = self.query_api('wifi_ssid_psk', [self.home_id])

        self.logger.info("wifi ap ssid: {}, psk: {}".format(ssid, psk))

        for device in self.non_ap_devices:
            self.logger.info(
                "device: {} connecting to desired ssid: {}".format(
                    self.get_device_pretty_id(device['id']), ssid))
            self.write_psk_to_wpa_supplicant_config(device['id'], ssid, psk)
            self.kill_existing_wpa_supplicant(device['id'])
            self.start_wpa_supplicant(device['id'])
        time.sleep(5)

    def get_device_thread_ip(self, device_id):
        """Return the device's Mesh-Local EID as reported by ``ot-ctl ipaddr``.

        Link-local addresses are skipped, as are non-private addresses. Also
        skipped are addresses whose interface identifier is
        ``0000:00ff:fe00:xxxx``. Returns None if no address qualifies.
        """
        ret = self.execute_device_cmd(device_id, 'ot-ctl ipaddr')
        ipaddr_list = ret["output"].splitlines()
        for ipstr in ipaddr_list:
            try:
                self.logger.info(
                    "device: {} thread ip: {}".format(self.get_device_pretty_id(device_id), ipstr))
                ipaddr = ipaddress.ip_address(ipstr)
                if ipaddr.is_link_local:
                    continue
                if not ipaddr.is_private:
                    continue
                if re.match(r"fd[0-9a-f]{2}:[0-9a-f]{4}:[0-9a-f]{4}:[0-9a-f]{4}:0000:00ff:fe00:[0-9a-f]{4}",
                            ipaddr.exploded) is not None:
                    continue
                self.logger.info("Get Mesh-Local EID: {}".format(ipstr))
                return str(ipaddr)
            except ValueError:
                # Since we are using ot-ctl, which is a command line interface and it will append 'Done' to end of output
                pass
        return None

    def get_device_log(self, device_id):
        """Return the device's log as raw bytes from Cirque."""
        return self.query_api('device_log', [self.home_id, device_id], binary=True)

    def wait_for_device_output(self, device_id, pattern, timeout=1):
        """Poll the device log about once per second until ``pattern`` appears.

        Returns True if the pattern was found before ``timeout`` seconds
        elapsed, otherwise False.
        """
        due = time.time() + timeout
        while True:
            # Logs may contain non-UTF-8 bytes. Replace them instead of
            # raising UnicodeDecodeError.
            log_text = self.get_device_log(device_id).decode("utf-8", errors="replace")
            if self.sequenceMatch(log_text, [pattern, ]):
                return True
            if time.time() < due:
                time.sleep(1)
            else:
                break
        return False

    def _fail(self, note):
        """Log ``note`` (if given) and raise AssertionError carrying it."""
        if note:
            self.logger.error(note)
            raise AssertionError(note)
        raise AssertionError

    def assertTrue(self, exp, note=None):
        '''
        assert{True|False}
        assert(Not)Equal
        python unittest style functions that raise exceptions when condition not met
        '''
        # Deliberately strict: only values that compare equal to True pass.
        if not exp == True:  # noqa: E712
            self._fail(note)

    def assertFalse(self, exp, note=None):
        # Deliberately strict: only values that compare equal to False pass.
        if not exp == False:  # noqa: E712
            self._fail(note)

    def assertEqual(self, val1, val2, note=None):
        if not val1 == val2:
            self._fail(note)

    def assertNotEqual(self, val1, val2, note=None):
        if val1 == val2:
            self._fail(note)

    def _build_request_url(self, end_point, args=()):
        """Join ``end_point`` and the ``/``-separated ``args`` onto the Cirque base URL."""
        if not args:
            return urljoin(self.cirque_url, end_point)
        return urljoin(self.cirque_url, "{}/{}".format(end_point, '/'.join([str(argv) for argv in args])))

    def destroy_home(self):
        """Ask Cirque to destroy the current home."""
        self.logger.info("destroying home: {}".format(self.home_id))
        self.query_api('destroy_home', [self.home_id])

    def initialize_home(self):
        """Create the home from ``device_config`` and record the created devices.

        Raises AssertionError if Cirque does not list the new home, or if the
        set of created device types differs from the requested one.
        """
        home_id = requests.post(
            self._build_request_url('create_home'), json=self.device_config).json()

        self.logger.info("home id: {} created!".format(home_id))

        self.assertTrue(home_id in
                        list(self.query_api('get_homes')),
                        "created home_id did not match id from get_homes!!")

        self.home_id = home_id

        created_devices = self.query_api('home_devices', [home_id])

        self.logger.info("home id: {} devices: {}".format(
            home_id, json.dumps(created_devices, indent=4, sort_keys=True)))

        device_types = {device['type'] for device in created_devices.values()}
        wanted_device_types = {device['type'] for device in self.device_config.values()}

        self.assertEqual(device_types, wanted_device_types,
                         "created device does not match to device config!!")

        self.device_config = created_devices

        self.device_ids = list(self.device_config)
        self.non_ap_devices = [device for device in self.device_config.values()
                               if device['type'] != 'wifi_ap']
        self.thread_devices = [device for device in self.device_config.values()
                               if device['capability'].get('Thread', None) is not None]
        self.ap_devices = [device for device in self.device_config.values()
                           if device['type'] == 'wifi_ap']

    def save_device_logs(self):
        """Write each non-AP device's log into ``$DEVICE_LOG_DIR`` (default ``logs``).

        The directory is created if needed. Files are named
        ``<type>-<unix timestamp>-<first 8 chars of id>.log``.
        """
        timestamp = int(time.time())
        log_dir = os.environ.get("DEVICE_LOG_DIR") or "logs"
        os.makedirs(log_dir, exist_ok=True)

        for device in self.non_ap_devices:
            ret_log = self.get_device_log(device['id'])
            # Use this format for easier sort
            f_name = '{}-{}-{}.log'.format(device['type'],
                                           timestamp, device['id'][:8])
            self.logger.debug("device log name: \n{}".format(f_name))
            with open(os.path.join(log_dir, f_name), 'wb') as fp:
                fp.write(ret_log)

    def start_wpa_supplicant(self, device_id):
        """Start wpa_supplicant in the background on the device's wlan0."""
        self.logger.info("device: {}: starting wpa_supplicant on device"
                         .format(self.get_device_pretty_id(device_id)))

        start_wpa_supplicant_command = "".join(
            ["wpa_supplicant -B -i wlan0 ",
             "-c /etc/wpa_supplicant/wpa_supplicant.conf ",
             "-f /var/log/wpa_supplicant.log -t -dd"])

        return self.execute_device_cmd(device_id, start_wpa_supplicant_command)

    def write_psk_to_wpa_supplicant_config(self, device_id, ssid, psk):
        """Append a wpa_passphrase network block to the device's wpa_supplicant.conf."""
        self.logger.info("device: {}: writing ssid, psk to wpa_supplicant config"
                         .format(self.get_device_pretty_id(device_id)))

        # NOTE(review): ssid and psk are interpolated unquoted into a shell
        # command, so values with spaces or quotes will break it. Confirm how
        # Cirque executes device_cmd before adding quoting.
        write_psk_command = "".join(
            ["sh -c 'wpa_passphrase {} {} >> ".format(ssid, psk),
             "/etc/wpa_supplicant/wpa_supplicant.conf'"])

        return self.execute_device_cmd(device_id, write_psk_command)

    def kill_existing_wpa_supplicant(self, device_id):
        """Run ``killall wpa_supplicant`` on the device."""
        self.logger.info("device: {}: kill existing wpa_supplicant"
                         .format(self.get_device_pretty_id(device_id)))

        kill_wpa_supplicant_command = 'killall wpa_supplicant'

        return self.execute_device_cmd(device_id, kill_wpa_supplicant_command)

    def get_device_pretty_name(self, device_id):
        """Return the device's type, or "<unknown>" if it is not in device_config."""
        device_obj = self.device_config.get(device_id, None)
        if device_obj is not None:
            return device_obj['type']
        return "<unknown>"

    def get_device_pretty_id(self, device_id):
        """Return ``type(first-8-chars-of-id...)`` for log messages."""
        return "{}({}...)".format(self.get_device_pretty_name(device_id), device_id[:8])
