From b810cb1e22eb5cf583308e698563d7fba79f8dac Mon Sep 17 00:00:00 2001 From: Dsplib Date: Fri, 16 Oct 2020 23:30:55 +0300 Subject: [PATCH] added pseudoinverion of matrix Changes to be committed: modified: dspl/src/matrix.c new file: examples/src/matrix_pinv_test.c modified: include/dspl.c modified: include/dspl.h --- dspl/src/matrix.c | 178 ++++++++++++++++++++++++++------ examples/src/matrix_pinv_test.c | 46 +++++++++ include/dspl.c | 2 + include/dspl.h | 7 ++ 4 files changed, 202 insertions(+), 31 deletions(-) create mode 100644 examples/src/matrix_pinv_test.c diff --git a/dspl/src/matrix.c b/dspl/src/matrix.c index 83b2d78..5f0e1dd 100644 --- a/dspl/src/matrix.c +++ b/dspl/src/matrix.c @@ -20,6 +20,7 @@ #include #include #include +#include #include "dspl.h" #include "dspl_internal.h" #include "blas.h" @@ -235,50 +236,116 @@ int DSPL_API matrix_mul(double* a, int na, int ma, #ifdef DOXYGEN_RUSSIAN #endif -int DSPL_API matrix_svd(double* a, int n, int m, - double* u, double* s, double* vt, int* info) +int DSPL_API matrix_pinv(double* a, int n, int m, double* tol, + double* inv, int* info) { - char jobz = 'A'; - double* work = NULL; - int* iwork = NULL; - int lwork, mn, mx, err; - int pi; + int err, mn, i, j; + double eps; + + + double* u = NULL; + double* vt = NULL; + double* v = NULL; + double* ut = NULL; + double* s = NULL; - if(!a || !u || !s || !vt) - return ERROR_PTR; - if(n < 1 || m < 1) - return ERROR_SIZE; - - if(n > m) + mn = (m > n) ? n : m; + + u = (double*) malloc(n*n*sizeof(double)); + if(!u) { - mn = m; - mx = n; + err = ERROR_MALLOC; + goto exit_label; } + + vt = (double*) malloc(m*m*sizeof(double)); + if(!vt) + { + err = ERROR_MALLOC; + goto exit_label; + } + + s = (double*) malloc(mn*sizeof(double)); + if(!s) + { + err = ERROR_MALLOC; + goto exit_label; + } + + err = matrix_svd(a, n, m, u, s, vt, info); + if(err != RES_OK) + goto exit_label; + + + + if(tol) + eps = *tol; else - { - mn = n; - mx = m; + { + double smax; + double mx = (n > m) ? (double)n : (double)m; + err = minmax(s, mn, NULL, &smax); + eps = DBL_EPSILON * mx * smax; } - err = RES_OK; + for(i = 0; i < mn; i++) + if(s[i] > eps) + s[i] = 1.0 / s[i]; + else + s[i] = 0.0; + + v = (double*) malloc(m*m*sizeof(double)); + if(!v) + { + err = ERROR_MALLOC; + goto exit_label; + } + err = matrix_transpose(vt, m, m, v); + if(err != RES_OK) + goto exit_label; - lwork = 4 * mn * mn + 6 * mn + mx; - work = (double*) malloc(lwork*sizeof(double)); - iwork = (int*) malloc(8*mn*sizeof(int)); - dgesdd_(&jobz, &n, &m, a, &n, s, u, &n, vt, &m, work, &lwork, iwork, &pi); + ut = (double*) malloc(n*n*sizeof(double)); + if(!ut) + { + err = ERROR_MALLOC; + goto exit_label; + } + err = matrix_transpose(u, n, n, ut); + if(err != RES_OK) + goto exit_label; + + for(i = 0; i < mn; i++) + for(j = 0; j < m; j++) + v[j + i*m] *= s[i]; + + if(mn < m) + memset(v+ mn*m, 0, (m-mn)*sizeof(double)); + + err = matrix_mul(v, m, n, ut, n, n, inv); + +exit_label: + if(u) + free(u); + if(vt) + free(vt); + if(s) + free(s); + if(v) + free(v); + if(ut) + free(ut); - if(info) - *info = pi; - if(pi) - err = ERROR_LAPACK; - if(work) - free(work); - if(iwork) - free(iwork); return err; } + + + + + + + #ifdef DOXYGEN_ENGLISH #endif @@ -357,6 +424,55 @@ int DSPL_API matrix_print_cmplx(complex_t* a, int n, int m, +#ifdef DOXYGEN_ENGLISH + +#endif +#ifdef DOXYGEN_RUSSIAN + +#endif +int DSPL_API matrix_svd(double* a, int n, int m, + double* u, double* s, double* vt, int* info) +{ + char jobz = 'A'; + double* work = NULL; + int* iwork = NULL; + int lwork, mn, mx, err; + int pi; + + if(!a || !u || !s || !vt) + return ERROR_PTR; + if(n < 1 || m < 1) + return ERROR_SIZE; + + if(n > m) + { + mn = m; + mx = n; + } + else + { + mn = n; + mx = m; + } + + err = RES_OK; + + lwork = 4 * mn * mn + 6 * mn + mx; + work = (double*) malloc(lwork*sizeof(double)); + iwork = (int*) malloc(8*mn*sizeof(int)); + dgesdd_(&jobz, &n, &m, a, &n, s, u, &n, vt, &m, work, &lwork, iwork, &pi); + + if(info) + *info = pi; + + if(pi) + err = ERROR_LAPACK; + if(work) + free(work); + if(iwork) + free(iwork); + return err; +} #ifdef DOXYGEN_ENGLISH diff --git a/examples/src/matrix_pinv_test.c b/examples/src/matrix_pinv_test.c new file mode 100644 index 0000000..956d17d --- /dev/null +++ b/examples/src/matrix_pinv_test.c @@ -0,0 +1,46 @@ +#include +#include +#include +#include "dspl.h" + +#define N 2 +#define M 6 + + + +int main(int argc, char* argv[]) +{ + void* handle; /* DSPL handle */ + handle = dspl_load(); /* Load DSPL function */ + + /* Matrix + A = [ 1 2 2; + 2 4 5; + 2 0 1; + 2 -1 0;]; + in array a by columns + */ + double a[N*M] = { 1, 2, 2, 5, 2, 4, 4, -1, 0, 0, 3, -2}; + double inv[M*N]; /* left orthogonal matrix U */ + + + int err, info, mn; + + /* print input matrix */ + matrix_print(a, N, M, "A", "%8.2f"); + + /* SVD decomposition A = U*S*V^T */ + /*-----------------------------------------------------*/ + err = matrix_pinv(a, N, M, NULL, inv, &info); + if(err != RES_OK) + printf("err = %.8x info = %d\n", err, info); + + /* Print SVD decomposition */ + matrix_print(inv, M, N, "inv(A)", "%8.8f"); + + + + + dspl_free(handle); /* free dspl handle */ + return 0; +} diff --git a/include/dspl.c b/include/dspl.c index b7a818e..7b050e7 100644 --- a/include/dspl.c +++ b/include/dspl.c @@ -137,6 +137,7 @@ p_matrix_eig_cmplx matrix_eig_cmplx ; p_matrix_eye matrix_eye ; p_matrix_eye_cmplx matrix_eye_cmplx ; p_matrix_mul matrix_mul ; +p_matrix_pinv matrix_pinv ; p_matrix_print matrix_print ; p_matrix_print_cmplx matrix_print_cmplx ; p_matrix_svd matrix_svd ; @@ -350,6 +351,7 @@ void* dspl_load() LOAD_FUNC(matrix_eye); LOAD_FUNC(matrix_eye_cmplx); LOAD_FUNC(matrix_mul); + LOAD_FUNC(matrix_pinv); LOAD_FUNC(matrix_print); LOAD_FUNC(matrix_print_cmplx); LOAD_FUNC(matrix_svd); diff --git a/include/dspl.h b/include/dspl.h index 63b8489..298262b 100644 --- a/include/dspl.h +++ b/include/dspl.h @@ -1216,6 +1216,13 @@ DECLARE_FUNC(int, matrix_mul, double* a COMMA int mb COMMA double* c); /*----------------------------------------------------------------------------*/ +DECLARE_FUNC(int, matrix_pinv, double* a + COMMA int n + COMMA int m + COMMA double* tol + COMMA double* inv + COMMA int* info); +/*----------------------------------------------------------------------------*/ DECLARE_FUNC(int, matrix_print, double* a COMMA int n COMMA int m