Add hungarian file
[trackerpp.git] / src / hungarian.cpp
1 #include <iostream>
2 #include <algorithm>
3 #include "hungarian.h"
4
5 using namespace std;
6 using namespace Eigen;
7
8 class Hungary;
9
10 int step_one(Hungary& state);
11 int step_two(Hungary& state);
12 int step_three(Hungary& state);
13 int step_four(Hungary& state);
14 int step_five(Hungary& state);
15 int step_six(Hungary& state);
16
17 class Hungary
18 {
19 public:
20     Hungary(const MatrixXi& cost){
21         this->cost = cost;
22         this->row_uncovered = VectorXi::Ones(cost.rows());
23         this->col_uncovered = VectorXi::Ones(cost.cols());
24         Z0_r = 0;
25         Z0_c = 0;
26         this->mark = MatrixXi::Zero(cost.rows(), cost.cols());
27     };
28     ~Hungary(){};
29
30     void print(){
31         cout <<"Cost:\n" << cost << "\nMark:\n" << mark << endl;
32         cout <<"Row uncovered :[" << row_uncovered.transpose() << "], col uncovered: [" << col_uncovered.transpose() << "]" << endl;
33     }
34
35 public:
36     void clearCovers(){
37         this->row_uncovered.setOnes();
38         this->col_uncovered.setOnes();
39     }
40     
41 public :
42     // The cost matrix, cols() should > rows(). If not, transpose it first.
43     MatrixXi cost;
44     // The mark matrix. the value of the element is 0 (None), 1 (means the starred), 2 (primed)
45     MatrixXi mark;
46     // the covered state of the rows (0: covered; 1: nocovered)
47     VectorXi row_uncovered;
48     // the covered state of the column (0: covered; 1: nocovered)
49     VectorXi col_uncovered;
50     // the position of Z0, used in step 5
51     int Z0_r;
52     int Z0_c;
53     int min_value;
54 };
55
56 int linear_sum_assignment(const MatrixXi& cost_matrix, VectorXi& row_ind, VectorXi& col_ind)
57 {
58     // The algorithm expects more columns than rows in the cost matrix.
59     MatrixXi correct_matrix = cost_matrix;
60     bool is_transposed = false;
61     if (cost_matrix.cols() < cost_matrix.rows()){
62         cout << "cols < rows, transpose." << endl;
63         correct_matrix = cost_matrix.transpose();
64         is_transposed = true;
65     }
66     Hungary state(correct_matrix);
67     cout << "Cost Matrix: \n" << correct_matrix << endl;;
68
69     bool done = false;
70     int next = 1;
71     int pre_step = 1;
72     while (!done){
73         pre_step = next;
74         switch(next)
75         {
76             case 1:
77                 next = step_one(state);
78                 break;
79             case 2:
80                 next = step_two(state);
81                 break;
82             case 3:
83                 next = step_three(state);
84                 break;
85             case 4:
86                 next = step_four(state);
87                 break;
88             case 5:
89                 next = step_five(state);
90                 break;
91             case 6:
92                 next = step_six(state);
93                 break;
94             case 7:
95                 done = true;
96                 break;
97         }
98         cout << "After step: " << pre_step << endl;
99         state.print();
100     }
101     MatrixXi mark = state.mark;
102     if (is_transposed){
103         mark = state.mark.transpose();
104     }
105     cout << "Done" << endl << mark << endl;
106     int sum = (mark.array() * cost_matrix.array()).sum();
107     cout << "Sum :" << sum << endl;
108     
109
110     // return the array of the positions
111     row_ind = VectorXi::Zero(mark.cols());
112     col_ind = VectorXi::Zero(mark.cols());
113     VectorXi::Index max_index;
114     int point = 0;
115     for (int i = 0; i < mark.rows(); i++){
116         if (mark.row(i).maxCoeff(&max_index)){
117             row_ind(point) = i;
118             col_ind(point) = max_index;
119             point++;
120         }
121     }
122     return sum;
123 }
124
125
126 // Step 1:
127 // For each row of the matrix, find the smallest element and subtract
128 // it from every element in its row. Go to Step 2.
129 int step_one(Hungary& state)
130 {
131     state.cost.colwise() -= state.cost.rowwise().minCoeff();
132     return 2;
133 }
134
135 // Step 2:
136 // Find a zero (Z) in the resulting matrix. If there is no starred zero in its row or column,
137 // star Z. Repeat for each elements in the matrix. Go to Step 3.
138 int step_two(Hungary& state)
139 {
140     for (int i = 0; i < state.cost.rows(); i++){
141         for (int j = 0; j < state.cost.cols(); j++)
142             if (state.row_uncovered(i) == 1 && state.col_uncovered(j) == 1 && state.cost(i, j) == 0){
143                 state.mark(i, j) = 1;
144                 state.col_uncovered(j) = 0;
145                 state.row_uncovered(i) = 0;
146             }
147     }
148     state.clearCovers();
149     return 3;
150 }
151
152 // Step 3:
153 // Cover each column containing a starred zero. If K columns are covered,
154 // the starred zeros describe a complete set of unique assignments. In this case,
155 // Go to DONE, otherwise, Go to Step 4.
156 int step_three(Hungary& state)
157 {
158     MatrixXi m1 = (state.mark.array() == 1).select(state.mark, 0);
159     state.col_uncovered = m1.colwise().any();
160     //for (int i = 0; i < state.col_uncovered.size(); i++)
161     //    if (m1.col(i).any())
162     //        state.col_uncovered(i) = 0;
163         //state.col_uncovered(i) = (state.col_uncovered(i) == 1) ? 0 : 1;
164     state.col_uncovered = (state.col_uncovered.array() == 1).select(0, VectorXi::Ones(state.col_uncovered.size()));
165     int next = 4;
166     if (m1.sum() == m1.rows())
167         next = 7;
168     return next;
169 }
170
171 // Step 4 :
172 // Find a noncovered zero and prime it. If there is no starred zero in the row
173 // constaining this primed zero, Go to Step 5. Otherwise, cover this row and uncover the column
174 // containing the starred zero. Continue in this manner until there are no uncovered zeros left.
175 // Save the smallest uncovered value and Go to Step 6.
176 //
177 //
178 // find_uncovered_zero 
179 //   - find zero or minimal value in the uncovered elements. 
180 //   Return 0 if zero found, or the minimal element.
181 int find_uncovered_zero(Hungary& state, int& row, int& col)
182 {
183     MatrixXi cover = state.row_uncovered * state.col_uncovered.transpose();
184     int min_value = 0;
185     for (int i = 0; i < cover.rows(); i++)
186         for (int j = 0; j < cover.cols(); j++){
187             if (cover(i, j) != 0){
188                 if (state.cost(i, j) == 0){
189                     row = i;
190                     col = j;
191                     return 0;
192                 } else {
193                     min_value = (min_value > 0) ? min(min_value, state.cost(i, j)) : state.cost(i, j);
194                 }
195             }
196         }
197     return min_value;
198 }
199
200 int step_four(Hungary& state)
201 {
202     while(true)
203     {
204         int rr = 0;
205         int cc = 0;
206         // find a nocovered zero
207         int min = find_uncovered_zero(state, rr, cc);
208         if(min == 0){
209             // prime it
210             state.mark(rr, cc) = 2;
211             // If no starred zero in its row
212             if ((state.mark.row(rr).array() == 1).any() == 0){
213                 state.Z0_r = rr;
214                 state.Z0_c = cc;
215                 return 5;
216             } else {
217                 // Otherwise, cover this row
218                 state.row_uncovered(rr) = 0;
219                 // uncover the column
220                 for (int j = 0; j < state.mark.cols(); j++){
221                     if (state.mark(rr, j) == 1){
222                         state.col_uncovered(j) = 1;
223                     }
224                 }
225             }
226         } else {
227             state.min_value = min;
228             return 6;
229         }
230     }
231 }
232
233 // Step 5:
234 //   Construct a seriee of alternating primed and starred zeros as follows,
235 //   Let Z0 represent the uncovered primed zero found in Step 4.
236 //   Let Z1 denote the starred zero in the column of Z0 (if any).
237 //   Let Z2 denote the primed zero in the row of Z1 (there will always be one).
238 //   Continue until the series terminates at a primed zero that has no starred
239 //   zero in its column. Unstar each starred zero of the series, star each primed 
240 //   zero of the series, erase all primes and uncover every line in the matrix. Return to Step 3
241 int step_five(Hungary& state)
242 {
243     MatrixXi path = MatrixXi::Zero(state.cost.rows() + state.cost.cols(), 2);
244     int count = 0;
245     path(count, 0) = state.Z0_r;
246     path(count, 1) = state.Z0_c;
247
248     MatrixXi m1 = (state.mark.array() == 1).select(state.mark, 0);
249     MatrixXi m2 = (state.mark.array() == 2).select(state.mark, 0);
250     MatrixXi::Index index;
251     while(true){
252         // Z1, find the starred zero in the column of Z0
253         int max = m1.col(path(count, 1)).maxCoeff(&index);
254         if (max != 1)
255             break;
256         else {
257             count++;
258             path(count, 0) = index;
259             path(count, 1) = path(count - 1, 1);
260         }
261
262         // Z2, find the primed zero in the row of Z1;
263         max = m2.row(path(count, 0)).maxCoeff(&index);
264         if (max != 2){
265             cout << "Error, should never reach here" << endl;
266         } else {
267             count++;
268             path(count, 0) = path(count - 1, 0);
269             path(count, 1) = index;
270         }
271     }
272
273     // terminates
274     for (int i = 0; i < count + 1; i++){
275         if (state.mark(path(i, 0), path(i, 1)) == 1)
276             state.mark(path(i, 0), path(i, 1)) = 0;
277         else
278             state.mark(path(i, 0), path(i, 1)) = 1;
279     }
280
281     state.clearCovers();
282
283     // erase all prime markings
284     state.mark = (state.mark.array() == 2).select(0, state.mark);
285     return 3;
286 }
287
288 // Step 6:
289 //    Add the value found in Step 4 to every element of each covered row, and substract it 
290 //    from every element of each uncovered column.
291 //    Return to Step 4 without altering any stars, primes, or covered lines.
292 //
293 int step_six(Hungary& state)
294 {
295     for (int i = 0; i < state.mark.rows(); i++)
296         for (int j = 0; j < state.mark.cols(); j++){
297             if (state.row_uncovered(i) == 0)
298                 state.cost(i, j) += state.min_value;
299             if (state.col_uncovered(j) == 1)
300                 state.cost(i, j) -= state.min_value;
301         }
302     return 4;
303 }
304
305 /*
306 int main()
307 {
308     Matrix3i C;
309 //    MatrixXi C2(4, 3);
310 //
311     C << 1, 2, 3,
312          2, 4, 2,
313          3, 6, 9;
314
315     Matrix3i M;
316     M << 0, 1, 2,
317          0, 0, 0,
318          1, 0, 0;
319
320 //    C2 << 4, 1, 3,
321 //         2, 4, 2,
322 //         3, 6, 9,
323 //         2, 6, 3;
324
325     Vector3i vv;
326     //Matrix3i m1 = (M.array() == 1).select(0, MatrixXi::Ones(M.cols(), M.rows()));
327     //cout << m1.colwise().sum().transpose() << endl;
328     //vv = vv.rowwise() 
329
330     VectorXi row_ind, col_ind;
331
332     //MatrixXi RR = MatrixXi::Random(10, 10);
333     linear_sum_assignment(C, row_ind, col_ind);
334     cout << "row: [" << row_ind.transpose() << "], col: [" << col_ind.transpose() << "]" << endl;
335 }
336 */