import bpy
import nodeitems_utils
from nodeitems_utils import NodeCategory

import lnx
import lnx.material.lnx_nodes.lnx_nodes as lnx_nodes
# Import all nodes so that they register. Do not remove this import
# even if it looks unused
from lnx.material.lnx_nodes import *

# Reload handling. If lnx.is_reload reports that this module is being
# re-imported (presumably an add-on reload — TODO confirm against lnx), the
# node modules are reloaded too. The star import is then repeated, so the names
# used below are rebound to the reloaded definitions.
if lnx.is_reload(__name__):
    # Reload the module whose `nodes` and `category_items` are read by
    # register_nodes()...
    lnx_nodes = lnx.reload_module(lnx_nodes)
    # ...then the node package itself, and re-run the star import that pulls
    # in (and thereby registers) all node definitions.
    lnx.material.lnx_nodes = lnx.reload_module(lnx.material.lnx_nodes)
    from lnx.material.lnx_nodes import *
else:
    # First import: presumably marks this module so that a later re-import is
    # reported by lnx.is_reload — TODO confirm against lnx.
    lnx.enable_reload(__name__)

# Node classes recorded by register_nodes(); unregister_nodes() passes each of
# them to bpy.utils.unregister_class and then empties this list.
registered_nodes = []


class MaterialNodeCategory(NodeCategory):
    """Node Add-menu category that is only offered for shader node trees."""

    @classmethod
    def poll(cls, context):
        """Return True when the active editor's tree type is 'ShaderNodeTree'."""
        editor_space = context.space_data
        is_shader_tree = editor_space.tree_type == 'ShaderNodeTree'
        return is_shader_tree


def register_nodes():
    """Register every material node class and the Add-menu categories for them.

    If a previous call left node classes registered, they are unregistered
    first, so calling this again re-registers everything from scratch.

    Classes come from ``lnx_nodes.nodes``. One ``MaterialNodeCategory`` is
    built per key of ``lnx_nodes.category_items``; categories are sorted by
    name and their items are sorted by ``nodetype``. All categories are
    registered under the list identifier ``'LnxMaterialNodes'``.
    """
    global registered_nodes

    # Start from a clean slate if an earlier registration is still active.
    if registered_nodes:
        unregister_nodes()

    for node_class in lnx_nodes.nodes:
        bpy.utils.register_class(node_class)
        # Record the class only after Blender accepted it. That way
        # unregister_nodes() never tries to unregister a class whose
        # registration failed.
        registered_nodes.append(node_class)

    node_categories = [
        MaterialNodeCategory(
            'LnxMaterial' + category + 'Nodes',
            category,
            items=sorted(lnx_nodes.category_items[category], key=lambda item: item.nodetype),
        )
        for category in sorted(lnx_nodes.category_items)
    ]

    nodeitems_utils.register_node_categories('LnxMaterialNodes', node_categories)


def unregister_nodes():
    """Unregister all node classes recorded by register_nodes() and their categories.

    Classes are unregistered in reverse registration order. Each class is
    removed from ``registered_nodes`` only after ``bpy.utils.unregister_class``
    returned. If an exception interrupts the loop, the list therefore still
    names exactly the classes that were not unregistered.
    """
    while registered_nodes:
        bpy.utils.unregister_class(registered_nodes[-1])
        registered_nodes.pop()

    try:
        nodeitems_utils.unregister_node_categories('LnxMaterialNodes')
    except KeyError:
        # nodeitems_utils raises KeyError when no category list is registered
        # under this identifier. That happens, e.g., when register_nodes()
        # failed before reaching register_node_categories(). There is nothing
        # to remove in that case.
        pass


def register():
    """Module-level register hook; delegates to register_nodes()."""
    return register_nodes()


def unregister():
    """Module-level unregister hook; delegates to unregister_nodes()."""
    return unregister_nodes()