【pytorch】Animatable 3D Gaussian+源码解读(一)_animatable gaussians-程序员宅基地

技术标签: 3d  3DGS  pytorch  windows  

概述

创新点:

  1. 多人场景 无遮挡处理
  2. 以3DGS进行表达

方法:
在这里插入图片描述

环境配置

基本和3DGS的配置差不多…

pip install torch==1.13.1+cu117 torchvision --index-url https://download.pytorch.org/whl/cu117
pip install hydra-core==1.3.2
pip install pytorch-lightning==2.1.2
pip install imageio
pip install ./submodules/diff-gaussian-rasterization
pip install ./submodules/simple-knn
pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch

子模块记得 git clone --recursive 多次踩坑…
tinycudann建议先git到本地…

数据准备

数据集下载

|---data
|   |---Gala
|   |---PeopleSnapshot
|   |---smpl

下载以后挨个解压,整理成格式

单个数据集结构:
在这里插入图片描述

0-6代表不同相机序号——同一视频的不同视角

每个文件夹均有对应的300帧+300mask。

在这里插入图片描述

bg定义了7个相机视角下的背景,目前还不知道有什么用,后续看看…

在这里插入图片描述

camera文件夹下是7个文本文件,描述相机参数。具体的表示在【pytorch】Animatable 3D Gaussian+源码解读(二)中分析。

。

pose文件夹下是300个文本文件,分别描述300帧中300个pose。

此外,Gala数据集中还有一个名为model的文件夹需要注意:

在这里插入图片描述
应该是定义了人体标准模板,包括网格、tpose、蒙皮权重等…

model_path (str) : The path to the folder that holds the vertices, tpose matrix, binding weights and indexes.

数据集具体是怎么利用的我们再结合代码来看看…

代码解读

先跟着debug走一遍流程… 然后再以面向对象的思路把握全局

Define Gaussian

首先定义场景表达元素:高斯球

train.py:

model = NeRFModel(opt)

nerf_model.py:

class NeRFModel(pl.LightningModule):
    def __init__(self, opt):
        super(NeRFModel, self).__init__()
        self.save_hyperparameters() # 储存init中输入的所有超参
        self.model = hydra.utils.instantiate(opt.deformer)

此处,opt.deformer==Gala模型:

Since the public dataset [1] contains few pose and shadow changes, we create a new dataset named GalaBasketball in order to show the
performance of our method under complex motion and dynamic
shadows
.

正文开始:

    """
    Attributes:
        parents (list[J]) : Indicate the parent joint for each joint, -1 for root joint.
        bone_count (int) : The count of joints including root joint, i.e. J.
        joints (torch.Tensor[J-1, 3]) : The translations of each joint relative to the parent joint, except for root joint.
        tpose_w2l_mats (torch.Tensor[J, 3]) : The skeleton to local transform matrixes for each joint.
    """

初始化函数:

        """
        Init joints and offset matrixes from files.

        Args:
            model_path (str) : The path to the folder that holds the vertices, tpose matrix, binding weights and indexes.
            num_players (int) : Number of players.  # 多人场景
        """

由此模型超参数为:

在这里插入图片描述

首先是一些人体基本操作——读取相关数据:

        model_file = os.path.join(model_path, "mesh.txt")
        vertices, normals, uvs, bone_weights, bone_indices = read_skinned_mesh_data(
            model_file) 

        tpose_file = os.path.join(model_path, "tpose.txt")
        tpose_positions, tpose_rotations, tpose_scales = read_bone_joints(
            tpose_file)

        tpose_mat_file = os.path.join(model_path, "tpose_mat.txt")
        tpose_w2l_mats = read_bone_joint_mats(tpose_mat_file)

        joint_parent_file = os.path.join(model_path, "jointParent.txt")
        self.joint_parent_idx = read_bone_parent_indices(joint_parent_file)

        self.bone_count = tpose_positions.shape[0]
        self.vertex_count = vertices.shape[0]

        print("mesh loaded:")
        print("total vertices: " + str(vertices.shape[0]))
        print("num of joints: " + str(self.bone_count))

read_skinned_mesh_data(“mesh.txt”)函数读取顶点、蒙皮权重、UV坐标;
read_bone_joints(“tpose.txt”)函数读取关节数据;
read_bone_joint_mats(“tpose_mat.txt”)读取world-to-local转化矩阵;
read_bone_parent_indices(“jointParent.txt”)读取关节父子关系。

