import pandas as pd

from .technical_analysis import TechnicalAnalysis

# Signal direction codes. In Rsi.load they are passed as the side argument to
# ``api.hit_price`` and used as the index into the order-book ``depth``.
BEARISH_CROSS = 0
BULLISH_CROSS = 1


class Rsi(TechnicalAnalysis):
    """
    Implements a trading strategy based on the Relative Strength Index (RSI). The RSI is a momentum indicator
    that measures the magnitude of recent price changes to evaluate overbought or oversold conditions in the
    price of a stock or other asset.

    This strategy emits a sell signal when the RSI is above the overbought threshold and a buy signal when it
    is below the oversold threshold. Signals alternate: after a buy signal only a sell signal can follow, and
    vice versa.

    :param api: The API interface for interacting with the trading system.
    :type api: :class:`traderion.client.TraderionClient`
    :param options: Strategy-specific options including the RSI calculation period and threshold values.
    :type options: dict
    :raises Exception: If required keys (`period`, `overbought_threshold`, `oversold_threshold`) are missing in
        `options`, if `period` is not a positive integer, if the thresholds are not integers strictly between
        0 and 100, or if `oversold_threshold` is not strictly lower than `overbought_threshold`.
    """

    def __init__(self, api, options):
        """
        Initializes the RSI strategy with the specified API interface and options, including the RSI calculation period
        and overbought/oversold threshold values.
        """
        super().__init__(api, options)

        if not {'period', 'overbought_threshold', 'oversold_threshold'} <= set(options):
            raise Exception(
                'Rsi options should contain the keys: period, overbought_threshold, oversold_threshold')

        self.period = options['period']
        self.overbought_threshold = options['overbought_threshold']
        self.oversold_threshold = options['oversold_threshold']

        if not (isinstance(self.period, int) and self.period > 0):
            raise Exception('Period must be a positive integer')

        if not all(isinstance(x, int) and 0 < x < 100 for x in [self.overbought_threshold, self.oversold_threshold]):
            raise Exception(
                'Overbought and oversold thresholds must be integers between 0 and 100 exclusively')

        # With inverted thresholds both signal conditions could hold at once, so the
        # strategy would have no meaningful interpretation.
        if self.oversold_threshold >= self.overbought_threshold:
            raise Exception('Oversold threshold must be lower than overbought threshold')

        self.current_amount = 0
        self.rsi = None
        self.last_cross = None
        self.compute_rsi()

    def on_price_curve_change(self, price_curve):
        """
        Responds to changes in the price curve by recalculating the RSI.

        :param price_curve: The updated price curve data.
        :type price_curve: list
        """
        super().on_price_curve_change(price_curve)
        self.compute_rsi()

    def compute_rsi(self):
        """
        Computes the Relative Strength Index (RSI) based on the latest price curve and stores it in `self.rsi`.

        Average gains and losses are smoothed with a pandas exponentially weighted mean using `span=self.period`.
        If the price curve has fewer than two points, no price change exists yet and `self.rsi` is set to None.
        If there were only gains over the window, the RSI is 100.0. If there were no moves at all, the RSI is NaN
        and no signal can fire.
        """
        changes = pd.Series(self.price_curve).diff().iloc[1:]  # first diff is always NaN
        if changes.empty:
            self.rsi = None
            return

        gains = changes.clip(lower=0)
        losses = changes.clip(upper=0).abs()
        # Calculate the EWMA of gains and losses
        avg_gain = gains.ewm(span=self.period).mean()
        avg_loss = losses.ewm(span=self.period).mean()

        # Series division maps x/0 to inf (RSI -> 100) and 0/0 to NaN without raising.
        rs = (avg_gain / avg_loss).iloc[-1]
        self.rsi = 100.0 - (100.0 / (1.0 + rs))

    def run(self):
        """
        Executes the strategy's logic. Determines if there's a new RSI signal and updates positions accordingly.
        """
        cross_direction = self.get_new_cross_direction()

        if cross_direction is not None:
            self.last_cross = cross_direction
            self.load(cross_direction)

    def get_new_cross_direction(self):
        """
        Determines the direction of any new signal based on the current RSI value and the configured thresholds.

        A bullish signal requires the RSI to be below `oversold_threshold`. A bearish signal requires it to be
        above `overbought_threshold`. The same direction is never signalled twice in a row.

        :return: The signal direction (BEARISH_CROSS for overbought, BULLISH_CROSS for oversold), or None if no
            new signal has occurred or the RSI is not yet available.
        :rtype: int or None
        """
        if self.rsi is None:
            return None
        if self.last_cross in (BEARISH_CROSS, None) and self.rsi < self.oversold_threshold:
            return BULLISH_CROSS
        if self.last_cross in (BULLISH_CROSS, None) and self.rsi > self.overbought_threshold:
            return BEARISH_CROSS
        return None

    def load(self, direction):
        """
        Adjusts the trading position based on the signal direction, aiming to reach the target position amount.

        Orders are sent in tickets no larger than `api.max_ticket`, at the price of the first level of
        `self.depth[direction]`. The loop stops when the target is reached, when the signal direction changes,
        or when an order fills nothing.

        :param direction: The signal direction, either BEARISH_CROSS or BULLISH_CROSS.
        :type direction: int
        """
        sign = 1 if direction == BULLISH_CROSS else -1
        remaining_amount = self.target_amount - sign * self.current_amount
        while remaining_amount > 0 and direction == self.last_cross:
            ticket = min(remaining_amount, self.api.max_ticket)
            amount = self.api.hit_price(direction, ticket, self.depth[direction][0]['price'])['amount']
            if not amount or amount < 0:
                # Nothing was filled (presumably no liquidity at the quote — TODO confirm hit_price semantics);
                # retrying with an unchanged remaining amount would loop forever.
                break
            self.current_amount += sign * amount
            remaining_amount -= amount

    def log_info(self):
        """
        Generates and returns a log message describing the current state of the strategy, including the current
        position amount, the RSI value, and the last signal direction.

        :return: A log message summarizing the strategy's current state.
        :rtype: str
        """
        last_cross = {BEARISH_CROSS: "Bearish", BULLISH_CROSS: "Bullish"}.get(self.last_cross)

        return f'RSI     -- Amount: {self.current_amount}, RSI: {self.rsi}, Last signal: {last_cross}'
