MLIR 23.0.0git
TilingInterfaceImpl.cpp
Go to the documentation of this file.
1//===- TilingInterfaceImpl.cpp - Implementation of TilingInterface -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
27#include "llvm/ADT/SmallVectorExtras.h"
28#include "llvm/Support/Debug.h"
29#include <optional>
30
31#define DEBUG_TYPE "linalg-tiling-interface-impl"
32
33using namespace mlir;
34using namespace mlir::linalg;
35
36//===----------------------------------------------------------------------===//
37// Utility methods for implementation of Tiling Interface for Linalg ops
38//===----------------------------------------------------------------------===//
39
/// Return the SSA values that represent the data point accessed using a given
/// `indexingMap` for a given point in the iteration space represented by `ivs`.
// NOTE(review): the function's signature line and the declaration of
// `indices` are not visible in this view of the file; the body below is
// kept byte-identical to what is visible.
                                      AffineMap indexingMap,
                                      ValueRange ivs) {
  // One index value is produced per result expression of the indexing map.
  indices.reserve(indexingMap.getNumResults());
  for (auto result : indexingMap.getResults()) {
    // Wrap each result expression in a single-result affine map and apply it
    // to the iteration-space point `ivs` to get the access index.
    AffineMap m = AffineMap::get(indexingMap.getNumDims(),
                                 indexingMap.getNumSymbols(), result);
    Value v = affine::AffineApplyOp::create(b, loc, m, ivs);
    indices.push_back(v);
  }
  return indices;
}
55
/// Method to inline the payload of a `linalgOp` given the iteration space
/// point and values for the arguments of the payload.
static LogicalResult inlinePayload(OpBuilder &b, LinalgOp linalgOp,
                                   ValueRange ivs, ValueRange argValues) {
  Block *body = linalgOp.getBlock();
  // Remap the payload block arguments to the scalar values loaded for this
  // iteration-space point.
  IRMapping map;
  map.map(body->getArguments(), argValues);
  for (auto &op : body->without_terminator()) {
    // `linalg.index` ops are replaced directly by the corresponding induction
    // variable instead of being cloned.
    if (auto indexOp = dyn_cast<IndexOp>(&op)) {
      map.map(indexOp.getResult(), ivs[indexOp.getDim()]);
      continue;
    }
    b.clone(op, map);
  }

  // Each yielded value is stored into the matching init (output) buffer at
  // the indices computed from that operand's indexing map.
  Operation *terminator = body->getTerminator();
  Location loc = terminator->getLoc();
  for (const auto &operand : llvm::enumerate(terminator->getOperands())) {
    Value toStore = map.lookupOrDefault(operand.value());
    OpOperand *storeInto = linalgOp.getDpsInitOperand(operand.index());
    // NOTE(review): the line declaring `indices` (a call to
    // getIndicesForAccess) is not visible in this view.
        b, loc, linalgOp.getMatchingIndexingMap(storeInto), ivs);
    memref::StoreOp::create(b, loc, toStore,
                            linalgOp.getDpsInitOperand(operand.index())->get(),
                            indices);
  }
  return success();
}
84
85//===----------------------------------------------------------------------===//
86// External Model for implementing `TilingInterface` for `LinalgOp`s.
87//===----------------------------------------------------------------------===//
88
89namespace {
/// External model implementation of TilingInterface for LinalgOps. An external
/// model implementation is used for now till the use of `TilingInterface` is
/// on-par with the current Linalg tiling + fusion patterns. Once it is
/// maybe possible to move this into the op-definition (though there are
/// advantages to leaving it as an external model)
// NOTE(review): several signature/parameter lines inside this struct are not
// visible in this view of the file; all visible code is kept byte-identical.
template <typename LinalgOpTy>
struct LinalgOpTilingInterface
    : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>,
                                            LinalgOpTy> {
  /// Return the loop iterator type.
  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    LinalgOpTy concreteOp = cast<LinalgOpTy>(op);
    return concreteOp.getIteratorTypesArray();
  }

  /// Return the iteration domain range.
  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(op);
    Location loc = op->getLoc();
    LinalgOp linalgOp = cast<LinalgOp>(op);
    // Derive loop bounds from the operand shapes via the shapes-to-loops map.
    SmallVector<OpFoldResult> allShapesSizes =
        linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();

    // Every loop range is [0, size) with unit stride.
    return llvm::map_to_vector(map.getResults(), [&](AffineExpr loopExpr) {
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(b, loc, loopExpr,
                                                               allShapesSizes);
      return Range{b.getIndexAttr(0), ofr, b.getIndexAttr(1)};
    });
  }

  /// Instantiate the tiled implementation of the operation.
  FailureOr<TilingResult>
                         ArrayRef<OpFoldResult> sizes) const {
    // Leave the `sizeBounds` value empty. That is only needed when the `sizes`
    // specified could lead to out of bounds accesses.
    Location loc = op->getLoc();
    LinalgOp linalgOp = cast<LinalgOp>(op);
    SmallVector<Value> valuesToTile = linalgOp->getOperands();
    // Create tiled slices/views of all operands for the given offsets/sizes.
    SmallVector<Value> tiledOperands = makeTiledShapes(
        b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
    // Report back the slice ops that were generated (operands whose defining
    // op is an extract_slice/subview).
    SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
        llvm::make_filter_range(
            tiledOperands,
            [](Value v) -> bool {
              return isa_and_nonnull<tensor::ExtractSliceOp, memref::SubViewOp>(
                  v.getDefiningOp());
            }),
        [](Value v) -> Operation * { return v.getDefiningOp(); });

    SmallVector<Type> resultTensorTypes =
        getTensorOutputTypes(linalgOp, tiledOperands);

    Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands);
    // Shift any `linalg.index` results inside the clone by the tile offsets.
    offsetIndices(b, cast<LinalgOp>(tiledOp), offsets);

    return TilingResult{
        {tiledOp}, SmallVector<Value>(tiledOp->getResults()), generatedSlices};
  }

  /// Utility to fetch the offsets and sizes when applied as per the indexing
  /// map of the linalg op. This helps in fusing the linalg op as a consumer of
  /// a given slice op.
  static LogicalResult
  getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b,
                         ArrayRef<AffineMap> indexingMaps,
                         SmallVectorImpl<OpFoldResult> &mappedOffsetsVec,
                         SmallVectorImpl<OpFoldResult> &mappedSizesVec) {
    // Iteration-space dimension -> offset/size derived from operand tiles.
    DenseMap<unsigned, OpFoldResult> mappedOffsets, mappedSizes;

    for (auto [indexingMap, offsets, sizes] :
         llvm::zip_equal(indexingMaps, allOffsets, allSizes)) {
      for (auto [resultExpr, offset, size] :
           llvm::zip_equal(indexingMap.getResults(), offsets, sizes)) {
        // Only plain dimension accesses can be inverted directly.
        auto dimExpr = dyn_cast<AffineDimExpr>(resultExpr);
        if (!dimExpr)
          return failure();
        unsigned position = dimExpr.getPosition();
        auto it = mappedOffsets.find(position);
        if (it != mappedOffsets.end()) {
          // A dimension accessed through several operands must have the same
          // offset/size everywhere, otherwise the mapping is ambiguous.
          OpFoldResult seenOffset = it->second;
          OpFoldResult seenSize = mappedSizes.lookup(position);
          if (seenOffset != offset || seenSize != size) {
            LLVM_DEBUG({
              llvm::dbgs() << "inconsistent iteration space mapping from "
                              "offsets/sizes of operands/results";
            });
            return failure();
          }
        } else {
          mappedOffsets[position] = offset;
          mappedSizes[position] = size;
        }
      }
    }

    // Aggregate from the given operand offsets and sizes, or default to
    // iteration space values.
    SmallVector<Range> iterationDomain =
        cast<TilingInterface>(linalgOp.getOperation()).getIterationDomain(b);
    mappedOffsetsVec.resize(iterationDomain.size());
    mappedSizesVec.resize(iterationDomain.size());
    for (auto [index, domain] : llvm::enumerate(iterationDomain)) {
      auto it = mappedOffsets.find(index);
      if (it != mappedOffsets.end()) {
        mappedOffsetsVec[index] = it->second;
        mappedSizesVec[index] = mappedSizes.lookup(index);
        continue;
      }
      // Dimension not touched by any given operand tile: use the full
      // iteration-domain extent.
      mappedOffsetsVec[index] = domain.offset;
      mappedSizesVec[index] = domain.size;
    }
    return success();
  }

  /// Method to return the position of the result tile computed by the tiled
  /// operation.
  LogicalResult getIterationDomainTileFromOperandTiles(
      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
    auto linalgOp = cast<LinalgOp>(op);

    // Gather one indexing map per requested operand and invert the mapping.
    SmallVector<AffineMap> indexingMaps =
        llvm::map_to_vector(operandNumbers, [&](unsigned operandNumber) {
          OpOperand &opOperand = linalgOp->getOpOperand(operandNumber);
          return linalgOp.getMatchingIndexingMap(&opOperand);
        });
    if (failed(getMappedOffsetAndSize(linalgOp, b, indexingMaps, allOffsets,
                                      allSizes, iterDomainOffsets,
                                      iterDomainSizes))) {
      return failure();
    }
    return success();
  }

  /// Return the details of the output tile generated by the tiled
  /// implementation.
  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    Location loc = op->getLoc();
    LinalgOp linalgOp = cast<LinalgOp>(op);

    // subShapeSizes = sizes - 1; used by the slice-parameter computation.
    AffineExpr d0;
    bindDims(b.getContext(), d0);
    SmallVector<OpFoldResult> subShapeSizes =
        llvm::map_to_vector(sizes, [&](OpFoldResult ofr) {
          return affine::makeComposedFoldedAffineApply(b, loc, d0 - 1, ofr);
        });

    OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
    // NOTE(review): the line declaring `sliceParams` (a call to
    // computeSliceParameters) is not visible in this view.
        b, loc, outOperand->get(), sizes,
        linalgOp.getMatchingIndexingMap(outOperand), offsets,
        /*ubs*/ {}, subShapeSizes, true);
    resultOffsets = sliceParams.offsets;
    resultSizes = sliceParams.sizes;
    return success();
  }

  LogicalResult getIterationDomainTileFromResultTile(
      Operation *op, OpBuilder &b, unsigned resultNumber,
      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
    auto linalgOp = cast<LinalgOp>(op);

    // Check that the indexing map used for the output is a projected
    // permutation. This could be relaxed with a more general approach that can
    // map the offsets and sizes from the result to iteration space tiles
    // (filling in full extent for dimensions not used to access the result).
    AffineMap indexingMap =
        linalgOp.getIndexingMapMatchingResult(op->getResult(resultNumber));
    if (!indexingMap.isProjectedPermutation()) {
      return op->emitOpError(
          "unhandled tiled implementation generation when result is not "
          "accessed using a permuted projection");
    }

    SmallVector<OpFoldResult> allOffsets = llvm::to_vector(offsets);
    SmallVector<OpFoldResult> allSizes = llvm::to_vector(sizes);
    // A projected permutation map cannot fail the mapping; assert on it.
    auto status =
        getMappedOffsetAndSize(linalgOp, b, indexingMap, {allOffsets},
                               {allSizes}, iterDomainOffsets, iterDomainSizes);
    (void)status;
    assert(succeeded(status) && "unexpected error in offset calculation");
    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> sizes) const {
    // Map the requested result tile back to an iteration-domain tile, then
    // tile the whole op for that domain tile.
    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
    if (failed(getIterationDomainTileFromResultTile(
            op, b, resultNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
      return failure();
    }
    auto tilingInterfaceOp = cast<TilingInterface>(op);
    FailureOr<TilingResult> tilingResult =
        tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);

    if (failed(tilingResult))
      return failure();

    if (tilingResult->tiledOps.size() != 1)
      return op->emitOpError("failed to generate tiled implementation");

    // Only the value for the requested result is returned.
    return TilingResult{
        tilingResult->tiledOps,
        SmallVector<Value>{tilingResult->tiledValues[resultNumber]},
        tilingResult->generatedSlices};
  }

  /// Method to generate the tiled implementation of an operation from the tile
  /// of the operand.
  FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
      ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
    if (failed(getIterationDomainTileFromOperandTiles(
            op, b, operandNumbers, allOffsets, allSizes, mappedOffsets,
            mappedSizes))) {
      return failure();
    }
    return getTiledImplementation(op, b, mappedOffsets, mappedSizes);
  }

  LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder,
                                             Location loc,
                                             ValueRange ivs) const {
    auto linalgOp = cast<LinalgOp>(op);
    // Scalar lowering stores through memrefs; tensors are not supported here.
    if (!linalgOp.hasPureBufferSemantics())
      return op->emitOpError("expected operation to have buffer semantics");

    SmallVector<Value> indexedValues;
    indexedValues.reserve(linalgOp->getNumOperands());
    Location linalgOpLoc = op->getLoc();
    /// Load the data corresponding to the block arguments that
    /// represent input operands.
    for (OpOperand &operand : linalgOp->getOpOperands()) {
      // Operands unused by the payload get a null placeholder so positions
      // stay aligned with block arguments.
      if (!linalgOp.payloadUsesValueFromOperand(&operand)) {
        indexedValues.push_back(nullptr);
        continue;
      }
      if (linalgOp.isScalar(&operand)) {
        indexedValues.push_back(operand.get());
        continue;
      }
      // NOTE(review): the line declaring `indices` (a call to
      // getIndicesForAccess) is not visible in this view.
          builder, linalgOpLoc, linalgOp.getMatchingIndexingMap(&operand), ivs);
      Value load =
          memref::LoadOp::create(builder, linalgOpLoc, operand.get(), indices);
      indexedValues.push_back(load);
    }

    /// Inline the op payload and store the result.
    return inlinePayload(builder, linalgOp, ivs, indexedValues);
  }

  bool isOpFusableWithConsumerSlice(Operation *op, unsigned resultNumber,
                                    ArrayRef<OpFoldResult> sizes) const {
    // The verifier gives all the necessary requirements for consumer fusion.
    return true;
  }

  bool isOpFusableWithProducerSlices(
      Operation *op, ArrayRef<unsigned> operandNumbers,
      ArrayRef<SmallVector<OpFoldResult>> allSizes) const {

    auto linalgOp = cast<LinalgOp>(op);
    SmallVector<AffineMap> indexingMaps =
        llvm::map_to_vector(operandNumbers, [&](unsigned operandNumber) {
          OpOperand &opOperand = linalgOp->getOpOperand(operandNumber);
          return linalgOp.getMatchingIndexingMap(&opOperand);
        });
    // Check that offsets/sizes are consistent across all operands.
    OpBuilder b(op);
    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
    return succeeded(getMappedOffsetAndSize(linalgOp, b, indexingMaps,
                                            allOffsets, allSizes, mappedOffsets,
                                            mappedSizes));
  }
};
388
389//===----------------------------------------------------------------------===//
390// External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
391//===----------------------------------------------------------------------===//
392
393/// In a given set vector, get the position of a particular element.
394std::optional<int> getPositionIn(const llvm::SetVector<unsigned> &reductionDims,
395 unsigned value) {
396 for (auto [index, reductionDim] : llvm::enumerate(reductionDims)) {
397 if (reductionDim == value) {
398 return index;
399 }
400 }
401 return std::nullopt;
402}
403
/// Return an AffineMaps to use for the `outs` operands of the linalg op
/// generated for partial results. The new AffineMap is the AffineMap of the
/// untiled op with reduction dimensions appended at end in order in which they
/// were specified during tiling.
// NOTE(review): the return-type line of this function is not visible in this
// view; it returns the vector produced below (one map per DPS init).
getPartialResultAffineMaps(LinalgOp linalgOp,
                           const SetVector<unsigned> &reductionDims) {
  auto partialReductionMaps = llvm::map_to_vector(
      linalgOp.getDpsInitsMutable(), [&](OpOperand &opOperand) {
        AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
        // Append one result per tiled reduction dimension, in the order the
        // dimensions were specified.
        for (auto redPos : reductionDims) {
          map =
              map.insertResult(getAffineDimExpr(redPos, linalgOp.getContext()),
                               map.getNumResults());
        }
        return map;
      });
  return partialReductionMaps;
}
423
/// Aggregate describing the slice of an init operand used as the destination
/// of a partial reduction op: the static result shape plus the mixed offsets,
/// sizes, and strides of the slice. Note: field order matters — callers
/// aggregate-initialize this struct as {resultShape, offsets, sizes, strides}.
struct InitSliceInfo {
  // Static shape of the slice result, derived from the mixed sizes.
  SmallVector<int64_t> resultShape;
  // Mixed (static-or-dynamic) slice offsets, sizes, and strides.
  SmallVector<OpFoldResult> offsets;
  SmallVector<OpFoldResult> sizes;
  SmallVector<OpFoldResult> strides;
};
430
431/// Return the result shape, offsets, sizes and strides of the slice of the
432/// `initValue` to use as the destination of the partial reduction op generated
433/// with outer reduction strategy.
434static InitSliceInfo getInitSliceInfoForOuterReduction(
435 MLIRContext *context, ArrayRef<OpFoldResult> offsets,
436 ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
437 ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap) {
438 int64_t initRank = partialReductionMap.getNumResults();
439 SmallVector<OpFoldResult> initOffsets, initSizes;
440 Attribute zero = IntegerAttr::get(IndexType::get(context), 0);
441 Attribute one = IntegerAttr::get(IndexType::get(context), 1);
442 SmallVector<OpFoldResult> initStrides(initRank, one);
443 for (AffineExpr dimExpr : partialReductionMap.getResults()) {
444 unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
445 if (reductionDims.contains(dim)) {
446 initOffsets.push_back(zero);
447 } else {
448 initOffsets.push_back(offsets[dim]);
449 }
450 initSizes.push_back(sizes[dim]);
451 }
452 SmallVector<int64_t> resultShape;
453 std::tie(resultShape, std::ignore) = decomposeMixedValues(initSizes);
454 return {resultShape, initOffsets, initSizes, initStrides};
455}
456
457/// Return the result shape, offsets, sizes and strides of the slice of the
458/// `initValue` to use as destination of the partial reduction op generated with
459/// outer parallel strategy.
460static InitSliceInfo getInitSliceInfoForOuterParallel(
461 MLIRContext *context, ArrayRef<OpFoldResult> offsets,
462 ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
463 ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap) {
464 int64_t initRank = partialReductionMap.getNumResults();
465 SmallVector<OpFoldResult> initOffsets, initSizes;
466 Attribute one = IntegerAttr::get(IndexType::get(context), 1);
467 SmallVector<OpFoldResult> initStrides(initRank, one);
468 SmallVector<OpFoldResult> resultShape;
469 for (AffineExpr dimExpr : partialReductionMap.getResults()) {
470 unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
471 if (std::optional<unsigned> dimPos = getPositionIn(reductionDims, dim)) {
472 initOffsets.push_back(splitReductionIvs[dimPos.value()]);
473 initSizes.push_back(one);
474 } else {
475 initOffsets.push_back(offsets[dim]);
476 initSizes.push_back(sizes[dim]);
477 resultShape.push_back(sizes[dim]);
478 }
479 }
480 SmallVector<int64_t> staticShapes;
481 std::tie(staticShapes, std::ignore) = decomposeMixedValues(resultShape);
482 return {staticShapes, initOffsets, initSizes, initStrides};
483}
484
485/// Return the result shape, offsets, sizes and strides of the slice of the
486/// `initValue` to use as destination of the partial reduction op.
/// Return the result shape, offsets, sizes and strides of the slice of the
/// `initValue` to use as destination of the partial reduction op.
// NOTE(review): the remaining parameter lines, the strategy check `if`, and
// the first line of the assert are not visible in this view; the visible code
// is kept byte-identical. Dispatches to the outer-reduction or outer-parallel
// helper based on the tiling strategy.
static InitSliceInfo getInitSliceInfo(MLIRContext *context,
                                      const SetVector<unsigned> &reductionDims,
                                      ArrayRef<OpFoldResult> splitReductionIvs,
                                      AffineMap partialReductionMap) {
    return getInitSliceInfoForOuterReduction(context, offsets, sizes,
                                             reductionDims, splitReductionIvs,
                                             partialReductionMap);
  }
         "unexpected ReductionTilingStrategy");
  return getInitSliceInfoForOuterParallel(context, offsets, sizes,
                                          reductionDims, splitReductionIvs,
                                          partialReductionMap);
}
505
506/// External model implementation of PartialReductionInterface for
507/// LinalgOps.
508template <typename LinalgOpTy>
509struct LinalgOpPartialReductionInterface
510 : public PartialReductionOpInterface::ExternalModel<
511 LinalgOpPartialReductionInterface<LinalgOpTy>, LinalgOpTy> {
512 FailureOr<SmallVector<Value>> generateInitialTensorForPartialReduction(
513 Operation *op, OpBuilder &b, Location loc, ArrayRef<OpFoldResult> sizes,
514 const SetVector<unsigned> &reductionDims) const {
515 auto linalgOp = cast<LinalgOp>(op);
516
517 OpBuilder::InsertionGuard guard(b);
518 if (linalgOp.hasPureBufferSemantics())
519 return op->emitOpError("expected operation to have tensor semantics");
520
521 SmallVector<AffineMap> partialResultMaps =
522 getPartialResultAffineMaps(linalgOp, reductionDims);
523
524 SmallVector<Value> inits;
525 for (auto [initIdx, result, partialMap] :
526 llvm::enumerate(linalgOp->getResults(), partialResultMaps)) {
527 SmallVector<Operation *, 4> combinerOps;
528 if (!matchReduction(linalgOp.getRegionOutputArgs(), initIdx,
529 combinerOps) ||
530 combinerOps.size() != 1)
531 return op->emitOpError("Failed to anaysis the reduction operation.");
532
533 Operation *reductionOp = combinerOps[0];
534 std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp);
535 if (!identity.has_value())
536 return op->emitOpError(
537 "Failed to get an identity value for the reduction operation.");
538
539 // Append the new partial result dimensions.
540 SmallVector<OpFoldResult> partialResultShape;
541 for (AffineExpr dimExpr : partialMap.getResults()) {
542 auto dim = cast<AffineDimExpr>(dimExpr);
543 partialResultShape.push_back(sizes[dim.getPosition()]);
544 }
545
546 Type elType = getElementTypeOrSelf(result.getType());
547 Value emptyTensor =
548 tensor::EmptyOp::create(b, loc, partialResultShape, elType);
549 Value constantOp = arith::ConstantOp::create(b, loc, *identity);
550 auto identityTensor =
551 linalg::FillOp::create(b, loc, constantOp, emptyTensor);
552 inits.push_back(identityTensor.getResult(0));
553 }
554
555 return inits;
556 }
557
558 FailureOr<TilingResult>
559 tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
560 ReductionTilingStrategy tilingStrategy,
561 ValueRange init, ArrayRef<OpFoldResult> offsets,
562 ArrayRef<OpFoldResult> sizes,
563 const SetVector<unsigned> &reductionDims,
564 ArrayRef<OpFoldResult> splitReductionIvs) const {
565 OpBuilder::InsertionGuard guard(b);
566 auto linalgOp = cast<LinalgOp>(op);
567
568 SmallVector<AffineMap> partialReductionMaps =
569 getPartialResultAffineMaps(linalgOp, reductionDims);
570
571 // Step 1. Extend init maps to have reduction dimension dims, since we
572 // are converting them to parallel dimensions.
573 SmallVector<AffineMap> newInitMaps;
574 if (tilingStrategy ==
575 ReductionTilingStrategy::PartialReductionOuterReduction) {
576 newInitMaps = llvm::to_vector(partialReductionMaps);
577 } else {
578 newInitMaps = llvm::map_to_vector(
579 linalgOp.getDpsInitsMutable(), [&](OpOperand &opOperand) {
580 return linalgOp.getMatchingIndexingMap(&opOperand);
581 });
582 }
583
584 // Step 2a: Extract a slice of the input operands.
585 SmallVector<Value> tiledInputs = makeTiledShapes(
586 b, loc, linalgOp, linalgOp.getDpsInputs(), offsets, sizes, {}, true);
587 SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
588 llvm::make_filter_range(
589 tiledInputs, [](Value v) -> bool { return v.getDefiningOp(); }),
590 [](Value v) -> Operation * { return v.getDefiningOp(); });
591
592 // Step 2b: Extract a slice of the init operands.
593 SmallVector<Value, 1> tiledInits;
594 for (auto [partialReductionMap, valueToTile] :
595 llvm::zip_equal(partialReductionMaps, init)) {
596 InitSliceInfo sliceInfo = getInitSliceInfo(
597 b.getContext(), tilingStrategy, offsets, sizes, reductionDims,
598 splitReductionIvs, partialReductionMap);
599 auto valueToTileType = cast<RankedTensorType>(valueToTile.getType());
600 RankedTensorType sliceResultType = RankedTensorType::get(
601 sliceInfo.resultShape, valueToTileType.getElementType(),
602 valueToTileType.getEncoding());
603 auto sliceOp = tensor::ExtractSliceOp::create(
604 b, loc, sliceResultType, valueToTile, sliceInfo.offsets,
605 sliceInfo.sizes, sliceInfo.strides);
606 tiledInits.push_back(sliceOp.getResult());
607 generatedSlices.push_back(sliceOp);
608 }
609
610 // Update the indexing maps.
611 SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
612 for (auto [initOperand, newInitMap] :
613 llvm::zip_equal(linalgOp.getDpsInitsMutable(), newInitMaps)) {
614 int mapIdx = linalgOp.getIndexingMapIndex(&initOperand);
615 newMaps[mapIdx] = newInitMap;
616 }
617
618 // Step 3. Change the reduction dim iterator types.
619 SmallVector<utils::IteratorType> newIteratorTypes =
620 linalgOp.getIteratorTypesArray();
621 if (tilingStrategy ==
622 ReductionTilingStrategy::PartialReductionOuterReduction) {
623 for (int dim : reductionDims)
624 newIteratorTypes[dim] = utils::IteratorType::parallel;
625 }
626
627 // Step 4. Create the new generic op.
628 Operation *partialReductionOp;
629 auto resultTypes = ValueRange(tiledInits).getTypes();
630 if (tilingStrategy ==
631 ReductionTilingStrategy::PartialReductionOuterReduction) {
632 auto genericOp = GenericOp::create(b, loc, resultTypes, tiledInputs,
633 tiledInits, newMaps, newIteratorTypes);
634 IRMapping mapping;
635 op->getRegion(0).cloneInto(&genericOp.getRegion(),
636 genericOp.getRegion().begin(), mapping);
637 partialReductionOp = genericOp.getOperation();
638 } else {
639 SmallVector<Value> operands = std::move(tiledInputs);
640 llvm::append_range(operands, tiledInits);
641 partialReductionOp = mlir::clone(b, op, resultTypes, operands);
642 }
643 return TilingResult{
644 {partialReductionOp},
645 llvm::map_to_vector(partialReductionOp->getResults(),
646 [](OpResult r) -> Value { return r; }),
647 generatedSlices};
648 }
649
650 FailureOr<MergeResult>
651 mergeReductions(Operation *op, OpBuilder &b, Location loc,
652 ValueRange partialReduce,
653 const SetVector<unsigned> &reductionDims) const {
654 auto linalgOp = cast<LinalgOp>(op);
655 SmallVector<AffineMap> partialReductionMaps =
656 getPartialResultAffineMaps(linalgOp, reductionDims);
657
658 // Permute the reduction dims as permuted by the partial result map.
659 SmallVector<Operation *> mergeOperations;
660 SmallVector<Value> replacements;
661 for (auto [idx, init, partialResult, partialMap] : llvm::enumerate(
662 linalgOp.getDpsInits(), partialReduce, partialReductionMaps)) {
663 unsigned initIdx = idx;
664 // linalg.reduce's iteration space is the tiled result's iteration space
665 // (and not the tiled operation's iteration space). To account for this,
666 // permute the reduction dimensions based on the partial result map of the
667 // tiled result.
668 SmallVector<int64_t> partialReductionDims;
669 for (auto [resultNum, dimExpr] :
670 llvm::enumerate(partialMap.getResults())) {
671 unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
672 if (llvm::is_contained(reductionDims, dim)) {
673 partialReductionDims.push_back(resultNum);
674 }
675 }
676
677 auto reduction = linalg::ReduceOp::create(
678 b, loc, partialResult, init, partialReductionDims,
679 [&linalgOp, &initIdx](OpBuilder &b, Location loc, ValueRange inputs) {
680 // Get the combiner op.
681 SmallVector<Operation *, 4> combinerOps;
682 matchReduction(linalgOp.getRegionOutputArgs(), initIdx,
683 combinerOps);
684 Operation *clonedReductionOp = b.clone(*combinerOps[0]);
685 // Combine the input at idx and output at numInits + idx.
686 clonedReductionOp->setOperand(0, inputs[0]);
687 clonedReductionOp->setOperand(1, inputs[1]);
688 linalg::YieldOp::create(b, loc, clonedReductionOp->getResult(0));
689 });
690
691 mergeOperations.push_back(reduction);
692 replacements.push_back(reduction->getResult(0));
693 }
694
695 return MergeResult{mergeOperations, replacements};
696 }
697
698 LogicalResult getPartialResultTilePosition(
699 Operation *op, OpBuilder &b, unsigned resultNumber,
700 ReductionTilingStrategy tilingStrategy, ArrayRef<OpFoldResult> offsets,
701 ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
702 ArrayRef<OpFoldResult> splitReductionIvs,
703 SmallVector<OpFoldResult> &resultOffsets,
704 SmallVector<OpFoldResult> &resultSizes) const {
705 auto linalgOp = cast<LinalgOp>(op);
706 SmallVector<AffineMap> partialReductionMaps =
707 getPartialResultAffineMaps(linalgOp, reductionDims);
708 InitSliceInfo sliceInfo = getInitSliceInfo(
709 b.getContext(), tilingStrategy, offsets, sizes, reductionDims,
710 splitReductionIvs, partialReductionMaps[resultNumber]);
711 std::swap(resultOffsets, sliceInfo.offsets);
712 std::swap(resultSizes, sliceInfo.sizes);
713
714 return success();
715 }
716};
717
718template <typename OpTy>
719static SmallVector<Range> getPackUnPackIterationDomain(OpTy op,
720 OpBuilder &builder) {
721 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
722 "applies to only pack or unpack operations");
723 OpBuilder::InsertionGuard g(builder);
724 int64_t rank = (std::is_same<OpTy, PackOp>::value) ? op.getSourceRank()
725 : op.getDestRank();
726 OpFoldResult zero = builder.getIndexAttr(0);
727 OpFoldResult one = builder.getIndexAttr(1);
728 ReifiedRankedShapedTypeDims resultShape;
729 (void)op.reifyResultShapes(builder, resultShape);
730 SmallVector<Range> loopBounds(rank);
731 for (auto dim : llvm::seq<int64_t>(0, rank)) {
732 loopBounds[dim].offset = zero;
733 loopBounds[dim].stride = one;
734 loopBounds[dim].size = resultShape[0][dim];
735 }
736 return loopBounds;
737}
738
/// Apply `permutation` in place to the offset and size vectors; no-op when
/// the permutation is empty.
// NOTE(review): the parameter line declaring `sizes` is not visible in this
// view of the file; the visible code is kept byte-identical.
static void applyPermToRange(SmallVector<OpFoldResult> &offsets,
                             ArrayRef<int64_t> permutation) {
  if (permutation.empty())
    return;
  applyPermutationToVector<OpFoldResult>(offsets, permutation);
  applyPermutationToVector<OpFoldResult>(sizes, permutation);
}
747
/// Compute the permutation vector to interchange `elements` such that the
/// elements at positions in `dimsPos` are moved to the positions `[0, ...,
/// dimsPos.size())` in order.
// NOTE(review): the return-type line is not visible in this view; the
// function returns the interchange vector built below.
computeInterchangeFromDimPos(ArrayRef<int64_t> dimsPos, int64_t rank) {
  SmallVector<int64_t> interchangeVector;
  interchangeVector.reserve(dimsPos.size());
  // First map dims and their position. For example, dims_pos = [2, 0] will map
  // to:
  // [
  //  [ key: 2, value: 0]
  //  [ key: 0, value: 1]
  // ]
  // where key is the idx in dims_pos while value its position in dims_pos.
  DenseMap<int64_t, int64_t> dimsAndPosMapping;
  for (int64_t dimsIdx = 0, end = dimsPos.size(); dimsIdx < end; dimsIdx++)
    dimsAndPosMapping[dimsPos[dimsIdx]] = dimsIdx;

  // Scan the position in order and insert the value in the map
  // to compute the interchange vector. Dimensions not mentioned in
  // `dimsPos` are simply skipped.
  for (int64_t dimsIdx = 0; dimsIdx < rank; dimsIdx++) {
    if (dimsAndPosMapping.count(dimsIdx))
      interchangeVector.push_back(dimsAndPosMapping[dimsIdx]);
  }
  return interchangeVector;
}
774
775/// Permute the elements of `vec` starting at position `offset` according to
776/// `interchangeVector`. The permutation maps position `i` in the permuted range
777/// to position `interchangeVector[i]` in the original range. Elements before
778/// `offset` are unchanged.
779///
780/// Example: interchange([a, b, c, d, e], [2, 0, 1], offset=2)
781/// returns [a, b, e, c, d] (permutes the suffix [c, d, e])
782///
783/// Note: This is similar to `applyPermutationToVector` but supports an offset
784/// for permuting a suffix of the vector. It is only used for pack/unpack scalar
785/// implementation where we need to permute inner tile dimensions which are
786/// stored at the end of the index vector.
787template <typename T>
788static SmallVector<T> interchange(ArrayRef<T> elements,
789 ArrayRef<int64_t> interchangeVector,
790 int offset = 0) {
791 SmallVector<T> vec = llvm::to_vector(elements);
792 for (auto [idx, val] : llvm::enumerate(interchangeVector))
793 vec[idx + offset] = elements[val + offset];
794 return vec;
795}
796
797/// Generate the body of the innermost loop of the scalar implementation
798/// of `pack` operation.
799static void generatePackOpScalarImplementationBody(PackOp packOp,
800 OpBuilder &builder,
801 Location loc,
802 ValueRange ivs) {
803 // Note: `ivs` are already in the correct order, possibly interchanged based
804 // on `dims_pos`. However, connecting the loops with the access patterns is
805 // difficult - What is the relation between the position of the tile loop and
806 // the point loop? However, if we interchange `ivs` once more to go to the
807 // canonical blocking format: ABCabc, this connection becomes trivial: Each
808 // point loop is pointLoopsOffset + inputRank away from the tiled loop.
809 ArrayRef<int64_t> dimsToInnerBlock = packOp.getInnerDimsPos();
810 ArrayRef<int64_t> dimsToOuterBlock = packOp.getOuterDimsPerm();
811
812 SmallVector<Value> interchangedIvs = ivs;
813 SmallVector<int64_t> interchangeVector =
814 computeInterchangeFromDimPos(dimsToInnerBlock, packOp.getSourceRank());
815 interchangedIvs = interchange<Value>(interchangedIvs, interchangeVector,
816 /*offset=*/packOp.getSourceRank());
817 if (!dimsToOuterBlock.empty()) {
818 interchangeVector =
819 computeInterchangeFromDimPos(dimsToOuterBlock, packOp.getSourceRank());
820 interchangedIvs =
821 interchange<Value>(interchangedIvs, interchangeVector, /*offset=*/0);
822 }
823 DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
824 packOp.getDimAndTileMapping();
825 SmallVector<OpFoldResult> sourceIndices;
826 size_t pointLoopsOffset = 0;
827 int64_t sourceRank = packOp.getSourceRank();
828 for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
829 if (dimAndTileMapping.contains(dim)) {
830 AffineExpr i, j, tile;
831 bindDims(builder.getContext(), i, j);
832 bindSymbols(builder.getContext(), tile);
834 builder, loc, i * tile + j,
836 interchangedIvs[dim],
837 interchangedIvs[pointLoopsOffset + packOp.getSourceRank()],
838 dimAndTileMapping[dim]});
839 sourceIndices.push_back(sourceIndex);
840 ++pointLoopsOffset;
841 } else {
842 sourceIndices.push_back(interchangedIvs[dim]);
843 }
844 }
845
846 auto createLoad = [&]() -> Value {
847 return memref::LoadOp::create(
848 builder, loc, packOp.getSource(),
849 getValueOrCreateConstantIndexOp(builder, loc, sourceIndices));
850 };
851 Value scalar;
852 if (auto paddingValue = packOp.getPaddingValue()) {
853 ArithBuilder arithBuilder(builder, loc);
855 for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
856 Value idx =
857 getValueOrCreateConstantIndexOp(builder, loc, sourceIndices[dim]);
858 Value cond = arithBuilder.slt(
859 idx, createOrFoldDimOp(builder, loc, packOp.getSource(), dim));
860 isInBounds = dim == 0 ? cond : arithBuilder._and(isInBounds, cond);
861 }
862 scalar = scf::IfOp::create(
863 builder, loc, isInBounds, /*thenBuilder=*/
864 [&](OpBuilder &b, Location l) {
865 scf::YieldOp::create(b, l, createLoad());
866 },
867 /*elseBuilder=*/
868 [&](OpBuilder &b, Location l) {
869 scf::YieldOp::create(b, l, paddingValue);
870 })
871 .getResult(0);
872 } else {
873 scalar = createLoad();
874 }
875
876 memref::StoreOp::create(builder, loc, scalar, packOp.getDest(), ivs);
877}
878
879struct PackOpTiling
880 : public TilingInterface::ExternalModel<PackOpTiling, linalg::PackOp> {
881
882 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
883 // Note that here we only consider untiled dimensions and outer tiled data
884 // dimensions, the inner tiled data dimensions are materialized when
885 // building the body of the operation.
886 auto packOp = cast<PackOp>(op);
887 SmallVector<utils::IteratorType> iteratorTypes(
888 packOp.getSourceRank(), utils::IteratorType::parallel);
889 return iteratorTypes;
890 }
891
892 SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
893 return getPackUnPackIterationDomain<PackOp>(cast<PackOp>(op), b);
894 }
895
896 FailureOr<TilingResult>
897 getTiledImplementation(Operation *op, OpBuilder &b,
898 ArrayRef<OpFoldResult> offsets,
899 ArrayRef<OpFoldResult> sizes) const {
900 auto packOp = cast<PackOp>(op);
901 // TODO: Support Memref PackOp. Temporarily return failure.
902 if (!packOp.hasPureTensorSemantics())
903 return failure();
904
905 Location loc = packOp.getLoc();
906
907 // The tiling is applied on interchanged dimensions. We have to undo the
908 // interchange to map sizes and offsets to the original input.
909 int64_t inputRank = packOp.getSourceRank();
910 SmallVector<OpFoldResult> origOffsets(offsets);
911 SmallVector<OpFoldResult> origSizes(sizes);
912 applyPermToRange(origOffsets, origSizes,
913 invertPermutationVector(packOp.getOuterDimsPerm()));
914
915 DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
916 packOp.getDimAndTileMapping();
917 SmallVector<OpFoldResult> srcDimValues =
918 tensor::getMixedSizes(b, loc, packOp.getSource());
919 SmallVector<OpFoldResult> inputIndices, inputSizes;
920 for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
921 using AV = affine::AffineValueExpr;
922 affine::AffineBuilder ab(b, loc);
923 AffineExpr dim0, dim1, sym;
924 bindDims(b.getContext(), dim0, dim1);
925 bindSymbols(b.getContext(), sym);
926 if (dimAndTileMapping.count(dim)) {
927 // If the data dimension is tiled, the i-th index is the product of
928 // offset_i and tile_i, and the i-th size is the product of sizes_i and
929 // tile_i.
930 auto avOffset = AV(dim0).bind(origOffsets[dim]);
931 auto avSize = AV(dim0).bind(origSizes[dim]);
932 auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
933 inputIndices.push_back(ab.mul(avOffset, avTileSize));
934 inputSizes.push_back(ab.mul(avSize, avTileSize));
935 } else {
936 inputIndices.push_back(origOffsets[dim]);
937 inputSizes.push_back(origSizes[dim]);
938 }
939
940 // Limit the size of the input operand for incomplete tiles.
941 if (packOp.getPaddingValue()) {
942 OpFoldResult dimSize = srcDimValues[dim];
943 auto avDimSize = AV(dim0).bind(dimSize);
944 auto avInputIdx = AV(dim1).bind(inputIndices.back());
945 inputSizes.back() =
946 ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)});
947 }
948 }
949
950 auto oneAttr = b.getI64IntegerAttr(1);
951 SmallVector<OpFoldResult> strides(inputRank, oneAttr);
952
953 SmallVector<Value> tiledOperands;
954 auto sourceSlice = tensor::ExtractSliceOp::create(
955 b, loc, packOp.getSource(), inputIndices, inputSizes, strides);
956 tiledOperands.push_back(sourceSlice);
957
958 SmallVector<OpFoldResult> outputOffsets, outputSizes;
959 if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets,
960 outputSizes)))
961 return {};
962
963 strides.append(packOp.getDestRank() - inputRank, oneAttr);
964 auto outSlice = tensor::ExtractSliceOp::create(
965 b, loc, packOp.getDest(), outputOffsets, outputSizes, strides);
966 tiledOperands.push_back(outSlice);
967
968 if (auto val = packOp.getPaddingValue())
969 tiledOperands.push_back(val);
970 for (auto tile : packOp.getInnerTiles())
971 tiledOperands.push_back(tile);
972
973 Operation *tiledPackOp = PackOp::create(
974 b, loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());
975
976 return TilingResult{
977 {tiledPackOp},
978 SmallVector<Value>(tiledPackOp->getResults()),
979 llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
980 }
981
982 LogicalResult
983 getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
984 ArrayRef<OpFoldResult> offsets,
985 ArrayRef<OpFoldResult> sizes,
986 SmallVector<OpFoldResult> &resultOffsets,
987 SmallVector<OpFoldResult> &resultSizes) const {
988 // The iteration domain is over outer dimensions of packed layout. In this
989 // context, the outer dimensions of `resultOffsets` are `offsets`. The
990 // inner dimensions of `resultOffsets` are zeros because tiling is not
991 // applied to them.
992 auto packOp = cast<PackOp>(op);
993 int64_t inputRank = packOp.getSourceRank();
994 int64_t outputRank = packOp.getDestRank();
995 auto zeroAttr = b.getI64IntegerAttr(0);
996 resultOffsets.assign(offsets.begin(), offsets.end());
997 resultOffsets.append(outputRank - inputRank, zeroAttr);
998
999 ReifiedRankedShapedTypeDims outputShape;
1000 (void)reifyResultShapes(b, packOp, outputShape);
1001 resultSizes.assign(sizes.begin(), sizes.end());
1002 for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
1003 resultSizes.push_back(outputShape[0][dataTileDim]);
1004
1005 return success();
1006 }
1007
1008 FailureOr<TilingResult>
1009 generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
1010 ArrayRef<OpFoldResult> offsets,
1011 ArrayRef<OpFoldResult> sizes) const {
1012 auto packOp = cast<PackOp>(op);
1013 int64_t numTiles = packOp.getInnerDimsPos().size();
1014
1015 // tensor.pack op is fusible (as a producer) only if full inner tiles are
1016 // iterated or inner dims are not tiled. Otherwise, it will generate a
1017 // sequence of non-trivial ops (for partial tiles).
1018 for (auto offset : offsets.take_back(numTiles))
1019 if (!isZeroInteger(offset))
1020 return failure();
1021
1022 for (auto iter :
1023 llvm::zip_equal(packOp.getMixedTiles(), sizes.take_back(numTiles)))
1024 if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter)))
1025 return failure();
1026
1027 FailureOr<TilingResult> tilingResult = getTiledImplementation(
1028 op, b, offsets.drop_back(numTiles), sizes.drop_back(numTiles));
1029 if (failed(tilingResult))
1030 return failure();
1031 return tilingResult.value();
1032 }
1033
1034 LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder,
1035 Location loc,
1036 ValueRange ivs) const {
1037 auto packOp = cast<PackOp>(op);
1038 assert(packOp.hasPureBufferSemantics() &&
1039 "expected operation to have buffer semantics");
1040 OpBuilder::InsertionGuard g(builder);
1041 // The `ivs` already represent the position into the output for the non
1042 // data-tile dimensions.
1043 SmallVector<Value> ivVec(ivs);
1044
1045 // Get output shape - for memrefs, get dimensions from dest directly.
1046 SmallVector<OpFoldResult> outputShape;
1047 Value dest = packOp.getDest();
1048 for (auto dim : llvm::seq<int64_t>(0, packOp.getDestRank()))
1049 outputShape.push_back(createOrFoldDimOp(builder, loc, dest, dim));
1050
1051 // Generate the loops that iterate over the data tile.
1052 Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
1053 Value one = arith::ConstantIndexOp::create(builder, loc, 1);
1054
1055 // All loops except the innermost are simple loops that just iterate
1056 // over the tile dimensions.
1057 for (auto dataTileDim : llvm::seq<unsigned>(packOp.getSourceRank(),
1058 packOp.getDestRank() - 1)) {
1059 Value ub = getValueOrCreateConstantIndexOp(builder, loc,
1060 outputShape[dataTileDim]);
1061 scf::ForOp loop = scf::ForOp::create(builder, loc, zero, ub, one);
1062 builder.setInsertionPointToStart(loop.getBody());
1063 ivVec.push_back(loop.getInductionVar());
1064 }
1065 // The body of the innermost loops does the actual data movement.
1066 scf::ForOp::create(
1067 builder, loc, zero,
1068 getValueOrCreateConstantIndexOp(builder, loc, outputShape.back()), one,
1069 ValueRange{},
1070 [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
1071 ValueRange regionIterArgs) {
1072 ivVec.push_back(iv);
1073 generatePackOpScalarImplementationBody(packOp, bodyBuilder, bodyLoc,
1074 ivVec);
1075 scf::YieldOp::create(bodyBuilder, bodyLoc);
1076 });
1077 return success();
1078 }
1079
1080 /// Method to return the position of iteration domain tile computed by the
1081 /// tiled operation. In current `tensor.pack` context, the `resultOffsets` and
1082 /// `resultSizes` only cover outer dimensions.
1083 LogicalResult getIterationDomainTileFromOperandTiles(
1084 Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
1085 ArrayRef<SmallVector<OpFoldResult>> allOffsets,
1086 ArrayRef<SmallVector<OpFoldResult>> allSizes,
1087 SmallVectorImpl<OpFoldResult> &resultOffsets,
1088 SmallVectorImpl<OpFoldResult> &resultSizes) const {
1089 if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
1090 LLVM_DEBUG(
1091 { llvm::dbgs() << "unsupported operands for consumer fusion"; });
1092 return failure();
1093 }
1094
1095 ArrayRef<OpFoldResult> offsets(allOffsets[0]);
1096 ArrayRef<OpFoldResult> sizes(allSizes[0]);
1097 auto packOp = cast<PackOp>(op);
1098 Location loc = packOp.getLoc();
1099 SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
1100 DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
1101 packOp.getDimAndTileMapping();
1102 SmallVector<int64_t> outerShapeWithoutTranspose(
1103 packOp.getDestType().getShape().take_front(packOp.getSourceRank()));
1104 if (!packOp.getOuterDimsPerm().empty()) {
1106 outerShapeWithoutTranspose,
1107 invertPermutationVector(packOp.getOuterDimsPerm()));
1108 }
1109 for (auto dim : llvm::seq<int64_t>(packOp.getSourceRank())) {
1110 if (dimAndTileMapping.count(dim)) {
1111 FailureOr<int64_t> cstTileSize =
1113 presburger::BoundType::UB, sizes[dim],
1114 /*stopCondition=*/nullptr, /*closedUB=*/true);
1115 std::optional<int64_t> cstInnerSize =
1116 getConstantIntValue(dimAndTileMapping[dim]);
1117
1118 // If a dimension is not tiled, it is always valid to fuse the pack op,
1119 // even if the op has padding semantics. Because it always generates a
1120 // full slice along the dimension. The tile sizes are for unpacked
1121 // domain, i.e., `srcDimSize`, so `tileSize < srcDimSize` means that the
1122 // dimension is tiled.
1123 // TODO: It could be untiled if the `srcDimSize` is dynamic. It is a
1124 // hard check to determine if a dimension is tiled or not.
1125 int64_t srcDimSize = packOp.getSourceType().getDimSize(dim);
1126 int64_t destDimSize = outerShapeWithoutTranspose[dim];
1127 bool isTiled = failed(cstTileSize) ||
1128 ShapedType::isDynamic(srcDimSize) ||
1129 cstTileSize.value() < srcDimSize;
1130 if (!isTiled) {
1131 outerDimOffsets.push_back(offsets[dim]);
1132 if (ShapedType::isStatic(destDimSize)) {
1133 outerDimSizes.push_back(b.getIndexAttr(destDimSize));
1134 } else {
1135 outerDimSizes.push_back(
1136 b.createOrFold<tensor::DimOp>(loc, packOp.getDest(), dim));
1137 }
1138 continue;
1139 }
1140
1141 // Currently fusing `packOp` as consumer only expects perfect tiling
1142 // scenario because even if without padding semantic, the `packOp` may
1143 // also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>,
1144 // where the `tileSize` from operand of `packOp` is 5, which is not
1145 // exactly divided by `innerTile`(=6) of `packOp`. As the result:
1146 // 1. the first slice is extracted from (0) to (4) and inserted into
1147 // (0,0)~(0,4) at first row.
1148 // 2. the second slice is extracted from (5) to (9) and SHOULD BE
1149 // respectively inserted into two rows with different length, including
1150 // first row: (0,5) and second row (1,0)~(1,3). It is hard to coordinate
1151 // them, thus adding below constraint to bypass them temporarily. In
1152 // another word, we can only support tiling with consumer if the tile
1153 // size for the producer is a multiple of the inner tile size for the
1154 // packed dimensions at this moment.
1155 if ((failed(cstTileSize) || !cstInnerSize ||
1156 *cstTileSize % *cstInnerSize != 0))
1157 return failure();
1158
1159 using AV = affine::AffineValueExpr;
1160 affine::AffineBuilder ab(b, loc);
1161 AffineExpr dim0, sym;
1162 bindDims(b.getContext(), dim0);
1163 bindSymbols(b.getContext(), sym);
1164 auto avOffset = AV(dim0).bind(offsets[dim]);
1165 auto avSize = AV(dim0).bind(sizes[dim]);
1166 auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
1167 outerDimOffsets.push_back(ab.floor(avOffset, avTileSize));
1168 outerDimSizes.push_back(ab.ceil(avSize, avTileSize));
1169 } else {
1170 outerDimOffsets.push_back(offsets[dim]);
1171 outerDimSizes.push_back(sizes[dim]);
1172 }
1173 }
1174 applyPermToRange(outerDimOffsets, outerDimSizes, packOp.getOuterDimsPerm());
1175 resultOffsets = outerDimOffsets;
1176 resultSizes = outerDimSizes;
1177 return success();
1178 }
1179
1180 /// Method to return the tiled implementation of tensor.pack as a consumer.
1181 FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
1182 Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
1183 ArrayRef<SmallVector<OpFoldResult>> allOffsets,
1184 ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
1185 if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
1186 LLVM_DEBUG(
1187 { llvm ::dbgs() << "unhandled operands for consumer fusion"; });
1188 return failure();
1189 }
1190
1191 ArrayRef<OpFoldResult> offsets(allOffsets[0]);
1192 ArrayRef<OpFoldResult> sizes(allSizes[0]);
1193
1194 auto packOp = cast<PackOp>(op);
1195 // TODO: Support Memref UnPackOp. Temporarily return failure.
1196 if (!packOp.hasPureTensorSemantics())
1197 return failure();
1198
1199 Location loc = packOp.getLoc();
1200
1201 int64_t inputRank = packOp.getSourceRank();
1202 auto oneAttr = b.getI64IntegerAttr(1);
1203 SmallVector<OpFoldResult> strides(inputRank, oneAttr);
1204
1205 SmallVector<Value> tiledOperands;
1206 auto sourceSlice = tensor::ExtractSliceOp::create(
1207 b, loc, packOp.getSource(), offsets, sizes, strides);
1208 tiledOperands.push_back(sourceSlice);
1209
1210 SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
1211 if (failed(getIterationDomainTileFromOperandTiles(
1212 op, b, operandNumbers, allOffsets, allSizes, outerDimOffsets,
1213 outerDimSizes)))
1214 return failure();
1215
1216 SmallVector<OpFoldResult> outputOffsets, outputSizes;
1217 if (failed(getResultTilePosition(op, b, 0, outerDimOffsets, outerDimSizes,
1218 outputOffsets, outputSizes)))
1219 return failure();
1220
1221 strides.append(packOp.getDestRank() - inputRank, oneAttr);
1222 auto outSlice = tensor::ExtractSliceOp::create(
1223 b, loc, packOp.getDest(), outputOffsets, outputSizes, strides);
1224 tiledOperands.push_back(outSlice);
1225
1226 if (auto val = packOp.getPaddingValue())
1227 tiledOperands.push_back(val);
1228 for (auto tile : packOp.getInnerTiles())
1229 tiledOperands.push_back(tile);
1230
1231 Operation *tiledPackOp = PackOp::create(
1232 b, loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());
1233
1234 return TilingResult{
1235 {tiledPackOp},
1236 SmallVector<Value>(tiledPackOp->getResults()),
1237 llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
1238 }
1239};
1240
/// Per-dimension information needed to tile an unpack op; computed by
/// `getUnpackTileDimInfo` for one destination dimension at a time.
struct UnpackTileDimInfo {
  // True when the requested tile size is a multiple of the inner tile size
  // (or the dimension is not packed), i.e. no partial inner tiles occur.
  bool isAlignedToInnerTileSize;
  // Offset into the outer (tiled) dimension of the packed source.
  OpFoldResult sourceOffset;
  // Number of outer-dimension elements to read from the packed source.
  OpFoldResult sourceSize;
  // Offset into the expanded destination where the requested tile begins;
  // non-zero only in the unaligned case.
  OpFoldResult resultOffset;
  // Size of the (possibly expanded) destination along this dimension.
  OpFoldResult destExpandedSize;
};
1248
1249/// Returns the needed information for tiling unpack op on `tileDim` with given
1250/// `tileOffset` and `tileSize`. For more details, see the comment of the
1251/// `getTiledImplementation`.
1252static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp,
1253 int64_t tileDim,
1254 OpFoldResult tileOffset,
1255 OpFoldResult tileSize) {
1256 UnpackTileDimInfo info;
1257 Attribute zeroAttr = b.getIndexAttr(0);
1258 Attribute oneAttr = b.getIndexAttr(1);
1259 DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
1260 unpackOp.getDimAndTileMapping();
1261 // The dimension is not one of packed data dimension.
1262 if (!dimAndTileMapping.count(tileDim)) {
1263 info.isAlignedToInnerTileSize = true;
1264 info.sourceOffset = tileOffset;
1265 info.sourceSize = tileSize;
1266 info.resultOffset = zeroAttr;
1267 info.destExpandedSize = tileSize;
1268 return info;
1269 }
1270
1271 Location loc = unpackOp.getLoc();
1272 using AV = affine::AffineValueExpr;
1273 affine::AffineBuilder ab(b, loc);
1274 AffineExpr dim0, dim1, sym0;
1275 bindDims(b.getContext(), dim0, dim1);
1276 bindSymbols(b.getContext(), sym0);
1277
1278 OpFoldResult innerTileSize = dimAndTileMapping[tileDim];
1279
1280 info.isAlignedToInnerTileSize = false;
1281 FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound(
1282 presburger::BoundType::UB, tileSize,
1283 /*stopCondition=*/nullptr, /*closedUB=*/true);
1284 std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize);
1285 if (!failed(cstSize) && cstInnerSize) {
1286 if (*cstSize % *cstInnerSize == 0)
1287 info.isAlignedToInnerTileSize = true;
1288
1289 // If the tiling size equals to the inner tiling size, the outer dims are
1290 // always 1.
1291 if (*cstInnerSize == *cstSize) {
1292 auto lhs = AV(dim0).bind(tileOffset);
1293 auto rhs = AV(dim1).bind(innerTileSize);
1294 info.sourceOffset = ab.floor(lhs, rhs);
1295 info.sourceSize = oneAttr;
1296 info.resultOffset = zeroAttr;
1297 info.destExpandedSize = tileSize;
1298 return info;
1299 }
1300 }
1301
1302 if (info.isAlignedToInnerTileSize) {
1303 info.sourceOffset =
1304 ab.floor(AV(dim0).bind(tileOffset), AV(dim1).bind(innerTileSize));
1305 info.resultOffset = zeroAttr;
1306 info.destExpandedSize = tileSize;
1307
1308 // The ceilDiv is needed here because there could be incomplete tile even
1309 // it is perfect tiling cases. E.g.,
1310 // %0 = unpack tensor<33x2xf32> into tensor<64xf32>
1311 // If the tiling size is 32, there will be 3 tiles. Two of them have
1312 // size=32; one of them have size=2. The size is represented using
1313 // affine_min op; we need ceilDiv.
1314 info.sourceSize =
1315 ab.ceil(AV(dim0).bind(tileSize), AV(dim1).bind(innerTileSize));
1316 return info;
1317 }
1318
1319 affine::DivModValue firstCoord = affine::getDivMod(
1320 b, loc, getValueOrCreateConstantIndexOp(b, loc, tileOffset),
1321 getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
1322 OpFoldResult tileExclusiveBound =
1323 ab.add(AV(dim0).bind(tileOffset), AV(dim1).bind(tileSize));
1324 affine::DivModValue lastCoord = affine::getDivMod(
1325 b, loc,
1327 b, loc,
1328 ab.sub(AV(dim0).bind(tileExclusiveBound), AV(dim1).bind(oneAttr))),
1329 getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
1330
1331 OpFoldResult lengthMinusOne = ab.sub(AV(dim0).bind(lastCoord.quotient),
1332 AV(dim1).bind(firstCoord.quotient));
1333 info.sourceSize =
1334 ab.add(AV(dim0).bind(lengthMinusOne), AV(dim1).bind(oneAttr));
1335 info.sourceOffset = firstCoord.quotient;
1336 info.resultOffset = firstCoord.remainder;
1337 // Do not create an Affine ops for expanded size because the affine op is too
1338 // complicated which would trigger an issue in affine ops simplification.
1339 info.destExpandedSize = b.createOrFold<arith::MulIOp>(
1340 loc, getValueOrCreateConstantIndexOp(b, loc, info.sourceSize),
1341 getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
1342 return info;
1343}
1344
1345struct UnPackOpTiling
1346 : public TilingInterface::ExternalModel<UnPackOpTiling, linalg::UnPackOp> {
1347
1348 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
1349 auto unpackOp = cast<UnPackOp>(op);
1350 SmallVector<utils::IteratorType> iteratorTypes(
1351 unpackOp.getDestRank(), utils::IteratorType::parallel);
1352 return iteratorTypes;
1353 }
1354
1355 SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
1356 return getPackUnPackIterationDomain<UnPackOp>(cast<UnPackOp>(op), b);
1357 }
1358
  /// There are two cases in tiling unpack ops. If the tiling size is aligned to
  /// the inner tile size, the corresponding tiles of source are all complete.
  /// Otherwise, there are in-complete tiles. We will need to expand the slice
  /// of source for getting complete tiles. The tiled unpack op unpacks more
  /// data from source, so We'll need an extract_slice op to shift and truncate
  /// the output.
  /// Take Nn_to_N as an example. Say that N=32, n=8, and tiling_size=15. The
  /// coordinates of second tile (i.e., result[15..31]) are
  /// [(1, 7), (2, 0,), (2, 1) ... (3, 6), (3, 7)]. The first row and the last
  /// row are incomplete tiles. To represent the unpack op, we have to complete
  /// the rows. I.e., the input coordinates would start with (1, 0); end with
  /// (3, 7). In this context, the tiled unpack produces a (3 * n) elements
  /// because there are 3 rows in total. Follow by a tensor.extract_slice op, we
  /// can get the actual result.
  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto unpackOp = cast<UnPackOp>(op);
    // TODO: Support Memref UnPackOp. Temporarily return failure.
    if (!unpackOp.hasPureTensorSemantics())
      return failure();

    int64_t srcRank = unpackOp.getSourceRank();
    int64_t destRank = unpackOp.getDestRank();
    int64_t numInnerTiles = srcRank - destRank;
    Location loc = unpackOp.getLoc();

    // The perfect tiling case indicates that the tiling sizes are multiple of
    // inner_tile_size. In this context, no extra data is needed when
    // representing the tiled unpack op.
    bool isPerfectTilingCase = true;
    Attribute oneAttr = b.getIndexAttr(1);
    SmallVector<OpFoldResult> sliceSrcStrides(destRank, oneAttr);
    SmallVector<OpFoldResult> sliceSrcIndices, sliceSrcSizes;
    SmallVector<OpFoldResult> destExpandedSizes, resultOffsetsFromDest;
    // Compute per-dimension source offsets/sizes (see `getUnpackTileDimInfo`).
    for (auto dim : llvm::seq<int64_t>(0, destRank)) {
      UnpackTileDimInfo info =
          getUnpackTileDimInfo(b, unpackOp, dim, offsets[dim], sizes[dim]);
      if (!info.isAlignedToInnerTileSize)
        isPerfectTilingCase = false;
      sliceSrcIndices.push_back(info.sourceOffset);
      sliceSrcSizes.push_back(info.sourceSize);
      destExpandedSizes.push_back(info.destExpandedSize);
      resultOffsetsFromDest.push_back(info.resultOffset);
    }

    // The tiling is applied on destination dimensions. We have to apply the
    // interchange on source dimensions if outer_dims_perm is set.
    applyPermToRange(sliceSrcIndices, sliceSrcSizes,
                     unpackOp.getOuterDimsPerm());
    // Take the inner-tile dimensions of the source whole (offset 0, full
    // tile size, stride 1).
    Attribute zeroAttr = b.getIndexAttr(0);
    sliceSrcIndices.append(numInnerTiles, zeroAttr);
    sliceSrcSizes.append(unpackOp.getMixedTiles());
    sliceSrcStrides.append(numInnerTiles, oneAttr);
    SmallVector<Operation *> generatedSlices;
    tensor::ExtractSliceOp sliceSource = tensor::ExtractSliceOp::create(
        b, loc, unpackOp.getSource(), sliceSrcIndices, sliceSrcSizes,
        sliceSrcStrides);
    generatedSlices.push_back(sliceSource);

    SmallVector<OpFoldResult> destStrides(destRank, oneAttr);
    Value sliceDest;
    if (isPerfectTilingCase) {
      // Perfect tiling: unpack directly into a slice of the destination.
      auto destSliceOp = tensor::ExtractSliceOp::create(
          b, loc, unpackOp.getDest(), offsets, sizes, destStrides);
      sliceDest = destSliceOp;
      generatedSlices.push_back(destSliceOp);
    } else {
      // Imperfect tiling: unpack into an expanded temporary; the requested
      // window is extracted afterwards.
      sliceDest = tensor::EmptyOp::create(
          b, loc, destExpandedSizes, unpackOp.getDestType().getElementType());
    }

    SmallVector<Value> tiledOperands = {sliceSource.getResult(), sliceDest};
    for (auto tile : unpackOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledUnpackOp = UnPackOp::create(
        b, loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs());

    if (isPerfectTilingCase)
      return TilingResult{{tiledUnpackOp},
                          SmallVector<Value>(tiledUnpackOp->getResults()),
                          generatedSlices};

    // Shift and truncate the expanded result down to the requested tile.
    auto extractSlice = tensor::ExtractSliceOp::create(
        b, loc, tiledUnpackOp->getResult(0), resultOffsetsFromDest, sizes,
        destStrides);
    return TilingResult{
        {tiledUnpackOp}, {extractSlice.getResult()}, generatedSlices};
  }
