MLIR 22.0.0git
Tiling.cpp
Go to the documentation of this file.
1//===- Tiling.cpp - Implementation of linalg Tiling -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the linalg dialect Tiling pass.
10//
11//===----------------------------------------------------------------------===//
12
24#include "mlir/IR/AffineExpr.h"
25#include "mlir/IR/AffineMap.h"
26#include "mlir/IR/ValueRange.h"
28#include "llvm/ADT/STLExtras.h"
29#include <utility>
30
31namespace mlir {
32#define GEN_PASS_DEF_LINALGTILINGPASS
33#include "mlir/Dialect/Linalg/Passes.h.inc"
34} // namespace mlir
35
36using namespace mlir;
37using namespace mlir::affine;
38using namespace mlir::linalg;
39using namespace mlir::scf;
40
41#define DEBUG_TYPE "linalg-tiling"
42
43std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
45 ArrayRef<OpFoldResult> allShapeSizes,
46 ArrayRef<OpFoldResult> allTileSizes) {
47 assert(allTileSizes.size() == map.getNumResults());
48 // Apply `map` to get shape sizes in loop order.
49 SmallVector<OpFoldResult> shapeSizes =
50 makeComposedFoldedMultiResultAffineApply(b, loc, map, allShapeSizes);
51 SmallVector<OpFoldResult> tileSizes(allTileSizes);
52
53 // Traverse the tile sizes, which are in loop order, erase zeros everywhere.
54 LoopIndexToRangeIndexMap loopIndexToRangeIndex;
55 for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) {
56 if (getConstantIntValue(tileSizes[idx - zerosCount]) ==
57 static_cast<int64_t>(0)) {
58 shapeSizes.erase(shapeSizes.begin() + idx - zerosCount);
59 tileSizes.erase(tileSizes.begin() + idx - zerosCount);
60 ++zerosCount;
61 continue;
62 }
63 loopIndexToRangeIndex[idx] = idx - zerosCount;
64 }
65
66 // Create a new range with the applied tile sizes.
68 for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx)
69 res.push_back(Range{b.getIndexAttr(0), shapeSizes[idx], tileSizes[idx]});
70 return std::make_tuple(res, loopIndexToRangeIndex);
71}
72
74 RewriterBase &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
75 const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
76 SmallVector<Value> allIvs(op.getNumLoops(), nullptr);
77 for (auto en : enumerate(allIvs)) {
78 auto rangeIndex = loopIndexToRangeIndex.find(en.index());
79 if (rangeIndex == loopIndexToRangeIndex.end())
80 continue;
81 en.value() = ivs[rangeIndex->second];
82 }
83 offsetIndices(b, op, getAsOpFoldResult(allIvs));
84}
85
86/// Asserts that the given index-typed value is strictly positive. If the value
87/// is an attribute, asserts at compile time, otherwise emits an assertion
88/// checked at runtime.
90 OpFoldResult value) {
91 if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
92 assert(cast<IntegerAttr>(attr).getValue().isStrictlyPositive() &&
93 "expected strictly positive tile size and divisor");
94 return;
95 }
96
98 Value condition = arith::CmpIOp::create(b, arith::CmpIPredicate::sgt,
99 cast<Value>(value), zero);
100 cf::AssertOp::create(
101 b, condition,
102 b.getStringAttr("expected strictly positive tile size and divisor"));
103}
104
105FailureOr<StaticContinuousTileSizeSpecification>
107 unsigned targetSize) {
108
109 assert(!op.hasDynamicShape() &&
110 "cannot compute static multi-tile sizes for an op with dynamic shape");
111 assert(targetSize > 0 && "target size must be non-negative");
112 assert(dimension < op.getNumLoops() && "dimension overflow");
113
115 int64_t loopRange = op.getStaticLoopRanges()[dimension];
116 int64_t tripCount = loopRange / targetSize;
117
118 unsigned tileSize = targetSize;
119
120 spec.tileSizes.push_back(tileSize);
121 spec.tripCounts.push_back(tripCount);
122
123 int64_t remainderChunk = loopRange % targetSize;
124
125 while (tileSize > 1 && remainderChunk != 0) {
126
127 uint64_t maxPower = llvm::bit_floor(tileSize);
128 tileSize = maxPower == tileSize ? maxPower >> 1 : maxPower;
129
130 tripCount = remainderChunk / tileSize;
131
132 if (tripCount > 0) {
133 spec.tileSizes.push_back(tileSize);
134 spec.tripCounts.push_back(tripCount);
135 }
136
137 remainderChunk = remainderChunk % tileSize;
138 }
139
140 auto tripCountCheck = [&](SmallVector<int64_t> tileSizes,
141 SmallVector<int64_t> tripCounts,
142 int64_t range) -> bool {
143 int64_t computedRange = 0;
144 for (auto [tileSize, tripCount] : llvm::zip(tileSizes, tripCounts))
145 computedRange += tileSize * tripCount;
146 return range == computedRange;
147 };
148
149 if (!tripCountCheck(spec.tileSizes, spec.tripCounts, loopRange))
150 return failure();
151
152 return spec;
153}
154
155FailureOr<ContinuousTileSizeSpecification>
157 unsigned dimension,
158 OpFoldResult targetSize,
159 bool emitAssertions) {
160
161 SmallVector<Range> loopRanges = op.getIterationDomain(builder);
162 unsigned numLoops = loopRanges.size();
163
164 // Bail out on dimension overflow.
165 if (dimension >= numLoops)
166 return failure();
167
168 // The code below works only on values.
169 Location loc = op->getLoc();
170 ImplicitLocOpBuilder b(loc, builder);
171 if (emitAssertions) {
172 emitIsPositiveIndexAssertion(b, targetSize);
173 }
174 Value targetSizeValue =
175 getValueOrCreateConstantIndexOp(builder, loc, targetSize);
176
177 // Find the trip count of the iteration space dimension for which the tile
178 // sizes are computed.
179 Value loopRange =
180 getValueOrCreateConstantIndexOp(b, loc, loopRanges[dimension].size);
182
183 // Compute the tile sizes and the respective numbers of tiles.
184 AffineExpr s0 = b.getAffineSymbolExpr(0);
185 AffineExpr s1 = b.getAffineSymbolExpr(1);
186 auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
187 return affine::makeComposedAffineApply(b, b.getLoc(), expr, ofrs);
188 };
189
190 Value tripCountValue = apply(s0.floorDiv(s1), {loopRange, targetSizeValue});
191 Value remainderChunkValue = apply(s0 % s1, {loopRange, targetSizeValue});
192
194 b, b.getLoc(), s0.floorDiv(s1), {loopRange, targetSizeValue});
195
196 // emitAssertions above already asserts that targetSize is
197 // a poistive integer.
198 uint64_t tileSizeInt = *getConstantIntValue(targetSizeValue);
199
200 assert(tileSizeInt > 0 && "target size must be non-negative");
201
202 spec.tileSizes.push_back(targetSizeValue);
203 spec.tripCounts.push_back(tripCountValue);
204
205 while (tileSizeInt > 1) {
206 uint64_t maxPower = llvm::bit_floor(tileSizeInt);
207 tileSizeInt = maxPower == tileSizeInt ? maxPower >> 1 : maxPower;
208 auto constStepOp =
209 builder.createOrFold<arith::ConstantIndexOp>(b.getLoc(), tileSizeInt);
210 tripCountValue = apply(s0.floorDiv(s1), {remainderChunkValue, constStepOp});
211
213 b, b.getLoc(), s0.floorDiv(s1), {remainderChunkValue, constStepOp});
214
215 // Optimization if tripCount can be determined to be zero.
216 if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tripCountSize)) {
217 auto intAttr = cast<IntegerAttr>(attr);
218 bool isTripCountZero = intAttr.getValue().isZero();
219
220 if (!isTripCountZero) {
221 spec.tileSizes.push_back(constStepOp);
222 spec.tripCounts.push_back(tripCountValue);
223 }
224 } else {
225 spec.tileSizes.push_back(constStepOp);
226 spec.tripCounts.push_back(tripCountValue);
227 }
228
229 remainderChunkValue = apply(s0 % s1, {remainderChunkValue, constStepOp});
230 }
231
232 return spec;
233}
234
235FailureOr<StaticMultiSizeSpecification>
236mlir::linalg::computeStaticMultiTileSizes(LinalgOp op, unsigned dimension,
237 int64_t targetSize, int64_t divisor) {
238 assert(!op.hasDynamicShape() &&
239 "cannot compute static multi-tile sizes for an op with dynamic shape");
240 assert(targetSize > 0 && "target size must be non-negative");
241 assert(divisor > 0 && "divisor must be non-negative");
242 assert(dimension < op.getNumLoops() && "dimension overflow");
243
245 int64_t tripCount = op.getStaticLoopRanges()[dimension];
246 int64_t a = tripCount / divisor;
247 int64_t t = (targetSize + divisor - 1) / divisor;
248 int64_t totalTripCount = (a + t - 1) / t;
249 spec.lowTileSize = (a / totalTripCount) * divisor;
250 spec.highTileSize = spec.lowTileSize + divisor;
251 spec.highTripCount = a % totalTripCount;
252 spec.lowTripCount = totalTripCount - spec.highTripCount;
253 if (spec.lowTileSize * spec.lowTripCount +
254 spec.highTileSize * spec.highTripCount !=
255 tripCount) {
256 return failure();
257 }
258 return spec;
259}
260
261FailureOr<MultiSizeSpecification>
263 unsigned dimension, OpFoldResult targetSize,
264 OpFoldResult divisor, bool emitAssertions) {
265 // Bail out on dimension overflow.
266 if (dimension >= op.getNumLoops())
267 return failure();
268
269 // The code below works only on values.
270 Location loc = op.getLoc();
271 ImplicitLocOpBuilder b(loc, builder);
272 if (emitAssertions) {
273 emitIsPositiveIndexAssertion(b, targetSize);
275 }
276 Value targetSizeValue =
277 getValueOrCreateConstantIndexOp(builder, loc, targetSize);
278 Value divisorValue = getValueOrCreateConstantIndexOp(builder, loc, divisor);
279
280 // Find the trip count of the iteration space dimension for which the tile
281 // sizes are computed.
282 SmallVector<OpFoldResult> allShapes =
283 op.createFlatListOfOperandDims(b, b.getLoc());
284 AffineMap shapesToLoops = op.getShapesToLoopsMap();
285 SmallVector<OpFoldResult> loopRanges =
286 makeComposedFoldedMultiResultAffineApply(b, op.getLoc(), shapesToLoops,
287 allShapes);
288 Value tripCount =
289 getValueOrCreateConstantIndexOp(b, op.getLoc(), loopRanges[dimension]);
290
291 // Compute the tile sizes and the respective numbers of tiles.
292 AffineExpr s0 = b.getAffineSymbolExpr(0);
293 AffineExpr s1 = b.getAffineSymbolExpr(1);
294 AffineExpr s2 = b.getAffineSymbolExpr(2);
295 auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
296 return affine::makeComposedAffineApply(b, b.getLoc(), expr, ofrs);
297 };
298 Value a = apply(s0.floorDiv(s1), {tripCount, divisorValue});
299 Value t = apply((s0 + s1 - 1).floorDiv(s1), {targetSizeValue, divisorValue});
300 Value d = apply((s0 + s1 - 1).floorDiv(s1), {a, t});
301 Value s = apply(s0.floorDiv(s1) * s2, {a, d, divisorValue});
302 Value v = apply(s0 % s1, {a, d});
303 Value u = apply(s0 - s1, {d, v});
304
306 spec.lowTileSize = s;
307 spec.highTileSize = apply(s0 + s1, {s, divisorValue});
308 spec.lowTripCount = u;
309 spec.highTripCount = v;
310
311 // If requested, emit the check that the tile sizes are computed correctly.
312 // For example, for iteration dimension size of 15 and the target size 8 it is
313 // impossible to find two tile sizes both divisible by 8 that fully cover the
314 // original space dimension.
315 if (emitAssertions) {
316 AffineExpr s3 = builder.getAffineSymbolExpr(3);
317 Value coveredSize =
318 apply(s0 * s1 + s2 * s3, {spec.lowTileSize, spec.lowTripCount,
319 spec.highTileSize, spec.highTripCount});
320 Value equals = arith::CmpIOp::create(b, arith::CmpIPredicate::eq,
321 coveredSize, tripCount);
322 cf::AssertOp::create(
323 b, equals,
324 builder.getStringAttr(
325 "could not compute dynamic multi-size tile shapes"));
326 }
327
328 return spec;
329}
330
331/// Returns true if the maximum tile offset `tileSize * numThreads-1` is less
332/// than `iterationSize`.
334 OpFoldResult numThreads,
335 OpFoldResult iterationSize) {
336 std::optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
337 std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
338 std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
339 if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
340 return false;
341 return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
342}
343
344/// Build an `affine_max` of all the `vals`.
351
352/// Build an `affine_min` of all the `vals`.
359
360/// Fill out the `tiledOffsets` and `tiledSizes` to be used to tile to a given
361/// number of threads.
363 RewriterBase &b, Location loc, scf::ForallOp forallOp,
364 ArrayRef<OpFoldResult> numThreads, SmallVector<Range> loopRanges,
365 bool omitTileOffsetBoundsCheck,
366 std::optional<ArrayRef<OpFoldResult>> nominalTileSizes,
367 SmallVector<OpFoldResult> &tiledOffsets,
368 SmallVector<OpFoldResult> &tiledSizes) {
370 b.setInsertionPointToStart(forallOp.getBody(0));
371
372 SmallVector<Value> threadIds = forallOp.getInductionVars();
373 SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
374 numThreads, [](OpFoldResult ofr) { return !isZeroInteger(ofr); });
375 int64_t nLoops = loopRanges.size();
376 tiledOffsets.reserve(nLoops);
377 tiledSizes.reserve(nLoops);
378 for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; ++loopIdx) {
379 bool overflow = loopIdx >= numThreads.size();
380 bool isZero = !overflow && isZeroInteger(numThreads[loopIdx]);
381 // Degenerate case: take the whole domain.
382 if (overflow || isZero) {
383 tiledOffsets.push_back(loopRanges[loopIdx].offset);
384 tiledSizes.push_back(loopRanges[loopIdx].size);
385 continue;
386 }
387
388 // Tiled case: compute the offset and size.
389 AffineExpr i, j, m, n, o;
390 bindDims(b.getContext(), i, j);
391 bindSymbols(b.getContext(), m, n, o);
392 OpFoldResult size = loopRanges[loopIdx].size;
393 OpFoldResult offset = loopRanges[loopIdx].offset;
394 OpFoldResult threadId = threadIds[threadIdIdx];
395 // Symbolic fixed max size per thread.
396 // TODO: floor + 0/1 depending on case for better load-balancing.
397 OpFoldResult tileSizePerThread =
398 nominalTileSizes.has_value()
399 ? (*nominalTileSizes)[loopIdx]
401 b, loc, m.ceilDiv(n),
402 ArrayRef<OpFoldResult>{size, nonZeroNumThreads[threadIdIdx]});
403
404 // Dynamic offset shifted by threadId * maxSizePerThread.
406 b, loc, i + j * m, {offset, threadId, tileSizePerThread});
407 // Dynamic upper-bound depending on the threadId.
409 b, loc, i + j * m - n,
410 {offset, nonZeroNumThreads[threadIdIdx], tileSizePerThread, size});
411 if (!isZeroInteger(residualTileSize)) {
412 OpFoldResult sizeMinusOffsetPerThread = makeComposedFoldedAffineApply(
413 b, loc, -i + m, {offsetPerThread, size});
414 tileSizePerThread =
415 buildMin(b, loc, {sizeMinusOffsetPerThread, tileSizePerThread});
416 }
417
418 tiledOffsets.push_back(offsetPerThread);
419 // TODO: if tileSizePerThread <= 0 early exit.
420 if (!omitTileOffsetBoundsCheck &&
421 !canOmitTileOffsetInBoundsCheck(tileSizePerThread,
422 nonZeroNumThreads[threadIdIdx], size))
423 tileSizePerThread =
424 buildMax(b, loc, {b.getIndexAttr(0), tileSizePerThread});
425
426 tiledSizes.push_back(tileSizePerThread);
427 ++threadIdIdx;
428 }
429}
430
431template <typename LoopTy>
432static FailureOr<TiledLinalgOp>
436
437 auto nLoops = op.getNumLoops();
438 // Initial tile sizes may be too big, only take the first nLoops.
439 tileSizes = tileSizes.take_front(nLoops);
440
441 if (llvm::all_of(tileSizes, [](OpFoldResult ofr) {
442 return getConstantIntValue(ofr) == static_cast<int64_t>(0);
443 })) {
444 TiledLinalgOp tiledOp;
445 tiledOp.op = cast<LinalgOp>(b.clone(*op.getOperation()));
446 tiledOp.tensorResults.assign(tiledOp.op->result_begin(),
447 tiledOp.op->result_end());
448 return tiledOp;
449 }
450
451 // 1. Build the tiled loop ranges.
452 SmallVector<OpFoldResult> allShapeSizes =
453 op.createFlatListOfOperandDims(b, op.getLoc());
454 AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap();
455 assert(shapeSizesToLoopsMap && "invalid linalgOp with null ShapesToLoopsMap");
456
457 auto [loopRanges, loopIndexToRangeIndex] = makeTiledLoopRanges(
458 b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);
459
461 for (const auto &attr : enumerate(op.getIteratorTypesArray())) {
462 if (loopIndexToRangeIndex.count(attr.index()))
463 iteratorTypes.push_back(attr.value());
464 }
465 // If interchangeVector is empty, use the identity. Build the permutation map
466 // otherwise.
467 auto invPermutationMap =
468 AffineMap::getMultiDimIdentityMap(tileSizes.size(), b.getContext());
469 if (!options.interchangeVector.empty()) {
470 // Based on the pruned iterations (due to zero tile size), recompute the
471 // interchange vector.
472 SmallVector<unsigned, 4> interchangeVector;
473 interchangeVector.reserve(options.interchangeVector.size());
474 for (auto pos : options.interchangeVector) {
475 auto it = loopIndexToRangeIndex.find(pos);
476 if (it == loopIndexToRangeIndex.end())
477 continue;
478 interchangeVector.push_back(it->second);
479 }
480 // Interchange vector is guaranteed to be a permutation,
481 // `inversePermutation` must succeed.
482 invPermutationMap = inversePermutation(
483 AffineMap::getPermutationMap(interchangeVector, b.getContext()));
484 assert(invPermutationMap);
485 SmallVector<int64_t> permutation(interchangeVector.begin(),
486 interchangeVector.end());
487 applyPermutationToVector(loopRanges, permutation);
488 applyPermutationToVector(iteratorTypes, permutation);
489 }
490
491 // Handle distribution. Create a vector of the same size of loops that are to
492 // be tiled.
494 if (options.distribution) {
495 procInfo.resize(
496 iteratorTypes.size(),
497 linalg::ProcInfo{nullptr, nullptr, linalg::DistributionMethod::None});
498 // Collect loop ranges of tiled loops, loops that are parallel.
499 SmallVector<Range> parallelLoopRanges;
500 for (const auto &iteratorType : llvm::enumerate(iteratorTypes)) {
501 if (!isParallelIterator(iteratorType.value()))
502 break;
503 parallelLoopRanges.push_back(loopRanges[iteratorType.index()]);
504 }
505 auto returnedProcInfo =
506 options.distribution->procInfo(b, op.getLoc(), parallelLoopRanges);
507 unsigned procIdIdx = 0;
508 // Update the distribution information for the loops.
509 for (const auto &iteratorType : llvm::enumerate(iteratorTypes)) {
510 if (!isParallelIterator(iteratorType.value()))
511 break;
512 procInfo[iteratorType.index()] = returnedProcInfo[procIdIdx++];
513 }
514 }
515
516 // 2. Create the tiled loops.
517 LinalgOp res = op;
518 SmallVector<Value, 4> ivs, tensorResults;
519 auto tiledLoopBodyBuilder =
520 [&](OpBuilder &builder, Location loc, ValueRange localIvs,
521 ValueRange operandValuesToUse) -> scf::ValueVector {
522 ivs.assign(localIvs.begin(), localIvs.end());
523
524 // When an `interchangeVector` is present, it has been applied to the
525 // loop ranges and the iterator types. Apply its inverse to the
526 // resulting loop `ivs` to match the op definition.
527 SmallVector<Value, 4> interchangedIvs;
528 if (!options.interchangeVector.empty()) {
529 for (AffineExpr result : invPermutationMap.getResults())
530 interchangedIvs.push_back(
531 ivs[cast<AffineDimExpr>(result).getPosition()]);
532 } else {
533 interchangedIvs.assign(ivs.begin(), ivs.end());
534 }
535
536 // Tile the `operandValuesToUse` that either match the `op` operands
537 // themselves or the tile loop arguments forwarding them.
538 assert(operandValuesToUse.size() ==
539 static_cast<size_t>(op->getNumOperands()) &&
540 "expect the number of operands and inputs and outputs to match");
541 SmallVector<Value> valuesToTile = operandValuesToUse;
542 SmallVector<OpFoldResult> sizeBounds =
543 makeComposedFoldedMultiResultAffineApply(b, loc, shapeSizesToLoopsMap,
544 allShapeSizes);
545 SmallVector<Value> tiledOperands = makeTiledShapes(
546 b, loc, op, valuesToTile, getAsOpFoldResult(interchangedIvs), tileSizes,
547 sizeBounds,
548 /*omitPartialTileCheck=*/false);
549
550 SmallVector<Type> resultTensorTypes =
551 getTensorOutputTypes(op, tiledOperands);
552 res = clone(b, op, resultTensorTypes, tiledOperands);
553 tensorResults =
554 insertSlicesBack(builder, loc, op, tiledOperands, res->getResults());
555 return scf::ValueVector(tensorResults.begin(), tensorResults.end());
556 };
557 GenerateLoopNest<LoopTy>::doit(b, op.getLoc(), loopRanges, op, iteratorTypes,
558 tiledLoopBodyBuilder, procInfo);
559
560 // 3. Transform IndexOp results w.r.t. the tiling.
561 transformIndexOps(b, res, ivs, loopIndexToRangeIndex);
562
563 // 4. Gather the newly created loops and return them with the new op.
565 loops.reserve(ivs.size());
566 for (auto iv : ivs) {
567 if (isa<BlockArgument>(iv)) {
568 loops.push_back(cast<BlockArgument>(iv).getOwner()->getParentOp());
569 assert(loops.back() && "no owner found for induction variable!");
570 } else {
571 // TODO: Instead of doing this, try to recover the ops used instead of the
572 // loop.
573 loops.push_back(nullptr);
574 }
575 }
576
577 // 5. Get the tensor results from the outermost loop if available. Otherwise
578 // use the previously captured `tensorResults`.
579 Operation *outermostLoop = nullptr;
580 for (Operation *loop : loops)
581 if ((outermostLoop = loop))
582 break;
583
584 return TiledLinalgOp{
585 res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults};
586}
587
588FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
589 RewriterBase &b, PartialReductionOpInterface op,
591 std::optional<ArrayAttr> mapping) {
592 Location loc = op.getLoc();
594
595 // Ops implementing PartialReductionOpInterface are expected to implement
596 // TilingInterface.
597 // TODO: proper core mechanism to tie interfaces together.
598 auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
599
600 // Ops implementing PartialReductionOpInterface are not necessarily expected
601 // to implement TilingInterface.. This cast is unsafe atm.
602 // TODO: proper core mechanism to tie interfaces together.
603 // TODO: this function requires a pair of interfaces ..
604 auto destinationStyleOp =
605 dyn_cast<DestinationStyleOpInterface>(op.getOperation());
606 if (!destinationStyleOp)
607 return b.notifyMatchFailure(op, "not a destination style op");
608
609 // Actually this only work for Linalg ops atm.
610 auto linalgOp = dyn_cast<linalg::LinalgOp>(op.getOperation());
611 if (!linalgOp)
612 return b.notifyMatchFailure(op, "not a linalg op");
613
614 SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
615 if (op->getNumResults() != 1)
616 return b.notifyMatchFailure(
617 op, "don't support ops with multiple results for now");
618
620 tilingInterfaceOp.getLoopIteratorTypes();
621 SmallVector<unsigned> redDims;
622 linalgOp.getReductionDims(redDims);
623 if (redDims.size() != 1)
624 return b.notifyMatchFailure(
625 op, "only support ops with one reduction dimension.");
626 if (!tileSizes.empty() && tileSizes.size() != numThreads.size())
627 return b.notifyMatchFailure(op, "if tile sizes are present it must have as "
628 "many elements as number of threads");
629
630 if (redDims.front() >= numThreads.size())
631 return b.notifyMatchFailure(
632 op, "reduction dimension must be mapped to threads");
633
634 // 1. Create the inital tensor value.
635 unsigned reductionDim = redDims.front();
636 SetVector<unsigned> reductionDims;
637 reductionDims.insert(reductionDim);
638 FailureOr<SmallVector<Value>> maybeInitTensors =
639 op.generateInitialTensorForPartialReduction(b, loc, numThreads,
640 reductionDims);
641 if (failed(maybeInitTensors))
642 return b.notifyMatchFailure(
643 op, "Failed to create inital tensors for partial reduction");
644 SmallVector<Value> &initTensors = maybeInitTensors.value();
645
646 // Gather destination tensors.
648 if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
649 return b.notifyMatchFailure(op, "failed to get destination tensors");
650
651 Operation *tiledOp = nullptr;
652
653 SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
654 numThreads, [](OpFoldResult ofr) { return !isZeroInteger(ofr); });
655 SmallVector<Value> materializedNonZeroNumThreads =
656 getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads);
657
658 // 2. Create the ForallOp with an empty region.
659 scf::ForallOp forallOp = scf::ForallOp::create(
660 b, loc, getAsOpFoldResult(materializedNonZeroNumThreads), initTensors,
661 mapping);
662
663 // 3. Calculate the tile offsets and sizes for the subsequent loop that will
664 // be nested under `forallOp`.
665 SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
666 calculateTileOffsetsAndSizes(b, loc, forallOp, numThreads, iterationDomain,
667 /*omitTileOffsetBoundsCheck =*/false,
668 /*nominalTileSizes=*/std::nullopt, tiledOffsets,
669 tiledSizes);
670
671 // 4b. Clone the tileable op and update its destination operands to use the
672 // output bbArgs of the ForallOp.
673 SmallVector<Value> tilingResults;
674 ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
675 {
676 // 4.a. RAII guard, inserting within forallOp, before terminator.
678 b.setInsertionPoint(forallOp.getTerminator());
679
680 SmallVector<Value> tiledDpsInitOperands;
681 for (Value initOperand : destinationStyleOp.getDpsInits()) {
682 auto *it = llvm::find(dest, initOperand);
683 assert(it != dest.end() && "dest operand not found in dest");
684 unsigned destNum = std::distance(dest.begin(), it);
685 SmallVector<OpFoldResult> strides(numThreads.size(), b.getIndexAttr(1));
686 SmallVector<OpFoldResult> outOffsets(numThreads.size(),
687 b.getIndexAttr(0));
688 SmallVector<OpFoldResult> sizes = tiledSizes;
689 sizes[reductionDim] = b.getIndexAttr(1);
690 outOffsets[reductionDim] = forallOp.getInductionVars()[0];
691 // TODO: use SubsetExtractOpInterface once it is available.
692 tiledDpsInitOperands.push_back(tensor::ExtractSliceOp::create(
693 b, loc, cast<RankedTensorType>(initOperand.getType()),
694 destBbArgs[destNum], outOffsets, sizes, strides));
695 }
696
697 // 4.b. Clone the op and update init operands.
698 // We cannot use a IRMapping here because it can replace
699 // different OpOperands with the same value.
700 Operation *clonedOp = b.clone(*op.getOperation());
701 b.modifyOpInPlace(clonedOp, [&]() {
702 for (auto [initOperandPtr, tiledInitValue] : llvm::zip_equal(
703 cast<DestinationStyleOpInterface>(clonedOp).getDpsInitsMutable(),
704 tiledDpsInitOperands)) {
705 initOperandPtr.set(tiledInitValue);
706 }
707 });
708
709 // 5. Tile the cloned op and delete the clone.
710 if (tileSizes.empty()) {
711 FailureOr<TilingResult> tilingResult =
712 cast<TilingInterface>(clonedOp).getTiledImplementation(
713 b, tiledOffsets, tiledSizes);
714 if (failed(tilingResult))
715 return clonedOp->emitError("Failed to tile op: ");
716 if (tilingResult->tiledOps.size() != 1) {
717 return clonedOp->emitError("expected a single produced tiled op, got ")
718 << tilingResult->tiledOps.size();
719 }
720 tiledOp = tilingResult->tiledOps.front();
721 tilingResults = tilingResult->tiledValues;
722 } else {
724 FailureOr<TiledLinalgOp> maybeTiled = tileLinalgOpImpl<scf::ForOp>(
725 b, cast<LinalgOp>(clonedOp), tileSizes, options);
726 if (failed(maybeTiled))
727 return b.notifyMatchFailure(op, "failed tileLinalgOpImpl");
728
729 SmallVector<Value> ids = forallOp.getInductionVars();
730 mapLoopToProcessorIds(cast<scf::ForOp>(maybeTiled->loops.back()), ids,
731 materializedNonZeroNumThreads);
732 if (maybeTiled->loops.size() != 1) {
733 return clonedOp->emitError("expected a single produced loop");
734 }
735 tiledOp = maybeTiled->op;
736 tilingResults = maybeTiled->loops.front()->getResults();
737 }
738
739 b.eraseOp(clonedOp);
740 }
741
742 // 6. Insert the partial reductions back into a new tensor.
743 for (auto [index, result, bbArg] : llvm::zip(
744 llvm::seq<unsigned>(0, dest.size()), tilingResults, destBbArgs)) {
745 // 6.a. Partial subset information is inserted just before the terminator.
747 b.setInsertionPoint(forallOp.getTerminator());
748
749 SmallVector<OpFoldResult> resultOffsets, resultSizes;
750 if (failed(tilingInterfaceOp.getResultTilePosition(
751 b, index, tiledOffsets, tiledSizes, resultOffsets, resultSizes)))
752 return op->emitOpError("output offsets couldn't be calculated");
753 SmallVector<OpFoldResult> resultOffsetsRank, resultSizesRank;
754 int64_t offIdx = 0;
755 int64_t sizeIdx = 0;
756 for (int64_t i = 0, e = numThreads.size(); i < e; ++i) {
757 if (i == reductionDim) {
758 resultOffsetsRank.push_back(forallOp.getInductionVars()[0]);
759 resultSizesRank.push_back(b.getIndexAttr(1));
760 continue;
761 }
762 resultOffsetsRank.push_back(resultOffsets[offIdx++]);
763 resultSizesRank.push_back(resultSizes[sizeIdx++]);
764 }
765 SmallVector<OpFoldResult> strides(resultSizesRank.size(),
766 b.getIndexAttr(1));
767
768 // 6.b. Parallel insertions are inserted at the end of the combining
769 // terminator.
770 b.setInsertionPointToEnd(forallOp.getTerminator().getBody());
771 tensor::ParallelInsertSliceOp::create(
772 b, loc, result, bbArg, resultOffsetsRank, resultSizesRank, strides);
773 }
774
775 // 7. Merge the partial reductions.
776 b.setInsertionPointAfter(forallOp);
777 FailureOr<MergeResult> mergeResult =
778 op.mergeReductions(b, loc, forallOp->getResults(), reductionDims);
779 if (failed(mergeResult)) {
780 return failure();
781 }
782 b.replaceOp(op, mergeResult->replacements);
783
784 // 8. Return.
786 results.initialValues = initTensors;
787 results.loops = forallOp;
788 results.parallelTiledOps.push_back(tiledOp);
789 results.mergeOps.append(mergeResult->mergeOps);
790 return results;
791}
792
793template <typename LoopTy>
794FailureOr<TiledLinalgOp> static tileLinalgOpImpl(
795 RewriterBase &b, LinalgOp op, const LinalgTilingOptions &options) {
797 b.setInsertionPoint(op);
798
799 if (!options.tileSizeComputationFunction)
800 return failure();
801
802 // Enforce the convention that "tiling by zero" skips tiling a particular
803 // dimension. This convention is significantly simpler to handle instead of
804 // adjusting affine maps to account for missing dimensions.
805 auto nLoops = op.getNumLoops();
806 SmallVector<OpFoldResult> tileSizeVector =
807 getAsOpFoldResult(options.tileSizeComputationFunction(b, op));
808 if (tileSizeVector.size() < nLoops) {
809 tileSizeVector.append(nLoops - tileSizeVector.size(), b.getIndexAttr(0));
810 }
811
812 return tileLinalgOpImpl<LoopTy>(b, op, tileSizeVector, options);
813}
814
815FailureOr<TiledLinalgOp>
818 switch (options.loopType) {
823 default:;
824 }
825 return failure();
826}
827
828namespace {
829/// Helper classes for type list expansion.
830template <typename... OpTypes>
831class CanonicalizationPatternList;
832
833template <>
834class CanonicalizationPatternList<> {
835public:
836 static void insert(RewritePatternSet &patterns) {}
837};
838
839template <typename OpTy, typename... OpTypes>
840class CanonicalizationPatternList<OpTy, OpTypes...> {
841public:
842 static void insert(RewritePatternSet &patterns) {
843 OpTy::getCanonicalizationPatterns(patterns, patterns.getContext());
844 CanonicalizationPatternList<OpTypes...>::insert(patterns);
845 }
846};
847} // namespace
848
855
858 auto *ctx = patterns.getContext();
859 affine::AffineApplyOp::getCanonicalizationPatterns(patterns, ctx);
860 affine::AffineForOp::getCanonicalizationPatterns(patterns, ctx);
861 affine::AffineMinOp::getCanonicalizationPatterns(patterns, ctx);
862 affine::AffineMaxOp::getCanonicalizationPatterns(patterns, ctx);
863 arith::ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);
864
865 memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
866 memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);
867
868 scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
869 scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx);
870
871 tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
872 tensor::EmptyOp::getCanonicalizationPatterns(patterns, ctx);
873 tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx);
874 tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx);
875 tensor::PadOp::getCanonicalizationPatterns(patterns, ctx);
876 ctx->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(patterns);
877
878 CanonicalizationPatternList<
879#define GET_OP_LIST
880#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
881 >::insert(patterns);
882}
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static llvm::ManagedStatic< PassManagerOptions > options
static void calculateTileOffsetsAndSizes(RewriterBase &b, Location loc, scf::ForallOp forallOp, ArrayRef< OpFoldResult > numThreads, SmallVector< Range > loopRanges, bool omitTileOffsetBoundsCheck, std::optional< ArrayRef< OpFoldResult > > nominalTileSizes, SmallVector< OpFoldResult > &tiledOffsets, SmallVector< OpFoldResult > &tiledSizes)
Fill out the tiledOffsets and tiledSizes to be used to tile to a given number of threads.
Definition Tiling.cpp:362
static FailureOr< TiledLinalgOp > tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef< OpFoldResult > tileSizes, const LinalgTilingOptions &options)
Definition Tiling.cpp:433
static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize, OpFoldResult numThreads, OpFoldResult iterationSize)
Returns true if the maximum tile offset tileSize * numThreads-1 is less than iterationSize.
Definition Tiling.cpp:333
static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b, OpFoldResult value)
Asserts that the given index-typed value is strictly positive.
Definition Tiling.cpp:89
static OpFoldResult buildMax(OpBuilder &b, Location loc, ArrayRef< OpFoldResult > vals)
Build an affine_max of all the vals.
Definition Tiling.cpp:345
static OpFoldResult buildMin(OpBuilder &b, Location loc, ArrayRef< OpFoldResult > vals)
Build an affine_min of all the vals.
Definition Tiling.cpp:353
Base type for affine expression.
Definition AffineExpr.h:68
AffineExpr floorDiv(uint64_t v) const
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
unsigned getNumResults() const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Attributes are known-constant values of operations.
Definition Attributes.h:25
AffineExpr getAffineSymbolExpr(unsigned position)
Definition Builders.cpp:368
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:262
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:630
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext * getContext() const
Return the context this location is uniqued in.
Definition Location.h:86
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:526
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
result_range getResults()
Definition Operation.h:415
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Specialization of arith.constant op that returns an integer of index type.
Definition Arith.h:113
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
SmallVector< Value > makeTiledShapes(OpBuilder &builder, Location loc, LinalgOp linalgOp, ValueRange valuesToTile, ArrayRef< OpFoldResult > ivs, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > sizeBounds, bool omitPartialTileCheck)
Creates extract_slice/subview ops for all valuesToTile of the given linalgOp with builder,...
Definition Utils.cpp:2497
void transformIndexOps(RewriterBase &b, LinalgOp op, SmallVectorImpl< Value > &ivs, const LoopIndexToRangeIndexMap &loopIndexToRangeIndex)
All indices returned by IndexOp should be invariant with respect to tiling.
Definition Tiling.cpp:73
bool isParallelIterator(utils::IteratorType iteratorType)
Check if iterator type has "parallel" semantics.
Definition Utils.cpp:230
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns)
Definition Tiling.cpp:856
SmallVector< Value > insertSlicesBack(OpBuilder &builder, Location loc, LinalgOp op, ValueRange operands, ValueRange results)
Creates insert_slice ops that insert results back into larger tensors they were originally extracted ...
Definition Utils.cpp:2417
std::tuple< SmallVector< Range, 4 >, LoopIndexToRangeIndexMap > makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > allShapeSizes, ArrayRef< OpFoldResult > allTileSizes)
Definition Tiling.cpp:44
void offsetIndices(OpBuilder &b, LinalgOp linalgOp, ArrayRef< OpFoldResult > offests)
Add the specified offsets to any linalg.index ops contained in the given linalgOp.
Definition Utils.cpp:2519
FailureOr< StaticMultiSizeSpecification > computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize, int64_t divisor)
Definition Tiling.cpp:236
FailureOr< ContinuousTileSizeSpecification > computeContinuousTileSizes(OpBuilder &builder, TilingInterface op, unsigned dimension, OpFoldResult targetSize, bool emitAssertions)
Definition Tiling.cpp:156
FailureOr< StaticContinuousTileSizeSpecification > computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension, unsigned targetSize)
Definition Tiling.cpp:106
FailureOr< ForallReductionTilingResult > tileReductionUsingForall(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > tileSizes={}, std::optional< ArrayAttr > mapping=std::nullopt)
Method to tile a reduction to parallel iterations computing partial reductions.
Definition Tiling.cpp:588
FailureOr< TiledLinalgOp > tileLinalgOp(RewriterBase &b, LinalgOp op, const LinalgTilingOptions &options)
Definition Tiling.cpp:816
DenseMap< int, int > LoopIndexToRangeIndexMap
Creates a number of ranges equal to the number of non-zero entries in tileSizes.
RewritePatternSet getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx)
Canonicalization patterns relevant to apply after tiling patterns.
Definition Tiling.cpp:850
SmallVector< Type > getTensorOutputTypes(LinalgOp op, ValueRange operands)
Returns the list of tensor output types produced when the given structured operation op is applied to...
Definition Utils.cpp:2408
FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)
Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetS...
Definition Tiling.cpp:262
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition SCF.h:64
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition AffineExpr.h:325
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:111
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Transformation information returned after reduction tiling.
SmallVector< Operation * > mergeOps
The final reduction operation merging all the partial reductions.
SmallVector< Value > initialValues
Initial values used for partial reductions.
scf::ForallOp loops
The scf.forall operation that iterate over the tiles.
SmallVector< Operation * > parallelTiledOps
The partial reduction tiled op generated.
static void doit(OpBuilder &b, Location loc, ArrayRef< Range > loopRanges, LinalgOp linalgOp, ArrayRef< utils::IteratorType > iteratorTypes, function_ref< scf::ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilderFn, ArrayRef< linalg::ProcInfo > procInfo={})
A description of a multi-size tiling comprising tile sizes and numbers of tiles, expressed as Values ...
Callback function type used to get processor ID, and number of processors used for distribution for a...
Definition Utils.h:312
Perform standalone tiling of a single LinalgOp by tileSizes.
Definition Transforms.h:894
SmallVector< Value, 4 > tensorResults
Definition Transforms.h:897
SmallVector< T > tripCounts
Number of tiles associated with each size.
T lowTripCount
Number of tiles associated with each size.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.