import { Quaternion } from './Quaternion'
import { Vec3 } from './Vec3'
import { Vec4 } from './Vec4'


/**
 * Determinant of a 3x3 matrix given as nine scalars.
 *
 * The arguments are passed three at a time. Each group shares its second index
 * (`a00, a10, a20` first, then `a01, a11, a21`, ...). The determinant is
 * transpose-invariant, so it does not matter whether each group is read as a
 * row or a column. The expansion is along the `a00, a10, a20` group.
 */
const det3x3 = (
  a00: number, a10: number, a20: number,
  a01: number, a11: number, a21: number,
  a02: number, a12: number, a22: number,
): number => {
  // 2x2 cofactors paired with a00, a10 and a20 respectively.
  const cof0 = a11 * a22 - a21 * a12
  const cof1 = a21 * a02 - a01 * a22
  const cof2 = a01 * a12 - a11 * a02
  return a00 * cof0 + a10 * cof1 + a20 * cof2
}


/**
 * 4x4 matrix of numbers.
 *
 * Storage: the 16 entries live in `m`, in constructor-argument order.
 * `mul`, `transform3d` and `transform4d` read `m` column-major: the entry at
 * row `r`, column `c` is `m[c * 4 + r]`. So each group of four constructor
 * arguments is one column, and `Translation` keeps its offset in `m[12..14]`.
 *
 * `LookAt` builds a view matrix in which the camera looks down -Z. The
 * projection helpers (`Ortho`, `Frustum`, `Perspective`) share that convention
 * and map view-space depth z = -near .. -far onto clip-space depth 0 .. 1.
 */
export class Mat4 {

  /** Shared all-zero matrix (frozen); callers must not mutate it. */
  public static Null(): Readonly<Mat4> {
    return NULL
  }

  /** Shared identity matrix (frozen); callers must not mutate it. */
  public static Identity(): Readonly<Mat4> {
    return IDENTITY
  }

  /** Translation by `pos`; the offset is stored in `m[12]`, `m[13]`, `m[14]`. */
  public static Translation(pos: Vec3): Readonly<Mat4> {
    return new Mat4(
      1, 0, 0, 0,
      0, 1, 0, 0,
      0, 0, 1, 0,
      pos.x, pos.y, pos.z, 1,
    )
  }

  /**
   * Rotation of `angle` degrees about `axis` (the axis is normalized here).
   *
   * NOTE(review): the Rodrigues rows below are passed as constructor *columns*.
   * Under the column-major reading used by `transform3d`, a positive angle about
   * +Z therefore maps +X toward -Y. `toQuaternion` reads `m` with the same
   * transposed layout, so the two agree with each other. Confirm against the
   * `Quaternion` class before changing either one.
   */
  public static Rotation(axis: Vec3, angle: number): Readonly<Mat4> {
    const normalizedAxis = axis.normalize()
    const a = angle / 180 * Math.PI
    const c = Math.cos(a)
    const s = Math.sin(a)
    const onec = 1 - c
    const u = normalizedAxis.x
    const v = normalizedAxis.y
    const w = normalizedAxis.z

    return new Mat4(
      u * u + (1 - u * u) * c, u * v * onec - w * s, u * w * onec + v * s, 0,
      u * v * onec + w * s, v * v + (1 - v * v) * c, v * w * onec - u * s, 0,
      u * w * onec - v * s, v * w * onec + u * s, w * w + (1 - w * w) * c, 0,
      0, 0, 0, 1,
    )
  }

  /**
   * Reflection across the plane through `pos` with normal `normal` (normalized here).
   *
   * Built as T(pos) * R * T(-pos). A point is first moved so that `pos` sits at
   * the origin, then reflected with the Householder matrix I - 2nn^T, then
   * moved back.
   */
  public static Reflection(pos: Vec3, normal: Vec3): Readonly<Mat4> {
    const moveFromOrigin = Mat4.Translation(pos)
    const moveToOrigin = Mat4.Translation(Vec3.Null().sub(pos))
    const n = normal.normalize()
    const reflection = new Mat4(
      1 - 2 * n.x * n.x, -2 * n.x * n.y, -2 * n.x * n.z, 0,
      -2 * n.x * n.y, 1 - 2 * n.y * n.y, -2 * n.y * n.z, 0,
      -2 * n.x * n.z, -2 * n.y * n.z, 1 - 2 * n.z * n.z, 0,
      0, 0, 0, 1,
    )

    // `mul` evaluates left to right, so the rightmost factor (moveToOrigin)
    // is the first one applied to a vector.
    return moveFromOrigin.mul(reflection).mul(moveToOrigin)
  }

