mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 21:38:10 -05:00
less lines, and oddly faster
This commit is contained in:
@@ -56,39 +56,31 @@ def clbuild(name, prg, options=tuple(), argdtypes=None):
|
||||
#print(prg)
|
||||
return CLProgram(name, prg, options, argdtypes)
|
||||
|
||||
code_for_op = {
|
||||
UnaryOps.RELU: 'max(A, (float)0.)', UnaryOps.EXP: 'exp(A)', UnaryOps.LOG: 'log(A)', UnaryOps.NEG: '-A', UnaryOps.SIGN: 'sign(A)',
|
||||
BinaryOps.ADD: "A+B", BinaryOps.SUB: "A-B", BinaryOps.MUL: "A*B", BinaryOps.DIV: "B/A", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)"
|
||||
}
|
||||
|
||||
def unary_op(op, x):
|
||||
ret = GPUBuffer(x.shape)
|
||||
if op == UnaryOps.RELU: code = 'max(a, (float)0.)'
|
||||
elif op == UnaryOps.EXP: code = 'exp(a)'
|
||||
elif op == UnaryOps.LOG: code = 'log(a)'
|
||||
elif op == UnaryOps.NEG: code = '-a'
|
||||
elif op == UnaryOps.SIGN: code = 'sign(a)'
|
||||
else: raise Exception(f"{op} isn't supported")
|
||||
unop = clbuild("unop", """
|
||||
__kernel void unop(__global const float4 *a_g, __global float4 *res_g) {
|
||||
int gid = get_global_id(0);
|
||||
float4 a = a_g[gid];
|
||||
res_g[gid] = """+code+""";
|
||||
float4 A = a_g[gid];
|
||||
res_g[gid] = convert_float4("""+code_for_op[op]+""");
|
||||
}""")
|
||||
unop([roundup(prod(ret.shape))//4], None, x.cl, ret.cl)
|
||||
return ret
|
||||
|
||||
def binary_op(op, x, y):
|
||||
ret = GPUBuffer(x.shape)
|
||||
if op == BinaryOps.ADD: code = "a+b"
|
||||
elif op == BinaryOps.SUB: code = "a-b"
|
||||
elif op == BinaryOps.MUL: code = "a*b"
|
||||
elif op == BinaryOps.DIV: code = "b/a"
|
||||
elif op == BinaryOps.POW: code = "pow(a,b)"
|
||||
elif op == BinaryOps.CMPEQ: code = "(float4)(1.0f*(a.x==b.x), 1.0f*(a.y==b.y), 1.0f*(a.z==b.z), 1.0f*(a.w==b.w))"
|
||||
else: raise Exception(f"{op} isn't supported")
|
||||
assert x.shape == ret.shape and y.shape == ret.shape
|
||||
binop = clbuild("binop", """
|
||||
__kernel void binop(__global const float4 *a_g, __global const float4 *b_g, __global float4 *res_g) {
|
||||
int gid = get_global_id(0);
|
||||
float4 a = a_g[gid];
|
||||
float4 b = b_g[gid];
|
||||
res_g[gid] = """+code+""";
|
||||
float4 A = a_g[gid];
|
||||
float4 B = b_g[gid];
|
||||
res_g[gid] = convert_float4("""+code_for_op[op]+""");
|
||||
}""")
|
||||
binop([roundup(prod(ret.shape))//4], None, x.cl, y.cl, ret.cl)
|
||||
return ret
|
||||
|
||||
@@ -42,7 +42,8 @@ class Device:
|
||||
DEFAULT = i if os.environ.get(name, 0) == "1" else DEFAULT
|
||||
try:
|
||||
llops[i] = importlib.import_module('tinygrad.llops.'+op)
|
||||
buffers[i] = [cls for name, cls in inspect.getmembers(llops[i], inspect.isclass) if name.endswith("Buffer")][0]
|
||||
def find_buffer(llo, name): return [cls for cname, cls in inspect.getmembers(llo, inspect.isclass) if (cname.upper() == name + "BUFFER")][0]
|
||||
buffers[i] = find_buffer(llops[i], name)
|
||||
except ImportError as e:
|
||||
print(op, "not available", e)
|
||||
DEFAULT = CPU if DEFAULT is None else DEFAULT
|
||||
|
||||
Reference in New Issue
Block a user