mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-15 01:48:23 -05:00)

Commit: ane struct
@@ -20,6 +20,54 @@ libane.ANE_TensorData.restype = POINTER(c_uint16)
libane.ANE_Run.argtypes = [c_void_p]*3
libane.ANE_Run.restype = c_int

ANE_Struct = [
# section @ 0x2C len 0xF4

# section @ 0x128 len 0x3C (conv)
  ("u16", 0x128, "InputWidth"),
  ("u16", 0x12A, "InputHeight"),
  ("u16", 0x12C, "InputDepth"),

  ("u32", 0x130, "InputOutputType"), # (OutputType * 0x10) | InputType

  ("u32", 0x134, "InputChannels"),
  ("u32", 0x138, "OutputChannels"),

  ("u16", 0x13C, "OutputWidth"),
  ("u16", 0x13E, "OutputHeight"),
  ("u16", 0x140, "OutputDepth"),

  ("u16", 0x144, "KernelSize"), # 0xa000 | (KernelHeight * 0x20) | KernelWidth
  ("u16", 0x146, "Padding"), # 0x5000 | (PadTop * 0x40) | (PadLeft * 2)

  ("u16", 0x14C, "BatchSize"),

# section @ 0x16C len 0x6C (input)
# reloc 0x16c-0x174 = image
  ("u32", 0x178, "InputRowStride"),
  ("u32", 0x17C, "InputPlaneStride"),
  ("u32", 0x180, "InputDepthStride"),
  ("u32", 0x184, "InputBatchStride"),

  ("u8", 0x1A7, "InputInterleave"),

# section @ 0x1E0 len 0x44

# section @ 0x22c len 0xC (scaling)
  ("u16", 0x230, "BiasScalar"),
  ("u16", 0x232, "ScaleScalar"),

# section @ 0x240 len 0x10

# section @ 0x258 len 0x18
  ("u32", 0x260, "OutputRowStride"),
  ("u32", 0x264, "OutputPlaneStride"),
  ("u32", 0x268, "OutputDepthStride"),
  ("u32", 0x26C, "OutputBatchStride"),

  ("u8", 0x273, "OutputInterleave")]


class ANETensor:
  def __init__(self, *shape):
    self.shape = shape
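The ANE_Struct table added above maps each named field to a (type, byte offset) slot in the ANE command descriptor, and the inline comments document how several fields pack multiple values into a single word (e.g. KernelSize = 0xa000 | (KernelHeight * 0x20) | KernelWidth). As a rough illustration of how such a table can be consumed, here is a minimal sketch; the fill_descriptor helper, the little-endian assumption, and the example values are illustrative only and not part of this commit:

import struct

ANE_TYPES = {"u8": "B", "u16": "H", "u32": "I"}

def fill_descriptor(desc, ane_struct, **fields):
  # look up each named field in the (dtype, offset, name) table and
  # write its value into the descriptor buffer at that byte offset
  lookup = {name: (dtype, offset) for dtype, offset, name in ane_struct}
  for name, value in fields.items():
    dtype, offset = lookup[name]
    struct.pack_into("<" + ANE_TYPES[dtype], desc, offset, value)

desc = bytearray(0x300)   # large enough to cover offsets up to 0x273
fill_descriptor(desc, ANE_Struct,
                InputWidth=4, InputHeight=1, InputDepth=1,
                KernelSize=0xa000 | (1 * 0x20) | 1,     # 1x1 kernel, per the comment above
                Padding=0x5000 | (0 * 0x40) | (0 * 2))  # no padding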
@@ -63,11 +111,12 @@ if __name__ == "__main__":
  toutd = tout.data()

  tind[0:4] = [-1,1,-2,2]
  print("** before **")
  print(tind)
  print(toutd)

  comp = ane.compile(open("../2_compile/model.hwx", "rb").read())
  comp = ane.compile(open("../ops/relu.hwx", "rb").read())
  ret = ane.run(comp, tin, tout)

  print("** after **")
  print(tind)
  print(toutd)
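For reference, the test input set above is [-1, 1, -2, 2]; assuming relu.hwx computes an elementwise ReLU (as its name and the ops change below suggest), the output buffer should come back as [0, 1, 0, 2]:

# expected result under the assumption that relu.hwx applies max(x, 0) elementwise
inp = [-1, 1, -2, 2]
expected = [max(x, 0) for x in inp]   # [0, 1, 0, 2]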
@@ -1,4 +1,4 @@
from .tensor import Tensor, Function, register
from .tensor import Device, Function, register
from functools import lru_cache
import struct

@@ -33,4 +33,4 @@ class ReLU(Function):
    ret = ctx.ane.tensor(input.shape)
    ctx.ane.run(compile_relu(ctx.ane, input.sz), input, ret)
    return ret
register('relu', ReLU, device=Tensor.ANE)
register('relu', ReLU, device=Device.ANE)
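The last two hunks move the ANE device constant from Tensor.ANE to Device.ANE, so registering the op no longer requires importing Tensor at all. A rough sketch of the registration pattern this relies on follows; the Device layout and the register body below are assumptions for illustration, not the actual tinygrad implementation:

class Device:
  # illustrative device identifiers an op can be registered against
  CPU, GPU, ANE = 0, 1, 2

OPS = {}

def register(name, fxn, device=Device.CPU):
  # remember which Function implements `name` on each device so the
  # tensor method can dispatch to it later (simplified sketch)
  OPS.setdefault(device, {})[name] = fxn

# as in the diff: bind the ANE-backed ReLU under the 'relu' name
# register('relu', ReLU, device=Device.ANE)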