# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction, 
# disclosure or distribution of this material and related documentation 
# without an express license agreement from NVIDIA CORPORATION or 
# its affiliates is strictly prohibited.

import torch
import nvdiffrast.torch as dr
import pytorch3d.ops

from . import util
from . import mesh

######################################################################################
# Computes the image gradient, useful for kd/ks smoothness losses
######################################################################################
def image_grad(buf, std=0.01):
    """Stochastic image-gradient estimate, useful for kd/ks smoothness losses.

    Every pixel of ``buf`` is compared against a bilinear tap of ``buf``
    taken at that pixel's grid coordinate plus Gaussian jitter.

    Args:
        buf: Image tensor indexed as ``shape[0], shape[1], shape[2]`` plus a
            trailing channel axis (assumed ``(N, H, W, C)`` -- TODO confirm).
            The last channel acts as a multiplicative weight on the result
            (presumably an alpha / coverage mask -- verify against caller).
        std: Standard deviation of the Gaussian jitter added to the sampling
            coordinates.

    Returns:
        ``|tap - buf|`` over every channel except the last, multiplied by the
        last channel of both the tap and ``buf``.
    """
    batch, height, width = buf.shape[0], buf.shape[1], buf.shape[2]
    # Build everything on the device that holds `buf`; a bare "cuda" resolves
    # to the current default GPU and mismatches inputs living on another one.
    device = buf.device

    # Pixel-centre coordinates along the row (t) and column (s) axes.
    # NOTE(review): nvdiffrast documents texture coordinates in [0, 1], while
    # this grid spans (-1, 1) and relies on boundary_mode='clamp' -- confirm
    # that this is intended before changing it.
    t, s = torch.meshgrid(
        torch.linspace(-1.0 + 1.0 / height, 1.0 - 1.0 / height, height, device=device),
        torch.linspace(-1.0 + 1.0 / width, 1.0 - 1.0 / width, width, device=device),
        indexing='ij')

    jitter = torch.normal(mean=0, std=std, size=(batch, height, width, 2), device=device)
    tc = jitter + torch.stack((s, t), dim=-1)[None, ...]
    tap = dr.texture(buf, tc, filter_mode='linear', boundary_mode='clamp')

    # Weight by the last channel of both samples so a difference only counts
    # where both the pixel and its jittered tap carry non-zero weight.
    return torch.abs(tap[..., :-1] - buf[..., :-1]) * tap[..., -1:] * buf[..., -1:]

######################################################################################
# Computes the average edge length of a mesh.
# Rough estimate of the tessellation of a mesh. Can be used e.g. to clamp gradients
######################################################################################
def avg_edge_length(v_pos, t_pos_idx):
    """Return the mean edge length of a mesh as a scalar tensor.

    Args:
        v_pos: Vertex positions, indexed by vertex id along dim 0.
        t_pos_idx: Triangle vertex indices, forwarded to
            ``mesh.compute_edges`` to obtain the edge list.

    Returns:
        Mean of ``util.length`` over all edge vectors.
    """
    # Each row of `edges` holds the two vertex ids of one edge
    # (presumably de-duplicated by mesh.compute_edges -- verify there).
    edges = mesh.compute_edges(t_pos_idx)
    start_pts = v_pos[edges[:, 0]]
    end_pts = v_pos[edges[:, 1]]
    return util.length(start_pts - end_pts).mean()

