import React, { useEffect, useMemo, useState } from 'react';
import * as THREE from 'three';
import { LineMaterial } from 'three/examples/jsm/lines/LineMaterial';
import { Line2 } from 'three/examples/jsm/lines/Line2';
import {
  createLine2,
  getXYZ,
  updateLine2Position,
} from '@/routes/dashboard/projects/project/project-canvas.helpers';
import { ThreeEvent } from '@react-three/fiber';
import { useUnmount } from 'react-use';

/**
 * Builds an open-ended, six-sided cylinder that spans the segment from
 * `pointX` to `pointY`.
 *
 * @param pointX - Segment start; the base of the cylinder is placed here.
 * @param pointY - Segment end; the cylinder is rotated to point at it.
 * @param width - Used as both the top and the bottom radius of the cylinder.
 * @returns A mesh with its own `MeshBasicMaterial`, positioned, oriented and
 *   with its world matrix already updated.
 */
const generateCylinderMesh = (
  pointX: THREE.Vector3,
  pointY: THREE.Vector3,
  width: number
) => {
  const height = pointX.distanceTo(pointY);

  const geometry = new THREE.CylinderGeometry(
    width, // radiusTop
    width, // radiusBottom
    height,
    6, // radialSegments
    4, // heightSegments
    true // openEnded
  );

  // CylinderGeometry is centred on the origin along +Y. Lift it so its base
  // sits on the origin, then tip +Y onto +Z, the axis Object3D.lookAt aims
  // at the target.
  const liftBaseToOrigin = new THREE.Matrix4().makeTranslation(
    0,
    height / 2,
    0
  );
  const alignAxisWithZ = new THREE.Matrix4().makeRotationX(
    THREE.MathUtils.degToRad(90)
  );
  geometry.applyMatrix4(liftBaseToOrigin).applyMatrix4(alignAxisWithZ);

  const mesh = new THREE.Mesh(geometry, new THREE.MeshBasicMaterial({}));
  mesh.position.copy(pointX);
  mesh.lookAt(pointY);
  mesh.updateMatrixWorld(true);
  return mesh;
};

/**
 * Pointer event delivered to `FatLine` callbacks: the react-three-fiber event
 * extended with `pointOnLine`, the point on the start–end segment (clamped to
 * its endpoints) closest to the event's intersection `point`.
 */
export type FatLinePointerEvent = ThreeEvent<PointerEvent> & {
  pointOnLine: THREE.Vector3;
};

/** Signature shared by every `FatLine` pointer callback. */
type FatLinePointerHandler = (
  e: FatLinePointerEvent,
  lineInstance: Line2
) => void;

interface FatLineProps {
  /** Start of the segment. */
  startPoint: THREE.Vector3;
  /** End of the segment. */
  endPoint: THREE.Vector3;
  /** Material for the visible line; supplied by the caller, not disposed by FatLine. */
  lineMaterial: LineMaterial;
  /** Radius of the invisible cylinder used as the pointer hit target (defaults to 0.003). */
  hitWidth?: number;
  onPointerEnter?: FatLinePointerHandler;
  onPointerLeave?: FatLinePointerHandler;
  onPointerOver?: FatLinePointerHandler;
  onPointerDown?: FatLinePointerHandler;
  onPointerMove?: FatLinePointerHandler;
}

/**
 * Disposes the geometry and material(s) owned by `mesh`.
 *
 * Only use this for meshes whose geometry and material are not shared with
 * anything else, such as those built by `generateCylinderMesh`.
 */
const disposeMeshResources = (mesh: THREE.Mesh) => {
  mesh.geometry.dispose();
  const materials = Array.isArray(mesh.material)
    ? mesh.material
    : [mesh.material];
  materials.forEach((material) => material.dispose());
};

/**
 * Draws a fat `Line2` between `startPoint` and `endPoint`, plus an invisible
 * cylinder of radius `hitWidth` along the same segment that receives the
 * pointer events.
 *
 * Each pointer callback is called with the original event extended by
 * `pointOnLine` (the point on the segment, clamped to its endpoints, closest
 * to the event's `point`) and the rendered `Line2` instance.
 *
 * The cylinder's geometry and material and the line's geometry are disposed
 * by this component. `lineMaterial` comes from the caller and is never
 * disposed here.
 */
const FatLine: React.FC<FatLineProps> = ({
  lineMaterial,
  startPoint,
  endPoint,
  hitWidth = 0.003,
  onPointerEnter,
  onPointerDown,
  onPointerOver,
  onPointerLeave,
  onPointerMove,
}) => {
  const [cylinderMesh, setCylinderMesh] = useState<THREE.Mesh>();
  const [line, setLine] = useState<Line2>();

  const line3 = useMemo(
    () => new THREE.Line3(startPoint, endPoint),
    [startPoint, endPoint]
  );

  /** Closest point to `point` on the segment, clamped to its endpoints. */
  const calculatePointOnLine = (point: THREE.Vector3) => {
    const result = new THREE.Vector3();
    line3.closestPointToPoint(point, true, result);
    return result;
  };

  // (Re)build the hit cylinder whenever its shape inputs change.
  useEffect(() => {
    setCylinderMesh(generateCylinderMesh(startPoint, endPoint, hitWidth));
  }, [startPoint, endPoint, hitWidth]);

  // Release a cylinder's geometry and material once it has been replaced by a
  // newer one or the component unmounts. Running this in the cleanup means the
  // old mesh has already been swapped out of the rendered tree.
  useEffect(() => {
    if (!cylinderMesh) return undefined;
    return () => disposeMeshResources(cylinderMesh);
  }, [cylinderMesh]);

  // Create the Line2 on first run; afterwards update its positions in place.
  useEffect(() => {
    const points = [getXYZ(startPoint), getXYZ(endPoint)].flat();
    if (line) {
      updateLine2Position(line, points);
    } else {
      setLine(createLine2(points, lineMaterial));
    }
  }, [startPoint, endPoint]);

  // Keep the rendered line in sync with the caller-supplied material.
  useEffect(() => {
    if (line) {
      line.material = lineMaterial;
    }
  }, [line, lineMaterial]);

  // Dispose the line's geometry on unmount; lineMaterial is left to the caller.
  useUnmount(() => {
    line?.geometry.dispose();
  });

  if (!line || !cylinderMesh) return null;

  // Wraps an optional consumer callback so it receives `pointOnLine` and the
  // line instance. Building the event in a variable (not inline) keeps the
  // extra field out of object-literal excess-property checks.
  const forwardTo =
    (handler: FatLineProps['onPointerEnter']) =>
    (e: ThreeEvent<PointerEvent>) => {
      if (!handler) return;
      const event = { ...e, pointOnLine: calculatePointOnLine(e.point) };
      handler(event, line);
    };

  return (
    <group>
      <primitive
        object={cylinderMesh}
        visible={false}
        onPointerEnter={forwardTo(onPointerEnter)}
        onPointerLeave={forwardTo(onPointerLeave)}
        onPointerOver={forwardTo(onPointerOver)}
        onPointerDown={forwardTo(onPointerDown)}
        onPointerMove={forwardTo(onPointerMove)}
      />
      <primitive object={line} />
    </group>
  );
};

export default FatLine;
