From 9ff6c532eb8a97aa013f9222c5fd01323cdd8a72 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Wed, 11 Jan 2023 20:18:42 -0800
Subject: [PATCH] Prereqs for IMAGE=1 (#461)

* contig

* move ast, debug prog

* add Token

* cleanup reduce

* exec_ast
---
 extra/thneed.py           |  1 +
 tinygrad/lazy.py          | 11 ++++++----
 tinygrad/llops/ops_gpu.py | 45 ++++++++++++++++++++++++---------------
 tinygrad/ops.py           |  1 +
 4 files changed, 37 insertions(+), 21 deletions(-)

diff --git a/extra/thneed.py b/extra/thneed.py
index 4a957dfc8a..f08fe9c3e3 100644
--- a/extra/thneed.py
+++ b/extra/thneed.py
@@ -19,6 +19,7 @@ class Thneed:
     self.gobj = 0
 
     # build graph
+    # NOTE: if CLCACHE=1, this is wrong!
     nodes = defaultdict(lambda: {'in_edges': [], 'out_edges': []})
     for _, args in self.cl_cache:
       # output is always the first parameter
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 7836b90d1e..5d6f24a5e6 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -27,7 +27,7 @@ class Device:
     vars()[name] = name
 
 # **** realize helpers ****
-def realize_buffers(real_srcs, x):
+def realize_buffers(real_srcs, x) -> LazyOp:
   if x in real_srcs: return realize_buffers(real_srcs, real_srcs[x]) if isinstance(real_srcs[x], LazyOp) else real_srcs[x]
   return LazyOp(x.op, tuple(realize_buffers(real_srcs, y) for y in x.src), x.arg)
 
@@ -41,7 +41,7 @@ def _realize_loadops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer],
   elif self.op.op == LoadOps.CONTIGUOUS:
     real_src = self.op.src[0].realize(self.device)
     ret = real_src.contiguous()
-    return ret, [real_src], LoadOps if id(ret) != id(real_src) else None
+    return ret, [real_src], LoadOps
   else:
     raise NotImplementedError(f"unknown LoadOp {self.op.op}")
 
@@ -65,7 +65,8 @@ def _realize_reduceops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer
     return self.dbuffer.exec_ast(ast), list(real_srcs.values()), ReduceOps
   else:
     real_src = src.realize(self.device)
-    return real_src.reduce_op(self.op.op, self.op.arg), [real_src], ReduceOps
+    ast = LazyOp(self.op.op, (real_src,), self.op.arg)
+    return self.dbuffer.exec_ast(ast), [real_src], ReduceOps
 
 # this supports late merging an upstream Reduce op and even an Elementwise op above that
 def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]:
@@ -99,6 +100,8 @@ def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer
     if psrcs[0][0].shape != psrcs[0][1].shape:
       intermediate_shape = psrcs[0][1].shape
       assert psrcs[0][0].shape == self.shape, f"shape mismatch {psrcs[0][0].shape} != {self.shape}"
+
+  # reshape all the late ops into the output shape
   # NOTE: these RESHAPEs will return self if they don't change the shape
   for x in real_srcs.keys():
     if real_srcs[x] is None:
