import * as tf from '@tensorflow/tfjs';
import PPO from 'ppo-tfjs';

/**
 * PPO agent that steps several robots through one shared environment.
 *
 * The environment is expected to return one observation per robot from
 * `reset()` and one `[newObs, reward, done]` triple per robot from `step()`.
 *
 * NOTE(review): the methods below read `nSteps`, `nEpochs`, `gamma`, `lam`,
 * `verbose`, `policyModel` and `valueModel` from `this`. Nothing in this class
 * assigns them, so they are presumably provided by the `PPO` base class (or by
 * the caller) — confirm against the installed `ppo-tfjs` version.
 */
export default class MultiRobotPPO extends PPO {
  /**
   * @param {object} env - Multi-robot environment, forwarded to the base class.
   * @param {object} params - Configuration, forwarded to the base class.
   * @param {number} [params.numRobots] - Number of robots; stored on the instance.
   */
  constructor(env, params) {
    // Diagnostics only: the base-class constructor is still invoked either way.
    if (env && env.actionSpace && env.actionSpace.class) {
      console.log('env.actionSpace.class', env.actionSpace.class);
    } else {
      console.error('env.actionSpace or env.actionSpace.class is undefined');
    }
    super(env, params);
    this.numRobots = params ? params.numRobots : undefined;
  }

  /**
   * Runs the greedy (argmax) policy for `this.nSteps` environment steps and
   * records one transition per robot per step.
   *
   * @returns {Promise<Array<{obs: *, action: number, reward: number, newObs: *, done: boolean, robot: number}>>}
   *   Transitions in step-major order (all robots of step 0, then step 1, ...).
   *   `robot` is the robot's position in the array returned by `env.step`.
   */
  async collectExperiences() {
    const experiences = [];
    // `await` is a no-op when reset() is synchronous.
    const obs = await this.env.reset();

    for (let step = 0; step < this.nSteps; step++) {
      // tidy() releases the observation and logits tensors once the plain
      // action array has been extracted.
      const chosenActions = tf.tidy(() =>
        this.policyModel.predict(tf.tensor(obs)).argMax(-1).arraySync()
      );
      const results = await this.env.step(chosenActions);

      for (let index = 0; index < results.length; index++) {
        const [newObs, reward, done] = results[index];
        experiences.push({
          obs: obs[index],
          action: chosenActions[index],
          reward,
          newObs,
          done,
          robot: index,
        });

        if (done) {
          // NOTE(review): reset() is called for the whole environment and only
          // this robot's observation is kept; whether the env supports
          // per-robot resets is not visible from this file.
          const freshObs = await this.env.reset();
          obs[index] = freshObs[index];
        } else {
          obs[index] = newObs;
        }
      }
    }

    return experiences;
  }

  /**
   * Generalized Advantage Estimation over ONE time-ordered trajectory.
   *
   * Entries may be numbers, booleans (for `dones`) or single-element arrays;
   * each is coerced with `Number()`.
   *
   * @param {Array} rewards - Reward received at each step.
   * @param {Array} values - Critic estimate for the observation at each step.
   * @param {Array} nextValues - Critic estimate for the following observation.
   * @param {Array} dones - Truthy/1 where the episode ended at that step.
   * @returns {number[]} Advantage per step, same length and order as `rewards`.
   */
  computeAdvantages(rewards, values, nextValues, dones) {
    const advantages = new Array(rewards.length);
    // Plain number: JS arithmetic operators do not work on tf.Tensor objects.
    let runningAdvantage = 0;

    for (let i = rewards.length - 1; i >= 0; i--) {
      // mask = 0 at an episode end, which cuts both the bootstrap value and
      // the advantage carried over from the following step.
      const mask = 1 - Number(dones[i]);
      const delta =
        Number(rewards[i]) + this.gamma * Number(nextValues[i]) * mask - Number(values[i]);
      runningAdvantage = delta + this.gamma * this.lam * mask * runningAdvantage;
      advantages[i] = runningAdvantage;
    }

    return advantages;
  }

  /**
   * For each of `this.nEpochs` epochs: collects a fresh rollout, computes
   * per-robot advantages, fits the policy model on (obs -> action) weighted by
   * advantage, then fits the value model on (advantage + value) targets.
   *
   * @returns {Promise<void>}
   * @throws {Error} If the value model does not yield exactly one value per observation.
   */
  async learn() {
    for (let epoch = 0; epoch < this.nEpochs; epoch++) {
      const experiences = await this.collectExperiences();
      if (experiences.length === 0) {
        continue; // Nothing to fit on.
      }

      const obs = experiences.map((exp) => exp.obs);
      const actions = experiences.map((exp) => exp.action);
      const rewards = experiences.map((exp) => exp.reward);
      const nextObs = experiences.map((exp) => exp.newObs);
      const dones = experiences.map((exp) => (exp.done ? 1 : 0));

      // Read critic outputs as flat numbers (dataSync) so later arithmetic is
      // numeric regardless of whether the output is shaped [N] or [N, 1].
      const [values, nextValues, valueShape] = tf.tidy(() => {
        const current = this.valueModel.predict(tf.tensor(obs));
        const following = this.valueModel.predict(tf.tensor(nextObs));
        return [
          Array.from(current.dataSync()),
          Array.from(following.dataSync()),
          current.shape.slice(),
        ];
      });
      if (values.length !== experiences.length || nextValues.length !== experiences.length) {
        throw new Error(
          `valueModel produced ${values.length} values for ${experiences.length} observations`
        );
      }

      // Experiences interleave robots, so the backward GAE recursion has to run
      // on each robot's own time-ordered subsequence. Entries without a `robot`
      // tag are treated as a single trajectory.
      const indicesByRobot = new Map();
      experiences.forEach((exp, i) => {
        const key = exp.robot === undefined ? 0 : exp.robot;
        if (!indicesByRobot.has(key)) {
          indicesByRobot.set(key, []);
        }
        indicesByRobot.get(key).push(i);
      });

      const advantages = new Array(experiences.length);
      for (const indices of indicesByRobot.values()) {
        const pick = (arr) => indices.map((i) => arr[i]);
        const robotAdvantages = this.computeAdvantages(
          pick(rewards),
          pick(values),
          pick(nextValues),
          pick(dones)
        );
        indices.forEach((i, k) => {
          advantages[i] = robotAdvantages[k];
        });
      }

      const targetValues = advantages.map((adv, idx) => adv + values[idx]);

      const obsTensor = tf.tensor(obs);
      const actionsTensor = tf.tensor(actions, [actions.length, 1], 'int32');
      const advantagesTensor = tf.tensor(advantages);
      // Shape the regression target like the critic's own output.
      const targetValuesTensor = tf.tensor(targetValues, valueShape);

      try {
        // NOTE(review): confirm that the installed tfjs-layers version honours
        // `sampleWeight` in fit().
        await this.policyModel.fit(obsTensor, actionsTensor, {
          sampleWeight: advantagesTensor,
          epochs: 1,
          verbose: this.verbose,
        });

        await this.valueModel.fit(obsTensor, targetValuesTensor, {
          epochs: 1,
          verbose: this.verbose,
        });
      } finally {
        tf.dispose([obsTensor, actionsTensor, advantagesTensor, targetValuesTensor]);
      }
    }
  }
}