/**
 * @fileoverview This library provides a 3D drawing utility that plots
 * landmarks and their connections on an NxNxN grid spanning, by default, a
 * 1x1x1 space.
 */

import {NormalizedLandmark, NormalizedLandmarkList} from 'google3/third_party/mediapipe/web/solutions/utils/transform_utils/landmark';
import * as THREE from 'three';

import * as util from './util';

/**
 * A connection between two landmarks, expressed as the indices of its two
 * endpoints in the landmark list. Only the first two entries are read.
 */
type Connection = Array<number>;

/**
 * Every connection to draw in a single color, e.g. a skeleton definition.
 */
type ConnectionList = Array<Connection>;

/**
 * Groups items (landmark indices or connections) under the name of a defined
 * color to draw them in. A `color` of `undefined` selects the material built
 * from `connectionColor`.
 */
type ColorMap<T> = Array<{list: T[]; color: ColorName | undefined}>;

/**
 * Name of a color registered through `LandmarkGridConfig.definedColors`.
 *
 * This is a plain `string`: the previous `Exclude<string, ''>` evaluated to
 * `string` as well, because `Exclude` only filters members of a union. At
 * runtime an empty (falsy) name is not an error; it simply selects the default
 * connection material.
 */
type ColorName = string;

/**
 * One axis tick label.
 *
 * - `value` is the base number (in grid units) from which the label text is
 *   derived.
 * - `position` is the point in the 3D scene the label is pinned to.
 * - `element` is the DOM node that displays the text.
 */
interface NumberLabel {
  value: number;
  position: THREE.Vector3;
  element: HTMLSpanElement;
}

/**
 * Configuration for the landmark grid. Every field is optional; omitted fields
 * fall back to the module's default configuration.
 */
export interface LandmarkGridConfig extends util.ViewerWidgetConfig {
  /**
   * Color of the line outlining the grid cube. The grid walls themselves are
   * drawn in a fixed gray.
   */
  axesColor?: number;
  /**
   * Line width of the cube outline. The three.js docs note that most WebGL
   * platforms draw lines 1 pixel wide regardless of this value.
   */
  axesWidth?: number;
  /**
   * The "centered" attribute describes whether the grid should use the center
   * of the bounding box of the landmarks as the origin. The bounding box is
   * taken over the visible landmarks, or over all landmarks if none are
   * visible.
   */
  centered?: boolean;
  /**
   * Default color for connections. It is also used for landmarks recolored
   * through a landmark color map entry that has no color name.
   */
  connectionColor?: number;
  /** Line width for connections and for every `definedColors` material. */
  connectionWidth?: number;
  /**
   * Named colors that connection and landmark color maps refer to by `name`.
   */
  definedColors?: Array<{name: string; value: number;}>;
  /**
   * The "fitToGrid" attribute describes whether the grid should dynamically
   * resize based on the landmarks given. The displayed range changes in
   * discrete steps rather than continuously.
   */
  fitToGrid?: boolean;
  /** Text placed before every axis label value. */
  labelPrefix?: string;
  /** Text placed after every axis label value. */
  labelSuffix?: string;
  /** Color of landmarks that count as visible. */
  landmarkColor?: number;
  /**
   * Radius of each landmark sphere, in world units (the grid cube is 100
   * units across).
   */
  landmarkSize?: number;
  /**
   * Used by `fitToGrid`: landmarks are fitted into the central
   * `1 - 2 * margin` fraction of the grid's extent (before step rounding).
   */
  margin?: number;
  /**
   * A landmark with a defined `visibility` counts as visible only if that
   * value is non-zero and greater than this threshold. Landmarks without
   * `visibility` always count as visible.
   */
  minVisibility?: number;
  /** Color of landmarks that are not visible (drawn only if `showHidden`). */
  nonvisibleLandmarkColor?: number;
  /** Number of grid cells along each axis. */
  numCellsPerAxis?: number;
  /**
   * The "range" attribute describes the default numerical extent of the
   * grid: the axis labels span [-range / 2, range / 2] on every axis.
   */
  range?: number;
  /**
   * Whether non-visible landmarks, and connections touching them, are still
   * drawn.
   */
  showHidden?: boolean;
}


/**
 * Fallback value for every configuration option. The constructor spreads the
 * caller's options on top of this object.
 */
