Skip to content

Commit

Permalink
Merge pull request #66 from jacobmerson/feature-scalarMatMult
Browse files Browse the repository at this point in the history
Create ScalarMatMultiply
  • Loading branch information
jacobmerson authored Jan 18, 2019
2 parents 8a0cbd8 + 06d5642 commit 0868fc3
Show file tree
Hide file tree
Showing 14 changed files with 520 additions and 50 deletions.
24 changes: 24 additions & 0 deletions core/lasCSRCore.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,34 @@ namespace las
}
}
};
class IdentityCSR : public CSRBuilder
{
protected:
int ndofs;
public:
IdentityCSR(int ndofs)
: CSRBuilder(ndofs,ndofs)
, ndofs(ndofs)
{
}
void run()
{
for(int i=0; i<ndofs;++i)
{
add(i,i);
}
}
};
Sparsity * createCSR(apf::Numbering * num, int ndofs)
{
CSRFromNumbering bldr(num,ndofs);
bldr.run();
return bldr.finalize();
}
Sparsity * createIdentityCSR(int ndofs)
{
IdentityCSR bldr(ndofs);
bldr.run();
return bldr.finalize();
}
}
1 change: 1 addition & 0 deletions core/lasCSRCore.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@ namespace las
* values of the numbering, no mesh partitioning is considered.
*/
Sparsity * createCSR(apf::Numbering * num, int ndofs);
Sparsity * createIdentityCSR(int ndofs);
}
#endif
58 changes: 58 additions & 0 deletions src/las.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ namespace las
{
static_cast<T*>(this)->_set(v,cnt,rws,vls);
}
void set(Vec * v, scalar * vls)
{
static_cast<T*>(this)->_set(v,vls);
}
void set(Mat * m, int cntr, int * rws, int cntc, int * cls, scalar * vls)
{
static_cast<T*>(this)->_set(m,cntr,rws,cntc,cls,vls);
Expand Down Expand Up @@ -198,6 +202,8 @@ namespace las
virtual void solve(Mat * k, Vec * u, Vec * f) = 0;
virtual ~Solve() {}
};
template <typename T>
Solve * getSolve(int id);
/**
* Interface for Matrix-Vector multiplication
* @todo Retrieve backend-specific solvers using
Expand All @@ -210,6 +216,8 @@ namespace las
virtual void exec(Mat * x, Vec * a, Vec * b) = 0;
virtual ~MatVecMult() {}
};
template <class T>
MatVecMult * getMatVecMult();
/**
* Interface for Matrix-Matrix multiplication
* @todo Retrieve backend-specific solvers using
Expand All @@ -222,6 +230,56 @@ namespace las
virtual void exec(Mat * a, Mat * b, Mat ** c) = 0;
virtual ~MatMatMult() {}
};
template <class T>
MatMatMult * getMatMatMult();
/**
* Interface for Scalar-Matrix multiplication
* If c is NULL performs an in place multiplication
* @todo Retrieve backend-specific solvers using
* backend id classes to do template
* specialization, as above.
*/
class ScalarMatMult
{
public:
virtual void exec(scalar s, Mat * a, Mat ** c) = 0;
virtual ~ScalarMatMult() {}
};
template <class T>
ScalarMatMult * getScalarMatMult();
/*
* interface for C = alpha_1*A+alpha_2*B
*/
class MatMatAdd
{
public:
virtual void exec(scalar s1, Mat * a, scalar s2, Mat * b, Mat ** c) = 0;
virtual ~MatMatAdd() {}
};
template <class T>
MatMatAdd * getMatMatAdd();
/*
* interface for vector vector addition
*/
class VecVecAdd
{
public:
virtual void exec(scalar s1, Vec * v1, scalar s2, Vec * v2, Vec *& v3) = 0;
virtual ~VecVecAdd() {}
};
template <class T>
VecVecAdd * getVecVecAdd();
/*
* Interface for scalar-vector multiplication
*/
class ScalarVecMult
{
public:
virtual void exec(scalar s, Vec * x, Vec ** y) = 0;
virtual ~ScalarVecMult() {}
};
template <class T>
ScalarVecMult * getScalarVecMult();
/*
* Finalize routines which must be called on a matrix when switching from
* add mode to set mode
Expand Down
30 changes: 20 additions & 10 deletions src/lasCSR.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#include "lasSys.h"
#include <vector>
#include <iostream>
#include <algorithm>
#include <cassert>
namespace las
{
class Sparsity;
Expand Down Expand Up @@ -35,21 +37,29 @@ namespace las
* they are converted to use 1-indexing (in a debug build this generates a warning).
*/
CSR(int r, int c, int nnz, int * rs, int * cs);
CSR(int r, int c, int nnz, std::vector<int> const & rs, std::vector<int> const & cs);
int getNumRows() const { return nr; }
int getNumCols() const { return nc; }
int getNumNonzero() const { return nnz; }
// return the index into the values array
// if the location is not stored then return -1
// note rw and cl start at zero
int operator()(int rw, int cl) const
{
int result = -1;
int fst = rws[rw] - 1;
while((fst < rws[rw+1] - 2) && (cls[fst] - 1 < cl))
++fst;
// the column is correct at offset and the row isn't empty
if(cls[fst] - 1 == cl && rws[rw] - 1 <= rws[rw+1] - 2)
result = fst;
else
result = -1;
return result;
assert(rw < nr && rw>=0);
assert(cl < nc && cl>=0);
// the row is empty
if(rws[rw+1]-rws[rw] == 0)
return -1;
// this approach finds the correct index in log(n) time where
// n is the number of elements on the row
typedef std::vector<int>::const_iterator vit_t;
vit_t bgn = cls.begin()+rws[rw]-1;
vit_t end = cls.begin()+rws[rw+1]-1;
std::pair<vit_t, vit_t> bounds = equal_range(bgn, end, cl+1);
if(bounds.first == bounds.second)
return -1;
return bounds.first-cls.begin();
}
int * getRows() { return &rws[0]; }
int * getCols() { return &cls[0]; }
Expand Down
10 changes: 6 additions & 4 deletions src/lasCSRBuilder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@ namespace las
auto it = std::unique(coords.begin(), coords.end(), unique_comps<int,int>);
coords.resize(std::distance(coords.begin(), it));
assert(nnz >= coords.size());
if(coords.size() < nnz) {
std::cerr<<"Warning: ignored "<<nnz-coords.size() << " duplicate entries\n";
}
// this is useful for debugging, but doesn't make sense when the builder is needed
// for things like the add function...
//if(coords.size() < nnz) {
// std::cerr<<"Warning: ignored "<<nnz-coords.size() << " duplicate entries\n";
//}
nnz = coords.size();
cls.resize(nnz);
for(std::size_t i=0; i<coords.size(); ++i) {
Expand Down Expand Up @@ -61,7 +63,7 @@ namespace las
result = false;
if(rw>=rw_cnt || cl>=cl_cnt)
{
std::cerr<<"Warning: inserting a row outside the matrix bounds. Skipping it.\n";
std::cerr<<"Warning: inserting a row/column outside the matrix bounds ("<<rw<<","<<cl<<"). Skipping it.\n";
result=false;
}
return result;
Expand Down
8 changes: 8 additions & 0 deletions src/lasCSR_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ namespace las
cls[cl]++;
}
}
LAS_INLINE CSR::CSR(int r, int c, int nnz, std::vector<int> const & rs, std::vector<int> const & cs)
: nr(r)
, nc(c)
, nnz(nnz)
, rws(rs)
, cls(cs)
{
}
LAS_INLINE Sparsity * csrFromArray(int rws, int cls, int nnz, int * row_arr, int * col_arr)
{
return reinterpret_cast<Sparsity*>(new CSR(rws,cls,nnz,row_arr,col_arr));
Expand Down
1 change: 1 addition & 0 deletions src/lasPETSc.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ namespace las
Solve * createPetscQNSolve(void * a);
MatVecMult * createPetscMatVecMult();
MatMatMult * createPetscMatMatMult();
ScalarMatMult * createPetscScalarMatMult();
template <>
void finalizeMatrix<petsc>(Mat * mat);
template <>
Expand Down
22 changes: 22 additions & 0 deletions src/lasPETSc_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,28 @@ namespace las
{
return new PetscMatMatMult;
}
class PetscScalarMatMult : public ScalarMatMult
{
public:
virtual void exec(scalar s, Mat * a, Mat ** c)
{
if (c == nullptr)
{
PetscErrorCode ierr = ::MatScale(*getPetscMat(a), s);
}
else
{
std::cerr << "Out of place matrix scalar multiplication not "
"implemented in petsc"
<< std::endl;
std::abort();
}
}
};
LAS_INLINE ScalarMatMult * createPetscScalarMatMult()
{
return new PetscScalarMatMult;
}
template <>
LAS_INLINE void finalizeMatrix<petsc>(Mat * mat)
{
Expand Down
Loading

0 comments on commit 0868fc3

Please sign in to comment.