import * as tf from '@tensorflow/tfjs';
import '@tensorflow/tfjs-backend-webgpu';

// Model handles, assigned asynchronously by loadModels(). Each one stays
// undefined until its load succeeds.

// Autoencoder whose reconstruction error is reported by getReconstructionError.
let autoencoder;

// Two-model ensemble averaged by getMLPredictionArray.
let model_1;
let model_2;

// Single-substance classifiers, one per exported get*Prediction helper.
let fentanyl_model;
let xylazine_model;
let benzodiazepine_model;
let hydromorphone_ml;

/**
 * Switches TensorFlow.js to `backend` and loads every model used by this
 * module into the module-level handles.
 *
 * If loading fails, it retries at most once on the backend that was active
 * before this call, provided that differs from the backend that just failed.
 * If that also fails, the error is logged and any handles not yet reached stay
 * as they were.
 *
 * @param {string} backend - tfjs backend name to switch to, e.g. 'webgpu'.
 * @param {boolean} [allowFallback=true] - Whether a failed load may retry once
 *   on the previously active backend.
 * @returns {Promise<void>} Resolves once loading (and any fallback) finished;
 *   errors are logged, not rethrown.
 */
async function loadModels(backend, allowFallback = true) {
    // Remember what tfjs was using so a failed load can fall back to it.
    const previousBackend = tf.getBackend();
    try {
        // tf.setBackend resolves to a boolean reporting whether initialisation
        // succeeded, so check it instead of assuming the switch happened.
        const switched = await tf.setBackend(backend);
        if (!switched) {
            console.warn(`Could not initialise ${backend} backend; using ${tf.getBackend()} backend instead.`);
        }
        autoencoder = await tf.loadLayersModel(`/tf/autoencoder/model.json`);
        model_1 = await tf.loadGraphModel(`/tf/ml/model_1/model.json`);
        model_2 = await tf.loadGraphModel(`/tf/ml/model_2/model.json`);
        fentanyl_model = await tf.loadLayersModel(`/tf/fentanyl_ml/model.json`);
        xylazine_model = await tf.loadLayersModel(`/tf/xylazine_ml/model.json`);
        benzodiazepine_model = await tf.loadLayersModel(`/tf/benzodiazepine_ml/model.json`);
        hydromorphone_ml = await tf.loadLayersModel(`/tf/hydromorphone_ml/model.json`);
        console.log(`Models loaded on ${tf.getBackend()} backend!`);
    } catch (error) {
        const failedBackend = tf.getBackend();
        console.log(`Error loading models on ${failedBackend} backend:`, error);
        // Retry at most once, and only on a *different* backend. Retrying on
        // the backend that just failed (e.g. because a model file is missing)
        // would repeat the same failure forever.
        if (allowFallback && previousBackend && previousBackend !== failedBackend) {
            console.log(`Retrying on ${previousBackend} backend.`);
            await loadModels(previousBackend, false);
        }
    }
}

// Start loading models as soon as this module is imported, preferring WebGPU.
// loadModels logs its own load errors; the trailing catch ensures a rejection
// from tf.ready() (or anything else here) is reported instead of surfacing as
// an unhandled promise rejection.
(async () => {
    await tf.ready();
    await loadModels('webgpu');
})().catch((error) => {
    console.log('Error initialising TensorFlow.js:', error);
});

/**
 * Scores each spectrum by how well the autoencoder reconstructs it.
 *
 * Each spectrum is truncated to its first 1576 values and scaled to unit
 * Euclidean norm. It is then passed through `autoencoder` and compared with
 * its reconstruction by mean squared error. An all-zero spectrum has norm 0,
 * so its normalised values, and therefore its error, come out as NaN.
 *
 * @param {Array<ArrayLike<number>>} spectra - Spectra to score. Each is assumed
 *   to be a plain or typed numeric array (needs `.slice`/`.map`); TODO confirm
 *   against callers.
 * @returns {Promise<number[]>} One reconstruction error per spectrum, in input
 *   order; `[]` when `spectra` is empty.
 * @throws {Error} If the autoencoder has not been loaded yet.
 */
export const getReconstructionError = async (spectra) => {
    // tf.stack throws on an empty list, so answer empty input directly.
    if (spectra.length === 0) {
        return [];
    }
    if (autoencoder == null) {
        throw new Error('Autoencoder model is not loaded yet; wait for loadModels() to finish.');
    }
    // tf.tidy releases the input, stacked and reconstructed tensors, which
    // would otherwise leak backend memory on every call. Only the returned
    // error tensor survives, and it is disposed below.
    const errorTensor = tf.tidy(() => {
        const batch = tf.stack(spectra.map((spectrum) => {
            const activeRegion = spectrum.slice(0, 1576);
            const norm = Math.hypot(...activeRegion);
            return tf.tensor(activeRegion.map((e) => e / norm));
        }));
        const reconstructed = autoencoder.predict(batch);
        return tf.metrics.meanSquaredError(batch, reconstructed);
    });
    try {
        return Array.from(await errorTensor.data());
    } finally {
        errorTensor.dispose();
    }
};

// Number of leading values of each spectrum that the classifier models consume.
const ACTIVE_REGION_LENGTH = 1576;

/**
 * Splits a flat prediction buffer into `count` consecutive, equal-sized chunks,
 * one per spectrum in the batch.
 *
 * @param {Float32Array|number[]} values - Flattened model output.
 * @param {number} count - Number of spectra in the batch (> 0).
 * @returns {number[][]} `count` plain arrays, in batch order.
 */
