import { Camera, DynamicTexture, Effect, PostProcess, Texture, Vector3 } from '../loader/babylonjs-import';
import { ShortFrustumDepthRenderer } from '../loader/ShortFrustumDepthRenderer';
import {
    DEPTH_DIMENSIONS,
    EXPENSIVE_NORMALS,
    INVERTED_PROJECTION_MATRIX,
    NEAR_FAR,
    NORMALS_FROM_DEPTH,
    TwinfinityPostProcess,
    VIEWSPACE_POSITION_FROM_LOGARITHMIC_DEPTH
} from './TwinfinityPostProcess';

// Key prefix under which the fragment shader source is registered in Effect.ShadersStore.
const SHADER_NAME_PREFIX = 'custom_ssao';

// Texture sampler names declared by the SSAO fragment shader.
const NOISE_SAMPLER = 'noiseSampler';
const DEPTH_SAMPLER = 'depthSampler';

// Preprocessor define name holding the number of kernel samples (also the GLSL array length).
const SAMPLE_KERNEL_SIZE = 'sampleKernelSize';

// Uniform names shared between the GLSL source and the TypeScript code that sets them.
const SAMPLE_KERNEL = 'sampleKernel';
const PROJECTION_MATRIX = 'projectionMatrix';
const NOISE_SCALE = 'noiseScale';
const UV_SCALE = 'scale';
const STRENGTH = 'strength';
const RADIUS = 'radius';

/**
 * Screen-space ambient occlusion (SSAO) computed from the depth buffer.
 *
 * This is essentially John Chapman's SSAO tutorial
 * (http://john-chapman-graphics.blogspot.com/2013/01/ssao-tutorial.html), except that there is no
 * normal buffer, so view-space normals are reconstructed from neighbouring depth samples instead.
 *
 * The shader writes the occlusion amount to the RGB channels. Fragments whose stored depth is
 * >= 0.9995 (treated as "nothing written") always get 0.
 */
export class SSAOCalculator extends TwinfinityPostProcess {
    /** Side length, in texels, of the tiling random-rotation noise texture. */
    private static readonly _noiseTextureSize = 64;

    private _shortFrustumDepthRenderer: ShortFrustumDepthRenderer;

