#!/usr/bin/env python2
import json
import argparse
import re
import hostlist
import math
import smtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from email.mime.base import MIMEBase
from email import encoders
import ntpath
import sys
import os.path
import tarfile


# each entry in file_list is the full path (including extension) of a file to attach, e.g. a csv or tgz result
def send_email(subject, result_msg, addresses, file_list):
    """Mail result_msg to addresses through the SMTP server on localhost.

    :param subject: mail subject line.
    :param result_msg: plain-text body of the mail.
    :param addresses: list of recipient addresses (envelope and To header).
    :param file_list: full paths of files to attach; files that cannot be
        read are skipped with a printed warning.

    On a delivery error a message is printed and the process exits with
    status 0. Recipients refused by the server are printed.
    """
    sender_email = "clusterkit@mellanox.com"
    msg = MIMEMultipart()
    msg['Subject'] = subject
    msg['From'] = sender_email
    # header only; the actual envelope recipients are given to sendmail() below
    msg['To'] = ", ".join(addresses)
    for file_name in file_list:
        try:
            with open(file_name, "rb") as fp:
                payload = fp.read()
        except (IOError, OSError) as err:
            print("Couldn't attach {}: {}".format(file_name, err))
            continue
        # csv results are text; anything else (e.g. the json .tgz archive) is binary
        if file_name.endswith(".csv"):
            attachment = MIMEBase("text", "csv")
        else:
            attachment = MIMEBase("application", "octet-stream")
        attachment.set_payload(payload)
        encoders.encode_base64(attachment)
        attachment.add_header("Content-Disposition", "attachment", filename=ntpath.basename(file_name))
        msg.attach(attachment)

    msg.attach(MIMEText(result_msg))
    print("Sending results...")
    print("recipients: {}".format(addresses))
    try:
        mailserver = smtplib.SMTP('localhost')
        try:
            # sendmail returns a dict of refused recipients (empty on full success)
            refused = mailserver.sendmail(sender_email, addresses, msg.as_string())
        finally:
            mailserver.close()
    except (smtplib.SMTPException, IOError, OSError) as err:
        print("Couldn't deliver the message: {}".format(err))
        sys.exit(0)
    print("Sending finished successfully!")

    if refused:
        print("Couldn't deliver the message to the following emails:\n" + " ".join(refused))


# create a gzipped tar file named tar_name that contains the json files found in the path directory
def create_tar(path, tar_name):
    """Write a gzip-compressed tar at tar_name holding path's json entries.

    Every directory entry whose name contains ".json" is added under its
    bare base name, so the archive has no directory prefix.
    """
    with tarfile.open(tar_name, "w:gz") as archive:
        json_entries = (entry for entry in os.listdir(path) if ".json" in entry)
        for entry in json_entries:
            source = os.path.join(path, entry)
            archive.add(source, arcname=ntpath.basename(source))

# get a comma delimited string of emails, return a list of the valid emails
def create_email_list(emails_str):
    """Split a comma delimited string into a list of valid e-mail addresses.

    Whitespace around each entry is stripped and empty entries are ignored.
    Any other entry that does not look like ``local@domain.tld`` is printed
    as invalid and left out. Returns an empty list for None/empty input.
    """
    if not emails_str:
        return list()
    valid_emails = []
    # build a new list instead of removing from the one being iterated,
    # which would skip the entry following each removed one
    for addr in emails_str.split(','):
        addr = addr.strip()
        if not addr:
            continue
        # anchored at the end so trailing garbage such as a second '@' is rejected
        if re.match(r"[^@\s]+@[^@\s]+\.[^@\s]+$", addr):
            valid_emails.append(addr)
        else:
            print("Invalid email address: {}".format(addr))
    return valid_emails


