Skip to content

Commit

Permalink
Merge pull request #10 from torch-points3d/model-dataset
Browse files Browse the repository at this point in the history
working training
  • Loading branch information
clee-ai authored Jul 26, 2021
2 parents 88bcc70 + 5041c31 commit 49052ce
Show file tree
Hide file tree
Showing 32 changed files with 449 additions and 300 deletions.
2 changes: 0 additions & 2 deletions conf/config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
defaults: # loads default configs
- dataset: ???
- optimizer: sgd
- scheduler: default
- model: ???
- training: default
- trainer: default
Expand Down
3 changes: 1 addition & 2 deletions conf/dataset/default.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# @package dataset
# cfg:
# torch data-loader specific arguments
cfg:
feature_dimension:
batch_size: ${training.batch_size}
num_workers: ${training.num_workers}
dataroot: data
Expand Down
5 changes: 4 additions & 1 deletion conf/dataset/segmentation/default.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# @package dataset
defaults:
- /dataset/default
- /dataset/default

cfg:
num_classes:
2 changes: 2 additions & 0 deletions conf/dataset/segmentation/s3dis/s3dis1x1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,6 @@ defaults:
- segmentation/default
_target_: torch_points3d.datasets.s3dis1x1.s3dis_data_module
cfg:
num_classes: 13
feature_dimension: 6 # todo: able to calculate this dynamically
fold: 5
8 changes: 5 additions & 3 deletions conf/model/default.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# @package model
defaults:
- /optimizer: sgd
- /scheduler:
# By default we turn off recursive instantiation, allowing the user to instantiate themselves at the appropriate times.
_recursive_: false

_target_: torch_points3d.models.base_model.PointCloudBaseModel
optimizer: ${optimizer}
scheduler: ${scheduler}
_target_: torch_points3d.tasks.base_model.PointCloudBaseModule
14 changes: 14 additions & 0 deletions conf/model/segmentation/default.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# @package model
defaults:
- /model/default

model:
_recursive_: false
_target_: torch_points3d.models.segmentation.base_model.SegmentationBaseModel
num_classes: ${dataset.cfg.num_classes}
criterion:
_target_: torch.nn.NLLLoss

backbone:
input_nc: ${dataset.cfg.feature_dimension}
architecture: unet
74 changes: 0 additions & 74 deletions conf/model/segmentation/sparseconv3d.yaml

This file was deleted.

10 changes: 10 additions & 0 deletions conf/model/segmentation/sparseconv3d/Res16UNet34.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# @package model
defaults:
- /model/segmentation/ResUNet32

model:
backbone:
down_conv:
N: [ 0, 2, 3, 4, 6 ]
up_conv:
N: [ 1, 1, 1, 1, 1 ]
41 changes: 41 additions & 0 deletions conf/model/segmentation/sparseconv3d/ResUNet32.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# @package model
defaults:
- /model/segmentation/default

model:
backbone:
_target_: torch_points3d.applications.sparseconv3d.SparseConv3d
backend: torchsparse

config:
define_constants:
in_feat: 32
block: ResBlock # Can be any of the blocks in modules/MinkowskiEngine/api_modules.py
down_conv:
module_name: ResNetDown
block: block
N: [ 0, 1, 2, 2, 3 ]
down_conv_nn:
[
[ FEAT, in_feat ],
[ in_feat, in_feat ],
[ in_feat, 2*in_feat ],
[ 2*in_feat, 4*in_feat ],
[ 4*in_feat, 8*in_feat ],
]
kernel_size: 3
stride: [ 1, 2, 2, 2, 2 ]
up_conv:
block: block
module_name: ResNetUp
N: [ 1, 1, 1, 1, 0 ]
up_conv_nn:
[
[ 8*in_feat, 4*in_feat ],
[ 4*in_feat + 4*in_feat, 4*in_feat ],
[ 4*in_feat + 2*in_feat, 3*in_feat ],
[ 3*in_feat + in_feat, 3*in_feat ],
[ 3*in_feat + in_feat, 3*in_feat ],
]
kernel_size: 3
stride: [ 2, 2, 2, 2, 1 ]
Empty file removed conf/scheduler/default.yaml
Empty file.
2 changes: 1 addition & 1 deletion conf/trainer/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ gradient_clip_val: 0.0
process_position: 0
num_nodes: 1
num_processes: 1
gpus: null
gpus: 1
auto_select_gpus: False
tpu_cores: null
log_gpu_memory: null
Expand Down
10 changes: 10 additions & 0 deletions torch_points3d/applications/base_architectures/base_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import torch.nn as nn

