首页 文章

Matplotlib 2子图,1个Colorbar

提问于
浏览
168

我花了很长时间研究如何让两个子图共享相同的y轴,并在Matplotlib中共享两个颜色条 .

发生了什么事情,当我在 subplot1subplot2 中调用 colorbar() 函数时,它会自动调整绘图,使得颜色条加上绘图将适合'subplot'边界框内,导致两个并排的图形为两个非常不同的尺寸 .

为了解决这个问题,我尝试创建了第三个子图,然后我将其修改为仅显示一个颜色条的情节 . 唯一的问题是,现在两个地块的高度和宽度是不均匀的,我无法弄清楚如何使它看起来没问题 .

这是我的代码:

from __future__ import division
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import patches
from matplotlib.ticker import NullFormatter

# SIS Functions
TE = 1 # Einstein radius
g1 = lambda x,y: (TE/2) * (y**2-x**2)/((x**2+y**2)**(3/2)) 
g2 = lambda x,y: -1*TE*x*y / ((x**2+y**2)**(3/2))
kappa = lambda x,y: TE / (2*np.sqrt(x**2+y**2))

coords = np.linspace(-2,2,400)
X,Y = np.meshgrid(coords,coords)
g1out = g1(X,Y)
g2out = g2(X,Y)
kappaout = kappa(X,Y)
for i in range(len(coords)):
    for j in range(len(coords)):
        if np.sqrt(coords[i]**2+coords[j]**2) <= TE:
            g1out[i][j]=0
            g2out[i][j]=0

fig = plt.figure()
fig.subplots_adjust(wspace=0,hspace=0)

# subplot number 1
ax1 = fig.add_subplot(1,2,1,aspect='equal',xlim=[-2,2],ylim=[-2,2])
plt.title(r"$\gamma_{1}$",fontsize="18")
plt.xlabel(r"x ($\theta_{E}$)",fontsize="15")
plt.ylabel(r"y ($\theta_{E}$)",rotation='horizontal',fontsize="15")
plt.xticks([-2.0,-1.5,-1.0,-0.5,0,0.5,1.0,1.5])
plt.xticks([-2.0,-1.5,-1.0,-0.5,0,0.5,1.0,1.5])
plt.imshow(g1out,extent=(-2,2,-2,2))
plt.axhline(y=0,linewidth=2,color='k',linestyle="--")
plt.axvline(x=0,linewidth=2,color='k',linestyle="--")
e1 = patches.Ellipse((0,0),2,2,color='white')
ax1.add_patch(e1)

# subplot number 2
ax2 = fig.add_subplot(1,2,2,sharey=ax1,xlim=[-2,2],ylim=[-2,2])
plt.title(r"$\gamma_{2}$",fontsize="18")
plt.xlabel(r"x ($\theta_{E}$)",fontsize="15")
ax2.yaxis.set_major_formatter( NullFormatter() )
plt.axhline(y=0,linewidth=2,color='k',linestyle="--")
plt.axvline(x=0,linewidth=2,color='k',linestyle="--")
plt.imshow(g2out,extent=(-2,2,-2,2))
e2 = patches.Ellipse((0,0),2,2,color='white')
ax2.add_patch(e2)

# subplot for colorbar
ax3 = fig.add_subplot(1,1,1)
ax3.axis('off')
cbar = plt.colorbar(ax=ax2)

plt.show()

