mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
get_single_element [pr] (#8328)
This commit is contained in:
@@ -56,6 +56,9 @@ def partition(itr:Iterable[T], fxn:Callable[[T],bool]) -> Tuple[List[T], List[T]
|
||||
def unwrap(x:Optional[T]) -> T:
|
||||
assert x is not None
|
||||
return x
|
||||
def get_single_element(x:List[T]) -> T:
|
||||
assert len(x) == 1, f"list {x} must only have 1 element"
|
||||
return x[0]
|
||||
def get_child(obj, key):
|
||||
for k in key.split('.'):
|
||||
if k.isnumeric(): obj = obj[int(k)]
|
||||
|
||||
@@ -5,7 +5,7 @@ from tinygrad.ops import Ops, UOp, PatternMatcher, UPat, GroupOp
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.renderer.cstyle import CUDARenderer
|
||||
from tinygrad.helpers import prod, flatten
|
||||
from tinygrad.helpers import prod, flatten, get_single_element
|
||||
|
||||
def render_val(x, dtype):
|
||||
if dtypes.is_float(dtype):
|
||||
@@ -172,8 +172,7 @@ class PTXRenderer(Renderer):
|
||||
r[u] = [cast(str,r[x]) for x in u.src]
|
||||
continue
|
||||
if u.op is Ops.GEP:
|
||||
assert len(u.arg) == 1
|
||||
r[u] = r[u.src[0]][u.arg[0]]
|
||||
r[u] = r[u.src[0]][get_single_element(u.arg)]
|
||||
continue
|
||||
if u.op in {Ops.CAST, Ops.BITCAST} and (u.src[0].dtype == u.dtype or isinstance(u.src[0].dtype, PtrDType)):
|
||||
r[u] = r[u.src[0]]
|
||||
|
||||
@@ -6,7 +6,7 @@ import sys
|
||||
from typing import Tuple, List, Optional, Any, Dict, TYPE_CHECKING
|
||||
import pickle, base64, itertools, time, struct
|
||||
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate
|
||||
from tinygrad.helpers import all_same, getenv, flatten
|
||||
from tinygrad.helpers import all_same, getenv, flatten, get_single_element
|
||||
from tinygrad.device import Compiled, Compiler, Allocator
|
||||
from tinygrad.ops import exec_alu, Ops, UOp, GroupOp
|
||||
from tinygrad.renderer import Renderer
|
||||
@@ -114,9 +114,7 @@ class PythonProgram:
|
||||
elif uop is Ops.ASSIGN:
|
||||
for j in range(len(inp[0])): inp[0][j] = inp[1][j]
|
||||
ul[i] = inp[0]
|
||||
elif uop is Ops.GEP:
|
||||
assert len(arg) == 1
|
||||
ul[i] = inp[0][arg[0]]
|
||||
elif uop is Ops.GEP: ul[i] = inp[0][get_single_element(arg)]
|
||||
elif uop is Ops.WMMA:
|
||||
# here are the models for the WMMA instruction on the different hardware
|
||||
def wmma_helper(WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_elem, b_elem, c_map):
|
||||
|
||||
Reference in New Issue
Block a user