Mirror of https://github.com/tinygrad/tinygrad.git
rename ops to have unique names (#7522)
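The change renames ReduceOps.MAX to ReduceOps.REDUCE_MAX and MetaOps.VIEW to MetaOps.BUFFER_VIEW so that, per the title, no two op enums share a member name: BinaryOps already has a MAX and Ops already has a BUFFER_VIEW (both appear in the hunks below). A minimal standalone sketch of the collision that unique names avoid, using toy enums rather than tinygrad's real classes:

from enum import Enum, auto

# toy stand-ins for illustration -- not tinygrad's real enum classes
class ToyBinaryOps(Enum): ADD = auto(); MUL = auto(); MAX = auto()
class ToyReduceOpsOld(Enum): SUM = auto(); PROD = auto(); MAX = auto()          # old naming: clashes with ToyBinaryOps.MAX
class ToyReduceOpsNew(Enum): SUM = auto(); PROD = auto(); REDUCE_MAX = auto()   # new naming: every member name is unique

# a flat name -> op table silently drops one of the two MAXes
table = {op.name: op for op in (*ToyBinaryOps, *ToyReduceOpsOld)}
assert table["MAX"] is ToyReduceOpsOld.MAX            # ToyBinaryOps.MAX was shadowed

# with unique member names nothing collides
table = {op.name: op for op in (*ToyBinaryOps, *ToyReduceOpsNew)}
assert table["MAX"] is ToyBinaryOps.MAX and table["REDUCE_MAX"] is ToyReduceOpsNew.REDUCE_MAX

With distinct names, a single flat lookup (or, presumably, a later merged op namespace) can hold every op without one member shadowing another.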
test/external/fuzz_schedule.py
@@ -50,7 +50,7 @@ def fuzz_schedule(outs:List[LazyBuffer]):
   rawbufs: Dict[LazyBuffer, Buffer] = {}
   for lsi in ts:
     for out in lsi.outputs:
-      base = rawbufs[lsi.inputs[0]].base if out.op is MetaOps.VIEW else None
+      base = rawbufs[lsi.inputs[0]].base if out.op is MetaOps.BUFFER_VIEW else None
       rawbufs[out] = Buffer(out.buffer.device, out.buffer.size, out.buffer.dtype, base=base)
       if out.op is MetaOps.ASSIGN: rawbufs[out].ensure_allocated().copyin(prerealized[out])
     for x in lsi.inputs:
@@ -144,13 +144,15 @@ class TestMultiTensor(unittest.TestCase):
     O = X.shrink(((0, 2), None)) * W.shrink(((0, 2), None)) < 2
     np.testing.assert_allclose(O.numpy(), X.numpy()[0:2]*W.numpy()[0:2] < 2)
 
-  @given(strat.sampled_from((4, 5)), strat.sampled_from((devices_2, devices_3)), strat.sampled_from((ReduceOps.SUM, ReduceOps.PROD, ReduceOps.MAX)),
+  @given(strat.sampled_from((4, 5)), strat.sampled_from((devices_2, devices_3)),
+         strat.sampled_from((ReduceOps.SUM, ReduceOps.PROD, ReduceOps.REDUCE_MAX)),
          strat.sampled_from((None, 0, 1)), strat.sampled_from((None, 0, 1)), strat.sampled_from((1, 0, -1)))
   def test_simple_reduce(self, N, devices, rop, shard_axis, reduce_axis, sign):
     X = Tensor.rand(N*N).reshape(N, N).mul(sign)
     n = X.numpy()
     X.shard_(devices, shard_axis)
-    f = {ReduceOps.SUM: lambda x: x.sum(reduce_axis), ReduceOps.PROD: lambda x: x.prod(reduce_axis), ReduceOps.MAX: lambda x: x.max(reduce_axis)}[rop]
+    f = {ReduceOps.SUM: lambda x: x.sum(reduce_axis), ReduceOps.PROD: lambda x: x.prod(reduce_axis),
+         ReduceOps.REDUCE_MAX: lambda x: x.max(reduce_axis)}[rop]
     fX = f(X)
     fn = f(n)
     np.testing.assert_allclose(fX.numpy(), fn, rtol=1e-6, atol=1e-6)
@@ -1495,7 +1495,7 @@ class TestIndexing(unittest.TestCase):
   def test_arange_view_op(self):
     a = Tensor.arange(12).reshape(4, 3).shrink(((1, 2), (1, 3))).contiguous()
     assert isinstance(a.lazydata, LazyBuffer)
-    self.assertIs(a.lazydata.base.op, MetaOps.VIEW)
+    self.assertIs(a.lazydata.base.op, MetaOps.BUFFER_VIEW)
     self.check_schedule(a, 1)
     np.testing.assert_equal(a.numpy(), [[4, 5]])
 
@@ -48,7 +48,7 @@ class TestVerifyAST(unittest.TestCase):
   def test_no_implicit_broadcasting(self):
     bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
     a = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((4, 32)).to_uop()))
-    b = a + UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.MAX, (1,)))
+    b = a + UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.REDUCE_MAX, (1,)))
     st = UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((4, 32)).to_uop(), b))
     with self.assertRaises(InvalidASTException): helper_test_verify_ast(st)
 
@@ -151,7 +151,7 @@ def get_realizes(outs:List[LazyBuffer], ctx) -> Tuple[List[List[UOp]], Dict[Buff
   for r in reduce_of_const:
     group = {tr:None for tr,rop in reduce_for_op.items() if rop is r}
     if any(tr.forced_realize for tr in group) or any(x.base in group for x in outs): continue
-    kernel_children = {c for tr in group for c in children[tr] if c.op not in {MetaOps.COPY, MetaOps.VIEW}}
+    kernel_children = {c for tr in group for c in children[tr] if c.op not in {MetaOps.COPY, MetaOps.BUFFER_VIEW}}
     if len(kernel_children) == 0: continue
     for tr in group:
       del realizes[tr]
@@ -34,7 +34,7 @@ class LazyBuffer(MathTrait):
     self.op, self.arg, self.srcs = op, arg, srcs # this is a UOp, except the src is LazyBuffers and not UOps
     assert self.op is not MetaOps.ASSIGN or srcs[0].base.realized is not None, "assign target must be realized"
 
-    if self.op is MetaOps.VIEW:
+    if self.op is MetaOps.BUFFER_VIEW:
       # some LazyBuffers can be processed with only a view, no AST required
       self.buffer: Buffer = srcs[0].base.buffer.view(st.size, self.dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize)
     else:
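In the MetaOps.BUFFER_VIEW branch above, the LazyBuffer's storage is created with buffer.view(size, dtype, byte offset) instead of scheduling a kernel, i.e. the result is a zero-copy window into the source's realized base buffer. A rough NumPy analogue of that idea (illustration only, not the Buffer API):

import numpy as np

# one flat backing allocation, analogous to a realized base Buffer
base = np.arange(12, dtype=np.int32)                         # 12 elements * 4 bytes each

# a "buffer view": same memory, reinterpreted with (size, dtype, byte offset)
size, dtype, offset_bytes = 2, np.int32, 1 * 3 * 4 + 1 * 4   # element [1][1] of a (4, 3) layout
view = np.frombuffer(base.data, dtype=dtype, count=size, offset=offset_bytes)

print(view)            # [4 5] -- matches the test_arange_view_op expectation above
base[4] = 99
print(view)            # [99  5] -- no copy was made; the view aliases the base storage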
@@ -89,7 +89,7 @@ class LazyBuffer(MathTrait):
 
   def contiguous(self, allow_buffer_view=True):
     if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const():
-      ret = self.alu(MetaOps.VIEW) if allow_buffer_view and self.can_view() else self.alu(MetaOps.CONTIGUOUS)
+      ret = self.alu(MetaOps.BUFFER_VIEW) if allow_buffer_view and self.can_view() else self.alu(MetaOps.CONTIGUOUS)
       if (sti := self.st.invert(self.base.shape)) is not None: self.base.contiguous_child = ref(ret), sti
       return ret
     self.base.forced_realize = True
@@ -111,7 +111,8 @@ class LazyBuffer(MathTrait):
     elif getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self is not self.base:
       # TODO: applying this makes gpt2 slower
       return self.base.cast(dtype, bitcast)._view(self.st)
-    cast_op: Union[MetaOps, UnaryOps] = (MetaOps.VIEW if self.can_view() and allow_buffer_view else UnaryOps.BITCAST) if bitcast else UnaryOps.CAST
+    cast_op: Union[MetaOps, UnaryOps] = \
+      (MetaOps.BUFFER_VIEW if self.can_view() and allow_buffer_view else UnaryOps.BITCAST) if bitcast else UnaryOps.CAST
     return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, dtype, (self,))
 
   def is_unrealized_const(self): return self.base.realized is None and self.base.op is MetaOps.CONST and not isinstance(self.base.arg, UOp)
@@ -188,7 +189,7 @@ class LazyBuffer(MathTrait):
     if self.is_unrealized_unmasked_const() and all_int(self.shape):
       if op is ReduceOps.SUM: return self.const_with_shape(self.base.arg * prod(self.shape[i] for i in axis), new_shape)
       if op is ReduceOps.PROD: return self.const_with_shape(self.base.arg ** prod(self.shape[i] for i in axis), new_shape)
-      if op is ReduceOps.MAX: return self.const_with_shape(self.base.arg, new_shape)
+      if op is ReduceOps.REDUCE_MAX: return self.const_with_shape(self.base.arg, new_shape)
 
     # TODO: can we split symbolic shape if the reduce axis is not symbolic?
     if not SPLIT_REDUCEOP or not all_int(self.shape) or (0 in self.shape) or \
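The constant-folding branch above never materializes the tensor: reducing an unrealized, unmasked constant c over axes holding n elements yields c*n for SUM, c**n for PROD, and c itself for REDUCE_MAX. A quick NumPy check of that arithmetic:

import numpy as np

c, shape, axis = 3.0, (4, 5), (1,)               # constant value, full shape, reduce axis
n = np.prod([shape[i] for i in axis])            # number of elements folded per output

full = np.full(shape, c)
assert np.allclose(full.sum(axis=axis),  c * n)      # ReduceOps.SUM        -> arg * prod(axis sizes)
assert np.allclose(full.prod(axis=axis), c ** n)     # ReduceOps.PROD       -> arg ** prod(axis sizes)
assert np.allclose(full.max(axis=axis),  c)          # ReduceOps.REDUCE_MAX -> arg unchanged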
@@ -15,7 +15,7 @@ from tinygrad.device import Buffer
 sys.setrecursionlimit(10000)
 
 BUF_LIMIT = {"METAL":32}
-METAOPS = {MetaOps.COPY:Ops.COPY, MetaOps.EMPTY:Ops.EMPTY, MetaOps.VIEW:Ops.BUFFER_VIEW}
+METAOPS = {MetaOps.COPY:Ops.COPY, MetaOps.EMPTY:Ops.EMPTY, MetaOps.BUFFER_VIEW:Ops.BUFFER_VIEW}
 
 # **** ScheduleItem return type
 
@@ -156,7 +156,7 @@ class Prod(Function):
 
 class Max(Function):
   def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
-    self.x, self.ret, self.axis = x, x.r(ReduceOps.MAX, axis), axis
+    self.x, self.ret, self.axis = x, x.r(ReduceOps.REDUCE_MAX, axis), axis
     return self.ret
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
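Only Max.forward changes here; backward (shown as context) still has to route the incoming gradient to the positions that achieved the maximum, splitting it across ties. A standalone NumPy sketch of that general recipe, as an illustration rather than tinygrad's exact backward:

import numpy as np

def max_reduce_backward(x: np.ndarray, grad_output: np.ndarray, axis: int) -> np.ndarray:
    """Gradient of y = x.max(axis): send grad_output to the max locations, split evenly across ties."""
    max_kept = x.max(axis=axis, keepdims=True)          # forward result, broadcastable back to x
    is_max = (x == max_kept).astype(x.dtype)            # 1.0 where x holds the max, else 0.0
    ties = is_max.sum(axis=axis, keepdims=True)         # how many entries share the max
    return is_max / ties * np.expand_dims(grad_output, axis)

x = np.array([[1.0, 3.0, 3.0], [2.0, 0.0, 5.0]])
print(max_reduce_backward(x, np.ones(2), axis=1))
# [[0.  0.5 0.5]
#  [0.  0.  1. ]]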
@@ -31,9 +31,9 @@ class TernaryOps(FastEnum):
   WHERE = auto(); MULACC = auto() # noqa: E702
 class ReduceOps(FastEnum):
   """A -> B (reduce)"""
-  SUM = auto(); PROD = auto(); MAX = auto() # noqa: E702
+  SUM = auto(); PROD = auto(); REDUCE_MAX = auto() # noqa: E702
 class MetaOps(FastEnum):
-  EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); ASSIGN = auto(); VIEW = auto() # noqa: E702
+  EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); ASSIGN = auto(); BUFFER_VIEW = auto() # noqa: E702
 Op = Union[UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps]
 
 class SimpleMathTrait:
@@ -118,7 +118,7 @@ class MathTrait(SimpleMathTrait): # pylint: disable=abstract-method
 # do not preserve f(0) = 0
 UNSAFE_PAD_OPS = {UnaryOps.RECIP, UnaryOps.LOG2, UnaryOps.EXP2, BinaryOps.IDIV}
 
-REDUCE_ALU: Dict[ReduceOps, BinaryOps] = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.PROD:BinaryOps.MUL, ReduceOps.MAX:BinaryOps.MAX}
+REDUCE_ALU: Dict[ReduceOps, BinaryOps] = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.PROD:BinaryOps.MUL, ReduceOps.REDUCE_MAX:BinaryOps.MAX}
 
 # https://en.wikipedia.org/wiki/Identity_element
 def identity_element(op:BinaryOps, dt:DType): return dtypes.as_const({BinaryOps.ADD:0, BinaryOps.MUL:1, BinaryOps.MAX:dtypes.min(dt)}[op], dt)
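REDUCE_ALU now keys the max reduce by its unique name but still lowers it to BinaryOps.MAX, whose identity_element is the smallest value of the dtype (dtypes.min), so padded-in elements can never win the reduce. A standalone sketch of the same idea in NumPy terms (not the tinygrad call):

import numpy as np

# mirror of the identity_element() idea: 0 for ADD, 1 for MUL, "smallest possible value" for MAX
def identity_for(op: str, dt: np.dtype):
    if op == "ADD": return dt.type(0)
    if op == "MUL": return dt.type(1)
    if op == "MAX": return dt.type(-np.inf) if np.issubdtype(dt, np.floating) else np.iinfo(dt).min
    raise ValueError(op)

dt = np.dtype(np.float32)
x = np.array([2.0, -7.5, 4.25], dtype=dt)
# padding a MAX reduce with its identity never changes the answer
padded = np.concatenate([x, np.full(5, identity_for("MAX", dt), dtype=dt)])
assert padded.max() == x.max()
# same story for SUM with 0 (and PROD with 1)
assert np.concatenate([x, np.full(5, identity_for("ADD", dt), dtype=dt)]).sum() == x.sum()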