33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/Support/Debug.h"
37#define DEBUG_TYPE "linalg-utils"
65 assert(cast<AffineConstantExpr>(expr.
getRHS()).getValue() > 0 &&
66 "nonpositive multiplying coefficient");
77 TileCheck t(tileSizes);
92std::optional<RegionMatcher::BinaryOpKind>
94 auto ®ion = op.getRegion();
95 if (!region.hasOneBlock())
113 if (addPattern.match(&ops.back()))
130 for (
Range range : ranges) {
149static SmallVector<int64_t>
152 PackingMetadata &packingMetadata) {
153 int64_t numPackedDims = innerDimsPos.size();
155 llvm::to_vector(llvm::seq<int64_t>(rank - numPackedDims, rank));
156 packingMetadata = computePackingMetadata(rank, innerDimsPos);
161 if (!outerPerm.empty())
168 return packInverseDestPermutation;
175 PackingMetadata &metadata) {
177 int64_t packedRank = packOp.getDestType().getRank();
182 return packInvDestPerm;
186 PackingMetadata &metadata) {
187 int64_t packedRank = unpackOp.getSourceType().getRank();
192 return unpackInvSrcPerm;
196 return llvm::all_of(op.getIndexingMapsArray(), [](
AffineMap m) {
197 return m.isProjectedPermutation(true);
205 if (!(isa<arith::ConstantOp, func::ConstantOp, tensor::ExtractOp,
206 linalg::YieldOp, linalg::IndexOp, AffineApplyOp>(op) ||
208 llvm::any_of(op.getResultTypes(),
209 [](
Type type) { return !type.isIntOrIndexOrFloat(); }))
216 if (op.getNumLoops() != op.getNumParallelLoops())
223 for (
OpOperand &opOperand : op.getDpsInitsMutable()) {
224 if (!op.getMatchingIndexingMap(&opOperand).isPermutation())
231 return iteratorType == utils::IteratorType::parallel;
235 return iteratorType == utils::IteratorType::reduction;
250 if (!dyn_cast_if_present<arith::ExtFOp>(defOp) &&
251 !dyn_cast_if_present<arith::ExtSIOp>(defOp) &&
252 !dyn_cast_if_present<arith::ExtUIOp>(defOp)) {
255 return dyn_cast<BlockArgument>(defOp->
getOperand(0));
265 if (!isa_and_present<arith::AddIOp, arith::AddFOp>(addOp))
269 if (!isa_and_present<arith::MulIOp, arith::MulFOp>(mulOp))
278 if (!lhsBlockArg || !rhsBlockArg || !outBlockArg ||
287template <
typename... OpTypes>
290 if (!(isa_and_present<OpTypes>(defOp) || ...))
297 if (!lhsArg || !rhsArg || lhsArg.
getOwner() != body ||
334 auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue();
335 if (dimIndex < affineMap.getNumResults())
336 return affineMap.getResult(dimIndex);
346 if ((dim = dyn_cast<AffineDimExpr>(expr)))
349 auto mulExpr = dyn_cast<AffineBinaryOpExpr>(expr);
357 if (((dim = dyn_cast<AffineDimExpr>(
lhs)) &&
358 (cst = dyn_cast<AffineConstantExpr>(
rhs))) ||
359 ((dim = dyn_cast<AffineDimExpr>(
rhs)) &&
360 (cst = dyn_cast<AffineConstantExpr>(
lhs))))
388 unsigned fDim,
unsigned oDim,
390 unsigned inputMapIdx = 0, filterMapIdx = 1,
391 outputMapIdx = indexingMaps.size() - 1;
393 auto addExpr = dyn_cast_or_null<AffineBinaryOpExpr>(inpExpr);
401 if (c0 == -1 || c1 == -1)
406 if (dim0 == fExpr && dim1 == oExpr) {
411 if (dim1 == fExpr && dim0 == oExpr) {
429 return indexingMaps ==
431 context, llvm::to_vector<4>(llvm::map_range(
433 return AffineMapAttr::get(m);
444 if (isa<linalg::Conv1DOp>(op))
448 "expected op to implement ConvolutionOpInterface");
455 ArrayAttr indexingMaps = op.getIndexingMaps();
459 0, (*dilations)[0], (*strides)[0]))
463 {{
W * (*strides)[0] + w * (*dilations)[0]},
466 indexingMaps, context))
469 Block *body = op.getBlock();
471 Value yieldVal = yieldOp.getOperand(0);
482 if (isa<linalg::Conv1DNwcWcfOp>(op))
486 "expected op to implement ConvolutionOpInterface");
496 ArrayAttr indexingMaps = op.getIndexingMaps();
500 1, (*dilations)[0], (*strides)[0]))
504 {{N,
W * (*strides)[0] + w * (*dilations)[0], c},
507 indexingMaps, context))
510 Block *body = op.getBlock();
512 Value yieldVal = yieldOp.getOperand(0);
523 if (isa<linalg::Conv1DNcwFcwOp>(op))
527 "expected op to implement ConvolutionOpInterface");
537 ArrayAttr indexingMaps = op.getIndexingMaps();
541 2, (*dilations)[0], (*strides)[0]))
545 {{N, c,
W * (*strides)[0] + w * (*dilations)[0]},
548 indexingMaps, context))
551 Block *body = op.getBlock();
553 Value yieldVal = yieldOp.getOperand(0);
564 if (isa<linalg::Conv2DOp>(op))
568 "expected op to implement ConvolutionOpInterface");
577 ArrayAttr indexingMaps = op.getIndexingMaps();
581 0, (*dilations)[0], (*strides)[0]))
585 1, (*dilations)[1], (*strides)[1]))
589 {{H * (*strides)[0] + h * (*dilations)[0],
590 W * (*strides)[1] + w * (*dilations)[1]},
593 indexingMaps, context))
596 Block *body = op.getBlock();
598 Value yieldVal = yieldOp.getOperand(0);
609 if (isa<linalg::Conv3DOp>(op))
613 "expected op to implement ConvolutionOpInterface");
624 ArrayAttr indexingMaps = op.getIndexingMaps();
628 0, (*dilations)[0], (*strides)[0]))
632 1, (*dilations)[1], (*strides)[1]))
636 2, (*dilations)[2], (*strides)[2]))
640 {{D * (*strides)[0] + d * (*dilations)[0],
641 H * (*strides)[1] + h * (*dilations)[1],
642 W * (*strides)[2] + w * (*dilations)[2]},
645 indexingMaps, context))
648 Block *body = op.getBlock();
650 Value yieldVal = yieldOp.getOperand(0);
661 if (isa<linalg::DepthwiseConv1DNcwCwOp>(op))
665 "expected op to implement ConvolutionOpInterface");
674 ArrayAttr indexingMaps = op.getIndexingMaps();
678 2, (*dilations)[0], (*strides)[0]))
682 {{N, C,
W * (*strides)[0] + w * (*dilations)[0]},
685 indexingMaps, context))
688 Block *body = op.getBlock();
690 Value yieldVal = yieldOp.getOperand(0);
701 if (isa<linalg::DepthwiseConv1DNwcWcOp>(op))
705 "expected op to implement ConvolutionOpInterface");
714 ArrayAttr indexingMaps = op.getIndexingMaps();
718 1, (*dilations)[0], (*strides)[0]))
722 {{N,
W * (*strides)[0] + w * (*dilations)[0], C},
725 indexingMaps, context))
728 Block *body = op.getBlock();
730 Value yieldVal = yieldOp.getOperand(0);
741 if (isa<linalg::DepthwiseConv1DNwcWcmOp>(op))
745 "expected op to implement ConvolutionOpInterface");
755 ArrayAttr indexingMaps = op.getIndexingMaps();
759 1, (*dilations)[0], (*strides)[0]))
763 {{N,
W * (*strides)[0] + w * (*dilations)[0], C},
766 indexingMaps, context))
769 Block *body = op.getBlock();
771 Value yieldVal = yieldOp.getOperand(0);
782 if (isa<linalg::DepthwiseConv2DNchwChwOp>(op))
786 "expected op to implement ConvolutionOpInterface");
797 ArrayAttr indexingMaps = op.getIndexingMaps();
801 2, (*dilations)[0], (*strides)[0]))
805 3, (*dilations)[1], (*strides)[1]))
809 {{N, C, H * (*strides)[0] + h * (*dilations)[0],
810 W * (*strides)[1] + w * (*dilations)[1]},
813 indexingMaps, context))
816 Block *body = op.getBlock();
818 Value yieldVal = yieldOp.getOperand(0);
832 if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op))
836 "expected op to implement ConvolutionOpInterface");
850 ArrayAttr indexingMaps = op.getIndexingMaps();
854 1, (*dilations)[0], (*strides)[0]))
858 2, (*dilations)[1], (*strides)[1]))
862 3, (*dilations)[2], (*strides)[2]))
866 {{N, D * (*strides)[0] + d * (*dilations)[0],
867 H * (*strides)[1] + h * (*dilations)[1],
868 W * (*strides)[2] + w * (*dilations)[2], C},
870 {N, D, H,
W, C, CM}},
871 indexingMaps, context))
874 Block *body = op.getBlock();
876 Value yieldVal = yieldOp.getOperand(0);
887 if (isa<linalg::PoolingNhwcMaxOp>(op))
891 "expected op to implement ConvolutionOpInterface");
902 ArrayAttr indexingMaps = op.getIndexingMaps();
906 1, (*dilations)[0], (*strides)[0]))
910 2, (*dilations)[1], (*strides)[1]))
914 {{N, H * (*strides)[0] + h * (*dilations)[0],
915 W * (*strides)[1] + w * (*dilations)[1], C},
918 indexingMaps, context))
921 Block *body = op.getBlock();
923 Value yieldVal = yieldOp.getOperand(0);
934 if (isa<linalg::PoolingNhwcMinOp>(op))
938 "expected op to implement ConvolutionOpInterface");
949 ArrayAttr indexingMaps = op.getIndexingMaps();
953 1, (*dilations)[0], (*strides)[0]))
957 2, (*dilations)[1], (*strides)[1]))
961 {{N, H * (*strides)[0] + h * (*dilations)[0],
962 W * (*strides)[1] + w * (*dilations)[1], C},
965 indexingMaps, context))
968 Block *body = op.getBlock();
970 Value yieldVal = yieldOp.getOperand(0);
981 if (isa<linalg::PoolingNhwcSumOp>(op))
985 "expected op to implement ConvolutionOpInterface");
996 ArrayAttr indexingMaps = op.getIndexingMaps();
1000 1, (*dilations)[0], (*strides)[0]))
1004 2, (*dilations)[1], (*strides)[1]))
1008 {{N, H * (*strides)[0] + h * (*dilations)[0],
1009 W * (*strides)[1] + w * (*dilations)[1], C},
1012 indexingMaps, context))
1015 Block *body = op.getBlock();
1016 auto yieldOp = cast<linalg::YieldOp>(body->
getTerminator());
1017 Value yieldVal = yieldOp.getOperand(0);
1028 if (isa<linalg::PoolingNhwcMaxUnsignedOp>(op))
1032 "expected op to implement ConvolutionOpInterface");
1043 ArrayAttr indexingMaps = op.getIndexingMaps();
1047 1, (*dilations)[0], (*strides)[0]))
1051 2, (*dilations)[1], (*strides)[1]))
1055 {{N, H * (*strides)[0] + h * (*dilations)[0],
1056 W * (*strides)[1] + w * (*dilations)[1], C},
1059 indexingMaps, context))
1062 Block *body = op.getBlock();
1063 auto yieldOp = cast<linalg::YieldOp>(body->
getTerminator());
1064 Value yieldVal = yieldOp.getOperand(0);
1075 if (isa<linalg::PoolingNhwcMinUnsignedOp>(op))
1079 "expected op to implement ConvolutionOpInterface");
1090 ArrayAttr indexingMaps = op.getIndexingMaps();
1094 1, (*dilations)[0], (*strides)[0]))
1098 2, (*dilations)[1], (*strides)[1]))
1102 {{N, H * (*strides)[0] + h * (*dilations)[0],
1103 W * (*strides)[1] + w * (*dilations)[1], C},
1106 indexingMaps, context))
1109 Block *body = op.getBlock();
1110 auto yieldOp = cast<linalg::YieldOp>(body->
getTerminator());
1111 Value yieldVal = yieldOp.getOperand(0);
1119 auto sliceOp = source.
getDefiningOp<tensor::ExtractSliceOp>();
1125 Value current = sliceOp.getSource();
1130 OpResult opResult = cast<OpResult>(current);
1131 current = linalgOp.getDpsInitOperand(opResult.
getResultNumber())->get();
1133 auto padOp = current ? current.
getDefiningOp<tensor::PadOp>() :
nullptr;
1142 if (sliceOp.getSource().getType() != type)
1147 if (llvm::any_of(padOp.getMixedLowPad(), [](
OpFoldResult ofr) {
1148 return getConstantIntValue(ofr) != static_cast<int64_t>(0);
1155 auto padOpSliceOp = padOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
1156 if (!padOpSliceOp ||
1157 sliceOp.getMixedSizes().size() != padOpSliceOp.getMixedSizes().size())
1164 llvm::zip(sliceOp.getMixedSizes(), padOpSliceOp.getMixedSizes()),
1165 [](std::tuple<OpFoldResult, OpFoldResult> it) {
1166 return !isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it));
1173 Value padOpPad = padOp.getConstantPaddingValue();
1180 return sliceOp.getSource();
1184 auto memrefTypeTo = cast<MemRefType>(to.
getType());
1186 auto memrefTypeFrom = cast<MemRefType>(from.
getType());
1187 assert(memrefTypeFrom.getRank() == memrefTypeTo.getRank() &&
1188 "`from` and `to` memref must have the same rank");
1194 utils::IteratorType::parallel);
1195 return linalg::GenericOp::create(
1202 linalg::YieldOp::create(
b, loc, args.front());
1215 assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
1216 "expected as many entries for proc info as number of loops, even if "
1217 "they are null entries");
1219 if (!linalgOp.hasPureBufferSemantics())
1220 llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
1224 b, loc, lbs, ubs, steps, iterArgInitValues,
1226 assert(iterArgs.size() == iterArgInitValues.size() &&
1227 "expect the number of output tensors and iter args to match");
1229 if (!iterArgs.empty()) {
1230 operandValuesToUse = linalgOp.getDpsInputs();
1231 operandValuesToUse.append(iterArgs.begin(), iterArgs.end());
1233 return bodyBuilderFn(
b, loc, ivs, operandValuesToUse);
1236 if (loopNest.
loops.empty() || procInfo.empty())
1240 for (
const auto &loop : llvm::enumerate(loopNest.
loops)) {
1241 if (procInfo[loop.index()].distributionMethod ==
1243 mapLoopToProcessorIds(loop.value(), procInfo[loop.index()].procId,
1244 procInfo[loop.index()].nprocs);
1259 if (!linalgOp.hasPureBufferSemantics())
1260 llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
1261 assert(iterArgInitValues.empty() &&
"unexpected AffineForOp init values");
1267 constantSteps.reserve(steps.size());
1268 for (
Value v : steps) {
1270 assert(constVal.has_value() &&
"Affine loops require constant steps");
1271 constantSteps.push_back(constVal.value());
1276 bodyBuilderFn(
b, loc, ivs,
1277 linalgOp->getOperands());
1309 assert(lbs.size() == ubs.size());
1310 assert(lbs.size() == steps.size());
1311 assert(lbs.size() == iteratorTypes.size());
1312 assert(procInfo.empty() || (lbs.size() == procInfo.size()));
1316 if (iteratorTypes.empty()) {
1317 bodyBuilderFn(
b, loc, ivStorage);
1325 b, loc, lbs.take_front(), ubs.take_front(), steps.take_front(),
1327 ivStorage.append(ivs.begin(), ivs.end());
1328 generateParallelLoopNest(
1329 b, loc, lbs.drop_front(), ubs.drop_front(), steps.drop_front(),
1330 iteratorTypes.drop_front(),
1331 procInfo.empty() ? procInfo : procInfo.drop_front(),
1332 bodyBuilderFn, ivStorage);
1337 unsigned nLoops = iteratorTypes.size();
1338 unsigned numProcessed = 0;
1340 if (procInfo.empty()) {
1343 distributionMethod = procInfo.front().distributionMethod;
1352 auto remainderProcInfo =
1353 procInfo.empty() ? procInfo : procInfo.drop_front(numProcessed);
1354 switch (distributionMethod) {
1358 scf::ParallelOp::create(
1359 b, loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
1360 steps.take_front(numProcessed),
1362 ivStorage.append(localIvs.begin(), localIvs.end());
1363 generateParallelLoopNest(
1364 nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
1365 ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
1366 iteratorTypes.drop_front(numProcessed), remainderProcInfo,
1367 bodyBuilderFn, ivStorage);
1374 scf::ParallelOp::create(
1375 b, loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
1376 steps.take_front(numProcessed),
1378 ivStorage.append(localIvs.begin(), localIvs.end());
1379 generateParallelLoopNest(
1380 nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
1381 ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
1382 iteratorTypes.drop_front(numProcessed), remainderProcInfo,
1383 bodyBuilderFn, ivStorage);
1390 Value cond = ab.
slt(lbs[0], ubs[0]);
1391 for (
unsigned i = 1; i < numProcessed; ++i)
1392 cond = ab.
_and(cond, ab.
slt(lbs[i], ubs[i]));
1393 ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
1396 ubs.drop_front(numProcessed),
1397 steps.drop_front(numProcessed),
1398 iteratorTypes.drop_front(numProcessed),
1399 remainderProcInfo, bodyBuilderFn, ivStorage);
1407 ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
1409 b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
1410 steps.drop_front(numProcessed), iteratorTypes.drop_front(numProcessed),
1411 remainderProcInfo, bodyBuilderFn, ivStorage);
1426 if (!linalgOp.hasPureBufferSemantics())
1427 llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
1428 assert(iterArgInitValues.empty() &&
"unexpected ParallelOp init values");
1430 assert(iteratorTypes.size() >= loopRanges.size() &&
1431 "expected iterator type for all ranges");
1432 assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
1433 "expected proc information for all loops when present");
1434 iteratorTypes = iteratorTypes.take_front(loopRanges.size());
1436 unsigned numLoops = iteratorTypes.size();
1437 ivs.reserve(numLoops);
1438 lbsStorage.reserve(numLoops);
1439 ubsStorage.reserve(numLoops);
1440 stepsStorage.reserve(numLoops);
1443 unpackRanges(
b, loc, loopRanges, lbsStorage, ubsStorage, stepsStorage);
1446 for (
const auto &it : llvm::enumerate(procInfo)) {
1449 b, loc, it.value().procId, it.value().nprocs, lbsStorage[it.index()],
1450 ubsStorage[it.index()], stepsStorage[it.index()]);
1453 ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
1455 b, loc, lbs, ubs, steps, iteratorTypes, procInfo,
1457 bodyBuilderFn(
b, loc, ivs, linalgOp->getOperands());
1461 assert(ivs.size() == iteratorTypes.size() &&
"did not generate enough loops");
1467 auto shapedType = dyn_cast<ShapedType>(valueToTile.
getType());
1469 .Case([&](MemRefType) {
1470 return memref::SubViewOp::create(
1471 builder, loc, valueToTile, sliceParams.
offsets,
1474 .Case([&](RankedTensorType) {
1475 return tensor::ExtractSliceOp::create(
1476 builder, loc, valueToTile, sliceParams.
offsets,
1479 .DefaultUnreachable(
"Unexpected shaped type");
1488 bool omitPartialTileCheck) {
1491 ubs, subShapeSizes, omitPartialTileCheck);
1500 bool omitPartialTileCheck) {
1501 auto shapedType = dyn_cast<ShapedType>(valueToTile.
getType());
1502 assert(shapedType &&
"only shaped types can be tiled");
1504 int64_t rank = shapedType.getRank();
1508 sliceParams.
offsets.reserve(rank);
1509 sliceParams.
sizes.reserve(rank);
1510 sliceParams.
strides.reserve(rank);
1511 for (
unsigned r = 0; r < rank; ++r) {
1512 LLVM_DEBUG(llvm::dbgs() <<
"computeSliceParameters: for dim#" << r);
1516 sliceParams.
sizes.push_back(dim);
1518 LLVM_DEBUG(llvm::dbgs() <<
": not tiled: use size: " << dim <<
"\n");
1521 LLVM_DEBUG(llvm::dbgs() <<
": tiled: figure out subsize...\n");
1526 LLVM_DEBUG(llvm::dbgs() <<
"computeSliceParameters: submap: " << m <<
"\n");
1531 [[maybe_unused]]
auto res = m.constantFold(zeros, mAtZero);
1532 assert(succeeded(res) &&
"affine_map must be evaluatable (not symbols)");
1534 cast<IntegerAttr>(mAtZero[0]).getValue().getSExtValue();
1536 rewriter, loc, m.getResult(0) - mAtZeroInt, lbs);
1537 sliceParams.
offsets.push_back(offset);
1545 LLVM_DEBUG(llvm::dbgs()
1546 <<
"computeSliceParameters: raw size: " << size <<
"\n");
1547 LLVM_DEBUG(llvm::dbgs()
1548 <<
"computeSliceParameters: new offset: " << offset <<
"\n");
1551 if (omitPartialTileCheck) {
1554 LLVM_DEBUG(llvm::dbgs() <<
"makeTiledShape: new size: " << size <<
"\n");
1555 sliceParams.
sizes.push_back(size);
1566 auto hasTileSizeOne = sizeCst == 1;
1567 auto dividesEvenly = sizeCst && ShapedType::isStatic(shapeSize) &&
1568 ((shapeSize % *sizeCst) == 0);
1569 if (!hasTileSizeOne && !dividesEvenly) {
1570 LLVM_DEBUG(llvm::dbgs() <<
"makeTiledShape: shapeSize=" << shapeSize
1571 <<
", size: " << size
1572 <<
": make sure in bound with affine.min\n");
1576 bindDims(context, dim0, dim1, dim2);
1607 LLVM_DEBUG(llvm::dbgs() <<
"makeTiledShape: new size: " << size <<
"\n");
1608 sliceParams.
sizes.push_back(size);
1617 for (
unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
1618 LLVM_DEBUG(llvm::dbgs() <<
"makeTiledShapes: for loop#" << idx <<
"\n");
1620 offsets.push_back(
isTiled ? ivs[idxIvs++] :
b.getIndexAttr(0));
1621 LLVM_DEBUG(llvm::dbgs()
1622 <<
"computeTileOffsets: " << offsets.back() <<
"\n");
1631 for (
unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
1638 LLVM_DEBUG(llvm::dbgs() <<
"computeTileSizes: " << sizes.back() <<
"\n");
1644 if (op.hasPureBufferSemantics())
1646 return llvm::to_vector(
1647 llvm::map_range(op.getDpsInitsMutable(), [&](
OpOperand &opOperand) {
1648 return operands[opOperand.getOperandNumber()].getType();
1655 if (op.hasPureBufferSemantics())
1658 tensorResults.reserve(results.size());
1660 unsigned resultIdx = 0;
1661 for (
OpOperand &opOperand : op.getDpsInitsMutable()) {
1664 Value outputTensor = operands[opOperand.getOperandNumber()];
1665 if (
auto sliceOp = outputTensor.
getDefiningOp<tensor::ExtractSliceOp>()) {
1667 builder, loc, sliceOp.getSource().getType(), results[resultIdx],
1668 sliceOp.getSource(), sliceOp.getOffsets(), sliceOp.getSizes(),
1669 sliceOp.getStrides(), sliceOp.getStaticOffsets(),
1670 sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
1673 tensorResults.push_back(results[resultIdx]);
1677 return tensorResults;
1685 bool omitPartialTileCheck) {
1686 assert(ivs.size() ==
static_cast<size_t>(llvm::count_if(
1687 llvm::make_range(tileSizes.begin(), tileSizes.end()),
1689 "expected as many ivs as non-zero sizes");
1698 assert(
static_cast<int64_t>(valuesToTile.size()) <=
1699 linalgOp->getNumOperands() &&
1700 "more value to tile than operands.");
1702 allSliceParams.reserve(valuesToTile.size());
1703 for (
auto [opOperand, val] :
1704 llvm::zip(linalgOp->getOpOperands(), valuesToTile)) {
1705 Value shapedOp = val;
1706 LLVM_DEBUG(llvm::dbgs() <<
"makeTiledShapes: for operand " << shapedOp);
1707 AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
1714 Type operandType = opOperand.get().getType();
1715 if (!
isTiled(map, tileSizes) && !(isa<RankedTensorType>(operandType) &&
1716 linalgOp.isDpsInit(&opOperand))) {
1717 allSliceParams.push_back(std::nullopt);
1718 LLVM_DEBUG(llvm::dbgs()
1719 <<
": not tiled: use shape: " << operandType <<
"\n");
1722 LLVM_DEBUG(llvm::dbgs() <<
": tiled: figure out subshape...\n");
1725 builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
1726 omitPartialTileCheck));
1729 return allSliceParams;
1737 bool omitPartialTileCheck) {
1740 tileSizes, sizeBounds, omitPartialTileCheck);
1742 for (
auto item : llvm::zip(valuesToTile, allSliceParameter)) {
1743 Value valueToTile = std::get<0>(item);
1744 std::optional<SliceParameters> sliceParams = std::get<1>(item);
1745 tiledShapes.push_back(
1746 sliceParams.has_value()
1762 if (!linalgOp.hasIndexSemantics())
1765 for (IndexOp indexOp : linalgOp.getBlock()->getOps<IndexOp>()) {
1766 if (indexOp.getDim() >= offsets.size() || !offsets[indexOp.getDim()])
1769 b.setInsertionPointAfter(indexOp);
1773 b, indexOp.getLoc(),
index + offset,
1774 {getAsOpFoldResult(indexOp.getResult()), offsets[indexOp.getDim()]});
1775 Value materialized =
1777 b.replaceUsesWithIf(indexOp, materialized, [&](
OpOperand &use) {
1789std::optional<SmallVector<ReassociationIndices>>
1793 for (
const auto &it : llvm::enumerate(mixedSizes)) {
1794 auto dim = it.index();
1795 auto size = it.value();
1796 curr.push_back(dim);
1797 auto attr = llvm::dyn_cast_if_present<Attribute>(size);
1798 if (attr && cast<IntegerAttr>(attr).getInt() == 1)
1801 std::swap(reassociation.back(), curr);
1806 if (!curr.empty() && !reassociation.empty())
1807 reassociation.back().append(curr.begin(), curr.end());
1808 return reassociation;
static SmallVector< int64_t > computePackUnPackPerm(int64_t rank, ArrayRef< int64_t > &innerDimsPos, ArrayRef< int64_t > &outerPerm, PackingMetadata &packingMetadata)
The permutation can be obtained from two permutations: a) Compute the permutation vector to move the ...
static bool isTiled(AffineExpr expr, ArrayRef< OpFoldResult > tileSizes)
static void unpackRanges(OpBuilder &builder, Location loc, ArrayRef< Range > ranges, SmallVectorImpl< Value > &lbs, SmallVectorImpl< Value > &ubs, SmallVectorImpl< Value > &steps)
Given a list of subview ranges, extract individual values for lower, upper bounds and steps and put t...
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
Affine binary operation expression.
AffineExpr getLHS() const
AffineExpr getRHS() const
An integer constant appearing in affine expression.
A dimensional identifier appearing in an affine expression.
unsigned getPosition() const
See documentation for AffineExprVisitorBase.
Base type for affine expression.
AffineExprKind getKind() const
Return the classification for this type.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
AffineExpr getResult(unsigned idx) const
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block * getOwner() const
Returns the block that owns this argument.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool hasOneBlock()
Return true if this region has exactly one block.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignlessIntOrFloat() const
Return true of this is a signless integer or a float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
void buildAffineLoopNest(OpBuilder &builder, Location loc, ArrayRef< int64_t > lbs, ArrayRef< int64_t > ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn=nullptr)
Builds a perfect nest of affine.for loops, i.e., each loop except the innermost one contains only ano...
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
bool isaConvolutionOpOfType< linalg::DepthwiseConv3DNdhwcDhwcmOp >(LinalgOp op, SmallVector< int64_t > *dilations, SmallVector< int64_t > *strides)
bool isaConvolutionOpOfType< linalg::PoolingNhwcSumOp >(LinalgOp op, SmallVector< int64_t > *dilations, SmallVector< int64_t > *strides)
bool isaConvolutionOpOfType< linalg::DepthwiseConv2DNchwChwOp >(LinalgOp op, SmallVector< int64_t > *dilations, SmallVector< int64_t > *strides)
SmallVector< int64_t > getUnPackInverseSrcPerm(linalg::UnPackOp, PackingMetadata &metadata)
Compute inverse permutation for the source tensor (i.e.
SmallVector< Value > makeTiledShapes(OpBuilder &builder, Location loc, LinalgOp linalgOp, ValueRange valuesToTile, ArrayRef< OpFoldResult > ivs, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > sizeBounds, bool omitPartialTileCheck)
Creates extract_slice/subview ops for all valuesToTile of the given linalgOp with builder,...
bool allIndexingsAreProjectedPermutation(LinalgOp op)
Check if all indexing maps are projected permutations.
bool isaConvolutionOpOfType< linalg::Conv1DOp >(LinalgOp op, SmallVector< int64_t > *dilations, SmallVector< int64_t > *strides)
bool isParallelIterator(utils::IteratorType iteratorType)
Check if iterator type has "parallel" semantics.
SmallVector< OpFoldResult > computeTileSizes(OpBuilder &b, Location loc, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > sizeBounds)
Computes tile sizes, given a list of tileSizes and dimension sizes (sizeBounds).
bool isaConvolutionOpOfType< linalg::Conv1DNcwFcwOp >(LinalgOp op, SmallVector< int64_t > *dilations, SmallVector< int64_t > *strides)
GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to)
Returns GenericOp that copies an n-D memref.
static void generateParallelLoopNest(OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ArrayRef< utils::IteratorType > iteratorTypes, ArrayRef< linalg::ProcInfo > procInfo, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn, SmallVectorImpl< Value > &ivStorage)
Generates a loop nest consisting of scf.parallel and scf.for, depending on the iteratorTypes.
SmallVector< OpFoldResult > computeTileOffsets(OpBuilder &b, Location loc, ArrayRef< OpFoldResult > ivs, ArrayRef< OpFoldResult > tileSizes)
Computes tile offsets, given a list of loop ivs and tileSizes.
bool isaConvolutionOpOfType< linalg::Conv2DOp >(LinalgOp op, SmallVector< int64_t > *dilations, SmallVector< int64_t > *strides)
bool isaConvolutionOpOfType< linalg::DepthwiseConv1DNwcWcOp >(LinalgOp op, SmallVector< int64_t > *dilations, SmallVector< int64_t > *strides)
static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body)
bool isaConvolutionOpOfType< linalg::Conv3DOp >(LinalgOp op, SmallVector< int64_t > *dilations, SmallVector< int64_t > *strides)
bool isaConvolutionOpOfType< linalg::PoolingNhwcMaxOp >(LinalgOp op, SmallVector< int64_t > *dilations, SmallVector< int64_t > *strides)
static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body)
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
bool hasOnlyScalarElementwiseOp(Region &r)
Detect whether r has only ConstantOp, ElementwiseMappable and YieldOp.
static AffineExpr getAffineMapDim(ArrayAttr indexingMaps, uint32_t mapIndex, uint32_t dimIndex)
static bool bodyMatcherForPoolOps(Value yieldVal, Block *body)
Utility to match block body for linalg.pool* ops.
std::optional< SmallVector< ReassociationIndices > > getReassociationMapForFoldingUnitDims(ArrayRef< OpFoldResult > mixedSizes)
Get the reassociation maps to fold the result of a extract_slice (or source of a insert_slice) operat...
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
DistributionMethod
Scheme used to distribute loops to processors.
@ CyclicNumProcsGeNumIters
Cyclic distribution where the number of processors can be assumed to be more than or equal to the num...
@ Cyclic
Cyclic distribution where no assumption is made about the dynamic relationship between number of proc...
@ CyclicNumProcsEqNumIters
Cyclic distribution where the number of processors can be assumed to be equal to the number of iterat...
static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body)
SmallVector< Value > insertSlicesBack(OpBuilder &builder, Location loc, LinalgOp op, ValueRange operands, ValueRange results)
Creates insert_slice ops that insert results back into larger tensors they were originally extracted ...
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
bool isaConvolutionOpOfType< linalg::DepthwiseConv1DNcwCwOp >(LinalgOp op, SmallVector< int64_t > *dilations, SmallVector< int64_t > *strides)
static BlockArgument getBlockArgumentWithOptionalExtOps(Value val)
Returns the BlockArgument that leads to val, if any.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
void offsetIndices(OpBuilder &b, LinalgOp linalgOp, ArrayRef< OpFoldResult > offests)
Add the specified offsets to any linalg.index ops contained in the given linalgOp.
static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body)
SmallVector< int64_t > getPackInverseDestPerm(linalg::PackOp packOp, PackingMetadata &metadata)
Compute inverse permutation for the destination tensor (i.e.
bool isaConvolutionOpOfType< linalg::DepthwiseConv1DNwcWcmOp >(LinalgOp op, SmallVector< int64_t > *dilations, SmallVector< int64_t > *strides)
bool isaConvolutionOpOfType< linalg::PoolingNhwcMinOp >(LinalgOp op, SmallVector< int64_t > *dilations, SmallVector< int64_t > *strides)
static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body)
Utility to match block body for convolution ops.
SmallVector< std::optional< SliceParameters > > computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp, ValueRange valuesToTile, ArrayRef< OpFoldResult > ivs, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > sizeBounds, bool omitPartialTileCheck)
Computes SliceParameters for all valuesToTile of the given linalgOp, assuming linalgOp is being fused...
Operation * makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile, ArrayRef< OpFoldResult > tileSizes, AffineMap map, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > ubs, ArrayRef< OpFoldResult > subShapeSizes, bool omitPartialTileCheck)
Creates an extract_slice/subview op for a single valueToTile with builder.
static bool convLayoutMatches(ArrayRef< ArrayRef< AffineExpr > > mapListExpected, ArrayAttr indexingMaps, MLIRContext *context)
Returns true if the given indexing maps match the expected indexing maps.
static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body)
static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, unsigned fDim, unsigned oDim, int64_t &dilation, int64_t &stride)
Given an array of AffineMaps indexingMaps verify the following commutatively:- indexingMaps[0]....
bool isaConvolutionOpOfType< linalg::PoolingNhwcMinUnsignedOp >(LinalgOp op, SmallVector< int64_t > *dilations, SmallVector< int64_t > *strides)
static Operation * materializeTiledShape(OpBuilder &builder, Location loc, Value valueToTile, const SliceParameters &sliceParams)
bool isaConvolutionOpOfType< linalg::Conv1DNwcWcfOp >(LinalgOp op, SmallVector< int64_t > *dilations, SmallVector< int64_t > *strides)
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, Value source, Value padding, bool nofold, ValueRange typeDynDims={})
Create a tensor::PadOp that pads source to the shape of type whose sizes are assumed to be greater th...
static int64_t isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim)
Check if expr is either:
void updateBoundsForCyclicDistribution(OpBuilder &builder, Location loc, Value procId, Value nprocs, Value &lb, Value &ub, Value &step)
Update the lb, ub and step to get per processor lb, ub and step.
SmallVector< Type > getTensorOutputTypes(LinalgOp op, ValueRange operands)
Returns the list of tensor output types produced when the given structured operation op is applied to...
SliceParameters computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile, ArrayRef< OpFoldResult > tileSizes, AffineMap map, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > ubs, ArrayRef< OpFoldResult > subShapeSizes, bool omitPartialTileCheck)
Computes SliceParameters for a single valueToTile assuming that its user is being tiled with the give...
bool isaConvolutionOpOfType< linalg::PoolingNhwcMaxUnsignedOp >(LinalgOp op, SmallVector< int64_t > *dilations, SmallVector< int64_t > *strides)
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
PadOp createPadHighOp(RankedTensorType resType, Value source, Value pad, bool nofold, Location loc, OpBuilder &builder, ValueRange dynOutDims={})
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
detail::NameOpMatcher m_Op(StringRef opName)
Matches a named operation.
@ Mul
RHS of mul is always a constant or a symbolic expression.
SmallVector< int64_t > computePermutationVector(int64_t permSize, ArrayRef< int64_t > positions, ArrayRef< int64_t > desiredPositions)
Return a permutation vector of size permSize that would result in moving positions into desiredPositi...
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
detail::op_matcher< OpClass > m_Op()
Matches the given OpClass.
SmallVector< int64_t, 2 > ReassociationIndices
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
llvm::function_ref< Fn > function_ref
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Helper struct to build simple arithmetic quantities with minimal type inference support.
Value _and(Value lhs, Value rhs)
Value slt(Value lhs, Value rhs)
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Utility class used to generate nested loops with ranges described by loopRanges and loop type describ...
static void doit(OpBuilder &b, Location loc, ArrayRef< Range > loopRanges, LinalgOp linalgOp, ArrayRef< utils::IteratorType > iteratorTypes, function_ref< scf::ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilderFn, ArrayRef< linalg::ProcInfo > procInfo={})
Callback function type used to get processor ID, and number of processors used for distribution for a...
DistributionMethod distributionMethod
static std::optional< BinaryOpKind > matchAsScalarBinaryOp(GenericOp op)
Matches the given linalg op if its body is performing binary operation on int or float scalar values ...
A struct containing offsets-sizes-strides arguments of the tiled shape.
SmallVector< OpFoldResult > strides
SmallVector< OpFoldResult > sizes
SmallVector< OpFoldResult > offsets