我正在尝试注释热图。所述matplotlib文档呈现的示例,其建议创建一个辅助函数来格式化注释。我觉得必须有一种更简单的方式来做我想做的事。我可以在热图的方框内进行注释,但是在编辑热图的范围时,这些文本会更改位置。我的问题是如何extent
在ax.imshow(...)
同时ax.text(...)
用于注释正确位置的同时使用。下面是一个示例:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
def get_manhattan_distance_matrix(coordinates):
shape = (coordinates.shape[0], 1, coordinates.shape[1])
ct = coordinates.reshape(shape)
displacement = coordinates - ct
return np.sum(np.abs(displacement), axis=-1)
x = np.arange(11)[::-1]
y = x.copy()
coordinates = np.array([x, y]).T
distance_matrix = get_manhattan_distance_matrix(coordinates)
# print("\n .. {} COORDINATES:\n{}\n".format(coordinates.shape, coordinates))
# print("\n .. {} DISTANCE MATRIX:\n{}\n".format(distance_matrix.shape, distance_matrix))
norm = Normalize(vmin=np.min(distance_matrix), vmax=np.max(distance_matrix))
这是修改的值的地方extent
。
extent = (np.min(x), np.max(x), np.min(y), np.max(y))
# extent = None
根据matplotlib文档,默认extent
值为None
。
fig, ax = plt.subplots()
handle = ax.imshow(distance_matrix, cmap='plasma', norm=norm, interpolation='nearest', origin='upper', extent=extent)
kws = dict(ha='center', va='center', color='gray', weight='semibold', fontsize=5)
for i in range(len(distance_matrix)):
for j in range(len(distance_matrix[i])):
if i == j:
ax.text(j, i, '', **kws)
else:
ax.text(j, i, distance_matrix[i, j], **kws)
plt.show()
plt.close(fig)
一个人可以通过修改来生成两个图形extent
-只需取消注释行,然后注释未注释行即可。这两个数字如下:
可以看到,通过设置extent
,像素位置发生了变化,进而改变了ax.text(...)
手柄的位置。是否有解决此问题的简单解决方案-也就是说,设置一个任意值,extent
并且仍然在每个框中居中放置文本句柄?
当extent=None
为时,x和y的有效范围均为-0.5至10.5。因此,中心位于整数位置。将范围设置为0到10不会与像素对齐。您必须乘以10/11才能使它们正确。
最好的方法是设置extent = (np.min(x)-0.5, np.max(x)+0.5, np.min(y)-0.5, np.max(y)+0.5)
使中心回到整数位置。
另请注意,默认情况下,图像从顶部开始显示,并且y轴反转。如果您更改范围,则需要使图像直立ax.imshow(..., origin='lower')
。(0,0像素应为示例图中的蓝色像素。)
要将文本放置在像素的中心,可以向水平索引添加0.5,除以像素的宽度,然后乘以x轴的差。并对y轴进行类似的计算。为了获得更好的可读性,可以使文本颜色取决于像素颜色。
# ...
extent = (np.min(x), np.max(x), np.min(y), np.max(y))
x0, x1, y0, y1 = extent
fig, ax = plt.subplots()
handle = ax.imshow(distance_matrix, cmap='plasma', norm=norm, interpolation='nearest', origin='lower', extent=extent)
kws = dict(ha='center', va='center', weight='semibold', fontsize=5)
height = len(distance_matrix)
width = len(distance_matrix[0])
for i in range(height):
for j in range(width):
if i != j:
val = distance_matrix[i, j]
ax.text(x0 + (j + 0.5) / width * (x1 - x0), y0 + (i + 0.5) / height * (y1 - y0),
f'{val}\n{i},{j}', color='white' if norm(val) < 0.6 else 'black', **kws)
plt.show()
本文收集自互联网,转载请注明来源。
如有侵权,请联系[email protected] 删除。
我来说两句