多人扩维模板复制:

        self.register_buffer('v_template', vertices[None, ...].repeat(
            [self.num_players, 1, 1]))
        uvs = uvs * 2. - 1.
        self.register_buffer('uvs', uvs[None, ...].repeat(
            [self.num_players, 1, 1])) 
        bone_weights = torch.Tensor(
            np.load(os.path.join(model_path, "weights.npy")))[None, ...].repeat([self.num_players, 1, 1])
        self.register_buffer("bone_weights", bone_weights)

1.register_buffer:定义一组参数,该组参数的特别之处在于:模型训练时不会更新(即调用 optimizer.step() 后该组参数不会变化,只可人为地改变它们的值),但是保存模型时,该组参数又作为模型参数不可或缺的一部分被保存。
2.[None,…]:多一维

		
		# 关节位置申请加入训练
        self.J = nn.Parameter(
            tpose_positions[None, ...].repeat([self.num_players, 1, 1]))
        self.tpose_w2l_mats = tpose_w2l_mats

        # 顶点归一化
        minmax = [self.v_template[0].min(
            dim=0).values * 1.05,  self.v_template[0].max(dim=0).values * 1.05]
        self.register_buffer('normalized_vertices',
                             (self.v_template - minmax[0]) / (minmax[1] - minmax[0]))

        # distCUDA2 from simple_knn 计算点云中的每个点到与其最近的K个点的平均距离的平方
        dist2 = torch.clamp_min(
            distCUDA2(vertices.float().cuda()), 0.0000001)[..., None].repeat([1, 3])

然后开始处理要训练的高斯:

定义顶点偏移:

using unconstrained per-vertex displacement can easily cause the optimization process to diverge in dynamic scenes.Therefore, we also model a parameter field for vertex displacement. F.

	# x0 →  δx
    if use_point_displacement:
        self.displacements = nn.Parameter(
            torch.zeros_like(self.v_template))
    else:
        # 使用encoder
        self.displacementEncoder = DisplacementEncoder(
            encoder=encoder_type, num_players=num_players)

多种编码方式:uv encoder、hash encoder…

Since our animatable 3D Gaussian representation is initialized by a standard human body model, the centers of 3D Gaussians are uniformly distributed near the human surface. We only need to sample at fixed positions near the surface of the human body in the parameter fields. This allows for significant compression of the hash table for the hash encoding [36]. Thus, we choose the hash encoding to model our parameter field to reduce the time and storage consumption.

class DisplacementEncoder(nn.Module):
    def __init__(self, encoder="uv", num_players=1):
        super().__init__()
        self.num_players = num_players
        if encoder == "uv":
            self.input_channels = 2
            self.encoder = UVEncoder(
                3, num_players=num_players)
        elif encoder == "hash":
            self.input_channels = 3
            self.encoder = HashEncoder(
                3, num_players=num_players)
        elif encoder == "triplane":
            self.input_channels = 3
            self.encoder = TriPlaneEncoder(
                3, num_players=num_players)
        else:
            raise Exception("encoder does not exist!")

这里先选择hash编码,使用tcnn

class HashEncoder(nn.Module):
    def __init__(self, num_channels, num_players=1):
        super().__init__()
        self.networks = []
        self.num_players = num_players
        for i in range(num_players):
            self.networks.append(tcnn.NetworkWithInputEncoding(
                n_input_dims=3,
                n_output_dims=num_channels,
                encoding_config={
                    "otype": "HashGrid",
                    "n_levels": 16,
                    "n_features_per_level": 4,
                    "log2_hashmap_size": 17,
                    "base_resolution": 4,
                    "per_level_scale": 1.5,
                },
                network_config={
                    "otype": "FullyFusedMLP",
                    "activation": "ReLU",
                    "output_activation": "None",
                    "n_neurons": 64,
                    "n_hidden_layers": 2,
                }
            ))
        self.networks = nn.ModuleList(self.networks)

定义颜色、透明度、缩放、旋转:

rendering based on 3D Gaussian rasterization can only backpropagate the gradient to a finite number of Gaussians in a single iteration, which leads to a slow or even divergent optimization process for dynamic scenes. To address this issue, we suggest sampling spherical harmonic coefficients SH for each vertex from a continuous parameter field, which is able to affect all neighboring Gaussians in a single optimization.

