Skip to content

Commit f919a02

Browse files
author
Alexandre Marquet
committed
Refactoring of log_bcjr implementation, to allow easy implementation of max_log_bcjr.
1 parent 6a099b4 commit f919a02

File tree

5 files changed

+356
-178
lines changed

5 files changed

+356
-178
lines changed

PyTurbo.pyx

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,19 +31,29 @@ cdef extern from "viterbi.h":
3131
int get_S()
3232
int get_O()
3333

34-
cdef extern from "log_bcjr.cc":
34+
cdef extern from "log_bcjr_base.cc":
3535
pass
3636

37-
cdef extern from "log_bcjr.h":
38-
cppclass log_bcjr:
39-
log_bcjr(int, int, int, vector[int], vector[int]) except +
40-
@staticmethod
41-
float max_star(const float*, size_t)
37+
cdef extern from "log_bcjr_base.h":
38+
cppclass log_bcjr_base:
39+
log_bcjr_base(int, int, int, vector[int], vector[int]) except +
4240
void log_bcjr_algorithm(vector[float], vector[float], vector[float], vector[float])
4341
int get_I()
4442
int get_S()
4543
int get_O()
4644

45+
cdef extern from "log_bcjr.h":
46+
cppclass log_bcjr(log_bcjr_base):
47+
log_bcjr(int, int, int, vector[int], vector[int]) except +
48+
@staticmethod
49+
float max_star(const float*, size_t)
50+
51+
cdef extern from "max_log_bcjr.h":
52+
cppclass max_log_bcjr(log_bcjr_base):
53+
max_log_bcjr(int, int, int, vector[int], vector[int]) except +
54+
@staticmethod
55+
float max(const float*, size_t)
56+
4757
import numpy
4858

4959
cdef class PyViterbi:
@@ -92,3 +102,29 @@ cdef class PyLogBCJR:
92102
self.cpp_log_bcjr.log_bcjr_algorithm(A0, BK, _in, _out)
93103

94104
return numpy.asarray(_out, dtype=numpy.float32)
105+
106+
cdef class PyMaxLogBCJR:
107+
cdef int I, S, O
108+
cdef max_log_bcjr* cpp_max_log_bcjr
109+
110+
def __cinit__(self, int I, int S, int O, vector[int] NS, vector[int] OS):
111+
self.cpp_max_log_bcjr= new max_log_bcjr(I, S, O, NS, OS)
112+
self.I = self.cpp_max_log_bcjr.get_I()
113+
self.S = self.cpp_max_log_bcjr.get_S()
114+
self.O = self.cpp_max_log_bcjr.get_O()
115+
116+
def __dealloc__(self):
117+
del self.cpp_max_log_bcjr
118+
119+
@staticmethod
120+
def max(float[::1] vec):
121+
cdef size_t n_ele = vec.shape[0]
122+
123+
return max_log_bcjr.max(&vec[0], n_ele)
124+
125+
def log_bcjr_algorithm(self, vector[float] &A0, vector[float] &BK, vector[float] &_in):
126+
cdef vector[float] _out
127+
128+
self.cpp_max_log_bcjr.log_bcjr_algorithm(A0, BK, _in, _out)
129+
130+
return numpy.asarray(_out, dtype=numpy.float32)

log_bcjr.h

Lines changed: 18 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -22,46 +22,14 @@
2222
#ifndef INCLUDED_TURBO_LOG_BCJR_H
2323
#define INCLUDED_TURBO_LOG_BCJR_H
2424

25-
#include <algorithm>
26-
#include <limits>
27-
#include <vector>
28-
#include <stdexcept>
29-
#include <cmath>
30-
#include <cfloat>
31-
//#include <iostream>
25+
#include "log_bcjr_base.h"
3226