@@ -111,7 +114,7 @@ _realize = {LoadOps:_realize_loadops, ReduceOps:_realize_reduceops, MovementOps:
 
 # **** lazy operations ****
 def get_weakop(op:LazyOp) -> LazyOp: return LazyOp(op.op, tuple(get_weakop(x) if isinstance(x, LazyOp) else weakref.ref(x) for x in op.src), op.arg)
-def get_movementroot(root:LazyBuffer) -> LazyBuffer: return get_movementroot(root.op.src[0]) if root.optype == MovementOps and root.realized is None else root
+def get_movementroot(root:LazyBuffer) -> LazyBuffer: return get_movementroot(root.op.src[0]) if root.realized is None and (root.optype == MovementOps or (root.op.op == LoadOps.CONTIGUOUS and root.op.src[0].st.contiguous)) else root
 def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot(x) if x.optype == MovementOps and x.st.contiguous else x
 
 LAZY = int(os.getenv("LAZY", "1"))
diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py
index c997327c81..f3d22e131e 100644
--- a/tinygrad/llops/ops_gpu.py
+++ b/tinygrad/llops/ops_gpu.py
@@ -1,10 +1,11 @@
 from __future__ import annotations
 import os, functools
+from enum import Enum
 import numpy as np
 import pyopencl as cl # type: ignore
 from collections import defaultdict
 from typing import List, Tuple, Optional, Dict, Union, Set
-from tinygrad.helpers import prod
+from tinygrad.helpers import prod, all_same
 from tinygrad.ops import DEBUG, ASTKernel, UnaryOps, BinaryOps, ReduceOps, LazyOp, Op, ExplicitExecAST, GlobalCounters
 from tinygrad.shapetracker import ShapeTracker
 
@@ -55,7 +56,11 @@ class CLProgram:
     self.name = f"{name}{('_N'+str(CLProgram.kernel_cnt[name])) if CLProgram.kernel_cnt[name] else ''}" if rename else name
     self.prg, self.options, self.argdtypes = prg.replace(f"{name}(", f"{self.name}(") if rename else prg, options, argdtypes
     self.clprogram = cl.Program(CL().cl_ctx, CL().cl_ctx.devices, [self.prg]) if binary else cl.Program(CL().cl_ctx, self.prg) # type: ignore
-    self.clprg = self.clprogram.build(options=list(self.options)).__getattr__(self.name)
+    try:
+      self.clprg = self.clprogram.build(options=list(self.options)).__getattr__(self.name)
+    except cl.RuntimeError as e:
+      print("FAILED TO BUILD", self.prg)
+      raise e
     if self.argdtypes is not None: self.clprg.set_scalar_arg_dtypes(self.argdtypes)
     CLProgram.kernel_cnt[name] += 1
 
@@ -79,10 +84,17 @@ class CLProgram:
 
 # **** end CL wrappers ****
 
+Types = Enum("Types", ["FLOAT", "FLOAT4"])
+class Token:
+  def __init__(self, tok:str, typ:Types):
+    assert isinstance(tok, str)
+    self.tok = tok
+    self.typ = typ
+  def __repr__(self): return f"<{self.typ} {self.tok}>"
+
 class CLASTKernel(ASTKernel):
   def __init__(self, ast:LazyOp):
     super().__init__(ast)
-    self.ast = ast
 
   def compute_buf_index(self, st, buf_index, offset=0):
     key = f"{buf_index}_{offset}"
@@ -97,13 +109,13 @@ class CLASTKernel(ASTKernel):
     self.seen_idx.add(key)
     return key
 
-  def store(self, buf_index, value, offset=0):
+  def store(self, buf_index, value:Token, offset=0):
     st = self.bufs[buf_index].st
     if offset > 0: assert len(st.views) == 1
     key = self.compute_buf_index(st, buf_index, offset)
-    self.kernel.append(f"data{buf_index}[bufi{key}] = {value};\n")
+    self.kernel.append(f"data{buf_index}[bufi{key}] = {value.tok};\n")
 
-  def load(self, buf_index, offset=0):
+  def load(self, buf_index, offset=0) -> Token:
     if buf_index not in self.loaded_keys:
       st = self.bufs[buf_index].st
       if offset > 0: assert len(st.views) == 1
@@ -118,18 +130,18 @@ class CLASTKernel(ASTKernel):
         ldr = f"data{buf_index}[bufi{key}]" if not constant_fold else constant_fold
         ldr = f"(bufvalid{key} ? {ldr} : 0.0)" if st.needs_valid() else ldr
         self.kernel.append(f"float val{key} = {ldr};\n")
-      self.loaded_keys[buf_index] = f"val{key}"
+      self.loaded_keys[buf_index] = Token(f"val{key}", Types.FLOAT)
     return self.loaded_keys[buf_index]
 
-  def ast_parse(self, x, reduce=False) -> str:
+  def ast_parse(self, x:Union[GPUBuffer, LazyOp], reduce:Optional[Token]=None) -> Token:
     if not isinstance(x, LazyOp): return self.load(self.bufs.index(x))
-    if isinstance(x.op, ReduceOps) and not reduce: return "acc"
+    if isinstance(x.op, ReduceOps) and reduce is not None: return reduce
     values = [self.ast_parse(v, reduce) for v in x.src]
     code = GPUBuffer.code_for_op[x.op] # TODO: replace this with a function
-    if isinstance(x.op, ReduceOps): return code.replace("A", values[0])
-    if len(values) >= 1: code = code.replace("A", values[0])
-    if len(values) >= 2: code = code.replace("B", values[1])
-    return code
+    assert all_same([x.typ for x in values]), f"type mismatch in {values}"
+    if len(values) >= 1: code = code.replace("A", values[0].tok)
+    if len(values) >= 2: code = code.replace("B", values[1].tok)
+    return Token(code, values[0].typ)
 
   def codegen(self):
     # TODO: fetch from quick cache before processing
@@ -137,7 +149,7 @@ class CLASTKernel(ASTKernel):
 
     self.bufs_to_delete : Set[int] = set()
     self.seen_idx : Set[str] = set()
-    self.loaded_keys : Dict[int, str] = {}
+    self.loaded_keys : Dict[int, Token] = {}
 
     self.output_shape = self.shapes[0][:self.first_reduce]
     self.kernel : List[str] = [f"int idx{i} = get_global_id({min(3, len(self.output_shape))-1-i});\n" for i in range(min(3, len(self.output_shape)))]
@@ -155,12 +167,11 @@ class CLASTKernel(ASTKernel):
       self.kernel.append(f"float acc = {GPUBuffer.start_for_op[self.reduceop.op]};\n")
       for i in range(self.first_reduce, self.last_reduce):
         self.kernel.append(f"for (int idx{i} = 0; idx{i} < {full_shape[i]}; idx{i}++) {{\n")
-      self.kernel.append("  acc = " + self.ast_parse(self.reduceop, reduce=True) + ";\n")
+      self.kernel.append("  acc = " + self.ast_parse(self.reduceop).tok + ";\n")
       self.kernel += ["}\n"] * (self.last_reduce - self.first_reduce)
 
     # late ast
-    process_ast = self.ast_parse(self.ast)
-    self.store(0, process_ast)
+    self.store(0, self.ast_parse(self.ast, Token("acc", Types.FLOAT)))
     self.kernel.append("}")
 
     # kernel function definition
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index b571f63d11..dc88c381fa 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -104,6 +104,7 @@ class ASTKernel:
     assert len(dedup(reduceops)) <= 1, "max one reduce op in an ast"
     self.reduceop = reduceops[0] if reduceops else None
     self.earlybufs = dedup(get_buffers(self.reduceop)) if self.reduceop else []
+    self.ast = ast
 
     # create the buffer we are returning (as the same type as the input buffers) and add it as the first buffer
     self.ret = type(self.bufs[0])(self.info.shape)