Merge remote-tracking branch 'upstream/main' into triton-mlir-IFU-03122023

Rohit Santhanam
2023-03-13 18:09:12 +00:00
91 changed files with 3492 additions and 3373 deletions

View File

@@ -42,7 +42,7 @@ jobs:
- name: Clear cache
run: |
rm -rf ~/.triton/cache/
rm -rf ~/.triton/
- name: Update path
run: |
@@ -93,7 +93,6 @@ jobs:
cd python/test/unit/
pytest
- name: Run CXX unittests
run: |
cd python/
@@ -107,4 +106,4 @@ jobs:
sudo nvidia-smi -i 0 -pm 1
sudo nvidia-smi -i 0 --lock-gpu-clocks=1350,1350
pytest -vs .
sudo nvidia-smi -i 0 -rgc
sudo nvidia-smi -i 0 -rgc

View File

@@ -1,37 +1,37 @@
dev,name,batch_size,speedup,abs_latency,compilation_latency,compression_ratio
cuda,AlbertForMaskedLM,4,1.5637,196.1976,36.7633,1.2637
cuda,AlbertForQuestionAnswering,4,1.5697,193.9925,22.3314,1.3133
cuda,BartForCausalLM,4,1.4704,88.2529,36.2264,0.9750
cuda,BertForMaskedLM,16,1.5146,83.3326,42.1532,1.0497
cuda,BertForQuestionAnswering,16,1.6353,65.9640,30.3219,1.1693
cuda,BlenderbotSmallForCausalLM,64,1.1694,62.6029,28.2282,0.9127
cuda,BlenderbotSmallForConditionalGeneration,64,1.3153,116.4360,51.6582,0.9804
cuda,CamemBert,16,1.4672,93.1047,33.8399,1.0462
cuda,DebertaForMaskedLM,4,0.8610,94.2568,41.9968,1.0406
cuda,DebertaForQuestionAnswering,8,0.9550,94.5806,42.3991,1.1506
cuda,DebertaV2ForMaskedLM,1,0.7829,214.8802,64.2595,0.9772
cuda,DistilBertForMaskedLM,128,1.2346,80.7496,30.1995,0.9625
cuda,DistilBertForQuestionAnswering,256,1.3784,86.2125,32.3333,1.1469
cuda,DistillGPT2,16,1.6558,72.8982,21.3507,1.0639
cuda,ElectraForCausalLM,32,1.5242,67.9513,57.7771,0.9719
cuda,ElectraForQuestionAnswering,64,1.9230,67.2341,45.5223,1.1624
cuda,GPT2ForSequenceClassification,4,2.0511,53.7014,31.2262,1.2305
cuda,LayoutLMForMaskedLM,16,1.5055,84.8592,40.6248,1.0491
cuda,LayoutLMForSequenceClassification,16,1.6464,66.9082,35.1418,1.1401
cuda,MBartForCausalLM,4,1.4704,88.7321,29.1925,0.9831
cuda,MegatronBertForCausalLM,4,1.1061,136.3322,79.4501,1.0946
cuda,MegatronBertForQuestionAnswering,8,1.2551,133.9124,75.9488,1.1147
cuda,MobileBertForMaskedLM,64,0.9569,333.3552,130.7601,1.0135
cuda,MobileBertForQuestionAnswering,128,0.9634,331.8111,126.0109,0.8400
cuda,PLBartForCausalLM,8,1.5155,83.1783,24.3849,0.9886
cuda,PLBartForConditionalGeneration,4,1.4414,93.6038,52.3630,1.0496
cuda,PegasusForCausalLM,32,1.1225,79.1829,36.5687,0.9737
cuda,PegasusForConditionalGeneration,32,1.1506,175.4371,59.1006,1.0686
cuda,RobertaForCausalLM,16,1.5780,83.6565,33.7543,1.0491
cuda,RobertaForQuestionAnswering,16,1.6336,66.3454,29.9597,1.1698
cuda,Speech2Text2ForCausalLM,256,1.5464,41.4059,25.6908,0.8768
cuda,T5ForConditionalGeneration,4,1.2736,96.9787,54.8479,1.1802
cuda,T5Small,4,1.2861,98.4766,32.1507,1.1802
cuda,TrOCRForCausalLM,32,1.2573,127.5731,36.5153,0.9584
cuda,XLNetLMHeadModel,8,1.6924,177.0149,83.7423,1.1026
cuda,YituTechConvBert,16,1.4142,107.2519,68.8073,1.0363
cuda,AlbertForMaskedLM,4,1.5511,164.3373,26.8523,1.2647
cuda,AlbertForQuestionAnswering,4,1.5501,163.5580,25.7983,1.3145
cuda,BartForCausalLM,4,1.5080,71.7230,32.8907,0.9749
cuda,BertForMaskedLM,16,1.5350,67.9451,35.3286,1.0494
cuda,BertForQuestionAnswering,16,1.6735,53.2963,34.3754,1.1710
cuda,BlenderbotSmallForCausalLM,64,1.2106,46.6466,23.8058,0.9120
cuda,BlenderbotSmallForConditionalGeneration,64,1.3616,77.3013,55.3546,0.9803
cuda,CamemBert,16,1.4779,76.1809,35.3883,1.0469
cuda,DebertaForMaskedLM,4,0.8415,62.3395,35.9657,1.0418
cuda,DebertaForQuestionAnswering,8,1.0609,67.5151,35.7728,1.1528
cuda,DebertaV2ForMaskedLM,1,0.6026,134.6517,66.1783,0.9773
cuda,DistilBertForMaskedLM,128,1.2460,66.9382,18.3089,0.9624
cuda,DistilBertForQuestionAnswering,256,1.3997,72.4126,18.1956,1.1486
cuda,DistillGPT2,16,1.6656,60.5455,17.2280,1.0641
cuda,ElectraForCausalLM,32,1.8299,45.4841,37.0944,0.9717
cuda,ElectraForQuestionAnswering,64,2.0289,52.6890,35.9632,1.1928
cuda,GPT2ForSequenceClassification,4,2.2567,38.2969,30.0527,1.2323
cuda,LayoutLMForMaskedLM,16,1.5423,68.8018,36.5562,1.0495
cuda,LayoutLMForSequenceClassification,16,1.7058,53.9355,35.2225,1.1659
cuda,MBartForCausalLM,4,1.4945,71.4649,32.8653,0.9830
cuda,MegatronBertForCausalLM,4,1.4328,58.4404,70.6226,1.0951
cuda,MegatronBertForQuestionAnswering,8,1.5886,85.2533,69.1219,1.1152
cuda,MobileBertForMaskedLM,64,0.9007,131.7379,107.5275,1.0136
cuda,MobileBertForQuestionAnswering,128,0.8435,167.9066,106.7049,0.8579
cuda,PLBartForCausalLM,8,1.5261,68.9224,19.5826,0.9887
cuda,PLBartForConditionalGeneration,4,1.5298,71.2811,45.6902,1.0495
cuda,PegasusForCausalLM,32,1.2212,57.5436,33.3863,0.9736
cuda,PegasusForConditionalGeneration,32,1.2822,106.4678,69.8825,1.0689
cuda,RobertaForCausalLM,16,1.6128,67.5706,34.7355,1.0496
cuda,RobertaForQuestionAnswering,16,1.6800,53.6267,33.8527,1.1704
cuda,Speech2Text2ForCausalLM,256,1.8230,32.9145,18.7201,0.8760
cuda,T5ForConditionalGeneration,4,1.6592,59.5324,39.4406,1.1814
cuda,T5Small,4,1.6581,59.5930,37.0471,1.1814
cuda,TrOCRForCausalLM,32,1.2586,106.2633,32.5330,0.9583
cuda,XLNetLMHeadModel,8,1.8108,142.8795,84.8197,1.1240
cuda,YituTechConvBert,16,1.5207,81.4595,53.1565,1.0362

View File

@@ -1,54 +1,54 @@
dev,name,batch_size,speedup,abs_latency,compilation_latency,compression_ratio
cuda,adv_inception_v3,128,1.5315,126.8003,157.5622,1.0179
cuda,beit_base_patch16_224,64,1.3308,91.0230,35.2707,0.9891
cuda,coat_lite_mini,128,2.0396,63.6890,87.8844,1.0199
cuda,convmixer_768_32,32,1.0443,336.7657,32.3711,0.9999
cuda,convnext_base,64,1.4920,97.2042,76.9539,1.0388
cuda,crossvit_9_240,128,1.2019,87.2138,86.0260,1.0032
cuda,cspdarknet53,64,1.4388,90.4760,88.7691,1.0140
cuda,deit_base_distilled_patch16_224,64,1.2471,81.1739,45.5359,0.9527
cuda,dla102,128,1.5027,147.5660,82.0541,1.0331
cuda,dm_nfnet_f0,128,1.4283,98.4167,67.5993,1.0713
cuda,dpn107,32,1.2356,100.9084,91.1994,0.9651
cuda,eca_botnext26ts_256,128,1.5012,86.6883,79.4803,1.0043
cuda,ese_vovnet19b_dw,128,1.3848,52.7837,47.5720,0.9915
cuda,fbnetc_100,128,1.5364,64.6275,121.7965,0.9548
cuda,gernet_l,128,1.1577,76.1686,56.8719,0.9712
cuda,ghostnet_100,128,1.7220,65.4748,209.2949,1.0223
cuda,gluon_inception_v3,128,1.5343,126.3718,52.4958,1.0466
cuda,gmixer_24_224,128,1.6468,84.2051,51.3789,1.1584
cuda,gmlp_s16_224,128,1.5873,94.6600,58.2645,1.2023
cuda,hrnet_w18,128,1.3546,263.2673,198.3349,0.9923
cuda,inception_v3,128,1.5238,126.7710,51.6400,1.0466
cuda,jx_nest_base,32,1.2384,103.6627,78.7322,0.9607
cuda,lcnet_050,128,1.7771,21.2319,34.1647,0.9458
cuda,mixer_b16_224,128,1.2902,108.2678,27.6445,0.9948
cuda,mixnet_l,128,1.2122,182.1990,118.8907,0.9908
cuda,mnasnet_100,128,1.6199,47.7919,48.7633,0.9408
cuda,mobilenetv2_100,128,1.5661,50.4129,42.9730,1.1166
cuda,mobilenetv3_large_100,128,1.5888,46.3295,51.9647,0.9704
cuda,mobilevit_s,64,1.3031,82.1195,111.3226,1.0065
cuda,nfnet_l0,128,1.4895,81.4742,58.5768,0.9691
cuda,pit_b_224,64,1.4120,97.3046,43.7878,1.0241
cuda,pnasnet5large,16,1.0523,239.1102,145.2293,1.2797
cuda,poolformer_m36,64,1.2154,138.0360,93.0292,1.1927
cuda,regnety_002,128,1.2659,38.8745,68.1799,0.8660
cuda,repvgg_a2,128,1.2185,73.6415,32.3085,0.9735
cuda,res2net101_26w_4s,64,1.0443,116.1409,144.6286,0.9491
cuda,res2net50_14w_8s,128,1.3212,130.1624,102.7642,0.9609
cuda,res2next50,128,1.2159,157.3657,46.9827,0.9756
cuda,resmlp_12_224,128,1.2970,54.6888,40.0312,1.0342
cuda,resnest101e,64,1.4079,134.5610,119.5467,1.0831
cuda,rexnet_100,128,1.4427,65.3909,222.1865,1.0439
cuda,selecsls42b,128,1.4015,53.4159,31.3161,0.9731
cuda,spnasnet_100,128,1.5507,54.3208,34.4102,1.0045
cuda,swin_base_patch4_window7_224,64,1.5038,115.4018,104.8326,0.9043
cuda,swsl_resnext101_32x16d,32,0.9981,136.5238,49.8939,0.9833
cuda,tf_efficientnet_b0,128,1.4894,67.2972,57.0583,1.0725
cuda,tf_mixnet_l,128,1.2179,189.8781,68.3717,1.0676
cuda,tinynet_a,128,1.3548,64.3571,117.1189,1.0718
cuda,tnt_s_patch16_224,128,3.0069,126.5317,67.8712,1.0505
cuda,twins_pcpvt_base,64,1.2016,154.8390,144.5083,1.0541
cuda,visformer_small,128,1.1935,87.3201,42.3853,1.0220
cuda,vit_base_patch16_224,64,1.2207,85.2031,39.3641,0.9551
cuda,volo_d1_224,0,0.0000
cuda,adv_inception_v3,128,1.5923,102.5292,51.6032,1.0472
cuda,beit_base_patch16_224,64,1.3390,75.3027,29.7471,1.0156
cuda,coat_lite_mini,128,2.0579,53.3689,37.1856,1.0437
cuda,convmixer_768_32,32,1.0470,275.5328,23.8037,0.9999
cuda,convnext_base,64,1.5084,80.1811,42.5659,1.0373
cuda,crossvit_9_240,128,1.5392,37.1806,44.9986,0.9193
cuda,cspdarknet53,64,1.4721,75.0403,35.2882,1.0547
cuda,deit_base_distilled_patch16_224,64,1.1432,55.9737,23.4038,0.9816
cuda,dla102,128,1.5282,123.7284,49.3612,1.0430
cuda,dm_nfnet_f0,128,1.4354,79.7518,34.8994,1.1038
cuda,dpn107,32,1.2412,83.8921,58.9111,0.9952
cuda,eca_botnext26ts_256,128,1.5425,71.2406,28.8920,1.0270
cuda,ese_vovnet19b_dw,128,1.4647,42.4837,18.0285,1.0135
cuda,fbnetc_100,128,1.5795,53.8033,33.0222,1.0082
cuda,gernet_l,128,1.1684,63.4230,26.8687,1.0053
cuda,ghostnet_100,128,1.7812,54.4211,47.6168,1.0484
cuda,gluon_inception_v3,128,1.5952,102.5018,50.0857,1.0469
cuda,gmixer_24_224,128,1.6749,69.2430,42.0841,1.1921
cuda,gmlp_s16_224,128,1.5886,79.2132,43.0142,1.2343
cuda,hrnet_w18,128,1.3743,221.5304,134.2573,1.0100
cuda,inception_v3,128,1.5847,102.8333,49.7648,1.0472
cuda,jx_nest_base,32,1.3747,71.4190,61.4053,0.9905
cuda,lcnet_050,128,1.8159,18.0047,18.8249,1.0005
cuda,mixer_b16_224,128,1.2795,90.9229,21.0438,1.0133
cuda,mixnet_l,128,1.2273,149.9722,47.7482,1.0129
cuda,mnasnet_100,128,1.6594,40.0512,26.5165,1.0047
cuda,mobilenetv2_100,128,1.6085,41.1217,27.4450,1.1731
cuda,mobilenetv3_large_100,128,1.6610,37.9995,29.8185,1.0052
cuda,mobilevit_s,64,1.5212,55.4152,53.6475,1.0258
cuda,nfnet_l0,128,1.4927,65.7078,32.4067,0.9980
cuda,pit_b_224,64,1.2286,57.9484,26.5321,0.9606
cuda,pnasnet5large,16,1.0000,198.2494,93.4641,1.3184
cuda,poolformer_m36,64,1.3486,103.9235,62.3196,1.1942
cuda,regnety_002,128,1.3030,32.4968,27.2439,1.0014
cuda,repvgg_a2,128,1.2485,59.7729,26.9209,1.0185
cuda,res2net101_26w_4s,64,1.0813,94.1773,86.6520,0.9655
cuda,res2net50_14w_8s,128,1.3251,109.5258,79.9578,0.9830
cuda,res2next50,128,1.2518,125.5008,43.9754,0.9756
cuda,resmlp_12_224,128,1.3060,45.2373,19.3709,1.1048
cuda,resnest101e,64,1.4346,108.1945,78.1993,1.1037
cuda,rexnet_100,128,1.4637,55.0121,41.2075,1.0862
cuda,selecsls42b,128,1.4284,44.6645,23.3892,1.0139
cuda,spnasnet_100,128,1.5908,45.3189,32.0148,1.0048
cuda,swin_base_patch4_window7_224,64,1.6164,89.5854,75.5848,0.9299
cuda,swsl_resnext101_32x16d,32,1.0175,110.0041,45.7853,1.0003
cuda,tf_efficientnet_b0,128,1.5271,55.7361,34.5551,1.1079
cuda,tf_mixnet_l,128,1.2369,155.9027,48.6695,1.0921
cuda,tinynet_a,128,1.3792,53.0640,40.6346,1.1108
cuda,tnt_s_patch16_224,128,3.1078,104.8486,59.6028,1.0660
cuda,twins_pcpvt_base,64,1.5921,67.4600,84.4977,1.0909
cuda,visformer_small,128,1.1952,72.8705,23.7303,1.0410
cuda,vit_base_patch16_224,64,1.1309,56.4866,22.0208,0.9804
cuda,volo_d1_224,64,1.6868,72.0957,65.3011,0.9729

View File

@@ -1,51 +1,53 @@
dev,name,batch_size,speedup,abs_latency,compilation_latency,compression_ratio
cuda,BERT_pytorch,16,1.3446,65.5196,59.4176,1.1679
cuda,LearningToPaint,96,1.0376,12.7398,35.9770,0.7613
cuda,Super_SloMo,6,1.3132,73.6570,39.9019,1.2390
cuda,alexnet,128,1.1653,10.1614,10.5925,0.9408
cuda,attention_is_all_you_need_pytorch,256,1.2514,82.7756,66.7768,1.1459
cuda,dcgan,32,0.8947,2.6217,5.9728,1.0082
cuda,densenet121,4,0.8777,65.7564,123.8633,0.8292
cuda,drq,1,1.0291,4.9372,7.8125,0.9849
cuda,fastNLP_Bert,6,1.5073,73.2808,42.6538,1.1547
cuda,functorch_dp_cifar10,64,1.4043,9.3820,51.7668,0.4986
cuda,functorch_maml_omniglot,1,1.0998,3.0214,12.6407,0.2181
cuda,hf_Albert,8,1.3968,56.3755,38.3806,1.2603
cuda,hf_Bart,4,1.1020,91.5587,53.6117,1.0087
cuda,hf_Bert,4,1.1458,65.0555,43.0593,1.0261
cuda,hf_Bert_large,4,1.1683,131.1134,59.7228,1.0909
cuda,hf_DistilBert,8,1.2007,34.2069,22.8917,1.0228
cuda,hf_GPT2,4,1.2689,52.0180,31.7956,1.1540
cuda,BERT_pytorch,16,1.7111,24.2741,35.7065,1.3212
cuda,LearningToPaint,96,1.0513,10.7557,11.1879,0.9896
cuda,Super_SloMo,6,1.3267,60.4328,28.2097,1.2392
cuda,alexnet,128,1.1754,8.3246,5.3319,1.0003
cuda,attention_is_all_you_need_pytorch,256,1.3416,36.4401,39.5927,1.1774
cuda,dcgan,32,0.9151,2.6249,3.2964,1.0082
cuda,densenet121,4,0.9225,51.3747,68.5841,0.9930
cuda,doctr_det_predictor,0,0.0000
cuda,doctr_reco_predictor,0,0.0000
cuda,drq,1,0.9500,3.4884,4.8028,0.9687
cuda,fastNLP_Bert,6,1.4328,34.7753,35.4863,1.2368
cuda,functorch_dp_cifar10,64,1.2015,8.1625,12.9040,1.0609
cuda,functorch_maml_omniglot,1,0.9322,2.5844,3.8640,1.0000
cuda,hf_Albert,8,2.1228,30.3377,26.8282,1.2676
cuda,hf_Bart,4,1.2899,39.1935,47.2373,1.0080
cuda,hf_Bert,4,1.3262,26.1063,35.0281,1.0656
cuda,hf_Bert_large,4,1.4163,55.1021,67.2825,1.0915
cuda,hf_DistilBert,8,1.4051,21.7191,18.0399,1.0242
cuda,hf_GPT2,4,1.6661,26.9039,29.9473,1.1555
cuda,hf_Longformer,0,0.0000
cuda,hf_Reformer,4,1.0903,82.9519,28.0343,0.9289
cuda,hf_T5_large,2,1.3534,332.3302,172.6140,1.1666
cuda,lennard_jones,1000,0.9952,3.8690,4.8521,1.0000
cuda,maml_omniglot,32,1.0328,3.3367,8.2772,0.2181
cuda,mnasnet1_0,32,1.0162,25.4638,69.8684,0.8356
cuda,mobilenet_v2,96,1.5212,38.4276,100.4918,1.1011
cuda,nvidia_deeprecommender,256,1.0517,11.1245,7.3804,0.9715
cuda,phlippe_densenet,128,1.0043,33.4096,108.0736,0.8774
cuda,phlippe_resnet,128,1.0229,14.0998,21.7420,0.4147
cuda,pytorch_CycleGAN_and_pix2pix,1,1.3815,9.3944,32.3602,0.6135
cuda,pytorch_stargan,16,1.1625,14.4103,41.3705,0.8893
cuda,pytorch_unet,1,1.3638,35.7120,51.2342,0.9525
cuda,resnet152,32,0.9568,76.3876,70.2073,0.9997
cuda,resnet18,16,0.9193,12.1360,23.4287,0.6492
cuda,resnet50,32,1.0230,29.6914,26.1574,1.0010
cuda,resnext50_32x4d,8,0.8679,25.7775,39.3170,0.8524
cuda,shufflenet_v2_x1_0,128,1.1374,31.2127,62.0057,0.9590
cuda,soft_actor_critic,256,0.9754,3.1737,5.5626,0.9998
cuda,speech_transformer,32,1.1390,94.3465,74.5561,0.8732
cuda,squeezenet1_1,32,1.1572,9.2585,19.0393,0.9243
cuda,timm_efficientdet,1,1.3338,95.3918,255.9148,1.0310
cuda,timm_efficientnet,32,1.1237,34.3466,80.1230,0.9445
cuda,timm_nfnet,128,1.4441,95.5148,36.3090,1.1050
cuda,timm_regnet,32,1.0374,65.3419,57.6930,0.9528
cuda,timm_resnest,32,1.5878,18.2585,54.0304,0.9636
cuda,timm_vision_transformer,8,1.0850,51.7360,50.3927,0.7429
cuda,timm_vision_transformer_large,0,0.0000
cuda,timm_vovnet,32,1.1318,27.3068,27.8668,0.8884
cuda,hf_Reformer,4,1.1709,64.6979,15.7035,0.9267
cuda,hf_T5_large,2,1.7215,107.0798,148.8805,1.1684
cuda,lennard_jones,1000,0.8428,1.8488,3.0609,1.0001
cuda,maml_omniglot,32,0.9648,2.6869,3.9775,0.9999
cuda,mnasnet1_0,32,1.0469,21.6251,25.8232,0.9996
cuda,mobilenet_v2,96,1.5604,31.9572,27.0225,1.1734
cuda,nvidia_deeprecommender,256,1.0605,9.2080,4.1318,0.9711
cuda,phlippe_densenet,128,1.0237,27.5988,28.0400,1.0023
cuda,phlippe_resnet,128,1.0493,10.9751,10.2485,1.0092
cuda,pytorch_CycleGAN_and_pix2pix,1,1.3724,8.2225,11.9561,1.0219
cuda,pytorch_stargan,16,1.1835,11.9178,10.0507,1.0868
cuda,pytorch_unet,1,1.3787,29.7543,13.7711,1.0100
cuda,resnet152,32,0.9834,63.2446,67.7935,0.9991
cuda,resnet18,16,0.9451,9.4977,11.7663,0.9948
cuda,resnet50,32,1.0513,24.5141,24.6629,1.0021
cuda,resnext50_32x4d,8,0.9216,22.2460,24.3420,0.9984
cuda,shufflenet_v2_x1_0,128,1.1943,25.4520,28.8611,1.0951
cuda,soft_actor_critic,256,0.8691,1.9637,3.3716,0.9996
cuda,speech_transformer,32,1.2718,35.2922,46.9957,1.0897
cuda,squeezenet1_1,32,1.1302,8.4540,7.9625,1.0771
cuda,timm_efficientdet,1,1.3370,80.0377,120.1814,1.2713
cuda,timm_efficientnet,32,1.1874,27.6302,33.9059,1.0971
cuda,timm_nfnet,128,1.4525,77.3461,34.3270,1.1056
cuda,timm_regnet,32,1.0644,50.6953,35.7562,1.0000
cuda,timm_resnest,32,1.6200,14.7763,17.2245,1.0906
cuda,timm_vision_transformer,32,1.0800,19.4188,22.0255,0.9966
cuda,timm_vision_transformer_large,32,1.0081,393.1742,127.8083,0.9735
cuda,timm_vovnet,32,1.1472,22.4727,22.7328,1.0120
cuda,torchrec_dlrm,0,0.0000
cuda,tts_angular,64,0.8185,10.2896,5.1774,1.0015
cuda,vgg16,64,1.2931,61.1714,10.9558,0.9828
cuda,yolov3,16,1.2202,68.8346,86.5149,1.0437
cuda,tts_angular,64,0.8974,6.5057,2.5555,0.9973
cuda,vgg16,64,1.2909,50.7405,6.1510,0.9828
cuda,yolov3,16,1.2930,54.8069,41.9269,1.0563

View File

@@ -4,7 +4,7 @@ from collections import namedtuple
# Create a named tuple for the output of the benchmark
BenchmarkOutput = namedtuple(
'BenchmarkOutput', ['dev', 'name', 'batch_size', 'speedup'])
'BenchmarkOutput', ['dev', 'name', 'batch_size', 'speedup', 'latency'])
def parse_output(file_path: str) -> dict:
@@ -12,39 +12,52 @@ def parse_output(file_path: str) -> dict:
with open(file_path) as f:
reader = csv.reader(f)
for i, row in enumerate(reader):
if i == 0:
if i == 0 or len(row) < 5:
continue
dev = row[0]
name = row[1]
batch_size = row[2]
speedup = float(row[3])
entries[name] = BenchmarkOutput(dev, name, batch_size, speedup)
latency = float(row[4])
entries[name] = BenchmarkOutput(dev, name, batch_size, speedup, latency)
return entries
def compare(baseline: dict, new: dict, threshold: float) -> bool:
def compare(baseline: dict, new: dict, threshold: float, geomean_threshold: float) -> bool:
baseline_geomean = 1.0
new_geomean = 1.0
for key in new:
if key not in baseline:
print(f"New benchmark {key} not found in baseline")
baseline_speedup = baseline[key].speedup
new_speedup = new[key].speedup
if new_speedup < baseline_speedup * (1 - threshold):
baseline_latency = baseline[key].latency
new_latency = new[key].latency
if new_latency < baseline_latency * (1 - threshold):
print(
f"New benchmark {key} is slower than baseline: {new_speedup} vs {baseline_speedup}")
elif new_speedup > baseline_speedup * (1 + threshold):
f"New benchmark {key} is faster than baseline: {new_latency} vs {baseline_latency}")
elif new_latency > baseline_latency * (1 + threshold):
print(
f"New benchmark {key} is faster than baseline: {new_speedup} vs {baseline_speedup}")
f"New benchmark {key} is slower than baseline: {new_latency} vs {baseline_latency}")
baseline_geomean *= baseline[key].speedup
new_geomean *= new[key].speedup
baseline_geomean = baseline_geomean ** (1 / len(baseline))
new_geomean = new_geomean ** (1 / len(new))
print(f"Baseline geomean: {baseline_geomean}")
print(f"New geomean: {new_geomean}")
assert new_geomean > baseline_geomean * (1 - geomean_threshold), \
f"New geomean is slower than baseline: {new_geomean} vs {baseline_geomean}"
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--baseline', required=True)
parser.add_argument('--new', required=True)
parser.add_argument('--threshold', type=float, default=0.02)
parser.add_argument('--threshold', type=float, default=0.1)
parser.add_argument('--geomean-threshold', type=float, default=0.02)
args = parser.parse_args()
baseline = parse_output(args.baseline)
new = parse_output(args.new)
compare(baseline, new, args.threshold)
compare(baseline, new, args.threshold, args.geomean_threshold)
main()
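
For readers following the diff above, here is a minimal, self-contained sketch of what the new comparison amounts to. The file paths are placeholders, and the geomean gate over per-model speedups s_1..s_n is (s_1 * s_2 * ... * s_n) ** (1/n); the sketch restricts the comparison to keys present in both reports to sidestep lookups on missing models:

```python
import csv
import math

def read_speedups(path: str) -> dict:
    # name -> speedup; skip the header and failed rows such as
    # "cuda,volo_d1_224,0,0.0000", which have fewer than 5 columns.
    out = {}
    with open(path) as f:
        for i, row in enumerate(csv.reader(f)):
            if i == 0 or len(row) < 5:
                continue
            out[row[1]] = float(row[3])
    return out

def geomean(values) -> float:
    # (v_1 * v_2 * ... * v_n) ** (1/n), computed in log space.
    vals = list(values)
    return math.exp(sum(math.log(v) for v in vals) / len(vals))

baseline = read_speedups("baseline.csv")  # placeholder paths
new = read_speedups("new.csv")
shared = baseline.keys() & new.keys()     # only compare models in both reports
drop = 1 - geomean(new[k] for k in shared) / geomean(baseline[k] for k in shared)
assert drop < 0.02, f"geomean speedup regressed by {drop:.1%}"
```

Computing the geometric mean in log space avoids the overflow or underflow that a long running product can hit; the script in the diff instead multiplies the speedups directly and takes the n-th root at the end.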

View File

@@ -27,6 +27,11 @@ cd "$ROOT" || exit
for model in "${MODELS[@]}"; do
echo "Checking performance test for $model"
python3 "$INDUCTOR"/scripts/check_perf.py --new "$TEST_REPORTS_DIR"/"$model".csv --baseline "$INDUCTOR"/data/"$model".csv
EXIT_STATUS=$?
if [ "$EXIT_STATUS" -ne 0 ]; then
echo "Performance test for $model failed"
exit "$EXIT_STATUS"
fi
done
# unlock GPU clocks

View File

@@ -1,6 +1,8 @@
cmake_minimum_required(VERSION 3.6)
cmake_minimum_required(VERSION 3.18)
if(POLICY CMP0116)
# Introduced in cmake 3.20
# https://cmake.org/cmake/help/latest/policy/CMP0116.html
cmake_policy(SET CMP0116 OLD)
endif()
@@ -12,6 +14,7 @@ set(CMAKE_INCLUDE_CURRENT_DIR ON)
project(triton)
include(CTest)
if(NOT WIN32)
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
endif()
@@ -24,7 +27,7 @@ option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
set(TRITON_USE_ROCM ON)
# Ensure Python3 vars are set correctly
# used conditionally in this file and by lit tests
# used conditionally in this file and by lit tests
# Customized release build type with assertions: TritonRelBuildWithAsserts
set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
@@ -37,7 +40,7 @@ if(NOT CMAKE_BUILD_TYPE)
endif()
if(NOT WIN32)
find_library(TERMINFO_LIBRARY tinfo)
find_library(TERMINFO_LIBRARY tinfo)
endif()
# Compiler flags
@@ -47,9 +50,9 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
include_directories(${PYBIND11_INCLUDE_DIR})
if(WIN32)
SET(BUILD_SHARED_LIBS OFF)
find_package(dlfcn-win32 REQUIRED)
set(CMAKE_DL_LIBS dlfcn-win32::dl)
SET(BUILD_SHARED_LIBS OFF)
find_package(dlfcn-win32 REQUIRED)
set(CMAKE_DL_LIBS dlfcn-win32::dl)
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17 -fvisibility=hidden -fvisibility-inlines-hidden")
@@ -62,12 +65,10 @@ if(APPLE)
set(CMAKE_OSX_DEPLOYMENT_TARGET 11.6)
endif()
##########
# #########
# LLVM
##########
if (NOT MLIR_DIR)
# #########
if(NOT MLIR_DIR)
if(NOT LLVM_LIBRARY_DIR)
if(WIN32)
find_package(LLVM 13 REQUIRED COMPONENTS nvptx amdgpu)
@@ -83,12 +84,16 @@ if (NOT MLIR_DIR)
else()
find_package(LLVM 11 REQUIRED COMPONENTS "nvptx;amdgpu")
endif()
message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}")
# FindLLVM outputs LLVM_LIBRARY_DIRS but we expect LLVM_LIBRARY_DIR here
set(LLVM_LIBRARY_DIR ${LLVM_LIBRARY_DIRS})
if(APPLE)
set(CMAKE_OSX_DEPLOYMENT_TARGET "10.14")
endif()
# sometimes we don't want to use llvm-config, since it may have been downloaded for some specific linux distros
else()
set(LLVM_LDFLAGS "-L${LLVM_LIBRARY_DIR}")
@@ -148,37 +153,38 @@ if (NOT MLIR_DIR)
libLLVMAnalysis.a
)
endif()
set (MLIR_DIR ${LLVM_LIBRARY_DIR}/cmake/mlir)
set(MLIR_DIR ${LLVM_LIBRARY_DIR}/cmake/mlir)
endif()
# Python module
if(TRITON_BUILD_PYTHON_MODULE)
message(STATUS "Adding Python module")
set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src)
set(PYTHON_SRC ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/triton.cc)
include_directories("." ${PYTHON_SRC_PATH})
if (PYTHON_INCLUDE_DIRS)
include_directories(${PYTHON_INCLUDE_DIRS})
else()
find_package(Python3 REQUIRED COMPONENTS Development Interpreter)
include_directories(${Python3_INCLUDE_DIRS})
link_directories(${Python3_LIBRARY_DIRS})
link_libraries(${Python3_LIBRARIES})
add_link_options(${Python3_LINK_OPTIONS})
endif()
message(STATUS "Adding Python module")
set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src)
set(PYTHON_SRC ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/triton.cc)
include_directories("." ${PYTHON_SRC_PATH})
if(PYTHON_INCLUDE_DIRS)
include_directories(${PYTHON_INCLUDE_DIRS})
else()
find_package(Python3 REQUIRED COMPONENTS Development Interpreter)
include_directories(${Python3_INCLUDE_DIRS})
link_directories(${Python3_LIBRARY_DIRS})
link_libraries(${Python3_LIBRARIES})
add_link_options(${Python3_LINK_OPTIONS})
endif()
endif()
# # Triton
# file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
# if (WIN32 AND TRITON_BUILD_PYTHON_MODULE)
# Python3_add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
# set_target_properties(triton PROPERTIES SUFFIX ".pyd")
# set_target_properties(triton PROPERTIES PREFIX "lib")
# Python3_add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
# set_target_properties(triton PROPERTIES SUFFIX ".pyd")
# set_target_properties(triton PROPERTIES PREFIX "lib")
# else()
# add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
# add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
# endif()
# MLIR
find_package(MLIR REQUIRED CONFIG PATHS ${MLIR_DIR})
@@ -196,14 +202,13 @@ include_directories(${MLIR_INCLUDE_DIRS})
include_directories(${LLVM_INCLUDE_DIRS})
include_directories(${PROJECT_SOURCE_DIR}/include)
include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files
# link_directories(${LLVM_LIBRARY_DIR})
# link_directories(${LLVM_LIBRARY_DIR})
add_subdirectory(include)
add_subdirectory(lib)
add_subdirectory(bin)
# find_package(PythonLibs REQUIRED)
set(TRITON_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
set(TRITON_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}")
@@ -221,6 +226,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
TritonHSACO
${dialect_libs}
${conversion_libs}
# optimizations
MLIRPass
MLIRTransforms
@@ -233,6 +239,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
MLIRROCDLToLLVMIRTranslation
MLIRIR
)
if(WIN32)
target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} ${CMAKE_DL_LIBS}
${TRITON_LIBRARIES}
@@ -246,21 +253,23 @@ if(TRITON_BUILD_PYTHON_MODULE)
${TRITON_LIBRARIES}
)
endif()
target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
endif()
if (UNIX AND NOT APPLE)
if(UNIX AND NOT APPLE)
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--exclude-libs,ALL")
endif()
if(TRITON_BUILD_PYTHON_MODULE AND NOT WIN32)
set(CMAKE_SHARED_LIBRARY_SUFFIX ".so")
# Check if the platform is MacOS
if(APPLE)
set(PYTHON_LDFLAGS "-undefined dynamic_lookup -flto")
endif()
target_link_libraries(triton ${CUTLASS_LIBRARIES} ${PYTHON_LDFLAGS})
set(CMAKE_SHARED_LIBRARY_SUFFIX ".so")
# Check if the platform is MacOS
if(APPLE)
set(PYTHON_LDFLAGS "-undefined dynamic_lookup -flto")
endif()
target_link_libraries(triton ${CUTLASS_LIBRARIES} ${PYTHON_LDFLAGS})
endif()
add_subdirectory(test)

View File

@@ -25,7 +25,7 @@ You can install the latest stable release of Triton from pip:
```bash
pip install triton
```
Binary wheels are available for CPython 3.6-3.9 and PyPy 3.6-3.7.
Binary wheels are available for CPython 3.6-3.11 and PyPy 3.7-3.9.
And the latest nightly release:
@@ -66,12 +66,11 @@ pip install -e .
# Changelog
Version 1.1 is out! New features include:
Version 2.0 is out! New features include:
- Many, many bugfixes
- More documentation
- Automatic on-disk caching of compiled binary objects
- Random Number Generation
- Faster (up to 2x on A100), cleaner blocksparse ops
- Performance improvements
- Backend rewritten to use MLIR
- Support for kernels that contain back-to-back matmuls (e.g., flash attention)
# Contributing
@@ -88,7 +87,3 @@ Supported Platforms:
Supported Hardware:
* NVIDIA GPUs (Compute Capability 7.0+)
* Under development: AMD GPUs, CPUs
# Disclaimer
Triton is a fairly recent project, and it is under active development. We expect it to be pretty useful in a wide variety of cases, but don't be surprised if it's a bit rough around the edges :)
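
For context on the API these wheels expose, the canonical vector-add kernel from the Triton tutorials looks like the following sketch; the tensor size and BLOCK_SIZE are illustrative:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)                            # one program per block
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements                            # guard the ragged tail
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

x = torch.rand(4096, device="cuda")
y = torch.rand(4096, device="cuda")
out = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)
add_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=1024)
```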

View File

@@ -1,7 +1,3 @@
add_subdirectory(FileCheck)
# add_llvm_executable(FileCheck FileCheck/FileCheck.cpp)
# target_link_libraries(FileCheck PRIVATE LLVMFileCheck LLVMSupport)
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)

View File

@@ -1,2 +0,0 @@
add_llvm_executable(FileCheck FileCheck.cpp)
target_link_libraries(FileCheck PRIVATE LLVMFileCheck LLVMSupport)

View File

@@ -1,885 +0,0 @@
//===- FileCheck.cpp - Check that File's Contents match what is expected --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// FileCheck does a line-by-line check of a file that validates whether it
// contains the expected content. This is useful for regression tests etc.
//
// This program exits with an exit status of 2 on error, exit status of 0 if
// the file matched the expected contents, and exit status of 1 if it did not
// contain the expected contents.
//
//===----------------------------------------------------------------------===//
#include "llvm/FileCheck/FileCheck.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/Process.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/WithColor.h"
#include "llvm/Support/raw_ostream.h"
#include <cmath>
#include <map>
using namespace llvm;
static cl::extrahelp FileCheckOptsEnv(
"\nOptions are parsed from the environment variable FILECHECK_OPTS and\n"
"from the command line.\n");
static cl::opt<std::string>
CheckFilename(cl::Positional, cl::desc("<check-file>"), cl::Optional);
static cl::opt<std::string>
InputFilename("input-file", cl::desc("File to check (defaults to stdin)"),
cl::init("-"), cl::value_desc("filename"));
static cl::list<std::string> CheckPrefixes(
"check-prefix",
cl::desc("Prefix to use from check file (defaults to 'CHECK')"));
static cl::alias CheckPrefixesAlias(
"check-prefixes", cl::aliasopt(CheckPrefixes), cl::CommaSeparated,
cl::NotHidden,
cl::desc(
"Alias for -check-prefix permitting multiple comma separated values"));
static cl::list<std::string> CommentPrefixes(
"comment-prefixes", cl::CommaSeparated, cl::Hidden,
cl::desc("Comma-separated list of comment prefixes to use from check file\n"
"(defaults to 'COM,RUN'). Please avoid using this feature in\n"
"LLVM's LIT-based test suites, which should be easier to\n"
"maintain if they all follow a consistent comment style. This\n"
"feature is meant for non-LIT test suites using FileCheck."));
static cl::opt<bool> NoCanonicalizeWhiteSpace(
"strict-whitespace",
cl::desc("Do not treat all horizontal whitespace as equivalent"));
static cl::opt<bool> IgnoreCase("ignore-case",
cl::desc("Use case-insensitive matching"));
static cl::list<std::string> ImplicitCheckNot(
"implicit-check-not",
cl::desc("Add an implicit negative check with this pattern to every\n"
"positive check. This can be used to ensure that no instances of\n"
"this pattern occur which are not matched by a positive pattern"),
cl::value_desc("pattern"));
static cl::list<std::string>
GlobalDefines("D", cl::AlwaysPrefix,
cl::desc("Define a variable to be used in capture patterns."),
cl::value_desc("VAR=VALUE"));
static cl::opt<bool> AllowEmptyInput(
"allow-empty", cl::init(false),
cl::desc("Allow the input file to be empty. This is useful when making\n"
"checks that some error message does not occur, for example."));
static cl::opt<bool> AllowUnusedPrefixes(
"allow-unused-prefixes", cl::init(false), cl::ZeroOrMore,
cl::desc("Allow prefixes to be specified but not appear in the test."));
static cl::opt<bool> MatchFullLines(
"match-full-lines", cl::init(false),
cl::desc("Require all positive matches to cover an entire input line.\n"
"Allows leading and trailing whitespace if --strict-whitespace\n"
"is not also passed."));
static cl::opt<bool> EnableVarScope(
"enable-var-scope", cl::init(false),
cl::desc("Enables scope for regex variables. Variables with names that\n"
"do not start with '$' will be reset at the beginning of\n"
"each CHECK-LABEL block."));
static cl::opt<bool> AllowDeprecatedDagOverlap(
"allow-deprecated-dag-overlap", cl::init(false),
cl::desc("Enable overlapping among matches in a group of consecutive\n"
"CHECK-DAG directives. This option is deprecated and is only\n"
"provided for convenience as old tests are migrated to the new\n"
"non-overlapping CHECK-DAG implementation.\n"));
static cl::opt<bool> Verbose(
"v", cl::init(false), cl::ZeroOrMore,
cl::desc("Print directive pattern matches, or add them to the input dump\n"
"if enabled.\n"));
static cl::opt<bool> VerboseVerbose(
"vv", cl::init(false), cl::ZeroOrMore,
cl::desc("Print information helpful in diagnosing internal FileCheck\n"
"issues, or add it to the input dump if enabled. Implies\n"
"-v.\n"));
// The order of DumpInputValue members affects their precedence, as documented
// for -dump-input below.
enum DumpInputValue {
DumpInputNever,
DumpInputFail,
DumpInputAlways,
DumpInputHelp
};
static cl::list<DumpInputValue> DumpInputs(
"dump-input",
cl::desc("Dump input to stderr, adding annotations representing\n"
"currently enabled diagnostics. When there are multiple\n"
"occurrences of this option, the <value> that appears earliest\n"
"in the list below has precedence. The default is 'fail'.\n"),
cl::value_desc("mode"),
cl::values(clEnumValN(DumpInputHelp, "help", "Explain input dump and quit"),
clEnumValN(DumpInputAlways, "always", "Always dump input"),
clEnumValN(DumpInputFail, "fail", "Dump input on failure"),
clEnumValN(DumpInputNever, "never", "Never dump input")));
// The order of DumpInputFilterValue members affects their precedence, as
// documented for -dump-input-filter below.
enum DumpInputFilterValue {
DumpInputFilterError,
DumpInputFilterAnnotation,
DumpInputFilterAnnotationFull,
DumpInputFilterAll
};
static cl::list<DumpInputFilterValue> DumpInputFilters(
"dump-input-filter",
cl::desc("In the dump requested by -dump-input, print only input lines of\n"
"kind <value> plus any context specified by -dump-input-context.\n"
"When there are multiple occurrences of this option, the <value>\n"
"that appears earliest in the list below has precedence. The\n"
"default is 'error' when -dump-input=fail, and it's 'all' when\n"
"-dump-input=always.\n"),
cl::values(clEnumValN(DumpInputFilterAll, "all", "All input lines"),
clEnumValN(DumpInputFilterAnnotationFull, "annotation-full",
"Input lines with annotations"),
clEnumValN(DumpInputFilterAnnotation, "annotation",
"Input lines with starting points of annotations"),
clEnumValN(DumpInputFilterError, "error",
"Input lines with starting points of error "
"annotations")));
static cl::list<unsigned> DumpInputContexts(
"dump-input-context", cl::value_desc("N"),
cl::desc("In the dump requested by -dump-input, print <N> input lines\n"
"before and <N> input lines after any lines specified by\n"
"-dump-input-filter. When there are multiple occurrences of\n"
"this option, the largest specified <N> has precedence. The\n"
"default is 5.\n"));
typedef cl::list<std::string>::const_iterator prefix_iterator;
static void DumpCommandLine(int argc, char **argv) {
errs() << "FileCheck command line: ";
for (int I = 0; I < argc; I++)
errs() << " " << argv[I];
errs() << "\n";
}
struct MarkerStyle {
/// The starting char (before tildes) for marking the line.
char Lead;
/// What color to use for this annotation.
raw_ostream::Colors Color;
/// A note to follow the marker, or empty string if none.
std::string Note;
/// Does this marker indicate inclusion by -dump-input-filter=error?
bool FiltersAsError;
MarkerStyle() {}
MarkerStyle(char Lead, raw_ostream::Colors Color,
const std::string &Note = "", bool FiltersAsError = false)
: Lead(Lead), Color(Color), Note(Note), FiltersAsError(FiltersAsError) {
assert((!FiltersAsError || !Note.empty()) &&
"expected error diagnostic to have note");
}
};
static MarkerStyle GetMarker(FileCheckDiag::MatchType MatchTy) {
switch (MatchTy) {
case FileCheckDiag::MatchFoundAndExpected:
return MarkerStyle('^', raw_ostream::GREEN);
case FileCheckDiag::MatchFoundButExcluded:
return MarkerStyle('!', raw_ostream::RED, "error: no match expected",
/*FiltersAsError=*/true);
case FileCheckDiag::MatchFoundButWrongLine:
return MarkerStyle('!', raw_ostream::RED, "error: match on wrong line",
/*FiltersAsError=*/true);
case FileCheckDiag::MatchFoundButDiscarded:
return MarkerStyle('!', raw_ostream::CYAN,
"discard: overlaps earlier match");
case FileCheckDiag::MatchFoundErrorNote:
// Note should always be overridden within the FileCheckDiag.
return MarkerStyle('!', raw_ostream::RED,
"error: unknown error after match",
/*FiltersAsError=*/true);
case FileCheckDiag::MatchNoneAndExcluded:
return MarkerStyle('X', raw_ostream::GREEN);
case FileCheckDiag::MatchNoneButExpected:
return MarkerStyle('X', raw_ostream::RED, "error: no match found",
/*FiltersAsError=*/true);
case FileCheckDiag::MatchNoneForInvalidPattern:
return MarkerStyle('X', raw_ostream::RED,
"error: match failed for invalid pattern",
/*FiltersAsError=*/true);
case FileCheckDiag::MatchFuzzy:
return MarkerStyle('?', raw_ostream::MAGENTA, "possible intended match",
/*FiltersAsError=*/true);
}
llvm_unreachable_internal("unexpected match type");
}
static void DumpInputAnnotationHelp(raw_ostream &OS) {
OS << "The following description was requested by -dump-input=help to\n"
<< "explain the input dump printed by FileCheck.\n"
<< "\n"
<< "Related command-line options:\n"
<< "\n"
<< " - -dump-input=<value> enables or disables the input dump\n"
<< " - -dump-input-filter=<value> filters the input lines\n"
<< " - -dump-input-context=<N> adjusts the context of filtered lines\n"
<< " - -v and -vv add more annotations\n"
<< " - -color forces colors to be enabled both in the dump and below\n"
<< " - -help documents the above options in more detail\n"
<< "\n"
<< "These options can also be set via FILECHECK_OPTS. For example, for\n"
<< "maximum debugging output on failures:\n"
<< "\n"
<< " $ FILECHECK_OPTS='-dump-input-filter=all -vv -color' ninja check\n"
<< "\n"
<< "Input dump annotation format:\n"
<< "\n";
// Labels for input lines.
OS << " - ";
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "L:";
OS << " labels line number L of the input file\n"
<< " An extra space is added after each input line to represent"
<< " the\n"
<< " newline character\n";
// Labels for annotation lines.
OS << " - ";
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "T:L";
OS << " labels the only match result for either (1) a pattern of type T"
<< " from\n"
<< " line L of the check file if L is an integer or (2) the"
<< " I-th implicit\n"
<< " pattern if L is \"imp\" followed by an integer "
<< "I (index origin one)\n";
OS << " - ";
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "T:L'N";
OS << " labels the Nth match result for such a pattern\n";
// Markers on annotation lines.
OS << " - ";
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "^~~";
OS << " marks good match (reported if -v)\n"
<< " - ";
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "!~~";
OS << " marks bad match, such as:\n"
<< " - CHECK-NEXT on same line as previous match (error)\n"
<< " - CHECK-NOT found (error)\n"
<< " - CHECK-DAG overlapping match (discarded, reported if "
<< "-vv)\n"
<< " - ";
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "X~~";
OS << " marks search range when no match is found, such as:\n"
<< " - CHECK-NEXT not found (error)\n"
<< " - CHECK-NOT not found (success, reported if -vv)\n"
<< " - CHECK-DAG not found after discarded matches (error)\n"
<< " - ";
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "?";
OS << " marks fuzzy match when no match is found\n";
// Elided lines.
OS << " - ";
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "...";
OS << " indicates elided input lines and annotations, as specified by\n"
<< " -dump-input-filter and -dump-input-context\n";
// Colors.
OS << " - colors ";
WithColor(OS, raw_ostream::GREEN, true) << "success";
OS << ", ";
WithColor(OS, raw_ostream::RED, true) << "error";
OS << ", ";
WithColor(OS, raw_ostream::MAGENTA, true) << "fuzzy match";
OS << ", ";
WithColor(OS, raw_ostream::CYAN, true, false) << "discarded match";
OS << ", ";
WithColor(OS, raw_ostream::CYAN, true, true) << "unmatched input";
OS << "\n";
}
/// An annotation for a single input line.
struct InputAnnotation {
/// The index of the match result across all checks
unsigned DiagIndex;
/// The label for this annotation.
std::string Label;
/// Is this the initial fragment of a diagnostic that has been broken across
/// multiple lines?
bool IsFirstLine;
/// What input line (one-origin indexing) this annotation marks. This might
/// be different from the starting line of the original diagnostic if
/// !IsFirstLine.
unsigned InputLine;
/// The column range (one-origin indexing, open end) in which to mark the
/// input line. If InputEndCol is UINT_MAX, treat it as the last column
/// before the newline.
unsigned InputStartCol, InputEndCol;
/// The marker to use.
MarkerStyle Marker;
/// Whether this annotation represents a good match for an expected pattern.
bool FoundAndExpectedMatch;
};
/// Get an abbreviation for the check type.
static std::string GetCheckTypeAbbreviation(Check::FileCheckType Ty) {
switch (Ty) {
case Check::CheckPlain:
if (Ty.getCount() > 1)
return "count";
return "check";
case Check::CheckNext:
return "next";
case Check::CheckSame:
return "same";
case Check::CheckNot:
return "not";
case Check::CheckDAG:
return "dag";
case Check::CheckLabel:
return "label";
case Check::CheckEmpty:
return "empty";
case Check::CheckComment:
return "com";
case Check::CheckEOF:
return "eof";
case Check::CheckBadNot:
return "bad-not";
case Check::CheckBadCount:
return "bad-count";
case Check::CheckMisspelled:
return "misspelled";
case Check::CheckNone:
llvm_unreachable("invalid FileCheckType");
}
llvm_unreachable("unknown FileCheckType");
}
static void
BuildInputAnnotations(const SourceMgr &SM, unsigned CheckFileBufferID,
const std::pair<unsigned, unsigned> &ImpPatBufferIDRange,
const std::vector<FileCheckDiag> &Diags,
std::vector<InputAnnotation> &Annotations,
unsigned &LabelWidth) {
struct CompareSMLoc {
bool operator()(const SMLoc &LHS, const SMLoc &RHS) const {
return LHS.getPointer() < RHS.getPointer();
}
};
// How many diagnostics does each pattern have?
std::map<SMLoc, unsigned, CompareSMLoc> DiagCountPerPattern;
for (auto Diag : Diags)
++DiagCountPerPattern[Diag.CheckLoc];
// How many diagnostics have we seen so far per pattern?
std::map<SMLoc, unsigned, CompareSMLoc> DiagIndexPerPattern;
// How many total diagnostics have we seen so far?
unsigned DiagIndex = 0;
// What's the widest label?
LabelWidth = 0;
for (auto DiagItr = Diags.begin(), DiagEnd = Diags.end(); DiagItr != DiagEnd;
++DiagItr) {
InputAnnotation A;
A.DiagIndex = DiagIndex++;
// Build label, which uniquely identifies this check result.
unsigned CheckBufferID = SM.FindBufferContainingLoc(DiagItr->CheckLoc);
auto CheckLineAndCol =
SM.getLineAndColumn(DiagItr->CheckLoc, CheckBufferID);
llvm::raw_string_ostream Label(A.Label);
Label << GetCheckTypeAbbreviation(DiagItr->CheckTy) << ":";
if (CheckBufferID == CheckFileBufferID)
Label << CheckLineAndCol.first;
else if (ImpPatBufferIDRange.first <= CheckBufferID &&
CheckBufferID < ImpPatBufferIDRange.second)
Label << "imp" << (CheckBufferID - ImpPatBufferIDRange.first + 1);
else
llvm_unreachable("expected diagnostic's check location to be either in "
"the check file or for an implicit pattern");
if (DiagCountPerPattern[DiagItr->CheckLoc] > 1)
Label << "'" << DiagIndexPerPattern[DiagItr->CheckLoc]++;
LabelWidth = std::max((std::string::size_type)LabelWidth, A.Label.size());
A.Marker = GetMarker(DiagItr->MatchTy);
if (!DiagItr->Note.empty()) {
A.Marker.Note = DiagItr->Note;
// It's less confusing if notes that don't actually have ranges don't have
// markers. For example, a marker for 'with "VAR" equal to "5"' would
// seem to indicate where "VAR" matches, but the location we actually have
// for the marker simply points to the start of the match/search range for
// the full pattern of which the substitution is potentially just one
// component.
if (DiagItr->InputStartLine == DiagItr->InputEndLine &&
DiagItr->InputStartCol == DiagItr->InputEndCol)
A.Marker.Lead = ' ';
}
if (DiagItr->MatchTy == FileCheckDiag::MatchFoundErrorNote) {
assert(!DiagItr->Note.empty() &&
"expected custom note for MatchFoundErrorNote");
A.Marker.Note = "error: " + A.Marker.Note;
}
A.FoundAndExpectedMatch =
DiagItr->MatchTy == FileCheckDiag::MatchFoundAndExpected;
// Compute the mark location, and break annotation into multiple
// annotations if it spans multiple lines.
A.IsFirstLine = true;
A.InputLine = DiagItr->InputStartLine;
A.InputStartCol = DiagItr->InputStartCol;
if (DiagItr->InputStartLine == DiagItr->InputEndLine) {
// Sometimes ranges are empty in order to indicate a specific point, but
// that would mean nothing would be marked, so adjust the range to
// include the following character.
A.InputEndCol =
std::max(DiagItr->InputStartCol + 1, DiagItr->InputEndCol);
Annotations.push_back(A);
} else {
assert(DiagItr->InputStartLine < DiagItr->InputEndLine &&
"expected input range not to be inverted");
A.InputEndCol = UINT_MAX;
Annotations.push_back(A);
for (unsigned L = DiagItr->InputStartLine + 1, E = DiagItr->InputEndLine;
L <= E; ++L) {
// If a range ends before the first column on a line, then it has no
// characters on that line, so there's nothing to render.
if (DiagItr->InputEndCol == 1 && L == E)
break;
InputAnnotation B;
B.DiagIndex = A.DiagIndex;
B.Label = A.Label;
B.IsFirstLine = false;
B.InputLine = L;
B.Marker = A.Marker;
B.Marker.Lead = '~';
B.Marker.Note = "";
B.InputStartCol = 1;
if (L != E)
B.InputEndCol = UINT_MAX;
else
B.InputEndCol = DiagItr->InputEndCol;
B.FoundAndExpectedMatch = A.FoundAndExpectedMatch;
Annotations.push_back(B);
}
}
}
}
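// A worked example of the labeling scheme above (illustrative, not tool
// output): a directive on line 12 of the check file yields a label like
// "check:12"; a pattern from the second command-line -implicit-check-not
// yields "not:imp2"; and when one pattern produces several diagnostics, a
// quote-suffixed per-pattern index is appended, e.g. "dag:12'0", "dag:12'1".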
static unsigned FindInputLineInFilter(
DumpInputFilterValue DumpInputFilter, unsigned CurInputLine,
const std::vector<InputAnnotation>::iterator &AnnotationBeg,
const std::vector<InputAnnotation>::iterator &AnnotationEnd) {
if (DumpInputFilter == DumpInputFilterAll)
return CurInputLine;
for (auto AnnotationItr = AnnotationBeg; AnnotationItr != AnnotationEnd;
++AnnotationItr) {
switch (DumpInputFilter) {
case DumpInputFilterAll:
llvm_unreachable("unexpected DumpInputFilterAll");
break;
case DumpInputFilterAnnotationFull:
return AnnotationItr->InputLine;
case DumpInputFilterAnnotation:
if (AnnotationItr->IsFirstLine)
return AnnotationItr->InputLine;
break;
case DumpInputFilterError:
if (AnnotationItr->IsFirstLine && AnnotationItr->Marker.FiltersAsError)
return AnnotationItr->InputLine;
break;
}
}
return UINT_MAX;
}
/// To OS, print a vertical ellipsis (right-justified at LabelWidth) if it would
/// occupy fewer lines than ElidedLines, but print ElidedLines otherwise. Either
/// way, clear ElidedLines. Thus, if ElidedLines is empty, do nothing.
static void DumpEllipsisOrElidedLines(raw_ostream &OS, std::string &ElidedLines,
unsigned LabelWidth) {
if (ElidedLines.empty())
return;
unsigned EllipsisLines = 3;
if (EllipsisLines < StringRef(ElidedLines).count('\n')) {
for (unsigned i = 0; i < EllipsisLines; ++i) {
WithColor(OS, raw_ostream::BLACK, /*Bold=*/true)
<< right_justify(".", LabelWidth);
OS << '\n';
}
} else
OS << ElidedLines;
ElidedLines.clear();
}
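// For example (following the logic above): eliding two input lines prints
// the two lines themselves, since a three-row ellipsis would be larger,
// while eliding ten lines prints three right-justified '.' rows instead.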
static void DumpAnnotatedInput(raw_ostream &OS, const FileCheckRequest &Req,
DumpInputFilterValue DumpInputFilter,
unsigned DumpInputContext,
StringRef InputFileText,
std::vector<InputAnnotation> &Annotations,
unsigned LabelWidth) {
OS << "Input was:\n<<<<<<\n";
// Sort annotations.
llvm::sort(Annotations,
[](const InputAnnotation &A, const InputAnnotation &B) {
// 1. Sort annotations in the order of the input lines.
//
// This makes it easier to find relevant annotations while
// iterating input lines in the implementation below. FileCheck
// does not always produce diagnostics in the order of input
// lines due to, for example, CHECK-DAG and CHECK-NOT.
if (A.InputLine != B.InputLine)
return A.InputLine < B.InputLine;
// 2. Sort annotations in the temporal order FileCheck produced
// their associated diagnostics.
//
// This sort offers several benefits:
//
// A. On a single input line, the order of annotations reflects
// the FileCheck logic for processing directives/patterns.
// This can be helpful in understanding cases in which the
// order of the associated directives/patterns in the check
// file or on the command line either (i) does not match the
// temporal order in which FileCheck looks for matches for the
// directives/patterns (due to, for example, CHECK-LABEL,
// CHECK-NOT, or `--implicit-check-not`) or (ii) does match
// that order but does not match the order of those
// diagnostics along an input line (due to, for example,
// CHECK-DAG).
//
// On the other hand, because our presentation format presents
// input lines in order, there's no clear way to offer the
// same benefit across input lines. For consistency, it might
// then seem worthwhile to have annotations on a single line
// also sorted in input order (that is, by input column).
// However, in practice, this appears to be more confusing
// than helpful. Perhaps it's intuitive to expect annotations
// to be listed in the temporal order in which they were
// produced except in cases the presentation format obviously
// and inherently cannot support it (that is, across input
// lines).
//
// B. When diagnostics' annotations are split among multiple
// input lines, the user must track them from one input line
// to the next. One property of the sort chosen here is that
// it facilitates the user in this regard by ensuring the
// following: when comparing any two input lines, a
// diagnostic's annotations are sorted in the same position
// relative to all other diagnostics' annotations.
return A.DiagIndex < B.DiagIndex;
});
// Compute the width of the label column.
const unsigned char *InputFilePtr = InputFileText.bytes_begin(),
*InputFileEnd = InputFileText.bytes_end();
unsigned LineCount = InputFileText.count('\n');
if (InputFileEnd[-1] != '\n')
++LineCount;
unsigned LineNoWidth = std::log10(LineCount) + 1;
// +3 below adds spaces (1) to the left of the (right-aligned) line numbers
// on input lines and (2) to the right of the (left-aligned) labels on
// annotation lines so that input lines and annotation lines are more
// visually distinct. For example, the spaces on the annotation lines ensure
// that input line numbers and check directive line numbers never align
// horizontally. Those line numbers might not even be for the same file.
// One space would be enough to achieve that, but more makes it even easier
// to see.
LabelWidth = std::max(LabelWidth, LineNoWidth) + 3;
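// Example with assumed values: a 120-line input gives LineNoWidth =
// log10(120) + 1 = 3; if the widest label is "check:101" (9 chars),
// LabelWidth becomes max(9, 3) + 3 = 12.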
// Print annotated input lines.
unsigned PrevLineInFilter = 0; // 0 means none so far
unsigned NextLineInFilter = 0; // 0 means uncomputed, UINT_MAX means none
std::string ElidedLines;
raw_string_ostream ElidedLinesOS(ElidedLines);
ColorMode TheColorMode =
WithColor(OS).colorsEnabled() ? ColorMode::Enable : ColorMode::Disable;
if (TheColorMode == ColorMode::Enable)
ElidedLinesOS.enable_colors(true);
auto AnnotationItr = Annotations.begin(), AnnotationEnd = Annotations.end();
for (unsigned Line = 1;
InputFilePtr != InputFileEnd || AnnotationItr != AnnotationEnd; ++Line) {
const unsigned char *InputFileLine = InputFilePtr;
// Compute the previous and next line included by the filter.
if (NextLineInFilter < Line)
NextLineInFilter = FindInputLineInFilter(DumpInputFilter, Line,
AnnotationItr, AnnotationEnd);
assert(NextLineInFilter && "expected NextLineInFilter to be computed");
if (NextLineInFilter == Line)
PrevLineInFilter = Line;
// Elide this input line and its annotations if it's not within the
// context specified by -dump-input-context of an input line included by
// -dump-input-filter. However, in case the resulting ellipsis would occupy
// more lines than the input lines and annotations it elides, buffer the
// elided lines and annotations so we can print them instead.
raw_ostream *LineOS = &OS;
if ((!PrevLineInFilter || PrevLineInFilter + DumpInputContext < Line) &&
(NextLineInFilter == UINT_MAX ||
Line + DumpInputContext < NextLineInFilter))
LineOS = &ElidedLinesOS;
else {
LineOS = &OS;
DumpEllipsisOrElidedLines(OS, ElidedLinesOS.str(), LabelWidth);
}
// Print right-aligned line number.
WithColor(*LineOS, raw_ostream::BLACK, /*Bold=*/true, /*BF=*/false,
TheColorMode)
<< format_decimal(Line, LabelWidth) << ": ";
// For the case where -v and colors are enabled, find the annotations for
// good matches for expected patterns in order to highlight everything
// else in the line. There are no such annotations if -v is disabled.
std::vector<InputAnnotation> FoundAndExpectedMatches;
if (Req.Verbose && TheColorMode == ColorMode::Enable) {
for (auto I = AnnotationItr; I != AnnotationEnd && I->InputLine == Line;
++I) {
if (I->FoundAndExpectedMatch)
FoundAndExpectedMatches.push_back(*I);
}
}
// Print numbered line with highlighting where there are no matches for
// expected patterns.
bool Newline = false;
{
WithColor COS(*LineOS, raw_ostream::SAVEDCOLOR, /*Bold=*/false,
/*BG=*/false, TheColorMode);
bool InMatch = false;
if (Req.Verbose)
COS.changeColor(raw_ostream::CYAN, true, true);
for (unsigned Col = 1; InputFilePtr != InputFileEnd && !Newline; ++Col) {
bool WasInMatch = InMatch;
InMatch = false;
for (auto M : FoundAndExpectedMatches) {
if (M.InputStartCol <= Col && Col < M.InputEndCol) {
InMatch = true;
break;
}
}
if (!WasInMatch && InMatch)
COS.resetColor();
else if (WasInMatch && !InMatch)
COS.changeColor(raw_ostream::CYAN, true, true);
if (*InputFilePtr == '\n') {
Newline = true;
COS << ' ';
} else
COS << *InputFilePtr;
++InputFilePtr;
}
}
*LineOS << '\n';
unsigned InputLineWidth = InputFilePtr - InputFileLine;
// Print any annotations.
while (AnnotationItr != AnnotationEnd && AnnotationItr->InputLine == Line) {
WithColor COS(*LineOS, AnnotationItr->Marker.Color, /*Bold=*/true,
/*BG=*/false, TheColorMode);
// The two spaces below are where the ": " appears on input lines.
COS << left_justify(AnnotationItr->Label, LabelWidth) << "  ";
unsigned Col;
for (Col = 1; Col < AnnotationItr->InputStartCol; ++Col)
COS << ' ';
COS << AnnotationItr->Marker.Lead;
// If InputEndCol=UINT_MAX, stop at InputLineWidth.
for (++Col; Col < AnnotationItr->InputEndCol && Col <= InputLineWidth;
++Col)
COS << '~';
const std::string &Note = AnnotationItr->Marker.Note;
if (!Note.empty()) {
// Put the note at the end of the input line. If we were to instead
// put the note right after the marker, subsequent annotations for the
// same input line might appear to mark this note instead of the input
// line.
for (; Col <= InputLineWidth; ++Col)
COS << ' ';
COS << ' ' << Note;
}
COS << '\n';
++AnnotationItr;
}
}
DumpEllipsisOrElidedLines(OS, ElidedLinesOS.str(), LabelWidth);
OS << ">>>>>>\n";
}
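// A minimal sketch of the resulting layout for a hypothetical one-line input
// (alignment approximate, not actual tool output):
//
//   Input was:
//   <<<<<<
//           1: hello world
//   check:1    ^~~~~
//   >>>>>>
//
// where "1:" is the right-justified input line number and "check:1" is the
// left-justified label of the directive whose match is marked.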
int main(int argc, char **argv) {
// Enable use of ANSI color codes because FileCheck is using them to
// highlight text.
llvm::sys::Process::UseANSIEscapeCodes(true);
InitLLVM X(argc, argv);
cl::ParseCommandLineOptions(argc, argv, /*Overview*/ "", /*Errs*/ nullptr,
"FILECHECK_OPTS");
// Select -dump-input* values. The -help documentation specifies the default
// value and which value to choose if an option is specified multiple times.
// In the latter case, the general rule of thumb is to choose the value that
// provides the most information.
DumpInputValue DumpInput =
DumpInputs.empty()
? DumpInputFail
: *std::max_element(DumpInputs.begin(), DumpInputs.end());
DumpInputFilterValue DumpInputFilter;
if (DumpInputFilters.empty())
DumpInputFilter = DumpInput == DumpInputAlways ? DumpInputFilterAll
: DumpInputFilterError;
else
DumpInputFilter =
*std::max_element(DumpInputFilters.begin(), DumpInputFilters.end());
unsigned DumpInputContext = DumpInputContexts.empty()
? 5
: *std::max_element(DumpInputContexts.begin(),
DumpInputContexts.end());
if (DumpInput == DumpInputHelp) {
DumpInputAnnotationHelp(outs());
return 0;
}
if (CheckFilename.empty()) {
errs() << "<check-file> not specified\n";
return 2;
}
FileCheckRequest Req;
append_range(Req.CheckPrefixes, CheckPrefixes);
append_range(Req.CommentPrefixes, CommentPrefixes);
append_range(Req.ImplicitCheckNot, ImplicitCheckNot);
bool GlobalDefineError = false;
for (StringRef G : GlobalDefines) {
size_t EqIdx = G.find('=');
if (EqIdx == std::string::npos) {
errs() << "Missing equal sign in command-line definition '-D" << G
<< "'\n";
GlobalDefineError = true;
continue;
}
if (EqIdx == 0) {
errs() << "Missing variable name in command-line definition '-D" << G
<< "'\n";
GlobalDefineError = true;
continue;
}
Req.GlobalDefines.push_back(G);
}
if (GlobalDefineError)
return 2;
Req.AllowEmptyInput = AllowEmptyInput;
Req.AllowUnusedPrefixes = AllowUnusedPrefixes;
Req.EnableVarScope = EnableVarScope;
Req.AllowDeprecatedDagOverlap = AllowDeprecatedDagOverlap;
Req.Verbose = Verbose;
Req.VerboseVerbose = VerboseVerbose;
Req.NoCanonicalizeWhiteSpace = NoCanonicalizeWhiteSpace;
Req.MatchFullLines = MatchFullLines;
Req.IgnoreCase = IgnoreCase;
if (VerboseVerbose)
Req.Verbose = true;
FileCheck FC(Req);
if (!FC.ValidateCheckPrefixes())
return 2;
Regex PrefixRE = FC.buildCheckPrefixRegex();
std::string REError;
if (!PrefixRE.isValid(REError)) {
errs() << "Unable to combine check-prefix strings into a prefix regular "
"expression! This is likely a bug in FileCheck's verification of "
"the check-prefix strings. Regular expression parsing failed "
"with the following error: "
<< REError << "\n";
return 2;
}
SourceMgr SM;
// Read the expected strings from the check file.
ErrorOr<std::unique_ptr<MemoryBuffer>> CheckFileOrErr =
MemoryBuffer::getFileOrSTDIN(CheckFilename, /*IsText=*/true);
if (std::error_code EC = CheckFileOrErr.getError()) {
errs() << "Could not open check file '" << CheckFilename
<< "': " << EC.message() << '\n';
return 2;
}
MemoryBuffer &CheckFile = *CheckFileOrErr.get();
SmallString<4096> CheckFileBuffer;
StringRef CheckFileText = FC.CanonicalizeFile(CheckFile, CheckFileBuffer);
unsigned CheckFileBufferID =
SM.AddNewSourceBuffer(MemoryBuffer::getMemBuffer(
CheckFileText, CheckFile.getBufferIdentifier()),
SMLoc());
std::pair<unsigned, unsigned> ImpPatBufferIDRange;
if (FC.readCheckFile(SM, CheckFileText, PrefixRE, &ImpPatBufferIDRange))
return 2;
// Open the file to check and add it to SourceMgr.
ErrorOr<std::unique_ptr<MemoryBuffer>> InputFileOrErr =
MemoryBuffer::getFileOrSTDIN(InputFilename, /*IsText=*/true);
if (InputFilename == "-")
InputFilename = "<stdin>"; // Overwrite for improved diagnostic messages
if (std::error_code EC = InputFileOrErr.getError()) {
errs() << "Could not open input file '" << InputFilename
<< "': " << EC.message() << '\n';
return 2;
}
MemoryBuffer &InputFile = *InputFileOrErr.get();
if (InputFile.getBufferSize() == 0 && !AllowEmptyInput) {
errs() << "FileCheck error: '" << InputFilename << "' is empty.\n";
DumpCommandLine(argc, argv);
return 2;
}
SmallString<4096> InputFileBuffer;
StringRef InputFileText = FC.CanonicalizeFile(InputFile, InputFileBuffer);
SM.AddNewSourceBuffer(MemoryBuffer::getMemBuffer(
InputFileText, InputFile.getBufferIdentifier()),
SMLoc());
std::vector<FileCheckDiag> Diags;
int ExitCode = FC.checkInput(SM, InputFileText,
DumpInput == DumpInputNever ? nullptr : &Diags)
? EXIT_SUCCESS
: 1;
if (DumpInput == DumpInputAlways ||
(ExitCode == 1 && DumpInput == DumpInputFail)) {
errs() << "\n"
<< "Input file: " << InputFilename << "\n"
<< "Check file: " << CheckFilename << "\n"
<< "\n"
<< "-dump-input=help explains the following input dump.\n"
<< "\n";
std::vector<InputAnnotation> Annotations;
unsigned LabelWidth;
BuildInputAnnotations(SM, CheckFileBufferID, ImpPatBufferIDRange, Diags,
Annotations, LabelWidth);
DumpAnnotatedInput(errs(), Req, DumpInputFilter, DumpInputContext,
InputFileText, Annotations, LabelWidth);
}
return ExitCode;
}
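// Exit-code summary of the paths above: 0 (EXIT_SUCCESS) when all checks
// match, 1 when a check fails, and 2 for usage or I/O errors (invalid
// prefixes, malformed -D definitions, unreadable check/input files, etc.).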

View File

@@ -2,6 +2,7 @@
#define TRITON_CONVERSION_PASSES_H
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Conversion/TritonGPUToLLVM/ArithToIndexPass.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"

View File

@@ -50,4 +50,18 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
];
}
def TritonConvertArithToIndex : Pass<"triton-convert-arith-to-index", "mlir::ModuleOp"> {
let summary = "Convert arith to index";
let constructor = "mlir::triton::createTritonConvertArithToIndexPass()";
let description = [{
Convert arith operations on index values to the corresponding ops in the index dialect.
We need this because SCFToCF conversion currently generates arith ops on indices.
}];
let dependentDialects = ["mlir::index::IndexDialect"];
}
#endif

View File

@@ -0,0 +1,20 @@
#ifndef TRITON_CONVERSION_ARITH_TO_INDEX_H
#define TRITON_CONVERSION_ARITH_TO_INDEX_H
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Transforms/DialectConversion.h"
#include <memory>
namespace mlir {
class ModuleOp;
template <typename T> class OperationPass;
namespace triton {
std::unique_ptr<OperationPass<ModuleOp>> createTritonConvertArithToIndexPass();
}
} // namespace mlir
#endif
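A minimal usage sketch for the factory declared above, assuming a standard
mlir::PassManager (not part of this commit):

  mlir::PassManager pm(&context);
  pm.addPass(mlir::triton::createTritonConvertArithToIndexPass());
  if (mlir::failed(pm.run(module)))
    llvm::errs() << "arith-to-index conversion failed\n";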

View File

@@ -72,7 +72,7 @@ def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape,
def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
Pure,
DeclareOpInterfaceMethods<CastOpInterface>]> {
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
let summary = "Floating point casting for custom types";
let description = [{
@@ -405,19 +405,31 @@ def TT_MakeRangeOp : TT_Op<"make_range", [Pure]> {
}
//
// Make PrintfOp
// Make PrintOp
//
def TT_PrintfOp : TT_Op<"printf", [MemoryEffects<[MemWrite]>]>,
Arguments<(ins StrAttr:$prefix,
Variadic<AnyTypeOf<[TT_Type]>>:$args)> {
let summary = "Device-side printf, as in CUDA for debugging";
def TT_PrintOp : TT_Op<"print", [MemoryEffects<[MemWrite]>]>,
Arguments<(ins StrAttr:$prefix, Variadic<AnyTypeOf<[TT_Type]>>:$args)> {
let summary = "Device-side print, as in CUDA for debugging";
let description = [{
`tt.printf` takes a literal string prefix and an arbitrary number of scalar or tensor arguments that should be printed.
`tt.print` takes a literal string prefix and an arbitrary number of scalar or tensor arguments that should be printed.
The format string is generated automatically from the arguments.
}];
let assemblyFormat = [{
$prefix attr-dict ($args^ `:` type($args))?
$prefix attr-dict `:` ($args^ `:` type($args))?
}];
}
//
// Make AssertOp
//
def TT_AssertOp : TT_Op<"assert", [MemoryEffects<[MemWrite]>]> {
let summary = "Device-side assert, as in CUDA for correctness checking";
let description = [{
`tt.assert` takes a condition tensor, a message string, a file string, a function string, and a line number.
If the condition is false, the message is printed, and the program is aborted.
}];
let arguments = (ins TT_Tensor:$condition, StrAttr:$message, StrAttr:$file, StrAttr:$func, I32Attr:$line);
let assemblyFormat = "$condition `,` $message `,` $file `,` $func `,` $line attr-dict `:` type($condition)";
}
#endif // Triton_OPS
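For reference, the assembly format declared for TT_AssertOp above admits IR of
the following shape (a hypothetical instance, not taken from this commit):

  tt.assert %cond, "lhs == rhs", "kernel.py", "my_kernel", 42 : tensor<128xi1>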

View File

@@ -14,9 +14,7 @@ class TritonTypeDef<string name, string _mnemonic>
}
// Floating-point Type
def F8 : TritonTypeDef<"Float8", "f8">;
def TT_Float : AnyTypeOf<[F8, F16, BF16, F32, F64], "floating-point">;
def TT_Float : AnyTypeOf<[F8E4M3FN, F8E5M2, F16, BF16, F32, F64], "floating-point">;
def TT_FloatTensor : TensorOf<[TT_Float]>;
def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;

View File

@@ -35,7 +35,7 @@ Right now, Triton implements two classes of layouts: shared, and distributed.
}];
code extraBaseClassDeclaration = [{
unsigned getElemsPerThread(ArrayRef<int64_t> shape) const;
unsigned getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
::mlir::LogicalResult verifyLayoutForArg(::mlir::Operation* op, unsigned argNo) const;
}];
}
@@ -382,44 +382,63 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
);
let builders = [
// Specially for MMAV1(Volta)
AttrBuilder<(ins "int":$versionMajor,
"int":$numWarps,
"ArrayRef<int64_t>":$shapeC,
"bool":$isARow,
"bool":$isBRow,
"bool":$isAVec4,
"bool":$isBVec4,
"int":$id), [{
assert(versionMajor == 1 && "This builder is specially for versionMajor==1");
SmallVector<unsigned> wpt({static_cast<unsigned>(numWarps), 1});
int versionMinor = 0;
// 4-bits to encode 4 booleans: [isARow, isBRow, isAVec4, isBVec4]
int versionMinor = (isARow * (1<<0)) |\
(isBRow * (1<<1)) |\
(isAVec4 * (1<<2)) |\
(isBVec4 * (1<<3));
assert(id < (1<<numBitsToHoldMmaV1ID) && "MMAv1 ID exceeds the maximum");
for (int i = 0; i < numBitsToHoldMmaV1ID; ++i)
versionMinor |= static_cast<bool>((1<<i) & id) * (1<<(4+i));
// TODO: Share code with
// DotOpMmaV1ConversionHelper::AParam/BParam, since the same code computes
// rep, spw, and fpw.
SmallVector<unsigned> wpt({1, 1});
SmallVector<unsigned> wpt_nm1;
SmallVector<int, 2> rep(2), spw(2);
std::array<int, 3> fpw{{2, 2, 1}};
int packSize0 = (isARow || isAVec4) ? 1 : 2;
rep[0] = 2 * packSize0;
spw[0] = fpw[0] * 4 * rep[0];
int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
rep[1] = 2 * packSize1;
spw[1] = fpw[1] * 4 * rep[1];
do {
wpt_nm1 = wpt;
if (wpt[0] * wpt[1] < numWarps)
wpt[0] = std::clamp<int>(wpt[0] * 2, 1, shapeC[0] / spw[0]);
if (wpt[0] * wpt[1] < numWarps)
wpt[1] = std::clamp<int>(wpt[1] * 2, 1, shapeC[1] / spw[1]);
} while (wpt_nm1 != wpt);
return $_get(context, versionMajor, versionMinor, wpt);
}]>,
// Specially for MMAV1(Volta)
AttrBuilder<(ins "int":$versionMajor,
"ArrayRef<unsigned>":$warpsPerCTA,
"int":$numWarps,
"ArrayRef<int64_t>":$shapeA,
"ArrayRef<int64_t>":$shapeB,
"ArrayRef<int64_t>":$shapeC,
"bool":$isARow,
"bool":$isBRow,
"int":$id), [{
assert(versionMajor == 1 && "This builder is specially for versionMajor==1");
bool isAVec4 = !isARow && (shapeA[isARow] <= 16);
bool isBVec4 = isBRow && (shapeB[isBRow] <= 16);
// 4-bits to encode 4 booleans: [isARow, isBRow, isAVec4, isBVec4]
// 3-bits to encode the MMA ID to make each unique
int versionMinor = (isARow * (1<<0)) |\
(isBRow * (1<<1)) |\
(isAVec4 * (1<<2)) |\
(isBVec4 * (1<<3));
assert(id < (1<<numBitsToHoldMmaV1ID) && "MMAv1 ID exceeds the maximum");
for (int i = 0; i < numBitsToHoldMmaV1ID; ++i)
versionMinor |= static_cast<bool>((1<<i) & id) * (1<<(4+i));
return $_get(context, versionMajor, versionMinor, warpsPerCTA);
return get(context, versionMajor, numWarps, shapeC, isARow, isBRow, isAVec4, isBVec4, id);
}]>
];
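To make the versionMinor packing used by both builders above concrete, here is
a standalone C++ sketch (numBitsToHoldMmaV1ID assumed to be 3, per the "3 bits
to encode the MMA ID" comment):

  #include <cassert>
  constexpr int numBitsToHoldMmaV1ID = 3; // assumption, not from this file
  int encodeVersionMinor(bool isARow, bool isBRow, bool isAVec4, bool isBVec4,
                         int id) {
    // Bits 0-3 hold the four layout booleans; bits 4-6 hold the MMA ID.
    int v = (isARow << 0) | (isBRow << 1) | (isAVec4 << 2) | (isBVec4 << 3);
    assert(id < (1 << numBitsToHoldMmaV1ID) && "MMAv1 ID exceeds the maximum");
    for (int i = 0; i < numBitsToHoldMmaV1ID; ++i)
      v |= static_cast<bool>(id & (1 << i)) * (1 << (4 + i));
    return v;
  }
  // E.g. encodeVersionMinor(true, false, true, false, 2) == 0b0100101 == 37.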
@@ -489,25 +508,21 @@ section 9.7.13.4.1 for more details.
let parameters = (
ins
"unsigned":$opIdx,
"Attribute":$parent,
"Attribute":$isMMAv1Row
"Attribute":$parent
);
let builders = [
AttrBuilder<(ins "unsigned":$opIdx,
"Attribute":$parent), [{
Attribute isMMAv1Row;
if(parent.isa<MmaEncodingAttr>() &&
parent.cast<MmaEncodingAttr>().isVolta()){
isMMAv1Row = BoolAttr::get(context, true);
}
return $_get(context, opIdx, parent, isMMAv1Row);
}]>
];
let hasCustomAssemblyFormat = 1;
let extraClassDeclaration = extraBaseClassDeclaration;
let extraClassDeclaration = extraBaseClassDeclaration # [{
bool getMMAv1IsRow() const;
bool getMMAv1IsVec4() const;
SmallVector<int> getMMAv1Rep() const;
SmallVector<int> getMMAv1ShapePerWarp() const;
int getMMAv1Vec() const;
int getMMAv1NumOuter(ArrayRef<int64_t> shape) const;
}];
}

View File

@@ -9,6 +9,8 @@ include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/ViewLikeInterface.td"
def ResultsAreSharedEncoding: NativeOpTrait<"ResultsAreSharedEncoding">;
@@ -105,6 +107,69 @@ def TTG_SelectOp : TTG_Op<"select", [Pure, Elementwise,
}
def TTG_ExtractSliceOp : TTG_Op<"extract_slice",
[AttrSizedOperandSegments,
ResultsAreSharedEncoding,
Pure,
OffsetSizeAndStrideOpInterface
]> {
let summary = "extract slice operation";
let description = [{
same as tensor.extract_slice, but with int32 index. The motivations for re-implementing it are:
We reimplement ExtractSliceOp with int32 index, because:
- we want to enforce int32 indexing on GPUs since Triton tensors fit in SRAM
- we still want to use indexWidth = 64 when lowering to LLVM because our loops can have
64-bit induction variables and scf.for uses indexType for bounds/ivs
}];
let arguments = (ins
AnyRankedTensor:$source,
Variadic<I32>:$offsets,
Variadic<I32>:$sizes,
Variadic<I32>:$strides,
DenseI64ArrayAttr:$static_offsets,
DenseI64ArrayAttr:$static_sizes,
DenseI64ArrayAttr:$static_strides
);
let results = (outs AnyRankedTensor:$result);
let builders = [
// Build an ExtractSliceOp with mixed static and dynamic entries and custom
// result type. If the type passed is nullptr, it is inferred.
OpBuilder<(ins "RankedTensorType":$resultType, "Value":$source,
"ArrayRef<OpFoldResult>":$offsets, "ArrayRef<OpFoldResult>":$sizes,
"ArrayRef<OpFoldResult>":$strides,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
];
let extraClassDeclaration = [{
/// Return the number of leading operands before the `offsets`, `sizes`,
/// and `strides` operands.
static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
/// Returns the type of the base tensor operand.
RankedTensorType getSourceType() {
return getSource().getType().cast<RankedTensorType>();
}
std::array<unsigned, 3> getArrayAttrMaxRanks() {
unsigned rank = getSourceType().getRank();
return {rank, rank, rank};
}
}];
let assemblyFormat = [{
$source ``
custom<DynamicIndexList>($offsets, $static_offsets)
custom<DynamicIndexList>($sizes, $static_sizes)
custom<DynamicIndexList>($strides, $static_strides)
attr-dict `:` type($source) `to` type($result)
}];
}
//
def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
[AttrSizedOperandSegments,
ResultsAreSharedEncoding,

View File

@@ -23,7 +23,7 @@ std::unique_ptr<Pass> createTritonGPURemoveLayoutConversionsPass();
std::unique_ptr<Pass> createTritonGPUVerifier();
std::unique_ptr<Pass> createTritonGPUFuseTranspositionsPass();
std::unique_ptr<Pass> createTritonGPUOptimizeDotOperandsPass();
std::unique_ptr<Pass> createTritonGPUUpdateMmaForVoltaPass();

View File

@@ -60,7 +60,7 @@ def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::Modul
}
def TritonGPUFuseTranspositions : Pass<"tritongpu-fuse-transposition", "mlir::ModuleOp"> {
def TritonGPUOptimizeDotOperands : Pass<"tritongpu-optimize-dot-operands", "mlir::ModuleOp"> {
let summary = "fuse transpositions";
let description = [{
@@ -68,7 +68,7 @@ def TritonGPUFuseTranspositions : Pass<"tritongpu-fuse-transposition", "mlir::Mo
hardware-accelerated transpositions.
}];
let constructor = "mlir::createTritonGPUFuseTranspositionsPass()";
let constructor = "mlir::createTritonGPUOptimizeDotOperandsPass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];
@@ -86,6 +86,7 @@ def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> {
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
}
def TritonGPURemoveLayoutConversions : Pass<"tritongpu-remove-layout-conversions", "mlir::ModuleOp"> {
let summary = "remove superfluous layout conversions";
@@ -122,16 +123,4 @@ def TritonGPUDecomposeConversions: Pass<"tritongpu-decompose-conversions", "mlir
"mlir::triton::TritonDialect"];
}
def UpdateMmaForVolta : Pass<"tritongpu-update-mma-for-volta", "mlir::ModuleOp"> {
let summary = "Update mma encodings for Volta";
let description = [{
This helps to update the mma encodings for Volta.
}];
let constructor = "mlir::createTritonGPUUpdateMmaForVoltaPass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
}
#endif

View File

@@ -27,7 +27,7 @@ void SharedMemoryAliasAnalysis::visitOperation(
// These ops may allocate a new shared memory buffer.
auto result = op->getResult(0);
// XXX(Keren): the following ops are always aliasing for now
if (isa<tensor::ExtractSliceOp, triton::TransOp>(op)) {
if (isa<triton::gpu::ExtractSliceOp, triton::TransOp>(op)) {
// extract_slice %src
// trans %src
aliasInfo = AliasInfo(operands[0]->getValue());

View File

@@ -846,10 +846,14 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
void AxisInfoAnalysis::visitOperation(
Operation *op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
ArrayRef<dataflow::Lattice<AxisInfo> *> results) {
// TODO: surely not the right way to do this,
// but why is scf.if not initialized otherwise?
for (auto op : operands)
if (op->getValue().getRank() == 0)
setToEntryState((dataflow::Lattice<AxisInfo> *)op);
AxisInfo curr = visitors.apply(op, operands);
if (curr.getRank() == 0) {
if (curr.getRank() == 0)
return setAllToEntryStates(results);
}
// override with hint
auto newContiguity = curr.getContiguity();
auto newDivisibility = curr.getDivisibility();

View File

@@ -78,8 +78,8 @@ void MembarAnalysis::visitTerminator(Operation *op,
void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
OpBuilder *builder) {
if (isa<tensor::ExtractSliceOp>(op) || isa<triton::gpu::AllocTensorOp>(op) ||
isa<triton::TransOp>(op)) {
if (isa<triton::gpu::ExtractSliceOp>(op) ||
isa<triton::gpu::AllocTensorOp>(op) || isa<triton::TransOp>(op)) {
// alloc is an allocation op without memory write.
// FIXME(Keren): extract_slice is always alias for now
return;

View File

@@ -117,7 +117,7 @@ bool maybeSharedAllocationOp(Operation *op) {
}
bool maybeAliasOp(Operation *op) {
return isa<tensor::ExtractSliceOp>(op) || isa<triton::TransOp>(op) ||
return isa<triton::gpu::ExtractSliceOp>(op) || isa<triton::TransOp>(op) ||
isa<triton::gpu::InsertSliceAsyncOp>(op) ||
isa<tensor::InsertSliceOp>(op);
}

View File

@@ -0,0 +1,92 @@
#include "triton/Conversion/TritonGPUToLLVM/ArithToIndexPass.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM//ControlFlowToLLVM.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Pass/Pass.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
using namespace mlir;
using namespace mlir::triton;
#define GEN_PASS_CLASSES
#include "triton/Conversion/Passes.h.inc"
namespace {
class TritonArithToIndexConversionTarget : public mlir::ConversionTarget {
public:
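// Note (added for clarity): despite its name, this returns true when the op
// has *no* index-typed results or operands. It serves as the dynamic
// legality callback below, so arith ops that touch index values are marked
// illegal and therefore rewritten.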
static bool hasIndexResultOrOperand(Operation *op) {
if (!op)
return false;
bool hasRetIndex = llvm::find_if(op->getResultTypes(), [](Type type) {
return type.isIndex();
}) != op->getResultTypes().end();
bool hasArgIndex = llvm::find_if(op->getOperandTypes(), [](Type type) {
return type.isIndex();
}) != op->getOperandTypes().end();
return !hasRetIndex && !hasArgIndex;
}
explicit TritonArithToIndexConversionTarget(MLIRContext &ctx)
: ConversionTarget(ctx) {
addLegalDialect<index::IndexDialect>();
addDynamicallyLegalDialect<arith::ArithDialect>(hasIndexResultOrOperand);
}
};
template <class SrcOp, class DstOp>
LogicalResult replaceArithWithIndex(SrcOp op, PatternRewriter &rewriter) {
// if (!hasIndexResultOrOperand(&*op))
// return failure();
rewriter.replaceOpWithNewOp<DstOp>(op, op->getResultTypes(),
op->getOperands(), op->getAttrs());
return success();
}
LogicalResult replaceArithCmpWithIndexCmp(arith::CmpIOp op,
PatternRewriter &rewriter) {
// if (!hasIndexResultOrOperand(&*op))
// return failure();
rewriter.replaceOpWithNewOp<index::CmpOp>(
op, op.getResult().getType(), (index::IndexCmpPredicate)op.getPredicate(),
op.getOperand(0), op.getOperand(1));
return success();
}
class ArithToIndex : public TritonConvertArithToIndexBase<ArithToIndex> {
public:
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp mod = getOperation();
TritonArithToIndexConversionTarget target(*context);
RewritePatternSet patterns(context);
patterns.add(replaceArithWithIndex<arith::IndexCastOp, index::CastSOp>);
patterns.add(replaceArithWithIndex<arith::ConstantOp, index::ConstantOp>);
patterns.add(replaceArithWithIndex<arith::AddIOp, index::AddOp>);
patterns.add(replaceArithCmpWithIndexCmp);
if (failed(applyPartialConversion(mod, target, std::move(patterns)))) {
return signalPassFailure();
}
}
};
} // namespace
namespace mlir {
namespace triton {
std::unique_ptr<OperationPass<ModuleOp>> createTritonConvertArithToIndexPass() {
return std::make_unique<::ArithToIndex>();
}
} // namespace triton
} // namespace mlir
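// Net effect of the registered patterns, for orientation: arith.index_cast
// becomes index.casts, arith.constant on index values becomes index.constant,
// arith.addi becomes index.add, and arith.cmpi becomes index.cmp with the
// predicate value reinterpreted as an index::IndexCmpPredicate.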

View File

@@ -1,17 +1,16 @@
add_mlir_conversion_library(TritonGPUToLLVM
TritonGPUToLLVM.cpp
GCNAsmFormat.cpp
PTXAsmFormat.cpp
TritonGPUToLLVMPass.cpp
ArithToIndexPass.cpp
ConvertLayoutOpToLLVM.cpp
DotOpToLLVM.cpp
ElementwiseOpToLLVM.cpp
LoadStoreOpToLLVM.cpp
TritonGPUToLLVM.cpp
TritonGPUToLLVMPass.cpp
GCNAsmFormat.cpp
PTXAsmFormat.cpp
ReduceOpToLLVM.cpp
Utility.cpp
TypeConverter.cpp
ViewOpToLLVM.cpp
DotOpHelpers.cpp

View File

@@ -4,10 +4,8 @@
using ::mlir::LLVM::DotOpFMAConversionHelper;
using ::mlir::LLVM::DotOpMmaV1ConversionHelper;
using ::mlir::LLVM::getElementsFromStruct;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::getStridesFromShapeAndOrder;
using ::mlir::LLVM::getStructFromElements;
using ::mlir::LLVM::MMA16816ConversionHelper;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getContigPerThread;
@@ -56,28 +54,32 @@ public:
private:
SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
ConversionPatternRewriter &rewriter,
unsigned elemId, ArrayRef<int64_t> shape,
unsigned elemId, RankedTensorType type,
ArrayRef<unsigned> multiDimCTAInRepId,
ArrayRef<unsigned> shapePerCTA) const {
auto shape = type.getShape();
unsigned rank = shape.size();
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
auto multiDimOffsetFirstElem =
emitBaseIndexForLayout(loc, rewriter, blockedLayout, shape);
emitBaseIndexForLayout(loc, rewriter, blockedLayout, type);
SmallVector<Value> multiDimOffset(rank);
SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
elemId, getSizePerThread(layout), getOrder(layout));
for (unsigned d = 0; d < rank; ++d) {
multiDimOffset[d] = add(multiDimOffsetFirstElem[d],
idx_val(multiDimCTAInRepId[d] * shapePerCTA[d] +
i32_val(multiDimCTAInRepId[d] * shapePerCTA[d] +
multiDimElemId[d]));
}
return multiDimOffset;
}
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
unsigned dim = sliceLayout.getDim();
auto parentEncoding = sliceLayout.getParent();
auto parentShape = sliceLayout.paddedShape(shape);
auto parentTy = RankedTensorType::get(parentShape, type.getElementType(),
parentEncoding);
auto multiDimOffsetParent =
getMultiDimOffset(sliceLayout.getParent(), loc, rewriter, elemId,
sliceLayout.paddedShape(shape),
getMultiDimOffset(parentEncoding, loc, rewriter, elemId, parentTy,
sliceLayout.paddedShape(multiDimCTAInRepId),
sliceLayout.paddedShape(shapePerCTA));
SmallVector<Value> multiDimOffset(rank);
@@ -94,24 +96,24 @@ private:
SmallVector<Value> mmaRowIdx(2);
Value threadId = getThreadId(rewriter, loc);
#ifdef USE_ROCM
Value warpSize = idx_val(64);
Value warpSize = i32_val(64);
#else
Value warpSize = idx_val(32);
Value warpSize = i32_val(32);
#endif
Value laneId = urem(threadId, warpSize);
Value warpId = udiv(threadId, warpSize);
// TODO: fix the bug in the MMAEncodingAttr documentation
SmallVector<Value> multiDimWarpId(2);
multiDimWarpId[0] = urem(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
multiDimWarpId[1] = udiv(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
Value _1 = idx_val(1);
Value _2 = idx_val(2);
Value _4 = idx_val(4);
Value _8 = idx_val(8);
Value _16 = idx_val(16);
multiDimWarpId[0] = urem(warpId, i32_val(mmaLayout.getWarpsPerCTA()[0]));
multiDimWarpId[1] = udiv(warpId, i32_val(mmaLayout.getWarpsPerCTA()[0]));
Value _1 = i32_val(1);
Value _2 = i32_val(2);
Value _4 = i32_val(4);
Value _8 = i32_val(8);
Value _16 = i32_val(16);
if (mmaLayout.isAmpere()) {
multiDimWarpId[0] = urem(multiDimWarpId[0], idx_val(shape[0] / 16));
multiDimWarpId[1] = urem(multiDimWarpId[1], idx_val(shape[1] / 8));
multiDimWarpId[0] = urem(multiDimWarpId[0], i32_val(shape[0] / 16));
multiDimWarpId[1] = urem(multiDimWarpId[1], i32_val(shape[1] / 8));
Value mmaGrpId = udiv(laneId, _4);
Value mmaGrpIdP8 = add(mmaGrpId, _8);
Value mmaThreadIdInGrp = urem(laneId, _4);
@@ -135,15 +137,15 @@ private:
multiDimOffset[0] = elemId < 2 ? mmaRowIdx[0] : mmaRowIdx[1];
multiDimOffset[1] = elemId % 2 == 0 ? mmaColIdx[0] : mmaColIdx[1];
multiDimOffset[0] = add(
multiDimOffset[0], idx_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
multiDimOffset[0], i32_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
multiDimOffset[1] = add(
multiDimOffset[1], idx_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
multiDimOffset[1], i32_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
} else if (mmaLayout.isVolta()) {
auto [isARow, isBRow, isAVec4, isBVec4, mmaId] =
auto [isARow, isBRow, isAVec4, isBVec4, _] =
mmaLayout.decodeVoltaLayoutStates();
auto coords = DotOpMmaV1ConversionHelper::getMNCoords(
threadId, rewriter, mmaLayout.getWarpsPerCTA(), shape, isARow,
isBRow, isAVec4, isBVec4);
threadId, rewriter, mmaLayout.getWarpsPerCTA(), mmaLayout, shape,
isARow, isBRow, isAVec4, isBVec4);
return DotOpMmaV1ConversionHelper::getCoord(elemId, coords);
} else {
llvm_unreachable("Unexpected MMALayout version");
@@ -199,7 +201,7 @@ private:
// of performance issue observed.
for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
SmallVector<Value> multiDimOffset =
getMultiDimOffset(layout, loc, rewriter, elemId, type.getShape(),
getMultiDimOffset(layout, loc, rewriter, elemId, type,
multiDimCTAInRepId, shapePerCTA);
Value offset =
linearize(rewriter, loc, multiDimOffset, paddedRepShape, outOrd);
@@ -216,13 +218,13 @@ private:
currVal = zext(llvmElemTy, currVal);
else if (isPtr)
currVal = ptrtoint(llvmElemTy, currVal);
valVec = insert_element(vecTy, valVec, currVal, idx_val(v));
valVec = insert_element(vecTy, valVec, currVal, i32_val(v));
}
store(valVec, ptr);
} else {
Value valVec = load(ptr);
for (unsigned v = 0; v < vec; ++v) {
Value currVal = extract_element(llvmElemTy, valVec, idx_val(v));
Value currVal = extract_element(llvmElemTy, valVec, i32_val(v));
if (isInt1)
currVal = icmp_ne(currVal,
rewriter.create<LLVM::ConstantOp>(
@@ -289,18 +291,15 @@ private:
// TODO[Superjomn]: Move the coordinate computation out of loop, it is
// duplicate in Volta.
SmallVector<Value> multiDimOffset =
getMultiDimOffset(layout, loc, rewriter, elemId, type.getShape(),
getMultiDimOffset(layout, loc, rewriter, elemId, type,
multiDimCTAInRepId, shapePerCTA);
coord2val[elemId] = std::make_pair(multiDimOffset, vals[elemId]);
}
if (needTrans) {
auto [isARow, isBRow, isAVec4, isBVec4, mmaId] =
mma.decodeVoltaLayoutStates();
DotOpMmaV1ConversionHelper helper(mma);
// do transpose
int numM = helper.getElemsM(mma.getWarpsPerCTA()[0], shape[0], isARow,
isAVec4);
auto aEncoding = DotOperandEncodingAttr::get(mma.getContext(), 0, mma);
int numM = aEncoding.getMMAv1NumOuter(shape);
int numN = accumSizePerThread / numM;
for (int r = 0; r < numM; r++) {
@@ -326,13 +325,13 @@ private:
Value valVec = undef(vecTy);
for (unsigned v = 0; v < vec; ++v) {
auto currVal = coord2valT[elemId + v].second;
valVec = insert_element(vecTy, valVec, currVal, idx_val(v));
valVec = insert_element(vecTy, valVec, currVal, i32_val(v));
}
store(valVec, ptr);
} else {
Value valVec = load(ptr);
for (unsigned v = 0; v < vec; ++v) {
Value currVal = extract_element(elemTy, valVec, idx_val(v));
Value currVal = extract_element(elemTy, valVec, i32_val(v));
vals[elemId + v] = currVal;
}
}
@@ -397,7 +396,8 @@ private:
// Potentially we need to store for multiple CTAs in this replication
auto accumNumReplicates = product<unsigned>(numReplicates);
// unsigned elems = getElemsPerThread(srcTy);
auto vals = getElementsFromStruct(loc, adaptor.getSrc(), rewriter);
auto vals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
rewriter, srcTy);
unsigned inVec = 0;
unsigned outVec = 0;
auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec);
@@ -448,7 +448,8 @@ private:
SmallVector<Type> types(outElems, llvmElemTy);
auto *ctx = llvmElemTy.getContext();
Type structTy = struct_ty(types);
Value result = getStructFromElements(loc, outVals, rewriter, structTy);
Value result =
getTypeConverter()->packLLElements(loc, outVals, rewriter, dstTy);
rewriter.replaceOp(op, result);
return success();
@@ -480,7 +481,7 @@ private:
auto dstStrides =
getStridesFromShapeAndOrder(dstShape, outOrd, loc, rewriter);
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy);
storeDistributedToShared(src, adaptor.getSrc(), dstStrides, srcIndices, dst,
smemBase, elemTy, loc, rewriter);
auto smemObj =
@@ -525,10 +526,10 @@ private:
auto thread = getThreadId(rewriter, loc);
if (dotOpLayout.getOpIdx() == 0) { // $a
res = helper.loadA(src, adaptor.getSrc(), blockedLayout, thread, loc,
rewriter);
getTypeConverter(), rewriter);
} else { // $b
res = helper.loadB(src, adaptor.getSrc(), blockedLayout, thread, loc,
rewriter);
getTypeConverter(), rewriter);
}
} else {
assert(false && "Unsupported dot operand layout found");
@@ -551,22 +552,42 @@ private:
auto dstDotLayout = dstLayout.cast<DotOperandEncodingAttr>();
if (isMmaToDotShortcut(srcMmaLayout, dstDotLayout)) {
// get source values
auto vals = getElementsFromStruct(loc, adaptor.getSrc(), rewriter);
auto vals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
rewriter, srcTy);
unsigned elems = getElemsPerThread(srcTy);
Type elemTy =
this->getTypeConverter()->convertType(srcTy.getElementType());
// for the destination type, we need to pack values together
// so they can be consumed by tensor core operations
unsigned vecSize =
std::max<unsigned>(32 / elemTy.getIntOrFloatBitWidth(), 1);
Type vecTy = vec_ty(elemTy, vecSize);
SmallVector<Type> types(elems / vecSize, vecTy);
SmallVector<Value> vecVals;
for (unsigned i = 0; i < elems; i += vecSize) {
Value packed = rewriter.create<LLVM::UndefOp>(loc, vecTy);
for (unsigned j = 0; j < vecSize; j++)
packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
vecVals.push_back(packed);
SmallVector<Type> types;
// For some reason, LLVM's NVPTX backend inserts unnecessary (?) integer
// instructions to pack & unpack sub-word integers. A workaround is to
// store the results of ldmatrix in i32.
auto elemSize = elemTy.getIntOrFloatBitWidth();
if (elemTy.isa<IntegerType>() && elemSize <= 16) {
auto fold = 32 / elemSize;
for (unsigned i = 0; i < elems; i += fold) {
Value val = i32_val(0);
for (unsigned j = 0; j < fold; j++) {
auto ext =
shl(i32_ty, zext(i32_ty, vals[i + j]), i32_val(elemSize * j));
val = or_(i32_ty, val, ext);
}
vecVals.push_back(val);
}
elems = elems / (32 / elemSize);
types = SmallVector<Type>(elems, i32_ty);
} else {
unsigned vecSize = std::max<unsigned>(32 / elemSize, 1);
Type vecTy = vec_ty(elemTy, vecSize);
types = SmallVector<Type>(elems / vecSize, vecTy);
for (unsigned i = 0; i < elems; i += vecSize) {
Value packed = rewriter.create<LLVM::UndefOp>(loc, vecTy);
for (unsigned j = 0; j < vecSize; j++)
packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
vecVals.push_back(packed);
}
}
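// Worked example of the sub-word packing above (illustrative): for i8
// elements, elemSize = 8 and fold = 4, so vals[0..3] are packed into one
// i32 as zext(v0) | (zext(v1) << 8) | (zext(v2) << 16) | (zext(v3) << 24).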
// This needs to be ordered the same way that
@@ -582,12 +603,8 @@ private:
reorderedVals.push_back(vecVals[i + 3]);
}
// return composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK);
Type structTy =
LLVM::LLVMStructType::getLiteral(this->getContext(), types);
Value view =
getStructFromElements(loc, reorderedVals, rewriter, structTy);
Value view = getTypeConverter()->packLLElements(loc, reorderedVals,
rewriter, dstTy);
rewriter.replaceOp(op, view);
return success();
}
@@ -622,8 +639,7 @@ private:
}
} else if (!isOuter && mmaLayout.isVolta() && isHMMA) { // tensor core v1
DotOpMmaV1ConversionHelper helper(mmaLayout);
bool isMMAv1Row =
dotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
bool isMMAv1Row = dotOperandLayout.getMMAv1IsRow();
auto srcSharedLayout = src.getType()
.cast<RankedTensorType>()
.getEncoding()
@@ -640,12 +656,12 @@ private:
// TODO[Superjomn]: transA is not available here.
bool transA = false;
res = helper.loadA(src, smemObj, getThreadId(rewriter, loc), loc,
rewriter);
getTypeConverter(), rewriter, dst.getType());
} else if (dotOperandLayout.getOpIdx() == 1) { // operand $b
// TODO[Superjomn]: transB is not available here.
bool transB = false;
res = helper.loadB(src, smemObj, getThreadId(rewriter, loc), loc,
rewriter);
getTypeConverter(), rewriter, dst.getType());
}
} else {
assert(false && "Unsupported mma layout found");
@@ -655,7 +671,7 @@ private:
};
void populateConvertLayoutOpToLLVMPatterns(
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,

View File

@@ -9,7 +9,7 @@ using namespace mlir::triton;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
void populateConvertLayoutOpToLLVMPatterns(
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,

View File

@@ -1,65 +1,13 @@
#include "DotOpHelpers.h"
#include "TypeConverter.h"
namespace mlir {
namespace LLVM {
int DotOpMmaV1ConversionHelper::numElemsPerThreadA(ArrayRef<int64_t> shape,
bool isARow, bool isAVec4,
int vec) const {
int numM = getNumM(shape[0], isARow, isAVec4);
int NK = shape[1];
// Here we mimic the logic in loadA; the result cannot be calculated
// directly.
llvm::DenseSet<std::pair<int, int>> visited;
auto ld = [&](int m, int k) {
visited.insert({m, k});
if (vec > 4) {
if (isARow)
visited.insert({m, k + 4});
else
visited.insert({m + 1, k});
}
};
for (unsigned k = 0; k < NK; k += 4)
for (unsigned m = 0; m < numM / 2; ++m)
if (!visited.count({m, k}))
ld(m, k);
return visited.size() * 2;
}
int DotOpMmaV1ConversionHelper::numElemsPerThreadB(ArrayRef<int64_t> shape,
bool isBRow, bool isBVec4,
int vec) const {
unsigned numN = getNumN(shape[1], isBRow, isBVec4);
int NK = shape[0];
// Here we mimic the logic in loadB; the result cannot be calculated
// directly.
llvm::DenseSet<std::pair<int, int>> visited;
int elemsPerLd = vec > 4 ? 4 : 2;
auto ld = [&](int n, int k) {
visited.insert({n, k});
if (vec > 4) {
if (isBRow)
visited.insert({n + 1, k});
else
visited.insert({n, k + 4});
}
};
for (unsigned k = 0; k < NK; k += 4)
for (unsigned n = 0; n < numN / 2; ++n) {
if (!visited.count({n, k}))
ld(n, k);
}
return visited.size() * 2;
}
Value DotOpMmaV1ConversionHelper::loadA(
Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc,
ConversionPatternRewriter &rewriter) const {
TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, Type resultTy) const {
auto *ctx = rewriter.getContext();
auto tensorTy = tensor.getType().cast<RankedTensorType>();
auto sharedLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
@@ -70,12 +18,12 @@ Value DotOpMmaV1ConversionHelper::loadA(
Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
bool isARow = order[0] != 0;
auto [isARow_, _0, isAVec4, _1, _2] = mmaLayout.decodeVoltaLayoutStates();
AParam param(isARow_, isAVec4);
auto resultEncoding = resultTy.cast<RankedTensorType>()
.getEncoding()
.cast<DotOperandEncodingAttr>();
auto [offsetAM, offsetAK, _3, _4] = computeOffsets(
thread, isARow, false, fpw, param.spw, param.rep, rewriter, loc);
thread, isARow, false, fpw, resultEncoding.getMMAv1ShapePerWarp(),
resultEncoding.getMMAv1Rep(), rewriter, loc);
int vecA = sharedLayout.getVec();
@@ -152,7 +100,9 @@ Value DotOpMmaV1ConversionHelper::loadA(
}
};
unsigned numM = getNumM(shape[0], isARow, isAVec4);
bool isARow_ = resultEncoding.getMMAv1IsRow();
bool isAVec4 = resultEncoding.getMMAv1IsVec4();
unsigned numM = resultEncoding.getMMAv1NumOuter(shape);
for (unsigned k = 0; k < NK; k += 4)
for (unsigned m = 0; m < numM / 2; ++m)
if (!has.count({m, k}))
@@ -165,14 +115,14 @@ Value DotOpMmaV1ConversionHelper::loadA(
elems.push_back(item.second.second);
}
Type resTy = struct_ty(SmallVector<Type>(elems.size(), elemX2Ty));
Value res = getStructFromElements(loc, elems, rewriter, resTy);
Value res = typeConverter->packLLElements(loc, elems, rewriter, resultTy);
return res;
}
Value DotOpMmaV1ConversionHelper::loadB(
Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc,
ConversionPatternRewriter &rewriter) const {
TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, Type resultTy) const {
// smem
auto strides = smemObj.strides;
@@ -186,10 +136,9 @@ Value DotOpMmaV1ConversionHelper::loadB(
Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
bool isBRow = order[0] != 0; // is row-major in shared memory layout
// isBRow_ indicates whether B is row-major in DotOperand layout
auto [_0, isBRow_, _1, isBVec4, _2] = mmaLayout.decodeVoltaLayoutStates();
assert(isBRow == isBRow_ && "B need smem isRow");
BParam param(isBRow_, isBVec4);
auto resultEncoding = resultTy.cast<RankedTensorType>()
.getEncoding()
.cast<DotOperandEncodingAttr>();
int vecB = sharedLayout.getVec();
Value strideBN = isBRow ? i32_val(1) : strides[1];
@@ -200,7 +149,8 @@ Value DotOpMmaV1ConversionHelper::loadB(
int strideRepK = 1;
auto [_3, _4, offsetBN, offsetBK] = computeOffsets(
thread, false, isBRow, fpw, param.spw, param.rep, rewriter, loc);
thread, false, isBRow, fpw, resultEncoding.getMMAv1ShapePerWarp(),
resultEncoding.getMMAv1Rep(), rewriter, loc);
// swizzling
int perPhaseB = sharedLayout.getPerPhase();
@@ -266,7 +216,10 @@ Value DotOpMmaV1ConversionHelper::loadB(
}
};
unsigned numN = getNumN(shape[1], isBRow, isBVec4);
bool isBRow_ = resultEncoding.getMMAv1IsRow();
assert(isBRow == isBRow_ && "B need smem isRow");
bool isBVec4 = resultEncoding.getMMAv1IsVec4();
unsigned numN = resultEncoding.getMMAv1NumOuter(shape);
for (unsigned k = 0; k < NK; k += 4)
for (unsigned n = 0; n < numN / 2; ++n) {
if (!hbs.count({n, k}))
@@ -279,8 +232,7 @@ Value DotOpMmaV1ConversionHelper::loadB(
elems.push_back(item.second.second);
}
Type resTy = struct_ty(SmallVector<Type>(elems.size(), elemX2Ty));
Value res = getStructFromElements(loc, elems, rewriter, resTy);
Value res = typeConverter->packLLElements(loc, elems, rewriter, resultTy);
return res;
}
@@ -350,10 +302,11 @@ DotOpMmaV1ConversionHelper::computeOffsets(Value threadId, bool isARow,
DotOpMmaV1ConversionHelper::ValueTable
DotOpMmaV1ConversionHelper::extractLoadedOperand(
Value llStruct, int NK, ConversionPatternRewriter &rewriter) const {
Value llStruct, int NK, ConversionPatternRewriter &rewriter,
TritonGPUToLLVMTypeConverter *typeConverter, Type type) const {
ValueTable rcds;
SmallVector<Value> elems =
getElementsFromStruct(llStruct.getLoc(), llStruct, rewriter);
SmallVector<Value> elems = typeConverter->unpackLLElements(
llStruct.getLoc(), llStruct, rewriter, type);
int offset = 0;
for (int i = 0; offset < elems.size(); ++i) {
@@ -366,10 +319,12 @@ DotOpMmaV1ConversionHelper::extractLoadedOperand(
return rcds;
}
// TODO: Mostly a duplicate of TritonGPUToLLVMBase::emitBaseIndexforMMaLayoutV1
SmallVector<DotOpMmaV1ConversionHelper::CoordTy>
DotOpMmaV1ConversionHelper::getMNCoords(Value thread,
ConversionPatternRewriter &rewriter,
ArrayRef<unsigned int> wpt,
const MmaEncodingAttr &mmaLayout,
ArrayRef<int64_t> shape, bool isARow,
bool isBRow, bool isAVec4,
bool isBVec4) {
@@ -388,11 +343,17 @@ DotOpMmaV1ConversionHelper::getMNCoords(Value thread,
Value _fpw0 = i32_val(fpw[0]);
Value _fpw1 = i32_val(fpw[1]);
DotOpMmaV1ConversionHelper::AParam aParam(isARow, isAVec4);
DotOpMmaV1ConversionHelper::BParam bParam(isBRow, isBVec4);
// A info
auto aEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaLayout);
auto aRep = aEncoding.getMMAv1Rep();
auto aSpw = aEncoding.getMMAv1ShapePerWarp();
// B info
auto bEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaLayout);
auto bSpw = bEncoding.getMMAv1ShapePerWarp();
auto bRep = bEncoding.getMMAv1Rep();
SmallVector<int, 2> rep({aParam.rep[0], bParam.rep[1]});
SmallVector<int, 2> spw({aParam.spw[0], bParam.spw[1]});
SmallVector<int, 2> rep({aRep[0], bRep[1]});
SmallVector<int, 2> spw({aSpw[0], bSpw[1]});
SmallVector<unsigned, 2> shapePerCTA({spw[0] * wpt[0], spw[1] * wpt[1]});
Value lane = urem(thread, warpSize);
@@ -468,23 +429,6 @@ DotOpMmaV1ConversionHelper::getMNCoords(Value thread,
return coords; // {M,N} in row-major
}
void DotOpMmaV1ConversionHelper::AParam::build(bool isARow) {
int packSize0 = (isARow || isAVec4) ? 1 : 2;
int repM = 2 * packSize0;
int repK = 1;
int spwM = fpw[0] * 4 * repM;
rep.assign({repM, 0, repK});
spw.assign({spwM, 0, 1});
vec = 2 * rep[0];
}
void DotOpMmaV1ConversionHelper::BParam::build(bool isBRow) {
int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
rep.assign({0, 2 * packSize1, 1});
spw.assign({0, fpw[1] * 4 * rep[1], 1});
vec = 2 * rep[1];
}
std::tuple<int, int>
DotOpMmaV2ConversionHelper::getRepMN(const RankedTensorType &tensorTy) {
auto mmaLayout = tensorTy.getEncoding().cast<MmaEncodingAttr>();
@@ -1077,6 +1021,7 @@ LogicalResult MMA16816ConversionHelper::convertDot(Value a, Value b, Value c,
helper.deduceMmaType(op);
auto aTensorTy = a.getType().cast<RankedTensorType>();
auto bTensorTy = b.getType().cast<RankedTensorType>();
auto dTensorTy = d.getType().cast<RankedTensorType>();
SmallVector<int64_t> aShape(aTensorTy.getShape().begin(),
@@ -1090,10 +1035,10 @@ LogicalResult MMA16816ConversionHelper::convertDot(Value a, Value b, Value c,
int numRepK = getNumRepK(aTensorTy, aShape[1]);
ValueTable ha =
getValuesFromDotOperandLayoutStruct(loadedA, numRepM, numRepK);
getValuesFromDotOperandLayoutStruct(loadedA, numRepM, numRepK, aTensorTy);
ValueTable hb = getValuesFromDotOperandLayoutStruct(
loadedB, std::max(numRepN / 2, 1), numRepK);
auto fc = getElementsFromStruct(loc, loadedC, rewriter);
loadedB, std::max(numRepN / 2, 1), numRepK, bTensorTy);
auto fc = typeConverter->unpackLLElements(loc, loadedC, rewriter, dTensorTy);
auto callMma = [&](unsigned m, unsigned n, unsigned k) {
unsigned colsPerThread = numRepN * 2;
@@ -1139,7 +1084,7 @@ LogicalResult MMA16816ConversionHelper::convertDot(Value a, Value b, Value c,
// replace with new packed result
Type structTy = LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(fc.size(), resElemTy));
Value res = getStructFromElements(loc, fc, rewriter, structTy);
Value res = typeConverter->packLLElements(loc, fc, rewriter, structTy);
rewriter.replaceOp(op, res);
return success();
@@ -1218,14 +1163,14 @@ Value MMA16816ConversionHelper::composeValuesToDotOperandLayoutStruct(
Type elemTy = elems[0].getType();
Type structTy = LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(elems.size(), elemTy));
auto result = getStructFromElements(loc, elems, rewriter, structTy);
auto result = typeConverter->packLLElements(loc, elems, rewriter, structTy);
return result;
}
MMA16816ConversionHelper::ValueTable
MMA16816ConversionHelper::getValuesFromDotOperandLayoutStruct(Value value,
int n0,
int n1) const {
auto elems = getElementsFromStruct(loc, value, rewriter);
int n0, int n1,
Type type) const {
auto elems = typeConverter->unpackLLElements(loc, value, rewriter, type);
int offset{};
ValueTable vals;
@@ -1257,6 +1202,7 @@ SmallVector<Value> DotOpFMAConversionHelper::getThreadIds(
}
Value DotOpFMAConversionHelper::loadA(
Value A, Value llA, BlockedEncodingAttr dLayout, Value thread, Location loc,
TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter) const {
auto aTensorTy = A.getType().cast<RankedTensorType>();
auto aLayout = aTensorTy.getEncoding().cast<SharedEncodingAttr>();
@@ -1316,10 +1262,11 @@ Value DotOpFMAConversionHelper::loadA(
vas.emplace_back(va);
}
return getStructFromValueTable(vas, rewriter, loc, elemTy);
return getStructFromValueTable(vas, rewriter, loc, typeConverter, elemTy);
}
Value DotOpFMAConversionHelper::loadB(
Value B, Value llB, BlockedEncodingAttr dLayout, Value thread, Location loc,
TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter) const {
auto bTensorTy = B.getType().cast<RankedTensorType>();
auto bLayout = bTensorTy.getEncoding().cast<SharedEncodingAttr>();
@@ -1379,14 +1326,15 @@ Value DotOpFMAConversionHelper::loadB(
vbs.emplace_back(vb);
}
return getStructFromValueTable(vbs, rewriter, loc, elemTy);
return getStructFromValueTable(vbs, rewriter, loc, typeConverter, elemTy);
}
DotOpFMAConversionHelper::ValueTable
DotOpFMAConversionHelper::getValueTableFromStruct(
Value val, int K, int n0, int shapePerCTA, int sizePerThread,
ConversionPatternRewriter &rewriter, Location loc) const {
ConversionPatternRewriter &rewriter, Location loc,
TritonGPUToLLVMTypeConverter *typeConverter, Type type) const {
ValueTable res;
auto elems = getElementsFromStruct(loc, val, rewriter);
auto elems = typeConverter->unpackLLElements(loc, val, rewriter, type);
int index = 0;
for (unsigned k = 0; k < K; ++k) {
for (unsigned m = 0; m < n0; m += shapePerCTA)
@@ -1398,7 +1346,7 @@ DotOpFMAConversionHelper::getValueTableFromStruct(
}
Value DotOpFMAConversionHelper::getStructFromValueTable(
ArrayRef<Value> vals, ConversionPatternRewriter &rewriter, Location loc,
Type elemTy) const {
TritonGPUToLLVMTypeConverter *typeConverter, Type elemTy) const {
SmallVector<Type> elemTypes(vals.size(), elemTy);
SmallVector<Value> elems;
elems.reserve(vals.size());
@@ -1407,7 +1355,7 @@ Value DotOpFMAConversionHelper::getStructFromValueTable(
}
Type structTy = struct_ty(elemTypes);
return getStructFromElements(loc, elems, rewriter, structTy);
return typeConverter->packLLElements(loc, elems, rewriter, structTy);
}
int DotOpFMAConversionHelper::getNumElemsPerThread(
ArrayRef<int64_t> shape, DotOperandEncodingAttr dotOpLayout) {

View File

@@ -26,6 +26,8 @@
#include "Utility.h"
class TritonGPUToLLVMTypeConverter;
namespace mlir {
namespace LLVM {
using namespace mlir::triton;
@@ -45,45 +47,6 @@ struct DotOpMmaV1ConversionHelper {
explicit DotOpMmaV1ConversionHelper(MmaEncodingAttr mmaLayout)
: mmaLayout(mmaLayout), wpt(mmaLayout.getWarpsPerCTA()) {}
// Helps share some variables across multiple functions for A.
// TODO[Superjomn]: refactor and restrict this to only use in DotOp
// conversion.
struct AParam {
SmallVector<int> rep;
SmallVector<int> spw;
bool isAVec4{};
int vec{}; // This can only be used in DotOp, not in
// loadA/loadB/TypeConverter
AParam(bool isARow, bool isAVec4) : isAVec4(isAVec4) { build(isARow); }
private:
void build(bool isARow);
};
// Helps share some variables across multiple functions for B.
// TODO[Superjomn]: refactor and restrict this to only use in DotOp
// conversion.
struct BParam {
SmallVector<int> rep;
SmallVector<int> spw;
bool isBVec4{};
int vec{}; // This can only be used in DotOp, not in
// loadA/loadB/TypeConverter
BParam(bool isBRow, bool isBVec4) : isBVec4(isBVec4) { build(isBRow); }
private:
void build(bool isBRow);
};
int getRepM(int M) const {
return std::max<int>(M / (wpt[0] * instrShape[0]), 1);
}
int getRepN(int N) const {
return std::max<int>(N / (wpt[1] * instrShape[1]), 1);
}
static ArrayRef<unsigned> getMmaInstrShape() { return instrShape; }
static Type getMmaRetType(TensorType operand) {
@@ -100,35 +63,15 @@ struct DotOpMmaV1ConversionHelper {
return struct_ty(SmallVector<Type>{vecTy});
}
// Get the number of fp16x2 elements for $a.
unsigned getNumM(int M, bool isARow, bool isAVec4) const {
AParam param(isARow, isAVec4);
unsigned numM = param.rep[0] * M / (param.spw[0] * wpt[0]);
return numM;
}
// Get the number of fp16x2 elements for $b.
unsigned getNumN(int N, bool isBRow, bool isBVec4) const {
BParam param(isBRow, isBVec4);
unsigned numN = param.rep[1] * N / (param.spw[1] * wpt[1]);
return numN;
}
int numElemsPerThreadA(ArrayRef<int64_t> shape, bool isARow, bool isAVec4,
int vec) const;
int numElemsPerThreadB(ArrayRef<int64_t> shape, bool isBRow, bool isBVec4,
int vec) const;
// Loading $a from smem to registers, returns a LLVM::Struct.
Value loadA(Value tensor, const SharedMemoryObject &smemObj, Value thread,
Location loc, ConversionPatternRewriter &rewriter) const;
Location loc, TritonGPUToLLVMTypeConverter *converter,
ConversionPatternRewriter &rewriter, Type resultTy) const;
// Loading $b from smem to registers, returns a LLVM::Struct.
Value loadB(Value tensor, const SharedMemoryObject &smemObj, Value thread,
Location loc, ConversionPatternRewriter &rewriter) const;
Location loc, TritonGPUToLLVMTypeConverter *converter,
ConversionPatternRewriter &rewriter, Type resultTy) const;
static ArrayRef<unsigned> getOrder() { return mmaOrder; }
@@ -145,25 +88,17 @@ struct DotOpMmaV1ConversionHelper {
ConversionPatternRewriter &rewriter, Location loc) const;
// Extract the values belonging to $a or $b from an LLVMStruct; the shape is n0 x n1.
DotOpMmaV1ConversionHelper::ValueTable
extractLoadedOperand(Value llStruct, int NK,
ConversionPatternRewriter &rewriter) const;
// Get the number of elements of this thread in the M axis. The N axis can
// be deduced from accSize / elemsM. \param wpt: the wpt in the M axis
// \param M: the shape in the M axis
int getElemsM(int wpt, int M, bool isARow, bool isAVec4) {
DotOpMmaV1ConversionHelper::AParam param(isARow, isAVec4);
int shapePerCTAM = param.spw[0] * wpt;
return M / shapePerCTAM * param.rep[0];
}
DotOpMmaV1ConversionHelper::ValueTable extractLoadedOperand(
Value llStruct, int NK, ConversionPatternRewriter &rewriter,
TritonGPUToLLVMTypeConverter *typeConverter, Type type) const;
using CoordTy = SmallVector<Value>;
// Get the coordinates (m, n) of the elements emitted by a thread in the accumulator.
static SmallVector<CoordTy>
getMNCoords(Value thread, ConversionPatternRewriter &rewriter,
ArrayRef<unsigned> wpt, ArrayRef<int64_t> shape, bool isARow,
bool isBRow, bool isAVec4, bool isBVec4);
ArrayRef<unsigned> wpt, const MmaEncodingAttr &mmaLayout,
ArrayRef<int64_t> shape, bool isARow, bool isBRow, bool isAVec4,
bool isBVec4);
// \param elemId the offset of the element in a thread
static CoordTy getCoord(int elemId, ArrayRef<CoordTy> coords) {
@@ -411,7 +346,7 @@ struct MMA16816ConversionHelper {
DotOpMmaV2ConversionHelper helper;
ConversionPatternRewriter &rewriter;
TypeConverter *typeConverter;
TritonGPUToLLVMTypeConverter *typeConverter;
Location loc;
MLIRContext *ctx{};
@@ -420,7 +355,8 @@ struct MMA16816ConversionHelper {
// dotOperand: the type of either operand of the dotOp.
MMA16816ConversionHelper(Type dotOperand, MmaEncodingAttr mmaLayout,
Value thread, ConversionPatternRewriter &rewriter,
TypeConverter *typeConverter, Location loc)
TritonGPUToLLVMTypeConverter *typeConverter,
Location loc)
: mmaLayout(mmaLayout), wpt(mmaLayout.getWarpsPerCTA()), thread(thread),
helper(mmaLayout), rewriter(rewriter), typeConverter(typeConverter),
loc(loc), ctx(mmaLayout.getContext()) {
@@ -564,8 +500,8 @@ private:
Value composeValuesToDotOperandLayoutStruct(const ValueTable &vals, int n0,
int n1) const;
ValueTable getValuesFromDotOperandLayoutStruct(Value value, int n0,
int n1) const;
ValueTable getValuesFromDotOperandLayoutStruct(Value value, int n0, int n1,
Type type) const;
};
// Helper for conversion of FMA DotOp.
@@ -584,19 +520,23 @@ struct DotOpFMAConversionHelper {
ConversionPatternRewriter &rewriter, Location loc) const;
Value loadA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread,
Location loc, ConversionPatternRewriter &rewriter) const;
Location loc, TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter) const;
Value loadB(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread,
Location loc, ConversionPatternRewriter &rewriter) const;
Location loc, TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter) const;
ValueTable getValueTableFromStruct(Value val, int K, int n0, int shapePerCTA,
int sizePerThread,
ConversionPatternRewriter &rewriter,
Location loc) const;
ValueTable getValueTableFromStruct(
Value val, int K, int n0, int shapePerCTA, int sizePerThread,
ConversionPatternRewriter &rewriter, Location loc,
TritonGPUToLLVMTypeConverter *typeConverter, Type type) const;
Value getStructFromValueTable(ArrayRef<Value> vals,
ConversionPatternRewriter &rewriter,
Location loc, Type elemTy) const;
Location loc,
TritonGPUToLLVMTypeConverter *typeConverter,
Type elemTy) const;
// Get the number of elements per thread for $a or $b.
static int getNumElemsPerThread(ArrayRef<int64_t> shape,

View File

@@ -7,8 +7,6 @@ using namespace mlir::triton;
using ::mlir::LLVM::DotOpFMAConversionHelper;
using ::mlir::LLVM::DotOpMmaV1ConversionHelper;
using ::mlir::LLVM::getElementsFromStruct;
using ::mlir::LLVM::getStructFromElements;
using ::mlir::LLVM::MMA16816ConversionHelper;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::MmaEncodingAttr;
@@ -116,28 +114,30 @@ private:
auto AShape = ATensorTy.getShape();
auto BShape = BTensorTy.getShape();
bool isARow = ALayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
bool isBRow = BLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
auto [isARow_, isBRow_, isAVec4_, isBVec4_, mmaId] =
bool isARow = ALayout.getMMAv1IsRow();
bool isBRow = BLayout.getMMAv1IsRow();
auto [isARow_, isBRow_, isAVec4_, isBVec4_, _] =
mmaLayout.decodeVoltaLayoutStates();
assert(isARow == isARow_);
assert(isBRow == isBRow_);
DotOpMmaV1ConversionHelper helper(mmaLayout);
unsigned numM = helper.getNumM(AShape[0], isARow, isAVec4_);
unsigned numN = helper.getNumN(BShape[1], isBRow, isBVec4_);
unsigned numM = ALayout.getMMAv1NumOuter(AShape);
unsigned numN = BLayout.getMMAv1NumOuter(BShape);
unsigned NK = AShape[1];
auto has = helper.extractLoadedOperand(adaptor.getA(), NK, rewriter);
auto hbs = helper.extractLoadedOperand(adaptor.getB(), NK, rewriter);
auto has = helper.extractLoadedOperand(adaptor.getA(), NK, rewriter,
getTypeConverter(), ATensorTy);
auto hbs = helper.extractLoadedOperand(adaptor.getB(), NK, rewriter,
getTypeConverter(), BTensorTy);
// Initialize accumulators with external values. acc holds the
// accumulator value shared between the MMA instructions inside a DotOp;
// we call the order of these values the accumulator-internal order.
SmallVector<Value> acc =
getElementsFromStruct(loc, adaptor.getC(), rewriter);
SmallVector<Value> acc = getTypeConverter()->unpackLLElements(
loc, adaptor.getC(), rewriter, DTensorTy);
size_t resSize = acc.size();
// The resVals holds the final result of the DotOp.
@@ -209,9 +209,8 @@ private:
resVals[i] = acc[i];
}
Type structTy = LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(resSize, type::f32Ty(ctx)));
Value res = getStructFromElements(loc, resVals, rewriter, structTy);
Value res =
getTypeConverter()->packLLElements(loc, resVals, rewriter, DTensorTy);
rewriter.replaceOp(op, res);
return success();
}
@@ -236,7 +235,8 @@ private:
BlockedEncodingAttr dLayout =
dTensorTy.getEncoding().cast<BlockedEncodingAttr>();
auto order = dLayout.getOrder();
auto cc = getElementsFromStruct(loc, adaptor.getC(), rewriter);
auto cc = getTypeConverter()->unpackLLElements(loc, adaptor.getC(),
rewriter, dTensorTy);
DotOpFMAConversionHelper helper(dLayout);
Value llA = adaptor.getA();
@@ -259,9 +259,11 @@ private:
order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]];
auto has = helper.getValueTableFromStruct(llA, K, M, mShapePerCTA,
mSizePerThread, rewriter, loc);
mSizePerThread, rewriter, loc,
getTypeConverter(), aTensorTy);
auto hbs = helper.getValueTableFromStruct(llB, K, N, nShapePerCTA,
nSizePerThread, rewriter, loc);
nSizePerThread, rewriter, loc,
getTypeConverter(), bTensorTy);
SmallVector<Value> ret = cc;
bool isCRow = order[0] == 1;
@@ -281,16 +283,15 @@ private:
}
}
auto res = getStructFromElements(
loc, ret, rewriter,
struct_ty(SmallVector<Type>(ret.size(), ret[0].getType())));
auto res =
getTypeConverter()->packLLElements(loc, ret, rewriter, dTensorTy);
rewriter.replaceOp(op, res);
return success();
}
};
void populateDotOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
void populateDotOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,

View File

@@ -6,7 +6,7 @@
using namespace mlir;
using namespace mlir::triton;
void populateDotOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
void populateDotOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,

View File

@@ -2,9 +2,6 @@
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::getElementsFromStruct;
using ::mlir::LLVM::getStructFromElements;
using ::mlir::triton::gpu::getElemsPerThread;
struct FpToFpOpConversion
@@ -12,10 +9,14 @@ struct FpToFpOpConversion
using ConvertTritonGPUOpToLLVMPattern<
triton::FpToFpOp>::ConvertTritonGPUOpToLLVMPattern;
/* ------------------ */
// FP8 -> FP16
/* ------------------ */
static SmallVector<Value>
convertFp8x4ToFp16x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
const char *ptxAsm, const Value &v0, const Value &v1,
const Value &v2, const Value &v3) {
auto ctx = rewriter.getContext();
#ifdef USE_ROCM
auto fp8x4VecTy = vec_ty(i8_ty, 4);
@@ -63,23 +64,12 @@ struct FpToFpOpConversion
fp8x4Vec = bitcast(fp8x4Vec, i32_ty);
PTXBuilder builder;
auto *ptxAsm = "{ \n"
".reg .b32 a<2>, b<2>; \n"
"prmt.b32 a0, 0, $2, 0x5040; \n"
"prmt.b32 a1, 0, $2, 0x7060; \n"
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n"
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n"
"shr.b32 b0, b0, 1; \n"
"shr.b32 b1, b1, 1; \n"
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n"
"}";
auto &call = *builder.create(ptxAsm);
auto &ptxOp = *builder.create(ptxAsm);
auto *o0 = builder.newOperand("=r");
auto *o1 = builder.newOperand("=r");
auto *i = builder.newOperand(fp8x4Vec, "r");
call({o0, o1, i}, /*onlyAttachMLIRArgs=*/true);
ptxOp({o0, o1, i}, /*onlyAttachMLIRArgs=*/true);
auto fp16x2VecTy = vec_ty(f16_ty, 2);
auto fp16x2x2StructTy =
@@ -96,72 +86,41 @@ struct FpToFpOpConversion
}
static SmallVector<Value>
convertFp16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto fp16x2VecTy = vec_ty(f16_ty, 2);
Value fp16x2Vec0 = undef(fp16x2VecTy);
Value fp16x2Vec1 = undef(fp16x2VecTy);
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v0, i32_val(0));
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v1, i32_val(1));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v2, i32_val(0));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v3, i32_val(1));
fp16x2Vec0 = bitcast(fp16x2Vec0, i32_ty);
fp16x2Vec1 = bitcast(fp16x2Vec1, i32_ty);
#ifdef USE_ROCM
Value a0 = shl(i32_ty, fp16x2Vec0, i32_val(1));
Value a1 = shl(i32_ty, fp16x2Vec1, i32_val(1));
a0 = and_(i32_ty, a0, i32_val(0x7fff7fff));
a1 = and_(i32_ty, a1, i32_val(0x7fff7fff));
a0 = add(i32_ty, a0, i32_val(0x00800080));
a1 = add(i32_ty, a1, i32_val(0x00800080));
Value b0 = or_( i32_ty, and_(i32_ty, fp16x2Vec0, i32_val(0x80008000)), a0 );
Value b1 = or_( i32_ty, and_(i32_ty, fp16x2Vec1, i32_val(0x80008000)), a1 );
auto fp8x4VecTy = vec_ty(i8_ty, 4);
b0 = bitcast(b0, fp8x4VecTy);
b1 = bitcast(b1, fp8x4VecTy);
return {extract_element(i8_ty, b0, i32_val(1)),
extract_element(i8_ty, b0, i32_val(3)),
extract_element(i8_ty, b1, i32_val(1)),
extract_element(i8_ty, b1, i32_val(3))
};
#else
PTXBuilder builder;
convertFp8E4M3x4ToFp16x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto *ptxAsm = "{ \n"
".reg .b32 a<2>, b<2>; \n"
"shl.b32 a0, $1, 1; \n"
"shl.b32 a1, $2, 1; \n"
"lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n"
"lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n"
"add.u32 a0, a0, 0x00800080; \n"
"add.u32 a1, a1, 0x00800080; \n"
"lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n"
"lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n"
"prmt.b32 $0, b0, b1, 0x7531; \n"
"prmt.b32 a0, 0, $2, 0x5040; \n"
"prmt.b32 a1, 0, $2, 0x7060; \n"
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n"
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n"
"shr.b32 b0, b0, 1; \n"
"shr.b32 b1, b1, 1; \n"
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n"
"}";
auto &call = *builder.create(ptxAsm);
auto *o = builder.newOperand("=r");
auto *i0 = builder.newOperand(fp16x2Vec0, "r");
auto *i1 = builder.newOperand(fp16x2Vec1, "r");
call({o, i0, i1}, /*onlyAttachMLIRArgs=*/true);
auto fp8x4VecTy = vec_ty(i8_ty, 4);
auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false);
return {extract_element(i8_ty, fp8x4Vec, i32_val(0)),
extract_element(i8_ty, fp8x4Vec, i32_val(1)),
extract_element(i8_ty, fp8x4Vec, i32_val(2)),
extract_element(i8_ty, fp8x4Vec, i32_val(3))};
#endif
return convertFp8x4ToFp16x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
}
static SmallVector<Value>
convertFp8E5M2x4ToFp16x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto *ptxAsm = "{ \n"
"prmt.b32 $0, 0, $2, 0x5140; \n\t"
"prmt.b32 $1, 0, $2, 0x7362; \n\t"
"}";
return convertFp8x4ToFp16x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
}
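// Why E5M2 -> FP16 is just two prmt instructions: fp8 E5M2 (1-5-2) and
// fp16 (1-5-10) share the same exponent width and bias, so widening is a
// pure byte move into the high byte of each 16-bit lane. A scalar sketch
// of one lane (packing aside):
//
//   uint16_t fp8e5m2_to_fp16_bits(uint8_t x) { return uint16_t(x) << 8; }
//
// The 0x5140/0x7362 selectors scatter input bytes 0..3 into the high
// bytes of the four 16-bit lanes of the two output registers.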
/* ------------------ */
// FP8 -> BF16
/* ------------------ */
static SmallVector<Value>
convertFp8x4ToBf16x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
const char *ptxAsm, const Value &v0, const Value &v1,
const Value &v2, const Value &v3) {
auto ctx = rewriter.getContext();
#ifdef USE_ROCM
auto fp8x4VecTy = vec_ty(i8_ty, 4);
@@ -178,7 +137,7 @@ struct FpToFpOpConversion
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(2));
a1 = insert_element(fp8x4VecTy, a1, v3, i32_val(3));
a1 = bitcast(a1, i32_ty);
Value sign0 = and_(i32_ty, a0, i32_val(0x80008000));
Value sign1 = and_(i32_ty, a1, i32_val(0x80008000));
Value nosign0 = and_(i32_ty, a0, i32_val(0x7fff7fff));
@@ -211,27 +170,12 @@ struct FpToFpOpConversion
fp8x4Vec = bitcast(fp8x4Vec, i32_ty);
PTXBuilder builder;
auto *ptxAsm = "{ \n"
".reg .b32 a<2>, sign<2>, nosign<2>, b<2>; \n"
"prmt.b32 a0, 0, $2, 0x5040; \n"
"prmt.b32 a1, 0, $2, 0x7060; \n"
"and.b32 sign0, a0, 0x80008000; \n"
"and.b32 sign1, a1, 0x80008000; \n"
"and.b32 nosign0, a0, 0x7fff7fff; \n"
"and.b32 nosign1, a1, 0x7fff7fff; \n"
"shr.b32 nosign0, nosign0, 4; \n"
"shr.b32 nosign1, nosign1, 4; \n"
"add.u32 nosign0, nosign0, 0x38003800; \n"
"add.u32 nosign1, nosign1, 0x38003800; \n"
"or.b32 $0, sign0, nosign0; \n"
"or.b32 $1, sign1, nosign1; \n"
"}";
auto &call = *builder.create(ptxAsm);
auto &ptxOp = *builder.create(ptxAsm);
auto *o0 = builder.newOperand("=r");
auto *o1 = builder.newOperand("=r");
auto *i = builder.newOperand(fp8x4Vec, "r");
call({o0, o1, i}, /* onlyAttachMLIRArgs */ true);
ptxOp({o0, o1, i}, /* onlyAttachMLIRArgs */ true);
auto bf16x2VecTy = vec_ty(i16_ty, 2);
auto bf16x2x2StructTy =
@@ -247,10 +191,166 @@ struct FpToFpOpConversion
#endif
}
static SmallVector<Value>
convertFp8E4M3x4ToBf16x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto *ptxAsm = "{ \n"
".reg .b32 a<2>, sign<2>, nosign<2>, b<2>; \n"
"prmt.b32 a0, 0, $2, 0x5040; \n"
"prmt.b32 a1, 0, $2, 0x7060; \n"
"and.b32 sign0, a0, 0x80008000; \n"
"and.b32 sign1, a1, 0x80008000; \n"
"and.b32 nosign0, a0, 0x7fff7fff; \n"
"and.b32 nosign1, a1, 0x7fff7fff; \n"
"shr.b32 nosign0, nosign0, 4; \n"
"shr.b32 nosign1, nosign1, 4; \n"
"add.u32 nosign0, nosign0, 0x38003800; \n"
"add.u32 nosign1, nosign1, 0x38003800; \n"
"or.b32 $0, sign0, nosign0; \n"
"or.b32 $1, sign1, nosign1; \n"
"}";
return convertFp8x4ToBf16x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
};
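// Scalar sketch of one lane of the E4M3 -> BF16 asm above. The +0x3800
// rebias equals (127 - 15) << 7, i.e. the 4-bit exponent is read with an
// fp16-style bias of 15 (an assumption, consistent with the E4M3 -> FP16
// path above, rather than the OCP E4M3 bias of 7); denormals and NaNs get
// no special treatment, matching the PTX:
//
//   uint16_t fp8e4m3_to_bf16_bits(uint8_t x) {
//     uint16_t v = uint16_t(x) << 8;     // prmt: byte -> high half of lane
//     uint16_t sign = v & 0x8000;        // and.b32 sign0, a0, 0x80008000
//     uint16_t mag = (v & 0x7fff) >> 4;  // realign exponent and mantissa
//     mag += 0x3800;                     // add.u32: rebias the exponent
//     return sign | mag;                 // or.b32 $0, sign0, nosign0
//   }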
static SmallVector<Value>
convertFp8E5M2x4ToBf16x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto *ptxAsm = "{ \n"
".reg .b32 a<2>, b<2>; \n"
"prmt.b32 a0, 0, $2, 0x5140; \n"
"prmt.b32 a1, 0, $2, 0x7362; \n"
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n"
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n"
"shr.b32 b0, b0, 3; \n"
"shr.b32 b1, b1, 3; \n"
"add.u32 b0, b0, 0x30003000; \n"
"add.u32 b1, b1, 0x30003000; \n"
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n"
"}";
return convertFp8x4ToBf16x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
};
/* ------------------ */
// FP16 -> FP8
/* ------------------ */
static SmallVector<Value>
convertFp16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
const char *ptxAsm, const Value &v0, const Value &v1,
const Value &v2, const Value &v3) {
auto fp16x2VecTy = vec_ty(f16_ty, 2);
Value fp16x2Vec0 = undef(fp16x2VecTy);
Value fp16x2Vec1 = undef(fp16x2VecTy);
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v0, i32_val(0));
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v1, i32_val(1));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v2, i32_val(0));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v3, i32_val(1));
fp16x2Vec0 = bitcast(fp16x2Vec0, i32_ty);
fp16x2Vec1 = bitcast(fp16x2Vec1, i32_ty);
#ifdef USE_ROCM
Value a0 = shl(i32_ty, fp16x2Vec0, i32_val(1));
Value a1 = shl(i32_ty, fp16x2Vec1, i32_val(1));
a0 = and_(i32_ty, a0, i32_val(0x7fff7fff));
a1 = and_(i32_ty, a1, i32_val(0x7fff7fff));
a0 = add(i32_ty, a0, i32_val(0x00800080));
a1 = add(i32_ty, a1, i32_val(0x00800080));
Value b0 = or_( i32_ty, and_(i32_ty, fp16x2Vec0, i32_val(0x80008000)), a0 );
Value b1 = or_( i32_ty, and_(i32_ty, fp16x2Vec1, i32_val(0x80008000)), a1 );
auto fp8x4VecTy = vec_ty(i8_ty, 4);
b0 = bitcast(b0, fp8x4VecTy);
b1 = bitcast(b1, fp8x4VecTy);
return {extract_element(i8_ty, b0, i32_val(1)),
extract_element(i8_ty, b0, i32_val(3)),
extract_element(i8_ty, b1, i32_val(1)),
extract_element(i8_ty, b1, i32_val(3))
};
#else
PTXBuilder builder;
auto &ptxOp = *builder.create(ptxAsm);
auto *o = builder.newOperand("=r");
auto *i0 = builder.newOperand(fp16x2Vec0, "r");
auto *i1 = builder.newOperand(fp16x2Vec1, "r");
ptxOp({o, i0, i1}, /*onlyAttachMLIRArgs=*/true);
auto fp8x4VecTy = vec_ty(i8_ty, 4);
auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false);
return {extract_element(i8_ty, fp8x4Vec, i32_val(0)),
extract_element(i8_ty, fp8x4Vec, i32_val(1)),
extract_element(i8_ty, fp8x4Vec, i32_val(2)),
extract_element(i8_ty, fp8x4Vec, i32_val(3))};
#endif
}
static SmallVector<Value>
convertFp16x4ToFp8E4M3x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto *ptxAsm = "{ \n"
".reg .b32 a<2>, b<2>; \n"
"shl.b32 a0, $1, 1; \n"
"shl.b32 a1, $2, 1; \n"
"lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n"
"lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n"
"add.u32 a0, a0, 0x00800080; \n"
"add.u32 a1, a1, 0x00800080; \n"
"lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n"
"lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n"
"prmt.b32 $0, b0, b1, 0x7531; \n"
"}";
return convertFp16x4ToFp8x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
}
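// Scalar sketch of one fp16 lane of the E4M3 asm above (same bias-15
// reading of the fp8 exponent as elsewhere in this file). Adding 0x0080
// is half of the discarded low byte, so the final truncation rounds to
// nearest, ties upward. The packed shl.b32 in the PTX leaks one bit
// across lane boundaries, but that bit never reaches the kept high byte:
//
//   uint8_t fp16_to_fp8e4m3_bits(uint16_t h) {
//     uint16_t a = (h << 1) & 0x7fff;  // drop the sign, realign exponent
//     a += 0x0080;                     // rounding bias
//     uint16_t b = (h & 0x8000) | a;   // lop3 0xea: reinsert the sign
//     return uint8_t(b >> 8);          // prmt 0x7531: keep the high byte
//   }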
static SmallVector<Value>
convertFp16x4ToFp8E5M2x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto *ptxAsm = "{ \n"
"prmt.b32 $0, $1, $2, 0x7531; \n\t"
"}";
return convertFp16x4ToFp8x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
}
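// By contrast, E5M2 narrowing is a pure truncation: prmt 0x7531 keeps the
// high byte of each fp16 lane (sign, 5 exponent bits, top 2 mantissa
// bits) and drops the low 8 mantissa bits with no rounding. Scalar
// equivalent for one lane:
//
//   uint8_t fp16_to_fp8e5m2_bits(uint16_t h) { return uint8_t(h >> 8); }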
/* ------------------ */
// FP32 -> FP8
/* ------------------ */
static SmallVector<Value>
convertFp32x4ToFp8E4M3x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto c0 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v0);
auto c1 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v1);
auto c2 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v2);
auto c3 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v3);
return convertFp16x4ToFp8E4M3x4(loc, rewriter, c0, c1, c2, c3);
}
static SmallVector<Value>
convertFp32x4ToFp8E5M2x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto c0 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v0);
auto c1 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v1);
auto c2 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v2);
auto c3 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v3);
return convertFp16x4ToFp8E5M2x4(loc, rewriter, c0, c1, c2, c3);
}
/* ------------------ */
// BF16 -> FP8
/* ------------------ */
static SmallVector<Value>
convertBf16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
const char *ptxAsm, const Value &v0, const Value &v1,
const Value &v2, const Value &v3) {
auto bf16x2VecTy = vec_ty(i16_ty, 2);
Value bf16x2Vec0 = undef(bf16x2VecTy);
Value bf16x2Vec1 = undef(bf16x2VecTy);
@@ -318,6 +418,26 @@ struct FpToFpOpConversion
#else
PTXBuilder builder;
auto &ptxOp = *builder.create(ptxAsm);
auto *o = builder.newOperand("=r");
auto *i0 = builder.newOperand(bf16x2Vec0, "r");
auto *i1 = builder.newOperand(bf16x2Vec1, "r");
ptxOp({o, i0, i1}, /*onlyAttachMLIRArgs=*/true);
auto fp8x4VecTy = vec_ty(i8_ty, 4);
auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false);
return {extract_element(i8_ty, fp8x4Vec, i32_val(0)),
extract_element(i8_ty, fp8x4Vec, i32_val(1)),
extract_element(i8_ty, fp8x4Vec, i32_val(2)),
extract_element(i8_ty, fp8x4Vec, i32_val(3))};
#endif
}
static SmallVector<Value>
convertBf16x4ToFp8E4M3x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto *ptxAsm = "{ \n"
".reg .u32 sign, sign<2>, nosign, nosign<2>; \n"
".reg .u32 fp8_min, fp8_max, rn_, zero; \n"
@@ -354,49 +474,49 @@ struct FpToFpOpConversion
"prmt.b32 nosign, nosign0, nosign1, 0x6420; \n"
"or.b32 $0, nosign, sign; \n"
"}";
auto &call = *builder.create(ptxAsm);
return convertBf16x4ToFp8x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
};
auto *o = builder.newOperand("=r");
auto *i0 = builder.newOperand(bf16x2Vec0, "r");
auto *i1 = builder.newOperand(bf16x2Vec1, "r");
call({o, i0, i1}, /*onlyAttachMLIRArgs=*/true);
// TODO:
// static SmallVector<Value>
// convertBf16x4ToFp8E5M2x4(Location loc, ConversionPatternRewriter &rewriter,
// const Value &v0, const Value &v1, const Value &v2,
// const Value &v3) {
// }
auto fp8x4VecTy = vec_ty(i8_ty, 4);
auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false);
return {extract_element(i8_ty, fp8x4Vec, i32_val(0)),
extract_element(i8_ty, fp8x4Vec, i32_val(1)),
extract_element(i8_ty, fp8x4Vec, i32_val(2)),
extract_element(i8_ty, fp8x4Vec, i32_val(3))};
#endif
/* ------------------ */
// FP8 -> FP32
/* ------------------ */
static SmallVector<Value>
convertFp8E4M3x4ToFp32x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto fp16Values = convertFp8E4M3x4ToFp16x4(loc, rewriter, v0, v1, v2, v3);
return {convertFp16ToFp32(loc, rewriter, fp16Values[0]),
convertFp16ToFp32(loc, rewriter, fp16Values[1]),
convertFp16ToFp32(loc, rewriter, fp16Values[2]),
convertFp16ToFp32(loc, rewriter, fp16Values[3])};
}
static SmallVector<Value>
convertFp8x4ToFp32x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto fp16Values = convertFp8x4ToFp16x4(loc, rewriter, v0, v1, v2, v3);
convertFp8E5M2x4ToFp32x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto fp16Values = convertFp8E5M2x4ToFp16x4(loc, rewriter, v0, v1, v2, v3);
return {rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[0]),
rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[1]),
rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[2]),
rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[3])};
}
static SmallVector<Value>
convertFp32x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto c0 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v0);
auto c1 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v1);
auto c2 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v2);
auto c3 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v3);
return convertFp16x4ToFp8x4(loc, rewriter, c0, c1, c2, c3);
}
//
static SmallVector<Value>
convertFp8x4ToFp64x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto fp16Values = convertFp8x4ToFp16x4(loc, rewriter, v0, v1, v2, v3);
convertFp8E4M3x4ToFp64x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto fp16Values = convertFp8E4M3x4ToFp16x4(loc, rewriter, v0, v1, v2, v3);
return {rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[0]),
rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[1]),
rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[2]),
@@ -404,14 +524,14 @@ struct FpToFpOpConversion
}
static SmallVector<Value>
convertFp64x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
convertFp64x4ToFp8E4M3x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto c0 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v0);
auto c1 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v1);
auto c2 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v2);
auto c3 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v3);
return convertFp16x4ToFp8x4(loc, rewriter, c0, c1, c2, c3);
return convertFp16x4ToFp8E4M3x4(loc, rewriter, c0, c1, c2, c3);
}
static Value convertBf16ToFp32(Location loc,
@@ -424,7 +544,7 @@ struct FpToFpOpConversion
return(bitcast(shifted, f32_ty));
#else
PTXBuilder builder;
auto &cvt = *builder.create("cvt.rn.f32.bf16");
auto &cvt = *builder.create("cvt.f32.bf16");
auto res = builder.newOperand("=r");
auto operand = builder.newOperand(v, "h");
cvt(res, operand);
@@ -432,6 +552,17 @@ struct FpToFpOpConversion
#endif
}
static Value convertFp16ToFp32(Location loc,
ConversionPatternRewriter &rewriter,
const Value &v) {
PTXBuilder builder;
auto &cvt = *builder.create("cvt.f32.f16");
auto res = builder.newOperand("=r");
auto operand = builder.newOperand(v, "h");
cvt(res, operand);
return builder.launch(rewriter, loc, f32_ty, false);
}
static Value convertFp32ToBf16(Location loc,
ConversionPatternRewriter &rewriter,
const Value &v) {
@@ -452,6 +583,17 @@ struct FpToFpOpConversion
#endif
}
static Value convertFp32ToFp16(Location loc,
ConversionPatternRewriter &rewriter,
const Value &v) {
PTXBuilder builder;
auto &cvt = *builder.create("cvt.rn.f16.f32");
auto res = builder.newOperand("=h");
auto operand = builder.newOperand(v, "r");
cvt(res, operand);
return builder.launch(rewriter, loc, f16_ty, false);
}
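// Note on the cvt flavors above: PTX requires a rounding modifier only
// when converting to a narrower float, hence cvt.rn.f16.f32 here, while
// the widening cvt.f32.f16 and cvt.f32.bf16 are exact and take none;
// presumably that is also why this diff drops the stray .rn from
// cvt.f32.bf16.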
LogicalResult
matchAndRewrite(triton::FpToFpOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -463,58 +605,68 @@ struct FpToFpOpConversion
auto loc = op->getLoc();
auto elems = getElemsPerThread(dstTensorType);
SmallVector<Value> resultVals;
bool isSrcFP8 =
srcEltType.isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>();
bool isDstFP8 =
dstEltType.isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>();
// Select convertor
if (srcEltType.isa<triton::Float8Type>() ||
dstEltType.isa<triton::Float8Type>()) {
std::function<SmallVector<Value>(Location, ConversionPatternRewriter &,
const Value &, const Value &,
const Value &, const Value &)>
convertor;
if (srcEltType.isa<triton::Float8Type>() && dstEltType.isF16()) {
convertor = convertFp8x4ToFp16x4;
} else if (srcEltType.isF16() && dstEltType.isa<triton::Float8Type>()) {
convertor = convertFp16x4ToFp8x4;
} else if (srcEltType.isa<triton::Float8Type>() && dstEltType.isBF16()) {
convertor = convertFp8x4ToBf16x4;
} else if (srcEltType.isBF16() && dstEltType.isa<triton::Float8Type>()) {
convertor = convertBf16x4ToFp8x4;
} else if (srcEltType.isa<triton::Float8Type>() && dstEltType.isF32()) {
convertor = convertFp8x4ToFp32x4;
} else if (srcEltType.isF32() && dstEltType.isa<triton::Float8Type>()) {
convertor = convertFp32x4ToFp8x4;
} else if (srcEltType.isa<triton::Float8Type>() && dstEltType.isF64()) {
convertor = convertFp8x4ToFp64x4;
} else if (srcEltType.isF64() && dstEltType.isa<triton::Float8Type>()) {
convertor = convertFp64x4ToFp8x4;
} else {
assert(false && "unsupported fp8 casting");
}
typedef std::function<SmallVector<Value>(
Location, ConversionPatternRewriter &, const Value &, const Value &,
const Value &, const Value &)>
ConvertorT;
// Vectorized casting
assert(elems % 4 == 0 &&
"FP8 casting only support tensors with 4-aligned sizes");
auto elements = getElementsFromStruct(loc, adaptor.getFrom(), rewriter);
for (size_t i = 0; i < elems; i += 4) {
auto converted = convertor(loc, rewriter, elements[i], elements[i + 1],
elements[i + 2], elements[i + 3]);
resultVals.append(converted);
}
} else if (srcEltType.isBF16() && dstEltType.isF32()) {
resultVals.emplace_back(
convertBf16ToFp32(loc, rewriter, adaptor.getFrom()));
} else if (srcEltType.isF32() && dstEltType.isBF16()) {
resultVals.emplace_back(
convertFp32ToBf16(loc, rewriter, adaptor.getFrom()));
} else {
assert(false && "unsupported type casting");
auto F8E4M3TyID = TypeID::get<mlir::Float8E4M3FNType>();
auto F8E5M2TyID = TypeID::get<mlir::Float8E5M2Type>();
auto F16TyID = TypeID::get<mlir::Float16Type>();
auto BF16TyID = TypeID::get<mlir::BFloat16Type>();
auto F32TyID = TypeID::get<mlir::Float32Type>();
auto F64TyID = TypeID::get<mlir::Float64Type>();
DenseMap<std::pair<TypeID, TypeID>, ConvertorT> convertorMap = {
// F8 -> F16
{{F8E4M3TyID, F16TyID}, convertFp8E4M3x4ToFp16x4},
{{F8E5M2TyID, F16TyID}, convertFp8E5M2x4ToFp16x4},
// F16 -> F8
{{F16TyID, F8E4M3TyID}, convertFp16x4ToFp8E4M3x4},
{{F16TyID, F8E5M2TyID}, convertFp16x4ToFp8E5M2x4},
// F8 -> BF16
{{F8E4M3TyID, BF16TyID}, convertFp8E4M3x4ToBf16x4},
{{F8E5M2TyID, BF16TyID}, convertFp8E5M2x4ToBf16x4},
// BF16 -> F8
{{BF16TyID, F8E4M3TyID}, convertBf16x4ToFp8E4M3x4},
// TODO:
// {{BF16TyID, F8E5M2TyID}, convertBf16x4ToFp8E5M2x4},
// F8 -> F32
{{F8E4M3TyID, F32TyID}, convertFp8E4M3x4ToFp32x4},
{{F8E5M2TyID, F32TyID}, convertFp8E5M2x4ToFp32x4},
// F32 -> F8
{{F32TyID, F8E4M3TyID}, convertFp32x4ToFp8E4M3x4},
{{F32TyID, F8E5M2TyID}, convertFp32x4ToFp8E5M2x4},
};
std::pair<TypeID, TypeID> key = {srcEltType.getTypeID(),
dstEltType.getTypeID()};
if (convertorMap.count(key) == 0) {
llvm::errs() << "Unsupported conversion from " << srcEltType << " to "
<< dstEltType << "\n";
llvm_unreachable("");
}
auto convertor = convertorMap.lookup(key);
// Vectorized casting
assert(elems % 4 == 0 &&
"FP8 casting only support tensors with 4-aligned sizes");
auto elements = getTypeConverter()->unpackLLElements(
loc, adaptor.getFrom(), rewriter, srcTensorType);
for (size_t i = 0; i < elems; i += 4) {
auto converted = convertor(loc, rewriter, elements[i], elements[i + 1],
elements[i + 2], elements[i + 3]);
resultVals.append(converted);
}
assert(resultVals.size() == elems);
auto convertedDstTensorType =
this->getTypeConverter()->convertType(dstTensorType);
auto result = getStructFromElements(loc, resultVals, rewriter,
convertedDstTensorType);
auto result = getTypeConverter()->packLLElements(loc, resultVals, rewriter,
dstTensorType);
rewriter.replaceOp(op, result);
return success();
}
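// Design note on the rewrite above: the old if/else chain is replaced by
// a DenseMap keyed on the (source, destination) TypeID pair, so the set
// of supported conversions is explicit and extending it is a one-entry
// change. For example, once convertBf16x4ToFp8E5M2x4 exists, enabling it
// would just be:
//
//   convertorMap.insert({{BF16TyID, F8E5M2TyID}, convertBf16x4ToFp8E5M2x4});
//
// Unsupported pairs now fail loudly with the type names printed instead
// of a bare assert(false).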
@@ -536,8 +688,8 @@ class ElementwiseOpConversionBase
public:
using OpAdaptor = typename SourceOp::Adaptor;
explicit ElementwiseOpConversionBase(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
explicit ElementwiseOpConversionBase(
TritonGPUToLLVMTypeConverter &typeConverter, PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
LogicalResult
@@ -553,7 +705,7 @@ public:
Type structTy = this->getTypeConverter()->convertType(resultTy);
auto *concreteThis = static_cast<const ConcreteT *>(this);
auto operands = getOperands(rewriter, adaptor, elems, loc);
auto operands = getOperands(rewriter, adaptor, resultTy, elems, loc);
SmallVector<Value> resultVals(elems);
for (unsigned i = 0; i < elems; ++i) {
resultVals[i] = concreteThis->createDestOp(op, adaptor, rewriter, elemTy,
@@ -561,7 +713,8 @@ public:
if (!bool(resultVals[i]))
return failure();
}
Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
Value view = this->getTypeConverter()->packLLElements(loc, resultVals,
rewriter, resultTy);
rewriter.replaceOp(op, view);
return success();
@@ -570,10 +723,11 @@ public:
protected:
SmallVector<SmallVector<Value>>
getOperands(ConversionPatternRewriter &rewriter, OpAdaptor adaptor,
const unsigned elems, Location loc) const {
Type operandTy, const unsigned elems, Location loc) const {
SmallVector<SmallVector<Value>> operands(elems);
for (auto operand : adaptor.getOperands()) {
auto sub_operands = getElementsFromStruct(loc, operand, rewriter);
auto sub_operands = this->getTypeConverter()->unpackLLElements(
loc, operand, rewriter, operandTy);
for (size_t i = 0; i < elems; ++i) {
operands[i].push_back(sub_operands[i]);
}
@@ -994,15 +1148,14 @@ struct ExpOpConversionApprox
}
};
void populateElementwiseOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation,
Value smem, PatternBenefit benefit) {
void populateElementwiseOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem, PatternBenefit benefit) {
#define POPULATE_TERNARY_OP(SRC_OP, DST_OP) \
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
POPULATE_TERNARY_OP(triton::gpu::SelectOp, LLVM::SelectOp)
POPULATE_TERNARY_OP(arith::SelectOp, LLVM::SelectOp)
#undef POPULATE_TERNARY_OP
#define POPULATE_BINARY_OP(SRC_OP, DST_OP) \
@@ -1056,9 +1209,160 @@ void populateElementwiseOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
patterns.add<FpToFpOpConversion>(typeConverter, benefit);
patterns.add<ExtElemwiseOpConversion>(typeConverter, benefit);
// ExpOpConversionApprox will try using ex2.approx if the input type is FP32.
// For FP64 input type, ExpOpConversionApprox will return failure and
// ExpOpConversionApprox will try using ex2.approx if the input type is
// FP32. For FP64 input type, ExpOpConversionApprox will return failure and
// ElementwiseOpConversion<math::ExpOp, math::ExpOp> defined below will call
// __nv_expf for higher-precision calculation
patterns.add<ExpOpConversionApprox>(typeConverter, benefit);
}
struct FPExtOpConversion
: ElementwiseOpConversionBase<LLVM::FPExtOp, FPExtOpConversion> {
using Base = ElementwiseOpConversionBase<LLVM::FPExtOp, FPExtOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
static bool isLegalOp(LLVM::FPExtOp op) {
auto retTy = op.getResult().getType();
auto srcTy = op.getOperand().getType();
if (retTy.isF32() && srcTy.isF16()) {
return false;
}
return true;
}
Value createDestOp(LLVM::FPExtOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
return FpToFpOpConversion::convertFp16ToFp32(loc, rewriter, operands[0]);
}
};
struct FPTruncOpConversion
: ElementwiseOpConversionBase<LLVM::FPTruncOp, FPTruncOpConversion> {
using Base =
ElementwiseOpConversionBase<LLVM::FPTruncOp, FPTruncOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
static bool isLegalOp(LLVM::FPTruncOp op) {
auto retTy = op.getResult().getType();
auto srcTy = op.getOperand().getType();
if (retTy.isF16() && srcTy.isF32()) {
return false;
}
return true;
}
Value createDestOp(LLVM::FPTruncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
return FpToFpOpConversion::convertFp32ToFp16(loc, rewriter, operands[0]);
}
};
struct TruncOpConversion
: ElementwiseOpConversionBase<LLVM::TruncOp, TruncOpConversion> {
using Base = ElementwiseOpConversionBase<LLVM::TruncOp, TruncOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
static bool isLegalOp(LLVM::TruncOp op) {
auto retTy = op.getResult().getType();
auto srcTy = op.getOperand().getType();
if (retTy.isInteger(16) && srcTy.isInteger(32)) {
return false;
}
return true;
}
Value createDestOp(LLVM::TruncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
PTXBuilder builder;
auto &cvt = *builder.create("cvt.u16.u32");
auto res = builder.newOperand("=h");
auto operand = builder.newOperand(operands[0], "r");
cvt(res, operand);
return builder.launch(rewriter, loc, i16_ty, false);
}
};
struct SExtOpConversion
: ElementwiseOpConversionBase<LLVM::SExtOp, SExtOpConversion> {
using Base = ElementwiseOpConversionBase<LLVM::SExtOp, SExtOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
static bool isLegalOp(LLVM::SExtOp op) {
auto retTy = op.getResult().getType();
auto srcTy = op.getOperand().getType();
if (retTy.isInteger(32) && srcTy.isInteger(16)) {
return false;
}
return true;
}
Value createDestOp(LLVM::SExtOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
PTXBuilder builder;
auto &cvt = *builder.create("cvt.s32.s16");
auto res = builder.newOperand("=r");
auto operand = builder.newOperand(operands[0], "h");
cvt(res, operand);
return builder.launch(rewriter, loc, i32_ty, false);
}
};
struct ZExtOpConversion
: ElementwiseOpConversionBase<LLVM::ZExtOp, ZExtOpConversion> {
using Base = ElementwiseOpConversionBase<LLVM::ZExtOp, ZExtOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
static bool isLegalOp(LLVM::ZExtOp op) {
auto retTy = op.getResult().getType();
auto srcTy = op.getOperand().getType();
if (retTy.isInteger(32) && srcTy.isInteger(16)) {
return false;
}
return true;
}
Value createDestOp(LLVM::ZExtOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
PTXBuilder builder;
auto &cvt = *builder.create("cvt.u32.u16");
auto res = builder.newOperand("=r");
auto operand = builder.newOperand(operands[0], "h");
cvt(res, operand);
return builder.launch(rewriter, loc, i32_ty, false);
}
};
bool isLegalElementwiseOp(Operation *op) {
if (isa<LLVM::FPExtOp>(op)) {
return FPExtOpConversion::isLegalOp(cast<LLVM::FPExtOp>(op));
} else if (isa<LLVM::FPTruncOp>(op)) {
return FPTruncOpConversion::isLegalOp(cast<LLVM::FPTruncOp>(op));
} else if (isa<LLVM::TruncOp>(op)) {
return TruncOpConversion::isLegalOp(cast<LLVM::TruncOp>(op));
} else if (isa<LLVM::SExtOp>(op)) {
return SExtOpConversion::isLegalOp(cast<LLVM::SExtOp>(op));
} else if (isa<LLVM::ZExtOp>(op)) {
return ZExtOpConversion::isLegalOp(cast<LLVM::ZExtOp>(op));
}
return true;
}
void populateElementwiseOpToPTXPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
PatternBenefit benefit) {
patterns.add<FPExtOpConversion>(typeConverter, benefit);
patterns.add<FPTruncOpConversion>(typeConverter, benefit);
patterns.add<TruncOpConversion>(typeConverter, benefit);
patterns.add<SExtOpConversion>(typeConverter, benefit);
patterns.add<ZExtOpConversion>(typeConverter, benefit);
}
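// A sketch of how these PTX patterns are presumably wired into the
// lowering (names such as ctx, module, and typeConverter are placeholders,
// not taken from this diff): ops like fpext f16 -> f32 are declared
// dynamically illegal so the cvt-based patterns above must fire.
//
//   mlir::ConversionTarget target(ctx);
//   target.addDynamicallyLegalDialect<mlir::LLVM::LLVMDialect>(
//       [](mlir::Operation *op) { return isLegalElementwiseOp(op); });
//   mlir::RewritePatternSet patterns(&ctx);
//   populateElementwiseOpToPTXPatterns(typeConverter, patterns,
//                                      /*benefit=*/10);
//   if (mlir::failed(mlir::applyPartialConversion(
//           module, target, std::move(patterns))))
//     return; // or signalPassFailure() inside a pass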

View File

@@ -6,11 +6,15 @@
using namespace mlir;
using namespace mlir::triton;
void populateElementwiseOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation,
Value smem, PatternBenefit benefit);
void populateElementwiseOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem, PatternBenefit benefit);
bool isLegalElementwiseOp(Operation *op);
void populateElementwiseOpToPTXPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
PatternBenefit benefit);
#endif

View File

@@ -7,9 +7,7 @@
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::getElementsFromStruct;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::getStructFromElements;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::SharedEncodingAttr;
@@ -18,19 +16,6 @@ struct LoadStoreConversionBase {
explicit LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass)
: axisAnalysisPass(axisAnalysisPass) {}
// Get corresponding LLVM element values of \param value.
static SmallVector<Value> getLLVMElems(Value value, Value llValue,
ConversionPatternRewriter &rewriter,
Location loc) {
if (!value)
return {};
if (!llValue.getType().isa<LLVM::LLVMStructType>())
return {llValue};
// Here, we assume that all inputs have a blocked layout.
auto valueVals = getElementsFromStruct(loc, llValue, rewriter);
return valueVals;
}
unsigned getContiguity(Value ptr) const {
auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
if (!tensorTy)
@@ -62,7 +47,7 @@ struct LoadOpConversion
using ConvertTritonGPUOpToLLVMPattern<
triton::LoadOp>::ConvertTritonGPUOpToLLVMPattern;
LoadOpConversion(LLVMTypeConverter &converter,
LoadOpConversion(TritonGPUToLLVMTypeConverter &converter,
AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
LoadStoreConversionBase(axisAnalysisPass) {}
@@ -92,13 +77,15 @@ struct LoadOpConversion
vec = std::min<size_t>(vec, getMaskAlignment(mask));
// Get the LLVM values for pointers
auto ptrElems = getLLVMElems(ptr, llPtr, rewriter, loc);
auto ptrElems = getTypeConverter()->unpackLLElements(loc, llPtr, rewriter,
ptr.getType());
assert(ptrElems.size() == numElems);
// Get the LLVM values for mask
SmallVector<Value> maskElems;
if (llMask) {
maskElems = getLLVMElems(mask, llMask, rewriter, loc);
maskElems = getTypeConverter()->unpackLLElements(loc, llMask, rewriter,
mask.getType());
assert(maskElems.size() == numElems);
}
@@ -114,7 +101,11 @@ struct LoadOpConversion
otherIsSplatConstInt = true;
splatVal = constAttr.getSplatValue<APInt>().getSExtValue();
}
auto otherElems = getLLVMElems(other, llOther, rewriter, loc);
SmallVector<Value> otherElems;
if (other) {
otherElems = getTypeConverter()->unpackLLElements(loc, llOther, rewriter,
other.getType());
}
// vectorized iteration through all the pointer/mask/other elements
const int valueElemNbits =
@@ -283,8 +274,8 @@ struct LoadOpConversion
} // end vec
Type llvmResultStructTy = getTypeConverter()->convertType(valueTy);
Value resultStruct =
getStructFromElements(loc, loadedVals, rewriter, llvmResultStructTy);
Value resultStruct = getTypeConverter()->packLLElements(
loc, loadedVals, rewriter, llvmResultStructTy);
rewriter.replaceOp(op, {resultStruct});
return success();
}
@@ -296,7 +287,7 @@ struct StoreOpConversion
using ConvertTritonGPUOpToLLVMPattern<
triton::StoreOp>::ConvertTritonGPUOpToLLVMPattern;
StoreOpConversion(LLVMTypeConverter &converter,
StoreOpConversion(TritonGPUToLLVMTypeConverter &converter,
AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>(converter, benefit),
LoadStoreConversionBase(axisAnalysisPass) {}
@@ -305,7 +296,6 @@ struct StoreOpConversion
matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value ptr = op.getPtr();
Value mask = op.getMask();
Value value = op.getValue();
Value llPtr = adaptor.getPtr();
@@ -320,28 +310,40 @@ struct StoreOpConversion
typeConverter->convertType(getElementTypeOrSelf(valueTy));
unsigned vec = getVectorSize(ptr);
unsigned numElems = getElemsPerThread(ptr.getType());
unsigned elemsPerThread = getElemsPerThread(ptr.getType());
auto ptrElems = getLLVMElems(ptr, llPtr, rewriter, loc);
auto valueElems = getLLVMElems(value, llValue, rewriter, loc);
auto ptrElems = getTypeConverter()->unpackLLElements(loc, llPtr, rewriter,
ptr.getType());
auto valueElems = getTypeConverter()->unpackLLElements(
loc, llValue, rewriter, value.getType());
assert(ptrElems.size() == valueElems.size());
// Determine the vectorization size
SmallVector<Value> maskElems;
if (llMask) {
maskElems = getLLVMElems(mask, llMask, rewriter, loc);
Value mask = op.getMask();
maskElems = getTypeConverter()->unpackLLElements(loc, llMask, rewriter,
mask.getType());
assert(valueElems.size() == maskElems.size());
unsigned maskAlign = getMaskAlignment(mask);
vec = std::min(vec, maskAlign);
}
// numElems = 1 for a scalar
auto tensorTy = valueTy.dyn_cast<RankedTensorType>();
auto numElems = tensorTy ? tensorTy.getNumElements() : 1;
Value mask = int_val(1, 1);
auto tid = tid_val();
mask = and_(mask,
icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems)));
const size_t dtsize =
std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
const size_t valueElemNbits = dtsize * 8;
const int numVecs = numElems / vec;
for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {
const int numVecs = elemsPerThread / vec;
for (size_t vecStart = 0; vecStart < elemsPerThread; vecStart += vec) {
// TODO: optimization when ptr is AddPtr with constant offset
size_t in_off = 0;
@@ -350,7 +352,7 @@ struct StoreOpConversion
const size_t width = std::min(totalWidth, maxWordWidth);
const size_t nWords = std::max<size_t>(1, totalWidth / width);
const size_t wordNElems = width / valueElemNbits;
assert(wordNElems * nWords * numVecs == numElems);
assert(wordNElems * nWords * numVecs == elemsPerThread);
// TODO(Superjomn) Add cache policy fields to StoreOp.
// TODO(Superjomn) Deal with cache policy here.
@@ -368,13 +370,13 @@ struct StoreOpConversion
assert(elemOffset < valueElems.size());
Value elem = valueElems[elemOffset];
if (elem.getType().isInteger(1))
elem = rewriter.create<LLVM::SExtOp>(loc, type::i8Ty(ctx), elem);
elem = sext(i8_ty, elem);
elem = bitcast(elem, valueElemTy);
llWord = insert_element(wordTy, llWord, elem, i32_val(elemIdx));
}
llWord = bitcast(llWord, valArgTy);
#ifdef USE_ROCM
Value maskVal = llMask ? maskElems[vecStart] : int_val(1, 1);
Value maskVal = llMask ? and_(mask, maskElems[vecStart]) : mask;
rewriter.create<scf::IfOp>(loc, maskVal,
[&](OpBuilder &builder, Location loc){
auto storeOp = builder.create<LLVM::StoreOp>(loc, llWord, ptrElems[vecStart + wordIdx * wordNElems]);
@@ -393,7 +395,7 @@ struct StoreOpConversion
PTXBuilder ptxBuilder;
auto *asmArgList = ptxBuilder.newListOperand(asmArgs);
Value maskVal = llMask ? maskElems[vecStart] : int_val(1, 1);
Value maskVal = llMask ? and_(mask, maskElems[vecStart]) : mask;
auto *asmAddr =
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
@@ -422,7 +424,7 @@ struct AtomicCASOpConversion
using ConvertTritonGPUOpToLLVMPattern<
triton::AtomicCASOp>::ConvertTritonGPUOpToLLVMPattern;
AtomicCASOpConversion(LLVMTypeConverter &converter,
AtomicCASOpConversion(TritonGPUToLLVMTypeConverter &converter,
const Allocation *allocation, Value smem,
AxisInfoAnalysis &axisAnalysisPass,
PatternBenefit benefit)
@@ -442,13 +444,16 @@ struct AtomicCASOpConversion
Value llCmp = adaptor.getCmp();
Value llVal = adaptor.getVal();
auto ptrElements = getElementsFromStruct(loc, llPtr, rewriter);
auto cmpElements = getElementsFromStruct(loc, llCmp, rewriter);
auto valElements = getElementsFromStruct(loc, llVal, rewriter);
auto ptrElements = getTypeConverter()->unpackLLElements(
loc, llPtr, rewriter, op.getPtr().getType());
auto cmpElements = getTypeConverter()->unpackLLElements(
loc, llCmp, rewriter, op.getCmp().getType());
auto valElements = getTypeConverter()->unpackLLElements(
loc, llVal, rewriter, op.getVal().getType());
auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
auto TensorTy = op.getResult().getType().dyn_cast<RankedTensorType>();
Type valueElemTy =
valueTy ? getTypeConverter()->convertType(valueTy.getElementType())
TensorTy ? getTypeConverter()->convertType(TensorTy.getElementType())
: op.getResult().getType();
auto tid = tid_val();
Value pred = icmp_eq(tid, i32_val(0));
@@ -508,14 +513,17 @@ struct AtomicCASOpConversion
Value llCmp = adaptor.getCmp();
Value llVal = adaptor.getVal();
auto ptrElements = getElementsFromStruct(loc, llPtr, rewriter);
auto cmpElements = getElementsFromStruct(loc, llCmp, rewriter);
auto valElements = getElementsFromStruct(loc, llVal, rewriter);
auto ptrElements = getTypeConverter()->unpackLLElements(
loc, llPtr, rewriter, op.getPtr().getType());
auto cmpElements = getTypeConverter()->unpackLLElements(
loc, llCmp, rewriter, op.getCmp().getType());
auto valElements = getTypeConverter()->unpackLLElements(
loc, llVal, rewriter, op.getVal().getType());
auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
auto TensorTy = op.getResult().getType().dyn_cast<RankedTensorType>();
Type valueElemTy =
valueTy ? getTypeConverter()->convertType(valueTy.getElementType())
: op.getResult().getType();
TensorTy ? getTypeConverter()->convertType(TensorTy.getElementType())
: op.getResult().getType();
auto tid = tid_val();
Value pred = icmp_eq(tid, i32_val(0));
PTXBuilder ptxBuilderMemfence;
@@ -565,7 +573,7 @@ struct AtomicRMWOpConversion
using ConvertTritonGPUOpToLLVMPattern<
triton::AtomicRMWOp>::ConvertTritonGPUOpToLLVMPattern;
AtomicRMWOpConversion(LLVMTypeConverter &converter,
AtomicRMWOpConversion(TritonGPUToLLVMTypeConverter &converter,
const Allocation *allocation, Value smem,
AxisInfoAnalysis &axisAnalysisPass,
PatternBenefit benefit)
@@ -617,31 +625,36 @@ struct AtomicRMWOpConversion
Value llVal = adaptor.getVal();
Value llMask = adaptor.getMask();
auto valElements = getElementsFromStruct(loc, llVal, rewriter);
auto ptrElements = getElementsFromStruct(loc, llPtr, rewriter);
auto maskElements = getElementsFromStruct(loc, llMask, rewriter);
auto valElements = getTypeConverter()->unpackLLElements(
loc, llVal, rewriter, val.getType());
auto ptrElements = getTypeConverter()->unpackLLElements(
loc, llPtr, rewriter, ptr.getType());
SmallVector<Value> maskElements;
if (llMask)
maskElements = getTypeConverter()->unpackLLElements(
loc, llMask, rewriter, op.getMask().getType());
Value opResult = op.getResult();
auto valueTy = opResult.getType().dyn_cast<RankedTensorType>();
auto tensorTy = opResult.getType().dyn_cast<RankedTensorType>();
Type valueElemTy =
valueTy ? getTypeConverter()->convertType(valueTy.getElementType())
tensorTy ? getTypeConverter()->convertType(tensorTy.getElementType())
: opResult.getType();
const size_t valueElemNbits = valueElemTy.getIntOrFloatBitWidth();
auto elemsPerThread = getElemsPerThread(val.getType());
// vec = 1 for scalar
// vec = 1, numElements = 1 for scalar
auto vec = getVectorSize(ptr);
Value mask = int_val(1, 1);
auto tid = tid_val();
int numElems = 1;
// tensor
if (valueTy) {
if (tensorTy) {
auto valTy = val.getType().cast<RankedTensorType>();
vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
// mask
auto shape = valueTy.getShape();
auto numElements = product(shape);
mask = and_(mask, icmp_slt(mul(tid, i32_val(elemsPerThread)),
i32_val(numElements)));
numElems = tensorTy.getNumElements();
}
Value mask = int_val(1, 1);
auto tid = tid_val();
mask = and_(mask,
icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems)));
auto vecTy = vec_ty(valueElemTy, vec);
SmallVector<Value> resultVals(elemsPerThread);
@@ -654,11 +667,7 @@ struct AtomicRMWOpConversion
}
Value rmwPtr = ptrElements[i];
Value rmwMask = maskElements[i];
rmwMask = and_(rmwMask, mask);
if (!valueTy) {
rmwMask = and_(rmwMask, icmp_eq(tid, i32_val(0)));
}
Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask;
Value undefVal = undef(valueElemTy);
// Build blocks to bypass the atomic instruction for ~rmwMask.
@@ -683,11 +692,11 @@ struct AtomicRMWOpConversion
rewriter.setInsertionPointToStart(endBlock);
Value retVal = endBlock->getArgument(0);
if (valueTy) {
if (tensorTy) {
for (int ii = 0; ii < vec; ++ii) {
resultVals[i + ii] =
vec == 1 ? retVal
: extract_element(valueElemTy, retVal, idx_val(ii));
: extract_element(valueElemTy, retVal, i32_val(ii));
}
} else {
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
@@ -697,10 +706,10 @@ struct AtomicRMWOpConversion
rewriter.replaceOp(op, {ret});
}
}
if (valueTy) {
Type structTy = getTypeConverter()->convertType(valueTy);
Value resultStruct =
getStructFromElements(loc, resultVals, rewriter, structTy);
if (tensorTy) {
Type structTy = getTypeConverter()->convertType(tensorTy);
Value resultStruct = getTypeConverter()->packLLElements(
loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, {resultStruct});
}
return success();
@@ -715,37 +724,43 @@ struct AtomicRMWOpConversion
MLIRContext *ctx = rewriter.getContext();
auto atomicRmwAttr = op.getAtomicRmwOp();
Value ptr = op.getPtr();
Value val = op.getVal();
Value ptr = op.getPtr();
Value llPtr = adaptor.getPtr();
Value llVal = adaptor.getVal();
Value llMask = adaptor.getMask();
auto valElements = getElementsFromStruct(loc, llVal, rewriter);
auto ptrElements = getElementsFromStruct(loc, llPtr, rewriter);
auto maskElements = getElementsFromStruct(loc, llMask, rewriter);
auto valElements = getTypeConverter()->unpackLLElements(
loc, llVal, rewriter, val.getType());
auto ptrElements = getTypeConverter()->unpackLLElements(
loc, llPtr, rewriter, ptr.getType());
SmallVector<Value> maskElements;
if (llMask)
maskElements = getTypeConverter()->unpackLLElements(
loc, llMask, rewriter, op.getMask().getType());
auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
auto tensorTy = op.getResult().getType().dyn_cast<RankedTensorType>();
Type valueElemTy =
valueTy ? getTypeConverter()->convertType(valueTy.getElementType())
: op.getResult().getType();
tensorTy ? getTypeConverter()->convertType(tensorTy.getElementType())
: op.getResult().getType();
const size_t valueElemNbits = valueElemTy.getIntOrFloatBitWidth();
auto elemsPerThread = getElemsPerThread(val.getType());
// vec = 1 for scalar
// vec = 1, numElements = 1 for scalar
auto vec = getVectorSize(ptr);
Value mask = int_val(1, 1);
auto tid = tid_val();
int numElems = 1;
// tensor
if (valueTy) {
if (tensorTy) {
auto valTy = val.getType().cast<RankedTensorType>();
vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
// mask
auto shape = valueTy.getShape();
auto numElements = product(shape);
mask = and_(mask, icmp_slt(mul(tid, i32_val(elemsPerThread)),
i32_val(numElements)));
numElems = tensorTy.getNumElements();
}
Value mask = int_val(1, 1);
auto tid = tid_val();
mask = and_(mask,
icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems)));
auto vecTy = vec_ty(valueElemTy, vec);
SmallVector<Value> resultVals(elemsPerThread);
@@ -758,8 +773,7 @@ struct AtomicRMWOpConversion
}
Value rmwPtr = ptrElements[i];
Value rmwMask = maskElements[i];
rmwMask = and_(rmwMask, mask);
Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask;
std::string sTy;
PTXBuilder ptxBuilderAtomicRMW;
std::string tyId = valueElemNbits * vec == 64
@@ -812,13 +826,13 @@ struct AtomicRMWOpConversion
return failure();
}
atom.o(rmwOp).o(sTy);
if (valueTy) {
if (tensorTy) {
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
auto retType = vec == 1 ? valueElemTy : vecTy;
auto ret = ptxBuilderAtomicRMW.launch(rewriter, loc, retType);
for (int ii = 0; ii < vec; ++ii) {
resultVals[i + ii] =
vec == 1 ? ret : extract_element(valueElemTy, ret, idx_val(ii));
vec == 1 ? ret : extract_element(valueElemTy, ret, i32_val(ii));
}
} else {
PTXBuilder ptxBuilderMemfence;
@@ -838,10 +852,10 @@ struct AtomicRMWOpConversion
rewriter.replaceOp(op, {ret});
}
}
if (valueTy) {
Type structTy = getTypeConverter()->convertType(valueTy);
Value resultStruct =
getStructFromElements(loc, resultVals, rewriter, structTy);
if (tensorTy) {
Type structTy = getTypeConverter()->convertType(tensorTy);
Value resultStruct = getTypeConverter()->packLLElements(
loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, {resultStruct});
}
return success();
@@ -904,7 +918,7 @@ struct InsertSliceOpConversion
auto smemBase = gep(elemPtrTy, smemObj.base, offset);
auto llSrc = adaptor.getSource();
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy);
storeDistributedToShared(src, llSrc, srcStrides, srcIndices, dst, smemBase,
elemTy, loc, rewriter);
// Barrier is not necessary.
@@ -922,7 +936,8 @@ struct InsertSliceAsyncOpConversion
triton::gpu::InsertSliceAsyncOp>::ConvertTritonGPUOpToLLVMPattern;
InsertSliceAsyncOpConversion(
LLVMTypeConverter &converter, const Allocation *allocation, Value smem,
TritonGPUToLLVMTypeConverter &converter, const Allocation *allocation,
Value smem,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::gpu::InsertSliceAsyncOp>(
@@ -958,7 +973,8 @@ struct InsertSliceAsyncOpConversion
Value llIndex = adaptor.getIndex();
// %src
auto srcElems = getLLVMElems(src, llSrc, rewriter, loc);
auto srcElems = getTypeConverter()->unpackLLElements(loc, llSrc, rewriter,
src.getType());
// %dst
auto dstTy = dst.getType().cast<RankedTensorType>();
@@ -984,7 +1000,8 @@ struct InsertSliceAsyncOpConversion
// %mask
SmallVector<Value> maskElems;
if (llMask) {
maskElems = getLLVMElems(mask, llMask, rewriter, loc);
maskElems = getTypeConverter()->unpackLLElements(loc, llMask, rewriter,
mask.getType());
assert(srcElems.size() == maskElems.size());
}
@@ -995,7 +1012,8 @@ struct InsertSliceAsyncOpConversion
// It's not necessary for now because the pipeline pass will skip
// generating insert_slice_async if the load op has any "other" tensor.
// assert(false && "insert_slice_async: Other value not supported yet");
otherElems = getLLVMElems(other, llOther, rewriter, loc);
otherElems = getTypeConverter()->unpackLLElements(loc, llOther, rewriter,
other.getType());
assert(srcElems.size() == otherElems.size());
}
@@ -1025,7 +1043,7 @@ struct InsertSliceAsyncOpConversion
// single vector read into multiple ones
auto numVecCols = std::max<unsigned>(inVec / outVec, 1);
auto srcIndices = emitIndices(loc, rewriter, srcBlockedLayout, srcShape);
auto srcIndices = emitIndices(loc, rewriter, srcBlockedLayout, srcTy);
for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) {
// 16 * 8 = 128 bits
@@ -1075,7 +1093,7 @@ struct InsertSliceAsyncOpConversion
};
void populateLoadStoreOpToLLVMPatterns(
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,

View File

@@ -7,7 +7,7 @@ using namespace mlir;
using namespace mlir::triton;
void populateLoadStoreOpToLLVMPatterns(
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,

View File

@@ -3,8 +3,6 @@
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::getElementsFromStruct;
using ::mlir::LLVM::getStructFromElements;
using ::mlir::LLVM::shflSync;
using ::mlir::LLVM::storeShared;
using ::mlir::triton::gpu::getElemsPerThread;
@@ -160,11 +158,12 @@ private:
indexSmemBase = bitcast(indexSmemBase, indexPtrTy);
unsigned srcElems = getElemsPerThread(srcTy);
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
auto srcValues = getElementsFromStruct(loc, adaptor.getOperand(), rewriter);
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy);
auto srcValues = getTypeConverter()->unpackLLElements(
loc, adaptor.getOperand(), rewriter, srcTy);
SmallVector<SmallVector<unsigned>> offset =
emitOffsetForLayout(srcLayout, srcShape);
emitOffsetForLayout(srcLayout, srcTy);
std::map<SmallVector<unsigned>, Value> accs;
std::map<SmallVector<unsigned>, Value> accIndices;
@@ -247,8 +246,7 @@ private:
auto resultShape = resultTy.getShape();
unsigned resultElems = getElemsPerThread(resultTy);
auto resultIndices =
emitIndices(loc, rewriter, resultLayout, resultShape);
auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy);
assert(resultIndices.size() == resultElems);
SmallVector<Value> resultVals(resultElems);
@@ -260,12 +258,8 @@ private:
Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset);
resultVals[i] = withIndex ? load(indexReadPtr) : load(readPtr);
}
SmallVector<Type> resultTypes(resultElems,
withIndex ? llvmIndexTy : llvmElemTy);
Type structTy =
LLVM::LLVMStructType::getLiteral(this->getContext(), resultTypes);
Value ret = getStructFromElements(loc, resultVals, rewriter, structTy);
Value ret = getTypeConverter()->packLLElements(loc, resultVals, rewriter,
resultTy);
rewriter.replaceOp(op, ret);
} else {
// 0d-tensor -> scalar
@@ -310,11 +304,12 @@ private:
unsigned sizeInterWarps = helper.getInterWarpSize();
unsigned srcElems = getElemsPerThread(srcTy);
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
auto srcValues = getElementsFromStruct(loc, adaptor.getOperand(), rewriter);
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy);
auto srcValues = getTypeConverter()->unpackLLElements(
loc, adaptor.getOperand(), rewriter, srcTy);
SmallVector<SmallVector<unsigned>> offset =
emitOffsetForLayout(srcLayout, srcShape);
emitOffsetForLayout(srcLayout, srcTy);
std::map<SmallVector<unsigned>, Value> accs;
std::map<SmallVector<unsigned>, Value> accIndices;
@@ -457,8 +452,7 @@ private:
auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
auto resultShape = resultTy.getShape();
unsigned resultElems = getElemsPerThread(resultTy);
auto resultIndices =
emitIndices(loc, rewriter, resultLayout, resultShape);
auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy);
assert(resultIndices.size() == resultElems);
SmallVector<Value> resultVals(resultElems);
@@ -471,12 +465,8 @@ private:
Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset);
resultVals[i] = withIndex ? load(indexReadPtr) : load(readPtr);
}
SmallVector<Type> resultTypes(resultElems,
withIndex ? llvmIndexTy : llvmElemTy);
Type structTy =
LLVM::LLVMStructType::getLiteral(this->getContext(), resultTypes);
Value ret = getStructFromElements(loc, resultVals, rewriter, structTy);
Value ret = getTypeConverter()->packLLElements(loc, resultVals, rewriter,
resultTy);
rewriter.replaceOp(op, ret);
} else {
// 0d-tensor -> scalar
@@ -489,7 +479,7 @@ private:
};
void populateReduceOpToLLVMPatterns(
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,

View File

@@ -7,7 +7,7 @@ using namespace mlir;
using namespace mlir::triton;
void populateReduceOpToLLVMPatterns(
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,

View File

@@ -1,13 +1,12 @@
#include "TritonGPUToLLVM.h"
#include "DotOpHelpers.h"
#include "Utility.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::getElementsFromStruct;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::getStructFromElements;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::SharedEncodingAttr;
@@ -66,9 +65,10 @@ struct BroadcastOpConversion
assert(rank == resultTy.getRank());
auto order = triton::gpu::getOrder(srcLayout);
auto srcOffsets = emitOffsetForLayout(srcLayout, srcShape);
auto resultOffsets = emitOffsetForLayout(resultLayout, resultShape);
SmallVector<Value> srcVals = getElementsFromStruct(loc, src, rewriter);
auto srcOffsets = emitOffsetForLayout(srcLayout, srcTy);
auto resultOffsets = emitOffsetForLayout(resultLayout, resultTy);
SmallVector<Value> srcVals =
getTypeConverter()->unpackLLElements(loc, src, rewriter, srcTy);
DenseMap<SmallVector<unsigned>, Value, SmallVectorKeyInfo> srcValues;
for (size_t i = 0; i < srcOffsets.size(); i++) {
@@ -84,27 +84,26 @@ struct BroadcastOpConversion
resultVals.push_back(srcValues.lookup(offset));
}
auto llvmStructTy = getTypeConverter()->convertType(resultTy);
Value resultStruct =
getStructFromElements(loc, resultVals, rewriter, llvmStructTy);
getTypeConverter()->packLLElements(loc, resultVals, rewriter, resultTy);
rewriter.replaceOp(op, {resultStruct});
return success();
}
};
struct PrintfOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::PrintfOp> {
struct PrintOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::PrintOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::PrintfOp>::ConvertTritonGPUOpToLLVMPattern;
triton::PrintOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::PrintfOp op, OpAdaptor adaptor,
matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
SmallVector<Value, 16> operands;
for (auto operand : adaptor.getOperands()) {
auto sub_operands = getElementsFromStruct(loc, operand, rewriter);
for (size_t i = 0; i < op.getNumOperands(); i++) {
auto sub_operands = getTypeConverter()->unpackLLElements(
loc, adaptor.getOperands()[i], rewriter, op.getOperand(i).getType());
for (auto elem : sub_operands) {
operands.push_back(elem);
}
@@ -170,22 +169,20 @@ struct PrintfOpConversion
auto type = value.getType();
Value newOp = value;
Type newType = type;
auto loc = UnknownLoc::get(context);
bool bUnsigned = type.isUnsignedInteger();
if (type.isIntOrIndex() && type.getIntOrFloatBitWidth() < 32) {
if (bUnsigned) {
newType = ui32_ty;
newOp = rewriter.create<LLVM::ZExtOp>(UnknownLoc::get(context), newType,
value);
newOp = zext(newType, value);
} else {
newType = i32_ty;
newOp = rewriter.create<LLVM::SExtOp>(UnknownLoc::get(context), newType,
value);
newOp = sext(newType, value);
}
} else if (type.isBF16() || type.isF16() || type.isF32()) {
newType = f64_ty;
newOp = rewriter.create<LLVM::FPExtOp>(UnknownLoc::get(context), newType,
value);
newOp = fpext(newType, value);
}
return {newType, newOp};
@@ -193,51 +190,24 @@ struct PrintfOpConversion
static void llPrintf(StringRef msg, ValueRange args,
ConversionPatternRewriter &rewriter) {
static const char formatStringPrefix[] = "printfFormat_";
assert(!msg.empty() && "printf with empty string not supported");
Type int8Ptr = ptr_ty(i8_ty);
auto *context = rewriter.getContext();
auto *ctx = rewriter.getContext();
auto moduleOp =
rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
auto funcOp = getVprintfDeclaration(rewriter);
auto loc = UnknownLoc::get(ctx);
Value one = rewriter.create<LLVM::ConstantOp>(
UnknownLoc::get(context), i32_ty, rewriter.getI32IntegerAttr(1));
Value zero = rewriter.create<LLVM::ConstantOp>(
UnknownLoc::get(context), i32_ty, rewriter.getI32IntegerAttr(0));
Value one = i32_val(1);
Value zero = i32_val(0);
unsigned stringNumber = 0;
SmallString<16> stringConstName;
do {
stringConstName.clear();
(formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
} while (moduleOp.lookupSymbol(stringConstName));
llvm::SmallString<64> formatString(msg);
formatString.push_back('\n');
formatString.push_back('\0');
size_t formatStringSize = formatString.size_in_bytes();
auto globalType = LLVM::LLVMArrayType::get(i8_ty, formatStringSize);
LLVM::GlobalOp global;
{
ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());
global = rewriter.create<LLVM::GlobalOp>(
UnknownLoc::get(context), globalType,
/*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
rewriter.getStringAttr(formatString));
}
Value globalPtr =
rewriter.create<LLVM::AddressOfOp>(UnknownLoc::get(context), global);
Value stringStart = rewriter.create<LLVM::GEPOp>(
UnknownLoc::get(context), int8Ptr, globalPtr,
SmallVector<Value>({zero, zero}));
Value bufferPtr =
rewriter.create<LLVM::NullOp>(UnknownLoc::get(context), int8Ptr);
llvm::SmallString<64> msgNewline(msg);
msgNewline.push_back('\n');
msgNewline.push_back('\0');
Value prefixString =
LLVM::addStringToModule(loc, rewriter, "printfFormat_", msgNewline);
Value bufferPtr = null(int8Ptr);
SmallVector<Value, 16> newArgs;
if (args.size() >= 1) {
@@ -250,27 +220,121 @@ struct PrintfOpConversion
newArgs.push_back(newArg);
}
Type structTy = LLVM::LLVMStructType::getLiteral(context, argTypes);
auto allocated = rewriter.create<LLVM::AllocaOp>(UnknownLoc::get(context),
ptr_ty(structTy), one,
/*alignment=*/0);
Type structTy = LLVM::LLVMStructType::getLiteral(ctx, argTypes);
auto allocated =
rewriter.create<LLVM::AllocaOp>(loc, ptr_ty(structTy), one,
/*alignment=*/0);
for (const auto &entry : llvm::enumerate(newArgs)) {
auto index = rewriter.create<LLVM::ConstantOp>(
UnknownLoc::get(context), i32_ty,
rewriter.getI32IntegerAttr(entry.index()));
auto fieldPtr = rewriter.create<LLVM::GEPOp>(
UnknownLoc::get(context), ptr_ty(argTypes[entry.index()]),
allocated, ArrayRef<Value>{zero, index});
rewriter.create<LLVM::StoreOp>(UnknownLoc::get(context), entry.value(),
fieldPtr);
auto index = i32_val(entry.index());
auto fieldPtr = gep(ptr_ty(argTypes[entry.index()]), allocated,
ArrayRef<Value>{zero, index});
store(entry.value(), fieldPtr);
}
bufferPtr = rewriter.create<LLVM::BitcastOp>(UnknownLoc::get(context),
int8Ptr, allocated);
bufferPtr = bitcast(allocated, int8Ptr);
}
SmallVector<Value> operands{stringStart, bufferPtr};
rewriter.create<LLVM::CallOp>(UnknownLoc::get(context), funcOp, operands);
SmallVector<Value> operands{prefixString, bufferPtr};
call(funcOp, operands);
}
};
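A host-side sketch of what this lowering arranges for the device: every operand is promoted, stored field by field into one stack struct, and the struct's address travels to vprintf together with the format string. The struct layout and names below are hypothetical stand-ins, not the device ABI.

#include <cstdio>

struct PrintArgs { // hypothetical: one field per promoted operand
  int i;           // sub-32-bit integers are sign/zero-extended to 32 bits
  double f;        // f16/bf16/f32 operands are widened to double
};

int main() {
  PrintArgs args{42, 3.5}; // the lowering allocas this and stores each field via GEP
  // the device call is vprintf(format, &args); host printf stands in here
  std::printf("i=%d f=%f\n", args.i, args.f);
}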
struct AssertOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::AssertOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::AssertOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::AssertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto ctx = rewriter.getContext();
auto elems = getTypeConverter()->unpackLLElements(
loc, adaptor.getCondition(), rewriter, op.getCondition().getType());
auto elemTy = elems[0].getType();
Value condition = int_val(elemTy.getIntOrFloatBitWidth(), 0);
for (auto elem : elems) {
if (elemTy.isSignedInteger() || elemTy.isSignlessInteger()) {
condition =
or_(condition,
icmp_eq(elem, rewriter.create<LLVM::ConstantOp>(
loc, elemTy, rewriter.getZeroAttr(elemTy))));
} else {
assert(false && "Unsupported type for assert");
return failure();
}
}
llAssert(op, condition, adaptor.getMessage(), adaptor.getFile(),
adaptor.getFunc(), adaptor.getLine(), rewriter);
rewriter.eraseOp(op);
return success();
}
// op: the op at which the assert is inserted. Unlike printf, we need to
// know about the op to split the block.
static void llAssert(Operation *op, Value condition, StringRef message,
StringRef file, StringRef func, int line,
ConversionPatternRewriter &rewriter) {
ConversionPatternRewriter::InsertionGuard guard(rewriter);
auto ctx = rewriter.getContext();
auto loc = op->getLoc();
// #block1
// if (condition) {
// #block2
// __assertfail(message);
// }
// #block3
Block *prevBlock = op->getBlock();
Block *ifBlock = rewriter.splitBlock(prevBlock, op->getIterator());
rewriter.setInsertionPointToStart(ifBlock);
auto funcOp = getAssertfailDeclaration(rewriter);
auto moduleOp =
rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
Value messageString =
LLVM::addStringToModule(loc, rewriter, "assertMessage_", message);
Value fileString =
LLVM::addStringToModule(loc, rewriter, "assertFile_", file);
Value funcString =
LLVM::addStringToModule(loc, rewriter, "assertFunc_", func);
Value lineNumber = i32_val(line);
Value charSize = int_val(sizeof(size_t) * 8, sizeof(char));
SmallVector<Value> operands = {messageString, fileString, lineNumber,
funcString, charSize};
auto ret = call(funcOp, operands);
// Split a block after the call.
Block *thenBlock = rewriter.splitBlock(ifBlock, op->getIterator());
rewriter.setInsertionPointToEnd(ifBlock);
rewriter.create<cf::BranchOp>(loc, thenBlock);
rewriter.setInsertionPointToEnd(prevBlock);
rewriter.create<cf::CondBranchOp>(loc, condition, ifBlock, thenBlock);
}
static LLVM::LLVMFuncOp
getAssertfailDeclaration(ConversionPatternRewriter &rewriter) {
auto moduleOp =
rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
StringRef funcName("__assertfail");
Operation *funcOp = moduleOp.lookupSymbol(funcName);
if (funcOp)
return cast<LLVM::LLVMFuncOp>(*funcOp);
// void __assertfail(const char *message, const char *file, unsigned line,
//                   const char *function, size_t charSize);
auto *ctx = rewriter.getContext();
SmallVector<Type> argsType{ptr_ty(i8_ty), ptr_ty(i8_ty), i32_ty,
ptr_ty(i8_ty),
rewriter.getIntegerType(sizeof(size_t) * 8)};
auto funcType = LLVM::LLVMFunctionType::get(void_ty(ctx), argsType);
ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());
return rewriter.create<LLVM::LLVMFuncOp>(UnknownLoc::get(ctx), funcName,
funcType);
}
};
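The block-splitting in llAssert is easier to see as ordinary control flow. A standalone C++ analogy, where assertfailStub is a hypothetical stand-in for the __assertfail declaration emitted above:

#include <cstdio>
#include <cstdlib>

// hypothetical stand-in for the __assertfail declaration built above
static void assertfailStub(const char *msg, const char *file, int line,
                           const char *func) {
  std::fprintf(stderr, "%s:%d: %s: %s\n", file, line, func, msg);
  std::abort();
}

static void checkedRegion(bool anyElementZero) {
  if (anyElementZero)                                          // prevBlock: cond_br
    assertfailStub("assert failed", "kernel.py", 7, "kernel"); // ifBlock
  // thenBlock: normal execution resumes here
}

int main() { checkedRegion(false); }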
@@ -278,7 +342,7 @@ struct MakeRangeOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp> {
MakeRangeOpConversion(
LLVMTypeConverter &converter,
TritonGPUToLLVMTypeConverter &converter,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp>(
@@ -289,14 +353,14 @@ struct MakeRangeOpConversion
matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto rankedTy = op.getResult().getType().dyn_cast<RankedTensorType>();
auto rankedTy = op.getResult().getType().cast<RankedTensorType>();
auto shape = rankedTy.getShape();
auto layout = rankedTy.getEncoding();
auto elemTy = rankedTy.getElementType();
assert(elemTy.isInteger(32));
Value start = createIndexAttrConstant(rewriter, loc, elemTy, op.getStart());
auto idxs = emitIndices(loc, rewriter, layout, shape);
auto idxs = emitIndices(loc, rewriter, layout, rankedTy);
unsigned elems = idxs.size();
SmallVector<Value> retVals(elems);
// TODO: slice layout has more elements than expected.
@@ -306,9 +370,8 @@ struct MakeRangeOpConversion
assert(multiDim.value().size() == 1);
retVals[multiDim.index()] = add(multiDim.value()[0], start);
}
SmallVector<Type> types(elems, elemTy);
Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types);
Value result = getStructFromElements(loc, retVals, rewriter, structTy);
Value result =
getTypeConverter()->packLLElements(loc, retVals, rewriter, rankedTy);
rewriter.replaceOp(op, result);
return success();
}
@@ -325,11 +388,9 @@ struct GetProgramIdOpConversion
Location loc = op->getLoc();
assert(op.getAxis() < 3);
Value blockId = rewriter.create<::mlir::gpu::BlockIdOp>(
loc, rewriter.getIndexType(), dims[op.getAxis()]);
auto llvmIndexTy = getTypeConverter()->getIndexType();
rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
op, TypeRange{llvmIndexTy}, ValueRange{blockId});
Value blockId =
rewriter.create<::mlir::gpu::BlockIdOp>(loc, dims[op.getAxis()]);
rewriter.replaceOpWithNewOp<arith::TruncIOp>(op, i32_ty, blockId);
return success();
}
@@ -349,11 +410,10 @@ struct GetNumProgramsOpConversion
Location loc = op->getLoc();
assert(op.getAxis() < 3);
Value blockId = rewriter.create<::mlir::gpu::GridDimOp>(
loc, rewriter.getIndexType(), dims[op.getAxis()]);
auto llvmIndexTy = getTypeConverter()->getIndexType();
rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
op, TypeRange{llvmIndexTy}, ValueRange{blockId});
Value blockId =
rewriter.create<::mlir::gpu::GridDimOp>(loc, dims[op.getAxis()]);
rewriter.replaceOpWithNewOp<arith::TruncIOp>(op, i32_ty, blockId);
return success();
}
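Both id conversions now share the same shape: read the hardware id and truncate it to i32, instead of producing an index value bridged by an unrealized cast. A plain C++ analogy of the truncation step (safe under the assumption that grid dimensions fit in 32 bits):

#include <cstdint>
#include <cstdio>

int main() {
  uint64_t blockIdX = 123456;                   // hypothetical hardware value
  int32_t pid = static_cast<int32_t>(blockIdX); // the arith.trunci step
  std::printf("program id = %d\n", pid);
}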
@@ -372,20 +432,23 @@ struct AddPtrOpConversion
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto resultTy = op.getType();
auto offsetTy = op.getOffset().getType();
auto ptrTy = op.getPtr().getType();
auto resultTensorTy = resultTy.dyn_cast<RankedTensorType>();
if (resultTensorTy) {
unsigned elems = getElemsPerThread(resultTy);
Type elemTy =
getTypeConverter()->convertType(resultTensorTy.getElementType());
SmallVector<Type> types(elems, elemTy);
Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types);
auto ptrs = getElementsFromStruct(loc, adaptor.getPtr(), rewriter);
auto offsets = getElementsFromStruct(loc, adaptor.getOffset(), rewriter);
auto ptrs = getTypeConverter()->unpackLLElements(loc, adaptor.getPtr(),
rewriter, ptrTy);
auto offsets = getTypeConverter()->unpackLLElements(
loc, adaptor.getOffset(), rewriter, offsetTy);
SmallVector<Value> resultVals(elems);
for (unsigned i = 0; i < elems; ++i) {
resultVals[i] = gep(elemTy, ptrs[i], offsets[i]);
}
Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
Value view = getTypeConverter()->packLLElements(loc, resultVals, rewriter,
resultTy);
rewriter.replaceOp(op, view);
} else {
assert(resultTy.isa<triton::PointerType>());
@@ -431,12 +494,12 @@ struct AllocTensorOpConversion
};
struct ExtractSliceOpConversion
: public ConvertTritonGPUOpToLLVMPattern<tensor::ExtractSliceOp> {
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ExtractSliceOp> {
using ConvertTritonGPUOpToLLVMPattern<
tensor::ExtractSliceOp>::ConvertTritonGPUOpToLLVMPattern;
triton::gpu::ExtractSliceOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
matchAndRewrite(triton::gpu::ExtractSliceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// %dst = extract_slice %src[%offsets]
Location loc = op->getLoc();
@@ -531,7 +594,7 @@ namespace LLVM {
void vprintf(StringRef msg, ValueRange args,
ConversionPatternRewriter &rewriter) {
PrintfOpConversion::llPrintf(msg, args, rewriter);
PrintOpConversion::llPrintf(msg, args, rewriter);
}
void vprintf_array(Value thread, ArrayRef<Value> arr, std::string info,
@@ -550,7 +613,7 @@ void vprintf_array(Value thread, ArrayRef<Value> arr, std::string info,
} // namespace mlir
void populateTritonGPUToLLVMPatterns(
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
@@ -568,5 +631,6 @@ void populateTritonGPUToLLVMPatterns(
patterns.add<GetNumProgramsOpConversion>(typeConverter, benefit);
patterns.add<MakeRangeOpConversion>(typeConverter, indexCacheInfo, benefit);
patterns.add<ReturnOpConversion>(typeConverter, benefit);
patterns.add<PrintfOpConversion>(typeConverter, benefit);
patterns.add<PrintOpConversion>(typeConverter, benefit);
patterns.add<AssertOpConversion>(typeConverter, benefit);
}

View File

@@ -7,7 +7,7 @@ using namespace mlir;
using namespace mlir::triton;
void populateTritonGPUToLLVMPatterns(
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,

View File

@@ -7,6 +7,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "triton/Analysis/Allocation.h"
#include "TypeConverter.h"
//
#include "DotOpHelpers.h"
#include "Utility.h"
@@ -18,6 +19,7 @@ using namespace mlir::triton;
using ::mlir::LLVM::SharedMemoryObject;
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;
@@ -141,25 +143,26 @@ protected:
}
};
using IndexCacheKeyT = std::pair<Attribute, SmallVector<int64_t>>;
using IndexCacheKeyT = std::pair<Attribute, RankedTensorType>;
struct CacheKeyDenseMapInfo {
static IndexCacheKeyT getEmptyKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return std::make_pair(
mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer)),
SmallVector<int64_t>{});
RankedTensorType{});
}
static IndexCacheKeyT getTombstoneKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
auto tombstone = llvm::DenseMapInfo<RankedTensorType>::getTombstoneKey();
return std::make_pair(
mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer)),
SmallVector<int64_t>{std::numeric_limits<int64_t>::max()});
tombstone);
}
static unsigned getHashValue(IndexCacheKeyT key) {
return llvm::hash_combine(
mlir::hash_value(key.first),
llvm::hash_combine_range(key.second.begin(), key.second.end()));
auto shape = key.second.getShape();
return llvm::hash_combine(mlir::hash_value(key.first),
mlir::hash_value(key.second));
}
static bool isEqual(IndexCacheKeyT LHS, IndexCacheKeyT RHS) {
return LHS == RHS;
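The same pattern with plain std types, for readers unfamiliar with DenseMap key traits: the cache key pairs a layout with a tensor type, both hashed together, plus two sentinel keys no real entry can take to mark empty and erased slots. Names and the hash-combine constant below are illustrative only.

#include <cstddef>
#include <cstdio>
#include <functional>
#include <utility>

using Key = std::pair<const void *, const void *>; // (layout, tensor type) ids
static const int kTombstone = 0;                   // unique sentinel address

struct KeyInfo {
  static Key empty() { return {nullptr, nullptr}; }         // empty slot marker
  static Key tombstone() { return {&kTombstone, nullptr}; } // erased slot marker
  static std::size_t hash(const Key &k) {
    std::size_t h1 = std::hash<const void *>{}(k.first);
    std::size_t h2 = std::hash<const void *>{}(k.second);
    return h1 ^ (h2 + 0x9e3779b97f4a7c15ULL + (h1 << 6) + (h1 >> 2)); // combine
  }
};

int main() {
  std::printf("empty %zu tombstone %zu\n", KeyInfo::hash(KeyInfo::empty()),
              KeyInfo::hash(KeyInfo::tombstone()));
}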
@@ -178,22 +181,22 @@ public:
OpBuilder::InsertPoint *indexInsertPoint;
};
explicit ConvertTritonGPUOpToLLVMPatternBase(LLVMTypeConverter &typeConverter)
explicit ConvertTritonGPUOpToLLVMPatternBase(
TritonGPUToLLVMTypeConverter &typeConverter)
: converter(&typeConverter) {}
explicit ConvertTritonGPUOpToLLVMPatternBase(LLVMTypeConverter &typeConverter,
const Allocation *allocation,
Value smem)
explicit ConvertTritonGPUOpToLLVMPatternBase(
TritonGPUToLLVMTypeConverter &typeConverter, const Allocation *allocation,
Value smem)
: converter(&typeConverter), allocation(allocation), smem(smem) {}
explicit ConvertTritonGPUOpToLLVMPatternBase(LLVMTypeConverter &typeConverter,
const Allocation *allocation,
Value smem,
IndexCacheInfo indexCacheInfo)
explicit ConvertTritonGPUOpToLLVMPatternBase(
TritonGPUToLLVMTypeConverter &typeConverter, const Allocation *allocation,
Value smem, IndexCacheInfo indexCacheInfo)
: converter(&typeConverter), allocation(allocation), smem(smem),
indexCacheInfo(indexCacheInfo) {}
LLVMTypeConverter *getTypeConverter() const { return converter; }
TritonGPUToLLVMTypeConverter *getTypeConverter() const { return converter; }
static Value
getStructFromSharedMemoryObject(Location loc,
@@ -203,18 +206,20 @@ public:
auto types = smemObj.getTypes();
auto structTy =
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types);
return getStructFromElements(loc, elems, rewriter, structTy);
// pack into struct
Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structTy);
for (const auto &v : llvm::enumerate(elems)) {
assert(v.value() && "can not insert null values");
llvmStruct = insert_val(structTy, llvmStruct, v.value(), v.index());
}
return llvmStruct;
}
Value getThreadId(ConversionPatternRewriter &rewriter, Location loc) const {
auto llvmIndexTy = this->getTypeConverter()->getIndexType();
auto cast = rewriter.create<UnrealizedConversionCastOp>(
loc, TypeRange{llvmIndexTy},
ValueRange{rewriter.create<::mlir::gpu::ThreadIdOp>(
loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x)});
Value threadId = cast.getResult(0);
return threadId;
auto tid = rewriter.create<::mlir::gpu::ThreadIdOp>(
loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x);
return rewriter.create<arith::TruncIOp>(loc, i32_ty, tid);
}
// -----------------------------------------------------------------------
@@ -223,13 +228,12 @@ public:
template <typename T>
Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter,
T value) const {
auto ptrTy = LLVM::LLVMPointerType::get(
this->getTypeConverter()->convertType(rewriter.getI8Type()), 3);
auto bufferId = allocation->getBufferId(value);
assert(bufferId != Allocation::InvalidBufferId && "BufferId not found");
size_t offset = allocation->getOffset(bufferId);
Value offVal = idx_val(offset);
Value offVal = i32_val(offset);
Value base = gep(ptrTy, smem, offVal);
return base;
}
@@ -244,8 +248,8 @@ public:
// This utility computes the pointers for accessing the provided swizzled
// shared memory layout `resSharedLayout`. More specifically, it computes,
// for all indices (row, col) of `srcEncoding` such that idx % inVec = 0,
// the pointer: ptr[(row, col)] = base + (rowOff * strides[ord[1]] + colOff)
// where :
// the pointer: ptr[(row, col)] = base + (rowOff * strides[ord[1]] +
// colOff) where :
// compute phase = (row // perPhase) % maxPhase
// rowOff = row
// colOff = colOffSwizzled + colOffOrdered
@@ -255,8 +259,8 @@ public:
// Note 1:
// -------
// Because swizzling happens at a granularity of outVec, we need to
// decompose the offset into a swizzled factor and a non-swizzled (ordered)
// factor
// decompose the offset into a swizzled factor and a non-swizzled
// (ordered) factor
//
// Note 2:
// -------
@@ -282,7 +286,7 @@ public:
auto inOrder = triton::gpu::getOrder(srcEncoding);
auto outOrder = triton::gpu::getOrder(resSharedLayout);
// tensor indices held by the current thread, as LLVM values
auto srcIndices = emitIndices(loc, rewriter, srcEncoding, srcShape);
auto srcIndices = emitIndices(loc, rewriter, srcEncoding, srcTy);
// return values
DenseMap<unsigned, Value> ret;
// cache for non-immediate offsets
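For the swizzled-offset arithmetic described in the comment, a standalone sketch assuming the usual XOR swizzle at outVec granularity; the decomposition in this file also adds an ordered colOffOrdered term, elided here, and the parameters are made up:

#include <cstdio>

int main() {
  // hypothetical layout parameters
  const unsigned perPhase = 4, maxPhase = 8, outVec = 4;
  for (unsigned row = 0; row < 16; row += perPhase) {
    unsigned phase = (row / perPhase) % maxPhase; // advances every perPhase rows
    for (unsigned col = 0; col < 32; col += outVec) {
      unsigned colOffSwizzled = ((col / outVec) ^ phase) * outVec;
      std::printf("(%2u,%2u)->%2u ", row, col, colOffSwizzled);
    }
    std::printf("\n");
  }
}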
@@ -376,7 +380,8 @@ public:
unsigned maxPhase = dstSharedLayout.getMaxPhase();
unsigned numElems = triton::gpu::getElemsPerThread(srcTy);
assert(numElems == srcIndices.size());
auto inVals = LLVM::getElementsFromStruct(loc, llSrc, rewriter);
auto inVals =
getTypeConverter()->unpackLLElements(loc, llSrc, rewriter, srcTy);
auto wordTy = vec_ty(elemTy, minVec);
auto elemPtrTy = ptr_ty(elemTy);
Value outVecVal = i32_val(outVec);
@@ -435,7 +440,7 @@ public:
} else {
Value remained = linear;
for (auto &&en : llvm::enumerate(shape.drop_back())) {
Value dimSize = idx_val(en.value());
Value dimSize = i32_val(en.value());
multiDim[en.index()] = urem(remained, dimSize);
remained = udiv(remained, dimSize);
}
@@ -454,12 +459,12 @@ public:
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> multiDim, ArrayRef<unsigned> shape) const {
auto rank = multiDim.size();
Value linear = idx_val(0);
Value linear = i32_val(0);
if (rank > 0) {
linear = multiDim.back();
for (auto [dim, dimShape] :
llvm::reverse(llvm::zip(multiDim.drop_back(), shape.drop_back()))) {
Value dimSize = idx_val(dimShape);
Value dimSize = i32_val(dimShape);
linear = add(mul(linear, dimSize), dim);
}
}
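delinearize and linearize invert each other, with dimension 0 fastest-varying. A plain-integer round-trip check (order handling elided):

#include <cassert>
#include <cstddef>
#include <vector>

static unsigned linearize(const std::vector<unsigned> &idx,
                          const std::vector<unsigned> &shape) {
  unsigned linear = 0;
  for (std::size_t k = idx.size(); k-- > 0;)
    linear = linear * shape[k] + idx[k]; // dimension 0 ends up fastest-varying
  return linear;
}

static std::vector<unsigned> delinearize(unsigned linear,
                                         const std::vector<unsigned> &shape) {
  std::vector<unsigned> idx(shape.size());
  for (std::size_t k = 0; k < shape.size(); ++k) {
    idx[k] = linear % shape[k];
    linear /= shape[k];
  }
  return idx;
}

int main() {
  const std::vector<unsigned> shape{4, 8};
  for (unsigned i = 0; i < 32; ++i)
    assert(linearize(delinearize(i, shape), shape) == i); // round trip holds
}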
@@ -469,7 +474,7 @@ public:
Value dot(ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> offsets, ArrayRef<Value> strides) const {
assert(offsets.size() == strides.size());
Value ret = idx_val(0);
Value ret = i32_val(0);
for (auto [offset, stride] : llvm::zip(offsets, strides)) {
ret = add(ret, mul(offset, stride));
}
@@ -499,8 +504,8 @@ public:
SmallVector<Value> emitBaseIndexForLayout(Location loc,
ConversionPatternRewriter &rewriter,
const Attribute &layout,
ArrayRef<int64_t> shape) const {
IndexCacheKeyT key = std::make_pair(layout, llvm::to_vector(shape));
RankedTensorType type) const {
IndexCacheKeyT key = std::make_pair(layout, type);
auto cache = indexCacheInfo.baseIndexCache;
assert(cache && "baseIndexCache is nullptr");
auto insertPt = indexCacheInfo.indexInsertPoint;
@@ -512,12 +517,12 @@ public:
SmallVector<Value> result;
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
result =
emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, type);
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.isVolta())
result = emitBaseIndexForMmaLayoutV1(loc, rewriter, mmaLayout, shape);
result = emitBaseIndexForMmaLayoutV1(loc, rewriter, mmaLayout, type);
if (mmaLayout.isAmpere())
result = emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, shape);
result = emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, type);
} else {
llvm_unreachable("unsupported emitBaseIndexForLayout");
}
@@ -528,14 +533,14 @@ public:
}
SmallVector<SmallVector<unsigned>>
emitOffsetForLayout(const Attribute &layout, ArrayRef<int64_t> shape) const {
emitOffsetForLayout(const Attribute &layout, RankedTensorType type) const {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>())
return emitOffsetForBlockedLayout(blockedLayout, shape);
return emitOffsetForBlockedLayout(blockedLayout, type);
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.isVolta())
return emitOffsetForMmaLayoutV1(mmaLayout, shape);
return emitOffsetForMmaLayoutV1(mmaLayout, type);
if (mmaLayout.isAmpere())
return emitOffsetForMmaLayoutV2(mmaLayout, shape);
return emitOffsetForMmaLayoutV2(mmaLayout, type);
}
llvm_unreachable("unsupported emitOffsetForLayout");
}
@@ -546,8 +551,8 @@ public:
SmallVector<SmallVector<Value>> emitIndices(Location loc,
ConversionPatternRewriter &b,
const Attribute &layout,
ArrayRef<int64_t> shape) const {
IndexCacheKeyT key(layout, llvm::to_vector(shape));
RankedTensorType type) const {
IndexCacheKeyT key(layout, type);
auto cache = indexCacheInfo.indexCache;
assert(cache && "indexCache is nullptr");
auto insertPt = indexCacheInfo.indexInsertPoint;
@@ -558,11 +563,11 @@ public:
restoreInsertionPointIfSet(insertPt, b);
SmallVector<SmallVector<Value>> result;
if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
result = emitIndicesForDistributedLayout(loc, b, blocked, shape);
result = emitIndicesForDistributedLayout(loc, b, blocked, type);
} else if (auto mma = layout.dyn_cast<MmaEncodingAttr>()) {
result = emitIndicesForDistributedLayout(loc, b, mma, shape);
result = emitIndicesForDistributedLayout(loc, b, mma, type);
} else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
result = emitIndicesForSliceLayout(loc, b, slice, shape);
result = emitIndicesForSliceLayout(loc, b, slice, type);
} else {
llvm_unreachable(
"emitIndices for layouts other than blocked & slice not "
@@ -591,16 +596,15 @@ private:
// -----------------------------------------------------------------------
// Get an index-base for each dimension for a \param blocked_layout.
SmallVector<Value>
emitBaseIndexForBlockedLayout(Location loc,
ConversionPatternRewriter &rewriter,
const BlockedEncodingAttr &blocked_layout,
ArrayRef<int64_t> shape) const {
SmallVector<Value> emitBaseIndexForBlockedLayout(
Location loc, ConversionPatternRewriter &rewriter,
const BlockedEncodingAttr &blocked_layout, RankedTensorType type) const {
auto shape = type.getShape();
Value threadId = getThreadId(rewriter, loc);
#ifdef USE_ROCM
Value warpSize = idx_val(64);
Value warpSize = i32_val(64);
#else
Value warpSize = idx_val(32);
Value warpSize = i32_val(32);
#endif
Value laneId = urem(threadId, warpSize);
Value warpId = udiv(threadId, warpSize);
@@ -623,13 +627,13 @@ private:
auto maxWarps =
ceil<unsigned>(shape[k], sizePerThread[k] * threadsPerWarp[k]);
auto maxThreads = ceil<unsigned>(shape[k], sizePerThread[k]);
multiDimWarpId[k] = urem(multiDimWarpId[k], idx_val(maxWarps));
multiDimThreadId[k] = urem(multiDimThreadId[k], idx_val(maxThreads));
multiDimWarpId[k] = urem(multiDimWarpId[k], i32_val(maxWarps));
multiDimThreadId[k] = urem(multiDimThreadId[k], i32_val(maxThreads));
// multiDimBase[k] = (multiDimThreadId[k] +
// multiDimWarpId[k] * threadsPerWarp[k]) *
// sizePerThread[k];
Value threadsPerWarpK = idx_val(threadsPerWarp[k]);
Value sizePerThreadK = idx_val(sizePerThread[k]);
Value threadsPerWarpK = i32_val(threadsPerWarp[k]);
Value sizePerThreadK = i32_val(sizePerThread[k]);
multiDimBase[k] =
mul(sizePerThreadK, add(multiDimThreadId[k],
mul(multiDimWarpId[k], threadsPerWarpK)));
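In one dimension, and ignoring the urem clamps by maxWarps/maxThreads, the emitted base index is sizePerThread * (laneId + warpId * threadsPerWarp). A plain-integer sketch with hypothetical layout parameters:

#include <cstdio>

int main() {
  // hypothetical 1-D blocked layout; warpSize is 32 here (the ROCm branch uses 64)
  const unsigned warpSize = 32, threadsPerWarp = 32, sizePerThread = 2;
  const unsigned tids[] = {0, 1, 33, 100};
  for (unsigned tid : tids) {
    unsigned laneId = tid % warpSize;
    unsigned warpId = tid / warpSize;
    unsigned base = sizePerThread * (laneId + warpId * threadsPerWarp);
    std::printf("tid %3u -> base index %u\n", tid, base);
  }
}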
@@ -639,7 +643,8 @@ private:
SmallVector<SmallVector<unsigned>>
emitOffsetForBlockedLayout(const BlockedEncodingAttr &blockedLayout,
ArrayRef<int64_t> shape) const {
RankedTensorType type) const {
auto shape = type.getShape();
auto sizePerThread = blockedLayout.getSizePerThread();
auto threadsPerWarp = blockedLayout.getThreadsPerWarp();
auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
@@ -668,7 +673,7 @@ private:
threadOffset * sizePerThread[k] + elemOffset);
}
unsigned elemsPerThread = blockedLayout.getElemsPerThread(shape);
unsigned elemsPerThread = triton::gpu::getElemsPerThread(type);
unsigned totalSizePerThread = product<unsigned>(sizePerThread);
SmallVector<SmallVector<unsigned>> reorderedOffset(elemsPerThread);
for (unsigned n = 0; n < elemsPerThread; ++n) {
@@ -696,11 +701,12 @@ private:
SmallVector<Value>
emitBaseIndexForMmaLayoutV1(Location loc, ConversionPatternRewriter &rewriter,
const MmaEncodingAttr &mmaLayout,
ArrayRef<int64_t> shape) const {
RankedTensorType type) const {
auto shape = type.getShape();
auto wpt = mmaLayout.getWarpsPerCTA();
auto fpw = LLVM::DotOpMmaV1ConversionHelper::fpw;
auto [isARow, isBRow, isAVec4, isBVec4, id] =
auto [isARow, isBRow, isAVec4, isBVec4, _] =
mmaLayout.decodeVoltaLayoutStates();
Value thread = getThreadId(rewriter, loc);
@@ -717,11 +723,17 @@ private:
Value _fpw0 = i32_val(fpw[0]);
Value _fpw1 = i32_val(fpw[1]);
LLVM::DotOpMmaV1ConversionHelper::AParam aParam(isARow, isAVec4);
LLVM::DotOpMmaV1ConversionHelper::BParam bParam(isBRow, isBVec4);
// A info
auto aEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaLayout);
auto aRep = aEncoding.getMMAv1Rep();
auto aSpw = aEncoding.getMMAv1ShapePerWarp();
// B info
auto bEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaLayout);
auto bSpw = bEncoding.getMMAv1ShapePerWarp();
auto bRep = bEncoding.getMMAv1Rep();
SmallVector<int, 2> rep({aParam.rep[0], bParam.rep[1]});
SmallVector<int, 2> spw({aParam.spw[0], bParam.spw[1]});
SmallVector<int, 2> rep({aRep[0], bRep[1]});
SmallVector<int, 2> spw({aSpw[0], bSpw[1]});
SmallVector<unsigned, 2> shapePerCTA({spw[0] * wpt[0], spw[1] * wpt[1]});
Value lane = urem(thread, warpSize);
@@ -764,16 +776,28 @@ private:
SmallVector<SmallVector<unsigned>>
emitOffsetForMmaLayoutV1(const MmaEncodingAttr &mmaLayout,
ArrayRef<int64_t> shape) const {
RankedTensorType type) const {
auto shape = type.getShape();
auto [isARow, isBRow, isAVec4, isBVec4, id] =
auto [isARow, isBRow, isAVec4, isBVec4, _] =
mmaLayout.decodeVoltaLayoutStates();
LLVM::DotOpMmaV1ConversionHelper::AParam aParam(isARow, isAVec4);
LLVM::DotOpMmaV1ConversionHelper::BParam bParam(isBRow, isBVec4);
// TODO: seems like the pattern below to get `rep`/`spw` appears quite often
// A info
auto aEncoding =
DotOperandEncodingAttr::get(type.getContext(), 0, mmaLayout);
auto aRep = aEncoding.getMMAv1Rep();
auto aSpw = aEncoding.getMMAv1ShapePerWarp();
// B info
auto bEncoding =
DotOperandEncodingAttr::get(type.getContext(), 1, mmaLayout);
auto bSpw = bEncoding.getMMAv1ShapePerWarp();
auto bRep = bEncoding.getMMAv1Rep();
auto wpt = mmaLayout.getWarpsPerCTA();
auto fpw = LLVM::DotOpMmaV1ConversionHelper::fpw;
SmallVector<int, 2> rep({aParam.rep[0], bParam.rep[1]});
SmallVector<int, 2> spw({aParam.spw[0], bParam.spw[1]});
SmallVector<int, 2> rep({aRep[0], bRep[1]});
SmallVector<int, 2> spw({aSpw[0], bSpw[1]});
SmallVector<unsigned, 2> shapePerCTA({spw[0] * wpt[0], spw[1] * wpt[1]});
SmallVector<unsigned> idxM;
@@ -804,34 +828,36 @@ private:
SmallVector<Value>
emitBaseIndexForMmaLayoutV2(Location loc, ConversionPatternRewriter &rewriter,
const MmaEncodingAttr &mmaLayout,
ArrayRef<int64_t> shape) const {
RankedTensorType type) const {
auto shape = type.getShape();
auto _warpsPerCTA = mmaLayout.getWarpsPerCTA();
assert(_warpsPerCTA.size() == 2);
SmallVector<Value> warpsPerCTA = {idx_val(_warpsPerCTA[0]),
idx_val(_warpsPerCTA[1])};
SmallVector<Value> warpsPerCTA = {i32_val(_warpsPerCTA[0]),
i32_val(_warpsPerCTA[1])};
Value threadId = getThreadId(rewriter, loc);
#ifdef USE_ROCM
Value warpSize = idx_val(64);
Value warpSize = i32_val(64);
#else
Value warpSize = idx_val(32);
Value warpSize = i32_val(32);
#endif
Value laneId = urem(threadId, warpSize);
Value warpId = udiv(threadId, warpSize);
Value warpId0 = urem(urem(warpId, warpsPerCTA[0]), idx_val(shape[0] / 16));
Value warpId0 = urem(urem(warpId, warpsPerCTA[0]), i32_val(shape[0] / 16));
Value warpId1 = urem(urem(udiv(warpId, warpsPerCTA[0]), warpsPerCTA[1]),
idx_val(shape[1] / 8));
Value offWarp0 = mul(warpId0, idx_val(16));
Value offWarp1 = mul(warpId1, idx_val(8));
i32_val(shape[1] / 8));
Value offWarp0 = mul(warpId0, i32_val(16));
Value offWarp1 = mul(warpId1, i32_val(8));
SmallVector<Value> multiDimBase(2);
multiDimBase[0] = add(udiv(laneId, idx_val(4)), offWarp0);
multiDimBase[1] = add(mul(idx_val(2), urem(laneId, idx_val(4))), offWarp1);
multiDimBase[0] = add(udiv(laneId, i32_val(4)), offWarp0);
multiDimBase[1] = add(mul(i32_val(2), urem(laneId, i32_val(4))), offWarp1);
return multiDimBase;
}
SmallVector<SmallVector<unsigned>>
emitOffsetForMmaLayoutV2(const MmaEncodingAttr &mmaLayout,
ArrayRef<int64_t> shape) const {
RankedTensorType type) const {
auto shape = type.getShape();
SmallVector<SmallVector<unsigned>> ret;
for (unsigned i = 0; i < shape[0]; i += getShapePerCTA(mmaLayout)[0]) {
@@ -849,31 +875,34 @@ private:
// [elemsPerThread X rank] index matrix.
SmallVector<SmallVector<Value>> emitIndicesForDistributedLayout(
Location loc, ConversionPatternRewriter &rewriter,
const Attribute &layout, ArrayRef<int64_t> shape) const {
const Attribute &layout, RankedTensorType type) const {
// step 1, delinearize threadId to get the base index
auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, layout, shape);
auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, layout, type);
// step 2, get offset of each element
auto offset = emitOffsetForLayout(layout, shape);
auto offset = emitOffsetForLayout(layout, type);
// step 3, add offset to base, and reorder the sequence of indices to
// guarantee that elems in the same sizePerThread are adjacent in order
auto shape = type.getShape();
unsigned rank = shape.size();
unsigned elemsPerThread = offset.size();
SmallVector<SmallVector<Value>> multiDimIdx(elemsPerThread,
SmallVector<Value>(rank));
for (unsigned n = 0; n < elemsPerThread; ++n)
for (unsigned k = 0; k < rank; ++k)
multiDimIdx[n][k] = add(multiDimBase[k], idx_val(offset[n][k]));
multiDimIdx[n][k] = add(multiDimBase[k], i32_val(offset[n][k]));
return multiDimIdx;
}
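Step 3 in plain integers: every element's index is the thread's base index plus a compile-time offset, one row per element. A minimal sketch with made-up values:

#include <cstdio>
#include <vector>

int main() {
  std::vector<unsigned> base{4, 8}; // this thread's base index (made up)
  std::vector<std::vector<unsigned>> offset{
      {0, 0}, {0, 1}, {1, 0}, {1, 1}}; // per-element compile-time offsets
  for (const auto &off : offset)
    std::printf("index = (%u, %u)\n", base[0] + off[0], base[1] + off[1]);
}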
SmallVector<SmallVector<Value>>
emitIndicesForSliceLayout(Location loc, ConversionPatternRewriter &rewriter,
const SliceEncodingAttr &sliceLayout,
ArrayRef<int64_t> shape) const {
auto parent = sliceLayout.getParent();
RankedTensorType type) const {
auto parentEncoding = sliceLayout.getParent();
unsigned dim = sliceLayout.getDim();
auto parentIndices =
emitIndices(loc, rewriter, parent, sliceLayout.paddedShape(shape));
auto parentShape = sliceLayout.paddedShape(type.getShape());
RankedTensorType parentTy = RankedTensorType::get(
parentShape, type.getElementType(), parentEncoding);
auto parentIndices = emitIndices(loc, rewriter, parentEncoding, parentTy);
unsigned numIndices = parentIndices.size();
SmallVector<SmallVector<Value>> resultIndices;
for (unsigned i = 0; i < numIndices; ++i) {
@@ -885,7 +914,7 @@ private:
}
protected:
LLVMTypeConverter *converter;
TritonGPUToLLVMTypeConverter *converter;
const Allocation *allocation;
Value smem;
IndexCacheInfo indexCacheInfo;
@@ -898,30 +927,29 @@ class ConvertTritonGPUOpToLLVMPattern
public:
using OpAdaptor = typename SourceOp::Adaptor;
explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
explicit ConvertTritonGPUOpToLLVMPattern(
TritonGPUToLLVMTypeConverter &typeConverter, PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
ConvertTritonGPUOpToLLVMPatternBase(typeConverter) {}
explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
const Allocation *allocation,
Value smem,
PatternBenefit benefit = 1)
explicit ConvertTritonGPUOpToLLVMPattern(
TritonGPUToLLVMTypeConverter &typeConverter, const Allocation *allocation,
Value smem, PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation, smem) {}
explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
const Allocation *allocation,
Value smem,
IndexCacheInfo indexCacheInfo,
PatternBenefit benefit = 1)
explicit ConvertTritonGPUOpToLLVMPattern(
TritonGPUToLLVMTypeConverter &typeConverter, const Allocation *allocation,
Value smem, IndexCacheInfo indexCacheInfo, PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation, smem,
indexCacheInfo) {}
protected:
LLVMTypeConverter *getTypeConverter() const {
return ((ConvertTritonGPUOpToLLVMPatternBase *)this)->getTypeConverter();
TritonGPUToLLVMTypeConverter *getTypeConverter() const {
LLVMTypeConverter *ret =
((ConvertTritonGPUOpToLLVMPatternBase *)this)->getTypeConverter();
return (TritonGPUToLLVMTypeConverter *)ret;
}
};

View File

@@ -5,8 +5,11 @@
#include "mlir/Conversion/ControlFlowToLLVM//ControlFlowToLLVM.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
@@ -26,39 +29,25 @@
#include "TypeConverter.h"
#include "ViewOpToLLVM.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
using namespace mlir;
using namespace mlir::triton;
#define GEN_PASS_CLASSES
#include "triton/Conversion/Passes.h.inc"
namespace mlir {
class TritonLLVMConversionTarget : public ConversionTarget {
public:
explicit TritonLLVMConversionTarget(MLIRContext &ctx)
: ConversionTarget(ctx) {
addLegalDialect<LLVM::LLVMDialect>();
#ifdef USE_ROCM
addLegalDialect<ROCDL::ROCDLDialect>();
addLegalDialect<mlir::scf::SCFDialect>();
#else
addLegalDialect<NVVM::NVVMDialect>();
#endif
addIllegalDialect<triton::TritonDialect>();
addIllegalDialect<triton::gpu::TritonGPUDialect>();
addIllegalDialect<mlir::gpu::GPUDialect>();
addLegalOp<mlir::UnrealizedConversionCastOp>();
}
};
namespace {
class TritonLLVMFunctionConversionTarget : public ConversionTarget {
public:
explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx)
: ConversionTarget(ctx) {
addLegalDialect<index::IndexDialect>();
addLegalDialect<LLVM::LLVMDialect>();
#ifdef USE_ROCM
addLegalDialect<ROCDL::ROCDLDialect>();
addLegalDialect<mlir::scf::SCFDialect>();
#else
addLegalDialect<NVVM::NVVMDialect>();
#endif
@@ -67,9 +56,40 @@ public:
}
};
} // namespace mlir
class TritonPTXConversionTarget : public ConversionTarget {
public:
explicit TritonPTXConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) {
addDynamicallyLegalDialect<LLVM::LLVMDialect>(
[&](Operation *op) { return isLegalElementwiseOp(op); });
#ifdef USE_ROCM
addLegalDialect<ROCDL::ROCDLDialect>();
#else
addLegalDialect<NVVM::NVVMDialect>();
#endif
addLegalOp<mlir::UnrealizedConversionCastOp>();
}
};
namespace {
struct ReturnOpConversion : public ConvertOpToLLVMPattern<func::ReturnOp> {
using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
unsigned numArguments = op.getNumOperands();
// Currently, Triton kernel functions always return nothing.
// TODO(Superjomn) add support for non-inline device function
if (numArguments > 0) {
return rewriter.notifyMatchFailure(
op, "Only kernel function with nothing returned is supported.");
}
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(),
op->getAttrs());
return success();
}
};
/// FuncOp legalization pattern that converts MemRef arguments to pointers to
/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
@@ -83,8 +103,9 @@ struct FuncOpConversion : public FuncOpConversionBase {
matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
if (!newFuncOp)
if (!newFuncOp) {
return failure();
}
auto ctx = funcOp->getContext();
@@ -105,6 +126,24 @@ private:
int numWarps{0};
};
class TritonLLVMConversionTarget : public ConversionTarget {
public:
explicit TritonLLVMConversionTarget(MLIRContext &ctx)
: ConversionTarget(ctx) {
addLegalDialect<LLVM::LLVMDialect>();
#ifdef USE_ROCM
addLegalDialect<ROCDL::ROCDLDialect>();
addLegalDialect<mlir::scf::SCFDialect>();
#else
addLegalDialect<NVVM::NVVMDialect>();
#endif
addIllegalDialect<triton::TritonDialect>();
addIllegalDialect<triton::gpu::TritonGPUDialect>();
addIllegalDialect<mlir::gpu::GPUDialect>();
addLegalOp<mlir::UnrealizedConversionCastOp>();
}
};
class ConvertTritonGPUToLLVM
: public ConvertTritonGPUToLLVMBase<ConvertTritonGPUToLLVM> {
@@ -115,48 +154,39 @@ public:
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp mod = getOperation();
mlir::LowerToLLVMOptions option(context);
option.overrideIndexBitwidth(32);
TritonGPUToLLVMTypeConverter typeConverter(context, option);
TritonLLVMFunctionConversionTarget funcTarget(*context);
TritonLLVMConversionTarget target(*context);
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
// Step 1: Decompose unoptimized layout conversions to use shared memory
// Step 2: Decompose insert_slice_async to use load + insert_slice for
// pre-Ampere architectures or unsupported vectorized load sizes
// Step 3: Allocate shared memories and insert barriers
// Step 4: Convert FuncOp to LLVMFuncOp via partial conversion
// Step 5: Get axis and shared memory info
// Step 6: Convert the rest of ops via partial conversion
//
// The reason for separating steps 4 and 6 is that step 5 is out of the
// scope of Dialect Conversion, so we need to make sure the smem is not
// revised during the conversion of step 6.
// Step 1
/* preprocess */
decomposeMmaToDotOperand(mod, numWarps);
decomposeBlockedToDotOperand(mod);
// Step 2
if (failed(decomposeInsertSliceAsyncOp(mod)))
return signalPassFailure();
// Step 3
/* allocate shared memory and set barrier */
Allocation allocation(mod);
MembarAnalysis membarPass(&allocation);
membarPass.run();
// Step 4
RewritePatternSet funcPatterns(context);
funcPatterns.add<FuncOpConversion>(typeConverter, numWarps, /*benefit=*/1);
if (failed(
applyPartialConversion(mod, funcTarget, std::move(funcPatterns))))
return signalPassFailure();
/* lower functions */
{
mlir::LowerToLLVMOptions option(context);
TritonGPUToLLVMTypeConverter typeConverter(context, option);
TritonLLVMFunctionConversionTarget funcTarget(*context);
RewritePatternSet funcPatterns(context);
funcPatterns.add<FuncOpConversion>(typeConverter, numWarps,
/*benefit=*/1);
funcPatterns.add<ReturnOpConversion>(typeConverter);
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
funcPatterns);
if (failed(
applyPartialConversion(mod, funcTarget, std::move(funcPatterns))))
return signalPassFailure();
}
// Step 5 - get axis and shared memory info
std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
AxisInfoAnalysis *axisInfoAnalysis = solver->load<AxisInfoAnalysis>();
if (failed(solver->initializeAndRun(mod)))
@@ -166,64 +196,58 @@ public:
mlir::IntegerAttr::get(mlir::IntegerType::get(context, 32),
allocation.getSharedMemorySize()));
// Step 6 - rewrite rest of ops
// We set a higher benefit here to ensure triton's patterns runs before
// arith patterns for some encoding not supported by the community
// patterns.
/* rewrite ops */
RewritePatternSet patterns(context);
// TritonGPU lowering patterns
OpBuilder::InsertPoint indexInsertPoint;
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo indexCacheInfo{
&baseIndexCache, &indexCache, &indexInsertPoint};
RewritePatternSet patterns(context);
// Normal conversions
populateTritonGPUToLLVMPatterns(typeConverter, patterns, numWarps,
*axisInfoAnalysis, &allocation, smem,
indexCacheInfo, /*benefit=*/10);
// ConvertLayoutOp
populateConvertLayoutOpToLLVMPatterns(typeConverter, patterns, numWarps,
*axisInfoAnalysis, &allocation, smem,
indexCacheInfo, /*benefit=*/10);
// DotOp
populateDotOpToLLVMPatterns(typeConverter, patterns, numWarps,
*axisInfoAnalysis, &allocation, smem,
/*benefit=*/10);
// ElementwiseOp
populateElementwiseOpToLLVMPatterns(typeConverter, patterns, numWarps,
*axisInfoAnalysis, &allocation, smem,
/*benefit=*/10);
// LoadStoreOp
populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, numWarps,
*axisInfoAnalysis, &allocation, smem,
indexCacheInfo, /*benefit=*/10);
// ReduceOp
populateReduceOpToLLVMPatterns(typeConverter, patterns, numWarps,
*axisInfoAnalysis, &allocation, smem,
indexCacheInfo, /*benefit=*/10);
// ViewOp
populateViewOpToLLVMPatterns(typeConverter, patterns, numWarps,
*axisInfoAnalysis, &allocation, smem,
/*benefit=*/10);
// Add arith/math's patterns to help convert scalar expressions to LLVM.
mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns);
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
patterns);
auto populatePatterns1 = [&](auto populateFunc) {
populateFunc(typeConverter, patterns, numWarps, *axisInfoAnalysis,
&allocation, smem, indexCacheInfo, /*benefit*/ 1);
};
auto populatePatterns2 = [&](auto populateFunc) {
populateFunc(typeConverter, patterns, numWarps, *axisInfoAnalysis,
&allocation, smem, /*benefit*/ 1);
};
populatePatterns1(populateTritonGPUToLLVMPatterns);
populatePatterns1(populateConvertLayoutOpToLLVMPatterns);
populatePatterns2(populateDotOpToLLVMPatterns);
populatePatterns2(populateElementwiseOpToLLVMPatterns);
populatePatterns1(populateLoadStoreOpToLLVMPatterns);
populatePatterns1(populateReduceOpToLLVMPatterns);
populatePatterns2(populateViewOpToLLVMPatterns);
// Native lowering patterns
#ifdef USE_ROCM
mlir::populateGpuToROCDLConversionPatterns(typeConverter, patterns, mlir::gpu::amd::HIP);
#else
mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns);
#endif
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
patterns);
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
return signalPassFailure();
#ifndef USE_ROCM
// Use our custom converters to convert some operations to PTX to avoid
// using NVPTX for two reasons:
// 1. NVPTX backend is flaky on data types like float16 and bfloat16
// 2. In some cases, we may generate faster PTX code than NVPTX backend
TritonPTXConversionTarget ptxTarget(*context);
RewritePatternSet ptxPatterns(context);
// Add patterns to convert LLVM to PTX
populateElementwiseOpToPTXPatterns(typeConverter, ptxPatterns,
/*benefits=*/10);
if (failed(applyPartialConversion(mod, ptxTarget, std::move(ptxPatterns))))
return signalPassFailure();
#endif
}
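  // A minimal sketch (hypothetical `FooOp`/`FooOpConversion`) of how the
  // populate* helpers used above typically register a pattern with an
  // explicit benefit, so Triton's lowerings outrank the community
  // arith/math patterns when both match:
  //
  //   void populateFooOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &tc,
  //                                    RewritePatternSet &patterns,
  //                                    PatternBenefit benefit) {
  //     patterns.add<FooOpConversion>(tc, benefit);
  //   }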
private:
Value smem;
using IndexCacheKeyT = std::pair<Attribute, SmallVector<int64_t>>;
using IndexCacheKeyT = std::pair<Attribute, RankedTensorType>;
DenseMap<IndexCacheKeyT, SmallVector<Value>, CacheKeyDenseMapInfo>
baseIndexCache;
DenseMap<IndexCacheKeyT, SmallVector<SmallVector<Value>>,

View File

@@ -0,0 +1,163 @@
#include "TypeConverter.h"
#include "DotOpHelpers.h"
#include "Utility.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Conversion/MLIRTypes.h"
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::DotOpFMAConversionHelper;
using ::mlir::LLVM::DotOpMmaV1ConversionHelper;
using ::mlir::LLVM::MMA16816ConversionHelper;
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;
TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
MLIRContext *ctx, LowerToLLVMOptions &option,
const DataLayoutAnalysis *analysis)
: LLVMTypeConverter(ctx, option, analysis) {
addConversion([&](triton::PointerType type) -> llvm::Optional<Type> {
return convertTritonPointerType(type);
});
addConversion([&](RankedTensorType type) -> llvm::Optional<Type> {
return convertTritonTensorType(type);
});
// Internally store float8 as int8
addConversion([&](mlir::Float8E4M3FNType type) -> llvm::Optional<Type> {
return IntegerType::get(type.getContext(), 8);
});
addConversion([&](mlir::Float8E5M2Type type) -> llvm::Optional<Type> {
return IntegerType::get(type.getContext(), 8);
});
// Internally store bfloat16 as int16
addConversion([&](BFloat16Type type) -> llvm::Optional<Type> {
return IntegerType::get(type.getContext(), 16);
});
}
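// For illustration (a sketch, not part of the patch): after these
// registrations, convertType(Float8E5M2Type) yields i8, and converting a
// tensor<32xbf16, #blocked> yields an LLVM struct whose fields are i16,
// one per element held by each thread.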
Type TritonGPUToLLVMTypeConverter::convertTritonPointerType(
triton::PointerType type) {
// Recursively translate pointee type
return LLVM::LLVMPointerType::get(convertType(type.getPointeeType()),
type.getAddressSpace());
}
Value TritonGPUToLLVMTypeConverter::packLLElements(
Location loc, ValueRange resultVals, ConversionPatternRewriter &rewriter,
Type type) {
auto structType = this->convertType(type);
if (!structType.isa<LLVM::LLVMStructType>()) {
return *resultVals.begin();
}
Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType);
for (const auto &v : llvm::enumerate(resultVals)) {
    assert(v.value() && "cannot insert null values");
llvmStruct = insert_val(structType, llvmStruct, v.value(), v.index());
}
return llvmStruct;
}
SmallVector<Value> TritonGPUToLLVMTypeConverter::unpackLLElements(
Location loc, Value llvmStruct, ConversionPatternRewriter &rewriter,
Type type) {
  assert(bool(llvmStruct) && "cannot unpack null values");
if (llvmStruct.getType().isIntOrIndexOrFloat() ||
llvmStruct.getType().isa<triton::PointerType>() ||
llvmStruct.getType().isa<LLVM::LLVMPointerType>())
return {llvmStruct};
ArrayRef<Type> types =
llvmStruct.getType().cast<LLVM::LLVMStructType>().getBody();
SmallVector<Value> results(types.size());
for (unsigned i = 0; i < types.size(); ++i) {
Type type = types[i];
results[i] = extract_val(type, llvmStruct, i);
}
return results;
}
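// Sketch of the intended pack/unpack round trip inside a conversion
// pattern (`op`, `adaptor`, `srcTy`, and `resTy` are assumed to come from
// a matchAndRewrite):
//
//   auto vals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
//                                                    rewriter, srcTy);
//   for (Value &v : vals)
//     v = rewriter.create<LLVM::FNegOp>(loc, v.getType(), v);
//   Value res = getTypeConverter()->packLLElements(loc, vals, rewriter,
//                                                  resTy);
//   rewriter.replaceOp(op, res);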
llvm::Optional<Type>
TritonGPUToLLVMTypeConverter::convertTritonTensorType(RankedTensorType type) {
auto ctx = type.getContext();
Attribute layout = type.getEncoding();
SmallVector<int64_t> shape(type.getShape().begin(), type.getShape().end());
if (layout &&
(layout.isa<BlockedEncodingAttr>() || layout.isa<SliceEncodingAttr>() ||
layout.isa<MmaEncodingAttr>())) {
unsigned numElementsPerThread = getElemsPerThread(type);
SmallVector<Type, 4> types(numElementsPerThread,
convertType(type.getElementType()));
return LLVM::LLVMStructType::getLiteral(ctx, types);
} else if (auto shared_layout =
layout.dyn_cast_or_null<SharedEncodingAttr>()) {
SmallVector<Type, 4> types;
// base ptr
auto ptrType =
LLVM::LLVMPointerType::get(convertType(type.getElementType()), 3);
types.push_back(ptrType);
// shape dims
auto rank = type.getRank();
// offsets + strides
for (auto i = 0; i < rank * 2; i++) {
types.push_back(IntegerType::get(ctx, 32));
}
return LLVM::LLVMStructType::getLiteral(ctx, types);
} else if (auto dotOpLayout =
layout.dyn_cast_or_null<DotOperandEncodingAttr>()) {
if (dotOpLayout.getParent()
.isa<BlockedEncodingAttr>()) { // for parent is blocked layout
int numElemsPerThread =
DotOpFMAConversionHelper::getNumElemsPerThread(shape, dotOpLayout);
return LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(numElemsPerThread, type::f32Ty(ctx)));
} else { // for parent is MMA layout
auto mmaLayout = dotOpLayout.getParent().cast<MmaEncodingAttr>();
auto wpt = mmaLayout.getWarpsPerCTA();
Type elemTy = convertType(type.getElementType());
if (mmaLayout.isAmpere()) {
const llvm::DenseMap<int, Type> targetTyMap = {
{32, vec_ty(elemTy, 1)},
{16, vec_ty(elemTy, 2)},
{8, vec_ty(elemTy, 4)},
};
Type targetTy;
if (targetTyMap.count(elemTy.getIntOrFloatBitWidth())) {
targetTy = targetTyMap.lookup(elemTy.getIntOrFloatBitWidth());
// <2xi16>/<4xi8> => i32
// We are doing this because NVPTX inserts extra integer instrs to
// pack & unpack vectors of sub-word integers
// Note: this needs to be synced with
// DotOpMmaV2ConversionHelper::loadX4
if (elemTy.isa<IntegerType>() &&
(elemTy.getIntOrFloatBitWidth() == 8 ||
elemTy.getIntOrFloatBitWidth() == 16))
targetTy = IntegerType::get(ctx, 32);
} else {
assert(false && "Unsupported element type");
}
auto elems = getElemsPerThread(type);
return struct_ty(SmallVector<Type>(elems, targetTy));
}
if (mmaLayout.isVolta()) {
int elems = getElemsPerThread(type);
Type x2Ty = vec_ty(elemTy, 2);
return struct_ty(SmallVector<Type>(elems, x2Ty));
}
}
llvm::errs() << "Unexpected dot operand layout detected in "
"TritonToLLVMTypeConverter";
return std::nullopt;
}
return std::nullopt;
}
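// Worked examples of the mapping above (a sketch; the encodings are
// illustrative): a tensor<128x128xf16, #blocked> holding 8 elements per
// thread converts to !llvm.struct<(f16, f16, ..., f16)> with 8 fields;
// a rank-2 tensor with a #shared encoding converts to
// !llvm.struct<(ptr<f16, 3>, i32, i32, i32, i32)>, i.e. the base pointer
// plus two offsets and two strides.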

View File

@@ -1,154 +1,30 @@
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_TYPECONVERTER_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_TYPECONVERTER_H
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Conversion/MLIRTypes.h"
#include "DotOpHelpers.h"
#include "Utility.h"
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::DotOpFMAConversionHelper;
using ::mlir::LLVM::DotOpMmaV1ConversionHelper;
using ::mlir::LLVM::MMA16816ConversionHelper;
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;
class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter {
public:
using TypeConverter::convertType;
TritonGPUToLLVMTypeConverter(MLIRContext *ctx, LowerToLLVMOptions &option,
const DataLayoutAnalysis *analysis = nullptr)
: LLVMTypeConverter(ctx, option, analysis) {
addConversion([&](triton::PointerType type) -> llvm::Optional<Type> {
return convertTritonPointerType(type);
});
addConversion([&](RankedTensorType type) -> llvm::Optional<Type> {
return convertTritonTensorType(type);
});
// Internally store float8 as int8
addConversion([&](triton::Float8Type type) -> llvm::Optional<Type> {
return IntegerType::get(type.getContext(), 8);
});
// Internally store bfloat16 as int16
addConversion([&](BFloat16Type type) -> llvm::Optional<Type> {
return IntegerType::get(type.getContext(), 16);
});
}
const DataLayoutAnalysis *analysis = nullptr);
Type convertTritonPointerType(triton::PointerType type) {
// Recursively translate pointee type
return LLVM::LLVMPointerType::get(convertType(type.getPointeeType()),
type.getAddressSpace());
}
Type convertTritonPointerType(triton::PointerType type);
llvm::Optional<Type> convertTritonTensorType(RankedTensorType type) {
auto ctx = type.getContext();
Attribute layout = type.getEncoding();
SmallVector<int64_t> shape(type.getShape().begin(), type.getShape().end());
Value packLLElements(Location loc, ValueRange resultVals,
ConversionPatternRewriter &rewriter, Type type);
if (layout &&
(layout.isa<BlockedEncodingAttr>() || layout.isa<SliceEncodingAttr>() ||
layout.isa<MmaEncodingAttr>())) {
unsigned numElementsPerThread = getElemsPerThread(type);
SmallVector<Type, 4> types(numElementsPerThread,
convertType(type.getElementType()));
return LLVM::LLVMStructType::getLiteral(ctx, types);
} else if (auto shared_layout =
layout.dyn_cast_or_null<SharedEncodingAttr>()) {
SmallVector<Type, 4> types;
// base ptr
auto ptrType =
LLVM::LLVMPointerType::get(convertType(type.getElementType()), 3);
types.push_back(ptrType);
// shape dims
auto rank = type.getRank();
// offsets + strides
for (auto i = 0; i < rank * 2; i++) {
types.push_back(IntegerType::get(ctx, 32));
}
return LLVM::LLVMStructType::getLiteral(ctx, types);
} else if (auto dotOpLayout =
layout.dyn_cast_or_null<DotOperandEncodingAttr>()) {
if (dotOpLayout.getParent()
.isa<BlockedEncodingAttr>()) { // for parent is blocked layout
int numElemsPerThread =
DotOpFMAConversionHelper::getNumElemsPerThread(shape, dotOpLayout);
SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
ConversionPatternRewriter &rewriter,
Type type);
return LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(numElemsPerThread, type::f32Ty(ctx)));
} else { // for parent is MMA layout
auto mmaLayout = dotOpLayout.getParent().cast<MmaEncodingAttr>();
auto wpt = mmaLayout.getWarpsPerCTA();
Type elemTy = convertType(type.getElementType());
if (mmaLayout.isAmpere()) {
const llvm::DenseMap<int, Type> targetTyMap = {
{32, vec_ty(elemTy, 1)},
{16, vec_ty(elemTy, 2)},
{8, vec_ty(elemTy, 4)},
};
Type targetTy;
if (targetTyMap.count(elemTy.getIntOrFloatBitWidth())) {
targetTy = targetTyMap.lookup(elemTy.getIntOrFloatBitWidth());
// <2xi16>/<4xi8> => i32
// We are doing this because NVPTX inserts extra integer instrs to
// pack & unpack vectors of sub-word integers
// Note: this needs to be synced with
// DotOpMmaV2ConversionHelper::loadX4
if (elemTy.isa<IntegerType>() &&
(elemTy.getIntOrFloatBitWidth() == 8 ||
elemTy.getIntOrFloatBitWidth() == 16))
targetTy = IntegerType::get(ctx, 32);
} else {
assert(false && "Unsupported element type");
}
if (dotOpLayout.getOpIdx() == 0) { // $a
auto elems =
MMA16816ConversionHelper::getANumElemsPerThread(type, wpt[0]);
return struct_ty(SmallVector<Type>(elems, targetTy));
}
if (dotOpLayout.getOpIdx() == 1) { // $b
auto elems =
MMA16816ConversionHelper::getBNumElemsPerThread(type, wpt[1]);
return struct_ty(SmallVector<Type>(elems, targetTy));
}
}
if (mmaLayout.isVolta()) {
auto [isARow, isBRow, isAVec4, isBVec4, mmaId] =
mmaLayout.decodeVoltaLayoutStates();
DotOpMmaV1ConversionHelper helper(mmaLayout);
if (dotOpLayout.getOpIdx() == 0) { // $a
DotOpMmaV1ConversionHelper::AParam param(isARow, isAVec4);
int elems =
helper.numElemsPerThreadA(shape, isARow, isAVec4, param.vec);
Type x2Ty = vec_ty(elemTy, 2);
return struct_ty(SmallVector<Type>(elems, x2Ty));
}
if (dotOpLayout.getOpIdx() == 1) { // $b
DotOpMmaV1ConversionHelper::BParam param(isBRow, isBVec4);
int elems =
helper.numElemsPerThreadB(shape, isBRow, isBVec4, param.vec);
Type x2Ty = vec_ty(elemTy, 2);
return struct_ty(SmallVector<Type>(elems, x2Ty));
}
}
}
llvm::errs() << "Unexpected dot operand layout detected in "
"TritonToLLVMTypeConverter";
return std::nullopt;
}
return std::nullopt;
}
llvm::Optional<Type> convertTritonTensorType(RankedTensorType type);
};
#endif

View File

@@ -1,41 +1,11 @@
#include "Utility.h"
#include "TypeConverter.h"
namespace mlir {
namespace LLVM {
using namespace mlir::triton;
Value getStructFromElements(Location loc, ValueRange resultVals,
ConversionPatternRewriter &rewriter,
Type structType) {
if (!structType.isa<LLVM::LLVMStructType>()) {
return *resultVals.begin();
}
Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType);
for (const auto &v : llvm::enumerate(resultVals)) {
assert(v.value() && "can not insert null values");
llvmStruct = insert_val(structType, llvmStruct, v.value(), v.index());
}
return llvmStruct;
}
SmallVector<Value> getElementsFromStruct(Location loc, Value llvmStruct,
ConversionPatternRewriter &rewriter) {
if (llvmStruct.getType().isIntOrIndexOrFloat() ||
llvmStruct.getType().isa<triton::PointerType>() ||
llvmStruct.getType().isa<LLVM::LLVMPointerType>())
return {llvmStruct};
ArrayRef<Type> types =
llvmStruct.getType().cast<LLVM::LLVMStructType>().getBody();
SmallVector<Value> results(types.size());
for (unsigned i = 0; i < types.size(); ++i) {
Type type = types[i];
results[i] = extract_val(type, llvmStruct, i);
}
return results;
}
Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v) {
auto i32ty = rewriter.getIntegerType(32);
return rewriter.create<LLVM::ConstantOp>(loc, i32ty,
@@ -73,7 +43,14 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
SharedMemoryObject
getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct,
ConversionPatternRewriter &rewriter) {
auto elems = getElementsFromStruct(loc, llvmStruct, rewriter);
ArrayRef<Type> types =
llvmStruct.getType().cast<LLVM::LLVMStructType>().getBody();
SmallVector<Value> elems(types.size());
for (unsigned i = 0; i < types.size(); ++i) {
Type type = types[i];
elems[i] = extract_val(type, llvmStruct, i);
}
auto rank = (elems.size() - 1) / 2;
return {/*base=*/elems[0],
/*strides=*/{elems.begin() + 1, elems.begin() + 1 + rank},
@@ -175,5 +152,39 @@ Value shflSync(Location loc, ConversionPatternRewriter &rewriter, Value val,
return builder.launch(rewriter, loc, val.getType(), false);
}
Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter,
StringRef key, StringRef content) {
auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
auto ctx = moduleOp.getContext();
unsigned stringNumber = 0;
SmallString<16> stringConstName;
do {
stringConstName.clear();
(key + Twine(stringNumber++)).toStringRef(stringConstName);
} while (moduleOp.lookupSymbol(stringConstName));
llvm::SmallString<64> contentStr(content);
size_t contentSize = contentStr.size_in_bytes();
auto globalType = LLVM::LLVMArrayType::get(i8_ty, contentSize);
LLVM::GlobalOp global;
{
ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());
global = rewriter.create<LLVM::GlobalOp>(
UnknownLoc::get(ctx), globalType,
/*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
rewriter.getStringAttr(contentStr));
}
Value zero = i32_val(0);
Value globalPtr =
rewriter.create<LLVM::AddressOfOp>(UnknownLoc::get(ctx), global);
Value stringStart =
rewriter.create<LLVM::GEPOp>(UnknownLoc::get(ctx), ptr_ty(i8_ty),
globalPtr, SmallVector<Value>({zero, zero}));
return stringStart;
}
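// Example use (a sketch): lowering a print-style op that needs a
// device-side format string.
//
//   Value fmt = LLVM::addStringToModule(loc, rewriter, "printfFormat_",
//                                       "%d\n");
//
// Repeated calls with the same key produce uniquely named internal
// globals ("printfFormat_0", "printfFormat_1", ...) and return an i8*
// to the first character.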
} // namespace LLVM
} // namespace mlir

View File

@@ -14,6 +14,8 @@
#define ptrtoint(...) rewriter.create<LLVM::PtrToIntOp>(loc, __VA_ARGS__)
#define zext(...) rewriter.create<LLVM::ZExtOp>(loc, __VA_ARGS__)
#define trunc(...) rewriter.create<LLVM::TruncOp>(loc, __VA_ARGS__)
#define sext(...) rewriter.create<LLVM::SExtOp>(loc, __VA_ARGS__)
#define fpext(...) rewriter.create<LLVM::FPExtOp>(loc, __VA_ARGS__)
#define udiv(...) rewriter.create<LLVM::UDivOp>(loc, __VA_ARGS__)
#define urem(...) rewriter.create<LLVM::URemOp>(loc, __VA_ARGS__)
#define add(...) rewriter.create<LLVM::AddOp>(loc, __VA_ARGS__)
@@ -30,9 +32,11 @@
#define smin(...) rewriter.create<LLVM::SMinOp>(loc, __VA_ARGS__)
#define umin(...) rewriter.create<LLVM::UMinOp>(loc, __VA_ARGS__)
#define fmin(...) rewriter.create<LLVM::MinNumOp>(loc, __VA_ARGS__)
#define shl(...) rewriter.create<LLVM::ShlOp>(loc, __VA_ARGS__)
#define and_(...) rewriter.create<LLVM::AndOp>(loc, __VA_ARGS__)
#define or_(...) rewriter.create<LLVM::OrOp>(loc, __VA_ARGS__)
#define xor_(...) rewriter.create<LLVM::XOrOp>(loc, __VA_ARGS__)
#define or_(...) rewriter.create<LLVM::OrOp>(loc, __VA_ARGS__)
#define bitcast(val__, type__) \
rewriter.create<LLVM::BitcastOp>(loc, type__, val__)
#define addrspacecast(val__, type__) \
@@ -53,6 +57,9 @@
#define fcmp_olt(lhs, rhs) \
rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(), \
LLVM::FCmpPredicate::olt, lhs, rhs)
#define fcmp_eq(lhs, rhs) \
rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(), \
LLVM::FCmpPredicate::oeq, lhs, rhs)
#define icmp_eq(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__)
#define icmp_ne(...) \
@@ -77,30 +84,31 @@
#define address_of(...) rewriter.create<LLVM::AddressOfOp>(loc, __VA_ARGS__)
#define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
#define null(...) rewriter.create<LLVM::NullOp>(loc, __VA_ARGS__)
#define call(...) rewriter.create<LLVM::CallOp>(loc, __VA_ARGS__)
// Types
#define i64_ty rewriter.getIntegerType(64)
#define i32_ty rewriter.getIntegerType(32)
#define i16_ty rewriter.getIntegerType(16)
#define ui32_ty rewriter.getIntegerType(32, false)
#define f16_ty rewriter.getF16Type()
#define bf16_ty rewriter.getBF16Type()
#define i8_ty rewriter.getIntegerType(8)
#define i1_ty rewriter.getI1Type()
#define f32_ty rewriter.getF32Type()
#define f64_ty rewriter.getF64Type()
#define vec_ty(type, num) VectorType::get(num, type)
#define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__)
#define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__)
#define void_ty(ctx) LLVM::LLVMVoidType::get(ctx)
#define struct_ty(...) LLVM::LLVMStructType::getLiteral(ctx, __VA_ARGS__)
#define array_ty(elemTy, count) LLVM::LLVMArrayType::get(elemTy, count)
// Constants
#define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__)
#define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__)
#define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__)
#define int_val(width, val) \
LLVM::createLLVMIntegerConstant(rewriter, loc, width, val)
#define idx_val(...) \
LLVM::createIndexConstant(rewriter, loc, this->getTypeConverter(), \
__VA_ARGS__)
#define tid_val() getThreadId(rewriter, loc)
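// Example of composing these shorthands in a lowering body (assumes
// `rewriter` and `loc` are in scope, as every macro above requires):
//   Value tid = tid_val();                    // thread id as i32
//   Value lane = urem(tid, i32_val(32));      // lane within the warp
//   Value warp = udiv(tid, i32_val(32));      // warp id
//   Value isLeader = icmp_eq(lane, i32_val(0));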
// Attributes
@@ -173,14 +181,6 @@ T getLinearIndex(llvm::ArrayRef<T> multiDimIndex, llvm::ArrayRef<T> shape,
namespace LLVM {
using namespace mlir::triton;
Value getStructFromElements(Location loc, ValueRange resultVals,
ConversionPatternRewriter &rewriter,
Type structType);
SmallVector<Value> getElementsFromStruct(Location loc, Value llvmStruct,
ConversionPatternRewriter &rewriter);
/// Create a 32-bit integer constant.
Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v);
/// Create a 32-bit float constant.
@@ -271,6 +271,9 @@ Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
Value shflSync(Location loc, ConversionPatternRewriter &rewriter, Value val,
int i);
Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter,
StringRef key, StringRef content);
} // namespace LLVM
} // namespace mlir

View File

@@ -7,9 +7,7 @@ using namespace mlir::triton;
using ::mlir::LLVM::DotOpFMAConversionHelper;
using ::mlir::LLVM::DotOpMmaV1ConversionHelper;
using ::mlir::LLVM::DotOpMmaV2ConversionHelper;
using ::mlir::LLVM::getElementsFromStruct;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::getStructFromElements;
using ::mlir::LLVM::MMA16816ConversionHelper;
using ::mlir::triton::gpu::getElemsPerThread;
@@ -25,7 +23,7 @@ struct SplatOpConversion
// @resType: the return type of the Splat-like op.
// @constVal: a LLVM::ConstantOp or other scalar value.
static Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
TypeConverter *typeConverter,
TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
Location loc) {
auto tensorTy = resType.cast<RankedTensorType>();
@@ -33,11 +31,7 @@ struct SplatOpConversion
auto llSrc = bitcast(constVal, srcType);
size_t elemsPerThread = getElemsPerThread(tensorTy);
llvm::SmallVector<Value> elems(elemsPerThread, llSrc);
llvm::SmallVector<Type> elemTypes(elems.size(), srcType);
auto structTy =
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);
return getStructFromElements(loc, elems, rewriter, structTy);
return typeConverter->packLLElements(loc, elems, rewriter, resType);
}
LogicalResult matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor,
@@ -95,7 +89,7 @@ struct ArithConstantSplatOpConversion
struct CatOpConversion : public ConvertTritonGPUOpToLLVMPattern<CatOp> {
using OpAdaptor = typename CatOp::Adaptor;
explicit CatOpConversion(LLVMTypeConverter &typeConverter,
explicit CatOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<CatOp>(typeConverter, benefit) {}
@@ -109,8 +103,10 @@ struct CatOpConversion : public ConvertTritonGPUOpToLLVMPattern<CatOp> {
this->getTypeConverter()->convertType(resultTy.getElementType());
SmallVector<Type> types(elems, elemTy);
// unpack input values
auto lhsVals = getElementsFromStruct(loc, adaptor.getLhs(), rewriter);
auto rhsVals = getElementsFromStruct(loc, adaptor.getRhs(), rewriter);
auto lhsVals = getTypeConverter()->unpackLLElements(
loc, adaptor.getLhs(), rewriter, op.getOperand(0).getType());
auto rhsVals = getTypeConverter()->unpackLLElements(
loc, adaptor.getRhs(), rewriter, op.getOperand(1).getType());
// concatenate (and potentially reorder) values
SmallVector<Value> retVals;
for (Value v : lhsVals)
@@ -118,8 +114,8 @@ struct CatOpConversion : public ConvertTritonGPUOpToLLVMPattern<CatOp> {
for (Value v : rhsVals)
retVals.push_back(v);
// pack and replace
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
Value ret = getStructFromElements(loc, retVals, rewriter, structTy);
Value ret =
getTypeConverter()->packLLElements(loc, retVals, rewriter, resultTy);
rewriter.replaceOp(op, ret);
return success();
}
@@ -128,24 +124,19 @@ struct CatOpConversion : public ConvertTritonGPUOpToLLVMPattern<CatOp> {
template <typename SourceOp>
struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
using OpAdaptor = typename SourceOp::Adaptor;
explicit ViewLikeOpConversion(LLVMTypeConverter &typeConverter,
explicit ViewLikeOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// We cannot directly run `rewriter.replaceOp(op, adaptor.getSrc())`
// due to MLIR's restrictions
Location loc = op->getLoc();
auto resultTy = op.getType().template cast<RankedTensorType>();
unsigned elems = getElemsPerThread(resultTy);
Type elemTy =
this->getTypeConverter()->convertType(resultTy.getElementType());
SmallVector<Type> types(elems, elemTy);
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
auto vals = getElementsFromStruct(loc, adaptor.getSrc(), rewriter);
Value view = getStructFromElements(loc, vals, rewriter, structTy);
auto vals = this->getTypeConverter()->unpackLLElements(
loc, adaptor.getSrc(), rewriter, op.getOperand().getType());
Value view =
this->getTypeConverter()->packLLElements(loc, vals, rewriter, resultTy);
rewriter.replaceOp(op, view);
return success();
}
@@ -174,7 +165,7 @@ struct TransOpConversion
}
};
void populateViewOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
void populateViewOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,

View File

@@ -6,7 +6,7 @@
using namespace mlir;
using namespace mlir::triton;
void populateViewOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
void populateViewOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,

View File

@@ -3,6 +3,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
@@ -22,6 +23,13 @@ using namespace mlir::triton;
namespace {
// pass named attrs (e.g., tt.contiguity) from Triton to TritonGPU
static void addNamedAttrs(Operation *op, DictionaryAttr dictAttrs) {
for (const NamedAttribute attr : dictAttrs.getValue())
if (!op->hasAttr(attr.getName()))
op->setAttr(attr.getName(), attr.getValue());
}
template <class Op> class GenericOpPattern : public OpConversionPattern<Op> {
public:
using OpConversionPattern<Op>::OpConversionPattern;
@@ -30,7 +38,9 @@ public:
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands());
addNamedAttrs(
rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands()),
adaptor.getAttributes());
return success();
}
};
@@ -44,8 +54,10 @@ public:
matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<DstOp>(op, retType, adaptor.getPredicate(),
adaptor.getLhs(), adaptor.getRhs());
addNamedAttrs(
rewriter.replaceOpWithNewOp<DstOp>(op, retType, adaptor.getPredicate(),
adaptor.getLhs(), adaptor.getRhs()),
adaptor.getAttributes());
return success();
}
};
@@ -66,7 +78,9 @@ public:
else
// This is a hack; we just want to add the encoding.
value = value.reshape(retType);
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, retType, value);
addNamedAttrs(
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, retType, value),
adaptor.getAttributes());
return success();
}
};
@@ -138,9 +152,10 @@ public:
matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<triton::gpu::SelectOp>(
op, retType, adaptor.getCondition(), adaptor.getTrueValue(),
adaptor.getFalseValue());
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::gpu::SelectOp>(
op, retType, adaptor.getCondition(),
adaptor.getTrueValue(), adaptor.getFalseValue()),
adaptor.getAttributes());
return success();
}
};
@@ -177,8 +192,9 @@ struct TritonMakeRangePattern
matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<triton::MakeRangeOp>(
op, retType, adaptor.getStart(), adaptor.getEnd());
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::MakeRangeOp>(
op, retType, adaptor.getStart(), adaptor.getEnd()),
adaptor.getAttributes());
return success();
}
};
@@ -221,8 +237,9 @@ struct TritonExpandDimsPattern
// construct new op
auto newSrc = rewriter.create<triton::gpu::ConvertLayoutOp>(
op.getLoc(), newArgType, adaptor.getSrc());
rewriter.replaceOpWithNewOp<triton::ExpandDimsOp>(op, newSrc,
adaptor.getAxis());
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::ExpandDimsOp>(
op, newSrc, adaptor.getAxis()),
adaptor.getAttributes());
return success();
}
};
@@ -280,8 +297,9 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
}
c = rewriter.create<triton::gpu::ConvertLayoutOp>(c.getLoc(), retType, c);
rewriter.replaceOpWithNewOp<triton::DotOp>(op, retType, a, b, c,
adaptor.getAllowTF32());
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::DotOp>(
op, retType, a, b, c, adaptor.getAllowTF32()),
adaptor.getAttributes());
return success();
}
};
@@ -296,8 +314,9 @@ struct TritonCatPattern : public OpConversionPattern<triton::CatOp> {
// For now, this behaves like the generic pattern, but this will evolve
// when we add support for `can_reorder=False`.
Type retType = this->getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<triton::CatOp>(op, retType,
adaptor.getOperands());
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::CatOp>(
op, retType, adaptor.getOperands()),
adaptor.getAttributes());
return success();
}
};
@@ -329,7 +348,8 @@ struct TritonTransPattern : public OpConversionPattern<triton::TransOp> {
src = rewriter.create<triton::gpu::ConvertLayoutOp>(src.getLoc(), srcType,
src);
}
rewriter.replaceOpWithNewOp<triton::TransOp>(op, src);
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::TransOp>(op, src),
adaptor.getAttributes());
return success();
}
};
@@ -340,10 +360,12 @@ struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
LogicalResult
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<triton::LoadOp>(
op, typeConverter->convertType(op.getType()), adaptor.getPtr(),
adaptor.getMask(), adaptor.getOther(), adaptor.getCache(),
adaptor.getEvict(), adaptor.getIsVolatile());
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::LoadOp>(
op, typeConverter->convertType(op.getType()),
adaptor.getPtr(), adaptor.getMask(), adaptor.getOther(),
adaptor.getCache(), adaptor.getEvict(),
adaptor.getIsVolatile()),
adaptor.getAttributes());
return success();
}
};
@@ -354,9 +376,11 @@ struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
LogicalResult
matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<triton::StoreOp>(
op, adaptor.getPtr(), adaptor.getValue(), adaptor.getMask(),
adaptor.getCache(), adaptor.getEvict());
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::StoreOp>(
op, adaptor.getPtr(), adaptor.getValue(),
adaptor.getMask(), adaptor.getCache(),
adaptor.getEvict()),
adaptor.getAttributes());
return success();
}
};
@@ -368,9 +392,10 @@ struct TritonAtomicCASPattern
LogicalResult
matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<triton::AtomicCASOp>(
op, typeConverter->convertType(op.getType()), adaptor.getPtr(),
adaptor.getCmp(), adaptor.getVal());
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::AtomicCASOp>(
op, typeConverter->convertType(op.getType()),
adaptor.getPtr(), adaptor.getCmp(), adaptor.getVal()),
adaptor.getAttributes());
return success();
}
};
@@ -382,9 +407,11 @@ struct TritonAtomicRMWPattern
LogicalResult
matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<triton::AtomicRMWOp>(
op, typeConverter->convertType(op.getType()), adaptor.getAtomicRmwOp(),
adaptor.getPtr(), adaptor.getVal(), adaptor.getMask());
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::AtomicRMWOp>(
op, typeConverter->convertType(op.getType()),
adaptor.getAtomicRmwOp(), adaptor.getPtr(),
adaptor.getVal(), adaptor.getMask()),
adaptor.getAttributes());
return success();
}
};
@@ -396,9 +423,11 @@ struct TritonExtElemwisePattern
LogicalResult
matchAndRewrite(triton::ExtElemwiseOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<triton::ExtElemwiseOp>(
op, typeConverter->convertType(op.getType()), adaptor.getArgs(),
adaptor.getLibname(), adaptor.getLibpath(), adaptor.getSymbol());
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::ExtElemwiseOp>(
op, typeConverter->convertType(op.getType()),
adaptor.getArgs(), adaptor.getLibname(),
adaptor.getLibpath(), adaptor.getSymbol()),
adaptor.getAttributes());
return success();
}
};
@@ -411,7 +440,9 @@ struct TritonGenericPattern : public OpConversionPattern<Op> {
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands());
addNamedAttrs(
rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands()),
adaptor.getAttributes());
return success();
}
};
@@ -432,8 +463,9 @@ struct TritonBroadcastPattern
Type retType = RankedTensorType::get(opType.getShape(),
opType.getElementType(), srcEncoding);
// Type retType = this->getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<triton::BroadcastOp>(op, retType,
adaptor.getOperands());
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::BroadcastOp>(
op, retType, adaptor.getOperands()),
adaptor.getAttributes());
return success();
}
};
@@ -444,20 +476,38 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
LogicalResult
matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<triton::ReduceOp>(
op, adaptor.getRedOp(), adaptor.getOperand(), adaptor.getAxis());
addNamedAttrs(
rewriter.replaceOpWithNewOp<triton::ReduceOp>(
op, adaptor.getRedOp(), adaptor.getOperand(), adaptor.getAxis()),
adaptor.getAttributes());
return success();
}
};
struct TritonPrintfPattern : public OpConversionPattern<triton::PrintfOp> {
using OpConversionPattern<PrintfOp>::OpConversionPattern;
struct TritonPrintPattern : public OpConversionPattern<triton::PrintOp> {
using OpConversionPattern<triton::PrintOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(PrintfOp op, typename PrintfOp::Adaptor adaptor,
matchAndRewrite(triton::PrintOp op, typename triton::PrintOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<triton::PrintfOp>(op, op.getPrefixAttr(),
adaptor.getOperands());
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::PrintOp>(
op, op.getPrefixAttr(), adaptor.getOperands()),
adaptor.getAttributes());
return success();
}
};
struct TritonAssertPattern : public OpConversionPattern<triton::AssertOp> {
using OpConversionPattern<triton::AssertOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(triton::AssertOp op,
typename triton::AssertOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::AssertOp>(
op, adaptor.getCondition(), op.getMessageAttr(),
op.getFileAttr(), op.getFuncAttr(), op.getLineAttr()),
adaptor.getAttributes());
return success();
}
};
@@ -476,8 +526,8 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
TritonGenericPattern<triton::AddPtrOp>, TritonCatPattern,
TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern,
TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
TritonStorePattern, TritonExtElemwisePattern, TritonPrintfPattern,
TritonAtomicRMWPattern>(typeConverter, context);
TritonStorePattern, TritonExtElemwisePattern, TritonPrintPattern,
TritonAssertPattern, TritonAtomicRMWPattern>(typeConverter, context);
}
//
@@ -536,7 +586,9 @@ struct SCFYieldPattern : public OpConversionPattern<scf::YieldOp> {
// rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
// rewriter.create<scf::YieldOp>(op.getLoc(), adaptor.getOperands());
// op.erase();
rewriter.replaceOpWithNewOp<scf::YieldOp>(op, adaptor.getOperands());
addNamedAttrs(
rewriter.replaceOpWithNewOp<scf::YieldOp>(op, adaptor.getOperands()),
adaptor.getAttributes());
return success();
}
};
@@ -641,8 +693,9 @@ public:
LogicalResult
matchAndRewrite(cf::BranchOp op, cf::BranchOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<cf::BranchOp>(op, op.getSuccessor(),
adaptor.getOperands());
addNamedAttrs(rewriter.replaceOpWithNewOp<cf::BranchOp>(
op, op.getSuccessor(), adaptor.getOperands()),
adaptor.getAttributes());
return success();
}
};
@@ -659,6 +712,7 @@ public:
op, adaptor.getCondition(), op.getTrueDest(),
adaptor.getTrueDestOperands(), op.getFalseDest(),
adaptor.getFalseDestOperands());
addNamedAttrs(newOp, adaptor.getAttributes());
if (failed(rewriter.convertRegionTypes(newOp.getTrueDest()->getParent(),
*converter)))

View File

@@ -123,29 +123,6 @@ void StoreOp::print(OpAsmPrinter &printer) {
namespace mlir {
namespace triton {
//-- FpToFpOp --
bool FpToFpOp::areCastCompatible(::mlir::TypeRange inputs,
::mlir::TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
auto srcEltType = inputs.front();
auto dstEltType = outputs.front();
auto srcTensorType = srcEltType.dyn_cast<mlir::RankedTensorType>();
auto dstTensorType = dstEltType.dyn_cast<mlir::RankedTensorType>();
if (srcTensorType && dstTensorType) {
srcEltType = srcTensorType.getElementType();
dstEltType = dstTensorType.getElementType();
}
// Check whether fp8 <=> fp16, bf16, f32, f64
// Make `srcEltType` always the fp8 side
if (dstEltType.dyn_cast<mlir::triton::Float8Type>())
std::swap(srcEltType, dstEltType);
if (!srcEltType.dyn_cast<mlir::triton::Float8Type>())
return false;
return dstEltType.isF16() || dstEltType.isBF16() || dstEltType.isF32() ||
dstEltType.isF64();
}
//-- StoreOp --
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
::mlir::Value ptr, ::mlir::Value value,

View File

@@ -43,9 +43,7 @@ namespace mlir {
unsigned getPointeeBitWidth(RankedTensorType tensorTy) {
auto ptrTy = tensorTy.getElementType().cast<triton::PointerType>();
auto pointeeType = ptrTy.getPointeeType();
return pointeeType.isa<triton::Float8Type>()
? 8
: pointeeType.getIntOrFloatBitWidth();
return pointeeType.getIntOrFloatBitWidth();
}
} // namespace mlir

View File

@@ -46,17 +46,18 @@ namespace gpu {
// so that all distributed layouts implement
// these utilities
unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape) {
unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape,
Type eltTy) {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return blockedLayout.getElemsPerThread(shape);
return blockedLayout.getElemsPerThread(shape, eltTy);
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
return sliceLayout.getElemsPerThread(shape);
return sliceLayout.getElemsPerThread(shape, eltTy);
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
return mmaLayout.getElemsPerThread(shape);
return mmaLayout.getElemsPerThread(shape, eltTy);
} else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>()) {
return sharedLayout.getElemsPerThread(shape);
return sharedLayout.getElemsPerThread(shape, eltTy);
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
return dotLayout.getElemsPerThread(shape);
return dotLayout.getElemsPerThread(shape, eltTy);
} else {
assert(0 && "getElemsPerThread not implemented");
return 0;
@@ -64,11 +65,11 @@ unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape) {
}
unsigned getElemsPerThread(Type type) {
if (type.isIntOrIndexOrFloat() || type.isa<triton::Float8Type>() ||
type.isa<triton::PointerType>())
if (type.isIntOrIndexOrFloat() || type.isa<triton::PointerType>())
return 1;
auto tensorType = type.cast<RankedTensorType>();
return getElemsPerThread(tensorType.getEncoding(), tensorType.getShape());
return getElemsPerThread(tensorType.getEncoding(), tensorType.getShape(),
tensorType.getElementType());
}
SmallVector<unsigned> getThreadsPerWarp(const Attribute &layout) {
@@ -330,7 +331,8 @@ SliceEncodingAttr BlockedEncodingAttr::squeeze(int axis) {
return SliceEncodingAttr::get(getContext(), axis, *this);
}
unsigned BlockedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
unsigned BlockedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
Type eltTy) const {
size_t rank = shape.size();
auto sizePerThread = getSizePerThread();
auto warpsPerCTA = getWarpsPerCTA();
@@ -365,12 +367,14 @@ SliceEncodingAttr::paddedShape<unsigned>(ArrayRef<unsigned> shape) const;
template SmallVector<int64_t>
SliceEncodingAttr::paddedShape<int64_t>(ArrayRef<int64_t> shape) const;
unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
Type eltTy) const {
auto parent = getParent();
return ::getElemsPerThread(parent, paddedShape(shape));
return ::getElemsPerThread(parent, paddedShape(shape), eltTy);
}
unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
Type eltTy) const {
size_t rank = shape.size();
assert(rank == 2 && "Unexpected rank of mma layout");
assert((isVolta() || isAmpere()) && "Only version 1 and 2 is supported");
@@ -401,18 +405,99 @@ unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
return res;
}
unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
// TODO:
assert(0 && "SharedEncodingAttr::getElemsPerThread not implemented");
unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
Type eltTy) const {
llvm_unreachable("Unexpected shared layout");
return 0;
}
unsigned
DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
if (auto blockedLayout = getParent().dyn_cast<BlockedEncodingAttr>()) {
return blockedLayout.getElemsPerThread(shape);
unsigned DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
Type eltTy) const {
if (auto mmaParent = getParent().dyn_cast<MmaEncodingAttr>()) {
int warpsPerCTAM = mmaParent.getWarpsPerCTA()[0];
int warpsPerCTAN = mmaParent.getWarpsPerCTA()[1];
// A100
if (mmaParent.isAmpere()) {
int bitwidth = eltTy.getIntOrFloatBitWidth();
int shapePerWarpM = 16;
int shapePerWarpN = 8;
int shapePerWarpK = 4 * 64 / bitwidth;
int shapePerCTAM = shapePerWarpM * warpsPerCTAM;
int shapePerCTAN = shapePerWarpN * warpsPerCTAN;
if (getOpIdx() == 0) {
int repM = std::max<int>(1, shape[0] / shapePerCTAM);
int repK = std::max<int>(1, shape[1] / shapePerWarpK);
return 4 * repM * repK;
}
if (getOpIdx() == 1) {
int repN = std::max<int>(1, shape[1] / shapePerCTAN);
int repK = std::max<int>(1, shape[0] / shapePerWarpK);
return 4 * std::max(repN / 2, 1) * repK;
}
}
// V100
if (mmaParent.isVolta()) {
bool isRow = getMMAv1IsRow();
bool isVec4 = getMMAv1IsVec4();
if (getOpIdx() == 0) {
int packSizeM = (isRow || isVec4) ? 1 : 2;
int repM = 2 * packSizeM;
int spwM = 2 * 4 * repM;
int numM = getMMAv1NumOuter(shape);
int NK = shape[1];
int vec = 2 * repM;
        // Here we mimic the logic in loadA; the result cannot be calculated
        // directly.
llvm::DenseSet<std::pair<int, int>> visited;
auto ld = [&](int m, int k) {
visited.insert({m, k});
if (vec > 4) {
if (isRow)
visited.insert({m, k + 4});
else
visited.insert({m + 1, k});
}
};
for (unsigned k = 0; k < NK; k += 4)
for (unsigned m = 0; m < numM / 2; ++m)
if (!visited.count({m, k}))
ld(m, k);
return visited.size() * 2;
}
if (getOpIdx() == 1) {
int packSizeN = (isRow && !isVec4) ? 2 : 1;
int repN = 2 * packSizeN;
int spwN = 2 * 4 * repN;
int numN = getMMAv1NumOuter(shape);
int vec = 2 * repN;
int NK = shape[0];
        // Here we mimic the logic in loadB; the result cannot be calculated
        // directly.
llvm::DenseSet<std::pair<int, int>> visited;
int elemsPerLd = vec > 4 ? 4 : 2;
auto ld = [&](int n, int k) {
visited.insert({n, k});
if (vec > 4) {
if (isRow)
visited.insert({n + 1, k});
else
visited.insert({n, k + 4});
}
};
for (unsigned k = 0; k < NK; k += 4)
for (unsigned n = 0; n < numN / 2; ++n) {
if (!visited.count({n, k}))
ld(n, k);
}
return visited.size() * 2;
}
}
}
assert(0 && "DotOperandEncodingAttr::getElemsPerThread not implemented");
llvm_unreachable("unknown mma version");
return 0;
}
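// Worked example for the Ampere path above (a sketch): opIdx == 0,
// eltTy = f16 (bitwidth 16), warpsPerCTA = {4, 1}, shape = [128, 64]:
//   shapePerWarpK = 4 * 64 / 16 = 16
//   shapePerCTAM  = 16 * 4     = 64
//   repM = max(1, 128 / 64) = 2,  repK = max(1, 64 / 16) = 4
//   elements per thread = 4 * repM * repK = 32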
@@ -630,26 +715,69 @@ Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) {
return {};
unsigned opIdx = attrs.get("opIdx").cast<IntegerAttr>().getInt();
Attribute parent = attrs.get("parent");
Attribute isMMAv1Row;
if (parent.isa<MmaEncodingAttr>() &&
parent.cast<MmaEncodingAttr>().isVolta()) {
isMMAv1Row = attrs.get("isMMAv1Row");
if (!isMMAv1Row)
llvm::report_fatal_error("isMMAv1Row attribute is missing");
}
return parser.getChecked<DotOperandEncodingAttr>(parser.getContext(), opIdx,
parent, isMMAv1Row);
parent);
}
void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
printer << "<{"
<< "opIdx = " << getOpIdx() << ", "
<< "parent = " << getParent();
if (getIsMMAv1Row())
printer << ", isMMAv1Row = " << getIsMMAv1Row();
printer << "}>";
}
bool DotOperandEncodingAttr::getMMAv1IsRow() const {
auto [isARow, isBRow, _0, _1, _2] =
getParent().cast<MmaEncodingAttr>().decodeVoltaLayoutStates();
return getOpIdx() == 0 ? isARow : isBRow;
}
bool DotOperandEncodingAttr::getMMAv1IsVec4() const {
auto [_0, _1, isAVec4, isBVec4, _2] =
getParent().cast<MmaEncodingAttr>().decodeVoltaLayoutStates();
return getOpIdx() == 0 ? isAVec4 : isBVec4;
}
SmallVector<int> DotOperandEncodingAttr::getMMAv1Rep() const {
auto [isARow, isBRow, isAVec4, isBVec4, _] =
getParent().cast<MmaEncodingAttr>().decodeVoltaLayoutStates();
// A
if (getOpIdx() == 0) {
int packSize = (isARow || isAVec4) ? 1 : 2;
return {2 * packSize, 0, 1};
}
// B
else {
int packSize = (isBRow && !isBVec4) ? 2 : 1;
return {0, 2 * packSize, 1};
}
}
SmallVector<int> DotOperandEncodingAttr::getMMAv1ShapePerWarp() const {
auto rep = getMMAv1Rep();
if (getOpIdx() == 0) {
return {8 * rep[0], 0, 1};
} else {
return {0, 8 * rep[1], 1};
}
}
int DotOperandEncodingAttr::getMMAv1Vec() const {
size_t opIdx = getOpIdx();
return 2 * getMMAv1Rep()[opIdx];
}
int DotOperandEncodingAttr::getMMAv1NumOuter(ArrayRef<int64_t> shape) const {
auto spw = getMMAv1ShapePerWarp();
auto rep = getMMAv1Rep();
auto warpsPerCTA = getParent().cast<MmaEncodingAttr>().getWarpsPerCTA();
if (getOpIdx() == 0) {
return rep[0] * shape[0] / (spw[0] * warpsPerCTA[0]);
} else {
return rep[1] * shape[1] / (spw[1] * warpsPerCTA[1]);
}
}
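// Worked example tying the helpers above together (a sketch): for
// operand A (opIdx == 0) with isARow = true and isAVec4 = false,
// packSize = 1, so rep = {2, 0, 1}, shapePerWarp = {16, 0, 1}, and
// vec = 4. With shape[0] = 128 and warpsPerCTA[0] = 4,
// getMMAv1NumOuter returns 2 * 128 / (16 * 4) = 4.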
//===----------------------------------------------------------------------===//
// InsertSliceAsyncOp
//===----------------------------------------------------------------------===//
@@ -851,7 +979,7 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
return mlir::success();
}
// cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2))
auto extract_slice = dyn_cast<tensor::ExtractSliceOp>(arg);
auto extract_slice = dyn_cast<triton::gpu::ExtractSliceOp>(arg);
if (extract_slice) {
if (!isSharedEncoding(op->getResult(0))) {
return mlir::failure();
@@ -872,7 +1000,7 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
rewriter.setInsertionPoint(extract_slice);
auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newType, extract_slice.getSource());
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
rewriter.replaceOpWithNewOp<triton::gpu::ExtractSliceOp>(
op, resType, newArg.getResult(), extract_slice.offsets(),
extract_slice.sizes(), extract_slice.strides(),
extract_slice.static_offsets(), extract_slice.static_sizes(),
@@ -925,6 +1053,29 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
//===----------------------------------------------------------------------===//
/// Build an ExtractSliceOp with mixed static and dynamic entries and custom
/// result type. If the type passed is nullptr, it is inferred.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
RankedTensorType resultType, Value source,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides,
ArrayRef<NamedAttribute> attrs) {
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
b.getDenseI64ArrayAttr(staticSizes),
b.getDenseI64ArrayAttr(staticStrides));
result.addAttributes(attrs);
}
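// Sketch of the static/dynamic split performed above: given offsets
// {%i, 0, 0}, dispatchIndexOpFoldResults yields dynamicOffsets = {%i}
// and staticOffsets = {ShapedType::kDynamic, 0, 0}, so mixed entries
// round-trip through the paired dense-i64 attributes. (Assumes the
// post-LLVM-16 ShapedType::kDynamic sentinel.)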
//===----------------------------------------------------------------------===//
void TritonGPUDialect::initialize() {
addAttributes<
#define GET_ATTRDEF_LIST

View File

@@ -8,6 +8,7 @@
using namespace mlir;
namespace {
using triton::DotOp;
using triton::gpu::BlockedEncodingAttr;
using triton::gpu::ConvertLayoutOp;
using triton::gpu::DotOperandEncodingAttr;
using triton::gpu::MmaEncodingAttr;
@@ -43,12 +44,6 @@ SmallVector<int64_t, 2> mmaVersionToShapePerWarp(int version) {
}
}
SmallVector<unsigned, 2> warpsPerTileV1(const ArrayRef<int64_t> shape,
int numWarps) {
// Set a default value that ensures product of wpt equals numWarps
return {static_cast<unsigned>(numWarps), 1};
}
SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
const ArrayRef<int64_t> shape,
int numWarps) {
@@ -92,19 +87,6 @@ public:
: mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context),
computeCapability(computeCapability) {}
static SmallVector<unsigned, 2> getWarpsPerTile(triton::DotOp dotOp,
const ArrayRef<int64_t> shape,
int version, int numWarps) {
switch (version) {
case 1:
return warpsPerTileV1(shape, numWarps);
case 2:
return warpsPerTileV2(dotOp, shape, numWarps);
default:
llvm_unreachable("unsupported MMA version");
}
}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
@@ -125,13 +107,46 @@ public:
auto mod = op->getParentOfType<mlir::ModuleOp>();
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
auto warpsPerTile =
getWarpsPerTile(dotOp, retShape, versionMajor, numWarps);
// operands
Value a = dotOp.getA();
Value b = dotOp.getB();
auto oldAType = a.getType().cast<RankedTensorType>();
auto oldBType = b.getType().cast<RankedTensorType>();
triton::gpu::MmaEncodingAttr mmaEnc;
if (versionMajor == 1) {
SetVector<Operation *> aBwdSlices, bBwdSlices;
auto isCvt = [](Operation *op) { return isa<ConvertLayoutOp>(op); };
getBackwardSlice(a, &aBwdSlices, isCvt);
getBackwardSlice(b, &bBwdSlices, isCvt);
// get the source of the first conversion found in slices
auto getCvtArgOrder = [](Operation *op) {
return cast<ConvertLayoutOp>(op)
.getOperand()
.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<BlockedEncodingAttr>()
.getOrder();
};
bool isARow = true;
bool isBRow = true;
Operation *aOp = a.getDefiningOp();
Operation *bOp = b.getDefiningOp();
if (!aBwdSlices.empty())
aOp = aBwdSlices[0];
if (!bBwdSlices.empty())
bOp = bBwdSlices[0];
if (aOp)
isARow = getCvtArgOrder(aOp)[0] == 1;
if (bOp)
isBRow = getCvtArgOrder(bOp)[0] == 1;
mmaEnc = triton::gpu::MmaEncodingAttr::get(
oldRetType.getContext(), versionMajor, numWarps, mmaV1Counter++);
oldRetType.getContext(), versionMajor, numWarps, oldAType.getShape(),
oldBType.getShape(), retShape, isARow, isBRow, mmaV1Counter++);
} else if (versionMajor == 2) {
auto warpsPerTile = warpsPerTileV2(dotOp, retShape, numWarps);
mmaEnc = triton::gpu::MmaEncodingAttr::get(
oldRetType.getContext(), versionMajor, 0 /*versionMinor*/,
warpsPerTile);
@@ -145,10 +160,6 @@ public:
auto oldAcc = dotOp.getOperand(2);
auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
oldAcc.getLoc(), newRetType, oldAcc);
Value a = dotOp.getA();
Value b = dotOp.getB();
auto oldAType = a.getType().cast<RankedTensorType>();
auto oldBType = b.getType().cast<RankedTensorType>();
auto oldAOrder = oldAType.getEncoding()
.cast<triton::gpu::DotOperandEncodingAttr>()
.getParent()
@@ -159,21 +170,15 @@ public:
.getParent()
.cast<triton::gpu::BlockedEncodingAttr>()
.getOrder();
Attribute isMMAv1RowA;
Attribute isMMAv1RowB;
if (versionMajor == 1) {
isMMAv1RowA = BoolAttr::get(getContext(), oldAOrder[0] == 1);
isMMAv1RowB = BoolAttr::get(getContext(), oldBOrder[0] == 1);
}
auto newAType = RankedTensorType::get(
oldAType.getShape(), oldAType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(
oldAType.getContext(), 0, newRetType.getEncoding(), isMMAv1RowA));
triton::gpu::DotOperandEncodingAttr::get(oldAType.getContext(), 0,
newRetType.getEncoding()));
auto newBType = RankedTensorType::get(
oldBType.getShape(), oldBType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(
oldBType.getContext(), 1, newRetType.getEncoding(), isMMAv1RowB));
triton::gpu::DotOperandEncodingAttr::get(oldBType.getContext(), 1,
newRetType.getEncoding()));
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);

View File

@@ -2,13 +2,12 @@ add_mlir_dialect_library(TritonGPUTransforms
AccelerateMatmul.cpp
Coalesce.cpp
DecomposeConversions.cpp
FuseTranspositions.cpp
OptimizeDotOperands.cpp
Pipeline.cpp
Prefetch.cpp
RemoveLayoutConversions.cpp
ReorderInstructions.cpp
TritonGPUConversion.cpp
UpdateMmaForVolta.cpp
Utility.cpp
DEPENDS

View File

@@ -16,54 +16,6 @@ using triton::gpu::DotOperandEncodingAttr;
using triton::gpu::MmaEncodingAttr;
using triton::gpu::SliceEncodingAttr;
class OptimizeConvertToDotOperand : public mlir::RewritePattern {
public:
explicit OptimizeConvertToDotOperand(mlir::MLIRContext *context)
: RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
context) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
// order
ArrayRef<unsigned> order;
if (auto srcBlockedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
order = srcBlockedLayout.getOrder();
else if (auto srcSharedLayout =
srcType.getEncoding()
.dyn_cast<triton::gpu::SharedEncodingAttr>())
order = srcSharedLayout.getOrder();
else
return failure();
// dot operand output
auto dstDotOperandLayout =
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if (!dstDotOperandLayout)
return failure();
if (!dstDotOperandLayout.getIsMMAv1Row())
return failure();
bool isMMAv1Row =
dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
if ((order[0] == 1 && isMMAv1Row) || (order[0] == 0 && !isMMAv1Row))
return failure();
auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
op->getContext(), dstDotOperandLayout.getOpIdx(),
dstDotOperandLayout.getParent(), newIsRow);
auto newDstType = RankedTensorType::get(
dstType.getShape(), dstType.getElementType(), newDstEncoding);
auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newDstType, cvt.getOperand());
rewriter.replaceOp(op, newCvt.getResult());
return success();
}
};
// convert(trans(convert(arg)))
// x = convert_layout arg: #distributed -> #shared_x
// y = trans x: #shared_x -> #shared_y
@@ -125,10 +77,11 @@ public:
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
class TritonGPUFuseTranspositionsPass
: public TritonGPUFuseTranspositionsBase<TritonGPUFuseTranspositionsPass> {
class TritonGPUOptimizeDotOperandsPass
: public TritonGPUOptimizeDotOperandsBase<
TritonGPUOptimizeDotOperandsPass> {
public:
TritonGPUFuseTranspositionsPass() = default;
TritonGPUOptimizeDotOperandsPass() = default;
void runOnOperation() override {
MLIRContext *context = &getContext();
@@ -139,7 +92,6 @@ public:
auto ret = pm.run(m);
mlir::RewritePatternSet patterns(context);
patterns.add<OptimizeConvertToDotOperand>(context);
patterns.add<ConvertTransConvert>(context);
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
signalPassFailure();
@@ -148,6 +100,6 @@ public:
}
};
std::unique_ptr<Pass> mlir::createTritonGPUFuseTranspositionsPass() {
return std::make_unique<TritonGPUFuseTranspositionsPass>();
std::unique_ptr<Pass> mlir::createTritonGPUOptimizeDotOperandsPass() {
return std::make_unique<TritonGPUOptimizeDotOperandsPass>();
}

View File

@@ -1,6 +1,8 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -173,6 +175,8 @@ LogicalResult LoopPipeliner::initialize() {
if (auto loadOp = dyn_cast<triton::LoadOp>(&op)) {
auto ptr = loadOp.getPtr();
unsigned vec = axisInfoAnalysis->getPtrContiguity(ptr);
if (auto mask = loadOp.getMask())
vec = std::min<unsigned>(vec, axisInfoAnalysis->getMaskAlignment(mask));
auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
if (!tensorTy)
continue;
@@ -394,7 +398,7 @@ void LoopPipeliner::emitPrologue() {
sliceType = RankedTensorType::get({bufferShape[1], bufferShape[2]},
sliceType.getElementType(),
loadsBufferType[loadOp].getEncoding());
Value extractSlice = builder.create<tensor::ExtractSliceOp>(
Value extractSlice = builder.create<triton::gpu::ExtractSliceOp>(
loadOp.getLoc(), sliceType, loadStageBuffer[loadOp][numStages - 1],
SmallVector<OpFoldResult>{int_attr(0), int_attr(0), int_attr(0)},
SmallVector<OpFoldResult>{int_attr(1),
@@ -480,15 +484,25 @@ scf::ForOp LoopPipeliner::createNewForOp() {
// 3. replace loads with block args (from prologue)
for (size_t idx = 0; idx < loads.size(); ++idx) {
OpBuilder::InsertionGuard guard(builder);
Value load = loads[idx];
assert(load.hasOneUse() &&
"we assume that this load has one use (ConvertLayout)");
Value loadUse = load.getUsers().begin()->getResult(0);
mapping.lookup(loadUse).replaceAllUsesWith(
// set insertion point
Value newLoad = mapping.lookup(load);
Value newLoadUse = mapping.lookup(loadUse);
builder.setInsertionPoint(newLoadUse.getDefiningOp());
// create conversion
auto cvt = builder.create<ttg::ConvertLayoutOp>(
loadUse.getLoc(), loadUse.getType(),
newForOp.getRegionIterArgs()[loadIdx + idx]);
// replace uses
newLoadUse.replaceAllUsesWith(cvt.getResult());
// delete old load and layout conversion
mapping.lookup(loadUse).getDefiningOp()->erase();
mapping.lookup(load).getDefiningOp()->erase();
newLoadUse.getDefiningOp()->erase();
newLoad.getDefiningOp()->erase();
}
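// Note the ordering above: the replacement ConvertLayoutOp is materialized
// from the loop's iter args *before* the stale mapped ops are erased, and the
// use (the old convert) is erased before its defining load, so the IR stays
// valid at every step.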
// 4. prefetch the next iteration
@@ -532,8 +546,6 @@ scf::ForOp LoopPipeliner::createNewForOp() {
Value extractSliceIndex = builder.create<arith::RemSIOp>(
nextIV.getLoc(), loopIterIdx,
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));
extractSliceIndex = builder.create<arith::IndexCastOp>(
extractSliceIndex.getLoc(), builder.getIndexType(), extractSliceIndex);
for (Operation *op : orderedDeps)
if (!loads.contains(op->getResult(0))) {
@@ -591,7 +603,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
sliceType.getElementType(),
loadsBufferType[loadOp].getEncoding());
nextOp = builder.create<tensor::ExtractSliceOp>(
nextOp = builder.create<triton::gpu::ExtractSliceOp>(
op->getLoc(), sliceType, insertAsyncOp,
SmallVector<OpFoldResult>{extractSliceIndex, int_attr(0),
int_attr(0)},
@@ -618,35 +630,6 @@ scf::ForOp LoopPipeliner::createNewForOp() {
}
}
{
OpBuilder::InsertionGuard guard(builder);
for (Operation &op : *newForOp.getBody()) {
if (auto dotOp = llvm::dyn_cast<triton::DotOp>(&op)) {
builder.setInsertionPoint(&op);
auto dotType = dotOp.getType().cast<RankedTensorType>();
Value a = dotOp.getA();
Value b = dotOp.getB();
auto layoutCast = [&](Value dotOperand, int opIdx) -> Value {
auto tensorType = dotOperand.getType().cast<RankedTensorType>();
if (!tensorType.getEncoding().isa<ttg::DotOperandEncodingAttr>()) {
auto newEncoding = ttg::DotOperandEncodingAttr::get(
tensorType.getContext(), opIdx, dotType.getEncoding());
auto newType =
RankedTensorType::get(tensorType.getShape(),
tensorType.getElementType(), newEncoding);
return builder.create<ttg::ConvertLayoutOp>(dotOperand.getLoc(),
newType, dotOperand);
}
return dotOperand;
};
a = layoutCast(a, 0);
b = layoutCast(b, 1);
dotOp->setOperand(0, a);
dotOp->setOperand(1, b);
}
}
}
// async.wait & extract_slice
Operation *asyncWait = builder.create<ttg::AsyncWaitOp>(
loads[0].getLoc(), loads.size() * (numStages - 2));
@@ -698,6 +681,17 @@ struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
if (numStages <= 1)
return;
// Pre-processing
// We make sure element-wise ops are done *after* the conversion
// to dot operands; we can achieve this with simple recursive
// pattern matching
// MLIRContext *context = &getContext();
// mlir::RewritePatternSet patterns(context);
// patterns.add<MoveOpAfterLayoutConversion>(context);
// auto didPreprocess =
// applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
// Do the pipelining
getOperation()->walk([&](scf::ForOp forOp) -> void {
LoopPipeliner pipeliner(forOp, numStages);
@@ -707,7 +701,6 @@ struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
pipeliner.emitPrologue();
scf::ForOp newForOp = pipeliner.createNewForOp();
pipeliner.emitEpilogue();
// replace the original loop

View File

@@ -103,11 +103,9 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
if (offsetK)
offset[kIdx] = *offsetK;
Value newSmem = builder.create<tensor::ExtractSliceOp>(
v.getLoc(),
// TODO: encoding?
RankedTensorType::get(shape, elementType, type.getEncoding()), v,
SmallVector<OpFoldResult>{intAttr(offset[0]), intAttr(offset[1])},
Value newSmem = builder.create<triton::gpu::ExtractSliceOp>(
v.getLoc(), RankedTensorType::get(shape, elementType, type.getEncoding()),
v, SmallVector<OpFoldResult>{intAttr(offset[0]), intAttr(offset[1])},
SmallVector<OpFoldResult>{intAttr(shape[0]), intAttr(shape[1])},
SmallVector<OpFoldResult>{intAttr(1), intAttr(1)});

View File

@@ -108,7 +108,7 @@ public:
auto newReduce = rewriter.create<triton::ReduceOp>(
op->getLoc(), reduce.getRedOp(), reduceArg.getOperand(),
reduce.getAxis());
if (isa<triton::gpu::ConvertLayoutOp>(
if (isa_and_nonnull<triton::gpu::ConvertLayoutOp>(
*reduceArg.getOperand().getDefiningOp()))
return mlir::failure();
Value newRet = newReduce.getResult();
@@ -146,157 +146,6 @@ public:
//
// -----------------------------------------------------------------------------
// TODO: Interface
LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
Attribute &ret) {
ret = targetEncoding;
if (auto expand_dims = dyn_cast<triton::ExpandDimsOp>(op)) {
ret = triton::gpu::SliceEncodingAttr::get(
op->getContext(), expand_dims.getAxis(), targetEncoding);
}
if (auto reduce = dyn_cast<triton::ReduceOp>(op)) {
auto sliceEncoding =
targetEncoding.dyn_cast<triton::gpu::SliceEncodingAttr>();
if (!sliceEncoding)
return failure();
if (sliceEncoding.getDim() != reduce.getAxis())
return failure();
ret = sliceEncoding.getParent();
}
if (auto view = dyn_cast<triton::ViewOp>(op)) {
return failure();
}
return success();
}
inline bool expensiveLoadOrStore(Operation *op, Attribute &targetEncoding) {
// Case 1: A size-1 tensor is not expensive since all threads will load the
// same value
if (isSingleValue(op->getOperand(0)))
return false;
auto ptr = op->getOperand(0);
// Case 2: We assume that `evict_last` loads/stores have a high hit rate
if (auto load = dyn_cast<triton::LoadOp>(op))
if (load.getEvict() == triton::EvictionPolicy::EVICT_LAST)
return false;
if (auto store = dyn_cast<triton::StoreOp>(op))
if (store.getEvict() == triton::EvictionPolicy::EVICT_LAST)
return false;
if (auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>()) {
auto encoding = tensorTy.getEncoding();
// Case 3: Different type conversion is expensive (e.g., mma <-> block)
if (encoding.getTypeID() != targetEncoding.getTypeID())
return true;
auto sizePerThread = triton::gpu::getSizePerThread(encoding);
auto targetSizePerThread = triton::gpu::getSizePerThread(targetEncoding);
auto order = triton::gpu::getOrder(encoding);
auto targetOrder = triton::gpu::getOrder(targetEncoding);
// Case 4: The targetEncoding may expose more vectorization opportunities
return sizePerThread[order[0]] >= targetSizePerThread[targetOrder[0]];
}
return false;
}
inline bool expensiveToRemat(Operation *op, Attribute &targetEncoding) {
if (!op)
return true;
if (isa<triton::LoadOp, triton::StoreOp>(op))
return expensiveLoadOrStore(op, targetEncoding);
if (isa<tensor::ExtractSliceOp, triton::gpu::AllocTensorOp,
triton::gpu::InsertSliceAsyncOp, triton::AtomicRMWOp,
triton::AtomicCASOp, triton::DotOp>(op))
return true;
if (isa<scf::YieldOp, scf::ForOp, scf::IfOp, scf::WhileOp, scf::ConditionOp>(
op))
return true;
return false;
}
LogicalResult simulateBackwardRematerialization(
Operation *initOp, SetVector<Operation *> &processed,
SetVector<Attribute> &layout, llvm::MapVector<Value, Attribute> &toConvert,
const Attribute &targetEncoding) {
// DFS
std::vector<std::pair<Operation *, Attribute>> queue;
queue.emplace_back(initOp, targetEncoding);
// We want to see the effect of converting `initOp` to a new layout
// so we initialize `numCvts = 1`.
int numCvts = 1;
while (!queue.empty()) {
Operation *currOp;
Attribute currLayout;
std::tie(currOp, currLayout) = queue.back();
queue.pop_back();
// If the current operation is expensive to rematerialize,
// we stop everything
if (expensiveToRemat(currOp, currLayout))
break;
// A conversion will be removed here (i.e. transferred to operands)
numCvts -= 1;
// Done processing
processed.insert(currOp);
layout.insert(currLayout);
// Add all operands to the queue
for (Value argI : currOp->getOperands()) {
Attribute newEncoding;
// Cannot invert the current encoding for this operand;
// we stop everything
if (failed(invertEncoding(currLayout, currOp, newEncoding)))
return mlir::failure();
if (toConvert.count(argI) && toConvert[argI] != newEncoding)
return mlir::failure();
Operation *opArgI = argI.getDefiningOp();
toConvert.insert({argI, newEncoding});
// 1. Only convert RankedTensorType
// 2. Skip if there's no defining op
// 3. Skip if the defining op has already been processed
// 4. Skip if the defining op is in a different block
if (!argI.getType().isa<RankedTensorType>() || !opArgI ||
processed.contains(opArgI) ||
opArgI->getBlock() != currOp->getBlock())
continue;
// If the conversion can be folded into opArgI then
// we don't count this conversion as expensive
if (isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
triton::MakeRangeOp, triton::SplatOp>(*opArgI))
continue;
if (auto view = dyn_cast<triton::ViewOp>(opArgI))
continue;
// We add one expensive conversion for the current operand
numCvts += 1;
queue.emplace_back(opArgI, newEncoding);
}
}
// if rematerialization would add more conversions than it removes
// then we don't do it
if (numCvts > 0)
return mlir::failure();
return mlir::success();
}
//
Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
IRMapping &mapping) {
Operation *newOp = rewriter.clone(*op, mapping);
auto origType = op->getResult(0).getType().cast<RankedTensorType>();
auto argType = newOp->getOperand(0).getType().cast<RankedTensorType>();
auto newType = RankedTensorType::get(
origType.getShape(), origType.getElementType(), argType.getEncoding());
newOp->getResult(0).setType(newType);
auto typeInfer = dyn_cast<InferTypeOpInterface>(newOp);
if (typeInfer) {
SmallVector<Type, 1> newTypes;
auto success = typeInfer.inferReturnTypes(
newOp->getContext(), newOp->getLoc(), newOp->getOperands(),
newOp->getAttrDictionary(), newOp->getRegions(), newTypes);
if (succeeded(success))
newOp->getResult(0).setType(newTypes.front());
}
return newOp;
}
// op(cvt(arg_0), arg_1, ..., arg_n)
// -> cvt(op(arg_0, cvt(arg_1), ..., cvt(arg_n)))
void pushConversionForward(triton::gpu::ConvertLayoutOp cvt,
@@ -368,11 +217,11 @@ public:
IRMapping mapping;
for (size_t i = 0; i < numOps; i++) {
auto thenCvt = dyn_cast<triton::gpu::ConvertLayoutOp>(
auto thenCvt = dyn_cast_or_null<triton::gpu::ConvertLayoutOp>(
thenYield.getOperand(i).getDefiningOp());
if (hasElse) {
auto elseYield = ifOp.elseYield();
auto elseCvt = dyn_cast<triton::gpu::ConvertLayoutOp>(
auto elseCvt = dyn_cast_or_null<triton::gpu::ConvertLayoutOp>(
elseYield.getOperand(i).getDefiningOp());
if (thenCvt && elseCvt &&
std::distance(elseCvt->user_begin(), elseCvt->user_end()) == 1 &&
@@ -492,8 +341,8 @@ public:
SetVector<Attribute> layout;
llvm::MapVector<Value, Attribute> toConvert;
if (argOp && (argOp != cvt) && cvtSlices.count(argOp) == 0 &&
failed(simulateBackwardRematerialization(argOp, processed, layout,
toConvert, srcEncoding))) {
simulateBackwardRematerialization(argOp, processed, layout,
toConvert, srcEncoding) > 0) {
return failure();
}
}
@@ -539,50 +388,13 @@ public:
SetVector<Attribute> layout;
llvm::MapVector<Value, Attribute> toConvert;
std::vector<std::pair<Operation *, Attribute>> queue;
if (failed(simulateBackwardRematerialization(
cvt, processed, layout, toConvert, targetType.getEncoding())))
if (simulateBackwardRematerialization(cvt, processed, layout, toConvert,
targetType.getEncoding()) > 0)
return mlir::failure();
SmallVector<Value, 4> sortedValues;
SetVector<Operation *> tmp;
for (auto &item : toConvert) {
Value v = item.first;
if (v.getDefiningOp())
tmp.insert(v.getDefiningOp());
else
sortedValues.push_back(v);
}
tmp = mlir::multiRootTopologicalSort(tmp);
for (Operation *op : tmp)
sortedValues.push_back(op->getResult(0));
IRMapping mapping;
for (Value currOperand : sortedValues) {
// unpack information
Attribute targetLayout = toConvert.lookup(currOperand);
// rematerialize the operand if necessary
Operation *currOperation = currOperand.getDefiningOp();
if (processed.contains(currOperation)) {
Operation *newOperation =
cloneWithInferType(rewriter, currOperation, mapping);
newOperation->moveAfter(currOperation);
currOperation = newOperation;
currOperand = currOperation->getResult(0);
}
// compute target type for the layout cast
auto currType = currOperand.getType().cast<RankedTensorType>();
auto newType = RankedTensorType::get(
currType.getShape(), currType.getElementType(), targetLayout);
auto newOperand = rewriter.create<triton::gpu::ConvertLayoutOp>(
currOperand.getLoc(), newType, currOperand);
if (currOperation)
newOperand->moveAfter(currOperation);
else {
Block *block = currOperand.cast<BlockArgument>().getOwner();
newOperand->moveAfter(block, block->begin());
}
mapping.map(currOperand, newOperand);
}
rematerializeConversionChain(toConvert, rewriter, processed, mapping);
rewriter.replaceOp(cvt, mapping.lookup(cvt->getOperand(0)));
return mlir::success();
}

View File

@@ -1,212 +0,0 @@
#include "Utility.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
namespace mlir {
namespace {
using triton::DotOp;
using triton::gpu::ConvertLayoutOp;
using triton::gpu::DotOperandEncodingAttr;
using triton::gpu::MmaEncodingAttr;
using triton::gpu::SharedEncodingAttr;
using triton::gpu::SliceEncodingAttr;
// Get the warps-per-CTA (wpt) for MMAv1 using layout and shape information.
// Reference the original logic here
// https://github.com/openai/triton/blob/0e4691e6dd91e001a8d33b71badf8b3314325459/lib/codegen/analysis/layout.cc#L223
SmallVector<unsigned> getWarpsPerCTA(ArrayRef<int64_t> shape, bool isARow,
bool isBRow, bool isAVec4, bool isBVec4,
int numWarps) {
// TODO[Superjomn]: Share code with
// DotOpMmaV1ConversionHelper::AParam/BParam, since the same code computes
// rep, spw and fpw.
SmallVector<unsigned> wpt({1, 1});
SmallVector<unsigned> wpt_nm1;
SmallVector<int, 2> rep(2), spw(2);
std::array<int, 3> fpw{{2, 2, 1}};
int packSize0 = (isARow || isAVec4) ? 1 : 2;
rep[0] = 2 * packSize0;
spw[0] = fpw[0] * 4 * rep[0];
int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
rep[1] = 2 * packSize1;
spw[1] = fpw[1] * 4 * rep[1];
do {
wpt_nm1 = wpt;
if (wpt[0] * wpt[1] < numWarps)
wpt[0] = std::clamp<int>(wpt[0] * 2, 1, shape[0] / spw[0]);
if (wpt[0] * wpt[1] < numWarps)
wpt[1] = std::clamp<int>(wpt[1] * 2, 1, shape[1] / spw[1]);
} while (wpt_nm1 != wpt);
return wpt;
}
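// Worked example (a sketch; the real states come from
// decodeVoltaLayoutStates): with shape = {128, 128}, isARow = isBRow = true,
// no vec4 packing, and numWarps = 8, we get rep = {2, 4} and spw = {16, 32};
// the doubling loop then converges to wpt = {4, 2} (4 * 2 == 8 warps, each
// dimension clamped by shape / spw).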
// Given a (potentially malformed) DotOp, determines the optimal
// MMAEncoding to use on V100
LogicalResult getOptimizedV100MMaLayout(triton::DotOp dotOp,
MmaEncodingAttr &old,
MmaEncodingAttr &ret) {
auto *ctx = dotOp->getContext();
auto AT = dotOp.getA().getType().cast<RankedTensorType>();
auto BT = dotOp.getB().getType().cast<RankedTensorType>();
auto DT = dotOp.getD().getType().cast<RankedTensorType>();
auto shapeA = AT.getShape();
auto shapeB = BT.getShape();
if (!DT.getEncoding())
return mlir::failure();
auto mmaLayout = DT.getEncoding().dyn_cast<MmaEncodingAttr>();
if (!(mmaLayout && mmaLayout.isVolta()))
return mlir::failure();
// We have an MmaEncodingAttr here. Find the correct layout for it.
auto dotOperandA = AT.getEncoding().cast<DotOperandEncodingAttr>();
auto dotOperandB = BT.getEncoding().cast<DotOperandEncodingAttr>();
bool isARow = dotOperandA.getIsMMAv1Row().cast<BoolAttr>().getValue();
bool isBRow = dotOperandB.getIsMMAv1Row().cast<BoolAttr>().getValue();
auto [isARow_, isBRow_, isAVec4_, isBVec4_, mmaId] =
mmaLayout.decodeVoltaLayoutStates();
bool isAVec4 = !isARow && (shapeA[isARow] <= 16);
bool isBVec4 = isBRow && (shapeB[isBRow] <= 16);
// The wpt of MMAv1 is also determined by isARow, isBRow and the shape; it
// can only be set here because those states may have been updated by earlier
// patterns in the Combine Pass.
auto tgtWpt = getWarpsPerCTA(DT.getShape(), isARow, isBRow, isAVec4, isBVec4,
product(mmaLayout.getWarpsPerCTA()));
if (isARow == isARow_ && isBRow == isBRow_ && isAVec4 == isAVec4_ &&
isBVec4 == isBVec4_) {
if (tgtWpt == mmaLayout.getWarpsPerCTA())
return mlir::failure();
}
// Recalculate the wpt here, where the latest layout information is
// available, so the wpt stays up to date.
auto updatedWpt =
getWarpsPerCTA(DT.getShape(), isARow, isBRow, isAVec4, isBVec4,
product(mmaLayout.getWarpsPerCTA()));
// return results
old = mmaLayout;
ret =
MmaEncodingAttr::get(ctx, mmaLayout.getVersionMajor(), updatedWpt,
AT.getShape(), BT.getShape(), isARow, isBRow, mmaId);
return mlir::success();
}
// Replace an op's result types
void setOpResultType(Operation *op, ArrayRef<Type> newTypes) {
if (op->getNumResults() != newTypes.size())
llvm_unreachable("number of types different from number of results");
// nothing to do
if (op->getNumResults() == 0)
return;
// replace types
for (unsigned i = 0; i < op->getNumResults(); i++) {
Type newType = newTypes[i];
op->getResult(i).setType(newType);
}
// special case: arith.constant: we need to change the value attr
if (isa<arith::ConstantOp>(op)) {
Type newType = newTypes[0];
auto attr = op->getAttrDictionary()
.get("value")
.dyn_cast<mlir::DenseElementsAttr>();
if (attr) {
auto newAttr =
mlir::DenseElementsAttr::getFromRawBuffer(newType, attr.getRawData());
op->setAttr("value", newAttr);
}
}
}
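// For arith.constant the DenseElementsAttr is rebuilt from the raw buffer
// because the attribute's own type must agree with the new result type; all
// other ops only need their result types swapped.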
// Update a stale type given the provided layoutMap
Type updateStaleType(
const DenseMap<MmaEncodingAttr, MmaEncodingAttr> &layoutMap,
RankedTensorType type) {
auto encoding = type.getEncoding();
// mma encoding
if (auto mma = encoding.dyn_cast<MmaEncodingAttr>()) {
auto newMma = layoutMap.lookup(mma);
if (!newMma)
return Type();
return RankedTensorType::get(type.getShape(), type.getElementType(),
newMma);
}
// slice encoding
else if (auto slice = encoding.dyn_cast<SliceEncodingAttr>()) {
if (auto mma = slice.getParent().dyn_cast<MmaEncodingAttr>()) {
auto newMma = layoutMap.lookup(mma);
if (!newMma)
return Type();
auto newSlice =
SliceEncodingAttr::get(slice.getContext(), slice.getDim(), newMma);
return RankedTensorType::get(type.getShape(), type.getElementType(),
newSlice);
}
}
// dot operand encoding
else if (auto dotOp = encoding.dyn_cast<DotOperandEncodingAttr>()) {
if (auto mma = dotOp.getParent().dyn_cast<MmaEncodingAttr>()) {
auto newMma = layoutMap.lookup(mma);
if (!newMma)
return Type();
auto newDotOp = DotOperandEncodingAttr::get(
dotOp.getContext(), dotOp.getOpIdx(), newMma, dotOp.getIsMMAv1Row());
return RankedTensorType::get(type.getShape(), type.getElementType(),
newDotOp);
}
}
return Type();
}
} // namespace
class UpdateMmaForVoltaPass
: public UpdateMmaForVoltaBase<UpdateMmaForVoltaPass> {
public:
UpdateMmaForVoltaPass() = default;
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp m = getOperation();
// Step 1:
// Build a map from old MMA encoding to new MMA encoding.
DenseMap<MmaEncodingAttr, MmaEncodingAttr> layoutMap;
m.walk([&layoutMap](triton::DotOp dotOp) {
MmaEncodingAttr newLayout;
MmaEncodingAttr oldLayout;
if (failed(getOptimizedV100MMaLayout(dotOp, oldLayout, newLayout)))
return;
layoutMap[oldLayout] = newLayout;
});
// Step 2:
// Replace all wrong layouts with the right one
m.walk([&layoutMap](Operation *op) {
if (op->getNumResults() != 1)
return;
auto type = op->getResult(0).getType().dyn_cast<RankedTensorType>();
if (!type)
return;
Type newType = updateStaleType(layoutMap, type);
if (!newType)
return;
setOpResultType(op, {newType});
});
// Step 3:
// We may have messed up some loops in the process.
// Fix them up
if (fixupLoops(m).failed())
signalPassFailure();
}
};
std::unique_ptr<Pass> createTritonGPUUpdateMmaForVoltaPass() {
return std::make_unique<UpdateMmaForVoltaPass>();
}
} // namespace mlir

View File

@@ -1,7 +1,10 @@
#include "Utility.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
namespace mlir {
@@ -60,4 +63,199 @@ LogicalResult fixupLoops(ModuleOp mod) {
return success();
}
// -------------------------------------------------------------------------- //
// TODO: Interface
LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
Attribute &ret) {
ret = targetEncoding;
if (auto expand_dims = dyn_cast<triton::ExpandDimsOp>(op)) {
ret = triton::gpu::SliceEncodingAttr::get(
op->getContext(), expand_dims.getAxis(), targetEncoding);
}
if (auto reduce = dyn_cast<triton::ReduceOp>(op)) {
auto sliceEncoding =
targetEncoding.dyn_cast<triton::gpu::SliceEncodingAttr>();
if (!sliceEncoding)
return failure();
if (sliceEncoding.getDim() != reduce.getAxis())
return failure();
ret = sliceEncoding.getParent();
}
if (auto view = dyn_cast<triton::ViewOp>(op)) {
return failure();
}
return success();
}
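// Example of the inversion above (a sketch): pushing a target encoding #L
// backward through an expand_dims with axis = 0 yields
// #triton_gpu.slice<{dim = 0, parent = #L}> for the operand, while pushing a
// slice encoding through a reduce along the same dim recovers the slice's
// parent; views cannot be inverted and abort the walk.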
bool expensiveLoadOrStore(Operation *op, Attribute &targetEncoding) {
// Case 1: A size-1 tensor is not expensive since all threads will load the
// same value
if (isSingleValue(op->getOperand(0)))
return false;
// auto ptr = op->getOperand(0);
//// Case 2: We assume that `evict_last` loads/stores have high hit rate
// if (auto load = dyn_cast<triton::LoadOp>(op))
// if (load.getEvict() == triton::EvictionPolicy::EVICT_LAST)
// return false;
// if (auto store = dyn_cast<triton::StoreOp>(op))
// if (store.getEvict() == triton::EvictionPolicy::EVICT_LAST)
// return false;
// if (auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>()) {
// auto encoding = tensorTy.getEncoding();
// // Case 3: Different type conversion is expensive (e.g., mma <-> block)
// if (encoding.getTypeID() != targetEncoding.getTypeID())
// return true;
// auto sizePerThread = triton::gpu::getSizePerThread(encoding);
// auto targetSizePerThread = triton::gpu::getSizePerThread(targetEncoding);
// auto order = triton::gpu::getOrder(encoding);
// auto targetOrder = triton::gpu::getOrder(targetEncoding);
// // Case 4: The targetEncoding may expose more vectorization opportunities
// return sizePerThread[order[0]] >= targetSizePerThread[targetOrder[0]];
// }
return true;
}
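// Note: relative to the original heuristic, everything past the size-1 check
// is commented out above, so for now any non-scalar load or store is treated
// as expensive and will never be rematerialized.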
bool expensiveToRemat(Operation *op, Attribute &targetEncoding) {
if (!op)
return true;
if (isa<triton::LoadOp, triton::StoreOp>(op))
return expensiveLoadOrStore(op, targetEncoding);
if (isa<tensor::ExtractSliceOp, triton::gpu::AllocTensorOp,
triton::gpu::InsertSliceAsyncOp, triton::AtomicRMWOp,
triton::AtomicCASOp, triton::DotOp>(op))
return true;
if (isa<scf::YieldOp, scf::ForOp, scf::IfOp, scf::WhileOp, scf::ConditionOp>(
op))
return true;
return false;
}
int simulateBackwardRematerialization(
Operation *initOp, SetVector<Operation *> &processed,
SetVector<Attribute> &layout, llvm::MapVector<Value, Attribute> &toConvert,
const Attribute &targetEncoding) {
// DFS
std::vector<std::pair<Operation *, Attribute>> queue;
queue.emplace_back(initOp, targetEncoding);
// We want to see the effect of converting `initOp` to a new layout
// so we initialize `numCvts = 1`.
int numCvts = 1;
while (!queue.empty()) {
Operation *currOp;
Attribute currLayout;
std::tie(currOp, currLayout) = queue.back();
queue.pop_back();
// If the current operation is expensive to rematerialize,
// we stop everything
if (expensiveToRemat(currOp, currLayout))
break;
// A conversion will be removed here (i.e. transferred to operands)
numCvts -= 1;
// Done processing
processed.insert(currOp);
layout.insert(currLayout);
// Add all operands to the queue
for (Value argI : currOp->getOperands()) {
Attribute newEncoding;
// Cannot invert the current encoding for this operand;
// we stop everything
if (failed(invertEncoding(currLayout, currOp, newEncoding)))
return INT_MAX;
if (toConvert.count(argI) && toConvert[argI] != newEncoding)
return INT_MAX;
Operation *opArgI = argI.getDefiningOp();
toConvert.insert({argI, newEncoding});
// 1. Only convert RankedTensorType
// 2. Skip if there's no defining op
// 3. Skip if the defining op has already been processed
// 4. Skip if the defining op is in a different block
if (!argI.getType().isa<RankedTensorType>() || !opArgI ||
processed.contains(opArgI) ||
opArgI->getBlock() != currOp->getBlock())
continue;
// If the conversion can be folded into opArgI then
// we don't count this conversion as expensive
if (isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
triton::MakeRangeOp, triton::SplatOp>(*opArgI))
continue;
if (auto view = dyn_cast<triton::ViewOp>(opArgI))
continue;
// We add one expensive conversion for the current operand
numCvts += 1;
queue.emplace_back(opArgI, newEncoding);
}
}
// return net number of conversions
return numCvts;
}
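// Usage sketch (mirrors the call sites in the remove-layout-conversions
// pass): a positive return value means rematerialization would add more
// conversions than it removes, so the caller bails out:
//   SetVector<Operation *> processed;
//   SetVector<Attribute> layout;
//   llvm::MapVector<Value, Attribute> toConvert;
//   if (simulateBackwardRematerialization(cvt, processed, layout, toConvert,
//                                         targetEncoding) > 0)
//     return failure();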
//
Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
IRMapping &mapping) {
Operation *newOp = rewriter.clone(*op, mapping);
auto origType = op->getResult(0).getType().cast<RankedTensorType>();
auto argType = newOp->getOperand(0).getType().cast<RankedTensorType>();
auto newType = RankedTensorType::get(
origType.getShape(), origType.getElementType(), argType.getEncoding());
newOp->getResult(0).setType(newType);
auto typeInfer = dyn_cast<InferTypeOpInterface>(newOp);
if (typeInfer) {
SmallVector<Type, 1> newTypes;
auto success = typeInfer.inferReturnTypes(
newOp->getContext(), newOp->getLoc(), newOp->getOperands(),
newOp->getAttrDictionary(), newOp->getRegions(), newTypes);
if (succeeded(success))
newOp->getResult(0).setType(newTypes.front());
}
return newOp;
}
void rematerializeConversionChain(
const llvm::MapVector<Value, Attribute> &toConvert,
mlir::PatternRewriter &rewriter, SetVector<Operation *> &processed,
IRMapping &mapping) {
SmallVector<Value, 4> sortedValues;
SetVector<Operation *> tmp;
for (auto &item : toConvert) {
Value v = item.first;
if (v.getDefiningOp())
tmp.insert(v.getDefiningOp());
else
sortedValues.push_back(v);
}
tmp = mlir::multiRootTopologicalSort(tmp);
for (Operation *op : tmp)
sortedValues.push_back(op->getResult(0));
for (Value currOperand : sortedValues) {
// unpack information
Attribute targetLayout = toConvert.lookup(currOperand);
// rematerialize the operand if necessary
Operation *currOperation = currOperand.getDefiningOp();
if (processed.contains(currOperation)) {
Operation *newOperation =
cloneWithInferType(rewriter, currOperation, mapping);
newOperation->moveAfter(currOperation);
currOperation = newOperation;
currOperand = currOperation->getResult(0);
}
// compute target type for the layout cast
auto currType = currOperand.getType().cast<RankedTensorType>();
auto newType = RankedTensorType::get(
currType.getShape(), currType.getElementType(), targetLayout);
auto newOperand = rewriter.create<triton::gpu::ConvertLayoutOp>(
currOperand.getLoc(), newType, currOperand);
if (currOperation)
newOperand->moveAfter(currOperation);
else {
Block *block = currOperand.cast<BlockArgument>().getOwner();
newOperand->moveAfter(block, block->begin());
}
mapping.map(currOperand, newOperand);
}
}
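// The helper above first orders the values needing a new layout (block
// arguments first, then the multi-root topological sort of their defining
// ops), then clones each processed op with its inferred type and places the
// new ConvertLayoutOp right after its producer, or at the top of the block
// for block arguments.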
} // namespace mlir

View File

@@ -2,11 +2,32 @@
#define TRITON_LIB_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/MapVector.h"
namespace mlir {
LogicalResult fixupLoops(ModuleOp mod);
// TODO: Interface
LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
Attribute &ret);
bool expensiveLoadOrStore(Operation *op, Attribute &targetEncoding);
bool expensiveToRemat(Operation *op, Attribute &targetEncoding);
int simulateBackwardRematerialization(
Operation *initOp, SetVector<Operation *> &processed,
SetVector<Attribute> &layout, llvm::MapVector<Value, Attribute> &toConvert,
const Attribute &targetEncoding);
Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
IRMapping &mapping);
void rematerializeConversionChain(
const llvm::MapVector<Value, Attribute> &toConvert,
mlir::PatternRewriter &rewriter, SetVector<Operation *> &processed,
IRMapping &mapping);
} // namespace mlir
#endif // TRITON_LIB_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_

View File

@@ -13,6 +13,7 @@
#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Conversion/TritonGPUToLLVM/ArithToIndexPass.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include "llvm/IR/CallingConv.h"
@@ -309,12 +310,14 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
/*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags);
pm.addPass(mlir::createConvertSCFToCFPass());
pm.addPass(createTritonConvertArithToIndexPass());
pm.addPass(mlir::createConvertIndexToLLVMPass());
pm.addPass(createConvertTritonGPUToLLVMPass(computeCapability));
// Canonicalize to eliminate the remaining UnrealizedConversionCastOp
pm.addPass(mlir::createArithToLLVMConversionPass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass()); // Simplify the IR to improve readability.
// Simplify the IR
pm.addPass(mlir::createCSEPass());
pm.addPass(mlir::createSymbolDCEPass());
pm.addPass(mlir::createCanonicalizerPass());
#ifdef USE_ROCM
pm.addPass(mlir::createConvertSCFToCFPass());
pm.addPass(createConvertControlFlowToLLVMPass());
@@ -325,6 +328,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
return nullptr;
}
// llvm::outs() << module << "\n";
auto llvmIR = translateLLVMToLLVMIR(llvmContext, module);
if (!llvmIR) {
llvm::errs() << "Translate to LLVM IR failed";

View File

@@ -50,7 +50,6 @@ std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
int ptxMajor = maxPTX / 10;
int ptxMinor = maxPTX % 10;
// create
llvm::SmallVector<char, 0> buffer;
std::string triple = "nvptx64-nvidia-cuda";
std::string proc = "sm_" + std::to_string(maxCC);
std::string layout = "";
@@ -82,17 +81,19 @@ std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
else
module.setDataLayout(layout);
// emit machine code
for (llvm::Function &f : module.functions())
f.addFnAttr(llvm::Attribute::AlwaysInline);
llvm::legacy::PassManager pass;
llvm::raw_svector_ostream stream(buffer);
// emit
machine->addPassesToEmitFile(pass, stream, nullptr,
llvm::CodeGenFileType::CGFT_AssemblyFile);
pass.run(module);
std::string result;
{
llvm::raw_string_ostream stream(result);
llvm::buffer_ostream pstream(stream);
for (llvm::Function &f : module.functions())
f.addFnAttr(llvm::Attribute::AlwaysInline);
llvm::legacy::PassManager pass;
// emit
machine->addPassesToEmitFile(pass, pstream, nullptr,
llvm::CodeGenFileType::CGFT_AssemblyFile);
pass.run(module);
}
// post-process
std::string result(buffer.begin(), buffer.end());
findAndReplace(result, ".version", "\n",
".version " + std::to_string(ptxMajor) + "." +
std::to_string(ptxMinor) + "\n");
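// With the change above the assembly is streamed straight into `result`
// through raw_string_ostream + buffer_ostream instead of going through a
// SmallVector<char> buffer, and the always-inline attribute plus pass-manager
// setup are scoped to the emission block.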

View File

@@ -9,6 +9,7 @@ import tarfile
import tempfile
import urllib.request
from distutils.version import LooseVersion
from pathlib import Path
from typing import NamedTuple
from setuptools import Extension, setup
@@ -38,7 +39,6 @@ class Package(NamedTuple):
package: str
name: str
url: str
test_file: str
include_flag: str
lib_flag: str
syspath_var_name: str
@@ -49,7 +49,7 @@ class Package(NamedTuple):
def get_pybind11_package_info():
name = "pybind11-2.10.0"
url = "https://github.com/pybind/pybind11/archive/refs/tags/v2.10.0.tar.gz"
return Package("pybind11", name, url, "include/pybind11/pybind11.h", "PYBIND11_INCLUDE_DIR", "", "PYBIND11_SYSPATH")
return Package("pybind11", name, url, "PYBIND11_INCLUDE_DIR", "", "PYBIND11_SYSPATH")
# llvm
@@ -65,12 +65,13 @@ def get_llvm_package_info():
linux_suffix = 'ubuntu-18.04' if vglibc > 217 else 'centos-7'
system_suffix = f"linux-gnu-{linux_suffix}"
else:
return Package("llvm", "LLVM-C.lib", "", "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
return Package("llvm", "LLVM-C.lib", "", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False")
release_suffix = "assert" if use_assert_enabled_llvm else "release"
name = f'llvm+mlir-17.0.0-x86_64-{system_suffix}-{release_suffix}'
url = f"https://github.com/ptillet/triton-llvm-releases/releases/download/llvm-17.0.0-37b7a60cd74b/{name}.tar.xz"
return Package("llvm", name, url, "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
version = "llvm-17.0.0-8e5a41e8271f"
url = f"https://github.com/ptillet/triton-llvm-releases/releases/download/{version}/{name}.tar.xz"
return Package("llvm", name, url, "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
def get_thirdparty_packages(triton_cache_path):
@@ -81,8 +82,9 @@ def get_thirdparty_packages(triton_cache_path):
package_dir = os.path.join(package_root_dir, p.name)
if p.syspath_var_name in os.environ:
package_dir = os.environ[p.syspath_var_name]
test_file_path = os.path.join(package_dir, p.test_file)
if not os.path.exists(test_file_path):
version_file_path = os.path.join(package_dir, "version.txt")
if not os.path.exists(version_file_path) or\
Path(version_file_path).read_text() != p.url:
try:
shutil.rmtree(package_root_dir)
except Exception:
@@ -92,6 +94,9 @@ def get_thirdparty_packages(triton_cache_path):
ftpstream = urllib.request.urlopen(p.url)
file = tarfile.open(fileobj=ftpstream, mode="r|*")
file.extractall(path=package_root_dir)
# write version url to package_dir
with open(os.path.join(package_dir, "version.txt"), "w") as f:
f.write(p.url)
if p.include_flag:
thirdparty_cmake_args.append(f"-D{p.include_flag}={package_dir}/include")
if p.lib_flag:
@@ -208,14 +213,14 @@ download_and_copy_ptxas()
setup(
name="triton",
version="2.0.0",
version="2.1.0",
author="Philippe Tillet",
author_email="phil@openai.com",
description="A language and compiler for custom Deep Learning operations",
long_description="",
packages=["triton", "triton/_C", "triton/language", "triton/tools", "triton/impl", "triton/ops", "triton/runtime", "triton/ops/blocksparse"],
install_requires=[
"cmake",
"cmake>=3.20",
"filelock",
"torch",
"lit",

View File

@@ -13,6 +13,8 @@
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
@@ -45,6 +47,7 @@
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
#include <regex>
#include <signal.h>
#include <sstream>
#include <stdexcept>
#include <string>
@@ -119,6 +122,7 @@ void init_triton_ir(py::module &&m) {
.def(py::init<>())
.def("load_triton", [](mlir::MLIRContext &self) {
self.getOrLoadDialect<mlir::triton::TritonDialect>();
self.getOrLoadDialect<mlir::index::IndexDialect>();
// we load LLVM because the frontend uses LLVM.undef for
// some placeholders
self.getOrLoadDialect<mlir::triton::TritonDialect>();
@@ -394,8 +398,8 @@ void init_triton_ir(py::module &&m) {
registry.insert<mlir::triton::TritonDialect,
mlir::triton::gpu::TritonGPUDialect,
mlir::math::MathDialect, mlir::arith::ArithDialect,
mlir::func::FuncDialect, mlir::scf::SCFDialect,
mlir::cf::ControlFlowDialect>();
mlir::index::IndexDialect, mlir::func::FuncDialect,
mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect>();
context.appendDialectRegistry(registry);
context.loadAllAvailableDialects();
@@ -513,6 +517,12 @@ void init_triton_ir(py::module &&m) {
return mlir::Value(self.create<mlir::arith::ConstantIntOp>(
loc, v, self.getI8Type()));
})
.def("get_int16",
[](mlir::OpBuilder &self, int64_t v) -> mlir::Value {
auto loc = self.getUnknownLoc();
return mlir::Value(self.create<mlir::arith::ConstantIntOp>(
loc, v, self.getI16Type()));
})
.def("get_int32",
[](mlir::OpBuilder &self, int64_t v) -> mlir::Value {
auto loc = self.getUnknownLoc();
@@ -525,16 +535,15 @@ void init_triton_ir(py::module &&m) {
return mlir::Value(self.create<mlir::arith::ConstantIntOp>(
loc, v, self.getI64Type()));
})
// bfloat16 cannot be initialized as it is treated as int16 for now
//.def("get_bf16",
// [](mlir::OpBuilder &self, float v) -> mlir::Value {
// auto loc = self.getUnknownLoc();
// auto type = self.getBF16Type();
// return self.create<mlir::arith::ConstantFloatOp>(
// loc,
// mlir::APFloat(type.getFloatSemantics(), std::to_string(v)),
// type);
// })
.def("get_bf16",
[](mlir::OpBuilder &self, float v) -> mlir::Value {
auto loc = self.getUnknownLoc();
auto type = self.getBF16Type();
return self.create<mlir::arith::ConstantFloatOp>(
loc,
mlir::APFloat(type.getFloatSemantics(), std::to_string(v)),
type);
})
.def("get_fp16",
[](mlir::OpBuilder &self, float v) -> mlir::Value {
auto loc = self.getUnknownLoc();
@@ -547,6 +556,12 @@ void init_triton_ir(py::module &&m) {
return self.create<mlir::arith::ConstantOp>(
loc, self.getF32FloatAttr(v));
})
.def("get_fp64",
[](mlir::OpBuilder &self, double v) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::arith::ConstantOp>(
loc, self.getF64FloatAttr(v));
})
.def("get_null_value",
[](mlir::OpBuilder &self, mlir::Type type) -> mlir::Value {
auto loc = self.getUnknownLoc();
@@ -598,9 +613,13 @@ void init_triton_ir(py::module &&m) {
.def(
"get_int64_ty",
[](mlir::OpBuilder &self) -> mlir::Type { return self.getI64Type(); })
.def("get_fp8_ty",
.def("get_fp8e4_ty",
[](mlir::OpBuilder &self) -> mlir::Type {
return self.getType<mlir::triton::Float8Type>();
return self.getType<mlir::Float8E4M3FNType>();
})
.def("get_fp8e5_ty",
[](mlir::OpBuilder &self) -> mlir::Type {
return self.getType<mlir::Float8E5M2Type>();
})
.def(
"get_half_ty",
@@ -817,7 +836,7 @@ void init_triton_ir(py::module &&m) {
[](mlir::OpBuilder &self, mlir::Value &input) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::arith::IndexCastOp>(
loc, self.getI32Type(), input);
loc, self.getI64Type(), input);
})
.def("create_fmul",
[](mlir::OpBuilder &self, mlir::Value &lhs,
@@ -1428,16 +1447,32 @@ void init_triton_ir(py::module &&m) {
return self.create<mlir::arith::SelectOp>(loc, condition,
trueValue, falseValue);
})
.def("create_printf",
.def("create_print",
[](mlir::OpBuilder &self, const std::string &prefix,
const std::vector<mlir::Value> &values) -> void {
auto loc = self.getUnknownLoc();
self.create<mlir::triton::PrintfOp>(
self.create<mlir::triton::PrintOp>(
loc,
mlir::StringAttr::get(self.getContext(),
llvm::StringRef(prefix)),
values);
})
.def("create_assert",
[](mlir::OpBuilder &self, mlir::Value &condition,
const std::string &message, const std::string &fileName,
const std::string &funcName, unsigned lineNo) -> void {
auto loc = self.getUnknownLoc();
auto messageAttr = mlir::StringAttr::get(self.getContext(),
llvm::StringRef(message));
auto fileNameAttr = mlir::StringAttr::get(
self.getContext(), llvm::StringRef(fileName));
auto funcNameAttr = mlir::StringAttr::get(
self.getContext(), llvm::StringRef(funcName));
auto lineNoAttr = self.getI32IntegerAttr(lineNo);
self.create<mlir::triton::AssertOp>(loc, condition, messageAttr,
fileNameAttr, funcNameAttr,
lineNoAttr);
})
// Undef
.def("create_undef",
[](mlir::OpBuilder &self, mlir::Type &type) -> mlir::Value {
@@ -1521,18 +1556,14 @@ void init_triton_ir(py::module &&m) {
self.addPass(
mlir::createTritonGPUAccelerateMatmulPass(computeCapability));
})
.def("add_tritongpu_fuse_transpositions_pass",
.def("add_tritongpu_optimize_dot_operands_pass",
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPUFuseTranspositionsPass());
self.addPass(mlir::createTritonGPUOptimizeDotOperandsPass());
})
.def("add_tritongpu_remove_layout_conversions_pass",
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPURemoveLayoutConversionsPass());
})
.def("add_tritongpu_update_mma_for_volta_pass",
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPUUpdateMmaForVoltaPass());
})
.def("add_tritongpu_reorder_instructions_pass",
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPUReorderInstructionsPass());
@@ -1611,7 +1642,6 @@ void init_triton_translation(py::module &m) {
llvm::sys::fs::createTemporaryFile("compile-ptx-src", "", fsrc);
llvm::sys::fs::createTemporaryFile("compile-ptx-log", "", flog);
std::string fbin = std::string(fsrc) + ".o";
llvm::FileRemover srcRemover(fsrc);
llvm::FileRemover logRemover(flog);
llvm::FileRemover binRemover(fbin);
const char *_fsrc = fsrc.c_str();
@@ -1628,16 +1658,30 @@ void init_triton_translation(py::module &m) {
err = system(cmd.c_str());
if (err != 0) {
err >>= 8;
std::ifstream _log(_flog);
std::string log(std::istreambuf_iterator<char>(_log), {});
throw std::runtime_error("Internal Triton PTX codegen error: \n" +
log);
if (err == 255) {
throw std::runtime_error("Internal Triton PTX codegen error: \n" +
log);
} else if (err == 128 + SIGSEGV) {
throw std::runtime_error("Please run `ptxas " + fsrc.str().str() +
"` to confirm that this is a "
"bug in `ptxas`\n" +
log);
} else {
throw std::runtime_error("`ptxas` failed with error code " +
std::to_string(err) + ": \n" + log);
}
return {};
} else {
llvm::FileRemover srcRemover(fsrc);
std::ifstream _cubin(_fbin, std::ios::binary);
std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
_cubin.close();
py::bytes bytes(cubin);
return std::move(bytes);
}
std::ifstream _cubin(_fbin, std::ios::binary);
std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
_cubin.close();
py::bytes bytes(cubin);
return std::move(bytes);
});
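// A note on the decoding above: system() packs the child's exit status into
// the high byte, hence `err >>= 8`; 255 is treated as an internal Triton
// codegen error, 128 + SIGSEGV as a ptxas crash worth reproducing by hand,
// and anything else as a generic ptxas failure. The temporary source file is
// only removed on success, so it stays around for debugging a crash.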
m.def("add_external_libs",

View File

@@ -0,0 +1,45 @@
import sys
import torch
from torch.testing import assert_close
import triton
import triton.language as tl
@triton.jit
def kernel_device_assert(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
tl.device_assert(x == 0, "x != 0")
tl.store(Y + tl.arange(0, BLOCK), x)
@triton.jit
def kernel_assert(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
assert x == 0, "x != 0"
tl.store(Y + tl.arange(0, BLOCK), x)
@triton.jit
def kernel_static_assert(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
tl.static_assert(BLOCK == 128, "BLOCK != 128")
tl.store(Y + tl.arange(0, BLOCK), x)
def test_assert(func: str):
shape = (128, )
x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda')
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
if func == "device_assert":
kernel_device_assert[(1,)](x, y, BLOCK=shape[0])
elif func == "assert":
kernel_assert[(1,)](x, y, BLOCK=shape[0])
elif func == "static_assert":
kernel_static_assert[(1,)](x, y, BLOCK=shape[0])
assert_close(y, x)
if __name__ == "__main__":
test_assert(sys.argv[1])
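# Presumably meant to be run as a standalone script, e.g.
# `python <this file> device_assert|assert|static_assert`: a failing
# device-side assert kills the CUDA context, so it is safer to exercise
# out-of-process; tl.static_assert instead fails at compile time, before the
# kernel ever launches.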

View File

@@ -0,0 +1,46 @@
import sys
import torch
from torch.testing import assert_close
import triton
import triton.language as tl
@triton.jit
def kernel_device_print(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
tl.device_print("", x)
tl.store(Y + tl.arange(0, BLOCK), x)
@triton.jit
def kernel_print(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
print("", x)
tl.store(Y + tl.arange(0, BLOCK), x)
@triton.jit
def kernel_static_print(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
tl.static_print(x)
tl.store(Y + tl.arange(0, BLOCK), x)
def test_print(func: str, data_type: str):
shape = (128, )
# limit the range of integers so that the sum does not overflow
x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda').to(getattr(torch, data_type))
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
if func == "device_print":
kernel_device_print[(1,)](x, y, BLOCK=shape[0])
elif func == "print":
kernel_print[(1,)](x, y, BLOCK=shape[0])
elif func == "static_print":
kernel_static_print[(1,)](x, y, BLOCK=shape[0])
assert_close(y, x)
if __name__ == "__main__":
test_print(sys.argv[1], sys.argv[2])

View File

@@ -1,56 +0,0 @@
import torch
from torch.testing import assert_close
import triton
import triton.language as tl
torch_type = {
"bool": torch.bool,
'int8': torch.int8,
'uint8': torch.uint8,
'int16': torch.int16,
"int32": torch.int32,
'int64': torch.long,
'float16': torch.float16,
'bfloat16': torch.bfloat16,
"float32": torch.float32,
"float64": torch.float64
}
def get_tensor(shape, data_type, b_positive=False):
x = None
if data_type.startswith('int'):
x = torch.arange(0, shape[0], dtype=torch_type[data_type], device='cuda')
else:
x = torch.arange(0, shape[0], dtype=torch_type[data_type], device='cuda')
return x
# @pytest.mark.parametrize('data_type',
# [("int8"),
# ('int16'),
# ('int32'),
# ("int64"),
# ('float16'),
# ("float32"),
# ("float64")])
def printf(data_type):
@triton.jit
def kernel(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
tl.printf("", x)
tl.store(Y + tl.arange(0, BLOCK), x)
shape = (128, )
# limit the range of integers so that the sum does not overflow
x = get_tensor(shape, data_type)
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
kernel[(1,)](x, y, BLOCK=shape[0])
assert_close(y, x)
printf("float16")
printf("int8")

View File

@@ -385,17 +385,22 @@ def test_where(dtype):
@triton.jit
def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements,
BLOCK_SIZE: tl.constexpr,
TEST_POINTERS: tl.constexpr):
TEST_POINTERS: tl.constexpr,
TEST_SCALAR_POINTERS: tl.constexpr):
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
decide = tl.load(cond_ptr + offsets, mask=mask)
if TEST_POINTERS:
a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t)
b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t)
if TEST_SCALAR_POINTERS:
ptr = tl.where(tl.load(cond_ptr), a_ptr, b_ptr)
output = tl.load(ptr + offsets, mask=mask)
else:
a = tl.load(a_ptr + offsets, mask=mask)
b = tl.load(b_ptr + offsets, mask=mask)
output = tl.where(decide, a, b)
if TEST_POINTERS:
a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t)
b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t)
else:
a = tl.load(a_ptr + offsets, mask=mask)
b = tl.load(b_ptr + offsets, mask=mask)
output = tl.where(decide, a, b)
tl.store(output_ptr + offsets, output, mask=mask)
SIZE = 1_000
@@ -411,8 +416,12 @@ def test_where(dtype):
z_tri = to_triton(np.empty(SIZE, dtype=z.dtype), device='cuda', dst_type=dtype)
grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']),)
where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs)
where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, TEST_SCALAR_POINTERS=False)
assert (z == to_numpy(z_tri)).all()
if select_ptrs:
where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, TEST_SCALAR_POINTERS=True)
z = np.where(cond[0], x, y)
assert (z == to_numpy(z_tri)).all()
def test_where_broadcast():
@@ -683,6 +692,22 @@ def test_tensor_atomic_rmw(shape, axis, device="cuda"):
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
def test_tensor_atomic_rmw_block(device="cuda"):
shape = (8, 8)
@triton.jit
def kernel(X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
off0 = tl.arange(0, SHAPE0)
off1 = tl.arange(0, SHAPE1)
offs = off0[:, None] * SHAPE1 + off1[None, :]
val = offs.to(tl.float32)
x = X + offs
tl.atomic_min(x, val)
x = torch.ones((8, 8), device=device, dtype=torch.float32)
kernel[(2,)](x, shape[0], shape[1])
assert torch.min(x).item() == 0.0
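# Why the min ends up 0.0: both programs in the (2,) grid compute the same
# offsets 0..63, cast them to float, and atomic_min them against a tensor of
# ones, so element (0, 0) receives min(1.0, 0.0) == 0.0 regardless of how the
# two programs interleave.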
def test_atomic_cas():
# 1. make sure that atomic_cas changes the original value (Lock)
@triton.jit
@@ -798,10 +823,25 @@ def test_store_constant(dtype_str):
assert torch.all(output == ref)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_f8_xf16_roundtrip(dtype):
def test_load_store_same_ptr():
@triton.jit()
def kernel(in_out_ptr):
pid = tl.program_id(axis=0)
x = tl.load(in_out_ptr + pid)
out = x * 2
tl.store(in_out_ptr + pid, out)
for _ in range(1000):
x = torch.ones((65536,), device="cuda", dtype=torch.float32)
kernel[(65536,)](x, num_warps=32)
assert torch.all(x == 2)
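# A stress test: each of the 65536 programs loads, doubles, and stores its
# own element of a shared buffer, repeated 1000 times, presumably to surface
# any miscompile where the load is not kept ordered before the store to the
# same pointer.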
@pytest.mark.parametrize("in_dtype", [tl.float8e4, tl.float8e5])
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.float32])
def test_f8_xf16_roundtrip(in_dtype, out_dtype):
"""Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
check_type_supported(dtype)
check_type_supported(out_dtype)
@triton.jit
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
@@ -812,20 +852,24 @@ def test_f8_xf16_roundtrip(dtype):
tl.store(output_ptr + offsets, output, mask=mask)
f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda')
f8 = triton.reinterpret(f8_tensor, tl.float8)
# f32_to_f8 doesn't handle nan, so we make sure f8_tensor doesn't contain any nan
all_exp_ones = (f8_tensor & 0b01111100) == 128 - 2**in_dtype.fp_mantissa_width
f8_tensor[all_exp_ones] = 0
f8 = triton.reinterpret(f8_tensor, in_dtype)
n_elements = f8_tensor.numel()
xf16 = torch.empty_like(f8_tensor, dtype=dtype)
xf16 = torch.empty_like(f8_tensor, dtype=out_dtype)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
copy_kernel[grid](f8, xf16, n_elements, BLOCK_SIZE=1024)
f8_output_tensor = torch.empty_like(xf16, dtype=torch.int8)
f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
f8_output = triton.reinterpret(f8_output_tensor, in_dtype)
copy_kernel[grid](xf16, f8_output, n_elements, BLOCK_SIZE=1024)
assert torch.all(f8_tensor == f8_output_tensor)
def test_f16_to_f8_rounding():
@pytest.mark.parametrize("in_dtype", [tl.float8e4])
def test_f16_to_f8_rounding(in_dtype):
"""Takes all float16s, converts them to float8 and back to float16. Checks that the absolute
error is the minimum over all float8.
Or the same explanation a bit mathier:
@@ -848,7 +892,7 @@ def test_f16_to_f8_rounding():
f16_input = torch.tensor(f16_input_np, dtype=torch.float16, device='cuda')
n_elements = f16_input.numel()
f8_output_tensor = torch.empty_like(f16_input, dtype=torch.int8)
f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
f8_output = triton.reinterpret(f8_output_tensor, in_dtype)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
copy_kernel[grid](f16_input, f8_output, n_elements, BLOCK_SIZE=1024)
@@ -858,7 +902,7 @@ def test_f16_to_f8_rounding():
abs_error = torch.abs(f16_input - f16_output)
all_f8_vals_tensor = torch.tensor(range(2 ** 8), dtype=torch.uint8, device='cuda')
all_f8_vals = triton.reinterpret(all_f8_vals_tensor, tl.float8)
all_f8_vals = triton.reinterpret(all_f8_vals_tensor, in_dtype)
all_f8_vals_in_f16 = torch.empty_like(all_f8_vals_tensor, dtype=torch.float16)
copy_kernel[grid](all_f8_vals, all_f8_vals_in_f16, n_elements=256, BLOCK_SIZE=1024)
@@ -1240,6 +1284,32 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
@pytest.mark.parametrize("dtype_str", int_dtypes + float_dtypes + ['bfloat16'])
def test_full(dtype_str):
dtype = getattr(torch, dtype_str)
check_type_supported(dtype) # bfloat16 on cc < 80 will not be tested
@triton.jit
def kernel_static(out):
a = GENERATE_TEST_HERE
out_ptr = out + tl.arange(0, 128)[:]
tl.store(out_ptr, a)
@triton.jit
def kernel_dynamic(out, val, dtype: tl.constexpr):
a = tl.full((128,), val, dtype)
out_ptr = out + tl.arange(0, 128)[:]
tl.store(out_ptr, a)
kernel_static_patched = patch_kernel(kernel_static, {'GENERATE_TEST_HERE': f"tl.full((128,), 2, tl.{dtype_str})"})
out_static = torch.zeros((128), dtype=dtype, device="cuda")
kernel_static_patched[(1,)](out_static)
out_dynamic = torch.zeros((128), dtype=dtype, device="cuda")
kernel_dynamic[(1,)](out_dynamic, 2, getattr(triton.language, dtype_str))
assert torch.all(out_static == 2)
assert torch.all(out_dynamic == 2)
# TODO: uncomment once DotOperandEncoding::getElemsPerThread is implemented
# @pytest.mark.parametrize("dtype_str", ['float32', 'float16'])
# def test_dot_without_load(dtype_str):
@@ -1409,6 +1479,28 @@ def test_vectorization(N):
assert "ld.global.b32" in ptx
# triton.testing.assert_almost_equal(dst, src[:N])
@pytest.mark.parametrize("has_hints", [False, True])
def test_vectorization_hints(has_hints):
src = torch.empty(1024, device='cuda')
dst = torch.empty(1024, device='cuda')
off = torch.zeros(1, device='cuda', dtype=torch.int32)
@triton.jit
def _kernel(dst, src, off, N, BLOCK_SIZE: tl.constexpr, HINT: tl.constexpr):
offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
offsets = offsets + tl.load(off)
if HINT:
tl.max_contiguous(tl.multiple_of(offsets, 1024), 1024)
x = tl.load(src + offsets, mask=offsets < N)
tl.store(dst + offsets, x, mask=offsets < N)
pgm = _kernel[(1,)](dst, src, off, N=1024, BLOCK_SIZE=src.shape[0], HINT=has_hints)
ptx = pgm.asm["ptx"]
if has_hints:
assert "ld.global.v4.b32" in ptx
else:
assert "ld.global.v4.b32" not in ptx
# ---------------
# test store
# ---------------
@@ -1479,7 +1571,7 @@ def test_pointer_arguments(device):
@pytest.mark.parametrize("value, value_type", [
(-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
(2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'),
(2**31, 'i64'), (2**32 - 1, 'i64'), (2**32, 'i64'), (2**63 - 1, 'i64'),
(-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')
])
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
@@ -1739,6 +1831,23 @@ def test_libdevice_scalar(dtype_str, expr, lib_path):
# -----------------------
def test_for_iv_int64():
@triton.jit
def kernel(Out, lo, hi):
acc = 0
acc = acc.to(tl.int64)
for i in range(lo, hi):
acc += i
tl.store(Out, acc)
lo = 2**35
hi = 2**35 + 20
out = to_triton(np.zeros((1,), dtype=np.int64), device='cuda')
kernel[(1,)](out, lo, hi)
assert out[0] == sum(range(lo, hi))
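# The bounds sit above 2**32, so the accumulated sum only comes out right if
# the loop induction variable is materialized as i64; this matches the
# index-cast change from i32 to i64 in the IR builder bindings elsewhere in
# this commit.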
def test_if_else():
@triton.jit
@@ -1972,3 +2081,42 @@ def test_load_scalar_with_mask():
Out = torch.empty_like(Index, device='cuda')
kernel[(1,)](Input, Index, Out, Index.numel())
assert Out.data[0] == 0
# This test is used to test our own PTX codegen for float16 and int16 conversions
# maybe delete it later after ptxas has been fixed
@pytest.mark.parametrize("dtype_str", ['float16', 'int16'])
def test_ptx_cast(dtype_str):
@triton.jit
def kernel(in_ptr0, out_ptr2, xnumel, rnumel, dtype: tl.constexpr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex
_tmp4 = (tl.zeros([XBLOCK, RBLOCK], dtype) - 10000).to(dtype)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp0 = tl.load(in_ptr0 + (r1 + (197 * x0)), rmask & xmask).to(dtype)
tmp1 = 2
tmp2 = tmp0 * tmp1
tmp3 = tmp2.to(dtype)
tmp5 = _tmp4 < tmp3
_tmp4 = tl.where(rmask & xmask & tmp5, tmp3, _tmp4)
tl.store(out_ptr2 + (r1 + (197 * x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), _tmp4, rmask & xmask)
torch.manual_seed(123)
if dtype_str == 'int16':
torch_dtype = torch.int16
triton_dtype = tl.int32
else:
torch_dtype = torch.float16
triton_dtype = tl.float32
s0 = 4
buf11 = -torch.ones((6 * s0, 197, 197), device='cuda', dtype=torch_dtype)
buf14 = -torch.ones((s0, 6, 197, 197), device='cuda', dtype=torch_dtype)
kernel[(4728,)](buf11, buf14, 1182 * s0, 197, triton_dtype, 1, 256, num_warps=2)
assert buf14.to(torch.float32).mean() == -2.0
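# Why -2.0: every input element is -1, the kernel doubles it to -2, and
# _tmp4 is a running maximum (initialized to -10000 and updated through the
# masked tl.where), so every output element, and hence the mean, is exactly
# -2.0 after the cast round-trip.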

View File

@@ -385,17 +385,22 @@ def test_where(dtype):
@triton.jit
def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements,
BLOCK_SIZE: tl.constexpr,
TEST_POINTERS: tl.constexpr):
TEST_POINTERS: tl.constexpr,
TEST_SCALAR_POINTERS: tl.constexpr):
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
decide = tl.load(cond_ptr + offsets, mask=mask)
if TEST_POINTERS:
a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t)
b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t)
if TEST_SCALAR_POINTERS:
ptr = tl.where(tl.load(cond_ptr), a_ptr, b_ptr)
output = tl.load(ptr + offsets, mask=mask)
else:
a = tl.load(a_ptr + offsets, mask=mask)
b = tl.load(b_ptr + offsets, mask=mask)
output = tl.where(decide, a, b)
if TEST_POINTERS:
a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t)
b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t)
else:
a = tl.load(a_ptr + offsets, mask=mask)
b = tl.load(b_ptr + offsets, mask=mask)
output = tl.where(decide, a, b)
tl.store(output_ptr + offsets, output, mask=mask)
SIZE = 1_000
@@ -411,8 +416,12 @@ def test_where(dtype):
z_tri = to_triton(np.empty(SIZE, dtype=z.dtype), device='cuda', dst_type=dtype)
grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']),)
where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs)
where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, TEST_SCALAR_POINTERS=False)
assert (z == to_numpy(z_tri)).all()
if select_ptrs:
where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, TEST_SCALAR_POINTERS=True)
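# a scalar condition makes tl.where select between the base pointers themselves,
# so the reference result uses cond[0] for every element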
z = np.where(cond[0], x, y)
assert (z == to_numpy(z_tri)).all()
def test_where_broadcast():
@@ -683,6 +692,22 @@ def test_tensor_atomic_rmw(shape, axis, device="cuda"):
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
def test_tensor_atomic_rmw_block(device="cuda"):
shape = (8, 8)
@triton.jit
def kernel(X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
off0 = tl.arange(0, SHAPE0)
off1 = tl.arange(0, SHAPE1)
offs = off0[:, None] * SHAPE1 + off1[None, :]
val = offs.to(tl.float32)
x = X + offs
tl.atomic_min(x, val)
x = torch.ones((8, 8), device=device, dtype=torch.float32)
kernel[(2,)](x, shape[0], shape[1])
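# offs takes the values 0..63, so atomic_min pulls each element of the
# all-ones tensor down to its offset value; the global minimum is 0.0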
assert torch.min(x).item() == 0.0
def test_atomic_cas():
# 1. make sure that atomic_cas changes the original value (Lock)
@triton.jit
@@ -798,10 +823,25 @@ def test_store_constant(dtype_str):
assert torch.all(output == ref)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_f8_xf16_roundtrip(dtype):
def test_load_store_same_ptr():
@triton.jit()
def kernel(in_out_ptr):
pid = tl.program_id(axis=0)
x = tl.load(in_out_ptr + pid)
out = x * 2
tl.store(in_out_ptr + pid, out)
for _ in range(1000):
x = torch.ones((65536,), device="cuda", dtype=torch.float32)
kernel[(65536,)](x, num_warps=16)
assert torch.all(x == 2)
@pytest.mark.parametrize("in_dtype", [tl.float8e4]) # TODO: support tl.float8e5
@pytest.mark.parametrize("out_dtype", [torch.float16]) # TODO: support torch.float32
def test_f8_xf16_roundtrip(in_dtype, out_dtype):
"""Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
check_type_supported(dtype)
check_type_supported(out_dtype)
@triton.jit
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
@@ -812,20 +852,24 @@ def test_f8_xf16_roundtrip(dtype):
tl.store(output_ptr + offsets, output, mask=mask)
f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda')
f8 = triton.reinterpret(f8_tensor, tl.float8)
# f32_to_f8 doesn't handle NaN, so we make sure f8_tensor doesn't contain any NaNs
all_exp_ones = (f8_tensor & 0b01111100) == 128 - 2**in_dtype.fp_mantissa_width
f8_tensor[all_exp_ones] = 0
f8 = triton.reinterpret(f8_tensor, in_dtype)
n_elements = f8_tensor.numel()
xf16 = torch.empty_like(f8_tensor, dtype=dtype)
xf16 = torch.empty_like(f8_tensor, dtype=out_dtype)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
copy_kernel[grid](f8, xf16, n_elements, BLOCK_SIZE=1024)
f8_output_tensor = torch.empty_like(xf16, dtype=torch.int8)
f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
f8_output = triton.reinterpret(f8_output_tensor, in_dtype)
copy_kernel[grid](xf16, f8_output, n_elements, BLOCK_SIZE=1024)
assert torch.all(f8_tensor == f8_output_tensor)
def test_f16_to_f8_rounding():
@pytest.mark.parametrize("in_dtype", [tl.float8e4])
def test_f16_to_f8_rounding(in_dtype):
"""Takes all float16s, converts them to float8 and back to float16. Checks that the absolute
error is the minimum over all float8.
Or the same explanation a bit mathier:
@@ -848,7 +892,7 @@ def test_f16_to_f8_rounding():
f16_input = torch.tensor(f16_input_np, dtype=torch.float16, device='cuda')
n_elements = f16_input.numel()
f8_output_tensor = torch.empty_like(f16_input, dtype=torch.int8)
f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
f8_output = triton.reinterpret(f8_output_tensor, in_dtype)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
copy_kernel[grid](f16_input, f8_output, n_elements, BLOCK_SIZE=1024)
@@ -858,7 +902,7 @@ def test_f16_to_f8_rounding():
abs_error = torch.abs(f16_input - f16_output)
all_f8_vals_tensor = torch.tensor(range(2 ** 8), dtype=torch.uint8, device='cuda')
all_f8_vals = triton.reinterpret(all_f8_vals_tensor, tl.float8)
all_f8_vals = triton.reinterpret(all_f8_vals_tensor, in_dtype)
all_f8_vals_in_f16 = torch.empty_like(all_f8_vals_tensor, dtype=torch.float16)
copy_kernel[grid](all_f8_vals, all_f8_vals_in_f16, n_elements=256, BLOCK_SIZE=1024)
@@ -1036,8 +1080,8 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
[(dtype, shape, perm)
# TODO: bfloat16
for dtype in ['float16', 'float32']
for shape in [(64, 64), (128, 128)]
for perm in [(1, 0)]])
for shape in [(64, 64), (128, 128)]
for perm in [(1, 0)]])
def test_permute(dtype_str, shape, perm, device='cuda'):
check_type_supported(dtype_str) # bfloat16 on cc < 80 will not be tested
if torch.version.hip is not None:
@@ -1248,6 +1292,51 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
@pytest.mark.parametrize("dtype_str", int_dtypes + float_dtypes + ['bfloat16'])
def test_full(dtype_str):
dtype = getattr(torch, dtype_str)
check_type_supported(dtype) # bfloat16 on cc < 80 will not be tested
@triton.jit
def kernel_static(out):
a = GENERATE_TEST_HERE
out_ptr = out + tl.arange(0, 128)[:]
tl.store(out_ptr, a)
@triton.jit
def kernel_dynamic(out, val, dtype: tl.constexpr):
a = tl.full((128,), val, dtype)
out_ptr = out + tl.arange(0, 128)[:]
tl.store(out_ptr, a)
kernel_static_patched = patch_kernel(kernel_static, {'GENERATE_TEST_HERE': f"tl.full((128,), 2, tl.{dtype_str})"})
out_static = torch.zeros((128), dtype=dtype, device="cuda")
kernel_static_patched[(1,)](out_static)
out_dynamic = torch.zeros((128), dtype=dtype, device="cuda")
kernel_dynamic[(1,)](out_dynamic, 2, getattr(triton.language, dtype_str))
assert torch.all(out_static == 2)
assert torch.all(out_dynamic == 2)
# TODO: uncomment once DotOperandEncoding::getElemsPerThread is implemented
# @pytest.mark.parametrize("dtype_str", ['float32', 'float16'])
# def test_dot_without_load(dtype_str):
# @triton.jit
# def _kernel(out):
# a = GENERATE_TEST_HERE
# b = GENERATE_TEST_HERE
# c = tl.dot(a, b)
# out_ptr = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
# tl.store(out_ptr, c)
# kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.full((32, 32), 1.0, tl.{dtype_str})"})
# a = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
# b = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
# out_ref = torch.matmul(a, b)
# out = torch.zeros((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
# kernel[(1,)](out)
# assert torch.all(out == out_ref)
# ---------------
# test arange
# ---------------
@@ -1426,7 +1515,7 @@ def test_pointer_arguments(device):
@pytest.mark.parametrize("value, value_type", [
(-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
(2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'),
(2**31, 'i64'), (2**32 - 1, 'i64'), (2**32, 'i64'), (2**63 - 1, 'i64'),
(-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')
])
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
@@ -1686,6 +1775,23 @@ def test_libdevice_scalar(dtype_str, expr, lib_path):
# -----------------------
def test_for_iv_int64():
@triton.jit
def kernel(Out, lo, hi):
acc = 0
acc = acc.to(tl.int64)
for i in range(lo, hi):
acc += i
tl.store(Out, acc)
lo = 2**35
hi = 2**35 + 20
out = to_triton(np.zeros((1,), dtype=np.int64), device='cuda')
kernel[(1,)](out, lo, hi)
assert out[0] == sum(range(lo, hi))
def test_if_else():
@triton.jit
@@ -1863,7 +1969,7 @@ else:
BlockedLayout([1, 8], [2, 16], [4, 1], [1, 0]),
BlockedLayout([1, 4], [4, 8], [2, 2], [1, 0]),
BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0]),
# BlockedLayout([8, 1], [16, 2], [1, 4], [0, 1]),
BlockedLayout([8, 1], [16, 2], [1, 4], [0, 1]),
BlockedLayout([4, 1], [8, 4], [2, 2], [0, 1]),
BlockedLayout([1, 1], [32, 1], [2, 2], [0, 1]),
BlockedLayout([4, 4], [1, 32], [4, 1], [1, 0])

View File

@@ -1,22 +0,0 @@
import os
import subprocess
import sys
dir_path = os.path.dirname(os.path.realpath(__file__))
printf_path = os.path.join(dir_path, "printf_helper.py")
def test_printf():
proc = subprocess.Popen([sys.executable, printf_path], stdout=subprocess.PIPE, shell=False)
(outs, err) = proc.communicate()
outs = outs.split()
new_lines = set()
for line in outs:
try:
value = int(float(line))
new_lines.add(value)
except Exception as e:
print(e)
for i in range(128):
assert i in new_lines
assert len(new_lines) == 128

View File

@@ -0,0 +1,53 @@
import os
import subprocess
import sys
import pytest
dir_path = os.path.dirname(os.path.realpath(__file__))
print_path = os.path.join(dir_path, "print_helper.py")
assert_path = os.path.join(dir_path, "assert_helper.py")
# TODO: bfloat16 after LLVM-15
func_types = ["device_assert", "assert", "static_assert"]
torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32", "float64"]
@pytest.mark.parametrize("func_type, data_type",
[("device_print", data_type) for data_type in torch_types] + [("print", "int32"), ("static_print", "int32")])
def test_print(func_type: str, data_type: str):
proc = subprocess.Popen([sys.executable, print_path, func_type, data_type], stdout=subprocess.PIPE, shell=False)
outs, _ = proc.communicate()
outs = outs.split()
new_lines = set()
for line in outs:
try:
value = line
if func_type != "static_print":
value = int(float(line))
new_lines.add(value)
except Exception as e:
print(e)
if func_type != "static_print":
for i in range(128):
assert i in new_lines
assert len(new_lines) == 128
else:
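# static_print fires once at compile time, so a single unique line is expected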
assert len(new_lines) == 1
@pytest.mark.parametrize("func_type", func_types)
def test_assert(func_type: str):
os.environ["TRITON_DEBUG"] = "1"
proc = subprocess.Popen([sys.executable, assert_path, func_type], stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False)
_, errs = proc.communicate()
errs = errs.splitlines()
num_errs = 0
for err in errs:
if "x != 0" in err.decode("utf-8"):
num_errs += 1
os.environ["TRITON_DEBUG"] = "0"
if func_type != "static_assert":
assert num_errs == 127
else:
assert num_errs == 0

View File

@@ -5,7 +5,8 @@ import triton
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 64)])
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
def test_op(Z, H, N_CTX, D_HEAD, dtype):
capability = torch.cuda.get_device_capability()
if capability[0] < 8:
pytest.skip("Flash attention only supported for compute capability < 80")
@@ -21,7 +22,7 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
for z in range(Z):
for h in range(H):
p[:, :, M == 0] = float("-inf")
p = torch.softmax(p.float(), dim=-1).half()
p = torch.softmax(p.float(), dim=-1).to(dtype)
# p = torch.exp(p)
ref_out = torch.matmul(p, v)
ref_out.backward(dout)
@@ -38,6 +39,7 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
tri_dq, q.grad = q.grad.clone(), None
# compare
triton.testing.assert_almost_equal(ref_out, tri_out)
triton.testing.assert_almost_equal(ref_dv, tri_dv)
decimal = 1 if dtype == torch.bfloat16 else 2
triton.testing.assert_almost_equal(ref_dv, tri_dv, decimal=decimal)
triton.testing.assert_almost_equal(ref_dk, tri_dk)
triton.testing.assert_almost_equal(ref_dq, tri_dq)

View File

@@ -1,6 +1,5 @@
import multiprocessing
import os
import re
import shutil
from collections import namedtuple
@@ -107,33 +106,6 @@ def test_specialize(mode):
assert counter == target
@pytest.mark.parametrize("value, value_type", [
(-1, 'i32'), (0, 'i32'), (1, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
(2**32, 'i64'), (2**63 - 1, 'i64'), (-2**63, 'i64'),
(2**31, 'u32'), (2**32 - 1, 'u32'), (2**63, 'u64'), (2**64 - 1, 'u64')
])
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
@triton.jit
def kernel(VALUE, X):
pass
cache_str = None
def get_cache_str(*args, **kwargs):
nonlocal cache_str
cache_str = kwargs["repr"]
triton.JITFunction.cache_hook = get_cache_str
reset_tmp_dir()
x = torch.tensor([3.14159], device='cuda')
kernel[(1, )](value, x)
triton.JITFunction.cache_hook = None
cache_str_match = re.match(r".*VALUE: (\w+).*", cache_str)
spec_type = None if cache_str_match is None else cache_str_match.group(1)
assert spec_type == value_type
def test_constexpr_not_callable() -> None:
@triton.jit
def kernel(X, c: tl.constexpr):
@@ -176,6 +148,26 @@ def test_jit_warmup_cache() -> None:
assert len(kernel_add.cache) == 1
def test_jit_debug() -> None:
@triton.jit
def kernel_add(a, b, o, N: tl.constexpr):
idx = tl.arange(0, N)
tl.device_assert(idx < 32, "idx < 32")
tl.store(o + idx,
tl.load(a + idx) + tl.load(b + idx))
device = torch.cuda.current_device()
assert len(kernel_add.cache[device]) == 0
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
assert len(kernel_add.cache[device]) == 1
kernel_add.debug = False
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
assert len(kernel_add.cache[device]) == 1
kernel_add.debug = True
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
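# toggling debug changes the kernel cache key (self.debug is part of the key),
# so a new binary is compiled rather than reusing the cached one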
assert len(kernel_add.cache[device]) == 2
def test_compile_in_subproc() -> None:
@triton.jit
def kernel_sub(a, b, o, N: tl.constexpr):

View File

@@ -39,7 +39,8 @@ def str_to_ty(name):
ty = str_to_ty(name[1:])
return triton.language.pointer_type(ty)
tys = {
"fp8": triton.language.float8,
"fp8e5": triton.language.float8e5,
"fp8e4": triton.language.float8e4,
"fp16": triton.language.float16,
"bf16": triton.language.bfloat16,
"fp32": triton.language.float32,
@@ -111,7 +112,7 @@ class enter_sub_region:
self.generator.local_defs = self.prev_defs
class CodeGenerator(ast.NodeVisitor):
def __init__(self, context, prototype, gscope, attributes, constants, function_name, module=None, is_kernel=False, function_types=dict()):
def __init__(self, context, prototype, gscope, attributes, constants, function_name, module=None, is_kernel=False, function_types=dict(), debug=False):
self.builder = _triton.ir.builder(context)
self.module = self.builder.create_module() if module is None else module
self.function_ret_types = function_types
@@ -123,15 +124,19 @@ class CodeGenerator(ast.NodeVisitor):
self.function_name = function_name
self.is_kernel = is_kernel
self.last_node = None
self.debug = debug
self.builtins = {
'range': range,
'min': triton.language.minimum,
'float': float,
'int': int,
'print': print,
'print': triton.language.core.device_print,
'isinstance': isinstance,
'getattr': getattr,
}
self.static_functions = [
'static_print', 'static_assert'
]
self.scf_stack = []
# SSA-construction
# name => triton.language.tensor
@@ -183,9 +188,23 @@ class CodeGenerator(ast.NodeVisitor):
break
return stmts and isinstance(stmt, ast.Return)
# TODO: should be its own AST visitor
def contains_return_op(self, node):
if isinstance(node, ast.Return):
return True
elif isinstance(node, ast.Assign):
return self.contains_return_op(node.value)
elif isinstance(node, ast.Module):
pred = lambda s: self.contains_return_op(s)
return any(pred(s) for s in node.body)
elif isinstance(node, ast.FunctionDef):
pred = lambda s: self.contains_return_op(s)
return any(pred(s) for s in node.body)
elif isinstance(node, ast.Call):
fn = self.visit(node.func)
if isinstance(fn, triton.JITFunction):
return self.contains_return_op(fn.parse())
return False
elif isinstance(node, ast.If):
pred = lambda s: self.contains_return_op(s)
ret = any(pred(s) for s in node.body)
@@ -670,17 +689,24 @@ class CodeGenerator(ast.NodeVisitor):
step = triton.language.constexpr(-step.value)
negative_step = True
lb, ub = ub, lb
lb = triton.language.core._to_tensor(lb, self.builder)
ub = triton.language.core._to_tensor(ub, self.builder)
step = triton.language.core._to_tensor(step, self.builder)
# induction variable type
iv_type = triton.language.semantic.integer_promote_impl(lb.dtype, ub.dtype)
iv_type = triton.language.semantic.integer_promote_impl(iv_type, step.dtype)
iv_ir_type = iv_type.to_ir(self.builder)
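# e.g. a range() with 64-bit bounds now yields an i64 induction variable
# instead of truncating to i32 (exercised by test_for_iv_int64)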
# lb/ub/step might be constexpr, we need to cast them to tensor
lb = triton.language.core._to_tensor(lb, self.builder).handle
ub = triton.language.core._to_tensor(ub, self.builder).handle
step = triton.language.core._to_tensor(step, self.builder).handle
lb = lb.handle
ub = ub.handle
step = step.handle
# ForOp can only accept IndexType as lb/ub/step. Cast integer to Index
lb = self.builder.create_to_index(lb)
ub = self.builder.create_to_index(ub)
step = self.builder.create_to_index(step)
# Create placeholder for the loop induction variable
iv = self.builder.create_undef(self.builder.get_int32_ty())
self.set_value(node.target.id, triton.language.core.tensor(iv, triton.language.core.int32))
iv = self.builder.create_undef(iv_ir_type)
self.set_value(node.target.id, triton.language.core.tensor(iv, iv_type))
with enter_sub_region(self) as sr:
liveins, insert_block = sr
@@ -737,11 +763,13 @@ class CodeGenerator(ast.NodeVisitor):
# update induction variable with actual value, and replace all uses
self.builder.set_insertion_point_to_start(for_op.get_body(0))
iv = self.builder.create_index_to_si(for_op.get_induction_var())
iv = self.builder.create_int_cast(iv, iv_ir_type, True)
if negative_step:
ub_si = self.builder.create_index_to_si(ub)
ub_si = self.builder.create_int_cast(ub_si, iv_ir_type, True)
iv = self.builder.create_sub(ub_si, iv)
self.lscope[node.target.id].handle.replace_all_uses_with(iv)
self.set_value(node.target.id, triton.language.core.tensor(iv, triton.language.core.int32))
self.set_value(node.target.id, triton.language.core.tensor(iv, iv_type))
# update lscope & local_defs (ForOp defines new values)
for i, name in enumerate(names):
@@ -763,6 +791,14 @@ class CodeGenerator(ast.NodeVisitor):
def visit_keyword(self, node):
return {node.arg: self.visit(node.value)}
def visit_Assert(self, node) -> Any:
if not self.debug:
return
test = self.visit(node.test)
msg = self.visit(node.msg)
# Convert assert to triton's device_assert which happens on the device
return triton.language.core.device_assert(test, msg, _builder=self.builder)
def visit_Call(self, node):
fn = self.visit(node.func)
if isinstance(fn, triton.language.constexpr):
@@ -771,6 +807,18 @@ class CodeGenerator(ast.NodeVisitor):
for keyword in node.keywords:
kws.update(self.visit(keyword))
args = [self.visit(arg) for arg in node.args]
if fn.__name__ == "print":
fn = self.builtins["print"]
elif fn.__name__ == "device_assert":
if not self.debug:
return
elif fn.__name__ in self.static_functions:
if fn.__name__ == "static_print":
print(*args, **kws)
return
elif fn.__name__ == "static_assert":
assert args[0], args[1]
return
if isinstance(fn, triton.runtime.JITFunction):
from inspect import getcallargs
args = getcallargs(fn.fn, *args, **kws)
@@ -790,7 +838,7 @@ class CodeGenerator(ast.NodeVisitor):
if not self.module.has_function(fn_name):
prototype = triton.language.function_type([], arg_types)
gscope = sys.modules[fn.fn.__module__].__dict__
generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types)
generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types, debug=self.debug)
generator.visit(fn.parse())
callee_ret_type = generator.last_ret_type
self.function_ret_types[fn_name] = callee_ret_type
@@ -928,7 +976,7 @@ def parse_mlir_module(path, context):
return module
def build_triton_ir(fn, signature, specialization, constants):
def build_triton_ir(fn, signature, specialization, constants, debug=False):
# canonicalize signature
if isinstance(signature, str):
signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
@@ -948,7 +996,7 @@ def build_triton_ir(fn, signature, specialization, constants):
arg_types = [str_to_ty(v) for k, v in signature.items() if k not in constants]
prototype = triton.language.function_type([], arg_types)
generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, attributes=new_attrs, is_kernel=True)
generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, attributes=new_attrs, is_kernel=True, debug=debug)
try:
generator.visit(fn.parse())
except Exception as e:
@@ -975,8 +1023,8 @@ def optimize_triton_ir(mod):
return mod
def ast_to_ttir(fn, signature, specialization, constants):
mod, _ = build_triton_ir(fn, signature, specialization, constants)
def ast_to_ttir(fn, signature, specialization, constants, debug=False):
mod, _ = build_triton_ir(fn, signature, specialization, constants, debug)
return optimize_triton_ir(mod)
@@ -991,18 +1039,15 @@ def optimize_ttgir(mod, num_stages, compute_capability):
pm = _triton.ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_tritongpu_coalesce_pass()
pm.add_tritongpu_remove_layout_conversions_pass()
pm.add_tritongpu_accelerate_matmul_pass(compute_capability)
pm.add_tritongpu_remove_layout_conversions_pass()
pm.add_tritongpu_fuse_transpositions_pass()
pm.add_tritongpu_optimize_dot_operands_pass()
pm.add_tritongpu_pipeline_pass(num_stages)
pm.add_tritongpu_prefetch_pass()
pm.add_tritongpu_fuse_transpositions_pass()
pm.add_tritongpu_optimize_dot_operands_pass()
pm.add_tritongpu_remove_layout_conversions_pass()
pm.add_tritongpu_decompose_conversions_pass()
if compute_capability // 10 == 7:
# The update_mma_for_volta pass helps to compute some information for MMA encoding specifically for MMAv1
# NOTE this pass should be placed after all the passes that modify the mma layout
pm.add_tritongpu_update_mma_for_volta_pass()
pm.add_tritongpu_reorder_instructions_pass()
pm.add_cse_pass()
pm.add_symbol_dce_pass()
@@ -1718,6 +1763,7 @@ def compile(fn, **kwargs):
num_warps = kwargs.get("num_warps", 4)
num_stages = kwargs.get("num_stages", 3 if capability >= 75 else 2)
extern_libs = kwargs.get("extern_libs", dict())
debug = kwargs.get("debug", False)
# build compilation stages
if torch.version.hip is not None:
if extern_libs is None:
@@ -1736,7 +1782,7 @@ def compile(fn, **kwargs):
stages = {
"ast": (lambda path: fn, None),
"ttir": (lambda path: parse_mlir_module(path, context),
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
lambda src: ast_to_ttir(src, signature, configs[0], constants, debug)),
"ttgir": (lambda path: parse_mlir_module(path, context),
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps), num_stages, capability)),
"llir": (lambda path: Path(path).read_text(),
@@ -1750,7 +1796,7 @@ def compile(fn, **kwargs):
stages = {
"ast": (lambda path: fn, None),
"ttir": (lambda path: parse_mlir_module(path, context),
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
lambda src: ast_to_ttir(src, signature, configs[0], constants, debug)),
"ttgir": (lambda path: parse_mlir_module(path, context),
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps), num_stages, capability)),
"llir": (lambda path: Path(path).read_text(),
@@ -1804,7 +1850,8 @@ def compile(fn, **kwargs):
with open(fn_cache_manager._make_path(f"{name}.json")) as f:
metadata = json.load(f)
else:
metadata = {"num_warps": num_warps, "num_stages": num_stages, "constants": _get_jsonable_constants(constants), "ctime": dict()}
metadata = {"num_warps": num_warps, "num_stages": num_stages,
"constants": _get_jsonable_constants(constants), "ctime": dict(), "debug": debug}
if ext == "ptx":
assert "shared" in kwargs, "ptx compilation must provide shared memory size"
metadata["shared"] = kwargs["shared"]
@@ -1935,7 +1982,8 @@ class CompiledKernel:
if self.shared > max_shared:
raise OutOfResources(self.shared, max_shared, "shared memory")
mod, func, n_regs, n_spills = cuda_utils.load_binary(self.metadata["name"], self.asm["cubin"], self.shared, device)
# print(self.shared, n_regs, n_spills)
self.n_spills = n_spills
self.n_regs = n_regs
self.cu_module = mod
self.cu_function = func

View File

@@ -28,6 +28,8 @@ from .core import (
constexpr,
cos,
debug_barrier,
device_assert,
device_print,
dot,
dtype,
exp,
@@ -36,7 +38,8 @@ from .core import (
float16,
float32,
float64,
float8,
float8e4,
float8e5,
function_type,
int1,
int16,
@@ -54,7 +57,6 @@ from .core import (
num_programs,
pi32_t,
pointer_type,
printf,
program_id,
ravel,
reshape,
@@ -62,6 +64,8 @@ from .core import (
sin,
softmax,
sqrt,
static_assert,
static_print,
store,
sum,
swizzle2d,
@@ -118,6 +122,8 @@ __all__ = [
"constexpr",
"cos",
"debug_barrier",
"device_assert",
"device_print",
"dot",
"dtype",
"exp",
@@ -125,7 +131,8 @@ __all__ = [
"float16",
"float32",
"float64",
"float8",
"float8e4",
"float8e5",
"full",
"function_type",
"int1",
@@ -149,7 +156,6 @@ __all__ = [
"philox_impl",
"pi32_t",
"pointer_type",
"printf",
"program_id",
"rand",
"rand4x",
@@ -164,6 +170,8 @@ __all__ = [
"softmax",
"sqrt",
"static_range",
"static_assert",
"static_print",
"store",
"sum",
"swizzle2d",

View File

@@ -39,8 +39,7 @@ def _to_tensor(x, builder):
class dtype:
SINT_TYPES = ['int1', 'int8', 'int16', 'int32', 'int64']
UINT_TYPES = ['uint8', 'uint16', 'uint32', 'uint64']
FP_TYPES = ['fp8', 'fp16', 'bf16', 'fp32', 'fp64']
CUSTOMIZED_FP_TYPES = ['fp8']
FP_TYPES = ['fp8e4', 'fp8e5', 'fp16', 'bf16', 'fp32', 'fp64']
STANDARD_FP_TYPES = ['fp16', 'bf16', 'fp32', 'fp64']
OTHER_TYPES = ['void']
@@ -60,9 +59,12 @@ class dtype:
self.int_bitwidth = int(name.split('int')[-1])
self.primitive_bitwidth = self.int_bitwidth
elif name in dtype.FP_TYPES:
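# fp8e4 is the e4m3 layout (1 sign, 4 exponent, 3 mantissa bits);
# fp8e5 is e5m2 (1 sign, 5 exponent, 2 mantissa bits)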
if name == 'fp8':
if name == 'fp8e4':
self.fp_mantissa_width = 3
self.primitive_bitwidth = 8
elif name == 'fp8e5':
self.fp_mantissa_width = 2
self.primitive_bitwidth = 8
elif name == 'fp16':
self.fp_mantissa_width = 10
self.primitive_bitwidth = 16
@@ -75,11 +77,13 @@ class dtype:
elif name == 'fp64':
self.fp_mantissa_width = 53
self.primitive_bitwidth = 64
else:
raise RuntimeError(f'Unsupported floating-point type {name}')
elif name == 'void':
self.primitive_bitwidth = 0
def is_fp8(self):
return self.name == 'fp8'
return 'fp8' in self.name
def is_fp16(self):
return self.name == 'fp16'
@@ -123,9 +127,6 @@ class dtype:
def is_floating(self):
return self.name in dtype.FP_TYPES
def is_customized_floating(self):
return self.name in dtype.CUSTOMIZED_FP_TYPES
def is_standard_floating(self):
return self.name in dtype.STANDARD_FP_TYPES
@@ -181,8 +182,10 @@ class dtype:
return builder.get_int32_ty()
elif self.name in ('int64', 'uint64'):
return builder.get_int64_ty()
elif self.name == 'fp8':
return builder.get_fp8_ty()
elif self.name == 'fp8e5':
return builder.get_fp8e5_ty()
elif self.name == 'fp8e4':
return builder.get_fp8e4_ty()
elif self.name == 'fp16':
return builder.get_half_ty()
elif self.name == 'bf16':
@@ -314,7 +317,8 @@ uint8 = dtype('uint8')
uint16 = dtype('uint16')
uint32 = dtype('uint32')
uint64 = dtype('uint64')
float8 = dtype('fp8')
float8e5 = dtype('fp8e5')
float8e4 = dtype('fp8e4')
float16 = dtype('fp16')
bfloat16 = dtype('bf16')
float32 = dtype('fp32')
@@ -1047,7 +1051,7 @@ def where(condition, x, y, _builder=None):
If you want to avoid unintended memory operations, use the :code:`mask` arguments in `triton.load` and `triton.store` instead.
The shape of :code:`x` and :code:`y` are both broadcast to the shape of :code:`condition`.
:code:`x` and :code:`y` must have the data type.
:code:`x` and :code:`y` must have the same data type.
:param condition: When True (nonzero), yield x, otherwise yield y.
:type condition: Block of triton.bool
@@ -1349,24 +1353,56 @@ def zeros(shape, dtype):
def zeros_like(input):
return zeros(input.shape, input.dtype)
# -----------------------
# Debugging functions
# -----------------------
@builtin
def printf(prefix, *args, _builder=None):
def static_print(*values, sep: str = " ", end: str = "\n", file=None, flush=False, _builder=None):
pass
@builtin
def static_assert(cond, msg="", _builder=None):
pass
@builtin
def device_print(prefix, *args, _builder=None):
import string
new_prefix = prefix
if isinstance(prefix, constexpr):
new_prefix = prefix.value
assert isinstance(new_prefix, str), f"{new_prefix} is not string"
prefix = _constexpr_to_value(prefix)
assert isinstance(prefix, str), f"{prefix} is not a string"
b_ascii = True
for ch in new_prefix:
for ch in prefix:
if ch not in string.printable:
b_ascii = False
break
assert b_ascii, f"{new_prefix} is not an ascii string"
assert b_ascii, f"{prefix} is not an ascii string"
new_args = []
for arg in args:
new_args.append(_to_tensor(arg, _builder))
return semantic.printf(new_prefix, new_args, _builder)
return semantic.device_print(prefix, new_args, _builder)
@builtin
def device_assert(cond, msg="", _builder=None):
msg = _constexpr_to_value(msg)
import inspect
frame = inspect.currentframe()
module = inspect.getmodule(frame)
# The module of a JIT-compiled triton function has no __name__ attribute,
# so we walk up the stack until we reach such a frame to find the caller.
while hasattr(module, "__name__"):
frame = frame.f_back
module = inspect.getmodule(frame)
func_name = frame.f_code.co_name
file_name = frame.f_back.f_code.co_filename
# TODO: The line number currently indicates the line
# where the triton function is called but not where the
# device_assert is called. Need to enhance this.
lineno = frame.f_back.f_lineno
return semantic.device_assert(_to_tensor(cond, _builder), msg, file_name, func_name, lineno, _builder)
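# A minimal usage sketch for the two debugging builtins above (the kernel name
# and sizes are illustrative; device_assert only fires when debug is enabled,
# e.g. via TRITON_DEBUG=1 or @triton.jit(debug=True)):
#
#     @triton.jit(debug=True)
#     def _checked_copy(src, dst, N: tl.constexpr):
#         idx = tl.arange(0, N)
#         x = tl.load(src + idx)
#         tl.device_assert(x >= 0, "inputs must be non-negative")
#         tl.device_print("x: ", x)
#         tl.store(dst + idx, x)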
# -----------------------
# Iterators

View File

@@ -493,13 +493,22 @@ def arange(start: int, end: int, builder: ir.builder) -> tl.tensor:
def full(shape: List[int], value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
if value == 0:
_value = builder.get_null_value(dtype.to_ir(builder))
if isinstance(value, tl.tensor):
assert value.numel.value == 1, "only accepts size-1 tensor"
value = cast(value, dtype, builder)
ret_ty = tl.block_type(value.dtype, shape)
return tl.tensor(builder.create_splat(value.handle, shape), ret_ty)
else:
get_value_fn = getattr(builder, f"get_{dtype.name}")
_value = get_value_fn(value)
ret_ty = tl.block_type(dtype, shape)
return tl.tensor(builder.create_splat(_value, shape), ret_ty)
# scalar
if dtype is None:
raise ValueError("dtype must be specified when value is not a tensor")
if value == 0:
value = builder.get_null_value(dtype.to_ir(builder))
else:
get_value_fn = getattr(builder, f"get_{dtype.name}")
value = get_value_fn(value)
ret_ty = tl.block_type(dtype, shape)
return tl.tensor(builder.create_splat(value, shape), ret_ty)
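# e.g. tl.full((128,), 2, tl.int32) splats the scalar 2, while a size-1 tensor
# value is first cast to `dtype` and then splat across the block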
@@ -640,7 +649,7 @@ def bitcast(input: tl.tensor,
builder: ir.builder) -> tl.tensor:
src_ty = input.type
if src_ty.is_block():
dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes())
dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes())
if src_ty == dst_ty:
return input
src_sca_ty = src_ty.scalar
@@ -664,7 +673,7 @@ def cast(input: tl.tensor,
if isinstance(dst_ty, tl.constexpr):
dst_ty = dst_ty.value
if src_ty.is_block():
dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes())
dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes())
if src_ty == dst_ty:
return input
@@ -672,8 +681,8 @@ def cast(input: tl.tensor,
dst_sca_ty = dst_ty.scalar
# Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64
if (src_sca_ty.is_customized_floating() and dst_sca_ty.is_floating()) or \
(src_sca_ty.is_floating() and dst_sca_ty.is_customized_floating()):
if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \
(src_sca_ty.is_floating() and dst_sca_ty.is_fp8()):
return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder)),
dst_ty)
@@ -1147,6 +1156,10 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
# get result type
shape = input.type.shape
rank = len(shape)
assert 0 <= axis < rank, f"axis (v={axis}) is out of range, should be within [0, {rank})"
ret_shape = []
for i, s in enumerate(shape):
if i != axis:
@@ -1248,8 +1261,12 @@ def debug_barrier(builder: ir.builder) -> tl.tensor:
return tl.tensor(builder.create_barrier(), tl.void)
def printf(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.tensor:
def device_print(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.tensor:
new_args = []
for arg in args:
new_args.append(arg.handle)
return tl.tensor(builder.create_printf(prefix, new_args), tl.void)
return tl.tensor(builder.create_print(prefix, new_args), tl.void)
def device_assert(cond: tl.tensor, msg: str, file_name: str, func_name, lineno: int, builder: ir.builder) -> tl.tensor:
return tl.tensor(builder.create_assert(cond.handle, msg, file_name, func_name, lineno), tl.void)

View File

@@ -63,7 +63,7 @@ def _fwd_kernel(
p *= l_rcp
acc *= (l_prev * l_rcp)[:, None]
# update acc
p = p.to(tl.float16)
p = p.to(Q.dtype.element_ty)
v = tl.load(v_ptrs)
acc += tl.dot(p, v)
# update m_i and l_i
@@ -167,7 +167,7 @@ def _bwd_kernel(
p = tl.exp(qk * sm_scale - m[:, None])
# compute dv
do = tl.load(do_ptrs)
dv += tl.dot(tl.trans(p.to(tl.float16)), do)
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
# compute dp = dot(v, do)
Di = tl.load(D_ptrs + offs_m_curr)
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
@@ -175,10 +175,10 @@ def _bwd_kernel(
# compute ds = p * (dp - delta[:, None])
ds = p * dp * sm_scale
# compute dk = dot(ds.T, q)
dk += tl.dot(tl.trans(ds.to(tl.float16)), q)
dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
# compute dq
dq = tl.load(dq_ptrs)
dq += tl.dot(ds.to(tl.float16), k)
dq += tl.dot(ds.to(Q.dtype.element_ty), k)
tl.store(dq_ptrs, dq)
# increment pointers
dq_ptrs += BLOCK_M * stride_qm
@@ -198,7 +198,7 @@ class _attention(torch.autograd.Function):
# only support for Ampere now
capability = torch.cuda.get_device_capability()
if capability[0] < 8:
raise RuntimeError("Flash attention currently only supported for compute capability < 80")
raise RuntimeError("Flash attention currently only supported for compute capability >= 80")
BLOCK = 128
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]

View File

@@ -15,7 +15,7 @@ class Autotuner(KernelInterface):
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predict running time with different configs, returns running time
'top_k': number of configs to bench
'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs.
'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs.
'''
if not configs:
self.configs = [Config({}, num_warps=4, num_stages=2)]
@@ -168,7 +168,7 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
:note: When all the configurations are evaluated, the kernel will run multiple times.
This means that whatever value the kernel updates will be updated multiple times.
To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
reset the value of the provided tensor to `zero` before running any configuration.
resets the value of the provided tensor to `zero` before running any configuration.
:param configs: a list of :code:`triton.Config` objects
:type configs: list[triton.Config]
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
@@ -176,7 +176,7 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predict running time with different configs, returns running time
'top_k': number of configs to bench
'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs.
'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
:type reset_to_zero: list[str]
"""

View File

@@ -125,8 +125,6 @@ class JITFunction(KernelInterface[T]):
elif isinstance(arg, int):
if -2**31 <= arg and arg <= 2**31 - 1:
return "i32"
elif 2**31 <= arg and arg <= 2**32 - 1:
return "u32"
elif 2**63 <= arg and arg <= 2**64 - 1:
return "u64"
else:
@@ -179,7 +177,8 @@ class JITFunction(KernelInterface[T]):
triton.language.uint16: 'u16',
triton.language.uint32: 'u32',
triton.language.uint64: 'u64',
triton.language.float8: 'fp8',
triton.language.float8e5: 'fp8e5',
triton.language.float8e4: 'fp8e4',
triton.language.float16: 'fp16',
triton.language.bfloat16: 'bf16',
triton.language.float32: 'fp32',
@@ -243,7 +242,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
sig_key = {sig_keys},
constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else ()}
key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_stages)
key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_stages, self.debug)
if not extern_libs is None:
key = (key, tuple(extern_libs.items()))
assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2"
@@ -278,7 +277,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
if callable(arg):
raise TypeError(f"Callable constexpr at index {{i}} is not supported")
if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
bin = triton.compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, configs=configs)
bin = triton.compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, configs=configs, debug=self.debug)
if not warmup:
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args)
self.cache[device][key] = bin
@@ -291,7 +290,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
exec(src, scope)
return scope[self.fn.__name__]
def __init__(self, fn, version=None, do_not_specialize=None):
def __init__(self, fn, version=None, do_not_specialize=None, debug=None):
self.fn = fn
self.module = fn.__module__
self.version = version
@@ -312,6 +311,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
# when called with a grid using __getitem__
self.kernel_decorators = []
self.kernel = None
self.debug = os.environ.get("TRITON_DEBUG", "0") == "1" if debug is None else debug
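# an explicit debug= argument takes precedence; TRITON_DEBUG=1 is only the default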
# annotations
self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()}
self.__annotations__ = fn.__annotations__
@@ -380,6 +380,7 @@ def jit(
*,
version=None,
do_not_specialize: Optional[Iterable[int]] = None,
debug: Optional[bool] = None,
) -> Callable[[T], JITFunction[T]]:
...
@@ -389,6 +390,7 @@ def jit(
*,
version=None,
do_not_specialize: Optional[Iterable[int]] = None,
debug: Optional[bool] = None,
) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]:
"""
Decorator for JIT-compiling a function using the Triton compiler.
@@ -413,6 +415,7 @@ def jit(
fn,
version=version,
do_not_specialize=do_not_specialize,
debug=debug,
)
if fn is not None:

View File

@@ -454,10 +454,12 @@ def get_max_simd_tflops(dtype: torch.dtype, backend=None, device=None):
backend = _triton.runtime.backend.CUDA
if not device:
device = torch.cuda.current_device()
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz
cc = _triton.runtime.cc(backend, device)
if cc < 80:
triton.compiler.init_cuda_utils()
num_subcores = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"] * 4
clock_rate = triton.compiler.cuda_utils.get_device_properties(device)["sm_clock_rate"] # in kHz
capability = torch.cuda.get_device_capability()
if capability[0] < 8:
if dtype == torch.float32:
ops_per_sub_core = 32 # 2*16
elif dtype == torch.float16:

View File

@@ -93,9 +93,9 @@ func.func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
func.func @extract_slice(%A : !tt.ptr<f16>) {
// CHECK: %cst -> %cst
%cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED>
%index = arith.constant 0 : index
// CHECK-NEXT: %extracted_slice -> %cst
%cst1 = tensor.extract_slice %cst0[%index, 0, 0][1, 16, 16][1, 1, 1] : tensor<1x16x16xf16, #A_SHARED> to tensor<16x16xf16, #A_SHARED>
%index = arith.constant 0 : i32
// CHECK-NEXT: %0 -> %cst
%cst1 = triton_gpu.extract_slice %cst0[%index, 0, 0][1, 16, 16][1, 1, 1] : tensor<1x16x16xf16, #A_SHARED> to tensor<16x16xf16, #A_SHARED>
return
}
@@ -169,9 +169,9 @@ func.func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
// CHECK-NEXT: %0#2 -> %cst,%cst_0
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
scf.if %i1 {
%index = arith.constant 8 : index
// CHECK-NEXT: %extracted_slice -> %cst,%cst_0
%cst0 = tensor.extract_slice %a_shared[%index, 0][1, 32][1, 1] : tensor<128x32xf16, #A_SHARED> to tensor<32xf16, #A_SHARED>
%index = arith.constant 8 : i32
// CHECK-NEXT: %1 -> %cst,%cst_0
%cst0 = triton_gpu.extract_slice %a_shared[%index, 0][1, 32][1, 1] : tensor<128x32xf16, #A_SHARED> to tensor<32xf16, #A_SHARED>
scf.yield
}
scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>

View File

@@ -200,8 +200,8 @@ func.func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
func.func @extract_slice(%A : !tt.ptr<f16>) {
// CHECK: offset = 0, size = 512
%cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED>
%index = arith.constant 0 : index
%cst1 = tensor.extract_slice %cst0[%index, 0, 0][1, 16, 16][1,1,1] : tensor<1x16x16xf16, #A_SHARED> to tensor<16x16xf16, #A_SHARED>
%index = arith.constant 0 : i32
%cst1 = triton_gpu.extract_slice %cst0[%index, 0, 0][1, 16, 16][1,1,1] : tensor<1x16x16xf16, #A_SHARED> to tensor<16x16xf16, #A_SHARED>
return
// CHECK-NEXT: size = 512
}
@@ -284,8 +284,8 @@ func.func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f1
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
scf.if %i1 {
%index = arith.constant 8 : index
%cst0 = tensor.extract_slice %a_shared[%index, 0][1, 32][1, 1] : tensor<128x32xf16, #A_SHARED> to tensor<32xf16, #A_SHARED>
%index = arith.constant 8 : i32
%cst0 = triton_gpu.extract_slice %a_shared[%index, 0][1, 32][1, 1] : tensor<128x32xf16, #A_SHARED> to tensor<32xf16, #A_SHARED>
scf.yield
}
scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>

View File

@@ -109,8 +109,8 @@ func.func @alloc() {
// CHECK-LABEL: extract_slice
func.func @extract_slice() {
%cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED>
%index = arith.constant 0 : index
%0 = tensor.extract_slice %cst0[%index, 0, 0][1, 16, 16][1, 1, 1] : tensor<1x16x16xf16, #A_SHARED> to tensor<16x16xf16, #A_SHARED>
%index = arith.constant 0 : i32
%0 = triton_gpu.extract_slice %cst0[%index, 0, 0][1, 16, 16][1, 1, 1] : tensor<1x16x16xf16, #A_SHARED> to tensor<16x16xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.convert_layout
%1 = triton_gpu.convert_layout %0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>

View File

@@ -13,11 +13,13 @@ configure_lit_site_cfg(
set(TRITON_TEST_DEPENDS
triton-opt
FileCheck
)
set(FILECHECK_PATH "${LLVM_LIBRARY_DIR}/../bin/FileCheck")
set(LIT_ARGS "-Dfilecheck=${FILECHECK_PATH}")
add_lit_testsuite(check-triton-lit-tests "Running the triton regression tests"
${CMAKE_CURRENT_BINARY_DIR}
ARGS ${LIT_ARGS}
DEPENDS ${TRITON_TEST_DEPENDS}
)

View File

@@ -1,7 +1,7 @@
// RUN: (triton-opt %s -split-input-file --convert-triton-gpu-to-llvm --mlir-pass-pipeline-crash-reproducer=%t 2>/dev/null; true) | FileCheck --check-prefixes=CHECK,GCN %s
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr<f16, 1>)
// CHECK: llvm.func @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<f16, 1>)
// Here the 128 comes from the 4 in the module attribute multiplied by 32
// PTX: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = [128 : i32]} {{.*}}
func.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
@@ -700,9 +700,9 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-NEXT: llvm.mul
// CHECK-NEXT: llvm.add
// CHECK-NEXT: llvm.getelementptr
%index = arith.constant 1 : index
%index = arith.constant 1 : i32
%0 = triton_gpu.alloc_tensor : tensor<128x16x32xf32, #shared0>
%1 = tensor.extract_slice %0[%index, 0, 0][1, 16, 32][1, 1, 1] : tensor<128x16x32xf32, #shared0> to tensor<16x32xf32, #shared0>
%1 = triton_gpu.extract_slice %0[%index, 0, 0][1, 16, 32][1, 1, 1] : tensor<128x16x32xf32, #shared0> to tensor<16x32xf32, #shared0>
return
}
}
@@ -1394,14 +1394,67 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
func.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
// GCN-NOT: llvm.inline_asm
// GCN: llvm.atomicrmw fadd {{.*}} monotonic : !llvm.ptr<f32, 1>, f32
// PTX: llvm.icmp "slt"
// PTX: llvm.inline_asm
// PTX-SAME: atom.global.gpu.add.f32
// PTX-SAME: @$3 atom.global.gpu.add.f32
// PTX: llvm.inline_asm
// PTX-SAME: @$3 atom.global.gpu.add.f32
%0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32} : (tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0>
return
}
}
// -----
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: atomic_add_f32_scalar
func.func @atomic_add_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : i1, %arg2 : f32) {
// GCN-NOT: llvm.inline_asm
// GCN: llvm.atomicrmw fadd {{.*}} monotonic : !llvm.ptr<f32, 1>, f32
// PTX: llvm.icmp "eq"
// PTX: llvm.inline_asm
// PTX-SAME: @$3 atom.global.gpu.add.f32
%0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32} : (!tt.ptr<f32>, f32, i1) -> f32
return
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: store_f32
func.func @store_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xf32, #blocked0>) {
// GCN-NOT: llvm.inline_asm
// GCN: llvm.store {{.*}} : !llvm.ptr<f32, 1>
// GCN: llvm.store {{.*}} : !llvm.ptr<f32, 1>
// PTX: llvm.icmp "slt"
// PTX: llvm.inline_asm
// PTX-SAME: @$2 st.global.b32
// PTX: llvm.inline_asm
// PTX-SAME: @$2 st.global.b32
tt.store %arg0, %arg1 : tensor<256xf32, #blocked0>
return
}
}
// -----
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: store_f32_scalar
func.func @store_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : f32) {
// GCN-NOT: llvm.inline_asm
// GCN: llvm.store {{.*}} : !llvm.ptr<f32, 1>
// PTX: llvm.icmp "slt"
// PTX: llvm.inline_asm
// PTX-SAME: @$2 st.global.b32
tt.store %arg0, %arg1 : f32
return
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {

View File

@@ -55,49 +55,6 @@ func.func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
// CHECK: return %6 : tensor<1024xi32, [[$target_layout]]>
}
// CHECK-LABEL: remat_load_store
func.func @remat_load_store(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #layout0>
%1 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<64x!tt.ptr<i32>, #layout0>
%2 = tt.addptr %1, %0 : tensor<64x!tt.ptr<i32>, #layout0>, tensor<64xi32, #layout0>
%3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xi32, #layout0>
// CHECK-NOT: triton_gpu.convert_layout
%4 = triton_gpu.convert_layout %3 : (tensor<64xi32, #layout0>) -> tensor<64xi32, #layout1>
%5 = triton_gpu.convert_layout %2 : (tensor<64x!tt.ptr<i32>, #layout0>) -> tensor<64x!tt.ptr<i32>, #layout1>
tt.store %5, %4 : tensor<64xi32, #layout1>
return
}
// Don't rematerialize vectorized loads
// CHECK-LABEL: remat_expensive
func.func @remat_expensive(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #layout1>
%1 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<64x!tt.ptr<i32>, #layout1>
%2 = tt.addptr %1, %0 : tensor<64x!tt.ptr<i32>, #layout1>, tensor<64xi32, #layout1>
%3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xi32, #layout1>
// CHECK: triton_gpu.convert_layout
// CHECK-NOT: triton_gpu.convert_layout
%4 = triton_gpu.convert_layout %3 : (tensor<64xi32, #layout1>) -> tensor<64xi32, #layout0>
%5 = triton_gpu.convert_layout %2 : (tensor<64x!tt.ptr<i32>, #layout1>) -> tensor<64x!tt.ptr<i32>, #layout0>
tt.store %5, %4 : tensor<64xi32, #layout0>
return
}
// Don't rematerialize loads when original and target layouts are different
// CHECK-LABEL: remat_multi_layout
func.func @remat_multi_layout(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #layout0>
%1 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<64x!tt.ptr<i32>, #layout0>
%2 = tt.addptr %1, %0 : tensor<64x!tt.ptr<i32>, #layout0>, tensor<64xi32, #layout0>
%3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xi32, #layout0>
// CHECK: triton_gpu.convert_layout
// CHECK-NOT: triton_gpu.convert_layout
%4 = triton_gpu.convert_layout %3 : (tensor<64xi32, #layout0>) -> tensor<64xi32, #layout2>
%5 = triton_gpu.convert_layout %2 : (tensor<64x!tt.ptr<i32>, #layout0>) -> tensor<64x!tt.ptr<i32>, #layout2>
tt.store %5, %4 : tensor<64xi32, #layout2>
return
}
// Always rematerialize single value loads
// CHECK-LABEL: remat_single_value
func.func @remat_single_value(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
@@ -998,3 +955,30 @@ func.func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !
}
return
}
// Just make sure it doesn't crash on non-tensor types.
// CHECK-LABEL: if_no_tensor
func.func public @if_no_tensor(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i64> {tt.divisibility = 16 : i32}) {
// CHECK-NOT: triton_gpu.convert_layout
%c-1_i64 = arith.constant -1 : i64
%cst = arith.constant 0.000000e+00 : f32
%c-1_i32 = arith.constant -1 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = tt.addptr %arg3, %0 : !tt.ptr<i64>, i32
%2 = tt.load %1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : i64
%3 = arith.cmpi eq, %2, %c-1_i64 : i64
%4 = arith.select %3, %c-1_i32, %arg2 : i32
%5 = scf.if %3 -> (!tt.ptr<f32>) {
scf.yield %arg0 : !tt.ptr<f32>
} else {
%10 = tt.addptr %arg0, %2 : !tt.ptr<f32>, i64
scf.yield %10 : !tt.ptr<f32>
}
%6 = arith.extsi %4 : i32 to i64
%7 = arith.cmpi slt, %2, %6 : i64
%8 = tt.load %5, %7, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : f32
%9 = tt.addptr %arg1, %0 : !tt.ptr<f32>, i32
tt.store %9, %8 {cache = 1 : i32, evict = 1 : i32} : f32
return
}

View File

@@ -29,26 +29,26 @@
// CHECK-DAG: %[[LOOP_COND_1_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_1]]
// CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]], %[[LOOP_COND_1_SPLAT_B]]
// CHECK: triton_gpu.async_wait {num = 2 : i32}
// CHECK: %[[A0:.*]] = tensor.extract_slice %[[A1BUFFER]][0, 0, 0]
// CHECK: %[[B0:.*]] = tensor.extract_slice %[[B1BUFFER]][0, 0, 0]
// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]][0, 0, 0]
// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0]
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_1]]
// CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]]
// CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]]
// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op]], {{.*}}
// CHECK: %[[arg_b0_dot_op_0:.*]] = triton_gpu.convert_layout %[[arg_b0]]
// CHECK: %[[arg_b0_dot_op_1:.*]] = arith.mulf %[[arg_b0_dot_op_0]]
// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_1]], {{.*}}
// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[EXTRACT_INT:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.index_cast %[[EXTRACT_INT]] : i32 to index
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
// CHECK: triton_gpu.async_wait {num = 2 : i32}
// CHECK: %[[NEXT_A:.*]] = tensor.extract_slice %[[NEXT_A_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
// CHECK: %[[NEXT_B:.*]] = tensor.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]
func.func @matmul_loop(%lb : index, %ub : index, %step : index,
%A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
%B : !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
%B : !tt.ptr<f16> {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> {
// A ptrs
%a_ptr_splat = tt.splat %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
%a_tmp0 = tt.make_range {end = 32: i32, start = 0: i32} : tensor<32xi32, #ALs0>
@@ -71,12 +71,15 @@ func.func @matmul_loop(%lb : index, %ub : index, %step : index,
%a_off = arith.constant dense<4> : tensor<128x32xi32, #AL>
%b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
%b_scale = arith.constant dense<4.> : tensor<32x128xf16, #B>
scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
%loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
%a_ = tt.load %a_ptr {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
%a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
%b__ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
%b_ = triton_gpu.convert_layout %b__ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
%b = arith.mulf %b_, %b_scale: tensor<32x128xf16, #B>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
@@ -84,10 +87,9 @@ func.func @matmul_loop(%lb : index, %ub : index, %step : index,
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
}
return
return %loop#2: tensor<128x128xf32, #C>
}
// CHECK: func.func @matmul_loop_nested
// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32
@@ -101,27 +103,28 @@ func.func @matmul_loop(%lb : index, %ub : index, %step : index,
// CHECK: %[[A1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]]
// CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]]
// CHECK: triton_gpu.async_wait {num = 2 : i32}
// CHECK: %[[A0:.*]] = tensor.extract_slice %[[A1BUFFER]][0, 0, 0]
// CHECK: %[[B0:.*]] = tensor.extract_slice %[[B1BUFFER]][0, 0, 0]
// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]][0, 0, 0]
// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0]
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_1]]
// CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]]
// CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]]
// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op]], {{.*}}
// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[EXTRACT_INT:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.index_cast %[[EXTRACT_INT]] : i32 to index
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
// CHECK: triton_gpu.async_wait {num = 2 : i32}
// CHECK: %[[NEXT_A:.*]] = tensor.extract_slice %[[NEXT_A_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
// CHECK: %[[NEXT_B:.*]] = tensor.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]
func.func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
%A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
%B : !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
scf.for %iv0 = %lb to %ub step %step {
%B : !tt.ptr<f16> {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> {
%c_start = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C>
%loop1:1 = scf.for %iv0 = %lb to %ub step %step iter_args(%c_init = %c_start) -> (tensor<128x128xf32, #C>) {
// A ptrs
%a_ptr_splat = tt.splat %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
%a_tmp0 = tt.make_range {end = 32: i32, start = 0: i32} : tensor<32xi32, #ALs0>
@@ -139,12 +142,11 @@ func.func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
%a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL>
%b_mask = arith.constant dense<true> : tensor<32x128xi1, #BL>
%b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL>
%c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C>
%a_off = arith.constant dense<4> : tensor<128x32xi32, #AL>
%b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
%loop2:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
%a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
%a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
@@ -156,9 +158,11 @@ func.func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
}
scf.yield %loop2#2 : tensor<128x128xf32, #C>
}
return
}
return %loop1#0 : tensor<128x128xf32, #C>
}
// CHECK: func.func @matmul_loop_single_pipeline
@@ -170,22 +174,21 @@ func.func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
// CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]]
// CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]]
// CHECK: triton_gpu.async_wait {num = 1 : i32}
// CHECK: %[[B0:.*]] = tensor.extract_slice %[[B1BUFFER]][0, 0, 0]
// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0]
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_1]]
// CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]]
// CHECK: tt.dot {{.*}}, %[[arg_b0_dot_op]], {{.*}}
// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[EXTRACT_INT:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.index_cast %[[EXTRACT_INT]] : i32 to index
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
// CHECK: triton_gpu.async_wait {num = 1 : i32}
// CHECK: %[[NEXT_B:.*]] = tensor.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
// CHECK: scf.yield {{.*}}, {{.*}}, %[[NEXT_B_BUFFER]], %[[NEXT_B]], {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]
func.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
%A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
%B : !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
%B : !tt.ptr<f16> {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> {
// A ptrs
%a_ptr_splat = tt.splat %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
%a_tmp0 = tt.make_range {end = 32: i32, start = 0: i32} : tensor<32xi32, #ALs0>
@@ -211,12 +214,12 @@ func.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
%b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
scf.for %iv = %lb to %ub step %step iter_args(%b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
%loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
scf.yield %next_b_ptr, %c : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
}
return
}
return %loop#1 : tensor<128x128xf32, #C>
}

View File

@@ -12,20 +12,20 @@
// CHECK: func.func @matmul_loop
// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = tensor.extract_slice %[[A0:.*]][0, 0] [128, 16]
// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = triton_gpu.extract_slice %[[A0:.*]][0, 0] [128, 16]
// CHECK-DAG: %[[A0_PREFETCH:.*]] = triton_gpu.convert_layout %[[A0_PREFETCH_SMEM]]
// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = tensor.extract_slice %[[B0:.*]][0, 0] [16, 128]
// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = triton_gpu.extract_slice %[[B0:.*]][0, 0] [16, 128]
// CHECK-DAG: %[[B0_PREFETCH:.*]] = triton_gpu.convert_layout %[[B0_PREFETCH_SMEM]]
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, %[[a0_prefetch:.*]] = %[[A0_PREFETCH]], %[[b0_prefetch:.*]] = %[[B0_PREFETCH]]
// CHECK-DAG: %[[A_REM_SMEM:.*]] = tensor.extract_slice %[[arg_a0]][0, 16] [128, 16]
// CHECK-DAG: %[[A_REM_SMEM:.*]] = triton_gpu.extract_slice %[[arg_a0]][0, 16] [128, 16]
// CHECK-DAG: %[[A_REM:.*]] = triton_gpu.convert_layout %[[A_REM_SMEM]]
// CHECK-DAG: %[[B_REM_SMEM:.*]] = tensor.extract_slice %[[arg_b0]][16, 0] [16, 128]
// CHECK-DAG: %[[B_REM_SMEM:.*]] = triton_gpu.extract_slice %[[arg_b0]][16, 0] [16, 128]
// CHECK-DAG: %[[B_REM:.*]] = triton_gpu.convert_layout %[[B_REM_SMEM]]
// CHECK: %[[D_FIRST:.*]] = tt.dot %[[a0_prefetch]], %[[b0_prefetch:.*]], {{.*}}
// CHECK: tt.dot %[[A_REM]], %[[B_REM]], %[[D_FIRST:.*]]
// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = tensor.extract_slice {{.*}}[0, 0] [128, 16]
// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = triton_gpu.extract_slice {{.*}}[0, 0] [128, 16]
// CHECK-DAG: %[[NEXT_A_PREFETCH:.*]] = triton_gpu.convert_layout %[[NEXT_A_PREFETCH_SMEM]]
// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = tensor.extract_slice {{.*}}[0, 0] [16, 128]
// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = triton_gpu.extract_slice {{.*}}[0, 0] [16, 128]
// CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = triton_gpu.convert_layout %[[NEXT_B_PREFETCH_SMEM]]
// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH]], %[[NEXT_B_PREFETCH]]
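// Note (a summary inferred from the CHECK lines above, not from pass docs): the
// prefetcher splits each operand along K into 16-wide slices, converts the
// first slice before the loop, dots it first inside the loop body, then dots
// the remaining slice, and finally prefetches the next iteration's first slice
// to feed the scf.yield.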
func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {

View File

@@ -1,73 +0,0 @@
// RUN: triton-opt %s -split-input-file -tritongpu-fuse-transposition -tritongpu-update-mma-for-volta 2>&1 | FileCheck %s
// -----
// check the UpdateMMAVersionMinorForVolta pattern
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [4, 4], order = [1, 0]}>
#shared0 = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0]}>
#mma0 = #triton_gpu.mma<{versionMajor=1, versionMinor=0, warpsPerCTA=[4,4]}>
// Here, the isMMAv1Row flags of $a's and $b's dot_operands do not match #mma0's
// versionMinor, so the pattern should update the versionMinor.
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, isMMAv1Row=true}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, isMMAv1Row=false}>
// The pattern creates a new MMA layout that fits $a's and $b's dot_operands and gets the right warpsPerCTA.
// The ID of this MMA instance should be 0.
// CHECK: [[$new_mma:#mma.*]] = #triton_gpu.mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [4, 2]}>
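// A worked reading of the expected attribute (inferred from this test's values,
// not from pass docs): the two isMMAv1Row flags appear to land in the low bits
// of versionMinor, so forcing both operands row-major gives 0b11 = 3, while the
// pattern also recomputes warpsPerCTA (here [4, 4] -> [4, 2]) to fit the operands.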
module attributes {"triton_gpu.num-warps" = 16 : i32} {
// CHECK-LABEL: dot_mmav1
func.func @dot_mmav1(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) -> tensor<64x64xf32, #blocked0> {
%C = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked0>
%AA = triton_gpu.convert_layout %A : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_a>
%BB = triton_gpu.convert_layout %B : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_b>
%CC = triton_gpu.convert_layout %C : (tensor<64x64xf32, #blocked0>) -> tensor<64x64xf32, #mma0>
// CHECK: {{.*}} = tt.dot {{.*}}, {{.*}}, %cst {allowTF32 = true} : tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = [[$new_mma]], isMMAv1Row = true}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = [[$new_mma]], isMMAv1Row = true}>> -> tensor<64x64xf32, [[$new_mma]]>
%D = tt.dot %AA, %BB, %CC {allowTF32 = true} : tensor<64x64xf16, #dot_operand_a> * tensor<64x64xf16, #dot_operand_b> -> tensor<64x64xf32, #mma0>
%res = triton_gpu.convert_layout %D : (tensor<64x64xf32, #mma0>) -> tensor<64x64xf32, #blocked0>
return %res : tensor<64x64xf32, #blocked0>
}
}
// -----
// Check the id encoding across multiple MMA layout instances
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [4, 4], order = [1, 0]}>
#shared0 = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0]}>
#mma0 = #triton_gpu.mma<{versionMajor=1, versionMinor=0, warpsPerCTA=[4,4]}>
// An MMA with id=1 and all other boolean flags false should get a versionMinor of 16 (= 1 * (1 << 4))
#mma1 = #triton_gpu.mma<{versionMajor=1, versionMinor=16, warpsPerCTA=[4,4]}>
// We should still get two distinct MMA layouts after the update
// CHECK-DAG: [[$new_mma:#mma.*]] = #triton_gpu.mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [4, 2]}>
// CHECK-DAG: [[$new_mma1:#mma.*]] = #triton_gpu.mma<{versionMajor = 1, versionMinor = 19, warpsPerCTA = [4, 2]}>
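// Worked values (inferred from the encoding hinted at above, not from pass
// docs): with the instance id stored from bit 4 upward and the isMMAv1Row
// flags in the low bits, id=0 yields 3 (= 0b11) and id=1 yields 19
// (= (1 << 4) | 0b11), matching the two CHECK-DAG layouts.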
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, isMMAv1Row=true}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, isMMAv1Row=false}>
#dot_operand_a1 = #triton_gpu.dot_op<{opIdx=0, parent=#mma1, isMMAv1Row=true}>
#dot_operand_b1 = #triton_gpu.dot_op<{opIdx=1, parent=#mma1, isMMAv1Row=false}>
module attributes {"triton_gpu.num-warps" = 16 : i32} {
// CHECK-LABEL: dot_mmav1
func.func @dot_mmav1(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) -> tensor<64x64xf32, #blocked0> {
%C = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked0>
%AA = triton_gpu.convert_layout %A : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_a>
%BB = triton_gpu.convert_layout %B : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_b>
%CC = triton_gpu.convert_layout %C : (tensor<64x64xf32, #blocked0>) -> tensor<64x64xf32, #mma0>
%AA1 = triton_gpu.convert_layout %A : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_a1>
%BB1 = triton_gpu.convert_layout %B : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_b1>
%CC1 = triton_gpu.convert_layout %C : (tensor<64x64xf32, #blocked0>) -> tensor<64x64xf32, #mma1>
// CHECK: {{.*}} = tt.dot {{.*}}, {{.*}}, {{.*}} {allowTF32 = true} : tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = [[$new_mma]], isMMAv1Row = true}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = [[$new_mma]], isMMAv1Row = true}>> -> tensor<64x64xf32, [[$new_mma]]>
// CHECK: {{.*}} = tt.dot {{.*}}, {{.*}}, {{.*}} {allowTF32 = true} : tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = [[$new_mma1]], isMMAv1Row = true}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = [[$new_mma1]], isMMAv1Row = true}>> -> tensor<64x64xf32, [[$new_mma1]]>
%D = tt.dot %AA, %BB, %CC {allowTF32 = true} : tensor<64x64xf16, #dot_operand_a> * tensor<64x64xf16, #dot_operand_b> -> tensor<64x64xf32, #mma0>
%D1 = tt.dot %AA1, %BB1, %CC1 {allowTF32 = true} : tensor<64x64xf16, #dot_operand_a1> * tensor<64x64xf16, #dot_operand_b1> -> tensor<64x64xf32, #mma1>
%res = triton_gpu.convert_layout %D : (tensor<64x64xf32, #mma0>) -> tensor<64x64xf32, #blocked0>
%res1 = triton_gpu.convert_layout %D1 : (tensor<64x64xf32, #mma1>) -> tensor<64x64xf32, #blocked0>
%sum = arith.addf %res, %res1 : tensor<64x64xf32, #blocked0>
return %sum : tensor<64x64xf32, #blocked0>
}
}