class BWAnalyzer(object):
    """Analyze a clusterkit pairwise bandwidth run and report suspicious nodes.

    The bandwidth json holds ``Nodes`` (node name -> matrix index), a half
    matrix ``Links`` (see getCell), ``Cluster``, ``HCA_Tag`` and ``Testname``.
    Matching ``noise``/``latency`` result files next to it, when present,
    add extra sections to the report.
    """

    def __init__(self, file_name, email, admin_emails, rate, percentages):
        """Load the results and initialise the analysis state.

        :param file_name: bandwidth json path, with or without ".json".
        :param email: comma delimited recipients (None/empty for no mail).
        :param admin_emails: comma delimited recipients that are added only
            when the minimal bandwidth is far below the average (SEND_ADMIN).
        :param rate: expected send rate; stored in expected_rate, not used yet.
        :param percentages: "normal,extreme" percentages of the maximal
            bandwidth under which a pair is bad / extremely bad (default 85,50).
        :raises ValueError: if percentages does not hold two integers.
        """
        self.email = create_email_list(email)
        self.admin_emails = create_email_list(admin_emails)

        if file_name.endswith(".json"):
            file_name = file_name[:-5]
        with open(file_name + ".json") as f:
            self.bandwidth = json.load(f)
        # optional companion results; None disables the matching report section
        self.noise_json = self._load_companion_json(file_name, "noise")
        self.latency_json = self._load_companion_json(file_name, "latency")

        if percentages:
            self.percentages = [int(p) for p in percentages.split(',')]
            if len(self.percentages) < 2:
                raise ValueError("percentages must be 'normal,extreme', got {!r}".format(percentages))
        else:
            self.percentages = [85, 50]
        # normal - self.percentages[0], extreme - self.percentages[1]
        self.file_name = file_name
        # the nodes_num is +1 because we get half a matrix so the last line is omitted
        self.nodes_num = len(self.bandwidth["Links"]) + 1
        self.expected_rate = rate
        self.bad_bandwidth = 0
        self.avg_bandwidth = 0
        self.extreme_bandwidth = 0
        self.PROPER_BANDWIDTH_PERC = 0.97  # default percentage of the maximum bandwidth that is considered a good result
        self.proper_bandwidth = sys.maxsize
        self.proper_nodes = list()
        # matrix index -> node name, using the same mapping as get_keys_by_vals
        self.nodes_list = self._nodes_by_index(self.bandwidth["Nodes"])
        self.positive_deviation_sum = list()
        self.statistics_dict = dict()
        self.final_message_dict = dict()
        self.RESULT = "Result"
        self.NODE = "Node"
        self.SHOW_PERC = 0.98
        self.SEND_ADMIN = 0.93
        # keys are known clusters name in lowercase
        self.KNOWN_CLUSTERS_BW = {"hercules": {"node_prefix": "clx-hercules", "bw": 22000},
                                  "everest": {"node_prefix": "clx-everest", "bw": 22000},
                                  "ppc": {"node_prefix": "clx-ppc", "bw": 10000},
                                  "jazz": {"node_prefix": "jazz", "bw": 11000}
                                  }
        # max/min deviation from all nodes
        self.max_deviation = -sys.maxsize - 1
        self.min_deviation = sys.maxsize
        # email templates
        self.base_bandwidth_template = ""
        self.general_test_summary = ""
        self.bad_nodes_template = ""
        self.extreme_bad_nodes_template = ""
        self.all_nodes_under_baseline_template = ""

    @staticmethod
    def _load_companion_json(file_name, test_name):
        """Load the test_name result that accompanies a bandwidth result.

        The companion is found by replacing "bandwidth" with test_name in the
        base name of file_name (given without extension). Returns None when the
        base name has no "bandwidth" in it or the file is missing/unparsable.
        """
        directory, base = os.path.split(file_name)
        if "bandwidth" not in base:
            # replacing would yield the bandwidth file itself
            return None
        companion = os.path.join(directory, base.replace("bandwidth", test_name))
        try:
            with open(companion + ".json") as f:
                return json.load(f)
        except (IOError, OSError, ValueError):
            return None

    @staticmethod
    def _nodes_by_index(nodes):
        """Return node names ordered by their matrix index (the dict values)."""
        return sorted(nodes, key=lambda name: nodes[name])

    # get the right value from the half matrix
    def getCell(self, test_json, i, j):
        """Return the value for the pair (i, j) from the half matrix.

        Only row i < j is stored, at Links[i][j - i]; (j, i) maps to the same
        cell and the diagonal is 0.
        """
        if i == j:
            return 0
        if i > j:
            return test_json["Links"][j][i - j]
        return test_json["Links"][i][j - i]

    def get_keys_by_vals(self, dictionary, search_vals):
        """Return "keyB, keyA" for the first two keys whose values are in search_vals.

        Keys are taken in dictionary iteration order; returns None when fewer
        than two values match.
        """
        found = list()
        for key, val in dictionary.items():
            if val in search_vals:
                found.append(key)
                if len(found) == 2:
                    return "{}, {}".format(found[1], found[0])
        return None

    def statistic_dict_to_msg(self, statistic_dict, test_name, unit):
        """Format a min/max/avg statistics dict as report text.

        For bandwidth, a minimum below SHOW_PERC of the average is annotated
        with how far (in whole percent) it lies below the average.
        """
        min_of_avg_perc = ""
        min_val = statistic_dict["min"][self.RESULT]
        max_val = statistic_dict["max"][self.RESULT]
        avg_val = statistic_dict["avg"]
        min_nodes = statistic_dict["min"][self.NODE]
        max_nodes = statistic_dict["max"][self.NODE]
        if test_name == "bandwidth" and avg_val > 0 and min_val < avg_val * self.SHOW_PERC:
            perc_diff = int(round((avg_val - min_val) / float(avg_val) * 100))
            min_of_avg_perc = "\t({}% below the avg)".format(perc_diff)

        return ("Minimum {0}: {1}{6}\t{2}{7}\nMaximum {0}: {3}{6}\t{4}\n"
                "Average {0}: {5}{6}\n").format(test_name, min_val, min_nodes, max_val, max_nodes,
                                                 avg_val, unit, min_of_avg_perc)

    def noiseStatistics(self):
        """Return {"max", "min", "avg"} over the per-node noise results.

        :raises IndexError: when the noise json lists no nodes.
        """
        nodes = self.noise_json["Nodes"]
        first = nodes[0]
        max_dict = {self.NODE: first[self.NODE], self.RESULT: first[self.RESULT]}
        min_dict = dict(max_dict)
        sum_res = 0
        for node in nodes:
            result = node[self.RESULT]
            sum_res += result
            if result > max_dict[self.RESULT]:
                max_dict[self.RESULT] = result
                max_dict[self.NODE] = node[self.NODE]
            if result < min_dict[self.RESULT]:
                min_dict[self.RESULT] = result
                min_dict[self.NODE] = node[self.NODE]

        # float() so integer results are not truncated under Python 2
        avg = round(float(sum_res) / len(nodes), 3)
        return {"max": max_dict, "min": min_dict, "avg": avg}

    # get the result of a pairwise test in json format
    # returns a dict of min, max, avg values paired with the relevant nodes
    def pairwiseStatistics(self, test_json):
        """Return {"max", "min", "avg"} over all node pairs of a pairwise test.

        max/min carry the value and the "nodeB, nodeA" names of the first pair
        reaching it. Bandwidth results are truncated to int.

        :raises ValueError: when the matrix describes fewer than two nodes.
        """
        node_names = test_json["Nodes"]
        nodes_num = len(test_json["Links"]) + 1
        if nodes_num < 2:
            raise ValueError("pairwise statistics need at least two nodes")
        init_val = self.getCell(test_json, 0, 1)
        init_nodes_str = self.get_keys_by_vals(node_names, [0, 1])
        max_dict = {self.NODE: init_nodes_str, self.RESULT: init_val}
        min_dict = {self.NODE: init_nodes_str, self.RESULT: init_val}
        sum_res = 0

        # getCell returns the same cell for (i, j) and (j, i), so visiting each
        # unordered pair once gives the same min/max/avg as visiting both
        for i in range(nodes_num):
            for j in range(i + 1, nodes_num):
                val = self.getCell(test_json, i, j)
                sum_res += val
                if val > max_dict[self.RESULT]:
                    max_dict[self.RESULT] = val
                    max_dict[self.NODE] = self.get_keys_by_vals(node_names, [i, j])
                if val < min_dict[self.RESULT]:
                    min_dict[self.RESULT] = val
                    min_dict[self.NODE] = self.get_keys_by_vals(node_names, [i, j])

        pairs_num = nodes_num * (nodes_num - 1) // 2
        avg = round(float(sum_res) / pairs_num, 3)
        if test_json["Testname"] == "bandwidth":
            # round all the bandwidth results
            max_dict[self.RESULT] = int(max_dict[self.RESULT])
            min_dict[self.RESULT] = int(min_dict[self.RESULT])
            avg = int(avg)
        return {"max": max_dict, "min": min_dict, "avg": avg}

    def _store_optional_statistics(self, key, compute):
        """Store compute() under key; malformed input is reported, not raised."""
        try:
            self.statistics_dict[key] = compute()
        except (KeyError, IndexError, TypeError, ValueError, ZeroDivisionError) as err:
            sys.stderr.write("Skipping {} statistics: {}\n".format(key, err))

    def get_statistics(self):
        """Fill statistics_dict with bandwidth and, when loaded, latency/noise stats."""
        self.statistics_dict["bandwidth"] = self.pairwiseStatistics(self.bandwidth)
        if self.latency_json is not None:
            self._store_optional_statistics("latency", lambda: self.pairwiseStatistics(self.latency_json))
        if self.noise_json is not None:
            self._store_optional_statistics("noise", self.noiseStatistics)

    def init_final_message(self):
        """Reset final_message_dict with the cluster info and empty node lists."""
        self.final_message_dict['$Cluster'] = self.bandwidth['Cluster']
        self.final_message_dict['$HCA'] = self.bandwidth['HCA_Tag']
        self.final_message_dict['$examined_nodes'] = hostlist.collect_hostlist(list(self.bandwidth['Nodes'].keys()))
        self.final_message_dict['$examined_bad_nodes'] = list()
        self.final_message_dict['$examined_extreme_bad_nodes'] = list()
        self.final_message_dict['all_nodes_under_baseline'] = False

    def find_suspicious_nodes_by_bad_bandwidth(self):
        """Classify every node by its bandwidth to all other nodes.

        - below bad_bandwidth with every other node -> '$examined_bad_nodes'
        - above proper_bandwidth with every other node -> proper_nodes
        - positive_deviation_sum[i] gets node i's signed sum of deviations from
          avg_bandwidth (as int) plus its lowest pair bandwidth; the extreme
          sums are stored in max_deviation / min_deviation.
        """
        bad_nodes = self.final_message_dict['$examined_bad_nodes']
        peers = self.nodes_num - 1
        max_dev = -sys.maxsize - 1
        min_dev = sys.maxsize
        # rebuilt on each call so repeated runs do not append stale entries
        self.positive_deviation_sum = []
        for i in range(self.nodes_num):
            bad_count = 0
            proper_count = 0
            lowestbw = sys.maxsize
            dev_sum = 0
            for j in range(self.nodes_num):
                if i == j:
                    continue
                val = round(self.getCell(self.bandwidth, i, j))
                if val < self.bad_bandwidth:
                    bad_count += 1
                elif val > self.proper_bandwidth:
                    proper_count += 1
                lowestbw = min(lowestbw, val)
                dev_sum += val - self.avg_bandwidth
            dev_sum = int(dev_sum)
            self.positive_deviation_sum.append({'dev_sum': dev_sum, 'lowestbw': lowestbw})
            max_dev = max(max_dev, dev_sum)
            min_dev = min(min_dev, dev_sum)
            if bad_count >= peers:
                bad_nodes.append({'node': self.nodes_list[i], 'lowestbw': lowestbw})
            elif proper_count >= peers:
                self.proper_nodes.append(self.nodes_list[i])

        self.max_deviation = max_dev
        self.min_deviation = min_dev

    # TODO: remove criteria / count extreme results only with nodes that are not in bad_nodes_list
    def find_suspicious_nodes_by_extreme_bandwidth(self):
        """Record nodes with at least one pair below extreme_bandwidth.

        Nodes already in '$examined_bad_nodes' are skipped, both as the
        examined node and as its peer. Each entry stores the rounded share of
        extreme pairs out of nodes_num - 1.
        """
        bad_nodes = {x['node'] for x in self.final_message_dict['$examined_bad_nodes']}
        total = self.nodes_num - 1
        for i in range(self.nodes_num):
            node = self.nodes_list[i]
            if node in bad_nodes:
                continue
            lowestbw = sys.maxsize
            extreme_count = 0
            for j in range(self.nodes_num):
                if i == j or self.nodes_list[j] in bad_nodes:
                    continue
                val = round(self.getCell(self.bandwidth, i, j))
                if val < self.extreme_bandwidth:
                    extreme_count += 1
                lowestbw = min(lowestbw, val)

            if extreme_count > 0:
                cur_percentage = round(extreme_count / float(total) * 100)
                self.final_message_dict['$examined_extreme_bad_nodes'].append(
                    {'node': node, 'percentages': cur_percentage,
                     'times': extreme_count, 'total': total, 'lowestbw': lowestbw})

    def find_suspicious_nodes_by_deviation(self):
        """Flag nodes whose deviation sum is below -1.5 * max_deviation.

        Proper nodes and nodes already reported as bad/extreme are skipped.
        """
        threshold = -1.5 * self.max_deviation
        bad_nodes = self.final_message_dict['$examined_bad_nodes']
        already_reported = {x['node'] for x in bad_nodes}
        already_reported.update(x['node'] for x in self.final_message_dict['$examined_extreme_bad_nodes'])
        proper = set(self.proper_nodes)
        for k, entry in enumerate(self.positive_deviation_sum):
            node = self.nodes_list[k]
            if node in proper or node in already_reported:
                continue
            if entry['dev_sum'] < threshold:
                bad_nodes.append({'node': node, 'lowestbw': entry['lowestbw']})

    def resolve_tokens(self, tok, template):
        """Replace every key of tok found in template with str(value)."""
        for t in tok:
            template = template.replace(t, str(tok[t]))
        return template

    def get_known_baseline_by_cluster_name(self, cluster_name):
        """Return the known bandwidth of cluster_name, or sys.maxsize if unknown."""
        return self.KNOWN_CLUSTERS_BW.get(cluster_name, {}).get("bw", sys.maxsize)

    def get_known_baseline_by_node_prefix(self):
        """Return the known bandwidth of the cluster whose node prefix matches, or sys.maxsize."""
        nodes_str = hostlist.collect_hostlist(list(self.bandwidth['Nodes'].keys()))
        for cluster_dict in self.KNOWN_CLUSTERS_BW.values():
            node_prefix = cluster_dict.get("node_prefix")
            if node_prefix is not None and "bw" in cluster_dict and nodes_str.startswith(node_prefix):
                return cluster_dict["bw"]
        return sys.maxsize

    def get_known_baseline(self):
        """Return the known baseline bandwidth; sys.maxsize if the cluster is not known."""
        cluster_name = self.bandwidth['Cluster'].lower()

        if cluster_name in self.KNOWN_CLUSTERS_BW:
            return self.get_known_baseline_by_cluster_name(cluster_name)
        return self.get_known_baseline_by_node_prefix()

    def is_cluster_known_by_node_prefix(self, input_nodes_str):
        """Return True if input_nodes_str starts with a known cluster node prefix."""
        for cluster_dict in self.KNOWN_CLUSTERS_BW.values():
            node_prefix = cluster_dict.get("node_prefix")
            if node_prefix is not None and input_nodes_str.startswith(node_prefix):
                return True
        return False

    def is_cluster_known(self):
        """Return True if the cluster name (or, for "unknown", the node prefix) is known."""
        cluster_name = self.bandwidth['Cluster'].lower()
        if cluster_name in self.KNOWN_CLUSTERS_BW:
            return True
        if cluster_name == "unknown":
            nodes_str = hostlist.collect_hostlist(list(self.bandwidth['Nodes'].keys()))
            return self.is_cluster_known_by_node_prefix(nodes_str)
        return False

    def check_known_nodes_under_baseline(self):
        """Mark the run when even the best pair is below the known cluster baseline."""
        if self.is_cluster_known():
            cluster_baseline = self.get_known_baseline()
            max_bandwidth = self.statistics_dict["bandwidth"]["max"][self.RESULT]
            if cluster_baseline > max_bandwidth:
                self.final_message_dict['all_nodes_under_baseline'] = True

    def create_templates(self):
        """Set the $token text templates used to build the report."""
        self.bad_nodes_template = """Nodes exhibited poor performance with all other nodes:
minimal bandwidth of $min_bandwidth MB/s
$bad_nodes_list"""

        self.extreme_bad_nodes_template = """Nodes exhibited poor performance with $bottom_perc%-$top_perc% of the nodes:
minimal bandwidth of $min_bandwidth MB/s
$extreme_bad_nodes_list

"""
        # initial message + bandwidth
        self.base_bandwidth_template = """Examined nodes: $examined_nodes

===============================
Bandwidth

$bandwidth_statistics
$bad_node_msg
$extreme_bad_node_msg"""

        self.general_test_summary = """
===============================
$test_name

$statistics"""

        self.all_nodes_under_baseline_template = """Bandwidth of all nodes is below the known baseline of cluster!
Known baseline bandwidth of cluster is ~$baseline MB/s

"""

    @staticmethod
    def _bucket_by_percentage(entries, step_size=10):
        """Group extreme-node entries into step_size-wide percentage buckets.

        Returns (step, entries) pairs for the non-empty buckets, highest step
        first; bucket ``step`` covers (step-1)*step_size < percentage <=
        step*step_size. A percentage that rounded to 0 (one bad pair among
        more than 200 peers) goes to the first bucket.
        """
        max_step = int(math.ceil(100.0 / step_size))
        buckets = {}
        for entry in entries:
            step = int(math.ceil(entry['percentages'] / float(step_size)))
            step = min(max(step, 1), max_step)
            buckets.setdefault(step, []).append(entry)
        return [(step, buckets[step]) for step in sorted(buckets, reverse=True)]

    def _bad_nodes_message(self, msg_dict):
        """Return (text, nodes) for the baseline warning and the bad nodes section."""
        bad_node_msg = ""
        if msg_dict['all_nodes_under_baseline']:
            bad_node_msg = self.all_nodes_under_baseline_template.replace("$baseline",
                                                                          str(self.get_known_baseline()))
        bad_entries = msg_dict['$examined_bad_nodes']
        if not bad_entries:
            if not msg_dict['all_nodes_under_baseline']:
                bad_node_msg = "All nodes bandwidth OK"
            return bad_node_msg, []
        bad_nodes = [x['node'] for x in bad_entries]
        tokens = {'$min_bandwidth': int(min(x['lowestbw'] for x in bad_entries)),
                  '$bad_nodes_list': hostlist.collect_hostlist(bad_nodes)}
        return bad_node_msg + self.resolve_tokens(tokens, self.bad_nodes_template), bad_nodes

    def _extreme_bad_nodes_message(self, msg_dict, step_size=10):
        """Return (text, nodes) for the extreme nodes section, grouped by percentage."""
        entries = msg_dict['$examined_extreme_bad_nodes']
        if not entries:
            return "", []
        parts = ["--------------------\n\n"]
        nodes = []
        for step, bucket in self._bucket_by_percentage(entries, step_size):
            bucket_nodes = [x['node'] for x in bucket]
            nodes += bucket_nodes
            tokens = {'$bottom_perc': (step - 1) * step_size + 1, '$top_perc': step * step_size,
                      '$min_bandwidth': int(min(x['lowestbw'] for x in bucket)),
                      '$extreme_bad_nodes_list': hostlist.collect_hostlist(bucket_nodes)}
            parts.append(self.resolve_tokens(tokens, self.extreme_bad_nodes_template))
        return "".join(parts), nodes

    def _collect_attachments(self):
        """Return (files, note): non-empty csv results plus a tgz of the json results.

        note is text to append to the report when a bandwidth csv cannot be
        attached. The tgz is written into the results directory.
        """
        output_dir = os.path.dirname(self.file_name) or "./"
        time_stamp = ntpath.basename(self.file_name).replace("bandwidth_", "")
        files_to_send = list()
        note = ""
        for cur_file in os.listdir(output_dir):
            if ".csv" not in cur_file:
                continue
            full_name = os.path.join(output_dir, cur_file)
            if os.path.isfile(full_name) and os.path.getsize(full_name) > 0:
                files_to_send.append(full_name)
            elif "bandwidth" in cur_file:
                note += "\n***Failed to attach the csv file***"

        tar_file = os.path.join(output_dir, "{}_json_results_{}.tgz".format(self.final_message_dict["$Cluster"],
                                                                            time_stamp))
        try:
            create_tar(output_dir, tar_file)
            files_to_send.append(tar_file)
        except (IOError, OSError, tarfile.TarError) as err:
            sys.stderr.write("Couldn't create {}: {}\n".format(tar_file, err))
        return files_to_send, note

    def prepare_email(self):
        """Build the report text and collect its attachments.

        May add admin_emails to self.email when the minimum bandwidth is below
        SEND_ADMIN of the average, and writes the json tgz archive.

        :return: tuple (subject, files_to_send, result_msg)
        """
        self.create_templates()
        msg_dict = self.final_message_dict
        bandwidth_stat = self.statistics_dict["bandwidth"]
        if bandwidth_stat["min"][self.RESULT] < bandwidth_stat["avg"] * self.SEND_ADMIN:
            self.email = list(set(self.email + self.admin_emails))  # avoid duplication

        bad_node_msg, bad_nodes = self._bad_nodes_message(msg_dict)
        extreme_bad_node_msg, extreme_nodes = self._extreme_bad_nodes_message(msg_dict)
        attention_nodes = bad_nodes + extreme_nodes

        bandwidth_dict = {'$examined_nodes': msg_dict['$examined_nodes'],
                          '$bandwidth_statistics': self.statistic_dict_to_msg(bandwidth_stat, "bandwidth", " MB/sec"),
                          '$bad_node_msg': bad_node_msg, '$extreme_bad_node_msg': extreme_bad_node_msg}
        result_msg = self.resolve_tokens(bandwidth_dict, self.base_bandwidth_template)
        # each optional section is reported independently of the others
        for key, title, label, unit in (("noise", "Noise", "efficiency", ""),
                                        ("latency", "Latency", "latency", " usec")):
            if key in self.statistics_dict:
                statistics = self.statistic_dict_to_msg(self.statistics_dict[key], label, unit)
                result_msg += self.resolve_tokens({'$test_name': title, '$statistics': statistics},
                                                  self.general_test_summary)

        files_to_send, note = self._collect_attachments()
        result_msg += note

        nodes_num = len(hostlist.expand_hostlist(msg_dict['$examined_nodes']))
        subject = 'clusterkit {0}, {2} nodes, {1}'.format(msg_dict['$Cluster'], msg_dict['$HCA'], nodes_num)
        if attention_nodes:
            attention_nodes_compressed = hostlist.collect_hostlist(attention_nodes)
            subject = '{}, {}: {}'.format(subject, 'Attention required', attention_nodes_compressed)
        return subject, files_to_send, result_msg

    def run_analysis(self):
        """Compute statistics and thresholds, classify nodes, print and mail the report."""
        self.get_statistics()

        max_bandwidth = self.statistics_dict["bandwidth"]["max"][self.RESULT]
        self.avg_bandwidth = self.statistics_dict["bandwidth"]["avg"]
        # 100.0: max_bandwidth is an int, so "/ 100" would truncate under Python 2
        self.bad_bandwidth = round(max_bandwidth * int(self.percentages[0]) / 100.0)
        self.extreme_bandwidth = round(max_bandwidth * int(self.percentages[1]) / 100.0)

        self.init_final_message()

        self.proper_bandwidth = self.PROPER_BANDWIDTH_PERC * max_bandwidth

        self.find_suspicious_nodes_by_bad_bandwidth()

        self.find_suspicious_nodes_by_extreme_bandwidth()

        self.find_suspicious_nodes_by_deviation()

        self.check_known_nodes_under_baseline()

        subject, file_list, result_msg = self.prepare_email()
        print(result_msg)
        if self.email:
            send_email(subject, result_msg, self.email, file_list)


if __name__ == "__main__":
    # command line entry point: parse the options and run one analysis
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "-e", "--emails",
        help="send the results to a specified list of emails, comma delimited")
    arg_parser.add_argument(
        "-a", "--admin_emails",
        help="send the results to a specified list of emails only in special cases")
    arg_parser.add_argument(
        "-r", "--rate", type=int,
        help="send rate, will allow another examination of the bandwidth values")
    arg_parser.add_argument(
        "-pr", "--percentages",
        help="specify the percentages out of the optimal bandwidth to determine what is considered low"
             " performance\nand what is considered extremely low, comma delimited\ndefault is 85,50")
    required_group = arg_parser.add_argument_group('required arguments')
    required_group.add_argument('-f', '--file_name', required=True,
                                help='bandwidth json file name')
    options = arg_parser.parse_args()

    BWAnalyzer(options.file_name, options.emails, options.admin_emails,
               options.rate, options.percentages).run_analysis()
