#!/usr/bin/python
# --
#                 - Mellanox Confidential and Proprietary -
#
# Copyright (C) Jan 2013, Mellanox Technologies Ltd.  ALL RIGHTS RESERVED.
#
# Except as specifically permitted herein, no portion of the information,
# including but not limited to object code and source code, may be reproduced,
# modified, distributed, republished or otherwise exploited in any form or by
# any means for any purpose without the prior written permission of Mellanox
# Technologies Ltd. Use of software subject to the terms and conditions
# detailed in the file "LICENSE.txt".
# --

try:
    # Python Imports
    import sys
    import json
    import string
    import socket
    import logging
    import re
    # NeoHost IFC Imports
    import neohost_plugin_ifc as APIFC
    from cmd_exec import Command
    import json_entry_creator
    # Plugin Imports
    from mft_core_defs import MftCoreDefs
    import mft_core_exceptions
    from neohost_common import NeoHostCommon
    from common_help_funcs import CommonHelpFuncs


except Exception as e:
    print("-E- could not import : %s" % str(e))
    sys.exit(1)

# Module-level logger named "neohost." followed by this module's name.
logger = logging.getLogger("neohost." + __name__)


class GetInterfaceInfoCommand(APIFC.AbsNeoHostCommandIFC):
    # Link attributes. On Linux they are read from `ethtool <ifc>` output
    # ("permanent address" is additionally queried with `ethtool -P <ifc>`);
    # on Windows they come from the mlx5cmd / Get-NetAdapter derived dict.
    # Key:   lower-cased attribute label as it appears in the command output.
    # Value: (result attribute name, display name, description).
    IFC_ATTR_MAPPING = {
        "speed":
            ("speed", "Speed", "Interface operating speed."),
        "duplex":
            ("duplex", "Duplex", "Duplex capability."),
        "transceiver":
            ("transceiver", "Transceiver", "Transceiver type."),
        "auto-negotiation":
            ("autoNegotiation", "Auto Negotiation",
             "Link auto-negotiation state."),
        "link detected":
            ("linkDetected", "Link Detected",
             "Is a Link Detected on this interface"),
        "permanent address":
            ("hwAddress", "Hardware Address", "Permanent Hardware Address."),
    }

    # Attributes that are not looked up by a command-output label.
    # Key:   result attribute name.
    # Value: (display name, description).
    GENERAL_ATTR_MAPPING = {
        "mtu":
            ("MTU", "Current MTU."),
        "driver":
            ("Driver", "Driver information for the driver currently binded "
             "to the interface."),
        "ifcInfo":
            ("Interface Information", "Various interface information."),
        "ipAddrs":
            ("IP Addresses", "The IP addresses of the interface"),
        "defaultGateway":
            ("Default Gateway", "The default gateway"),
        "subnetPrefix":
            ("Subnet Prefix", "The Subnet prefix and mask")
    }

    # Driver-related attributes: from `ethtool -i <ifc>` on Linux and from the
    # mlx5cmd-derived dict on Windows.
    # Value: (result attribute name, display name, description, is_driver,
    #         entry_index).
    # is_driver=True:  stored in the nested "driver" entry at entry_index.
    # is_driver=False: stored directly in the interface info dict at the next
    #                  sequential index; entry_index is not used in that case.
    DRIVER_ATTR_MAPPING = {
        "driver":
            ("name", "Name",
             "Name of the driver currently binded to the interface.", True, 0),
        "version":
            ("version", "Version", "Driver version.", True, 1),
        "bus-info":
            ("busInfo", "Bus Information", "Bus Information.", False, 2)
    }

    # (display name, description) pairs used to build the "vlans" sub-tree.
    VLANS_ATTR_MAPPING = {
        "vlans":
            ("Vlans", "vlan interfaces."),
        "id":
            ("ID", "Vlan ID."),
        "name":
            ("Name", "Interface name.")
    }

    # (display name, description) pairs used to build the "bondInfo" sub-tree.
    BOND_ATTR_MAPPING = {
        "bondInfo":
            ("Bond Information", "bond information."),
        "name":
            ("Name", "Bond name."),
        "members":
            ("Members", "Bond members.")
    }

    def __init__(self):
        """Register this handler with the base command interface.

        The command is named GetInterfaceInfo, is a "get" command with
        interface scope, and uses the Cmd_Exec_Mode_All execution mask.
        """
        registration = dict(
            cmd_name="GetInterfaceInfo",
            cmd_desc="get various interface information",
            cmd_type=APIFC.EnumCmdType.Cmd_Type_Get,
            cmd_scope=APIFC.EnumCmdScope.Cmd_Scope_Interface,
            supp_exec_mask=APIFC.EnumCmdExecMode.Cmd_Exec_Mode_All,
            cmd_in_desc="ifcUid: interface ID, attrMask: attribute mask",
            cmd_out_desc="various interface information",
            extra_str="",
        )
        super(GetInterfaceInfoCommand, self).__init__(**registration)

    def __get_cmd_output_dict(self, cmd):
        """Run cmd and parse its "key: value" output lines into a dict.

        Keys are lower-cased and stripped; values are stripped. Lines without
        a ':' are ignored. A non-zero exit status is reported through
        CommonHelpFuncs.printToLogAndRaiseMftCoreException.
        """
        rc, out, _ = Command(cmd).execute()
        if rc != 0:
            CommonHelpFuncs.printToLogAndRaiseMftCoreException(
                "Failed to get interface information from cmd: %s" % cmd,
                logger)
        parsed = {}
        for raw_line in out.splitlines():
            # Split on the first ':' only, so values may contain ':' (e.g. MACs).
            key, separator, value = raw_line.strip().partition(":")
            if not separator:
                continue
            parsed[key.lower().strip()] = value.strip()
        return parsed

    def __get_mtu(self, ifc_name):
        """Return the MTU of ifc_name from /sys/class/net, or -1 if unavailable.

        -1 is returned when the `cat` command fails or its output is not an
        integer.
        """
        rc, out, _ = Command("cat /sys/class/net/%s/mtu" % ifc_name).execute()
        if rc != 0:
            return -1
        try:
            return int(out)
        except (ValueError, TypeError):
            return -1

    def __linux_exec(self, ifc_info_dict, req, entry_creator):
        driver_info_dict = dict()
        local_idx = 0
        if req["ifcUid"].startswith("net-"):
            req["ifcUid"] = req["ifcUid"][4:]
        cmd = "ethtool %s" % req["ifcUid"]
        ethtool_dict = self.__get_cmd_output_dict(cmd)

        for ethtool_attr, (attr_name, disp, desc) in \
                self.IFC_ATTR_MAPPING.iteritems():
            if ethtool_attr in ethtool_dict:
                ethtool_val = ethtool_dict[ethtool_attr].title()
            else:
                ethtool_val = "NA"
            ifc_info_dict[attr_name] = entry_creator.create_entry(
                disp, ethtool_val, desc, local_idx)
            local_idx += 1

        cmd = "ethtool -i %s" % req["ifcUid"]
        ethtool_dict = self.__get_cmd_output_dict(cmd)

        for ethtool_attr, (attr_name, disp, desc, is_driver, entry_index) in \
                self.DRIVER_ATTR_MAPPING.iteritems():
            if ethtool_attr in ethtool_dict:
                if is_driver:
                    _dict = driver_info_dict
                    index = entry_index
                else:
                    _dict = ifc_info_dict
                    index = local_idx
                    local_idx += 1
                val = ethtool_dict[ethtool_attr]
            else:
                val = "NA"
                local_idx += 1
            _dict[attr_name] = entry_creator.create_entry(
                disp, val, desc, index)

        attr_name = "driver"
        disp, desc = self.GENERAL_ATTR_MAPPING.get(attr_name)
        ifc_info_dict[attr_name] = entry_creator.create_entry(
            disp, driver_info_dict, desc, local_idx)
        local_idx += 1

        attr_name = "mtu"
        disp, desc = self.GENERAL_ATTR_MAPPING.get(attr_name)
        ifc_info_dict[attr_name] = entry_creator.create_entry(
            disp, self.__get_mtu(req["ifcUid"]), desc, local_idx)
        local_idx += 1

        cmd = "ethtool -P %s" % req["ifcUid"]
        ethtool_dict = self.__get_cmd_output_dict(cmd)
        ethtool_attr = "permanent address"
        (attr_name, disp, desc) = self.IFC_ATTR_MAPPING.get(ethtool_attr)
        attr_val = ethtool_dict.get(ethtool_attr, "NA")
        ifc_info_dict[attr_name] = entry_creator.create_entry(
            disp, attr_val, desc, local_idx)
        local_idx += 1

        # add ip addresses to ifc_info_dict
        self.__setIpAddrsToInfoDict(ifc_info_dict, entry_creator,
                                    req["ifcUid"], local_idx)
        local_idx += 1

        # add default gateway to ifc_info_dict
        self.__setDefGWToInfoDict(ifc_info_dict, entry_creator,
                                  local_idx)
        local_idx += 1

        # add vlans to ifc_info_dict
        self.__setVlansToInfoDict(ifc_info_dict, entry_creator,
                                  req["ifcUid"], local_idx)
        local_idx += 1

        # add bond to ifc_info_dict
        self.__setBondToInfoDict(
            ifc_info_dict, entry_creator, req["ifcUid"], local_idx)
        local_idx += 1

    def execute_command(self, json_request, exec_opt):
        """Collect information about the interface named in json_request.

        Args:
            json_request: JSON text of the request. "ifcUid" is read by the
                platform helpers; an optional "attrMask" is forwarded to the
                entry creator's set_attr_mask.
            exec_opt: not used by this method.

        Returns:
            (MftCoreDefs.MFT_CORE_STATUS_SUCCESS, JSON string) where the JSON
            holds a single "ifcInfo" entry wrapping all collected attributes.
        """
        logger.info("Executing GetInterfaceInfo command..")
        request = json.loads(json_request)
        creator = json_entry_creator.InfoEntryCreator()
        if "attrMask" in request:
            creator.set_attr_mask(request["attrMask"])

        collected = {}
        if NeoHostCommon.isWindowsOs():
            platform_exec = self.__win_exec
        else:
            platform_exec = self.__linux_exec
        platform_exec(collected, request, creator)

        ifc_disp, ifc_desc = self.GENERAL_ATTR_MAPPING.get("ifcInfo")
        final_result = {
            "ifcInfo": creator.create_entry(ifc_disp, collected, ifc_desc, 0),
        }
        logger.info("finished executing GetInterfaceInfo command.")
        return MftCoreDefs.MFT_CORE_STATUS_SUCCESS, json.dumps(final_result)

    def __setIpAddrsToInfoDict(self, ifc_info_dict, entry_creator, ifcUid,
                               parent_dict_idx):
        attr_name = "ipAddrs"
        disp, desc = self.GENERAL_ATTR_MAPPING.get(attr_name)
        ips_str = self.__getInfIpsAndMaskString(ifcUid)
        ifc_info_dict[attr_name] = entry_creator.create_entry(
            disp, ips_str, desc, parent_dict_idx)

    def __setDefGWToInfoDict(self, ifc_info_dict, entry_creator,
                             parent_dict_idx):
        attr_name = "defaultGateway"
        gw = self.__getDefaultGateway()
        disp, desc = self.GENERAL_ATTR_MAPPING.get(attr_name)
        ifc_info_dict[attr_name] = entry_creator.create_entry(
            disp, gw, desc, parent_dict_idx)

    def __setVlansToInfoDict(self, ifc_info_dict, entry_creator, ifcUid,
                             parent_dict_idx):
        vlans_arr = self.__getVlansInfo(ifcUid)
        ifcs_index = 0
        vlan_ifcs = {}

        # for each interface in vlan
        for ifc_name, (ip, vlan_id) in vlans_arr.items():
            vlan_ifc_index = 0
            ifc_info_inner = {}

            # vlan interface IPs
            attr_name = "id"
            disp, desc = self.VLANS_ATTR_MAPPING.get(attr_name)
            ifc_info_inner[attr_name] = entry_creator.create_entry(
                disp, vlan_id, desc, vlan_ifc_index)
            vlan_ifc_index += 1

            # vlan interface IPs
            attr_name = "ipAddrs"
            disp, desc = self.GENERAL_ATTR_MAPPING.get(attr_name)
            ifc_info_inner[attr_name] = entry_creator.create_entry(
                disp, ip, desc, vlan_ifc_index)
            vlan_ifc_index += 1

            # vlan interface name
            disp, desc = self.VLANS_ATTR_MAPPING.get("name")
            vlan_ifcs[ifc_name] = entry_creator.create_entry(
                ifc_name, ifc_info_inner, desc, ifcs_index)
            ifcs_index += 1

        if not vlan_ifcs:
            vlan_ifcs = "NA"

        attr_name = "vlans"
        disp, desc = self.VLANS_ATTR_MAPPING.get(attr_name)
        ifc_info_dict[attr_name] = entry_creator.create_entry(
            disp, vlan_ifcs, desc, parent_dict_idx)

    def __getInfIpsAndMask(self, ifc_name):
        """Return the address tokens listed by `ip addr show <ifc_name>`.

        Each item is the word that follows an "inet" or "inet6" keyword in
        the command output (typically "address/prefix"). A non-zero exit
        status is reported through
        CommonHelpFuncs.printToLogAndRaiseMftCoreException.
        """
        cmd = "ip addr show %s" % ifc_name
        rc, out, _ = Command(cmd).execute()
        if rc:
            CommonHelpFuncs.printToLogAndRaiseMftCoreException(
                "Command not found - %s" % cmd, logger)

        addresses = []
        address_expected = False
        for token in out.split():
            if token in ("inet", "inet6"):
                address_expected = True
            elif address_expected:
                addresses.append(token)
                address_expected = False
        return addresses

    def __getInfIpsAndMaskString(self, ifc_name):
        ips_str = "NA"
        ips = self.__getInfIpsAndMask(ifc_name)

        ip_strs = []
        first_ip = True
        for ip in ips:
            ip_parts = ip.split("/")
            if first_ip:
                first_ip = False
                ip_strs.append("%s / %s" % (ip_parts[0], ip_parts[1]))
            else:
                ip_strs.append(",  %s / %s" % (ip_parts[0], ip_parts[1]))

        if ips:
            ips_str = '  '.join(ip_strs)
        return ips_str

    def __getDefaultGateway(self):
        gateway = "NA"
        cmd = "route -n"
        rc, out, _ = Command(cmd).execute()
        if rc:
            logger.error("Command not found - %s" % cmd)
            return gateway

        # get title line + destination and gateway cols num
        lines = out.split('\n')
        col = 0
        def_gw_col = 0
        dest_col = 0
        def_gw_found = False
        dest_found = False
        title_line = 0
        for line in lines:
            # go over cols
            for word in line.split():
                if word == "Gateway":
                    def_gw_found = True
                    def_gw_col = col
                if word == "Destination":
                    dest_found = True
                    dest_col = col
                col += 1
            if def_gw_found and dest_found:
                break
            if not def_gw_found or not dest_found:
                col = 0
                # depends on g"Gateway" and "Destination" on same line
                title_line += 1
            else:
                break

        # get next line
        line_index = 1 + title_line
        if def_gw_found and dest_found and len(lines) > line_index:
            # get default gateway line
            for line in lines[line_index:]:
                # split line to get destination
                splited_line = line.split()
                if len(splited_line) > dest_col:
                    dest = splited_line[dest_col]
                    if dest == "0.0.0.0":
                        break
                    line_index += 1
                else:
                    return gateway

            if line_index < len(lines):
                line = lines[line_index]
                gateway = line.split()[def_gw_col]

        return gateway

    def __getVlans(self, ifc_name):
        vlan_list = []
        cmd = "ls /sys/class/net"
        rc, out, _ = Command(cmd).execute()
        if rc:
            CommonHelpFuncs.printToLogAndRaiseMftCoreException(
                                        "Command not found - %s" % cmd, logger)
        interfaces = out.split()
        for ifc in interfaces:
            # if ifc_name in ifc :
            if ifc_name in ifc and ifc_name != ifc:
                vlan_list.append(ifc)
        return vlan_list

    def __getVlansInfo(self, ifc_name):
        ifcs = self.__getVlans(ifc_name)
        vlans_info = {}

        for ifc in ifcs:
            ips_str = self.__getInfIpsAndMaskString(ifc)
            vlan_id = self.__getValdID(ifc)
            vlans_info[ifc] = (ips_str, vlan_id)
        return vlans_info

    def __getValdID(self, name):
        ret = "NA"
        parsed_name = name.split(".")
        if len(parsed_name) > 1:
            ret = parsed_name[1]
        return ret

    def __setBondToInfoDict(self, ifc_info_dict, entry_creator, ifcUid,
                            parent_dict_idx):
        bond_name = self.__getBondName(ifcUid)
        bond_dict = {}
        bond_info_index = 0
        is_bond_exist = False

        if bond_name:
            is_bond_exist = True
        else:
            bond_name = "NA"

        # add bond name
        attr_name = "name"
        disp, desc = self.BOND_ATTR_MAPPING.get(attr_name)
        bond_dict[attr_name] = entry_creator.create_entry(
            disp, bond_name, desc, bond_info_index)
        bond_info_index += 1

        if is_bond_exist:
            # add bond IPs
            ips_str = self.__getInfIpsAndMaskString(bond_name)
            attr_name = "ipAddrs"
            disp, desc = self.GENERAL_ATTR_MAPPING.get(attr_name)
            bond_dict[attr_name] = entry_creator.create_entry(
                disp, ips_str, desc, bond_info_index)
            bond_info_index += 1

            # add bond members
            members_str = self.__getBondMembersString(bond_name)
            attr_name = "members"
            disp, desc = self.BOND_ATTR_MAPPING.get(attr_name)
            bond_dict[attr_name] = entry_creator.create_entry(
                disp, members_str, desc, bond_info_index)
            bond_info_index += 1

            # add Vlans
            self.__setVlansToInfoDict(
                bond_dict, entry_creator, bond_name, bond_info_index)
            bond_info_index += 1

        attr_name = "bondInfo"
        disp, desc = self.BOND_ATTR_MAPPING.get(attr_name)
        ifc_info_dict[attr_name] = entry_creator.create_entry(
            disp, bond_dict, desc, parent_dict_idx)

    def __getBondName(self, ifc_name):
        """Return the bonding master of ifc_name, or "" if it has none.

        Takes the word after the first "master" keyword in
        `ip addr show <ifc> | grep master`. The word is returned only if it
        is listed in /sys/class/net/bonding_masters. A non-zero exit status
        (e.g. no "master" line) yields "".
        """
        bond_name = ""
        rc, out, _ = Command(
            "ip addr show %s | grep master" % ifc_name).execute()
        if rc:
            return bond_name

        expecting_master = False
        for token in out.split():
            if token == "master":
                expecting_master = True
                continue
            if expecting_master:
                if self.__checkMasterIsBond(token):
                    bond_name = token
                break
        return bond_name

    def __checkMasterIsBond(self, master):
        """Return True if master is listed in /sys/class/net/bonding_masters.

        Returns False when that file cannot be read.
        """
        rc, out, _ = Command("cat /sys/class/net/bonding_masters").execute()
        if rc:
            return False
        return master in out.split()

    def __getBondMembersString(self, bond_name):
        members_str = "NA"
        cmd = "cat /proc/net/bonding/%s | grep \"Slave Interface\"" % bond_name
        rc, out, _ = Command(cmd).execute()
        if rc:
            CommonHelpFuncs.printToLogAndRaiseMftCoreException(
                                        "Command not found - %s" % cmd, logger)
        members = []
        if out:
            for line in out.splitlines():
                line_parts = line.split(":")
                if len(line_parts) > 1:
                    members.append(line_parts[1].strip())

        if members:
            members_str = ', '.join(members)

        return members_str