    private static readonly _fragmentShaderSource = `
    #ifdef GL_ES
        precision highp float;
    #endif
    varying vec2 vUV;

    // Samplers
    uniform sampler2D ${DEPTH_SAMPLER};
    uniform sampler2D ${NOISE_SAMPLER};
    uniform vec2 ${NOISE_SCALE};
    uniform vec3 ${SAMPLE_KERNEL}[${SAMPLE_KERNEL_SIZE}];
    uniform mat4 ${PROJECTION_MATRIX};
    
    uniform float ${STRENGTH};
    uniform float ${RADIUS};
    uniform vec2 ${NEAR_FAR};
    
    // Depth-comparison bias so that a surface does not occlude itself.
    float bias = -0.015;

    ${VIEWSPACE_POSITION_FROM_LOGARITHMIC_DEPTH}
    ${NORMALS_FROM_DEPTH}

    // Fraction of the hemisphere kernel samples around the fragment that lie behind stored geometry,
    // multiplied by the strength and clamped to 1.
    float ambientOcclusionValue(vec3 viewSpaceFragment, vec3 viewSpaceNormal, vec2 screenUV) {
        // Random vector from the tiling noise texture, remapped from [0, 1] to [-1, 1].
        vec3 noise = texture2D(${NOISE_SAMPLER}, screenUV * ${NOISE_SCALE}).rgb;
        vec3 randomVec = noise * 2.0 - 1.0;
        float occlusion = 0.0;
    
        // Gram-Schmidt: tangent frame around the normal, randomly rotated by randomVec.
        vec3 tangent = normalize(randomVec - viewSpaceNormal * dot(randomVec, viewSpaceNormal));
        vec3 biTangent = cross(viewSpaceNormal, tangent);
        mat3 tbn = mat3(tangent, biTangent, viewSpaceNormal);
    
        for (int i = 0; i < ${SAMPLE_KERNEL_SIZE}; i++) {
            // Orient the kernel sample along the normal and place it around the fragment.
            vec3 samplePosition = tbn * ${SAMPLE_KERNEL}[i];
            samplePosition = samplePosition * ${RADIUS} + viewSpaceFragment;
    
            // Project the sample to screen-space UV.
            vec4 offset = vec4(samplePosition, 1.0);
            offset = ${PROJECTION_MATRIX} * offset;
            offset.xy /= offset.w;
            offset.xy = offset.xy * 0.5 + vec2(0.5);
    
            // Reconstruct the stored geometry at the sample's own UV, i.e. where the depth was read.
            float viewDepth = texture2D(${DEPTH_SAMPLER}, offset.xy).r;
            vec4 viewSpaceSample = vec4(viewSpacePositionFromDepth(viewDepth, offset.x, offset.y), 1.0);
            float sampleDepth = viewSpaceSample.z;
    
            // Fade out occluders whose depth differs from the fragment by much more than the radius.
            float rangeCheck = smoothstep(0.0, 1.0, ${RADIUS} / abs(viewSpaceFragment.z - sampleDepth));
            occlusion += (sampleDepth <= samplePosition.z + bias ? rangeCheck : 0.0);
        }
    
        occlusion = min((occlusion / float(${SAMPLE_KERNEL_SIZE})) * ${STRENGTH}, 1.0);
    
        return occlusion;
    }
    
    void main(void)
    {
        vec2 sampleUV = vUV;
        float fragmentDepth = texture2D(${DEPTH_SAMPLER}, sampleUV).r;
        vec3 viewSpacePosition = viewSpacePositionFromDepth(fragmentDepth, sampleUV.x, sampleUV.y);
        #ifdef ${EXPENSIVE_NORMALS}
            vec3 viewSpaceNormal = expensiveViewSpaceNormalFromNeighbours(sampleUV, viewSpacePosition, fragmentDepth);
        #else
            vec3 viewSpaceNormal = viewSpaceNormalFromNeighbours(sampleUV, viewSpacePosition);
        #endif

        // Fade AO out between 10% and 90% of the far plane, and boost it for fragments close to the camera.
        float distanceFade = 1.0 - smoothstep(${NEAR_FAR}.y * 0.1, ${NEAR_FAR}.y * 0.9, viewSpacePosition.z);
        float depthDifferenceScaling = 1.0 + pow((1.0 - (viewSpacePosition.z / ${NEAR_FAR}.y)), 16.0) * 5.0;
        float aoContribution = distanceFade * depthDifferenceScaling;

        // Depth >= 0.9995 is treated as "no geometry written"; such fragments get no occlusion.
        float fragmentWritten = 1.0 - step(0.9995, fragmentDepth);
        float ambientOcclusion = 0.0;
        if (fragmentWritten == 1.0) {
            ambientOcclusion = ambientOcclusionValue(viewSpacePosition, viewSpaceNormal, sampleUV) * aoContribution;
        }
    
        gl_FragColor = vec4(vec3(ambientOcclusion), 1.0);
    }`;

    /** Hemisphere sample offsets uploaded to the shader; length always equals `samples`. */
    private _kernel: Vector3[] = [];
    /** Tiling random-rotation texture; undefined before initialize() and after detach(). */
    private _noiseTexture: DynamicTexture | undefined;
    /** Multiplier applied to the averaged occlusion before it is clamped to 1. */
    public strength: number;
    /** View-space radius of the sampling hemisphere; also used by the occluder range check. */
    public radius: number;
    private _samples: number;
    /** Scale applied to the render size when computing the depth texel size and noise tiling. */
    private _renderScale: number;

    /**
     * Number of hemisphere samples per fragment. Changing it regenerates the kernel and recompiles
     * the shader, because the count is baked in as a preprocessor define.
     *
     * @throws Error if `samples` is not a positive integer, or if the calculator has not been
     * initialized yet. In both cases the current state is left unchanged.
     */
    public set samples(samples: number) {
        SSAOCalculator.assertValidSampleCount(samples);

        const postProcess = this._postProcess;
        if (!postProcess) {
            // Fail before mutating anything, so a rejected call has no side effects.
            throw new Error('Updating samples of uninitialized ssao calculator');
        }

        this._samples = samples;
        this.createKernel();
        postProcess.updateEffect(this.getDefines(), this.getUniforms(), this.getSamplers(), undefined, () => {}, undefined);
    }

