import * as tf from '@tensorflow/tfjs';

/**
 * Actor/critic agent for a single robot that learns with a clipped-surrogate
 * (PPO-style) objective from experiences pooled in a buffer shared between
 * robots.
 *
 * Experiences are stored as plain number arrays (never tensors) so that they
 * stay valid for as long as the buffer keeps them and can be batched directly
 * by `learn`.
 */
export default class MultiRobotPPO {
  /**
   * @param {object} env - Environment handle. This class reads
   *   `env.actionSpace.n`, `env.robotId`, `env.episode`, `env.trainingAim`
   *   and calls `env.calculateStepPerformance(aim)`.
   * @param {object} params - Hyper-parameters. This class reads `inputShape`,
   *   `learningRate`, `netArch.pi`, `batchSize`, `gamma`, `clipRatio` and
   *   `entropyCoef`.
   * @param {object} sharedBuffer - Experience buffer exposing `add(item)`,
   *   `sample(count)` and a `length` property.
   */
  constructor(env, params, sharedBuffer) {
    this.env = env;
    this.params = params;
    this.isInitialized = false;
    this.sharedBuffer = sharedBuffer;
    this.initializationAttempts = 0;
    this.maxInitializationAttempts = 3;
    this.rewards = [];
    this.actions = [];
    this.actor = null;
    this.critic = null;
    this.averageRewards = 0;
    this.performanceScore = 0;
    this.replayMemory = []; // Local FIFO of this robot's own experiences
    this.memorySize = 1000; // Maximum number of entries kept in replayMemory
    this.epsilon = 0.1; // Probability of picking a random action in getAction
  }

  /**
   * Calls `dispose()` on the value when it exposes one (e.g. a tensor);
   * plain arrays and other values are left untouched.
   *
   * @param {*} value - Candidate tensor.
   */
  static disposeIfTensor(value) {
    if (value && typeof value.dispose === 'function') {
      value.dispose();
    }
  }

  /**
   * Appends an experience to the local replay memory, evicting the oldest
   * entries first so the memory never holds more than `memorySize` items.
   * Evicted states that are tensors are disposed.
   *
   * @param {*} state - Observation before the action.
   * @param {number} action - Action index that was taken.
   * @param {number} reward - Reward received.
   * @param {*} nextState - Observation after the action.
   * @param {boolean} done - Whether the episode ended.
   */
  remember(state, action, reward, nextState, done) {
    // Evict BEFORE inserting (>=, not >) so the cap is exact; the length guard
    // keeps the loop finite even if memorySize is set to 0 or less.
    while (this.replayMemory.length > 0 && this.replayMemory.length >= this.memorySize) {
      const removed = this.replayMemory.shift();
      MultiRobotPPO.disposeIfTensor(removed.state);
      MultiRobotPPO.disposeIfTensor(removed.nextState);
    }
    this.replayMemory.push({
      state,
      action,
      reward,
      nextState,
      done
    });
  }

  /**
   * Width of the observation vector: `params.inputShape` (first entry when it
   * is an array), falling back to 33 when unset.
   *
   * @returns {number} Number of input features.
   */
  resolveInputDim() {
    const inputShape = this.params.inputShape || 33;
    return Array.isArray(inputShape) ? inputShape[0] : inputShape;
  }

  /**
   * Selects a backend, builds the actor and critic networks, runs a smoke-test
   * prediction and creates the Adam optimizer. Does nothing when already
   * initialized.
   *
   * @throws Rethrows any error raised while setting up (after logging it).
   */
  async initialize() {
    if (this.isInitialized) return;

    try {
      await tf.ready();
      await this.initializeBackend();
      console.log('TensorFlow.js initialized with backend:', tf.getBackend());

      const inputShape = this.params.inputShape || 33;
      const outputShape = this.env.actionSpace.n;

      this.actor = await this.createNetwork(inputShape, outputShape, 'actor');
      this.critic = await this.createNetwork(inputShape, 1, 'critic');
      console.log('Actor model created:', this.actor);
      console.log('Critic model created:', this.critic);

      await this.testGetAction();

      this.optimizer = tf.train.adam(this.params.learningRate);

      this.isInitialized = true;
      console.log('MultiRobotPPO initialized for robot:', this.env.robotId);
    } catch (error) {
      console.error('Failed to initialize MultiRobotPPO:', error);
      throw error;
    }
  }