  /** Non-uniform scale by `s`, `t`, `u` along X, Y, Z. */
  public static Scale(s: number, t: number, u: number): Readonly<Mat4> {
    return new Mat4(
      s, 0, 0, 0,
      0, t, 0, 0,
      0, 0, u, 0,
      0, 0, 0, 1,
    )
  }

  /**
   * Orthographic projection.
   *
   * Maps x in [left, right] to [-1, 1], y in [bottom, top] to [-1, 1], and
   * view-space z in [-near, -far] to depth [0, 1]. This is the same depth
   * convention as `Perspective` and `Frustum`.
   * Note the parameter order: `top` comes before `bottom`.
   */
  public static Ortho(left: number, right: number, top: number, bottom: number, near: number, far: number): Readonly<Mat4> {
    const width = right - left
    const height = top - bottom
    // Negative for near < far; this flips -Z (the view direction) onto +depth.
    const depth = near - far

    return new Mat4(
      2 / width,               0,                        0,            0,
      0,                       2 / height,               0,            0,
      0,                       0,                        1 / depth,    0,
      -(right + left) / width, -(top + bottom) / height, near / depth, 1,
    )
  }

  /**
   * Perspective projection for an arbitrary (possibly off-center) frustum.
   *
   * `left`/`right`/`top`/`bottom` are the extents on the near plane. View-space
   * z = -near / -far maps to depth 0 / 1 after the divide by w = -z. For a
   * symmetric frustum the result equals `Perspective`.
   */
  public static Frustum(left: number, right: number, top: number, bottom: number, near: number, far: number): Readonly<Mat4> {
    return new Mat4(
      2 * near / (right - left),       0,                               0,                             0,
      0,                               2 * near / (top - bottom),       0,                             0,
      (right + left) / (right - left), (top + bottom) / (top - bottom), far / (near - far),           -1,
      0,                               0,                               -(far * near) / (far - near),  0,
    )
  }

  /**
   * Symmetric perspective projection.
   *
   * @param fovy  full vertical field of view, in degrees
   * @param ratio aspect ratio (width / height)
   * @param near  distance to the near plane (maps to depth 0)
   * @param far   distance to the far plane (maps to depth 1)
   */
  public static Perspective(fovy: number, ratio: number, near: number, far: number): Readonly<Mat4> {
    // tan(fovy / 2) with fovy in degrees: (fovy / 2) * PI / 180 == fovy / 360 * PI.
    const tanHalfFovy = Math.tan(fovy / 360 * Math.PI)

    return new Mat4(
      1 / (ratio * tanHalfFovy), 0,               0,                            0,
      0,                         1 / tanHalfFovy, 0,                            0,
      0,                         0,               far / (near - far),          -1,
      0,                         0,               -(far * near) / (far - near), 0,
    )
  }

  /**
   * View matrix for a camera at `from` looking toward `to`.
   *
   * Under the column-major reading, rows 0..2 hold the camera right (`s`),
   * up (`u`) and backward (`-f`) vectors, so the view direction maps to -Z.
   * Presumably degenerate when `up` is parallel to `to - from`; this depends on
   * `Vec3.normalize` on a zero vector, which is not visible here.
   */
  public static LookAt(from: Vec3, to: Vec3, up: Vec3): Readonly<Mat4> {
    const f = to.sub(from).normalize()
    const s = f.cross(up).normalize()
    const u = s.cross(f).normalize()

    return new Mat4(
      s.x, u.x, -f.x, 0,
      s.y, u.y, -f.y, 0,
      s.z, u.z, -f.z, 0,
      -s.dot(from), -u.dot(from), f.dot(from), 1,
    )
  }