8 回答

  • 0

    只需将颜色条放在自己的轴上,然后使用 subplots_adjust 为它腾出空间 .

    作为一个简单的例子:

    import numpy as np
    import matplotlib.pyplot as plt
    
    fig, axes = plt.subplots(nrows=2, ncols=2)
    for ax in axes.flat:
        im = ax.imshow(np.random.random((10,10)), vmin=0, vmax=1)
    
    fig.subplots_adjust(right=0.8)
    cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
    fig.colorbar(im, cax=cbar_ax)
    
    plt.show()
    

    enter image description here

  • 31

    按照 abevieiramota 使用轴列表的解决方案非常有效,直到您只使用一行图像,如评论中所指出的那样 . 使用合理的纵横比 figsize 会有所帮助,但仍然远非完美 . 例如:

    import numpy as np
    import matplotlib.pyplot as plt
    
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(9.75, 3))
    for ax in axes.flat:
        im = ax.imshow(np.random.random((10,10)), vmin=0, vmax=1)
    
    fig.colorbar(im, ax=axes.ravel().tolist())
    
    plt.show()
    

    1 x 3 image array

    colorbar function提供 shrink 参数,该参数是颜色条轴大小的缩放系数 . 它确实需要一些手动试验和错误 . 例如:

    fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.75)
    

    1 x 3 image array with shrunk colorbar

  • 7

    使用 make_axes 更容易,并提供更好的结果 . 它还提供了自定义颜色条定位的可能性 . 另请注意 subplots 选项以共享x和y轴 .

    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib as mpl
    
    fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True)
    for ax in axes.flat:
        im = ax.imshow(np.random.random((10,10)), vmin=0, vmax=1)
    
    cax,kw = mpl.colorbar.make_axes([ax for ax in axes.flat])
    plt.colorbar(im, cax=cax, **kw)
    
    plt.show()
    

  • 35

    作为一个偶然发现这个线程的初学者,我想添加一个python-for-dummies改编的 abevieiramota 's very neat answer (because I' m,在我必须查找的级别'ravel'来计算出他们的代码在做什么):

    import numpy as np
    import matplotlib.pyplot as plt
    
    fig, ((ax1,ax2,ax3),(ax4,ax5,ax6)) = plt.subplots(2,3)
    
    axlist = [ax1,ax2,ax3,ax4,ax5,ax6]
    
    first = ax1.imshow(np.random.random((10,10)), vmin=0, vmax=1)
    third = ax3.imshow(np.random.random((12,12)), vmin=0, vmax=1)
    
    fig.colorbar(first, ax=axlist)
    
    plt.show()
    

    更像pythonic,更容易让像我这样的新手看到这里真正发生的事情 .

  • 91

    此解决方案不需要手动调整轴位置或颜色条大小,可以使用多行和单行布局,并且可以使用 tight_layout() . 它改编自gallery example,使用来自matplotlib的AxesGrid ToolboxImageGrid .

    import numpy as np
    import matplotlib.pyplot as plt
    from mpl_toolkits.axes_grid1 import ImageGrid
    
    # Set up figure and image grid
    fig = plt.figure(figsize=(9.75, 3))
    
    grid = ImageGrid(fig, 111,          # as in plt.subplot(111)
                     nrows_ncols=(1,3),
                     axes_pad=0.15,
                     share_all=True,
                     cbar_location="right",
                     cbar_mode="single",
                     cbar_size="7%",
                     cbar_pad=0.15,
                     )
    
    # Add data to image grid
    for ax in grid:
        im = ax.imshow(np.random.random((10,10)), vmin=0, vmax=1)
    
    # Colorbar
    ax.cax.colorbar(im)
    ax.cax.toggle_label(True)
    
    #plt.tight_layout()    # Works, but may still require rect paramater to keep colorbar labels visible
    plt.show()
    

    image grid

  • 244

    我注意到几乎所有发布的解决方案都涉及 ax.imshow(im, ...) 并没有规范化显示在多个子图的颜色条上的颜色 . im mappable取自最后一个实例,但如果多个 im -s的值不同,该怎么办? (遗憾的是,由于3D轴,我发现了使用 plt.subplots(...) 的方法 .

    Example Plot

    如果我能以更好的方式定位彩条...(可能有更好的方法来做到这一点,但至少应该不太难以遵循 . )

    import matplotlib
    from matplotlib import cm
    import matplotlib.pyplot as plt
    import numpy as np
    from mpl_toolkits.mplot3d import Axes3D
    
    cmap = 'plasma'
    ncontours = 5
    
    def get_data(row, col):
        """ get X, Y, Z, and plot number of subplot
            Z > 0 for top row, Z < 0 for bottom row """
        if row == 0:
            x = np.linspace(1, 10, 10, dtype=int)
            X, Y = np.meshgrid(x, x)
            Z = np.sqrt(X**2 + Y**2)
            if col == 0:
                pnum = 1
            else:
                pnum = 2
        elif row == 1:
            x = np.linspace(1, 10, 10, dtype=int)
            X, Y = np.meshgrid(x, x)
            Z = -np.sqrt(X**2 + Y**2)
            if col == 0:
                pnum = 3
            else:
                pnum = 4
        print("\nPNUM: {}, Zmin = {}, Zmax = {}\n".format(pnum, np.min(Z), np.max(Z)))
        return X, Y, Z, pnum
    
    fig = plt.figure()
    nrows, ncols = 2, 2
    zz = []
    axes = []
    for row in range(nrows):
        for col in range(ncols):
            X, Y, Z, pnum = get_data(row, col)
            ax = fig.add_subplot(nrows, ncols, pnum, projection='3d')
            ax.set_title('row = {}, col = {}'.format(row, col))
            fhandle = ax.plot_surface(X, Y, Z, cmap=cmap)
            zz.append(Z)
            axes.append(ax)
    
    ## get full range of Z data as flat list for top and bottom rows
    zz_top = zz[0].reshape(-1).tolist() + zz[1].reshape(-1).tolist()
    zz_btm = zz[2].reshape(-1).tolist() + zz[3].reshape(-1).tolist()
    ## get top and bottom axes
    ax_top = [axes[0], axes[1]]
    ax_btm = [axes[2], axes[3]]
    ## normalize colors to minimum and maximum values of dataset
    norm_top = matplotlib.colors.Normalize(vmin=min(zz_top), vmax=max(zz_top))
    norm_btm = matplotlib.colors.Normalize(vmin=min(zz_btm), vmax=max(zz_btm))
    cmap = cm.get_cmap(cmap, ncontours) # number of colors on colorbar
    mtop = cm.ScalarMappable(cmap=cmap, norm=norm_top)
    mbtm = cm.ScalarMappable(cmap=cmap, norm=norm_btm)
    for m in (mtop, mbtm):
        m.set_array([])
    
    # ## create cax to draw colorbar in
    # cax_top = fig.add_axes([0.9, 0.55, 0.05, 0.4])
    # cax_btm = fig.add_axes([0.9, 0.05, 0.05, 0.4])
    cbar_top = fig.colorbar(mtop, ax=ax_top, orientation='vertical', shrink=0.75, pad=0.2) #, cax=cax_top)
    cbar_top.set_ticks(np.linspace(min(zz_top), max(zz_top), ncontours))
    cbar_btm = fig.colorbar(mbtm, ax=ax_btm, orientation='vertical', shrink=0.75, pad=0.2) #, cax=cax_btm)
    cbar_btm.set_ticks(np.linspace(min(zz_btm), max(zz_btm), ncontours))
    
    plt.show()
    plt.close(fig)
    ## orientation of colorbar = 'horizontal' if done by column
    
  • 12

    正如其他答案中所指出的,通常的想法是定义颜色条所在的轴 . 有多种方法可以做到这一点 . 尚未提及的一个是用 plt.subplots() 在子图创建时直接指定颜色条轴 . 优点是轴位置不需要手动设置,并且在所有情况下都具有自动方面,颜色条将与子图完全相同 . 即使在使用图像的许多情况下,结果也将令人满意,如下所示 .

    使用 plt.subplots() 时,使用 gridspec_kw 参数可以使色条轴比其他轴小得多 .

    fig, (ax, ax2, cax) = plt.subplots(ncols=3,figsize=(5.5,3), 
                      gridspec_kw={"width_ratios":[1,1, 0.05]})
    

    例:

    import matplotlib.pyplot as plt
    import numpy as np; np.random.seed(1)
    
    fig, (ax, ax2, cax) = plt.subplots(ncols=3,figsize=(5.5,3), 
                      gridspec_kw={"width_ratios":[1,1, 0.05]})
    fig.subplots_adjust(wspace=0.3)
    im  = ax.imshow(np.random.rand(11,8), vmin=0, vmax=1)
    im2 = ax2.imshow(np.random.rand(11,8), vmin=0, vmax=1)
    ax.set_ylabel("y label")
    
    fig.colorbar(im, cax=cax)
    
    plt.show()
    

    enter image description here

    如果绘图的方面是自动缩放的,或者图像由于它们在宽度方向上的方面而缩小(如上所述),则效果很好 . 但是,如果图像比较宽,那么结果看起来如下,这可能是不希望的 .

    enter image description here

    fix the colorbar height to the subplot height 的解决方案是使用 mpl_toolkits.axes_grid1.inset_locator.InsetPosition 来设置相对于图像子图轴的颜色条轴 .

    import matplotlib.pyplot as plt
    import numpy as np; np.random.seed(1)
    from mpl_toolkits.axes_grid1.inset_locator import InsetPosition
    
    fig, (ax, ax2, cax) = plt.subplots(ncols=3,figsize=(7,3), 
                      gridspec_kw={"width_ratios":[1,1, 0.05]})
    fig.subplots_adjust(wspace=0.3)
    im  = ax.imshow(np.random.rand(11,16), vmin=0, vmax=1)
    im2 = ax2.imshow(np.random.rand(11,16), vmin=0, vmax=1)
    ax.set_ylabel("y label")
    
    ip = InsetPosition(ax2, [1.05,0,0.05,1]) 
    cax.set_axes_locator(ip)
    
    fig.colorbar(im, cax=cax, ax=[ax,ax2])
    
    plt.show()
    

    enter image description here

  • 8

    您可以使用 figure.colorbar()ax 参数和轴列表简化Joe Kington的代码 . 来自the documentation

    ax无|父轴轴对象,新的颜色条轴的空间将被盗取 . 如果给出了一个轴列表,则将调整所有轴的大小,以便为色条轴腾出空间 .

    import numpy as np
    import matplotlib.pyplot as plt
    
    fig, axes = plt.subplots(nrows=2, ncols=2)
    for ax in axes.flat:
        im = ax.imshow(np.random.random((10,10)), vmin=0, vmax=1)
    
    fig.colorbar(im, ax=axes.ravel().tolist())
    
    plt.show()
    

    1

相关问题