From 526da93d40efffc22b4dbe74c7fea36ad5c4a824 Mon Sep 17 00:00:00 2001 From: "drCathieSo.eth" Date: Mon, 5 Feb 2024 01:33:50 +0800 Subject: [PATCH] Add MaxPooling2Dsame and MaxPooling2Dsame_stride test circuits and their dependencies --- circuits/MaxPooling2Dsame.circom | 73 +++ models/maxPooling2Dsame_input.json | 1 + models/maxPooling2Dsame_stride_input.json | 1 + models/maxPooling2d_same.ipynb | 619 ++++++++++++++++++ test/MaxPooling2Dsame.js | 37 ++ .../MaxPooling2Dsame_stride_test.circom | 5 + test/circuits/MaxPooling2Dsame_test.circom | 6 + 7 files changed, 742 insertions(+) create mode 100644 circuits/MaxPooling2Dsame.circom create mode 100644 models/maxPooling2Dsame_input.json create mode 100644 models/maxPooling2Dsame_stride_input.json create mode 100644 models/maxPooling2d_same.ipynb create mode 100644 test/MaxPooling2Dsame.js create mode 100644 test/circuits/MaxPooling2Dsame_stride_test.circom create mode 100644 test/circuits/MaxPooling2Dsame_test.circom diff --git a/circuits/MaxPooling2Dsame.circom b/circuits/MaxPooling2Dsame.circom new file mode 100644 index 0000000..b4f89b0 --- /dev/null +++ b/circuits/MaxPooling2Dsame.circom @@ -0,0 +1,73 @@ +pragma circom 2.0.0; + +include "./MaxPooling2D.circom"; + +template MaxPooling2Dsame (nRows, nCols, nChannels, poolSize, strides) { + signal input in[nRows][nCols][nChannels]; + + var rowPadding, colPadding; + + if (nRows % strides == 0) { + rowPadding = (poolSize - strides) > 0 ? (poolSize - strides) : 0; + } else { + rowPadding = (poolSize - (nRows % strides)) > 0 ? (poolSize - (nRows % strides)) : 0; + } + + if (nCols % strides == 0) { + colPadding = (poolSize - strides) > 0 ? (poolSize - strides) : 0; + } else { + colPadding = (poolSize - (nCols % strides)) > 0 ? (poolSize - (nCols % strides)) : 0; + } + + signal input out[(nRows+rowPadding-poolSize)\strides+1][(nCols+colPadding-poolSize)\strides+1][nChannels]; + + component max2d = MaxPooling2D(nRows+rowPadding, nCols+colPadding, nChannels, poolSize, strides); + + for (var i = rowPadding\2; i < rowPadding\2+nRows; i++) { + for (var j = colPadding\2; j < colPadding\2+nCols; j++) { + for (var k = 0; k < nChannels; k++) { + max2d.in[i][j][k] <== in[i-rowPadding\2][j-colPadding\2][k]; + } + } + } + + for (var i = 0; i< rowPadding\2; i++) { + for (var j = 0; j < nCols+colPadding; j++) { + for (var k = 0; k < nChannels; k++) { + max2d.in[i][j][k] <== 0; + } + } + } + + for (var i = nRows+rowPadding\2; i< nRows+rowPadding; i++) { + for (var j = 0; j < nCols+colPadding; j++) { + for (var k = 0; k < nChannels; k++) { + max2d.in[i][j][k] <== 0; + } + } + } + + for (var i = rowPadding\2; i < nRows+rowPadding\2; i++) { + for (var j = 0; j < colPadding\2; j++) { + for (var k = 0; k < nChannels; k++) { + max2d.in[i][j][k] <== 0; + } + } + } + + for (var i = rowPadding\2; i < nRows+rowPadding\2; i++) { + for (var j = nCols+colPadding\2; j < nCols+colPadding; j++) { + for (var k = 0; k < nChannels; k++) { + max2d.in[i][j][k] <== 0; + } + } + } + + for (var i = 0; i < (nRows+rowPadding-poolSize)\strides+1; i++) { + for (var j = 0; j < (nCols+colPadding-poolSize)\strides+1; j++) { + for (var k = 0; k < nChannels; k++) { + max2d.out[i][j][k] <== out[i][j][k]; + } + } + } +} \ No newline at end of file diff --git a/models/maxPooling2Dsame_input.json b/models/maxPooling2Dsame_input.json new file mode 100644 index 0000000..e8c41e3 --- /dev/null +++ b/models/maxPooling2Dsame_input.json @@ -0,0 +1 @@ +{"in": [[[985416895178787714183661703755988992, 805695864472562858264688082666127360, 632684451822347108628256674661007360], [962060468041561843488482131179470848, 8869073729764975861586898758664192, 677759670322146260067784271585083392], [467925439750069089712591214043201536, 813606982115283460844798537969434624, 347772122728895153321033327318138880], [908354388868282470618698257109352448, 691746563537225063033040168500068352, 981726193592448973649175667344932864], [107153282843697739853753056552288256, 916572289086665811955284394381410304, 816245550879711653823401560087986176]], [[570382642241468206642626555033944064, 902803940435826284139749932730941440, 777663362123877860598598645145665536], [634090669863853333267758182352027648, 825958123785832877484312556952616960, 381513410888697931311864158837800960], [795520454944760995904638766836350976, 216359595257833349149815253067890688, 874616717227600117488934168588976128], [746811470919424536534678233059164160, 475841014459437513096347703977181184, 930028665971741586694297106155307008], [114000116831430861403008392105033728, 974091109200621567612816335791194112, 874043654050301690820171495435665408]], [[539579607228014027803293647267430400, 508632843040824474037175176904310784, 890125946940297760199841005954924544], [278016685349965289070216885173223424, 552815374291668108366329210074038272, 500049722551334607063608137326002176], [313744264694903791749843753654812672, 418498494396639822766365224919367680, 817633569064628434285947918592507904], [680992631793690390926240128859897856, 217157254616318320755856276537212928, 602450814158750475602742768018915328], [230448677275229241444208872890826752, 894021844181003495121759207996522496, 738029169437394307130485433159385088]], [[656815180323522869339469446928924672, 454311340827568891467175719507853312, 111869508785107817930652937351069696], [632327323641954593527382646479388672, 99140057192692252460452964236001280, 469778003485259478255996251364392960], [539803270562814329659938110829494272, 9999223425718462488177430625255424, 323853684321289258856386212080386048], [511898330068137781770564935289405440, 235408855051092995844022897467195392, 488270711104327731196314985370222592], [973477114549060379932212627888930816, 163297737054907472179883527553155072, 734044941826309056667877810110464000]], [[20108031993913978534061672446820352, 732003881291962206032259471665790976, 752966593417693633098737987950215168], [470805829710577633964522992960536576, 899977578211573860664606514781093888, 777543729617406477245147183797239808], [450476427222540130797322504023048192, 608466014932603115839705401413074944, 220254487946003773078748363501862912], [813185240376988213714402965620523008, 864052363876363850452980196205133824, 441362157036501413461183215236546560], [537904308558723633624141672348647424, 459317359421049476909261905824055296, 9517425402804758462154346481057792]]], "out": [[["985416895178787714183661703755988992", "902803940435826284139749932730941440", "777663362123877860598598645145665536"], ["908354388868282470618698257109352448", "813606982115283460844798537969434624", "981726193592448973649175667344932864"], ["114000116831430861403008392105033728", "974091109200621567612816335791194112", "874043654050301690820171495435665408"]], [["656815180323522869339469446928924672", "552815374291668108366329210074038272", "890125946940297760199841005954924544"], ["680992631793690390926240128859897856", "418498494396639822766365224919367680", "817633569064628434285947918592507904"], ["973477114549060379932212627888930816", "894021844181003495121759207996522496", "738029169437394307130485433159385088"]], [["470805829710577633964522992960536576", "899977578211573860664606514781093888", "777543729617406477245147183797239808"], ["813185240376988213714402965620523008", "864052363876363850452980196205133824", "441362157036501413461183215236546560"], ["537904308558723633624141672348647424", "459317359421049476909261905824055296", "9517425402804758462154346481057792"]]]} \ No newline at end of file diff --git a/models/maxPooling2Dsame_stride_input.json b/models/maxPooling2Dsame_stride_input.json new file mode 100644 index 0000000..5c27ede --- /dev/null +++ b/models/maxPooling2Dsame_stride_input.json @@ -0,0 +1 @@ +{"in": [[[181060845852928909750665111192731648, 873335066240771590121766412651331584, 466591630903943415384862952693694464], [419512825789891392097314144843202560, 504687534234384880332925591705616384, 555108518031512182896083305607725056], [889727395338766582460901887049728000, 304715420607600730931356119511072768, 396868332630338490772573363885834240], [319583098696415993297833594590330880, 759947197661148233407130979270656, 895376897621437771937017335681908736], [186082694906830482286358809003687936, 589361183135800444011368821307736064, 229939006949938846437336028870082560], [377869456505531058270724751758458880, 593831349332373193997349940537327616, 617898741196265876307052239085633536], [208890055386047568155961513734045696, 45299963238380707641112856813895680, 48101503056849241382516478038769664], [968292067860118608616536728156504064, 424571806427637026929723097948880896, 41812473134057454089473646242299904], [33070335915057482728529046095790080, 613358806713430029637599279974449152, 962525235441581803388090406416678912], [880909150990983977881328708709515264, 31637788250919565139913459250495488, 407229563816807554881027269822775296]], [[722047037048510692294577429613117440, 287769872879727352864158742816489472, 633057959701190742109452742097371136], [587593114714360540327277724472180736, 995928671466629356695108803213918208, 94764680527825162313400310202105856], [673764008202154483917802187098423296, 441940517104381671888290741036777472, 205872675264919929820897759060819968], [274105248422619104019009535125487616, 341288944291860194012394329678544896, 703928155170246531697742298583400448], [253407875536760714785299235260596224, 954086618948581700566160227891675136, 695949832100202646216991075580510208], [772974380013021179531003942196477952, 694646688436906795548541068660178944, 239040667194620244838351059577470976], [761119703539180780317870240200392704, 841368944375512546814487235963912192, 912971104374087013640201200717529088], [319331760557976648592893712646799360, 236474092649413562963071165437837312, 311432210529661105396777421843202048], [618384476135186118143094127526412288, 223794997365045507832264506646462464, 733884991260016091068798783197282304], [100735392484192717783490475967119360, 742390579543848987849194909946871808, 743563277676321729157199064014520320]], [[182788443064937929682227014886490112, 969266407248505292190446922703568896, 67989640886783058282687539158450176], [547375743318019556181977609023782912, 700178164250420480306160355132833792, 16201034671487591644009635682189312], [270368171833540445776128922485784576, 953685968096406095226402728226848768, 462933776666994685541321886838816768], [480574858761667759505931818774822912, 155085614083040341586693710994735104, 238047469071987930182174431502139392], [486919139890581107026252744227815424, 19403335495558283277790977315569664, 180402947105016802412845905590878208], [396449293682511410849689631627149312, 495605584981020026007527055383592960, 226605072288882402861639187000459264], [17414102014369881983380961854750720, 3464874754555258300462930059067392, 898099349554742179221499835408449536], [785190855377658732611395349324496896, 305525060677295699367753044300136448, 204565390630897754560329821622108160], [823700963638153724209164115255492608, 297490709034777825263012262822543360, 65048292691036515903499949362380800], [126448774345507619385628942709817344, 810615034659410594133868903820951552, 509045959217838506309653855888998400]], [[713680159433190877367141066893950976, 845319479968869758289175934824087552, 911203285361282051254679884143788032], [827233800468269198331599375468331008, 412657299272379848615281280847708160, 697025850995712023646235328390365184], [281018305301285604451792003667066880, 648955743791868283505378771075072000, 891153605987557985551895812668129280], [629371477360976322692171102263705600, 775929100688812511565570103251566592, 779093686310633217204969814058074112], [123716197030944721057925228484624384, 310122093218366277778320631971971072, 837055315861977089490188905893330944], [168089474631037074473457381103632384, 954284277871745890387007263991660544, 318407199651891688655564050295422976], [362606007259281948457376818464292864, 778830848866666695382201257353543680, 399089019882525038220773462749741056], [85165376402537083399842660127604736, 534511940982250534195745403183300608, 604920400036749440823474048758448128], [271115350002795655205678953653075968, 436692189086875806315854535015268352, 696241628373556811398888259653206016], [865780466166194901147341462043623424, 322369175499225190546657061256560640, 743159547886228828110355377787240448]], [[92869070612180413927113198902706176, 394089567662198418381256095302680576, 561323421039284129171801027733618688], [590368082822858464161998638141145088, 974901899794306633179826070214410240, 947030261209948995604238219295588352], [630178576969414950281839283942719488, 513484750661627854365309594169769984, 118185923586877402106242064334192640], [438458544379405017992663956092616704, 563764553948950223392054624974274560, 766318217816117183898784970773102592], [718648242610498931670595372570378240, 170367042833493904212175913394110464, 770642982224225810805156492481658880], [222980408999043826610383226417446912, 374329609305188548268333017665110016, 480409175290119097140845377342144512], [945617766167062485510007332488085504, 881617846977144246296983483176714240, 119580345973383635510533722183041024], [915676460740849039909566342278152192, 970518691217475177799962810208223232, 243214091510184799886852557700595712], [241451378898586162348624120633098240, 488548632484219477011556801427013632, 370190636910089278397155469728677888], [459905504259236845042685491072204800, 380889549714929988856634904844173312, 385698952079474499128182038725656576]], [[272104416239653114624676168051195904, 798516947143056490327664879309160448, 272194052700581400886005094093422592], [948709870845387857346891868297756672, 266222379570141512580073653692006400, 942522429410427764176957286775783424], [733569990394794557752022913886715904, 347061476788642214734411677447487488, 533050355551517276272795668950548480], [611738758986895481018871938688221184, 86818661078505453401540813448544256, 375463634889197639840864653285523456], [214948078088181504774561244110651392, 399863393954484338773146229310226432, 190922351817263129077726923072733184], [510853284550473880828449351467532288, 850874634827709561170916730803847168, 381366846819539174365048847957229568], [839025984606462607442639918861385728, 698107125696239884550127377469931520, 844114496759040350837367869456515072], [699026648050104609901396739978428416, 691391860246058782304189411187752960, 595508900213276558954866653513383936], [375264961930589894681617213050322944, 21962591226620033534836703930351616, 965963232953063155561649061582864384], [873625069637210314527081271358652416, 17013011041365544509902786048032768, 428592791258254350241688002383314944]], [[530376524078391318576403350558867456, 519931918588238382668918224515497984, 879621802069332614144323070292131840], [593691399141314866582653531068563456, 943767390679553173919884888860786688, 385923386832176245175244870238863360], [26691066286083421505744468443136000, 936911794042926571187759755122704384, 102611786967474533380535374030307328], [359377041658319176282193392801153024, 267568221890592820056282999493230592, 535446422069610157460354392595103744], [551850226842807715717985301323317248, 635787754066812906470475510293987328, 198619411462123242320149559593926656], [849409850466932500582277765549522944, 672271595031144910633066386683854848, 239559394952509218495103850915037184], [17380971340718589103439522832580608, 690213328307762114218988812757893120, 468996473691080829608521180536373248], [298860826050814316023339962498809856, 97088769275785689925992414856609792, 759385182097158099930713639207567360], [257726817498370781111026414126104576, 957823990680256749684893824671809536, 718212290152370685872760641560772608], [701041664244132545037467114551640064, 778681301680931345113753214299144192, 283077119909484707357592139667079168]], [[726341250795347218817727679288049664, 98874600951474152578382574366228480, 811023617456629073825580183806017536], [517637503653494301215695653419614208, 122369588602427760258954234093895680, 626059217937769666947295689371025408], [889381041995761174871442113834254336, 469513890142709771759824808307589120, 441358856495420263909058468026253312], [131543858038160611299740240540860416, 25292383934178652847540505382223872, 459211802054529507024427175227949056], [297316028903018790176428814814412800, 374157506904283522668313022673453056, 146629093308910273096216904894251008], [542787120506082039348801047785111552, 483436833296610287432604322277359616, 648266144597272391722443174580322304], [587451855996849487170757966572814336, 662348937528706384079894366331600896, 909155419360394098803395006243536896], [419004870835690151466097954051325952, 18294586373576504601667852652511232, 663249101866330461651607742836637696], [232421672633227234834518971351302144, 630460530562004940297661931602313216, 69027371800090670656925201711759360], [100823603422431138134448128680525824, 638197200440921556662210295354097664, 154316173921419955725826895137210368]], [[395104674687803861494200200752791552, 954576105444982145091012933221089280, 718833792890275344548356355463839744], [655325672679351614686912948815790080, 209750597404820726483701800665874432, 508179296315473806639865544342241280], [835434923636889965803916675336110080, 311732328825750749493636615561543680, 253760178687312045038760979821756416], [601300571308724205782535401329655808, 174387890165552136605957440749436928, 128901151712453753777331793525997568], [360114136717724883763448597471821824, 603481824227854008112493084815654912, 642517615869771978807037827880058880], [474822394876300784538105805262553088, 40695369651322924660872558762000384, 806656676106104505229466148557291520], [235227782766514129862041539430055936, 186636675305325127674486323318620160, 292355799625837572085090278006849536], [757334531489503054541832608603439104, 302550198477650879608517885964582912, 878392400606874079114934534610616320], [989429374995405609003435779036282880, 917356224684922440232053906134269952, 994972891967103944551729525539471360], [338833763654035757967963503529558016, 239923972440443854373206883949346816, 24975367589743947952737807190982656]], [[53450472906456212489602249089089536, 260437559937745199304159808031031296, 475257720371555822213696431964815360], [510602045587970767184337129422454784, 724407671943644015539146665914007552, 127708924969769374954592752203988992], [889866378227432814713479757815611392, 768391514134683685944424328585019392, 810330839548143299962319792341254144], [904964699224751401721193296243458048, 907169014433129231084536890038157312, 330947299017325800841780638445993984], [325628890558827076203977452818530304, 362460672495233090485121388864602112, 609180260647294075284256963443032064], [397278799966624768096335987868696576, 287846814675991162983839467936153600, 951001737371019064659814308537958400], [394260450257245198773068562662686720, 300806280328724832256760381278519296, 851360451534335061119480218741899264], [362097670448151044151987973631508480, 224539340132971947925902626484387840, 761618704127653732576191351032381440], [250380406360501740927510858191339520, 756360962939184850464290598295699456, 116230048643782403360230233176276992], [901240322382283207026350889182953472, 373212986501062202469532290256994304, 736122528767077629789884872958935040]]], "out": [[["722047037048510692294577429613117440", "995928671466629356695108803213918208", "633057959701190742109452742097371136"], ["319583098696415993297833594590330880", "954086618948581700566160227891675136", "895376897621437771937017335681908736"], ["968292067860118608616536728156504064", "841368944375512546814487235963912192", "912971104374087013640201200717529088"], ["880909150990983977881328708709515264", "742390579543848987849194909946871808", "743563277676321729157199064014520320"]], [["827233800468269198331599375468331008", "974901899794306633179826070214410240", "947030261209948995604238219295588352"], ["718648242610498931670595372570378240", "775929100688812511565570103251566592", "837055315861977089490188905893330944"], ["945617766167062485510007332488085504", "970518691217475177799962810208223232", "604920400036749440823474048758448128"], ["865780466166194901147341462043623424", "380889549714929988856634904844173312", "743159547886228828110355377787240448"]], [["726341250795347218817727679288049664", "943767390679553173919884888860786688", "879621802069332614144323070292131840"], ["551850226842807715717985301323317248", "635787754066812906470475510293987328", "535446422069610157460354392595103744"], ["587451855996849487170757966572814336", "690213328307762114218988812757893120", "909155419360394098803395006243536896"], ["701041664244132545037467114551640064", "778681301680931345113753214299144192", "283077119909484707357592139667079168"]], [["510602045587970767184337129422454784", "724407671943644015539146665914007552", "475257720371555822213696431964815360"], ["904964699224751401721193296243458048", "907169014433129231084536890038157312", "609180260647294075284256963443032064"], ["394260450257245198773068562662686720", "300806280328724832256760381278519296", "851360451534335061119480218741899264"], ["901240322382283207026350889182953472", "373212986501062202469532290256994304", "736122528767077629789884872958935040"]]]} \ No newline at end of file diff --git a/models/maxPooling2d_same.ipynb b/models/maxPooling2d_same.ipynb new file mode 100644 index 0000000..3011eac --- /dev/null +++ b/models/maxPooling2d_same.ipynb @@ -0,0 +1,619 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from tensorflow.keras.layers import Input, MaxPooling2D\n", + "from tensorflow.keras import Model\n", + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "inputs = Input(shape=(5,5,3))\n", + "x = MaxPooling2D(pool_size=2, padding='same')(inputs)\n", + "model = Model(inputs, x)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: \"model\"\n", + "_________________________________________________________________\n", + " Layer (type) Output Shape Param # \n", + "=================================================================\n", + " input_1 (InputLayer) [(None, 5, 5, 3)] 0 \n", + " \n", + " max_pooling2d (MaxPooling2D (None, 3, 3, 3) 0 \n", + " ) \n", + " \n", + "=================================================================\n", + "Total params: 0\n", + "Trainable params: 0\n", + "Non-trainable params: 0\n", + "_________________________________________________________________\n" + ] + } + ], + "source": [ + "model.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[[[0.9854169 , 0.80569586, 0.63268445],\n", + " [0.96206047, 0.00886907, 0.67775967],\n", + " [0.46792544, 0.81360698, 0.34777212],\n", + " [0.90835439, 0.69174656, 0.98172619],\n", + " [0.10715328, 0.91657229, 0.81624555]],\n", + "\n", + " [[0.57038264, 0.90280394, 0.77766336],\n", + " [0.63409067, 0.82595812, 0.38151341],\n", + " [0.79552045, 0.2163596 , 0.87461672],\n", + " [0.74681147, 0.47584101, 0.93002867],\n", + " [0.11400012, 0.97409111, 0.87404365]],\n", + "\n", + " [[0.53957961, 0.50863284, 0.89012595],\n", + " [0.27801669, 0.55281537, 0.50004972],\n", + " [0.31374426, 0.41849849, 0.81763357],\n", + " [0.68099263, 0.21715725, 0.60245081],\n", + " [0.23044868, 0.89402184, 0.73802917]],\n", + "\n", + " [[0.65681518, 0.45431134, 0.11186951],\n", + " [0.63232732, 0.09914006, 0.469778 ],\n", + " [0.53980327, 0.00999922, 0.32385368],\n", + " [0.51189833, 0.23540886, 0.48827071],\n", + " [0.97347711, 0.16329774, 0.73404494]],\n", + "\n", + " [[0.02010803, 0.73200388, 0.75296659],\n", + " [0.47080583, 0.89997758, 0.77754373],\n", + " [0.45047643, 0.60846601, 0.22025449],\n", + " [0.81318524, 0.86405236, 0.44136216],\n", + " [0.53790431, 0.45931736, 0.00951743]]]])" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "X = np.random.rand(1,5,5,3)\n", + "X" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1/1 [==============================] - 0s 42ms/step\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-02-05 01:27:05.945781: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz\n" + ] + }, + { + "data": { + "text/plain": [ + "array([[[[0.9854169 , 0.90280396, 0.77766335],\n", + " [0.9083544 , 0.813607 , 0.98172617],\n", + " [0.11400012, 0.9740911 , 0.87404364]],\n", + "\n", + " [[0.6568152 , 0.5528154 , 0.89012593],\n", + " [0.6809926 , 0.4184985 , 0.81763357],\n", + " [0.9734771 , 0.89402187, 0.7380292 ]],\n", + "\n", + " [[0.47080582, 0.89997756, 0.7775437 ],\n", + " [0.8131852 , 0.86405236, 0.44136214],\n", + " [0.5379043 , 0.45931736, 0.00951743]]]], dtype=float32)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y = model.predict(X)\n", + "y" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "X_in = [[[int(X[0][i][j][k]*1e36) for k in range(3)] for j in range(5)] for i in range(5)]" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "def MaxPooling2DInt(nRows, nCols, nChannels, poolSize, strides, input):\n", + " out = [[[str(max(int(input[i*strides + x][j*strides + y][k]) for x in range(poolSize) for y in range(poolSize))) for k in range(nChannels)] for j in range((nCols - poolSize) // strides + 1)] for i in range((nRows - poolSize) // strides + 1)]\n", + " return out" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "def MaxPooling2DsameInt(nRows, nCols, nChannels, poolSize, strides, input):\n", + " if nRows % strides == 0:\n", + " rowPadding = max(poolSize - strides, 0)\n", + " else:\n", + " rowPadding = max(poolSize - nRows % strides, 0)\n", + " if nCols % strides == 0:\n", + " colPadding = max(poolSize - strides, 0)\n", + " else:\n", + " colPadding = max(poolSize - nCols % strides, 0)\n", + " \n", + " _input = [[[0 for _ in range(nChannels)] for _ in range(nCols + colPadding)] for _ in range(nRows + rowPadding)]\n", + "\n", + " for i in range(nRows):\n", + " for j in range(nCols):\n", + " for k in range(nChannels):\n", + " _input[i+rowPadding//2][j+colPadding//2][k] = input[i][j][k]\n", + " \n", + " out = MaxPooling2DInt(nRows + rowPadding, nCols + colPadding, nChannels, poolSize, strides, _input)\n", + " return out" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[[['985416895178787714183661703755988992',\n", + " '902803940435826284139749932730941440',\n", + " '777663362123877860598598645145665536'],\n", + " ['908354388868282470618698257109352448',\n", + " '813606982115283460844798537969434624',\n", + " '981726193592448973649175667344932864'],\n", + " ['114000116831430861403008392105033728',\n", + " '974091109200621567612816335791194112',\n", + " '874043654050301690820171495435665408']],\n", + " [['656815180323522869339469446928924672',\n", + " '552815374291668108366329210074038272',\n", + " '890125946940297760199841005954924544'],\n", + " ['680992631793690390926240128859897856',\n", + " '418498494396639822766365224919367680',\n", + " '817633569064628434285947918592507904'],\n", + " ['973477114549060379932212627888930816',\n", + " '894021844181003495121759207996522496',\n", + " '738029169437394307130485433159385088']],\n", + " [['470805829710577633964522992960536576',\n", + " '899977578211573860664606514781093888',\n", + " '777543729617406477245147183797239808'],\n", + " ['813185240376988213714402965620523008',\n", + " '864052363876363850452980196205133824',\n", + " '441362157036501413461183215236546560'],\n", + " ['537904308558723633624141672348647424',\n", + " '459317359421049476909261905824055296',\n", + " '9517425402804758462154346481057792']]]" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out = MaxPooling2DsameInt(5, 5, 3, 2, 2, X_in)\n", + "out" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "in_json = {\n", + " \"in\": X_in,\n", + " \"out\": out\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "import json" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "with open(\"maxPooling2Dsame_input.json\", \"w\") as f:\n", + " json.dump(in_json, f)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "inputs = Input(shape=(10,10,3))\n", + "x = MaxPooling2D(pool_size=2, strides=3, padding='same')(inputs)\n", + "model = Model(inputs, x)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: \"model_1\"\n", + "_________________________________________________________________\n", + " Layer (type) Output Shape Param # \n", + "=================================================================\n", + " input_2 (InputLayer) [(None, 10, 10, 3)] 0 \n", + " \n", + " max_pooling2d_1 (MaxPooling (None, 4, 4, 3) 0 \n", + " 2D) \n", + " \n", + "=================================================================\n", + "Total params: 0\n", + "Trainable params: 0\n", + "Non-trainable params: 0\n", + "_________________________________________________________________\n" + ] + } + ], + "source": [ + "model.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[[[1.81060846e-01, 8.73335066e-01, 4.66591631e-01],\n", + " [4.19512826e-01, 5.04687534e-01, 5.55108518e-01],\n", + " [8.89727395e-01, 3.04715421e-01, 3.96868333e-01],\n", + " [3.19583099e-01, 7.59947198e-04, 8.95376898e-01],\n", + " [1.86082695e-01, 5.89361183e-01, 2.29939007e-01],\n", + " [3.77869457e-01, 5.93831349e-01, 6.17898741e-01],\n", + " [2.08890055e-01, 4.52999632e-02, 4.81015031e-02],\n", + " [9.68292068e-01, 4.24571806e-01, 4.18124731e-02],\n", + " [3.30703359e-02, 6.13358807e-01, 9.62525235e-01],\n", + " [8.80909151e-01, 3.16377883e-02, 4.07229564e-01]],\n", + "\n", + " [[7.22047037e-01, 2.87769873e-01, 6.33057960e-01],\n", + " [5.87593115e-01, 9.95928671e-01, 9.47646805e-02],\n", + " [6.73764008e-01, 4.41940517e-01, 2.05872675e-01],\n", + " [2.74105248e-01, 3.41288944e-01, 7.03928155e-01],\n", + " [2.53407876e-01, 9.54086619e-01, 6.95949832e-01],\n", + " [7.72974380e-01, 6.94646688e-01, 2.39040667e-01],\n", + " [7.61119704e-01, 8.41368944e-01, 9.12971104e-01],\n", + " [3.19331761e-01, 2.36474093e-01, 3.11432211e-01],\n", + " [6.18384476e-01, 2.23794997e-01, 7.33884991e-01],\n", + " [1.00735392e-01, 7.42390580e-01, 7.43563278e-01]],\n", + "\n", + " [[1.82788443e-01, 9.69266407e-01, 6.79896409e-02],\n", + " [5.47375743e-01, 7.00178164e-01, 1.62010347e-02],\n", + " [2.70368172e-01, 9.53685968e-01, 4.62933777e-01],\n", + " [4.80574859e-01, 1.55085614e-01, 2.38047469e-01],\n", + " [4.86919140e-01, 1.94033355e-02, 1.80402947e-01],\n", + " [3.96449294e-01, 4.95605585e-01, 2.26605072e-01],\n", + " [1.74141020e-02, 3.46487475e-03, 8.98099350e-01],\n", + " [7.85190855e-01, 3.05525061e-01, 2.04565391e-01],\n", + " [8.23700964e-01, 2.97490709e-01, 6.50482927e-02],\n", + " [1.26448774e-01, 8.10615035e-01, 5.09045959e-01]],\n", + "\n", + " [[7.13680159e-01, 8.45319480e-01, 9.11203285e-01],\n", + " [8.27233800e-01, 4.12657299e-01, 6.97025851e-01],\n", + " [2.81018305e-01, 6.48955744e-01, 8.91153606e-01],\n", + " [6.29371477e-01, 7.75929101e-01, 7.79093686e-01],\n", + " [1.23716197e-01, 3.10122093e-01, 8.37055316e-01],\n", + " [1.68089475e-01, 9.54284278e-01, 3.18407200e-01],\n", + " [3.62606007e-01, 7.78830849e-01, 3.99089020e-01],\n", + " [8.51653764e-02, 5.34511941e-01, 6.04920400e-01],\n", + " [2.71115350e-01, 4.36692189e-01, 6.96241628e-01],\n", + " [8.65780466e-01, 3.22369175e-01, 7.43159548e-01]],\n", + "\n", + " [[9.28690706e-02, 3.94089568e-01, 5.61323421e-01],\n", + " [5.90368083e-01, 9.74901900e-01, 9.47030261e-01],\n", + " [6.30178577e-01, 5.13484751e-01, 1.18185924e-01],\n", + " [4.38458544e-01, 5.63764554e-01, 7.66318218e-01],\n", + " [7.18648243e-01, 1.70367043e-01, 7.70642982e-01],\n", + " [2.22980409e-01, 3.74329609e-01, 4.80409175e-01],\n", + " [9.45617766e-01, 8.81617847e-01, 1.19580346e-01],\n", + " [9.15676461e-01, 9.70518691e-01, 2.43214092e-01],\n", + " [2.41451379e-01, 4.88548632e-01, 3.70190637e-01],\n", + " [4.59905504e-01, 3.80889550e-01, 3.85698952e-01]],\n", + "\n", + " [[2.72104416e-01, 7.98516947e-01, 2.72194053e-01],\n", + " [9.48709871e-01, 2.66222380e-01, 9.42522429e-01],\n", + " [7.33569990e-01, 3.47061477e-01, 5.33050356e-01],\n", + " [6.11738759e-01, 8.68186611e-02, 3.75463635e-01],\n", + " [2.14948078e-01, 3.99863394e-01, 1.90922352e-01],\n", + " [5.10853285e-01, 8.50874635e-01, 3.81366847e-01],\n", + " [8.39025985e-01, 6.98107126e-01, 8.44114497e-01],\n", + " [6.99026648e-01, 6.91391860e-01, 5.95508900e-01],\n", + " [3.75264962e-01, 2.19625912e-02, 9.65963233e-01],\n", + " [8.73625070e-01, 1.70130110e-02, 4.28592791e-01]],\n", + "\n", + " [[5.30376524e-01, 5.19931919e-01, 8.79621802e-01],\n", + " [5.93691399e-01, 9.43767391e-01, 3.85923387e-01],\n", + " [2.66910663e-02, 9.36911794e-01, 1.02611787e-01],\n", + " [3.59377042e-01, 2.67568222e-01, 5.35446422e-01],\n", + " [5.51850227e-01, 6.35787754e-01, 1.98619411e-01],\n", + " [8.49409850e-01, 6.72271595e-01, 2.39559395e-01],\n", + " [1.73809713e-02, 6.90213328e-01, 4.68996474e-01],\n", + " [2.98860826e-01, 9.70887693e-02, 7.59385182e-01],\n", + " [2.57726817e-01, 9.57823991e-01, 7.18212290e-01],\n", + " [7.01041664e-01, 7.78681302e-01, 2.83077120e-01]],\n", + "\n", + " [[7.26341251e-01, 9.88746010e-02, 8.11023617e-01],\n", + " [5.17637504e-01, 1.22369589e-01, 6.26059218e-01],\n", + " [8.89381042e-01, 4.69513890e-01, 4.41358856e-01],\n", + " [1.31543858e-01, 2.52923839e-02, 4.59211802e-01],\n", + " [2.97316029e-01, 3.74157507e-01, 1.46629093e-01],\n", + " [5.42787121e-01, 4.83436833e-01, 6.48266145e-01],\n", + " [5.87451856e-01, 6.62348938e-01, 9.09155419e-01],\n", + " [4.19004871e-01, 1.82945864e-02, 6.63249102e-01],\n", + " [2.32421673e-01, 6.30460531e-01, 6.90273718e-02],\n", + " [1.00823603e-01, 6.38197200e-01, 1.54316174e-01]],\n", + "\n", + " [[3.95104675e-01, 9.54576105e-01, 7.18833793e-01],\n", + " [6.55325673e-01, 2.09750597e-01, 5.08179296e-01],\n", + " [8.35434924e-01, 3.11732329e-01, 2.53760179e-01],\n", + " [6.01300571e-01, 1.74387890e-01, 1.28901152e-01],\n", + " [3.60114137e-01, 6.03481824e-01, 6.42517616e-01],\n", + " [4.74822395e-01, 4.06953697e-02, 8.06656676e-01],\n", + " [2.35227783e-01, 1.86636675e-01, 2.92355800e-01],\n", + " [7.57334531e-01, 3.02550198e-01, 8.78392401e-01],\n", + " [9.89429375e-01, 9.17356225e-01, 9.94972892e-01],\n", + " [3.38833764e-01, 2.39923972e-01, 2.49753676e-02]],\n", + "\n", + " [[5.34504729e-02, 2.60437560e-01, 4.75257720e-01],\n", + " [5.10602046e-01, 7.24407672e-01, 1.27708925e-01],\n", + " [8.89866378e-01, 7.68391514e-01, 8.10330840e-01],\n", + " [9.04964699e-01, 9.07169014e-01, 3.30947299e-01],\n", + " [3.25628891e-01, 3.62460672e-01, 6.09180261e-01],\n", + " [3.97278800e-01, 2.87846815e-01, 9.51001737e-01],\n", + " [3.94260450e-01, 3.00806280e-01, 8.51360452e-01],\n", + " [3.62097670e-01, 2.24539340e-01, 7.61618704e-01],\n", + " [2.50380406e-01, 7.56360963e-01, 1.16230049e-01],\n", + " [9.01240322e-01, 3.73212987e-01, 7.36122529e-01]]]])" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "X = np.random.rand(1,10,10,3)\n", + "X" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1/1 [==============================] - 0s 30ms/step\n" + ] + }, + { + "data": { + "text/plain": [ + "array([[[[0.72204703, 0.99592865, 0.63305795],\n", + " [0.3195831 , 0.9540866 , 0.8953769 ],\n", + " [0.96829206, 0.841369 , 0.9129711 ],\n", + " [0.88090914, 0.7423906 , 0.7435633 ]],\n", + "\n", + " [[0.8272338 , 0.9749019 , 0.94703025],\n", + " [0.71864825, 0.7759291 , 0.8370553 ],\n", + " [0.9456178 , 0.9705187 , 0.6049204 ],\n", + " [0.8657805 , 0.38088953, 0.74315953]],\n", + "\n", + " [[0.72634125, 0.94376737, 0.8796218 ],\n", + " [0.5518502 , 0.6357877 , 0.5354464 ],\n", + " [0.5874519 , 0.6902133 , 0.9091554 ],\n", + " [0.70104164, 0.7786813 , 0.28307712]],\n", + "\n", + " [[0.51060206, 0.7244077 , 0.47525772],\n", + " [0.9049647 , 0.90716904, 0.6091803 ],\n", + " [0.39426044, 0.30080628, 0.85136044],\n", + " [0.90124035, 0.373213 , 0.73612255]]]], dtype=float32)" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y = model.predict(X)\n", + "y" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "X_in = [[[int(X[0][i][j][k]*1e36) for k in range(3)] for j in range(10)] for i in range(10)]" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[[['722047037048510692294577429613117440',\n", + " '995928671466629356695108803213918208',\n", + " '633057959701190742109452742097371136'],\n", + " ['319583098696415993297833594590330880',\n", + " '954086618948581700566160227891675136',\n", + " '895376897621437771937017335681908736'],\n", + " ['968292067860118608616536728156504064',\n", + " '841368944375512546814487235963912192',\n", + " '912971104374087013640201200717529088'],\n", + " ['880909150990983977881328708709515264',\n", + " '742390579543848987849194909946871808',\n", + " '743563277676321729157199064014520320']],\n", + " [['827233800468269198331599375468331008',\n", + " '974901899794306633179826070214410240',\n", + " '947030261209948995604238219295588352'],\n", + " ['718648242610498931670595372570378240',\n", + " '775929100688812511565570103251566592',\n", + " '837055315861977089490188905893330944'],\n", + " ['945617766167062485510007332488085504',\n", + " '970518691217475177799962810208223232',\n", + " '604920400036749440823474048758448128'],\n", + " ['865780466166194901147341462043623424',\n", + " '380889549714929988856634904844173312',\n", + " '743159547886228828110355377787240448']],\n", + " [['726341250795347218817727679288049664',\n", + " '943767390679553173919884888860786688',\n", + " '879621802069332614144323070292131840'],\n", + " ['551850226842807715717985301323317248',\n", + " '635787754066812906470475510293987328',\n", + " '535446422069610157460354392595103744'],\n", + " ['587451855996849487170757966572814336',\n", + " '690213328307762114218988812757893120',\n", + " '909155419360394098803395006243536896'],\n", + " ['701041664244132545037467114551640064',\n", + " '778681301680931345113753214299144192',\n", + " '283077119909484707357592139667079168']],\n", + " [['510602045587970767184337129422454784',\n", + " '724407671943644015539146665914007552',\n", + " '475257720371555822213696431964815360'],\n", + " ['904964699224751401721193296243458048',\n", + " '907169014433129231084536890038157312',\n", + " '609180260647294075284256963443032064'],\n", + " ['394260450257245198773068562662686720',\n", + " '300806280328724832256760381278519296',\n", + " '851360451534335061119480218741899264'],\n", + " ['901240322382283207026350889182953472',\n", + " '373212986501062202469532290256994304',\n", + " '736122528767077629789884872958935040']]]" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out = MaxPooling2DsameInt(10, 10, 3, 2, 3, X_in)\n", + "out" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "in_json = {\n", + " \"in\": X_in,\n", + " \"out\": out\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "with open(\"maxPooling2Dsame_stride_input.json\", \"w\") as f:\n", + " json.dump(in_json, f)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sklearn", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/test/MaxPooling2Dsame.js b/test/MaxPooling2Dsame.js new file mode 100644 index 0000000..3aacb20 --- /dev/null +++ b/test/MaxPooling2Dsame.js @@ -0,0 +1,37 @@ +const chai = require("chai"); +const path = require("path"); + +const wasm_tester = require("circom_tester").wasm; + +const F1Field = require("ffjavascript").F1Field; +const Scalar = require("ffjavascript").Scalar; +exports.p = Scalar.fromString("21888242871839275222246405745257275088548364400416034343698204186575808495617"); +const Fr = new F1Field(exports.p); + +const assert = chai.assert; + +describe("MaxPooling2Dsame layer test", function () { + this.timeout(100000000); + + // MaxPooling with strides==poolSize + it("(5,5,3) -> (3,3,3)", async () => { + const INPUT = require("../models/maxPooling2Dsame_input.json"); + + const circuit = await wasm_tester(path.join(__dirname, "circuits", "MaxPooling2Dsame_test.circom")); + + const witness = await circuit.calculateWitness(INPUT, true); + + assert(Fr.eq(Fr.e(witness[0]),Fr.e(1))); + }); + + // MaxPooling with strides!=poolSize + it("(10,10,3) -> (4,4,3)", async () => { + const INPUT = require("../models/maxPooling2Dsame_stride_input.json"); + + const circuit = await wasm_tester(path.join(__dirname, "circuits", "MaxPooling2Dsame_stride_test.circom")); + + const witness = await circuit.calculateWitness(INPUT, true); + + assert(Fr.eq(Fr.e(witness[0]),Fr.e(1))); + }); +}); \ No newline at end of file diff --git a/test/circuits/MaxPooling2Dsame_stride_test.circom b/test/circuits/MaxPooling2Dsame_stride_test.circom new file mode 100644 index 0000000..1d2b2f6 --- /dev/null +++ b/test/circuits/MaxPooling2Dsame_stride_test.circom @@ -0,0 +1,5 @@ +pragma circom 2.0.0; + +include "../../circuits/MaxPooling2Dsame.circom"; + +component main = MaxPooling2Dsame(10, 10, 3, 2, 3); \ No newline at end of file diff --git a/test/circuits/MaxPooling2Dsame_test.circom b/test/circuits/MaxPooling2Dsame_test.circom new file mode 100644 index 0000000..8c42ae0 --- /dev/null +++ b/test/circuits/MaxPooling2Dsame_test.circom @@ -0,0 +1,6 @@ +pragma circom 2.0.0; + +include "../../circuits/MaxPooling2Dsame.circom"; + +// poolSize=strides - default Keras settings +component main = MaxPooling2Dsame(5, 5, 3, 2, 2); \ No newline at end of file