diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 83f77a0dfe..83b514d645 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -152,7 +152,7 @@ jobs:
       name: Test GPU IMAGE ops
       run: |
         GPU=1 IMAGE=1 python -m pytest -n=auto test/test_ops.py
-        FORWARD_ONLY=1 GPU=1 IMAGE=2 python -m pytest -n=auto test/test_ops.py
+        GPU=1 IMAGE=2 python -m pytest -n=auto test/test_ops.py
     - if: ${{ matrix.task == 'openpilot' }}
       name: Test openpilot model compile and size
       run: |
@@ -299,7 +299,7 @@ jobs:
       run: python -m pytest -n=auto test/ -k 'not (half or test_efficientnet_safetensors) and not (test_conv2d and test_tensor.py)' -m 'not exclude_cuda' --ignore=test/external --ignore=test/models
     - name: Run pytest (triton)
       if: matrix.backend=='triton'
-      run: python -m pytest -v -n=auto test/ -k 'not (half or test_efficientnet_safetensors) and not (test_conv2d and test_tensor.py)' -m 'not exclude_cuda' --ignore=test/external --ignore=test/models
+      run: python -m pytest -v -n=auto test/ -k 'not (half or test_efficientnet_safetensors) and not (test_conv2d and test_tensor.py)' -m 'not exclude_cuda' --ignore=test/external --ignore=test/models

   testunicorn:
     name: ARM64 unicorn Test
diff --git a/openpilot/compile2.py b/openpilot/compile2.py
index ba31a2ea7e..de8fc34cc1 100644
--- a/openpilot/compile2.py
+++ b/openpilot/compile2.py
@@ -13,9 +13,9 @@ import io
 from typing import Tuple, List
 from extra.utils import fetch
 from extra.onnx import get_run_onnx
-from tinygrad.graph import print_tree
+from tinygrad.graph import print_tree, log_schedule_item
 from tinygrad.tensor import Tensor
-from tinygrad.helpers import dtypes, partition, GlobalCounters, Context, DEBUG, getenv, ImageDType
+from tinygrad.helpers import dtypes, partition, GlobalCounters, Context, DEBUG, getenv, ImageDType, GRAPH
 from tinygrad.realize import run_schedule
 from tinygrad.ops import LoadOps, Device, ScheduleItem
 from tinygrad.features.image import fix_schedule_for_images
@@ -64,15 +64,18 @@ def lb_to_numbers(schedule):

 if __name__ == "__main__":
   schedule, schedule_independent = get_schedule(sys.argv[1] if len(sys.argv) > 1 else OPENPILOT_MODEL)
-  run_schedule(schedule_independent, disable_logging=True)
-  schedule = fix_schedule_for_images(schedule)
-
-  image_count = 0
-  for si in schedule:
-    if isinstance(si.out.dtype, ImageDType):
-      image_count += 1
-
+  schedule, schedule_input = partition(schedule, lambda x: x.ast.op not in LoadOps)
+  print(f"{len(schedule_input)} inputs")
+  #schedule = fix_schedule_for_images(schedule)
+  image_count = sum(isinstance(si.out.dtype, ImageDType) for si in schedule)
   print(f"**** running real kernels {image_count}/{len(schedule)} images ****")
+
+  #if GRAPH:
+  #  for si in schedule_input: log_schedule_item(si)
+  #  for si in schedule: log_schedule_item(si)
+
+  run_schedule(schedule_independent, disable_logging=True)
+  run_schedule(schedule_input)
   with Context(DEBUG=2, BEAM=getenv("LATEBEAM")):
     GlobalCounters.reset()
     run_schedule(schedule)
diff --git a/test/test_schedule.py b/test/test_schedule.py
index d5dc2b217c..3ddc17ee2c 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -175,7 +175,7 @@ class TestSchedule(unittest.TestCase):
     c1 = nn.Conv2d(3,16,3)

     # run
-    img = Tensor.ones(2,3,64,64)
+    img = Tensor.rand(2,3,64,64)
     out = c1(img).elu()
     check_schedule(out, 1, [c1.weight, c1.bias])

diff --git a/tinygrad/features/image.py b/tinygrad/features/image.py
index c007375de9..7f75a124f0 100644
--- a/tinygrad/features/image.py
+++ b/tinygrad/features/image.py
@@ -1,5 +1,5 @@
 from typing import List, Tuple, Dict, Any
-from tinygrad.helpers import ImageDType, prod, IMAGE, getenv, dtypes, DEBUG
+from tinygrad.helpers import ImageDType, prod, IMAGE, getenv, dtypes, DEBUG, flatten

 # *** image Tensor function replacements ***

@@ -51,14 +51,14 @@ def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, paddin
   w = w.slice(tuple((0, rcout) if i == 1 else (0, s) for i,s in enumerate(w.shape)))

   # packed (note: flipping bs and iy would make the auto-padding work)
-  x = x.permute(0,2,3,1).reshape(bs * iy, ix * groups * cin//4, 4)
+  x = x.permute(0,2,3,1)
   cin_last = iy == 1 and ix == 1
-  if cin == 1: w = w.reshape(cout//4,4,H*W).permute(0,2,1)
-  elif cin_last: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,1,3).reshape(cout//4, H*cin//4*W*4, 4)
-  else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1).reshape(cout//4, H*cin//4*W*4, 4)
+  if cin == 1: w = w.reshape(cout//4,4,H,W).permute(0,2,3,1)
+  elif cin_last: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,1,3)
+  else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1)

   # contiguous creates the image, and early realize static weights (TODO: test for the static weight)
-  if IMAGE >= 2: x,w = x.cast(base_image_type(x.shape)), w.cast(base_image_type(w.shape))
+  if IMAGE >= 2: x,w = x.cast(base_image_type((bs*iy, ix*groups*cin//4, 4))), w.cast(base_image_type((cout//4, H*W*cin, 4)))
   x, w = x.contiguous(), w.contiguous()
   if getenv("PREREALIZE", 1) and get_single_root(w.lazydata).realized: w.realize()

@@ -84,12 +84,9 @@ def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, paddin
   w = w.reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W)).expand(x.shape)

   # the conv! (+ the bias)
-  ret = (x*w).cast(dtypes.float32).sum((-4, -3, -2, -1))
-
-  # reshape to image and cast back to image
-  ret = ret.reshape(bs*oy, ox*cout//4, 4)
-  if IMAGE >= 2: ret = ret.cast(base_image_type(ret.shape))
-  if IMAGE >= 3: ret = ret.contiguous()
+  ret = x*w
+  if IMAGE >= 2: ret = ret.cast(base_image_type((bs*oy, ox*cout//4, 4)))
+  ret = ret.sum((-4, -3, -2, -1))

   # undo hack for non multiples of 4 on C.rcout
   if added_output_channels != 0:
@@ -108,26 +105,40 @@ from tinygrad.ops import ScheduleItem, BufferOps, LazyOp, UnaryOps, LoadOps, Mem

 def fix_schedule_for_images(schedule:List[ScheduleItem]):
   # this is the fundamental fix, find unwritable or unreadable images and convert them to normal float32 (TODO: should it be float16?)
+  replace_inputs = {}
   for i, si in enumerate(schedule):
     if isinstance(si.out.dtype, ImageDType) and (prod(si.out.shape) != prod(si.out.dtype.shape) or not any(si.out.shape[x]%4 == 0 for x in si.out.st.unit_stride_axes())):
       if DEBUG >= 1: print(f"{i:3d}: rewrite output, output shape {prod(si.out.shape)}, image dtype {si.out.dtype} prod {prod(si.out.dtype.shape)}")
       si.out.dtype = dtypes.float32
     for b in si.ast.get_lazyops():
       if b.op != BufferOps.MEM: continue
+      # TODO: unit_stride axes will fail if there's a mask, even if the mask is divisible by four. this is too aggressive
       if isinstance(si.inputs[b.arg.idx-1].dtype, ImageDType) and (b.arg.st.real_offset() % 4 != 0 or not any(b.arg.st.shape[x]%4 == 0 for x in b.arg.st.unit_stride_axes())):
-        if DEBUG >= 1: print(f"{i:3d}: rewrite input, image dtype {si.inputs[b.arg.idx-1].dtype}")
-        si.inputs[b.arg.idx-1].dtype = dtypes.float32
+        if DEBUG >= 1: print(f"{i:3d}: rewrite input, image dtype {si.inputs[b.arg.idx-1].dtype}, {b.arg.st.views}")
+        if si.inputs[b.arg.idx-1].realized:
+          # have to copy it
+          replace_inputs[si.inputs[b.arg.idx-1]] = si.inputs[b.arg.idx-1].cast(dtypes.float32)
+        else:
+          # change it before it's created
+          si.inputs[b.arg.idx-1].dtype = dtypes.float32

   # now fix up the schedule to reflect the new dtypes
   fixed_schedule:List[ScheduleItem] = []
   for i,si in enumerate(schedule):
     ast = si.ast
+    inputs = si.inputs
+
+    # replace inputs with casted versions
+    if any(x in replace_inputs for x in inputs):
+      fixed_schedule += flatten([replace_inputs[x].schedule() for x in inputs if x in replace_inputs])
+      inputs = tuple(replace_inputs.get(x, x) for x in inputs)
+
     # fix input dtypes to match what they actually are
     replacements = {}
     for b in si.ast.get_lazyops():
       if b.op != BufferOps.MEM: continue
-      if b.arg.dtype != si.inputs[b.arg.idx-1].dtype:
-        replacements[b] = LazyOp(BufferOps.MEM, (), MemBuffer(b.arg.idx, si.inputs[b.arg.idx-1].dtype, b.arg.st))
+      if b.arg.dtype != inputs[b.arg.idx-1].dtype:
+        replacements[b] = LazyOp(BufferOps.MEM, (), MemBuffer(b.arg.idx, inputs[b.arg.idx-1].dtype, b.arg.st))
     if replacements: ast = ast.map_buffers(replacements)

     # fix the ops to create the output dtype
@@ -138,7 +149,7 @@ def fix_schedule_for_images(schedule:List[ScheduleItem]):
       ast = LazyOp(UnaryOps.CAST, (ast,), (si.out.dtype, False))

     # put this in the fixed schedule
-    fixed_schedule.append(dataclasses.replace(si, ast=ast))
+    fixed_schedule.append(dataclasses.replace(si, ast=ast, inputs=inputs))
   return fixed_schedule

 # *** images have weird indexing requirements ***
@@ -158,6 +169,7 @@ def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node) -> Tuple[Tup
     assert isinstance(node, LtNode)
     node_flat, node_vars = node.a.flat_components if isinstance(node.a, SumNode) else [node.a], node.vars()
     same_sym = [i for (i, var) in idxy_flat_var if var in node_vars]
+    if len(same_sym) == 0: continue
     first, second = sorted(same_sym)[0], sorted(node_flat)[0]
     f_b = 1 if isinstance(first, Variable) else first.b
     s_b = 1 if isinstance(second, Variable) else second.b
@@ -188,5 +200,5 @@ def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node) -> Tuple[Tup
       if idy_vars == node_vars or idy_vars & node_vars == set(): ones.append(node)
     valid = Variable.ands([i for i in nodes if i not in ones])

-  if DEBUG>=5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy)
+  if DEBUG>=5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy, valid)
   return (idx, idy), valid
diff --git a/tinygrad/graph.py b/tinygrad/graph.py
index 22ded69ed9..195af841f4 100644
--- a/tinygrad/graph.py
+++ b/tinygrad/graph.py
@@ -6,7 +6,7 @@ except ImportError:
 from collections import defaultdict
 from typing import Dict, List
 from tinygrad.ops import ScheduleItem, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, BufferOps, TernaryOps, Op, OpType, LazyOp
-from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, GlobalCounters, getenv
+from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, GlobalCounters, getenv, dedup

 # **** debugging and graphing ****

@@ -47,6 +47,7 @@ def str_dtype(dtyp):
 logops = open(getenv("LOGOPS", ""),"a") if getenv("LOGOPS", "") else None

 def log_schedule_item(si: ScheduleItem):
+  global node_count
   if logops and si.ast.op not in LoadOps: logops.write(str(si.ast)+"\n")
   show_graph = bool(GRAPH)
   if not DEBUG and not show_graph: return
@@ -60,12 +61,27 @@ def log_schedule_item(si: ScheduleItem):
   if show_graph:
     assert si.out.base == si.out, "all outputs based"
     top_colors = {LoadOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#8080ff", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", TernaryOps: "#c0c0c0", BufferOps: '#FF8080'}
-    for x in si.inputs:
-      assert x.base == x, "all inputs based"
-      #assert nm(x) in G.nodes, "all inputs seen"
-      G.add_edge(nm(x), nm(si.out), label=get_sop(op), color='#00000060')
+
+    # get inputs for shapetrackers
+    input_to_st = defaultdict(list)
+    for lo in si.ast.get_lazyops():
+      if lo.op != BufferOps.MEM: continue
+      input_to_st[si.inputs[lo.arg.idx-1]].append(lo.arg.st)
+
+    # add them to the graph, potentially with a movement op separating them
+    for x in input_to_st:
+      for st in dedup(input_to_st[x]):
+        if st.contiguous:
+          G.add_edge(nm(x), nm(si.out), label=get_sop(op), color='#00000060')
+        else:
+          inter_node = node_count
+          node_count += 1
+          G.add_node(inter_node, style='filled', fillcolor="#80ff8080", color="black", label=f"{st.shape}\n{st.real_strides()}" + (f"\n{st.real_offset()}" if st.real_offset() != 0 else ""))
+          G.add_edge(nm(x), inter_node, color='#00000060')
+          G.add_edge(inter_node, nm(si.out), label=get_sop(op), color='#00000060')
       if 'label' not in G.nodes[nm(x)]: G.nodes[nm(x)]['label'] = str(x.shape)+str_dtype(si.out.dtype)
+    if nm(si.out) not in G.nodes: G.add_node(nm(si.out))
     G.nodes[nm(si.out)]['label'] = (str(set(x.shape for x in si.inputs))+"\n"+str(si.out.shape) if optype == ReduceOps else str(si.out.shape))+str_dtype(si.out.dtype)+(f"\n{si.ast.op}" if si.ast.op in LoadOps else "")
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 5b6e4162f8..bd657b06a6 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -201,6 +201,9 @@ class LazyBuffer:
   def fromCPU(x: np.ndarray) -> LazyBuffer:
     return LazyBuffer("CPU", ShapeTracker.from_shape(x.shape), LoadOps, None, dtypes.from_np(x.dtype), RawNumpyBuffer.fromCPU(x))

+  def cast(self, dtype:DType, bitcast:bool=False):
+    return self.e(UnaryOps.CAST, arg=(dtype, bitcast))
+
   # *** elementwise ops ***

   def e(self:LazyBuffer, op:Union[UnaryOps, BinaryOps, TernaryOps], *srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
@@ -279,7 +282,7 @@ class LazyBuffer:
   def permute(self: LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
     if arg == tuple(range(len(self.shape))): return self
     if not self.realized and self.op.op == MovementOps.PERMUTE: return self.op.src[0].permute(tuple([self.op.arg[i] for i in arg]))
-    if not self.realized:
+    if SHUFFLE_MOVEMENT_OPS and not self.realized:
       if PUSH_PERMUTES and self.optype == ReduceOps:
         # reduceops have one buffer input, permute it
         narg = tuple([self.op.arg[a] for a in arg])
diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py
index 6c24d2466e..ee0766fbe9 100644
--- a/tinygrad/mlops.py
+++ b/tinygrad/mlops.py
@@ -17,10 +17,10 @@ class ContiguousBackward(Function):

 class Cast(Function):
   def forward(self, x:LazyBuffer, dtype:DType, bitcast:bool=False) -> LazyBuffer:
     self.input_dtype, self.bitcast = x.dtype, bitcast
-    return x.e(UnaryOps.CAST, arg=(dtype, bitcast))
+    return x.cast(dtype, bitcast)

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.e(UnaryOps.CAST, arg=(self.input_dtype, self.bitcast))
+    return grad_output.cast(self.input_dtype, self.bitcast)

 # ************* unary ops *************
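
Note on the schedule split in openpilot/compile2.py above: partition from tinygrad.helpers separates the LoadOps items (the model inputs) from the real compute kernels, so the inputs can be realized once via run_schedule(schedule_input) before the timed run_schedule(schedule). Below is a minimal standalone sketch of the assumed partition behavior (items matching the predicate go in the first list, the rest in the second); the LOAD_OPS set and the fake schedule values are hypothetical stand-ins, not tinygrad objects.

# standalone sketch, assuming partition(items, pred) -> (matching, non-matching)
from typing import Callable, Iterable, List, Tuple, TypeVar

T = TypeVar("T")

def partition(items: Iterable[T], pred: Callable[[T], bool]) -> Tuple[List[T], List[T]]:
  # first list collects items where pred is True, second list collects the rest
  matched: List[T] = []
  rest: List[T] = []
  for x in items:
    (matched if pred(x) else rest).append(x)
  return matched, rest

if __name__ == "__main__":
  LOAD_OPS = {"EMPTY", "FROM", "CUSTOM"}            # hypothetical stand-ins for LoadOps members
  fake_schedule = ["FROM", "MUL", "FROM", "SUM"]    # hypothetical stand-ins for ScheduleItem.ast.op
  schedule, schedule_input = partition(fake_schedule, lambda op: op not in LOAD_OPS)
  print(f"{len(schedule_input)} inputs")            # -> 2 inputs
  print(schedule)                                   # -> ['MUL', 'SUM']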