3327
/*!
3428
* \brief <+description+>
3529
*
3630
*/
37-
class log_bcjr
31+
class log_bcjr : public log_bcjr_base
3832
{
39-
private:
40-
//! The number of possible input sequences (e.g. 2 for binary codes).
41-
int d_I;
42-
//! The number of states in the trellis.
43-
int d_S;
44-
//! The number of possible output sequences.
45-
int d_O;
46-
/* Gives the next state ns of a branch defined by its
47-
* initial state s and its input symbol i : NS[s*I+i]=ns.
48-
*/
49-
std::vector<int> d_NS;
50-
/* Gives the output symbol of of a branch defined by its
51-
* initial state s and its input symbol i : OS[s*I+i]=os.
52-
*/
53-
std::vector<int> d_OS;
54-
/* Defined such that d_PS[s] contains all the previous states having a
55-
* branch with state s.
56-
* Such a previous state may appear multiple time if there are multiple
57-
* transistions between two states.
58-
*/
59-
std::vector<std::vector<int> > d_PS;
60-
//! Defined such that d_PI[s] contains all the inputs yielding to state s.
61-
std::vector<std::vector<int> > d_PI;
62-
//! Generates PS, PI and T tables.
63-
void generate_PS_PI();
64-
6533
public:
6634
//! Default constructor.
6735
log_bcjr();
@@ -78,7 +46,7 @@ class log_bcjr
7846
*/
7947
log_bcjr(int I, int S, int O,
8048
const std::vector<int> &NS,
81-
const std::vector<int> &OS);
49+
const std::vector<int> &OS) : log_bcjr_base(I, S, O, NS, OS) {};
8250

8351
//! Computes max* of two value.
8452
/*!
@@ -95,6 +63,8 @@ class log_bcjr
9563
{
9664
return std::max(A, B) + log(1.0 + exp(-fabs(A - B)));
9765
}
66+
// Override log_bcjr_base method
67+
float _max_star(float A, float B) { return max_star(A, B); }
9868

9969
//! Recursively compute max* of a vector.
10070
/*!
@@ -104,112 +74,21 @@ class log_bcjr
10474
* \param vec Input data.
10575
* \param n_ele number of elements in the vector.
10676
*
107-
* \return: max* of vec. If axis is None, the result is a scalar value.
77+
* \return: max* of vec.
10878
* If axis is given, the result is an array of dimension vec.ndim - 1.
10979
*/
110-
static float max_star(const float *vec, size_t n_ele);
111-
112-
//! Compute forward log metrics.
113-
/*!
114-
* From A_k(s) the forward log metric for state s at time index k, and
115-
* G_k(s,i) the log metric of the branch identified by state s and
116-
* input symbol i at index k, this function computes:
117-
*
118-
* A_k(s) = max*_{ s' \in [0 ; d_S[, i \in \tau(s',s) } G_{k-1}(s', i) + A_{k-1}(s')
119-
*
120-
* where \tau(s,s') regroups every input symbols that belongs to every
121-
* transitions between s and s'.
122-
*
123-
* Note: in practice, here, we only have the metrics of every possible
124-
* output symbols: G_k(o) with o \in [0 ; d_O[. The correspondance is
125-
* done through d_OS: G_k(s,i) = G_k(d_OS[s*I+i]).
126-
*
127-
* \param G Const reference to the log metrics vector (size: d_O*K).
128-
* \param A0 Const reference to the initial forward state metrics
129-
* (size: d_S).
130-
* \param A Reference to the forward metrics vector (will have a size
131-
* of d_S*(K+1) at the end of function execution).
132-
* \param K Number of observations.
133-
*/
134-
virtual void compute_fw_metrics(const std::vector<float> &G,
135-
const std::vector<float> &A0, std::vector<float> &A, size_t K);
136-
137-
//! Compute backward log metrics.
138-
/*!
139-
* From B_k(s) the backward log metric for state s at time index k, and
140-
* G_k(s,i) the log metric of the branch identified by state s and
141-
* input symbol i at index k, this function computes:
142-
*
143-
* B_k(s) = max*_{ s' \in [0 ; d_S[, i \in \tau(s,s') } G_k(s, i) + B_{k+1}(s').
144-
*
145-
* where \tau(s,s') regroups every input symbols that belongs to every
146-
* transitions between s and s'.
147-
*
148-
* Note: in practice, here, we only have the metrics of every possible
149-
* output symbols: G_k(o) with o \in [0 ; d_O[. The correspondance is
150-
* done through d_OS: G_k(s,i) = G_k(d_OS[s*I+i]).
151-
*
152-
* \param G Const reference to the log metrics vector (size: d_O*K).
153-
* \param BK Const reference to the final backward state metrics
154-
* (size: d_S).
155-
* \param B Reference to the backward metrics vector (will have a size
156-
* of d_S*(K+1) at the end of function execution).
157-
* \param K Number of observations.
158-
*/
159-
virtual void compute_bw_metrics(const std::vector<float> &G,
160-
const std::vector<float> &BK, std::vector<float> &B, size_t K);
161-
162-
//! Compute branch log a-posteriori probabilities.
163-
/*!
164-
* From A_k(s) the forward log metric for state s at time index k,
165-
* B_k(s) the backward log metric for state s at time index k, and
166-
* G_k(s,s') the branch log metric between states s' and s at time
167-
* index k, this function computes:
168-
*
169-
* APP_k(s,i) = B_{k+1}(NS(s,i)) + G_k(s,i) + A_k(s),
170-
*
171-
* where s' = NS(s,i) is the next state for transition with initial
172-
* state s and input symbol i (NS[s*I+i]).
173-
* Which is equivalent to log a-posteriori probabilites, up to an
174-
* additive constant.
175-
*
176-
* \param A Const reference to the forward metrics vector (size: d_S*(K+1)).
177-
* \param B Const reference to the backward metrics vector (size: d_S*(K+1)).
178-
* \param G Const reference to the branch log metrics vector (size: d_O*K).
179-
* \param K Number of observations.
180-
* \param out Reference to a posteriori branch log probabilities (will
181-
* have a size of d_S*d_I*K at the end of function execution).
182-
*
183-
*/
184-
virtual void compute_app(const std::vector<float> &A,
185-
const std::vector<float> &B, const std::vector<float> &G,
186-
size_t K, std::vector<float> &out);
187-
188-
/*! Actually computes logarithm of a-posteriori probabilities for a
189-
* given observation sequence.
190-
*
191-
* \param A0 Log of initial state probabilities of the encoder (size: d_S).
192-
* \param BK Log of final state probabilities of the encoder (size: d_S).
193-
* \param in Log of input branch metrics for the algorithm (size: d_O*k).
194-
* \param out A quantity equivalent to log a-posteriori probabilites, up
195-
* to an additive constant (will have a size of d_S*d_I*K at the end of
196-
* function execution).
197-
*/
198-
void log_bcjr_algorithm(const std::vector<float> &A0,
199-
const std::vector<float> &BK,
200-
const std::vector<float> &in,
201-
std::vector<float> &out);
202-
203-
//! Getter for d_I.
204-
int get_I() { return d_I; }
205-
//! Getter for d_S.
206-
int get_S() { return d_S; }
207-
//! Getter for d_O.
208-
int get_O() { return d_O; }
209-
//! Getter for d_NS.
210-
std::vector<int>& get_NS() { return d_NS; }
211-
//! Getter for d_OS.
212-
std::vector<int>& get_OS() { return d_OS; }
80+
static float max_star(const float *vec, size_t n_ele)
81+
{
82+
float ret_val = -std::numeric_limits<float>::max();
83+
84+
for (float *vec_it = (float*)vec ; vec_it < (vec + n_ele) ; ++vec_it) {
85+
ret_val = max_star(ret_val, *vec_it);
86+
}
87+
88+
return ret_val;
89+
}
90+
// Override log_bcjr_base method
91+
float _max_star(const float *vec, size_t n_ele) { return max_star(vec, n_ele); }
21392
};
21493

