8
8
__all__ = ['build_matrix' , 'get_ldpc_code_params' , 'ldpc_bp_decode' , 'write_ldpc_params' ,
9
9
'triang_ldpc_systematic_encode' ]
10
10
11
+ _llr_max = 500
11
12
12
13
def build_matrix (ldpc_code_params ):
13
14
"""
@@ -146,11 +147,12 @@ def sum_product_update(cnode_idx, cnode_adj_list, cnode_deg_list, cnode_msgs,
146
147
offset = cnode_deg_list [cnode_idx ]
147
148
vnode_list = cnode_adj_list [start_idx :start_idx + offset ]
148
149
vnode_list_msgs_tanh = np .tanh (vnode_msgs [vnode_list * max_vnode_deg +
149
- cnode_vnode_map [start_idx :start_idx + offset ]]/ 2.0 )
150
- msg_prod = np .prod (vnode_list_msgs_tanh )
150
+ cnode_vnode_map [start_idx :start_idx + offset ]] / 2.0 )
151
+ msg_prod = vnode_list_msgs_tanh .prod (0 )
151
152
152
153
# Compute messages on outgoing edges using the incoming message product
153
- cnode_msgs [start_idx :start_idx + offset ]= 2.0 * np .arctanh (msg_prod / vnode_list_msgs_tanh )
154
+ np .clip (2 * np .arctanh (msg_prod / vnode_list_msgs_tanh ),
155
+ - _llr_max , _llr_max , cnode_msgs [start_idx :start_idx + offset ])
154
156
155
157
156
158
def min_sum_update (cnode_idx , cnode_adj_list , cnode_deg_list , cnode_msgs ,
@@ -160,23 +162,21 @@ def min_sum_update(cnode_idx, cnode_adj_list, cnode_deg_list, cnode_msgs,
160
162
offset = cnode_deg_list [cnode_idx ]
161
163
vnode_list = cnode_adj_list [start_idx :start_idx + offset ]
162
164
vnode_list_msgs = vnode_msgs [vnode_list * max_vnode_deg + cnode_vnode_map [start_idx :start_idx + offset ]]
163
- vnode_list_msgs = np .ma .array (vnode_list_msgs , mask = False )
164
165
165
166
# Compute messages on outgoing edges using the incoming messages
166
167
for i in range (start_idx , start_idx + offset ):
167
- vnode_list_msgs .mask [i - start_idx ] = True
168
- cnode_msgs [i ] = np .prod (np .sign (vnode_list_msgs ))* np .min (np .abs (vnode_list_msgs ))
169
- vnode_list_msgs .mask [i - start_idx ] = False
168
+ vnode_list_msgs_excluded = vnode_list_msgs [np .arange (len (vnode_list_msgs )) != i - start_idx , :]
169
+ cnode_msgs [i ] = np .sign (vnode_list_msgs_excluded ).prod (0 ) * np .abs (vnode_list_msgs_excluded ).min (0 )
170
170
171
171
172
172
def ldpc_bp_decode (llr_vec , ldpc_code_params , decoder_algorithm , n_iters ):
173
173
"""
174
- LDPC Decoder using Belief Propagation (BP).
174
+ LDPC Decoder using Belief Propagation (BP). If several blocks are provided, they are all decoded at once.
175
175
176
176
Parameters
177
177
----------
178
- llr_vec : 1D array of float
179
- Received codeword LLR values from the channel. They will be clipped in [-38, 38 ].
178
+ llr_vec : 1D array of float with a length multiple of block length.
179
+ Received codeword LLR values from the channel. They will be clipped in [-500, 500 ].
180
180
181
181
ldpc_code_params : dictionary that at least contains these parameters
182
182
Parameters of the LDPC code as provided by `get_ldpc_code_params`:
@@ -203,15 +203,15 @@ def ldpc_bp_decode(llr_vec, ldpc_code_params, decoder_algorithm, n_iters):
203
203
204
204
Returns
205
205
-------
206
- dec_word : 1D array of 0's and 1's
206
+ dec_word : 1D array or 2D array of 0's and 1's with one block per column.
207
207
The codeword after decoding.
208
208
209
- out_llrs : 1D array of float
209
+ out_llrs : 1D array or 2D array of float with one block per column.
210
210
LLR values corresponding to the decoded output.
211
211
"""
212
212
213
- # Clip LLRs into [-38, 38]
214
- llr_vec .clip (- 38 , 38 , llr_vec )
213
+ # Clip LLRs
214
+ llr_vec .clip (- _llr_max , _llr_max , llr_vec )
215
215
216
216
n_cnodes = ldpc_code_params ['n_cnodes' ]
217
217
n_vnodes = ldpc_code_params ['n_vnodes' ]
@@ -224,11 +224,13 @@ def ldpc_bp_decode(llr_vec, ldpc_code_params, decoder_algorithm, n_iters):
224
224
cnode_deg_list = ldpc_code_params ['cnode_deg_list' ]
225
225
vnode_deg_list = ldpc_code_params ['vnode_deg_list' ]
226
226
227
- dec_word = np .zeros (n_vnodes , int )
228
- out_llrs = np .zeros (n_vnodes , int )
227
+ # Handling multi-block situations
228
+ n_blocks = llr_vec .size // n_vnodes
229
+ llr_vec = llr_vec .reshape (- 1 , n_blocks , order = 'F' )
229
230
230
- cnode_msgs = np .zeros (n_cnodes * max_cnode_deg )
231
- vnode_msgs = np .zeros (n_vnodes * max_vnode_deg )
231
+ dec_word = np .empty_like (llr_vec , bool )
232
+ out_llrs = np .empty_like (llr_vec )
233
+ cnode_msgs = np .empty ((n_cnodes * max_cnode_deg , n_blocks ))
232
234
233
235
if decoder_algorithm == 'SPA' :
234
236
check_node_update = sum_product_update
@@ -238,15 +240,12 @@ def ldpc_bp_decode(llr_vec, ldpc_code_params, decoder_algorithm, n_iters):
238
240
raise NameError ('Please input a valid decoder_algorithm string (meanning "SPA" or "MSA").' )
239
241
240
242
# Initialize vnode messages with the LLR values received
241
- for vnode_idx in range (n_vnodes ):
242
- start_idx = vnode_idx * max_vnode_deg
243
- offset = vnode_deg_list [vnode_idx ]
244
- vnode_msgs [start_idx : start_idx + offset ] = llr_vec [vnode_idx ]
243
+ vnode_msgs = llr_vec .repeat (max_vnode_deg , 0 )
245
244
246
245
# Main loop of Belief Propagation (BP) decoding iterations
247
246
for iter_cnt in range (n_iters ):
248
247
249
- continue_flag = 0
248
+ continue_flag = False
250
249
251
250
# Check Node Update
252
251
for cnode_idx in range (n_cnodes ):
@@ -262,33 +261,31 @@ def ldpc_bp_decode(llr_vec, ldpc_code_params, decoder_algorithm, n_iters):
262
261
offset = vnode_deg_list [vnode_idx ]
263
262
cnode_list = vnode_adj_list [start_idx :start_idx + offset ]
264
263
cnode_list_msgs = cnode_msgs [cnode_list * max_cnode_deg + vnode_cnode_map [start_idx :start_idx + offset ]]
265
- msg_sum = np .sum (cnode_list_msgs )
264
+ msg_sum = cnode_list_msgs .sum (0 )
266
265
267
- # Compute messages on outgoing edges using the incoming message sum (LLRs are clipped in [-38, 38])
266
+ # Compute messages on outgoing edges using the incoming message sum
268
267
vnode_msgs [start_idx :start_idx + offset ] = llr_vec [vnode_idx ] + msg_sum - cnode_list_msgs
269
268
270
269
# Update output LLRs and decoded word
271
270
out_llrs [vnode_idx ] = llr_vec [vnode_idx ] + msg_sum
272
- if out_llrs [vnode_idx ] > 0 :
273
- dec_word [vnode_idx ] = 0
274
- else :
275
- dec_word [vnode_idx ] = 1
271
+
272
+ np .signbit (out_llrs , out = dec_word )
276
273
277
274
# Compute early termination using parity check matrix
278
275
for cnode_idx in range (n_cnodes ):
279
- p_sum = 0
280
- for i in range ( cnode_deg_list [cnode_idx ]):
281
- p_sum ^= dec_word [cnode_adj_list [cnode_idx * max_cnode_deg + i ]]
276
+ start_idx = cnode_idx * max_cnode_deg
277
+ offset = cnode_deg_list [cnode_idx ]
278
+ parity_sum = np . logical_xor . reduce ( dec_word [cnode_adj_list [start_idx : start_idx + offset ]])
282
279
283
- if p_sum != 0 :
284
- continue_flag = 1
280
+ if parity_sum . any () :
281
+ continue_flag = True
285
282
break
286
283
287
284
# Stop iterations
288
- if continue_flag == 0 :
285
+ if not continue_flag :
289
286
break
290
287
291
- return dec_word , out_llrs
288
+ return dec_word . squeeze (). astype ( np . int8 ) , out_llrs . squeeze ()
292
289
293
290
294
291
def write_ldpc_params (parity_check_matrix , file_path ):
0 commit comments