首页 文章

如何使用一个数组的索引来定义另一个数组的__getitem__?

提问于
浏览
1

我有一个自定义类 Field 的对象,它基本上包裹着一个 numpy.ndarray 对象 . 该对象由两个输入定义:值数组( values )和切片对象( segment ),用于定义这些值应放置在某个较大数组( grid )中的位置 .

我希望能够使用 grid 的索引来访问 values 的项目 . 这应该可以通过定义自定义 Field.__getitem__ 方法来实现 .

import numpy as np

class Field:
    def __init__(self, values, segment, grid):
        if (not isinstance(segment, slice)) \\
        or (not isinstance(values, np.ndarray)) :
            raise TypeError
        if segment.step not in [1, -1]:
            raise ValueError('Segment must be continuous')
        if len(grid[segment]) != len(values):
            raise ValueError('values length must match segment')

        self.values = values
        self.segment = segment 
        self.grid = grid

    def __getitem__(self, key):
        new_key = ...  # <--- Code goes here
        return self.values[new_key]

grid = np.array([0.5, 1.5, 2.5, 3.5, 4.5])

values = np.array([42., 43., 44.])
segment = slice(2, 5)

my_field = Field(values, segment, grid)
print(grid[segment])  # output: [2.5, 3.5, 4.5]
print(my_field[2])  # Desired output: 42.
print(my_field[3])  # Desired output: 43.
print(my_field[0])  # Desired output: IndexError

关键是 segment 定义了 grid 中定义了 my_field 的位置集 . 我接近这个的方式已经证明是非常不优雅和笨拙的,并且基于定义一些布尔 index = np.zeros_like(grid, dtype=bool); index[segment] = True 的数组,然后涉及 np.cumsum(index) 的一些技巧...

How can I achieve this behavior in a simpler way?

1 回答

  • 1

    您可以使用显式步骤定义切片:

    segment = slice(2, 5, 1)
    

    这是为了确保 __init__ 中的 segment.step 返回 1 . 然后定义一个方法,检查您的输入 key 是否在相应的 range 中:

    def __getitem__(self, key):
        start, stop = self.segment.start, self.segment.stop
        new_key = key - start
        if new_key not in range(stop - start):
            raise IndexError(f'Key must be in range({start}, {stop})')
        return self.values[new_key]
    

    这给出了:

    my_field = Field(values, segment, grid)
    print(grid[segment])  # [2.5, 3.5, 4.5]
    print(my_field[2])    # 42.0
    print(my_field[3])    # 43.0
    print(my_field[0])    # IndexError: Key must be in range(2, 5)
    

相关问题