Optionally, we provide UV-encoded spherical harmonic coefficients, allowing fast processing of custom human models with UV coordinate mappings. UV encoding potentially achieves higher reconstruction quality compared to hash encoding.

        n = self.v_template.shape[1] * num_players # 总顶点数
        # x0 → SH
        if use_point_color:
            self.shs_dc = nn.Parameter(torch.zeros(
                [n, 1, 3]))
            self.shs_rest = nn.Parameter(torch.zeros(
                [n, (max_sh_degree + 1) ** 2 - 1, 3]))
        else:
        	# 使用encoder
            self.shEncoder = SHEncoder(max_sh_degree=max_sh_degree,
                                       encoder=encoder_type, num_players=num_players)
        self.opacity = nn.Parameter(inverse_sigmoid(
            0.2 * torch.ones((n, 1), dtype=torch.float)))
        self.scales = nn.Parameter(
            torch.log(torch.sqrt(dist2)).repeat([num_players, 1]))
        rotations = torch.zeros([n, 4])
        rotations[:, 0] = 1
        self.rotations = nn.Parameter(rotations)

遮挡处理:x0, γ(t) → ao

We propose a time-dependent ambient occlusion module to address the issue of dynamic shadows in specific scenes.

        if enable_ambient_occlusion:
            self.aoEncoder = AOEncoder(
                encoder=encoder_type, max_freq=max_freq, num_players=num_players)
        self.register_buffer("aos", torch.ones_like(self.opacity))

在这里插入图片描述

we also employ hash encoding for the ambient occlusion ao, since shadows should be continuously modeled in space

class AOEncoder(nn.Module):
    def __init__(self, encoder="uv", num_players=1, max_freq=4):
        super().__init__()
        self.num_players = num_players
        self.max_freq = max_freq
        if encoder == "uv":
            self.input_channels = 2
            self.encoder = UVTimeEncoder(
                1, num_players=num_players, time_dim=max_freq*2 + 1)
        elif encoder == "hash":
            self.input_channels = 3
            self.encoder = HashTimeEncoder(
                1, num_players=num_players, time_dim=max_freq*2 + 1)
        else:
            raise Exception("encoder does not exist!")
class HashTimeEncoder(nn.Module):
    def __init__(self, num_channels, time_dim=9, num_players=1):
        super().__init__()
        self.networks = []
        self.time_nets = []
        self.num_players = num_players
        for i in range(num_players):
            self.networks.append(tcnn.Encoding(
                n_input_dims=3,
                encoding_config={
                    "otype": "HashGrid",
                    "n_levels": 16,
                    "n_features_per_level": 4,
                    "log2_hashmap_size": 19,
                    "base_resolution": 4,
                    "per_level_scale": 1.5,
                },
            ))
            self.time_nets.append(tcnn.Network(
                n_input_dims=self.networks[i].n_output_dims + time_dim,
                n_output_dims=num_channels,
                network_config={
                    "otype": "FullyFusedMLP",
                    "activation": "ReLU",
                    "output_activation": "None",
                    "n_neurons": 64,
                    "n_hidden_layers": 2,
                }
            ))
        self.networks = nn.ModuleList(self.networks)
        self.time_nets = nn.ModuleList(self.time_nets)

至此,高斯球定义完成。
在这里插入图片描述
在这里插入图片描述

【pytorch】Animatable 3D Gaussian+源码解读(二)将进一步介绍数据集的处理细节。

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/XYQjiayou/article/details/136271978

智能推荐

oracle 12c 集群安装后的检查_12c查看crs状态-程序员宅基地

文章浏览阅读1.6k次。安装配置gi、安装数据库软件、dbca建库见下:http://blog.csdn.net/kadwf123/article/details/784299611、检查集群节点及状态:[root@rac2 ~]# olsnodes -srac1 Activerac2 Activerac3 Activerac4 Active[root@rac2 ~]_12c查看crs状态

解决jupyter notebook无法找到虚拟环境的问题_jupyter没有pytorch环境-程序员宅基地

文章浏览阅读1.3w次,点赞45次,收藏99次。我个人用的是anaconda3的一个python集成环境,自带jupyter notebook,但在我打开jupyter notebook界面后,却找不到对应的虚拟环境,原来是jupyter notebook只是通用于下载anaconda时自带的环境,其他环境要想使用必须手动下载一些库:1.首先进入到自己创建的虚拟环境(pytorch是虚拟环境的名字)activate pytorch2.在该环境下下载这个库conda install ipykernelconda install nb__jupyter没有pytorch环境

国内安装scoop的保姆教程_scoop-cn-程序员宅基地

文章浏览阅读5.2k次,点赞19次,收藏28次。选择scoop纯属意外,也是无奈,因为电脑用户被锁了管理员权限,所有exe安装程序都无法安装,只可以用绿色软件,最后被我发现scoop,省去了到处下载XXX绿色版的烦恼,当然scoop里需要管理员权限的软件也跟我无缘了(譬如everything)。推荐添加dorado这个bucket镜像,里面很多中文软件,但是部分国外的软件下载地址在github,可能无法下载。以上两个是官方bucket的国内镜像,所有软件建议优先从这里下载。上面可以看到很多bucket以及软件数。如果官网登陆不了可以试一下以下方式。_scoop-cn

