get_single_element [pr] (#8328)

This commit is contained in:
George Hotz
2024-12-18 22:23:45 -08:00
committed by GitHub
parent 423d823c50
commit 3a9ca62b9e
3 changed files with 7 additions and 7 deletions

View File

@@ -56,6 +56,9 @@ def partition(itr:Iterable[T], fxn:Callable[[T],bool]) -> Tuple[List[T], List[T]
def unwrap(x:Optional[T]) -> T:
assert x is not None
return x
def get_single_element(x:List[T]) -> T:
assert len(x) == 1, f"list {x} must only have 1 element"
return x[0]
def get_child(obj, key):
for k in key.split('.'):
if k.isnumeric(): obj = obj[int(k)]

View File

@@ -5,7 +5,7 @@ from tinygrad.ops import Ops, UOp, PatternMatcher, UPat, GroupOp
from tinygrad.dtype import dtypes, DType, PtrDType
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer
from tinygrad.helpers import prod, flatten
from tinygrad.helpers import prod, flatten, get_single_element
def render_val(x, dtype):
if dtypes.is_float(dtype):
@@ -172,8 +172,7 @@ class PTXRenderer(Renderer):
r[u] = [cast(str,r[x]) for x in u.src]
continue
if u.op is Ops.GEP:
assert len(u.arg) == 1
r[u] = r[u.src[0]][u.arg[0]]
r[u] = r[u.src[0]][get_single_element(u.arg)]
continue
if u.op in {Ops.CAST, Ops.BITCAST} and (u.src[0].dtype == u.dtype or isinstance(u.src[0].dtype, PtrDType)):
r[u] = r[u.src[0]]

View File

@@ -6,7 +6,7 @@ import sys
from typing import Tuple, List, Optional, Any, Dict, TYPE_CHECKING
import pickle, base64, itertools, time, struct
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate
from tinygrad.helpers import all_same, getenv, flatten
from tinygrad.helpers import all_same, getenv, flatten, get_single_element
from tinygrad.device import Compiled, Compiler, Allocator
from tinygrad.ops import exec_alu, Ops, UOp, GroupOp
from tinygrad.renderer import Renderer
@@ -114,9 +114,7 @@ class PythonProgram:
elif uop is Ops.ASSIGN:
for j in range(len(inp[0])): inp[0][j] = inp[1][j]
ul[i] = inp[0]
elif uop is Ops.GEP:
assert len(arg) == 1
ul[i] = inp[0][arg[0]]
elif uop is Ops.GEP: ul[i] = inp[0][get_single_element(arg)]
elif uop is Ops.WMMA:
# here are the models for the WMMA instruction on the different hardware
def wmma_helper(WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_elem, b_elem, c_map):