
Storing a numpy sparse matrix in HDF5 (PyTables)


I am having trouble storing a numpy csr_matrix with PyTables. I get this error:

TypeError: objects of type ``csr_matrix`` are not supported in this context, sorry; supported objects are: NumPy array, record or scalar; homogeneous list or tuple, integer, float, complex or string

My code:

f = tables.openFile(path,'w')

atom = tables.Atom.from_dtype(self.count_vector.dtype)
ds = f.createCArray(f.root, 'count', atom, self.count_vector.shape)
ds[:] = self.count_vector
f.close()

Any ideas?

Thanks

3 Answers

  • 34

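    A CSR matrix can be fully reconstructed from its data, indices and indptr attributes. These are just regular numpy arrays, so there should be no problem storing them as 3 separate arrays in PyTables and then passing them back to the csr_matrix constructor. See the scipy docs.

    Edit: Pietro's answer points out that the shape member should also be stored.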
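
    As a minimal sketch of the idea above (no PyTables involved, and the small example matrix is made up for illustration), the four pieces can be pulled out of a CSR matrix and passed back to the constructor like this:

```python
import numpy as np
from scipy import sparse

# Hypothetical example matrix, only used to illustrate the round trip.
m = sparse.csr_matrix(np.array([[0, 5, 0],
                                [7, 0, 0]]))

# The three plain numpy arrays that describe the CSR layout, plus the shape.
data, indices, indptr, shape = m.data, m.indices, m.indptr, m.shape

# Rebuild the matrix from those pieces; these are what would be stored on disk.
rebuilt = sparse.csr_matrix((data, indices, indptr), shape=shape)

same = np.array_equal(rebuilt.toarray(), m.toarray())
print(same)
```

    Each of `data`, `indices` and `indptr` is an ordinary `numpy.ndarray`, which is the kind of object the error message in the question lists as supported.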

  • 21

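    DaveP's answer is almost right... but it can cause problems with very sparse matrices: if the last columns or rows are empty, they are dropped. So, to make sure everything works, the "shape" attribute must be stored as well.

    This is the code I regularly use: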
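
    A small sketch of that failure mode, using a made-up matrix whose last two columns are empty. When no shape is passed, scipy has to infer the dimensions from the arrays themselves, so for a CSR matrix the trailing empty columns cannot be recovered:

```python
import numpy as np
from scipy import sparse

# Hypothetical matrix: the last two columns contain no stored values.
dense = np.array([[1, 0, 0, 0],
                  [0, 2, 0, 0],
                  [0, 0, 0, 0]])
m = sparse.csr_matrix(dense)

# Rebuilt without the shape: the dimensions are inferred from the arrays.
rebuilt_no_shape = sparse.csr_matrix((m.data, m.indices, m.indptr))

# Rebuilt with the stored shape: the original dimensions are preserved.
rebuilt_with_shape = sparse.csr_matrix((m.data, m.indices, m.indptr),
                                       shape=m.shape)

print(m.shape, rebuilt_no_shape.shape, rebuilt_with_shape.shape)
```

    In this CSR example the row count survives because it is implied by the length of `indptr`; it is the column count that shrinks when the shape is left out.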

    import tables as tb
    from numpy import array
    from scipy import sparse
    
    def store_sparse_mat(m, name, store='store.h5'):
        msg = "This code only works for csr matrices"
        assert(m.__class__ == sparse.csr.csr_matrix), msg
        with tb.openFile(store,'a') as f:
            for par in ('data', 'indices', 'indptr', 'shape'):
                full_name = '%s_%s' % (name, par)
                try:
                    n = getattr(f.root, full_name)
                    n._f_remove()
                except AttributeError:
                    pass
    
                arr = array(getattr(m, par))
                atom = tb.Atom.from_dtype(arr.dtype)
                ds = f.createCArray(f.root, full_name, atom, arr.shape)
                ds[:] = arr
    
    def load_sparse_mat(name, store='store.h5'):
        with tb.openFile(store) as f:
            pars = []
            for par in ('data', 'indices', 'indptr', 'shape'):
                pars.append(getattr(f.root, '%s_%s' % (name, par)).read())
        m = sparse.csr_matrix(tuple(pars[:3]), shape=pars[3])
        return m
    

    Adapting it to csc matrices is trivial.
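
    One possible way to make that adaptation, sketched without the PyTables part: CSC matrices expose the same data, indices, indptr and shape attributes, so only the constructor has to change. The helper names `to_parts` and `from_parts` and the idea of keeping a `format` tag next to the arrays are assumptions for illustration, not part of the answer above.

```python
import numpy as np
from scipy import sparse

# Map the scipy format tag to the matching constructor (illustrative only).
_CLASSES = {'csr': sparse.csr_matrix, 'csc': sparse.csc_matrix}

def to_parts(m):
    """Split a csr or csc matrix into plain arrays plus a format tag."""
    if m.format not in _CLASSES:
        raise TypeError('only csr and csc matrices are handled here')
    return {
        'format': m.format,
        'data': np.asarray(m.data),
        'indices': np.asarray(m.indices),
        'indptr': np.asarray(m.indptr),
        'shape': np.asarray(m.shape),
    }

def from_parts(parts):
    """Rebuild the matrix with the constructor matching the stored tag."""
    cls = _CLASSES[parts['format']]
    shape = tuple(int(n) for n in parts['shape'])
    return cls((parts['data'], parts['indices'], parts['indptr']), shape=shape)

# Round trip with a csc matrix that has an empty last row and column.
original = sparse.csc_matrix(np.array([[0, 1, 0],
                                       [0, 0, 0]]))
restored = from_parts(to_parts(original))
print(restored.format, restored.shape)
```

    Each array in the dictionary could then be written to its own node exactly as in the code above, with the format tag stored alongside them.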

  • 6

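    I updated Pietro Battiston's excellent answer for Python 3.6 and PyTables 3.x, since some PyTables function names changed in the upgrade from 2.x (for example, openFile became open_file and createCArray became create_carray).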

    import numpy as np
    from scipy import sparse
    import tables
    
    def store_sparse_mat(M, name, filename='store.h5'):
        """
        Store a csr matrix in HDF5
    
        Parameters
        ----------
        M : scipy.sparse.csr_matrix
            sparse matrix to be stored
    
        name: str
            node prefix in HDF5 hierarchy
    
        filename: str
            HDF5 filename
        """
        # isinstance avoids the scipy.sparse.csr submodule path, which newer SciPy deprecates
        assert isinstance(M, sparse.csr_matrix), 'M must be a csr matrix'
        with tables.open_file(filename, 'a') as f:
            for attribute in ('data', 'indices', 'indptr', 'shape'):
                full_name = f'{name}_{attribute}'
    
                # remove existing nodes
                try:
                    n = getattr(f.root, full_name)
                    n._f_remove()
                except AttributeError:
                    pass
    
                # add nodes
                arr = np.array(getattr(M, attribute))
                atom = tables.Atom.from_dtype(arr.dtype)
                ds = f.create_carray(f.root, full_name, atom, arr.shape)
                ds[:] = arr
    
    def load_sparse_mat(name, filename='store.h5'):
        """
        Load a csr matrix from HDF5
    
        Parameters
        ----------
        name: str
            node prefix in HDF5 hierarchy
    
        filename: str
            HDF5 filename
    
        Returns
        -------
        M : scipy.sparse.csr_matrix
            loaded sparse matrix
        """
        with tables.open_file(filename) as f:
    
            # get nodes
            attributes = []
            for attribute in ('data', 'indices', 'indptr', 'shape'):
                attributes.append(getattr(f.root, f'{name}_{attribute}').read())
    
        # construct sparse matrix
        M = sparse.csr_matrix(tuple(attributes[:3]), shape=attributes[3])
        return M
    