######################################################################################
# Laplacian regularization using umbrella operator (Fujiwara / Desbrun).
# https://mgarland.org/class/geom04/material/smoothing.pdf
######################################################################################
def laplace_regularizer_const(v_pos, t_pos_idx):
    """Uniform (umbrella-operator) Laplacian smoothness penalty.

    For every vertex, the offsets to its triangle neighbours are summed,
    divided by the number of contributions, and the mean of the squared
    result is returned.

    Args:
        v_pos: Per-vertex 3-vectors (assumed ``(V, 3)`` float tensor).
        t_pos_idx: Triangle vertex indices (assumed ``(F, 3)`` int64 tensor,
            as required by ``scatter_add_``).

    Returns:
        Scalar tensor: mean over all vertices and components of the squared
        averaged neighbour offset.
    """
    corners = [v_pos[t_pos_idx[:, k], :] for k in range(3)]

    offset_sum = torch.zeros_like(v_pos)
    weight_sum = torch.zeros_like(v_pos[..., 0:1])
    # Each triangle contributes two neighbour offsets to each of its corners.
    per_face_weight = torch.ones_like(corners[0]) * 2.0

    for k in range(3):
        here = corners[k]
        nbr_a = corners[(k + 1) % 3]
        nbr_b = corners[(k + 2) % 3]
        slot = t_pos_idx[:, k:k + 1]
        offset_sum.scatter_add_(0, slot.repeat(1, 3), (nbr_a - here) + (nbr_b - here))
        # Only the first column of the weight source is consumed because the
        # index has a single column.
        weight_sum.scatter_add_(0, slot, per_face_weight)

    # Clamp so vertices that no triangle references do not divide by zero.
    mean_offset = offset_sum / torch.clamp(weight_sum, min=1.0)
    return torch.mean(mean_offset ** 2)


def scale_dependent_relative_laplace_regularizer_const(v_pos, v_pos_abs, t_pos_idx):
    """Laplacian penalty on ``v_pos`` with edges weighted by lengths from ``v_pos_abs``.

    Neighbour offsets are taken from ``v_pos`` and each one is divided by the
    length of the corresponding edge measured on ``v_pos_abs``. The per-vertex
    sums are not normalised by the number of neighbours.

    Args:
        v_pos: Per-vertex 3-vectors being regularised (presumably offsets
            relative to ``v_pos_abs`` -- verify against caller).
        v_pos_abs: Per-vertex positions used only to measure edge lengths.
            The lengths are not detached, so gradients also reach this
            tensor if it requires grad.
        t_pos_idx: Triangle vertex indices (assumed ``(F, 3)`` int64 tensor).

    Returns:
        Scalar tensor: mean of the squared per-vertex accumulated term.
    """
    term = torch.zeros_like(v_pos)

    v0 = v_pos[t_pos_idx[:, 0], :]
    v1 = v_pos[t_pos_idx[:, 1], :]
    v2 = v_pos[t_pos_idx[:, 2], :]

    v0_abs = v_pos_abs[t_pos_idx[:, 0], :]
    v1_abs = v_pos_abs[t_pos_idx[:, 1], :]
    v2_abs = v_pos_abs[t_pos_idx[:, 2], :]

    # Edge lengths are measured on v_pos_abs only; eps keeps sqrt (and its
    # gradient) finite for degenerate, zero-length edges.
    eps = 1e-8
    v01_dist = ((v0_abs - v1_abs).pow(2).sum(-1, keepdim=True) + eps).sqrt()
    v12_dist = ((v1_abs - v2_abs).pow(2).sum(-1, keepdim=True) + eps).sqrt()
    v20_dist = ((v2_abs - v0_abs).pow(2).sum(-1, keepdim=True) + eps).sqrt()

    # Accumulate, per vertex, the length-normalised offsets to both
    # neighbours inside each incident triangle.
    term.scatter_add_(0, t_pos_idx[:, 0:1].repeat(1, 3), (v1 - v0) / v01_dist + (v2 - v0) / v20_dist)
    term.scatter_add_(0, t_pos_idx[:, 1:2].repeat(1, 3), (v0 - v1) / v01_dist + (v2 - v1) / v12_dist)
    term.scatter_add_(0, t_pos_idx[:, 2:3].repeat(1, 3), (v0 - v2) / v20_dist + (v1 - v2) / v12_dist)

    return torch.mean(term**2)


