# --
#                 - Mellanox Confidential and Proprietary -
#
# Copyright (C) Jan 2013, Mellanox Technologies Ltd.  ALL RIGHTS RESERVED.
#
# Except as specifically permitted herein, no portion of the information,
# including but not limited to object code and source code, may be reproduced,
# modified, distributed, republished or otherwise exploited in any form or by
# any means for any purpose without the prior written permission of Mellanox
# Technologies Ltd. Use of software subject to the terms and conditions
# detailed in the file "LICENSE.txt".
# --


# @author: Simon Raviv
# @date: April 02, 2017

import time
import re

import mtcr

from request_factory import RequestFactory
from msg_mgr import MsgMgr, MsgConstants
from performance.entities.connectx_device import ConnectXDevice
from performance.common.counters_names import CountersNames
from performance.common.units_measure import UnitsMeasure
from performance.common.utilization_reference import UtilizationReference \
    as Reference
from performance.performance_exceptions import PerformanceException
from performance.services.analysis.analyzers.layer_3.cx_family_pci_inbound \
    import CXFamilyPCIInboundUsedBW, CXFamilyPCIInboundAvailableBW


# Hz per MHz: scales the DEVICE_CLOCKS counter value (reported in MHz) up to
# a cycle count in get_reference_data().
RATE_FACTOR = 10 ** 6
# HW sampling interval of 2^28 device cycles, expressed in millions of cycles
# so that dividing it by a clock frequency in MHz yields seconds.
INTERVAL_SAMPLING_FIX = 2 ** 28 / 1e6


