mirror of https://github.com/tinygrad/tinygrad.git

Merge branch 'tinygrad:master' into triton
25 .github/workflows/test.yml vendored
@@ -267,3 +267,28 @@ jobs:
    - name: Run pytest (cuda)
      if: matrix.backend=='cuda'
      run: python -m pytest -n=auto test/ -k 'not (half or test_efficientnet_safetensors) and not (test_conv2d and test_tensor.py)' -m 'not exclude_cuda' --ignore=test/external --ignore=test/models

  testunicorn:
    name: ARM64 unicorn Test
    runs-on: ubuntu-latest
    timeout-minutes: 20
    steps:
    - name: Checkout Code
      uses: actions/checkout@v3
    - name: Set up Python 3.8
      uses: actions/setup-python@v4
      with:
        python-version: 3.8
    - name: Cache pip
      uses: actions/cache@v3
      with:
        path: '~/.cache/pip'
        key: unicorn
    - name: Install cross-assembler
      run: |
        sudo apt-get update -y && \
        sudo apt-get install -y --no-install-recommends gcc-aarch64-linux-gnu
    - name: Install dependencies
      run: pip install -e '.[testing,arm]' --extra-index-url https://download.pytorch.org/whl/cpu
    - name: Test arm
      run: CI=1 ARM64=1 CLANG=1 python -m pytest -n=auto test/ -k 'not (test_nn.py and (test_conv_transpose2d or test_conv2d))' --ignore=test/models --ignore=test/test_speed_v_torch.py --ignore=test/test_net_speed.py --ignore=test/test_specific_conv.py --ignore=test/unit/test_disk_tensor.py
@@ -1,182 +0,0 @@
from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict
from tinygrad.codegen.linearizer import Linearizer, UOps, Token
from tinygrad.ops import ASTRunner, BinaryOps, UnaryOps
from tinygrad.helpers import DType, dtypes, DEBUG
from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
import functools
import math
from collections import defaultdict

_type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'u', dtypes.uint64: 'b', dtypes._float4: 'x'}
def type_to_letter(x): return _type_to_letter[x[0]].upper() if x[1] else _type_to_letter[x[0]]

class Register(NamedTuple):
  nm:str
  dtype:DType
  scalar:bool
  off:Optional[int] = None
  def __repr__(self): return self.nm if self.off is None else f"{self.nm}:{self.off}"
  def subregs(self):
    if self.dtype == dtypes._float4:
      return [Register(self.nm, dtypes.float, False, off=off) for off in range(4)]
    return []

class AssemblyInstruction(NamedTuple):
  op: UOps
  out: Optional[Register]
  vin: List[Union[Register, int, float]]
  arg: Any = None

# warp size of 32, s registers are shared across the warp, v are 32-wide vectors
class AssemblyCodegen(Linearizer):
  supports_load3: bool = False
  sin_is_sin2pi: bool = False
  no_div: bool = False

  def specialize(self, asm:List[AssemblyInstruction]) -> Tuple[str, str]:
    raise NotImplementedError("must be implemented")

  # s registers are the addresses and non local indexes
  def codegen(self):
    self.process()
    self.hand_coded_optimizations()
    self.limit_global_dims(3)  # all GPU asms have 3 (for now)
    self.linearize()

    cnts:DefaultDict[Tuple[DType, bool], int] = defaultdict(int)
    tor: Dict[Any, Register] = {}
    def newreg(tok, dtype=dtypes.float32, scalar=False):
      nonlocal cnts, tor
      if isinstance(tok, Token): dtype = tok.dtype  # this
      tor[tok] = ret = Register(f"%{type_to_letter((dtype, scalar))}{cnts[(dtype, scalar)]}", dtype, scalar)
      if dtype == dtypes._float4:
        for off in range(4):
          tor[Token(tok.name, tok.dtype, off)] = Register(ret.nm, dtypes.float, ret.scalar, off)
      cnts[(dtype, scalar)] += 1
      return ret

    def render_numnode(b):
      key = ("num", b)
      if key not in tor: ins.append(AssemblyInstruction(UOps.CONST, newreg(key, scalar=True, dtype=dtypes.int32), [], b))
      return tor[key]

    def render_alu(op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
      key = (op, a, b)
      if key not in tor:
        #if not isinstance(b, Register): b = render_numnode(b)
        ins.append(AssemblyInstruction(UOps.ALU, newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
      return tor[key]

    def render_cast(a:Register, new_dtype:DType) -> Register:
      if a.dtype == new_dtype: return a
      key = (a, new_dtype)
      if key not in tor:
        ins.append(AssemblyInstruction(UOps.CAST, newreg(key, dtype=new_dtype), [a]))
      return tor[key]

    render_ops = { Variable: lambda self, ops, ctx: tor[self], NumNode: lambda self, ops, ctx: render_numnode(self.b),
                   MulNode: lambda self, ops, ctx: render_alu(BinaryOps.MUL, self.a.render(ops, ctx), self.b),
                   DivNode: lambda self, ops, ctx: render_alu(BinaryOps.DIV, self.a.render(ops, ctx), self.b),
                   ModNode: lambda self, ops, ctx: render_alu(BinaryOps.MOD, self.a.render(ops, ctx), self.b),
                   LtNode: lambda self, ops, ctx: render_alu(BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool),
                   SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: render_alu(BinaryOps.ADD, a, b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
                   AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: render_alu(BinaryOps.MUL, a, b.render(ops,ctx), dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }

    def addr_w_offset(args):
      idx = args.idx*self.bufs[args.i].dtype.itemsize
      off = 0  # TODO: should this be None?
      if isinstance(idx, SumNode):
        nums = [n.b for n in idx.nodes if isinstance(n, NumNode)]
        if len(nums) > 0 and nums[0] < 4096 and (idx-nums[0]).min >= 0:  # TODO: different for each GPU?
          idx -= nums[0]
          off = nums[0]
      reg = idx.render(render_ops)
      if self.supports_load3:
        if reg.scalar:
          new_reg = newreg((reg.nm, 'vec'), dtype=reg.dtype)
          ins.append(AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP))
          reg = new_reg
        return tor[f"buf{args.i}"], reg, off
      reg = render_alu(BinaryOps.ADD, render_cast(reg, dtypes.uint64), tor[f"buf{args.i}"], dtype=dtypes.uint64)
      return reg, None, off

    ins = []
    ins += [AssemblyInstruction(UOps.SPECIAL, newreg(f"buf{i}", dtype=dtypes.uint64, scalar=True), [], f"buf{i}") for i in range(len(self.bufs))]
    global_size, local_size = [], []
    skipload_branch = 0
    for uop,newvar,vin,args in self.uops:
      if uop == UOps.CONST and newvar is not None:
        ins.append(AssemblyInstruction(UOps.CONST, newreg(newvar, dtype=newvar.dtype), [], args))
      elif uop == UOps.DEFINE_LOCAL:
        ins.append(AssemblyInstruction(UOps.DEFINE_LOCAL, None, [], args))
        ins.append(AssemblyInstruction(UOps.ALU, newreg("buf-1", dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
      elif uop == UOps.LOOP:
        if args[1] == "global":
          for i,var in enumerate(args[0]):
            global_size.append(var.max+1)
            ins.append(AssemblyInstruction(UOps.SPECIAL, newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}"))
        elif args[1] == "local":
          for i,var in enumerate(args[0]):
            local_size.append(var.max+1)
            ins.append(AssemblyInstruction(UOps.SPECIAL, newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}"))
        else:
          for var in args[0]:
            if not isinstance(var, NumNode):  # TODO: why is this coming through?
              ins.append(AssemblyInstruction(UOps.CONST, newreg(var, dtype=dtypes.int32, scalar=True), [], 0))
              ins.append(AssemblyInstruction(UOps.LABEL, None, [], "$loop_"+var.expr))
      elif uop == UOps.ENDLOOP:
        if args[1] not in ["global", "local", "global+local"]:
          for var in reversed(args[0]):
            if not isinstance(var, NumNode):  # TODO: why is this coming through?
              ins.append(AssemblyInstruction(UOps.ALU, tor[var], [tor[var], 1], BinaryOps.ADD))
              pred = render_alu(BinaryOps.CMPLT, tor[var], var.max+1, dtypes.bool)
              ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
      elif uop == UOps.CAST and newvar is not None:
        # TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
        out = newreg(newvar)
        for i,sr in enumerate(out.subregs()):
          ins.append(AssemblyInstruction(UOps.ALU, sr, [tor[vin[i]]], UnaryOps.NOOP))
      elif uop == UOps.ALU and newvar is not None:
        out = newreg(newvar) if newvar not in tor else tor[newvar]
        # this is the only thing that can violate SSA
        if args in [BinaryOps.CMPLT]:
          pred_reg = newreg((newvar, 'pred'), dtype=dtypes.bool)
          ins.append(AssemblyInstruction(UOps.ALU, pred_reg, [tor[x] for x in vin], args))
          ins.append(AssemblyInstruction(UOps.CAST, out, [pred_reg], args))
        elif args == BinaryOps.DIV and self.no_div:
          tmp = newreg((newvar, "rcp"))
          ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[1]]], UnaryOps.RECIP))
          ins.append(AssemblyInstruction(UOps.ALU, out, [tor[vin[0]], tmp], BinaryOps.MUL))
        elif args == UnaryOps.SIN and self.sin_is_sin2pi:
          tmp = newreg((newvar, "2pi"))
          ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
          ins.append(AssemblyInstruction(UOps.ALU, out, [tmp], args))
        else:
          ins.append(AssemblyInstruction(UOps.ALU, out, [tor[x] for x in vin], args))
      elif uop == UOps.LOAD and newvar is not None:
        idx, treg, off = addr_w_offset(args)
        reg = newreg(newvar, dtype=newvar.dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar)))  # and not dtypes.is_float(newvar.dtype)))
        if args.valid.min == 0:
          ins.append(AssemblyInstruction(UOps.CONST, reg, [], 0))
          if args.valid.max == 1:
            pred = args.valid.render(render_ops)
            ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
        if args.valid.max == 1:
          # NOTE: you can't compute the index in here, because it assumes it's all available later
          ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if args.i != -1 else 'shared')))
        if args.valid.min == 0 and args.valid.max == 1:
          ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
          skipload_branch += 1
      elif uop == UOps.STORE:
        idx, treg, off = addr_w_offset(args)
        ins.append(AssemblyInstruction(UOps.STORE, None, [idx, tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if args.i != -1 else 'shared')))

    # define registers
    ins = [AssemblyInstruction(UOps.DEFINE_REGISTER, None, [], (dtype, type_to_letter(dtype), c)) for dtype,c in cnts.items()] + ins

    if DEBUG >= 4:
      for tins in ins: print(tins)
    name, asm = self.specialize(ins)

    return ASTRunner(name, asm,
      global_size[::-1], local_size[::-1],
      op_estimate=self.info.flops, mem_estimate=self.mem_estimate, display_name=self.display_name, runtime_args={"binary": True})
1 setup.py
@@ -24,6 +24,7 @@ setup(name='tinygrad',
      extras_require={
        'llvm': ["llvmlite"],
        'cuda': ["pycuda"],
        'arm': ["unicorn"],
        'triton': ["triton>=2.0.0.dev20221202"],
        'webgpu': ["wgpu"],
        'metal': ["pyobjc-framework-Metal", "pyobjc-framework-Cocoa", "pyobjc-framework-libdispatch"],
2 test/external/external_test_speed_llama.py vendored
@@ -11,7 +11,7 @@ from tinygrad.helpers import dtypes, prod
from tinygrad.runtime.lib import RawBuffer

class FakeProgram:
  def __init__(self, name:str, prg:str): pass
  def __init__(self, name:str, prg:str, binary:bool): pass
  def __call__(self, global_size, local_size, *bufs, wait=False): pass

class RawFakeBuffer(RawBuffer):
@@ -68,29 +68,5 @@ class TestLazyBuffer(unittest.TestCase):
    assert GlobalCounters.cache[2][0].name.startswith("E_")
    GlobalCounters.cache = None

class TestVariableBuffer(unittest.TestCase):
  def test_get_variable_buffers_no_variable(self):
    t = Tensor.rand(2, 3)
    assert t.lazydata.get_variable_buffers() == {}

  def test_get_variable_buffers_one_variable(self):
    v = Variable("v", 1, 10)
    t = Tensor.rand(2, 3).reshape(v, 3)
    buffers = t.lazydata.get_variable_buffers()
    assert len(buffers) == 1 and buffers[v].realize().realized.toCPU() == 2
    v = Variable("v", 1, 10)
    t = Tensor.rand(2, 3).reshape(2, v)
    buffers = t.lazydata.get_variable_buffers()
    assert len(buffers) == 1 and buffers[v].realize().realized.toCPU() == 3

  def test_get_variable_buffers_cat(self):
    v1 = Variable("v1", 1, 10)
    v2 = Variable("v2", 1, 10)
    t1 = Tensor.rand(2, 3).reshape(v1, 3)
    t2 = Tensor.rand(6, 3).reshape(v2, 3)
    t = t1.cat(t2)
    buffers = t.lazydata.get_variable_buffers()
    assert len(buffers) == 2 and buffers[v1].realize().realized.toCPU() == 2 and buffers[v2].realize().realized.toCPU() == 6

if __name__ == "__main__":
  unittest.main()
@@ -40,65 +40,91 @@ class TestSymbolic(unittest.TestCase):

class TestSymbolicReshape(unittest.TestCase):
  def test_reshape_into_symbols_simple(self):
    for i in range(1, 5):
      vi = Variable("i", 1, 10)
      assert Tensor.rand(i, 4).reshape(vi, 4).shape == (vi, 4)
      assert vi.val == i
      vi = Variable("i", 1, 10)
      assert Tensor.rand(i, 6).reshape(vi, 2, 3).shape == (vi, 2, 3)
      assert vi.val == i
    vi = Variable("i", 1, 5)
    for i in range(1, 6):
      t = Tensor.rand(i, 4).reshape(vi, 4)
      assert t.shape == (vi, 4)
      assert t.lazydata.st.var_vals[vi] == i
      t = Tensor.rand(i, 6).reshape(vi, 2, 3)
      assert t.shape == (vi, 2, 3)
      assert t.lazydata.st.var_vals[vi] == i

  def test_reshape_symbols_reshape_ints(self):
    for i in range(1, 5):
      vi = Variable("i", 1, 10)
      assert Tensor.rand(i, 4).reshape(vi, 4).reshape(i, 4).shape == (i, 4)
      assert Tensor.rand(i, 4).reshape(vi, 4).reshape(i*4,).shape == (i*4,)
      assert Tensor.rand(i, 6).reshape(vi, 6).reshape(i*2, 3).shape == (i*2, 3)
      with self.assertRaises(AssertionError):
        Tensor.rand(i, 6).reshape(vi, 6).reshape(1, 77).shape
    vi = Variable("i", 1, 5)
    for i in range(1, 6):
      t = Tensor.rand(i, 4).reshape(vi, 4)
      assert t.shape == (vi, 4)
      assert t.lazydata.st.var_vals == {vi: i}
      t = t.reshape(i, 4)
      assert t.shape == (i, 4)
      assert t.lazydata.st.var_vals == {}

  def test_reshape_reuse_var_same_value_ok(self):
    for i in range(1, 5):
      vi = Variable("i", 1, 10)
    vi = Variable("i", 1, 5)
    for i in range(1, 6):
      a = Tensor.rand(i, 4).reshape(vi, 4)
      b = Tensor.rand(i, 3).reshape(vi, 3)
      assert vi.val == i
      assert a.lazydata.st.var_vals[vi] == i
      assert b.lazydata.st.var_vals[vi] == i

  def test_reshape_reuse_var_different_value_fail(self):
    for i in range(1, 5):
      vi = Variable("i", 1, 10)
  def test_reshape_reuse_var_different_value_ok(self):
    vi = Variable("i", 1, 10)
    for i in range(1, 6):
      a = Tensor.rand(i, 4).reshape(vi, 2)
      with self.assertRaises(AssertionError):
        b = Tensor.rand(i, 3).reshape(vi, 3)
      b = Tensor.rand(i, 3).reshape(vi, 3)
      # a and b have different values of vi
      assert a.lazydata.st.var_vals[vi] == 2 * i
      assert b.lazydata.st.var_vals[vi] == i

  def test_reshape_into_symbols_bad_shape(self):
    vi = Variable("i", 1, 10)
    vj = Variable("j", 1, 10)
    with self.assertRaises(AssertionError):
      t = Tensor.rand(3, 4).reshape(vi, vj)
      t = Tensor.rand(3, 4).reshape(vi, vj)  # reshape into two variables
    with self.assertRaises(AssertionError):
      t = Tensor.rand(4, 4).reshape(vi, vi)
      t = Tensor.rand(4, 4).reshape(vi, vi)  # reshape into same variable in 2 dimensions
    with self.assertRaises(AssertionError):
      t = Tensor.rand(4, 6).reshape(vi, 6).reshape(vi, 4)
      t = Tensor.rand(4, 6).reshape(vi, 6).reshape(vi, 4)  # conflicted implied variable values
    with self.assertRaises(AssertionError):
      t = Tensor.rand(4, 6).reshape(vi, 6).reshape(1, 77)  # reshape to a different size new shape through symbolic shape
    with self.assertRaises(AssertionError):
      t = Tensor.rand(100, 4).reshape(Variable("too_small", 1, 10), 4)
    with self.assertRaises(AssertionError):
      t = Tensor.rand(3, 4).reshape(Variable("too_big", 100, 200), 4)
    with self.assertRaises(AssertionError):
      t = Tensor.rand(3, 4).reshape(3, (vi+1))  # reshape into non-Variable Node

  def test_two_symbol_reshape(self):
    vi = Variable("i", 1, 5)
    vj = Variable("j", 1, 5)
    for i in range(1, 6):
      for j in range(1, 6):
        t1 = Tensor.rand(i, 5).reshape(vi, 5)
        t2 = Tensor.rand(5, j).reshape(5, vj)
        t = t1@t2
        assert t.shape == (vi, vj)
        t = t.reshape(1, vi*vj)
        assert t.shape == (1, vi*vj)
        t = t.reshape(vj, vi)
        assert t.shape == (vj, vi)

class TestSymbolicExpand(unittest.TestCase):
  def test_expand_into_symbols(self):
    vi = Variable("i", 1, 10)
    vi = Variable("i", 1, 5)
    vj = Variable("j", 1, 5)
    a = Tensor([[1], [2], [3]]).expand((3, vi))
    assert a.shape == (3, vi)
    vj = Variable("j", 1, 10)
    assert a.lazydata.st.var_vals == {}
    a = a.reshape(3, vi, 1).expand((3, vi, vj))
    assert a.shape == (3, vi, vj)
    assert a.lazydata.st.var_vals == {}

  def test_plus_expands_constant(self):
    vi = Variable("i", 1, 10)
    a = Tensor.rand(3, 4).reshape(3, vi)
    a = a + 1
    assert a.shape == (3, vi)
    vi = Variable("i", 1, 5)
    for i in range(1, 6):
      a = Tensor.rand(3, i).reshape(3, vi)
      a = a + 1
      assert a.shape == (3, vi)

class TestSymbolicShapeExpr(unittest.TestCase):
  def test_symbolic_expr_idxs(self):
@@ -114,5 +140,23 @@ class TestSymbolicShapeExpr(unittest.TestCase):
    idx, valid = st.expr_idxs(idx)
    assert idx.render() == "(((1+i)*1)+(lidx1*((i*4)+4))+gidx0)"

class TestShapeTrackerVarVals(unittest.TestCase):
  def test_reshape_reshape_updates_var_vals(self):
    vi = Variable("i", 1, 5)
    vj = Variable("j", 1, 5)
    t = Tensor.rand(3, 4).reshape(3, vi).reshape(4, vj)
    assert t.lazydata.st.var_vals == {vi: 4, vj: 3}

  def test_lazy_check_var_vals(self):
    vi = Variable("i", 1, 5)
    a = Tensor.rand(3, 4).reshape(3, vi)
    b = Tensor.rand(5, 6).reshape(vi, 6)
    assert a.lazydata.st.var_vals == {vi: 4}
    assert b.lazydata.st.var_vals == {vi: 5}
    c = a@b
    # shapetracker works with symbolic shape and doesn't check / propagate the underlying variable values
    assert c.shape == (3, 6)
    assert c.lazydata.st.var_vals == {}

if __name__ == '__main__':
  unittest.main()
@@ -1,14 +1,15 @@
import unittest, math
import numpy as np
from tinygrad.helpers import dtypes
from tinygrad.helpers import dtypes, getenv
from tinygrad.tensor import Device
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ASTRunner, Compiled
from tinygrad.codegen.linearizer import UOps, Token, ConstOp, MemOp
from tinygrad.shape.symbolic import Variable

def _uops_to_prg(uops):
  src, global_size, local_size = Device[Device.DEFAULT].renderer("test", uops)
  return ASTRunner("test", src, global_size, local_size).build(Device[Device.DEFAULT].runtime)
  ret = Device[Device.DEFAULT].renderer("test", uops)
  src, global_size, local_size, binary = ret if len(ret) == 4 else ret + (False,)
  return ASTRunner("test", src, global_size, local_size, runtime_args={"binary": binary}).build(Device[Device.DEFAULT].runtime)

def _test_single_value(tc, tt, vals, op):
  uops = [
@@ -36,36 +37,20 @@ def _test_single_value_const(tc, tt, vals, op):
  prg([buf])
  return buf.toCPU()[0]

@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled), "only test for compiled backends")
class TestUOps(unittest.TestCase):
  def _equal(self, v1, v2):
    if not (math.isnan(v1) and math.isnan(v2)): self.assertAlmostEqual(v1, v2, places=5)

  def _test_uop_fxn(self, bop, fxn, dt=dtypes.float32):
    for f in [_test_single_value, _test_single_value_const]:
      for a in [-2.0, 2.0]:
      for a in [-2.0, 0.0, 1.0, 2.0]:
        self._equal(f(Token('c', dt), [Token('a', dt)], [a], bop), fxn(a))
  def test_exp2(self): self._test_uop_fxn(UnaryOps.EXP2, lambda a: np.exp2(a))
  def test_log2(self): self._test_uop_fxn(UnaryOps.LOG2, lambda a: math.log2(a) if a > 0 else float('nan'))
  def test_sin(self): self._test_uop_fxn(UnaryOps.SIN, lambda a: math.sin(a))
  def test_sqrt(self): self._test_uop_fxn(UnaryOps.SQRT, lambda a: math.sqrt(a) if a >= 0 else float('nan'))
  #def test_recip(self): self._test_uop_fxn(UnaryOps.RECIP, lambda a: 1.0/a)

  def _test_bop_fxn(self, bop, fxn, dt=dtypes.float32):
  def _test_bop_fxn(self, bop, fxn, dt=dtypes.float32, no_b_zero=False):
    for f in [_test_single_value, _test_single_value_const]:
      for a in [-2.0, 2.0]:
        for b in [-3.0, 3.0]:
      for a in [-2.0, 0.0, 1.0, 2.0]:
        for b in [-3.0, 1.0, 3.0] + ([] if no_b_zero else [0.0]):
          self._equal(f(Token('c', dt), [Token('a', dt), Token('b', dt)], [a,b], bop), fxn(a,b))
  def test_add(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: a+b)
  def test_sub(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: a-b)
  def test_mul(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: a*b)
  def test_div(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: a/b)
  def test_max(self): self._test_bop_fxn(BinaryOps.MAX, lambda a,b: max(a,b))
  def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b))
  # MOD isn't tested

  # doesn't work in LLVM
  #def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: a+b, dtypes.int32)

  def _test_top_fxn(self, bop, fxn, dt=dtypes.float32):
    for f in [_test_single_value, _test_single_value_const]:
@@ -73,8 +58,37 @@ class TestUOps(unittest.TestCase):
      for b in [-3.0, 3.0]:
        for c in [-4.0, 4.0]:
          self._equal(f(Token('d', dt), [Token('a', dt), Token('b', dt), Token('c', dt)], [a,b,c], bop), fxn(a,b,c))

@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled), "only test for compiled backends")
class TestFloatUOps(TestUOps):
  def test_exp2(self): self._test_uop_fxn(UnaryOps.EXP2, lambda a: np.exp2(a))
  def test_log2(self): self._test_uop_fxn(UnaryOps.LOG2, lambda a: math.log2(a) if a > 0 else float('-inf' if a==0 else 'nan'))
  def test_sin(self): self._test_uop_fxn(UnaryOps.SIN, lambda a: math.sin(a))
  def test_sqrt(self): self._test_uop_fxn(UnaryOps.SQRT, lambda a: math.sqrt(a) if a >= 0 else float('nan'))
  # this is not on most backends
  #def test_recip(self): self._test_uop_fxn(UnaryOps.RECIP, lambda a: 1.0/a if a != 0 else float('inf'))

  def test_add(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: a+b)
  def test_sub(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: a-b)
  def test_mul(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: a*b)
  def test_div(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: a/b if b != 0 else a*float('inf'))
  def test_max(self): self._test_bop_fxn(BinaryOps.MAX, lambda a,b: max(a,b))
  def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b))
  # MOD isn't tested on floats

  def test_mulacc(self): self._test_top_fxn(TernaryOps.MULACC, lambda a,b,c: (a*b)+c)
  def test_where(self): self._test_top_fxn(TernaryOps.WHERE, lambda a,b,c: b if a!=0 else c)

# TODO: fix this on all the backends
@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled) or Device.DEFAULT == "LLVM" or getenv('ARM64', False), "only test for compiled backends, broken on some")
class TestNonFloatUOps(TestUOps):
  def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: int(a)+int(b), dtypes.int32)
  def test_mul_int32(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: int(a)*int(b), dtypes.int32)
  def test_div_int32(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: int(a/b), dtypes.int32, no_b_zero=True)
  def test_mod_int32(self): self._test_bop_fxn(BinaryOps.MOD, lambda a,b: abs(int(a))%abs(int(b))*(1,-1)[a<0], dtypes.int32, no_b_zero=True)
  def test_cmplt_int32(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b), dtypes.int32)
  @unittest.skipIf(Device.DEFAULT == "CLANG", "broken in CLANG")
  def test_mul_bool(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: bool(a) and bool(b), dtypes.bool)

if __name__ == '__main__':
  unittest.main(verbosity=2)
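`test_div` now defines the b == 0 reference as `a*float('inf')` because Python's `/` raises ZeroDivisionError while compiled backends follow IEEE-754. That the reference matches IEEE semantics can be checked with numpy (a hedged standalone check, not part of the commit):

import math
import numpy as np

# Python's "/" raises on zero, but IEEE-754 (what the backends do) yields
# +/-inf for x/0 and nan for 0/0. a*float('inf') reproduces exactly that.
with np.errstate(divide='ignore', invalid='ignore'):
    for a in [-2.0, 0.0, 2.0]:
        ieee = float(np.float32(a) / np.float32(0.0))  # IEEE-754 result
        ref = a * float('inf')                         # reference used in the test
        assert (math.isnan(ieee) and math.isnan(ref)) or ieee == ref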
@@ -1,6 +1,6 @@
#!/usr/bin/env python
import unittest
from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, LtNode, sym_vars, sym_render
from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, LtNode, sym_render

class TestSymbolic(unittest.TestCase):
  def helper_test_variable(self, v, n, m, s):
@@ -261,14 +261,6 @@ class TestSymbolicVars(unittest.TestCase):
    assert (a % 3 + b // 5).vars() == [a, b]
    assert (a + b + c - a).vars() == [b, c]

  def test_sym_vars(self):
    a = Variable("a", 0, 10)
    b = Variable("b", 0, 10)
    assert sym_vars(1) == []
    assert sym_vars(a) == [a]
    assert sym_vars(a+b) == [a, b]
    assert sym_vars(a*3) == [a]

class TestSymbolicMinMax(unittest.TestCase):
  def test_min_max_known(self):
    a = Variable("a", 1, 8)
188 tinygrad/codegen/assembly.py Normal file
@@ -0,0 +1,188 @@
from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict, cast
from tinygrad.codegen.linearizer import UOps, Token, ConstOp, MemOp, UOp
from tinygrad.ops import BinaryOps, UnaryOps
from tinygrad.helpers import DType, dtypes, DEBUG
from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
import functools
import math
from collections import defaultdict

_type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'u', dtypes.uint64: 'b', dtypes._float4: 'x'}
def type_to_letter(x): return _type_to_letter[x[0]].upper() if x[1] else _type_to_letter[x[0]]
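The `(dtype, scalar)` pair keys both the per-type counters and the register letters: scalar (warp-shared) registers get the upper-case letter, vectors the lower-case one. A minimal standalone sketch of the names `newreg` produces below (a re-implementation for illustration, not part of the commit):

# Hypothetical sketch of the register naming convention: "%f0", "%F0", "%B2", ...
_letters = {'float32': 'f', 'bool': 'p', 'int32': 'i', 'uint64': 'b'}

def reg_name(dtype: str, scalar: bool, count: int) -> str:
    letter = _letters[dtype]
    return f"%{letter.upper() if scalar else letter}{count}"

assert reg_name('float32', False, 0) == "%f0"  # first vector float32 register
assert reg_name('float32', True, 0) == "%F0"   # first scalar float32 register
assert reg_name('uint64', True, 2) == "%B2"    # e.g. a third scalar buffer address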
class Register(NamedTuple):
  nm:str
  dtype:DType
  scalar:bool
  off:Optional[int] = None
  def __repr__(self): return self.nm if self.off is None else f"{self.nm}:{self.off}"
  def subregs(self):
    if self.dtype == dtypes._float4:
      return [Register(self.nm, dtypes.float, False, off=off) for off in range(4)]
    return []

class AssemblyInstruction(NamedTuple):
  op: UOps
  out: Optional[Register]
  vin: List[Union[Register, int, float]]
  arg: Any = None

# warp size of 32, s registers are shared across the warp, v are 32-wide vectors
class AssemblyLanguage:
  supports_load3: bool = False
  sin_is_sin2pi: bool = False
  no_div: bool = False
  #TODO: these should be global vars
  cnts:DefaultDict[Tuple[DType, bool], int] = defaultdict(int)
  tor: Dict[Any, Register] = {}
  ins: List[AssemblyInstruction] = []

  def newreg(self, tok, dtype=dtypes.float32, scalar=False):
    if isinstance(tok, Token): dtype = tok.dtype  # this
    self.tor[tok] = ret = Register(f"%{type_to_letter((dtype, scalar))}{self.cnts[(dtype, scalar)]}", dtype, scalar)
    if dtype == dtypes._float4:
      for off in range(4):
        self.tor[Token(tok.name, tok.dtype, off)] = Register(ret.nm, dtypes.float, ret.scalar, off)
    self.cnts[(dtype, scalar)] += 1
    return ret

  def render_numnode(self, b):
    key = ("num", b)
    if key not in self.tor: self.ins.append(AssemblyInstruction(UOps.LOAD, self.newreg(key, scalar=True, dtype=dtypes.int32), [], b))
    return self.tor[key]

  def render_alu(self, op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
    key = (op, a, b)
    if key not in self.tor:
      #if not isinstance(b, Register): b = render_numnode(b)
      self.ins.append(AssemblyInstruction(UOps.ALU, self.newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
    return self.tor[key]

  def render_cast(self, a:Register, new_dtype:DType) -> Register:
    if a.dtype == new_dtype: return a
    key = (a, new_dtype)
    if key not in self.tor:
      self.ins.append(AssemblyInstruction(UOps.CAST, self.newreg(key, dtype=new_dtype), [a]))
    return self.tor[key]

  render_ops: Any = { Variable: lambda self, ops, ctx: ctx.tor[self], NumNode: lambda self, ops, ctx: ctx.render_numnode(self.b),
                      MulNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MUL, self.a.render(ops, ctx), self.b),
                      DivNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.DIV, self.a.render(ops, ctx), self.b),
                      ModNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MOD, self.a.render(ops, ctx), self.b),
                      LtNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool),
                      SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.ADD, a, b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
                      AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.MUL, a, b.render(ops,ctx), dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }

  def addr_w_offset(self, args):
    assert isinstance(args, MemOp)
    idx = args.idx*args.memory_dtype.itemsize
    off = 0  # TODO: should this be None?
    if isinstance(idx, SumNode):
      nums = [n.b for n in idx.nodes if isinstance(n, NumNode)]
      if len(nums) > 0 and nums[0] < 4096 and (idx-nums[0]).min >= 0:  # TODO: different for each GPU?
        idx -= nums[0]
        off = cast(int, nums[0])
    reg = idx.render(self.render_ops, self)
    if self.supports_load3:
      if reg.scalar:
        new_reg = self.newreg((reg.nm, 'vec'), dtype=reg.dtype)
        self.ins.append(AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP))
        reg = new_reg
      return self.tor[args.name], reg, off
    reg = self.render_alu(BinaryOps.ADD, self.render_cast(reg, dtypes.uint64), self.tor[args.name], dtype=dtypes.uint64)
    return reg, None, off
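`addr_w_offset` peels one small constant (< 4096) out of a SumNode index so it can become the instruction's immediate displacement instead of being materialized into the address register. A hedged arithmetic sketch of that folding (helper and values are hypothetical, not from the commit):

# Hypothetical illustration: for a byte index like gidx0*16 + 8, the +8 becomes
# the load/store's immediate offset and only gidx0*16 is computed into a
# register, provided the constant is < 4096 and the remainder stays >= 0.
def fold_offset(terms, const):
    if 0 <= const < 4096:
        return terms, const        # ld [reg + #const]
    return terms + [const], 0      # constant must go through the register

assert fold_offset(["gidx0*16"], 8) == (["gidx0*16"], 8)
assert fold_offset(["gidx0*16"], 8192) == (["gidx0*16", 8192], 0)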

def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]):
  #TODO: Do not use clear()
  lang.ins.clear()
  lang.tor.clear()
  buf_to_dtype = {args[0]:args[1] for uop,_,_,args in uops if uop == UOps.DEFINE_GLOBAL}
  buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
  global_size, local_size = [], []
  skipload_branch = 0
  lang.ins += [AssemblyInstruction(UOps.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf) for buf in buf_to_dtype]
  for uop,newvar,vin,args in uops:
    if uop == UOps.DEFINE_LOCAL:
      lang.ins.append(AssemblyInstruction(UOps.DEFINE_LOCAL, None, [], args))
      lang.ins.append(AssemblyInstruction(UOps.ALU, lang.newreg(args[0], dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
    elif uop == UOps.LOOP:
      if args[1] == "global":
        for i,var in enumerate(args[0]):
          global_size.append(var.max+1)
          lang.ins.append(AssemblyInstruction(UOps.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}"))
      elif args[1] == "local":
        for i,var in enumerate(args[0]):
          local_size.append(var.max+1)
          lang.ins.append(AssemblyInstruction(UOps.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}"))
      else:
        for var in args[0]:
          if not isinstance(var, NumNode):  # TODO: why is this coming through?
            lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(var, dtype=dtypes.int32, scalar=True), [], 0))  #FIXME: what should valid be here?
            lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], "$loop_"+var.expr))
    elif uop == UOps.ENDLOOP:
      if args[1] not in ["global", "local", "global+local"]:
        for var in reversed(args[0]):
          if not isinstance(var, NumNode):  # TODO: why is this coming through?
            lang.ins.append(AssemblyInstruction(UOps.ALU, lang.tor[var], [lang.tor[var], 1], BinaryOps.ADD))
            pred = lang.render_alu(BinaryOps.CMPLT, lang.tor[var], var.max+1, dtypes.bool)
            lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
      elif args[1] == "global+local":
        for i, var in enumerate(reversed(args[0])):
          lang.ins.append(AssemblyInstruction(UOps.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"gid{i}")))

    elif uop == UOps.CAST and newvar is not None:
      # TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
      out = lang.newreg(newvar)
      for i,sr in enumerate(out.subregs()):
        lang.ins.append(AssemblyInstruction(UOps.ALU, sr, [lang.tor[vin[i]]], UnaryOps.NOOP))
    elif uop == UOps.ALU and newvar is not None:
      out = lang.newreg(newvar) if newvar not in lang.tor else lang.tor[newvar]
      # this is the only thing that can violate SSA
      if args in [BinaryOps.CMPLT]:
        pred_reg = lang.newreg((newvar, 'pred'), dtype=dtypes.bool)
        lang.ins.append(AssemblyInstruction(UOps.ALU, pred_reg, [lang.tor[x] for x in vin], args))
        lang.ins.append(AssemblyInstruction(UOps.CAST, out, [pred_reg], args))
      elif args == BinaryOps.DIV and lang.no_div:
        tmp = lang.newreg((newvar, "rcp"))
        lang.ins.append(AssemblyInstruction(UOps.ALU, tmp, [lang.tor[vin[1]]], UnaryOps.RECIP))
        lang.ins.append(AssemblyInstruction(UOps.ALU, out, [lang.tor[vin[0]], tmp], BinaryOps.MUL))
      elif args == UnaryOps.SIN and lang.sin_is_sin2pi:
        tmp = lang.newreg((newvar, "2pi"))
        lang.ins.append(AssemblyInstruction(UOps.ALU, tmp, [lang.tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
        lang.ins.append(AssemblyInstruction(UOps.ALU, out, [tmp], args))
      else:
        lang.ins.append(AssemblyInstruction(UOps.ALU, out, [lang.tor[x] for x in vin], args))
    elif uop == UOps.LOAD and newvar is not None:
      if isinstance(args, ConstOp):
        if args.valid.min == 0 and args.valid.max == 1:
          lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(newvar, dtype=newvar.dtype), [], args.invalid_value))
          pred = args.valid.render(lang.render_ops, lang)
          lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
          lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(newvar, dtype=newvar.dtype), [], args.value))
          lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
          skipload_branch += 1
        else:
          lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(newvar, dtype=newvar.dtype), [], args.value if args.valid.min == 1 else args.invalid_value))
      else:
        idx, treg, off = lang.addr_w_offset(args)
        reg = lang.newreg(newvar, dtype=newvar.dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar)))  # and not dtypes.is_float(newvar.dtype)))
        if args.valid.min == 0:
          lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], 0))
          if args.valid.max == 1:
            pred = args.valid.render(lang.render_ops, lang)
            lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
        if args.valid.max == 1:
          # NOTE: you can't compute the index in here, because it assumes it's all available later
          lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if buf_index[args.name] != -1 else 'shared', args.memory_dtype if buf_to_dtype[args.name] != dtypes.float else None)))
        if args.valid.min == 0 and args.valid.max == 1:
          lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
          skipload_branch += 1
    elif uop == UOps.STORE:
      idx, treg, off = lang.addr_w_offset(args)
      lang.ins.append(AssemblyInstruction(UOps.STORE, None, [idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if buf_index[args.name] != -1 else 'shared', args.memory_dtype if buf_to_dtype['data0'] != dtypes.float else None)))
  # define registers
  lang.ins = [AssemblyInstruction(UOps.DEFINE_REGISTER, None, [], (dtype, type_to_letter(dtype), c)) for dtype,c in lang.cnts.items()] + lang.ins

  if DEBUG >= 4:
    for tins in lang.ins: print(tins)
  return global_size, local_size
170 tinygrad/codegen/assembly_arm64.py Normal file
@@ -0,0 +1,170 @@
import struct
from platform import system
from typing import Tuple, Dict, List, Optional
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.helpers import dtypes, CI
from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage

def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
def compute_offsets(total):
  quotient, remainder = divmod(total, 4096)
  return [4096]*quotient + [remainder] if remainder else [4096]*quotient
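Two helpers here merit a worked example: `float_to_hex` returns the IEEE-754 bits of a float32 as big-endian hex, and `compute_offsets` splits a stack adjustment into chunks of at most 4096, since the immediate in ARM64 `sub sp, sp, #imm` is limited. The same functions, copied verbatim and exercised:

import struct

def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f", x)[::-1])
def compute_offsets(total):
    quotient, remainder = divmod(total, 4096)
    return [4096]*quotient + [remainder] if remainder else [4096]*quotient

assert float_to_hex(1.0) == "3F800000"               # IEEE-754 bits of 1.0f
assert compute_offsets(10000) == [4096, 4096, 1808]  # chunks sum back to 10000
assert compute_offsets(8192) == [4096, 4096]         # exact multiple: no tail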
#NOTE: Darwin needs names to start with a "_"
def get_name(name): return ('_' if system() == 'Darwin' else '') + name

class ARM64Language(AssemblyLanguage): pass

def specialize_to_arm64(fn_nm, asm):
  var_size = 16
  prev_uop:Optional[UOps] = None
  ins = []
  x_regs = ['x' + str(i) for i in reversed(range(29)) if i not in (10,11,12,13,14,15,16,17,18,19,20)]
  s_regs = ['s' + str(i) for i in reversed(range(3,30))]
  type_to_reg = {dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'}
  alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
         BinaryOps.MOD: "", BinaryOps.CMPLT: "subs",
         UnaryOps.SIN:'bl ' + get_name('sinf'), UnaryOps.LOG2: 'bl ' + get_name("log2f"), UnaryOps.EXP2: 'bl ' + get_name("exp2f"), UnaryOps.SQRT: 'bl ' + get_name("sqrtf"),
         TernaryOps.MULACC: "madd", TernaryOps.WHERE: "fcsel"}

  def mov_imm(value, reg):
    # Manually move value into reg if value can't fit
    if value.__class__ is not float and abs(value) > abs(65535):
      ins.append(f"movz w15, #{value & 0xffff}")
      ins.append(f"movk w15, #{(value >> 16) & 0xffff}, lsl #16")
      ins.append(f"sxtw {reg}, w15")
    elif reg[0] == 's':
      ins.append(f"movz x15, 0x{float_to_hex(value)[4:]}")
      ins.append(f"movk x15, 0x{float_to_hex(value)[:4]}, lsl #16")
      ins.append("str x15, [sp, 16]")
      ins.append(f"ldr {reg}, [sp, 16]")
    else:
      ins.append(f"mov {reg}, #{value}")
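`mov_imm` works around ARM64's 16-bit `mov` immediates: large integers are rebuilt from two 16-bit halves via movz/movk, and float immediates go through x15 and the stack. The halving arithmetic, checked in isolation (a standalone sanity check, not from the commit):

# The movz/movk pair rebuilds a 32-bit value from 16-bit halves.
value = 100000                      # doesn't fit in a 16-bit immediate
lo, hi = value & 0xffff, (value >> 16) & 0xffff
assert (hi << 16) | lo == value     # movz w15,#lo ; movk w15,#hi,lsl #16
print(lo, hi)                       # 34464 1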
  # Get variables intervals
  live_range:Dict[str, List[int]] = {}
  for i, (uop, out, vin, arg) in enumerate(asm):
    for var in ([v for v in [out] + vin if v is not None and v.__class__ is not int]):
      live_range[var.nm] = [i,i] if var.nm not in live_range else [live_range[var.nm][0], i]
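Register allocation below is a linear scan over these intervals: a name's range is [first use index, last use index], and its register is recycled once the scan passes the end. A standalone sketch of the interval computation (the instruction tuples are hypothetical):

# Hypothetical standalone version of the live-range scan: map each name to
# the first and last instruction index where it occurs.
prog = [("LOAD", "a", []), ("LOAD", "b", []), ("ALU", "c", ["a", "b"]), ("STORE", None, ["c"])]
live_range = {}
for i, (_, out, vin) in enumerate(prog):
    for nm in ([out] if out else []) + vin:
        live_range[nm] = [i, i] if nm not in live_range else [live_range[nm][0], i]
assert live_range == {"a": [0, 2], "b": [1, 2], "c": [2, 3]}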
  mem_vars:Dict[str, int] = {}
  rtor:Dict[str, str] = {}
  def allocate_regs(mvars):
    nonlocal var_size
    for v in [v for v in mvars if v is not None and v.__class__ is not int and v.nm not in rtor]:
      available_regs = s_regs if dtypes.is_float(v[1]) else x_regs
      #NOTE: Very simple spill, everything that doesn't fit in regs goes to mem
      if len(available_regs) == 0:
        # ARM needs the stack 16-byte aligned
        var_size += 16
        available_regs.append('s0' if dtypes.is_float(out[1]) else 'x11')
        mem_vars[v.nm] = var_size
      rtor[v.nm] = available_regs.pop()

  temp_floats = ['s0', 's1', 's2']
  temp_ints = ['x11', 'x12', 'x13']
  for i, (uop, out, vin, arg) in enumerate(asm):
    # Clear regs out of interval
    for var, reg in list(rtor.items()):
      available_regs = s_regs if reg[0] == 's' else x_regs
      if var[1] not in 'B' and var not in mem_vars and i > live_range[var][1]:
        available_regs.append(rtor.pop(var))
    # Assign registers to the variables using live ranges.
    allocate_regs([out] + vin)
    # Assign temp regs to vin and load them before direct use
    for i, v in enumerate([v for v in vin if v.__class__ is not int and v.nm in mem_vars]):
      rtor[v.nm] = temp_floats[i] if dtypes.is_float(v[1]) else temp_ints[i]
      # ARM64 addressing constraints https://devblogs.microsoft.com/oldnewthing/20220728-00/?p=106912
      ins.append(f"mov x15, {mem_vars[v.nm]}")
      ins.append(f"ldr {rtor[v.nm]}, [sp, x15]")

    if uop == UOps.SPECIAL:
      if arg.startswith('data'):
        # data 8 to n into the stack
        if int(arg[4:]) >= 8:
          ins.append(f"ldr x15, [x19, #{(int(arg[4:]) - 8) * 8}]")
          ins.append(f"mov {rtor[out.nm]}, x15")
      else:
        ins.append(f"mov {rtor[out.nm]}, #0")
        ins.append(f"loop_{arg}:")
    elif uop == UOps.CAST:
      if arg == BinaryOps.CMPLT:
        mov_imm(0.0, 's0')
        mov_imm(1.0, 's1')
        ins.append(f"fcsel {rtor[out.nm]}, s1, s0, lt")
      else:
        ins.append(f"sxtw {rtor[out.nm]}, w{rtor[vin[0].nm][1:]}")
    elif uop == UOps.ALU:
      if len(vin)==2 and vin[1].__class__ is int: mov_imm(vin[1], 'x15')
      if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
        ins.append(f"ands {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
      elif arg == TernaryOps.WHERE:
        ins.append(f"fcmp {rtor[vin[0].nm]}, #0.0")
        ins.append(f"{alu[arg]} {rtor[out.nm]}, {rtor[vin[1].nm]}, {rtor[vin[2].nm]}, ne")
      elif arg in [UnaryOps.LOG2, UnaryOps.SIN, UnaryOps.EXP2, UnaryOps.SQRT]:
        #NOTE: Not a real instruction, use to emulate a ext call in unicorn
        if CI: ins.append(f"{alu[arg]} {rtor[out.nm]} {rtor[vin[0].nm]}")
        else:
          save_regs = [k for k in rtor.keys() if k != out.nm and k not in mem_vars]
          ins.append(f"sub sp, sp, #{(len(save_regs))*16}")
          # Save the registers before they are cleared by func call
          for i,k in enumerate(save_regs,1):
            ins.append(f"str {rtor[k]}, [sp, #{16*i}]")
          ins.append("stp x29, x30, [sp, #0]!")
          ins.append("mov x29, sp")
          ins.append(f"fmov s0, {rtor[vin[0].nm]}")
          ins.append(alu[arg])
          ins.append(f"fmov {rtor[out.nm]}, s0")
          ins.append("mov sp, x29")
          ins.append("ldp x29, x30, [sp], #0")
          for i,k in enumerate(save_regs,1):
            ins.append(f"ldr {rtor[k]}, [sp, #{16*i}]")
          ins.append(f"add sp, sp, #{len(save_regs)*16}")
      elif arg == BinaryOps.CMPLT:
        ins.append(f"{alu[arg]} {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}" if not dtypes.is_float(vin[0][1]) else f"fcmp {rtor[vin[0].nm]}, {rtor[vin[1].nm]}")
      elif arg == BinaryOps.MOD:
        ins.append(f"udiv x14, {rtor[vin[0].nm]}, x15")
        ins.append(f"msub {rtor[out.nm]}, x14, x15, {rtor[vin[0].nm]}")
      else:
        ins.append(f"{'f' if dtypes.is_float(vin[0][1]) else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {', '.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
    elif uop == UOps.LOAD:
      if arg.__class__ in (int, float):
        mov_imm(arg, rtor[out.nm])
      else:
        #NOTE: if need casting load var in s/h0 or x/w12 temp regs
        reg_in = type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[out.nm]
        mov_imm(arg[0], "x15")
        ins.append(f"add x15, {rtor[vin[0].nm]}, x15")
        ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]")
        if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] == dtypes.half else 'scvtf'} {rtor[out.nm]}, {reg_in}")
    elif uop == UOps.STORE:
      shifts = {dtypes.int64: "#3", dtypes.half: "#1", dtypes.int8:"#2", dtypes.uint8: "#2", dtypes.bool: "#2"}
      #NOTE: if need casting load var in s/h0 or x/w12 temp regs
      reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm])
      if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] != dtypes.half else '' } {reg_out}, {rtor[vin[1].nm]}")
      ins.append(f"mov x15, #{arg[0]}")
      ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl {shifts[arg[2]] if arg[2] is not None and arg[2] in shifts else '#0'}]")
    elif uop == UOps.COND_BRANCH:
      #TODO: this is a hack, it shouldn't always be a cmp before a cond branch
      if prev_uop == UOps.LOAD:
        ins.append(f"cmp {rtor[vin[0].nm]}, #0")
      ins.append(f"b.{'lt' if arg[1] else 'ge'} {arg[0][1:]}")
    elif uop == UOps.LABEL:
      ins.append(f"{arg[1:]}:")
    elif uop == UOps.ENDLOOP:
      mov_imm(arg[0], "x15")
      ins.append(f"add {rtor[vin[0].nm]}, {rtor[vin[0].nm]}, #1")
      ins.append(f"cmp {rtor[vin[0].nm]}, x15")
      ins.append(f"b.lt loop_{arg[1]}")
    prev_uop = uop
    # store regs into memory if needed
    if out is not None and out.nm in mem_vars:
      ins.append(f"mov x15, {mem_vars[out.nm]}")
      ins.append(f"str {rtor[out.nm]}, [sp, x15]")
  return "\n".join([f"//varsize {var_size}",".arch armv8-a",".text", f".global {get_name(fn_nm)}",".p2align 2", f"{get_name(fn_nm)}:", "mov x19, sp"] + [f"sub sp, sp, #{offset}" for offset in compute_offsets(var_size)]+ ins + [f"add sp, sp, #{offset}" for offset in compute_offsets(var_size)] +["ret", "\n"])

def uops_to_arm64_asm(fn_nm:str, uops:List[UOp]) -> Tuple[str, List[int], List[int], bool]:
  lang = ARM64Language()
  global_size, local_size = uops_to_asmstyle(lang, fn_nm, uops)
  return specialize_to_arm64(fn_nm, lang.ins), global_size[::-1], local_size[::-1], True
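`uops_to_arm64_asm` returns assembly text plus binary=True, and the CI workflow above installs gcc-aarch64-linux-gnu; presumably the text is assembled with that cross-toolchain and executed under unicorn. A hedged sketch of just the assemble step (the trivial kernel and paths are illustrative, and the cross-compiler must be installed):

import os, subprocess, tempfile

# Assemble hand-rolled ARM64 text with the cross-assembler the workflow installs.
asm_text = ".arch armv8-a\n.text\n.global kernel\n.p2align 2\nkernel:\nret\n"
with tempfile.TemporaryDirectory() as d:
    src, obj = os.path.join(d, "kernel.s"), os.path.join(d, "kernel.o")
    with open(src, "w") as f: f.write(asm_text)
    subprocess.run(["aarch64-linux-gnu-gcc", "-c", src, "-o", obj], check=True)
    print("assembled", os.path.getsize(obj), "bytes")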
@@ -9,9 +9,9 @@ from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten,
from tinygrad.runtime.ops_cpu import RawNumpyBuffer
from tinygrad.runtime.ops_disk import RawDiskBuffer
from tinygrad.shape.shapetracker import MovementOps, ShapeTracker, View, get_contraction
from tinygrad.shape.symbolic import Variable, sym_vars
from tinygrad.shape.symbolic import Node
from tinygrad.ops import Compiled, Interpreted, UnaryOps, BinaryOps, TernaryOps, ReduceOps, LoadOps, OpType, LazyOp
from tinygrad.runtime.lib import RawBufferMapped, RawConst, RawBuffer
from tinygrad.runtime.lib import RawBufferMapped, RawConst, RawBuffer, RawBufferTransfer

# lazy can recurse a lot
sys.setrecursionlimit(10000)
@@ -19,6 +19,7 @@ sys.setrecursionlimit(10000)
OPT = getenv("OPT", 2)
LAZY = getenv("LAZY", 1)
LAZYCACHE = getenv("LAZYCACHE", 1)
P2P = getenv("P2P", 0)

# TODO: movement ops that only change shape are really nops. treat them as such
REMOVE_MOVEMENT_NOPS, MERGE_ELEMENTWISE_INTO_REDUCE, SHUFFLE_MOVEMENT_OPS, MERGE_ELEMENTWISE_OPS = OPT>=1, OPT>=1, OPT>=1, OPT>=1
@@ -214,7 +215,7 @@ class LazyBuffer:
    if not self.realized and self.op.op == LoadOps.CONTIGUOUS: return self  # two CONTIGUOUS in a row is one
    return create_lazybuffer(self.device, ShapeTracker(self.shape), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None), self.dtype)

  def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[int, ...], Tuple[Tuple[int, int], ...]]) -> LazyBuffer:
  def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[Union[Node,int], ...], Tuple[Tuple[int, int], ...]]) -> LazyBuffer:
    if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and not self.realized and (op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and self.op.op in UnaryOps)) and len(self.children) == 0:
      return self.op.replace_with_movement_ops([(op, arg)])
    ret = create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg), self.dtype)
@@ -231,13 +232,13 @@ class LazyBuffer:
    return create_lazybuffer(self.device, ShapeTracker(new_shape), ReduceOps, LazyOp(op, srcs, new_shape), self.dtype)

  def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
    if prod(self.shape) // prod(new_shape) < 32768: return self._reduce_op(op, new_shape)  # The amount of work should be big enough to take the benefit of "2 kernels" approach.
    if any(not isinstance(s, int) for s in self.shape) or prod(self.shape) // prod(new_shape) < 32768: return self._reduce_op(op, new_shape)  # The amount of work should be big enough to take the benefit of "2 kernels" approach.
    heuristic, divisor, dim_to_split = max(((divisor := math.gcd(256, old))/(stride or math.inf), divisor, i) for i, (old, new, stride) in enumerate(zip(self.shape, new_shape, self.st.real_strides())) if old != new)  # type: ignore
    if divisor < 16 or heuristic < 0.125: return self._reduce_op(op, new_shape)  # Choose largest divisor (>=16) to split on, penalize large strides.
    def splitted_shape(dim_aft_div): return self.shape[:dim_to_split] + (self.shape[dim_to_split]//divisor,) + dim_aft_div + self.shape[dim_to_split+1:]
    return self.reshape(splitted_shape((divisor,)))._reduce_op(op, splitted_shape((1,))).reshape(splitted_shape(()))._reduce_op(op, new_shape)
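The new guard falls back to a single kernel whenever any dimension is symbolic or the reduction is small. For the split itself, a worked example of the heuristic's arithmetic under a hypothetical concrete shape:

import math

# Hypothetical walk-through of the two-kernel reduce split: (1048576,) -> (1,).
shape, new_shape = (1048576,), (1,)
assert shape[0] // new_shape[0] >= 32768   # big enough to be worth two kernels
divisor = math.gcd(256, shape[0])          # 256
assert divisor >= 16
split = (shape[0] // divisor, divisor)     # reshape to (4096, 256)
# kernel 1 reduces (4096, 256) -> (4096, 1); kernel 2 reduces (4096,) -> (1,)
print(split)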
  def reshape(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
  def reshape(self:LazyBuffer, arg:Tuple[Union[Node, int], ...]) -> LazyBuffer:
    if self.shape == arg: return self
    if not self.realized and self.op.op == MovementOps.RESHAPE:
      self.op.src[0].children.discard(self)  # NOTE: this is only required in reshape and when pushing permutes, why??
@@ -249,7 +250,7 @@ class LazyBuffer:
    if not self.realized and self.op.op == MovementOps.PAD: return self.op.src[0].pad(tuple([(b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)]))
    return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).pad(arg), MovementOps.PAD, arg)

  def expand(self: LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
  def expand(self: LazyBuffer, arg:Tuple[Union[Node,int], ...]) -> LazyBuffer:
    if self.shape == arg: return self
    if not self.realized and self.op.op == MovementOps.EXPAND:
      return self.op.src[0].expand(arg)
@@ -293,7 +294,6 @@ class LazyBuffer:
  def buffers(self) -> Tuple[LazyBuffer, ...]: return (self,)
  def map_buffers(self, real_srcs: Dict[Any, Any]): return real_srcs.get(self, self)
  def get_lazyops(self) -> List[Any]: return []
  def get_variable_buffers(self) -> Dict[Variable, LazyBuffer]: return {v:LazyBuffer.loadop(LoadOps.FROM, (1,), dtypes.int32, self.device, src=LazyBuffer.fromCPU(np.array([v.val], dtype=np.int32))) for s in self.shape for v in sym_vars(s)}
  def replace_with_movement_ops(self: LazyBuffer, ops:List[Tuple[MovementOps, Any]]) -> LazyBuffer:
    y = self
    for op, arg in ops: y = MOVEMENT_OPS_DISPATCHER[op](y, arg)
@@ -381,6 +381,8 @@ def _realize_from(buffer: LazyBuffer) -> None:
  if isinstance(rawbuf.realized, RawDiskBuffer) and issubclass(Device[buffer.device].buffer, RawBufferMapped):
    buffer.realized = Device[buffer.device].buffer(prod(buffer.shape), buffer.dtype, **buffer._device_extra_args())
    rawbuf.realized.readinto(cast(RawBufferMapped, buffer.realized)._buffer())
  elif isinstance(rawbuf.realized, RawBufferTransfer) and issubclass(Device[buffer.device].buffer, RawBufferTransfer) and P2P >= 1:
    buffer.realized = cast(RawBufferTransfer, Device[buffer.device].buffer).transfer(rawbuf.realized, buffer.shape, buffer.dtype, **buffer._device_extra_args())
  else:
    buffer.realized = Device[buffer.device].buffer.fromCPU(rawbuf.toCPU(), **buffer._device_extra_args())

@@ -413,4 +415,4 @@ MOVEMENT_OPS_DISPATCHER: Dict[MovementOps, Callable] = {
  MovementOps.PERMUTE: LazyBuffer.permute,
  MovementOps.PAD: LazyBuffer.pad,
  MovementOps.STRIDE: LazyBuffer.stride,
}
}
@@ -124,7 +124,7 @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter: return InterpretedFlopCounter.ex

class ASTRunner:
  def __init__(self, name, prg, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None, runtime_args:Optional[dict]=None):
    if DEBUG >= 4 and (runtime_args is None or 'binary' not in runtime_args): print(prg)
    if DEBUG >= 4 and (runtime_args is None or 'binary' not in runtime_args or not runtime_args['binary']): print(prg)
    self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {}

  def build(self, runtime):
@@ -156,10 +156,12 @@ class Compiled:

  def to_program(self, k):
    k.linearize()
    src, global_size, local_size = self.renderer(k.function_name, k.uops)
    ret = self.renderer(k.function_name, k.uops)
    src, global_size, local_size, binary = ret if len(ret) == 4 else ret + (False,)
    #TODO: I need to find a better way to select ARM64
    return ASTRunner(k.function_name, src, global_size, local_size,
      op_estimate=k.info.flops, mem_estimate=k.mem_estimate,
      display_name=k.display_name).build(self.runtime)
      display_name=k.display_name, runtime_args={"binary": binary}).build(self.runtime)

  def exec_ast(self, ast:LazyOp, output, **kwargs):
    # all movementops do nothing in a Compiled buffer!
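Renderers may now return either the old 3-tuple or a 4-tuple ending in `binary`; `ret if len(ret) == 4 else ret + (False,)` keeps both working. The idiom in isolation (standalone check, values hypothetical):

# Back-compat unpacking: renderers may return 3 or 4 values.
def unpack(ret):
    src, global_size, local_size, binary = ret if len(ret) == 4 else ret + (False,)
    return src, global_size, local_size, binary

assert unpack(("src", [1], [1])) == ("src", [1], [1], False)       # old renderer
assert unpack(("asm", [1], [1], True)) == ("asm", [1], [1], True)  # ARM64 renderer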
@@ -35,7 +35,7 @@ class CStyleLanguage(NamedTuple):
    UnaryOps.SQRT: lambda x: f"sqrt({x})",
    BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})",
    BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})",
    BinaryOps.MAX: lambda a,b: f"max({a},{b})",
    BinaryOps.MAX: lambda a,b: f"max({a},{b})", BinaryOps.MOD: lambda a,b: f"({a}%{b})",
    BinaryOps.CMPLT: lambda a,b: f"({a}<{b})", TernaryOps.MULACC: lambda a,b,c: f"(({a}*{b})+{c})",
    TernaryOps.WHERE: lambda a,b,c: f"({a}!=0?{b}:{c})"
  }
@@ -32,6 +32,35 @@ code_for_op: Final[Dict[Op, Callable]] = {
  TernaryOps.WHERE: lambda builder,x,y,z: builder.select(builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=('fast',)), y, z, flags=('fast',)),
}

dtype_to_llvm_dtype = {dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32)}

def cast(bb, val, input_type, output_type):
  if input_type == output_type: return val

  if output_type == dtypes.float32:
    if dtypes.is_int(input_type) or input_type == dtypes.bool:
      val = bb[-1].uitofp(val, ir.FloatType()) if dtypes.is_unsigned(input_type) or input_type == dtypes.bool else bb[-1].sitofp(val, ir.FloatType())
    elif input_type == dtypes.bfloat16:
      val = bb[-1].sext(val, ir.IntType(32))
      val = bb[-1].shl(val, ir.Constant(ir.IntType(32), 16))
      val = bb[-1].bitcast(val, ir.FloatType())
    else:
      val = bb[-1].fpext(val, ir.FloatType())
    return val

  if input_type == dtypes.float32:
    if dtypes.is_int(output_type) or output_type == dtypes.bool:
      val = bb[-1].fptoui(val, dtype_to_llvm_dtype[output_type]) if dtypes.is_unsigned(output_type) or output_type == dtypes.bool else bb[-1].fptosi(val, dtype_to_llvm_dtype[output_type])
    elif output_type == dtypes.bfloat16:
      val = bb[-1].bitcast(val, ir.IntType(32))
      val = bb[-1].lshr(val, ir.Constant(ir.IntType(32), 16))
      val = bb[-1].trunc(val, ir.IntType(16))
    else:
      val = bb[-1].fptrunc(val, dtype_to_llvm_dtype[output_type])
    return val

  raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")
|
||||
def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[List[int]], Optional[List[int]]]:
|
||||
# all llvm stuff goes into a module
|
||||
module = ir.Module(name=__file__)
|
||||
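The bfloat16 branches in cast() lean on bfloat16 being exactly the top 16 bits of a float32. A standalone NumPy check of that bit trick (illustration only, not tinygrad code):

import numpy as np

x = np.array([3.14159], dtype=np.float32)
bits = x.view(np.uint32)
bf16 = (bits >> 16).astype(np.uint16)                    # store path: keep the high 16 bits (truncating round)
back = (bf16.astype(np.uint32) << 16).view(np.float32)   # load path: shift back into place
print(x[0], back[0])  # 3.14159 3.140625 -- roughly 3 decimal digits survive the round trip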
@@ -41,7 +70,6 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[List[int]], Optional[List[int]]]:
   buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}

   # create llvm function
-  dtype_to_llvm_dtype = {dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32)}
   func_dtypes = [dtype_to_llvm_dtype[dtype] for dtype in buf_to_dtype.values()]
   func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() for x in func_dtypes]), name=function_name)

@@ -84,9 +112,9 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[List[int]], Optional[List[int]]]:
       bb[-2].cbranch(bb[-2].icmp_unsigned("==", idx_p1, int_const(var.max+1)), bb[-1]._block, block._block)
     if uop == UOps.LOAD:
       assert newvar is not None and isinstance(args, (MemOp, ConstOp))
-      assert newvar.dtype == dtypes.float, "newvar must be float"
       valid = args.valid.render(render_llvm, bb[-1])
       if isinstance(args, ConstOp):
+        assert newvar.dtype == dtypes.float, "newvar must be float"
         if args.valid.min == 0 and args.valid.max == 1:
           val = bb[-1].select(valid, ir.Constant(ir.FloatType(), args.value), ir.Constant(ir.FloatType(), args.invalid_value))
         else:
@@ -100,30 +128,12 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[List[int]], Optional[List[int]]]:
           val = bb[-1].select(valid, bb[-1].load(bb[-1].gep(func.args[buf_index[args.name]], [aug_idx], inbounds=True)), ir.Constant(dtype_to_llvm_dtype[args.memory_dtype], args.invalid_value))
         else:
           val = bb[-1].load(bb[-1].gep(func.args[buf_index[args.name]], [idx], inbounds=True))

-      if args.memory_dtype != newvar.dtype:
-        if dtypes.is_int(args.memory_dtype) or args.memory_dtype == dtypes.bool:
-          val = bb[-1].uitofp(val, ir.FloatType()) if dtypes.is_unsigned(args.memory_dtype) or args.memory_dtype == dtypes.bool else bb[-1].sitofp(val, ir.FloatType())
-        elif args.memory_dtype == dtypes.bfloat16:
-          val = bb[-1].sext(val, ir.IntType(32))
-          val = bb[-1].shl(val, ir.Constant(ir.IntType(32), 16))
-          val = bb[-1].bitcast(val, ir.FloatType())
-        else:
-          val = bb[-1].fpext(val, ir.FloatType())
+      val = cast(bb, val, args.memory_dtype, newvar.dtype)
       lvars[newvar] = val
     if uop == UOps.STORE:
       assert args.valid.min == 1 and isinstance(args, MemOp), "store must be valid and to memory"
       idx = args.idx.render(render_llvm, bb[-1])
-      element = lvars[vin[0]]
-      if args.memory_dtype != vin[0].dtype:
-        if dtypes.is_int(args.memory_dtype) or args.memory_dtype == dtypes.bool:
-          element = bb[-1].fptoui(element, dtype_to_llvm_dtype[args.memory_dtype]) if dtypes.is_unsigned(args.memory_dtype) or args.memory_dtype == dtypes.bool else bb[-1].fptosi(element, dtype_to_llvm_dtype[args.memory_dtype])
-        elif args.memory_dtype == dtypes.bfloat16:
-          element = bb[-1].bitcast(element, ir.IntType(32))
-          element = bb[-1].lshr(element, ir.Constant(ir.IntType(32), 16))
-          element = bb[-1].trunc(element, ir.IntType(16))
-        else:
-          element = bb[-1].fptrunc(element, dtype_to_llvm_dtype[args.memory_dtype])
+      element = cast(bb, lvars[vin[0]], vin[0].dtype, args.memory_dtype)
       bb[-1].store(element, bb[-1].gep(func.args[buf_index[args.name]], [idx], inbounds=True))
     if uop == UOps.ALU:
       lvars[newvar] = code_for_op[args](bb[-1], *[lvars[x] for x in vin])
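With the helper in place, LOAD and STORE each reduce to one cast() call. A self-contained llvmlite sketch of the load-side pattern on a toy function (the demo names are hypothetical, not tinygrad code):

# Build a tiny function that loads an int8 and widens it to float32, the same
# int -> float32 path cast() emits for a LOAD. Assumes llvmlite is installed.
import llvmlite.ir as ir

module = ir.Module(name="cast_demo")
fnty = ir.FunctionType(ir.VoidType(), [ir.IntType(8).as_pointer(), ir.FloatType().as_pointer()])
func = ir.Function(module, fnty, name="demo")
bb = [ir.IRBuilder(func.append_basic_block("entry"))]

val = bb[-1].load(func.args[0])           # load the int8 from memory
val = bb[-1].sitofp(val, ir.FloatType())  # signed int -> float32, as in cast()
bb[-1].store(val, func.args[1])
bb[-1].ret_void()
print(module)  # dumps the generated LLVM IR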
tinygrad/runtime/lib.py
@@ -50,6 +50,15 @@ class RawBufferCopyInOut(RawBufferCopyIn):
     self._copyout(x)
     return x

+class RawBufferTransfer(RawBuffer):
+  def _transfer(self, x) -> None: raise NotImplementedError("must be implemented")
+
+  @classmethod
+  def transfer(cls, x, shape, dtype, **kwargs):
+    ret = cls(prod(shape), dtype, **kwargs)
+    ret._transfer(x)
+    return ret
+
 class RawConst(RawBuffer): # pylint: disable=abstract-method
   def __repr__(self): return f"const<{self._buf}, {self.dtype}>"
   @property
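A backend opts into the new P2P path by subclassing RawBufferTransfer and implementing _transfer; transfer() allocates the destination, then pulls from the source. A runnable toy sketch (EchoBuffer is hypothetical, not a tinygrad class):

from math import prod

class EchoBuffer:
  def __init__(self, size, dtype): self.size, self.dtype, self._buf = size, dtype, bytearray(size)

  def _transfer(self, x): self._buf[:] = x._buf  # stand-in for a device-to-device copy

  @classmethod
  def transfer(cls, x, shape, dtype, **kwargs):
    ret = cls(prod(shape), dtype, **kwargs)  # allocate on the destination device
    ret._transfer(x)                         # then copy straight from the source buffer
    return ret

src = EchoBuffer(4, "uint8"); src._buf[:] = b"\x01\x02\x03\x04"
dst = EchoBuffer.transfer(src, (4,), "uint8")
print(bytes(dst._buf))  # b'\x01\x02\x03\x04'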
tinygrad/runtime/ops_clang.py
@@ -1,8 +1,15 @@
 import os, time, ctypes, hashlib, subprocess, platform, tempfile, functools
+from functools import partial, reduce
 from tinygrad.ops import Compiled
 from tinygrad.helpers import fromimport, getenv, DEBUG, CI
 from tinygrad.runtime.lib import RawMallocBuffer
 from tinygrad.codegen.linearizer import LinearizerOptions
 from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
+import struct
+import numpy as np
+
+ARM64 = getenv('ARM64', False)
+if CI and ARM64: from unicorn import Uc, UC_ARCH_ARM64, UC_MODE_ARM, UC_HOOK_CODE, arm64_const  # type: ignore

 args = {
   'Windows': {'cflags':'', 'ext':'dll', 'exp':'__declspec(dllexport)'},
@@ -11,22 +18,64 @@ args = {
 }[platform.system()]

 CLANG_PROGRAM_HEADER = '#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n#define int64 long\n#define half __fp16\n#define uchar unsigned char\n#define bool uchar\n'
+ADDRESS = 0x10000
+
+# Unicorn doesn't support external calls
+def align(addr): return (addr+4095) & ~(4095)
+mock_lm = {"sinf": np.sin, "sqrtf": np.sqrt, "exp2f": np.exp2, "log2f": np.log2}
+def emulate_ext_calls(fn, uc, address, size, user_data):
+  s_in = struct.unpack('f', struct.pack('I', uc.reg_read(getattr(arm64_const, f'UC_ARM64_REG_S{fn[2][1:]}'))))[0]
+  uc.reg_write(getattr(arm64_const, f'UC_ARM64_REG_S{fn[1][1:]}'), struct.unpack('I', struct.pack('f', mock_lm[fn[0]](s_in)))[0])  # type: ignore

 class ClangProgram:
-  def __init__(self, name:str, prg:str):
-    prg = CLANG_PROGRAM_HEADER + prg
+  def __init__(self, name:str, prg:str, binary:bool=False):
     # TODO: is there a way to not write this to disk?
     # A: it seems there isn't https://stackoverflow.com/questions/28053328/ctypes-cdll-load-library-from-memory-rather-than-file
    # because ctypes.CDLL() calls dlopen (POSIX) or LoadLibrary (Windows) which require a file
     fn = f"{tempfile.gettempdir()}/clang_{hashlib.md5(prg.encode('utf-8')).hexdigest()}.{args['ext']}"
+    if binary and DEBUG >= 5: print(prg)
     if not os.path.exists(fn):
-      _, tmp = tempfile.mkstemp()
-      subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+tmp).split(), input=prg.encode('utf-8'))
-      os.rename(tmp, fn)
+      tmp = f"{fn}.{os.getpid()}.tmp"
+      if not binary:
+        prg = CLANG_PROGRAM_HEADER + prg
+        subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+tmp).split(), input=prg.encode('utf-8'))
+        os.rename(tmp, fn)
+      else:
+        if CI and ARM64:
+          prg = prg.split('\n')  # type: ignore
+          self.varsize = align(int(prg[0].split(" ")[1]))
+          self.ext_calls = {(i*4+ADDRESS):ins.split(" ")[1:] for i, ins in enumerate(filter(lambda ins: ins[:4] != 'loop', prg[6:-3])) if ins[:2] == 'bl'}
+          prg = "\n".join(['nop' if ins[:2] == 'bl' else ins for ins in prg[6:-3]] + ['\n'])
+          subprocess.check_output(args=('aarch64-linux-gnu-as -o '+tmp).split(), input=prg.encode('utf-8'))
+          subprocess.check_output(args=('aarch64-linux-gnu-objcopy -O binary --only-section=.text '+tmp+' '+fn+'.bin').split())
+          self.prg = open(fn + '.bin', 'rb').read()
+          return
+        subprocess.check_output(args=('as -o' + tmp).split(), input=prg.encode('utf-8'))
+        subprocess.check_output(args=('clang -lm -shared '+tmp+' -o'+fn).split())
     self.lib = ctypes.CDLL(fn)
     self.fxn = self.lib[name]

   def __call__(self, global_size, local_size, *args, wait=False):
     if wait: st = time.monotonic()
-    self.fxn(*[x._buf for x in args])
+    if CI and ARM64:
+      mu = Uc(UC_ARCH_ARM64, UC_MODE_ARM)
+      total_mem = align(reduce(lambda total, arg: total + arg.size * arg.dtype.itemsize, args, len(self.prg)+self.varsize))
+      mu.mem_map(ADDRESS, total_mem)
+      for k, fn in self.ext_calls.items(): mu.hook_add(UC_HOOK_CODE, partial(emulate_ext_calls, fn), begin=k, end=k)
+      mu.mem_write(ADDRESS, self.prg + b''.join(bytes(arg._buf) for arg in args))
+      addr = ADDRESS + len(self.prg)
+      for i, arg in enumerate(args):
+        if i<=7:
+          mu.reg_write(getattr(arm64_const, f'UC_ARM64_REG_X{i}'), addr)
+        else:
+          # NOTE: in ARM, args beyond the first 8 are placed on the stack; this also accounts for the stack red zone.
+          mu.mem_write(ADDRESS + total_mem - (len(args[8:])+2)*8 + 8*(i-8), addr.to_bytes(8, 'little'))
+        addr += arg.size * arg.dtype.itemsize
+      mu.reg_write(arm64_const.UC_ARM64_REG_SP, ADDRESS + total_mem - (len(args[8:])+2)*8)
+      mu.emu_start(ADDRESS, ADDRESS + len(self.prg))
+      args[0]._buf = mu.mem_read(mu.reg_read(arm64_const.UC_ARM64_REG_X0), args[0].size * args[0].dtype.itemsize)
+    else:
+      self.fxn(*[x._buf for x in args])
     if wait: return time.monotonic()-st

-renderer = functools.partial(uops_to_cstyle, CStyleLanguage(kernel_prefix=args['exp'], buffer_suffix=" restrict"))
+renderer = fromimport("tinygrad.codegen.assembly_arm64", "uops_to_arm64_asm") if ARM64 else functools.partial(uops_to_cstyle, CStyleLanguage(kernel_prefix=args['exp'], buffer_suffix=" restrict"))
 ClangBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False), renderer, ClangProgram)
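emulate_ext_calls shuttles floats between Python and the emulated S registers by bit-reinterpreting through struct, since Unicorn registers are plain integers. The round trip in isolation:

import struct

# Pack a float into its raw 32-bit pattern (what reg_write needs), then back
# (what reg_read returns) -- the exact round trip emulate_ext_calls performs.
f = 1.5
bits = struct.unpack('I', struct.pack('f', f))[0]   # float -> uint32 bit pattern
print(hex(bits))                                    # 0x3fc00000
back = struct.unpack('f', struct.pack('I', bits))[0]
print(back)                                         # 1.5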
tinygrad/runtime/ops_gpu.py
@@ -5,7 +5,7 @@ import pyopencl as cl  # type: ignore
 from typing import Optional, List
 from tinygrad.helpers import DEBUG, getenv, prod, ImageDType, OSX, fromimport
 from tinygrad.ops import Compiled
-from tinygrad.runtime.lib import RawBufferCopyInOut
+from tinygrad.runtime.lib import RawBufferCopyInOut, RawBufferTransfer
 from tinygrad.codegen.linearizer import LinearizerOptions
 from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage

@@ -29,7 +29,7 @@ class _CL:
 CL = _CL()
 CL.post_init() if not getenv("DELAYED_RUNTIME_INIT", False) else None

-class CLBuffer(RawBufferCopyInOut):
+class CLBuffer(RawBufferCopyInOut, RawBufferTransfer):
   def __init__(self, size, dtype, device='0'):
     if isinstance(dtype, ImageDType):
       fmt = cl.ImageFormat(cl.channel_order.RGBA, {2: cl.channel_type.HALF_FLOAT, 4: cl.channel_type.FLOAT}[dtype.itemsize])
@@ -49,6 +49,10 @@ class CLBuffer(RawBufferCopyInOut):
       buf = cl.Buffer(CL.cl_ctxs[self._buf.device], cl.mem_flags.WRITE_ONLY | cl.mem_flags.USE_HOST_PTR, 0, hostbuf=x.data)
       mapped, event = cl.enqueue_map_buffer(CL.cl_queue[self._buf.device], buf, cl.map_flags.WRITE, 0, self.size, dtype=self.dtype.np, is_blocking=False)
       with mapped.base: cl.enqueue_copy(CL.cl_queue[self._buf.device], mapped, self._buf, is_blocking=True, wait_for=[event] + ([self.event] if hasattr(self, "event") else []))
+  def _transfer(self, x):
+    if "gfx" in CL.cl_ctxs[x._buf.device].devices[0].name:
+      cl.enqueue_copy_buffer_p2p_amd(CL.cl_platform, CL.cl_queue[x._buf.device], x._buf, self._buf, x.size * x.dtype.itemsize).wait()
+    else: raise NotImplementedError("p2p transfer between devices not implemented on non-amd")

 class CLProgram:
   def __init__(self, name:str, prg:str, binary=False, argdtypes=None, options=None):
tinygrad/runtime/ops_llvm.py
@@ -55,7 +55,8 @@ class LLVMProgram:
     LLVM.engine.finalize_object()
     self.fxn = LLVM.engine.get_function_address(name)

-  def __del__(self): LLVM.engine.remove_module(self.mod)
+  def __del__(self):
+    if hasattr(self, 'mod'): LLVM.engine.remove_module(self.mod)

   def __call__(self, unused_global_size, unused_local_size, *bufs, wait=False):
     cfunc = CFUNCTYPE(ctypes.c_int, *[ctypes.c_void_p for _ in bufs])(self.fxn)
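The hasattr guard matters because __del__ runs even when __init__ raised before self.mod was assigned (e.g. a failed compile). A minimal reproduction of the failure mode the guard avoids:

class Fragile:
  def __init__(self):
    raise RuntimeError("compile failed")  # self.mod is never assigned
    self.mod = object()

  def __del__(self):
    # without the guard this would raise AttributeError during collection
    if hasattr(self, 'mod'): print("releasing", self.mod)

try:
  Fragile()
except RuntimeError:
  pass  # the half-built object is collected; __del__ runs safely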
tinygrad/runtime/ops_metal.py
@@ -37,7 +37,7 @@ def unwrap(x):
   return ret

 class MetalProgram:
-  def __init__(self, name:str, prg:str):
+  def __init__(self, name:str, prg:str, binary:bool=False):
     if METAL_XCODE:
       air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8'))
       # NOTE: if you run llvm-dis on "air" you can see the llvm bytecode
tinygrad/runtime/ops_webgpu.py
@@ -12,7 +12,7 @@ import wgpu  # type: ignore
 device = get_default_device()

 class WebGPUProgram:
-  def __init__(self, name: str, prg: str): self.name,self.prg = name,device.create_shader_module(code=prg)
+  def __init__(self, name: str, prg: str, binary=False): self.name,self.prg = name,device.create_shader_module(code=prg)
   def __call__(self, global_size, local_size, *bufs, wait=False):
     assert len(bufs) <= 8, "WEBGPU only supports 8 buffers"
     binding_layouts = [{"binding": i, "visibility": wgpu.ShaderStage.COMPUTE, "buffer": {"type": wgpu.BufferBindingType.storage}} for i in range(len(bufs))]
tinygrad/shape/shapetracker.py
@@ -3,7 +3,7 @@ from __future__ import annotations
 from enum import Enum, auto
 import functools
 from typing import Dict, Tuple, Union, List, Optional, Callable, cast, NamedTuple
-from tinygrad.helpers import prod, DEBUG
+from tinygrad.helpers import prod, DEBUG, partition
 from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode, is_sym_int

 # these ops live here
@@ -129,17 +129,18 @@ def get_unsafe_resize_offset(strides, arg):
   return sum([s * x[0] for s, x in zip(strides,arg)])

 class ShapeTracker:
-  __slots__ = "views"
-  def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], views:Optional[List[View]]=None):
+  __slots__ = "views", "var_vals"
+  def __init__(self, shape:Union[ShapeTracker, Tuple[Union[Node,int], ...]], views:Optional[List[View]]=None):
     self.views: List[View] = views if views is not None else ([*cast(ShapeTracker, shape).views] if shape.__class__ is ShapeTracker else [View(shape)])
-  def __repr__(self): return f"ShapeTracker(shape={self.views[-1].shape}, views={self.views})"
+    self.var_vals: Dict[Variable, int] = shape.var_vals if isinstance(shape, ShapeTracker) else {}
+  def __repr__(self): return f"ShapeTracker(shape={self.views[-1].shape}, views={self.views}, var_vals={self.var_vals})"
   def copy(self) -> ShapeTracker: return ShapeTracker(self.views[-1].shape, [*self.views])

   @property
   def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous

   @property
-  def shape(self) -> Tuple[int, ...]: return self.views[-1].shape
+  def shape(self) -> Tuple[int, ...]: return self.views[-1].shape  # NOTE: real type is Tuple[Union[Node, int], ...] but mypy complains about prod(shape)

   @property
   def key(self) -> Tuple[View, ...]: return tuple(self.views)
@@ -231,15 +232,16 @@ class ShapeTracker:
     return self

   def reshape(self, new_shape: Tuple[Union[Node,int], ...]):
-    # reshape into symbolic shape, update the variable value
-    if all(isinstance(s, int) for s in self.shape) and len(new_vars:=list(s for s in new_shape if isinstance(s, Variable))) > 0:
-      assert len(new_vars) == 1, "only one variable is supported in a shape"
-      new_var, new_val = new_vars[0], prod(self.shape) // prod(s for s in new_shape if isinstance(s, int))
-      if new_var.val is None:
+    new_ints, new_nodes = partition(new_shape, lambda s: isinstance(s, int))
+    if new_nodes and all(isinstance(s, int) for s in self.shape):
+      # reshape from all int shape into shape with a variable, update the variable value
+      assert len(new_nodes) == 1 and isinstance(new_nodes[0], Variable), "only support adding one Variable to the int shape"
+      new_var, new_val = new_nodes[0], prod(self.shape) // prod(new_ints)
+      if new_var not in self.var_vals:
         assert new_var.min <= new_val <= new_var.max, f"variable value {new_val} out of range [{new_var.min}, {new_var.max}]"
-        new_var.val = new_val
-      else: assert new_var.val == new_val, f"value conflicts, was {new_var.val}, set to {new_val}"
+        self.var_vals[new_var] = new_val
+      else: assert self.var_vals[new_var] == new_val, f"value conflicts, was {self.var_vals[new_var]}, set to {new_val}"
+    elif not new_nodes: self.var_vals = {}
     if self.views[-1].shape == new_shape: return self
     assert all(is_sym_int(x) and x > 0 for x in new_shape), f"shape must be symbolic ints and can't contain 0 or negative numbers {new_shape}"
     # only check size for int shapes. we don't check symbolic here as long as the reshape itself can be done
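partition() splits the requested shape into int dims and symbolic dims, and the inferred binding now lives in the tracker's var_vals instead of mutating the Variable itself. A plain-Python sketch of the value inference (Var stands in for tinygrad's Variable; infer_var_val is illustrative):

from math import prod

class Var:
  def __init__(self, name, lo, hi): self.name, self.min, self.max = name, lo, hi
  def __repr__(self): return self.name

def infer_var_val(old_shape, new_shape):
  # split new_shape into concrete ints and symbolic dims, like partition() does
  new_ints  = [s for s in new_shape if isinstance(s, int)]
  new_nodes = [s for s in new_shape if not isinstance(s, int)]
  assert len(new_nodes) == 1, "only support adding one Variable to the int shape"
  # element count is conserved, so the variable's value is forced
  new_val = prod(old_shape) // prod(new_ints)
  assert new_nodes[0].min <= new_val <= new_nodes[0].max, "variable value out of range"
  return {new_nodes[0]: new_val}

i = Var("i", 1, 10)
print(infer_var_val((6, 4), (i, 4)))  # {i: 6} -- reshaping (6,4) -> (i,4) pins i=6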
tinygrad/shape/symbolic.py
@@ -9,7 +9,6 @@ from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any
 # symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod

 def is_sym_int(x: Any) -> bool: return isinstance(x, (int, Node))
-def sym_vars(x: Union[Node, int]) -> List[Variable]: return [] if isinstance(x, int) else x.vars()

 class Node:
   b: Union[Node, int]
@@ -141,7 +140,6 @@ class Variable(Node):

   def __init__(self, expr:Optional[str], nmin:int, nmax:int):
     self.expr, self.min, self.max = expr, nmin, nmax
-    self.val: Optional[int] = None
   def vars(self): return [self]

 class NumNode(Node):