Implementing NumPy style matrix slicing in Mojo🔥
NumPy 和 Pandas 等 Python 库之所以如此受欢迎,是因为它们提高了工程师和数据科学家的工作效率。当我第一次使用 NumPy 时,我被震撼的不是它支持的大量数学函数,而是我可以轻松操作数组和矩阵,特别是从大型表格数据集中切片行、列或子矩阵。

切片是数据科学中一项非常基本的操作。例如,以一个大型矩阵为例,其中的行代表单个客户,列代表客户属性,如人口统计、购买偏好等。若要计算特定行的基本汇总统计信息,需要对特定列进行切片并计算统计信息。要计算跨客户属性的相似性度量值(例如对它们进行聚类),您需要有效地对行进行切片并计算成对的相似性度量值。您可能还希望对特定列进行切片,以从数据集中删除不需要的列。
在这篇博文中,我将展示如何在 Mojo 中从头开始实现 NumPy 样式切片功能。通过此示例,你将演示如何:
- 创建支持 NumPy 样式切片操作的自定义 Matrix 数据结构
- 使用线性和跨步内存访问进行行和列切片
- 处理 NumPy 样式切片边缘情况(例如负索引 [-2:,:] 和开放索引 [:,4])
- 实现子矩阵的矢量化和并行化切片
这篇博文中使用的代码示例可在 GitHub 上的 Mojo 存储库中找到。
在进入实现细节之前,让我们先看一下最终结果。我们将构建一个可以执行切片操作的 Matrix 数据结构。
let mat = Matrix(8,5)
mat.print()
输出:
[[0.0850 0.8916 0.1896 0.3980 0.7435]
[0.5603 0.8095 0.5117 0.9950 0.9666]
[0.4260 0.6529 0.9615 0.8579 0.2940]
[0.4146 0.5148 0.7897 0.5442 0.0936]
[0.4322 0.8449 0.7728 0.1918 0.7803]
[0.1813 0.5791 0.3141 0.4119 0.9923]
[0.1639 0.3348 0.0762 0.1745 0.0372]
[0.4674 0.6741 0.0667 0.3897 0.1653]]
Matrix: 8 x 5 , DType: float32
我们可以对矩阵进行切片以提取第 2、3、4 行和最后 3 列的元素:
mat[2:4,-3:].print()
输出:
[[0.9615 0.8579 0.2940]
[0.7897 0.5442 0.0936]]
Matrix: 2 x 3 , DType: float32
验证切片子矩阵中的元素是否与原始矩阵中的元素匹配

现在,让我们尝试一些高级切片:
mat[1::2,::2].print()
输出:
[[0.5603 0.5117 0.9666]
[0.4146 0.7897 0.0936]
[0.1813 0.3141 0.9923]
[0.4674 0.0667 0.1653]]
Matrix: 4 x 3 , DType: float32
验证切片子矩阵中的元素是否与原始矩阵中的元素匹配

