作者
Ferdinand Schenck
发布于
2024 年 5 月 4 日
原文地址:https://fnands.com/blog/2024/mojo-png-parsing/
因此,在过去,虽然我一直在努力跟踪 Mojo 的开发,但到目前为止,我主要只是跟踪变更日志并编写了一些非常琐碎的代码。在我的上一篇文章中,我说过我想尝试一些更实质性的东西,所以就开始吧。
我正在研究Basalt项目,该项目试图在纯 Mojo 中构建机器学习框架,并意识到迄今为止使用的唯一图像是 MNIST,无论如何它都采用奇怪的二进制格式。为什么没有其他的呢?由于 Mojo 尚不支持加速器(如 GPU),Imagenet 可能不切实际,但现在在 CPU 上使用 CIFAR-10 等工具训练 CNN 应该相当快。 CIFAR-10 数据集可从原始来源以 pickle 存档或某种自定义二进制格式获得。我考虑过为这些编写解码器,但在 Mojo 中编写一个 PNG 解析器可能更有用,然后使用Kaggle和其他地方托管的数据集版本,或者只是使用这个包将原始数据转换为 PNG 。这样,代码就可以用来打开一般的 PNG 图像。
(纯)Mojo 中的 PNG 解析器
不要将这篇文章误认为是教程:将其视为某人在学习新语言时发现 PNG 标准的血淋淋的细节。如果您想了解有关 PNG 格式的更多信息,维基百科页面作为概述非常有帮助,W3C 页面提供了很多详细信息。
作为参考,这是用 Mojo 编写的24.3.0
,并且由于 Mojo 仍在快速变化,下面所做的很多操作可能已经过时。实际上,我基本上完成了整个帖子,24.2.1
但24.3
在发布之前就发布了,但只需要进行一些小的更改即可使其在新版本中工作。
这里的目标不是构建一个显示 PNG 的工具,而只是将它们读入可用于 ML 目的的数组(或张量)中,因此我将跳过许多更多面向显示的细节。
读入数据
首先,让我们拍摄一张测试图像。这是来自 PIL 库的测试图像,是 OG 程序员Grace Hopper的图像:这是一个相对简单的 PNG,所以它应该是一个很好的起点。