1450
1451 LogicalResult
1452 getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
1453 ArrayRef<OpFoldResult> offsets,
1454 ArrayRef<OpFoldResult> sizes,
1455 SmallVector<OpFoldResult> &resultOffsets,
1456 SmallVector<OpFoldResult> &resultSizes) const {
1457 resultOffsets = llvm::to_vector(offsets);
1458 resultSizes = llvm::to_vector(sizes);
1459 return success();
1460 }
1461
1462 FailureOr<TilingResult>
1463 generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
1464 ArrayRef<OpFoldResult> offsets,
1465 ArrayRef<OpFoldResult> sizes) const {
1466 FailureOr<TilingResult> tilingResult =
1467 getTiledImplementation(op, b, offsets, sizes);
1468 if (failed(tilingResult))
1469 return failure();
1470 return tilingResult.value();
1471 }
1472
1473 LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder,
1474 Location loc,
1475 ValueRange ivs) const {
1476 auto unpackOp = cast<UnPackOp>(op);
1477 assert(unpackOp.hasPureBufferSemantics() &&
1478 "expected operation to have buffer semantics");
1479 assert(ivs.size() == unpackOp.getDestRank() &&
1480 "number of ivs must match the rank of the output tensor");
1481 OpBuilder::InsertionGuard g(builder);
1482
1483 DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
1484 unpackOp.getDimAndTileMapping();
1485 // Untiled loops and tile loops induction variables.
1486 SmallVector<Value> inputIvs;
1487 // Point loops induction variables.
1488 SmallVector<Value> inputIvsPointLoops;
1489 inputIvs.reserve(unpackOp.getDestRank());
1490 inputIvsPointLoops.reserve(dimAndTileMapping.size());
1491 for (auto dim : llvm::seq<int64_t>(0, unpackOp.getDestRank())) {
1492 if (dimAndTileMapping.count(dim)) {
1493 affine::DivModValue divMod =
1494 affine::getDivMod(builder, loc, ivs[dim],
1496 builder, loc, dimAndTileMapping[dim]));
1497 inputIvsPointLoops.push_back(divMod.remainder);
1498 inputIvs.push_back(divMod.quotient);
1499 } else {
1500 inputIvs.push_back(ivs[dim]);
1501 }
1502 }
1503
1504 // TODO: (lorenzo) simplify the logic a bit. There is `ivs`,
1505 // `inputIvsPointLoops` and `inputIvs`.
1506 assert(inputIvsPointLoops.size() + inputIvs.size() ==
1507 unpackOp.getSourceRank() &&
1508 "expect same number of induction variables equals to input rank");
1509 // Interchange the point loops induction variables based on `inner_dim_pos`.
1510 ArrayRef<int64_t> innerDims = unpackOp.getInnerDimsPos();
1511 SmallVector<int64_t> interchangeVector =
1512 computeInterchangeFromDimPos(innerDims, unpackOp.getDestRank());
1513 SmallVector<Value> interchangedInputIvsPointLoops = inputIvsPointLoops;
1514 interchangedInputIvsPointLoops = interchange<Value>(
1515 interchangedInputIvsPointLoops, interchangeVector, /*offset=*/0);
1516 // Interchange the tiled loops induction variables based on
1517 // `outer_dims_perm`.
1518 ArrayRef<int64_t> outerDims = unpackOp.getOuterDimsPerm();
1519 if (!outerDims.empty())
1520 inputIvs = interchange<Value>(inputIvs, outerDims, /*offset=*/0);
1521
1522 llvm::append_range(inputIvs, interchangedInputIvsPointLoops);
1523 Value scalar =
1524 memref::LoadOp::create(builder, loc, unpackOp.getSource(), inputIvs);
1525 memref::StoreOp::create(builder, loc, scalar, unpackOp.getDest(), ivs);
1526 return success();
1527 }
1528
  /// Method to return the position of iteration domain tile computed by the
  /// tiled operation. Maps a tile of one operand (source or dest) back to a
  /// tile of the iteration domain (which is the destination domain).
  LogicalResult getIterationDomainTileFromOperandTiles(
      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
      ArrayRef<SmallVector<OpFoldResult>> allSizes,
      SmallVectorImpl<OpFoldResult> &resultOffsets,
      SmallVectorImpl<OpFoldResult> &resultSizes) const {
    if (operandNumbers.size() != 1) {
      LLVM_DEBUG({ llvm::dbgs() << "unable to handle multiple operands"; });
      return failure();
    }
    auto unPackOp = cast<UnPackOp>(op);
    unsigned operandNumber = operandNumbers[0];
    ArrayRef<OpFoldResult> offsets(allOffsets[0]);
    ArrayRef<OpFoldResult> sizes(allSizes[0]);

    // If the operand tile is the dest, then no adjustment is needed.
    if (operandNumber == unPackOp.getDestMutable().getOperandNumber()) {
      resultOffsets = llvm::to_vector(offsets);
      resultSizes = llvm::to_vector(sizes);
      return success();
    }
    Location loc = unPackOp.getLoc();

    // Otherwise the tile is of the source: drop the trailing inner-tile
    // coordinates and undo the outer-dims permutation.
    int64_t numTiles = unPackOp.getInnerDimsPos().size();
    auto destOffsets = offsets.drop_back(numTiles);
    auto destSizes = sizes.drop_back(numTiles);
    // The tiling is applied on interchanged dimensions. We have to undo the
    // interchange to map sizes and offsets to the original input.
    int64_t outputRank = unPackOp.getDestRank();
    ReifiedRankedShapedTypeDims reifiedReturnShapes;
    if (failed(reifyResultShapes(b, unPackOp, reifiedReturnShapes)))
      return failure();
    SmallVector<OpFoldResult> outputMixedSizes = reifiedReturnShapes.front();
    SmallVector<OpFoldResult> origOffsets(destOffsets);
    SmallVector<OpFoldResult> origSizes(destSizes);
    applyPermToRange(origOffsets, origSizes,
                     invertPermutationVector(unPackOp.getOuterDimsPerm()));

    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        unPackOp.getDimAndTileMapping();

    for (auto dim : llvm::seq<int64_t>(0, outputRank)) {
      using AV = affine::AffineValueExpr;
      affine::AffineBuilder ab(b, loc);
      AffineExpr dim0, dim1, sym0;
      bindDims(b.getContext(), dim0, dim1);
      bindSymbols(b.getContext(), sym0);
      if (dimAndTileMapping.count(dim)) {
        // If the data dimension is tiled, the i-th index is the product of
        // offset_i and tile_i, and the i-th size is the product of sizes_i and
        // tile_i. The sizes must be clamped to the sizes of the unpack result.
        auto avOffset = AV(dim0).bind(origOffsets[dim]);
        auto avSize = AV(dim0).bind(origSizes[dim]);
        auto avTileSize = AV(sym0).bind(dimAndTileMapping[dim]);
        auto avResultSize = AV(dim0).bind(outputMixedSizes[dim]);
        resultOffsets.push_back(ab.mul(avOffset, avTileSize));
        auto avResultOffset = AV(dim1).bind(resultOffsets.back());
        resultSizes.push_back(ab.min({ab.mul(avSize, avTileSize),
                                      ab.sub(avResultSize, avResultOffset)}));
      } else {
        resultOffsets.push_back(origOffsets[dim]);
        resultSizes.push_back(origSizes[dim]);
      }
    }
    return success();
  }