创建支持切片的自定义 Matrix 数据结构
让我们从 Mojo 中的基本 Matrix 结构开始。Mojo 中的结构类似于 Python 中的类,但 Mojo 结构是静态的,并且受编译时限制。我们的矩阵结构具有以下功能:
初始化矩阵:init()
允许将切片矩阵/向量复制到新变量:copyinit()
处理切片边缘情况:adjust_slice()
4 x 重载 getitem() 方法,支持索引、仅行切片、仅列切片以及行列切片的组合
方便的 print() 函数来可视化矩阵
矩阵结构的骨架如下所示:
struct Matrix[dtype: DType = DType.float32]:
var dim0: Int
var dim1: Int
var _data: DTypePointer[dtype]
...
fn __init__(inout self, *dims: Int):
...
fn __copyinit__(inout self, other: Self):
...
fn _adjust_slice_(self, inout span: slice, dim: Int):
...
fn __getitem__(self, x: Int, y: Int) -> SIMD[dtype,1]:
...
fn __getitem__(self, owned row_slice: slice, col: Int) -> Self:
...
fn __getitem__(self, row: Int, owned col_slice: slice) -> Self:
...
fn __getitem__(self, owned row_slice: slice, owned col_slice: slice) -> Self:
...
fn print(self, prec: Int=4)->None:
...
我们将花时间讨论的三个概念是:
定义矩阵的结构变量:它的维度 dim0(行)和 dim1(列)以及我们使用 DTypePointer 分配和访问的数据
处理切片边缘情况的 adjust_slice() 函数
执行所有切片魔术的 getitem() 函数
调整切片表达式以处理边缘情况
在 Python 和 Mojo 中,getitem() 是一种特殊类型的函数,可以对对象进行索引。例如,当您想访问 row=0 和 column=3 处的元素时,您可以执行以下操作:matrix[0,3],这相当于:arr.getitem(0,3)。在我们的示例中,getitem() 被重载以支持整数和切片索引的各种组合,在这里我们将讨论以下支持切片变量的 getitem()
fn __getitem__(self, owned row_slice: slice, owned col_slice: slice) -> Self:
有两个参数 row_slice 和 col_slice 类型为 slice。当您使用切片表达式(如 [:,3:4])时,会生成一个切片对象并将其传递给 getitem(),这就是为什么我们有 row_slice 和 col_slice 类型为 slice 的变量。slice 对象允许您访问请求的切片项和切片的长度,但它不处理边缘情况。如果切片表达式具有负数或超过矩阵中的行数或列数,则必须调整切片以处理这些情况。这就是 adjust_slice() 函数的作用,也是我们对 row_slice 和 col_slice 变量调用的第一个函数。
下面是 getitem() 的屏幕截图,突出显示了对 adjust_slice() 函数的调用以处理切片边缘情况:

让我们仔细看看 adjust_slice()
fn _adjust_slice_(self, inout span: slice, dim: Int):
if span.start < 0:
span.start = dim + span.start
if not span._has_end():
span.end = dim
elif span.end < 0:
span.end = dim + span.end
if span.end > dim:
span.end = dim
if span.end < span.start:
span.start = 0
span.end = 0
adjust_slice() 函数支持行和列切片调整,工作原理如下。我们检查
- 如果起始值为负数,我们从尺寸中减去它
- 如果切片表达式没有结束,即使用 :,我们分配一个结束
- 如果端数为负数,则从尺寸大小中减去它
- 如果端部超过尺寸大小,则将其固定为尺寸大小
- 如果结束时间小于开始时间,则使切片无效
调整切片后,我们就可以开始对矩阵进行切片了。
使用线性和跨步内存访问对行和列进行切片
首先,让我们回顾一下我们的内存访问策略。矩阵结构包括存储在内存中并使用指针_data访问的原始数据,它包括存储在变量 dim0 中(行维度)和 dim1(列维度)中的维度信息。它们共同包含描述矩阵的完整信息。

访问一行中的所有元素

行访问相对简单。我们获取指向 (行数到切片) * (列数) 定义的第一个元素的指针。以下代码行获取指向第一个元素的指针:

一旦您拥有指向数据的指针并知道要加载的元素数(=矩阵中的列数),就可以使用 _data.simd_load 加载simd_width数据块。当您想要加载非顺序数据时,事情会变得更加复杂,我们将在下面看到。
访问列中的所有元素

列数据不按顺序存储在内存中,但它们的间隔相等。就像之前一样,我们将获取指向第一个元素的指针:

列中每个元素之间的间距也称为步幅长度,等于上述示例中的列数。要以跨步方式加载数据,可以使用 strided_load(src_ptr, stride)。
访问任意行和列切片

现在,事情开始变得更加有趣。在这里,我们有行和列切片的组合:[1::2,::2]。对于每一行,我们必须:
- 获取指向第一个元素的指针
- 使用列切片作为步幅的strided_load获取数据
在上面的示例中,要访问元素:[[4,6][12,14]],我们遵循两个步骤,如下图所示:

由于每一行都可以独立计算,因此我们可以跨行并行化以获得加速!
所有这些都汇集在 getitem() 函数中,如下面的注释所示:

结论
使用 Mojo 的好处之一是它不会受到两种语言问题的困扰。大多数高性能 Python 库都是复杂 C 和 C++ 实现的瘦 Python 包装器,NumPy 也不例外。这使得 Python 程序员很难理解引擎盖下的内容,如果您不熟悉这两种语言,则更难扩展。Mojo 的外观和读物都像 Python,并提供较低级别的功能来管理内存、添加类型等。您不必使用两种不同的语言来编写高性能代码,您可以在 Mojo 中完成所有操作!