From 5713443fb124d3b4da1675abf3f05ef7882f5883 Mon Sep 17 00:00:00 2001 From: "drCathieSo.eth" Date: Sun, 4 Feb 2024 01:31:22 +0800 Subject: [PATCH] Add Conv2Dsame and Conv2Dsame_stride test circuits and Conv2Dsame.circom --- circuits/Conv2Dsame.circom | 91 +++ models/conv2Dsame_input.json | 1 + models/conv2Dsame_stride_input.json | 1 + models/conv2d_same.ipynb | 803 ++++++++++++++++++++ test/Conv2Dsame.js | 37 + test/circuits/Conv2Dsame_stride_test.circom | 5 + test/circuits/Conv2Dsame_test.circom | 5 + 7 files changed, 943 insertions(+) create mode 100644 circuits/Conv2Dsame.circom create mode 100644 models/conv2Dsame_input.json create mode 100644 models/conv2Dsame_stride_input.json create mode 100644 models/conv2d_same.ipynb create mode 100644 test/Conv2Dsame.js create mode 100644 test/circuits/Conv2Dsame_stride_test.circom create mode 100644 test/circuits/Conv2Dsame_test.circom diff --git a/circuits/Conv2Dsame.circom b/circuits/Conv2Dsame.circom new file mode 100644 index 0000000..c59a644 --- /dev/null +++ b/circuits/Conv2Dsame.circom @@ -0,0 +1,91 @@ +pragma circom 2.0.0; + +include "./Conv2D.circom"; + +template Conv2Dsame (nRows, nCols, nChannels, nFilters, kernelSize, strides, n) { + signal input in[nRows][nCols][nChannels]; + signal input weights[kernelSize][kernelSize][nChannels][nFilters]; + signal input bias[nFilters]; + + var rowPadding, colPadding; + + if (nRows % strides == 0) { + rowPadding = (kernelSize - strides) > 0 ? (kernelSize - strides) : 0; + } else { + rowPadding = (kernelSize - (nRows % strides)) > 0 ? (kernelSize - (nRows % strides)) : 0; + } + + if (nCols % strides == 0) { + colPadding = (kernelSize - strides) > 0 ? (kernelSize - strides) : 0; + } else { + colPadding = (kernelSize - (nCols % strides)) > 0 ? (kernelSize - (nCols % strides)) : 0; + } + + signal input out[(nRows+rowPadding-kernelSize)\strides+1][(nCols+colPadding-kernelSize)\strides+1][nFilters]; + signal input remainder[(nRows+rowPadding-kernelSize)\strides+1][(nCols+colPadding-kernelSize)\strides+1][nFilters]; + + component conv2d = Conv2D(nRows+rowPadding, nCols+colPadding, nChannels, nFilters, kernelSize, strides, n); + + for (var i = rowPadding\2; i < rowPadding\2+nRows; i++) { + for (var j = colPadding\2; j < colPadding\2+nCols; j++) { + for (var k = 0; k < nChannels; k++) { + conv2d.in[i][j][k] <== in[i-rowPadding\2][j-colPadding\2][k]; + } + } + } + + for (var i = 0; i< rowPadding\2; i++) { + for (var j = 0; j < nCols+colPadding; j++) { + for (var k = 0; k < nChannels; k++) { + conv2d.in[i][j][k] <== 0; + } + } + } + + for (var i = nRows+rowPadding\2; i < nRows+rowPadding; i++) { + for (var j = 0; j < nCols+colPadding; j++) { + for (var k = 0; k < nChannels; k++) { + conv2d.in[i][j][k] <== 0; + } + } + } + + for (var i = rowPadding\2; i < nRows+rowPadding\2; i++) { + for (var j = 0; j < colPadding\2; j++) { + for (var k = 0; k < nChannels; k++) { + conv2d.in[i][j][k] <== 0; + } + } + } + + for (var i = rowPadding\2; i < nRows+rowPadding\2; i++) { + for (var j = nCols+colPadding\2; j < nCols+colPadding; j++) { + for (var k = 0; k < nChannels; k++) { + conv2d.in[i][j][k] <== 0; + } + } + } + + for (var i = 0; i < kernelSize; i++) { + for (var j = 0; j < kernelSize; j++) { + for (var k = 0; k < nChannels; k++) { + for (var l = 0; l < nFilters; l++) { + conv2d.weights[i][j][k][l] <== weights[i][j][k][l]; + } + } + } + } + + for (var i = 0; i < nFilters; i++) { + conv2d.bias[i] <== bias[i]; + } + + for (var i = 0; i < (nRows+rowPadding-kernelSize)\strides+1; i++) { + for (var j = 0; j < (nCols+colPadding-kernelSize)\strides+1; j++) { + for (var k = 0; k < nFilters; k++) { + conv2d.out[i][j][k] <== out[i][j][k]; + conv2d.remainder[i][j][k] <== remainder[i][j][k]; + } + } + } +} \ No newline at end of file diff --git a/models/conv2Dsame_input.json b/models/conv2Dsame_input.json new file mode 100644 index 0000000..6bdf785 --- /dev/null +++ b/models/conv2Dsame_input.json @@ -0,0 +1 @@ +{"in": [[[573578318342437025844851866425360384, 796897075929102617762369594491666432, 544813853709568530836070618837811200], [950964808512243091948221705113042944, 63020991177631942013592578962751488, 534246864759300757335855534585151488], [145553273162883481564471617673756672, 168014069645222242011527365934448640, 918377427067427353453418389608333312], [140426938484365606594402473039888384, 264667746275250807784404837319311360, 75491098918846425947581327949692928], [193541506598040702303784498423136256, 897966980662074461603530364592062464, 180806663312596539664229464283283456]], [[848938644943171838323721255004405760, 946893746765754170995422226555076608, 768897933131184028070406971585462272], [912801056444623171008965891011379200, 812119034248563567653941176875089920, 506148778166511590361123171143254016], [726299972329744035058453094326075392, 96207737897536896527190301501030400, 50742547072808897124511004495970304], [301789259811556599819892346689945600, 485486502032339103829549531168505856, 464866287179084526270412742325174272], [875588607670153614680019680937115648, 932891851707191663462868731918548992, 339653173344804033759485646264598528]], [[237265950183872320398504585615900672, 92539353472303921531909154417934336, 85988770955032665647971390548082688], [244847082330976307871225285122719744, 8112524408893539994207368429699072, 236428844273952725283544611891970048], [659751238345296368251004385155350528, 636339039183294843043782526962237440, 825077716305557216991914246402998272], [531007310948199867694301878014705664, 724330536345950423826530223955902464, 663737513567450936766727588333748224], [439166246656338740466881160898674688, 743472567205232394663525661517283328, 677226299867745725736978696823111680]], [[874121934955878104176816691842383872, 538534223430679464327474931591806976, 347726279820595664077109546547412992], [773851100708451725315764067596500992, 975657286487747620319763447391715328, 944268611874408052782348044693667840], [215372314814578868225208595744030720, 657746228084012664949295522158477312, 894050164467754304522655814498910208], [741111797690731997653584571294285824, 687926089998044122195818245588516864, 327261898856574746857668546142404608], [448878343890703267942854989155663872, 924485996489381566675746159104884736, 482698411338606487321031874749399040]], [[139523373279942021389818047177424896, 796598027167266997218949056595755008, 976033348284549093920013784130256896], [660994593288569905303091763568705536, 69341432375643946894064073402482688, 998540588907763631143940943779987456], [316093678592279445721052906812080128, 495961038159453158592866196205338624, 937970692966970988133652679045939200], [49413224909801158568130981876727808, 247095544932269935825119489411776512, 583844162475215139440723570525732864], [715278042019727128024286358891659264, 489768637551234821572637878669803520, 987635686311730664114239575842357248]]], "weights": [[[[-223889976739883438974480414821842944, 182179450988769526980719757565624320], [-312540173530578612367366987190370304, 139334559440612791523818767365701632], [-144911438226699839338211669549514752, 11686950922012330194202493232087040]], [[-184420511126518259849192222901665792, 166687548160552971364293392512057344], [-150520324707031259246079775103516672, -127768903970718397153279061885190144], [-176604077219963084012262368328286208, -255772858858108525284918825947496448]], [[-327348411083221455632540069993119744, 65518230199813848675511947019419648], [23358047008514404274419529662070784, 305796265602111814573319260595027968], [50124526023864744999691541270233088, 51966935396194458503048216038080512]]], [[[-152963533997535716062692648819884032, -286536753177642827878558535144964096], [317600071430206318641788425301131264, 275960147380828863380341414965542912], [61409443616867069727730049321271296, 284595370292663576978075748665393152]], [[-271937012672424322146439224864276480, 302216708660125731526730788911448064], [347241044044494658554346293661007872, 350840210914611820177063825035493376], [-261212855577468885938913667830513664, -104776054620742800821305038629502976]], [[183119535446167002871030033657364480, 156187891960144058736559209685450752], [216080427169799798654929942775595008, 77634811401367193379228793098469376], [-176951438188552874579061318755549184, -299460560083389290857424393370337280]]], [[[218642294406890891751406281464741888, 238958597183227540631916579255222272], [-297664105892181404654557333611872256, -300561219453811641381297073246175232], [-357277274131774953904931896050057216, -363140106201171886485940135356530688]], [[-42309552431106571048104260356014080, -152798458933830258622761475474718720], [68972945213317871191629614314160128, -81471890211105347292681833950478336], [-303670346736907976036272459671928832, 212895214557647712730549686833774592]], [[-292076885700225858726145735778107392, -227672487497329711878256654156824576], [-290411472320556670589050515393347584, -240332305431365987041979825326653440], [-204085990786552442503615118510653440, -230043768882751482065503977736241152]]]], "bias": [0, 0], "out": [[["-738206413373029545006965393527124494", "-173320559228639935825945330543413025"], ["-1034698083357167138897799627355946893", "-434089078628623353676508551211157060"], ["-830734559502817036444813835112148949", "-686578771866623630758607927275546159"], ["-285424960673894669036410085848276078", "77035899885765173133597720803628368"], ["-41293201078409047199874417933080796", "28777494655737863382696013669386951"]], [["-606713676308333020774846244444307246", "388466467207809267549922375129065936"], ["-1102737485952039787132334772316590877", "711672842283067459049106389539428334"], ["-1279246107747464951673577203827572661", "-104057682619248993879676863654015882"], ["-1137502411882990109129978808062308003", "-285737283216816113761016693211231047"], ["-703798773277911056745141956845225462", "355407338114374779173330249858051758"]], [["-1546951103062535866431703941066730304", "-509754977574520651977576289008629976"], ["-2007190678249673958423804906616390874", "-255124530399913653884258580354194679"], ["-1987955686599915815112948301772925117", "92785353760923841311068732708486389"], ["-1501354303867896914042548115076362343", "-140005761885967356612636419407991713"], ["-758589315783580773280487949446211355", "437520880361906870599644364227465106"]], [["-761417923444182653301008173638043967", "91771720385437306243354293205033102"], ["-1688165103880073226811051249397561256", "-253090840416308272240635021911178871"], ["-868411975706764838263900420928401111", "427029536451544469306217379046517378"], ["-1632207770832291707418877818442334628", "313343146963309014974813849350304619"], ["-1123607486465218134583708408995464790", "277056825693143901704758542110909098"]], [["-543826042689234949496060048498817249", "415061953419939109346123158429653654"], ["-1006898403524974693098527575437710832", "646952095929760474345606121011841547"], ["-1339156364320875137533821288468963984", "424307047679243230200796457433761030"], ["-633848388019879384336947814123500320", "638133921526842615943818078787826079"], ["-911192101070039979166490878519403558", "572640728975516234503728848776059996"]]], "remainder": [[["732636700329184288172821299452706816", "163633766656393366871975061931163648"], ["106376739450670706774435272610283520", "224263678919132314456428282264944640"], ["302563808180207575103565765357338624", "724302115124451087131827274648649728"], ["414653571797565705485911652130881536", "793412538503882240528593488489480192"], ["802744020509610387937821108140507136", "304480836101726291197711385865748480"]], [["859770729700736535878340696623546368", "892305185639885971968166660350672896"], ["159017773966938546227144022541991936", "243286973980475632014665402361577472"], ["45707991996386548824583584683655168", "305057405020119480420035216424304640"], ["513655874550133048911800632902418432", "966709976332017469863264280209522688"], ["198125756383956075595551934970331136", "988945351929415067553212891893071872"]], [["787267116688540528237625132699353088", "820149329025927193528840575260819456"], ["472980868351632170335228269959315456", "763189020330031063201267188022378496"], ["644197050565810210754859371753111552", "989731482915520051484045079570546688"], ["948204822800878014118539586611183616", "748247265754602884949081086979211264"], ["716169397100995271731495802811449344", "485403242675801265137341039288254464"]], [["43352138962814326789218918981435392", "405166142631707980272441068569493504"], ["810091229170266512829629243259355136", "269653910662555626980942633563586560"], ["676839398629154231820514533721505792", "419805522694742694500594416384737280"], ["745383827341964610018052023030644736", "256049836598014041054468595098058752"], ["675940914936064095590762595072606208", "261569026192941931124614732445646848"]], [["945584600952117166014909825389953024", "419123852798410471873927944123449344"], ["610870231548665514822003410375540736", "444338394948450192884055088832708608"], ["392811864299264176482313886100357120", "482398655357498430014877178936688640"], ["865948833234167064622566980785799168", "886966130974971444248685772183437312"], ["195220905237827872879130999353507840", "356250510761353797733544351900893184"]]]} \ No newline at end of file diff --git a/models/conv2Dsame_stride_input.json b/models/conv2Dsame_stride_input.json new file mode 100644 index 0000000..1fbe20f --- /dev/null +++ b/models/conv2Dsame_stride_input.json @@ -0,0 +1 @@ +{"in": [[[159171310144660808849580245813035008, 773132349533750995129683302059868160, 707323469667019518938139499522162688], [852102363732118002231621956635983872, 777470553936418897668135401853812736, 912675696321925965137123394623373312], [934790562565718672076246643409682432, 819227003366986408699695570832326656, 701561885411571628821689409296400384], [791437938233053393796499443462176768, 878015911921594657747068833174126592, 829049758023987011959747747652829184], [686192480082645202759568915251068928, 529391849605564071577133900208537600, 978137403467643333173402204660301824], [607195729884374240887191922974654464, 899253143512219833045747655321321472, 769459611003872686456829926268993536], [875525571302137082460949151326666752, 905349315078055472853261931732008960, 878899508021791928128005251081961472], [390668195900968457628571893929869312, 955806155875142623557665841800544256, 317897360852441809853589055401361408], [764772770358354077326876658708250624, 817542251768019748791085333628321792, 698396679207362476625572468208173056], [830918087477103471651109848534220800, 89290468797192929331595843706290176, 416900244375856390386138320165928960]], [[21564038529804373679738784862699520, 124834200298711062016947601604083712, 101019143307703019436928302103658496], [898013986088237561613604320628965376, 267533414530880335665019949661814784, 74417362171673606182845737018064896], [22083112884501379488001495478042624, 639575094022533000396282831445164032, 897414753683917650316929906337906688], [340537839125285361591543938290810880, 266943577771092040721646423596072960, 757156053400529010228232096084656128], [537806377277787614683130425904201728, 655654358648736392403875180224446464, 915044555322720651288679744498827264], [259388949621496418498438216479145984, 931645919439548083830207673505153024, 895087706828045931556592301526482944], [596732492154831625943755905558380544, 835791249437404265312162866603753472, 437188744507928882882969449122496512], [151597991101908310079829963626774528, 745592633963059161117082720290209792, 758908391516642904228756601385779200], [149609639851604149416690222781956096, 729812486195530257578258311986282496, 373825804678432385839397749540257792], [776094384318571624485778405969625088, 611521452399336332790010979760996352, 267437272791195900907136051381796864]], [[863862767916225052353724300925075456, 88544579529258436157439732000751616, 188118101124087351290782324680032256], [446059005695225165907537529243959296, 536203199834733414416414436419960832, 359935226501319065405942353459412992], [301536013319989483031781601422868480, 205269198076443416467102059907252224, 946866616612165410069994608649043968], [192584278660269718395576300139970560, 295457430611532317957440637465788416, 671233220110806728196724913604657152], [640595754071868803627299787939774464, 256410938657877796940612789040316416, 119539900696204239084647625921658880], [384411467201420039389480265367683072, 121478834971778701604687264855621632, 396612766212281618100259637910568960], [580686779548411450479678600201109504, 991963370723973300747577055137234944, 809429714965597730205260671661113344], [428545066666959259662110155864014848, 580751051395154140581414143019450368, 178981822176542671348342480898621440], [598358610693843031133127027447562240, 539167349035998519813866215096451072, 582694232161625227416499626199482368], [734422977030743513595009069609910272, 501038076695012147750987349589229568, 556188133153858056660925366493249536]], [[702200348743891556346431078362251264, 652997838617521173858993502498586624, 67209058429937072775758631386742784], [984395860996606161260820222214930432, 492179284708902035797813867640258560, 291107573267516256557980242174541824], [403627013783817373797669717097316352, 312145515851412930021869923847897088, 338772566094994630799169102813331456], [630588554016737296289539144655306752, 357604237434804951770146663032684544, 469401530992174285529297582827765760], [138041373084244255960706896025878528, 60588519870016682601848027628961792, 550805087401687595763844395378409472], [760853085445330442842943058471813120, 656032328313485905607989846863773696, 768557894265301485249897698619817984], [821093321019202119723545920654016512, 286252984195685255392690265103794176, 519872203895423426518966012645212160], [529348909548683653879549703728660480, 986768232460898305415437734390005760, 451723514223614441254787041567178752], [267262092856670156971402065367531520, 856575647503531993546154138069368832, 959552364650536590420913003090149376], [191107279732280492349793229829308416, 502061473520169749726015337560801280, 50494109948060365746026987468619776]], [[18808717983980206895196932841406464, 254287514038177636506092848881336320, 327725135154048551146988188467200000], [130366296016952274873916723105366016, 758043671818217346423988798095360000, 532688950706481809635431743265701888], [377119572966067932922363569974018048, 722742328676682520778605144775327744, 950323990740689614438198775816126464], [75638362164763833168142373710987264, 933593105751863046072814640939663360, 157961248884251352768311150115815424], [620451461806050305723568341412478976, 645809227407024634304977328274407424, 155771399662181468838141989576769536], [789056082206593931641883359273025536, 854326120412576705713617568327008256, 874259036097889242656297718747496448], [803166263932878542153840428914311168, 519244233175028819680988230020235264, 458010544377194781973471921329668096], [643719027521253884633857475838214144, 556730767949949636894773862949978112, 854552732852505205188726040960172032], [425224506827976449106749294666317824, 757961774619209190824436328566882304, 737943324614926858727277962740105216], [904213200705842527835521253780750336, 944325262588247072676177835391451136, 841800716485828935966639421481025536]], [[773463261740755853127912758991388672, 848115493423481102740754561979383808, 978974572529443839055880731106476032], [922913900875078572447311588170399744, 874559503928025891318344541248946176, 811512752285126468564185300594589696], [314666833229342854731036598470705152, 229687882301993354339637203180716032, 329379464831743437379118898905874432], [322652551036862740972065841420959744, 608666667742671552200498560641269760, 749665505533932829126297826923380736], [896906814466416550109120195475800064, 226520978274841327888299645391077376, 387904766529772378907547272510177280], [733309623485673413959338398601707520, 607140742447717433160663575914610688, 535811878316594995055434527284396032], [492161996367537150177929024268926976, 424094296988999149850962826866196480, 778365695565463747773853572999938048], [18685501489869694408816769978335232, 697528483877688107260945765163335680, 529086277433752370617291023226765312], [842730958769206212992348160421527552, 306583962181422716210299244476628992, 704657794095505719489299528441397248], [868489499794575228872368425707503616, 69292785658408905417793649007132672, 735603567186071645740498485227028480]], [[411835245401874146288689286993674240, 63093876957170394765242288360652800, 30125474068855687334847013508349952], [47898551210630023534675299724689408, 395286843987227286738352815980150784, 142389642809073746311824653934395392], [378925871041476690058896884415594496, 736187443624994741164799438824144896, 704101958222426155856918732871827456], [544057061423966935500290576806313984, 76722441097137107247433472813301760, 107983530547514303558693274960003072], [105305932636806433899618069198667776, 463457679608919028779684429356335104, 952313571557659357320751329342128128], [146407454211864751557052352140673024, 477926458677983782144977297519673344, 695505596198805907915560532089765888], [396335018951781689426907135553306624, 682474926483070610463192138817470464, 715400495541909583375446260412579840], [893540672225323490376299871163383808, 928896690781682494075139466775756800, 185249828253696531963515773846028288], [856847364053297769527550630835519488, 405465737389540321506823852116672512, 89021309521015052445769604825874432], [456947385499195124497525841290330112, 577207394497325956869286953452830720, 105786208243838112178934871027089408]], [[367855942936817053921362907381104640, 176672262085555306737880348055568384, 464071887893671477287367896794660864], [116402231320876472626503162108313600, 375875361115548664981136574028709888, 933145224485535540819219628465061888], [38401408078140412640617005829849088, 731180100213754321171207347108839424, 236030652926759744069562680737792000], [541170323502511751457995511073603584, 857146589619748529838995352934416384, 49566101561443918519889524655390720], [84947664129841894843931294309548032, 879664379592867153257508802396160000, 376501641192181031519011817020981248], [950114461962523956189408245854502912, 233337706159432955866156639226691584, 268898205073523805060371020785909760], [861363170585036940315451261829775360, 160721383842972743559881882242908160, 33232756717677226032412655471296512], [316582152649938060777866434604695552, 256750170883927701266108415251316736, 592403915852235345024158264423612416], [78671279274685316792652259475324928, 731613373249635939032367289824444416, 960396190648122212981943454599741440], [954769362582709934851177303265771520, 683300056439137029178775913709436928, 995818754886710095093520475143274496]], [[735411549866786756187555437492568064, 304825947217761348503602789926043648, 111670079167642491513239918302199808], [203158059766762154373258229982429184, 459390575652981811753132273369612288, 908830045767803924875862479525117952], [620374479504696214033598465068498944, 77815084176558135136827457126006784, 859933737393714903517783116656672768], [166250199689153464356176867828432896, 961408049131634435847760364921946112, 474148500157170171022042972555313152], [106829368479589347217193506430779392, 361840275096540704991460749580697600, 522630436013246029512953889720107008], [556620282938957167644055746868412416, 18364737317944967742377462847242240, 907875425950277228530340490089857024], [799307069158986995452920090939359232, 760487543584898339036066641206050816, 121069023123069663480436974598225920], [454913287839619843212190048166346752, 940615784615227265816534405900926976, 487578147212192383989562821012619264], [129182046507286487548703627683037184, 188069728234344095226226224947789824, 588129250066928883926851255103651840], [804310000030451229043176488868773888, 972997973676414779862508689946574848, 56322229310840301412348130804891648]], [[339113685680049939956886658992832512, 994905304820387191672371611634237440, 614615191900327976029962497572208640], [783846667327171751349981000513880064, 733168640758945697847499404055937024, 859133505808499923873612735329599488], [384921658881441263030006217340616704, 563174319474954280737024679836909568, 596377152786360889140039145240920064], [360797302064305871977033076425359360, 53327510861837405562767438260469760, 590087793922963068549105774137180160], [802257019978236918332755683531816960, 889270177009243942034290310945701888, 876531146240051695778026776745213952], [621738668867005515896695114937925632, 621842831084094367828774914824863744, 907954772908133753374247235828580352], [158948136012292747176688172077678592, 546618121181380033665544842135994368, 831199928763794091567729364711768064], [448613764984414906653857167209136128, 281226788055348949658305905213571072, 371934764268284183797255211050860544], [407477858563138916961821337799622656, 244978883579953117621809365869658112, 551744085472183724850412658205130752], [318850102262657320675197938885984256, 477131739769685265614951836597354496, 236148569978969846770138073817677824]]], "weights": [[[[89226394891738890234911678693113856, -12069195508956909805469919556403200], [95124244689941404601266147262726144, 79700708389282227598662739918061568], [-106933251023292538179271866657538048, 11905372142791749578016705977778176]], [[55349647998809820200052223899598848, -271235316991806036467764210327093248], [-238274306058883681980267326334828544, 169450312852859514254862271951929344], [-20704418420791625709047262469947392, 140111625194549568484037170722504704]], [[-263449639081954977017005765203853312, -147867843508720407110009087022071808], [200233995914459233480026503151353856, 161291241645812997886514343272513536], [-149971544742584236450439666270732288, -130827650427818305114516209891540992]], [[-52112340927124021121563694836219904, 244466841220855721196137403973107712], [-186103269457817095661514717466198016, -43945550918579104427517743785312256], [-265386432409286512976910198385934336, 45939624309539797436489419773181952]]], [[[257476031780242924029161813183561728, -32398566603660583017571197273505792], [-170099467039108269311056468154253312, 189800709486007701861330391150886912], [151705473661422733134109989247385600, -238412275910377518524011487522979840]], [[137879759073257460334671033840697344, -91348052024841318713015880931868672], [1190751791000366328049568246136832, -7973134517669678528123717449416704], [-247020125389099141035953763018342400, -138846039772033690452985462384492544]], [[108750581741333008843753623014866944, -150537505745887763703861208831492096], [-24231642484664917393664667916697600, -259205669164657620977990293757362176], [-232033744454383871339490748072460288, 61730682849884034139412772318871552]], [[57480067014694213873351683819438080, -265266835689544675693196851118342144], [-212788581848144545712612755636224, -32541170716285705180684522054746112], [-162627667188644413163585500860121088, -204299226403236406463255315503644672]]], [[[164101541042327882549652201981083648, -177933335304260261850911328062930944], [195380657911300683379160585536012288, 183013111352920553784555019331174400], [-132887929677963254361824189963304960, 11586606502532959939431539016728576]], [[-101354628801345823451981722907836416, 25569826364517212868326762638999552], [14838755130767823797406468795269120, -228026866912841809950725657554059264], [-170267224311828613114478839069671424, -205271184444427482458750124522733568]], [[30023515224456786306858612091256832, -259240210056304931055460476105261056], [-37529796361923219296861458570149888, 269733667373657233818009734511329280], [18460065126419068849695306035494912, 202887982130050653568544556760170496]], [[28933763504028319292346184807481344, 154086738824844362218073231273754624], [-189923465251922623429434688065765376, 121185868978500370751786176259031040], [255532920360565182606987333287804928, -162540823221206671424084376215879680]]], [[[-124473601579666134921001625045696512, 76049178838729864185487815664992256], [-3660440444946289093542198898589696, 23682653903961180701787908336517120], [229330360889434828521677277626368000, 256998419761657743361176908213518336]], [[-261349916458129892707963116732809216, -253867387771606472659783285448114176], [-79969659447669982856380253042900992, 239941000938415552394053731777249280], [-145047560334205626656667162447446016, -71907922625541691614536510669324288]], [[-268330514430999779867853655185031168, 57022869586944586913915234666151936], [-117865175008773801753807872529006592, -186005204916000378815711281185554432], [161419004201889052964298258124898304, 268087506294250485039740796614475776]], [[-181928724050521859497024987187904512, -171299442648887629281524525649887232], [-110691815614700329811484264611446784, 157881140708923341042396618492477440], [-113486051559448249076086034197905408, -204881250858306899115212947186515968]]]], "bias": [0, 0], "out": [[["-811051029307312151898080091259013408", "-1472109697883872855085104222700990006"], ["-528539552873451841296887387802887946", "-479910753097361964472240108640655029"], ["-859536355090285293603582494722226955", "-670518221637687591501448018045711492"], ["-77675673238639598916850751879067826", "-132028915984577730731287477184728367"]], [["-973791923027617402137438972479106354", "-529895757470201620182206632643141082"], ["-1068758558577939368141318995900977305", "-330698196577790148366400564712521445"], ["-968582927650604267027127767242160672", "-615391447297339813950840270560399175"], ["-349524047948521414307381813385449474", "-394712802987869376369174007174918150"]], [["-1195008336948108945640918772722080227", "-391453675833271485890705066425298906"], ["-1035539474474336745928350434338370235", "375449144451453718974060701074107917"], ["-468574096206261610672918401549766744", "-461492881731804944422419395904964084"], ["-157987822858060281896825874562638157", "-145297327207434479738146166840718303"]], [["-717832718354692660291274590189907309", "-640793387586819749749055661801718428"], ["-818234976163909363095633653303124629", "-486529556567947364827511906932334658"], ["-488975320390965478682141686514332960", "-593549900211690891912136829366038418"], ["-88817120440234533520269026413841619", "-188924963686509380494214483287084738"]]], "remainder": [[["57546413369690632436021828010377216", "844338365610010970536479881582084096"], ["713605335486457688719381102919155712", "375215677051519710413966252668092416"], ["755914587195177990836477246220271616", "818027579922420637827386176052396032"], ["66541310804045710582849369230278656", "347409709477862532129959056406216704"]], [["289886146631823524117043901901045760", "747073207457187946361744580709187584"], ["596964332285599153985153944246026240", "55475850457661407668048488535425024"], ["170834380408647470600498462609899520", "509441462772270674298218142377181184"], ["34071275330636135911467390225874944", "138225886497714127955769149616029696"]], [["799025578005122329965199425850572800", "638105259408657768502010079451545600"], ["646664651333832093639422680485593088", "675100193761157594649844145046159360"], ["445458083786331341876221205003894784", "502163711096538860296377604167958528"], ["47026304856349756806446693973229568", "831593504386261794099030491582693376"]], [["939137107020160348118700928203227136", "113237806886600667494639994056736768"], ["659213932223174550761216604282814464", "197300411875572149344585017353830400"], ["421437083594697043343248459305058304", "120269150227163283149044634810843136"], ["198117022230724190125402300398698496", "310521369112965179461249563143176192"]]]} \ No newline at end of file diff --git a/models/conv2d_same.ipynb b/models/conv2d_same.ipynb new file mode 100644 index 0000000..36caac5 --- /dev/null +++ b/models/conv2d_same.ipynb @@ -0,0 +1,803 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from tensorflow.keras.layers import Input, Conv2D\n", + "from tensorflow.keras import Model\n", + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "inputs = Input(shape=(5,5,3))\n", + "x = Conv2D(2, 3, padding=\"same\")(inputs)\n", + "model = Model(inputs, x)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: \"model\"\n", + "_________________________________________________________________\n", + " Layer (type) Output Shape Param # \n", + "=================================================================\n", + " input_1 (InputLayer) [(None, 5, 5, 3)] 0 \n", + " \n", + " conv2d (Conv2D) (None, 5, 5, 2) 56 \n", + " \n", + "=================================================================\n", + "Total params: 56\n", + "Trainable params: 56\n", + "Non-trainable params: 0\n", + "_________________________________________________________________\n" + ] + } + ], + "source": [ + "model.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[,\n", + " ]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.weights" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[[[0.57357832, 0.79689708, 0.54481385],\n", + " [0.95096481, 0.06302099, 0.53424686],\n", + " [0.14555327, 0.16801407, 0.91837743],\n", + " [0.14042694, 0.26466775, 0.0754911 ],\n", + " [0.19354151, 0.89796698, 0.18080666]],\n", + "\n", + " [[0.84893864, 0.94689375, 0.76889793],\n", + " [0.91280106, 0.81211903, 0.50614878],\n", + " [0.72629997, 0.09620774, 0.05074255],\n", + " [0.30178926, 0.4854865 , 0.46486629],\n", + " [0.87558861, 0.93289185, 0.33965317]],\n", + "\n", + " [[0.23726595, 0.09253935, 0.08598877],\n", + " [0.24484708, 0.00811252, 0.23642884],\n", + " [0.65975124, 0.63633904, 0.82507772],\n", + " [0.53100731, 0.72433054, 0.66373751],\n", + " [0.43916625, 0.74347257, 0.6772263 ]],\n", + "\n", + " [[0.87412193, 0.53853422, 0.34772628],\n", + " [0.7738511 , 0.97565729, 0.94426861],\n", + " [0.21537231, 0.65774623, 0.89405016],\n", + " [0.7411118 , 0.68792609, 0.3272619 ],\n", + " [0.44887834, 0.924486 , 0.48269841]],\n", + "\n", + " [[0.13952337, 0.79659803, 0.97603335],\n", + " [0.66099459, 0.06934143, 0.99854059],\n", + " [0.31609368, 0.49596104, 0.93797069],\n", + " [0.04941322, 0.24709554, 0.58384416],\n", + " [0.71527804, 0.48976864, 0.98763569]]]])" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "X = np.random.rand(1,5,5,3)\n", + "X" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1/1 [==============================] - 0s 51ms/step\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-02-04 01:25:53.532541: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz\n" + ] + }, + { + "data": { + "text/plain": [ + "array([[[[-0.7382064 , -0.17332056],\n", + " [-1.034698 , -0.43408912],\n", + " [-0.83073455, -0.6865788 ],\n", + " [-0.28542495, 0.0770359 ],\n", + " [-0.04129319, 0.02877748]],\n", + "\n", + " [[-0.60671365, 0.38846642],\n", + " [-1.1027374 , 0.71167284],\n", + " [-1.2792461 , -0.10405768],\n", + " [-1.1375024 , -0.28573734],\n", + " [-0.7037988 , 0.3554073 ]],\n", + "\n", + " [[-1.5469512 , -0.50975496],\n", + " [-2.0071907 , -0.25512454],\n", + " [-1.9879558 , 0.09278526],\n", + " [-1.5013543 , -0.14000578],\n", + " [-0.7585893 , 0.43752092]],\n", + "\n", + " [[-0.7614179 , 0.09177176],\n", + " [-1.688165 , -0.25309083],\n", + " [-0.8684121 , 0.42702955],\n", + " [-1.6322078 , 0.3133433 ],\n", + " [-1.1236074 , 0.2770568 ]],\n", + "\n", + " [[-0.54382604, 0.415062 ],\n", + " [-1.0068984 , 0.6469521 ],\n", + " [-1.3391564 , 0.42430705],\n", + " [-0.6338484 , 0.63813394],\n", + " [-0.9111921 , 0.5726407 ]]]], dtype=float32)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y = model.predict(X)\n", + "y" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "X_in = [[[int(X[0][i][j][k]*1e36) for k in range(3)] for j in range(5)] for i in range(5)]\n", + "weights = [[[[int(model.weights[0].numpy()[i][j][k][l]*1e36) for l in range(2)] for k in range(3)] for j in range(3)] for i in range(3)]\n", + "bias = [int(model.weights[1].numpy()[i]*1e72) for i in range(2)]" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "def Conv2DInt(nRows, nCols, nChannels, nFilters, kernelSize, strides, n, input, weights, bias):\n", + " out = [[[0 for _ in range(nFilters)] for _ in range((nCols - kernelSize)//strides + 1)] for _ in range((nRows - kernelSize)//strides + 1)]\n", + " remainder = [[[None for _ in range(nFilters)] for _ in range((nCols - kernelSize)//strides + 1)] for _ in range((nRows - kernelSize)//strides + 1)]\n", + " for i in range((nRows - kernelSize)//strides + 1):\n", + " for j in range((nCols - kernelSize)//strides + 1):\n", + " for m in range(nFilters):\n", + " for k in range(nChannels):\n", + " for x in range(kernelSize):\n", + " for y in range(kernelSize):\n", + " out[i][j][m] += int(input[i*strides+x][j*strides+y][k])*int(weights[x][y][k][m])\n", + " out[i][j][m] += int(bias[m])\n", + " remainder[i][j][m] = str(out[i][j][m] % n)\n", + " out[i][j][m] = str(out[i][j][m] // n)\n", + " return out, remainder" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "def Conv2DsameInt(nRows, nCols, nChannels, nFilters, kernelSize, strides, n, input, weights, bias):\n", + " if nRows % strides == 0:\n", + " rowPadding = max(kernelSize - strides, 0)\n", + " else:\n", + " rowPadding = max(kernelSize - nRows % strides, 0)\n", + " if nCols % strides == 0:\n", + " colPadding = max(kernelSize - strides, 0)\n", + " else:\n", + " colPadding = max(kernelSize - nCols % strides, 0)\n", + " \n", + " _input = [[[0 for _ in range(nChannels)] for _ in range(nCols + colPadding)] for _ in range(nRows + rowPadding)]\n", + "\n", + " for i in range(nRows):\n", + " for j in range(nCols):\n", + " for k in range(nChannels):\n", + " _input[i+rowPadding//2][j+colPadding//2][k] = input[i][j][k]\n", + " \n", + " out, remainder = Conv2DInt(nRows + rowPadding, nCols + colPadding, nChannels, nFilters, kernelSize, strides, n, _input, weights, bias)\n", + " return out, remainder" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "([[['-738206413373029545006965393527124494',\n", + " '-173320559228639935825945330543413025'],\n", + " ['-1034698083357167138897799627355946893',\n", + " '-434089078628623353676508551211157060'],\n", + " ['-830734559502817036444813835112148949',\n", + " '-686578771866623630758607927275546159'],\n", + " ['-285424960673894669036410085848276078',\n", + " '77035899885765173133597720803628368'],\n", + " ['-41293201078409047199874417933080796',\n", + " '28777494655737863382696013669386951']],\n", + " [['-606713676308333020774846244444307246',\n", + " '388466467207809267549922375129065936'],\n", + " ['-1102737485952039787132334772316590877',\n", + " '711672842283067459049106389539428334'],\n", + " ['-1279246107747464951673577203827572661',\n", + " '-104057682619248993879676863654015882'],\n", + " ['-1137502411882990109129978808062308003',\n", + " '-285737283216816113761016693211231047'],\n", + " ['-703798773277911056745141956845225462',\n", + " '355407338114374779173330249858051758']],\n", + " [['-1546951103062535866431703941066730304',\n", + " '-509754977574520651977576289008629976'],\n", + " ['-2007190678249673958423804906616390874',\n", + " '-255124530399913653884258580354194679'],\n", + " ['-1987955686599915815112948301772925117',\n", + " '92785353760923841311068732708486389'],\n", + " ['-1501354303867896914042548115076362343',\n", + " '-140005761885967356612636419407991713'],\n", + " ['-758589315783580773280487949446211355',\n", + " '437520880361906870599644364227465106']],\n", + " [['-761417923444182653301008173638043967',\n", + " '91771720385437306243354293205033102'],\n", + " ['-1688165103880073226811051249397561256',\n", + " '-253090840416308272240635021911178871'],\n", + " ['-868411975706764838263900420928401111',\n", + " '427029536451544469306217379046517378'],\n", + " ['-1632207770832291707418877818442334628',\n", + " '313343146963309014974813849350304619'],\n", + " ['-1123607486465218134583708408995464790',\n", + " '277056825693143901704758542110909098']],\n", + " [['-543826042689234949496060048498817249',\n", + " '415061953419939109346123158429653654'],\n", + " ['-1006898403524974693098527575437710832',\n", + " '646952095929760474345606121011841547'],\n", + " ['-1339156364320875137533821288468963984',\n", + " '424307047679243230200796457433761030'],\n", + " ['-633848388019879384336947814123500320',\n", + " '638133921526842615943818078787826079'],\n", + " ['-911192101070039979166490878519403558',\n", + " '572640728975516234503728848776059996']]],\n", + " [[['732636700329184288172821299452706816',\n", + " '163633766656393366871975061931163648'],\n", + " ['106376739450670706774435272610283520',\n", + " '224263678919132314456428282264944640'],\n", + " ['302563808180207575103565765357338624',\n", + " '724302115124451087131827274648649728'],\n", + " ['414653571797565705485911652130881536',\n", + " '793412538503882240528593488489480192'],\n", + " ['802744020509610387937821108140507136',\n", + " '304480836101726291197711385865748480']],\n", + " [['859770729700736535878340696623546368',\n", + " '892305185639885971968166660350672896'],\n", + " ['159017773966938546227144022541991936',\n", + " '243286973980475632014665402361577472'],\n", + " ['45707991996386548824583584683655168',\n", + " '305057405020119480420035216424304640'],\n", + " ['513655874550133048911800632902418432',\n", + " '966709976332017469863264280209522688'],\n", + " ['198125756383956075595551934970331136',\n", + " '988945351929415067553212891893071872']],\n", + " [['787267116688540528237625132699353088',\n", + " '820149329025927193528840575260819456'],\n", + " ['472980868351632170335228269959315456',\n", + " '763189020330031063201267188022378496'],\n", + " ['644197050565810210754859371753111552',\n", + " '989731482915520051484045079570546688'],\n", + " ['948204822800878014118539586611183616',\n", + " '748247265754602884949081086979211264'],\n", + " ['716169397100995271731495802811449344',\n", + " '485403242675801265137341039288254464']],\n", + " [['43352138962814326789218918981435392',\n", + " '405166142631707980272441068569493504'],\n", + " ['810091229170266512829629243259355136',\n", + " '269653910662555626980942633563586560'],\n", + " ['676839398629154231820514533721505792',\n", + " '419805522694742694500594416384737280'],\n", + " ['745383827341964610018052023030644736',\n", + " '256049836598014041054468595098058752'],\n", + " ['675940914936064095590762595072606208',\n", + " '261569026192941931124614732445646848']],\n", + " [['945584600952117166014909825389953024',\n", + " '419123852798410471873927944123449344'],\n", + " ['610870231548665514822003410375540736',\n", + " '444338394948450192884055088832708608'],\n", + " ['392811864299264176482313886100357120',\n", + " '482398655357498430014877178936688640'],\n", + " ['865948833234167064622566980785799168',\n", + " '886966130974971444248685772183437312'],\n", + " ['195220905237827872879130999353507840',\n", + " '356250510761353797733544351900893184']]])" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out, remainder = Conv2DsameInt(5, 5, 3, 2, 3, 1, 10**36, X_in, weights, bias)\n", + "out, remainder" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "in_json = {\n", + " \"in\": X_in,\n", + " \"weights\": weights,\n", + " \"bias\": bias,\n", + " \"out\": out,\n", + " \"remainder\": remainder\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "import json" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "with open(\"conv2Dsame_input.json\", \"w\") as f:\n", + " json.dump(in_json, f)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "inputs = Input(shape=(10,10,3))\n", + "x = Conv2D(2, 4, 3, padding=\"same\")(inputs)\n", + "model = Model(inputs, x)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: \"model_1\"\n", + "_________________________________________________________________\n", + " Layer (type) Output Shape Param # \n", + "=================================================================\n", + " input_2 (InputLayer) [(None, 10, 10, 3)] 0 \n", + " \n", + " conv2d_1 (Conv2D) (None, 4, 4, 2) 98 \n", + " \n", + "=================================================================\n", + "Total params: 98\n", + "Trainable params: 98\n", + "Non-trainable params: 0\n", + "_________________________________________________________________\n" + ] + } + ], + "source": [ + "model.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[[[0.15917131, 0.77313235, 0.70732347],\n", + " [0.85210236, 0.77747055, 0.9126757 ],\n", + " [0.93479056, 0.819227 , 0.70156189],\n", + " [0.79143794, 0.87801591, 0.82904976],\n", + " [0.68619248, 0.52939185, 0.9781374 ],\n", + " [0.60719573, 0.89925314, 0.76945961],\n", + " [0.87552557, 0.90534932, 0.87889951],\n", + " [0.3906682 , 0.95580616, 0.31789736],\n", + " [0.76477277, 0.81754225, 0.69839668],\n", + " [0.83091809, 0.08929047, 0.41690024]],\n", + "\n", + " [[0.02156404, 0.1248342 , 0.10101914],\n", + " [0.89801399, 0.26753341, 0.07441736],\n", + " [0.02208311, 0.63957509, 0.89741475],\n", + " [0.34053784, 0.26694358, 0.75715605],\n", + " [0.53780638, 0.65565436, 0.91504456],\n", + " [0.25938895, 0.93164592, 0.89508771],\n", + " [0.59673249, 0.83579125, 0.43718874],\n", + " [0.15159799, 0.74559263, 0.75890839],\n", + " [0.14960964, 0.72981249, 0.3738258 ],\n", + " [0.77609438, 0.61152145, 0.26743727]],\n", + "\n", + " [[0.86386277, 0.08854458, 0.1881181 ],\n", + " [0.44605901, 0.5362032 , 0.35993523],\n", + " [0.30153601, 0.2052692 , 0.94686662],\n", + " [0.19258428, 0.29545743, 0.67123322],\n", + " [0.64059575, 0.25641094, 0.1195399 ],\n", + " [0.38441147, 0.12147883, 0.39661277],\n", + " [0.58068678, 0.99196337, 0.80942971],\n", + " [0.42854507, 0.58075105, 0.17898182],\n", + " [0.59835861, 0.53916735, 0.58269423],\n", + " [0.73442298, 0.50103808, 0.55618813]],\n", + "\n", + " [[0.70220035, 0.65299784, 0.06720906],\n", + " [0.98439586, 0.49217928, 0.29110757],\n", + " [0.40362701, 0.31214552, 0.33877257],\n", + " [0.63058855, 0.35760424, 0.46940153],\n", + " [0.13804137, 0.06058852, 0.55080509],\n", + " [0.76085309, 0.65603233, 0.76855789],\n", + " [0.82109332, 0.28625298, 0.5198722 ],\n", + " [0.52934891, 0.98676823, 0.45172351],\n", + " [0.26726209, 0.85657565, 0.95955236],\n", + " [0.19110728, 0.50206147, 0.05049411]],\n", + "\n", + " [[0.01880872, 0.25428751, 0.32772514],\n", + " [0.1303663 , 0.75804367, 0.53268895],\n", + " [0.37711957, 0.72274233, 0.95032399],\n", + " [0.07563836, 0.93359311, 0.15796125],\n", + " [0.62045146, 0.64580923, 0.1557714 ],\n", + " [0.78905608, 0.85432612, 0.87425904],\n", + " [0.80316626, 0.51924423, 0.45801054],\n", + " [0.64371903, 0.55673077, 0.85455273],\n", + " [0.42522451, 0.75796177, 0.73794332],\n", + " [0.9042132 , 0.94432526, 0.84180072]],\n", + "\n", + " [[0.77346326, 0.84811549, 0.97897457],\n", + " [0.9229139 , 0.8745595 , 0.81151275],\n", + " [0.31466683, 0.22968788, 0.32937946],\n", + " [0.32265255, 0.60866667, 0.74966551],\n", + " [0.89690681, 0.22652098, 0.38790477],\n", + " [0.73330962, 0.60714074, 0.53581188],\n", + " [0.492162 , 0.4240943 , 0.7783657 ],\n", + " [0.0186855 , 0.69752848, 0.52908628],\n", + " [0.84273096, 0.30658396, 0.70465779],\n", + " [0.8684895 , 0.06929279, 0.73560357]],\n", + "\n", + " [[0.41183525, 0.06309388, 0.03012547],\n", + " [0.04789855, 0.39528684, 0.14238964],\n", + " [0.37892587, 0.73618744, 0.70410196],\n", + " [0.54405706, 0.07672244, 0.10798353],\n", + " [0.10530593, 0.46345768, 0.95231357],\n", + " [0.14640745, 0.47792646, 0.6955056 ],\n", + " [0.39633502, 0.68247493, 0.7154005 ],\n", + " [0.89354067, 0.92889669, 0.18524983],\n", + " [0.85684736, 0.40546574, 0.08902131],\n", + " [0.45694739, 0.57720739, 0.10578621]],\n", + "\n", + " [[0.36785594, 0.17667226, 0.46407189],\n", + " [0.11640223, 0.37587536, 0.93314522],\n", + " [0.03840141, 0.7311801 , 0.23603065],\n", + " [0.54117032, 0.85714659, 0.0495661 ],\n", + " [0.08494766, 0.87966438, 0.37650164],\n", + " [0.95011446, 0.23333771, 0.26889821],\n", + " [0.86136317, 0.16072138, 0.03323276],\n", + " [0.31658215, 0.25675017, 0.59240392],\n", + " [0.07867128, 0.73161337, 0.96039619],\n", + " [0.95476936, 0.68330006, 0.99581875]],\n", + "\n", + " [[0.73541155, 0.30482595, 0.11167008],\n", + " [0.20315806, 0.45939058, 0.90883005],\n", + " [0.62037448, 0.07781508, 0.85993374],\n", + " [0.1662502 , 0.96140805, 0.4741485 ],\n", + " [0.10682937, 0.36184028, 0.52263044],\n", + " [0.55662028, 0.01836474, 0.90787543],\n", + " [0.79930707, 0.76048754, 0.12106902],\n", + " [0.45491329, 0.94061578, 0.48757815],\n", + " [0.12918205, 0.18806973, 0.58812925],\n", + " [0.80431 , 0.97299797, 0.05632223]],\n", + "\n", + " [[0.33911369, 0.9949053 , 0.61461519],\n", + " [0.78384667, 0.73316864, 0.85913351],\n", + " [0.38492166, 0.56317432, 0.59637715],\n", + " [0.3607973 , 0.05332751, 0.59008779],\n", + " [0.80225702, 0.88927018, 0.87653115],\n", + " [0.62173867, 0.62184283, 0.90795477],\n", + " [0.15894814, 0.54661812, 0.83119993],\n", + " [0.44861376, 0.28122679, 0.37193476],\n", + " [0.40747786, 0.24497888, 0.55174409],\n", + " [0.3188501 , 0.47713174, 0.23614857]]]])" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "X = np.random.rand(1,10,10,3)\n", + "X" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1/1 [==============================] - 0s 28ms/step\n" + ] + }, + { + "data": { + "text/plain": [ + "array([[[[-0.811051 , -1.4721096 ],\n", + " [-0.52853954, -0.47991082],\n", + " [-0.8595364 , -0.6705183 ],\n", + " [-0.07767564, -0.1320289 ]],\n", + "\n", + " [[-0.9737919 , -0.5298958 ],\n", + " [-1.0687585 , -0.33069813],\n", + " [-0.96858287, -0.61539143],\n", + " [-0.34952405, -0.3947128 ]],\n", + "\n", + " [[-1.1950083 , -0.39145365],\n", + " [-1.0355395 , 0.37544912],\n", + " [-0.46857417, -0.46149284],\n", + " [-0.1579878 , -0.14529735]],\n", + "\n", + " [[-0.7178327 , -0.6407934 ],\n", + " [-0.81823504, -0.48652953],\n", + " [-0.48897535, -0.5935499 ],\n", + " [-0.0888171 , -0.18892495]]]], dtype=float32)" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y = model.predict(X)\n", + "y" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "X_in = [[[int(X[0][i][j][k]*1e36) for k in range(3)] for j in range(10)] for i in range(10)]\n", + "weights = [[[[int(model.weights[0].numpy()[i][j][k][l]*1e36) for l in range(2)] for k in range(3)] for j in range(4)] for i in range(4)]\n", + "bias = [int(model.weights[1].numpy()[i]*1e72) for i in range(2)]" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "([[['-811051029307312151898080091259013408',\n", + " '-1472109697883872855085104222700990006'],\n", + " ['-528539552873451841296887387802887946',\n", + " '-479910753097361964472240108640655029'],\n", + " ['-859536355090285293603582494722226955',\n", + " '-670518221637687591501448018045711492'],\n", + " ['-77675673238639598916850751879067826',\n", + " '-132028915984577730731287477184728367']],\n", + " [['-973791923027617402137438972479106354',\n", + " '-529895757470201620182206632643141082'],\n", + " ['-1068758558577939368141318995900977305',\n", + " '-330698196577790148366400564712521445'],\n", + " ['-968582927650604267027127767242160672',\n", + " '-615391447297339813950840270560399175'],\n", + " ['-349524047948521414307381813385449474',\n", + " '-394712802987869376369174007174918150']],\n", + " [['-1195008336948108945640918772722080227',\n", + " '-391453675833271485890705066425298906'],\n", + " ['-1035539474474336745928350434338370235',\n", + " '375449144451453718974060701074107917'],\n", + " ['-468574096206261610672918401549766744',\n", + " '-461492881731804944422419395904964084'],\n", + " ['-157987822858060281896825874562638157',\n", + " '-145297327207434479738146166840718303']],\n", + " [['-717832718354692660291274590189907309',\n", + " '-640793387586819749749055661801718428'],\n", + " ['-818234976163909363095633653303124629',\n", + " '-486529556567947364827511906932334658'],\n", + " ['-488975320390965478682141686514332960',\n", + " '-593549900211690891912136829366038418'],\n", + " ['-88817120440234533520269026413841619',\n", + " '-188924963686509380494214483287084738']]],\n", + " [[['57546413369690632436021828010377216',\n", + " '844338365610010970536479881582084096'],\n", + " ['713605335486457688719381102919155712',\n", + " '375215677051519710413966252668092416'],\n", + " ['755914587195177990836477246220271616',\n", + " '818027579922420637827386176052396032'],\n", + " ['66541310804045710582849369230278656',\n", + " '347409709477862532129959056406216704']],\n", + " [['289886146631823524117043901901045760',\n", + " '747073207457187946361744580709187584'],\n", + " ['596964332285599153985153944246026240',\n", + " '55475850457661407668048488535425024'],\n", + " ['170834380408647470600498462609899520',\n", + " '509441462772270674298218142377181184'],\n", + " ['34071275330636135911467390225874944',\n", + " '138225886497714127955769149616029696']],\n", + " [['799025578005122329965199425850572800',\n", + " '638105259408657768502010079451545600'],\n", + " ['646664651333832093639422680485593088',\n", + " '675100193761157594649844145046159360'],\n", + " ['445458083786331341876221205003894784',\n", + " '502163711096538860296377604167958528'],\n", + " ['47026304856349756806446693973229568',\n", + " '831593504386261794099030491582693376']],\n", + " [['939137107020160348118700928203227136',\n", + " '113237806886600667494639994056736768'],\n", + " ['659213932223174550761216604282814464',\n", + " '197300411875572149344585017353830400'],\n", + " ['421437083594697043343248459305058304',\n", + " '120269150227163283149044634810843136'],\n", + " ['198117022230724190125402300398698496',\n", + " '310521369112965179461249563143176192']]])" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out, remainder = Conv2DsameInt(10, 10, 3, 2, 4, 3, 10**36, X_in, weights, bias)\n", + "out, remainder" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "in_json = {\n", + " \"in\": X_in,\n", + " \"weights\": weights,\n", + " \"bias\": bias,\n", + " \"out\": out,\n", + " \"remainder\": remainder\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "with open(\"conv2Dsame_stride_input.json\", \"w\") as f:\n", + " json.dump(in_json, f)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sklearn", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/test/Conv2Dsame.js b/test/Conv2Dsame.js new file mode 100644 index 0000000..3b2867d --- /dev/null +++ b/test/Conv2Dsame.js @@ -0,0 +1,37 @@ +const chai = require("chai"); +const path = require("path"); + +const wasm_tester = require("circom_tester").wasm; + +const F1Field = require("ffjavascript").F1Field; +const Scalar = require("ffjavascript").Scalar; +exports.p = Scalar.fromString("21888242871839275222246405745257275088548364400416034343698204186575808495617"); +const Fr = new F1Field(exports.p); + +const assert = chai.assert; + + + +describe("Conv2Dsame layer test", function () { + this.timeout(100000000); + + it("(5,5,3) -> (5,5,2)", async () => { + const INPUT = require("../models/Conv2Dsame_input.json"); + + const circuit = await wasm_tester(path.join(__dirname, "circuits", "Conv2Dsame_test.circom")); + + const witness = await circuit.calculateWitness(INPUT, true); + + assert(Fr.eq(Fr.e(witness[0]),Fr.e(1))); + }); + + it("(10,10,3) -> (4,4,2)", async () => { + const INPUT = require("../models/Conv2Dsame_stride_input.json"); + + const circuit = await wasm_tester(path.join(__dirname, "circuits", "Conv2Dsame_stride_test.circom")); + + const witness = await circuit.calculateWitness(INPUT, true); + + assert(Fr.eq(Fr.e(witness[0]),Fr.e(1))); + }); +}); \ No newline at end of file diff --git a/test/circuits/Conv2Dsame_stride_test.circom b/test/circuits/Conv2Dsame_stride_test.circom new file mode 100644 index 0000000..86a1f43 --- /dev/null +++ b/test/circuits/Conv2Dsame_stride_test.circom @@ -0,0 +1,5 @@ +pragma circom 2.0.0; + +include "../../circuits/Conv2Dsame.circom"; + +component main = Conv2Dsame(10, 10, 3, 2, 4, 3, 10**36); \ No newline at end of file diff --git a/test/circuits/Conv2Dsame_test.circom b/test/circuits/Conv2Dsame_test.circom new file mode 100644 index 0000000..76725f1 --- /dev/null +++ b/test/circuits/Conv2Dsame_test.circom @@ -0,0 +1,5 @@ +pragma circom 2.0.0; + +include "../../circuits/Conv2Dsame.circom"; + +component main = Conv2Dsame(5, 5, 3, 2, 3, 1, 10**36); \ No newline at end of file