  /** The 16 entries in constructor-argument (column-major) order. */
  public readonly m: [
    number, number, number, number,
    number, number, number, number,
    number, number, number, number,
    number, number, number, number
  ]

  /** Builds a matrix from 16 entries in storage order; defaults to identity. */
  constructor(
    a00: number = 1, a01: number = 0, a02: number = 0, a03: number = 0,
    a10: number = 0, a11: number = 1, a12: number = 0, a13: number = 0,
    a20: number = 0, a21: number = 0, a22: number = 1, a23: number = 0,
    a30: number = 0, a31: number = 0, a32: number = 0, a33: number = 1,
  ) {
    this.m = [
      a00, a01, a02, a03,
      a10, a11, a12, a13,
      a20, a21, a22, a23,
      a30, a31, a32, a33,
    ]
  }

  /** Returns a new matrix with rows and columns swapped. */
  public transpose(): Readonly<Mat4> {
    return new Mat4(
      this.m[0], this.m[4], this.m[8], this.m[12],
      this.m[1], this.m[5], this.m[9], this.m[13],
      this.m[2], this.m[6], this.m[10], this.m[14],
      this.m[3], this.m[7], this.m[11], this.m[15],
    )
  }

  /**
   * Inverse via adjugate / determinant.
   *
   * `dRC` is the 3x3 minor that drops row R and column C of `m` read
   * row-major. Inversion commutes with transposition, so the result is the
   * same under the column-major reading. A singular matrix is not detected:
   * `1 / 0` makes every entry non-finite.
   */
  public inverse(): Readonly<Mat4> {
    const d00 = det3x3(this.m[5], this.m[6], this.m[7], this.m[9], this.m[10], this.m[11], this.m[13], this.m[14], this.m[15])
    const d01 = det3x3(this.m[4], this.m[6], this.m[7], this.m[8], this.m[10], this.m[11], this.m[12], this.m[14], this.m[15])
    const d02 = det3x3(this.m[4], this.m[5], this.m[7], this.m[8], this.m[9], this.m[11], this.m[12], this.m[13], this.m[15])
    const d03 = det3x3(this.m[4], this.m[5], this.m[6], this.m[8], this.m[9], this.m[10], this.m[12], this.m[13], this.m[14])

    const d10 = det3x3(this.m[1], this.m[2], this.m[3], this.m[9], this.m[10], this.m[11], this.m[13], this.m[14], this.m[15])
    const d11 = det3x3(this.m[0], this.m[2], this.m[3], this.m[8], this.m[10], this.m[11], this.m[12], this.m[14], this.m[15])
    const d12 = det3x3(this.m[0], this.m[1], this.m[3], this.m[8], this.m[9], this.m[11], this.m[12], this.m[13], this.m[15])
    const d13 = det3x3(this.m[0], this.m[1], this.m[2], this.m[8], this.m[9], this.m[10], this.m[12], this.m[13], this.m[14])

    const d20 = det3x3(this.m[1], this.m[2], this.m[3], this.m[5], this.m[6], this.m[7], this.m[13], this.m[14], this.m[15])
    const d21 = det3x3(this.m[0], this.m[2], this.m[3], this.m[4], this.m[6], this.m[7], this.m[12], this.m[14], this.m[15])
    const d22 = det3x3(this.m[0], this.m[1], this.m[3], this.m[4], this.m[5], this.m[7], this.m[12], this.m[13], this.m[15])
    const d23 = det3x3(this.m[0], this.m[1], this.m[2], this.m[4], this.m[5], this.m[6], this.m[12], this.m[13], this.m[14])

    const d30 = det3x3(this.m[1], this.m[2], this.m[3], this.m[5], this.m[6], this.m[7], this.m[9], this.m[10], this.m[11])
    const d31 = det3x3(this.m[0], this.m[2], this.m[3], this.m[4], this.m[6], this.m[7], this.m[8], this.m[10], this.m[11])
    const d32 = det3x3(this.m[0], this.m[1], this.m[3], this.m[4], this.m[5], this.m[7], this.m[8], this.m[9], this.m[11])
    const d33 = det3x3(this.m[0], this.m[1], this.m[2], this.m[4], this.m[5], this.m[6], this.m[8], this.m[9], this.m[10])

    // Reciprocal of the determinant (cofactor expansion along the first row).
    const d = 1 / (this.m[0] * d00 - this.m[1] * d01 + this.m[2] * d02 - this.m[3] * d03)

    // Adjugate = transposed cofactor matrix, with checkerboard signs.
    return new Mat4(
      +d00 * d, -d10 * d, +d20 * d, -d30 * d,
      -d01 * d, +d11 * d, -d21 * d, +d31 * d,
      +d02 * d, -d12 * d, +d22 * d, -d32 * d,
      -d03 * d, +d13 * d, -d23 * d, +d33 * d,
    )
  }

