Snowman  0.1.0
matrix-wrapper.h
1 #pragma once
2 #include <cstdint>
3 #include <matrix-types.h>
4 #include <ostream>
5 #include <string>
6 #include <vector>
7 
8 namespace snowboy {
// Forward declarations. In the declarations of this header these types are
// only used by reference or as return types, so their full definitions are
// not needed here.
class VectorBase;  // defined in vector-wrapper.h
struct SubMatrix;  // defined further down in this header
11  struct MatrixBase {
12  size_t m_rows{0};
13  size_t m_cols{0};
14  size_t m_stride{0};
15  float* m_data{nullptr};
16 
17  size_t rows() const noexcept { return m_rows; }
18  size_t cols() const noexcept { return m_cols; }
19  size_t stride() const noexcept { return m_stride; }
20  float* data() const noexcept { return m_data; }
21  float* data(size_t row) const noexcept { return m_data + (row * stride()); }
22  float& operator()(size_t row, size_t col) const noexcept { return m_data[row * m_stride + col]; }
23  bool empty() const noexcept { return rows() == 0 || cols() == 0; }
24 
25  void AddMat(float alpha, const MatrixBase& A, MatrixTransposeType transA);
26  void AddMatMat(float, const MatrixBase&, MatrixTransposeType, const MatrixBase&, MatrixTransposeType, float);
27  void AddVecToRows(float, const VectorBase&);
28  void AddVecVec(float, const VectorBase&, const VectorBase&);
29  void ApplyFloor(float);
30  SubMatrix ColRange(size_t, size_t) const;
31  void CopyColFromVec(const VectorBase&, size_t);
32  void CopyCols(const MatrixBase&, const std::vector<ssize_t>&);
33  void CopyColsFromVec(const VectorBase&);
34  void CopyDiagFromVec(const VectorBase&);
35  void CopyFromMat(const MatrixBase&, MatrixTransposeType transposeType);
36  void CopyRowFromVec(const VectorBase&, size_t);
37  void CopyRows(const MatrixBase&, const std::vector<ssize_t>&);
38  void CopyRowsFromVec(const VectorBase&);
39  bool IsDiagonal(float) const;
40  bool IsSymmetric(float) const;
41  bool IsUnit(float) const;
42  // Implementation does a Max and checks that against cutoff
43  // We can probably cancel early if the current value is above cutoff
44  bool IsZero(float cutoff = 0.00001) const;
45  void MulColsVec(const VectorBase&);
46  void MulRowsVec(const VectorBase&);
47  SubMatrix Range(size_t, size_t, size_t, size_t) const;
48  void Read(bool, bool, std::istream*);
49  void Read(bool, std::istream*); // Read(p1, false, p2);
50  SubMatrix RowRange(size_t, size_t) const;
51  void Scale(float factor);
52  void Set(float value);
53  void SetRandomGaussian();
54  void SetRandomUniform();
55  void SetUnit();
56  void Transpose();
57  void Write(bool, std::ostream*) const;
58  bool HasNan() const;
59  bool HasInfinity() const;
60  };
61  struct Matrix : MatrixBase {
62  Matrix() {}
63  Matrix(const Matrix& other) {
64  Resize(other.m_rows, other.m_cols, MatrixResizeType::kUndefined);
65  CopyFromMat(other, MatrixTransposeType::kNoTrans);
66  }
67  Matrix(const MatrixBase& other) {
68  Resize(other.m_rows, other.m_cols, MatrixResizeType::kUndefined);
69  CopyFromMat(other, MatrixTransposeType::kNoTrans);
70  }
71  Matrix(Matrix&& other) {
72  m_rows = other.m_rows;
73  m_cols = other.m_cols;
74  m_stride = other.m_stride;
75  m_data = other.m_data;
76  other.m_rows = 0;
77  other.m_data = nullptr;
78  other.m_stride = 0;
79  other.m_cols = 0;
80  }
81  void Resize(size_t rows, size_t cols, MatrixResizeType resize = MatrixResizeType::kSetZero);
82  void AllocateMatrixMemory(size_t rows, size_t cols);
83  void ReleaseMatrixMemory(); // NOTE: Called destroy in kaldi
84  ~Matrix() { ReleaseMatrixMemory(); }
85 
86  Matrix& operator=(const Matrix& other);
87  Matrix& operator=(const MatrixBase& other);
88  Matrix& operator=(Matrix&& other) {
89  Swap(&other);
90  return *this;
91  }
92 
93  void RemoveRow(size_t row);
94  void Read(bool, bool, std::istream*);
95  void Read(bool, std::istream*);
96  void Swap(Matrix* other);
97  void Transpose();
98 
99  static void PrintAllocStats(std::ostream&);
100  static void ResetAllocStats();
101  };
103  SubMatrix(const MatrixBase& parent, size_t rowoffset, size_t rows, size_t coloffset, size_t cols);
104  };
105 
106  std::ostream& operator<<(std::ostream&, const MatrixBase&);
107 } // namespace snowboy
Definition: vector-wrapper.h:11
Definition: matrix-wrapper.h:11
Definition: matrix-wrapper.h:61
Definition: matrix-wrapper.h:102