Skip to content

Commit d462c57

Browse files
committed
Rewrite LDPC Bp decoder to hundle several block decoding at once
1 parent b0cf955 commit d462c57

File tree

3 files changed

+78
-76
lines changed

3 files changed

+78
-76
lines changed

commpy/channelcoding/ldpc.py

Lines changed: 33 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
__all__ = ['build_matrix', 'get_ldpc_code_params', 'ldpc_bp_decode', 'write_ldpc_params',
99
'triang_ldpc_systematic_encode']
1010

11+
_llr_max = 500
1112

1213
def build_matrix(ldpc_code_params):
1314
"""
@@ -146,11 +147,12 @@ def sum_product_update(cnode_idx, cnode_adj_list, cnode_deg_list, cnode_msgs,
146147
offset = cnode_deg_list[cnode_idx]
147148
vnode_list = cnode_adj_list[start_idx:start_idx+offset]
148149
vnode_list_msgs_tanh = np.tanh(vnode_msgs[vnode_list*max_vnode_deg +
149-
cnode_vnode_map[start_idx:start_idx+offset]]/2.0)
150-
msg_prod = np.prod(vnode_list_msgs_tanh)
150+
cnode_vnode_map[start_idx:start_idx+offset]] / 2.0)
151+
msg_prod = vnode_list_msgs_tanh.prod(0)
151152

152153
# Compute messages on outgoing edges using the incoming message product
153-
cnode_msgs[start_idx:start_idx+offset]= 2.0*np.arctanh(msg_prod/vnode_list_msgs_tanh)
154+
np.clip(2 * np.arctanh(msg_prod / vnode_list_msgs_tanh),
155+
-_llr_max, _llr_max, cnode_msgs[start_idx:start_idx+offset])
154156

155157

156158
def min_sum_update(cnode_idx, cnode_adj_list, cnode_deg_list, cnode_msgs,
@@ -160,23 +162,21 @@ def min_sum_update(cnode_idx, cnode_adj_list, cnode_deg_list, cnode_msgs,
160162
offset = cnode_deg_list[cnode_idx]
161163
vnode_list = cnode_adj_list[start_idx:start_idx+offset]
162164
vnode_list_msgs = vnode_msgs[vnode_list*max_vnode_deg + cnode_vnode_map[start_idx:start_idx+offset]]
163-
vnode_list_msgs = np.ma.array(vnode_list_msgs, mask=False)
164165

165166
# Compute messages on outgoing edges using the incoming messages
166167
for i in range(start_idx, start_idx+offset):
167-
vnode_list_msgs.mask[i-start_idx] = True
168-
cnode_msgs[i] = np.prod(np.sign(vnode_list_msgs))*np.min(np.abs(vnode_list_msgs))
169-
vnode_list_msgs.mask[i-start_idx] = False
168+
vnode_list_msgs_excluded = vnode_list_msgs[np.arange(len(vnode_list_msgs)) != i - start_idx, :]
169+
cnode_msgs[i] = np.sign(vnode_list_msgs_excluded).prod(0) * np.abs(vnode_list_msgs_excluded).min(0)
170170

171171

172172
def ldpc_bp_decode(llr_vec, ldpc_code_params, decoder_algorithm, n_iters):
173173
"""
174-
LDPC Decoder using Belief Propagation (BP).
174+
LDPC Decoder using Belief Propagation (BP). If several blocks are provided, they are all decoded at once.
175175
176176
Parameters
177177
----------
178-
llr_vec : 1D array of float
179-
Received codeword LLR values from the channel. They will be clipped in [-38, 38].
178+
llr_vec : 1D array of float with a length multiple of block length.
179+
Received codeword LLR values from the channel. They will be clipped in [-500, 500].
180180
181181
ldpc_code_params : dictionary that at least contains these parameters
182182
Parameters of the LDPC code as provided by `get_ldpc_code_params`:
@@ -203,15 +203,15 @@ def ldpc_bp_decode(llr_vec, ldpc_code_params, decoder_algorithm, n_iters):
203203
204204
Returns
205205
-------
206-
dec_word : 1D array of 0's and 1's
206+
dec_word : 1D array or 2D array of 0's and 1's with one block per column.
207207
The codeword after decoding.
208208
209-
out_llrs : 1D array of float
209+
out_llrs : 1D array or 2D array of float with one block per column.
210210
LLR values corresponding to the decoded output.
211211
"""
212212

213-
# Clip LLRs into [-38, 38]
214-
llr_vec.clip(-38, 38, llr_vec)
213+
# Clip LLRs
214+
llr_vec.clip(-_llr_max, _llr_max, llr_vec)
215215

216216
n_cnodes = ldpc_code_params['n_cnodes']
217217
n_vnodes = ldpc_code_params['n_vnodes']
@@ -224,11 +224,13 @@ def ldpc_bp_decode(llr_vec, ldpc_code_params, decoder_algorithm, n_iters):
224224
cnode_deg_list = ldpc_code_params['cnode_deg_list']
225225
vnode_deg_list = ldpc_code_params['vnode_deg_list']
226226

227-
dec_word = np.zeros(n_vnodes, int)
228-
out_llrs = np.zeros(n_vnodes, int)
227+
# Handling multi-block situations
228+
n_blocks = llr_vec.size // n_vnodes
229+
llr_vec = llr_vec.reshape(-1, n_blocks, order='F')
229230

230-
cnode_msgs = np.zeros(n_cnodes*max_cnode_deg)
231-
vnode_msgs = np.zeros(n_vnodes*max_vnode_deg)
231+
dec_word = np.empty_like(llr_vec, bool)
232+
out_llrs = np.empty_like(llr_vec)
233+
cnode_msgs = np.empty((n_cnodes * max_cnode_deg, n_blocks))
232234

233235
if decoder_algorithm == 'SPA':
234236
check_node_update = sum_product_update
@@ -238,15 +240,12 @@ def ldpc_bp_decode(llr_vec, ldpc_code_params, decoder_algorithm, n_iters):
238240
raise NameError('Please input a valid decoder_algorithm string (meanning "SPA" or "MSA").')
239241

240242
# Initialize vnode messages with the LLR values received
241-
for vnode_idx in range(n_vnodes):
242-
start_idx = vnode_idx*max_vnode_deg
243-
offset = vnode_deg_list[vnode_idx]
244-
vnode_msgs[start_idx : start_idx+offset] = llr_vec[vnode_idx]
243+
vnode_msgs = llr_vec.repeat(max_vnode_deg, 0)
245244

246245
# Main loop of Belief Propagation (BP) decoding iterations
247246
for iter_cnt in range(n_iters):
248247

249-
continue_flag = 0
248+
continue_flag = False
250249

251250
# Check Node Update
252251
for cnode_idx in range(n_cnodes):
@@ -262,33 +261,31 @@ def ldpc_bp_decode(llr_vec, ldpc_code_params, decoder_algorithm, n_iters):
262261
offset = vnode_deg_list[vnode_idx]
263262
cnode_list = vnode_adj_list[start_idx:start_idx+offset]
264263
cnode_list_msgs = cnode_msgs[cnode_list*max_cnode_deg + vnode_cnode_map[start_idx:start_idx+offset]]
265-
msg_sum = np.sum(cnode_list_msgs)
264+
msg_sum = cnode_list_msgs.sum(0)
266265

267-
# Compute messages on outgoing edges using the incoming message sum (LLRs are clipped in [-38, 38])
266+
# Compute messages on outgoing edges using the incoming message sum
268267
vnode_msgs[start_idx:start_idx+offset] = llr_vec[vnode_idx] + msg_sum - cnode_list_msgs
269268

270269
# Update output LLRs and decoded word
271270
out_llrs[vnode_idx] = llr_vec[vnode_idx] + msg_sum
272-
if out_llrs[vnode_idx] > 0:
273-
dec_word[vnode_idx] = 0
274-
else:
275-
dec_word[vnode_idx] = 1
271+
272+
np.signbit(out_llrs, out=dec_word)
276273

277274
# Compute early termination using parity check matrix
278275
for cnode_idx in range(n_cnodes):
279-
p_sum = 0
280-
for i in range(cnode_deg_list[cnode_idx]):
281-
p_sum ^= dec_word[cnode_adj_list[cnode_idx*max_cnode_deg + i]]
276+
start_idx = cnode_idx * max_cnode_deg
277+
offset = cnode_deg_list[cnode_idx]
278+
parity_sum = np.logical_xor.reduce(dec_word[cnode_adj_list[start_idx:start_idx + offset]])
282279

283-
if p_sum != 0:
284-
continue_flag = 1
280+
if parity_sum.any():
281+
continue_flag = True
285282
break
286283

287284
# Stop iterations
288-
if continue_flag == 0:
285+
if not continue_flag:
289286
break
290287

291-
return dec_word, out_llrs
288+
return dec_word.squeeze().astype(np.int8), out_llrs.squeeze()
292289

293290

294291
def write_ldpc_params(parity_check_matrix, file_path):

commpy/channelcoding/tests/test_ldpc.py

Lines changed: 36 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from tempfile import TemporaryDirectory
66

77
from nose.plugins.attrib import attr
8-
from numpy import array, sqrt, zeros, zeros_like, empty, int8
8+
from numpy import array, sqrt, zeros, zeros_like
99
from numpy.random import randn, choice
1010
from numpy.testing import assert_allclose, assert_equal, assert_raises
1111

@@ -29,39 +29,40 @@ def test_ldpc_bp_decode(self):
2929
ldpc_design_file = os.path.join(self.dir, '../designs/ldpc/gallager/96.33.964.txt')
3030
ldpc_code_params = get_ldpc_code_params(ldpc_design_file)
3131

32-
N = 96
33-
rate = 0.5
34-
Es = 1.0
35-
snr_list = array([2.0, 2.5])
36-
niters = 10000000
37-
tx_codeword = zeros(N, int)
38-
ldpcbp_iters = 100
32+
for n_blocks in (1, 2):
33+
N = 96 * n_blocks
34+
rate = 0.5
35+
Es = 1.0
36+
snr_list = array([2.0, 2.5])
37+
niters = 10000000
38+
tx_codeword = zeros(N, int)
39+
ldpcbp_iters = 100
3940

40-
fer_array_ref = array([200.0/1000, 200.0/2000])
41-
fer_array_test = zeros(len(snr_list))
41+
fer_array_ref = array([200.0/1000, 200.0/2000])
42+
fer_array_test = zeros(len(snr_list))
4243

43-
for idx, ebno in enumerate(snr_list):
44+
for idx, ebno in enumerate(snr_list):
4445

45-
noise_std = 1/sqrt((10**(ebno/10.0))*rate*2/Es)
46-
fer_cnt_bp = 0
46+
noise_std = 1/sqrt((10**(ebno/10.0))*rate*2/Es)
47+
fer_cnt_bp = 0
4748

48-
for iter_cnt in range(niters):
49+
for iter_cnt in range(niters):
4950

50-
awgn_array = noise_std * randn(N)
51-
rx_word = 1-(2*tx_codeword) + awgn_array
52-
rx_llrs = 2.0*rx_word/(noise_std**2)
51+
awgn_array = noise_std * randn(N)
52+
rx_word = 1-(2*tx_codeword) + awgn_array
53+
rx_llrs = 2.0*rx_word/(noise_std**2)
5354

54-
[dec_word, out_llrs] = ldpc_bp_decode(rx_llrs, ldpc_code_params, 'SPA', ldpcbp_iters)
55+
[dec_word, out_llrs] = ldpc_bp_decode(rx_llrs, ldpc_code_params, 'SPA', ldpcbp_iters)
5556

56-
num_bit_errors = hamming_dist(tx_codeword, dec_word)
57-
if num_bit_errors > 0:
58-
fer_cnt_bp += 1
57+
num_bit_errors = hamming_dist(tx_codeword, dec_word.reshape(-1))
58+
if num_bit_errors > 0:
59+
fer_cnt_bp += 1
5960

60-
if fer_cnt_bp >= 200:
61-
fer_array_test[idx] = float(fer_cnt_bp)/(iter_cnt+1)
62-
break
61+
if fer_cnt_bp >= 200:
62+
fer_array_test[idx] = float(fer_cnt_bp) / (iter_cnt + 1) / n_blocks
63+
break
6364

64-
assert_allclose(fer_array_test, fer_array_ref, rtol=.5, atol=0)
65+
assert_allclose(fer_array_test, fer_array_ref, rtol=.5, atol=0)
6566

6667
def test_write_ldpc_params(self):
6768
with TemporaryDirectory() as tmp_dir:
@@ -92,11 +93,13 @@ def test_triang_ldpc_systematic_encode(self):
9293
# Test decoding
9394
coded_bits[coded_bits == 1] = -1
9495
coded_bits[coded_bits == 0] = 1
95-
block_length = param['generator_matrix'].shape[1]
96-
nb_blocks = coded_bits.shape[1]
97-
decoded_bits = empty(block_length * nb_blocks, int8)
98-
for i in range(nb_blocks):
99-
decoded_bits[i * block_length:(i + 1) * block_length] = \
100-
ldpc_bp_decode(coded_bits[:, i], param, 'SPA', 10)[0][:block_length]
101-
assert_equal(decoded_bits[:len(message_bits)], message_bits,
102-
'Encoded and decoded messages do not match the initial bits without noise')
96+
MSA_decoded_bits = ldpc_bp_decode(coded_bits.reshape(-1, order='F').astype(float), param, 'MSA', 10)[0]
97+
SPA_decoded_bits = ldpc_bp_decode(coded_bits.reshape(-1, order='F').astype(float), param, 'SPA', 10)[0]
98+
99+
# Extract systematic part
100+
MSA_decoded_bits = MSA_decoded_bits[:720].reshape(-1, order='F')
101+
SPA_decoded_bits = SPA_decoded_bits[:720].reshape(-1, order='F')
102+
assert_equal(MSA_decoded_bits[:len(message_bits)], message_bits,
103+
'Encoded and decoded messages do not match the initial bits without noise (MS algorithm)')
104+
assert_equal(SPA_decoded_bits[:len(message_bits)], message_bits,
105+
'Encoded and decoded messages do not match the initial bits without noise (SP algorithm)')

commpy/links.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -190,16 +190,13 @@ def __init__(self, modulate, channel, receive, num_bits_symbol, constellation, E
190190
self.decoder = decoder
191191

192192

193-
def idd_decoder(word_size, detector, decoder, n_it):
193+
def idd_decoder(detector, decoder, decision, n_it):
194194
"""
195195
Produce a decoder function that model the specified MIMO iterative detection and decoding (IDD) process.
196196
The returned function can be used as is to build a working LinkModel object.
197197
198198
Parameters
199199
----------
200-
word_size : positive integer
201-
Size of the words exchanged between the detector and the decoder.
202-
203200
detector : function with prototype detector(y, H, constellation, noise_var, a_priori) that return a LLRs array.
204201
y : 1D ndarray
205202
Received complex symbols (shape: num_receive_antennas x 1).
@@ -215,7 +212,12 @@ def idd_decoder(word_size, detector, decoder, n_it):
215212
a_priori : 1D ndarray of floats
216213
A priori as Log-Likelihood Ratios.
217214
218-
decoder : function with the same signature as detector.
215+
decoder : function with prototype(LLRs) that return a LLRs array.
216+
LLRs : 1D ndarray of floats
217+
A priori as Log-Likelihood Ratios.
218+
219+
decision : function wih prototype(LLRs) that return a binary 1D-array that model the decision to extract the
220+
information bits from the LLRs array.
219221
220222
n_it : positive integer
221223
Number or iteration during the IDD process.
@@ -238,7 +240,7 @@ def idd_decoder(word_size, detector, decoder, n_it):
238240
Number or bit send at each symbol vector.
239241
"""
240242
def decode(y, h, constellation, noise_var, a_priori, bits_per_send):
241-
a_priori_decoder = a_priori
243+
a_priori_decoder = a_priori.copy()
242244
nb_vect, nb_rx, nb_tx = h.shape
243245
for iteration in range(n_it):
244246
a_priori_detector = (decoder(a_priori_decoder) - a_priori_decoder)
@@ -247,6 +249,6 @@ def decode(y, h, constellation, noise_var, a_priori, bits_per_send):
247249
detector(y[i], h[i], constellation, noise_var,
248250
a_priori_detector[i * bits_per_send:(i + 1) * bits_per_send])
249251
a_priori_decoder -= a_priori_detector
250-
return np.signbit(a_priori_decoder + a_priori_detector)
252+
return decision(a_priori_decoder + a_priori_detector)
251253

252254
return decode

0 commit comments

Comments
 (0)