import joblib
import numpy as np
from matplotlib import pyplot as plt
from sklearn.preprocessing import PolynomialFeatures, SplineTransformer
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.pipeline import make_pipeline
import json
import sys
import base64
from io import BytesIO
import traceback

# Degree of the B-spline basis used for the 'spline' model.
# NOTE(review): earlier comments in this script described the splines as cubic
# (degree 3), but the value in use has been 2 (quadratic). The value is kept so
# fitted models and the reported 'degree' do not change -- confirm the intent.
_SPLINE_DEGREE = 2

# Bounds applied to the requested number of spline knots.
_MIN_KNOTS = 3
_MAX_KNOTS = 10


def _fit_polynomial(x_shaped, y, x_plot, degree):
    """Fit a polynomial regression and predict along the plotting grid.

    Args:
        x_shaped: Training inputs shaped ``(n_samples, 1)``.
        y: Training targets.
        x_plot: Inputs shaped ``(n_plot, 1)`` at which the curve is evaluated.
        degree: Polynomial degree passed to ``PolynomialFeatures``.

    Returns:
        A ``(model_parameters, y_pred)`` tuple: a JSON-serialisable dict that
        describes the fitted model, and the predictions for ``x_plot``.
    """
    poly_features = PolynomialFeatures(degree=degree, include_bias=False)
    model = LinearRegression()
    model.fit(poly_features.fit_transform(x_shaped), y)

    model_parameters = {
        'model_type': 'polynomial',
        'degree': degree,
        'coefficients': model.coef_.tolist(),
        'intercept': float(model.intercept_),
    }
    y_pred = model.predict(poly_features.transform(x_plot))
    return model_parameters, y_pred


def _fit_spline(x_shaped, y, x_plot, requested_knots, alpha, job_file=None):
    """Fit a spline + ridge pipeline and predict along the plotting grid.

    Args:
        x_shaped: Training inputs shaped ``(n_samples, 1)``.
        y: Training targets.
        x_plot: Inputs shaped ``(n_plot, 1)`` at which the curve is evaluated.
        requested_knots: Desired number of knots; clamped to
            ``[_MIN_KNOTS, min(_MAX_KNOTS, n_samples - 2)]`` with the lower
            bound taking precedence.
        alpha: Regularisation strength passed to ``Ridge``.
        job_file: When not ``None``, the fitted pipeline is dumped to this
            path with joblib.

    Returns:
        A ``(model_parameters, y_pred)`` tuple: a JSON-serialisable dict that
        describes the fitted model, and the predictions for ``x_plot``.
    """
    # Limit the knot count by the amount of data, but never go below the
    # minimum (so very small data sets still use _MIN_KNOTS knots).
    max_knots = min(_MAX_KNOTS, x_shaped.shape[0] - 2)
    n_knots = max(_MIN_KNOTS, min(requested_knots, max_knots))

    print(f"Debug: Using n_knots={n_knots}, spline_degree={_SPLINE_DEGREE}", file=sys.stderr)

    spline = SplineTransformer(n_knots=n_knots, degree=_SPLINE_DEGREE)
    model = make_pipeline(spline, Ridge(alpha=alpha))
    model.fit(x_shaped, y)

    if job_file is not None:
        joblib.dump(model, job_file, compress=3)

    feature_names = spline.get_feature_names_out(['x'])
    ridge = model[-1]

    if hasattr(spline, 'bsplines_') and len(spline.bsplines_) > 0:
        # Single input feature, so only the first BSpline is relevant. Its knot
        # vector carries `degree` extra knots at each end; strip them to
        # report only the base knots.
        padded_knots = spline.bsplines_[0].t
        knots = padded_knots[_SPLINE_DEGREE:-_SPLINE_DEGREE].tolist()
    else:
        # Fallback if bsplines_ is not available.
        knots = []
        print(f"Debug: Cannot access knots, attributes available: {dir(spline)}", file=sys.stderr)

    model_parameters = {
        'model_type': 'spline',
        'n_knots': n_knots,
        'degree': _SPLINE_DEGREE,
        'coefficients': dict(zip(feature_names, ridge.coef_.tolist())),
        'intercept': float(ridge.intercept_),
        'knots': knots,
        'alpha': alpha,
    }
    return model_parameters, model.predict(x_plot)


