#!/usr/bin/python

# --
#                 - Mellanox Confidential and Proprietary -
#
# Copyright (C) Jan 2013, Mellanox Technologies Ltd.  ALL RIGHTS RESERVED.
#
# Except as specifically permitted herein, no portion of the information,
# including but not limited to object code and source code, may be reproduced,
# modified, distributed, republished or otherwise exploited in any form or by
# any means for any purpose without the prior written permission of Mellanox
# Technologies Ltd. Use of software subject to the terms and conditions
# detailed in the file "LICENSE.txt".
# --

# Python Imports #####
import os
import json
import logging
import jsonschema

# Local Imports #####
import neohost_exceptions
from path_provider import MPathProvider
from plugin_manager import MPluginPathFinder
from common_meta import MEnumMeta
from neohost_common import NeoHostCommon

#####################
# initialize logger #
#####################
# Module-level logger, named as a child of the "neohost" logger hierarchy.
logger = logging.getLogger("neohost." + __name__)


class ValidationMode(object):
    # Enumeration of the validation stages performed by MSchemaValidator.
    # Each value is also a key of EXCEPTION_MAP, which selects the exception
    # class raised when validation in that stage fails.
    __metaclass__ = MEnumMeta
    Request = 0   # basic request layout (BasicRequest schema)
    Params = 1    # method-specific request params (<method>.IN schema)
    Response = 2  # method result, async result or error payload
    # Validation mode -> exception class raised on a failed validation.
    EXCEPTION_MAP = {
        Request: neohost_exceptions.MInvalidRequest,
        Response: neohost_exceptions.MInvalidResponse,
        Params: neohost_exceptions.MInvalidParams,
    }


class SchemaPath(object):
    """Names and on-disk locations of the JSON schema files.

    Core schemas have fixed file names. Method-specific schemas derive their
    file name from the method name. Every schema lives in the ``schemas``
    sub-directory of the directory that owns it (a plugin or the core).
    """
    __metaclass__ = MEnumMeta
    # Sub-directory (relative to a plugin/core directory) holding schemas.
    SCHEMAS_DIR = "schemas"
    # Fixed file names of the core schemas.
    BASIC_ASYNC_RESULT_SCHEMA = "BasicAsyncResult.schema.json"
    BASIC_REQUEST_SCHEMA = "BasicRequest.schema.json"
    BASIC_ERROR_SCHEMA = "BasicError.schema.json"

    @classmethod
    def getSchemaPath(cls, plugin_dir, schema_name):
        """Return the normalized path <plugin_dir>/schemas/<schema_name>."""
        joined_path = os.path.join(plugin_dir, cls.SCHEMAS_DIR, schema_name)
        return os.path.normpath(joined_path)

    @classmethod
    def getMethodInputSchema(cls, method_name):
        """Return the generic input (params) schema file name of a method."""
        file_pattern = "%s.IN.schema.json"
        return file_pattern % method_name

    @classmethod
    def getMethodInputSchemaForWindows(cls, method_name):
        """Return the Windows-specific input schema file name of a method."""
        file_pattern = "%s.IN.Windows.schema.json"
        return file_pattern % method_name

    @classmethod
    def getMethodOutputSchema(cls, method_name):
        """Return the generic output (result) schema file name of a method."""
        file_pattern = "%s.OUT.schema.json"
        return file_pattern % method_name

    @classmethod
    def getSpecialMethodOutputSchemaForWindows(cls, method_name):
        """Return the Windows-specific output schema file name of a method."""
        file_pattern = "%s.OUT.Windows.schema.json"
        return file_pattern % method_name



