import { Effect, PostProcess, Camera } from '../loader/babylonjs-import';
import { ShortFrustumDepthRenderer } from '../loader/ShortFrustumDepthRenderer';
import { BaseLineShading, TEXTURE_SAMPLER } from './BaseLineShading';
import { NORMALS_FROM_DEPTH, VIEWSPACE_POSITION_FROM_LOGARITHMIC_DEPTH } from './TwinfinityPostProcess';

/**
 * NormalsFromDepthLineShading is a post-process effect that adds an outline to objects in the viewer.
 * Unlike the other line shader, this one also compares normals reconstructed from the depth buffer.
 * This lets it find edges where geometry has opposing normals but depth values too similar to separate by depth alone.
 */

// Key prefix under which the fragment shader is registered in Effect.ShadersStore
// (the full key is this prefix followed by 'FragmentShader').
const SHADER_NAME_PREFIX = 'outlineFromDepthNormals';

// Preprocessor define that makes the shader use the more expensive normal reconstruction.
const EXPENSIVE_NORMALS = 'EXPENSIVE_NORMALS';

// Sampler holding the short-frustum depth map.
const DEPTH_SAMPLER = 'depthSampler';

// Uniform names handed to the PostProcess and set from onApplyObservable.
const STRENGTH = 'strength';
const THICKNESS = 'thickness';
const NEAR_FAR = 'nearFar';
const DEPTH_DIMENSIONS = 'depthDimensions';
const INVERTED_PROJECTION_MATRIX = 'invertedProjectionMatrix';
// Spelling ("Seperation") must match the uniform identifier used in the shader source.
const MIN_MAX_SEPERATION = 'minMaxSeperation';

/**
 * Outline post-process that darkens fragments where neighbouring samples differ in depth
 * and/or in view-space normals reconstructed from the depth buffer.
 *
 * The depth map is taken from the viewer's short-frustum depth renderer. That renderer is
 * requested when the post-process is applied and disabled again on a successful detach.
 */
export class NormalsFromDepthLineShading extends BaseLineShading {
    /**
     * Depth renderer obtained from the viewer inside onApplyObservable.
     * It stays `undefined` until the post-process has been applied at least once,
     * so it must be checked before use.
     */
    private _shortFrustumDepthRenderer: ShortFrustumDepthRenderer | undefined;
    private _strength: number;

    /**
     * @param lineThickness Passed to the base class; `this.lineThickness` is what gets uploaded
     *                      to the shader when the post-process is applied.
     * @param strength Line strength. Values above 1 are clamped to 1 when the uniform is set.
     */
    constructor(lineThickness = 0.3, strength = 0.35) {
        super(lineThickness);
        this._strength = strength;
    }

    /**
     * Returns the existing post-process if `_postProcess` is already set. Otherwise it registers
     * the fragment shader and creates a new PostProcess for `camera`. That PostProcess sets all
     * uniforms and the depth texture each time it is applied.
     * @param camera Camera whose engine, scene and short-frustum state drive the effect.
     * @returns The post-process to attach.
     */
    protected initialize(camera: Camera): PostProcess {
        const engine = camera.getEngine();

        if (this._postProcess !== undefined) {
            return this._postProcess;
        }

        Effect.ShadersStore[SHADER_NAME_PREFIX + 'FragmentShader'] = NormalsFromDepthLineShading._shaderSource;

        const uniforms = [
            MIN_MAX_SEPERATION,
            THICKNESS,
            STRENGTH,
            INVERTED_PROJECTION_MATRIX,
            DEPTH_DIMENSIONS,
            NEAR_FAR
        ];
        const samplers = [DEPTH_SAMPLER, TEXTURE_SAMPLER];
        const defines = [`#define ${EXPENSIVE_NORMALS}`];

        const newPostProcess = new PostProcess(
            'Normals and depth based outline post-process',
            SHADER_NAME_PREFIX,
            uniforms,
            samplers,
            1 /* ratio */,
            null,
            undefined,
            engine,
            false,
            defines.join('\n')
        );

        const shortFrustumProjectionMatrixState = camera.twinfinity.shortFrustum;

        newPostProcess.onApplyObservable.add((effect) => {
            const viewer = camera.getScene().twinfinity.viewer;
            if (!viewer) {
                throw new Error('No viewer created?');
            }

            // Remember the renderer so detach() can disable it again.
            const depthRenderer = viewer.enableDepthRenderer(true);
            this._shortFrustumDepthRenderer = depthRenderer;

            this.updateLineThicknessUniform(effect, this.lineThickness);
            effect.setTexture(DEPTH_SAMPLER, depthRenderer.depthRenderer.getDepthMap());

            effect.setMatrix(INVERTED_PROJECTION_MATRIX, shortFrustumProjectionMatrixState.getInvertedProjectionMatrix());

            // Size of one texel in UV space, used as the neighbour sampling offset in the shader.
            effect.setFloat2(DEPTH_DIMENSIONS, 1.0 / engine.getRenderWidth(), 1.0 / engine.getRenderHeight());
            effect.setFloat2(
                NEAR_FAR,
                shortFrustumProjectionMatrixState.depthRenderCameraMinZ,
                shortFrustumProjectionMatrixState.depthRenderCameraMaxZ
            );

            effect.setFloat2(MIN_MAX_SEPERATION, 1.0, 1.0); // Hardcoded to avoid artifacts

            effect.setFloat(THICKNESS, this.lineThickness);
            effect.setFloat(STRENGTH, Math.min(this._strength, 1.0));
        });

        return newPostProcess;
    }

