import { Vector3, Mesh, Intersection, Object3D, Raycaster, Ray, Triangle, Matrix4, BufferGeometry, Float32BufferAttribute, LineBasicMaterial, LineSegments } from "three";
import { SceneContext } from "../scene/SceneContext";

/**
 * A ray given by its starting point and its direction of travel.
 * `RayTracer.getIntersection` hands both vectors straight to a three.js `Raycaster`.
 */
export interface iRay {
    /** Direction of travel; `Raycaster` expects this vector to be normalized. */
    direction: Vector3;
    /** Point the ray starts from. */
    origin: Vector3;
}

/**
 * A `Mesh` that carries a refractive index for `RayTracer.startRayCast`.
 */
export class OpticsMesh extends Mesh {
    /**
     * Refractive index of this mesh's material. `RayTracer.startRayCast` pairs it with
     * an index of 1 on the other side of the surface. It is not initialized here; whoever
     * creates the mesh must assign it (the `!` only tells the compiler so).
     */
    public n!: number;
}

/**
 * A three.js `Intersection` augmented with a surface normal.
 * `RayTracer.getIntersection` fills in `normal` after raycasting.
 */
export interface iIntersection extends Intersection {
    /** Surface normal at the hit, rotated by the hit object's world rotation. */
    normal: Vector3;
}

/**
 * Static helpers for casting rays against scene objects and bending them at optical
 * surfaces with Snell's law. `startRayCast` adds every traced ray segment to
 * `SceneContext.MAIN_SCENE` as a line object.
 */
export class RayTracer {

