From 563bf2d8e8cfddb5884d2927b3262635abcbb22a Mon Sep 17 00:00:00 2001 From: George Hotz Date: Fri, 8 Jul 2022 07:40:30 -0700 Subject: [PATCH] force input/weight to be contiguous (uncached) --- tinygrad/llops/ops_gpu.py | 4 ++-- tinygrad/ops.py | 9 ++++----- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py index 41c02350fe..de3abb86f1 100644 --- a/tinygrad/llops/ops_gpu.py +++ b/tinygrad/llops/ops_gpu.py @@ -154,9 +154,9 @@ class GPUBuffer: global_size = [C.bs*C.cout, C.oy, C.ox] # now input and weight can be anywhere in bufs - convbufs_contig = tuple(x[1].st.contiguous for x in bufs if x[0] in ["input", "weight"]) - assert convbufs_contig == (True, True), "input and weight missing or not contiguous" + bufs = [(x[0], x[1].contiguous_op()) if x[0] in ["input", "weight"] else x for x in bufs] ewbufs = [x for x in bufs if x[0] not in ["input", "weight"]] + assert len(bufs) == len(ewbufs)+2, "input or weight missing" kernel_name = "conv" conv_src = """ diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 10be064668..93d1dc5bb4 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -25,8 +25,8 @@ GRAPH = int(os.getenv("GRAPH", "0")) OPT = int(os.getenv("OPT", "1")) MERGE_MOVEMENT_OPS, REMOVE_MOVEMENT_NOPS, MERGE_UNARY_OPS = OPT>=1, OPT>=1, OPT>=1 -MERGE_ELEMENTWISE_OPS = OPT>=2 -SHUFFLE_MOVEMENT_OPS, MERGE_ELEMENTWISE_INTO_CONV_OUTPUTS = OPT>=3, OPT>=3 +MERGE_ELEMENTWISE_OPS, MERGE_ONE_CONV_INTO_ELEMENTWISE = OPT>=2, OPT>=2 +SHUFFLE_MOVEMENT_OPS = OPT>=3 SHUFFLE_SLICE_OPS = OPT>=4 # NOTE: 0/0 is NaN if you slice, so this can change the output # **** enumerate supported devices **** @@ -113,7 +113,7 @@ def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer # NOTE: if it references the same conv multiple times, they should already be merged by the dictionary conv_args : Optional[ConvArgs] = None psrcs = [x for x in real_srcs.keys() if x.optype == ProcessingOps and x.realized is None and len(x.children) <= 1] - if len(psrcs) == 1 and MERGE_ELEMENTWISE_INTO_CONV_OUTPUTS: + if len(psrcs) == 1 and MERGE_ONE_CONV_INTO_ELEMENTWISE: conv_args = psrcs[0].op.arg del real_srcs[psrcs[0]] real_srcs[psrcs[0].op.src[0]], real_srcs[psrcs[0].op.src[1]] = None, None @@ -144,8 +144,7 @@ def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer return ast_eval(self.op), list(real_srcs.values()), BinaryOps def _realize_processingops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]: - real_src_x = self.op.src[0].realize(self.device) - real_src_w = self.op.src[1].realize(self.device) + real_src_x, real_src_w = [x.realize(self.device) for x in self.op.src] return real_src_x.processing_op(self.op.op, real_src_w, self.op.arg), [real_src_x, real_src_w], ProcessingOps _realize = {LoadOps:_realize_loadops, ReduceOps:_realize_reduceops, MovementOps:_realize_movementops, BinaryOps:_realize_binaryops, ProcessingOps:_realize_processingops}