const DEFAULT_CONFIG: util.BaseRequired<LandmarkGridConfig> = {
  // Cube outline.
  axesColor: 0xffffff,
  axesWidth: 2,
  centered: false,
  // Connection lines.
  connectionColor: 0x00ffff,
  connectionWidth: 3,
  definedColors: [],
  fitToGrid: false,
  // Axis label decoration.
  labelPrefix: '',
  labelSuffix: '',
  // Landmark spheres.
  landmarkSize: 3,
  landmarkColor: 0xaaaaaa,
  margin: 0,
  minVisibility: 0.65,
  nonvisibleLandmarkColor: 0xff7777,
  numCellsPerAxis: 3,
  range: 1,
  // Not declared on LandmarkGridConfig itself; presumably comes from
  // util.ViewerWidgetConfig.
  rotationSpeed: 0.05,
  showHidden: true,
};

/**
 * Shared material with `visible` turned off, assigned to landmarks that are
 * not visible when `showHidden` is false.
 */
const HIDDEN_MATERIAL = (() => {
  const invisible = new THREE.Material();
  invisible.visible = false;
  return invisible;
})();

/**
 * A 3D viewer that draws landmarks as spheres, and connections between them as
 * line segments, inside a cube divided into an NxNxN grid
 * (N = `numCellsPerAxis`).
 *
 * The axis labels span [-range / 2, range / 2]. With the default `range` of 1
 * that is a 1x1x1 space. Call `updateLandmarks` with each new set of
 * landmarks.
 */
export class LandmarkGrid extends util.ViewerWidget {
  /** Edge length of the grid cube in three.js world units. */
  private readonly size: number = 100;
  /** Tick labels along each axis. */
  private readonly labels:
      {x: NumberLabel[]; y: NumberLabel[]; z: NumberLabel[];};
  /** One sphere mesh per landmark, index-aligned with `landmarks`. */
  private readonly landmarkGroup: THREE.Group;
  /** Connection line segments; cleared and rebuilt on every update. */
  private readonly connectionGroup: THREE.Group;
  /**
   * Bounding-box center subtracted from the landmarks when `centered` is set.
   */
  private readonly origin: THREE.Vector3;
  protected override config!: Required<LandmarkGridConfig>;
  private axesMaterial!: THREE.Material;
  private connectionMaterial!: THREE.Material;
  /**
   * Line materials for `config.definedColors`, keyed by color name. A Map
   * avoids accidental matches on Object.prototype keys such as "toString".
   */
  private definedColors!: Map<string, THREE.Material>;
  private gridMaterial!: THREE.Material;
  private isVisible!: (e: NormalizedLandmark) => boolean;
  private landmarkGeometry!: THREE.BufferGeometry;
  private landmarkMaterial!: THREE.Material;
  private nonvisibleMaterial!: THREE.Material;
  /** Working copies of the latest landmarks, after centering and fitting. */
  private landmarks: NormalizedLandmarkList = [];
  /** `1 - 2 * margin`: fraction of the grid that fitted landmarks may fill. */
  private sizeWhenFitted!: number;

  constructor(parent: HTMLElement, config: LandmarkGridConfig = {}) {
    super(parent, {...DEFAULT_CONFIG, ...config});
    // Materials and geometry must exist before the axes are drawn.
    this.setConfig();

    this.drawAxes();
    this.labels = this.createAxesLabels();
    this.landmarkGroup = new THREE.Group();
    this.scene.add(this.landmarkGroup);
    this.connectionGroup = new THREE.Group();
    this.scene.add(this.connectionGroup);
    this.origin = new THREE.Vector3();

    this.requestFrame();
  }