def _render_plot_base64(x, y, x_plot, y_pred, title):
    """Draw the data and fitted curve; return the PNG as a base64 string.

    The figure and the in-memory buffer are always released, including when
    rendering fails.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.set_title(title)
        ax.scatter(x, y, label='Faktisk data')
        ax.plot(x_plot, y_pred, color='red', label='Genererat pris')
        ax.set_xlabel('Size (m^2)')
        ax.set_ylabel('Pris')
        ax.legend(loc='upper left')
        ax.set_xlim(0, x.max() * 1.1)
        ax.set_ylim(0, y.max() * 1.1)
        ax.grid(True, alpha=0.3)

        with BytesIO() as buffer:
            fig.savefig(buffer, format='png', dpi=300)
            return base64.b64encode(buffer.getvalue()).decode('utf8')
    finally:
        plt.close(fig)


def main(argv=None):
    """Fit a regression model to the points in a JSON payload and print JSON.

    The payload is read from ``argv[1]`` (``sys.argv`` when ``argv`` is
    ``None``). Keys read: ``points`` (list of dicts with ``m2`` and ``price``),
    ``degrees``, ``model_type`` (``'polynomial'`` or ``'spline'``), ``alpha``,
    ``create_job``, ``job_name`` and ``job_path``.

    On success a JSON object with ``model_parameters``, ``poly``,
    ``graph_image`` (base64 PNG) and ``file`` is printed to stdout. On any
    failure a JSON object with ``error`` and ``traceback`` is printed to
    stdout instead; the function itself does not raise and does not change the
    process exit status.
    """
    if argv is None:
        argv = sys.argv

    try:
        payload = json.loads(argv[1])
        data = payload['points']

        # Model parameters
        model_degree = payload.get('degrees', 6)
        model_type = payload.get('model_type', 'polynomial')
        create_job = payload.get('create_job', False)
        job_name = payload.get('job_name', 'test.joblib')
        job_path = payload.get('job_path', False)
        file_name = ''

        print(f"Debug: model_type={model_type}, model_degree={model_degree}, data_points={len(data)}", file=sys.stderr)

        # Fail early with a clear message instead of numpy's
        # "zero-size array to reduction operation" error further down.
        if not data:
            raise ValueError("No data points provided: 'points' must contain at least one entry")

        # Extract m2 and price values; coerce to float so numeric strings in
        # the JSON payload are accepted.
        x = np.array([d['m2'] for d in data], dtype=float)
        y = np.array([d['price'] for d in data], dtype=float)

        # sklearn expects a 2-D (n_samples, n_features) input.
        x_shaped = x.reshape(-1, 1)

        # Evenly spaced inputs across the observed range for drawing the curve.
        x_plot = np.linspace(x.min(), x.max(), 300).reshape(-1, 1)

        if model_type == 'polynomial':
            # NOTE(review): create_job is ignored for polynomial models; only
            # the spline pipeline is persisted.
            model_parameters, y_pred = _fit_polynomial(x_shaped, y, x_plot, model_degree)
            plot_title = 'Polynomial Regression Fit'
        elif model_type == 'spline':
            job_file = None
            if create_job:
                # job_path defaults to False; reject it (and '') up front
                # rather than failing on str concatenation after the fit or
                # writing to the filesystem root.
                if not isinstance(job_path, str) or not job_path:
                    raise ValueError("job_path must be a non-empty string when create_job is set")
                # NOTE(review): job_name comes straight from the payload and is
                # not sanitised; confirm callers cannot inject path components.
                job_file = job_path + '/' + job_name

            model_parameters, y_pred = _fit_spline(
                x_shaped,
                y,
                x_plot,
                requested_knots=model_degree,  # input degree doubles as knot count
                alpha=payload.get('alpha', 0.02),
                job_file=job_file,
            )
            if create_job:
                file_name = job_name
            plot_title = 'Spline Regression Fit'
        else:
            raise ValueError(f"Invalid model_type: {model_type}")

        graph_image = _render_plot_base64(x, y, x_plot, y_pred, plot_title)

        response = {
            'model_parameters': model_parameters,
            'poly': model_degree,  # Keep for backward compatibility
            'graph_image': graph_image,
            'file': file_name
        }
        print(json.dumps(response))

    except Exception as e:
        # Top-level boundary: errors are reported as JSON on stdout, in the
        # same shape as before, rather than propagated.
        error_info = {
            'error': str(e),
            'traceback': traceback.format_exc()
        }
        print(json.dumps(error_info))


if __name__ == "__main__":
    main()