class MSchemaValidator(object):
    """Schema validation class, validates user requests and responses against
    built-in/plug-in-specific JSON schemas.

    Core schemas (basic request, basic error, basic async result) are read
    from the ``schemas`` directory of the core directory. Method-specific
    input/output schemas are read from the ``schemas`` directory of the
    plugin that owns the method. On Windows, a method-specific
    ``*.Windows.schema.json`` file is preferred over the generic one when it
    exists.
    """

    # execMode bit that selects validation of the result against the generic
    # BasicAsyncResult schema instead of the method's own output schema.
    _ASYNC_EXEC_MODE_FLAG = 2

    # Response types that are decoded with json.loads before validation.
    # type(u"") is ``unicode`` on Python 2 and ``str`` on Python 3, so text
    # responses are parsed on both. A plain ``type(x) is str`` check let
    # Python 2 unicode responses fall through to substring membership tests.
    _JSON_TEXT_TYPES = (str, bytes, type(u""))

    def __init__(self):
        # NOTE(review): not referenced by the methods of this class; confirm
        # no external code reads it before removing.
        self._running_dir = os.path.dirname(os.path.realpath(__file__))
        self._core_dir = os.path.join(
            MPathProvider().get_root_dir(), MPathProvider.CORE_DIR)
        self._path_finder = MPluginPathFinder()

    def validate_basic_request(self, request):
        """Validate the basic layout of a request against the core
        BasicRequest schema.

        :param request: request object handed to jsonschema as-is. It is not
            decoded here, so callers presumably pass an already-parsed
            dictionary.
        :returns: ``request`` unchanged.
        :raises MInvalidRequest: if the request does not match the schema.
        :raises MJsonConversionError: if the schema file is not valid JSON.
        """
        schema_path = SchemaPath.getSchemaPath(
            self._core_dir, SchemaPath.BASIC_REQUEST_SCHEMA)
        return self.__validate_comm(
            request, schema_path, ValidationMode.Request)

    def validate_request_params(self, request_dict):
        """Validate ``request_dict["params"]`` against the input schema of the
        method named by ``request_dict["module"]`` / ``request_dict["method"]``.

        On Windows, ``<method>.IN.Windows.schema.json`` is used when present.
        Otherwise ``<method>.IN.schema.json`` is used.

        :param request_dict: request dictionary as returned by
            validate_basic_request.
        :returns: ``request_dict["params"]`` (not the whole request dict).
        :raises MMethodNotFound: if the plugin or the input schema cannot be
            located.
        :raises MInvalidParams: if the params do not match the schema.
        """
        method_name = request_dict["method"]
        module_name = request_dict["module"]
        # Any failure resolving the plugin is reported as an unknown method.
        try:
            plugin_path = self._path_finder.get_plugin_path(module_name)
        except Exception as e:
            raise neohost_exceptions.MMethodNotFound(
                "Method (%s.%s) not found: %s" %
                (module_name, method_name, str(e)))
        schema_name, cmd_schema_path = self._locate_schema(
            plugin_path,
            SchemaPath.getMethodInputSchema(method_name),
            SchemaPath.getMethodInputSchemaForWindows(method_name))
        logger.debug("params schema path: %s", cmd_schema_path)
        if not os.path.exists(cmd_schema_path):
            raise neohost_exceptions.MMethodNotFound(
                "Method (%s.%s) not found: Failed to locate %s" %
                (module_name, method_name, schema_name))
        return self.__validate_comm(
            request_dict["params"], cmd_schema_path,
            ValidationMode.Params)

    def validate_response(self, request_dict, response):
        """Validate a response (JSON text or already-decoded dictionary).

        A response holding "result" is validated against BasicAsyncResult when
        the async execMode bit is set. Otherwise it is validated against the
        method's output schema (Windows variant preferred on Windows). A
        response holding "error" is validated against BasicError.

        :param request_dict: the request the response belongs to.
        :param response: JSON text, or a decoded dictionary.
        :returns: the response as a dictionary.
        :raises MInvalidResponse: if the response is not valid JSON, is not a
            JSON object holding "result" or "error", or fails validation.
        :raises MPathError: if the response schema file cannot be found.
        """
        if isinstance(response, self._JSON_TEXT_TYPES):
            try:
                response = json.loads(response)
            except Exception as e:
                logger.error("Failed to parse response '%s' as json: %s",
                             response, str(e))
                raise neohost_exceptions.MInvalidResponse(
                    "Internal Error: Got unexpected response, "
                    "see log for details")
        method_name = request_dict["method"]
        # Only a JSON object can carry "result"/"error". Anything else (e.g. a
        # list or a number) is treated as an unexpected format instead of
        # crashing on the membership tests / indexing below.
        is_object = isinstance(response, dict)
        if is_object and "result" in response:
            logger.debug(
                "validating response against plugin specific output schema")
            if request_dict["execMode"] & self._ASYNC_EXEC_MODE_FLAG:
                schema_name, schema_path = self._locate_schema(
                    self._core_dir, SchemaPath.BASIC_ASYNC_RESULT_SCHEMA)
            else:
                # Only a method's own output schema has a Windows variant.
                plugin_dir = self._path_finder.get_plugin_path(
                    request_dict["module"])
                schema_name, schema_path = self._locate_schema(
                    plugin_dir,
                    SchemaPath.getMethodOutputSchema(method_name),
                    SchemaPath.getSpecialMethodOutputSchemaForWindows(
                        method_name))
            validatee = response["result"]
        elif is_object and "error" in response:
            logger.debug("validating response against basic error schema")
            schema_name, schema_path = self._locate_schema(
                self._core_dir, SchemaPath.BASIC_ERROR_SCHEMA)
            validatee = response["error"]
        else:
            logger.error("invalid format for response '%s'", response)
            raise neohost_exceptions.MInvalidResponse(
                "Internal Error: Unexpected response format, "
                "see log for more details")

        if not os.path.exists(schema_path):
            logger.error("Failed to locate response scheme: %s", schema_name)
            raise neohost_exceptions.MPathError(
                "Internal Error: Failed to validate response, "
                "see log for more details")
        self.__validate_comm(
            validatee, schema_path, ValidationMode.Response)
        return response

    def _locate_schema(self, plugin_dir, schema_name,
                       windows_schema_name=None):
        """Choose the schema file to use inside ``plugin_dir``.

        On Windows, ``windows_schema_name`` (if given) is chosen when that file
        exists. Otherwise the generic ``schema_name`` is used.

        :returns: ``(schema_name, schema_path)`` of the chosen schema. The
            generic path is returned even if it does not exist; callers check
            existence and report the error themselves.
        """
        if windows_schema_name is not None and NeoHostCommon.isWindowsOs():
            windows_path = SchemaPath.getSchemaPath(
                plugin_dir, windows_schema_name)
            if os.path.exists(windows_path):
                return windows_schema_name, windows_path
        return schema_name, SchemaPath.getSchemaPath(plugin_dir, schema_name)

    def __validate_comm(self, validatee, schema_path, validation_mode):
        """Validate ``validatee`` against the JSON schema file at
        ``schema_path``.

        :param validatee: Python object to validate.
        :param schema_path: path of a JSON schema file.
        :param validation_mode: a ValidationMode value. It selects the
            exception class raised on validation failure.
        :returns: ``validatee`` unchanged.
        :raises MJsonConversionError: if the schema file is not valid JSON.
        :raises MInvalidRequest/MInvalidParams/MInvalidResponse: (per
            ``validation_mode``) if validation fails or the schema is invalid.
        """
        exception_class = ValidationMode.EXCEPTION_MAP[validation_mode]

        logger.debug("performing validation against: %s", schema_path)
        with open(schema_path, "r") as fd:
            try:
                schema = json.load(fd)
            except Exception as e:
                msg = "Failed to parse schema %s as json: %s" % (
                    schema_path, str(e))
                logger.error(msg)
                raise neohost_exceptions.MJsonConversionError(msg)
        try:
            jsonschema.validate(validatee, schema)
        except (jsonschema.ValidationError, jsonschema.SchemaError) as e:
            msg = "Failed to validate against schema: %s. %s" % (
                os.path.basename(schema_path), e.message)
            logger.error(msg)
            raise exception_class(msg)
        return validatee