  /**
   * Smoke test: feeds one random observation through the actor and logs the
   * result. A failing prediction is logged, not thrown, so initialization can
   * continue. Models are not saved here; call `saveModels` explicitly.
   */
  async testGetAction() {
    // Use the same width the networks were built with rather than a literal.
    const inputDim = this.resolveInputDim();
    const testState = Array.from({
      length: inputDim
    }, () => Math.random() * 2 - 1);
    const testStateTensor = tf.tensor2d([testState], [1, inputDim]);
    let actionProbs = null;

    try {
      console.log('Test state:', testState);
      console.log('Test state tensor:', testStateTensor.toString());
      console.log('Test state tensor data:', testStateTensor.dataSync());

      actionProbs = this.actor.predict(testStateTensor);
      console.log('Test action probabilities:', actionProbs.arraySync());
    } catch (error) {
      console.error('Error during test action prediction:', error);
    } finally {
      // Release both tensors whether or not the prediction succeeded.
      testStateTensor.dispose();
      if (actionProbs) {
        actionProbs.dispose();
      }
    }

    console.warn('Test action prediction completed');
  }

  /**
   * Persists the actor and critic to IndexedDB under fixed keys.
   */
  async saveModels() {
    const saveResultActor = await this.actor.save(`indexeddb://actor-model`);
    console.log(`actor model saved:`, saveResultActor);

    const saveResultCritic = await this.critic.save(`indexeddb://critic-model`);
    console.log(`critic model saved:`, saveResultCritic);
  }

  /**
   * Activates the first usable backend, trying WebGPU, then WebGL, then CPU.
   *
   * @throws {Error} When none of the backends could be activated.
   */
  async initializeBackend() {
    const backendPriority = ['webgpu', 'webgl', 'cpu'];

    for (const backend of backendPriority) {
      try {
        // setBackend can signal failure by resolving to false instead of
        // throwing, so the resolved value has to be inspected as well.
        const activated = await tf.setBackend(backend);
        if (activated === false) {
          console.warn(`Failed to initialize ${backend.toUpperCase()} backend:`, 'setBackend resolved to false');
          continue;
        }
        await tf.ready();
        console.log(`${backend.toUpperCase()} backend initialized successfully`);
        return;
      } catch (error) {
        console.warn(`Failed to initialize ${backend.toUpperCase()} backend:`, error);
      }
    }

    throw new Error('Failed to initialize any supported backend');
  }

  /**
   * Builds a dense feed-forward network whose hidden layer sizes come from
   * `params.netArch.pi` (used for both the actor and the critic).
   *
   * @param {number|number[]} inputShape - Feature count, or an array whose
   *   first entry is the feature count.
   * @param {number} outputShape - Number of output units.
   * @param {string} name - Model name; `'actor'` gets a softmax output layer,
   *   anything else a linear one.
   * @returns {Promise<object>} The (uncompiled) sequential model.
   */
  async createNetwork(inputShape, outputShape, name) {
    const inputDim = Array.isArray(inputShape) ? inputShape[0] : inputShape;
    const hiddenUnits = this.params.netArch.pi;

    const model = tf.sequential({
      name
    });

    model.add(tf.layers.dense({
      units: hiddenUnits[0],
      activation: 'relu',
      kernelInitializer: 'glorotNormal',
      batchInputShape: [null, inputDim] // null leaves the batch size variable
    }));

    for (let i = 1; i < hiddenUnits.length; i++) {
      model.add(tf.layers.dense({
        units: hiddenUnits[i],
        activation: 'relu',
        kernelInitializer: 'glorotNormal'
      }));
    }

    model.add(tf.layers.dense({
      units: outputShape,
      activation: name === 'actor' ? 'softmax' : 'linear',
      kernelInitializer: 'glorotNormal'
    }));

    return model;
  }