  /**
   * Creates the axis tick labels.
   *
   * Each axis has numCellsPerAxis + 1 tick positions. The x and z axes each
   * skip the one position that coincides with a label on another axis, so
   * each of them gets numCellsPerAxis labels. The y axis gets all
   * numCellsPerAxis + 1.
   */
  private createAxesLabels() {
    const labels = {
      x: [] as NumberLabel[],
      y: [] as NumberLabel[],
      z: [] as NumberLabel[],
    };
    const cellsPerAxis = this.config.numCellsPerAxis;
    const range = this.config.range;

    const HALF_SIZE = this.size / 2;
    for (let i = 0; i < cellsPerAxis; i++) {
      // X labels: tick indices 1..numCellsPerAxis. Index 0, at
      // (-HALF_SIZE, -HALF_SIZE, HALF_SIZE), is already the first y label.
      const xValue = ((i + 1) / cellsPerAxis - .5) * range;
      labels.x.push({
        position: new THREE.Vector3(
            (i + 1) / cellsPerAxis * this.size - HALF_SIZE, -HALF_SIZE,
            HALF_SIZE),
        element: this.createLabel(xValue),
        value: xValue
      });
      // Z labels: tick indices 0..numCellsPerAxis-1. The last index, at
      // (HALF_SIZE, -HALF_SIZE, HALF_SIZE), is already the last x label.
      const zValue = (i / cellsPerAxis - .5) * range;
      labels.z.push({
        position: new THREE.Vector3(
            HALF_SIZE, -HALF_SIZE, i / cellsPerAxis * this.size - HALF_SIZE),
        element: this.createLabel(zValue),
        value: zValue
      });
    }
    // Y labels: every tick index 0..numCellsPerAxis.
    for (let i = 0; i <= cellsPerAxis; i++) {
      const yValue = (i / cellsPerAxis - .5) * range;
      labels.y.push({
        position: new THREE.Vector3(
            -HALF_SIZE, i / cellsPerAxis * this.size - HALF_SIZE, HALF_SIZE),
        element: this.createLabel(yValue),
        value: yValue,
      });
    }

    return labels;
  }

  /** Creates a label span showing `value` and appends it to the container. */
  private createLabel(value: number) {
    const el = document.createElement('span');
    el.classList.add('landmark-label-js');
    this.setLabel(el, value);
    this.container.appendChild(el);
    return el;
  }

  /** Writes `value` (2 significant digits) plus prefix/suffix into `el`. */
  private setLabel(el: HTMLSpanElement, value: number) {
    el.textContent =
        this.config.labelPrefix + value.toPrecision(2) + this.config.labelSuffix;
  }

  /**
   * Adds the three back grid walls (at x, y and z = -HALF_SIZE) and the cube
   * outline to the scene.
   */
  private drawAxes() {
    const axes = new THREE.Group();
    const HALF_SIZE = this.size / 2;

    const grid = this.makeGrid(this.size, this.config.numCellsPerAxis);
    const xGrid = grid;
    const yGrid = grid.clone();
    const zGrid = grid.clone();

    xGrid.translateX(-HALF_SIZE);
    xGrid.rotateY(Math.PI / 2);
    axes.add(xGrid);

    yGrid.translateY(-HALF_SIZE);
    yGrid.rotateX(Math.PI / 2);
    axes.add(yGrid);

    zGrid.translateZ(-HALF_SIZE);
    axes.add(zGrid);

    // Closed polyline along six cube edges, drawn with the axes material.
    const border = new THREE.BufferGeometry().setFromPoints([
      new THREE.Vector3(-HALF_SIZE, HALF_SIZE, HALF_SIZE),
      new THREE.Vector3(-HALF_SIZE, -HALF_SIZE, HALF_SIZE),
      new THREE.Vector3(HALF_SIZE, -HALF_SIZE, HALF_SIZE),
      new THREE.Vector3(HALF_SIZE, -HALF_SIZE, -HALF_SIZE),
      new THREE.Vector3(HALF_SIZE, HALF_SIZE, -HALF_SIZE),
      new THREE.Vector3(-HALF_SIZE, HALF_SIZE, -HALF_SIZE),
      new THREE.Vector3(-HALF_SIZE, HALF_SIZE, HALF_SIZE)
    ]);
    axes.add(new THREE.Line(border, this.axesMaterial));

    this.scene.add(axes);
  }

  protected override render() {
    super.render();
    // Labels are DOM elements and must follow the camera every frame.
    this.setLabels();
  }

  /**
   * Returns the material registered for `colorName`. It falls back to the
   * default connection material when no name is given or the name is not
   * among `config.definedColors`.
   */
  private resolveMaterial(colorName?: string): THREE.Material {
    if (!colorName) return this.connectionMaterial;
    return this.definedColors.get(colorName) ?? this.connectionMaterial;
  }

