-
-
Notifications
You must be signed in to change notification settings - Fork 3
/
pbf_decoder.py
152 lines (135 loc) · 8.39 KB
/
pbf_decoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from numpy.typing import ArrayLike, NDArray
import numpy as np
from typing import Optional, Union
from ldpc.utils import IncorrectLength
from ldpc.decoder.common import InfoBitsNotSpecified, bsc_llr
from enum import Enum, auto
__all__ = ["PbfDecoder", "PbfVariant"]
class PbfVariant(Enum):
    """Enumeration of the supported probabilistic bit-flipping (PBF) algorithm variants.

    PPBF follows the article by K. Le et al., "A Probabilistic Parallel Bit-Flipping
    Decoder for Low-Density Parity-Check Codes".
    """
    PPBF = auto()
class PbfDecoder:
    """Probabilistic bit-flipping (PBF) decoder for LDPC codes.

    The concrete flipping rule is selected via the ``decoder_variant`` constructor
    argument; see the :class:`PbfVariant` enum for available variants.

    - PPBF: implementation based on an article by K. Le et al., "A Probabilistic
      Parallel Bit-Flipping Decoder for Low-Density Parity-Check Codes".
    """

    def __init__(self, h: ArrayLike, max_iter: int, decoder_variant: PbfVariant,
                 info_idx: Optional[NDArray[np.bool_]] = None, **kwargs) -> None:
        """
        :param h: the parity check matrix of the code (m x n, binary entries)
        :param max_iter: the maximal number of decoding iterations
        :param decoder_variant: the variant of the decoder to use
        :param info_idx: a boolean array marking the indices of information bits in the code
        :keyword use_priors: if nonzero, hard priors passed to :meth:`decode` are used
            and contribute one energy level (default 0)
        :keyword p_vector: optional default vector of flip probabilities per energy level,
            used by :meth:`decode` when none is supplied per call
        """
        self.info_idx = info_idx
        self.h: NDArray[np.int_] = h
        self.m, self.n = h.shape
        # number of information bits: inferred from shape unless info_idx pins it down
        self.k = self.n - self.m if info_idx is None else np.sum(info_idx)
        self.max_iter = max_iter
        self.decoder_variant = decoder_variant
        self.use_priors = kwargs.get("use_priors", 0)
        if self.decoder_variant in {PbfVariant.PPBF}:
            # for each check node i, the vnodes j connected to it ("N_i" set in Ryan's notation)
            self.check2vnode = {i: [j for j in range(self.n) if self.h[i, j] == 1] for i in range(self.m)}
            # for each vnode j, the cnodes i connected to it ("M_j" set in Ryan's notation)
            self.vnode2check = {j: [i for i in range(self.m) if self.h[i, j] == 1] for j in range(self.n)}
            self.vnode_degree = np.sum(h, axis=0)
            self.cnode_degree = np.sum(h, axis=1)
            # if a probabilities vector was supplied use it, otherwise None
            self.p_vector: NDArray[np.float_] = kwargs.get("p_vector", None)
            if self.p_vector is not None:
                self._verify_p_vector(self.p_vector)

    def decode(self, channel_word: NDArray[np.int_], to_flip: int, prior: Optional[NDArray[np.int_]] = None,
               p_vector: Optional[NDArray[np.float_]] = None) \
            -> tuple[NDArray[np.int_], NDArray[np.float_], bool, int, NDArray[np.int_], NDArray[np.int_]]:
        """Decode a sequence received from the channel.

        :param channel_word: a sequence of channel hard values
        :param to_flip: soft cap on the expected number of bits flipped per iteration
        :param prior: an array of hard priors. -1 for no prior, 0 for a 0 prior, 1 for a 1 prior
        :param p_vector: a vector of probabilities for flipping a bit, per energy level.
            If None, the default supplied at construction time is used.
        :raises IncorrectLength: if channel_word or prior has the wrong length
        :raises ValueError: if no p_vector is available (neither per-call nor default)
        :return: a tuple (estimated_bits, llr, decode_success, no_iterations, syndrome, energy)
            where:
                - estimated_bits is a 1-d np array of hard bit estimates
                - llr is a 1-d np array of soft bit estimates
                - decode_success is a boolean flag stating if the estimated_bits form a valid codeword
                - no_iterations is the number of iterations performed before exiting the loop
                - syndrome is the final syndrome vector
                - energy is a per-vnode measure of validity, lower is better
        """
        if len(channel_word) != self.n:
            raise IncorrectLength("incorrect block size")
        if prior is None or self.use_priors == 0:
            prior = -1 * np.ones(self.n, dtype=np.int_)
        elif len(prior) != self.n:
            raise IncorrectLength("incorrect prior size")
        if p_vector is None:
            if self.p_vector is None:
                raise ValueError("p_vector must be supplied")
            p_vector = self.p_vector
        else:
            self._verify_p_vector(p_vector)
        # initialize the vnodes to the channel word
        vnode_values = channel_word.copy()
        energy = np.zeros(self.n, dtype=np.int_)
        # hoist loop invariants, and pre-bind syndrome/iteration so a max_iter of 0
        # cannot leave them unbound at the return statement below
        max_dv = max(self.vnode_degree)
        prior_mask = np.array(prior != -1, dtype=np.int_)
        iteration = 0
        syndrome = self.h @ vnode_values % 2
        for iteration in range(self.max_iter):
            syndrome = self.h @ vnode_values % 2
            if not syndrome.any():  # no errors detected, exit
                break
            # for each vnode, how many of its parity equations are currently unsatisfied
            vnode_reliability = syndrome @ self.h
            flipped_vnodes = vnode_values ^ channel_word
            prior_reliability = np.bitwise_xor(prior, vnode_values) * prior_mask
            # energy per vnode: number of unsatisfied checks (vnode_reliability), plus 1 if
            # flipped w.r.t. the channel word, plus 1 if flipped w.r.t. the prior
            energy = flipped_vnodes + vnode_reliability + prior_reliability
            # The energy dictates the probability of flipping the vnode in the next iteration:
            # p_vector holds one probability per energy level, normalized by vnode degree.
            bit_flip_p = np.array([p_vector[energy[i]] * max_dv / self.vnode_degree[i] for i in range(self.n)])
            # expected number of flipped bits is the sum of bit_flip_p; cap it at to_flip
            if np.sum(bit_flip_p) > to_flip:
                bit_flip_p = bit_flip_p * to_flip / np.sum(bit_flip_p)
            # the degree normalization above can push probabilities past 1, which would make
            # np.random.binomial raise ValueError — clip so the Bernoulli draw is well defined
            bit_flip_p = np.minimum(bit_flip_p, 1.0)
            # draw a Bernoulli bit per vnode; a drawn 1 flips the vnode's current value
            flipped = np.random.binomial(1, bit_flip_p)
            if np.sum(flipped) == 0:
                # if no bit was flipped, flip the one with the highest energy
                flip_bit = np.argwhere(energy == np.amax(energy)).flatten()
                if len(flip_bit) > 1:  # several bits tie for highest energy, pick one at random
                    flip_bit = np.random.choice(flip_bit)
                flipped[flip_bit] = 1
            vnode_values = np.bitwise_xor(vnode_values, flipped)
        # Soft output (LLR) per bit, usable for Turbo decoding: p_vector[energy] of each vnode
        # is treated as the crossover probability of a hypothetical BSC channel, so the LLR
        # magnitude reflects the confidence in that bit's value.
        p_regular = p_vector.copy()
        p_regular[p_regular < p_regular[1]] = p_regular[1]  # regularize: avoid too much confidence in any bit
        p_regular[-1] = p_regular[-2]  # regularize: avoid too little confidence (the LLR explodes otherwise)
        channels = [bsc_llr(p) for p in p_regular]
        llr = np.array([channels[energy[i]](vnode_values[i]) for i in range(self.n)])
        return vnode_values, llr, not syndrome.any(), iteration, syndrome, energy

    def info_bits(self, estimate: NDArray[np.int_]) -> NDArray[np.int_]:
        """Extract the information-bearing bits from a decoded estimate.

        :param estimate: a 1-d array of hard bit estimates of length n
        :raises InfoBitsNotSpecified: if info_idx was not supplied at construction time
        """
        if self.info_idx is not None:
            return estimate[self.info_idx]
        else:
            raise InfoBitsNotSpecified("decoder cannot tell info bits")

    def _verify_p_vector(self, p_vector: NDArray[np.float_]) -> None:
        """Verify that a flip-probability vector is valid for this code.

        The vector must hold one probability per possible energy level:
        max vnode degree + 2 (+1 more when priors are in use), each in [0, 1].

        :param p_vector: p_vector to verify
        :raises IncorrectLength: if the vector has the wrong length
        :raises ValueError: if any entry is outside [0, 1]
        :return: None
        """
        if len(p_vector) != max(self.vnode_degree) + 2 + self.use_priors:
            raise IncorrectLength(f"incorrect length of p_vector, must be of length "
                                  f"{max(self.vnode_degree) + 2 + self.use_priors}")
        if np.any(p_vector < 0) or np.any(p_vector > 1):
            raise ValueError("p_vector must be between 0 and 1")