  /**
   * Returns `value` when it is a finite number, otherwise logs a warning and
   * returns 0.
   *
   * @param {*} value - Candidate number.
   * @param {string} name - Label used in the warning.
   * @returns {number} The value itself or 0.
   */
  sanitizeValue(value, name) {
    // Number.isFinite does not coerce, so null / '' / '5' are rejected too.
    if (typeof value === 'number' && Number.isFinite(value)) {
      return value;
    }

    let problem = 'Infinite';
    if (typeof value !== 'number') {
      problem = 'not a number';
    } else if (Number.isNaN(value)) {
      problem = 'NaN';
    }
    console.warn(`${name} is ${problem}, setting to 0`);
    return 0;
  }

  /**
   * Copies an observation into a plain array, replacing every entry that is
   * not a finite number with 0.
   *
   * @param {ArrayLike<*>} state - Observation (array or typed array).
   * @returns {number[]} Sanitized copy.
   */
  sanitizeState(state) {
    // Array.from (rather than state.map) guarantees a plain Array even when a
    // typed array is passed in.
    return Array.from(state, (value) => {
      if (typeof value !== 'number' || !Number.isFinite(value)) {
        console.warn(`Invalid state value: ${value}, replacing with 0`);
        return 0;
      }
      return value;
    });
  }

  /**
   * Picks an action for the given observation: with probability `epsilon` a
   * uniformly random action, otherwise the actor's most probable action.
   * Initializes the agent on first use, giving up after
   * `maxInitializationAttempts` failed attempts.
   *
   * @param {*} robotId - Identifier used in log output only.
   * @param {ArrayLike<number>} state - Current observation.
   * @returns {Promise<number>} Index of the chosen action.
   * @throws {Error} When initialization keeps failing, the actor is missing,
   *   or the prediction fails.
   */
  async getAction(robotId, state) {
    console.log('Getting action for ', robotId);
    console.log('isInitialized', this.isInitialized);

    if (!this.isInitialized) {
      if (this.initializationAttempts >= this.maxInitializationAttempts) {
        throw new Error('Max initialization attempts reached');
      }
      // Count the attempt before awaiting: initialize() rethrows on failure,
      // so counting afterwards would never register failed attempts.
      this.initializationAttempts++;
      console.log(`Attempting to initialize (Attempt ${this.initializationAttempts})`);
      await this.initialize();
    }

    if (!this.actor) {
      console.error('Actor model is not initialized');
      throw new Error('Actor model is not initialized');
    }

    try {
      const cleanState = this.sanitizeState(state);
      console.log('state', state);
      console.log(`Getting action for robot ${robotId}`);

      const shouldExplore = Math.random() < this.epsilon;
      let probabilities;
      let greedyAction;

      // tidy releases every tensor created here, including on exceptions.
      tf.tidy(() => {
        const stateTensor = tf.tensor2d([cleanState], [1, cleanState.length]);
        console.log('Input tensor shape:', stateTensor.shape);
        const actionProbs = this.actor.predict(stateTensor);
        probabilities = actionProbs.arraySync();
        if (!shouldExplore) {
          greedyAction = tf.argMax(actionProbs, 1).dataSync()[0];
        }
      });
      console.log(`Action probabilities for robot ${robotId}:`, probabilities);

      let action;
      if (shouldExplore) {
        // Explore: select a random action
        action = Math.floor(Math.random() * this.env.actionSpace.n);
        console.log(`Exploring: selected random action ${action}`);
      } else {
        // Exploit: select the action with the highest probability
        action = greedyAction;
        console.log(`Exploiting: selected action for robot ${robotId}:`, action);
      }

      return action;
    } catch (error) {
      console.error(`Error in getAction for robot ${robotId}:`, error);
      console.error('Error stack:', error.stack);
      console.error('Current TensorFlow backend:', tf.getBackend());
      if (this.actor) {
        console.error('Actor model summary:');
        this.actor.summary();
      } else {
        console.error('Actor model is not initialized');
      }
      throw error;
    }
  }

