import bpy

import lnx.material.shader as shader
import lnx.material.mat_state as mat_state
import lnx.material.mat_utils as mat_utils
import lnx.utils

if lnx.is_reload(__name__):
    shader = lnx.reload_module(shader)
    lnx.utils = lnx.reload_module(lnx.utils)
else:
    lnx.enable_reload(__name__)


def write(vert: shader.Shader, frag: shader.Shader):
    """Emit clustered light accumulation code into the given shaders.

    Declares the cluster and light uniforms on ``frag``. In most
    configurations it also forwards the clip-space position from ``vert`` to
    ``frag`` as ``wvpposition``. It then writes a GLSL loop that adds
    ``sampleLight(...)`` for every light in the fragment's cluster to
    ``direct``.

    The emitted GLSL references ``wposition``, ``n``, ``vVec``, ``dotNV``,
    ``albedo``, ``roughness``, ``specular``, ``f0`` and ``direct``. Depending
    on world defines it also references ``opacity``, ``occlusion``,
    ``voxels``, ``voxelsSDF`` and ``clipmaps``. None of these are declared
    here, so the calling generator must provide them.

    Args:
        vert: Vertex shader being generated.
        frag: Fragment shader being generated.
    """
    wrd = bpy.data.worlds['Lnx']
    world_defs = wrd.world_defs
    rpdat = lnx.utils.get_rp()
    blend = mat_state.material.lnx_blending
    parse_opacity = blend or mat_utils.is_transluc(mat_state.material)
    is_mobile = rpdat.lnx_material_model == 'Mobile'
    is_shadows = '_ShadowMap' in world_defs
    is_shadows_atlas = '_ShadowMapAtlas' in world_defs
    is_single_atlas = '_SingleAtlas' in world_defs
    has_voxel_ao_or_gi = '_VoxelAOvar' in world_defs or '_VoxelGI' in world_defs
    has_spot = '_Spot' in world_defs

    # --- Uniforms shared by all clustered lights -------------------------
    frag.add_include_front('std/clusters.glsl')
    frag.add_uniform('vec2 cameraProj', link='_cameraPlaneProj')
    frag.add_uniform('vec2 cameraPlane', link='_cameraPlane')
    frag.add_uniform('vec4 lightsArray[maxLights * 3]', link='_lightsArray')
    frag.add_uniform('sampler2D clustersData', link='_clustersData')
    if is_shadows:
        frag.add_uniform('bool receiveShadow')
        frag.add_uniform('vec2 lightProj', link='_lightPlaneProj', included=True)
        if is_shadows_atlas:
            if not is_single_atlas:
                frag.add_uniform('sampler2DShadow shadowMapAtlasPoint', included=True)
                frag.add_uniform('sampler2D shadowMapAtlasPointTransparent', included=True)
            else:
                frag.add_uniform('sampler2DShadow shadowMapAtlas', top=True)
                frag.add_uniform('sampler2D shadowMapAtlasTransparent', top=True)
            frag.add_uniform('vec4 pointLightDataArray[maxLightsCluster]', link='_pointLightsAtlasArray', included=True)
        else:
            frag.add_uniform('samplerCubeShadow shadowMapPoint[4]', included=True)
            frag.add_uniform('samplerCube shadowMapPointTransparent[4]', included=True)

    # The cluster lookup below always reads `wvpposition`. It is declared here
    # unless voxel AO/GI is on without opacity or voxel shadows; in that case
    # it is presumably declared by other generator code. TODO: confirm.
    if not has_voxel_ao_or_gi or parse_opacity or '_VoxelShadow' in world_defs:
        vert.add_out('vec4 wvpposition')
        vert.write('wvpposition = gl_Position;')

    # --- Locate this fragment's cluster and its light count --------------
    frag.write('float viewz = linearize(gl_FragCoord.z, cameraProj);')
    frag.write('int clusterI = getClusterI((wvpposition.xy / wvpposition.w) * 0.5 + 0.5, viewz, cameraPlane);')
    frag.write('int numLights = int(texelFetch(clustersData, ivec2(clusterI, 0), 0).r * 255);')

    frag.write('#ifdef HLSL')
    frag.write('viewz += texture(clustersData, vec2(0.0)).r * 1e-9;') # TODO: krafix bug, needs to generate sampler
    frag.write('#endif')

    # --- Spot-light specific data ----------------------------------------
    if has_spot:
        frag.add_uniform('vec4 lightsArraySpot[maxLights * 2]', link='_lightsArraySpot')
        frag.write('int numSpots = int(texelFetch(clustersData, ivec2(clusterI, 1 + maxLightsCluster), 0).r * 255);')
        frag.write('int numPoints = numLights - numSpots;')
        if is_shadows:
            if is_shadows_atlas:
                if not is_single_atlas:
                    frag.add_uniform('sampler2DShadow shadowMapAtlasSpot', included=True)
                    frag.add_uniform('sampler2D shadowMapAtlasSpotTransparent', included=True)
                else:
                    frag.add_uniform('sampler2DShadow shadowMapAtlas', top=True)
                    frag.add_uniform('sampler2D shadowMapAtlasTransparent', top=True)
            else:
                frag.add_uniform('sampler2DShadow shadowMapSpot[4]', included=True)
                frag.add_uniform('sampler2D shadowMapSpotTransparent[4]', included=True)
            frag.add_uniform('mat4 LWVPSpotArray[maxLightsCluster]', link='_biasLightWorldViewProjectionMatrixSpotArray', included=True)

    # --- Per-light loop: accumulate sampleLight() into `direct` ----------
    frag.write('for (int i = 0; i < min(numLights, maxLightsCluster); i++) {')
    frag.write('int li = int(texelFetch(clustersData, ivec2(clusterI, i + 1), 0).r * 255);')
    frag.write('direct += sampleLight(')
    frag.write('    wposition,')
    frag.write('    n,')
    frag.write('    vVec,')
    frag.write('    dotNV,')
    frag.write('    lightsArray[li * 3].xyz,') # lp
    frag.write('    lightsArray[li * 3 + 1].xyz,') # lightCol
    frag.write('    albedo,')
    frag.write('    roughness,')
    frag.write('    specular,')
    frag.write('    f0')

    # The optional arguments below are appended in this fixed order.
    # NOTE(review): the order is assumed to match sampleLight's
    # #ifdef-guarded parameters; verify against its GLSL declaration.
    if is_shadows:
        # Light index, bias (lightsArray[..+2].x), the .z flag, and whether
        # this surface is treated as transparent.
        if parse_opacity:
            frag.write('\t, li, lightsArray[li * 3 + 2].x, lightsArray[li * 3 + 2].z != 0.0, opacity != 1.0') # bias
        else:
            frag.write('\t, li, lightsArray[li * 3 + 2].x, lightsArray[li * 3 + 2].z != 0.0, false') # bias
    if has_spot:
        frag.write('\t, lightsArray[li * 3 + 2].y != 0.0')  # presumably "is spot": non-zero spot size
        frag.write('\t, lightsArray[li * 3 + 2].y') # spot size (cutoff)
        frag.write('\t, lightsArraySpot[li * 2].w') # spot blend (exponent)
        frag.write('\t, lightsArraySpot[li * 2].xyz') # spotDir
        frag.write('\t, vec2(lightsArray[li * 3].w, lightsArray[li * 3 + 1].w)') # scale
        frag.write('\t, lightsArraySpot[li * 2 + 1].xyz') # right
    if '_VoxelShadow' in world_defs:
        frag.write(', voxels, voxelsSDF, clipmaps')
    if '_MicroShadowing' in world_defs and not is_mobile:
        frag.write('\t, occlusion')
    if '_SSRS' in world_defs:
        frag.add_uniform('mat4 invVP', '_inverseViewProjectionMatrix')
        frag.add_uniform('vec3 eye', '_cameraPosition')
        # Must reference the uniform name declared just above (`invVP`).
        # `inVP` is not declared by this generator.
        frag.write(', gl_FragCoord.z, invVP, eye')
    frag.write(');')

    frag.write('}') # for numLights