diff --git a/accel/llvm/ops_llvm.py b/accel/llvm/ops_llvm.py index 1a2d44b922..62f72f1c77 100644 --- a/accel/llvm/ops_llvm.py +++ b/accel/llvm/ops_llvm.py @@ -4,8 +4,9 @@ import math import time from typing import Tuple, Union, Dict, Any, List from tinygrad.helpers import prod -from tinygrad.shapetracker import ShapeTracker, ZeroView -from tinygrad.ops import LazyOp, ASTKernel +from tinygrad.shape import ShapeTracker, ZeroView +from tinygrad.ops import LazyOp +from tinygrad.ast import ASTKernel import ctypes import numpy as np from ctypes import CFUNCTYPE diff --git a/accel/opencl/ops_opencl.py b/accel/opencl/ops_opencl.py index 582a052e1b..9b9966437d 100644 --- a/accel/opencl/ops_opencl.py +++ b/accel/opencl/ops_opencl.py @@ -139,6 +139,7 @@ class OpenCLBuffer(GPUBuffer): BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)", BinaryOps.DIV: "(A/B)", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)", ReduceOps.SUM: "(acc + A)", ReduceOps.MAX: "max(A, acc)", MovementOps.RESHAPE: "(A)" } + start_for_op = {ReduceOps.SUM: "0.0", ReduceOps.MAX: "-INFINITY"} def __init__(self, shape, hostbuf:Optional[OpenCLBuffer]=None, backing:Optional[np.ndarray]=None): self._image = hostbuf._image if hostbuf is not None else None self.copied_backing = False diff --git a/extra/thneed.py b/extra/thneed.py index 21e686385f..c9f0761431 100644 --- a/extra/thneed.py +++ b/extra/thneed.py @@ -281,7 +281,7 @@ class Thneed: for i, ((prg, args), e) in enumerate(zip(self.cl_cache, events)): runtime = (e.profile.end - e.profile.start) print(f"{i:3d} time {total_runtime/1e6:5.2f} ms running {prg.name:20s} with {str(args[0]):15s} {str(args[1]):15s} count {len(args)-2:2d} runtime {runtime/1e3:7.2f} us {(prg.op_estimate)/runtime:9.2f} GFLOPS {prg.options} -> {args[2].shape if hasattr(args[2], 'shape') else args[2].size}") - if DEBUGCL >= 2 and int(os.getenv("PRINT_KERNEL")) == i: + if DEBUGCL >= 2 and int(os.getenv("PRINT_KERNEL", "-1")) == i: print(prg.prg) total_runtime += runtime print(f"total runtime: {total_runtime/1e6:.2f} ms wall time: {et*1000.0:.2f} ms") diff --git a/test/test_shapetracker.py b/test/test_shapetracker.py index c1f30371de..d7017db2c3 100644 --- a/test/test_shapetracker.py +++ b/test/test_shapetracker.py @@ -2,7 +2,7 @@ import unittest import numpy as np from tinygrad.helpers import prod -from tinygrad.shapetracker import ShapeTracker +from tinygrad.shape import ShapeTracker class DumbShapeTracker: def __init__(self, shape): diff --git a/test/test_symbolic.py b/test/test_symbolic.py index 13bb30c623..1375a1afb8 100644 --- a/test/test_symbolic.py +++ b/test/test_symbolic.py @@ -1,6 +1,6 @@ #!/usr/bin/env python import unittest -from tinygrad.symbolic import Variable +from tinygrad.shape.symbolic import Variable class TestSymbolic(unittest.TestCase): def test_mul_0(self): diff --git a/tinygrad/ast.py b/tinygrad/ast.py new file mode 100644 index 0000000000..16401668af --- /dev/null +++ b/tinygrad/ast.py @@ -0,0 +1,107 @@ +from tinygrad.helpers import dedup, all_same +from tinygrad.ops import LazyOp, MovementOps, get_lazyop_info, get_buffers, ReduceOps, get_lazyops +from tinygrad.shape import ShapeTracker + +def get_first_reduce(shapes): + for i in range(len(shapes[0])): + if not all_same([x[i] for x in shapes]): + return i + return len(shapes[0]) # off the end + +# ast kernel can contain one ReduceOp with arbitrary Binary/Unary ops +class ASTKernel: + def __init__(self, ast:LazyOp): + # key for lookup in cache (can change, str might not be right) + self.key = str(ast) + + # if the AST ends with a RESHAPE, we remove it and create the buffer accordingly + if ast.op == MovementOps.RESHAPE: + output_shape = ast.arg + ast = ast.src[0] + else: + output_shape = None + + self.info = get_lazyop_info(ast) + self.bufs = dedup(get_buffers(ast)) + reduceops = [x for x in get_lazyops(ast) if x.op in ReduceOps] + assert len(dedup(reduceops)) <= 1, "max one reduce op in an ast" + self.reduceop = reduceops[0] if reduceops else None + self.earlybufs = dedup(get_buffers(self.reduceop)) if self.reduceop else [] + self.ast = ast + + # create the buffer we are returning (as the same type as the input buffers) and add it as the first buffer + self.ret = type(self.bufs[0])(output_shape if output_shape else self.info.shape) + if hasattr(self.ret, "cl"): self.ret.cl # does the allocation of unbacked buffer, pylint: disable=W0104 + self.bufs = [type(self.ret)(self.info.shape, hostbuf=self.ret)] + self.bufs + + # check valid AST kernel + assert all_same([x.shape for x in self.earlybufs]), "all earlybufs must have the same shape" + assert all_same([x.shape for x in self.bufs if x not in self.earlybufs]), "all latebufs must have the same shape" + assert all_same([len(x.shape) for x in self.bufs]), "all bufs must have the same shape size" + + def process(self): + # get shape, strides, and offset + # if it's a multiview buffer we take the final view + self.shapes = [x.shape for x in self.bufs] + self.strides = [x.st.views[-1].strides for x in self.bufs] + self.offsets = [x.st.views[-1].offset for x in self.bufs] # include the offsets (as is) + self.last_reduce = len(self.shapes[0]) + self.simplify_ones() + self.simplify_merge_adjacent() + + def simplify_ones(self): + # remove places where the shape is all ones + # TODO: this should be factored in to multi shape stride + all_ones = [all(s[i]==1 for s in self.shapes) for i in range(len(self.shapes[0]))] + # keep at least 1 one + if all(all_ones): + all_ones[-1] = False + self.shapes = [[s[i] for i in range(len(s)) if not all_ones[i]] for s in self.shapes] + self.strides = [[s[i] for i in range(len(s)) if not all_ones[i]] for s in self.strides] + self.last_reduce -= sum(all_ones) + # find first mismatch, don't reduce this + self.first_reduce = get_first_reduce(self.shapes) + + def simplify_merge_adjacent(self): + shapes, strides = self.shapes, self.strides + + # merge dimensions if we can, multi get_shape_strides + # TODO: does this always preserve the reduce dimension, NO + # TODO: move this into shapetracker, with tests! + rets = [[(shapes[j][0], strides[j][0])] for j in range(len(shapes))] + for i in range(1, len(shapes[0])): + can_merge = [] + for j in range(len(shapes)): + # TODO: added the always mergability of 1s, is this right? if so, add to shapetracker in the 1 case + can_merge.append((strides[j][i] != 0 and rets[j][-1][1] == shapes[j][i]*strides[j][i]) or (strides[j][i] == 0 and rets[j][-1][1] == 0)) + # more can merge than this + can_merge = all(can_merge) and i != self.first_reduce + if can_merge: + self.last_reduce -= 1 + for j in range(len(shapes)): + if can_merge: + rets[j][-1] = (rets[j][-1][0] * shapes[j][i], strides[j][i]) + else: + rets[j].append((shapes[j][i], strides[j][i])) + self.shapes, self.strides = [[y[0] for y in x] for x in rets], [[y[1] for y in x] for x in rets] + self.first_reduce = get_first_reduce(self.shapes) + + @property + def shape_len(self): return len(self.shapes[0]) + + # this should be aware of the three parts to the shape + # * the input/output dimensions + # * the reduce dimensions + # * the size outputted by each kernel + def reshape_and_permute(self, new_shape_fxn, axis): + new_shapes, new_strides = [], [] + for shape, stride in zip(self.shapes, self.strides): + st = ShapeTracker(tuple(shape)) + st.strided(*zip(shape, stride)) + # TODO: handle reduced shape here + st.reshape(*new_shape_fxn(shape)) + if axis is not None: st.permute(*axis) + assert len(st.views) == 1 + new_shapes.append(st.shape) + new_strides.append(st.strides) + self.shapes, self.strides = new_shapes, new_strides \ No newline at end of file diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index bd143cc607..c79bb3de98 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -3,7 +3,7 @@ from typing import Optional, Tuple, Union, List, Dict from copy import copy import os, sys, weakref from tinygrad.helpers import ConvArgs, get_available_llops, prod -from tinygrad.shapetracker import ShapeTracker +from tinygrad.shape import ShapeTracker from tinygrad.ops import DeviceBuffer, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps, OpType, LazyOp, get_buffers, get_lazyops, DEBUG from tinygrad.graph import log_op diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py index 9098584006..c5b5a460a4 100644 --- a/tinygrad/llops/ops_gpu.py +++ b/tinygrad/llops/ops_gpu.py @@ -6,15 +6,18 @@ import pyopencl as cl # type: ignore from collections import defaultdict from typing import List, Tuple, Optional, Dict, Union, Set from tinygrad.helpers import prod, all_same -from tinygrad.ops import DEBUG, ASTKernel, UnaryOps, BinaryOps, ReduceOps, MovementOps, LazyOp, Op, ExplicitExecAST, GlobalCounters +from tinygrad.ops import DEBUG, UnaryOps, BinaryOps, ReduceOps, MovementOps, LazyOp, Op, ExplicitExecAST, GlobalCounters +from tinygrad.ast import ASTKernel from tinygrad.lazy import IMAGE -from tinygrad.shapetracker import ShapeTracker, View, ZeroView -from tinygrad.symbolic import Variable, ModNode +from tinygrad.shape import ShapeTracker, View, ZeroView +from tinygrad.shape.symbolic import Variable, ModNode -VALIDHACKS = int(os.getenv("VALIDHACKS", "0")) +VALIDHACKS = int(os.getenv("VALIDHACKS", "0")) # TODO: remove the need for this NATIVE_EXPLOG = int(os.getenv("NATIVE_EXPLOG", 0)) # this is needed as a switch for the tests to pass CLCACHE = int(os.getenv("CLCACHE", "1")) +FLOAT16 = int(os.getenv("FLOAT16", "0")) + class CLBuffer: def __init__(self, size): if len(CL.BUFFER_CACHE[size]) > 0: @@ -30,7 +33,6 @@ class CLBuffer: else: CL.mem_used -= self.cl.size -FLOAT16 = int(os.getenv("FLOAT16", "0")) class CLImage: fmt = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.HALF_FLOAT if FLOAT16 else cl.channel_type.FLOAT) @@ -109,6 +111,17 @@ class Token: def __repr__(self): return f"<{self.typ} {self.tok}>" class CLASTKernel(ASTKernel): + code_for_op : Dict[Op, str] = { + UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)", UnaryOps.SIGN: "sign(A)", + UnaryOps.EXP: "native_exp(A)" if NATIVE_EXPLOG else "exp(A)", + UnaryOps.LOG: "native_log(A)" if NATIVE_EXPLOG else "log(A)", + UnaryOps.RECIPROCAL: "native_recip(A)" if NATIVE_EXPLOG else "((float)1.0/A)", + BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)", + BinaryOps.DIV: "(A/B)", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)", + ReduceOps.SUM: "(acc + A)", ReduceOps.MAX: "max(A, acc)" + } + start_for_op = {ReduceOps.SUM: "0.0", ReduceOps.MAX: "-INFINITY"} + def __init__(self, ast:LazyOp): super().__init__(ast) @@ -197,7 +210,7 @@ class CLASTKernel(ASTKernel): return self.load(buf_index, offset=(offset*self.strides[buf_index][-1] if offset != 0 else 0) + (alt_offset*self.strides[buf_index][-2] if alt_offset != 0 else 0)) if isinstance(x.op, ReduceOps) and reduce is not None: return reduce values = [self.ast_parse(v, offset, alt_offset, reduce) for v in x.src] - code = GPUBuffer.code_for_op[x.op] # TODO: replace this with a function + code = CLASTKernel.code_for_op[x.op] # TODO: replace this with a function if isinstance(x.op, ReduceOps) and values[0].typ != Types.FLOAT and not self.early_loads_are_non_reduce_float4: self.prekernel.add("float clsum(float4 x) { return x.x + x.y + x.z + x.w; }\n") return Token(code.replace("A", f"clsum({values[0].tok})").replace("acc", f"acc.s{offset}" if self.late_are_float4 else "acc"), Types.FLOAT) @@ -216,10 +229,6 @@ class CLASTKernel(ASTKernel): buftypes = [f"{'read_only' if i > 0 else 'write_only'} image2d_t" if isinstance(x._buf, CLImage) else "__global float *" for i,x in enumerate(self.bufs)] self.prekernel = set() - # promote to float4 if these hit - any_early_images = any(isinstance(buf._buf, CLImage) for buf in self.earlybufs) - any_late_images = any(isinstance(buf._buf, CLImage) for buf in self.bufs if buf not in self.earlybufs) - # four toggles determine the kernel self.early_loads_are_non_reduce_float4 = False self.early_loads_are_float4 = False @@ -228,7 +237,7 @@ class CLASTKernel(ASTKernel): # if there's images in the earlybufs, we have to make an axis the 4 loading one # shove the axis to the end and remove - if any_early_images: + if any(isinstance(buf._buf, CLImage) for buf in self.earlybufs): eb_valids = [True] * len(self.shapes[0]) for i in range(len(self.bufs)): if isinstance(self.bufs[i]._buf, CLImage) and self.bufs[i] in self.earlybufs: @@ -250,7 +259,7 @@ class CLASTKernel(ASTKernel): self.early_loads_are_float4 = True # if there's images in the latebufs, we have to make an axis the 4 storing one. this affects the kernel shape - if any_late_images and not self.early_loads_are_non_reduce_float4: + if any(isinstance(buf._buf, CLImage) for buf in self.bufs if buf not in self.earlybufs) and not self.early_loads_are_non_reduce_float4: lb_valids = [True] * len(self.shapes[0]) for i in range(len(self.bufs)): #assert len(self.bufs[i].st.views) == 1 or not isinstance(self.bufs[i]._buf, CLImage) # images can't have views @@ -281,6 +290,8 @@ class CLASTKernel(ASTKernel): [i for i in range(self.shape_len) if i != xb_choice+1] + [xb_choice+1, self.shape_len]) # no change, we added a dimension self.four_float4 = True + + # first simplify self.simplify_ones() # use more opencl indexing @@ -289,7 +300,6 @@ class CLASTKernel(ASTKernel): if all([(base_shape[0]*base_shape[1])%x[0] == 0 for x in self.shapes]): #print("split here", base_shape, self.shapes[0]) self.reshape_and_permute(lambda x: [base_shape[0], x[0]//base_shape[0]]+list(x[1:]), None) - self.first_reduce += 1 self.last_reduce += 1 self.simplify_ones() @@ -320,11 +330,8 @@ class CLASTKernel(ASTKernel): full_shape = [x for x in self.shapes if x != self.shapes[0]] full_shape = self.shapes[0] if len(full_shape) == 0 else full_shape[0] - for accumulator in accumulators: - self.kernel.append(f"{accumulator.decltype()} {accumulator.tok} = {GPUBuffer.start_for_op[self.reduceop.op]};\n") - - for i in range(self.first_reduce, self.last_reduce): - self.kernel.append(f"for (int idx{i} = 0; idx{i} < {full_shape[i]}; idx{i}++) {{\n") + self.kernel += [f"{accumulator.decltype()} {accumulator.tok} = {CLASTKernel.start_for_op[self.reduceop.op]};\n" for accumulator in accumulators] + self.kernel += [f"for (int idx{i} = 0; idx{i} < {full_shape[i]}; idx{i}++) {{\n" for i in range(self.first_reduce, self.last_reduce)] tmp_kernel = [] for accnum, accumulator in enumerate(accumulators): @@ -332,6 +339,7 @@ class CLASTKernel(ASTKernel): tmp_kernel += [f" {accumulator.tok}.s{j} = " + self.ast_parse(self.reduceop, offset=j, alt_offset=accnum).tok.replace("acc", f"acc{accnum}") + ";\n" for j in range(4)] else: tmp_kernel.append(f" {accumulator.tok} = " + self.ast_parse(self.reduceop, alt_offset=accnum).tok.replace("acc", f"acc{accnum}") + ";\n") + self.kernel += tmp_kernel + ["}\n"] * (self.last_reduce - self.first_reduce) # late ast @@ -357,17 +365,6 @@ class CLASTKernel(ASTKernel): return runner class GPUBuffer(ExplicitExecAST): - code_for_op : Dict[Op, str] = { - UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)", UnaryOps.SIGN: "sign(A)", - UnaryOps.EXP: "native_exp(A)" if NATIVE_EXPLOG else "exp(A)", - UnaryOps.LOG: "native_log(A)" if NATIVE_EXPLOG else "log(A)", - UnaryOps.RECIPROCAL: "native_recip(A)" if NATIVE_EXPLOG else "((float)1.0/A)", - BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)", - BinaryOps.DIV: "(A/B)", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)", - ReduceOps.SUM: "(acc + A)", ReduceOps.MAX: "max(A, acc)" - } - start_for_op = {ReduceOps.SUM: "0.0", ReduceOps.MAX: "-INFINITY"} - def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], hostbuf:Optional[GPUBuffer]=None, backing:Optional[np.ndarray]=None): super().__init__(shape, hostbuf) self._buf : Optional[CLBuffer] = hostbuf._buf if hostbuf is not None else None diff --git a/tinygrad/ops.py b/tinygrad/ops.py index bd3e9f9d71..ca30cb2a89 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -2,8 +2,8 @@ import os from enum import Enum from typing import Union, Type, NamedTuple, Tuple, Any, List import functools, operator -from tinygrad.helpers import prod, dedup, all_same -from tinygrad.shapetracker import ShapeTracker +from tinygrad.helpers import prod +from tinygrad.shape import ShapeTracker DEBUG = int(os.getenv("DEBUG", "0")) @@ -87,108 +87,4 @@ class ExplicitExecAST(DeviceBuffer): # TODO: creating a new object is making a copy, breaking the thneed compiler def contiguous(self): return self if self.st.contiguous else self.unary_op(UnaryOps.NOOP) - #def contiguous(self): return type(self)(self.shape, hostbuf=self) if self.st.contiguous else self.unary_op(UnaryOps.NOOP) - -def get_first_reduce(shapes): - for i in range(len(shapes[0])): - if not all_same([x[i] for x in shapes]): - return i - return len(shapes[0]) # off the end - -# ast kernel can contain one ReduceOp with arbitrary Binary/Unary ops -class ASTKernel: - def __init__(self, ast:LazyOp): - # key for lookup in cache (can change, str might not be right) - self.key = str(ast) - - # if the AST ends with a RESHAPE, we remove it and create the buffer accordingly - if ast.op == MovementOps.RESHAPE: - output_shape = ast.arg - ast = ast.src[0] - else: - output_shape = None - - self.info = get_lazyop_info(ast) - self.bufs = dedup(get_buffers(ast)) - reduceops = [x for x in get_lazyops(ast) if x.op in ReduceOps] - assert len(dedup(reduceops)) <= 1, "max one reduce op in an ast" - self.reduceop = reduceops[0] if reduceops else None - self.earlybufs = dedup(get_buffers(self.reduceop)) if self.reduceop else [] - self.ast = ast - - # create the buffer we are returning (as the same type as the input buffers) and add it as the first buffer - self.ret = type(self.bufs[0])(output_shape if output_shape else self.info.shape) - if hasattr(self.ret, "cl"): self.ret.cl # does the allocation of unbacked buffer, pylint: disable=W0104 - self.bufs = [type(self.ret)(self.info.shape, hostbuf=self.ret)] + self.bufs - - # check valid AST kernel - assert all_same([x.shape for x in self.earlybufs]), "all earlybufs must have the same shape" - assert all_same([x.shape for x in self.bufs if x not in self.earlybufs]), "all latebufs must have the same shape" - assert all_same([len(x.shape) for x in self.bufs]), "all bufs must have the same shape size" - - def process(self): - # get shape, strides, and offset - # if it's a multiview buffer we take the final view - self.shapes = [x.shape for x in self.bufs] - self.strides = [x.st.views[-1].strides for x in self.bufs] - self.offsets = [x.st.views[-1].offset for x in self.bufs] # include the offsets (as is) - self.last_reduce = len(self.shapes[0]) - self.simplify_ones() - self.simplify_merge_adjacent() - - def simplify_ones(self): - # remove places where the shape is all ones - # TODO: this should be factored in to multi shape stride - all_ones = [all(s[i]==1 for s in self.shapes) for i in range(len(self.shapes[0]))] - # keep at least 1 one - if all(all_ones): - all_ones[-1] = False - self.shapes = [[s[i] for i in range(len(s)) if not all_ones[i]] for s in self.shapes] - self.strides = [[s[i] for i in range(len(s)) if not all_ones[i]] for s in self.strides] - self.last_reduce -= sum(all_ones) - # find first mismatch, don't reduce this - self.first_reduce = get_first_reduce(self.shapes) - - def simplify_merge_adjacent(self): - shapes, strides = self.shapes, self.strides - - # merge dimensions if we can, multi get_shape_strides - # TODO: does this always preserve the reduce dimension, NO - # TODO: move this into shapetracker, with tests! - rets = [[(shapes[j][0], strides[j][0])] for j in range(len(shapes))] - for i in range(1, len(shapes[0])): - can_merge = [] - for j in range(len(shapes)): - # TODO: added the always mergability of 1s, is this right? if so, add to shapetracker in the 1 case - can_merge.append((strides[j][i] != 0 and rets[j][-1][1] == shapes[j][i]*strides[j][i]) or (strides[j][i] == 0 and rets[j][-1][1] == 0)) - # more can merge than this - can_merge = all(can_merge) and i != self.first_reduce - if can_merge: - self.last_reduce -= 1 - for j in range(len(shapes)): - if can_merge: - rets[j][-1] = (rets[j][-1][0] * shapes[j][i], strides[j][i]) - else: - rets[j].append((shapes[j][i], strides[j][i])) - self.shapes, self.strides = [[y[0] for y in x] for x in rets], [[y[1] for y in x] for x in rets] - self.first_reduce = get_first_reduce(self.shapes) - - @property - def shape_len(self): return len(self.shapes[0]) - - # this should be aware of the three parts to the shape - # * the input/output dimensions - # * the reduce dimensions - # * the size outputted by each kernel - def reshape_and_permute(self, new_shape_fxn, axis): - new_shapes, new_strides = [], [] - for shape, stride in zip(self.shapes, self.strides): - st = ShapeTracker(tuple(shape)) - st.strided(*zip(shape, stride)) - # TODO: handle reduced shape here - st.reshape(*new_shape_fxn(shape)) - if axis is not None: st.permute(*axis) - assert len(st.views) == 1 - new_shapes.append(st.shape) - new_strides.append(st.strides) - self.shapes, self.strides = new_shapes, new_strides + #def contiguous(self): return type(self)(self.shape, hostbuf=self) if self.st.contiguous else self.unary_op(UnaryOps.NOOP) \ No newline at end of file diff --git a/tinygrad/shapetracker.py b/tinygrad/shape/__init__.py similarity index 99% rename from tinygrad/shapetracker.py rename to tinygrad/shape/__init__.py index fdec2bce53..9e266ebcb5 100644 --- a/tinygrad/shapetracker.py +++ b/tinygrad/shape/__init__.py @@ -4,7 +4,7 @@ import os import functools from typing import Tuple, Union, List from tinygrad.helpers import prod -from tinygrad.symbolic import Variable +from tinygrad.shape.symbolic import Variable # TODO: fix DEBUG import DEBUG = int(os.getenv("DEBUG", "0")) diff --git a/tinygrad/symbolic.py b/tinygrad/shape/symbolic.py similarity index 94% rename from tinygrad/symbolic.py rename to tinygrad/shape/symbolic.py index 8b54f183b5..ec01cddd93 100644 --- a/tinygrad/symbolic.py +++ b/tinygrad/shape/symbolic.py @@ -44,7 +44,7 @@ class Node: def __mod__(self, b:int): if b == 1: return NumNode(0) if isinstance(self, SumNode): - a = Variable.sum([x for x in self.nodes if not (isinstance(x, MulNode) or isinstance(x, NumNode)) or (x.b%b != 0)]) + a = Variable.sum([(x if not isinstance(x, NumNode) else Variable.num(modn(x.b, b))) for x in self.nodes if not (isinstance(x, MulNode) or isinstance(x, NumNode)) or (x.b%b != 0)]) else: a = self if a.min >= 0 and a.max < b: return a @@ -108,10 +108,8 @@ class DivNode(Node): class ModNode(Node): def __init__(self, a:Node, b:int): - if isinstance(a, SumNode): - a = Variable.sum([(x if not isinstance(x, NumNode) else Variable.num(modn(x.b, b))) for x in a.nodes if not (isinstance(x, MulNode) or isinstance(x, NumNode)) or (x.b%b != 0)]) self.a, self.b = a, b - self.min, self.max = min(a.min, 0), max(a.max, b) + self.min, self.max = min(a.min, 0), max(a.max, b-1) @property def expr(self): assert self.a != self