ane struct

This commit is contained in:
George Hotz
2020-12-18 09:06:25 -08:00
parent 56d44637f3
commit fbcd1912cf
2 changed files with 53 additions and 4 deletions

View File

@@ -20,6 +20,54 @@ libane.ANE_TensorData.restype = POINTER(c_uint16)
libane.ANE_Run.argtypes = [c_void_p]*3
libane.ANE_Run.restype = c_int
ANE_Struct = [
# section @ 0x2C len 0xF4
# section @ 0x128 len 0x3C (conv)
("u16", 0x128, "InputWidth"),
("u16", 0x12A, "InputHeight"),
("u16", 0x12C, "InputDepth"),
("u32", 0x130, "InputOutputType"), # (OutputType * 0x10) | InputType
("u32", 0x134, "InputChannels"),
("u32", 0x138, "OutputChannels"),
("u16", 0x13C, "OutputWidth"),
("u16", 0x13E, "OutputHeight"),
("u16", 0x140, "OutputDepth"),
("u16", 0x144, "KernelSize"), # 0xa000 | (KernelHeight * 0x20) | KernelWidth
("u16", 0x146, "Padding"), # 0x5000 | (PadTop * 0x40) | (PadLeft * 2)
("u16", 0x14C, "BatchSize"),
# section @ 0x16C len 0x6C (input)
# reloc 0x16c-0x174 = image
("u32", 0x178, "InputRowStride"),
("u32", 0x17C, "InputPlaneStride"),
("u32", 0x180, "InputDepthStride"),
("u32", 0x184, "InputBatchStride"),
("u8", 0x1A7, "InputInterleave"),
# section @ 0x1E0 len 0x44
# section @ 0x22c len 0xC (scaling)
("u16", 0x230, "BiasScalar"),
("u16", 0x232, "ScaleScalar"),
# section @ 0x240 len 0x10
# section @ 0x258 len 0x18
("u32", 0x260, "OutputRowStride"),
("u32", 0x264, "OutputPlaneStride"),
("u32", 0x268, "OutputDepthStride"),
("u32", 0x26C, "OutputBatchStride"),
("u8", 0x273, "OutputInterleave")]
class ANETensor:
def __init__(self, *shape):
self.shape = shape
@@ -63,11 +111,12 @@ if __name__ == "__main__":
toutd = tout.data()
tind[0:4] = [-1,1,-2,2]
print("** before **")
print(tind)
print(toutd)
comp = ane.compile(open("../2_compile/model.hwx", "rb").read())
comp = ane.compile(open("../ops/relu.hwx", "rb").read())
ret = ane.run(comp, tin, tout)
print("** after **")
print(tind)
print(toutd)

View File

@@ -1,4 +1,4 @@
from .tensor import Tensor, Function, register
from .tensor import Device, Function, register
from functools import lru_cache
import struct
@@ -33,4 +33,4 @@ class ReLU(Function):
ret = ctx.ane.tensor(input.shape)
ctx.ane.run(compile_relu(ctx.ane, input.sz), input, ret)
return ret
register('relu', ReLU, device=Tensor.ANE)
register('relu', ReLU, device=Device.ANE)