import numpy as np

class TransitionMatrix:
    """Transition matrix between risk categories plus one extra output state.

    The matrix has one row per entry of ``risk_categories``. Its columns are
    the same categories followed by ``TERMINAL_STATE``, so its shape is
    ``(k, k + 1)`` with ``k = len(risk_categories)``.
    """

    # Output-only state appended after the risk categories. It has no row of
    # its own in the matrix. NOTE(review): presumably "default"; confirm.
    TERMINAL_STATE = 'D'

    def __init__(self, risk_categories):
        """Store the ordered risk categories; the matrix is loaded later.

        Args:
            risk_categories: ordered sized iterable of category labels. The
                order defines the row order and the leading column order of
                the matrix.
        """
        self.risk_categories = risk_categories
        self.nb_risk_categories = len(risk_categories)
        # Populated by input_matrix_data(); None means "not loaded yet".
        self.matrix = None

    def input_matrix_data(self, matrix_data):
        """Build ``self.matrix`` from long-format transition probabilities.

        Args:
            matrix_data: DataFrame with columns ``credit_status`` (origin),
                ``credit_status_next`` (destination) and
                ``probability_scenario``. Each (origin, destination) pair may
                appear at most once; ``DataFrame.pivot`` raises ValueError
                otherwise.

        Origins outside ``risk_categories`` and destinations outside
        ``risk_categories + [TERMINAL_STATE]`` are dropped. Pairs that are
        absent from ``matrix_data`` become probability 0.
        """
        output_states = list(self.risk_categories) + [self.TERMINAL_STATE]
        matrix = matrix_data.pivot(
            index='credit_status',
            columns='credit_status_next',
            values='probability_scenario',
        )
        # Fix row/column order and shape; missing cells become NaN here...
        matrix = matrix.reindex(index=self.risk_categories, columns=output_states)
        # ...and are then treated as zero probability.
        self.matrix = np.nan_to_num(matrix.to_numpy(dtype=float), nan=0)

    def forecast_next_positions(self, positions):
        """Project exposures one step forward through the transition matrix.

        Args:
            positions: DataFrame with columns ``credit_status`` and ``EAD``.
                The EADs of all rows sharing a ``credit_status`` are summed.
                Categories with no rows count as zero exposure. Rows whose
                status is not a risk category are ignored.

        Returns:
            1-D float array of length ``k + 1``: the projected exposure per
            risk category, followed by ``TERMINAL_STATE``.

        Raises:
            RuntimeError: if ``input_matrix_data`` has not been called yet.
        """
        if getattr(self, 'matrix', None) is None:
            raise RuntimeError(
                "transition matrix not loaded; call input_matrix_data() first"
            )
        # Aggregate every position of a category rather than keeping only the
        # first row. min_count=1 keeps an all-NaN EAD as NaN, so bad input is
        # surfaced instead of being hidden as 0. observed=True stops
        # categorical dtypes from creating NaN groups for unused categories.
        exposure = (
            positions.groupby('credit_status', observed=True)['EAD']
            .sum(min_count=1)
            .reindex(self.risk_categories, fill_value=0.0)
        )
        position_vector = exposure.to_numpy(dtype=float)
        return position_vector.dot(self.matrix)