const Shader = require("@/game/Shader");

/**
 * Shader program for skinned, physically-shaded models.
 *
 * Vertex stage: blends each vertex by four bone matrices (`ids` index into
 * the `bones` uniform array, `weights` are the blend factors), flips the V
 * texture coordinate and forwards the skinned position, normal and tangent.
 *
 * Fragment stage: builds a tangent-space basis from the interpolated normal
 * and tangent, perturbs it with `normalMap`, and sums the contribution of
 * three hard-coded directional lights using a microfacet specular term
 * (NDF * G * F) plus an `albedo / PI` diffuse term. `physMap.g` is used as
 * roughness and `physMap.b` as metallic.
 *
 * NOTE(review): `this.gl`, `this.program` and `this.uniforms` are expected to
 * be provided by the `Shader` base class, which is not visible here.
 */
class ModelShader extends Shader {
    /**
     * Returns the GLSL sources for this program.
     *
     * @returns {{vert: string, frag: string}} vertex and fragment shader source.
     */
    __getShaders() {
        return {
            // The `bones` array is declared with 63 entries, so the joint
            // matrices uploaded in bind() must not exceed 63 mat4s.
            vert:
            // language=Glsl
                `precision highp float;
                 attribute vec3 _position;
                 attribute vec4 _normal;
                 attribute vec4 _tangent;
                 attribute vec2 tex;
                 attribute vec4 weights;
                 attribute vec4 ids;

                 uniform mat4 projectView;
                 uniform mat4 bones[63];

                 varying vec2 uv;
                 varying vec3 normal;
                 varying vec3 tangent;
                 varying vec3 pos;
                 void main() {
                     uv = vec2(tex.x, 1. - tex.y);
                     vec3 norm = _normal.xyz;
                     vec3 tang = _tangent.xyz;

                     vec4 totalLocalPos = vec4(0.0);
                     vec3 totalNormal = vec3(0.0);
                     vec3 totalTangent = vec3(0.0);

                     for (int i = 0; i < 4; i++) {
                         mat4 JM = bones[int(ids[i])];
                         totalLocalPos += (JM * vec4(_position, 1.)) * weights[i];
                         totalNormal += (mat3(JM) * norm) * weights[i];
                         totalTangent += (mat3(JM) * tang) * weights[i];
                     }
                     normal = normalize(totalNormal.xyz);
                     tangent = normalize(totalTangent.xyz);
                     pos = totalLocalPos.xyz;
                     gl_Position = projectView * totalLocalPos;
                 }`,
            // The colour accumulator in main() is explicitly zero-initialised:
            // GLSL does not initialise locals, and accumulating into an
            // uninitialised vec3 yields undefined output.
            // NOTE(review): the view vector is computed as normalize(-pos),
            // i.e. it assumes the camera sits at the origin of the space
            // `pos` is expressed in - confirm against the caller.
            // language=Glsl
            frag:
                `precision highp float;
                 uniform sampler2D albedoMap;
                 uniform sampler2D normalMap;
                 uniform sampler2D physMap;
                 varying vec2 uv;
                 varying vec3 normal;
                 varying vec3 tangent;
                 varying vec3 pos;

                 const float PI = 3.14159265359;
                 float NDF(float n_dot_h, float roughness) {
                     float roughness_sqr = roughness * roughness;
                     float temp = n_dot_h * n_dot_h * (roughness_sqr - 1.) + 1.;
                     return roughness_sqr / (PI * temp * temp);
                 }

                 float G_Schlick(float n_dot_v, float k) {
                     return n_dot_v / (n_dot_v * (1. - k) + k);
                 }

                 float k_direct(float roughness) {
                     float roughness_plus_one = roughness + 1.;
                     float k = roughness_plus_one * roughness_plus_one * 0.125;
                     return k;
                 }

                 float k_IBL(float roughness) {
                     float k = roughness * roughness * 0.5;
                     return k;
                 }

                 float G_Smith(float n_dot_v, float n_dot_l, float k) {
                     return G_Schlick(n_dot_v, k) * G_Schlick(n_dot_l, k);
                 }

                 vec3 F_Schlick(float h_dot_v, vec3 f0) {
                     float temp = 1. - h_dot_v;
                     float temp2 = temp * temp;
                     float temp5 = temp2 * temp2 * temp;

                     return f0 + (1. - f0) * temp5;
                 }

                 vec3 F(float h_dot_v, vec3 albedo, float metallic) {
                     vec3 f0 = mix(vec3(0.04), albedo, metallic);
                     return F_Schlick(h_dot_v, f0);
                 }

                 vec3 f_lit(vec3 light_dir, vec3 light_radiance, vec3 normal, vec3 view_dir, vec3 albedo, float metallic, float roughness) {
                     vec3 halfway = normalize(view_dir + light_dir);
                     float n_dot_v = clamp(dot(normal, view_dir), 0., 1.);
                     float n_dot_l = clamp(dot(normal, light_dir), 0., 1.);
                     float h_dot_v = clamp(dot(halfway, view_dir), 0., 1.);
                     float n_dot_h = clamp(dot(normal, halfway), 0., 1.);

                     float ndf = NDF(n_dot_h, roughness);
                     float g = G_Smith(n_dot_v, n_dot_l, k_direct(roughness));
                     vec3 f = F(h_dot_v, albedo, metallic);
                     float normalization_factor = 4. * n_dot_v * n_dot_l;

                     vec3 kS = f;
                     vec3 kD = (vec3(1.) - kS) * (1. - metallic);

                     vec3 specular = ndf * g * f / max(normalization_factor, 0.001);
                     vec3 diffuse = kD * albedo / PI;

                     vec3 BRDF = diffuse + specular;
                     return BRDF * light_radiance * n_dot_l;
                 }
                 void main() {
                     vec3 albedo = texture2D(albedoMap, uv).rgb;
                     vec3 phys = texture2D(physMap, uv).rgb;
                     vec3 N = normalize(mat3(tangent, normalize(cross(normal, tangent)), normal) * (pow(texture2D(normalMap, uv).rgb, vec3(1. )) * 2.0 - 1.0));
                     vec3 V = normalize(-pos);
                     vec3 ld = vec3(3.2, 2.9, 2.8) * .9;
                     vec3 c = vec3(0.0);
                     c += f_lit(vec3(-.96, .28, 0.), ld, N, V, albedo, phys.b, phys.g);
                     c += f_lit(vec3(.96, .28, 0.), ld, N, V, albedo, phys.b, phys.g);
                     c += f_lit(vec3(0., .707, .707), ld, N, V, albedo, phys.b, phys.g);
                     gl_FragColor = vec4(c, 1.0);
                 }`
        }
    }