def scale_dependent_laplace_regularizer_const(v_pos, t_pos_idx, *, stopgd=True):
    """Laplacian penalty with neighbour offsets normalised by edge length.

    Each neighbour offset is divided by the length of its edge, measured on
    ``v_pos`` itself. The per-vertex sums are not normalised by the number of
    neighbours.

    Args:
        v_pos: Per-vertex 3-vectors (assumed ``(V, 3)`` float tensor).
        t_pos_idx: Triangle vertex indices (assumed ``(F, 3)`` int64 tensor).
        stopgd: When true (the default, matching the previous hard-coded
            behaviour) the edge lengths are detached, so gradients flow only
            through the offsets and not through the normalisation.

    Returns:
        Scalar tensor: mean of the squared per-vertex accumulated term.
    """
    term = torch.zeros_like(v_pos)

    v0 = v_pos[t_pos_idx[:, 0], :]
    v1 = v_pos[t_pos_idx[:, 1], :]
    v2 = v_pos[t_pos_idx[:, 2], :]

    # eps keeps sqrt (and its gradient) finite for zero-length edges.
    eps = 1e-8
    v01_dist = ((v0 - v1).pow(2).sum(-1, keepdim=True) + eps).sqrt()
    v12_dist = ((v1 - v2).pow(2).sum(-1, keepdim=True) + eps).sqrt()
    v20_dist = ((v2 - v0).pow(2).sum(-1, keepdim=True) + eps).sqrt()

    if stopgd:
        # Treat the edge lengths as constants during back-propagation.
        v01_dist = v01_dist.detach()
        v12_dist = v12_dist.detach()
        v20_dist = v20_dist.detach()

    term.scatter_add_(0, t_pos_idx[:, 0:1].repeat(1, 3), (v1 - v0) / v01_dist + (v2 - v0) / v20_dist)
    term.scatter_add_(0, t_pos_idx[:, 1:2].repeat(1, 3), (v0 - v1) / v01_dist + (v2 - v1) / v12_dist)
    term.scatter_add_(0, t_pos_idx[:, 2:3].repeat(1, 3), (v0 - v2) / v20_dist + (v1 - v2) / v12_dist)

    return torch.mean(term**2)


def mesh_repulsion(v_pos, t_pos_idx):
    """Accumulate triangle edge lengths per vertex and return them squared.

    For each triangle, corner 0 receives ``|v0 - v1|``, corner 1 receives
    ``|v1 - v2|`` and corner 2 receives ``|v2 - v0|``.

    Args:
        v_pos: Per-vertex 3-vectors (assumed ``(V, 3)`` float tensor).
        t_pos_idx: Triangle vertex indices (assumed ``(F, 3)`` int64 tensor).

    Returns:
        Tensor shaped like ``v_pos`` (no reduction is applied). Only column 0
        is written, because the scatter index has a single column; the other
        columns stay zero.
        NOTE(review): the zero columns look unintended -- confirm with the
        callers before relying on them.
    """
    eps = 1e-8
    corners = [v_pos[t_pos_idx[:, k], :] for k in range(3)]
    accum = torch.zeros_like(v_pos)

    for k in range(3):
        # Edge from this corner to the next one in winding order.
        delta = corners[k] - corners[(k + 1) % 3]
        edge_len = torch.sqrt(delta.pow(2).sum(-1, keepdim=True) + eps)
        accum.scatter_add_(0, t_pos_idx[:, k:k + 1], edge_len)

    return accum ** 2

