import * as tf from "@tensorflow/tfjs";

/**
 * Tracks how quickly a stream of rewards converges.
 *
 * Rewards are kept in a sliding window holding the most recent `windowSize` values.
 * Once the window is full, it is considered converged when its coefficient of
 * variation (population standard deviation divided by the absolute mean) is
 * strictly below `convergenceThreshold`. The total number of rewards added at
 * the moment this first happens is recorded as the convergence epoch.
 *
 * @example
 * const tracker = new ConvergenceSpeedTracker(3, 0.05);
 * tracker.addReward(0.5);
 * console.log(tracker.getConvergenceSpeed()); // "Not converged yet"
 * tracker.addReward(1.0);
 * tracker.addReward(1.01);
 * tracker.addReward(0.99);
 * console.log(tracker.getConvergenceSpeed()); // 4
 */
export default class ConvergenceSpeedTracker {
    /**
     * Creates a tracker with an empty reward window.
     * @param {number} [windowSize=100] - Number of most recent rewards evaluated; must be a positive integer.
     * @param {number} [convergenceThreshold=0.01] - The window counts as converged when its
     *   coefficient of variation is strictly below this value.
     * @throws {RangeError} If `windowSize` is not a positive integer.
     */
    constructor(windowSize = 100, convergenceThreshold = 0.01) {
        // A zero, negative or fractional size would either never fill the window
        // or evaluate an empty one, so reject it up front.
        if (!Number.isInteger(windowSize) || windowSize < 1) {
            throw new RangeError(`windowSize must be a positive integer, got ${windowSize}`);
        }
        this.windowSize = windowSize;
        this.convergenceThreshold = convergenceThreshold;
        this.rewardHistory = [];
        // Total rewards ever added. The window length caps at windowSize, so it
        // cannot tell *when* convergence happened; this counter can.
        this.rewardCount = 0;
        this.convergenceEpoch = null;
    }

    /**
     * Appends a reward, evicts the oldest one if the window overflows, then re-checks convergence.
     * @param {number} reward - The reward value to record.
     */
    addReward(reward) {
        this.rewardHistory.push(reward);
        this.rewardCount += 1;
        if (this.rewardHistory.length > this.windowSize) {
            this.rewardHistory.shift();
        }
        this.checkConvergence();
    }

    /**
     * Records the current reward count as the convergence epoch the first time a
     * full window's coefficient of variation falls below the threshold.
     * Does nothing while the window is not yet full or once convergence is recorded.
     */
    checkConvergence() {
        // The epoch is only ever set once, so later windows cannot change the result.
        if (this.convergenceEpoch !== null) return;
        if (this.rewardHistory.length < this.windowSize) return;

        // Plain arithmetic instead of tensors: the window is small, and per-call tensor
        // allocation plus dataSync() forces a backend sync and risks leaking tensors.
        const count = this.rewardHistory.length;
        const mean = this.rewardHistory.reduce((sum, r) => sum + r, 0) / count;
        // Population variance (divide by N), matching tf.moments().
        const variance = this.rewardHistory.reduce((sum, r) => sum + (r - mean) ** 2, 0) / count;
        const stdDev = Math.sqrt(variance);

        // No spread means converged regardless of the mean; this also avoids 0/0 for
        // all-zero rewards. Otherwise divide by |mean| so negative-mean rewards cannot
        // produce a negative CV that trivially passes the threshold.
        const coefficientOfVariation = stdDev === 0 ? 0 : stdDev / Math.abs(mean);

        if (coefficientOfVariation < this.convergenceThreshold) {
            this.convergenceEpoch = this.rewardCount;
        }
    }

    /**
     * Returns the convergence epoch, or a marker string if convergence has not been detected.
     * @returns {number|string} The 1-based count of rewards added when convergence was first
     *   detected, or "Not converged yet".
     */
    getConvergenceSpeed() {
        return this.convergenceEpoch ?? "Not converged yet";
    }
}