    /**
     * Detaches the effect from `camera`. On success, it also disables the depth renderer that
     * was enabled while the post-process was applied, if one was ever acquired.
     * @param camera Camera to detach from.
     * @returns Whatever the base-class detach returned.
     */
    detach(camera: Camera): boolean {
        const detached = super.detach(camera);

        // The renderer is only acquired inside onApplyObservable. If the effect is detached
        // before it ever rendered a frame, there is nothing to disable.
        if (detached && this._shortFrustumDepthRenderer !== undefined) {
            this._shortFrustumDepthRenderer.disable();
            this._shortFrustumDepthRenderer = undefined;
        }

        return detached;
    }

    /**
     * Fragment shader. For fragments with written depth (depth < 0.9995) it:
     * - compares depth and reconstructed normals against 4 axis-aligned neighbours;
     * - fades the result with distance toward the far plane (`nearFar.y`);
     * - mixes the base colour toward opaque black by the resulting line contribution.
     * NOTE(review): `depthDimensions` and `invertedProjectionMatrix` are not declared in this
     * literal. Presumably they come from the included snippets; confirm.
     */
    private static readonly _shaderSource = `
    #ifdef GL_ES
        precision highp float;
    #endif

    varying vec2 vUV;

    uniform sampler2D textureSampler;
    uniform sampler2D depthSampler;
    uniform vec2 nearFar;
    uniform vec2 minMaxSeperation;
    uniform float thickness;
    uniform float strength;
    uniform vec3 lineColor;
    #define NORMAL_SCALE 2.5
    #define DEPTH_SCALE 0.75

    ${VIEWSPACE_POSITION_FROM_LOGARITHMIC_DEPTH}
    ${NORMALS_FROM_DEPTH}

    void compare(inout float depthOutline, inout float normalOutline, float viewSpaceZ, vec3 fragmentViewSpaceNormal, vec2 uv, vec2 offset){
        
        vec2 neighborUV = uv + offset;
        float neighborFragmentDepth = texture2D(depthSampler, neighborUV).r;
        vec3 neighborViewSpacePosition = viewSpacePositionFromDepth(neighborFragmentDepth, neighborUV.x, neighborUV.y);
       
        #ifdef ${EXPENSIVE_NORMALS}
        vec3 neighborViewSpaceNormal = expensiveViewSpaceNormalFromNeighbours(neighborUV, neighborViewSpacePosition, neighborFragmentDepth);
        #else
        vec3 neighborViewSpaceNormal = viewSpaceNormalFromNeighbours(neighborUV, neighborViewSpacePosition);
        #endif

        // new
        float depthDifference = min(abs(viewSpaceZ - neighborViewSpacePosition.z), 1.0);
        depthOutline = depthOutline + depthDifference;

        // Use  step function to avoid the normal contributing to the line on the fragments of edges    
        float normalDifference = min(pow(length(fragmentViewSpaceNormal - neighborViewSpaceNormal), 2.0), 1.0);
        normalOutline = normalOutline + normalDifference;
    }

    void main(void) 
    {
        float fragmentDepth = texture2D(depthSampler, vUV).r;
        float fragmentWritten = 1.0 - step(0.9995, fragmentDepth);
        vec4 baseColor = texture2D(textureSampler, vUV);
        if (fragmentWritten == 1.0) {
        
            vec3 viewSpacePosition = viewSpacePositionFromDepth(fragmentDepth, vUV.x, vUV.y);

            #ifdef ${EXPENSIVE_NORMALS}
                vec3 viewSpaceNormal = expensiveViewSpaceNormalFromNeighbours(vUV, viewSpacePosition, fragmentDepth);
            #else
                vec3 viewSpaceNormal = viewSpaceNormalFromNeighbours(vUV, viewSpacePosition);
            #endif

            float depthOutline = 0.0;
            float normalOutline = 0.0;

            float separation = mix(minMaxSeperation.x, minMaxSeperation.y, viewSpacePosition.z/nearFar.y);

            compare(depthOutline, normalOutline, viewSpacePosition.z, viewSpaceNormal, vUV, vec2(depthDimensions.x, 0.0) * separation);
            compare(depthOutline, normalOutline, viewSpacePosition.z, viewSpaceNormal, vUV, vec2(0, depthDimensions.y) * separation);
            compare(depthOutline, normalOutline, viewSpacePosition.z, viewSpaceNormal, vUV, vec2(0, -depthDimensions.y) * separation);
            compare(depthOutline, normalOutline, viewSpacePosition.z, viewSpaceNormal, vUV, vec2(-depthDimensions.x, 0) * separation);

            // Fade out the line strength the closer we get to the short distance view frustum
            float distanceFade = 1.0 - smoothstep(nearFar.y * 0.1, nearFar.y * 0.9, viewSpacePosition.z);
            float depthDifferenceScaling = 1.0 + pow((1.0 - (viewSpacePosition.z / nearFar.y)), 16.0) * 5.0;
            // There are 4 samples, each sample may get contribution from either depth different to neighbors, or normal difference to neighbor, hence 4 * 2 as denominator
            float denominator = (4.0 * 2.0);
            float lineContribution = strength * distanceFade * ((depthDifferenceScaling * depthOutline * DEPTH_SCALE) + normalOutline * NORMAL_SCALE) / denominator;

            gl_FragColor = mix(baseColor, vec4(0.0, 0.0, 0.0, 1.0), lineContribution);
        } else {
            gl_FragColor = baseColor;
        }
    }
    `;
}
