From 3a9ca62b9eead39adff16c00eb35e0da555e74d1 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Wed, 18 Dec 2024 22:23:45 -0800 Subject: [PATCH] get_single_element [pr] (#8328) --- tinygrad/helpers.py | 3 +++ tinygrad/renderer/ptx.py | 5 ++--- tinygrad/runtime/ops_python.py | 6 ++---- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 03f7c653c7..519669ff77 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -56,6 +56,9 @@ def partition(itr:Iterable[T], fxn:Callable[[T],bool]) -> Tuple[List[T], List[T] def unwrap(x:Optional[T]) -> T: assert x is not None return x +def get_single_element(x:List[T]) -> T: + assert len(x) == 1, f"list {x} must only have 1 element" + return x[0] def get_child(obj, key): for k in key.split('.'): if k.isnumeric(): obj = obj[int(k)] diff --git a/tinygrad/renderer/ptx.py b/tinygrad/renderer/ptx.py index 5be5c1c1f3..b8c77374be 100644 --- a/tinygrad/renderer/ptx.py +++ b/tinygrad/renderer/ptx.py @@ -5,7 +5,7 @@ from tinygrad.ops import Ops, UOp, PatternMatcher, UPat, GroupOp from tinygrad.dtype import dtypes, DType, PtrDType from tinygrad.renderer import Renderer from tinygrad.renderer.cstyle import CUDARenderer -from tinygrad.helpers import prod, flatten +from tinygrad.helpers import prod, flatten, get_single_element def render_val(x, dtype): if dtypes.is_float(dtype): @@ -172,8 +172,7 @@ class PTXRenderer(Renderer): r[u] = [cast(str,r[x]) for x in u.src] continue if u.op is Ops.GEP: - assert len(u.arg) == 1 - r[u] = r[u.src[0]][u.arg[0]] + r[u] = r[u.src[0]][get_single_element(u.arg)] continue if u.op in {Ops.CAST, Ops.BITCAST} and (u.src[0].dtype == u.dtype or isinstance(u.src[0].dtype, PtrDType)): r[u] = r[u.src[0]] diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py index 3076cd747f..2ca46c7f1d 100644 --- a/tinygrad/runtime/ops_python.py +++ b/tinygrad/runtime/ops_python.py @@ -6,7 +6,7 @@ import sys from typing import Tuple, List, Optional, Any, Dict, TYPE_CHECKING import pickle, base64, itertools, time, struct from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate -from tinygrad.helpers import all_same, getenv, flatten +from tinygrad.helpers import all_same, getenv, flatten, get_single_element from tinygrad.device import Compiled, Compiler, Allocator from tinygrad.ops import exec_alu, Ops, UOp, GroupOp from tinygrad.renderer import Renderer @@ -114,9 +114,7 @@ class PythonProgram: elif uop is Ops.ASSIGN: for j in range(len(inp[0])): inp[0][j] = inp[1][j] ul[i] = inp[0] - elif uop is Ops.GEP: - assert len(arg) == 1 - ul[i] = inp[0][arg[0]] + elif uop is Ops.GEP: ul[i] = inp[0][get_single_element(arg)] elif uop is Ops.WMMA: # here are the models for the WMMA instruction on the different hardware def wmma_helper(WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_elem, b_elem, c_map):