mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 06:48:22 -05:00
* Add CDLL interface for metal * remove two unused functions * Cover most of the API methods * switch to cdll * directly call objc message in ops_metal * keep only obj interface * Use direct message sending for graph * may have found a solution to the memoryview on ctypes pointer * buf indexing bug fixed * fix c_int * fix c int to bytes * fix gpu time bug * line savings for cdll metal core * wip * c int bug * fix buf casting * dedup for c_void_p * dedup for c_void_p * linter fix * remove unused stuff * my py fix * more mypy error fix * line savings * line savings * rename send_message to msg; add __hash__ and __eq__ for dedup * wip * refactor * refactor * remove named import from ctypes * forgot to change variable name * file reorg, put support.py to ops_metal * refactor * hash error * remove to_ns_array * test oom exception, fix exception change * typevar for msg * add back dedup * test for compile error * move constant to graph * move header constant around * get label for icb buffer * check icb label using "in" * wip fixing mypy reported error * fixed mypy error * code formatting * all_resources dedup match previous * code formatting * code formatting; buffer set to objc_id * revert changes on buf for the manual release, seems like _free is not always called * skip unless on metal, for test_metal * fix premature mem release causing seg fault * test_metal check for device before importing * Buffer should only be released under _free explicitly * mypy fixes * change object ownership * test compile success * lint fixes * remove load_library * wrap sel_register in cache * simplify to_struct * swap lines * fix type error in to_struct * bump line to 9800 * remove pyobjc from setup.py * command buffer should be objc_instance and get released * stringWithUTF8String: returns objc_instance * Use constant for MTLPipelineOptionNone * better explanation for [MTLBuffer contents:] return * Use dyld_find in case the path differs * trailing whitespace * handle exception for 
methods that take error: * load /System/Library instead of /Library * Init c_void_p with None instead of zero for error objects --------- Co-authored-by: Mesozoic Egg <mesozoic.egg@proton.me> Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
55 lines
1.8 KiB
Python
55 lines
1.8 KiB
Python
import unittest
|
|
from tinygrad.device import CompileError, Device
|
|
if Device.DEFAULT=="METAL":
|
|
from tinygrad.runtime.ops_metal import MetalDevice, MetalCompiler, MetalProgram
|
|
@unittest.skipIf(Device.DEFAULT!="METAL", "Metal support required")
class TestMetal(unittest.TestCase):
  """Tests for the Metal backend: allocation failure, shader compilation and program loading.

  The whole class is skipped unless ``Device.DEFAULT`` is ``"METAL"``.
  """

  def test_alloc_oom(self):
    """Requesting 10**19 bytes from the allocator must raise MemoryError."""
    device = MetalDevice("metal")
    with self.assertRaises(MemoryError):
      device.allocator.alloc(10000000000000000000)

  def test_compile_error(self):
    """Compiling text that is not Metal source must raise CompileError."""
    device = MetalDevice("metal")
    compiler = MetalCompiler(device)
    with self.assertRaises(CompileError):
      compiler.compile("this is not valid metal")

  def test_compile_success(self):
    """Compiling a small element-wise add kernel must return a non-None result."""
    device = MetalDevice("metal")
    compiler = MetalCompiler(device)
    ret = compiler.compile("""
#include <metal_stdlib>
using namespace metal;
kernel void E_4n1(device int* data0, const device int* data1, const device int* data2,
uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
int val0 = *(data1+0);
int val1 = *(data1+1);
int val2 = *(data1+2);
int val3 = *(data1+3);
int val4 = *(data2+0);
int val5 = *(data2+1);
int val6 = *(data2+2);
int val7 = *(data2+3);
*(data0+0) = (val0+val4);
*(data0+1) = (val1+val5);
*(data0+2) = (val2+val6);
*(data0+3) = (val3+val7);
}
""")
    # unittest assertion instead of a bare assert: not stripped under -O, and reports a proper failure
    self.assertIsNotNone(ret)

  def test_failed_newLibraryWithData(self):
    """Constructing a MetalProgram from a truncated compiled blob must raise RuntimeError."""
    device = MetalDevice("metal")
    compiler = MetalCompiler(device)
    compiled = compiler.compile("""
#include <metal_stdlib>
kernel void r_5(device int* data0, const device int* data1, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]){
data0[0] = 0;
}
""")
    # corrupt the compiled program; done outside assertRaises so only the load below is expected to raise
    corrupted = compiled[:40]
    with self.assertRaises(RuntimeError):
      MetalProgram(device, "r_5", corrupted)
|
|
|