    //__________________________________________________________________________________________
    /**
     * Casts `pRay` against `pObjects` (descendants included) and returns the nearest hit,
     * augmented with a surface normal.
     *
     * @param pRay       Ray origin and direction, passed directly to `Raycaster`.
     * @param pObjects   Objects to test; their children are tested recursively.
     * @param threshold  Pick tolerance applied to `Points` and `Line` objects.
     * @returns The closest intersection with `normal` filled in, or `null` if nothing is hit.
     */
    public static getIntersection(pRay: iRay, pObjects: Array<Object3D>,
        threshold: number = 1): iIntersection | null {
        const aRaycaster = new Raycaster(pRay.origin, pRay.direction, 0, 10000);
        // Pick tolerance for Points/Line objects (not used for meshes).
        aRaycaster.params.Points.threshold = threshold;
        aRaycaster.params.Line.threshold = threshold;

        // Raycaster returns hits sorted by distance, so the first entry is the nearest.
        const aIntersection = aRaycaster.intersectObjects(pObjects, true)[0] as iIntersection | undefined;
        if (aIntersection == null) {
            return null;
        }

        aIntersection.normal = RayTracer._getIntersectionNormal(aIntersection);
        return aIntersection;
    }
    //__________________________________________________________________________________________
    /**
     * Traces `pRay` through a sequence of refracting meshes and draws each segment into
     * `SceneContext.MAIN_SCENE`.
     *
     * At each hit:
     * - The surface normal is obtained from `normalFunc`, evaluated at the hit point's x/y
     *   offset from the hit object's world position, then rotated by the object's world rotation.
     * - Snell's law is applied.
     * - The hit mesh is removed from the candidate list, and the refracted ray is traced
     *   recursively with the medium flag toggled.
     *
     * When the refracted ray hits nothing further, a 1000-unit segment is drawn along it.
     *
     * @param pRay        Incoming ray.
     * @param pObjects    Candidate meshes; this array is not modified (a copy is recursed on).
     * @param pAirToMat   `true` when travelling from index 1 into the mesh, `false` when leaving it.
     * @param normalFunc  Returns the surface normal at offset (x, y), expressed in the object's
     *                    rotation frame (it is rotated by the object's world rotation before use).
     * @returns The first intersection, or `null` when the ray hits nothing.
     */
    public static startRayCast(pRay: Ray, pObjects: Array<OpticsMesh>,
        pAirToMat: boolean = true, normalFunc: (x: number, y: number) => Vector3): Intersection | null {

        const o = pRay.origin;
        const d = pRay.direction;

        const aIntersection = new Raycaster(o, d).intersectObjects(pObjects)[0];
        if (aIntersection == null) {
            return null;
        }

        const aObject = aIntersection.object as OpticsMesh;
        // NOTE(review): Object3D.add re-parents the mesh under the scene root. Kept as-is;
        // confirm this is intended.
        SceneContext.MAIN_SCENE.add(aObject);

        // World-space hit point. The previous code used the triangle's barycentric
        // coordinates (weights in [0, 1]) here, which is not a position.
        const aPoint = aIntersection.point.clone();

        // Exclude the mesh just hit so the refracted ray cannot immediately re-hit it.
        // Guard against indexOf() === -1, which would otherwise remove the last element.
        const aRemaining = pObjects.slice();
        const aHitIndex = aRemaining.indexOf(aObject);
        if (aHitIndex >= 0) {
            aRemaining.splice(aHitIndex, 1);
        }

        // getWorldPosition() also refreshes matrixWorld, which is read just below.
        const aOffset = aPoint.clone().sub(aObject.getWorldPosition(new Vector3()));
        // clone(): do not mutate a vector that normalFunc may share with its caller.
        const N = normalFunc(aOffset.x, aOffset.y).clone();
        const aRotation = new Matrix4().extractRotation(aObject.matrixWorld);
        N.applyMatrix4(aRotation).normalize();

        const n1 = pAirToMat ? 1 : aObject.n;
        const n2 = pAirToMat ? aObject.n : 1;

        const aDirection = RayTracer.snellLaw(N, d, n1, n2);

        SceneContext.MAIN_SCENE.add(RayTracer.getLine(o, aPoint));

        const aNext = RayTracer.startRayCast(new Ray(aPoint, aDirection), aRemaining, !pAirToMat, normalFunc);
        if (aNext == null) {
            // The ray escapes: draw a long segment along it without mutating aDirection.
            const aEndPoint = aPoint.clone().addScaledVector(aDirection, 1000);
            SceneContext.MAIN_SCENE.add(RayTracer.getLine(aPoint, aEndPoint));
        }

        return aIntersection;
    }
    //__________________________________________________________________________________________
    /**
     * Builds a single line segment from `pStart` to `pEnd`.
     *
     * @param pStart  Segment start.
     * @param pEnd    Segment end.
     * @param pColor  Line colour (defaults to green, 0x00FF00).
     * @returns A new `LineSegments` object; the caller is responsible for adding it to a scene.
     */
    public static getLine(pStart: Vector3, pEnd: Vector3, pColor: number = 0x00FF00) {
        const aLineG = new BufferGeometry();
        // Normals are intentionally not computed: LineBasicMaterial does not use them, and
        // computeVertexNormals() on a 2-vertex line reads a non-existent third vertex.
        aLineG.setAttribute('position',
            new Float32BufferAttribute([...pStart.toArray(), ...pEnd.toArray()], 3));
        return new LineSegments(aLineG, new LineBasicMaterial({ color: pColor }));
    }
    //__________________________________________________________________________________________
    /**
     * Refracts direction `V_in` at a surface with normal `N_in`, going from a medium of
     * index `n_input` into a medium of index `n_element`.
     *
     * Under total internal reflection (sin(theta_in) * n_input / n_element > 1) the
     * reflected direction is returned, not re-normalized. Otherwise the normalized
     * refracted direction is returned.
     *
     * NOTE(review): the refraction formula is correct only when N·V < 0, i.e. the normal
     * faces the incoming ray. The normal is flipped when n_input > n_element. Confirm that
     * callers supply normals oriented so this holds.
     *
     * @param N_in       Surface normal; not modified.
     * @param V_in       Incident direction; not modified.
     * @param n_input    Refractive index of the medium the ray comes from.
     * @param n_element  Refractive index of the medium the ray enters.
     */
    public static snellLaw(N_in: Vector3, V_in: Vector3, n_input: number,
        n_element: number): Vector3 {

        const eta = n_input / n_element;
        // Work on a copy so the caller's normal is never flipped as a side effect.
        const N = N_in.clone();
        if (eta > 1) {
            N.multiplyScalar(-1);
        }

        const N_dot_V = N.dot(V_in);
        // Clamp so floating-point drift (|cos| slightly above 1) cannot make acos return NaN.
        const cosThetaIn = Math.min(1, Math.max(-1, N_dot_V / (N.length() * V_in.length())));
        const sinThetaIn = Math.sin(Math.acos(cosThetaIn));

        if ((sinThetaIn * eta) > 1) { // total internal reflection
            return V_in.clone().sub(N.clone().multiplyScalar(2 * cosThetaIn));
        }

        const sinPhi = sinThetaIn * eta;
        const XY = eta * N_dot_V + Math.cos(Math.asin(sinPhi));
        // Vout = eta * V - XY * N
        return V_in.clone().multiplyScalar(eta).sub(N.clone().multiplyScalar(XY)).normalize();
    }
    //__________________________________________________________________________________________
    /**
     * Array-based Snell's-law variant operating on 3-component arrays.
     *
     * @param N_in       Surface normal [x, y, z].
     * @param V_in       Incident direction [x, y, z].
     * @param DeltaR_in  Offset from the hit centre [x, y, z]; copied to the output unchanged.
     * @param _W         Unused.
     * @param n_input    Refractive index of the medium the ray comes from.
     * @param n_element  Refractive index of the medium the ray enters.
     * @returns An object with three fields:
     *   - `normalOut`: the negated normal.
     *   - `laserDirectionOut`: the refracted direction, or under total internal reflection
     *     the reflected direction; not normalized.
     *   - `deltaFromHitCenterOut`: a copy of `DeltaR_in`.
     */
    public static snellLaw2(N_in: ArrayLike<number>, V_in: ArrayLike<number>,
        DeltaR_in: ArrayLike<number>, _W: unknown, n_input: number, n_element: number):
        { normalOut: Vector3; laserDirectionOut: Vector3; deltaFromHitCenterOut: Vector3 } {
        const eta = n_input / n_element;

        // Clamp against floating-point drift so acos/asin stay inside their domains.
        const N_dot_V = Math.min(1, Math.max(-1,
            N_in[0] * V_in[0] + N_in[1] * V_in[1] + N_in[2] * V_in[2]));
        const sinTheta = Math.sin(Math.acos(N_dot_V));
        const sinPhi = Math.min(1, Math.max(-1, sinTheta * eta));
        const XY = eta * N_dot_V + Math.cos(Math.asin(sinPhi));

        // Component-wise negation; `-N_in` on an array evaluates to NaN.
        const normalOut = new Vector3(-N_in[0], -N_in[1], -N_in[2]);
        const deltaFromHitCenterOut = new Vector3(DeltaR_in[0], DeltaR_in[1], DeltaR_in[2]);

        if ((sinTheta * eta) > 1) {
            // Total internal reflection: Vref = V - 2 (N·V) N (also the sequential-mirror formula).
            const laserDirectionOut = new Vector3(
                V_in[0] - 2 * N_dot_V * N_in[0],
                V_in[1] - 2 * N_dot_V * N_in[1],
                V_in[2] - 2 * N_dot_V * N_in[2]);
            return { normalOut, laserDirectionOut, deltaFromHitCenterOut };
        }

        // Refraction: Vout = -(N * XY - eta * V)
        const laserDirectionOut = new Vector3(
            -(N_in[0] * XY - eta * V_in[0]),
            -(N_in[1] * XY - eta * V_in[1]),
            -(N_in[2] * XY - eta * V_in[2]));
        return { normalOut, laserDirectionOut, deltaFromHitCenterOut };
    }
    //__________________________________________________________________________________________
    /**
     * Returns the surface normal for an intersection, rotated by the hit object's world
     * rotation (translation and scale are discarded).
     *
     * Mesh hits use three.js's object-space `face.normal`. Mesh intersections set
     * `faceIndex`, not `index`, so an index lookup would read `normal[NaN]`. Point/line
     * hits, which carry `index`, read the per-vertex `normal` attribute at that index.
     * With neither a face nor an index available, a zero vector is returned.
     */
    private static _getIntersectionNormal(pIntersection: Intersection): Vector3 {
        const aObject = pIntersection.object as Mesh;

        const aNormal = new Vector3();
        if (pIntersection.face != null) {
            aNormal.copy(pIntersection.face.normal);
        } else if (pIntersection.index != null) {
            const aNormals = aObject.geometry.attributes['normal'].array as Array<number>;
            aNormal.fromArray(aNormals, pIntersection.index * 3);
        }

        const aMatrix4 = new Matrix4().extractRotation(aObject.matrixWorld);
        return aNormal.applyMatrix4(aMatrix4);
    }
    //__________________________________________________________________________________________
}