21594
#endif /* INCLUDED_TURBO_LOG_BCJR_H */

log_bcjr.cc renamed to log_bcjr_base.cc

Lines changed: 12 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818
* Boston, MA 02110-1301, USA.
1919
*/
2020

21-
#include "log_bcjr.h"
21+
#include "log_bcjr_base.h"
2222

23-
log_bcjr::log_bcjr(int I, int S, int O,
23+
log_bcjr_base::log_bcjr_base(int I, int S, int O,
2424
const std::vector<int> &NS,
2525
const std::vector<int> &OS)
2626
: d_I(I), d_S(S), d_O(O)
@@ -39,7 +39,7 @@ log_bcjr::log_bcjr(int I, int S, int O,
3939
}
4040

4141
void
42-
log_bcjr::generate_PS_PI()
42+
log_bcjr_base::generate_PS_PI()
4343
{
4444
d_PS.resize(d_S);
4545
d_PI.resize(d_S);
@@ -61,20 +61,8 @@ log_bcjr::generate_PS_PI()
6161
}
6262
}
6363

64-
float
65-
log_bcjr::max_star(const float *vec, size_t n_ele)
66-
{
67-
float ret_val = -std::numeric_limits<float>::max();
68-
69-
for (float *vec_it = (float*)vec ; vec_it < (vec + n_ele) ; ++vec_it) {
70-
ret_val = max_star(ret_val, *vec_it);
71-
}
72-
73-
return ret_val;
74-
}
75-
7664
void
77-
log_bcjr::compute_fw_metrics(const std::vector<float> &G,
65+
log_bcjr_base::compute_fw_metrics(const std::vector<float> &G,
7866
const std::vector<float> &A0, std::vector<float> &A, size_t K)
7967
{
8068
A.resize(d_S*(K+1), -std::numeric_limits<float>::max());
@@ -99,7 +87,7 @@ log_bcjr::compute_fw_metrics(const std::vector<float> &G,
9987

10088
//Loop
10189
for(size_t i=0 ; i<(d_PS[s]).size() ; ++i) {
102-
*A_curr = max_star(*A_curr,
90+
*A_curr = _max_star(*A_curr,
10391
A_prev[*PS_it] + G_k[d_OS[(*PS_it)*d_I + (*PI_it)]]);
10492

10593
//Update PS/PI iterators
@@ -115,14 +103,14 @@ log_bcjr::compute_fw_metrics(const std::vector<float> &G,
115103
A_prev += d_S;
116104

117105
//Metrics normalization
118-
norm_A = max_star(&(*(A_prev)), d_S);
106+
norm_A = _max_star(&(*(A_prev)), d_S);
119107
std::transform(A_prev, A_curr, A_prev,
120108
std::bind2nd(std::minus<float>(), norm_A));
121109
}
122110
}
123111

124112
void
125-
log_bcjr::compute_bw_metrics(const std::vector<float> &G,
113+
log_bcjr_base::compute_bw_metrics(const std::vector<float> &G,
126114
const std::vector<float> &BK, std::vector<float> &B, size_t K)
127115
{
128116
B.resize(d_S*(K+1), -std::numeric_limits<float>::max());
@@ -146,7 +134,7 @@ log_bcjr::compute_bw_metrics(const std::vector<float> &G,
146134
for(int s=0 ; s < d_S ; ++s) {
147135
//Loop
148136
for(size_t i=0 ; i < d_I ; ++i) {
149-
*B_curr = max_star(*B_curr,
137+
*B_curr = _max_star(*B_curr,
150138
B_next[(d_S-1)-*NS_it] + G_k[(d_O-1)-*OS_it]);
151139

152140
//Update PS/PI iterators
@@ -162,27 +150,19 @@ log_bcjr::compute_bw_metrics(const std::vector<float> &G,
162150
B_next += d_S;
163151

164152
//Metrics normalization
165-
norm_B = max_star(&(*B_curr)+1, d_S);
153+
norm_B = _max_star(&(*B_curr)+1, d_S);
166154
std::transform(B_next, B_curr, B_next,
167155
std::bind2nd(std::minus<float>(), norm_B));
168156
}
169157
}
170158

171159
void
172-
log_bcjr::compute_app(const std::vector<float> &A, const std::vector<float> &B,
160+
log_bcjr_base::compute_app(const std::vector<float> &A, const std::vector<float> &B,
173161
const std::vector<float> &G, size_t K, std::vector<float> &out)
174162
{
175163
std::vector<float>::const_iterator A_it = A.begin();
176164
std::vector<float>::const_iterator B_it = B.begin() + d_S;
177165

178-
//std::vector<float>::const_iterator it1 = A.begin();
179-
//std::vector<float>::const_iterator it2 = B.begin();
180-
//while(it1 != A.end()) {
181-
// std::cout << "A " << *it1 << "\tB " << *it2 << std::endl;
182-
// ++it1; ++it2;
183-
//}
184-
//std::cout << std::endl;
185-
186166
out.reserve(d_S*d_I*K);
187167

188168
for(std::vector<float>::const_iterator G_k = G.begin() ;
@@ -199,13 +179,11 @@ log_bcjr::compute_app(const std::vector<float> &A, const std::vector<float> &B,
199179

200180
//Update backward iterator
201181
B_it += d_S;
202-
203-
//std::cout << "---------------------------------------------------" << std::endl;
204182
}
205183
}
206184

207185
void
208-
log_bcjr::log_bcjr_algorithm(const std::vector<float> &A0,
186+
log_bcjr_base::log_bcjr_algorithm(const std::vector<float> &A0,
209187
const std::vector<float> &BK, const std::vector<float> &in,
210188
std::vector<float> &out)
211189
{
@@ -221,3 +199,4 @@ log_bcjr::log_bcjr_algorithm(const std::vector<float> &A0,
221199
//Compute branch APP
222200
compute_app(A, B, in, K, out);
223201
}
202+

0 commit comments

Comments
 (0)