    public get samples(): number {
        return this._samples;
    }

    /** Throws if `samples` cannot be used as a GLSL array length (it must be a positive integer). */
    private static assertValidSampleCount(samples: number): void {
        if (!Number.isInteger(samples) || samples < 1) {
            throw new Error(`SSAO sample count must be a positive integer, got ${samples}`);
        }
    }

    /** Preprocessor defines: the kernel size and the switch selecting the expensive normal reconstruction. */
    private getDefines(): string {
        const defines: string[] = [];
        defines.push(`#define ${SAMPLE_KERNEL_SIZE} ${this.samples}`);
        defines.push(`#define ${EXPENSIVE_NORMALS}`);
        return defines.join('\n');
    }

    /** Uniform names to register with the effect, including one entry per kernel element. */
    private getUniforms(): string[] {
        const uniforms = [
            UV_SCALE,
            STRENGTH,
            RADIUS,
            NOISE_SCALE,
            DEPTH_DIMENSIONS,
            NEAR_FAR,
            PROJECTION_MATRIX,
            INVERTED_PROJECTION_MATRIX
        ];

        for (let i = 0; i < this._kernel.length; i++) {
            uniforms.push(`${SAMPLE_KERNEL}[${i}]`);
        }

        return uniforms;
    }

    private getSamplers(): string[] {
        return [DEPTH_SAMPLER, NOISE_SAMPLER];
    }

    /**
     * Creates (once) the SSAO post process and hooks up its per-frame uniforms. A fresh noise texture is
     * generated on every call; the previous one is disposed rather than leaked.
     *
     * @throws Error if the camera's scene has no Twinfinity viewer.
     */
    protected initialize(camera: Camera): PostProcess {
        this._noiseTexture?.dispose();
        this._noiseTexture = this.createNoiseTexture(camera);

        if (this._postProcess !== undefined) {
            return this._postProcess;
        }

        // Validate before creating the post process so a failure does not leave an orphaned one behind.
        const viewer = camera.getScene().twinfinity.viewer;
        if (!viewer) {
            throw new Error('No viewer created?');
        }

        const engine = camera.getEngine();
        // NOTE(review): the render size is captured once here; nothing below reacts to engine resizes.
        const renderWidth = engine.getRenderWidth();
        const renderHeight = engine.getRenderHeight();
        const noiseSize = SSAOCalculator._noiseTextureSize;

        Effect.ShadersStore[SHADER_NAME_PREFIX + 'FragmentShader'] = SSAOCalculator._fragmentShaderSource;
        const newPostProcess = new PostProcess(
            'SSAO ',
            SHADER_NAME_PREFIX,
            this.getUniforms(),
            this.getSamplers(),
            {
                // Do not include scale here, because that would mean that the postprocess before this one (this blit) would be resized, because Babylonjs has really weird rules for this
                width: renderWidth,
                height: renderHeight
            },
            null,
            undefined,
            engine,
            false,
            this.getDefines()
        );

        const nonLinearDepth = true;
        this._shortFrustumDepthRenderer = viewer.enableDepthRenderer(nonLinearDepth);
        const shortFrustumProjectionMatrixState = this._shortFrustumDepthRenderer.shortFrustumProjectionMatrixState;

        newPostProcess.onApplyObservable.add((effect) => {
            const scaledWidth = renderWidth * this._renderScale;
            const scaledHeight = renderHeight * this._renderScale;

            // Read the field each frame: initialize() may have replaced the noise texture since this was registered.
            effect.setTexture(NOISE_SAMPLER, this._noiseTexture ?? null);
            effect.setTexture(DEPTH_SAMPLER, this._shortFrustumDepthRenderer.depthRenderer.getDepthMap());
            effect.setFloat(STRENGTH, this.strength);
            effect.setFloat(RADIUS, this.radius);
            effect.setFloat2(DEPTH_DIMENSIONS, 1.0 / scaledWidth, 1.0 / scaledHeight);
            // Tile the noise texture so that one noise texel maps to one (scaled) screen pixel.
            effect.setFloat2(NOISE_SCALE, scaledWidth / noiseSize, scaledHeight / noiseSize);
            effect.setFloat2(UV_SCALE, 1.0, 1.0);
            effect.setFloat2(
                NEAR_FAR,
                shortFrustumProjectionMatrixState.depthRenderCameraMinZ,
                shortFrustumProjectionMatrixState.depthRenderCameraMaxZ
            );

            effect.setMatrix(INVERTED_PROJECTION_MATRIX, shortFrustumProjectionMatrixState.getInvertedProjectionMatrix());
            effect.setMatrix(PROJECTION_MATRIX, shortFrustumProjectionMatrixState.getProjectionMatrix());

            this._kernel.forEach((kernelSample, i) => {
                effect.setFloat3(`${SAMPLE_KERNEL}[${i}]`, kernelSample.x, kernelSample.y, kernelSample.z);
            });
        });

        this._postProcess = newPostProcess;
        return newPostProcess;
    }