  /**
   * Records one transition: updates the running statistics, stores the
   * experience in the shared buffer and the local replay memory, and runs a
   * learning step once the shared buffer holds at least `params.batchSize`
   * items.
   *
   * States are stored as sanitized plain arrays. When a tensor-like object
   * (anything exposing `dataSync`) is passed instead of an array, its values
   * are copied and the object is then disposed.
   *
   * @param {ArrayLike<number>|object} state - Observation before the action.
   * @param {number} action - Action index that was taken.
   * @param {number} reward - Reward received; non-finite values count as 0.
   * @param {ArrayLike<number>|object} nextState - Observation after the action.
   * @param {boolean} done - Whether the episode ended.
   */
  async update(state, action, reward, nextState, done) {
    const sanitizedReward = this.sanitizeValue(reward, 'reward');
    // Track the sanitized reward so one NaN cannot poison the statistics.
    this.rewards.push(sanitizedReward);
    this.actions.push(action);

    // Guard the divisor: a zero/undefined episode counter would otherwise turn
    // both statistics into Infinity/NaN permanently.
    const episode = this.env.episode;
    const episodeCount = typeof episode === 'number' && episode > 0 ? episode : 1;
    const rewardSum = this.rewards.reduce((sum, value) => sum + value, 0);
    this.averageRewards = rewardSum / episodeCount;
    this.performanceScore +=
      (sanitizedReward * this.env.calculateStepPerformance(this.env.trainingAim) + 1) / 100 / episodeCount;

    // Copy both observations out before disposing anything, in case the caller
    // passed the same tensor for state and nextState.
    const stateValues = this.toStateArray(state);
    const nextStateValues = this.toStateArray(nextState);
    MultiRobotPPO.disposeIfTensor(state);
    MultiRobotPPO.disposeIfTensor(nextState);

    this.sharedBuffer.add({
      robotId: this.env.robotId,
      state: stateValues,
      action,
      reward: sanitizedReward,
      nextState: nextStateValues,
      done
    });

    if (this.sharedBuffer.length >= this.params.batchSize) {
      await this.learn(this.params.batchSize);
    }

    this.remember(stateValues, action, sanitizedReward, nextStateValues, done);
  }

  /**
   * Converts an observation (array, typed array, or tensor-like object with
   * `dataSync`) into a sanitized plain array.
   *
   * @param {ArrayLike<number>|object} state - Observation.
   * @returns {number[]} Sanitized plain-array copy.
   */
  toStateArray(state) {
    const values = state && typeof state.dataSync === 'function' ? state.dataSync() : state;
    return this.sanitizeState(values);
  }