料斗
现在 Mojo 已经在 stdlib 中实现了它的 pathlib 版本,我们实际上可以检查该文件是否存在:
from pathlib import Path
test_image = Path('hopper.png')
print(test_image.exists())
True
我们还将通过 Python 导入图像,以便我们可以比较得到的输出是否与 Python 情况匹配。
from python import Python
var Image = Python.import_module('PIL.Image')
var np = Python.import_module('numpy')
py_array = np.array(Image.open("hopper.png"))
我们将读取原始字节。我本以为数据是无符号8 位整数,但 Mojo 将它们读取为有符号8 位整数。然而,有人提议改变这一点,因此这可能很快就会改变。
with open(test_image, "r") as f:
file_contents = f.read_bytes()
print(len(file_contents))
30605
检查文件头
PNG 文件在前 8 个字节中定义了一个签名,其中一部分是 ASCII 中的字母 PNG。我们定义一个小辅助函数来将字节转换为字符串:
fn bytes_to_string(list: List[Int8]) -> String:
var word = String("")
for letter in list:
word += chr(int(letter[].cast[DType.uint8]()))
return word
为了确保我们确实处理的是 PNG,我们可以检查位 1 到 3:
png_signature = file_contents[0:8]
print(bytes_to_string(png_signature[1:4]))
PNG
是的,它告诉我们这是一个 PNG 文件。
读取块
现在我们读取第一个“块”,它应该是标题。每个块由四个部分组成:块长度(4 字节)、块类型(4 字节)、块数据(无论前 4 个字节表示的长度)以及根据数据计算出的校验和(称为 CRC) (4字节)。
| 长度 | 块类型 | 块数据 | CRC |
| --- | --- | --- | --- |
| 4字节 | 4字节 | 长度字节 | 4字节 |
当使用 读取数据时read_bytes
,数据以有符号 8 位整数列表的形式出现,但我们希望将数据解释为 32 位无符号整数。以下是执行此操作的辅助函数(感谢Michael Kowalski)的帮助。
from math.bit import bswap, bitreverse
from testing import assert_true
fn bytes_to_uint32_be(owned list: List[Int8]) raises -> List[UInt32]:
assert_true(len(list) % 4 == 0, "List[Int8] length must be a multiple of 4 to convert to List[Int32]")
var result_length = len(list) // 4
# get the data pointer with ownership.
# This avoids copying and makes sure only one List owns a pointer to the underlying address.
var ptr_to_int8 = list.steal_data()
var ptr_to_uint32 = ptr_to_int8.bitcast[UInt32]()
var result = List[UInt32]()
result.data = ptr_to_uint32
result.capacity = result_length
result.size = result_length
# swap the bytes in each UInt32 to convert from big-endian to little-endian
for i in range(result_length):
result[i] = bswap(result[i])
return result
读取图像标题
文件头之后的第一个块应该始终是图像头,所以让我们看一下它:
让我们看看第一个块有多长:
read_head = 8
chunk_length = bytes_to_uint32_be(file_contents[read_head:read_head+4])[0]
print(chunk_length)
13
所以第一个块的长度是 13 个字节。我们来看看它是什么类型:
chunk_type = file_contents[read_head+4:read_head+8]
print(bytes_to_string(chunk_type))
IHDR
IHDR,确认该块是图像头。我们现在可以解析接下来的 13 个字节的标头数据来获取有关图像的信息:
start_header = int(read_head+8)
end_header = int(read_head+8+chunk_length)
header_data = file_contents[start_header:end_header]
前两个块分别告诉我们图像的宽度和高度:
print("Image width: ", bytes_to_uint32_be(header_data[0:4])[0])
print("Image height: ", bytes_to_uint32_be(header_data[4:8])[0])
Image width: 128
Image height: 128
所以我们的图像大小是 128x128 像素。
接下来的字节告诉我们每个像素的位深度、颜色类型、压缩方法、过滤方法以及图像是否隔行扫描。
print("Bit depth: ", int(header_data[8]))
print("Color type: ", int(header_data[9]))
print("Compression method: ", int(header_data[10]))
print("Filter method: ", int(header_data[11]))
print("Interlaced: ", int(header_data[12]))
Bit depth: 8
Color type: 2
Compression method: 0
Filter method: 0
Interlaced: 0
因此颜色类型为Truecolor
RGB,位深度为 8。
有趣的旁注:在PIL PngImagePlugin中有一个变更日志:
# history:
# 1996-05-06 fl Created (couldn't resist it)
# 1996-12-14 fl Upgraded, added read and verify support (0.2)
# 1996-12-15 fl Separate PNG stream parser
# 1996-12-29 fl Added write support, added getchunks
# 1996-12-30 fl Eliminated circular references in decoder (0.3)
# 1998-07-12 fl Read/write 16-bit images as mode I (0.4)
# 2001-02-08 fl Added transparency support (from Zircon) (0.5)
# 2001-04-16 fl Don't close data source in "open" method (0.6)
# 2004-02-24 fl Don't even pretend to support interlaced files (0.7)
# 2004-08-31 fl Do basic sanity check on chunk identifiers (0.8)
# 2004-09-20 fl Added PngInfo chunk container
# 2004-12-18 fl Added DPI read support (based on code by Niki Spahiev)
# 2008-08-13 fl Added tRNS support for RGB images
# 2009-03-06 fl Support for preserving ICC profiles (by Florian Hoech)
# 2009-03-08 fl Added zTXT support (from Lowell Alleman)
# 2009-03-29 fl Read interlaced PNG files (from Conrado Porto Lopes Gouvua)
我喜欢 2004 年的评论:Don't even pretend to support interlaced files
在 PNG 读取添加到 PIL 大约 13 年后,隔行扫描 PNG 才得到支持。我有一种感觉,我不会在这篇文章中处理交错文件......
该块的最后部分是 CRC32 值,即 32 位循环冗余校验。我不会透露太多细节,但它基本上是添加一个错误检测代码来检测块数据是否损坏。通过检查提供的 CRC32 值与我们自己计算的值,我们可以确保我们正在读取的数据没有损坏。
start_crc = int(read_head+8+chunk_length)
end_crc = int(start_crc+4)
header_crc = bytes_to_uint32_be(file_contents[start_crc:end_crc])[0]
print("CRC: ", hex(header_crc))
CRC: 0x4c5cf69c
我们需要一些代码来计算 CRC32 值。
这不是最有效的实现,但很简单。
我可能会写一篇后续文章,更详细地解释它的作用。
fn CRC32(owned data: List[SIMD[DType.int8, 1]]) -> SIMD[DType.uint32, 1]:
var crc32: UInt32 = 0xffffffff
for byte in data:
crc32 = (bitreverse(byte[]).cast[DType.uint32]() << 24) ^ crc32
for i in range(8):
if crc32 & 0x80000000 != 0:
crc32 = (crc32 << 1) ^ 0x04c11db7
else:
crc32 = crc32 << 1
return bitreverse(crc32^0xffffffff)
print(hex(CRC32(file_contents[read_head+4:end_header])))
0x4c5cf69c
太好了,CRC 十六进制匹配,所以我们知道 IHDR 块中的数据是好的。
读取更多块
现在,读取每个块的部分将会重复,因此让我们定义一个名为 的结构体Chunk
来保存块中包含的信息,以及一个为我们解析块并返回组成部分的函数:
struct Chunk(Movable, Copyable):
var length: UInt32
var type: String
var data: List[Int8]
var crc: UInt32
var end: Int
fn __init__(inout self, length: UInt32, chunk_type: String, data : List[Int8], crc: UInt32, end: Int):
self.length = length
self.type = chunk_type
self.data = data
self.crc = crc
self.end = end
fn __moveinit__(inout self, owned existing: Chunk):
self.length = existing.length
self.type = existing.type
self.data = existing.data
self.crc = existing.crc
self.end = existing.end
fn __copyinit__(inout self, existing: Chunk):
self.length = existing.length
self.type = existing.type
self.data = existing.data
self.crc = existing.crc
self.end = existing.end
def parse_next_chunk(owned data: List[Int8], read_head: Int) -> Chunk:
chunk_length = bytes_to_uint32_be(data[read_head:read_head+4])[0]
chunk_type = bytes_to_string(data[read_head+4:read_head+8])
start_data = int(read_head+8)
end_data = int(start_data+chunk_length)
chunk_data = data[start_data:end_data]
start_crc = int(end_data)
end_crc = int(start_crc+4)
chunk_crc = bytes_to_uint32_be(data[start_crc:end_crc])[0]
# Check CRC
assert_true(CRC32(data[read_head+4:end_data]) == chunk_crc, "CRC32 does not match")
return Chunk(length=chunk_length, chunk_type=chunk_type, data=chunk_data, crc=chunk_crc, end=end_crc)
在块创建期间,会计算块数据的 CRC32 值,如果它与预期不同,则会引发问题。
让我们测试一下它是否解析 IHDR 块:
var header_chunk = parse_next_chunk(file_contents, 8)
print(header_chunk.type)
read_head = header_chunk.end
IHDR
接下来的几个块称为“辅助块”,并不是绝对必要的。它们包含可用于渲染图像的图像属性(如gamma ):
var gamma_chunk = parse_next_chunk(file_contents, read_head)
print(gamma_chunk.type)
read_head = gamma_chunk.end
gAMA
var chromacity_chunk = parse_next_chunk(file_contents, read_head)
print(chromacity_chunk.type)
read_head = chromacity_chunk.end
cHRM
var background_chunk = parse_next_chunk(file_contents, read_head)
print(background_chunk.type)
read_head = background_chunk.end
bKGD
var pixel_size_chunk = parse_next_chunk(file_contents, read_head)
print(pixel_size_chunk.type)
read_head = pixel_size_chunk.end
pHYs
图像数据块
IDAT 块(实际上每个图像可以有多个)包含实际的图像数据。
var image_data_chunk = parse_next_chunk(file_contents, read_head)
print(image_data_chunk.type)
read_head = image_data_chunk.end
IDAT
减压
PNG 使用DEFLATE压缩算法进行压缩(无损)。
PNG 首先被过滤,然后被压缩,但是当我们解码时,我们需要首先解压缩数据并撤消过滤器。
下一节就是为什么我在“纯粹的”Mojo 中说:我考虑过实现它,但这将是相当大量的工作,所以我希望其他人这样做,或者我可以在未来。
因此,目前我通过 Mojo 的外部函数接口 (FFI) 使用该算法的zlib版本。
以下内容是我根据 Ilya Lubenets 和 Jack Clayton 之间的一条 Mojo 不和谐内容稍微改编而来的:
from sys import ffi
alias Bytef = Scalar[DType.int8]
alias uLong = UInt64
alias zlib_type = fn(
_out: Pointer[Bytef],
_out_len: Pointer[UInt64],
_in: Pointer[Bytef],
_in_len: uLong
) -> Int
fn log_zlib_result(Z_RES: Int, compressing: Bool = True) raises -> NoneType:
var prefix: String = ''
if not compressing:
prefix = "un"
if Z_RES == 0:
print('OK ' + prefix.upper() + 'COMPRESSING: Everything ' + prefix + 'compressed fine')
elif Z_RES == -4:
raise Error('ERROR ' + prefix.upper() + 'COMPRESSING: Not enought memory')
elif Z_RES == -5:
raise Error('ERROR ' + prefix.upper() + 'COMPRESSING: Buffer have not enough memory')
else:
raise Error('ERROR ' + prefix.upper() + 'COMPRESSING: Unhandled exception')
fn uncompress(data: List[Int8], quiet: Bool = True) raises -> List[UInt8]:
var data_memory_amount: Int = len(data)*10 # This can be done better.
var handle = ffi.DLHandle('')
var zlib_uncompress = handle.get_function[zlib_type]('uncompress')
var uncompressed = Pointer[Bytef].alloc(data_memory_amount)
var compressed = Pointer[Bytef].alloc(len(data))
var uncompressed_len = Pointer[uLong].alloc(1)
memset_zero(uncompressed, data_memory_amount)
memset_zero(uncompressed_len, 1)
uncompressed_len[0] = data_memory_amount
for i in range(len(data)):
compressed.store(i, data[i])
var Z_RES = zlib_uncompress(
uncompressed,
uncompressed_len,
compressed,
len(data),
)
if not quiet:
log_zlib_result(Z_RES, compressing=False)
print('Uncompressed length: ' + str(uncompressed_len[0]))
# Can probably do something more efficient here with pointers, but eh.
var res = List[UInt8]()
for i in range(uncompressed_len[0]):
res.append(uncompressed[i].cast[DType.uint8]())
return res
鼓声......让我们看看这是否有效:
uncompressed_data = uncompress(image_data_chunk.data, quiet=False)
OK UNCOMPRESSING: Everything uncompressed fine
Uncompressed length: 49280
现在我们有一个未压缩字节的列表。然而,这些还不是像素值。未压缩的数据长度为 49280 字节。我们知道我们有一个 8 位颜色深度的 RGB 图像,所以期望128*128*3\=49152像素数据的字节数。请注意49280-49152\=128,并且我们的图像具有 的形状(128, 128)
。
这些额外的 128 字节是为了让我们知道使用什么过滤器将每行像素(称为扫描线)的字节值转换为可以有效压缩的内容。
取消过滤
PNG 规范指定的可能的过滤器类型有:
Type Name
0 None
1 Sub
2 Up
3 Average
4 Paeth
规范中有一些微妙的地方需要注意,例如这些过滤器是按字节应用的,而不是按像素值应用的。对于 8 位颜色深度,这并不重要,但对于 16 位,这意味着像素的第一个字节(MSB 或最高有效字节)将与第二个字节(LSB 或最低有效字节)分开计算。我不会在这里太深入地讨论所有细节,但您可以在此处阅读规范的详细信息。
我将简要解释每个过滤器背后的基本思想:
- 0:无
- 1:子
- 2:向上
- 3:平均:
- 4:帕斯:
- 三个相邻像素(左、上、左上)用于计算从该像素中减去的值。它比其他三个更复杂一些。
因此,在对过滤后的数据进行解码时,我们需要反转上述操作来重新获得像素值。
现在我们知道了,让我们看看第一个字节值:
print(uncompressed_data[0])
1
所以我们在这里处理过滤器类型 1。让我们解码第一行:
var filter_type = uncompressed_data[0]
var scanline = uncompressed_data[1:128*3+1]
# Decoded image data
var result = List[UInt8](capacity=128*3)
# take the first pixels as 0
var left: UInt8 = 0
var pixel_size: Int = 3
var offset: Int = 1
for i in range(len(scanline)):
if i >= pixel_size:
left = result[i-pixel_size]
# The specification specifies that the result is modulo 256
# Silimar to the C implementation, we can just add the left pixel to the current pixel,
# and the result will be modulo 256 due to overflow
result.append((uncompressed_data[i + offset] + left))
让我们确认我们解码的行与 PIL 所做的相同:
for i in range(128):
for j in range(3):
assert_true(result[i*3+j] == py_array[0][i][j].__int__(), "Pixel values do not match")
现在我们已经有了大致的了解,让我们更笼统地写一下这个,并做其他过滤器。
有关如何选择过滤器的想法,请阅读此 stackoverflow 帖子及其指向的资源:PNG 编码器如何选择要使用的过滤器?
我将这些作为采用 16 位有符号整数的函数来完成。这对于 Paeth 过滤器来说非常重要,其中标准规定:
PaethPredictor 函数内的计算必须准确执行,不得溢出。算术模 256 仅用于从目标字节值中减去函数结果的最后一步。
所以基本上我们需要保持更高的精度,然后在最后转换回字节。
我的实现基于用 C 编写的png 解码器的iPXE实现。
from math import abs
fn undo_trivial(current: Int16, left: Int16 = 0, above: Int16 = 0, above_left: Int16 = 0) -> Int16:
return current
fn undo_sub(current: Int16, left: Int16 = 0, above: Int16 = 0, above_left: Int16 = 0) -> Int16:
return current + left
fn undo_up(current: Int16, left: Int16 = 0, above: Int16 = 0, above_left: Int16 = 0) -> Int16:
return current + above
fn undo_average(current: Int16, left: Int16 = 0, above: Int16 = 0, above_left: Int16 = 0) -> Int16:
return current + ((above + left) >> 1) # Bitshift is equivalent to division by 2
fn undo_paeth(current: Int16, left: Int16 = 0, above: Int16 = 0, above_left: Int16 = 0) -> Int16:
var peath: Int16 = left + above - above_left
var peath_a: Int16 = abs(peath - left)
var peath_b: Int16 = abs(peath - above)
var peath_c: Int16 = abs(peath - above_left)
if ( peath_a <= peath_b ) and ( peath_a <= peath_c ):
return (current + left)
elif ( peath_b <= peath_c ):
return (current + above)
else:
return (current + above_left)
fn undo_filter(filter_type: UInt8, current: UInt8, left: UInt8 = 0, above: UInt8 = 0, above_left: UInt8 = 0) raises -> UInt8:
var current_int = current.cast[DType.int16]()
var left_int = left.cast[DType.int16]()
var above_int = above.cast[DType.int16]()
var above_left_int = above_left.cast[DType.int16]()
var result_int: Int16 = 0
if filter_type == 0:
result_int= undo_trivial(current_int, left_int, above_int, above_left_int)
elif filter_type == 1:
result_int = undo_sub(current_int, left_int, above_int, above_left_int)
elif filter_type == 2:
result_int = undo_up(current_int, left_int, above_int, above_left_int)
elif filter_type == 3:
result_int = undo_average(current_int, left_int, above_int, above_left_int)
elif filter_type == 4:
result_int = undo_paeth(current_int, left_int, above_int, above_left_int)
else:
raise Error("Unknown filter type")
return result_int.cast[DType.uint8]()
对于这个undo_filter
函数,我试图将单独的过滤器添加到某种元组或列表中,这样我就可以对它们进行索引(因此是统一的签名),但还无法弄清楚如何在 Mojo 中执行此操作。
因此,让我们将这些应用到整个图像,并确认我们得到的结果与从 Python 得到的结果相同:
# Decoded image data
# take the first pixels as 0
var pixel_size: Int = 3
# Initialize the previous scanline to 0
var previous_result = List[UInt8](0*128)
for line in range(128):
var offset = 1 + 1*line + line * 128 * 3
var left: UInt8 = 0
var above_left: UInt8 = 0
#var left: UInt8 = 0
var result = List[UInt8](capacity=128*3)
var scanline = uncompressed_data[offset:offset+128*3]
var filter_type = uncompressed_data[offset - 1]
for i in range(len(scanline)):
if i >= pixel_size:
left = result[i-pixel_size]
above_left = previous_result[i-pixel_size]
result.append(undo_filter(filter_type, uncompressed_data[i + offset], left, previous_result[i], above_left))
previous_result = result
for i in range(128):
for j in range(3):
assert_true(result[i*3+j] == py_array[line][i][j].__int__(), "Pixel values do not match")
就是这样。如果上面的代码运行,则意味着我们已经成功解析了 PNG 文件,并且至少获得了与使用 Pillow 相同的数据。
创建张量
现在理想情况下我们希望将上述内容转换为张量。
让我们编写一个函数来解析图像数据并返回一个张量。
from tensor import Tensor, TensorSpec, TensorShape
from utils.index import Index
from random import rand
var height = 128
var width = 128
var channels = 3
# Declare the grayscale image.
var spec = TensorSpec(DType.uint8, height, width, channels)
var tensor_image = Tensor[DType.uint8](spec)
# Decoded image data
# take the first pixels as 0
var pixel_size: Int = 3
# Initialize the previous scanline to 0
var previous_result = List[UInt8](0*128)
for line in range(128):
var offset = 1 + 1*line + line * 128 * 3
var left: UInt8 = 0
var above_left: UInt8 = 0
#var left: UInt8 = 0
var result = List[UInt8](capacity=128*3)
var scanline = uncompressed_data[offset:offset+128*3]
var filter_type = uncompressed_data[offset - 1]
for i in range(len(scanline)):
if i >= pixel_size:
left = result[i-pixel_size]
above_left = previous_result[i-pixel_size]
result.append(undo_filter(filter_type, uncompressed_data[i + offset], left, previous_result[i], above_left))
previous_result = result
for i in range(128):
for j in range(3):
tensor_image[Index(line, i, j)] = result[i*3+j]
我不完全确定为什么我需要Index
在设置项目时使用,但是在获取时我可以只提供索引:
print(tensor_image[0,1,2])
print(py_array[0][1][2])
62
62
我们终于得到它了。我很快就会把它们放在一起,但让我们快速完成文件的解析。
最后的块
此时还有一些块:包含一些注释的文本块,以及表示文件末尾的结束块:
var text_chunk_1 = parse_next_chunk(file_contents, read_head)
print(text_chunk_1.type)
read_head = text_chunk_1.end
print(bytes_to_string(text_chunk_1.data))
tEXt
comment
var text_chunk_2 = parse_next_chunk(file_contents, read_head)
print(text_chunk_2.type)
read_head = text_chunk_2.end
print(bytes_to_string(text_chunk_2.data))
tEXt
date:create
var text_chunk_3 = parse_next_chunk(file_contents, read_head)
print(text_chunk_3.type)
read_head = text_chunk_3.end
print(bytes_to_string(text_chunk_3.data))
tEXt
date:modify
上面的文本块实际上有更多信息,但似乎是 UTF-8 编码的,而 Mojo 似乎只处理 ASCII?
var end_chunk = parse_next_chunk(file_contents, read_head)
print(end_chunk.type)
read_head = end_chunk.end
IEND
把它们放在一起。
让我们更好地封装上面的逻辑。我正在考虑类似于 PIL 的东西。
我们从一个名为的结构开始PNGImage
fn bytes_to_hex_string(list: List[Int8]) -> String:
var word = String("")
for letter in list:
word += hex(int(letter[].cast[DType.uint8]()))
return word
fn determine_file_type(data: List[Int8]) -> String:
# Is there a better way? Probably
if bytes_to_hex_string(data[0:8]) == String("0x890x500x4e0x470xd0xa0x1a0xa"):
return "PNG"
else:
return "Unknown"
struct PNGImage:
var image_path: Path
var raw_data: List[Int8]
var width: Int
var height: Int
var channels: Int
var bit_depth: Int
var color_type: Int
var compression_method: UInt8
var filter_method: UInt8
var interlaced: UInt8
var data: List[UInt8]
var data_type: DType
fn __init__(inout self, file_name: Path) raises:
self.image_path = file_name
assert_true(self.image_path.exists(), "File does not exist")
with open(self.image_path , "r") as image_file:
self.raw_data = image_file.read_bytes()
assert_true(determine_file_type(self.raw_data) == "PNG", "File is not a PNG. Only PNGs are supported")
var read_head = 8
var header_chunk = parse_next_chunk(self.raw_data, read_head)
read_head = header_chunk.end
self.width = int(bytes_to_uint32_be(header_chunk.data[0:4])[0])
self.height = int(bytes_to_uint32_be(header_chunk.data[4:8])[0])
self.bit_depth = int(header_chunk.data[8].cast[DType.uint32]())
self.color_type = int(header_chunk.data[9])
self.compression_method = header_chunk.data[10].cast[DType.uint8]()
self.filter_method = header_chunk.data[11].cast[DType.uint8]()
self.interlaced = header_chunk.data[12].cast[DType.uint8]()
# There must be a better way to do this
var color_type_dict = Dict[Int, Int]()
color_type_dict[0] = 1
color_type_dict[2] = 3
color_type_dict[3] = 1
color_type_dict[4] = 2
color_type_dict[6] = 4
self.channels = color_type_dict[self.color_type]
if self.bit_depth == 8:
self.data_type = DType.uint8
elif self.bit_depth == 16:
self.data_type = DType.uint16
else:
raise Error("Unknown bit depth")
# Check color_type and bit_depth
assert_true(self.color_type == 2, "Only RGB images are supported")
assert_true(self.bit_depth == 8, "Only 8-bit images are supported")
# Check if the image is interlaced
assert_true(self.interlaced == 0, "Interlaced images are not supported")
# Chack compression method
assert_true(self.compression_method == 0, "Compression method not supported")
# Scan over chunks until end found
var ended = False
var data_found = False
var uncompressd_data = List[Int8]()
while read_head < len(self.raw_data) and not ended:
var chunk = parse_next_chunk(self.raw_data, read_head)
read_head = chunk.end
if chunk.type == "IDAT":
uncompressd_data.extend(chunk.data)
data_found = True
elif chunk.type == "IEND":
ended = True
assert_true(ended, "IEND chunk not found")
assert_true(data_found, "IDAT chunk not found")
self.data = uncompress(uncompressd_data)
# In case the filename is passed as a string
fn __init__(inout self, file_name: String) raises:
self.__init__(Path(file_name))
fn to_tensor(self) raises -> Tensor[DType.uint8]:
var spec = TensorSpec(DType.uint8, self.height, self.width, self.channels)
var tensor_image = Tensor[DType.uint8](spec)
var pixel_size: Int = self.channels * (self.bit_depth // 8)
# Initialize the previous scanline to 0
var previous_result = List[UInt8](0*self.width)
for line in range(self.height):
var offset = 1 + 1*line + line * self.width * pixel_size
var left: UInt8 = 0
var above_left: UInt8 = 0
var result = List[UInt8](capacity=self.width*pixel_size)
var scanline = self.data[offset:offset+self.width*pixel_size]
var filter_type = self.data[offset - 1]
for i in range(len(scanline)):
if i >= pixel_size:
left = result[i-pixel_size]
above_left = previous_result[i-pixel_size]
result.append(undo_filter(filter_type, self.data[i + offset], left, previous_result[i], above_left))
previous_result = result
for i in range(self.width):
for j in range(self.channels):
tensor_image[Index(line, i, j)] = result[i*self.channels+j]
return tensor_image
好吧,这不是最漂亮的,但让我们看看它是否有效:
var hopper = PNGImage(test_image)
hopper_tensor = hopper.to_tensor()
for i in range(hopper.height):
for j in range(hopper.width):
for c in range(hopper.channels):
assert_true(hopper_tensor[i, j, c] == py_array[i][j][c].__int__(), "Pixel values do not match")
如果上面的代码运行了,那么就意味着我们正确地读取了图像!
让我们尝试一下 CIFAR-10 数据集中的 PNG 图像:
cifar_image = Path('114_automobile.png')
var cifar = PNGImage(cifar_image)
cifar_tensor = cifar.to_tensor()
py_cifar = np.array(Image.open('114_automobile.png'))
for i in range(cifar.height):
for j in range(cifar.width):
for c in range(cifar.channels):
assert_true(cifar_tensor[i, j, c] == py_cifar[i][j][c].__int__(), "Pixel values do not match")
这也有效!现在我们应该能够读取 CIFAR-10 数据集!
我对上面的实现有几个问题,即:
- 如何处理16位图像?我们需要一个单独的函数吗
to_tensor_16
?
- 结构体是执行此操作的正确方法吗?
这是我觉得我还不知道在 Mojo 中执行此操作的惯用方法是什么的地方之一。你可能会说,🪄Mojical🪄方式。 Mojo 还很年轻,我不确定是否已经出现了一种惯用的方式。
结论
对于博客文章来说,读取 PNG 是一个非常有趣的主题。这让我真正接触到了 Mojo 中一些较低层次的概念,我觉得我以前没有完全掌握这些概念。
我承认,这最终比我预期的工作量要多一些。套用肯尼迪的一句话:

jfk_meme
Mojo 在短短几个月内取得的进步令人印象深刻:去年 9 月我试图编写一些 Mojo 时,感觉很难做任何实际的事情,而现在该语言似乎非常有用。
还有一些事情我需要习惯。一件事是我总是觉得我“需要”编写fn
函数,而不是def
函数。在编写库等时,这是一个很好的做法,但它让我想知道:什么时候编写def
样式函数合适,因为fn
总是更安全、更高效?
附录
我对这篇博客文章中的代码进行了一些重构,并将其写入一个我称为Mimage 的库中。目标是能够在纯 Mojo 中读取和写入常见的图像格式,而不必调用 Python 或 C。目前 Mimage 仍然需要一些 C 库来进行解压缩步骤,但我希望这些可以在纯 Mojo 中使用。莫乔很快。
接下来的步骤可能是添加对 Mimage 对 16 位 PNG 和 JPEG 的支持。长期目标是能够读取和写入与 Python 的Pillow相同的所有图像格式,但这可能需要很长时间才能实现。由于我从事 ML 方面的工作,因此我将尝试重点关注 ML 所需的格式和功能,例如能够读取 Imagenet 数据集中的所有图像。