Merge branch 'tinygrad:master' into triton

Szymon Ożóg
2023-08-15 21:03:25 +02:00
committed by GitHub
22 changed files with 637 additions and 332 deletions

View File

@@ -267,3 +267,28 @@ jobs:
- name: Run pytest (cuda)
if: matrix.backend=='cuda'
run: python -m pytest -n=auto test/ -k 'not (half or test_efficientnet_safetensors) and not (test_conv2d and test_tensor.py)' -m 'not exclude_cuda' --ignore=test/external --ignore=test/models
testunicorn:
name: ARM64 unicorn Test
runs-on: ubuntu-latest
timeout-minutes: 20
steps:
- name: Checkout Code
uses: actions/checkout@v3
- name: Set up Python 3.8
uses: actions/setup-python@v4
with:
python-version: 3.8
- name: Cache pip
uses: actions/cache@v3
with:
path: '~/.cache/pip'
key: unicorn
- name: Install cross-assembler
run: |
sudo apt-get update -y && \
sudo apt-get install -y --no-install-recommends gcc-aarch64-linux-gnu
- name: Install dependencies
run: pip install -e '.[testing,arm]' --extra-index-url https://download.pytorch.org/whl/cpu
- name: Test arm
run: CI=1 ARM64=1 CLANG=1 python -m pytest -n=auto test/ -k 'not (test_nn.py and (test_conv_transpose2d or test_conv2d))' --ignore=test/models --ignore=test/test_speed_v_torch.py --ignore=test/test_net_speed.py --ignore=test/test_specific_conv.py --ignore=test/unit/test_disk_tensor.py

View File

@@ -1,182 +0,0 @@
from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict
from tinygrad.codegen.linearizer import Linearizer, UOps, Token
from tinygrad.ops import ASTRunner, BinaryOps, UnaryOps
from tinygrad.helpers import DType, dtypes, DEBUG
from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
import functools
import math
from collections import defaultdict
_type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'u', dtypes.uint64: 'b', dtypes._float4: 'x'}
def type_to_letter(x): return _type_to_letter[x[0]].upper() if x[1] else _type_to_letter[x[0]]
class Register(NamedTuple):
nm:str
dtype:DType
scalar:bool
off:Optional[int] = None
def __repr__(self): return self.nm if self.off is None else f"{self.nm}:{self.off}"
def subregs(self):
if self.dtype == dtypes._float4:
return [Register(self.nm, dtypes.float, False, off=off) for off in range(4)]
return []
class AssemblyInstruction(NamedTuple):
op: UOps
out: Optional[Register]
vin: List[Union[Register, int, float]]
arg: Any = None
# warp size of 32, s registers are shared across the warp, v are 32-wide vectors
class AssemblyCodegen(Linearizer):
supports_load3: bool = False
sin_is_sin2pi: bool = False
no_div: bool = False
def specialize(self, asm:List[AssemblyInstruction]) -> Tuple[str, str]:
raise NotImplementedError("must be implemented")
# s registers are the addresses and non-local indexes
def codegen(self):
self.process()
self.hand_coded_optimizations()
self.limit_global_dims(3) # all GPU asms have 3 (for now)
self.linearize()
cnts:DefaultDict[Tuple[DType, bool], int] = defaultdict(int)
tor: Dict[Any, Register] = {}
def newreg(tok, dtype=dtypes.float32, scalar=False):
nonlocal cnts, tor
if isinstance(tok, Token): dtype = tok.dtype # this
tor[tok] = ret = Register(f"%{type_to_letter((dtype, scalar))}{cnts[(dtype, scalar)]}", dtype, scalar)
if dtype == dtypes._float4:
for off in range(4):
tor[Token(tok.name, tok.dtype, off)] = Register(ret.nm, dtypes.float, ret.scalar, off)
cnts[(dtype, scalar)] += 1
return ret
def render_numnode(b):
key = ("num", b)
if key not in tor: ins.append(AssemblyInstruction(UOps.CONST, newreg(key, scalar=True, dtype=dtypes.int32), [], b))
return tor[key]
def render_alu(op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
key = (op, a, b)
if key not in tor:
#if not isinstance(b, Register): b = render_numnode(b)
ins.append(AssemblyInstruction(UOps.ALU, newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
return tor[key]
def render_cast(a:Register, new_dtype:DType) -> Register:
if a.dtype == new_dtype: return a
key = (a, new_dtype)
if key not in tor:
ins.append(AssemblyInstruction(UOps.CAST, newreg(key, dtype=new_dtype), [a]))
return tor[key]
render_ops = { Variable: lambda self, ops, ctx: tor[self], NumNode: lambda self, ops, ctx: render_numnode(self.b),
MulNode: lambda self, ops, ctx: render_alu(BinaryOps.MUL, self.a.render(ops, ctx), self.b),
DivNode: lambda self, ops, ctx: render_alu(BinaryOps.DIV, self.a.render(ops, ctx), self.b),
ModNode: lambda self, ops, ctx: render_alu(BinaryOps.MOD, self.a.render(ops, ctx), self.b),
LtNode: lambda self, ops, ctx: render_alu(BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool),
SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: render_alu(BinaryOps.ADD, a, b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: render_alu(BinaryOps.MUL, a, b.render(ops,ctx), dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
def addr_w_offset(args):
idx = args.idx*self.bufs[args.i].dtype.itemsize
off = 0 # TODO: should this be None?
if isinstance(idx, SumNode):
nums = [n.b for n in idx.nodes if isinstance(n, NumNode)]
if len(nums) > 0 and nums[0] < 4096 and (idx-nums[0]).min >= 0: # TODO: different for each GPU?
idx -= nums[0]
off = nums[0]
reg = idx.render(render_ops)
if self.supports_load3:
if reg.scalar:
new_reg = newreg((reg.nm, 'vec'), dtype=reg.dtype)
ins.append(AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP))
reg = new_reg
return tor[f"buf{args.i}"], reg, off
reg = render_alu(BinaryOps.ADD, render_cast(reg, dtypes.uint64), tor[f"buf{args.i}"], dtype=dtypes.uint64)
return reg, None, off
ins = []
ins += [AssemblyInstruction(UOps.SPECIAL, newreg(f"buf{i}", dtype=dtypes.uint64, scalar=True), [], f"buf{i}") for i in range(len(self.bufs))]
global_size, local_size = [], []
skipload_branch = 0
for uop,newvar,vin,args in self.uops:
if uop == UOps.CONST and newvar is not None:
ins.append(AssemblyInstruction(UOps.CONST, newreg(newvar, dtype=newvar.dtype), [], args))
elif uop == UOps.DEFINE_LOCAL:
ins.append(AssemblyInstruction(UOps.DEFINE_LOCAL, None, [], args))
ins.append(AssemblyInstruction(UOps.ALU, newreg("buf-1", dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
elif uop == UOps.LOOP:
if args[1] == "global":
for i,var in enumerate(args[0]):
global_size.append(var.max+1)
ins.append(AssemblyInstruction(UOps.SPECIAL, newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}"))
elif args[1] == "local":
for i,var in enumerate(args[0]):
local_size.append(var.max+1)
ins.append(AssemblyInstruction(UOps.SPECIAL, newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}"))
else:
for var in args[0]:
if not isinstance(var, NumNode): # TODO: why is this coming through?
ins.append(AssemblyInstruction(UOps.CONST, newreg(var, dtype=dtypes.int32, scalar=True), [], 0))
ins.append(AssemblyInstruction(UOps.LABEL, None, [], "$loop_"+var.expr))
elif uop == UOps.ENDLOOP:
if args[1] not in ["global", "local", "global+local"]:
for var in reversed(args[0]):
if not isinstance(var, NumNode): # TODO: why is this coming through?
ins.append(AssemblyInstruction(UOps.ALU, tor[var], [tor[var], 1], BinaryOps.ADD))
pred = render_alu(BinaryOps.CMPLT, tor[var], var.max+1, dtypes.bool)
ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
elif uop == UOps.CAST and newvar is not None:
# TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
out = newreg(newvar)
for i,sr in enumerate(out.subregs()):
ins.append(AssemblyInstruction(UOps.ALU, sr, [tor[vin[i]]], UnaryOps.NOOP))
elif uop == UOps.ALU and newvar is not None:
out = newreg(newvar) if newvar not in tor else tor[newvar]
# this is the only thing that can violate SSA
if args in [BinaryOps.CMPLT]:
pred_reg = newreg((newvar, 'pred'), dtype=dtypes.bool)
ins.append(AssemblyInstruction(UOps.ALU, pred_reg, [tor[x] for x in vin], args))
ins.append(AssemblyInstruction(UOps.CAST, out, [pred_reg], args))
elif args == BinaryOps.DIV and self.no_div:
tmp = newreg((newvar, "rcp"))
ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[1]]], UnaryOps.RECIP))
ins.append(AssemblyInstruction(UOps.ALU, out, [tor[vin[0]], tmp], BinaryOps.MUL))
elif args == UnaryOps.SIN and self.sin_is_sin2pi:
tmp = newreg((newvar, "2pi"))
ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
ins.append(AssemblyInstruction(UOps.ALU, out, [tmp], args))
else:
ins.append(AssemblyInstruction(UOps.ALU, out, [tor[x] for x in vin], args))
elif uop == UOps.LOAD and newvar is not None:
idx, treg, off = addr_w_offset(args)
reg = newreg(newvar, dtype=newvar.dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar))) # and not dtypes.is_float(newvar.dtype)))
if args.valid.min == 0:
ins.append(AssemblyInstruction(UOps.CONST, reg, [], 0))
if args.valid.max == 1:
pred = args.valid.render(render_ops)
ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
if args.valid.max == 1:
# NOTE: you can't compute the index in here, because it assumes it's all available later
ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if args.i != -1 else 'shared')))
if args.valid.min == 0 and args.valid.max == 1:
ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
skipload_branch += 1
elif uop == UOps.STORE:
idx, treg, off = addr_w_offset(args)
ins.append(AssemblyInstruction(UOps.STORE, None, [idx, tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if args.i != -1 else 'shared')))
# define registers
ins = [AssemblyInstruction(UOps.DEFINE_REGISTER, None, [], (dtype, type_to_letter(dtype), c)) for dtype,c in cnts.items()] + ins
if DEBUG >= 4:
for tins in ins: print(tins)
name, asm = self.specialize(ins)
return ASTRunner(name, asm,
global_size[::-1], local_size[::-1],
op_estimate=self.info.flops, mem_estimate=self.mem_estimate, display_name=self.display_name, runtime_args={"binary": True})
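
For reference, the register-naming scheme above is small enough to check standalone: type_to_letter maps a (dtype, scalar) pair to a one-letter register class, uppercased when the register is scalar (shared across the warp). A minimal sketch, with toy dtype objects standing in for tinygrad's dtypes:

from typing import NamedTuple

class ToyDType(NamedTuple):  # stand-in for tinygrad's DType
    name: str

float32, int32 = ToyDType("float32"), ToyDType("int32")
_type_to_letter = {float32: 'f', int32: 'i'}
def type_to_letter(x): return _type_to_letter[x[0]].upper() if x[1] else _type_to_letter[x[0]]

# per-thread (vector) float registers are named %f0, %f1, ...; scalar ints get %I0, %I1, ...
assert type_to_letter((float32, False)) == 'f'
assert type_to_letter((int32, True)) == 'I'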

View File

@@ -24,6 +24,7 @@ setup(name='tinygrad',
extras_require={
'llvm': ["llvmlite"],
'cuda': ["pycuda"],
'arm': ["unicorn"],
'triton': ["triton>=2.0.0.dev20221202"],
'webgpu': ["wgpu"],
'metal': ["pyobjc-framework-Metal", "pyobjc-framework-Cocoa", "pyobjc-framework-libdispatch"],

View File

@@ -11,7 +11,7 @@ from tinygrad.helpers import dtypes, prod
from tinygrad.runtime.lib import RawBuffer
class FakeProgram:
def __init__(self, name:str, prg:str): pass
def __init__(self, name:str, prg:str, binary:bool): pass
def __call__(self, global_size, local_size, *bufs, wait=False): pass
class RawFakeBuffer(RawBuffer):

View File

@@ -68,29 +68,5 @@ class TestLazyBuffer(unittest.TestCase):
assert GlobalCounters.cache[2][0].name.startswith("E_")
GlobalCounters.cache = None
class TestVariableBuffer(unittest.TestCase):
def test_get_variable_buffers_no_variable(self):
t = Tensor.rand(2, 3)
assert t.lazydata.get_variable_buffers() == {}
def test_get_variable_buffers_one_variable(self):
v = Variable("v", 1, 10)
t = Tensor.rand(2, 3).reshape(v, 3)
buffers = t.lazydata.get_variable_buffers()
assert len(buffers) == 1 and buffers[v].realize().realized.toCPU() == 2
v = Variable("v", 1, 10)
t = Tensor.rand(2, 3).reshape(2, v)
buffers = t.lazydata.get_variable_buffers()
assert len(buffers) == 1 and buffers[v].realize().realized.toCPU() == 3
def test_get_variable_buffers_cat(self):
v1 = Variable("v1", 1, 10)
v2 = Variable("v2", 1, 10)
t1 = Tensor.rand(2, 3).reshape(v1, 3)
t2 = Tensor.rand(6, 3).reshape(v2, 3)
t = t1.cat(t2)
buffers = t.lazydata.get_variable_buffers()
assert len(buffers) == 2 and buffers[v1].realize().realized.toCPU() == 2 and buffers[v2].realize().realized.toCPU() == 6
if __name__ == "__main__":
unittest.main()

View File

@@ -40,65 +40,91 @@ class TestSymbolic(unittest.TestCase):
class TestSymbolicReshape(unittest.TestCase):
def test_reshape_into_symbols_simple(self):
for i in range(1, 5):
vi = Variable("i", 1, 10)
assert Tensor.rand(i, 4).reshape(vi, 4).shape == (vi, 4)
assert vi.val == i
vi = Variable("i", 1, 10)
assert Tensor.rand(i, 6).reshape(vi, 2, 3).shape == (vi, 2, 3)
assert vi.val == i
vi = Variable("i", 1, 5)
for i in range(1, 6):
t = Tensor.rand(i, 4).reshape(vi, 4)
assert t.shape == (vi, 4)
assert t.lazydata.st.var_vals[vi] == i
t = Tensor.rand(i, 6).reshape(vi, 2, 3)
assert t.shape == (vi, 2, 3)
assert t.lazydata.st.var_vals[vi] == i
def test_reshape_symbols_reshape_ints(self):
for i in range(1, 5):
vi = Variable("i", 1, 10)
assert Tensor.rand(i, 4).reshape(vi, 4).reshape(i, 4).shape == (i, 4)
assert Tensor.rand(i, 4).reshape(vi, 4).reshape(i*4,).shape == (i*4,)
assert Tensor.rand(i, 6).reshape(vi, 6).reshape(i*2, 3).shape == (i*2, 3)
with self.assertRaises(AssertionError):
Tensor.rand(i, 6).reshape(vi, 6).reshape(1, 77).shape
vi = Variable("i", 1, 5)
for i in range(1, 6):
t = Tensor.rand(i, 4).reshape(vi, 4)
assert t.shape == (vi, 4)
assert t.lazydata.st.var_vals == {vi: i}
t = t.reshape(i, 4)
assert t.shape == (i, 4)
assert t.lazydata.st.var_vals == {}
def test_reshape_reuse_var_same_value_ok(self):
for i in range(1, 5):
vi = Variable("i", 1, 10)
vi = Variable("i", 1, 5)
for i in range(1, 6):
a = Tensor.rand(i, 4).reshape(vi, 4)
b = Tensor.rand(i, 3).reshape(vi, 3)
assert vi.val == i
assert a.lazydata.st.var_vals[vi] == i
assert b.lazydata.st.var_vals[vi] == i
def test_reshape_reuse_var_different_value_fail(self):
for i in range(1, 5):
vi = Variable("i", 1, 10)
def test_reshape_reuse_var_different_value_ok(self):
vi = Variable("i", 1, 10)
for i in range(1, 6):
a = Tensor.rand(i, 4).reshape(vi, 2)
with self.assertRaises(AssertionError):
b = Tensor.rand(i, 3).reshape(vi, 3)
b = Tensor.rand(i, 3).reshape(vi, 3)
# a and b have different values of vi
assert a.lazydata.st.var_vals[vi] == 2 * i
assert b.lazydata.st.var_vals[vi] == i
def test_reshape_into_symbols_bad_shape(self):
vi = Variable("i", 1, 10)
vj = Variable("j", 1, 10)
with self.assertRaises(AssertionError):
t = Tensor.rand(3, 4).reshape(vi, vj)
t = Tensor.rand(3, 4).reshape(vi, vj) # reshape into two variables
with self.assertRaises(AssertionError):
t = Tensor.rand(4, 4).reshape(vi, vi)
t = Tensor.rand(4, 4).reshape(vi, vi) # reshape into same variable in 2 dimensions
with self.assertRaises(AssertionError):
t = Tensor.rand(4, 6).reshape(vi, 6).reshape(vi, 4)
t = Tensor.rand(4, 6).reshape(vi, 6).reshape(vi, 4) # conflicting implied variable values
with self.assertRaises(AssertionError):
t = Tensor.rand(4, 6).reshape(vi, 6).reshape(1, 77) # reshape through a symbolic shape to a different total size
with self.assertRaises(AssertionError):
t = Tensor.rand(100, 4).reshape(Variable("too_small", 1, 10), 4)
with self.assertRaises(AssertionError):
t = Tensor.rand(3, 4).reshape(Variable("too_big", 100, 200), 4)
with self.assertRaises(AssertionError):
t = Tensor.rand(3, 4).reshape(3, (vi+1)) # reshape into non-Variable Node
def test_two_symbol_reshape(self):
vi = Variable("i", 1, 5)
vj = Variable("j", 1, 5)
for i in range(1, 6):
for j in range(1, 6):
t1 = Tensor.rand(i, 5).reshape(vi, 5)
t2 = Tensor.rand(5, j).reshape(5, vj)
t = t1@t2
assert t.shape == (vi, vj)
t = t.reshape(1, vi*vj)
assert t.shape == (1, vi*vj)
t = t.reshape(vj, vi)
assert t.shape == (vj, vi)
class TestSymbolicExpand(unittest.TestCase):
def test_expand_into_symbols(self):
vi = Variable("i", 1, 10)
vi = Variable("i", 1, 5)
vj = Variable("j", 1, 5)
a = Tensor([[1], [2], [3]]).expand((3, vi))
assert a.shape == (3, vi)
vj = Variable("j", 1, 10)
assert a.lazydata.st.var_vals == {}
a = a.reshape(3, vi, 1).expand((3, vi, vj))
assert a.shape == (3, vi, vj)
assert a.lazydata.st.var_vals == {}
def test_plus_expands_constant(self):
vi = Variable("i", 1, 10)
a = Tensor.rand(3, 4).reshape(3, vi)
a = a + 1
assert a.shape == (3, vi)
vi = Variable("i", 1, 5)
for i in range(1, 6):
a = Tensor.rand(3, i).reshape(3, vi)
a = a + 1
assert a.shape == (3, vi)
class TestSymbolicShapeExpr(unittest.TestCase):
def test_symbolic_expr_idxs(self):
@@ -114,5 +140,23 @@ class TestSymbolicShapeExpr(unittest.TestCase):
idx, valid = st.expr_idxs(idx)
assert idx.render() == "(((1+i)*1)+(lidx1*((i*4)+4))+gidx0)"
class TestShapeTrackerVarVals(unittest.TestCase):
def test_reshape_reshape_updates_var_vals(self):
vi = Variable("i", 1, 5)
vj = Variable("j", 1, 5)
t = Tensor.rand(3, 4).reshape(3, vi).reshape(4, vj)
assert t.lazydata.st.var_vals == {vi: 4, vj: 3}
def test_lazy_check_var_vals(self):
vi = Variable("i", 1, 5)
a = Tensor.rand(3, 4).reshape(3, vi)
b = Tensor.rand(5, 6).reshape(vi, 6)
assert a.lazydata.st.var_vals == {vi: 4}
assert b.lazydata.st.var_vals == {vi: 5}
c = a@b
# shapetracker works with symbolic shape and doesn't check / propagate the underlying variable values
assert c.shape == (3, 6)
assert c.lazydata.st.var_vals == {}
if __name__ == '__main__':
unittest.main()
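
Taken together, these tests pin down the contract: reshaping a concrete tensor into a Variable dimension records the implied value in the ShapeTracker's var_vals instead of mutating the Variable, and reshaping back to ints drops the binding. A condensed sketch of that rule (assuming a tinygrad checkout at this revision):

from tinygrad.shape.symbolic import Variable
from tinygrad.tensor import Tensor

vi = Variable("i", 1, 5)
t = Tensor.rand(3, 4).reshape(3, vi)
assert t.shape == (3, vi)                   # shape stays symbolic
assert t.lazydata.st.var_vals == {vi: 4}    # the implied value lives on the ShapeTracker
assert t.reshape(3, 4).lazydata.st.var_vals == {}  # concrete reshape drops the binding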

View File

@@ -1,14 +1,15 @@
import unittest, math
import numpy as np
from tinygrad.helpers import dtypes
from tinygrad.helpers import dtypes, getenv
from tinygrad.tensor import Device
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ASTRunner, Compiled
from tinygrad.codegen.linearizer import UOps, Token, ConstOp, MemOp
from tinygrad.shape.symbolic import Variable
def _uops_to_prg(uops):
src, global_size, local_size = Device[Device.DEFAULT].renderer("test", uops)
return ASTRunner("test", src, global_size, local_size).build(Device[Device.DEFAULT].runtime)
ret = Device[Device.DEFAULT].renderer("test", uops)
src, global_size, local_size, binary = ret if len(ret) == 4 else ret + (False,)
return ASTRunner("test", src, global_size, local_size, runtime_args={"binary": binary}).build(Device[Device.DEFAULT].runtime)
def _test_single_value(tc, tt, vals, op):
uops = [
@@ -36,36 +37,20 @@ def _test_single_value_const(tc, tt, vals, op):
prg([buf])
return buf.toCPU()[0]
@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled), "only test for compiled backends")
class TestUOps(unittest.TestCase):
def _equal(self, v1, v2):
if not (math.isnan(v1) and math.isnan(v2)): self.assertAlmostEqual(v1, v2, places=5)
def _test_uop_fxn(self, bop, fxn, dt=dtypes.float32):
for f in [_test_single_value, _test_single_value_const]:
for a in [-2.0, 2.0]:
for a in [-2.0, 0.0, 1.0, 2.0]:
self._equal(f(Token('c', dt), [Token('a', dt)], [a], bop), fxn(a))
def test_exp2(self): self._test_uop_fxn(UnaryOps.EXP2, lambda a: np.exp2(a))
def test_log2(self): self._test_uop_fxn(UnaryOps.LOG2, lambda a: math.log2(a) if a > 0 else float('nan'))
def test_sin(self): self._test_uop_fxn(UnaryOps.SIN, lambda a: math.sin(a))
def test_sqrt(self): self._test_uop_fxn(UnaryOps.SQRT, lambda a: math.sqrt(a) if a >= 0 else float('nan'))
#def test_recip(self): self._test_uop_fxn(UnaryOps.RECIP, lambda a: 1.0/a)
def _test_bop_fxn(self, bop, fxn, dt=dtypes.float32):
def _test_bop_fxn(self, bop, fxn, dt=dtypes.float32, no_b_zero=False):
for f in [_test_single_value, _test_single_value_const]:
for a in [-2.0, 2.0]:
for b in [-3.0, 3.0]:
for a in [-2.0, 0.0, 1.0, 2.0]:
for b in [-3.0, 1.0, 3.0] + ([] if no_b_zero else [0.0]):
self._equal(f(Token('c', dt), [Token('a', dt), Token('b', dt)], [a,b], bop), fxn(a,b))
def test_add(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: a+b)
def test_sub(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: a-b)
def test_mul(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: a*b)
def test_div(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: a/b)
def test_max(self): self._test_bop_fxn(BinaryOps.MAX, lambda a,b: max(a,b))
def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b))
# MOD isn't tested
# doesn't work in LLVM
#def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: a+b, dtypes.int32)
def _test_top_fxn(self, bop, fxn, dt=dtypes.float32):
for f in [_test_single_value, _test_single_value_const]:
@@ -73,8 +58,37 @@ class TestUOps(unittest.TestCase):
for b in [-3.0, 3.0]:
for c in [-4.0, 4.0]:
self._equal(f(Token('d', dt), [Token('a', dt), Token('b', dt), Token('c', dt)], [a,b,c], bop), fxn(a,b,c))
@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled), "only test for compiled backends")
class TestFloatUOps(TestUOps):
def test_exp2(self): self._test_uop_fxn(UnaryOps.EXP2, lambda a: np.exp2(a))
def test_log2(self): self._test_uop_fxn(UnaryOps.LOG2, lambda a: math.log2(a) if a > 0 else float('-inf' if a==0 else 'nan'))
def test_sin(self): self._test_uop_fxn(UnaryOps.SIN, lambda a: math.sin(a))
def test_sqrt(self): self._test_uop_fxn(UnaryOps.SQRT, lambda a: math.sqrt(a) if a >= 0 else float('nan'))
# this is not supported on most backends
#def test_recip(self): self._test_uop_fxn(UnaryOps.RECIP, lambda a: 1.0/a if a != 0 else float('inf'))
def test_add(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: a+b)
def test_sub(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: a-b)
def test_mul(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: a*b)
def test_div(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: a/b if b != 0 else a*float('inf'))
def test_max(self): self._test_bop_fxn(BinaryOps.MAX, lambda a,b: max(a,b))
def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b))
# MOD isn't tested on floats
def test_mulacc(self): self._test_top_fxn(TernaryOps.MULACC, lambda a,b,c: (a*b)+c)
def test_where(self): self._test_top_fxn(TernaryOps.WHERE, lambda a,b,c: b if a!=0 else c)
# TODO: fix this on all the backends
@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled) or Device.DEFAULT == "LLVM" or getenv('ARM64', False), "only test for compiled backends, broken on some")
class TestNonFloatUOps(TestUOps):
def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: int(a)+int(b), dtypes.int32)
def test_mul_int32(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: int(a)*int(b), dtypes.int32)
def test_div_int32(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: int(a/b), dtypes.int32, no_b_zero=True)
def test_mod_int32(self): self._test_bop_fxn(BinaryOps.MOD, lambda a,b: abs(int(a))%abs(int(b))*(1,-1)[a<0], dtypes.int32, no_b_zero=True)
def test_cmplt_int32(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b), dtypes.int32)
@unittest.skipIf(Device.DEFAULT == "CLANG", "broken in CLANG")
def test_mul_bool(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: bool(a) and bool(b), dtypes.bool)
if __name__ == '__main__':
unittest.main(verbosity=2)
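
One subtlety worth spelling out: the int32 MOD reference in test_mod_int32 implements C-style truncated modulo, where the result takes the dividend's sign, which differs from Python's floored %. A quick standalone check of the reference lambda:

ref_mod = lambda a, b: abs(int(a)) % abs(int(b)) * (1, -1)[a < 0]  # from test_mod_int32 above
assert ref_mod(-2.0, 3.0) == -2   # C-style: sign follows the dividend
assert -2 % 3 == 1                # Python's floored mod disagrees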

View File

@@ -1,6 +1,6 @@
#!/usr/bin/env python
import unittest
from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, LtNode, sym_vars, sym_render
from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, LtNode, sym_render
class TestSymbolic(unittest.TestCase):
def helper_test_variable(self, v, n, m, s):
@@ -261,14 +261,6 @@ class TestSymbolicVars(unittest.TestCase):
assert (a % 3 + b // 5).vars() == [a, b]
assert (a + b + c - a).vars() == [b, c]
def test_sym_vars(self):
a = Variable("a", 0, 10)
b = Variable("b", 0, 10)
assert sym_vars(1) == []
assert sym_vars(a) == [a]
assert sym_vars(a+b) == [a, b]
assert sym_vars(a*3) == [a]
class TestSymbolicMinMax(unittest.TestCase):
def test_min_max_known(self):
a = Variable("a", 1, 8)

View File

@@ -0,0 +1,188 @@
from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict, cast
from tinygrad.codegen.linearizer import UOps, Token, ConstOp, MemOp, UOp
from tinygrad.ops import BinaryOps, UnaryOps
from tinygrad.helpers import DType, dtypes, DEBUG
from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
import functools
import math
from collections import defaultdict
_type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'u', dtypes.uint64: 'b', dtypes._float4: 'x'}
def type_to_letter(x): return _type_to_letter[x[0]].upper() if x[1] else _type_to_letter[x[0]]
class Register(NamedTuple):
nm:str
dtype:DType
scalar:bool
off:Optional[int] = None
def __repr__(self): return self.nm if self.off is None else f"{self.nm}:{self.off}"
def subregs(self):
if self.dtype == dtypes._float4:
return [Register(self.nm, dtypes.float, False, off=off) for off in range(4)]
return []
class AssemblyInstruction(NamedTuple):
op: UOps
out: Optional[Register]
vin: List[Union[Register, int, float]]
arg: Any = None
# warp size of 32, s registers are shared across the warp, v are 32-wide vectors
class AssemblyLanguage:
supports_load3: bool = False
sin_is_sin2pi: bool = False
no_div: bool = False
#TODO: these should be global vars
cnts:DefaultDict[Tuple[DType, bool], int] = defaultdict(int)
tor: Dict[Any, Register] = {}
ins: List[AssemblyInstruction] = []
def newreg(self, tok, dtype=dtypes.float32, scalar=False):
if isinstance(tok, Token): dtype = tok.dtype # this
self.tor[tok] = ret = Register(f"%{type_to_letter((dtype, scalar))}{self.cnts[(dtype, scalar)]}", dtype, scalar)
if dtype == dtypes._float4:
for off in range(4):
self.tor[Token(tok.name, tok.dtype, off)] = Register(ret.nm, dtypes.float, ret.scalar, off)
self.cnts[(dtype, scalar)] += 1
return ret
def render_numnode(self, b):
key = ("num", b)
if key not in self.tor: self.ins.append(AssemblyInstruction(UOps.LOAD, self.newreg(key, scalar=True, dtype=dtypes.int32), [], b))
return self.tor[key]
def render_alu(self, op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
key = (op, a, b)
if key not in self.tor:
#if not isinstance(b, Register): b = render_numnode(b)
self.ins.append(AssemblyInstruction(UOps.ALU, self.newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
return self.tor[key]
def render_cast(self, a:Register, new_dtype:DType) -> Register:
if a.dtype == new_dtype: return a
key = (a, new_dtype)
if key not in self.tor:
self.ins.append(AssemblyInstruction(UOps.CAST, self.newreg(key, dtype=new_dtype), [a]))
return self.tor[key]
render_ops: Any = { Variable: lambda self, ops, ctx: ctx.tor[self], NumNode: lambda self, ops, ctx: ctx.render_numnode(self.b),
MulNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MUL, self.a.render(ops, ctx), self.b),
DivNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.DIV, self.a.render(ops, ctx), self.b),
ModNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MOD, self.a.render(ops, ctx), self.b),
LtNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool),
SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.ADD, a, b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.MUL, a, b.render(ops,ctx), dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
def addr_w_offset(self, args):
assert isinstance(args, MemOp)
idx = args.idx*args.memory_dtype.itemsize
off = 0 # TODO: should this be None?
if isinstance(idx, SumNode):
nums = [n.b for n in idx.nodes if isinstance(n, NumNode)]
if len(nums) > 0 and nums[0] < 4096 and (idx-nums[0]).min >= 0: # TODO: different for each GPU?
idx -= nums[0]
off = cast(int, nums[0])
reg = idx.render(self.render_ops, self)
if self.supports_load3:
if reg.scalar:
new_reg = self.newreg((reg.nm, 'vec'), dtype=reg.dtype)
self.ins.append(AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP))
reg = new_reg
return self.tor[args.name], reg, off
reg = self.render_alu(BinaryOps.ADD, self.render_cast(reg, dtypes.uint64), self.tor[args.name], dtype=dtypes.uint64)
return reg, None, off
def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]):
#TODO: Do not use clear()
lang.ins.clear()
lang.tor.clear()
buf_to_dtype = {args[0]:args[1] for uop,_,_,args in uops if uop == UOps.DEFINE_GLOBAL}
buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
global_size, local_size = [], []
skipload_branch = 0
lang.ins += [AssemblyInstruction(UOps.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf) for buf in buf_to_dtype]
for uop,newvar,vin,args in uops:
if uop == UOps.DEFINE_LOCAL:
lang.ins.append(AssemblyInstruction(UOps.DEFINE_LOCAL, None, [], args))
lang.ins.append(AssemblyInstruction(UOps.ALU, lang.newreg(args[0], dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
elif uop == UOps.LOOP:
if args[1] == "global":
for i,var in enumerate(args[0]):
global_size.append(var.max+1)
lang.ins.append(AssemblyInstruction(UOps.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}"))
elif args[1] == "local":
for i,var in enumerate(args[0]):
local_size.append(var.max+1)
lang.ins.append(AssemblyInstruction(UOps.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}"))
else:
for var in args[0]:
if not isinstance(var, NumNode): # TODO: why is this coming through?
lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(var, dtype=dtypes.int32, scalar=True), [], 0)) #FIXME: what should valid be here?
lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], "$loop_"+var.expr))
elif uop == UOps.ENDLOOP:
if args[1] not in ["global", "local", "global+local"]:
for var in reversed(args[0]):
if not isinstance(var, NumNode): # TODO: why is this coming through?
lang.ins.append(AssemblyInstruction(UOps.ALU, lang.tor[var], [lang.tor[var], 1], BinaryOps.ADD))
pred = lang.render_alu(BinaryOps.CMPLT, lang.tor[var], var.max+1, dtypes.bool)
lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
elif args[1] == "global+local":
for i, var in enumerate(reversed(args[0])):
lang.ins.append(AssemblyInstruction(UOps.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"gid{i}")))
elif uop == UOps.CAST and newvar is not None:
# TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
out = lang.newreg(newvar)
for i,sr in enumerate(out.subregs()):
lang.ins.append(AssemblyInstruction(UOps.ALU, sr, [lang.tor[vin[i]]], UnaryOps.NOOP))
elif uop == UOps.ALU and newvar is not None:
out = lang.newreg(newvar) if newvar not in lang.tor else lang.tor[newvar]
# this is the only thing that can violate SSA
if args in [BinaryOps.CMPLT]:
pred_reg = lang.newreg((newvar, 'pred'), dtype=dtypes.bool)
lang.ins.append(AssemblyInstruction(UOps.ALU, pred_reg, [lang.tor[x] for x in vin], args))
lang.ins.append(AssemblyInstruction(UOps.CAST, out, [pred_reg], args))
elif args == BinaryOps.DIV and lang.no_div:
tmp = lang.newreg((newvar, "rcp"))
lang.ins.append(AssemblyInstruction(UOps.ALU, tmp, [lang.tor[vin[1]]], UnaryOps.RECIP))
lang.ins.append(AssemblyInstruction(UOps.ALU, out, [lang.tor[vin[0]], tmp], BinaryOps.MUL))
elif args == UnaryOps.SIN and lang.sin_is_sin2pi:
tmp = lang.newreg((newvar, "2pi"))
lang.ins.append(AssemblyInstruction(UOps.ALU, tmp, [lang.tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
lang.ins.append(AssemblyInstruction(UOps.ALU, out, [tmp], args))
else:
lang.ins.append(AssemblyInstruction(UOps.ALU, out, [lang.tor[x] for x in vin], args))
elif uop == UOps.LOAD and newvar is not None:
if isinstance(args, ConstOp):
if args.valid.min == 0 and args.valid.max == 1:
lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(newvar, dtype=newvar.dtype), [], args.invalid_value))
pred = args.valid.render(lang.render_ops, lang)
lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(newvar, dtype=newvar.dtype), [], args.value))
lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
skipload_branch += 1
else:
lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(newvar, dtype=newvar.dtype), [], args.value if args.valid.min == 1 else args.invalid_value))
else:
idx, treg, off = lang.addr_w_offset(args)
reg = lang.newreg(newvar, dtype=newvar.dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar))) # and not dtypes.is_float(newvar.dtype)))
if args.valid.min == 0:
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], 0))
if args.valid.max == 1:
pred = args.valid.render(lang.render_ops, lang)
lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
if args.valid.max == 1:
# NOTE: you can't compute the index in here, because it assumes it's all available later
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if buf_index[args.name] != -1 else 'shared', args.memory_dtype if buf_to_dtype[args.name] != dtypes.float else None)))
if args.valid.min == 0 and args.valid.max == 1:
lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
skipload_branch += 1
elif uop == UOps.STORE:
idx, treg, off = lang.addr_w_offset(args)
lang.ins.append(AssemblyInstruction(UOps.STORE, None, [idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if buf_index[args.name] != -1 else 'shared', args.memory_dtype if buf_to_dtype['data0'] != dtypes.float else None)))
# define registers
lang.ins = [AssemblyInstruction(UOps.DEFINE_REGISTER, None, [], (dtype, type_to_letter(dtype), c)) for dtype,c in lang.cnts.items()] + lang.ins
if DEBUG >= 4:
for tins in lang.ins: print(tins)
return global_size, local_size
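
The constant folding in addr_w_offset is easiest to see on a concrete index. Assuming a tinygrad checkout at this revision, a SumNode index like gidx0*4 + 100 passes both guards (the NumNode term is under 4096 and the remainder cannot go negative), so the constant is peeled off into the instruction's immediate offset:

from tinygrad.shape.symbolic import Variable, NumNode, SumNode

gidx0 = Variable("gidx0", 0, 31)
idx = gidx0*4 + 100                # SumNode(MulNode(gidx0, 4), NumNode(100))
assert isinstance(idx, SumNode)
nums = [n.b for n in idx.nodes if isinstance(n, NumNode)]
assert nums[0] == 100 and nums[0] < 4096 and (idx - nums[0]).min >= 0
# addr_w_offset therefore renders only gidx0*4 into a register and returns off=100,
# letting the backend emit reg+100 addressing instead of an extra ADD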

View File

@@ -0,0 +1,170 @@
import struct
from platform import system
from typing import Tuple, Dict, List, Optional
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.helpers import dtypes, CI
from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
def compute_offsets(total):
quotient, remainder = divmod(total, 4096)
return [4096]*quotient + [remainder] if remainder else [4096]*quotient
#NOTE: Darwin needs names to start with a "_"
def get_name(name): return ('_' if system() == 'Darwin' else '') + name
class ARM64Language(AssemblyLanguage): pass
def specialize_to_arm64(fn_nm, asm):
var_size = 16
prev_uop:Optional[UOps] = None
ins = []
x_regs = ['x' + str(i) for i in reversed(range(29)) if i not in (10,11,12,13,14,15,16,17,18,19,20)]
s_regs = ['s' + str(i) for i in reversed(range(3,30))]
type_to_reg = {dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'}
alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
BinaryOps.MOD: "", BinaryOps.CMPLT: "subs",
UnaryOps.SIN:'bl ' + get_name('sinf'), UnaryOps.LOG2: 'bl ' + get_name("log2f"), UnaryOps.EXP2: 'bl ' + get_name("exp2f"), UnaryOps.SQRT: 'bl ' + get_name("sqrtf"),
TernaryOps.MULACC: "madd", TernaryOps.WHERE: "fcsel"}
def mov_imm(value, reg):
# Manually move value into reg if value can't fit
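# (ARM64's plain mov takes only a 16-bit immediate, so larger ints are built
# from two 16-bit halves with movz/movk; floats are staged through the stack)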
if value.__class__ is not float and abs(value) > abs(65535):
ins.append(f"movz w15, #{value & 0xffff}")
ins.append(f"movk w15, #{(value >> 16) & 0xffff}, lsl #16")
ins.append(f"sxtw {reg}, w15")
elif reg[0] == 's':
ins.append(f"movz x15, 0x{float_to_hex(value)[4:]}")
ins.append(f"movk x15, 0x{float_to_hex(value)[:4]}, lsl #16")
ins.append("str x15, [sp, 16]")
ins.append(f"ldr {reg}, [sp, 16]")
else:
ins.append(f"mov {reg}, #{value}")
# Compute live ranges: first and last instruction index where each variable appears
live_range:Dict[str, List[int]] = {}
for i, (uop, out, vin, arg) in enumerate(asm):
for var in ([v for v in [out] + vin if v is not None and v.__class__ is not int]):
live_range[var.nm] = [i,i] if var.nm not in live_range else [live_range[var.nm][0], i]
mem_vars:Dict[str, int] = {}
rtor:Dict[str, str] = {}
def allocate_regs(mvars):
nonlocal var_size
for v in [v for v in mvars if v is not None and v.__class__ is not int and v.nm not in rtor]:
available_regs = s_regs if dtypes.is_float(v[1]) else x_regs
#NOTE: Very simple spill, everything that doesn't fit in regs goes to mem
if len(available_regs) == 0:
# ARM needs the stack 16-byte aligned
var_size += 16
available_regs.append('s0' if dtypes.is_float(out[1]) else 'x11')
mem_vars[v.nm] = var_size
rtor[v.nm] = available_regs.pop()
temp_floats = ['s0', 's1', 's2']
temp_ints = ['x11', 'x12', 'x13']
for i, (uop, out, vin, arg) in enumerate(asm):
# Free registers whose variable's live interval has ended
for var, reg in list(rtor.items()):
available_regs = s_regs if reg[0] == 's' else x_regs
if var[1] not in 'B' and var not in mem_vars and i > live_range[var][1]:
available_regs.append(rtor.pop(var))
# Assign registers to the variables using live ranges.
allocate_regs([out] + vin)
# Assign temp regs to vin and load them before direct use
for i, v in enumerate([v for v in vin if v.__class__ is not int and v.nm in mem_vars]):
rtor[v.nm] = temp_floats[i] if dtypes.is_float(v[1]) else temp_ints[i]
# ARM64 addressing constraints https://devblogs.microsoft.com/oldnewthing/20220728-00/?p=106912
ins.append(f"mov x15, {mem_vars[v.nm]}")
ins.append(f"ldr {rtor[v.nm]}, [sp, x15]")
if uop == UOps.SPECIAL:
if arg.startswith('data'):
# buffers data8 and beyond are passed on the stack
if int(arg[4:]) >= 8:
ins.append(f"ldr x15, [x19, #{(int(arg[4:]) - 8) * 8}]")
ins.append(f"mov {rtor[out.nm]}, x15")
else:
ins.append(f"mov {rtor[out.nm]}, #0")
ins.append(f"loop_{arg}:")
elif uop == UOps.CAST:
if arg == BinaryOps.CMPLT:
mov_imm(0.0, 's0')
mov_imm(1.0, 's1')
ins.append(f"fcsel {rtor[out.nm]}, s1, s0, lt")
else:
ins.append(f"sxtw {rtor[out.nm]}, w{rtor[vin[0].nm][1:]}")
elif uop == UOps.ALU:
if len(vin)==2 and vin[1].__class__ is int: mov_imm(vin[1], 'x15')
if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
ins.append(f"ands {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
elif arg == TernaryOps.WHERE:
ins.append(f"fcmp {rtor[vin[0].nm]}, #0.0")
ins.append(f"{alu[arg]} {rtor[out.nm]}, {rtor[vin[1].nm]}, {rtor[vin[2].nm]}, ne")
elif arg in [UnaryOps.LOG2, UnaryOps.SIN, UnaryOps.EXP2, UnaryOps.SQRT]:
#NOTE: Not a real instruction; used to emulate an external call in unicorn
if CI: ins.append(f"{alu[arg]} {rtor[out.nm]} {rtor[vin[0].nm]}")
else:
save_regs = [k for k in rtor.keys() if k != out.nm and k not in mem_vars]
ins.append(f"sub sp, sp, #{(len(save_regs))*16}")
# Save the registers before they are cleared by func call
for i,k in enumerate(save_regs,1):
ins.append(f"str {rtor[k]}, [sp, #{16*i}]")
ins.append("stp x29, x30, [sp, #0]!")
ins.append("mov x29, sp")
ins.append(f"fmov s0, {rtor[vin[0].nm]}")
ins.append(alu[arg])
ins.append(f"fmov {rtor[out.nm]}, s0")
ins.append("mov sp, x29")
ins.append("ldp x29, x30, [sp], #0")
for i,k in enumerate(save_regs,1):
ins.append(f"ldr {rtor[k]}, [sp, #{16*i}]")
ins.append(f"add sp, sp, #{len(save_regs)*16}")
elif arg == BinaryOps.CMPLT:
ins.append(f"{alu[arg]} {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}" if not dtypes.is_float(vin[0][1]) else f"fcmp {rtor[vin[0].nm]}, {rtor[vin[1].nm]}")
elif arg == BinaryOps.MOD:
ins.append(f"udiv x14, {rtor[vin[0].nm]}, x15")
ins.append(f"msub {rtor[out.nm]}, x14, x15, {rtor[vin[0].nm]}")
else:
ins.append(f"{'f' if dtypes.is_float(vin[0][1]) else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {', '.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
elif uop == UOps.LOAD:
if arg.__class__ in (int, float):
mov_imm(arg, rtor[out.nm])
else:
#NOTE: if casting is needed, load the var into the s/h0 or x/w12 temp regs
reg_in = type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[out.nm]
mov_imm(arg[0], "x15")
ins.append(f"add x15, {rtor[vin[0].nm]}, x15")
ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]")
if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] == dtypes.half else 'scvtf'} {rtor[out.nm]}, {reg_in}")
elif uop == UOps.STORE:
shifts = {dtypes.int64: "#3", dtypes.half: "#1", dtypes.int8:"#2", dtypes.uint8: "#2", dtypes.bool: "#2"}
#NOTE: if casting is needed, load the var into the s/h0 or x/w12 temp regs
reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm])
if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] != dtypes.half else '' } {reg_out}, {rtor[vin[1].nm]}")
ins.append(f"mov x15, #{arg[0]}")
ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl {shifts[arg[2]] if arg[2] is not None and arg[2] in shifts else '#0'}]")
elif uop == UOps.COND_BRANCH:
#TODO: this is a hack; it shouldn't always be a cmp before a cond branch
if prev_uop == UOps.LOAD:
ins.append(f"cmp {rtor[vin[0].nm]}, #0")
ins.append(f"b.{'lt' if arg[1] else 'ge'} {arg[0][1:]}")
elif uop == UOps.LABEL:
ins.append(f"{arg[1:]}:")
elif uop == UOps.ENDLOOP:
mov_imm(arg[0], "x15")
ins.append(f"add {rtor[vin[0].nm]}, {rtor[vin[0].nm]}, #1")
ins.append(f"cmp {rtor[vin[0].nm]}, x15")
ins.append(f"b.lt loop_{arg[1]}")
prev_uop = uop
# store regs into memory if needed
if out is not None and out.nm in mem_vars:
ins.append(f"mov x15, {mem_vars[out.nm]}")
ins.append(f"str {rtor[out.nm]}, [sp, x15]")
return "\n".join([f"//varsize {var_size}",".arch armv8-a",".text", f".global {get_name(fn_nm)}",".p2align 2", f"{get_name(fn_nm)}:", "mov x19, sp"] + [f"sub sp, sp, #{offset}" for offset in compute_offsets(var_size)]+ ins + [f"add sp, sp, #{offset}" for offset in compute_offsets(var_size)] +["ret", "\n"])
def uops_to_arm64_asm(fn_nm:str, uops:List[UOp]) -> Tuple[str, List[int], List[int], bool]:
lang = ARM64Language()
global_size, local_size = uops_to_asmstyle(lang, fn_nm, uops)
return specialize_to_arm64(fn_nm, lang.ins), global_size[::-1], local_size[::-1], True
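
The two helpers at the top of this file are easy to sanity-check in isolation (copied here so the snippet runs standalone): float_to_hex returns the big-endian hex of an IEEE float32, and compute_offsets chops the stack size into chunks the assembler can encode in a single sub sp, sp, #imm:

import struct

def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f", x)[::-1])
def compute_offsets(total):
    quotient, remainder = divmod(total, 4096)
    return [4096]*quotient + [remainder] if remainder else [4096]*quotient

assert float_to_hex(1.0) == "3F800000"               # IEEE-754 bits of 1.0f
assert compute_offsets(10000) == [4096, 4096, 1808]  # three sp adjustments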

View File

@@ -9,9 +9,9 @@ from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten,
from tinygrad.runtime.ops_cpu import RawNumpyBuffer
from tinygrad.runtime.ops_disk import RawDiskBuffer
from tinygrad.shape.shapetracker import MovementOps, ShapeTracker, View, get_contraction
from tinygrad.shape.symbolic import Variable, sym_vars
from tinygrad.shape.symbolic import Node
from tinygrad.ops import Compiled, Interpreted, UnaryOps, BinaryOps, TernaryOps, ReduceOps, LoadOps, OpType, LazyOp
from tinygrad.runtime.lib import RawBufferMapped, RawConst, RawBuffer
from tinygrad.runtime.lib import RawBufferMapped, RawConst, RawBuffer, RawBufferTransfer
# lazy can recurse a lot
sys.setrecursionlimit(10000)
@@ -19,6 +19,7 @@ sys.setrecursionlimit(10000)
OPT = getenv("OPT", 2)
LAZY = getenv("LAZY", 1)
LAZYCACHE = getenv("LAZYCACHE", 1)
P2P = getenv("P2P", 0)
# TODO: movement ops that only change shape are really nops. treat them as such
REMOVE_MOVEMENT_NOPS, MERGE_ELEMENTWISE_INTO_REDUCE, SHUFFLE_MOVEMENT_OPS, MERGE_ELEMENTWISE_OPS = OPT>=1, OPT>=1, OPT>=1, OPT>=1
@@ -214,7 +215,7 @@ class LazyBuffer:
if not self.realized and self.op.op == LoadOps.CONTIGUOUS: return self # two CONTIGUOUS in a row is one
return create_lazybuffer(self.device, ShapeTracker(self.shape), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None), self.dtype)
def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[int, ...], Tuple[Tuple[int, int], ...]]) -> LazyBuffer:
def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[Union[Node,int], ...], Tuple[Tuple[int, int], ...]]) -> LazyBuffer:
if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and not self.realized and (op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and self.op.op in UnaryOps)) and len(self.children) == 0:
return self.op.replace_with_movement_ops([(op, arg)])
ret = create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg), self.dtype)
@@ -231,13 +232,13 @@ class LazyBuffer:
return create_lazybuffer(self.device, ShapeTracker(new_shape), ReduceOps, LazyOp(op, srcs, new_shape), self.dtype)
def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
if prod(self.shape) // prod(new_shape) < 32768: return self._reduce_op(op, new_shape) # The amount of work should be big enough to take the benefit of "2 kernels" approach.
if any(not isinstance(s, int) for s in self.shape) or prod(self.shape) // prod(new_shape) < 32768: return self._reduce_op(op, new_shape) # The amount of work should be big enough to take the benefit of "2 kernels" approach.
heuristic, divisor, dim_to_split = max(((divisor := math.gcd(256, old))/(stride or math.inf), divisor, i) for i, (old, new, stride) in enumerate(zip(self.shape, new_shape, self.st.real_strides())) if old != new) # type: ignore
if divisor < 16 or heuristic < 0.125: return self._reduce_op(op, new_shape) # Choose largest divisor (>=16) to split on, penalize large strides.
def splitted_shape(dim_aft_div): return self.shape[:dim_to_split] + (self.shape[dim_to_split]//divisor,) + dim_aft_div + self.shape[dim_to_split+1:]
return self.reshape(splitted_shape((divisor,)))._reduce_op(op, splitted_shape((1,))).reshape(splitted_shape(()))._reduce_op(op, new_shape)
def reshape(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
def reshape(self:LazyBuffer, arg:Tuple[Union[Node, int], ...]) -> LazyBuffer:
if self.shape == arg: return self
if not self.realized and self.op.op == MovementOps.RESHAPE:
self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why??
@@ -249,7 +250,7 @@ class LazyBuffer:
if not self.realized and self.op.op == MovementOps.PAD: return self.op.src[0].pad(tuple([(b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)]))
return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).pad(arg), MovementOps.PAD, arg)
def expand(self: LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
def expand(self: LazyBuffer, arg:Tuple[Union[Node,int], ...]) -> LazyBuffer:
if self.shape == arg: return self
if not self.realized and self.op.op == MovementOps.EXPAND:
return self.op.src[0].expand(arg)
@@ -293,7 +294,6 @@ class LazyBuffer:
def buffers(self) -> Tuple[LazyBuffer, ...]: return (self,)
def map_buffers(self, real_srcs: Dict[Any, Any]): return real_srcs.get(self, self)
def get_lazyops(self) -> List[Any]: return []
def get_variable_buffers(self) -> Dict[Variable, LazyBuffer]: return {v:LazyBuffer.loadop(LoadOps.FROM, (1,), dtypes.int32, self.device, src=LazyBuffer.fromCPU(np.array([v.val], dtype=np.int32))) for s in self.shape for v in sym_vars(s)}
def replace_with_movement_ops(self: LazyBuffer, ops:List[Tuple[MovementOps, Any]]) -> LazyBuffer:
y = self
for op, arg in ops: y = MOVEMENT_OPS_DISPATCHER[op](y, arg)
@@ -381,6 +381,8 @@ def _realize_from(buffer: LazyBuffer) -> None:
if isinstance(rawbuf.realized, RawDiskBuffer) and issubclass(Device[buffer.device].buffer, RawBufferMapped):
buffer.realized = Device[buffer.device].buffer(prod(buffer.shape), buffer.dtype, **buffer._device_extra_args())
rawbuf.realized.readinto(cast(RawBufferMapped, buffer.realized)._buffer())
elif isinstance(rawbuf.realized, RawBufferTransfer) and issubclass(Device[buffer.device].buffer, RawBufferTransfer) and P2P >= 1:
buffer.realized = cast(RawBufferTransfer, Device[buffer.device].buffer).transfer(rawbuf.realized, buffer.shape, buffer.dtype, **buffer._device_extra_args())
else:
buffer.realized = Device[buffer.device].buffer.fromCPU(rawbuf.toCPU(), **buffer._device_extra_args())
@@ -413,4 +415,4 @@ MOVEMENT_OPS_DISPATCHER: Dict[MovementOps, Callable] = {
MovementOps.PERMUTE: LazyBuffer.permute,
MovementOps.PAD: LazyBuffer.pad,
MovementOps.STRIDE: LazyBuffer.stride,
}
}
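
The new guard in reduce_op also makes the split heuristic easier to follow with numbers. For a contiguous (1, 100000) -> (1, 1) reduce, the work ratio clears 32768 and the heuristic picks dim 1 (a hedged walkthrough of the code above, not new behavior):

import math

old, stride = 100000, 1                 # the only dim where old != new; unit stride
divisor = math.gcd(256, old)            # 32
heuristic = divisor / (stride or math.inf)
assert divisor >= 16 and heuristic >= 0.125
# so (1, 100000) is reshaped to (1, 3125, 32), partially reduced to (1, 3125, 1),
# reshaped to (1, 3125), then reduced again to (1, 1): the "2 kernels" approach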

View File

@@ -124,7 +124,7 @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter: return InterpretedFlopCounter.ex
class ASTRunner:
def __init__(self, name, prg, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None, runtime_args:Optional[dict]=None):
if DEBUG >= 4 and (runtime_args is None or 'binary' not in runtime_args): print(prg)
if DEBUG >= 4 and (runtime_args is None or 'binary' not in runtime_args or not runtime_args['binary']): print(prg)
self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {}
def build(self, runtime):
@@ -156,10 +156,12 @@ class Compiled:
def to_program(self, k):
k.linearize()
src, global_size, local_size = self.renderer(k.function_name, k.uops)
ret = self.renderer(k.function_name, k.uops)
src, global_size, local_size, binary = ret if len(ret) == 4 else ret + (False,)
#TODO: I need to find a better way to select ARM64
return ASTRunner(k.function_name, src, global_size, local_size,
op_estimate=k.info.flops, mem_estimate=k.mem_estimate,
display_name=k.display_name).build(self.runtime)
display_name=k.display_name, runtime_args={"binary": binary}).build(self.runtime)
def exec_ast(self, ast:LazyOp, output, **kwargs):
# all movementops do nothing in a Compiled buffer!

View File

@@ -35,7 +35,7 @@ class CStyleLanguage(NamedTuple):
UnaryOps.SQRT: lambda x: f"sqrt({x})",
BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})",
BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})",
BinaryOps.MAX: lambda a,b: f"max({a},{b})",
BinaryOps.MAX: lambda a,b: f"max({a},{b})", BinaryOps.MOD: lambda a,b: f"({a}%{b})",
BinaryOps.CMPLT: lambda a,b: f"({a}<{b})", TernaryOps.MULACC: lambda a,b,c: f"(({a}*{b})+{c})",
TernaryOps.WHERE: lambda a,b,c: f"({a}!=0?{b}:{c})"
}

View File

@@ -32,6 +32,35 @@ code_for_op: Final[Dict[Op, Callable]] = {
TernaryOps.WHERE: lambda builder,x,y,z: builder.select(builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=('fast',)), y, z, flags=('fast',)),
}
dtype_to_llvm_dtype = {dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32)}
def cast(bb, val, input_type, output_type):
if input_type == output_type: return val
if output_type == dtypes.float32:
if dtypes.is_int(input_type) or input_type == dtypes.bool:
val = bb[-1].uitofp(val, ir.FloatType()) if dtypes.is_unsigned(input_type) or input_type == dtypes.bool else bb[-1].sitofp(val, ir.FloatType())
elif input_type == dtypes.bfloat16:
val = bb[-1].sext(val, ir.IntType(32))
val = bb[-1].shl(val, ir.Constant(ir.IntType(32), 16))
val = bb[-1].bitcast(val, ir.FloatType())
else:
val = bb[-1].fpext(val, ir.FloatType())
return val
if input_type == dtypes.float32:
if dtypes.is_int(output_type) or output_type == dtypes.bool:
val = bb[-1].fptoui(val, dtype_to_llvm_dtype[output_type]) if dtypes.is_unsigned(output_type) or output_type == dtypes.bool else bb[-1].fptosi(val, dtype_to_llvm_dtype[output_type])
elif output_type == dtypes.bfloat16:
val = bb[-1].bitcast(val, ir.IntType(32))
val = bb[-1].lshr(val, ir.Constant(ir.IntType(32), 16))
val = bb[-1].trunc(val, ir.IntType(16))
else:
val = bb[-1].fptrunc(val, dtype_to_llvm_dtype[output_type])
return val
raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")
def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[List[int]], Optional[List[int]]]:
# all llvm stuff goes into a module
module = ir.Module(name=__file__)
@@ -41,7 +70,6 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
# create llvm function
dtype_to_llvm_dtype = {dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32)}
func_dtypes = [dtype_to_llvm_dtype[dtype] for dtype in buf_to_dtype.values()]
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() for x in func_dtypes]), name=function_name)
@@ -84,9 +112,9 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
bb[-2].cbranch(bb[-2].icmp_unsigned("==", idx_p1, int_const(var.max+1)), bb[-1]._block, block._block)
if uop == UOps.LOAD:
assert newvar is not None and isinstance(args, (MemOp, ConstOp))
assert newvar.dtype == dtypes.float, "newvar must be float"
valid = args.valid.render(render_llvm, bb[-1])
if isinstance(args, ConstOp):
assert newvar.dtype == dtypes.float, "newvar must be float"
if args.valid.min == 0 and args.valid.max == 1:
val = bb[-1].select(valid, ir.Constant(ir.FloatType(), args.value), ir.Constant(ir.FloatType(), args.invalid_value))
else:
@@ -100,30 +128,12 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
val = bb[-1].select(valid, bb[-1].load(bb[-1].gep(func.args[buf_index[args.name]], [aug_idx], inbounds=True)), ir.Constant(dtype_to_llvm_dtype[args.memory_dtype], args.invalid_value))
else:
val = bb[-1].load(bb[-1].gep(func.args[buf_index[args.name]], [idx], inbounds=True))
if args.memory_dtype != newvar.dtype:
if dtypes.is_int(args.memory_dtype) or args.memory_dtype == dtypes.bool:
val = bb[-1].uitofp(val, ir.FloatType()) if dtypes.is_unsigned(args.memory_dtype) or args.memory_dtype == dtypes.bool else bb[-1].sitofp(val, ir.FloatType())
elif args.memory_dtype == dtypes.bfloat16:
val = bb[-1].sext(val, ir.IntType(32))
val = bb[-1].shl(val, ir.Constant(ir.IntType(32), 16))
val = bb[-1].bitcast(val, ir.FloatType())
else:
val = bb[-1].fpext(val, ir.FloatType())
val = cast(bb, val, args.memory_dtype, newvar.dtype)
lvars[newvar] = val
if uop == UOps.STORE:
assert args.valid.min == 1 and isinstance(args, MemOp), "store must be valid and to memory"
idx = args.idx.render(render_llvm, bb[-1])
element = lvars[vin[0]]
if args.memory_dtype != vin[0].dtype:
if dtypes.is_int(args.memory_dtype) or args.memory_dtype == dtypes.bool:
element = bb[-1].fptoui(element, dtype_to_llvm_dtype[args.memory_dtype]) if dtypes.is_unsigned(args.memory_dtype) or args.memory_dtype == dtypes.bool else bb[-1].fptosi(element, dtype_to_llvm_dtype[args.memory_dtype])
elif args.memory_dtype == dtypes.bfloat16:
element = bb[-1].bitcast(element, ir.IntType(32))
element = bb[-1].lshr(element, ir.Constant(ir.IntType(32), 16))
element = bb[-1].trunc(element, ir.IntType(16))
else:
element = bb[-1].fptrunc(element, dtype_to_llvm_dtype[args.memory_dtype])
element = cast(bb, lvars[vin[0]], vin[0].dtype, args.memory_dtype)
bb[-1].store(element, bb[-1].gep(func.args[buf_index[args.name]], [idx], inbounds=True))
if uop == UOps.ALU:
lvars[newvar] = code_for_op[args](bb[-1], *[lvars[x] for x in vin])
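
The bfloat16 branches in cast rely on bfloat16 being exactly the top 16 bits of an IEEE float32: widening is a shift plus bitcast, narrowing is a truncating shift. The same trick in plain Python:

import struct

def bf16_to_f32(bits):   # mirrors the sext/shl/bitcast path above
    return struct.unpack('f', struct.pack('I', (bits & 0xffff) << 16))[0]
def f32_to_bf16(x):      # mirrors the bitcast/lshr/trunc path above (truncates, no rounding)
    return struct.unpack('I', struct.pack('f', x))[0] >> 16

assert f32_to_bf16(1.0) == 0x3F80
assert bf16_to_f32(0x3F80) == 1.0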

View File

@@ -50,6 +50,15 @@ class RawBufferCopyInOut(RawBufferCopyIn):
self._copyout(x)
return x
class RawBufferTransfer(RawBuffer):
def _transfer(self, x) -> None: raise NotImplementedError("must be implemented")
@classmethod
def transfer(cls, x, shape, dtype, **kwargs):
ret = cls(prod(shape), dtype, **kwargs)
ret._transfer(x)
return ret
class RawConst(RawBuffer): # pylint: disable=abstract-method
def __repr__(self): return f"const<{self._buf}, {self.dtype}>"
@property
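
RawBufferTransfer is the hook the new P2P path in lazy.py uses: a backend opts in by implementing _transfer, and the transfer classmethod allocates the destination buffer and copies device-to-device. A hypothetical minimal subclass (FooBuffer is illustrative, not a real backend):

from tinygrad.runtime.lib import RawBufferTransfer  # assuming this revision

class FooBuffer(RawBufferTransfer):
    def _transfer(self, x) -> None:
        # a real backend would issue a device-to-device copy from x's buffer
        # into self._buf here, e.g. through the driver's peer-copy API
        raise NotImplementedError

# used by _realize_from when P2P >= 1 and both buffer classes support it:
#   dst = FooBuffer.transfer(src_rawbuf, buffer.shape, buffer.dtype)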

View File

@@ -1,8 +1,15 @@
import os, time, ctypes, hashlib, subprocess, platform, tempfile, functools
from functools import partial, reduce
from tinygrad.ops import Compiled
from tinygrad.helpers import fromimport, getenv, DEBUG, CI
from tinygrad.runtime.lib import RawMallocBuffer
from tinygrad.codegen.linearizer import LinearizerOptions
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
import struct
import numpy as np
ARM64 = getenv('ARM64', False)
if CI and ARM64: from unicorn import Uc, UC_ARCH_ARM64, UC_MODE_ARM, UC_HOOK_CODE, arm64_const # type: ignore
args = {
'Windows': {'cflags':'', 'ext':'dll', 'exp':'__declspec(dllexport)'},
@@ -11,22 +18,64 @@ args = {
}[platform.system()]
CLANG_PROGRAM_HEADER = '#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n#define int64 long\n#define half __fp16\n#define uchar unsigned char\n#define bool uchar\n'
ADDRESS = 0x10000
# Unicorn doesn't support external calls: bl instructions to libm are assembled as nops and emulated with a code hook (see ClangProgram below)
def align(addr): return (addr+4095) & ~(4095)
mock_lm = {"sinf": np.sin, "sqrtf": np.sqrt, "exp2f": np.exp2, "log2f": np.log2}
def emulate_ext_calls(fn, uc, address, size, user_data):
s_in = struct.unpack('f', struct.pack('I', uc.reg_read(getattr(arm64_const, f'UC_ARM64_REG_S{fn[2][1:]}'))))[0]
uc.reg_write(getattr(arm64_const, f'UC_ARM64_REG_S{fn[1][1:]}'), struct.unpack('I', struct.pack('f', mock_lm[fn[0]](s_in)))[0]) # type: ignore
class ClangProgram:
def __init__(self, name:str, prg:str):
prg = CLANG_PROGRAM_HEADER + prg
def __init__(self, name:str, prg:str, binary:bool=False):
# TODO: is there a way to not write this to disk?
# A: it seems there isn't https://stackoverflow.com/questions/28053328/ctypes-cdll-load-library-from-memory-rather-than-file
# because ctypes.CDLL() calls dlopen (POSIX) or LoadLibrary (Windows) which require a file
fn = f"{tempfile.gettempdir()}/clang_{hashlib.md5(prg.encode('utf-8')).hexdigest()}.{args['ext']}"
if binary and DEBUG >= 5: print(prg)
if not os.path.exists(fn):
_, tmp = tempfile.mkstemp()
subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+tmp).split(), input=prg.encode('utf-8'))
os.rename(tmp, fn)
tmp = f"{fn}.{os.getpid()}.tmp"
if not binary:
prg = CLANG_PROGRAM_HEADER + prg
subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+tmp).split(), input=prg.encode('utf-8'))
os.rename(tmp, fn)
else:
if CI and ARM64:
prg = prg.split('\n') # type: ignore
self.varsize = align(int(prg[0].split(" ")[1]))
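# record the address of each 'bl' call site (i*4 past ADDRESS) so unicorn hooks can emulate it;
# 'loop' labels are filtered first so the index keeps tracking instruction addresses, and the
# 'bl' instructions themselves are then patched to 'nop' before assembling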
self.ext_calls = {(i*4+ADDRESS):ins.split(" ")[1:] for i, ins in enumerate(filter(lambda ins: ins[:4] != 'loop', prg[6:-3])) if ins[:2] == 'bl'}
prg = "\n".join(['nop' if ins[:2] == 'bl' else ins for ins in prg[6:-3]] + ['\n'])
subprocess.check_output(args=('aarch64-linux-gnu-as -o '+tmp).split(), input=prg.encode('utf-8'))
subprocess.check_output(args=('aarch64-linux-gnu-objcopy -O binary --only-section=.text '+tmp+' '+fn+'.bin').split())
self.prg = open(fn + '.bin', 'rb').read()
return
subprocess.check_output(args=('as -o' + tmp).split(), input=prg.encode('utf-8'))
subprocess.check_output(args=('clang -lm -shared '+tmp+' -o'+fn).split())
self.lib = ctypes.CDLL(fn)
self.fxn = self.lib[name]
def __call__(self, global_size, local_size, *args, wait=False):
if wait: st = time.monotonic()
self.fxn(*[x._buf for x in args])
if CI and ARM64:
mu = Uc(UC_ARCH_ARM64, UC_MODE_ARM)
total_mem = align(reduce(lambda total, arg: total + arg.size * arg.dtype.itemsize, args, len(self.prg)+self.varsize))
mu.mem_map(ADDRESS, total_mem)
for k, fn in self.ext_calls.items(): mu.hook_add(UC_HOOK_CODE, partial(emulate_ext_calls, fn), begin=k, end=k)
mu.mem_write(ADDRESS, self.prg + b''.join(bytes(arg._buf) for arg in args))
addr = ADDRESS + len(self.prg)
for i, arg in enumerate(args):
if i<=7:
mu.reg_write(getattr(arm64_const, f'UC_ARM64_REG_X{i}'), addr)
else:
# NOTE: on ARM64, args beyond the first 8 are placed on the stack; this also accounts for the stack red zone.
mu.mem_write(ADDRESS + total_mem - (len(args[8:])+2)*8 + 8*(i-8), addr.to_bytes(8, 'little'))
addr += arg.size * arg.dtype.itemsize
mu.reg_write(arm64_const.UC_ARM64_REG_SP, ADDRESS + total_mem - (len(args[8:])+2)*8)
mu.emu_start(ADDRESS, ADDRESS + len(self.prg))
args[0]._buf = mu.mem_read(mu.reg_read(arm64_const.UC_ARM64_REG_X0), args[0].size * args[0].dtype.itemsize)
else:
self.fxn(*[x._buf for x in args])
if wait: return time.monotonic()-st
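# a minimal sketch (hypothetical helper, not in the diff) of the placement rule used above,
# matching AAPCS64: the first 8 args go in X0..X7, the rest in 8-byte stack slots from SP up
def arg_location(i, total_mem, n_args):
  n_stack = max(n_args - 8, 0)
  sp = ADDRESS + total_mem - (n_stack + 2) * 8  # matches the SP written above
  return f"X{i}" if i <= 7 else f"[SP + {8 * (i - 8)}] (abs 0x{sp + 8 * (i - 8):x})"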
renderer = functools.partial(uops_to_cstyle, CStyleLanguage(kernel_prefix=args['exp'], buffer_suffix=" restrict"))
renderer = fromimport("tinygrad.codegen.assembly_arm64", "uops_to_arm64_asm") if ARM64 else functools.partial(uops_to_cstyle, CStyleLanguage(kernel_prefix=args['exp'], buffer_suffix=" restrict"))
ClangBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False), renderer, ClangProgram)

View File

@@ -5,7 +5,7 @@ import pyopencl as cl # type: ignore
from typing import Optional, List
from tinygrad.helpers import DEBUG, getenv, prod, ImageDType, OSX, fromimport
from tinygrad.ops import Compiled
from tinygrad.runtime.lib import RawBufferCopyInOut
from tinygrad.runtime.lib import RawBufferCopyInOut, RawBufferTransfer
from tinygrad.codegen.linearizer import LinearizerOptions
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
@@ -29,7 +29,7 @@ class _CL:
CL = _CL()
CL.post_init() if not getenv("DELAYED_RUNTIME_INIT", False) else None
class CLBuffer(RawBufferCopyInOut):
class CLBuffer(RawBufferCopyInOut, RawBufferTransfer):
def __init__(self, size, dtype, device='0'):
if isinstance(dtype, ImageDType):
fmt = cl.ImageFormat(cl.channel_order.RGBA, {2: cl.channel_type.HALF_FLOAT, 4: cl.channel_type.FLOAT}[dtype.itemsize])
@@ -49,6 +49,10 @@ class CLBuffer(RawBufferCopyInOut):
buf = cl.Buffer(CL.cl_ctxs[self._buf.device], cl.mem_flags.WRITE_ONLY | cl.mem_flags.USE_HOST_PTR, 0, hostbuf=x.data)
mapped, event = cl.enqueue_map_buffer(CL.cl_queue[self._buf.device], buf, cl.map_flags.WRITE, 0, self.size, dtype=self.dtype.np, is_blocking=False)
with mapped.base: cl.enqueue_copy(CL.cl_queue[self._buf.device], mapped, self._buf, is_blocking=True, wait_for=[event] + ([self.event] if hasattr(self, "event") else []))
def _transfer(self, x):
if "gfx" in CL.cl_ctxs[x._buf.device].devices[0].name:
cl.enqueue_copy_buffer_p2p_amd(CL.cl_platform, CL.cl_queue[x._buf.device], x._buf, self._buf, x.size * x.dtype.itemsize).wait()
else: raise NotImplementedError("p2p transfer between devices not implemented on non-amd")
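# a hedged usage sketch of the new path (assumes two AMD GPUs with device ids '0' and '1'):
# src = CLBuffer(1024, dtypes.float32, device='0')
# dst = CLBuffer.transfer(src, (1024,), dtypes.float32, device='1')  # direct P2P, no host round-trip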
class CLProgram:
def __init__(self, name:str, prg:str, binary=False, argdtypes=None, options=None):

View File

@@ -55,7 +55,8 @@ class LLVMProgram:
LLVM.engine.finalize_object()
self.fxn = LLVM.engine.get_function_address(name)
def __del__(self): LLVM.engine.remove_module(self.mod)
def __del__(self):
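# presumably guards against __init__ raising before self.mod was ever assigned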
if hasattr(self, 'mod'): LLVM.engine.remove_module(self.mod)
def __call__(self, unused_global_size, unused_local_size, *bufs, wait=False):
cfunc = CFUNCTYPE(ctypes.c_int, *[ctypes.c_void_p for _ in bufs])(self.fxn)

View File

@@ -37,7 +37,7 @@ def unwrap(x):
return ret
class MetalProgram:
def __init__(self, name:str, prg:str):
def __init__(self, name:str, prg:str, binary:bool=False):
if METAL_XCODE:
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8'))
# NOTE: if you run llvm-dis on "air" you can see the llvm bytecode

View File

@@ -12,7 +12,7 @@ import wgpu # type: ignore
device = get_default_device()
class WebGPUProgram:
def __init__(self, name: str, prg: str): self.name,self.prg = name,device.create_shader_module(code=prg)
def __init__(self, name: str, prg: str, binary=False): self.name,self.prg = name,device.create_shader_module(code=prg)
def __call__(self, global_size, local_size, *bufs, wait=False):
assert len(bufs) <= 8, "WEBGPU only supports 8 buffers"
binding_layouts = [{"binding": i, "visibility": wgpu.ShaderStage.COMPUTE, "buffer": {"type": wgpu.BufferBindingType.storage}} for i in range(len(bufs))]

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from enum import Enum, auto
import functools
from typing import Dict, Tuple, Union, List, Optional, Callable, cast, NamedTuple
from tinygrad.helpers import prod, DEBUG
from tinygrad.helpers import prod, DEBUG, partition
from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode, is_sym_int
# these ops live here
@@ -129,17 +129,18 @@ def get_unsafe_resize_offset(strides, arg):
return sum([s * x[0] for s, x in zip(strides,arg)])
class ShapeTracker:
__slots__ = "views"
def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], views:Optional[List[View]]=None):
__slots__ = "views", "var_vals"
def __init__(self, shape:Union[ShapeTracker, Tuple[Union[Node,int], ...]], views:Optional[List[View]]=None):
self.views: List[View] = views if views is not None else ([*cast(ShapeTracker, shape).views] if shape.__class__ is ShapeTracker else [View(shape)])
def __repr__(self): return f"ShapeTracker(shape={self.views[-1].shape}, views={self.views})"
self.var_vals: Dict[Variable, int] = shape.var_vals if isinstance(shape, ShapeTracker) else {}
def __repr__(self): return f"ShapeTracker(shape={self.views[-1].shape}, views={self.views}, var_vals={self.var_vals})"
def copy(self) -> ShapeTracker: return ShapeTracker(self.views[-1].shape, [*self.views])
@property
def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous
@property
def shape(self) -> Tuple[int, ...]: return self.views[-1].shape
def shape(self) -> Tuple[int, ...]: return self.views[-1].shape # NOTE: real type is Tuple[Union[Node, int], ...] but mypy complains about prod(shape)
@property
def key(self) -> Tuple[View, ...]: return tuple(self.views)
@@ -231,15 +232,16 @@ class ShapeTracker:
return self
def reshape(self, new_shape: Tuple[Union[Node,int], ...]):
# reshape into symbolic shape, update the variable value
if all(isinstance(s, int) for s in self.shape) and len(new_vars:=list(s for s in new_shape if isinstance(s, Variable))) > 0:
assert len(new_vars) == 1, "only one variable is supported in a shape"
new_var, new_val = new_vars[0], prod(self.shape) // prod(s for s in new_shape if isinstance(s, int))
if new_var.val is None:
new_ints, new_nodes = partition(new_shape, lambda s: isinstance(s, int))
if new_nodes and all(isinstance(s, int) for s in self.shape):
# reshape from all int shape into shape with a variable, update the variable value
assert len(new_nodes) == 1 and isinstance(new_nodes[0], Variable), "only support adding one Variable to the int shape"
new_var, new_val = new_nodes[0], prod(self.shape) // prod(new_ints)
if new_var not in self.var_vals:
assert new_var.min <= new_val <= new_var.max, f"variable value {new_val} out of range [{new_var.min}, {new_var.max}]"
new_var.val = new_val
else: assert new_var.val == new_val, f"value conflicts, was {new_var.val}, set to {new_val}"
self.var_vals[new_var] = new_val
else: assert self.var_vals[new_var] == new_val, f"value conflicts, was {self.var_vals[new_var]}, set to {new_val}"
elif not new_nodes: self.var_vals = {}
if self.views[-1].shape == new_shape: return self
assert all(is_sym_int(x) and x > 0 for x in new_shape), f"shape must be symbolic ints and can't contain 0 or negative numbers {new_shape}"
# only check size for int shapes; symbolic shapes aren't size-checked here as long as the reshape itself can be done
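# a minimal sketch of the new binding behavior (names from this diff):
# i = Variable("i", 1, 10)
# st = ShapeTracker((6,))
# st.reshape((i, 2))   # infers i = 6 // 2 = 3, recorded in st.var_vals rather than on the Variable
# assert st.var_vals[i] == 3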

View File

@@ -9,7 +9,6 @@ from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any
# symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod
def is_sym_int(x: Any) -> bool: return isinstance(x, (int, Node))
def sym_vars(x: Union[Node, int]) -> List[Variable]: return [] if isinstance(x, int) else x.vars()
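# e.g. is_sym_int(4) and is_sym_int(Variable("i", 1, 10)) are both True;
# sym_vars(4) == [] while sym_vars(v) returns [v] for a Variable v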
class Node:
b: Union[Node, int]
@@ -141,7 +140,6 @@ class Variable(Node):
def __init__(self, expr:Optional[str], nmin:int, nmax:int):
self.expr, self.min, self.max = expr, nmin, nmax
self.val: Optional[int] = None
def vars(self): return [self]
class NumNode(Node):