# This is the base class of the Backbone/API models, that provides useful functions to use within these models
class BaseModel(nn.Module):

# When creating new tensors (esp sparsetensors), we need to be able to send them to the correct device
# as ptl won't do it automatically
@property
def device(self):
return next(self.parameters()).device
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
BatchNorm1d as BN,
Dropout,
)
from torch_points3d.applications.base_architectures.base_model import BaseModel
from omegaconf.listconfig import ListConfig
from omegaconf.dictconfig import DictConfig
import logging
Expand Down Expand Up @@ -48,7 +49,7 @@ def get_module(self, flow):
############################# UNWRAPPED UNET BASE ###################################


class UnwrappedUnetBasedModel(nn.Module):
class UnwrappedUnetBasedModel(BaseModel):
"""Create a Unet unwrapped generator"""

def _save_sampling_and_search(self, down_conv):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys

from torch_points3d.core.common_modules import Seq, Identity
import torch_points3d.modules.SparseConv3d.nn as snn
import torch_points3d.applications.modules.SparseConv3d.nn as snn


class ResBlock(torch.nn.Module):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(
stride=stride,
dilation=dilation,
bias=bias,
transpose=True,
transposed=True,
)


Expand All @@ -62,8 +62,8 @@ def __init__(self, inplace=True):
super().__init__(inplace=inplace)


def cat(*args, dim=1):
return TS.cat(args, dim)
def cat(*args):
return TS.cat(args)


def SparseTensor(feats, coordinates, batch, device=torch.device("cpu")):
Expand Down
10 changes: 5 additions & 5 deletions torch_points3d/applications/sparseconv3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
from torch_geometric.data import Batch

from torch_points3d.applications.modelfactory import ModelFactory
import torch_points3d.modules.SparseConv3d as sp3d
from torch_points3d.modules.SparseConv3d.modules import *
import torch_points3d.applications.modules.SparseConv3d as sp3d
from torch_points3d.applications.modules.SparseConv3d.modules import *

# from torch_points3d.core.base_conv.message_passing import *
# from torch_points3d.core.base_conv.partial_dense import *
from torch_points3d.models.base_architectures.unet import UnwrappedUnetBasedModel
from torch_points3d.applications.base_architectures.unet import UnwrappedUnetBasedModel
from torch_points3d.core.common_modules.base_modules import MLP

from .utils import extract_output_nc
Expand Down Expand Up @@ -135,7 +135,7 @@ def _set_input(self, data):
data:
a dictionary that contains the data itself and its metadata information.
"""
self.input = sp3d.nn.SparseTensor(data.x, data.coords, data.batch)
self.input = sp3d.nn.SparseTensor(data.x, data.coords, data.batch, self.device)
if data.pos is not None:
self.xyz = data.pos
else:
Expand Down Expand Up @@ -163,7 +163,7 @@ def forward(self, data, *args, **kwargs):
for i in range(len(self.down_modules)):
data = self.down_modules[i](data)

out = Batch(x=data.F, batch=data.C[:, 0].long().to(data.F.device))
out = Batch(x=data.F, batch=data.C[:, 0].long())
if not isinstance(self.inner_modules[0], Identity):
out = self.inner_modules[0](out)

Expand Down
Loading

0 comments on commit 49052ce

Please sign in to comment.