def laplace_regularizer_const_adaptive(v_pos, t_pos_idx):
    """Umbrella-operator Laplacian penalty weighted by nearest-neighbour spacing.

    The per-vertex averaged neighbour offset is computed as in the uniform
    Laplacian, then each vertex's mean squared offset is multiplied by a
    gradient-free weight derived from the distance to its nearest other
    vertex.

    Args:
        v_pos: Per-vertex 3-vectors (assumed ``(V, 3)`` float tensor).
        t_pos_idx: Triangle vertex indices (assumed ``(F, 3)`` int64 tensor).

    Returns:
        Scalar tensor: mean over vertices of ``mean(offset**2) * weight``.
    """
    corners = [v_pos[t_pos_idx[:, k], :] for k in range(3)]

    offset_sum = torch.zeros_like(v_pos)
    weight_sum = torch.zeros_like(v_pos[..., 0:1])
    # Each triangle contributes two neighbour offsets to each of its corners.
    per_face_weight = torch.ones_like(corners[0]) * 2.0

    for k in range(3):
        here = corners[k]
        nbr_a = corners[(k + 1) % 3]
        nbr_b = corners[(k + 2) % 3]
        slot = t_pos_idx[:, k:k + 1]
        offset_sum.scatter_add_(0, slot.repeat(1, 3), (nbr_a - here) + (nbr_b - here))
        weight_sum.scatter_add_(0, slot, per_face_weight)

    # Clamp so vertices that no triangle references do not divide by zero.
    mean_offset = offset_sum / torch.clamp(weight_sum, min=1.0)

    with torch.no_grad():
        # Positions are scaled by 64 before the neighbour query
        # (NOTE(review): presumably a grid resolution -- confirm the origin
        # of this constant).
        scaled_pos = v_pos.unsqueeze(0) * 64
        # K=2 because the closest hit of a self-query is the point itself at
        # distance 0; the last entry is the nearest other vertex. The value is
        # square-rooted below, i.e. `dists` is treated as squared distance.
        nearest_sq = pytorch3d.ops.knn.knn_points(scaled_pos, scaled_pos, K=2).dists[0, :, -1]
        scale = nearest_sq.sqrt().pow(1.5)

    # Per-vertex mean over components, mirroring the mean used by the
    # non-adaptive variant.
    per_vertex = mean_offset.pow(2).mean(-1)
    return torch.mean(per_vertex * scale)

# def laplace_regularizer_const_sec_order(v_pos, t_pos_idx):
#     term = torch.zeros_like(v_pos)
#     norm = torch.zeros_like(v_pos[..., 0:1])

#     v0 = v_pos[t_pos_idx[:, 0], :]
#     v1 = v_pos[t_pos_idx[:, 1], :]
#     v2 = v_pos[t_pos_idx[:, 2], :]

#     term.scatter_add_(0, t_pos_idx[:, 0:1].repeat(1,3), (v1 - v0) + (v2 - v0))
#     term.scatter_add_(0, t_pos_idx[:, 1:2].repeat(1,3), (v0 - v1) + (v2 - v1))
#     term.scatter_add_(0, t_pos_idx[:, 2:3].repeat(1,3), (v0 - v2) + (v1 - v2))

#     two = torch.ones_like(v0) * 2.0
#     norm.scatter_add_(0, t_pos_idx[:, 0:1], two)
#     norm.scatter_add_(0, t_pos_idx[:, 1:2], two)
#     norm.scatter_add_(0, t_pos_idx[:, 2:3], two)

#     term = term / torch.clamp(norm, min=1.0)

#     return torch.mean(term**2)

######################################################################################
# Smooth vertex normals
######################################################################################
def normal_consistency(v_pos, t_pos_idx):
    """Penalise differing normals between the two faces sharing each edge.

    Args:
        v_pos: Vertex positions (assumed ``(V, 3)`` float tensor).
        t_pos_idx: Triangle vertex indices (assumed ``(F, 3)`` integer tensor).

    Returns:
        Scalar tensor: mean over edges of ``(1 - dot(n0, n1)) / 2``, where
        ``n0`` and ``n1`` are the normalised normals of the two faces listed
        for the edge by ``mesh.compute_edge_to_face_mapping``.
    """
    # Compute face normals
    v0 = v_pos[t_pos_idx[:, 0], :]
    v1 = v_pos[t_pos_idx[:, 1], :]
    v2 = v_pos[t_pos_idx[:, 2], :]

    # dim=-1 must be explicit: without it torch.cross operates along the first
    # dimension of size 3, which is the face axis whenever the mesh has
    # exactly three triangles.
    face_normals = util.safe_normalize(torch.cross(v1 - v0, v2 - v0, dim=-1))

    tris_per_edge = mesh.compute_edge_to_face_mapping(t_pos_idx)

    # Fetch normals for both faces sharing an edge
    n0 = face_normals[tris_per_edge[:, 0], :]
    n1 = face_normals[tris_per_edge[:, 1], :]

    # Compute error metric based on normal difference: the clamped cosine is
    # mapped from [-1, 1] to [1, 0], so identical normals cost nothing.
    term = torch.clamp(util.dot(n0, n1), min=-1.0, max=1.0)
    term = (1.0 - term) * 0.5

    return torch.mean(torch.abs(term))