added pseudoinverion of matrix

Changes to be committed:
	modified:   dspl/src/matrix.c
	new file:   examples/src/matrix_pinv_test.c
	modified:   include/dspl.c
	modified:   include/dspl.h
pull/6/merge
Dsplib 2020-10-16 23:30:55 +03:00
rodzic f848c84fac
commit b810cb1e22
4 zmienionych plików z 202 dodań i 31 usunięć

Wyświetl plik

@ -20,6 +20,7 @@
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <float.h>
#include "dspl.h"
#include "dspl_internal.h"
#include "blas.h"
@ -235,50 +236,116 @@ int DSPL_API matrix_mul(double* a, int na, int ma,
#ifdef DOXYGEN_RUSSIAN
#endif
int DSPL_API matrix_svd(double* a, int n, int m,
double* u, double* s, double* vt, int* info)
int DSPL_API matrix_pinv(double* a, int n, int m, double* tol,
double* inv, int* info)
{
char jobz = 'A';
double* work = NULL;
int* iwork = NULL;
int lwork, mn, mx, err;
int pi;
int err, mn, i, j;
double eps;
double* u = NULL;
double* vt = NULL;
double* v = NULL;
double* ut = NULL;
double* s = NULL;
if(!a || !u || !s || !vt)
return ERROR_PTR;
if(n < 1 || m < 1)
return ERROR_SIZE;
if(n > m)
mn = (m > n) ? n : m;
u = (double*) malloc(n*n*sizeof(double));
if(!u)
{
mn = m;
mx = n;
err = ERROR_MALLOC;
goto exit_label;
}
vt = (double*) malloc(m*m*sizeof(double));
if(!vt)
{
err = ERROR_MALLOC;
goto exit_label;
}
s = (double*) malloc(mn*sizeof(double));
if(!s)
{
err = ERROR_MALLOC;
goto exit_label;
}
err = matrix_svd(a, n, m, u, s, vt, info);
if(err != RES_OK)
goto exit_label;
if(tol)
eps = *tol;
else
{
mn = n;
mx = m;
{
double smax;
double mx = (n > m) ? (double)n : (double)m;
err = minmax(s, mn, NULL, &smax);
eps = DBL_EPSILON * mx * smax;
}
err = RES_OK;
for(i = 0; i < mn; i++)
if(s[i] > eps)
s[i] = 1.0 / s[i];
else
s[i] = 0.0;
v = (double*) malloc(m*m*sizeof(double));
if(!v)
{
err = ERROR_MALLOC;
goto exit_label;
}
err = matrix_transpose(vt, m, m, v);
if(err != RES_OK)
goto exit_label;
lwork = 4 * mn * mn + 6 * mn + mx;
work = (double*) malloc(lwork*sizeof(double));
iwork = (int*) malloc(8*mn*sizeof(int));
dgesdd_(&jobz, &n, &m, a, &n, s, u, &n, vt, &m, work, &lwork, iwork, &pi);
ut = (double*) malloc(n*n*sizeof(double));
if(!ut)
{
err = ERROR_MALLOC;
goto exit_label;
}
err = matrix_transpose(u, n, n, ut);
if(err != RES_OK)
goto exit_label;
for(i = 0; i < mn; i++)
for(j = 0; j < m; j++)
v[j + i*m] *= s[i];
if(mn < m)
memset(v+ mn*m, 0, (m-mn)*sizeof(double));
err = matrix_mul(v, m, n, ut, n, n, inv);
exit_label:
if(u)
free(u);
if(vt)
free(vt);
if(s)
free(s);
if(v)
free(v);
if(ut)
free(ut);
if(info)
*info = pi;
if(pi)
err = ERROR_LAPACK;
if(work)
free(work);
if(iwork)
free(iwork);
return err;
}
#ifdef DOXYGEN_ENGLISH
#endif
@ -357,6 +424,55 @@ int DSPL_API matrix_print_cmplx(complex_t* a, int n, int m,
#ifdef DOXYGEN_ENGLISH
#endif
#ifdef DOXYGEN_RUSSIAN
#endif
int DSPL_API matrix_svd(double* a, int n, int m,
double* u, double* s, double* vt, int* info)
{
char jobz = 'A';
double* work = NULL;
int* iwork = NULL;
int lwork, mn, mx, err;
int pi;
if(!a || !u || !s || !vt)
return ERROR_PTR;
if(n < 1 || m < 1)
return ERROR_SIZE;
if(n > m)
{
mn = m;
mx = n;
}
else
{
mn = n;
mx = m;
}
err = RES_OK;
lwork = 4 * mn * mn + 6 * mn + mx;
work = (double*) malloc(lwork*sizeof(double));
iwork = (int*) malloc(8*mn*sizeof(int));
dgesdd_(&jobz, &n, &m, a, &n, s, u, &n, vt, &m, work, &lwork, iwork, &pi);
if(info)
*info = pi;
if(pi)
err = ERROR_LAPACK;
if(work)
free(work);
if(iwork)
free(iwork);
return err;
}
#ifdef DOXYGEN_ENGLISH

Wyświetl plik

@ -0,0 +1,46 @@
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "dspl.h"
#define N 2
#define M 6
int main(int argc, char* argv[])
{
void* handle; /* DSPL handle */
handle = dspl_load(); /* Load DSPL function */
/* Matrix
A = [ 1 2 2;
2 4 5;
2 0 1;
2 -1 0;];
in array a by columns
*/
double a[N*M] = { 1, 2, 2, 5, 2, 4, 4, -1, 0, 0, 3, -2};
double inv[M*N]; /* left orthogonal matrix U */
int err, info, mn;
/* print input matrix */
matrix_print(a, N, M, "A", "%8.2f");
/* SVD decomposition A = U*S*V^T */
/*-----------------------------------------------------*/
err = matrix_pinv(a, N, M, NULL, inv, &info);
if(err != RES_OK)
printf("err = %.8x info = %d\n", err, info);
/* Print SVD decomposition */
matrix_print(inv, M, N, "inv(A)", "%8.8f");
dspl_free(handle); /* free dspl handle */
return 0;
}

Wyświetl plik

@ -137,6 +137,7 @@ p_matrix_eig_cmplx matrix_eig_cmplx ;
p_matrix_eye matrix_eye ;
p_matrix_eye_cmplx matrix_eye_cmplx ;
p_matrix_mul matrix_mul ;
p_matrix_pinv matrix_pinv ;
p_matrix_print matrix_print ;
p_matrix_print_cmplx matrix_print_cmplx ;
p_matrix_svd matrix_svd ;
@ -350,6 +351,7 @@ void* dspl_load()
LOAD_FUNC(matrix_eye);
LOAD_FUNC(matrix_eye_cmplx);
LOAD_FUNC(matrix_mul);
LOAD_FUNC(matrix_pinv);
LOAD_FUNC(matrix_print);
LOAD_FUNC(matrix_print_cmplx);
LOAD_FUNC(matrix_svd);

Wyświetl plik

@ -1216,6 +1216,13 @@ DECLARE_FUNC(int, matrix_mul, double* a
COMMA int mb
COMMA double* c);
/*----------------------------------------------------------------------------*/
DECLARE_FUNC(int, matrix_pinv, double* a
COMMA int n
COMMA int m
COMMA double* tol
COMMA double* inv
COMMA int* info);
/*----------------------------------------------------------------------------*/
DECLARE_FUNC(int, matrix_print, double* a
COMMA int n
COMMA int m