#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "linalg-utils"
    assert(cast<AffineConstantExpr>(expr.getRHS()).getValue() > 0 &&
           "nonpositive multiplying coefficient");

  TileCheck t(tileSizes);
std::optional<RegionMatcher::BinaryOpKind>
RegionMatcher::matchAsScalarBinaryOp(GenericOp op) {
  auto &region = op.getRegion();
  if (!region.hasOneBlock())
    return std::nullopt;
  // ...
  if (addPattern.match(&ops.back()))
    return BinaryOpKind::IAdd;
  for (Range range : ranges) {
static SmallVector<int64_t>
computePackUnPackPerm(int64_t rank, ArrayRef<int64_t> &innerDimsPos,
                      ArrayRef<int64_t> &outerPerm,
                      PackingMetadata &packingMetadata) {
  int64_t numPackedDims = innerDimsPos.size();
  auto lastDims =
      llvm::to_vector(llvm::seq<int64_t>(rank - numPackedDims, rank));
  packingMetadata = computePackingMetadata(rank, innerDimsPos);
  // ...
  if (!outerPerm.empty())
    applyPermutationToVector(outerPos, outerPerm);
  // ...
  return packInverseDestPermutation;
}
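// Hedged illustration (not from the original listing): with rank = 4 and two
// packed (inner) dims, the packed dims occupy the trailing positions
// seq(rank - numPackedDims, rank) = {2, 3}; the returned vector is the
// permutation that moves them back to the positions the packing metadata
// prescribes, composed with the optional outer permutation.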
SmallVector<int64_t> getPackInverseDestPerm(linalg::PackOp packOp,
                                            PackingMetadata &metadata) {
  int64_t packedRank = packOp.getDestType().getRank();
  // ...
  return packInvDestPerm;
}

SmallVector<int64_t> getUnPackInverseSrcPerm(linalg::UnPackOp unpackOp,
                                             PackingMetadata &metadata) {
  int64_t packedRank = unpackOp.getSourceType().getRank();
  // ...
  return unpackInvSrcPerm;
}
bool allIndexingsAreProjectedPermutation(LinalgOp op) {
  return llvm::all_of(op.getIndexingMapsArray(), [](AffineMap m) {
    return m.isProjectedPermutation(/*allowZeroInResults=*/true);
  });
}
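// Hedged examples of the predicate above: (d0, d1, d2) -> (d2, d0) is a
// projected permutation (a subset of the dims, each used once), and with
// allowZeroInResults a result may also be the constant 0; by contrast
// (d0, d1) -> (d0 + d1) is rejected.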
    if (!(isa<arith::ConstantOp, func::ConstantOp, tensor::ExtractOp,
              linalg::YieldOp, linalg::IndexOp, AffineApplyOp>(op) ||
          OpTrait::hasElementwiseMappableTraits(&op)) ||
        llvm::any_of(op.getResultTypes(),
                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
      return false;
  if (op.getNumLoops() != op.getNumParallelLoops())
    return false;
  // ...
  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    if (!op.getMatchingIndexingMap(&opOperand).isPermutation())
      return false;
  }
  return true;
}

bool isParallelIterator(utils::IteratorType iteratorType) {
  return iteratorType == utils::IteratorType::parallel;
}

bool isReductionIterator(utils::IteratorType iteratorType) {
  return iteratorType == utils::IteratorType::reduction;
}
static BlockArgument getBlockArgumentWithOptionalCastOps(Value val) {
  // ...
  if (!dyn_cast_if_present<arith::ExtFOp>(defOp) &&
      !dyn_cast_if_present<arith::ExtSIOp>(defOp) &&
      !dyn_cast_if_present<arith::ExtUIOp>(defOp) &&
      !dyn_cast_if_present<arith::SIToFPOp>(defOp)) {
    // ...
  }
  return dyn_cast<BlockArgument>(defOp->getOperand(0));
}
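// Hedged example: given
//   %e = arith.extf %arg0 : f16 to f32
// the helper above looks through the cast and returns %arg0 when it is a
// block argument; a value produced by any other op yields a null
// BlockArgument.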
  if (!isa_and_present<arith::SubIOp, arith::SubFOp>(inputSubOp))
    return false;
  if (!isa_and_present<arith::SubIOp, arith::SubFOp>(filterSubOp))
    return false;
  // ...
  if (!inputBlockArg || !inputZpBlockArg || !filterBlockArg ||
      !filterZpBlockArg || !outBlockArg)
    return false;
  // ...
  if (inputBlockArg.getOwner() != body || inputZpBlockArg.getOwner() != body ||
      filterBlockArg.getOwner() != body || /* ... */)
    return false;
static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body,
                                         bool containsZeroPointOffset = false) {
  // ...
  if (!isa_and_present<arith::AddIOp, arith::AddFOp>(accOp)) {
    if (!isa_and_present<arith::OrIOp>(accOp))
      return false;
    // ...
  }
  // ...
  if (!isOrOp && !isa_and_present<arith::MulIOp, arith::MulFOp>(mulOp))
    return false;
  if (isOrOp && !isa_and_present<arith::AndIOp>(mulOp))
    return false;
  // ...
  if (containsZeroPointOffset) {
    // ...
  }
  if (!lhsBlockArg || !rhsBlockArg || !outBlockArg || /* ... */)
    return false;
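// Hedged sketch of a body the matcher above accepts (float case):
//   %mul = arith.mulf %in, %filter : f32
//   %acc = arith.addf %out, %mul : f32
//   linalg.yield %acc : f32
// For the i1 variant, arith.andi plays the role of the multiply and
// arith.ori the role of the accumulate.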
template <typename... OpTypes>
// ...
  if (!(isa_and_present<OpTypes>(defOp) || ...))
    return false;
  // ...
  if (!lhsArg || !rhsArg || lhsArg.getOwner() != body || /* ... */)
    return false;
static AffineExpr getAffineMapDim(ArrayAttr indexingMaps, uint32_t mapIndex,
                                  uint32_t dimIndex) {
  auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue();
  if (dimIndex < affineMap.getNumResults())
    return affineMap.getResult(dimIndex);
  return nullptr;
}
  if ((dim = dyn_cast<AffineDimExpr>(expr)))
    // ...
  auto mulExpr = dyn_cast<AffineBinaryOpExpr>(expr);
  // ...
  if (((dim = dyn_cast<AffineDimExpr>(lhs)) &&
       (cst = dyn_cast<AffineConstantExpr>(rhs))) ||
      ((dim = dyn_cast<AffineDimExpr>(rhs)) &&
       (cst = dyn_cast<AffineConstantExpr>(lhs))))
static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
                                       unsigned fDim, unsigned oDim,
                                       int64_t &dilation, int64_t &stride) {
  unsigned inputMapIdx = 0, filterMapIdx = 1,
           outputMapIdx = indexingMaps.size() - 1;
  // ...
  auto addExpr = dyn_cast_or_null<AffineBinaryOpExpr>(inpExpr);
  // ...
  if (c0 == -1 || c1 == -1)
    return false;
  // ...
  if (dim0 == fExpr && dim1 == oExpr) {
    // ...
  }
  if (dim1 == fExpr && dim0 == oExpr) {
    // ...
  }
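// Hedged worked example: if the input map's expression for one spatial dim is
// `d0 * 2 + d2 * 3`, with d0 the output-map dim (oDim) and d2 the filter-map
// dim (fDim), the match extracts stride = 2 and dilation = 3; the two `if`s
// above accept either operand order, making the match commutative.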
  return indexingMaps ==
         ArrayAttr::get(context,
                        llvm::map_to_vector<4>(expectedIndexingMaps,
                                               [](AffineMap m) -> Attribute {
                                                 return AffineMapAttr::get(m);
                                               }));
  ArrayAttr indexingMaps;
  // ...
  ConvMatcherBuilder(LinalgOp op, unsigned spatialRank, SmallVector<int64_t> *d,
                     SmallVector<int64_t> *s,
                     PoolingType poolingType = PoolingType::None)
      : op(op), ctx(op->getContext()), dilations(d), strides(s),
        indexingMaps(op.getIndexingMaps()), poolingType(poolingType) {
    // ...
  }

  /// Build strided expression: base * stride[idx] + kernel * dilation[idx].
  AffineExpr strided(AffineExpr base, AffineExpr kernel, unsigned idx) {
    return base * (*strides)[idx] + kernel * (*dilations)[idx];
  }

  /// Match stride/dilation pattern for a spatial dimension.
  ConvMatcherBuilder &matchStride(unsigned iDim, unsigned fDim, unsigned oDim,
                                  unsigned idx) {
    // ...
    matchConvDimAddExprPattern(indexingMaps, iDim, fDim, oDim,
                               (*dilations)[idx], (*strides)[idx]);
    // ...
    return *this;
  }
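  // Hedged worked example for strided() above: with strides = {2} and
  // dilations = {3}, strided(W, w, 0) builds the affine expression
  // `W * 2 + w * 3`, i.e. the input position read for output index W and
  // kernel tap w.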
  /// Match body pattern. This should be called last.
  bool matchBody(bool containsZeroPointOffset = false) {
    Block *body = op.getBlock();
    // ...
    switch (poolingType) {
    case PoolingType::None:
      return bodyMatcherForConvolutionOps(yieldVal, body,
                                          containsZeroPointOffset);
    // ...
    }
  }
template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::Conv1DOp>(LinalgOp op) {
  if (isa<linalg::Conv1DOp>(op)) {
    // ...
  }
  // ...
  if (m.matchStride(0, 0, 0, 0)
          .matchMaps({{m.strided(W, w, 0)},
                      // ...
}

template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::Conv1DNwcWcfOp>(LinalgOp op) {
  if (auto convOp = dyn_cast<linalg::Conv1DNwcWcfOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
    return result;
  }
  // ...
  AffineExpr N = m.dim(0);
  AffineExpr W = m.dim(1);
  AffineExpr F = m.dim(2);
  AffineExpr w = m.dim(3);
  AffineExpr c = m.dim(4);

  if (m.matchStride(1, 0, 1, 0)
          .matchMaps({{N, m.strided(W, w, 0), c},
                      // ...
}

template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::Conv1DNcwFcwOp>(LinalgOp op) {
  if (auto convOp = dyn_cast<linalg::Conv1DNcwFcwOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
    return result;
  }
  // ...
  AffineExpr N = m.dim(0);
  AffineExpr F = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr c = m.dim(3);
  AffineExpr w = m.dim(4);

  if (m.matchStride(2, 2, 2, 0)
          .matchMaps({{N, c, m.strided(W, w, 0)},
                      // ...
}
template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op) {
  if (isa<linalg::Conv2DOp>(op)) {
    DilationsAndStrides result;
    result.dilations = SmallVector<int64_t>(2, 1);
    result.strides = SmallVector<int64_t>(2, 1);
    return result;
  }
  // ...
  AffineExpr H = m.dim(0);
  AffineExpr W = m.dim(1);
  AffineExpr h = m.dim(2);
  AffineExpr w = m.dim(3);

  if (m.matchStride(0, 0, 0, 0)
          .matchStride(1, 1, 1, 1)
          .matchMaps({{m.strided(H, h, 0), m.strided(W, w, 1)},
                      // ...
}
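// Hedged illustration of the layout matched above: for a plain conv_2d with
// unit strides and dilations, the expected indexing maps are
//   input:  (H, W, h, w) -> (H + h, W + w)
//   filter: (H, W, h, w) -> (h, w)
//   output: (H, W, h, w) -> (H, W)
// where matchStride(iDim, fDim, oDim, idx) ties input dim iDim, filter dim
// fDim, and output dim oDim to stride/dilation slot idx.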
template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::Conv2DNhwcHwcfOp>(LinalgOp op) {
  if (auto convOp = dyn_cast<linalg::Conv2DNhwcHwcfOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
    return result;
  }
  // ...
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr F = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);
  AffineExpr c = m.dim(6);

  if (m.matchStride(1, 0, 1, 0)
          .matchStride(2, 1, 2, 1)
          .matchMaps({{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
                      // ...
}

template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::Conv2DNhwcHwcfQOp>(LinalgOp op) {
  if (auto convOp = dyn_cast<linalg::Conv2DNhwcHwcfQOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
    return result;
  }
  // ...
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr F = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);
  AffineExpr c = m.dim(6);

  if (m.matchStride(1, 0, 1, 0)
          .matchStride(2, 1, 2, 1)
          .matchMaps({{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
                      // ...
}

template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::Conv2DNhwcFhwcOp>(LinalgOp op) {
  if (auto convOp = dyn_cast<linalg::Conv2DNhwcFhwcOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
    return result;
  }
  // ...
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr F = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);
  AffineExpr c = m.dim(6);

  if (m.matchStride(1, 1, 1, 0)
          .matchStride(2, 2, 2, 1)
          .matchMaps({{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
                      // ...
}

template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::Conv2DNhwcFhwcQOp>(LinalgOp op) {
  if (auto convOp = dyn_cast<linalg::Conv2DNhwcFhwcQOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
    return result;
  }
  // ...
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr F = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);
  AffineExpr c = m.dim(6);

  if (m.matchStride(1, 1, 1, 0)
          .matchStride(2, 2, 2, 1)
          .matchMaps({{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
                      // ...
}

template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::Conv2DNchwFchwOp>(LinalgOp op) {
  if (auto convOp = dyn_cast<linalg::Conv2DNchwFchwOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
    return result;
  }
  // ...
  AffineExpr N = m.dim(0);
  AffineExpr F = m.dim(1);
  AffineExpr H = m.dim(2);
  AffineExpr W = m.dim(3);
  AffineExpr c = m.dim(4);
  AffineExpr h = m.dim(5);
  AffineExpr w = m.dim(6);

  if (m.matchStride(2, 2, 2, 0)
          .matchStride(3, 3, 3, 1)
          .matchMaps({{N, c, m.strided(H, h, 0), m.strided(W, w, 1)},
                      // ...
}

template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::Conv2DNchwFchwQOp>(LinalgOp op) {
  if (auto convOp = dyn_cast<linalg::Conv2DNchwFchwQOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
    return result;
  }
  // ...
  AffineExpr N = m.dim(0);
  AffineExpr F = m.dim(1);
  AffineExpr H = m.dim(2);
  AffineExpr W = m.dim(3);
  AffineExpr c = m.dim(4);
  AffineExpr h = m.dim(5);
  AffineExpr w = m.dim(6);

  if (m.matchStride(2, 2, 2, 0)
          .matchStride(3, 3, 3, 1)
          .matchMaps({{N, c, m.strided(H, h, 0), m.strided(W, w, 1)},
                      // ...
}
template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::Conv2DNgchwFgchwOp>(LinalgOp op) {
  if (auto convOp = dyn_cast<linalg::Conv2DNgchwFgchwOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
    return result;
  }
  // ...
  AffineExpr N = m.dim(0);
  AffineExpr G = m.dim(1);
  AffineExpr F = m.dim(2);
  AffineExpr H = m.dim(3);
  AffineExpr W = m.dim(4);
  AffineExpr c = m.dim(5);
  AffineExpr h = m.dim(6);
  AffineExpr w = m.dim(7);

  if (m.matchStride(3, 3, 3, 0)
          .matchStride(4, 4, 4, 1)
          .matchMaps({{N, G, c, m.strided(H, h, 0), m.strided(W, w, 1)},
                      // ...
}

template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::Conv2DNgchwGfchwOp>(LinalgOp op) {
  if (auto convOp = dyn_cast<linalg::Conv2DNgchwGfchwOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
    return result;
  }
  // ...
  AffineExpr N = m.dim(0);
  AffineExpr G = m.dim(1);
  AffineExpr F = m.dim(2);
  AffineExpr H = m.dim(3);
  AffineExpr W = m.dim(4);
  AffineExpr c = m.dim(5);
  AffineExpr h = m.dim(6);
  AffineExpr w = m.dim(7);

  if (m.matchStride(3, 3, 3, 0)
          .matchStride(4, 4, 4, 1)
          .matchMaps({{N, G, c, m.strided(H, h, 0), m.strided(W, w, 1)},
                      // ...
  return std::nullopt;
}

template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::Conv2DNgchwGfchwQOp>(LinalgOp op) {
  if (auto convOp = dyn_cast<linalg::Conv2DNgchwGfchwQOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
    return result;
  }
  if (/* ... */)
    return std::nullopt;

  AffineExpr N = m.dim(0);
  AffineExpr G = m.dim(1);
  AffineExpr F = m.dim(2);
  AffineExpr H = m.dim(3);
  AffineExpr W = m.dim(4);
  AffineExpr c = m.dim(5);
  AffineExpr h = m.dim(6);
  AffineExpr w = m.dim(7);

  if (m.matchStride(3, 3, 3, 0)
          .matchStride(4, 4, 4, 1)
          .matchMaps({{N, G, c, m.strided(H, h, 0), m.strided(W, w, 1)},
                      // ...
  return std::nullopt;
}

template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcOp>(LinalgOp op) {
  if (auto convOp = dyn_cast<linalg::Conv2DNhwgcGfhwcOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
    return result;
  }
  if (/* ... */)
    return std::nullopt;

  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr G = m.dim(3);
  AffineExpr F = m.dim(4);
  AffineExpr h = m.dim(5);
  AffineExpr w = m.dim(6);
  AffineExpr c = m.dim(7);

  if (m.matchStride(1, 2, 1, 0)
          .matchStride(2, 3, 2, 1)
          .matchMaps({{N, m.strided(H, h, 0), m.strided(W, w, 1), G, c},
                      // ...
  return std::nullopt;
}

template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcQOp>(LinalgOp op) {
  if (auto convOp = dyn_cast<linalg::Conv2DNhwgcGfhwcQOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
    return result;
  }
  if (/* ... */)
    return std::nullopt;

  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr G = m.dim(3);
  AffineExpr F = m.dim(4);
  AffineExpr h = m.dim(5);
  AffineExpr w = m.dim(6);
  AffineExpr c = m.dim(7);

  if (m.matchStride(1, 2, 1, 0)
          .matchStride(2, 3, 2, 1)
          .matchMaps({{N, m.strided(H, h, 0), m.strided(W, w, 1), G, c},
                      // ...
  return std::nullopt;
}
template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op) {
  if (isa<linalg::Conv3DOp>(op)) {
    DilationsAndStrides result;
    result.dilations = SmallVector<int64_t>(3, 1);
    result.strides = SmallVector<int64_t>(3, 1);
    return result;
  }
  if (/* ... */)
    return std::nullopt;

  AffineExpr D = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr d = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);

  if (m.matchStride(0, 0, 0, 0)
          .matchStride(1, 1, 1, 1)
          .matchStride(2, 2, 2, 2)
          .matchMaps({{m.strided(D, d, 0), m.strided(H, h, 1),
                       m.strided(W, w, 2)},
                      // ...
  return std::nullopt;
}

template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::Conv3DNdhwcDhwcfOp>(LinalgOp op) {
  if (auto convOp = dyn_cast<linalg::Conv3DNdhwcDhwcfOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
    return result;
  }
  if (/* ... */)
    return std::nullopt;

  AffineExpr N = m.dim(0);
  AffineExpr D = m.dim(1);
  AffineExpr H = m.dim(2);
  AffineExpr W = m.dim(3);
  AffineExpr F = m.dim(4);
  AffineExpr d = m.dim(5);
  AffineExpr h = m.dim(6);
  AffineExpr w = m.dim(7);
  AffineExpr c = m.dim(8);

  if (m.matchStride(1, 0, 1, 0)
          .matchStride(2, 1, 2, 1)
          .matchStride(3, 2, 3, 2)
          .matchMaps({{N, m.strided(D, d, 0), m.strided(H, h, 1),
                       m.strided(W, w, 2), c},
                      // ...
  return std::nullopt;
}

template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::Conv3DNdhwcDhwcfQOp>(LinalgOp op) {
  if (auto convOp = dyn_cast<linalg::Conv3DNdhwcDhwcfQOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
    return result;
  }
  if (/* ... */)
    return std::nullopt;

  AffineExpr N = m.dim(0);
  AffineExpr D = m.dim(1);
  AffineExpr H = m.dim(2);
  AffineExpr W = m.dim(3);
  AffineExpr F = m.dim(4);
  AffineExpr d = m.dim(5);
  AffineExpr h = m.dim(6);
  AffineExpr w = m.dim(7);
  AffineExpr c = m.dim(8);

  if (m.matchStride(1, 0, 1, 0)
          .matchStride(2, 1, 2, 1)
          .matchStride(3, 2, 3, 2)
          .matchMaps({{N, m.strided(D, d, 0), m.strided(H, h, 1),
                       m.strided(W, w, 2), c},
                      // ...
  return std::nullopt;
}

template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::Conv3DNcdhwFcdhwOp>(LinalgOp op) {
  if (auto convOp = dyn_cast<linalg::Conv3DNcdhwFcdhwOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
    return result;
  }
  if (/* ... */)
    return std::nullopt;

  AffineExpr N = m.dim(0);
  AffineExpr F = m.dim(1);
  AffineExpr D = m.dim(2);
  AffineExpr H = m.dim(3);
  AffineExpr W = m.dim(4);
  AffineExpr c = m.dim(5);
  AffineExpr d = m.dim(6);
  AffineExpr h = m.dim(7);
  AffineExpr w = m.dim(8);

  if (m.matchStride(2, 2, 2, 0)
          .matchStride(3, 3, 3, 1)
          .matchStride(4, 4, 4, 2)
          .matchMaps({{N, c, m.strided(D, d, 0),
                       m.strided(H, h, 1), m.strided(W, w, 2)},
                      // ...
  return std::nullopt;
}
template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(LinalgOp op) {
  if (auto convOp =
          dyn_cast<linalg::DepthwiseConv1DNcwCwOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
    return result;
  }
  if (/* ... */)
    return std::nullopt;

  AffineExpr N = m.dim(0);
  AffineExpr W = m.dim(1);
  AffineExpr C = m.dim(2);
  AffineExpr w = m.dim(3);

  if (m.matchStride(2, 1, 2, 0)
          .matchMaps({{N, C, m.strided(W, w, 0)},
                      // ...
  return std::nullopt;
}

template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(LinalgOp op) {
  if (auto convOp =
          dyn_cast<linalg::DepthwiseConv1DNwcWcOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
    return result;
  }
  if (/* ... */)
    return std::nullopt;

  AffineExpr N = m.dim(0);
  AffineExpr W = m.dim(1);
  AffineExpr C = m.dim(2);
  AffineExpr w = m.dim(3);

  if (m.matchStride(1, 0, 1, 0)
          .matchMaps({{N, m.strided(W, w, 0), C},
                      // ...
  return std::nullopt;
}

template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>(LinalgOp op) {
  if (auto convOp =
          dyn_cast<linalg::DepthwiseConv1DNwcWcmOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
    return result;
  }
  if (/* ... */)
    return std::nullopt;

  AffineExpr N = m.dim(0);
  AffineExpr W = m.dim(1);
  AffineExpr C = m.dim(2);
  AffineExpr CM = m.dim(3);
  AffineExpr w = m.dim(4);

  if (m.matchStride(1, 0, 1, 0)
          .matchMaps({{N, m.strided(W, w, 0), C},
                      // ...
  return std::nullopt;
}
template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(LinalgOp op) {
  if (auto convOp =
          dyn_cast<linalg::DepthwiseConv2DNchwChwOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
    return result;
  }
  if (/* ... */)
    return std::nullopt;

  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr C = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);

  if (m.matchStride(2, 1, 2, 0)
          .matchStride(3, 2, 3, 1)
          .matchMaps({{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
                      // ...
  return std::nullopt;
}

template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcOp>(LinalgOp op) {
  if (auto convOp =
          dyn_cast<linalg::DepthwiseConv2DNhwcHwcOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
    return result;
  }
  if (/* ... */)
    return std::nullopt;

  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr C = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);

  if (m.matchStride(1, 0, 1, 0)
          .matchStride(2, 1, 2, 1)
          .matchMaps({{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
                      // ...
  return std::nullopt;
}

template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcQOp>(LinalgOp op) {
  if (auto convOp =
          dyn_cast<linalg::DepthwiseConv2DNhwcHwcQOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
    return result;
  }
  if (/* ... */)
    return std::nullopt;

  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr C = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);

  if (m.matchStride(1, 0, 1, 0)
          .matchStride(2, 1, 2, 1)
          .matchMaps({{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
                      // ...
  return std::nullopt;
}

template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmOp>(LinalgOp op) {
  if (auto convOp =
          dyn_cast<linalg::DepthwiseConv2DNhwcHwcmOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
    return result;
  }
  if (/* ... */)
    return std::nullopt;

  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr C = m.dim(3);
  AffineExpr CM = m.dim(4);
  AffineExpr h = m.dim(5);
  AffineExpr w = m.dim(6);

  if (m.matchStride(1, 0, 1, 0)
          .matchStride(2, 1, 2, 1)
          .matchMaps({{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
                      // ...
  return std::nullopt;
}

template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmQOp>(LinalgOp op) {
  if (auto convOp =
          dyn_cast<linalg::DepthwiseConv2DNhwcHwcmQOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
    return result;
  }
  if (/* ... */)
    return std::nullopt;

  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr C = m.dim(3);
  AffineExpr CM = m.dim(4);
  AffineExpr h = m.dim(5);
  AffineExpr w = m.dim(6);

  if (m.matchStride(1, 0, 1, 0)
          .matchStride(2, 1, 2, 1)
          .matchMaps({{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
                      // ...
  return std::nullopt;
}
template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcOp>(LinalgOp op) {
  if (auto convOp =
          dyn_cast<linalg::DepthwiseConv3DNdhwcDhwcOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
    return result;
  }
  if (/* ... */)
    return std::nullopt;

  AffineExpr N = m.dim(0);
  AffineExpr D = m.dim(1);
  AffineExpr H = m.dim(2);
  AffineExpr W = m.dim(3);
  AffineExpr d = m.dim(4);
  AffineExpr h = m.dim(5);
  AffineExpr w = m.dim(6);
  AffineExpr C = m.dim(7);

  if (m.matchStride(1, 0, 1, 0)
          .matchStride(2, 1, 2, 1)
          .matchStride(3, 2, 3, 2)
          .matchMaps({{N, m.strided(D, d, 0), m.strided(H, h, 1),
                       m.strided(W, w, 2), C},
                      // ...
  return std::nullopt;
}

template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::DepthwiseConv3DNcdhwCdhwOp>(LinalgOp op) {
  if (auto convOp =
          dyn_cast<linalg::DepthwiseConv3DNcdhwCdhwOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
    return result;
  }
  if (/* ... */)
    return std::nullopt;

  AffineExpr N = m.dim(0);
  AffineExpr D = m.dim(1);
  AffineExpr H = m.dim(2);
  AffineExpr W = m.dim(3);
  AffineExpr d = m.dim(4);
  AffineExpr h = m.dim(5);
  AffineExpr w = m.dim(6);
  AffineExpr C = m.dim(7);

  if (m.matchStride(2, 1, 2, 0)
          .matchStride(3, 2, 3, 1)
          .matchStride(4, 3, 4, 2)
          .matchMaps({{N, C, m.strided(D, d, 0),
                       m.strided(H, h, 1), m.strided(W, w, 2)},
                      // ...
  return std::nullopt;
}

template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(LinalgOp op) {
  if (auto convOp =
          dyn_cast<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
    return result;
  }
  if (/* ... */)
    return std::nullopt;

  AffineExpr N = m.dim(0);
  AffineExpr D = m.dim(1);
  AffineExpr H = m.dim(2);
  AffineExpr W = m.dim(3);
  AffineExpr CM = m.dim(4);
  AffineExpr d = m.dim(5);
  AffineExpr h = m.dim(6);
  AffineExpr w = m.dim(7);
  AffineExpr C = m.dim(8);

  if (m.matchStride(1, 0, 1, 0)
          .matchStride(2, 1, 2, 1)
          .matchStride(3, 2, 3, 2)
          .matchMaps({{N, m.strided(D, d, 0), m.strided(H, h, 1),
                       m.strided(W, w, 2), C},
                      // ...
                      {N, D, H, W, C, CM}})
          // ...
  return std::nullopt;
}
template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(LinalgOp op) {
  if (auto poolOp = dyn_cast<linalg::PoolingNhwcMaxOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
    return result;
  }
  if (/* ... */)
    return std::nullopt;

  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr C = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);

  if (m.matchStride(1, 0, 1, 0)
          .matchStride(2, 1, 2, 1)
          .matchMaps({{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
                      // ...
  return std::nullopt;
}

template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::PoolingNhwcMinOp>(LinalgOp op) {
  if (auto poolOp = dyn_cast<linalg::PoolingNhwcMinOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
    return result;
  }
  if (/* ... */)
    return std::nullopt;

  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr C = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);

  if (m.matchStride(1, 0, 1, 0)
          .matchStride(2, 1, 2, 1)
          .matchMaps({{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
                      // ...
  return std::nullopt;
}

template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::PoolingNhwcSumOp>(LinalgOp op) {
  if (auto poolOp = dyn_cast<linalg::PoolingNhwcSumOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
    return result;
  }
  if (/* ... */)
    return std::nullopt;

  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr C = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);

  if (m.matchStride(1, 0, 1, 0)
          .matchStride(2, 1, 2, 1)
          .matchMaps({{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
                      // ...
  return std::nullopt;
}

template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(LinalgOp op) {
  if (auto poolOp =
          dyn_cast<linalg::PoolingNhwcMaxUnsignedOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
    return result;
  }
  if (/* ... */)
    return std::nullopt;

  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr C = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);

  if (m.matchStride(1, 0, 1, 0)
          .matchStride(2, 1, 2, 1)
          .matchMaps({{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
                      // ...
  return std::nullopt;
}

template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(LinalgOp op) {
  if (auto poolOp =
          dyn_cast<linalg::PoolingNhwcMinUnsignedOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
    return result;
  }
  if (/* ... */)
    return std::nullopt;

  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr C = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);

  if (m.matchStride(1, 0, 1, 0)
          .matchStride(2, 1, 2, 1)
          .matchMaps({{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
                      // ...
  return std::nullopt;
}
template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::PoolingNchwSumOp>(LinalgOp op) {
  if (auto poolOp = dyn_cast<linalg::PoolingNchwSumOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
    return result;
  }
  if (/* ... */)
    return std::nullopt;

  AffineExpr N = m.dim(0);
  AffineExpr C = m.dim(1);
  AffineExpr H = m.dim(2);
  AffineExpr W = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);

  if (m.matchStride(2, 0, 2, 0)
          .matchStride(3, 1, 3, 1)
          .matchMaps({{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
                      // ...
  return std::nullopt;
}

template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::PoolingNchwMaxOp>(LinalgOp op) {
  if (auto poolOp = dyn_cast<linalg::PoolingNchwMaxOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
    return result;
  }
  if (/* ... */)
    return std::nullopt;

  AffineExpr N = m.dim(0);
  AffineExpr C = m.dim(1);
  AffineExpr H = m.dim(2);
  AffineExpr W = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);

  if (m.matchStride(2, 0, 2, 0)
          .matchStride(3, 1, 3, 1)
          .matchMaps({{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
                      // ...
  return std::nullopt;
}
template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::PoolingNwcSumOp>(LinalgOp op) {
  if (auto poolOp = dyn_cast<linalg::PoolingNwcSumOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
    return result;
  }
  if (/* ... */)
    return std::nullopt;

  AffineExpr N = m.dim(0);
  AffineExpr W = m.dim(1);
  AffineExpr C = m.dim(2);
  AffineExpr w = m.dim(3);

  if (m.matchStride(1, 0, 1, 0)
          .matchMaps({{N, m.strided(W, w, 0), C},
                      // ...
  return std::nullopt;
}

template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::PoolingNcwSumOp>(LinalgOp op) {
  if (auto poolOp = dyn_cast<linalg::PoolingNcwSumOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
    return result;
  }
  if (/* ... */)
    return std::nullopt;

  AffineExpr N = m.dim(0);
  AffineExpr C = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr w = m.dim(3);

  if (m.matchStride(2, 0, 2, 0)
          .matchMaps({{N, C, m.strided(W, w, 0)},
                      // ...
  return std::nullopt;
}

template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::PoolingNwcMaxOp>(LinalgOp op) {
  if (auto poolOp = dyn_cast<linalg::PoolingNwcMaxOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
    return result;
  }
  if (/* ... */)
    return std::nullopt;

  AffineExpr N = m.dim(0);
  AffineExpr W = m.dim(1);
  AffineExpr C = m.dim(2);
  AffineExpr w = m.dim(3);

  if (m.matchStride(1, 0, 1, 0)
          .matchMaps({{N, m.strided(W, w, 0), C},
                      // ...
  return std::nullopt;
}

template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::PoolingNwcMaxUnsignedOp>(LinalgOp op) {
  if (auto poolOp =
          dyn_cast<linalg::PoolingNwcMaxUnsignedOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
    return result;
  }
  if (/* ... */)
    return std::nullopt;

  AffineExpr N = m.dim(0);
  AffineExpr W = m.dim(1);
  AffineExpr C = m.dim(2);
  AffineExpr w = m.dim(3);

  if (m.matchStride(1, 0, 1, 0)
          .matchMaps({{N, m.strided(W, w, 0), C},
                      // ...
  return std::nullopt;
}

template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::PoolingNcwMaxOp>(LinalgOp op) {
  if (auto poolOp = dyn_cast<linalg::PoolingNcwMaxOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
    return result;
  }
  if (/* ... */)
    return std::nullopt;

  AffineExpr N = m.dim(0);
  AffineExpr C = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr w = m.dim(3);

  if (m.matchStride(2, 0, 2, 0)
          .matchMaps({{N, C, m.strided(W, w, 0)},
                      // ...
  return std::nullopt;
}

template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::PoolingNwcMinOp>(LinalgOp op) {
  if (auto poolOp = dyn_cast<linalg::PoolingNwcMinOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
    return result;
  }
  if (/* ... */)
    return std::nullopt;

  AffineExpr N = m.dim(0);
  AffineExpr W = m.dim(1);
  AffineExpr C = m.dim(2);
  AffineExpr w = m.dim(3);

  if (m.matchStride(1, 0, 1, 0)
          .matchMaps({{N, m.strided(W, w, 0), C},
                      // ...
  return std::nullopt;
}

template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::PoolingNwcMinUnsignedOp>(LinalgOp op) {
  if (auto poolOp =
          dyn_cast<linalg::PoolingNwcMinUnsignedOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
    return result;
  }
  if (/* ... */)
    return std::nullopt;

  AffineExpr N = m.dim(0);
  AffineExpr W = m.dim(1);
  AffineExpr C = m.dim(2);
  AffineExpr w = m.dim(3);

  if (m.matchStride(1, 0, 1, 0)
          .matchMaps({{N, m.strided(W, w, 0), C},
                      // ...
  return std::nullopt;
}
template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::PoolingNdhwcSumOp>(LinalgOp op) {
  if (auto poolOp = dyn_cast<linalg::PoolingNdhwcSumOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
    return result;
  }
  if (/* ... */)
    return std::nullopt;

  AffineExpr N = m.dim(0);
  AffineExpr D = m.dim(1);
  AffineExpr H = m.dim(2);
  AffineExpr W = m.dim(3);
  AffineExpr C = m.dim(4);
  AffineExpr d = m.dim(5);
  AffineExpr h = m.dim(6);
  AffineExpr w = m.dim(7);

  if (m.matchStride(1, 0, 1, 0)
          .matchStride(2, 1, 2, 1)
          .matchStride(3, 2, 3, 2)
          .matchMaps({{N, m.strided(D, d, 0), m.strided(H, h, 1),
                       m.strided(W, w, 2), C},
                      // ...
  return std::nullopt;
}

template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::PoolingNdhwcMaxOp>(LinalgOp op) {
  if (auto poolOp = dyn_cast<linalg::PoolingNdhwcMaxOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
    return result;
  }
  if (/* ... */)
    return std::nullopt;

  AffineExpr N = m.dim(0);
  AffineExpr D = m.dim(1);
  AffineExpr H = m.dim(2);
  AffineExpr W = m.dim(3);
  AffineExpr C = m.dim(4);
  AffineExpr d = m.dim(5);
  AffineExpr h = m.dim(6);
  AffineExpr w = m.dim(7);

  if (m.matchStride(1, 0, 1, 0)
          .matchStride(2, 1, 2, 1)
          .matchStride(3, 2, 3, 2)
          .matchMaps({{N, m.strided(D, d, 0), m.strided(H, h, 1),
                       m.strided(W, w, 2), C},
                      // ...
  return std::nullopt;
}

template <>
std::optional<DilationsAndStrides>
matchConvolutionOpOfType<linalg::PoolingNdhwcMinOp>(LinalgOp op) {
  if (auto poolOp = dyn_cast<linalg::PoolingNdhwcMinOp>(op.getOperation())) {
    DilationsAndStrides result;
    result.dilations =
        llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
    result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
    return result;
  }
  if (/* ... */)
    return std::nullopt;

  AffineExpr N = m.dim(0);
  AffineExpr D = m.dim(1);
  AffineExpr H = m.dim(2);
  AffineExpr W = m.dim(3);
  AffineExpr C = m.dim(4);
  AffineExpr d = m.dim(5);
  AffineExpr h = m.dim(6);
  AffineExpr w = m.dim(7);

  if (m.matchStride(1, 0, 1, 0)
          .matchStride(2, 1, 2, 1)
          .matchStride(3, 2, 3, 2)
          .matchMaps({{N, m.strided(D, d, 0), m.strided(H, h, 1),
                       m.strided(W, w, 2), C},
                      // ...
  return std::nullopt;
}
  auto sliceOp = source.getDefiningOp<tensor::ExtractSliceOp>();
  // ...
  Value current = sliceOp.getSource();
  // ...
    OpResult opResult = cast<OpResult>(current);
    current = linalgOp.getDpsInitOperand(opResult.getResultNumber())->get();
  // ...
  auto padOp = current ? current.getDefiningOp<tensor::PadOp>() : nullptr;
  // ...
  if (sliceOp.getSource().getType() != type)
    // ...
  if (llvm::any_of(padOp.getMixedLowPad(), [](OpFoldResult ofr) {
        return getConstantIntValue(ofr) != static_cast<int64_t>(0);
      }))
    // ...
  auto padOpSliceOp = padOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
  if (!padOpSliceOp ||
      sliceOp.getMixedSizes().size() != padOpSliceOp.getMixedSizes().size())
    // ...
  if (llvm::any_of(
          llvm::zip(sliceOp.getMixedSizes(), padOpSliceOp.getMixedSizes()),
          [](std::tuple<OpFoldResult, OpFoldResult> it) {
            return !isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it));
          }))
    // ...
  Value padOpPad = padOp.getConstantPaddingValue();
  // ...
  return sliceOp.getSource();
GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
  auto memrefTypeTo = cast<MemRefType>(to.getType());
  // ...
  auto memrefTypeFrom = cast<MemRefType>(from.getType());
  assert(memrefTypeFrom.getRank() == memrefTypeTo.getRank() &&
         "`from` and `to` memref must have the same rank");
  // ...
  SmallVector<utils::IteratorType> iteratorTypes(memrefTypeTo.getRank(),
                                                 utils::IteratorType::parallel);
  return linalg::GenericOp::create(
      // ...
      [](OpBuilder &b, Location loc, ValueRange args) {
        linalg::YieldOp::create(b, loc, args.front());
      });
}
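// Hedged usage sketch (value names assumed): copies one memref into another
// of equal rank with a rank-many parallel linalg.generic whose body simply
// yields the input element.
//   Value src = ..., dst = ...;                       // memrefs, same rank
//   linalg::GenericOp copy = makeMemRefCopyOp(b, loc, src, dst);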
  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
         "expected as many entries for proc info as number of loops, even if "
         "they are null entries");
  SmallVector<Value> iterArgInitValues;
  if (!linalgOp.hasPureBufferSemantics())
    llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
  // ...
  auto loopNest = mlir::scf::buildLoopNest(
      b, loc, lbs, ubs, steps, iterArgInitValues,
      [&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) {
        assert(iterArgs.size() == iterArgInitValues.size() &&
               "expect the number of output tensors and iter args to match");
        SmallVector<Value> operandValuesToUse = linalgOp->getOperands();
        if (!iterArgs.empty()) {
          operandValuesToUse = linalgOp.getDpsInputs();
          operandValuesToUse.append(iterArgs.begin(), iterArgs.end());
        }
        return bodyBuilderFn(b, loc, ivs, operandValuesToUse);
      });

  if (loopNest.loops.empty() || procInfo.empty())
    return;

  // Map loops that correspond to distributed dimensions onto processor ids.
  for (const auto &loop : llvm::enumerate(loopNest.loops)) {
    if (procInfo[loop.index()].distributionMethod ==
        DistributionMethod::Cyclic)
      mapLoopToProcessorIds(loop.value(), procInfo[loop.index()].procId,
                            procInfo[loop.index()].nprocs);
  }
  SmallVector<Value> iterArgInitValues;
  if (!linalgOp.hasPureBufferSemantics())
    llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
  assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
  // Affine loops require constant steps.
  SmallVector<int64_t> constantSteps;
  constantSteps.reserve(steps.size());
  for (Value v : steps) {
    // ...
    assert(constVal.has_value() && "Affine loops require constant steps");
    constantSteps.push_back(constVal.value());
  }
  // ...
  bodyBuilderFn(b, loc, ivs, linalgOp->getOperands());
static void generateParallelLoopNest(
    OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs,
    ValueRange steps, ArrayRef<utils::IteratorType> iteratorTypes,
    ArrayRef<linalg::ProcInfo> procInfo,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
    SmallVectorImpl<Value> &ivStorage) {
  assert(lbs.size() == ubs.size());
  assert(lbs.size() == steps.size());
  assert(lbs.size() == iteratorTypes.size());
  assert(procInfo.empty() || (lbs.size() == procInfo.size()));

  // If there are no (more) loops to generate, emit the body.
  if (iteratorTypes.empty()) {
    bodyBuilderFn(b, loc, ivStorage);
    return;
  }
  // Sequential case: build the outermost loop, then recurse on the rest.
  // ...
      b, loc, lbs.take_front(), ubs.take_front(), steps.take_front(),
      // ...
        ivStorage.append(ivs.begin(), ivs.end());
        generateParallelLoopNest(
            b, loc, lbs.drop_front(), ubs.drop_front(), steps.drop_front(),
            iteratorTypes.drop_front(),
            procInfo.empty() ? procInfo : procInfo.drop_front(),
            bodyBuilderFn, ivStorage);
  unsigned nLoops = iteratorTypes.size();
  unsigned numProcessed = 0;
  // ...
  if (procInfo.empty()) {
    // ...
  } else {
    distributionMethod = procInfo.front().distributionMethod;
    // ...
  }
  // ...
  auto remainderProcInfo =
      procInfo.empty() ? procInfo : procInfo.drop_front(numProcessed);
  switch (distributionMethod) {
  // ...
    // Generate a single scf.parallel for the processed parallel loops, then
    // recurse for the remaining loops.
    scf::ParallelOp::create(
        b, loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
        steps.take_front(numProcessed),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
              ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
              iteratorTypes.drop_front(numProcessed), remainderProcInfo,
              bodyBuilderFn, ivStorage);
        });
    return;
  // ...
    scf::ParallelOp::create(
        b, loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
        steps.take_front(numProcessed),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
              ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
              iteratorTypes.drop_front(numProcessed), remainderProcInfo,
              bodyBuilderFn, ivStorage);
        });
    return;
  // ...
    // Check (for the processed loops) that the iteration is in bounds.
    Value cond = ab.slt(lbs[0], ubs[0]);
    for (unsigned i = 1; i < numProcessed; ++i)
      cond = ab._and(cond, ab.slt(lbs[i], ubs[i]));
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    // ...
                             ubs.drop_front(numProcessed),
                             steps.drop_front(numProcessed),
                             iteratorTypes.drop_front(numProcessed),
                             remainderProcInfo, bodyBuilderFn, ivStorage);
    return;
  // ...
    // As many processors as iterations: no loop or guard needed, the lower
    // bounds are the ivs.
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    generateParallelLoopNest(
        b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
        steps.drop_front(numProcessed), iteratorTypes.drop_front(numProcessed),
        remainderProcInfo, bodyBuilderFn, ivStorage);
  }
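// Hedged summary of the cases above: Cyclic distribution keeps an
// scf.parallel that is later mapped to processor ids;
// CyclicNumProcsGeNumIters drops the loop and guards the body with the
// `lb < ub` conjunction built into `cond`; CyclicNumProcsEqNumIters needs
// neither loop nor guard since each processor executes exactly one iteration.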
  SmallVector<Value> iterArgInitValues;
  if (!linalgOp.hasPureBufferSemantics())
    llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
  assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
  // ...
  assert(iteratorTypes.size() >= loopRanges.size() &&
         "expected iterator type for all ranges");
  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
         "expected proc information for all loops when present");
  iteratorTypes = iteratorTypes.take_front(loopRanges.size());
  // ...
  unsigned numLoops = iteratorTypes.size();
  ivs.reserve(numLoops);
  lbsStorage.reserve(numLoops);
  ubsStorage.reserve(numLoops);
  stepsStorage.reserve(numLoops);

  // Get the loop lb, ub, and step.
  unpackRanges(b, loc, loopRanges, lbsStorage, ubsStorage, stepsStorage);

  // Modify the lb, ub, and step based on the distribution options.
  for (const auto &it : llvm::enumerate(procInfo)) {
    updateBoundsForCyclicDistribution(
        b, loc, it.value().procId, it.value().nprocs, lbsStorage[it.index()],
        ubsStorage[it.index()], stepsStorage[it.index()]);
  }
  ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
  generateParallelLoopNest(
      b, loc, lbs, ubs, steps, iteratorTypes, procInfo,
      [&](OpBuilder &b, Location loc, ValueRange localIvs) {
        // ...
        bodyBuilderFn(b, loc, ivs, linalgOp->getOperands());
      },
      ivs);

  assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
  auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
  // ...
      .Case([&](MemRefType) {
        return memref::SubViewOp::create(builder, loc, valueToTile,
                                         sliceParams.offsets,
                                         sliceParams.sizes,
                                         sliceParams.strides);
      })
      .Case([&](RankedTensorType) {
        return tensor::ExtractSliceOp::create(builder, loc, valueToTile,
                                              sliceParams.offsets,
                                              sliceParams.sizes,
                                              sliceParams.strides);
      })
      .DefaultUnreachable("Unexpected shaped type");
Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
                          ArrayRef<OpFoldResult> tileSizes, AffineMap map,
                          ArrayRef<OpFoldResult> lbs,
                          ArrayRef<OpFoldResult> ubs,
                          ArrayRef<OpFoldResult> subShapeSizes,
                          bool omitPartialTileCheck) {
  SliceParameters sliceParams =
      computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
                             ubs, subShapeSizes, omitPartialTileCheck);
  return materializeTiledShape(builder, loc, valueToTile, sliceParams);
}
SliceParameters
computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
                       ArrayRef<OpFoldResult> tileSizes, AffineMap map,
                       ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
                       ArrayRef<OpFoldResult> subShapeSizes,
                       bool omitPartialTileCheck) {
  auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
  assert(shapedType && "only shaped types can be tiled");
  // ...
  int64_t rank = shapedType.getRank();

  SliceParameters sliceParams;
  sliceParams.offsets.reserve(rank);
  sliceParams.sizes.reserve(rank);
  sliceParams.strides.reserve(rank);
  for (unsigned r = 0; r < rank; ++r) {
    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: for dim#" << r);
    if (!isTiled(map.getSubMap({r}), tileSizes)) {
      // ...
      sliceParams.sizes.push_back(dim);
      // ...
      LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
      continue;
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subsize...\n");
    // ...
    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: submap: " << m << "\n");
    // ...
    [[maybe_unused]] auto res = m.constantFold(zeros, mAtZero);
    assert(succeeded(res) && "affine_map must be evaluatable (not symbols)");
    int64_t mAtZeroInt =
        cast<IntegerAttr>(mAtZero[0]).getValue().getSExtValue();
    OpFoldResult offset = makeComposedFoldedAffineApply(
        rewriter, loc, m.getResult(0) - mAtZeroInt, lbs);
    sliceParams.offsets.push_back(offset);
    // ...
    LLVM_DEBUG(llvm::dbgs()
               << "computeSliceParameters: raw size: " << size << "\n");
    LLVM_DEBUG(llvm::dbgs()
               << "computeSliceParameters: new offset: " << offset << "\n");
    // ...
    if (omitPartialTileCheck) {
      // ...
      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
      sliceParams.sizes.push_back(size);
      continue;
    }
    // ...
    auto hasTileSizeOne = sizeCst == 1;
    auto dividesEvenly = sizeCst && ShapedType::isStatic(shapeSize) &&
                         ((shapeSize % *sizeCst) == 0);
    if (!hasTileSizeOne && !dividesEvenly) {
      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: shapeSize=" << shapeSize
                              << ", size: " << size
                              << ": make sure in bound with affine.min\n");
      // ...
      bindDims(context, dim0, dim1, dim2);
      // ...
    }
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
    sliceParams.sizes.push_back(size);
  }
  return sliceParams;
}
SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc,
                                             ArrayRef<OpFoldResult> ivs,
                                             ArrayRef<OpFoldResult> tileSizes) {
  SmallVector<OpFoldResult> offsets;
  for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
    // ...
    offsets.push_back(isTiled ? ivs[idxIvs++] : b.getIndexAttr(0));
    LLVM_DEBUG(llvm::dbgs()
               << "computeTileOffsets: " << offsets.back() << "\n");
  }
  return offsets;
}
SmallVector<OpFoldResult> computeTileSizes(OpBuilder &b, Location loc,
                                           ArrayRef<OpFoldResult> tileSizes,
                                           ArrayRef<OpFoldResult> sizeBounds) {
  SmallVector<OpFoldResult> sizes;
  for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
    // ...
    LLVM_DEBUG(llvm::dbgs() << "computeTileSizes: " << sizes.back() << "\n");
  }
  return sizes;
}
SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands) {
  if (op.hasPureBufferSemantics())
    return {};
  return llvm::map_to_vector(
      op.getDpsInitsMutable(), [&](OpOperand &opOperand) {
        return operands[opOperand.getOperandNumber()].getType();
      });
}
SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
                                    LinalgOp op, ValueRange operands,
                                    ValueRange results) {
  if (op.hasPureBufferSemantics())
    return {};
  SmallVector<Value> tensorResults;
  tensorResults.reserve(results.size());
  // ...
  unsigned resultIdx = 0;
  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    // ...
    Value outputTensor = operands[opOperand.getOperandNumber()];
    if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
      tensorResults.push_back(tensor::InsertSliceOp::create(
          builder, loc, sliceOp.getSource().getType(), results[resultIdx],
          sliceOp.getSource(), sliceOp.getOffsets(), sliceOp.getSizes(),
          sliceOp.getStrides(), sliceOp.getStaticOffsets(),
          sliceOp.getStaticSizes(), sliceOp.getStaticStrides()));
    } else {
      tensorResults.push_back(results[resultIdx]);
    }
    ++resultIdx;
  }
  return tensorResults;
}
SmallVector<std::optional<SliceParameters>> computeAllSliceParameters(
    OpBuilder &builder, Location loc, LinalgOp linalgOp,
    ValueRange valuesToTile, ArrayRef<OpFoldResult> ivs,
    ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> sizeBounds,
    bool omitPartialTileCheck) {
  assert(ivs.size() ==
             static_cast<size_t>(llvm::count_if(
                 llvm::make_range(tileSizes.begin(), tileSizes.end()),
                 /* ... */)) &&
         "expected as many ivs as non-zero sizes");
  // ...
  assert(static_cast<int64_t>(valuesToTile.size()) <=
             linalgOp->getNumOperands() &&
         "more value to tile than operands.");
  SmallVector<std::optional<SliceParameters>> allSliceParams;
  allSliceParams.reserve(valuesToTile.size());
  for (auto [opOperand, val] :
       llvm::zip(linalgOp->getOpOperands(), valuesToTile)) {
    Value shapedOp = val;
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
    AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
    // ...
    Type operandType = opOperand.get().getType();
    if (!isTiled(map, tileSizes) && !(isa<RankedTensorType>(operandType) &&
                                      linalgOp.isDpsInit(&opOperand))) {
      allSliceParams.push_back(std::nullopt);
      LLVM_DEBUG(llvm::dbgs()
                 << ": not tiled: use shape: " << operandType << "\n");
      continue;
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");

    allSliceParams.push_back(computeSliceParameters(
        builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
        omitPartialTileCheck));
  }
  return allSliceParams;
}
SmallVector<Value> makeTiledShapes(OpBuilder &builder, Location loc,
                                   LinalgOp linalgOp, ValueRange valuesToTile,
                                   ArrayRef<OpFoldResult> ivs,
                                   ArrayRef<OpFoldResult> tileSizes,
                                   ArrayRef<OpFoldResult> sizeBounds,
                                   bool omitPartialTileCheck) {
  SmallVector<std::optional<SliceParameters>> allSliceParameter =
      computeAllSliceParameters(builder, loc, linalgOp, valuesToTile, ivs,
                                tileSizes, sizeBounds, omitPartialTileCheck);
  SmallVector<Value> tiledShapes;
  for (auto item : llvm::zip(valuesToTile, allSliceParameter)) {
    Value valueToTile = std::get<0>(item);
    std::optional<SliceParameters> sliceParams = std::get<1>(item);
    tiledShapes.push_back(
        sliceParams.has_value()
            // ...
            : valueToTile);
  }
  return tiledShapes;
}
void offsetIndices(OpBuilder &b, LinalgOp linalgOp,
                   ArrayRef<OpFoldResult> offsets) {
  if (!linalgOp.hasIndexSemantics())
    return;
  for (IndexOp indexOp : linalgOp.getBlock()->getOps<IndexOp>()) {
    if (indexOp.getDim() >= offsets.size() || !offsets[indexOp.getDim()])
      continue;
    // ...
    b.setInsertionPointAfter(indexOp);
    AffineExpr index, offset;
    bindDims(b.getContext(), index, offset);
    OpFoldResult applied = makeComposedFoldedAffineApply(
        b, indexOp.getLoc(), index + offset,
        {getAsOpFoldResult(indexOp.getResult()), offsets[indexOp.getDim()]});
    Value materialized =
        // ...
    b.replaceUsesWithIf(indexOp, materialized, [&](OpOperand &use) {
      // ...
    });
  }
}
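// Hedged example: a `%idx = linalg.index 1` inside a tile whose dim-1 offset
// is %off is rewritten to the equivalent of
//   affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%idx, %off)
// so index computations observe the tile's position in the original
// iteration space.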
std::optional<SmallVector<ReassociationIndices>>
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<ReassociationIndices> reassociation;
  ReassociationIndices curr;
  for (const auto &it : llvm::enumerate(mixedSizes)) {
    auto dim = it.index();
    auto size = it.value();
    curr.push_back(dim);
    auto attr = llvm::dyn_cast_if_present<Attribute>(size);
    if (attr && cast<IntegerAttr>(attr).getInt() == 1)
      continue;
    // ...
    std::swap(reassociation.back(), curr);
  }
  // Fold any trailing unit dims into the last non-unit group.
  if (!curr.empty() && !reassociation.empty())
    reassociation.back().append(curr.begin(), curr.end());
  return reassociation;
}
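// Hedged worked example: mixedSizes = [1, 4, 1, 8, 1] yields the
// reassociation [[0, 1], [2, 3, 4]]: each unit dim is grouped with the next
// non-unit dim, and trailing unit dims are appended to the last group.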
static SmallVector< int64_t > computePackUnPackPerm(int64_t rank, ArrayRef< int64_t > &innerDimsPos, ArrayRef< int64_t > &outerPerm, PackingMetadata &packingMetadata)
The permutation can be obtained from two permutations: a) Compute the permutation vector to move the ...
static bool isTiled(AffineExpr expr, ArrayRef< OpFoldResult > tileSizes)
static void unpackRanges(OpBuilder &builder, Location loc, ArrayRef< Range > ranges, SmallVectorImpl< Value > &lbs, SmallVectorImpl< Value > &ubs, SmallVectorImpl< Value > &steps)
Given a list of subview ranges, extract individual values for lower, upper bounds and steps and put t...
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
Affine binary operation expression.
AffineExpr getLHS() const
AffineExpr getRHS() const
An integer constant appearing in affine expression.
A dimensional identifier appearing in an affine expression.
unsigned getPosition() const
See documentation for AffineExprVisitorBase.
Base type for affine expression.
AffineExprKind getKind() const
Return the classification for this type.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
AffineExpr getResult(unsigned idx) const
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block * getOwner() const
Returns the block that owns this argument.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool hasOneBlock()
Return true if this region has exactly one block.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignlessIntOrFloat() const
Return true of this is a signless integer or a float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
Helper class for building convolution op matchers with minimal boilerplate.
ConvMatcherBuilder & matchStride(unsigned iDim, unsigned fDim, unsigned oDim, unsigned idx)
Match stride/dilation pattern for a spatial dimension.
bool matchBody(bool containsZeroPointOffset=false)
Match body pattern. This should be called last.
AffineExpr strided(AffineExpr base, AffineExpr kernel, unsigned idx)
Build strided expression: base * stride[idx] + kernel * dilation[idx].
AffineExpr dim(unsigned i)
Get affine dimension expression for dimension i.
ConvMatcherBuilder(LinalgOp op, unsigned spatialRank, SmallVector< int64_t > *d, SmallVector< int64_t > *s, PoolingType poolingType=PoolingType::None)
ConvMatcherBuilder & matchMaps(ArrayRef< ArrayRef< AffineExpr > > maps)
Match expected indexing maps layout. Returns *this for method chaining.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
void buildAffineLoopNest(OpBuilder &builder, Location loc, ArrayRef< int64_t > lbs, ArrayRef< int64_t > ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn=nullptr)
Builds a perfect nest of affine.for loops, i.e., each loop except the innermost one contains only ano...
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv2DNhwcHwcmQOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv2DNchwChwOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv2DNhwcHwcOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv3DNdhwcDhwcmOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNcwMaxOp >(LinalgOp op)
SmallVector< int64_t > getUnPackInverseSrcPerm(linalg::UnPackOp, PackingMetadata &metadata)
Compute inverse permutation for the source tensor (i.e.
SmallVector< Value > makeTiledShapes(OpBuilder &builder, Location loc, LinalgOp linalgOp, ValueRange valuesToTile, ArrayRef< OpFoldResult > ivs, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > sizeBounds, bool omitPartialTileCheck)
Creates extract_slice/subview ops for all valuesToTile of the given linalgOp with builder,...
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNwcMinOp >(LinalgOp op)
bool allIndexingsAreProjectedPermutation(LinalgOp op)
Check if all indexing maps are projected permutations.
static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body, bool containsZeroPointOffset=false)
Utility to match block body for convolution ops.
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv1DNwcWcmOp >(LinalgOp op)
bool isParallelIterator(utils::IteratorType iteratorType)
Check if iterator type has "parallel" semantics.
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv1DNcwFcwOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNchwSumOp >(LinalgOp op)
SmallVector< OpFoldResult > computeTileSizes(OpBuilder &b, Location loc, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > sizeBounds)
Computes tile sizes, given a list of tileSizes and dimension sizes (sizeBounds).
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNhwcMaxUnsignedOp >(LinalgOp op)
GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to)
Returns GenericOp that copies an n-D memref.
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNhwgcGfhwcOp >(LinalgOp op)
static void generateParallelLoopNest(OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ArrayRef< utils::IteratorType > iteratorTypes, ArrayRef< linalg::ProcInfo > procInfo, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn, SmallVectorImpl< Value > &ivStorage)
Generates a loop nest consisting of scf.parallel and scf.for, depending on the iteratorTypes.
SmallVector< OpFoldResult > computeTileOffsets(OpBuilder &b, Location loc, ArrayRef< OpFoldResult > ivs, ArrayRef< OpFoldResult > tileSizes)
Computes tile offsets, given a list of loop ivs and tileSizes.
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNgchwGfchwQOp >(LinalgOp op)
PoolingType
Enum representing pooling operation types used by ConvMatcherBuilder.
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv1DNcwCwOp >(LinalgOp op)
static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body)
static bool bodyMatcherForZeroPointOffsets(Operation *addOp, Operation *mulOp, Block *body)
Utility function to match the zero point offset body of quantized convolution ops.
static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv2DNhwcHwcmOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv1DNwcWcOp >(LinalgOp op)
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
bool hasOnlyScalarElementwiseOp(Region &r)
Detect whether r has only ConstantOp, ElementwiseMappable and YieldOp.
static AffineExpr getAffineMapDim(ArrayAttr indexingMaps, uint32_t mapIndex, uint32_t dimIndex)
static BlockArgument getBlockArgumentWithOptionalCastOps(Value val)
Returns the BlockArgument that leads to val, if any.
static bool bodyMatcherForPoolOps(Value yieldVal, Block *body)
Utility to match block body for linalg.pool* ops.
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNchwFchwQOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv1DNwcWcfOp >(LinalgOp op)
std::optional< SmallVector< ReassociationIndices > > getReassociationMapForFoldingUnitDims(ArrayRef< OpFoldResult > mixedSizes)
Get the reassociation maps to fold the result of a extract_slice (or source of a insert_slice) operat...
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNwcMinUnsignedOp >(LinalgOp op)
DistributionMethod
Scheme used to distribute loops to processors.
@ CyclicNumProcsGeNumIters
Cyclic distribution where the number of processors can be assumed to be more than or equal to the num...
@ Cyclic
Cyclic distribution where no assumption is made about the dynamic relationship between number of proc...
@ CyclicNumProcsEqNumIters
Cyclic distribution where the number of processors can be assumed to be equal to the number of iterat...
static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body)
SmallVector< Value > insertSlicesBack(OpBuilder &builder, Location loc, LinalgOp op, ValueRange operands, ValueRange results)
Creates insert_slice ops that insert results back into the larger tensors they were originally extracted from.
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv3DNcdhwCdhwOp >(LinalgOp op)
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv3DNdhwcDhwcOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNwcMaxUnsignedOp >(LinalgOp op)
void offsetIndices(OpBuilder &b, LinalgOp linalgOp, ArrayRef< OpFoldResult > offests)
Add the specified offsets to any linalg.index ops contained in the given linalgOp.
static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body)
Matches the sum-pooling body pattern.
SmallVector< int64_t > getPackInverseDestPerm(linalg::PackOp packOp, PackingMetadata &metadata)
Compute the inverse permutation for the destination tensor (i.e., the packed tensor) of a linalg.pack op.
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNhwcHwcfOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv3DNcdhwFcdhwOp >(LinalgOp op)
SmallVector< std::optional< SliceParameters > > computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp, ValueRange valuesToTile, ArrayRef< OpFoldResult > ivs, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > sizeBounds, bool omitPartialTileCheck)
Computes SliceParameters for all valuesToTile of the given linalgOp, assuming linalgOp is being fused into a loop nest.
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNhwcHwcfQOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNchwFchwOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNhwcFhwcQOp >(LinalgOp op)
Operation * makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile, ArrayRef< OpFoldResult > tileSizes, AffineMap map, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > ubs, ArrayRef< OpFoldResult > subShapeSizes, bool omitPartialTileCheck)
Creates an extract_slice/subview op for a single valueToTile with builder.
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv3DOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNchwMaxOp >(LinalgOp op)
static bool convLayoutMatches(ArrayRef< ArrayRef< AffineExpr > > mapListExpected, ArrayAttr indexingMaps, MLIRContext *context)
Returns true if the given indexing maps match the expected indexing maps.
static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body)
static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, unsigned fDim, unsigned oDim, int64_t &dilation, int64_t &stride)
Given an array of AffineMaps indexingMaps, verify commutatively that indexingMaps[0].getResult(iDim) equals indexingMaps[1].getResult(fDim) * dilation + indexingMaps[2].getResult(oDim) * stride, accepting either operand order.
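For intuition, a standalone sketch of the kind of expression this matcher accepts (the helper name is hypothetical): the convolved input dimension is an add of the filter dimension scaled by the dilation and the output dimension scaled by the stride.

  #include "mlir/IR/AffineExpr.h"
  #include "mlir/IR/AffineMap.h"
  using namespace mlir;

  AffineMap buildConvInputMap(MLIRContext *ctx) {
    AffineExpr dF, dO;
    bindDims(ctx, dF, dO); // dF: filter dim, dO: output dim
    // 1-D conv access with stride 2 and dilation 3; the matcher also
    // accepts the operands of the add (and of each mul) swapped, and
    // recovers stride = 2, dilation = 3 via isDimTimesConstantOrDimOnly.
    AffineExpr inputExpr = dO * 2 + dF * 3;
    return AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, inputExpr);
  }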
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNhwcSumOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv3DNdhwcDhwcfQOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNhwgcGfhwcQOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNhwcMinOp >(LinalgOp op)
static Operation * materializeTiledShape(OpBuilder &builder, Location loc, Value valueToTile, const SliceParameters &sliceParams)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNgchwFgchwOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNgchwGfchwOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNhwcMaxOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNdhwcMinOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNhwcMinUnsignedOp >(LinalgOp op)
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, Value source, Value padding, bool nofold, ValueRange typeDynDims={})
Create a tensor::PadOp that pads source to the shape of type whose sizes are assumed to be greater than the sizes of source.
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv3DNdhwcDhwcfOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNwcSumOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv1DOp >(LinalgOp op)
static int64_t isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim)
Check if expr is either a bare AffineDimExpr, or an AffineDimExpr multiplied by an AffineConstantExpr (in either operand order); returns the constant factor (1 for a bare dim, -1 if neither form matches) and sets dim.
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNdhwcSumOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv2DNhwcHwcQOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNhwcFhwcOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNwcMaxOp >(LinalgOp op)
void updateBoundsForCyclicDistribution(OpBuilder &builder, Location loc, Value procId, Value nprocs, Value &lb, Value &ub, Value &step)
Update lb, ub and step to get the per-processor lb, ub and step.
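The cyclic mapping itself is simple arithmetic; a minimal sketch assuming plain arith ops and index-typed Values procId, nprocs, lb and step in scope (the in-tree implementation may build the equivalent affine expressions instead):

  // Per-processor bounds for cyclic distribution:
  //   lb'   = lb + procId * step
  //   step' = nprocs * step
  Value offset = b.create<arith::MulIOp>(loc, procId, step);
  lb = b.create<arith::AddIOp>(loc, lb, offset);
  step = b.create<arith::MulIOp>(loc, nprocs, step);
  // ub is unchanged; under CyclicNumProcsGeNumIters each processor
  // executes at most one iteration.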
SmallVector< Type > getTensorOutputTypes(LinalgOp op, ValueRange operands)
Returns the list of tensor output types produced when the given structured operation op is applied to the given operands.
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNcwSumOp >(LinalgOp op)
SliceParameters computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile, ArrayRef< OpFoldResult > tileSizes, AffineMap map, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > ubs, ArrayRef< OpFoldResult > subShapeSizes, bool omitPartialTileCheck)
Computes SliceParameters for a single valueToTile assuming that its user is being tiled with the given tileSizes.
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNdhwcMaxOp >(LinalgOp op)
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e. all loops but the innermost contain only another loop and a terminator.
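A hypothetical use, accumulating an index-typed iter_arg over a 2-D nest (b, loc, the bounds lb0/lb1, ub0/ub1, steps s0/s1 and an index-typed init are assumed in scope, with the usual SCF/Arith includes):

  scf::LoopNest nest = scf::buildLoopNest(
      b, loc, /*lbs=*/ValueRange{lb0, lb1}, /*ubs=*/ValueRange{ub0, ub1},
      /*steps=*/ValueRange{s0, s1}, /*iterArgs=*/ValueRange{init},
      [](OpBuilder &nb, Location nl, ValueRange ivs,
         ValueRange args) -> scf::ValueVector {
        // Innermost body: fold the outer iv into the accumulator.
        Value next = nb.create<arith::AddIOp>(nl, args[0], ivs[0]);
        return {next};
      });
  Value total = nest.loops.front().getResult(0);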
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
PadOp createPadHighOp(RankedTensorType resType, Value source, Value pad, bool nofold, Location loc, OpBuilder &builder, ValueRange dynOutDims={})
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
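For example, classifying a hypothetical Value v with the matchers listed on this page:

  bool isCst = matchPattern(v, m_Constant());          // constant-foldable op
  bool isAdd = matchPattern(v, m_Op<arith::AddIOp>()); // result of arith.addi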
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
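A small fragment (ofr is a hypothetical OpFoldResult in scope) showing how a static value is recovered; the zero test is the same condition isZeroInteger below checks:

  if (std::optional<int64_t> cst = getConstantIntValue(ofr)) {
    bool isZero = (*cst == 0);
    (void)isZero;
  }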
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions [0 .. N), one per expression.
detail::NameOpMatcher m_Op(StringRef opName)
Matches a named operation.
@ Mul
RHS of mul is always a constant or a symbolic expression.
SmallVector< int64_t > computePermutationVector(int64_t permSize, ArrayRef< int64_t > positions, ArrayRef< int64_t > desiredPositions)
Return a permutation vector of size permSize that would result in moving positions into desiredPositions.
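A hypothetical pairing with applyPermutationToVector (listed below): build a permutation that moves position 2 of a rank-4 ordering to the front, then reorder a shape vector with it.

  SmallVector<int64_t> perm = computePermutationVector(
      /*permSize=*/4, /*positions=*/{2}, /*desiredPositions=*/{0});
  SmallVector<int64_t> shape = {2, 3, 5, 7};
  applyPermutationToVector(shape, perm); // reorders shape according to perm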
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
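A one-line illustration (b and loc assumed in scope): an Attribute is materialized as a constant op, while a Value passes through unchanged.

  OpFoldResult ofr = b.getIndexAttr(4);
  Value v = getValueOrCreateConstantIndexOp(b, loc, ofr); // constant 4 : index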
detail::op_matcher< OpClass > m_Op()
Matches the given OpClass.
SmallVector< int64_t, 2 > ReassociationIndices
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to avoid using the classes in the detail namespace.
llvm::function_ref< Fn > function_ref
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
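For example, building the map (d0)[s0] -> (d0 * 4 + s0) with these free functions (ctx is an assumed MLIRContext pointer):

  AffineExpr d0 = getAffineDimExpr(0, ctx);
  AffineExpr s0 = getAffineSymbolExpr(0, ctx);
  AffineMap map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/1,
                                 {d0 * 4 + s0}, ctx);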
Helper struct to build simple arithmetic quantities with minimal type inference support.
Value _and(Value lhs, Value rhs)
Value slt(Value lhs, Value rhs)
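A hedged fragment using this helper struct (b, loc and index-typed Values i, j, n, m are assumed in scope) to build the guard (i < n) && (j < m):

  ArithBuilder ab(b, loc);
  Value guard = ab._and(ab.slt(i, n), ab.slt(j, m));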
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or static.
A struct containing dilations and strides inferred from convolution ops.
Utility class used to generate nested loops with ranges described by loopRanges and loop type described by iteratorTypes.
static void doit(OpBuilder &b, Location loc, ArrayRef< Range > loopRanges, LinalgOp linalgOp, ArrayRef< utils::IteratorType > iteratorTypes, function_ref< scf::ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilderFn, ArrayRef< linalg::ProcInfo > procInfo={})
Callback function type used to get processor ID, and number of processors used for distribution for all parallel loops generated.
DistributionMethod distributionMethod
static std::optional< BinaryOpKind > matchAsScalarBinaryOp(GenericOp op)
Matches the given linalg op if its body is performing a binary operation on int or float scalar values.
A struct containing offsets-sizes-strides arguments of the tiled shape.
SmallVector< OpFoldResult > strides
SmallVector< OpFoldResult > sizes
SmallVector< OpFoldResult > offsets
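To tie the struct to its producer and a typical consumer, a hedged sketch (all arguments assumed in scope; see computeSliceParameters above — this is an illustration, not the in-tree tiling path):

  SliceParameters params =
      computeSliceParameters(b, loc, valueToTile, tileSizes, map, lbs, ubs,
                             subShapeSizes, /*omitPartialTileCheck=*/false);
  // The three fields feed directly into a slice op on the tensor path.
  Value slice = b.create<tensor::ExtractSliceOp>(
      loc, valueToTile, params.offsets, params.sizes, params.strides);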