From 4e17d27d093c77af1c19915d08d227e0fead2fb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Henrik=20H=C3=B8iland?= Date: Sun, 16 Apr 2023 22:52:10 +0200 Subject: [PATCH] Fix cuda errors when running llama example (#749) --- tinygrad/codegen/cstyle.py | 2 ++ tinygrad/runtime/ops_cuda.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tinygrad/codegen/cstyle.py b/tinygrad/codegen/cstyle.py index f20afdbfb3..7f86192d36 100644 --- a/tinygrad/codegen/cstyle.py +++ b/tinygrad/codegen/cstyle.py @@ -116,6 +116,8 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan assert newvar is not None if args == -math.inf: kk(f"{newvar.render(True)} = -INFINITY;") + elif newvar.ltype == LocalTypes.float4: + kk(f"{newvar.render(True)} = {{ {args}f, {args}f, {args}f, {args}f }};") else: kk(f"{newvar.render(True)} = {args}f;") elif uop == UOps.ALU: diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py index b14774ce52..8d09a6c0b4 100644 --- a/tinygrad/runtime/ops_cuda.py +++ b/tinygrad/runtime/ops_cuda.py @@ -32,7 +32,7 @@ class CUDAProgram: if wait: start, end = cuda.Event(), cuda.Event() start.record() - self.prg(*[x._cl for x in args], block=tuple(local_size), grid=tuple(global_size)) + self.prg(*[x._buf for x in args], block=tuple(local_size), grid=tuple(global_size)) if wait: end.record() end.synchronize()