Actual source code: mattransposematmult.c
/*
  Defines the matrix-matrix product routines for
          C = A^T * B  and  C = A * B^T
  where A is a SeqAIJ matrix and B is a SeqDense matrix.
*/
8: #include <../src/mat/impls/aij/seq/aij.h>
9: #include <../src/mat/impls/dense/seq/dense.h>
11: PetscErrorCode MatDestroy_SeqDense_MatTransMatMult(void *data)
12: {
13: Mat_MatTransMatMult *atb = (Mat_MatTransMatMult *)data;
15: MatDestroy(&atb->mA);
16: VecDestroy(&atb->bt);
17: VecDestroy(&atb->ct);
18: PetscFree(atb);
19: return 0;
20: }
22: static PetscErrorCode MatTMatTMultNumeric_SeqAIJ_SeqDense(Mat,Mat,Mat);
24: PETSC_INTERN PetscErrorCode MatTMatTMultSymbolic_SeqAIJ_SeqDense(Mat A,Mat B,PetscReal fill,Mat C)
25: {
26: Mat_MatTransMatMult *atb;
27: PetscBool cisdense;
28: PetscInt dofm;
30: MatCheckProduct(C,4);
34: /* create output dense matrix C */
35: if (C->product->type == MATPRODUCT_AtB) {
36: MatSetSizes(C,A->cmap->n,B->cmap->N,A->cmap->n,B->cmap->N);
37: dofm = B->cmap->n;
38: } else {
39: MatSetSizes(C,A->rmap->n,B->rmap->N,A->rmap->n,B->rmap->N);
40: dofm = B->rmap->n;
41: }
42: PetscObjectTypeCompareAny((PetscObject)C,&cisdense,MATSEQDENSE,MATSEQDENSECUDA,"");
43: if (!cisdense) {
44: MatSetType(C,((PetscObject)B)->type_name);
45: }
46: MatSetUp(C);
48: /* create additional data structure for the product */
49: PetscNew(&atb);
50: MatCreateMAIJ(A,dofm,&atb->mA);
51: MatCreateVecs(atb->mA,&atb->ct,&atb->bt);
52: C->product->data = atb;
53: C->product->destroy = MatDestroy_SeqDense_MatTransMatMult;
55: if (C->product->type == MATPRODUCT_AtB) {
56: C->ops->transposematmultnumeric = MatTMatTMultNumeric_SeqAIJ_SeqDense;
57: } else {
58: C->ops->mattransposemultnumeric = MatTMatTMultNumeric_SeqAIJ_SeqDense;
59: }
60: return 0;
61: }
63: PetscErrorCode MatTMatTMultNumeric_SeqAIJ_SeqDense(Mat A,Mat B,Mat C)
64: {
65: PetscInt i,j,m=A->rmap->n,n=A->cmap->n,blda,clda;
66: PetscInt mdof = C->cmap->N;
67: const PetscScalar *Barray;
68: PetscScalar *Carray;
69: Mat_MatTransMatMult *atb;
70: Vec bt,ct;
72: MatCheckProduct(C,3);
74: atb = (Mat_MatTransMatMult *)C->product->data;
76: bt = atb->bt;
77: ct = atb->ct;
79: MatDenseGetArrayRead(B,&Barray);
80: MatDenseGetLDA(B,&blda);
81: MatDenseGetArrayWrite(C,&Carray);
82: MatDenseGetLDA(C,&clda);
83: if (C->product->type == MATPRODUCT_AtB) { /* transpose local array of B, then copy it to vector bt */
84: const PetscScalar *ctarray;
85: PetscScalar *btarray;
87: VecGetArrayWrite(bt,&btarray);
88: for (j=0; j<mdof; j++) {
89: for (i=0; i<m; i++) btarray[i*mdof + j] = Barray[j*blda + i];
90: }
91: VecRestoreArrayWrite(bt,&btarray);
93: /* compute ct = mA^T * cb */
94: MatMultTranspose(atb->mA,bt,ct);
96: /* transpose local array of ct to matrix C */
97: VecGetArrayRead(ct,&ctarray);
98: for (j=0; j<mdof; j++) {
99: for (i=0; i<n; i++) Carray[j*clda + i] = ctarray[i*mdof + j];
100: }
101: VecRestoreArrayRead(ct,&ctarray);
102: } else {
103: const PetscScalar *btarray;
104: PetscScalar *ctarray;
106: if (blda == B->rmap->n) {
107: VecPlaceArray(ct,Barray);
108: } else {
109: PetscInt bn = B->cmap->n;
110: PetscInt bm = B->rmap->n;
112: VecGetArrayWrite(ct,&ctarray);
113: for (j=0; j<bn; j++) {
114: for (i=0; i<bm; i++) ctarray[j*bm + i] = Barray[j*blda + i];
115: }
116: VecRestoreArrayWrite(ct,&ctarray);
117: }
119: MatMult(atb->mA,ct,bt);
120: if (blda == B->rmap->n) {
121: VecResetArray(ct);
122: }
123: VecGetArrayRead(bt,&btarray);
124: for (j=0; j<mdof; j++) {
125: for (i=0; i<m; i++) Carray[j*clda + i] = btarray[i*mdof + j];
126: }
127: VecRestoreArrayRead(bt,&btarray);
128: }
129: MatDenseRestoreArrayRead(B,&Barray);
130: MatDenseRestoreArray(C,&Carray);
131: return 0;
132: }