const splitPerSpectrum = (values, count) =>
    Array.from({ length: count }, (_, i) => {
        const start = (i * values.length) / count;
        const end = ((i + 1) * values.length) / count;
        return Array.from(values.slice(start, end));
    });

/**
 * Runs every model in `models` on one batch built from `spectra` and returns
 * each model's output split per spectrum.
 *
 * Each spectrum is truncated to its first ACTIVE_REGION_LENGTH values. With
 * `addChannelAxis` every sample is shaped [length, 1], otherwise [length].
 * Every tensor created here is disposed before returning.
 *
 * @param {Array<object>} models - Loaded tfjs models with a synchronous `predict`.
 * @param {Array<ArrayLike<number>>} spectra - Input spectra (plain or typed arrays).
 * @param {boolean} addChannelAxis - Whether to append a trailing axis of size 1.
 * @returns {Promise<number[][][]>} `result[m][i]` holds model m's outputs for
 *   spectrum i; every `result[m]` is `[]` when `spectra` is empty.
 * @throws {Error} If any model has not been loaded yet.
 */
const predictPerSpectrum = async (models, spectra, addChannelAxis) => {
    // tf.stack throws on an empty list, so answer empty input directly.
    if (spectra.length === 0) {
        return models.map(() => []);
    }
    if (models.some((model) => model == null)) {
        throw new Error('Model is not loaded yet; wait for loadModels() to finish.');
    }
    // tf.tidy releases the per-spectrum and stacked input tensors; only the
    // returned predictions survive, and those are disposed in `finally`.
    const predictions = tf.tidy(() => {
        const batch = tf.stack(spectra.map((spectrum) => {
            const activeRegion = spectrum.slice(0, ACTIVE_REGION_LENGTH);
            const sampleShape = addChannelAxis ? [activeRegion.length, 1] : [activeRegion.length];
            return tf.tensor(activeRegion, sampleShape);
        }));
        return models.map((model) => model.predict(batch));
    });
    try {
        const buffers = await Promise.all(predictions.map((prediction) => prediction.data()));
        return buffers.map((values) => splitPerSpectrum(values, spectra.length));
    } finally {
        tf.dispose(predictions);
    }
};

/**
 * Runs the two-model ensemble (`model_1`, `model_2`) on a batch of spectra.
 *
 * @param {Array<ArrayLike<number>>} spectra - Input spectra; only the first
 *   1576 values of each are used, fed as a [length, 1] sample.
 * @returns {Promise<Array<[number[], number[], number[]]>>} Per spectrum,
 *   `[averaged, values_1, values_2]`, where `averaged[j]` is the mean of the
 *   two models' j-th outputs; `[]` for empty input.
 * @throws {Error} If either model has not been loaded yet.
 */
export const getMLPredictionArray = async (spectra) => {
    const [perSpectrum_1, perSpectrum_2] = await predictPerSpectrum([model_1, model_2], spectra, true);
    return perSpectrum_1.map((values_1, i) => {
        const values_2 = perSpectrum_2[i];
        const averaged = values_1.map((value, j) => (value + values_2[j]) / 2);
        return [averaged, values_1, values_2];
    });
};

/**
 * Runs the fentanyl classifier on a batch of spectra ([length, 1] samples).
 *
 * @param {Array<ArrayLike<number>>} spectra - Input spectra.
 * @returns {Promise<Array<[number[]]>>} Per spectrum, a one-element array
 *   wrapping that spectrum's model outputs; `[]` for empty input.
 * @throws {Error} If the model has not been loaded yet.
 */
export const getFentanylPrediction = async (spectra) => {
    const [perSpectrum] = await predictPerSpectrum([fentanyl_model], spectra, true);
    return perSpectrum.map((values) => [values]);
};

/**
 * Runs the xylazine classifier on a batch of spectra ([length, 1] samples).
 *
 * @param {Array<ArrayLike<number>>} spectra - Input spectra.
 * @returns {Promise<Array<[number[]]>>} Per spectrum, a one-element array
 *   wrapping that spectrum's model outputs; `[]` for empty input.
 * @throws {Error} If the model has not been loaded yet.
 */
export const getXylazinePrediction = async (spectra) => {
    const [perSpectrum] = await predictPerSpectrum([xylazine_model], spectra, true);
    return perSpectrum.map((values) => [values]);
};

/**
 * Runs the benzodiazepine classifier on a batch of spectra ([length, 1] samples).
 *
 * @param {Array<ArrayLike<number>>} spectra - Input spectra.
 * @returns {Promise<Array<[number[]]>>} Per spectrum, a one-element array
 *   wrapping that spectrum's model outputs; `[]` for empty input.
 * @throws {Error} If the model has not been loaded yet.
 */
export const getBenzodiazepinePrediction = async (spectra) => {
    const [perSpectrum] = await predictPerSpectrum([benzodiazepine_model], spectra, true);
    return perSpectrum.map((values) => [values]);
};

/**
 * Runs the hydromorphone classifier on a batch of spectra. Unlike the other
 * classifiers, this model is fed flat [length] samples (no channel axis).
 *
 * @param {Array<ArrayLike<number>>} spectra - Input spectra.
 * @returns {Promise<Array<[number[]]>>} Per spectrum, a one-element array
 *   wrapping that spectrum's model outputs; `[]` for empty input.
 * @throws {Error} If the model has not been loaded yet.
 */
export const getHydromorphonePrediction = async (spectra) => {
    const [perSpectrum] = await predictPerSpectrum([hydromorphone_ml], spectra, false);
    return perSpectrum.map((values) => [values]);
};