1597
1598 /// Method to return the tiled implementation of tensor.unpack as a consumer.
1599 FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
1600 Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
1601 ArrayRef<SmallVector<OpFoldResult>> allOffsets,
1602 ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
1603 if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
1604 LLVM_DEBUG({ llvm::dbgs() << "unhandled operands for consumer fusion"; });
1605 return failure();
1606 }
1607 auto unPackOp = cast<UnPackOp>(op);
1608 // TODO: Support Memref UnPackOp. Temporarily return failure.
1609 if (!unPackOp.hasPureTensorSemantics())
1610 return failure();
1611
1612 ArrayRef<OpFoldResult> offsets(allOffsets[0]);
1613 ArrayRef<OpFoldResult> sizes(allSizes[0]);
1614
1615 // tensor.unpack op is fusible (as a consumer) only if inner dims are not
1616 // tiled.
1617 int64_t numTiles = unPackOp.getInnerDimsPos().size();
1618 for (auto iter :
1619 llvm::zip_equal(unPackOp.getMixedTiles(), sizes.take_back(numTiles))) {
1620 if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter)))
1621 return failure();
1622 }
1623
1624 Location loc = unPackOp.getLoc();
1625
1626 // Fetch offset/size for creating the slice of the dest operand of
1627 // unpack op.
1628 SmallVector<OpFoldResult> outputOffsets, outputSizes;
1629 if (failed(getIterationDomainTileFromOperandTiles(
1630 op, b, operandNumbers, allOffsets, allSizes, outputOffsets,
1631 outputSizes)))
1632 return failure();
1633
1634 auto oneAttr = b.getI64IntegerAttr(1);
1635 int64_t outputRank = unPackOp.getDestRank();
1636 SmallVector<OpFoldResult> strides(outputRank, oneAttr);
1637
1638 SmallVector<Value> tiledOperands;
1639 // Create slice of the dest operand.
1640 auto extractDestSlice = tensor::ExtractSliceOp::create(
1641 b, loc, unPackOp.getDest(), outputOffsets, outputSizes, strides);
1642 tiledOperands.push_back(extractDestSlice);
1643
1644 strides.append(unPackOp.getSourceRank() - outputRank, oneAttr);
1645 // Create slice of the source operand.
1646 auto extractSourceSlice = tensor::ExtractSliceOp::create(
1647 b, loc, unPackOp.getSource(), offsets, sizes, strides);
1648 tiledOperands.insert(tiledOperands.begin(), extractSourceSlice);
1649 for (auto tile : unPackOp.getInnerTiles())
1650 tiledOperands.push_back(tile);
1651
1652 // Create tiled unpack op.
1653 Operation *tiledUnPackOp =
1654 UnPackOp::create(b, loc, TypeRange{extractDestSlice.getType()},
1655 tiledOperands, op->getAttrs());
1656
1657 return TilingResult{{tiledUnPackOp},
1658 SmallVector<Value>(tiledUnPackOp->getResults()),
1659 llvm::to_vector(ArrayRef<Operation *>{
1660 extractSourceSlice, extractDestSlice})};
1661 }
1662};
1663
1664} // namespace
1665
1666template <typename OpType>
1667static void registerOne(MLIRContext *ctx) {
1668 OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx);
1669 OpType::template attachInterface<LinalgOpPartialReductionInterface<OpType>>(
1670 *ctx);
1671}
1672
1673/// Variadic helper function.
1674template <typename... OpTypes>
1675static void registerAll(MLIRContext *ctx) {
1676 (registerOne<OpTypes>(ctx), ...);
1677}
1678
1679#define GET_OP_LIST
1680
1682 DialectRegistry &registry) {
1683 registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
1685 linalg::PackOp::attachInterface<PackOpTiling>(*ctx);
1686 linalg::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
1688#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
1689 >(ctx);
1690 });
1691}
1692
1694 DialectRegistry &registry) {
1695 registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *dialect) {
1696 linalg::PackOp::attachInterface<PackOpTiling>(*ctx);
1697 linalg::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
1698 });
1699}
return success()
static bool isTiled(AffineExpr expr, ArrayRef< OpFoldResult > tileSizes)
Definition Utils.cpp:76
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
auto load
static RankedTensorType sliceResultType(Type operandType, GridOp grid, ArrayRef< GridAxis > gridAxes, int64_t sliceAxis)
static LogicalResult getResultTilePosition(RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize)
static FailureOr< TilingResult > getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ReductionTilingStrategy reductionStrategy, ValueRange regionIterArg, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims)
static LogicalResult inlinePayload(OpBuilder &b, LinalgOp linalgOp, ValueRange ivs, ValueRange argValues)
Method to inline the payload of a linalgOp given the iteration space point and values for the argumen...
static SmallVector< Value > getIndicesForAccess(OpBuilder &b, Location loc, AffineMap indexingMap, ValueRange ivs)
Return the SSA values that represent the data point accessed using a given indexingMap for a given po...
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx)
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgListType getArguments()
Definition Block.h:97
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:222
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
MLIRContext * getContext() const
Definition Builders.h:56
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:254
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:715
void setOperand(unsigned idx, Value value)
Definition Operation.h:380
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:541
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:436
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:407
result_range getResults()
Definition Operation.h:444
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void cloneInto(Region *dest, IRMapping &mapper)
Clone the internal blocks from this region into dest.
Definition Region.cpp:70
static FailureOr< int64_t > computeConstantBound(presburger::BoundType type, const Variable &var, const StopConditionFn &stopCondition=nullptr, bool closedUB=false)
Compute a constant bound for the given variable.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
SmallVector< Value > makeTiledShapes(OpBuilder &builder, Location loc, LinalgOp linalgOp, ValueRange valuesToTile, ArrayRef< OpFoldResult > ivs, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > sizeBounds, bool omitPartialTileCheck)
Creates extract_slice/subview ops for all valuesToTile of the given linalgOp with builder,...
Definition Utils.cpp:2850
void registerTilingInterfaceExternalModelsForPackUnPackOps(DialectRegistry &registry)
Similar to the above registration, but it is only for tensor.pack and tensor.unpack ops.
static void registerOne(MLIRContext *ctx)
static void registerAll(MLIRContext *ctx)
Variadic helper function.
void offsetIndices(OpBuilder &b, LinalgOp linalgOp, ArrayRef< OpFoldResult > offests)
Add the specified offsets to any linalg.index ops contained in the given linalgOp.
Definition Utils.cpp:2872
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition LinalgOps.cpp:97
void registerTilingInterfaceExternalModels(DialectRegistry &registry)
SmallVector< Type > getTensorOutputTypes(LinalgOp op, ValueRange operands)
Returns the list of tensor output types produced when the given structured operation op is applied to...
Definition Utils.cpp:2761
SliceParameters computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile, ArrayRef< OpFoldResult > tileSizes, AffineMap map, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > ubs, ArrayRef< OpFoldResult > subShapeSizes, bool omitPartialTileCheck)
Computes SliceParameters for a single valueToTile assuming that its user is being tiled with the give...
Definition Utils.cpp:2614
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:68
Include the generated interface declarations.
ReductionTilingStrategy
Tiling can be thought of as splitting a dimension into 2 and materializing the outer dimension as a l...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:123
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition AffineExpr.h:325
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling fo imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1306
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Helper struct to build simple arithmetic quantities with minimal type inference support.
Definition Utils.h:103
Container for result values of tiling.
Helper struct to build simple AffineValueExprs with minimal type inference support.
Definition Utils.h:377
A struct containg offsets-sizes-strides arguments of the tiled shape.
Definition Utils.h:172
SmallVector< OpFoldResult > sizes
Definition Utils.h:174
SmallVector< OpFoldResult > offsets
Definition Utils.h:173
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.