Files
tinygrad/test/test_metal.py
mesozoic-egg 992cde05d7 Metal with CDLL instead of py-objc (#6545)
* Add CDLL interface for metal

* remove two unused functions

* Cover most of the API methods

* switch to cdll

* directly call objc message in ops_metal

* keep only obj interface

* Use direct message sending for graph

* may have found a solution for the memoryview on a ctypes pointer

* buf indexing bug fixed

* fix c_int

* fix c_int to bytes

* fix gpu time bug

* line savings for cdll metal core

* wip

* c_int bug

* fix buf casting

* dedup for c_void_p

* dedup for c_void_p

* linter fix

* remove unused stuff

* mypy fix

* more mypy error fixes

* line savings

* line savings

* rename send_message to msg; add __hash__ and __eq__ for dedup

* wip

* refactor

* refactor

* remove named import from ctypes

* forgot to change variable name

* file reorg, put support.py to ops_metal

* refactor

* hash error

* remove to_ns_array

* test oom exception, fix exception change

* typevar for msg

* add back dedup

* test for compile error

* move constant to graph

* move header constant around

* get label for icb buffer

* check icb label using "in"

* wip fixing mypy reported error

* fixed mypy error

* code formatting

* all_resources dedup match previous

* code formatting

* code formatting; buffer set to objc_id

* revert changes on buf for the manual release, seems like _free is not always called

* skip unless on metal, for test_metal

* fix premature mem release causing seg fault

* test_metal check for device before importing

* Buffer should only be released under _free explicitly

* mypy fixes

* change object ownership

* test compile success

* lint fixes

* remove load_library

* wrap sel_register in cache

* simplify to_struct

* swap lines

* fix type error in to_struct

* bump line to 9800

* remove pyobjc from setup.py

* command buffer should be objc_instance and get released

* stringWithUTF8String: returns objc_instance

* Use constant for MTLPipelineOptionNone

* better explanation for [MTLBuffer contents:] return

* Use dyld_find in case the path differs

* trailing whitespace

* handle exception for methods that take error:

* load /System/Library instead of /Library

* Init c_void_p with None instead of zero for error objects

---------

Co-authored-by: Mesozoic Egg <mesozoic.egg@proton.me>
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
2024-09-25 17:43:01 +08:00

55 lines
1.8 KiB
Python

import unittest
from tinygrad.device import CompileError, Device
if Device.DEFAULT=="METAL":
from tinygrad.runtime.ops_metal import MetalDevice, MetalCompiler, MetalProgram
@unittest.skipIf(Device.DEFAULT!="METAL", "Metal support required")
class TestMetal(unittest.TestCase):
  """Tests for the Metal runtime: allocation failure, shader compilation, and loading compiled libraries."""

  def test_alloc_oom(self):
    """An absurdly large allocation must raise MemoryError instead of crashing or succeeding."""
    device = MetalDevice("metal")
    with self.assertRaises(MemoryError):
      device.allocator.alloc(10000000000000000000)

  def test_compile_error(self):
    """Source that is not valid Metal must make MetalCompiler.compile raise CompileError."""
    device = MetalDevice("metal")
    compiler = MetalCompiler(device)
    with self.assertRaises(CompileError):
      compiler.compile("this is not valid metal")

  def test_compile_success(self):
    """A valid kernel (elementwise add of data1 and data2 into data0) compiles to a non-None result."""
    device = MetalDevice("metal")
    compiler = MetalCompiler(device)
    ret = compiler.compile("""
#include <metal_stdlib>
using namespace metal;
kernel void E_4n1(device int* data0, const device int* data1, const device int* data2,
uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
int val0 = *(data1+0);
int val1 = *(data1+1);
int val2 = *(data1+2);
int val3 = *(data1+3);
int val4 = *(data2+0);
int val5 = *(data2+1);
int val6 = *(data2+2);
int val7 = *(data2+3);
*(data0+0) = (val0+val4);
*(data0+1) = (val1+val5);
*(data0+2) = (val2+val6);
*(data0+3) = (val3+val7);
}
""")
    # use the TestCase assertion rather than a bare `assert`, which is stripped under `python -O`
    self.assertIsNotNone(ret)

  def test_failed_newLibraryWithData(self):
    """Constructing a MetalProgram from a truncated compiled library must raise RuntimeError."""
    device = MetalDevice("metal")
    compiler = MetalCompiler(device)
    compiled = compiler.compile("""
#include <metal_stdlib>
kernel void r_5(device int* data0, const device int* data1, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]){
data0[0] = 0;
}
""")
    # corrupt the compiled program before entering assertRaises, so that only the
    # MetalProgram construction is inside the region expected to raise
    corrupted = compiled[:40]
    with self.assertRaises(RuntimeError):
      MetalProgram(device, "r_5", corrupted)