    /**
     * Binds the program (via the base class) and uploads per-draw uniforms.
     *
     * @param projectView value uploaded to the `projectView` mat4 uniform
     *        (presumably a 16-element Float32Array - verify against caller).
     * @param joints flat matrix data uploaded to the `bones` mat4 array; the
     *        shader declares 63 entries.
     * @param time value uploaded to a `time` float uniform.
     *        NOTE(review): neither shader stage in this file declares `time`,
     *        so its location is expected to be null and this upload to have
     *        no effect - confirm whether it is still needed.
     */
    bind(projectView, joints, time) {
        super.bind();
        this.gl.uniformMatrix4fv(this.uniforms.projectView, false, projectView);
        this.gl.uniformMatrix4fv(this.uniforms.bones, false, joints);
        this.gl.uniform1f(this.uniforms.time, time);
    }

    /**
     * Looks up the uniform locations used by this program.
     *
     * The sampler locations are exposed under the short keys `a`
     * (albedoMap), `n` (normalMap) and `p` (physMap).
     *
     * @returns {Object} map from key to the result of getUniformLocation.
     */
    __initUniforms() {
        return {
            projectView: this.gl.getUniformLocation(this.program, 'projectView'),
            bones: this.gl.getUniformLocation(this.program, 'bones'),
            time: this.gl.getUniformLocation(this.program, 'time'),
            a: this.gl.getUniformLocation(this.program, 'albedoMap'),
            n: this.gl.getUniformLocation(this.program, 'normalMap'),
            p: this.gl.getUniformLocation(this.program, 'physMap'),
        }
    }

    /**
     * Looks up the vertex attribute locations used by this program.
     *
     * Note that the `position`, `normal` and `tangent` keys map to the
     * underscore-prefixed attribute names in the vertex shader.
     *
     * @returns {Object} map from key to the result of getAttribLocation.
     */
    __initAttributions() {
        return {
            position: this.gl.getAttribLocation(this.program, '_position'),
            normal: this.gl.getAttribLocation(this.program, '_normal'),
            tangent: this.gl.getAttribLocation(this.program, '_tangent'),
            tex: this.gl.getAttribLocation(this.program, 'tex'),
            weights: this.gl.getAttribLocation(this.program, 'weights'),
            ids: this.gl.getAttribLocation(this.program, 'ids')
        }
    }

    /**
     * Intentionally empty.
     *
     * NOTE(review): the sampler uniforms (`a`, `n`, `p`) are never assigned
     * texture units in this file; presumably a caller does that - verify.
     */
    bindUniforms() {

    }
}

// CommonJS export: the ModelShader class is this module's sole export.
module.exports = ModelShader;