33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/Support/Debug.h"
37#define DEBUG_TYPE "linalg-utils"
65 assert(cast<AffineConstantExpr>(expr.
getRHS()).getValue() > 0 &&
66 "nonpositive multiplying coefficient");
77 TileCheck t(tileSizes);
92std::optional<RegionMatcher::BinaryOpKind>
94 auto ®ion = op.getRegion();
95 if (!region.hasOneBlock())
113 if (addPattern.match(&ops.back()))
130 for (
Range range : ranges) {
149static SmallVector<int64_t>
152 PackingMetadata &packingMetadata) {
153 int64_t numPackedDims = innerDimsPos.size();
155 llvm::to_vector(llvm::seq<int64_t>(rank - numPackedDims, rank));
156 packingMetadata = computePackingMetadata(rank, innerDimsPos);
161 if (!outerPerm.empty())
168 return packInverseDestPermutation;
175 PackingMetadata &metadata) {
177 int64_t packedRank = packOp.getDestType().getRank();
182 return packInvDestPerm;
186 PackingMetadata &metadata) {
187 int64_t packedRank = unpackOp.getSourceType().getRank();
192 return unpackInvSrcPerm;
196 return llvm::all_of(op.getIndexingMapsArray(), [](
AffineMap m) {
197 return m.isProjectedPermutation(true);
205 if (!(isa<arith::ConstantOp, func::ConstantOp, tensor::ExtractOp,
206 linalg::YieldOp, linalg::IndexOp, AffineApplyOp>(op) ||
208 llvm::any_of(op.getResultTypes(),
209 [](
Type type) { return !type.isIntOrIndexOrFloat(); }))
216 if (op.getNumLoops() != op.getNumParallelLoops())
223 for (
OpOperand &opOperand : op.getDpsInitsMutable()) {
224 if (!op.getMatchingIndexingMap(&opOperand).isPermutation())
231 return iteratorType == utils::IteratorType::parallel;
235 return iteratorType == utils::IteratorType::reduction;
250 if (!dyn_cast_if_present<arith::ExtFOp>(defOp) &&
251 !dyn_cast_if_present<arith::ExtSIOp>(defOp) &&
252 !dyn_cast_if_present<arith::ExtUIOp>(defOp) &&
253 !dyn_cast_if_present<arith::SIToFPOp>(defOp)) {
256 return dyn_cast<BlockArgument>(defOp->
getOperand(0));
278 if (!isa_and_present<arith::SubIOp, arith::SubFOp>(inputSubOp))
282 if (!isa_and_present<arith::SubIOp, arith::SubFOp>(filterSubOp))
298 if (!inputBlockArg || !inputZpBlockArg || !filterBlockArg ||
299 !filterZpBlockArg || !outBlockArg)
303 if (inputBlockArg.
getOwner() != body || inputZpBlockArg.
getOwner() != body ||
304 filterBlockArg.
getOwner() != body ||
330 bool containsZeroPointOffset =
false) {
333 if (!isa_and_present<arith::AddIOp, arith::AddFOp>(accOp)) {
334 if (!isa_and_present<arith::OrIOp>(accOp))
340 if (!isOrOp && !isa_and_present<arith::MulIOp, arith::MulFOp>(mulOp))
342 if (isOrOp && !isa_and_present<arith::AndIOp>(mulOp))
345 if (containsZeroPointOffset) {
354 if (!lhsBlockArg || !rhsBlockArg || !outBlockArg ||
363template <
typename... OpTypes>
366 if (!(isa_and_present<OpTypes>(defOp) || ...))
373 if (!lhsArg || !rhsArg || lhsArg.
getOwner() != body ||
407 auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue();
408 if (dimIndex < affineMap.getNumResults())
409 return affineMap.getResult(dimIndex);
419 if ((dim = dyn_cast<AffineDimExpr>(expr)))
422 auto mulExpr = dyn_cast<AffineBinaryOpExpr>(expr);
430 if (((dim = dyn_cast<AffineDimExpr>(
lhs)) &&
431 (cst = dyn_cast<AffineConstantExpr>(
rhs))) ||
432 ((dim = dyn_cast<AffineDimExpr>(
rhs)) &&
433 (cst = dyn_cast<AffineConstantExpr>(
lhs))))
461 unsigned fDim,
unsigned oDim,
463 unsigned inputMapIdx = 0, filterMapIdx = 1,
464 outputMapIdx = indexingMaps.size() - 1;
466 auto addExpr = dyn_cast_or_null<AffineBinaryOpExpr>(inpExpr);
474 if (c0 == -1 || c1 == -1)
479 if (dim0 == fExpr && dim1 == oExpr) {
484 if (dim1 == fExpr && dim0 == oExpr) {
498 return indexingMaps ==
500 context, llvm::to_vector<4>(llvm::map_range(
502 return AffineMapAttr::get(m);
537 ArrayAttr indexingMaps;
545 : op(op), ctx(op->
getContext()), dilations(d), strides(s),
546 indexingMaps(op.getIndexingMaps()), poolingType(poolingType) {
556 return base * (*strides)[idx] + kernel * (*dilations)[idx];
565 (*dilations)[idx], (*strides)[idx]);
581 Block *body = op.getBlock();
583 switch (poolingType) {
586 containsZeroPointOffset);
607std::optional<DilationsAndStrides>
610 if (isa<linalg::Conv1DOp>(op)) {
625 if (m.matchStride(0, 0, 0, 0)
626 .matchMaps({{m.strided(W, w, 0)},
635std::optional<DilationsAndStrides>
638 if (
auto convOp = dyn_cast<linalg::Conv1DNwcWcfOp>(op.getOperation())) {
640 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
641 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
650 AffineExpr N = m.dim(0);
651 AffineExpr
W = m.dim(1);
652 AffineExpr F = m.dim(2);
653 AffineExpr w = m.dim(3);
654 AffineExpr c = m.dim(4);
656 if (m.matchStride(1, 0, 1, 0)
657 .matchMaps({{N, m.strided(W, w, 0), c},
666std::optional<DilationsAndStrides>
669 if (
auto convOp = dyn_cast<linalg::Conv1DNcwFcwOp>(op.getOperation())) {
671 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
672 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
681 AffineExpr N = m.dim(0);
682 AffineExpr F = m.dim(1);
683 AffineExpr
W = m.dim(2);
684 AffineExpr c = m.dim(3);
685 AffineExpr w = m.dim(4);
687 if (m.matchStride(2, 2, 2, 0)
688 .matchMaps({{N, c, m.strided(W, w, 0)},
697std::optional<DilationsAndStrides>
700 if (isa<linalg::Conv2DOp>(op)) {
702 result.dilations = SmallVector<int64_t>(2, 1);
703 result.strides = SmallVector<int64_t>(2, 1);
712 AffineExpr H = m.dim(0);
713 AffineExpr
W = m.dim(1);
714 AffineExpr h = m.dim(2);
715 AffineExpr w = m.dim(3);
717 if (m.matchStride(0, 0, 0, 0)
718 .matchStride(1, 1, 1, 1)
719 .matchMaps({{m.strided(H, h, 0), m.strided(W, w, 1)},
728std::optional<DilationsAndStrides>
731 if (
auto convOp = dyn_cast<linalg::Conv2DNhwcHwcfOp>(op.getOperation())) {
733 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
734 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
743 AffineExpr N = m.dim(0);
744 AffineExpr H = m.dim(1);
745 AffineExpr
W = m.dim(2);
746 AffineExpr F = m.dim(3);
747 AffineExpr h = m.dim(4);
748 AffineExpr w = m.dim(5);
749 AffineExpr c = m.dim(6);
751 if (m.matchStride(1, 0, 1, 0)
752 .matchStride(2, 1, 2, 1)
754 {{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
763std::optional<DilationsAndStrides>
766 if (
auto convOp = dyn_cast<linalg::Conv2DNhwcHwcfQOp>(op.getOperation())) {
768 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
769 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
778 AffineExpr N = m.dim(0);
779 AffineExpr H = m.dim(1);
780 AffineExpr
W = m.dim(2);
781 AffineExpr F = m.dim(3);
782 AffineExpr h = m.dim(4);
783 AffineExpr w = m.dim(5);
784 AffineExpr c = m.dim(6);
786 if (m.matchStride(1, 0, 1, 0)
787 .matchStride(2, 1, 2, 1)
789 {{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
800std::optional<DilationsAndStrides>
803 if (
auto convOp = dyn_cast<linalg::Conv2DNhwcFhwcOp>(op.getOperation())) {
805 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
806 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
815 AffineExpr N = m.dim(0);
816 AffineExpr H = m.dim(1);
817 AffineExpr
W = m.dim(2);
818 AffineExpr F = m.dim(3);
819 AffineExpr h = m.dim(4);
820 AffineExpr w = m.dim(5);
821 AffineExpr c = m.dim(6);
823 if (m.matchStride(1, 1, 1, 0)
824 .matchStride(2, 2, 2, 1)
826 {{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
835std::optional<DilationsAndStrides>
838 if (
auto convOp = dyn_cast<linalg::Conv2DNhwcFhwcQOp>(op.getOperation())) {
840 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
841 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
850 AffineExpr N = m.dim(0);
851 AffineExpr H = m.dim(1);
852 AffineExpr
W = m.dim(2);
853 AffineExpr F = m.dim(3);
854 AffineExpr h = m.dim(4);
855 AffineExpr w = m.dim(5);
856 AffineExpr c = m.dim(6);
858 if (m.matchStride(1, 1, 1, 0)
859 .matchStride(2, 2, 2, 1)
861 {{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
872std::optional<DilationsAndStrides>
875 if (
auto convOp = dyn_cast<linalg::Conv2DNchwFchwOp>(op.getOperation())) {
877 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
878 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
887 AffineExpr N = m.dim(0);
888 AffineExpr F = m.dim(1);
889 AffineExpr H = m.dim(2);
890 AffineExpr
W = m.dim(3);
891 AffineExpr c = m.dim(4);
892 AffineExpr h = m.dim(5);
893 AffineExpr w = m.dim(6);
895 if (m.matchStride(2, 2, 2, 0)
896 .matchStride(3, 3, 3, 1)
898 {{N, c, m.strided(H, h, 0), m.strided(W, w, 1)},
907std::optional<DilationsAndStrides>
910 if (
auto convOp = dyn_cast<linalg::Conv2DNchwFchwQOp>(op.getOperation())) {
912 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
913 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
922 AffineExpr N = m.dim(0);
923 AffineExpr F = m.dim(1);
924 AffineExpr H = m.dim(2);
925 AffineExpr
W = m.dim(3);
926 AffineExpr c = m.dim(4);
927 AffineExpr h = m.dim(5);
928 AffineExpr w = m.dim(6);
930 if (m.matchStride(2, 2, 2, 0)
931 .matchStride(3, 3, 3, 1)
933 {{N, c, m.strided(H, h, 0), m.strided(W, w, 1)},
944std::optional<DilationsAndStrides>
947 if (
auto convOp = dyn_cast<linalg::Conv2DNgchwFgchwOp>(op.getOperation())) {
949 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
950 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
959 AffineExpr N = m.dim(0);
960 AffineExpr G = m.dim(1);
961 AffineExpr F = m.dim(2);
962 AffineExpr H = m.dim(3);
963 AffineExpr
W = m.dim(4);
964 AffineExpr c = m.dim(5);
965 AffineExpr h = m.dim(6);
966 AffineExpr w = m.dim(7);
968 if (m.matchStride(3, 3, 3, 0)
969 .matchStride(4, 4, 4, 1)
971 {{N, G, c, m.strided(H, h, 0), m.strided(W, w, 1)},
980std::optional<DilationsAndStrides>
983 if (
auto convOp = dyn_cast<linalg::Conv2DNgchwGfchwOp>(op.getOperation())) {
985 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
986 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
995 AffineExpr N = m.dim(0);
996 AffineExpr G = m.dim(1);
997 AffineExpr F = m.dim(2);
998 AffineExpr H = m.dim(3);
999 AffineExpr
W = m.dim(4);
1000 AffineExpr c = m.dim(5);
1001 AffineExpr h = m.dim(6);
1002 AffineExpr w = m.dim(7);
1004 if (m.matchStride(3, 3, 3, 0)
1005 .matchStride(4, 4, 4, 1)
1007 {{N, G, c, m.strided(H, h, 0), m.strided(W, w, 1)},
1012 return std::nullopt;
1016std::optional<DilationsAndStrides>
1019 if (
auto convOp = dyn_cast<linalg::Conv2DNgchwGfchwQOp>(op.getOperation())) {
1021 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1022 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1027 return std::nullopt;
1031 AffineExpr N = m.dim(0);
1032 AffineExpr G = m.dim(1);
1033 AffineExpr F = m.dim(2);
1034 AffineExpr H = m.dim(3);
1035 AffineExpr
W = m.dim(4);
1036 AffineExpr c = m.dim(5);
1037 AffineExpr h = m.dim(6);
1038 AffineExpr w = m.dim(7);
1040 if (m.matchStride(3, 3, 3, 0)
1041 .matchStride(4, 4, 4, 1)
1043 {{N, G, c, m.strided(H, h, 0), m.strided(W, w, 1)},
1050 return std::nullopt;
1054std::optional<DilationsAndStrides>
1057 if (
auto convOp = dyn_cast<linalg::Conv2DNhwgcGfhwcOp>(op.getOperation())) {
1059 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1060 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1065 return std::nullopt;
1069 AffineExpr N = m.dim(0);
1070 AffineExpr H = m.dim(1);
1071 AffineExpr
W = m.dim(2);
1072 AffineExpr G = m.dim(3);
1073 AffineExpr F = m.dim(4);
1074 AffineExpr h = m.dim(5);
1075 AffineExpr w = m.dim(6);
1076 AffineExpr c = m.dim(7);
1078 if (m.matchStride(1, 2, 1, 0)
1079 .matchStride(2, 3, 2, 1)
1081 {{N, m.strided(H, h, 0), m.strided(W, w, 1), G, c},
1086 return std::nullopt;
1090std::optional<DilationsAndStrides>
1093 if (
auto convOp = dyn_cast<linalg::Conv2DNhwgcGfhwcQOp>(op.getOperation())) {
1095 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1096 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1101 return std::nullopt;
1105 AffineExpr N = m.dim(0);
1106 AffineExpr H = m.dim(1);
1107 AffineExpr
W = m.dim(2);
1108 AffineExpr G = m.dim(3);
1109 AffineExpr F = m.dim(4);
1110 AffineExpr h = m.dim(5);
1111 AffineExpr w = m.dim(6);
1112 AffineExpr c = m.dim(7);
1114 if (m.matchStride(1, 2, 1, 0)
1115 .matchStride(2, 3, 2, 1)
1117 {{N, m.strided(H, h, 0), m.strided(W, w, 1), G, c},
1124 return std::nullopt;
1128std::optional<DilationsAndStrides>
1131 if (isa<linalg::Conv3DOp>(op)) {
1133 result.dilations = SmallVector<int64_t>(3, 1);
1134 result.strides = SmallVector<int64_t>(3, 1);
1139 return std::nullopt;
1143 AffineExpr D = m.dim(0);
1144 AffineExpr H = m.dim(1);
1145 AffineExpr
W = m.dim(2);
1146 AffineExpr d = m.dim(3);
1147 AffineExpr h = m.dim(4);
1148 AffineExpr w = m.dim(5);
1150 if (m.matchStride(0, 0, 0, 0)
1151 .matchStride(1, 1, 1, 1)
1152 .matchStride(2, 2, 2, 2)
1153 .matchMaps({{m.strided(D, d, 0), m.strided(H, h, 1),
1154 m.strided(W, w, 2)},
1159 return std::nullopt;
1163std::optional<DilationsAndStrides>
1166 if (
auto convOp = dyn_cast<linalg::Conv3DNdhwcDhwcfOp>(op.getOperation())) {
1168 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1169 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1174 return std::nullopt;
1178 AffineExpr N = m.dim(0);
1179 AffineExpr D = m.dim(1);
1180 AffineExpr H = m.dim(2);
1181 AffineExpr
W = m.dim(3);
1182 AffineExpr F = m.dim(4);
1183 AffineExpr d = m.dim(5);
1184 AffineExpr h = m.dim(6);
1185 AffineExpr w = m.dim(7);
1186 AffineExpr c = m.dim(8);
1188 if (m.matchStride(1, 0, 1, 0)
1189 .matchStride(2, 1, 2, 1)
1190 .matchStride(3, 2, 3, 2)
1191 .matchMaps({{N, m.strided(D, d, 0), m.strided(H, h, 1),
1192 m.strided(W, w, 2), c},
1197 return std::nullopt;
1201std::optional<DilationsAndStrides>
1204 if (
auto convOp = dyn_cast<linalg::Conv3DNdhwcDhwcfQOp>(op.getOperation())) {
1206 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1207 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1212 return std::nullopt;
1216 AffineExpr N = m.dim(0);
1217 AffineExpr D = m.dim(1);
1218 AffineExpr H = m.dim(2);
1219 AffineExpr
W = m.dim(3);
1220 AffineExpr F = m.dim(4);
1221 AffineExpr d = m.dim(5);
1222 AffineExpr h = m.dim(6);
1223 AffineExpr w = m.dim(7);
1224 AffineExpr c = m.dim(8);
1226 if (m.matchStride(1, 0, 1, 0)
1227 .matchStride(2, 1, 2, 1)
1228 .matchStride(3, 2, 3, 2)
1229 .matchMaps({{N, m.strided(D, d, 0), m.strided(H, h, 1),
1230 m.strided(W, w, 2), c},
1237 return std::nullopt;
1241std::optional<DilationsAndStrides>
1244 if (
auto convOp = dyn_cast<linalg::Conv3DNcdhwFcdhwOp>(op.getOperation())) {
1246 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1247 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1252 return std::nullopt;
1256 AffineExpr N = m.dim(0);
1257 AffineExpr F = m.dim(1);
1258 AffineExpr D = m.dim(2);
1259 AffineExpr H = m.dim(3);
1260 AffineExpr
W = m.dim(4);
1261 AffineExpr c = m.dim(5);
1262 AffineExpr d = m.dim(6);
1263 AffineExpr h = m.dim(7);
1264 AffineExpr w = m.dim(8);
1266 if (m.matchStride(2, 2, 2, 0)
1267 .matchStride(3, 3, 3, 1)
1268 .matchStride(4, 4, 4, 2)
1269 .matchMaps({{N, c, m.strided(D, d, 0),
1270 m.strided(H, h, 1), m.strided(W, w, 2)},
1275 return std::nullopt;
1279std::optional<DilationsAndStrides>
1283 dyn_cast<linalg::DepthwiseConv1DNcwCwOp>(op.getOperation())) {
1285 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1286 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1291 return std::nullopt;
1295 AffineExpr N = m.dim(0);
1296 AffineExpr
W = m.dim(1);
1297 AffineExpr
C = m.dim(2);
1298 AffineExpr w = m.dim(3);
1300 if (m.matchStride(2, 1, 2, 0)
1301 .matchMaps({{N, C, m.strided(W, w, 0)},
1306 return std::nullopt;
1310std::optional<DilationsAndStrides>
1314 dyn_cast<linalg::DepthwiseConv1DNwcWcOp>(op.getOperation())) {
1316 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1317 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1322 return std::nullopt;
1326 AffineExpr N = m.dim(0);
1327 AffineExpr
W = m.dim(1);
1328 AffineExpr
C = m.dim(2);
1329 AffineExpr w = m.dim(3);
1331 if (m.matchStride(1, 0, 1, 0)
1332 .matchMaps({{N, m.strided(W, w, 0), C},
1337 return std::nullopt;
1341std::optional<DilationsAndStrides>
1345 dyn_cast<linalg::DepthwiseConv1DNwcWcmOp>(op.getOperation())) {
1347 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1348 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1353 return std::nullopt;
1357 AffineExpr N = m.dim(0);
1358 AffineExpr
W = m.dim(1);
1359 AffineExpr
C = m.dim(2);
1360 AffineExpr CM = m.dim(3);
1361 AffineExpr w = m.dim(4);
1363 if (m.matchStride(1, 0, 1, 0)
1364 .matchMaps({{N, m.strided(W, w, 0), C},
1369 return std::nullopt;
1373std::optional<DilationsAndStrides>
1377 dyn_cast<linalg::DepthwiseConv2DNchwChwOp>(op.getOperation())) {
1379 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1380 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1385 return std::nullopt;
1389 AffineExpr N = m.dim(0);
1390 AffineExpr H = m.dim(1);
1391 AffineExpr
W = m.dim(2);
1392 AffineExpr
C = m.dim(3);
1393 AffineExpr h = m.dim(4);
1394 AffineExpr w = m.dim(5);
1396 if (m.matchStride(2, 1, 2, 0)
1397 .matchStride(3, 2, 3, 1)
1399 {{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
1404 return std::nullopt;
1408std::optional<DilationsAndStrides>
1412 dyn_cast<linalg::DepthwiseConv2DNhwcHwcOp>(op.getOperation())) {
1414 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1415 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1420 return std::nullopt;
1424 AffineExpr N = m.dim(0);
1425 AffineExpr H = m.dim(1);
1426 AffineExpr
W = m.dim(2);
1427 AffineExpr
C = m.dim(3);
1428 AffineExpr h = m.dim(4);
1429 AffineExpr w = m.dim(5);
1431 if (m.matchStride(1, 0, 1, 0)
1432 .matchStride(2, 1, 2, 1)
1434 {{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
1439 return std::nullopt;
1443std::optional<DilationsAndStrides>
1447 dyn_cast<linalg::DepthwiseConv2DNhwcHwcQOp>(op.getOperation())) {
1449 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1450 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1455 return std::nullopt;
1459 AffineExpr N = m.dim(0);
1460 AffineExpr H = m.dim(1);
1461 AffineExpr
W = m.dim(2);
1462 AffineExpr
C = m.dim(3);
1463 AffineExpr h = m.dim(4);
1464 AffineExpr w = m.dim(5);
1466 if (m.matchStride(1, 0, 1, 0)
1467 .matchStride(2, 1, 2, 1)
1469 {{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
1476 return std::nullopt;
1480std::optional<DilationsAndStrides>
1484 dyn_cast<linalg::DepthwiseConv2DNhwcHwcmOp>(op.getOperation())) {
1486 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1487 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1492 return std::nullopt;
1496 AffineExpr N = m.dim(0);
1497 AffineExpr H = m.dim(1);
1498 AffineExpr
W = m.dim(2);
1499 AffineExpr
C = m.dim(3);
1500 AffineExpr CM = m.dim(4);
1501 AffineExpr h = m.dim(5);
1502 AffineExpr w = m.dim(6);
1504 if (m.matchStride(1, 0, 1, 0)
1505 .matchStride(2, 1, 2, 1)
1507 {{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
1512 return std::nullopt;
1516std::optional<DilationsAndStrides>
1520 dyn_cast<linalg::DepthwiseConv2DNhwcHwcmQOp>(op.getOperation())) {
1522 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1523 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1528 return std::nullopt;
1532 AffineExpr N = m.dim(0);
1533 AffineExpr H = m.dim(1);
1534 AffineExpr
W = m.dim(2);
1535 AffineExpr
C = m.dim(3);
1536 AffineExpr CM = m.dim(4);
1537 AffineExpr h = m.dim(5);
1538 AffineExpr w = m.dim(6);
1540 if (m.matchStride(1, 0, 1, 0)
1541 .matchStride(2, 1, 2, 1)
1543 {{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
1550 return std::nullopt;
1554std::optional<DilationsAndStrides>
1558 dyn_cast<linalg::DepthwiseConv3DNdhwcDhwcOp>(op.getOperation())) {
1560 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1561 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1566 return std::nullopt;
1570 AffineExpr N = m.dim(0);
1571 AffineExpr D = m.dim(1);
1572 AffineExpr H = m.dim(2);
1573 AffineExpr
W = m.dim(3);
1574 AffineExpr d = m.dim(4);
1575 AffineExpr h = m.dim(5);
1576 AffineExpr w = m.dim(6);
1577 AffineExpr
C = m.dim(7);
1579 if (m.matchStride(1, 0, 1, 0)
1580 .matchStride(2, 1, 2, 1)
1581 .matchStride(3, 2, 3, 2)
1582 .matchMaps({{N, m.strided(D, d, 0), m.strided(H, h, 1),
1583 m.strided(W, w, 2), C},
1588 return std::nullopt;
1592std::optional<DilationsAndStrides>
1596 dyn_cast<linalg::DepthwiseConv3DNcdhwCdhwOp>(op.getOperation())) {
1598 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1599 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1604 return std::nullopt;
1608 AffineExpr N = m.dim(0);
1609 AffineExpr D = m.dim(1);
1610 AffineExpr H = m.dim(2);
1611 AffineExpr
W = m.dim(3);
1612 AffineExpr d = m.dim(4);
1613 AffineExpr h = m.dim(5);
1614 AffineExpr w = m.dim(6);
1615 AffineExpr
C = m.dim(7);
1617 if (m.matchStride(2, 1, 2, 0)
1618 .matchStride(3, 2, 3, 1)
1619 .matchStride(4, 3, 4, 2)
1620 .matchMaps({{N, C, m.strided(D, d, 0),
1621 m.strided(H, h, 1), m.strided(W, w, 2)},
1626 return std::nullopt;
1630std::optional<DilationsAndStrides>
1634 dyn_cast<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op.getOperation())) {
1636 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1637 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1642 return std::nullopt;
1646 AffineExpr N = m.dim(0);
1647 AffineExpr D = m.dim(1);
1648 AffineExpr H = m.dim(2);
1649 AffineExpr
W = m.dim(3);
1650 AffineExpr CM = m.dim(4);
1651 AffineExpr d = m.dim(5);
1652 AffineExpr h = m.dim(6);
1653 AffineExpr w = m.dim(7);
1654 AffineExpr
C = m.dim(8);
1656 if (m.matchStride(1, 0, 1, 0)
1657 .matchStride(2, 1, 2, 1)
1658 .matchStride(3, 2, 3, 2)
1659 .matchMaps({{N, m.strided(D, d, 0), m.strided(H, h, 1),
1660 m.strided(W, w, 2), C},
1662 {N, D, H, W, C, CM}})
1665 return std::nullopt;
1669std::optional<DilationsAndStrides>
1672 if (
auto poolOp = dyn_cast<linalg::PoolingNhwcMaxOp>(op.getOperation())) {
1674 llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
1675 result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
1680 return std::nullopt;
1684 AffineExpr N = m.dim(0);
1685 AffineExpr H = m.dim(1);
1686 AffineExpr
W = m.dim(2);
1687 AffineExpr
C = m.dim(3);
1688 AffineExpr h = m.dim(4);
1689 AffineExpr w = m.dim(5);
1691 if (m.matchStride(1, 0, 1, 0)
1692 .matchStride(2, 1, 2, 1)
1694 {{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
1699 return std::nullopt;
1703std::optional<DilationsAndStrides>
1706 if (
auto poolOp = dyn_cast<linalg::PoolingNhwcMinOp>(op.getOperation())) {
1708 llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
1709 result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
1714 return std::nullopt;
1718 AffineExpr N = m.dim(0);
1719 AffineExpr H = m.dim(1);
1720 AffineExpr
W = m.dim(2);
1721 AffineExpr
C = m.dim(3);
1722 AffineExpr h = m.dim(4);
1723 AffineExpr w = m.dim(5);
1725 if (m.matchStride(1, 0, 1, 0)
1726 .matchStride(2, 1, 2, 1)
1728 {{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
1733 return std::nullopt;
1737std::optional<DilationsAndStrides>
1740 if (
auto poolOp = dyn_cast<linalg::PoolingNhwcSumOp>(op.getOperation())) {
1742 llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
1743 result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
1748 return std::nullopt;
1752 AffineExpr N = m.dim(0);
1753 AffineExpr H = m.dim(1);
1754 AffineExpr
W = m.dim(2);
1755 AffineExpr
C = m.dim(3);
1756 AffineExpr h = m.dim(4);
1757 AffineExpr w = m.dim(5);
1759 if (m.matchStride(1, 0, 1, 0)
1760 .matchStride(2, 1, 2, 1)
1762 {{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
1767 return std::nullopt;
1771std::optional<DilationsAndStrides>
1775 dyn_cast<linalg::PoolingNhwcMaxUnsignedOp>(op.getOperation())) {
1777 llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
1778 result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
1783 return std::nullopt;
1787 AffineExpr N = m.dim(0);
1788 AffineExpr H = m.dim(1);
1789 AffineExpr
W = m.dim(2);
1790 AffineExpr
C = m.dim(3);
1791 AffineExpr h = m.dim(4);
1792 AffineExpr w = m.dim(5);
1794 if (m.matchStride(1, 0, 1, 0)
1795 .matchStride(2, 1, 2, 1)
1797 {{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
1802 return std::nullopt;
1806std::optional<DilationsAndStrides>
1810 dyn_cast<linalg::PoolingNhwcMinUnsignedOp>(op.getOperation())) {
1812 llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
1813 result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
1818 return std::nullopt;
1822 AffineExpr N = m.dim(0);
1823 AffineExpr H = m.dim(1);
1824 AffineExpr
W = m.dim(2);
1825 AffineExpr
C = m.dim(3);
1826 AffineExpr h = m.dim(4);
1827 AffineExpr w = m.dim(5);
1829 if (m.matchStride(1, 0, 1, 0)
1830 .matchStride(2, 1, 2, 1)
1832 {{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
1837 return std::nullopt;
1841std::optional<DilationsAndStrides>
1844 if (
auto poolOp = dyn_cast<linalg::PoolingNchwSumOp>(op.getOperation())) {
1846 llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
1847 result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
1852 return std::nullopt;
1856 AffineExpr N = m.dim(0);
1857 AffineExpr
C = m.dim(1);
1858 AffineExpr H = m.dim(2);
1859 AffineExpr
W = m.dim(3);
1860 AffineExpr h = m.dim(4);
1861 AffineExpr w = m.dim(5);
1863 if (m.matchStride(2, 0, 2, 0)
1864 .matchStride(3, 1, 3, 1)
1866 {{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
1871 return std::nullopt;
1875std::optional<DilationsAndStrides>
1878 if (
auto poolOp = dyn_cast<linalg::PoolingNchwMaxOp>(op.getOperation())) {
1880 llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
1881 result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
1886 return std::nullopt;
1890 AffineExpr N = m.dim(0);
1891 AffineExpr
C = m.dim(1);
1892 AffineExpr H = m.dim(2);
1893 AffineExpr
W = m.dim(3);
1894 AffineExpr h = m.dim(4);
1895 AffineExpr w = m.dim(5);
1897 if (m.matchStride(2, 0, 2, 0)
1898 .matchStride(3, 1, 3, 1)
1900 {{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
1905 return std::nullopt;
1909std::optional<DilationsAndStrides>
1912 if (
auto poolOp = dyn_cast<linalg::PoolingNwcSumOp>(op.getOperation())) {
1914 llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
1915 result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
1920 return std::nullopt;
1924 AffineExpr N = m.dim(0);
1925 AffineExpr
W = m.dim(1);
1926 AffineExpr
C = m.dim(2);
1927 AffineExpr w = m.dim(3);
1929 if (m.matchStride(1, 0, 1, 0)
1930 .matchMaps({{N, m.strided(W, w, 0), C},
1935 return std::nullopt;
1939std::optional<DilationsAndStrides>
1942 if (
auto poolOp = dyn_cast<linalg::PoolingNcwSumOp>(op.getOperation())) {
1944 llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
1945 result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
1950 return std::nullopt;
1954 AffineExpr N = m.dim(0);
1955 AffineExpr
C = m.dim(1);
1956 AffineExpr
W = m.dim(2);
1957 AffineExpr w = m.dim(3);
1959 if (m.matchStride(2, 0, 2, 0)
1960 .matchMaps({{N, C, m.strided(W, w, 0)},
1965 return std::nullopt;
1969std::optional<DilationsAndStrides>
1972 if (
auto poolOp = dyn_cast<linalg::PoolingNwcMaxOp>(op.getOperation())) {
1974 llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
1975 result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
1980 return std::nullopt;
1984 AffineExpr N = m.dim(0);
1985 AffineExpr
W = m.dim(1);
1986 AffineExpr
C = m.dim(2);
1987 AffineExpr w = m.dim(3);
1989 if (m.matchStride(1, 0, 1, 0)
1990 .matchMaps({{N, m.strided(W, w, 0), C},
1995 return std::nullopt;
1999std::optional<DilationsAndStrides>
2003 dyn_cast<linalg::PoolingNwcMaxUnsignedOp>(op.getOperation())) {
2005 llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
2006 result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
2011 return std::nullopt;
2015 AffineExpr N = m.dim(0);
2016 AffineExpr
W = m.dim(1);
2017 AffineExpr
C = m.dim(2);
2018 AffineExpr w = m.dim(3);
2020 if (m.matchStride(1, 0, 1, 0)
2021 .matchMaps({{N, m.strided(W, w, 0), C},
2026 return std::nullopt;
2030std::optional<DilationsAndStrides>
2033 if (
auto poolOp = dyn_cast<linalg::PoolingNcwMaxOp>(op.getOperation())) {
2035 llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
2036 result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
2041 return std::nullopt;
2045 AffineExpr N = m.dim(0);
2046 AffineExpr
C = m.dim(1);
2047 AffineExpr
W = m.dim(2);
2048 AffineExpr w = m.dim(3);
2050 if (m.matchStride(2, 0, 2, 0)
2051 .matchMaps({{N, C, m.strided(W, w, 0)},
2056 return std::nullopt;
2060std::optional<DilationsAndStrides>
2063 if (
auto poolOp = dyn_cast<linalg::PoolingNwcMinOp>(op.getOperation())) {
2065 llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
2066 result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
2071 return std::nullopt;
2075 AffineExpr N = m.dim(0);
2076 AffineExpr
W = m.dim(1);
2077 AffineExpr
C = m.dim(2);
2078 AffineExpr w = m.dim(3);
2080 if (m.matchStride(1, 0, 1, 0)
2081 .matchMaps({{N, m.strided(W, w, 0), C},
2086 return std::nullopt;
2090std::optional<DilationsAndStrides>
2094 dyn_cast<linalg::PoolingNwcMinUnsignedOp>(op.getOperation())) {
2096 llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
2097 result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
2102 return std::nullopt;
2106 AffineExpr N = m.dim(0);
2107 AffineExpr
W = m.dim(1);
2108 AffineExpr
C = m.dim(2);
2109 AffineExpr w = m.dim(3);
2111 if (m.matchStride(1, 0, 1, 0)
2112 .matchMaps({{N, m.strided(W, w, 0), C},
2117 return std::nullopt;
2121std::optional<DilationsAndStrides>
2124 if (
auto poolOp = dyn_cast<linalg::PoolingNdhwcSumOp>(op.getOperation())) {
2126 llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
2127 result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
2132 return std::nullopt;
2136 AffineExpr N = m.dim(0);
2137 AffineExpr D = m.dim(1);
2138 AffineExpr H = m.dim(2);
2139 AffineExpr
W = m.dim(3);
2140 AffineExpr
C = m.dim(4);
2141 AffineExpr d = m.dim(5);
2142 AffineExpr h = m.dim(6);
2143 AffineExpr w = m.dim(7);
2145 if (m.matchStride(1, 0, 1, 0)
2146 .matchStride(2, 1, 2, 1)
2147 .matchStride(3, 2, 3, 2)
2148 .matchMaps({{N, m.strided(D, d, 0), m.strided(H, h, 1),
2149 m.strided(W, w, 2), C},
2154 return std::nullopt;
2158std::optional<DilationsAndStrides>
2161 if (
auto poolOp = dyn_cast<linalg::PoolingNdhwcMaxOp>(op.getOperation())) {
2163 llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
2164 result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
2169 return std::nullopt;
2173 AffineExpr N = m.dim(0);
2174 AffineExpr D = m.dim(1);
2175 AffineExpr H = m.dim(2);
2176 AffineExpr
W = m.dim(3);
2177 AffineExpr
C = m.dim(4);
2178 AffineExpr d = m.dim(5);
2179 AffineExpr h = m.dim(6);
2180 AffineExpr w = m.dim(7);
2182 if (m.matchStride(1, 0, 1, 0)
2183 .matchStride(2, 1, 2, 1)
2184 .matchStride(3, 2, 3, 2)
2185 .matchMaps({{N, m.strided(D, d, 0), m.strided(H, h, 1),
2186 m.strided(W, w, 2), C},
2191 return std::nullopt;
2195std::optional<DilationsAndStrides>
2198 if (
auto poolOp = dyn_cast<linalg::PoolingNdhwcMinOp>(op.getOperation())) {
2200 llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
2201 result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
2206 return std::nullopt;
2210 AffineExpr N = m.dim(0);
2211 AffineExpr D = m.dim(1);
2212 AffineExpr H = m.dim(2);
2213 AffineExpr
W = m.dim(3);
2214 AffineExpr
C = m.dim(4);
2215 AffineExpr d = m.dim(5);
2216 AffineExpr h = m.dim(6);
2217 AffineExpr w = m.dim(7);
2219 if (m.matchStride(1, 0, 1, 0)
2220 .matchStride(2, 1, 2, 1)
2221 .matchStride(3, 2, 3, 2)
2222 .matchMaps({{N, m.strided(D, d, 0), m.strided(H, h, 1),
2223 m.strided(W, w, 2), C},
2228 return std::nullopt;
2235 auto sliceOp = source.
getDefiningOp<tensor::ExtractSliceOp>();
2241 Value current = sliceOp.getSource();
2246 OpResult opResult = cast<OpResult>(current);
2247 current = linalgOp.getDpsInitOperand(opResult.
getResultNumber())->get();
2249 auto padOp = current ? current.
getDefiningOp<tensor::PadOp>() :
nullptr;
2258 if (sliceOp.getSource().getType() != type)
2263 if (llvm::any_of(padOp.getMixedLowPad(), [](
OpFoldResult ofr) {
2264 return getConstantIntValue(ofr) != static_cast<int64_t>(0);
2271 auto padOpSliceOp = padOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
2272 if (!padOpSliceOp ||
2273 sliceOp.getMixedSizes().size() != padOpSliceOp.getMixedSizes().size())
2280 llvm::zip(sliceOp.getMixedSizes(), padOpSliceOp.getMixedSizes()),
2281 [](std::tuple<OpFoldResult, OpFoldResult> it) {
2282 return !isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it));
2289 Value padOpPad = padOp.getConstantPaddingValue();
2296 return sliceOp.getSource();
2300 auto memrefTypeTo = cast<MemRefType>(to.
getType());
2302 auto memrefTypeFrom = cast<MemRefType>(from.
getType());
2303 assert(memrefTypeFrom.getRank() == memrefTypeTo.getRank() &&
2304 "`from` and `to` memref must have the same rank");
2310 utils::IteratorType::parallel);
2311 return linalg::GenericOp::create(
2318 linalg::YieldOp::create(
b, loc, args.front());
2331 assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
2332 "expected as many entries for proc info as number of loops, even if "
2333 "they are null entries");
2335 if (!linalgOp.hasPureBufferSemantics())
2336 llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
2340 b, loc, lbs, ubs, steps, iterArgInitValues,
2342 assert(iterArgs.size() == iterArgInitValues.size() &&
2343 "expect the number of output tensors and iter args to match");
2345 if (!iterArgs.empty()) {
2346 operandValuesToUse = linalgOp.getDpsInputs();
2347 operandValuesToUse.append(iterArgs.begin(), iterArgs.end());
2349 return bodyBuilderFn(
b, loc, ivs, operandValuesToUse);
2352 if (loopNest.
loops.empty() || procInfo.empty())
2356 for (
const auto &loop : llvm::enumerate(loopNest.
loops)) {
2357 if (procInfo[loop.index()].distributionMethod ==
2359 mapLoopToProcessorIds(loop.value(), procInfo[loop.index()].procId,
2360 procInfo[loop.index()].nprocs);
2375 if (!linalgOp.hasPureBufferSemantics())
2376 llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
2377 assert(iterArgInitValues.empty() &&
"unexpected AffineForOp init values");
2383 constantSteps.reserve(steps.size());
2384 for (
Value v : steps) {
2386 assert(constVal.has_value() &&
"Affine loops require constant steps");
2387 constantSteps.push_back(constVal.value());
2392 bodyBuilderFn(
b, loc, ivs,
2393 linalgOp->getOperands());
2425 assert(lbs.size() == ubs.size());
2426 assert(lbs.size() == steps.size());
2427 assert(lbs.size() == iteratorTypes.size());
2428 assert(procInfo.empty() || (lbs.size() == procInfo.size()));
2432 if (iteratorTypes.empty()) {
2433 bodyBuilderFn(
b, loc, ivStorage);
2441 b, loc, lbs.take_front(), ubs.take_front(), steps.take_front(),
2443 ivStorage.append(ivs.begin(), ivs.end());
2444 generateParallelLoopNest(
2445 b, loc, lbs.drop_front(), ubs.drop_front(), steps.drop_front(),
2446 iteratorTypes.drop_front(),
2447 procInfo.empty() ? procInfo : procInfo.drop_front(),
2448 bodyBuilderFn, ivStorage);
2453 unsigned nLoops = iteratorTypes.size();
2454 unsigned numProcessed = 0;
2456 if (procInfo.empty()) {
2459 distributionMethod = procInfo.front().distributionMethod;
2468 auto remainderProcInfo =
2469 procInfo.empty() ? procInfo : procInfo.drop_front(numProcessed);
2470 switch (distributionMethod) {
2474 scf::ParallelOp::create(
2475 b, loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
2476 steps.take_front(numProcessed),
2478 ivStorage.append(localIvs.begin(), localIvs.end());
2479 generateParallelLoopNest(
2480 nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
2481 ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
2482 iteratorTypes.drop_front(numProcessed), remainderProcInfo,
2483 bodyBuilderFn, ivStorage);
2490 scf::ParallelOp::create(
2491 b, loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
2492 steps.take_front(numProcessed),
2494 ivStorage.append(localIvs.begin(), localIvs.end());
2495 generateParallelLoopNest(
2496 nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
2497 ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
2498 iteratorTypes.drop_front(numProcessed), remainderProcInfo,
2499 bodyBuilderFn, ivStorage);
2506 Value cond = ab.
slt(lbs[0], ubs[0]);
2507 for (
unsigned i = 1; i < numProcessed; ++i)
2508 cond = ab.
_and(cond, ab.
slt(lbs[i], ubs[i]));
2509 ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
2512 ubs.drop_front(numProcessed),
2513 steps.drop_front(numProcessed),
2514 iteratorTypes.drop_front(numProcessed),
2515 remainderProcInfo, bodyBuilderFn, ivStorage);
2523 ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
2525 b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
2526 steps.drop_front(numProcessed), iteratorTypes.drop_front(numProcessed),
2527 remainderProcInfo, bodyBuilderFn, ivStorage);
2542 if (!linalgOp.hasPureBufferSemantics())
2543 llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
2544 assert(iterArgInitValues.empty() &&
"unexpected ParallelOp init values");
2546 assert(iteratorTypes.size() >= loopRanges.size() &&
2547 "expected iterator type for all ranges");
2548 assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
2549 "expected proc information for all loops when present");
2550 iteratorTypes = iteratorTypes.take_front(loopRanges.size());
2552 unsigned numLoops = iteratorTypes.size();
2553 ivs.reserve(numLoops);
2554 lbsStorage.reserve(numLoops);
2555 ubsStorage.reserve(numLoops);
2556 stepsStorage.reserve(numLoops);
2559 unpackRanges(
b, loc, loopRanges, lbsStorage, ubsStorage, stepsStorage);
2562 for (
const auto &it : llvm::enumerate(procInfo)) {
2565 b, loc, it.value().procId, it.value().nprocs, lbsStorage[it.index()],
2566 ubsStorage[it.index()], stepsStorage[it.index()]);
2569 ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
2571 b, loc, lbs, ubs, steps, iteratorTypes, procInfo,
2573 bodyBuilderFn(
b, loc, ivs, linalgOp->getOperands());
2577 assert(ivs.size() == iteratorTypes.size() &&
"did not generate enough loops");
2583 auto shapedType = dyn_cast<ShapedType>(valueToTile.
getType());
2585 .Case([&](MemRefType) {
2586 return memref::SubViewOp::create(
2587 builder, loc, valueToTile, sliceParams.
offsets,
2590 .Case([&](RankedTensorType) {
2591 return tensor::ExtractSliceOp::create(
2592 builder, loc, valueToTile, sliceParams.
offsets,
2595 .DefaultUnreachable(
"Unexpected shaped type");
2604 bool omitPartialTileCheck) {
2607 ubs, subShapeSizes, omitPartialTileCheck);
2616 bool omitPartialTileCheck) {
2617 auto shapedType = dyn_cast<ShapedType>(valueToTile.
getType());
2618 assert(shapedType &&
"only shaped types can be tiled");
2620 int64_t rank = shapedType.getRank();
2624 sliceParams.
offsets.reserve(rank);
2625 sliceParams.
sizes.reserve(rank);
2626 sliceParams.
strides.reserve(rank);
2627 for (
unsigned r = 0; r < rank; ++r) {
2628 LLVM_DEBUG(llvm::dbgs() <<
"computeSliceParameters: for dim#" << r);
2632 sliceParams.
sizes.push_back(dim);
2634 LLVM_DEBUG(llvm::dbgs() <<
": not tiled: use size: " << dim <<
"\n");
2637 LLVM_DEBUG(llvm::dbgs() <<
": tiled: figure out subsize...\n");
2642 LLVM_DEBUG(llvm::dbgs() <<
"computeSliceParameters: submap: " << m <<
"\n");
2647 [[maybe_unused]]
auto res = m.constantFold(zeros, mAtZero);
2648 assert(succeeded(res) &&
"affine_map must be evaluatable (not symbols)");
2650 cast<IntegerAttr>(mAtZero[0]).getValue().getSExtValue();
2652 rewriter, loc, m.getResult(0) - mAtZeroInt, lbs);
2653 sliceParams.
offsets.push_back(offset);
2661 LLVM_DEBUG(llvm::dbgs()
2662 <<
"computeSliceParameters: raw size: " << size <<
"\n");
2663 LLVM_DEBUG(llvm::dbgs()
2664 <<
"computeSliceParameters: new offset: " << offset <<
"\n");
2667 if (omitPartialTileCheck) {
2670 LLVM_DEBUG(llvm::dbgs() <<
"makeTiledShape: new size: " << size <<
"\n");
2671 sliceParams.
sizes.push_back(size);
2682 auto hasTileSizeOne = sizeCst == 1;
2683 auto dividesEvenly = sizeCst && ShapedType::isStatic(shapeSize) &&
2684 ((shapeSize % *sizeCst) == 0);
2685 if (!hasTileSizeOne && !dividesEvenly) {
2686 LLVM_DEBUG(llvm::dbgs() <<
"makeTiledShape: shapeSize=" << shapeSize
2687 <<
", size: " << size
2688 <<
": make sure in bound with affine.min\n");
2692 bindDims(context, dim0, dim1, dim2);
2723 LLVM_DEBUG(llvm::dbgs() <<
"makeTiledShape: new size: " << size <<
"\n");
2724 sliceParams.
sizes.push_back(size);
2733 for (
unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
2734 LLVM_DEBUG(llvm::dbgs() <<
"makeTiledShapes: for loop#" << idx <<
"\n");
2736 offsets.push_back(
isTiled ? ivs[idxIvs++] :
b.getIndexAttr(0));
2737 LLVM_DEBUG(llvm::dbgs()
2738 <<
"computeTileOffsets: " << offsets.back() <<
"\n");
2747 for (
unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
2754 LLVM_DEBUG(llvm::dbgs() <<
"computeTileSizes: " << sizes.back() <<
"\n");
2760 if (op.hasPureBufferSemantics())
2762 return llvm::to_vector(
2763 llvm::map_range(op.getDpsInitsMutable(), [&](
OpOperand &opOperand) {
2764 return operands[opOperand.getOperandNumber()].getType();
2771 if (op.hasPureBufferSemantics())
2774 tensorResults.reserve(results.size());
2776 unsigned resultIdx = 0;
2777 for (
OpOperand &opOperand : op.getDpsInitsMutable()) {
2780 Value outputTensor = operands[opOperand.getOperandNumber()];
2781 if (
auto sliceOp = outputTensor.
getDefiningOp<tensor::ExtractSliceOp>()) {
2783 builder, loc, sliceOp.getSource().getType(), results[resultIdx],
2784 sliceOp.getSource(), sliceOp.getOffsets(), sliceOp.getSizes(),
2785 sliceOp.getStrides(), sliceOp.getStaticOffsets(),
2786 sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
2789 tensorResults.push_back(results[resultIdx]);
2793 return tensorResults;
2801 bool omitPartialTileCheck) {
2802 assert(ivs.size() ==
static_cast<size_t>(llvm::count_if(
2803 llvm::make_range(tileSizes.begin(), tileSizes.end()),
2805 "expected as many ivs as non-zero sizes");
2814 assert(
static_cast<int64_t>(valuesToTile.size()) <=
2815 linalgOp->getNumOperands() &&
2816 "more value to tile than operands.");
2818 allSliceParams.reserve(valuesToTile.size());
2819 for (
auto [opOperand, val] :
2820 llvm::zip(linalgOp->getOpOperands(), valuesToTile)) {
2821 Value shapedOp = val;
2822 LLVM_DEBUG(llvm::dbgs() <<
"makeTiledShapes: for operand " << shapedOp);
2823 AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
2830 Type operandType = opOperand.get().getType();
2831 if (!
isTiled(map, tileSizes) && !(isa<RankedTensorType>(operandType) &&
2832 linalgOp.isDpsInit(&opOperand))) {
2833 allSliceParams.push_back(std::nullopt);
2834 LLVM_DEBUG(llvm::dbgs()
2835 <<
": not tiled: use shape: " << operandType <<
"\n");
2838 LLVM_DEBUG(llvm::dbgs() <<
": tiled: figure out subshape...\n");
2841 builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
2842 omitPartialTileCheck));
2845 return allSliceParams;
2853 bool omitPartialTileCheck) {
2856 tileSizes, sizeBounds, omitPartialTileCheck);
2858 for (
auto item : llvm::zip(valuesToTile, allSliceParameter)) {
2859 Value valueToTile = std::get<0>(item);
2860 std::optional<SliceParameters> sliceParams = std::get<1>(item);
2861 tiledShapes.push_back(
2862 sliceParams.has_value()
2878 if (!linalgOp.hasIndexSemantics())
2881 for (IndexOp indexOp : linalgOp.getBlock()->getOps<IndexOp>()) {
2882 if (indexOp.getDim() >= offsets.size() || !offsets[indexOp.getDim()])
2885 b.setInsertionPointAfter(indexOp);
2889 b, indexOp.getLoc(),
index + offset,
2890 {getAsOpFoldResult(indexOp.getResult()), offsets[indexOp.getDim()]});
2891 Value materialized =
2893 b.replaceUsesWithIf(indexOp, materialized, [&](
OpOperand &use) {
2905std::optional<SmallVector<ReassociationIndices>>
2909 for (
const auto &it : llvm::enumerate(mixedSizes)) {
2910 auto dim = it.index();
2911 auto size = it.value();
2912 curr.push_back(dim);
2913 auto attr = llvm::dyn_cast_if_present<Attribute>(size);
2914 if (attr && cast<IntegerAttr>(attr).getInt() == 1)
2917 std::swap(reassociation.back(), curr);
2922 if (!curr.empty() && !reassociation.empty())
2923 reassociation.back().append(curr.begin(), curr.end());
2924 return reassociation;
static SmallVector< int64_t > computePackUnPackPerm(int64_t rank, ArrayRef< int64_t > &innerDimsPos, ArrayRef< int64_t > &outerPerm, PackingMetadata &packingMetadata)
The permutation can be obtained from two permutations: a) Compute the permutation vector to move the ...
static bool isTiled(AffineExpr expr, ArrayRef< OpFoldResult > tileSizes)
static void unpackRanges(OpBuilder &builder, Location loc, ArrayRef< Range > ranges, SmallVectorImpl< Value > &lbs, SmallVectorImpl< Value > &ubs, SmallVectorImpl< Value > &steps)
Given a list of subview ranges, extract individual values for lower, upper bounds and steps and put t...
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
Affine binary operation expression.
AffineExpr getLHS() const
AffineExpr getRHS() const
An integer constant appearing in affine expression.
A dimensional identifier appearing in an affine expression.
unsigned getPosition() const
See documentation for AffineExprVisitorBase.
Base type for affine expression.
AffineExprKind getKind() const
Return the classification for this type.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
AffineExpr getResult(unsigned idx) const
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block * getOwner() const
Returns the block that owns this argument.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool hasOneBlock()
Return true if this region has exactly one block.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignlessIntOrFloat() const
Return true of this is a signless integer or a float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
Helper class for building convolution op matchers with minimal boilerplate.
ConvMatcherBuilder & matchStride(unsigned iDim, unsigned fDim, unsigned oDim, unsigned idx)
Match stride/dilation pattern for a spatial dimension.
bool matchBody(bool containsZeroPointOffset=false)
Match body pattern. This should be called last.
AffineExpr strided(AffineExpr base, AffineExpr kernel, unsigned idx)
Build strided expression: base * stride[idx] + kernel * dilation[idx].
AffineExpr dim(unsigned i)
Get affine dimension expression for dimension i.
ConvMatcherBuilder(LinalgOp op, unsigned spatialRank, SmallVector< int64_t > *d, SmallVector< int64_t > *s, PoolingType poolingType=PoolingType::None)
ConvMatcherBuilder & matchMaps(ArrayRef< ArrayRef< AffineExpr > > maps)
Match expected indexing maps layout. Returns *this for method chaining.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
void buildAffineLoopNest(OpBuilder &builder, Location loc, ArrayRef< int64_t > lbs, ArrayRef< int64_t > ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn=nullptr)
Builds a perfect nest of affine.for loops, i.e., each loop except the innermost one contains only ano...
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv2DNhwcHwcmQOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv2DNchwChwOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv2DNhwcHwcOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv3DNdhwcDhwcmOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNcwMaxOp >(LinalgOp op)
SmallVector< int64_t > getUnPackInverseSrcPerm(linalg::UnPackOp, PackingMetadata &metadata)
Compute inverse permutation for the source tensor (i.e.
SmallVector< Value > makeTiledShapes(OpBuilder &builder, Location loc, LinalgOp linalgOp, ValueRange valuesToTile, ArrayRef< OpFoldResult > ivs, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > sizeBounds, bool omitPartialTileCheck)
Creates extract_slice/subview ops for all valuesToTile of the given linalgOp with builder,...
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNwcMinOp >(LinalgOp op)
bool allIndexingsAreProjectedPermutation(LinalgOp op)
Check if all indexing maps are projected permutations.
static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body, bool containsZeroPointOffset=false)
Utility to match block body for convolution ops.
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv1DNwcWcmOp >(LinalgOp op)
bool isParallelIterator(utils::IteratorType iteratorType)
Check if iterator type has "parallel" semantics.
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv1DNcwFcwOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNchwSumOp >(LinalgOp op)
SmallVector< OpFoldResult > computeTileSizes(OpBuilder &b, Location loc, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > sizeBounds)
Computes tile sizes, given a list of tileSizes and dimension sizes (sizeBounds).
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNhwcMaxUnsignedOp >(LinalgOp op)
GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to)
Returns GenericOp that copies an n-D memref.
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNhwgcGfhwcOp >(LinalgOp op)
static void generateParallelLoopNest(OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ArrayRef< utils::IteratorType > iteratorTypes, ArrayRef< linalg::ProcInfo > procInfo, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn, SmallVectorImpl< Value > &ivStorage)
Generates a loop nest consisting of scf.parallel and scf.for, depending on the iteratorTypes.
SmallVector< OpFoldResult > computeTileOffsets(OpBuilder &b, Location loc, ArrayRef< OpFoldResult > ivs, ArrayRef< OpFoldResult > tileSizes)
Computes tile offsets, given a list of loop ivs and tileSizes.
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNgchwGfchwQOp >(LinalgOp op)
PoolingType
Enum representing pooling operation types used by ConvMatcherBuilder.
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv1DNcwCwOp >(LinalgOp op)
static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body)
static bool bodyMatcherForZeroPointOffsets(Operation *addOp, Operation *mulOp, Block *body)
Utility function to match the zero point offset body of quantized convolution ops.
static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv2DNhwcHwcmOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv1DNwcWcOp >(LinalgOp op)
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
bool hasOnlyScalarElementwiseOp(Region &r)
Detect whether r has only ConstantOp, ElementwiseMappable and YieldOp.
static AffineExpr getAffineMapDim(ArrayAttr indexingMaps, uint32_t mapIndex, uint32_t dimIndex)
static BlockArgument getBlockArgumentWithOptionalCastOps(Value val)
Returns the BlockArgument that leads to val, if any.
static bool bodyMatcherForPoolOps(Value yieldVal, Block *body)
Utility to match block body for linalg.pool* ops.
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNchwFchwQOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv1DNwcWcfOp >(LinalgOp op)
std::optional< SmallVector< ReassociationIndices > > getReassociationMapForFoldingUnitDims(ArrayRef< OpFoldResult > mixedSizes)
Get the reassociation maps to fold the result of a extract_slice (or source of a insert_slice) operat...
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNwcMinUnsignedOp >(LinalgOp op)
DistributionMethod
Scheme used to distribute loops to processors.
@ CyclicNumProcsGeNumIters
Cyclic distribution where the number of processors can be assumed to be more than or equal to the num...
@ Cyclic
Cyclic distribution where no assumption is made about the dynamic relationship between number of proc...
@ CyclicNumProcsEqNumIters
Cyclic distribution where the number of processors can be assumed to be equal to the number of iterat...
static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body)
SmallVector< Value > insertSlicesBack(OpBuilder &builder, Location loc, LinalgOp op, ValueRange operands, ValueRange results)
Creates insert_slice ops that insert results back into larger tensors they were originally extracted ...
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv3DNcdhwCdhwOp >(LinalgOp op)
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv3DNdhwcDhwcOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNwcMaxUnsignedOp >(LinalgOp op)
void offsetIndices(OpBuilder &b, LinalgOp linalgOp, ArrayRef< OpFoldResult > offests)
Add the specified offsets to any linalg.index ops contained in the given linalgOp.
static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body)
Matches sum pooling body pattern.
SmallVector< int64_t > getPackInverseDestPerm(linalg::PackOp packOp, PackingMetadata &metadata)
Compute inverse permutation for the destination tensor (i.e.
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNhwcHwcfOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv3DNcdhwFcdhwOp >(LinalgOp op)
SmallVector< std::optional< SliceParameters > > computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp, ValueRange valuesToTile, ArrayRef< OpFoldResult > ivs, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > sizeBounds, bool omitPartialTileCheck)
Computes SliceParamaters for all valuesToTile of the given linalgOp, assuming linalgOp is being fused...
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNhwcHwcfQOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNchwFchwOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNhwcFhwcQOp >(LinalgOp op)
Operation * makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile, ArrayRef< OpFoldResult > tileSizes, AffineMap map, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > ubs, ArrayRef< OpFoldResult > subShapeSizes, bool omitPartialTileCheck)
Creates an extract_slice/subview op for a single valueToTile with builder.
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv3DOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNchwMaxOp >(LinalgOp op)
static bool convLayoutMatches(ArrayRef< ArrayRef< AffineExpr > > mapListExpected, ArrayAttr indexingMaps, MLIRContext *context)
Returns true if the given indexing maps matches with the expected indexing maps.
static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body)
static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, unsigned fDim, unsigned oDim, int64_t &dilation, int64_t &stride)
Given an array of AffineMaps indexingMaps verify the following commutatively:- indexingMaps[0]....
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNhwcSumOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv3DNdhwcDhwcfQOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNhwgcGfhwcQOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNhwcMinOp >(LinalgOp op)
static Operation * materializeTiledShape(OpBuilder &builder, Location loc, Value valueToTile, const SliceParameters &sliceParams)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNgchwFgchwOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNgchwGfchwOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNhwcMaxOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNdhwcMinOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNhwcMinUnsignedOp >(LinalgOp op)
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, Value source, Value padding, bool nofold, ValueRange typeDynDims={})
Create a tensor::PadOp that pads source to the shape of type whose sizes are assumed to be greater th...
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv3DNdhwcDhwcfOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNwcSumOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv1DOp >(LinalgOp op)
static int64_t isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim)
Check if expr is either:
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNdhwcSumOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv2DNhwcHwcQOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNhwcFhwcOp >(LinalgOp op)
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNwcMaxOp >(LinalgOp op)
void updateBoundsForCyclicDistribution(OpBuilder &builder, Location loc, Value procId, Value nprocs, Value &lb, Value &ub, Value &step)
Update the lb, ub and step to get per processor lb, ub and step.
SmallVector< Type > getTensorOutputTypes(LinalgOp op, ValueRange operands)
Returns the list of tensor output types produced when the given structured operation op is applied to...
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNcwSumOp >(LinalgOp op)
SliceParameters computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile, ArrayRef< OpFoldResult > tileSizes, AffineMap map, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > ubs, ArrayRef< OpFoldResult > subShapeSizes, bool omitPartialTileCheck)
Computes SliceParameters for a single valueToTile assuming that its user is being tiled with the give...
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNdhwcMaxOp >(LinalgOp op)
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
PadOp createPadHighOp(RankedTensorType resType, Value source, Value pad, bool nofold, Location loc, OpBuilder &builder, ValueRange dynOutDims={})
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)).
detail::NameOpMatcher m_Op(StringRef opName)
Matches a named operation.
@ Mul
RHS of mul is always a constant or a symbolic expression.
SmallVector< int64_t > computePermutationVector(int64_t permSize, ArrayRef< int64_t > positions, ArrayRef< int64_t > desiredPositions)
Return a permutation vector of size permSize that would result in moving positions into desiredPositions.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
detail::op_matcher< OpClass > m_Op()
Matches the given OpClass.
SmallVector< int64_t, 2 > ReassociationIndices
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
llvm::function_ref< Fn > function_ref
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Helper struct to build simple arithmetic quantities with minimal type inference support.
Value _and(Value lhs, Value rhs)
Value slt(Value lhs, Value rhs)
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or static.
A struct containing dilations and strides inferred from convolution ops.
Utility class used to generate nested loops with ranges described by loopRanges and loop type describ...
static void doit(OpBuilder &b, Location loc, ArrayRef< Range > loopRanges, LinalgOp linalgOp, ArrayRef< utils::IteratorType > iteratorTypes, function_ref< scf::ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilderFn, ArrayRef< linalg::ProcInfo > procInfo={})
Callback function type used to get processor ID, and number of processors used for distribution for all parallel loops generated.
DistributionMethod distributionMethod
static std::optional< BinaryOpKind > matchAsScalarBinaryOp(GenericOp op)
Matches the given linalg op if its body is performing a binary operation on int or float scalar values, returning the binary op kind if successful.
A struct containing offsets-sizes-strides arguments of the tiled shape.
SmallVector< OpFoldResult > strides
SmallVector< OpFoldResult > sizes
SmallVector< OpFoldResult > offsets