  /**
   * Matrix product `this * other` (column-major).
   *
   * Applied to a vector, `a.mul(b)` applies `b` first and then `a`.
   */
  public mul(other: Mat4): Readonly<Mat4> {
    return new Mat4(
      this.m[0] * other.m[0] + this.m[4] * other.m[1] + this.m[8] * other.m[2] + this.m[12] * other.m[3],
      this.m[1] * other.m[0] + this.m[5] * other.m[1] + this.m[9] * other.m[2] + this.m[13] * other.m[3],
      this.m[2] * other.m[0] + this.m[6] * other.m[1] + this.m[10] * other.m[2] + this.m[14] * other.m[3],
      this.m[3] * other.m[0] + this.m[7] * other.m[1] + this.m[11] * other.m[2] + this.m[15] * other.m[3],

      this.m[0] * other.m[4] + this.m[4] * other.m[5] + this.m[8] * other.m[6] + this.m[12] * other.m[7],
      this.m[1] * other.m[4] + this.m[5] * other.m[5] + this.m[9] * other.m[6] + this.m[13] * other.m[7],
      this.m[2] * other.m[4] + this.m[6] * other.m[5] + this.m[10] * other.m[6] + this.m[14] * other.m[7],
      this.m[3] * other.m[4] + this.m[7] * other.m[5] + this.m[11] * other.m[6] + this.m[15] * other.m[7],

      this.m[0] * other.m[8] + this.m[4] * other.m[9] + this.m[8] * other.m[10] + this.m[12] * other.m[11],
      this.m[1] * other.m[8] + this.m[5] * other.m[9] + this.m[9] * other.m[10] + this.m[13] * other.m[11],
      this.m[2] * other.m[8] + this.m[6] * other.m[9] + this.m[10] * other.m[10] + this.m[14] * other.m[11],
      this.m[3] * other.m[8] + this.m[7] * other.m[9] + this.m[11] * other.m[10] + this.m[15] * other.m[11],

      this.m[0] * other.m[12] + this.m[4] * other.m[13] + this.m[8] * other.m[14] + this.m[12] * other.m[15],
      this.m[1] * other.m[12] + this.m[5] * other.m[13] + this.m[9] * other.m[14] + this.m[13] * other.m[15],
      this.m[2] * other.m[12] + this.m[6] * other.m[13] + this.m[10] * other.m[14] + this.m[14] * other.m[15],
      this.m[3] * other.m[12] + this.m[7] * other.m[13] + this.m[11] * other.m[14] + this.m[15] * other.m[15],
    )
  }

  /**
   * Transforms a point, treating `vec` as (x, y, z, 1).
   *
   * The bottom row (m[3], m[7], m[11], m[15]) is ignored, so no perspective
   * divide is performed.
   */
  public transform3d(vec: Vec3): Readonly<Vec3> {
    return new Vec3(
      this.m[0] * vec.x + this.m[4] * vec.y + this.m[8] * vec.z + this.m[12],
      this.m[1] * vec.x + this.m[5] * vec.y + this.m[9] * vec.z + this.m[13],
      this.m[2] * vec.x + this.m[6] * vec.y + this.m[10] * vec.z + this.m[14],
    )
  }

  /** Applies `transform3d` to every element, returning a new array. */
  public transformBatch3d(vectors: Vec3[]): Readonly<Vec3[]> {
    return vectors.map((vec) => this.transform3d(vec))
  }

