首页 文章

如何在seaborn heatmap子图中保持颜色条宽度一致

提问于
浏览
0

我试图在具有不同子图的单个图上绘制两个具有相同x轴的热图,每个图的比例都不同,因此它们需要不同的颜色条 . 每个热图也沿y轴有不同数量的观测值,见下图:

import numpy as np
import seaborn as sns

arr1 = np.random.rand(10, 10)*100
arr2 = np.random.rand(2, 10)

fig, axes = plt.subplots(2, 1, 
                     gridspec_kw={'height_ratios':[5, 1]}, 
                     sharex=True)

sns.heatmap(arr1, ax=axes[0]) 
sns.heatmap(arr2, ax=axes[1])

enter image description here

颜色条的宽度在图之间不一致 . 这是有道理的并且并不出乎意料,但是出于美学原因,我想强制每个的宽度保持一致 .

我尝试使用 cbar 参数并明确定义颜色条的大小和位置:

fig, axes = plt.subplots(2, 1, 
                     gridspec_kw={'height_ratios':[5, 1]}, 
                     sharex=True)
sns.heatmap(arr1, ax=axes[0])

cbar_ax = fig.add_axes([.85, .135, .019, .15])
sns.heatmap(arr2, ax=axes[1], cbar_ax=cbar_ax)

enter image description here

虽然在这样做时,情节不再是一致的宽度 . 这里最好的方法是为各个轴子图上的每个颜色条定义轴吗?似乎可能有一种我想念的简单方法

2 回答

  • 1

    首先

    解决方案

    如果可能的话,尝试使用 GridSpec 来组织你的4个子图(两个主图和两个颜色条在他们自己的_1765206中):

    import matplotlib.pyplot as plt
    import numpy as np
    import seaborn as sns
    
    arr1 = np.random.rand(10, 10)*100
    arr2 = np.random.rand(2, 10)
    
    fig, axes = plt.subplots(2, 2,
                             gridspec_kw={'height_ratios': [5, 1], 
                                          'width_ratios': [30, 1], 
                                          'wspace': 0.1}, 
                             sharex='col')
    
    sns.heatmap(arr1, ax=axes[0][0], cbar_ax=axes[0][1]) 
    sns.heatmap(arr2, ax=axes[1][0], cbar_ax=axes[1][1])
    

    enter image description here

    然后,对您的想法发表一些评论:

    使用 add_axes 时,手动添加子图( axes )并强制它在特定位置 . 这可能使新子图难以与其他子图协调 . 在你的第二个情节中,它发现第三个 axes 漂浮在其他子图上面 .

    第二种方法的问题是您将颜色条移出 axes arr2 . 因此,主要情节采用 axes 的全宽,而 arr1 的主要情节与其颜色条共享 axes 的全宽 . 认为两个主要情节似乎是"no longer consistent width",两个子图( axes 的宽度仍然一致) .

  • 3

    您可以将颜色条轴( cbar_ax )定义为子图网格的一部分,以便所有间距和宽度在两行之间同步 .

    import matplotlib.pyplot as plt
    import numpy as np
    import seaborn as sns
    
    arr1 = np.random.rand(10, 10)*100
    arr2 = np.random.rand(2, 10)
    
    fig, axes = plt.subplots(2, 2, 
                         gridspec_kw={'height_ratios':[5, 1],
                                      'width_ratios' :[1, 0.04]})
    
    sns.heatmap(arr1, ax=axes[0,0], cbar_ax=axes[0,1]) 
    sns.heatmap(arr2, ax=axes[1,0], cbar_ax=axes[1,1]) 
    
    
    plt.show()
    

    enter image description here

相关问题