  /**
   * Applies the material for `colorName` to the given landmark meshes, or to
   * all of them when `landmarks` is omitted. Non-visible landmarks, and
   * indices with no corresponding landmark, are left untouched.
   */
  private colorLandmarks(landmarks?: number[], colorName?: string) {
    const material = this.resolveMaterial(colorName);
    const meshList = this.landmarkGroup.children as THREE.Mesh[];
    const indices = landmarks ?? this.landmarks.map((_, i) => i);

    for (const index of indices) {
      const landmark = this.landmarks[index];
      if (landmark === undefined || !this.isVisible(landmark)) continue;
      meshList[index].material = material;
    }
  }

  /**
   * Replaces the displayed landmarks and connections and schedules a redraw.
   *
   * @param landmarks The landmarks to draw. Each is passed through
   *     util.copyLandmark before centering/fitting adjust the stored values.
   * @param colorConnections Either a plain list of connections, drawn with
   *     `connectionColor`, or a ColorMap that assigns defined color names to
   *     lists of connections. Connections referencing indices outside
   *     `landmarks` are skipped.
   * @param colorLandmarks Optional ColorMap of landmark indices to recolor.
   *     Only visible landmarks are recolored; an entry without a color name
   *     uses the connection material.
   */
  updateLandmarks(
      landmarks: NormalizedLandmarkList,
      colorConnections: ConnectionList|ColorMap<Connection> = [],
      colorLandmarks?: ColorMap<number>) {
    // Drop the previous update's connection lines and queued resources.
    this.connectionGroup.clear();
    this.clearResources();

    this.landmarks = landmarks.map(util.copyLandmark);
    const connections = LandmarkGrid.toColorMap(colorConnections);

    // Centering and fitting are driven by the visible landmarks, falling back
    // to all landmarks when none are visible. `filter` keeps the same object
    // references, so adjustments below are seen through both lists.
    const visibleLandmarks = this.landmarks.filter((e) => this.isVisible(e));
    const referenceLandmarks =
        visibleLandmarks.length === 0 ? this.landmarks : visibleLandmarks;
    if (this.config.centered) {
      this.centerLandmarks(referenceLandmarks);
    }
    const scalingFactor =
        this.config.fitToGrid ? this.fitLandmarksToGrid(referenceLandmarks) : 1;

    this.updateLabelValues(scalingFactor);

    const landmarkVectors: THREE.Vector3[] =
        this.landmarks.map(e => this.landmarkToVector(e));

    for (const connection of connections) {
      this.drawConnections(landmarkVectors, connection.list, connection.color);
    }

    this.resizeLandmarkMeshes(this.landmarks.length);

    for (let i = 0; i < this.landmarks.length; i++) {
      const sphere = this.landmarkGroup.children[i] as THREE.Mesh;
      if (this.isVisible(this.landmarks[i])) {
        sphere.material = this.landmarkMaterial;
      } else {
        sphere.material =
            this.config.showHidden ? this.nonvisibleMaterial : HIDDEN_MATERIAL;
      }
      sphere.position.copy(landmarkVectors[i]);
    }

    if (colorLandmarks) {
      for (const colorDef of colorLandmarks) {
        this.colorLandmarks(colorDef.list, colorDef.color);
      }
    }

    this.requestFrame();
  }

  /**
   * Normalizes the `colorConnections` argument into a ColorMap. A bare
   * ConnectionList, recognized by its first entry being an array, becomes a
   * single entry with no color name.
   */
  private static toColorMap(
      colorConnections: ConnectionList|ColorMap<Connection>):
      ColorMap<Connection> {
    if (colorConnections.length > 0 && Array.isArray(colorConnections[0])) {
      return [{color: undefined, list: colorConnections as ConnectionList}];
    }
    return colorConnections as ColorMap<Connection>;
  }