  /**
   * Runs one optimization step on a batch sampled from the shared buffer.
   *
   * The loss is the clipped-surrogate policy loss, minus an entropy bonus
   * weighted by `params.entropyCoef`, plus the critic's mean squared error
   * against one-step bootstrapped returns (`reward + gamma * V(next)` for
   * non-terminal transitions). Errors are logged and swallowed so that a
   * failed step does not interrupt the caller.
   *
   * @param {number} batchSize - Number of experiences to sample.
   */
  async learn(batchSize) {
    if (!this.isInitialized) {
      await this.initialize();
    }

    const batch = this.sharedBuffer.sample(batchSize);
    if (!batch || batch.length === 0) {
      return;
    }

    try {
      console.log(`Learning from batch of size ${batch.length} for robot ${this.env.robotId}`);
      const { gamma, clipRatio, entropyCoef } = this.params;
      const probFloor = 1e-8; // keeps log() and the ratio division finite

      // Everything, including the optimizer step, runs inside one tidy so no
      // tensor survives this call even when an operation throws.
      tf.tidy(() => {
        const states = tf.tensor2d(batch.map((e) => e.state));
        const nextStates = tf.tensor2d(batch.map((e) => e.nextState));
        const actions = tf.tensor1d(batch.map((e) => e.action), 'int32');
        const rewards = tf.tensor1d(batch.map((e) => e.reward));
        const notDone = tf.tensor1d(batch.map((e) => (e.done ? 0 : 1)));

        // Quantities from the current policy, computed outside the loss
        // function so they act as constants during differentiation.
        const oldActionProbs = this.actor.predict(states);
        const actionMask = tf.oneHot(actions, oldActionProbs.shape[1]);
        const oldSelectedProbs = oldActionProbs.mul(actionMask).sum(1);

        // The critic outputs [batch, 1]; flatten to [batch] so arithmetic with
        // the 1-D reward vector does not broadcast to [batch, batch].
        const oldValues = this.critic.predict(states).reshape([-1]);
        const nextValues = this.critic.predict(nextStates).reshape([-1]);
        const returns = rewards.add(notDone.mul(gamma).mul(nextValues));
        const advantages = returns.sub(oldValues);

        // The forward passes must happen INSIDE this function for gradients to
        // reach the network weights.
        this.optimizer.minimize(() => {
          const newActionProbs = this.actor.predict(states);
          const newSelectedProbs = newActionProbs.mul(actionMask).sum(1);
          const ratios = newSelectedProbs.div(oldSelectedProbs.add(probFloor));
          const clippedRatios = tf.clipByValue(ratios, 1 - clipRatio, 1 + clipRatio);
          const surrogate = tf.minimum(ratios.mul(advantages), clippedRatios.mul(advantages));
          const policyLoss = tf.mean(surrogate).neg();

          // Entropy regularization: higher entropy lowers the loss.
          const entropy = newActionProbs.mul(tf.log(newActionProbs.add(probFloor))).sum(1).neg();
          const entropyLoss = tf.mean(entropy).mul(-entropyCoef);

          const newValues = this.critic.predict(states).reshape([-1]);
          const criticLoss = tf.mean(tf.squaredDifference(returns, newValues));

          return policyLoss.add(entropyLoss).add(criticLoss);
        });
      });

      console.log(`Learning step completed for robot ${this.env.robotId}`);
    } catch (error) {
      console.error(`Error in learn method for robot ${this.env.robotId}:`, error);
      console.error('Error stack:', error.stack);
      console.error('Current TensorFlow backend:', tf.getBackend());
      if (this.actor) {
        console.error('Actor model summary:');
        this.actor.summary();
      } else {
        console.error('Actor model not initialized');
      }
      if (this.critic) {
        console.error('Critic model summary:');
        this.critic.summary();
      } else {
        console.error('Critic model not initialized');
      }
    }
  }

  /**
   * Creates an agent that shares the first instance's environment, parameters,
   * buffer, networks and optimizer, and adds a `predictWithCombinedModels`
   * method that averages the actor and critic outputs of all given instances.
   *
   * `predictWithCombinedModels(state)` resolves to `{ actor, critic }` tensors;
   * the caller is responsible for disposing them.
   *
   * @param {MultiRobotPPO[]} multiRobotPPOs - Agents to combine (at least one).
   * @returns {Promise<MultiRobotPPO>} The combined agent.
   * @throws {Error} When no instances are provided.
   */
  static async combinedPrediction(multiRobotPPOs) {
    if (!multiRobotPPOs || multiRobotPPOs.length === 0) {
      throw new Error("No MultiRobotPPO instances provided");
    }

    const primary = multiRobotPPOs[0];
    // The constructor stores hyper-parameters as `params`; `config` is only
    // consulted as a fallback for objects that carry that field instead.
    const params = primary.params !== undefined ? primary.params : primary.config;

    const combinedMultiRobotPPO = new MultiRobotPPO(primary.env, params, primary.sharedBuffer);

    combinedMultiRobotPPO.actor = primary.actor;
    combinedMultiRobotPPO.critic = primary.critic;
    combinedMultiRobotPPO.optimizer = primary.optimizer;
    // Carry the flag over so a later getAction()/learn() on the combined agent
    // does not re-initialize and replace the borrowed networks.
    combinedMultiRobotPPO.isInitialized = Boolean(primary.isInitialized && primary.actor && primary.critic);

    combinedMultiRobotPPO.predictWithCombinedModels = async function (state) {
      // tidy frees the input and all intermediates; only the two returned
      // averages survive.
      return tf.tidy(() => {
        const input = tf.tensor2d([state], [1, state.length]);

        const averageOf = (selectModel) => {
          const predictions = multiRobotPPOs.map((ppo) => selectModel(ppo).predict(input));
          return predictions
            .reduce((sum, prediction) => sum.add(prediction))
            .div(predictions.length);
        };

        return {
          actor: averageOf((ppo) => ppo.actor),
          critic: averageOf((ppo) => ppo.critic),
        };
      });
    };

    return combinedMultiRobotPPO;
  }
}