    /**
     * Builds the tiling texture of random rotation vectors used to rotate the kernel per pixel.
     * Red and green hold uniformly random bytes. Blue is mid-grey (128), so the shader's `* 2.0 - 1.0`
     * remap gives z ~= 0 and the vector lies in the xy plane, as in Chapman's tutorial.
     */
    private createNoiseTexture(camera: Camera): DynamicTexture {
        const size = SSAOCalculator._noiseTextureSize;
        const noiseTexture = new DynamicTexture(
            'SSAOnoiseTexture',
            size,
            camera.getScene(),
            false,
            Texture.NEAREST_SAMPLINGMODE
        );
        noiseTexture.wrapR = Texture.WRAP_ADDRESSMODE;
        noiseTexture.wrapU = Texture.WRAP_ADDRESSMODE;
        noiseTexture.wrapV = Texture.WRAP_ADDRESSMODE;

        const context = noiseTexture.getContext();
        // Uniform integer in [0, 255].
        const randomByte = (): number => Math.floor(Math.random() * 256);
        for (let x = 0; x < size; x++) {
            for (let y = 0; y < size; y++) {
                // Three-argument rgb() is valid in every CSS colour parser (three-argument rgba() is not).
                context.fillStyle = `rgb(${randomByte()}, ${randomByte()}, 128)`;
                context.fillRect(x, y, 1, 1);
            }
        }
        noiseTexture.update();

        return noiseTexture;
    }

    /**
     * Regenerates the sample kernel: `samples` random unit directions in the z >= 0 hemisphere, scaled
     * from 0.1 up towards 1.0 quadratically with the sample index, so samples cluster near the origin.
     */
    private createKernel(): void {
        const lerp = (a: number, b: number, t: number): number => (1 - t) * a + t * b;
        const count = this.samples;

        this._kernel = [];
        for (let i = 0; i < count; i++) {
            const kernelSample = new Vector3(Math.random() * 2.0 - 1.0, Math.random() * 2.0 - 1.0, Math.random());
            kernelSample.normalize();

            const t = i / count;
            kernelSample.scaleInPlace(lerp(0.1, 1.0, t * t));

            this._kernel.push(kernelSample);
        }
    }

    /**
     * @param samplesCount Number of hemisphere samples per fragment; must be a positive integer.
     * @param strength Multiplier applied to the averaged occlusion.
     * @param radius View-space sampling radius.
     * @param renderScale Scale applied to the render size when computing depth texel size and noise tiling.
     * @throws Error if `samplesCount` is not a positive integer.
     */
    constructor(samplesCount: number, strength: number, radius: number, renderScale: number) {
        super();
        SSAOCalculator.assertValidSampleCount(samplesCount);
        this._samples = samplesCount;
        this.strength = strength;
        this.radius = radius;
        this._renderScale = renderScale;
        this.createKernel();
    }

    /** Detaches from `camera`; on success also disables the depth renderer and releases the noise texture. */
    detach(camera: Camera): boolean {
        const detached = super.detach(camera);

        if (detached) {
            // The depth renderer only exists once initialize() has completed successfully.
            if (this._shortFrustumDepthRenderer) {
                this._shortFrustumDepthRenderer.disable();
            }
            this._noiseTexture?.dispose();
            this._noiseTexture = undefined;
        }

        return detached;
    }
}