class ConnectXDeviceMCRA(ConnectXDevice):
    """ Base class for MCRA ConnectX devices.

    Samples device performance counters through direct register access on an
    ``mtcr.MstDevice``. The counter/unit layout lives in ``self.mcra``, which
    is initialized to ``None`` here (NOTE(review): presumably assigned by a
    subclass; confirm).
    """

    # Global performance-counter control registers.
    # Bit 0 of _PERF_ENABLE_ADDR enables all performance counters and is
    # polled by is_all_counters_running().
    _PERF_ENABLE_ADDR = 0xe3080
    # 5-bit field holding log2 of the sampling interval in device clocks.
    _PERF_INTERVAL_ADDR = 0xe3084
    _PERF_INTERVAL_LOG2 = 28

    def __init__(self, mstdev, family, device_id):
        """ Open the device and prepare it for sampling.

        Args:
            mstdev: MST device name, passed to ``mtcr.MstDevice``.
            family: device family, forwarded to ``ConnectXDevice``.
            device_id: device ID to sample; when falsy it is read from HW
                (see set_device_id).

        Raises:
            PerformanceException: if the device cannot be opened/queried or
                its clock frequency cannot be determined.
        """
        super(ConnectXDeviceMCRA, self).__init__(mstdev, family)

        self.mcra = None
        self.mf = None
        self.device_error_message = "Invalid device ID: {device_id}"
        # Milliseconds since the epoch; stamped on every counter returned by
        # get_data().
        # NOTE(review): set only here and never refreshed per sample in this
        # class; confirm whether the base class updates it.
        self.timestamp = int(time.time() * 1000)

        try:
            self.mf = mtcr.MstDevice(mstdev)
            self.name = mstdev
            self.mstdev = mstdev
            self.set_device_id(device_id)
        except Exception:
            # %-formatting (rather than '+') cannot itself raise TypeError
            # inside this handler when mstdev is not a string.
            message = 'Invalid mstdev: %s' % (mstdev,)
            self.set_error(message)
            raise PerformanceException(message)
        self.set_clock()
        # The closest sampling interval to 1 second in HW is 2^28 cycles
        # which are 268,435,456 cycles or 0.894 of a second for a device
        # with 300MHz chip. dev_clock is in MHz, so this ratio is the actual
        # interval length in seconds.
        self.one_sec_correction = INTERVAL_SAMPLING_FIX / self.dev_clock

    def set_clock(self):
        """ Set the device clock frequency attribute ``self.dev_clock``.

        The value comes from the "coreFrequency" field of an mftCore
        "GetDeviceUpTime" request, with the characters ' ', 'M', 'H' and 'z'
        stripped before parsing. If that yields zero or unparsable text,
        register 0x1544 is read instead.

        Raises:
            PerformanceException: if the request fails or no non-zero clock
                frequency can be obtained.
        """
        sub_req_obj = RequestFactory.createRequest(
            1,
            "mftCore",
            "GetDeviceUpTime",
            {"devUid": self.mstdev})
        _, rc, response = MsgMgr().handle(MsgConstants.NEOHOST_REQUEST,
                                          sub_req_obj)
        if rc != 0:
            raise PerformanceException("Failed to get device frequency")

        frequency_str = re.sub('[ MHz]', '', response["coreFrequency"]["val"])
        try:
            self.dev_clock = float(frequency_str)
        except ValueError:
            # Unparsable text: fall through to the register fallback below.
            self.dev_clock = 0

        if not self.dev_clock:
            self.dev_clock = self.mf.read4(0x1544)

        if not self.dev_clock:
            # Without this, __init__ would fail with ZeroDivisionError when
            # computing one_sec_correction.
            raise PerformanceException("Failed to get device frequency")

    def set_device_id(self, device_id):
        """ Set the device ID used for sampling.

        A truthy ``device_id`` is stored as-is. Otherwise an already-set
        ``self.device_id`` is kept; if there is none, the ID is read from
        bits 16..31 of register 0xdf24.
        """
        if device_id:
            self.device_id = device_id
        # getattr: the attribute may not exist yet on first call.
        if not getattr(self, 'device_id', None):
            self.device_id = self.mf.read_field(
                addr=0xdf24, startBit=16, size=16)

    def counters_set_selector_values(self):
        """ Program the counter selectors of every MCRA unit.
        """
        for unit in self.mcra.units:
            unit.set_selectors(self.mf)

    def reset_counters(self):
        """ Reset the counters of every MCRA unit.
        """
        for unit in self.mcra.units:
            unit.reset_counters(self.mf)

    def read_counter_values(self):
        """ Read the current counter values of every MCRA unit.
        """
        for unit in self.mcra.units:
            unit.read_values(self.mf)

    def print_counter_values(self):
        """ Print the counter records returned by get_data().
        """
        # Single-argument print(...) behaves the same in Python 2 and 3.
        print(self.get_data())

    def get_counter_value(self, name):
        """ Return the value of the MCRA counter called ``name``.

        Raises:
            StopIteration: if no counter in ``self.mcra.counters`` has that
                name.
        """
        counter = next(c for c in self.mcra.counters if c.name == name)
        return counter.getValue()

    def _write_units_enable(self, value):
        """ Write ``value`` to the 1-bit enable field of every MCRA unit.
        """
        for unit in self.mcra.units:
            self.mf.write_field(val=value, addr=unit.en_addr,
                                startBit=unit.en_start_bit, size=1)

    def counters_disable(self):
        """ Disable the counters of every MCRA unit.
        """
        self._write_units_enable(0)

    def counters_enable(self):
        """ Enable the counters of every MCRA unit.
        """
        self._write_units_enable(1)

    def get_data(self):
        """ Return the list of counter records for the current sample.

        MCRA counters get their value scaled by ``one_sec_correction``.
        Every counter (MCRA and regular) is stamped with ``self.timestamp``
        and passed through special_treatment().
        """
        counters = list()
        for counter in self.mcra.counters:
            counter.set_value(one_sec_correction=self.one_sec_correction)
            counter.timestamp = self.timestamp
            self.special_treatment(counter)
            counters.append(counter)

        for counter in self.mcra.regular_counters:
            counter.timestamp = self.timestamp
            self.special_treatment(counter)
            counters.append(counter)

        return counters

    def worker(self):
        """ One polling step, run repeatedly by get_next_data() while the HW
        sampling window is open.

        Programs the latency counters and sleeps 0.2 s. It then rebuilds
        ``self.mcra.regular_counters`` (cleared in place) from the latency
        counter readings.
        """
        self.mcra.set_latency_counters(self.mf)
        time.sleep(0.2)
        del self.mcra.regular_counters[:]
        self.mcra.get_latency_counters(self.mf, self.dev_clock)

    def get_next_data(self):
        """ Run one sampling cycle and return its counter records.

        Programs and resets the unit counters, then opens a 2^28-cycle HW
        sampling window. It calls worker() until the global enable bit reads
        back as 0 (presumably cleared by HW when the window ends), then reads
        the values and returns get_data().
        """
        self.disable_all_performance_counters()
        self.counters_set_selector_values()
        self.reset_counters()
        self.counters_enable()
        self.enables_all_performance_counters_for_about_1_sec()
        while self.is_all_counters_running():
            self.worker()
        self.read_counter_values()
        self.sample_index += 1
        return self.get_data()

    def take_snapshot(self):
        """ Take a snapshot of every MCRA unit's counters.
        """
        for unit in self.mcra.units:
            unit.take_snapshot(self.mf)

    def restore_from_snapshot(self):
        """ Restore every MCRA unit from its snapshot.
        """
        for unit in self.mcra.units:
            unit.restore(self.mf)

    def get_reference_data(self, counters):
        """ Calculate reference values for the given counter records.

        Returns a dict keyed by ``UtilizationReference`` members. Keys cover
        cycles, RX/TX/total packets, available PCIe bandwidth (total and per
        RX/TX link 4/8/12) and CQE zipping sessions. Any reference whose
        source counter is absent defaults to 0.
        """
        cycles = 0
        tx_packets = 0
        rx_packets = 0
        available_pcie_bandwidth = 0
        available_rx_link_4_pcie_bandwidth = 0
        available_rx_link_8_pcie_bandwidth = 0
        available_rx_link_12_pcie_bandwidth = 0
        available_tx_link_4_pcie_bandwidth = 0
        available_tx_link_8_pcie_bandwidth = 0
        available_tx_link_12_pcie_bandwidth = 0
        rxc_cqe_zip_open_session = 0

        for counter in counters:
            value = counter.value
            if counter.name == CountersNames.DEVICE_CLOCKS:
                # special_treatment() reports this counter in MHz.
                cycles = value * RATE_FACTOR
            # ConnectX-4/4LX counter for TX packets:
            elif counter.name == CountersNames.SXW_PACKET_SEND_SXW2SXP_GO_VID:
                tx_packets = value
            # ConnectX-5/5EX counter for TX packets:
            elif counter.name == CountersNames.SXP_BW_COUNT_PERF_COUNT_0_2:
                tx_packets = value
            # ConnectX-4/4LX counters for RX packets:
            elif counter.name in [CountersNames.RXT_CTRL_PERF_SLICE_LOAD_SLOW,
                                  CountersNames.RXT_CTRL_PERF_SLICE_LOAD_FAST]:
                rx_packets += value
            elif counter.name == CountersNames.RXB_LRO_FIFO_PERF_COUNT2:
                rx_packets -= value
            # ConnectX-5/5EX counter for RX packets:
            elif counter.name == CountersNames.RXB_BW_COUNT_PERF_COUNT_0_2:
                rx_packets = value
            elif counter.name == "PXDP_RX_128B_TOTAL":
                available_pcie_bandwidth = value
            elif counter.name == CountersNames.RXC_CQE_ZIP_OPEN_SESSION:
                rxc_cqe_zip_open_session = value
            elif counter.name == "PXDP_RX_128B_DATA_LINK4":
                available_rx_link_4_pcie_bandwidth = value
            elif counter.name == "PXDP_RX_128B_DATA_LINK8":
                available_rx_link_8_pcie_bandwidth = value
            elif counter.name == "PXDP_RX_128B_DATA_LINK12":
                available_rx_link_12_pcie_bandwidth = value
            elif counter.name == "PXDP_TX_128B_DATA_LINK4":
                available_tx_link_4_pcie_bandwidth = value
            elif counter.name == "PXDP_TX_128B_DATA_LINK8":
                available_tx_link_8_pcie_bandwidth = value
            elif counter.name == "PXDP_TX_128B_DATA_LINK12":
                available_tx_link_12_pcie_bandwidth = value

        result = {
            Reference.CYCLES: cycles,
            Reference.RX_PACKETS: rx_packets,
            Reference.TX_PACKETS: tx_packets,
            Reference.TOTAL_PACKETS: tx_packets + rx_packets,
            Reference.AVAILABLE_PCIE_BANDWIDTH: available_pcie_bandwidth,
            Reference.AVAILABLE_RX_LINK_4_PCIE_BANDWIDTH:
                available_rx_link_4_pcie_bandwidth,
            Reference.AVAILABLE_RX_LINK_8_PCIE_BANDWIDTH:
                available_rx_link_8_pcie_bandwidth,
            Reference.AVAILABLE_RX_LINK_12_PCIE_BANDWIDTH:
                available_rx_link_12_pcie_bandwidth,
            Reference.AVAILABLE_TX_LINK_4_PCIE_BANDWIDTH:
                available_tx_link_4_pcie_bandwidth,
            Reference.AVAILABLE_TX_LINK_8_PCIE_BANDWIDTH:
                available_tx_link_8_pcie_bandwidth,
            Reference.AVAILABLE_TX_LINK_12_PCIE_BANDWIDTH:
                available_tx_link_12_pcie_bandwidth,
            Reference.CQE_ZIPPING_SESSIONS: rxc_cqe_zip_open_session
            }
        return result

    def special_treatment(self, counter):
        """ Apply per-counter adjustments to ``counter`` in place.

        The clock counters report the device clock frequency in MHz instead
        of a raw reading.
        """
        if counter.name in \
            [CountersNames.DEVICE_CLOCKS_WRAPAROUND,
             CountersNames.DEVICE_CLOCKS]:
            counter.value = self.dev_clock
            counter.units = UnitsMeasure.MHZ

    def disable_all_performance_counters(self):
        """ Disable all performance counters and clear the interval field.
        """
        self.mf.write_field(val=0, addr=self._PERF_ENABLE_ADDR,
                            startBit=0, size=1)
        self.mf.write_field(val=0, addr=self._PERF_INTERVAL_ADDR,
                            startBit=0, size=5)

    def enables_all_performance_counters_for_about_1_sec(self):
        """ Enable all performance counters for a 2^28-clock window.
        """
        # Enables all performance counters:
        self.mf.write_field(val=1, addr=self._PERF_ENABLE_ADDR,
                            startBit=0, size=1)
        # Enable counters for 2^28 clocks:
        # for 400Mhz = 268435456/400000000 = 0.67108864 seconds
        # for 330Mhz = 268435456/330000000 = 0.813440775 seconds
        self.mf.write_field(val=self._PERF_INTERVAL_LOG2,
                            addr=self._PERF_INTERVAL_ADDR,
                            startBit=0, size=5)

    def is_all_counters_running(self):
        """ Return the global performance-counter enable bit (truthy while
        counting).
        """
        return self.mf.read_field(addr=self._PERF_ENABLE_ADDR,
                                  startBit=0, size=1)

    def read_global_clock(self):
        """ Return the raw value of the device global clock register
        (0x96700).
        """
        clock = self.mf.read4(0x96700)
        return clock

    def _initialize_analyzers(self):
        """ Register the PCIe inbound used/available bandwidth analyzers.
        """
        self._analyzers += [
            CXFamilyPCIInboundUsedBW(),
            CXFamilyPCIInboundAvailableBW()
            ]