####################################################################################
###################################  WINDOWS  ######################################
####################################################################################

    def commandExecute(self, cmd, failMsg=None, raiseExcept=True):
        """Run cmd and return its stdout, or its stderr when stdout is empty.

        Args:
            cmd: command line handed to Command.
            failMsg: message prefix on failure; defaults to
                "Command <cmd> has failed.".
            raiseExcept: when true, a non-zero exit status is reported through
                CommonHelpFuncs.printToLogAndRaiseMftCoreException with the
                output appended.
        """
        rc, out, err = Command(cmd).execute()
        message = failMsg if failMsg is not None else \
            "Command %s has failed." % cmd
        output = out if out != "" else err
        if rc and raiseExcept:
            CommonHelpFuncs.printToLogAndRaiseMftCoreException(
                "%s. %s" % (message, output), logger)
        return output
    
    def powerShellCommandExec(self, cmd, logger, failMsg=None, raiseExcept=True):
        pshell_cmd = 'powershell -command "%s"' % cmd
        out = self.commandExecute(pshell_cmd, failMsg, raiseExcept)
        logger.info("Runnig command in powershell - %s" % cmd)
        return out
    
    def __win_get_mac(self, ifc_name):
        netAdapters = self.powerShellCommandExec("Get-NetAdapter ", logger)
        macPtrn = re.compile("%s\s+M.*\s(?P<macAddr>(\w+-){5}\w+)" % ifc_name)
        match = macPtrn.search(netAdapters)
        if match == None:
            CommonHelpFuncs.printToLogAndRaiseMftCoreException(
                                    "Failed to get Interface MAC", logger)
        return match.group("macAddr").strip()
        
    
    def __win_get_ifc_name(self, bdf_dict):
        netAdaptersHwInfo = self.powerShellCommandExec("Get-NetAdapterHardwareInfo ", logger)
        ifcPtrn = re.compile ("\s+(?P<ifcName>[\w\s]+\d*)\s*\d+\s*%s\s*%s\s*%s.*" % (bdf_dict["Bus"], bdf_dict["Device"], bdf_dict["Function"]))
        match = ifcPtrn.search(netAdaptersHwInfo)
        if match == None:
            CommonHelpFuncs.printToLogAndRaiseMftCoreException(
                                    "Failed to get Interface Hardware Info", logger)
        return match.group("ifcName").strip()
    
    def __win_get_mlx5_dict(self):
        """Build per-adapter attributes from `mlx5cmd -stat -json`.

        Returns a dict keyed by Windows adapter name (resolved from each
        adapter's physical_location). Each value maps attribute labels
        ("bus-info", "driver", "version", "mtu", "speed", "link detected")
        to values. A non-zero exit status of mlx5cmd is reported through
        CommonHelpFuncs.printToLogAndRaiseMftCoreException.
        """
        rc, out, _ = Command("mlx5cmd -stat -json").execute()
        if rc:
            CommonHelpFuncs.printToLogAndRaiseMftCoreException(
                "Failed to get Interface Info from the driver", logger)

        adapters_by_name = dict()
        for nic_info in json.loads(out).values():
            for adapter_info in nic_info["Adapters"].values():
                location = adapter_info["physical_location"]
                attrs = dict()
                attrs["bus-info"] = "0000:%02x:%02x.%x" % tuple(
                    int(location[part])
                    for part in ("Bus", "Device", "Function"))
                attrs["driver"] = "MLNX_WinOF2"
                attrs["version"] = nic_info["Info"].get("driver_ver", "NA")
                attrs["mtu"] = adapter_info.get("active_mtu", "NA")
                attrs["speed"] = adapter_info.get("link_speed", "NA")
                attrs["link detected"] = adapter_info.get("state", "NA")
                adapters_by_name[self.__win_get_ifc_name(location)] = attrs
        return adapters_by_name
        
    def __win_exec(self, ifc_info_dict, req, entry_creator):
        driver_info_dict = dict()
        adapterName = req["ifcUid"]
        ipInfoDict =  self.__get_cmd_output_dict("netsh interface ip show addresses \"%s\"" % req["ifcUid"])

        ifcsDict = self.__win_get_mlx5_dict()
        
        ifcsDict[adapterName]["permanent address"] = self.__win_get_mac(adapterName)
        ifcsDict[adapterName]["ipAddrs"] = ipInfoDict.get("ip address", "NA")
        ifcsDict[adapterName]["subnetPrefix"] = ipInfoDict.get("subnet prefix", "NA")
        local_idx = 0
        for attr, val in ifcsDict[adapterName].iteritems():
            if attr in self.IFC_ATTR_MAPPING:
                (attr_name, disp, desc) = self.IFC_ATTR_MAPPING[attr]
                ifc_info_dict[attr_name] = entry_creator.create_entry(disp, val, desc, local_idx)
                local_idx += 1
            elif attr in self.DRIVER_ATTR_MAPPING:
                (attr_name, disp, desc, is_driver, entry_index) = self.DRIVER_ATTR_MAPPING[attr]
                if is_driver:
                    _dict = driver_info_dict
                    index = entry_index
                else:
                    _dict = ifc_info_dict
                    index = local_idx
                    local_idx += 1
                _dict[attr_name] = entry_creator.create_entry(
                    disp, val, desc, index)
                local_idx += 1
            
        
        attr_name = "driver"
        disp, desc = self.GENERAL_ATTR_MAPPING.get(attr_name)
        ifc_info_dict[attr_name] = entry_creator.create_entry(
            disp, driver_info_dict, desc, local_idx)
        local_idx += 1
        
        for attrs in ["mtu", "ipAddrs", "subnetPrefix"]:
            disp, desc = self.GENERAL_ATTR_MAPPING.get(attrs)
            ifc_info_dict[attrs] = entry_creator.create_entry(
                disp, ifcsDict[adapterName][attrs], desc, local_idx)
            local_idx += 1
         