我正在使用scikit learn进行高斯过程回归(GPR)操作来预测数据 . 我的培训数据如下:
x_train = np.array([[0,0],[2,2],[3,3]]) #2-D cartesian coordinate points
y_train = np.array([[200,250, 155],[321,345,210],[417,445,851]]) #observed output from three different datasources at respective input data points (x_train)
需要预测平均值和方差/标准偏差的测试点(2-D)是:
xvalues = np.array([0,1,2,3])
yvalues = np.array([0,1,2,3])
x,y = np.meshgrid(xvalues,yvalues) #Total 16 locations (2-D)
positions = np.vstack([x.ravel(), y.ravel()])
x_test = (np.array(positions)).T
现在,在运行GPR( GausianProcessRegressor
)拟合之后(这里,ConstantKernel和RBF的乘积在 GaussianProcessRegressor
中用作内核),可以通过以下代码行预测均值和方差/标准差:
y_pred_test, sigma = gp.predict(x_test, return_std =True)
在打印预测平均值( y_pred_test
)和方差( sigma
)时,我在控制台中打印了以下输出:
在预测值(平均值)中,打印内部数组内有三个对象的“嵌套数组” . 可以假设内部阵列是每个2-D测试点位置处的每个数据源的预测平均值 . 但是,打印的方差只包含一个包含16个对象的数组(可能包含16个测试位置点) . 我知道方差提供了估计不确定性的指示 . 因此,我期待每个测试点的每个数据源的预测方差 . 我的期望是错的吗?如何在每个测试点获得每个数据源的预测方差?这是由于错误的代码?
谢谢!
2 回答
好吧,你无意中碰上了冰山......
作为前奏,让我们明确指出方差和标准差的概念仅适用于标量变量;对于矢量变量(比如你自己的3d输出),方差的概念不再有意义,而是使用协方差矩阵(Wikipedia,Wolfram) .
继续前奏,你的
sigma
的形状确实如预期的那样根据predict
方法的scikit-learn docs(即你的情况下没有编码错误):结合我之前关于协方差矩阵的评论,第一个选择是尝试使用参数
return_cov=True
的predict
函数(因为要求矢量变量的方差是没有意义的);但同样,这将导致16x16矩阵,而不是3x3矩阵(3个输出变量的协方差矩阵的预期形状)......澄清了这些细节之后,让我们继续讨论这个问题的本质 .
问题的核心在于实践和相关教程中很少提及(甚至暗示)的事情:具有多个输出的高斯过程回归是 highly non-trivial ,仍然是一个活跃的研究领域 . 可以说,scikit-learn无法真正处理这个案例,尽管事实上它表面上似乎没有发出至少一些相关的警告 .
让我们在最近的科学文献中寻找对这种主张的一些佐证:
Gaussian process regression with multiple response variables(2015) - 引用(强调我的):
Remarks on multi-output Gaussian process regression(2018) - 引用(强调原文):
Physics-Based Covariance Models for Gaussian Processes with Multiple Outputs(2013) - 引用:
因此,正如我所说,我的理解是,sckit-learn并不能真正处理这种情况,尽管事实上文档中没有提到或暗示过这样的事情(在以下问题上打开相关问题可能会很有趣)项目页面) . 这似乎也是this relevant SO thread以及this CrossValidated thread中有关GPML(Matlab)工具箱的结论 .
话虽如此,除了回复单独建模每个输出的选择(不是无效的选择,只要你记住你可能从你的3-D输出元素之间的相关性丢弃有用的信息),至少有一个Python工具箱似乎能够为多输出GP建模,即 runlmc (paper,code,documentation) .
首先,如果使用的参数是“sigma”,那指的是标准差,而不是方差(召回,方差只是标准偏差的平方) .
使用方差概念化更容易,因为方差被定义为从数据点到集合均值的欧几里德距离 .
在您的情况下,您有一组2D点 . 如果您将这些视为2D平面上的点,那么方差就是从每个点到平均值的距离 . 标准偏差将是方差的正根 .
在这种情况下,您有16个测试点和16个标准差值 . 这非常有意义,因为每个测试点都有自己与集合平均值的定义距离 .
如果要计算点的SET的方差,可以通过单独求和每个点的方差,将其除以点数,然后减去均方来实现 . 该数字的正根将产生该组的标准偏差 .
ASIDE:这也意味着如果您通过插入,删除或替换更改集合,则每个点的标准偏差将发生变化 . 这是因为将重新计算平均值以适应新数据 . 这个迭代过程是k-means聚类背后的基本力量 .