  /** Full 4-component product `this * vec`. */
  public transform4d(vec: Vec4): Readonly<Vec4> {
    return new Vec4(
      this.m[0] * vec.x + this.m[4] * vec.y + this.m[8] * vec.z + this.m[12] * vec.w,
      this.m[1] * vec.x + this.m[5] * vec.y + this.m[9] * vec.z + this.m[13] * vec.w,
      this.m[2] * vec.x + this.m[6] * vec.y + this.m[10] * vec.z + this.m[14] * vec.w,
      this.m[3] * vec.x + this.m[7] * vec.y + this.m[11] * vec.z + this.m[15] * vec.w,
    )
  }

  /** Applies `transform4d` to every element, returning a new array. */
  public transformBatch4d(vectors: Vec4[]): Readonly<Vec4[]> {
    return vectors.map((vec) => this.transform4d(vec))
  }

  /**
   * Converts the upper-left 3x3 block to a quaternion.
   *
   * The method branches on the trace, then on the largest diagonal entry, so
   * that the square root argument stays well away from zero.
   * NOTE(review): it reads `m` row-major (e.g. x uses m[9] - m[6]), the same
   * transposed layout that `Rotation` writes. It presumably assumes the 3x3
   * block is a pure rotation with no scale — TODO confirm.
   */
  public toQuaternion(): Readonly<Quaternion> {
    const trace = this.m[0] + this.m[5] + this.m[10]
    let qw: number
    let qx: number
    let qy: number
    let qz: number

    if (trace > 0) {
      const s = 0.5 / Math.sqrt(trace + 1.0)
      qw = 0.25 / s
      qx = (this.m[9] - this.m[6]) * s
      qy = (this.m[2] - this.m[8]) * s
      qz = (this.m[4] - this.m[1]) * s
    } else if (this.m[0] > this.m[5] && this.m[0] > this.m[10]) {
      const s = 2.0 * Math.sqrt(1.0 + this.m[0] - this.m[5] - this.m[10])
      qw = (this.m[9] - this.m[6]) / s
      qx = 0.25 * s
      qy = (this.m[1] + this.m[4]) / s
      qz = (this.m[2] + this.m[8]) / s
    } else if (this.m[5] > this.m[10]) {
      const s = 2.0 * Math.sqrt(1.0 + this.m[5] - this.m[0] - this.m[10])
      qw = (this.m[2] - this.m[8]) / s
      qx = (this.m[1] + this.m[4]) / s
      qy = 0.25 * s
      qz = (this.m[6] + this.m[9]) / s
    } else {
      const s = 2.0 * Math.sqrt(1.0 + this.m[10] - this.m[0] - this.m[5])
      qw = (this.m[4] - this.m[1]) / s
      qx = (this.m[2] + this.m[8]) / s
      qy = (this.m[6] + this.m[9]) / s
      qz = 0.25 * s
    }

    return new Quaternion(qx, qy, qz, qw)
  }

  /**
   * Returns (m[0], m[4], m[8]), which is row 0 under the column-major reading.
   * For a `LookAt` matrix this is the camera right vector `s`.
   */
  public getXAxis(): Vec3 {
    return new Vec3(this.m[0], this.m[4], this.m[8])
  }

  /**
   * Returns (m[1], m[5], m[9]), which is row 1 under the column-major reading.
   * For a `LookAt` matrix this is the camera up vector `u`.
   */
  public getYAxis(): Vec3 {
    return new Vec3(this.m[1], this.m[5], this.m[9])
  }

  /**
   * Returns (m[2], m[6], m[10]), which is row 2 under the column-major reading.
   * For a `LookAt` matrix this is the backward vector `-f`.
   */
  public getZAxis(): Vec3 {
    return new Vec3(this.m[2], this.m[6], this.m[10])
  }
}

// Shared singletons returned by Mat4.Identity() / Mat4.Null().
// Object.freeze is shallow, so the `m` tuples are frozen separately. Otherwise
// a write through `.m[i]` would silently corrupt every other user of these constants.
const IDENTITY = Object.freeze(new Mat4())
const NULL = Object.freeze(new Mat4(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0))
Object.freeze(IDENTITY.m)
Object.freeze(NULL.m)