Element ui colorpicker在Vue中的使用_vue el-color-picker-程序员宅基地

文章浏览阅读4.5k次,点赞2次,收藏3次。首先要有一个color-picker组件 <el-color-picker v-model="headcolor"></el-color-picker>在data里面data() { return {headcolor: ’ #278add ’ //这里可以选择一个默认的颜色} }然后在你想要改变颜色的地方用v-bind绑定就好了,例如:这里的:sty..._vue el-color-picker

迅为iTOP-4412精英版之烧写内核移植后的镜像_exynos 4412 刷机-程序员宅基地

文章浏览阅读640次。基于芯片日益增长的问题,所以内核开发者们引入了新的方法,就是在内核中只保留函数,而数据则不包含,由用户(应用程序员)自己把数据按照规定的格式编写,并放在约定的地方,为了不占用过多的内存,还要求数据以根精简的方式编写。boot启动时,传参给内核,告诉内核设备树文件和kernel的位置,内核启动时根据地址去找到设备树文件,再利用专用的编译器去反编译dtb文件,将dtb还原成数据结构,以供驱动的函数去调用。firmware是三星的一个固件的设备信息,因为找不到固件,所以内核启动不成功。_exynos 4412 刷机

Linux系统配置jdk_linux配置jdk-程序员宅基地

文章浏览阅读2w次,点赞24次,收藏42次。Linux系统配置jdkLinux学习教程,Linux入门教程(超详细)_linux配置jdk

随便推点

matlab(4):特殊符号的输入_matlab微米怎么输入-程序员宅基地

文章浏览阅读3.3k次,点赞5次,收藏19次。xlabel('\delta');ylabel('AUC');具体符号的对照表参照下图:_matlab微米怎么输入

C语言程序设计-文件(打开与关闭、顺序、二进制读写)-程序员宅基地

文章浏览阅读119次。顺序读写指的是按照文件中数据的顺序进行读取或写入。对于文本文件,可以使用fgets、fputs、fscanf、fprintf等函数进行顺序读写。在C语言中,对文件的操作通常涉及文件的打开、读写以及关闭。文件的打开使用fopen函数,而关闭则使用fclose函数。在C语言中,可以使用fread和fwrite函数进行二进制读写。‍ Biaoge 于2024-03-09 23:51发布 阅读量:7 ️文章类型:【 C语言程序设计 】在C语言中,用于打开文件的函数是____,用于关闭文件的函数是____。

Touchdesigner自学笔记之三_touchdesigner怎么让一个模型跟着鼠标移动-程序员宅基地

文章浏览阅读3.4k次,点赞2次,收藏13次。跟随鼠标移动的粒子以grid(SOP)为partical(SOP)的资源模板,调整后连接【Geo组合+point spirit(MAT)】,在连接【feedback组合】适当调整。影响粒子动态的节点【metaball(SOP)+force(SOP)】添加mouse in(CHOP)鼠标位置到metaball的坐标,实现鼠标影响。..._touchdesigner怎么让一个模型跟着鼠标移动

【附源码】基于java的校园停车场管理系统的设计与实现61m0e9计算机毕设SSM_基于java技术的停车场管理系统实现与设计-程序员宅基地

文章浏览阅读178次。项目运行环境配置:Jdk1.8 + Tomcat7.0 + Mysql + HBuilderX(Webstorm也行)+ Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。项目技术:Springboot + mybatis + Maven +mysql5.7或8.0+html+css+js等等组成,B/S模式 + Maven管理等等。环境需要1.运行环境:最好是java jdk 1.8,我们在这个平台上运行的。其他版本理论上也可以。_基于java技术的停车场管理系统实现与设计

Android系统播放器MediaPlayer源码分析_android多媒体播放源码分析 时序图-程序员宅基地

文章浏览阅读3.5k次。前言对于MediaPlayer播放器的源码分析内容相对来说比较多,会从Java-&amp;amp;gt;Jni-&amp;amp;gt;C/C++慢慢分析,后面会慢慢更新。另外,博客只作为自己学习记录的一种方式,对于其他的不过多的评论。MediaPlayerDemopublic class MainActivity extends AppCompatActivity implements SurfaceHolder.Cal..._android多媒体播放源码分析 时序图

java 数据结构与算法 ——快速排序法-程序员宅基地

文章浏览阅读2.4k次,点赞41次,收藏13次。java 数据结构与算法 ——快速排序法_快速排序法