MLIR 22.0.0git
Tiling.cpp
Go to the documentation of this file.
1//===- Tiling.cpp - Implementation of linalg Tiling -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the linalg dialect Tiling pass.
10//
11//===----------------------------------------------------------------------===//
12
24#include "mlir/IR/AffineExpr.h"
25#include "mlir/IR/AffineMap.h"
26#include "mlir/IR/ValueRange.h"
28#include "llvm/ADT/STLExtras.h"
29#include <utility>
30
31namespace mlir {
32#define GEN_PASS_DEF_LINALGTILINGPASS
33#include "mlir/Dialect/Linalg/Passes.h.inc"
34} // namespace mlir
35
36using namespace mlir;
37using namespace mlir::affine;
38using namespace mlir::linalg;
39using namespace mlir::scf;
40
41#define DEBUG_TYPE "linalg-tiling"
42
43std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
45 ArrayRef<OpFoldResult> allShapeSizes,
46 ArrayRef<OpFoldResult> allTileSizes) {
47 assert(allTileSizes.size() == map.getNumResults());
48 // Apply `map` to get shape sizes in loop order.
49 SmallVector<OpFoldResult> shapeSizes =
50 makeComposedFoldedMultiResultAffineApply(b, loc, map, allShapeSizes);
51 SmallVector<OpFoldResult> tileSizes(allTileSizes);
52
53 // Traverse the tile sizes, which are in loop order, erase zeros everywhere.
54 LoopIndexToRangeIndexMap loopIndexToRangeIndex;
55 for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) {
56 if (getConstantIntValue(tileSizes[idx - zerosCount]) ==
57 static_cast<int64_t>(0)) {
58 shapeSizes.erase(shapeSizes.begin() + idx - zerosCount);
59 tileSizes.erase(tileSizes.begin() + idx - zerosCount);
60 ++zerosCount;
61 continue;
62 }
63 loopIndexToRangeIndex[idx] = idx - zerosCount;
64 }
65
66 // Create a new range with the applied tile sizes.
68 for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx)
69 res.push_back(Range{b.getIndexAttr(0), shapeSizes[idx], tileSizes[idx]});
70 return std::make_tuple(res, loopIndexToRangeIndex);
71}
72
74 RewriterBase &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
75 const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
76 SmallVector<Value> allIvs(op.getNumLoops(), nullptr);
77 for (auto en : enumerate(allIvs)) {
78 auto rangeIndex = loopIndexToRangeIndex.find(en.index());
79 if (rangeIndex == loopIndexToRangeIndex.end())
80 continue;
81 en.value() = ivs[rangeIndex->second];
82 }
83 offsetIndices(b, op, getAsOpFoldResult(allIvs));
84}
85
86/// Asserts that the given index-typed value is strictly positive. If the value
87/// is an attribute, asserts at compile time, otherwise emits an assertion
88/// checked at runtime.
90 OpFoldResult value) {
91 if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
92 assert(cast<IntegerAttr>(attr).getValue().isStrictlyPositive() &&
93 "expected strictly positive tile size and divisor");
94 return;
95 }
96
98 Value condition = arith::CmpIOp::create(b, arith::CmpIPredicate::sgt,
99 cast<Value>(value), zero);
100 cf::AssertOp::create(
101 b, condition,
102 b.getStringAttr("expected strictly positive tile size and divisor"));
103}
104
105FailureOr<StaticContinuousTileSizeSpecification>
107 unsigned targetSize) {
108
109 assert(!op.hasDynamicShape() &&
110 "cannot compute static multi-tile sizes for an op with dynamic shape");
111 assert(targetSize > 0 && "target size must be non-negative");
112 assert(dimension < op.getNumLoops() && "dimension overflow");
113
115 int64_t loopRange = op.getStaticLoopRanges()[dimension];
116 int64_t tripCount = loopRange / targetSize;
117
118 unsigned tileSize = targetSize;
119
120 spec.tileSizes.push_back(tileSize);
121 spec.tripCounts.push_back(tripCount);
122
123 int64_t remainderChunk = loopRange % targetSize;
124
125 while (tileSize > 1 && remainderChunk != 0) {
126
127 uint64_t maxPower = llvm::bit_floor(tileSize);
128 tileSize = maxPower == tileSize ? maxPower >> 1 : maxPower;
129
130 tripCount = remainderChunk / tileSize;
131
132 if (tripCount > 0) {
133 spec.tileSizes.push_back(tileSize);
134 spec.tripCounts.push_back(tripCount);
135 }
136
137 remainderChunk = remainderChunk % tileSize;
138 }
139
140 auto tripCountCheck = [&](SmallVector<int64_t> tileSizes,
141 SmallVector<int64_t> tripCounts,
142 int64_t range) -> bool {
143 int64_t computedRange = 0;
144 for (auto [tileSize, tripCount] : llvm::zip(tileSizes, tripCounts))
145 computedRange += tileSize * tripCount;
146 return range == computedRange;
147 };
148
149 if (!tripCountCheck(spec.tileSizes, spec.tripCounts, loopRange))
150 return failure();
151
152 return spec;
153}
154
155FailureOr<ContinuousTileSizeSpecification>
157 unsigned dimension,
158 OpFoldResult targetSize,
159 bool emitAssertions) {
160
161 SmallVector<Range> loopRanges = op.getIterationDomain(builder);
162 unsigned numLoops = loopRanges.size();
163
164 // Bail out on dimension overflow.
165 if (dimension >= numLoops)
166 return failure();
167
168 // The code below works only on values.
169 Location loc = op->getLoc();
170 ImplicitLocOpBuilder b(loc, builder);
171 if (emitAssertions) {
172 emitIsPositiveIndexAssertion(b, targetSize);
173 }
174 Value targetSizeValue =
175 getValueOrCreateConstantIndexOp(builder, loc, targetSize);
176
177 // Find the trip count of the iteration space dimension for which the tile
178 // sizes are computed.
179 Value loopRange =
180 getValueOrCreateConstantIndexOp(b, loc, loopRanges[dimension].size);
182
183 // Compute the tile sizes and the respective numbers of tiles.
184 AffineExpr s0 = b.getAffineSymbolExpr(0);
185 AffineExpr s1 = b.getAffineSymbolExpr(1);
186 auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
187 return affine::makeComposedAffineApply(b, b.getLoc(), expr, ofrs);
188 };
189
190 Value tripCountValue = apply(s0.floorDiv(s1), {loopRange, targetSizeValue});
191 Value remainderChunkValue = apply(s0 % s1, {loopRange, targetSizeValue});
192
194 b, b.getLoc(), s0.floorDiv(s1), {loopRange, targetSizeValue});
195
196 // emitAssertions above already asserts that targetSize is
197 // a poistive integer.
198 uint64_t tileSizeInt = *getConstantIntValue(targetSizeValue);
199
200 assert(tileSizeInt > 0 && "target size must be non-negative");
201
202 spec.tileSizes.push_back(targetSizeValue);
203 spec.tripCounts.push_back(tripCountValue);
204
205 while (tileSizeInt > 1) {
206 uint64_t maxPower = llvm::bit_floor(tileSizeInt);
207 tileSizeInt = maxPower == tileSizeInt ? maxPower >> 1 : maxPower;
208 auto constStepOp =
209 builder.createOrFold<arith::ConstantIndexOp>(b.getLoc(), tileSizeInt);
210 tripCountValue = apply(s0.floorDiv(s1), {remainderChunkValue, constStepOp});
211
213 b, b.getLoc(), s0.floorDiv(s1), {remainderChunkValue, constStepOp});
214
215 // Optimization if tripCount can be determined to be zero.
216 if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tripCountSize)) {
217 auto intAttr = cast<IntegerAttr>(attr);
218 bool isTripCountZero = intAttr.getValue().isZero();
219
220 if (!isTripCountZero) {
221 spec.tileSizes.push_back(constStepOp);
222 spec.tripCounts.push_back(tripCountValue);
223 }
224 } else {
225 spec.tileSizes.push_back(constStepOp);
226 spec.tripCounts.push_back(tripCountValue);
227 }
228
229 remainderChunkValue = apply(s0 % s1, {remainderChunkValue, constStepOp});
230 }
231
232 return spec;
233}
234
235FailureOr<StaticMultiSizeSpecification>
236mlir::linalg::computeStaticMultiTileSizes(LinalgOp op, unsigned dimension,
237 int64_t targetSize, int64_t divisor) {
238 assert(!op.hasDynamicShape() &&
239 "cannot compute static multi-tile sizes for an op with dynamic shape");
240 assert(targetSize > 0 && "target size must be non-negative");
241 assert(divisor > 0 && "divisor must be non-negative");
242 assert(dimension < op.getNumLoops() && "dimension overflow");
243
245 int64_t tripCount = op.getStaticLoopRanges()[dimension];
246 int64_t a = tripCount / divisor;
247 int64_t t = (targetSize + divisor - 1) / divisor;
248 int64_t totalTripCount = (a + t - 1) / t;
249 spec.lowTileSize = (a / totalTripCount) * divisor;
250 spec.highTileSize = spec.lowTileSize + divisor;
251 spec.highTripCount = a % totalTripCount;
252 spec.lowTripCount = totalTripCount - spec.highTripCount;
253 if (spec.lowTileSize * spec.lowTripCount +
254 spec.highTileSize * spec.highTripCount !=
255 tripCount) {
256 return failure();
257 }
258 return spec;
259}
260
261FailureOr<MultiSizeSpecification>
263 unsigned dimension, OpFoldResult targetSize,
264 OpFoldResult divisor, bool emitAssertions) {
265 // Bail out on dimension overflow.
266 if (dimension >= op.getNumLoops())
267 return failure();
268
269 // The code below works only on values.
270 Location loc = op.getLoc();
271 ImplicitLocOpBuilder b(loc, builder);
272 if (emitAssertions) {
273 emitIsPositiveIndexAssertion(b, targetSize);
275 }
276 Value targetSizeValue =
277 getValueOrCreateConstantIndexOp(builder, loc, targetSize);
278 Value divisorValue = getValueOrCreateConstantIndexOp(builder, loc, divisor);
279
280 // Find the trip count of the iteration space dimension for which the tile
281 // sizes are computed.
282 SmallVector<OpFoldResult> allShapes =
283 op.createFlatListOfOperandDims(b, b.getLoc());
284 AffineMap shapesToLoops = op.getShapesToLoopsMap();
285 SmallVector<OpFoldResult> loopRanges =
286 makeComposedFoldedMultiResultAffineApply(b, op.getLoc(), shapesToLoops,
287 allShapes);
288 Value tripCount =
289 getValueOrCreateConstantIndexOp(b, op.getLoc(), loopRanges[dimension]);
290
291 // Compute the tile sizes and the respective numbers of tiles.
292 AffineExpr s0 = b.getAffineSymbolExpr(0);
293 AffineExpr s1 = b.getAffineSymbolExpr(1);
294 AffineExpr s2 = b.getAffineSymbolExpr(2);
295 auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
296 return affine::makeComposedAffineApply(b, b.getLoc(), expr, ofrs);
297 };
298 Value a = apply(s0.floorDiv(s1), {tripCount, divisorValue});
299 Value t = apply((s0 + s1 - 1).floorDiv(s1), {targetSizeValue, divisorValue});
300 Value d = apply((s0 + s1 - 1).floorDiv(s1), {a, t});
301 Value s = apply(s0.floorDiv(s1) * s2, {a, d, divisorValue});
302 Value v = apply(s0 % s1, {a, d});
303 Value u = apply(s0 - s1, {d, v});
304
306 spec.lowTileSize = s;
307 spec.highTileSize = apply(s0 + s1, {s, divisorValue});
308 spec.lowTripCount = u;
309 spec.highTripCount = v;
310
311 // If requested, emit the check that the tile sizes are computed correctly.
312 // For example, for iteration dimension size of 15 and the target size 8 it is
313 // impossible to find two tile sizes both divisible by 8 that fully cover the
314 // original space dimension.
315 if (emitAssertions) {
316 AffineExpr s3 = builder.getAffineSymbolExpr(3);
317 Value coveredSize =
318 apply(s0 * s1 + s2 * s3, {spec.lowTileSize, spec.lowTripCount,
319 spec.highTileSize, spec.highTripCount});
320 Value equals = arith::CmpIOp::create(b, arith::CmpIPredicate::eq,
321 coveredSize, tripCount);
322 cf::AssertOp::create(
323 b, equals,
324 builder.getStringAttr(
325 "could not compute dynamic multi-size tile shapes"));
326 }
327
328 return spec;
329}
330
331/// Returns true if the maximum tile offset `tileSize * numThreads-1` is less
332/// than `iterationSize`.
334 OpFoldResult numThreads,
335 OpFoldResult iterationSize) {
336 std::optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
337 std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
338 std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
339 if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
340 return false;
341 return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
342}
343
344/// Build an `affine_max` of all the `vals`.
351
352/// Build an `affine_min` of all the `vals`.
359
360/// Fill out the `tiledOffsets` and `tiledSizes` to be used to tile to a given
361/// number of threads.
363 RewriterBase &b, Location loc, scf::ForallOp forallOp,
364 ArrayRef<OpFoldResult> numThreads, SmallVector<Range> loopRanges,
365 bool omitTileOffsetBoundsCheck,
366 std::optional<ArrayRef<OpFoldResult>> nominalTileSizes,
367 SmallVector<OpFoldResult> &tiledOffsets,
368 SmallVector<OpFoldResult> &tiledSizes) {
370 b.setInsertionPointToStart(forallOp.getBody(0));
371
372 SmallVector<Value> threadIds = forallOp.getInductionVars();
373 SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
374 numThreads, [](OpFoldResult ofr) { return !isZeroInteger(ofr); });
375 int64_t nLoops = loopRanges.size();
376 tiledOffsets.reserve(nLoops);
377 tiledSizes.reserve(nLoops);
378 for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; ++loopIdx) {
379 bool overflow = loopIdx >= numThreads.size();
380 bool isZero = !overflow && isZeroInteger(numThreads[loopIdx]);
381 // Degenerate case: take the whole domain.
382 if (overflow || isZero) {
383 tiledOffsets.push_back(loopRanges[loopIdx].offset);
384 tiledSizes.push_back(loopRanges[loopIdx].size);
385 continue;
386 }
387
388 // Tiled case: compute the offset and size.
389 AffineExpr i, j, m, n, o;
390 bindDims(b.getContext(), i, j);
391 bindSymbols(b.getContext(), m, n, o);
392 OpFoldResult size = loopRanges[loopIdx].size;
393 OpFoldResult offset = loopRanges[loopIdx].offset;
394 OpFoldResult threadId = threadIds[threadIdIdx];
395 // Symbolic fixed max size per thread.
396 // TODO: floor + 0/1 depending on case for better load-balancing.
397 OpFoldResult tileSizePerThread =
398 nominalTileSizes.has_value()
399 ? (*nominalTileSizes)[loopIdx]
401 b, loc, m.ceilDiv(n),
402 ArrayRef<OpFoldResult>{size, nonZeroNumThreads[threadIdIdx]});
403
404 // Dynamic offset shifted by threadId * maxSizePerThread.
406 b, loc, i + j * m, {offset, threadId, tileSizePerThread});
407 // Dynamic upper-bound depending on the threadId.
409 b, loc, i + j * m - n,
410 {offset, nonZeroNumThreads[threadIdIdx], tileSizePerThread, size});
411 if (!isZeroInteger(residualTileSize)) {
412 OpFoldResult sizeMinusOffsetPerThread = makeComposedFoldedAffineApply(
413 b, loc, -i + m, {offsetPerThread, size});
414 tileSizePerThread =
415 buildMin(b, loc, {sizeMinusOffsetPerThread, tileSizePerThread});
416 }
417
418 tiledOffsets.push_back(offsetPerThread);
419 // TODO: if tileSizePerThread <= 0 early exit.
420 if (!omitTileOffsetBoundsCheck &&
421 !canOmitTileOffsetInBoundsCheck(tileSizePerThread,
422 nonZeroNumThreads[threadIdIdx], size))
423 tileSizePerThread =
424 buildMax(b, loc, {b.getIndexAttr(0), tileSizePerThread});
425
426 tiledSizes.push_back(tileSizePerThread);
427 ++threadIdIdx;
428 }
429}
430
431template <typename LoopTy>
432static FailureOr<TiledLinalgOp>
436
437 auto nLoops = op.getNumLoops();
438 // Initial tile sizes may be too big, only take the first nLoops.
439 tileSizes = tileSizes.take_front(nLoops);
440
441 if (llvm::all_of(tileSizes, [](OpFoldResult ofr) {
442 return getConstantIntValue(ofr) == static_cast<int64_t>(0);
443 })) {
444 TiledLinalgOp tiledOp;
445 tiledOp.op = cast<LinalgOp>(b.clone(*op.getOperation()));
446 tiledOp.tensorResults.assign(tiledOp.op->result_begin(),
447 tiledOp.op->result_end());
448 return tiledOp;
449 }
450
451 // 1. Build the tiled loop ranges.
452 SmallVector<OpFoldResult> allShapeSizes =
453 op.createFlatListOfOperandDims(b, op.getLoc());
454 AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap();
455 if (!shapeSizesToLoopsMap)
456 return failure();
457
458 auto [loopRanges, loopIndexToRangeIndex] = makeTiledLoopRanges(
459 b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);
460
462 for (const auto &attr : enumerate(op.getIteratorTypesArray())) {
463 if (loopIndexToRangeIndex.count(attr.index()))
464 iteratorTypes.push_back(attr.value());
465 }
466 // If interchangeVector is empty, use the identity. Build the permutation map
467 // otherwise.
468 auto invPermutationMap =
469 AffineMap::getMultiDimIdentityMap(tileSizes.size(), b.getContext());
470 if (!options.interchangeVector.empty()) {
471 // Based on the pruned iterations (due to zero tile size), recompute the
472 // interchange vector.
473 SmallVector<unsigned, 4> interchangeVector;
474 interchangeVector.reserve(options.interchangeVector.size());
475 for (auto pos : options.interchangeVector) {
476 auto it = loopIndexToRangeIndex.find(pos);
477 if (it == loopIndexToRangeIndex.end())
478 continue;
479 interchangeVector.push_back(it->second);
480 }
481 // Interchange vector is guaranteed to be a permutation,
482 // `inversePermutation` must succeed.
483 invPermutationMap = inversePermutation(
484 AffineMap::getPermutationMap(interchangeVector, b.getContext()));
485 assert(invPermutationMap);
486 SmallVector<int64_t> permutation(interchangeVector.begin(),
487 interchangeVector.end());
488 applyPermutationToVector(loopRanges, permutation);
489 applyPermutationToVector(iteratorTypes, permutation);
490 }
491
492 // Handle distribution. Create a vector of the same size of loops that are to
493 // be tiled.
495 if (options.distribution) {
496 procInfo.resize(
497 iteratorTypes.size(),
498 linalg::ProcInfo{nullptr, nullptr, linalg::DistributionMethod::None});
499 // Collect loop ranges of tiled loops, loops that are parallel.
500 SmallVector<Range> parallelLoopRanges;
501 for (const auto &iteratorType : llvm::enumerate(iteratorTypes)) {
502 if (!isParallelIterator(iteratorType.value()))
503 break;
504 parallelLoopRanges.push_back(loopRanges[iteratorType.index()]);
505 }
506 auto returnedProcInfo =
507 options.distribution->procInfo(b, op.getLoc(), parallelLoopRanges);
508 unsigned procIdIdx = 0;
509 // Update the distribution information for the loops.
510 for (const auto &iteratorType : llvm::enumerate(iteratorTypes)) {
511 if (!isParallelIterator(iteratorType.value()))
512 break;
513 procInfo[iteratorType.index()] = returnedProcInfo[procIdIdx++];
514 }
515 }
516
517 // 2. Create the tiled loops.
518 LinalgOp res = op;
519 SmallVector<Value, 4> ivs, tensorResults;
520 auto tiledLoopBodyBuilder =
521 [&](OpBuilder &builder, Location loc, ValueRange localIvs,
522 ValueRange operandValuesToUse) -> scf::ValueVector {
523 ivs.assign(localIvs.begin(), localIvs.end());
524
525 // When an `interchangeVector` is present, it has been applied to the
526 // loop ranges and the iterator types. Apply its inverse to the
527 // resulting loop `ivs` to match the op definition.
528 SmallVector<Value, 4> interchangedIvs;
529 if (!options.interchangeVector.empty()) {
530 for (AffineExpr result : invPermutationMap.getResults())
531 interchangedIvs.push_back(
532 ivs[cast<AffineDimExpr>(result).getPosition()]);
533 } else {
534 interchangedIvs.assign(ivs.begin(), ivs.end());
535 }
536
537 // Tile the `operandValuesToUse` that either match the `op` operands
538 // themselves or the tile loop arguments forwarding them.
539 assert(operandValuesToUse.size() ==
540 static_cast<size_t>(op->getNumOperands()) &&
541 "expect the number of operands and inputs and outputs to match");
542 SmallVector<Value> valuesToTile = operandValuesToUse;
543 SmallVector<OpFoldResult> sizeBounds =
544 makeComposedFoldedMultiResultAffineApply(b, loc, shapeSizesToLoopsMap,
545 allShapeSizes);
546 SmallVector<Value> tiledOperands = makeTiledShapes(
547 b, loc, op, valuesToTile, getAsOpFoldResult(interchangedIvs), tileSizes,
548 sizeBounds,
549 /*omitPartialTileCheck=*/false);
550
551 SmallVector<Type> resultTensorTypes =
552 getTensorOutputTypes(op, tiledOperands);
553 res = clone(b, op, resultTensorTypes, tiledOperands);
554 tensorResults =
555 insertSlicesBack(builder, loc, op, tiledOperands, res->getResults());
556 return scf::ValueVector(tensorResults.begin(), tensorResults.end());
557 };
558 GenerateLoopNest<LoopTy>::doit(b, op.getLoc(), loopRanges, op, iteratorTypes,
559 tiledLoopBodyBuilder, procInfo);
560
561 // 3. Transform IndexOp results w.r.t. the tiling.
562 transformIndexOps(b, res, ivs, loopIndexToRangeIndex);
563
564 // 4. Gather the newly created loops and return them with the new op.
566 loops.reserve(ivs.size());
567 for (auto iv : ivs) {
568 if (isa<BlockArgument>(iv)) {
569 loops.push_back(cast<BlockArgument>(iv).getOwner()->getParentOp());
570 assert(loops.back() && "no owner found for induction variable!");
571 } else {
572 // TODO: Instead of doing this, try to recover the ops used instead of the
573 // loop.
574 loops.push_back(nullptr);
575 }
576 }
577
578 // 5. Get the tensor results from the outermost loop if available. Otherwise
579 // use the previously captured `tensorResults`.
580 Operation *outermostLoop = nullptr;
581 for (Operation *loop : loops)
582 if ((outermostLoop = loop))
583 break;
584
585 return TiledLinalgOp{
586 res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults};
587}
588
589FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
590 RewriterBase &b, PartialReductionOpInterface op,
592 std::optional<ArrayAttr> mapping) {
593 Location loc = op.getLoc();
595
596 // Ops implementing PartialReductionOpInterface are expected to implement
597 // TilingInterface.
598 // TODO: proper core mechanism to tie interfaces together.
599 auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
600
601 // Ops implementing PartialReductionOpInterface are not necessarily expected
602 // to implement TilingInterface.. This cast is unsafe atm.
603 // TODO: proper core mechanism to tie interfaces together.
604 // TODO: this function requires a pair of interfaces ..
605 auto destinationStyleOp =
606 dyn_cast<DestinationStyleOpInterface>(op.getOperation());
607 if (!destinationStyleOp)
608 return b.notifyMatchFailure(op, "not a destination style op");
609
610 // Actually this only work for Linalg ops atm.
611 auto linalgOp = dyn_cast<linalg::LinalgOp>(op.getOperation());
612 if (!linalgOp)
613 return b.notifyMatchFailure(op, "not a linalg op");
614
615 SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
616 if (op->getNumResults() != 1)
617 return b.notifyMatchFailure(
618 op, "don't support ops with multiple results for now");
619
621 tilingInterfaceOp.getLoopIteratorTypes();
622 SmallVector<unsigned> redDims;
623 linalgOp.getReductionDims(redDims);
624 if (redDims.size() != 1)
625 return b.notifyMatchFailure(
626 op, "only support ops with one reduction dimension.");
627 if (!tileSizes.empty() && tileSizes.size() != numThreads.size())
628 return b.notifyMatchFailure(op, "if tile sizes are present it must have as "
629 "many elements as number of threads");
630
631 if (redDims.front() >= numThreads.size())
632 return b.notifyMatchFailure(
633 op, "reduction dimension must be mapped to threads");
634
635 // 1. Create the inital tensor value.
636 unsigned reductionDim = redDims.front();
637 SetVector<unsigned> reductionDims;
638 reductionDims.insert(reductionDim);
639 FailureOr<SmallVector<Value>> maybeInitTensors =
640 op.generateInitialTensorForPartialReduction(b, loc, numThreads,
641 reductionDims);
642 if (failed(maybeInitTensors))
643 return b.notifyMatchFailure(
644 op, "Failed to create inital tensors for partial reduction");
645 SmallVector<Value> &initTensors = maybeInitTensors.value();
646
647 // Gather destination tensors.
649 if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
650 return b.notifyMatchFailure(op, "failed to get destination tensors");
651
652 Operation *tiledOp = nullptr;
653
654 SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
655 numThreads, [](OpFoldResult ofr) { return !isZeroInteger(ofr); });
656 SmallVector<Value> materializedNonZeroNumThreads =
657 getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads);
658
659 // 2. Create the ForallOp with an empty region.
660 scf::ForallOp forallOp = scf::ForallOp::create(
661 b, loc, getAsOpFoldResult(materializedNonZeroNumThreads), initTensors,
662 mapping);
663
664 // 3. Calculate the tile offsets and sizes for the subsequent loop that will
665 // be nested under `forallOp`.
666 SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
667 calculateTileOffsetsAndSizes(b, loc, forallOp, numThreads, iterationDomain,
668 /*omitTileOffsetBoundsCheck =*/false,
669 /*nominalTileSizes=*/std::nullopt, tiledOffsets,
670 tiledSizes);
671
672 // 4b. Clone the tileable op and update its destination operands to use the
673 // output bbArgs of the ForallOp.
674 SmallVector<Value> tilingResults;
675 ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
676 {
677 // 4.a. RAII guard, inserting within forallOp, before terminator.
679 b.setInsertionPoint(forallOp.getTerminator());
680
681 SmallVector<Value> tiledDpsInitOperands;
682 for (Value initOperand : destinationStyleOp.getDpsInits()) {
683 auto *it = llvm::find(dest, initOperand);
684 assert(it != dest.end() && "dest operand not found in dest");
685 unsigned destNum = std::distance(dest.begin(), it);
686 SmallVector<OpFoldResult> strides(numThreads.size(), b.getIndexAttr(1));
687 SmallVector<OpFoldResult> outOffsets(numThreads.size(),
688 b.getIndexAttr(0));
689 SmallVector<OpFoldResult> sizes = tiledSizes;
690 sizes[reductionDim] = b.getIndexAttr(1);
691 outOffsets[reductionDim] = forallOp.getInductionVars()[0];
692 // TODO: use SubsetExtractOpInterface once it is available.
693 tiledDpsInitOperands.push_back(tensor::ExtractSliceOp::create(
694 b, loc, cast<RankedTensorType>(initOperand.getType()),
695 destBbArgs[destNum], outOffsets, sizes, strides));
696 }
697
698 // 4.b. Clone the op and update init operands.
699 // We cannot use a IRMapping here because it can replace
700 // different OpOperands with the same value.
701 Operation *clonedOp = b.clone(*op.getOperation());
702 b.modifyOpInPlace(clonedOp, [&]() {
703 for (auto [initOperandPtr, tiledInitValue] : llvm::zip_equal(
704 cast<DestinationStyleOpInterface>(clonedOp).getDpsInitsMutable(),
705 tiledDpsInitOperands)) {
706 initOperandPtr.set(tiledInitValue);
707 }
708 });
709
710 // 5. Tile the cloned op and delete the clone.
711 if (tileSizes.empty()) {
712 FailureOr<TilingResult> tilingResult =
713 cast<TilingInterface>(clonedOp).getTiledImplementation(
714 b, tiledOffsets, tiledSizes);
715 if (failed(tilingResult))
716 return clonedOp->emitError("Failed to tile op: ");
717 if (tilingResult->tiledOps.size() != 1) {
718 return clonedOp->emitError("expected a single produced tiled op, got ")
719 << tilingResult->tiledOps.size();
720 }
721 tiledOp = tilingResult->tiledOps.front();
722 tilingResults = tilingResult->tiledValues;
723 } else {
725 FailureOr<TiledLinalgOp> maybeTiled = tileLinalgOpImpl<scf::ForOp>(
726 b, cast<LinalgOp>(clonedOp), tileSizes, options);
727 if (failed(maybeTiled))
728 return b.notifyMatchFailure(op, "failed tileLinalgOpImpl");
729
730 SmallVector<Value> ids = forallOp.getInductionVars();
731 mapLoopToProcessorIds(cast<scf::ForOp>(maybeTiled->loops.back()), ids,
732 materializedNonZeroNumThreads);
733 if (maybeTiled->loops.size() != 1) {
734 return clonedOp->emitError("expected a single produced loop");
735 }
736 tiledOp = maybeTiled->op;
737 tilingResults = maybeTiled->loops.front()->getResults();
738 }
739
740 b.eraseOp(clonedOp);
741 }
742
743 // 6. Insert the partial reductions back into a new tensor.
744 for (auto [index, result, bbArg] : llvm::zip(
745 llvm::seq<unsigned>(0, dest.size()), tilingResults, destBbArgs)) {
746 // 6.a. Partial subset information is inserted just before the terminator.
748 b.setInsertionPoint(forallOp.getTerminator());
749
750 SmallVector<OpFoldResult> resultOffsets, resultSizes;
751 if (failed(tilingInterfaceOp.getResultTilePosition(
752 b, index, tiledOffsets, tiledSizes, resultOffsets, resultSizes)))
753 return op->emitOpError("output offsets couldn't be calculated");
754 SmallVector<OpFoldResult> resultOffsetsRank, resultSizesRank;
755 int64_t offIdx = 0;
756 int64_t sizeIdx = 0;
757 for (int64_t i = 0, e = numThreads.size(); i < e; ++i) {
758 if (i == reductionDim) {
759 resultOffsetsRank.push_back(forallOp.getInductionVars()[0]);
760 resultSizesRank.push_back(b.getIndexAttr(1));
761 continue;
762 }
763 resultOffsetsRank.push_back(resultOffsets[offIdx++]);
764 resultSizesRank.push_back(resultSizes[sizeIdx++]);
765 }
766 SmallVector<OpFoldResult> strides(resultSizesRank.size(),
767 b.getIndexAttr(1));
768
769 // 6.b. Parallel insertions are inserted at the end of the combining
770 // terminator.
771 b.setInsertionPointToEnd(forallOp.getTerminator().getBody());
772 tensor::ParallelInsertSliceOp::create(
773 b, loc, result, bbArg, resultOffsetsRank, resultSizesRank, strides);
774 }
775
776 // 7. Merge the partial reductions.
777 b.setInsertionPointAfter(forallOp);
778 FailureOr<MergeResult> mergeResult =
779 op.mergeReductions(b, loc, forallOp->getResults(), reductionDims);
780 if (failed(mergeResult)) {
781 return failure();
782 }
783 b.replaceOp(op, mergeResult->replacements);
784
785 // 8. Return.
787 results.initialValues = initTensors;
788 results.loops = forallOp;
789 results.parallelTiledOps.push_back(tiledOp);
790 results.mergeOps.append(mergeResult->mergeOps);
791 return results;
792}
793
794template <typename LoopTy>
795FailureOr<TiledLinalgOp> static tileLinalgOpImpl(
796 RewriterBase &b, LinalgOp op, const LinalgTilingOptions &options) {
798 b.setInsertionPoint(op);
799
800 if (!options.tileSizeComputationFunction)
801 return failure();
802
803 // Enforce the convention that "tiling by zero" skips tiling a particular
804 // dimension. This convention is significantly simpler to handle instead of
805 // adjusting affine maps to account for missing dimensions.
806 auto nLoops = op.getNumLoops();
807 SmallVector<OpFoldResult> tileSizeVector =
808 getAsOpFoldResult(options.tileSizeComputationFunction(b, op));
809 if (tileSizeVector.size() < nLoops) {
810 tileSizeVector.append(nLoops - tileSizeVector.size(), b.getIndexAttr(0));
811 }
812
813 return tileLinalgOpImpl<LoopTy>(b, op, tileSizeVector, options);
814}
815
816FailureOr<TiledLinalgOp>
819 switch (options.loopType) {
824 default:;
825 }
826 return failure();
827}
828
829namespace {
830/// Helper classes for type list expansion.
831template <typename... OpTypes>
832class CanonicalizationPatternList;
833
834template <>
835class CanonicalizationPatternList<> {
836public:
837 static void insert(RewritePatternSet &patterns) {}
838};
839
840template <typename OpTy, typename... OpTypes>
841class CanonicalizationPatternList<OpTy, OpTypes...> {
842public:
843 static void insert(RewritePatternSet &patterns) {
844 OpTy::getCanonicalizationPatterns(patterns, patterns.getContext());
845 CanonicalizationPatternList<OpTypes...>::insert(patterns);
846 }
847};
848} // namespace
849
856
859 auto *ctx = patterns.getContext();
860 affine::AffineApplyOp::getCanonicalizationPatterns(patterns, ctx);
861 affine::AffineForOp::getCanonicalizationPatterns(patterns, ctx);
862 affine::AffineMinOp::getCanonicalizationPatterns(patterns, ctx);
863 affine::AffineMaxOp::getCanonicalizationPatterns(patterns, ctx);
864 arith::ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);
865
866 memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
867 memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);
868
869 scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
870 scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx);
871
872 tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
873 tensor::EmptyOp::getCanonicalizationPatterns(patterns, ctx);
874 tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx);
875 tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx);
876 tensor::PadOp::getCanonicalizationPatterns(patterns, ctx);
877 ctx->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(patterns);
878
879 CanonicalizationPatternList<
880#define GET_OP_LIST
881#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
882 >::insert(patterns);
883}
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static llvm::ManagedStatic< PassManagerOptions > options
static void calculateTileOffsetsAndSizes(RewriterBase &b, Location loc, scf::ForallOp forallOp, ArrayRef< OpFoldResult > numThreads, SmallVector< Range > loopRanges, bool omitTileOffsetBoundsCheck, std::optional< ArrayRef< OpFoldResult > > nominalTileSizes, SmallVector< OpFoldResult > &tiledOffsets, SmallVector< OpFoldResult > &tiledSizes)
Fill out the tiledOffsets and tiledSizes to be used to tile to a given number of threads.
Definition Tiling.cpp:362
static FailureOr< TiledLinalgOp > tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef< OpFoldResult > tileSizes, const LinalgTilingOptions &options)
Definition Tiling.cpp:433
static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize, OpFoldResult numThreads, OpFoldResult iterationSize)
Returns true if the maximum tile offset tileSize * numThreads-1 is less than iterationSize.
Definition Tiling.cpp:333
static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b, OpFoldResult value)
Asserts that the given index-typed value is strictly positive.
Definition Tiling.cpp:89
static OpFoldResult buildMax(OpBuilder &b, Location loc, ArrayRef< OpFoldResult > vals)
Build an affine_max of all the vals.
Definition Tiling.cpp:345
static OpFoldResult buildMin(OpBuilder &b, Location loc, ArrayRef< OpFoldResult > vals)
Build an affine_min of all the vals.
Definition Tiling.cpp:353
Base type for affine expression.
Definition AffineExpr.h:68
AffineExpr floorDiv(uint64_t v) const
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
unsigned getNumResults() const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Attributes are known-constant values of operations.
Definition Attributes.h:25
AffineExpr getAffineSymbolExpr(unsigned position)
Definition Builders.cpp:368
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:262
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:630
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext * getContext() const
Return the context this location is uniqued in.
Definition Location.h:86
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:526
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
result_range getResults()
Definition Operation.h:415
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Specialization of arith.constant op that returns an integer of index type.
Definition Arith.h:113
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
SmallVector< Value > makeTiledShapes(OpBuilder &builder, Location loc, LinalgOp linalgOp, ValueRange valuesToTile, ArrayRef< OpFoldResult > ivs, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > sizeBounds, bool omitPartialTileCheck)
Creates extract_slice/subview ops for all valuesToTile of the given linalgOp with builder,...
Definition Utils.cpp:1732
void transformIndexOps(RewriterBase &b, LinalgOp op, SmallVectorImpl< Value > &ivs, const LoopIndexToRangeIndexMap &loopIndexToRangeIndex)
All indices returned by IndexOp should be invariant with respect to tiling.
Definition Tiling.cpp:73
bool isParallelIterator(utils::IteratorType iteratorType)
Check if iterator type has "parallel" semantics.
Definition Utils.cpp:230
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns)
Definition Tiling.cpp:857
SmallVector< Value > insertSlicesBack(OpBuilder &builder, Location loc, LinalgOp op, ValueRange operands, ValueRange results)
Creates insert_slice ops that insert results back into larger tensors they were originally extracted ...
Definition Utils.cpp:1652
std::tuple< SmallVector< Range, 4 >, LoopIndexToRangeIndexMap > makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > allShapeSizes, ArrayRef< OpFoldResult > allTileSizes)
Definition Tiling.cpp:44
void offsetIndices(OpBuilder &b, LinalgOp linalgOp, ArrayRef< OpFoldResult > offsets)
Add the specified offsets to any linalg.index ops contained in the given linalgOp.
Definition Utils.cpp:1754
FailureOr< StaticMultiSizeSpecification > computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize, int64_t divisor)
Definition Tiling.cpp:236
FailureOr< ContinuousTileSizeSpecification > computeContinuousTileSizes(OpBuilder &builder, TilingInterface op, unsigned dimension, OpFoldResult targetSize, bool emitAssertions)
Definition Tiling.cpp:156
FailureOr< StaticContinuousTileSizeSpecification > computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension, unsigned targetSize)
Definition Tiling.cpp:106
FailureOr< ForallReductionTilingResult > tileReductionUsingForall(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > tileSizes={}, std::optional< ArrayAttr > mapping=std::nullopt)
Method to tile a reduction to parallel iterations computing partial reductions.
Definition Tiling.cpp:589
FailureOr< TiledLinalgOp > tileLinalgOp(RewriterBase &b, LinalgOp op, const LinalgTilingOptions &options)
Definition Tiling.cpp:817
DenseMap< int, int > LoopIndexToRangeIndexMap
Creates a number of ranges equal to the number of non-zero entries in tileSizes.
Definition Transforms.h:940
RewritePatternSet getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx)
Canonicalization patterns relevant to apply after tiling patterns.
Definition Tiling.cpp:851
SmallVector< Type > getTensorOutputTypes(LinalgOp op, ValueRange operands)
Returns the list of tensor output types produced when the given structured operation op is applied to...
Definition Utils.cpp:1643
FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)
Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetS...
Definition Tiling.cpp:262
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition SCF.h:64
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition AffineExpr.h:325
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:111
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Transformation information returned after reduction tiling.
SmallVector< Operation * > mergeOps
The final reduction operation merging all the partial reductions.
SmallVector< Value > initialValues
Initial values used for partial reductions.
scf::ForallOp loops
The scf.forall operation that iterate over the tiles.
SmallVector< Operation * > parallelTiledOps
The partial reduction tiled op generated.
static void doit(OpBuilder &b, Location loc, ArrayRef< Range > loopRanges, LinalgOp linalgOp, ArrayRef< utils::IteratorType > iteratorTypes, function_ref< scf::ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilderFn, ArrayRef< linalg::ProcInfo > procInfo={})
A description of a multi-size tiling comprising tile sizes and numbers of tiles, expressed as Values ...
Definition Transforms.h:969
Callback function type used to get processor ID, and number of processors used for distribution for a...
Definition Utils.h:312
Perform standalone tiling of a single LinalgOp by tileSizes.
Definition Transforms.h:798
SmallVector< Value, 4 > tensorResults
Definition Transforms.h:801
SmallVector< T > tripCounts
Number of tiles associated with each size.
Definition Transforms.h:960
T lowTripCount
Number of tiles associated with each size.
Definition Transforms.h:952
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.