Commit 4ad91ca

Author: Shrestha Malik

Shrestha/upgrade r17 to master (#189)

Upgrade to master:
  Added bfloat tests
  Fixed version to correctly reflect v0.17.0-rc3

1 parent 4bce385, commit 4ad91ca

19 files changed: +1211 additions, -144 deletions

ngraph_bridge/version.cc

Lines changed: 1 addition & 1 deletion

@@ -32,7 +32,7 @@
 // candidate such as v0.7.0-rc0
 // The code in master will always have the last released version number
 // with a suffix of '-master'
-#define NG_TF_VERSION_SUFFIX "-rc2"
+#define NG_TF_VERSION_SUFFIX "-rc3"

 #define VERSION_STR_HELPER(x) #x
 #define VERSION_STR(x) VERSION_STR_HELPER(x)
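
Editor's note (not part of the diff): this suffix feeds into the version string the Python package reports at runtime. A minimal, hypothetical sanity check, assuming the installed ngraph_bridge module exposes a __version__ attribute as in the bridge's usage examples:

import ngraph_bridge

# Assumption: the bridge reports its version (built from NG_TF_VERSION_SUFFIX)
# through __version__; after this change it should carry "-rc3", not "-rc2".
version = str(ngraph_bridge.__version__)
print(version)
assert "-rc3" in version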

test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -109,6 +109,7 @@ if (NGRAPH_PLAIDML_ENABLE)
 endif()

 add_subdirectory(python)
+add_subdirectory(python/bfloat16)
 add_subdirectory(model_level_tests)

 if (DEFINED NGRAPH_TF_INSTALL_PREFIX)
Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
# Copyright 2019 Nervana Systems Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

cmake_minimum_required(VERSION 3.4)

file(GLOB files RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.py")
foreach(file ${files})
    execute_process(
        COMMAND ${CMAKE_COMMAND} -E create_symlink
        ${CMAKE_CURRENT_SOURCE_DIR}/${file}
        ${CMAKE_CURRENT_BINARY_DIR}/${file}
    )
endforeach()
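
Editor's note (not part of the diff): this new CMakeLists.txt symlinks every *.py test from the source directory into the matching build directory, so the bfloat16 tests can be run in place from the build tree. A rough Python equivalent of the same idea, with placeholder paths that are assumptions rather than anything defined in the commit:

import glob
import os

src_dir = "test/python/bfloat16"          # hypothetical source dir
bin_dir = "build/test/python/bfloat16"    # hypothetical build dir

os.makedirs(bin_dir, exist_ok=True)
for path in glob.glob(os.path.join(src_dir, "*.py")):
    link = os.path.join(bin_dir, os.path.basename(path))
    if not os.path.exists(link):
        # Mirror the CMake create_symlink step: link each test file
        # into the build tree instead of copying it.
        os.symlink(os.path.abspath(path), link)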
Lines changed: 107 additions & 0 deletions

@@ -0,0 +1,107 @@
# ==============================================================================
# Copyright 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""nGraph TensorFlow bridge Conv2d operation test"""
import pytest

import tensorflow as tf
import numpy as np
import os

import ngraph_bridge

# Tests nGraph op Convolution, TF op: conv2d
# Implemented based on NNP's unit test TEST(test_assign_layout, convolution_special_case)

np.random.seed(5)

# The Convolution op is placed on NNP and converted to bfloat16 only for the
# special case below; otherwise it falls back to CPU for compute.
# Checks to assure:
#   the input rank is 4-D
#   the stride is less than the filter size
#   the window and data dilation are {1, 1}
#   the filter shape is allowed
# If any of these fail, the op should be placed on CPU for compute.

# Inputs
N = 1
C = 1
H = 3
W = 5

filter_size = np.random.rand(1, 1, 1, 2)
input_size_nhwc = [N, H, W, C]
input_size_nchw = [N, C, H, W]
input_nhwc = tf.placeholder(tf.float32, shape=input_size_nhwc, name='x')
input_nchw = tf.placeholder(tf.float32, shape=input_size_nchw, name='x')

n_np = np.random.rand(*input_size_nchw).astype('f')
# TensorFlow supports only NHWC here, so change the input shape from NCHW to NHWC
t_np = np.transpose(n_np, (0, 2, 3, 1))


# TF graph
def tf_model():
    stride_nhwc = [1, 2, 2, 1]
    x = tf.cast(input_nhwc, dtype=tf.bfloat16)
    filter_cast = tf.cast(filter_size, dtype=tf.bfloat16)
    m = tf.nn.conv2d(
        x, filter_cast, stride_nhwc, "SAME", data_format="NHWC", name="m")
    m = tf.cast(m, dtype=tf.float32)
    return m, input_nhwc


# nGraph graph
def ng_model():
    stride_nchw = [1, 1, 2, 2]
    m = tf.nn.conv2d(
        input_nchw,
        filter_size,
        stride_nchw,
        "SAME",
        data_format="NCHW",
        name="m")
    return m, input_nchw


config = tf.ConfigProto(
    allow_soft_placement=True,
    log_device_placement=False,
    inter_op_parallelism_threads=1)


def test_conv2d():
    # Test 1: tf_model, TF-native
    with tf.Session(config=config) as sess_tf:
        ngraph_bridge.disable()
        tf_out, input_data = tf_model()
        feed_dict = {input_data: t_np}
        tf_outval = sess_tf.run(tf_out, feed_dict=feed_dict)

    # Test 2: ng_model with ngraph, NNP backend
    with tf.Session(config=config) as sess_ng:
        ngraph_bridge.enable()
        ngraph_bridge.update_config(config)
        os.environ['NGRAPH_TF_DISABLE_DEASSIGN_CLUSTERS'] = '1'
        ng_out, input_data = ng_model()
        feed_dict = {input_data: n_np}
        ng_outval = sess_ng.run(ng_out, feed_dict=feed_dict)

    assert np.allclose(
        np.transpose(tf_outval, (0, 3, 1, 2)), ng_outval, rtol=0, atol=1e-02)
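
Editor's note (not part of the diff): the test compares the two graphs with rtol=0 and atol=1e-02 rather than an exact match. That loose tolerance is consistent with bfloat16's 8-bit mantissa: casting float32 activations and filters down to bfloat16 already loses roughly two to three decimal digits before the convolution runs. A rough, numpy-only illustration of that quantization error (this sketch truncates the low 16 bits; the real cast rounds to nearest even, but the magnitude is comparable):

import numpy as np

def bfloat16_trunc(x):
    # Emulate a float32 -> bfloat16 -> float32 round trip by zeroing the
    # low 16 bits of each float32 (an approximation of the real cast).
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFF0000)).view(np.float32)

x = np.random.rand(10000).astype(np.float32)
err = np.abs(x - bfloat16_trunc(x))
print(err.max())  # roughly 4e-3 for values in [0, 1): well inside atol=1e-02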
Lines changed: 116 additions & 0 deletions

@@ -0,0 +1,116 @@
# ==============================================================================
# Copyright 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""nGraph TensorFlow bridge conv2d_backprop_filter operation test (NCHW)"""
import pytest

import tensorflow as tf
import numpy as np
import os
from tensorflow.python.ops import nn_ops
import ngraph_bridge

# Tests nGraph op ConvolutionBackpropFilters with data format NCHW
# TF op: conv2d_backprop_filter

np.random.seed(5)

# Inputs
N = 1
H = 7
W = 6
C = 2

I = C
O = 2
filt_width = 3
filt_height = 3

input_sizes_nchw = [N, C, H, W]
input_sizes_nhwc = [N, H, W, C]
filter_size_hwio = [filt_height, filt_width, I, O]
out_backprop_valid = [1, 2, 3, 2]
out_backprop_same = [1, 2, 4, 3]
out_backprop_in_sizes = {"VALID": out_backprop_valid, "SAME": out_backprop_same}
stride_nhwc = [1, 2, 2, 1]
stride_nchw = [1, 1, 2, 2]


# TF graph
def tf_model(padding):
    t1 = tf.placeholder(dtype=tf.float32, shape=input_sizes_nhwc, name='t1')
    t2 = tf.constant(filter_size_hwio, dtype=tf.int32, name='t2')
    t3 = tf.placeholder(
        dtype=tf.float32, shape=out_backprop_in_sizes[padding], name='t3')

    # Reshape the out_backprop to NHWC since TF does not support NCHW here
    t3 = tf.transpose(t3, [0, 2, 3, 1])

    # Cast dtype to bfloat16 for TF because NNP casts ng_model inputs
    t1 = tf.cast(t1, dtype=tf.bfloat16)
    t3 = tf.cast(t3, dtype=tf.bfloat16)

    filt = nn_ops.conv2d_backprop_filter(
        t1, t2, t3, stride_nhwc, padding=padding, data_format='NHWC')

    # Cast dtype back to float32, similar to NNP
    filt = tf.cast(filt, dtype=tf.float32)
    return filt, t1, t3


# nGraph graph
def ng_model(padding):
    t1 = tf.placeholder(dtype=tf.float32, shape=input_sizes_nchw, name='t1')
    t2 = tf.constant(filter_size_hwio, dtype=tf.int32, name='t2')
    t3 = tf.placeholder(
        dtype=tf.float32, shape=out_backprop_in_sizes[padding], name='t3')

    filt = nn_ops.conv2d_backprop_filter(
        t1, t2, t3, stride_nchw, padding=padding, data_format='NCHW')
    return filt, t1, t3


config = tf.ConfigProto(
    allow_soft_placement=True,
    log_device_placement=False,
    inter_op_parallelism_threads=1)


@pytest.mark.parametrize("padding", ("VALID", "SAME"))
def test_conv2dbackpropfilter_nchw(padding):
    n_np_inp = np.random.rand(*input_sizes_nchw).astype('f')
    n_np_out = np.random.rand(*out_backprop_in_sizes[padding]).astype('f')

    # Reshape to NHWC for TF
    t_np_inp = np.transpose(n_np_inp, (0, 2, 3, 1))
    t_np_out = np.transpose(n_np_out, (0, 2, 3, 1))

    with tf.Session(config=config) as sess_tf:
        ngraph_bridge.disable()
        tf_out, input_data, out_backprop = tf_model(padding)
        feed_dict = {input_data: t_np_inp, out_backprop: t_np_out}
        tf_outval = sess_tf.run(tf_out, feed_dict=feed_dict)

    # Test 2: ng_model with ngraph, NNP backend
    with tf.Session(config=config) as sess_ng:
        ngraph_bridge.enable()
        ngraph_bridge.update_config(config)
        os.environ['NGRAPH_TF_DISABLE_DEASSIGN_CLUSTERS'] = '1'
        ng_out, input_data, out_backprop = ng_model(padding)
        feed_dict = {input_data: n_np_inp, out_backprop: n_np_out}
        ng_outval = sess_ng.run(ng_out, feed_dict=feed_dict)

    assert np.allclose(tf_outval, ng_outval, rtol=0, atol=1e-02)
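
Editor's note (not part of the diff): this NCHW test repeatedly flips between layouts with np.transpose, and the index permutations are easy to mix up. A minimal, self-contained round-trip check, numpy only, using the same dimensions as the test:

import numpy as np

x_nchw = np.random.rand(1, 2, 7, 6).astype('f')   # [N, C, H, W]

# NCHW -> NHWC: move the channel axis to the end.
x_nhwc = np.transpose(x_nchw, (0, 2, 3, 1))        # shape (1, 7, 6, 2)

# NHWC -> NCHW: move the channel axis back to position 1.
x_back = np.transpose(x_nhwc, (0, 3, 1, 2))        # shape (1, 2, 7, 6)

assert x_nhwc.shape == (1, 7, 6, 2)
assert np.array_equal(x_back, x_nchw)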
Lines changed: 107 additions & 0 deletions

@@ -0,0 +1,107 @@
# ==============================================================================
# Copyright 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""nGraph TensorFlow bridge conv2d_backprop_filter operation test (NHWC)"""
import pytest

import tensorflow as tf
import numpy as np
import os
from tensorflow.python.ops import nn_ops
import ngraph_bridge

# Tests nGraph op ConvolutionBackpropFilters with data format NHWC
# TF op: conv2d_backprop_filter

np.random.seed(5)

# Inputs
N = 1
H = 7
W = 6
C = 2

I = C
O = 2
filt_width = 3
filt_height = 3

input_sizes_nhwc = [N, H, W, C]
filter_size_hwio = [filt_height, filt_width, I, O]
out_backprop_valid = [1, 3, 2, 2]
out_backprop_same = [1, 4, 3, 2]
out_backprop_in_sizes = {"VALID": out_backprop_valid, "SAME": out_backprop_same}
stride = [1, 2, 2, 1]


# TF graph
def tf_model(padding):
    t1 = tf.placeholder(dtype=tf.float32, shape=input_sizes_nhwc, name='t1')
    t2 = tf.constant(filter_size_hwio, dtype=tf.int32, name='t2')
    t3 = tf.placeholder(
        dtype=tf.float32, shape=out_backprop_in_sizes[padding], name='t3')

    # Cast dtype to bfloat16 for TF because NNP casts ng_model inputs
    t1 = tf.cast(t1, dtype=tf.bfloat16)
    t3 = tf.cast(t3, dtype=tf.bfloat16)

    filt = nn_ops.conv2d_backprop_filter(
        t1, t2, t3, stride, padding=padding, data_format='NHWC')

    # Cast dtype back to float32, similar to NNP
    filt = tf.cast(filt, dtype=tf.float32)
    return filt, t1, t3


# nGraph graph
def ng_model(padding):
    t1 = tf.placeholder(dtype=tf.float32, shape=input_sizes_nhwc, name='t1')
    t2 = tf.constant(filter_size_hwio, dtype=tf.int32, name='t2')
    t3 = tf.placeholder(
        dtype=tf.float32, shape=out_backprop_in_sizes[padding], name='t3')

    filt = nn_ops.conv2d_backprop_filter(
        t1, t2, t3, stride, padding=padding, data_format='NHWC')
    return filt, t1, t3


config = tf.ConfigProto(
    allow_soft_placement=True,
    log_device_placement=False,
    inter_op_parallelism_threads=1)


@pytest.mark.parametrize("padding", ("VALID", "SAME"))
def test_conv2dbackpropfilter_nhwc(padding):
    np_inp = np.random.rand(*input_sizes_nhwc).astype('f')
    np_out = np.random.rand(*out_backprop_in_sizes[padding]).astype('f')

    with tf.Session(config=config) as sess_tf:
        ngraph_bridge.disable()
        tf_out, input_data, out_backprop = tf_model(padding)
        feed_dict = {input_data: np_inp, out_backprop: np_out}
        tf_outval = sess_tf.run(tf_out, feed_dict=feed_dict)

    # Test 2: ng_model with ngraph, NNP backend
    with tf.Session(config=config) as sess_ng:
        ngraph_bridge.enable()
        ngraph_bridge.update_config(config)
        os.environ['NGRAPH_TF_DISABLE_DEASSIGN_CLUSTERS'] = '1'
        ng_out, input_data, out_backprop = ng_model(padding)
        feed_dict = {input_data: np_inp, out_backprop: np_out}
        ng_outval = sess_ng.run(ng_out, feed_dict=feed_dict)

    assert np.allclose(tf_outval, ng_outval, rtol=0, atol=1e-02)
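
Editor's note (not part of the diff): the hard-coded out_backprop shapes follow directly from TensorFlow's conv2d output-size rules for a 3x3 filter, stride 2, on a 7x6 input. A small sketch verifying the NHWC shapes used above; the same arithmetic, with the channel axis moved, gives the NCHW variant's shapes:

import math

N, H, W = 1, 7, 6
filt_h, filt_w, out_channels = 3, 3, 2
stride = 2

def out_dim(in_dim, filt, stride, padding):
    # Standard TF conv2d output-size formulas.
    if padding == "SAME":
        return math.ceil(in_dim / stride)
    return math.ceil((in_dim - filt + 1) / stride)  # VALID

for padding, expected in (("VALID", [1, 3, 2, 2]), ("SAME", [1, 4, 3, 2])):
    shape = [N, out_dim(H, filt_h, stride, padding),
             out_dim(W, filt_w, stride, padding), out_channels]
    assert shape == expected, (padding, shape)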