  /**
   * Rescales `this.landmarks` in place so that `referenceLandmarks` fit the
   * grid, and returns the factor they were multiplied by.
   *
   * The factor is rounded so that the displayed half-range is `range / 2`
   * plus a whole number of RESCALE steps. It returns 1 (no rescale) when no
   * usable fit factor exists.
   */
  private fitLandmarksToGrid(referenceLandmarks: NormalizedLandmarkList):
      number {
    const rawScalingFactor = this.getFitToGridFactor(referenceLandmarks);
    // A non-finite or non-positive raw factor (e.g. margin >= 0.5) would
    // produce an infinite, zero or mirrored scale below.
    if (!Number.isFinite(rawScalingFactor) || rawScalingFactor <= 0) {
      return 1;
    }
    const RESCALE = .5;
    const halfRange = this.config.range / 2;
    // Deviation from the default half-range, measured in RESCALE steps and
    // rounded up so the landmarks always fit.
    const numRescaleSteps =
        Math.ceil((1 / rawScalingFactor - 1) * halfRange / RESCALE);
    // Convert the step count back into a multiplier for landmark coordinates.
    const scalingFactor = 1 / (numRescaleSteps * RESCALE / halfRange + 1);
    for (const landmark of this.landmarks) {
      landmark.x *= scalingFactor;
      landmark.y *= scalingFactor;
      landmark.z *= scalingFactor;
    }
    return scalingFactor;
  }

  /** Rewrites every tick label's text for the current origin and scale. */
  private updateLabelValues(scalingFactor: number) {
    // NOTE(review): centering subtracts `origin` before scaling, so the
    // pre-centering value at a tick would seem to be
    // `value / scalingFactor + origin`. This keeps the original formula;
    // confirm the intent against util.landmarkToVector's axis conventions.
    for (const axis of ['x', 'y', 'z'] as const) {
      for (const label of this.labels[axis]) {
        this.setLabel(
            label.element, (label.value - this.origin[axis]) / scalingFactor);
      }
    }
  }

  /** Adds or removes sphere meshes so there is exactly one per landmark. */
  private resizeLandmarkMeshes(count: number) {
    const group = this.landmarkGroup;
    for (let i = group.children.length; i < count; i++) {
      // The material is overwritten per update; passing the shared one avoids
      // allocating a throwaway default material for every new mesh.
      group.add(new THREE.Mesh(this.landmarkGeometry, this.landmarkMaterial));
    }
    // Walk backwards: remove() splices `children`, so a forward walk would
    // skip every other surplus mesh and leave stale spheres on screen.
    for (let i = group.children.length - 1; i >= count; i--) {
      group.remove(group.children[i]);
    }
  }

  /**
   * Adds one LineSegments object for `connections`, drawn with the material
   * for `colorName`. The geometry and object are queued for cleanup on the
   * next update.
   */
  private drawConnections(
      landmarks: THREE.Vector3[], connections: ConnectionList,
      colorName?: ColorName) {
    const material = this.resolveMaterial(colorName);

    const points: THREE.Vector3[] = [];
    for (const [start, end] of connections) {
      const from = this.landmarks[start];
      const to = this.landmarks[end];
      // Skip connections to landmarks absent from this update (e.g. an empty
      // landmark list drawn with a fixed connection list).
      if (from === undefined || to === undefined) continue;
      if (!this.config.showHidden &&
          (!this.isVisible(from) || !this.isVisible(to))) {
        continue;
      }
      points.push(landmarks[start], landmarks[end]);
    }
    const geometry = new THREE.BufferGeometry().setFromPoints(points);
    this.disposeQueue.push(geometry);
    const wireframe = new THREE.LineSegments(geometry, material);
    this.removeQueue.push(wireframe);
    this.connectionGroup.add(wireframe);
  }

  /** Converts a landmark to scene coordinates (scaled by size / range). */
  private landmarkToVector(point: NormalizedLandmark): THREE.Vector3 {
    return util.landmarkToVector(point).multiplyScalar(
        this.size / this.config.range);
  }

  /**
   * Builds a square wireframe of side `size`, centered on the origin and lying
   * in the XY plane, subdivided into `numSteps` x `numSteps` cells.
   */
  private makeGrid(size: number, numSteps: number) {
    const grid = new THREE.Group();

    const plane = new THREE.PlaneGeometry(size, size);
    const edges = new THREE.EdgesGeometry(plane);
    const wireframe = new THREE.LineSegments(edges, this.gridMaterial);
    grid.add(wireframe);

    const stepPlaneSize = size / numSteps;
    const stepPlane = new THREE.PlaneGeometry(stepPlaneSize, stepPlaneSize);
    const stepEdges = new THREE.EdgesGeometry(stepPlane);
    // Center of the bottom-left cell.
    const corner = -size / 2 + stepPlaneSize / 2;
    for (let i = 0; i < numSteps; i++) {
      for (let j = 0; j < numSteps; j++) {
        const stepFrame = new THREE.LineSegments(stepEdges, this.gridMaterial);
        stepFrame.translateX(corner + i * stepPlaneSize);
        stepFrame.translateY(corner + j * stepPlaneSize);
        grid.add(stepFrame);
      }
    }

    return grid;
  }

