Skip to content

Commit

Permalink
Fix nn functional conv2d bug (#7892)
Browse files Browse the repository at this point in the history
* fix reduce_sum scalar check bug

* fix bug

* fix bug

* revert

* revert

* auto format by CI

* fix commnet

* auto format by CI

* fix clang check error

Co-authored-by: oneflow-ci-bot <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Apr 2, 2022
1 parent aa749df commit 709f56a
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 9 deletions.
18 changes: 9 additions & 9 deletions oneflow/core/functional/functional_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -704,23 +704,23 @@

- name: "conv1d"
signature:
"Tensor (Tensor x, Tensor weight, Tensor bias=None, Int32List[1] stride,
Int32List[1] padding, Int32List[1] dilation, Int32 groups=1,
String channel_pos) => Conv1d"
"Tensor (Tensor x, Tensor weight, Tensor bias=None, Int32List[1] stride=1,
Int32List[1] padding=0, Int32List[1] dilation=1, Int32 groups=1,
String channel_pos=\"channels_first\") => Conv1d"
bind_python: True

- name: "conv2d"
signature:
"Tensor (Tensor x, Tensor weight, Tensor bias=None, Int32List[2] stride,
Int32List[2] padding, Int32List[2] dilation, Int32 groups=1,
String channel_pos) => Conv2d"
"Tensor (Tensor x, Tensor weight, Tensor bias=None, Int32List[2] stride=1,
Int32List[2] padding=0, Int32List[2] dilation=1, Int32 groups=1,
String channel_pos=\"channels_first\") => Conv2d"
bind_python: True

- name: "conv3d"
signature:
"Tensor (Tensor x, Tensor weight, Tensor bias=None, Int32List[3] stride,
Int32List[3] padding, Int32List[3] dilation, Int32 groups=1,
String channel_pos) => Conv3d"
"Tensor (Tensor x, Tensor weight, Tensor bias=None, Int32List[3] stride=1,
Int32List[3] padding=0, Int32List[3] dilation=1, Int32 groups=1,
String channel_pos=\"channels_first\") => Conv3d"
bind_python: True

- name: "fake_quantization"
Expand Down
8 changes: 8 additions & 0 deletions python/oneflow/test/modules/test_conv1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,14 @@ def test_conv1d(test_case):
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])

@autotest(n=3)
def test_nn_functional_conv1d(test_case):
device = random_device()
img = torch.ones((1, 3, 224), requires_grad=True).to(device)
kernel = torch.ones((3, 1, 3), requires_grad=True).to(device)
y = torch.nn.functional.conv1d(img, kernel, groups=3)
return y

@autotest()
def test_conv1d_with_random_data(test_case):
channels = random(1, 6)
Expand Down
8 changes: 8 additions & 0 deletions python/oneflow/test/modules/test_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -1581,6 +1581,14 @@ def test_conv2d_default_init(test_case):
)
)

@autotest(n=3)
def test_nn_functional_conv2d(test_case):
device = random_device()
img = torch.ones((1, 3, 224, 224), requires_grad=True).to(device)
kernel = torch.ones((3, 1, 3, 3), requires_grad=True).to(device)
y = torch.nn.functional.conv2d(img, kernel, groups=3)
return y

def test_conv2d(test_case):
arg_dict = OrderedDict()
arg_dict["device"] = ["cuda", "cpu"]
Expand Down
8 changes: 8 additions & 0 deletions python/oneflow/test/modules/test_conv3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@

@flow.unittest.skip_unless_1n1d()
class TestConv3DModule(flow.unittest.TestCase):
@autotest(n=3)
def test_nn_functional_conv3d(test_case):
device = random_device()
img = torch.ones((1, 3, 224, 224, 224), requires_grad=True).to(device)
kernel = torch.ones((6, 3, 3, 3, 3), requires_grad=True).to(device)
y = torch.nn.functional.conv3d(img, kernel)
return y

@autotest(n=10)
def test_conv3d_with_random_data(test_case):
channels = random(1, 6)
Expand Down

0 comments on commit 709f56a

Please sign in to comment.