  /** Builds materials, geometry and derived settings from `this.config`. */
  private setConfig() {
    this.landmarkMaterial =
        new THREE.MeshBasicMaterial({color: this.config.landmarkColor});
    this.landmarkGeometry = new THREE.SphereGeometry(this.config.landmarkSize);
    this.nonvisibleMaterial = new THREE.MeshBasicMaterial(
        {color: this.config.nonvisibleLandmarkColor});
    this.axesMaterial = new THREE.LineBasicMaterial(
        {color: this.config.axesColor, linewidth: this.config.axesWidth});
    this.gridMaterial = new THREE.LineBasicMaterial({color: 0x999999});
    this.connectionMaterial = new THREE.LineBasicMaterial({
      color: this.config.connectionColor,
      linewidth: this.config.connectionWidth
    });
    this.isVisible = (e: NormalizedLandmark) => e.visibility === undefined ||
        (!!e.visibility && (e.visibility > this.config.minVisibility));
    this.definedColors = new Map();
    for (const color of this.config.definedColors) {
      // Later entries with the same name replace earlier ones.
      this.definedColors.set(
          color.name,
          new THREE.LineBasicMaterial(
              {color: color.value, linewidth: this.config.connectionWidth}));
    }
    this.sizeWhenFitted = (1 - 2 * this.config.margin);
  }

  /**
   * Returns the largest multiplier that keeps every landmark within
   * `sizeWhenFitted` of the grid's half-range. Returns 1 when no landmark
   * constrains it (empty list, or all landmarks at the origin).
   */
  private getFitToGridFactor(landmarks?: NormalizedLandmarkList): number {
    const list = landmarks ?? this.landmarks;

    let factor = Infinity;
    for (const {x, y, z} of list) {
      const maxNum = Math.max(Math.abs(x), Math.abs(y), Math.abs(z));
      factor = Math.min(factor, (this.config.range / 2) / maxNum);
    }
    // With no landmark away from the origin the factor stays unbounded;
    // treat that as "no rescale" instead of propagating Infinity/NaN.
    if (!Number.isFinite(factor)) {
      return 1;
    }
    return factor * this.sizeWhenFitted;
  }

  /** Moves every label element to its projected position on the canvas. */
  private setLabels() {
    for (const axisLabels of [this.labels.x, this.labels.y, this.labels.z]) {
      for (const {element, position} of axisLabels) {
        const {x, y} = this.getCanvasPosition(position);
        element.style.transform = `translate(${x}px, ${y}px)`;
      }
    }
  }

  /**
   * Shifts every stored landmark so the bounding-box center of `landmarks`
   * becomes the origin, and records that center in `origin`. Does nothing
   * for an empty list.
   */
  private centerLandmarks(landmarks: NormalizedLandmarkList) {
    if (landmarks.length === 0) {
      return;
    }

    let maxX = landmarks[0].x, minX = landmarks[0].x, maxY = landmarks[0].y,
        minY = landmarks[0].y, maxZ = landmarks[0].z, minZ = landmarks[0].z;
    for (let i = 1; i < landmarks.length; i++) {
      const landmark = landmarks[i];
      maxX = Math.max(maxX, landmark.x);
      maxY = Math.max(maxY, landmark.y);
      maxZ = Math.max(maxZ, landmark.z);
      minX = Math.min(minX, landmark.x);
      minY = Math.min(minY, landmark.y);
      minZ = Math.min(minZ, landmark.z);
    }
    const centerX = (maxX + minX) / 2;
    const centerY = (maxY + minY) / 2;
    const centerZ = (maxZ + minZ) / 2;
    // Shift all landmarks, not just the ones the bounds were computed from.
    for (const landmark of this.landmarks) {
      landmark.x -= centerX;
      landmark.y -= centerY;
      landmark.z -= centerZ;
    }

    this.origin.set(centerX, centerY, centerZ);
  }
}

// Publishes the class under the public name 'LandmarkGrid' via Closure's
// goog.exportSymbol (presumably so the name survives Closure Compiler
// renaming when bundled).
goog.exportSymbol('LandmarkGrid', LandmarkGrid);
