MLIR 23.0.0git
TilingInterfaceImpl.cpp
Go to the documentation of this file.
1//===- TilingInterfaceImpl.cpp - Implementation of TilingInterface -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
27#include "llvm/ADT/SmallVectorExtras.h"
28#include "llvm/Support/Debug.h"
29#include <optional>
30
31#define DEBUG_TYPE "linalg-tiling-interface-impl"
32
33using namespace mlir;
34using namespace mlir::linalg;
35
36//===----------------------------------------------------------------------===//
37// Utility methods for implementation of Tiling Interface for Linalg ops
38//===----------------------------------------------------------------------===//
39
40/// Return the SSA values that represent the data point accessed using a given
41/// `indexingMap` for a given point in the iteration space represented by `ivs`.
43 AffineMap indexingMap,
44 ValueRange ivs) {
46 indices.reserve(indexingMap.getNumResults());
47 for (auto result : indexingMap.getResults()) {
48 AffineMap m = AffineMap::get(indexingMap.getNumDims(),
49 indexingMap.getNumSymbols(), result);
50 Value v = affine::AffineApplyOp::create(b, loc, m, ivs);
51 indices.push_back(v);
52 }
53 return indices;
54}
55
56/// Method to inline the payload of a `linalgOp` given the iteration space
57/// point and values for the arguments of the payload.
58static LogicalResult inlinePayload(OpBuilder &b, LinalgOp linalgOp,
59 ValueRange ivs, ValueRange argValues) {
60 Block *body = linalgOp.getBlock();
61 IRMapping map;
62 map.map(body->getArguments(), argValues);
63 for (auto &op : body->without_terminator()) {
64 if (auto indexOp = dyn_cast<IndexOp>(&op)) {
65 map.map(indexOp.getResult(), ivs[indexOp.getDim()]);
66 continue;
67 }
68 b.clone(op, map);
69 }
70
71 Operation *terminator = body->getTerminator();
72 Location loc = terminator->getLoc();
73 for (const auto &operand : llvm::enumerate(terminator->getOperands())) {
74 Value toStore = map.lookupOrDefault(operand.value());
75 OpOperand *storeInto = linalgOp.getDpsInitOperand(operand.index());
77 b, loc, linalgOp.getMatchingIndexingMap(storeInto), ivs);
78 memref::StoreOp::create(b, loc, toStore,
79 linalgOp.getDpsInitOperand(operand.index())->get(),
80 indices);
81 }
82 return success();
83}
84
85//===----------------------------------------------------------------------===//
86// External Model for implementing `TilingInterface` for `LinalgOp`s.
87//===----------------------------------------------------------------------===//
88
89namespace {
90/// External model implementation of TilingInterface for LinalgOps. An external
91/// model implementation is used for now till the use of `TilingInterface` is
92/// on-par with the current Linalg tiling + fusion patterns. Once it is
93/// maybe possible to move this into the op-definition (though there are
94/// advantages to leaving it as an external model)
95template <typename LinalgOpTy>
96struct LinalgOpTilingInterface
97 : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>,
98 LinalgOpTy> {
99 /// Return the loop iterator type.
100 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
101 LinalgOpTy concreteOp = cast<LinalgOpTy>(op);
102 return concreteOp.getIteratorTypesArray();
103 }
104
105 /// Return the iteration domain range.
106 SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
107 OpBuilder::InsertionGuard g(b);
108 b.setInsertionPoint(op);
109 Location loc = op->getLoc();
110 LinalgOp linalgOp = cast<LinalgOp>(op);
111 SmallVector<OpFoldResult> allShapesSizes =
112 linalgOp.createFlatListOfOperandDims(b, loc);
113 AffineMap map = linalgOp.getShapesToLoopsMap();
114
115 return llvm::map_to_vector(map.getResults(), [&](AffineExpr loopExpr) {
116 OpFoldResult ofr = affine::makeComposedFoldedAffineApply(b, loc, loopExpr,
117 allShapesSizes);
118 return Range{b.getIndexAttr(0), ofr, b.getIndexAttr(1)};
119 });
120 }
121
122 /// Instantiate the tiled implementation of the operation.
123 FailureOr<TilingResult>
126 ArrayRef<OpFoldResult> sizes) const {
127 // Leave the `sizeBounds` value empty. That is only needed when the `sizes`
128 // specified could lead to out of bounds accesses.
129 Location loc = op->getLoc();
130 LinalgOp linalgOp = cast<LinalgOp>(op);
131 SmallVector<Value> valuesToTile = linalgOp->getOperands();
132 SmallVector<Value> tiledOperands = makeTiledShapes(
133 b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
134 SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
135 llvm::make_filter_range(
136 tiledOperands,
137 [](Value v) -> bool {
138 return isa_and_nonnull<tensor::ExtractSliceOp, memref::SubViewOp>(
139 v.getDefiningOp());
140 }),
141 [](Value v) -> Operation * { return v.getDefiningOp(); });
142
143 SmallVector<Type> resultTensorTypes =
144 getTensorOutputTypes(linalgOp, tiledOperands);
145
146 Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands);
147 offsetIndices(b, cast<LinalgOp>(tiledOp), offsets);
148
149 return TilingResult{
150 {tiledOp}, SmallVector<Value>(tiledOp->getResults()), generatedSlices};
151 }
152
153 /// Utility to fetch the offsets and sizes when applied as per the indexing
154 /// map of the linalg op. This helps in fusing the linalg op as a consumer of
155 /// a given slice op.
156 static LogicalResult
157 getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b,
158 ArrayRef<AffineMap> indexingMaps,
161 SmallVectorImpl<OpFoldResult> &mappedOffsetsVec,
162 SmallVectorImpl<OpFoldResult> &mappedSizesVec) {
163 DenseMap<unsigned, OpFoldResult> mappedOffsets, mappedSizes;
164
165 for (auto [indexingMap, offsets, sizes] :
166 llvm::zip_equal(indexingMaps, allOffsets, allSizes)) {
167 for (auto [resultExpr, offset, size] :
168 llvm::zip_equal(indexingMap.getResults(), offsets, sizes)) {
169 auto dimExpr = dyn_cast<AffineDimExpr>(resultExpr);
170 if (!dimExpr)
171 return failure();
172 unsigned position = dimExpr.getPosition();
173 auto it = mappedOffsets.find(position);
174 if (it != mappedOffsets.end()) {
175 OpFoldResult seenOffset = it->second;
176 OpFoldResult seenSize = mappedSizes.lookup(position);
177 if (seenOffset != offset || seenSize != size) {
178 LLVM_DEBUG({
179 llvm::dbgs() << "inconsistent iteration space mapping from "
180 "offsets/sizes of operands/results";
181 });
182 return failure();
183 }
184 } else {
185 mappedOffsets[position] = offset;
186 mappedSizes[position] = size;
187 }
188 }
189 }
190
191 // Aggregate from the given operand offsets and sizes, or default to
192 // iteration space values.
193 SmallVector<Range> iterationDomain =
194 cast<TilingInterface>(linalgOp.getOperation()).getIterationDomain(b);
195 mappedOffsetsVec.resize(iterationDomain.size());
196 mappedSizesVec.resize(iterationDomain.size());
197 for (auto [index, domain] : llvm::enumerate(iterationDomain)) {
198 auto it = mappedOffsets.find(index);
199 if (it != mappedOffsets.end()) {
200 mappedOffsetsVec[index] = it->second;
201 mappedSizesVec[index] = mappedSizes.lookup(index);
202 continue;
203 }
204 mappedOffsetsVec[index] = domain.offset;
205 mappedSizesVec[index] = domain.size;
206 }
207 return success();
208 }
209
210 /// Method to return the position of the result tile computed by the tiled
211 /// operation.
212 LogicalResult getIterationDomainTileFromOperandTiles(
213 Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
216 SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
217 SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
218 auto linalgOp = cast<LinalgOp>(op);
219
220 SmallVector<AffineMap> indexingMaps =
221 llvm::map_to_vector(operandNumbers, [&](unsigned operandNumber) {
222 OpOperand &opOperand = linalgOp->getOpOperand(operandNumber);
223 return linalgOp.getMatchingIndexingMap(&opOperand);
224 });
225 if (failed(getMappedOffsetAndSize(linalgOp, b, indexingMaps, allOffsets,
226 allSizes, iterDomainOffsets,
227 iterDomainSizes))) {
228 return failure();
229 }
230 return success();
231 }
232
233 /// Return the details of the output tile generated by the tiled
234 /// implementation.
235 LogicalResult
236 getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
239 SmallVector<OpFoldResult> &resultOffsets,
240 SmallVector<OpFoldResult> &resultSizes) const {
241 Location loc = op->getLoc();
242 LinalgOp linalgOp = cast<LinalgOp>(op);
243
244 AffineExpr d0;
245 bindDims(b.getContext(), d0);
246 SmallVector<OpFoldResult> subShapeSizes =
247 llvm::map_to_vector(sizes, [&](OpFoldResult ofr) {
248 return affine::makeComposedFoldedAffineApply(b, loc, d0 - 1, ofr);
249 });
250
251 OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
253 b, loc, outOperand->get(), sizes,
254 linalgOp.getMatchingIndexingMap(outOperand), offsets,
255 /*ubs*/ {}, subShapeSizes, true);
256 resultOffsets = sliceParams.offsets;
257 resultSizes = sliceParams.sizes;
258 return success();
259 }
260
261 LogicalResult getIterationDomainTileFromResultTile(
262 Operation *op, OpBuilder &b, unsigned resultNumber,
264 SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
265 SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
266 auto linalgOp = cast<LinalgOp>(op);
267
268 // Check that the indexing map used for the output is a projected
269 // permutation. This could be relaxed with a more general approach that can
270 // map the offsets and sizes from the result to iteration space tiles
271 // (filling in full extent for dimensions not used to access the result).
272 AffineMap indexingMap =
273 linalgOp.getIndexingMapMatchingResult(op->getResult(resultNumber));
274 if (!indexingMap.isProjectedPermutation()) {
275 return op->emitOpError(
276 "unhandled tiled implementation generation when result is not "
277 "accessed using a permuted projection");
278 }
279
280 SmallVector<OpFoldResult> allOffsets = llvm::to_vector(offsets);
281 SmallVector<OpFoldResult> allSizes = llvm::to_vector(sizes);
282 auto status =
283 getMappedOffsetAndSize(linalgOp, b, indexingMap, {allOffsets},
284 {allSizes}, iterDomainOffsets, iterDomainSizes);
285 (void)status;
286 assert(succeeded(status) && "unexpected error in offset calculation");
287 return success();
288 }
289
290 FailureOr<TilingResult>
291 generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
293 ArrayRef<OpFoldResult> sizes) const {
294 SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
295 if (failed(getIterationDomainTileFromResultTile(
296 op, b, resultNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
297 return failure();
298 }
299 auto tilingInterfaceOp = cast<TilingInterface>(op);
300 FailureOr<TilingResult> tilingResult =
301 tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
302
303 if (failed(tilingResult))
304 return failure();
305
306 if (tilingResult->tiledOps.size() != 1)
307 return op->emitOpError("failed to generate tiled implementation");
308
309 return TilingResult{
310 tilingResult->tiledOps,
311 SmallVector<Value>{tilingResult->tiledValues[resultNumber]},
312 tilingResult->generatedSlices};
313 }
314
315 /// Method to generate the tiled implementation of an operation from the tile
316 /// of the operand.
317 FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
318 Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
320 ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
321 SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
322 if (failed(getIterationDomainTileFromOperandTiles(
323 op, b, operandNumbers, allOffsets, allSizes, mappedOffsets,
324 mappedSizes))) {
325 return failure();
326 }
327 return getTiledImplementation(op, b, mappedOffsets, mappedSizes);
328 }
329
330 LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder,
331 Location loc,
332 ValueRange ivs) const {
333 auto linalgOp = cast<LinalgOp>(op);
334 if (!linalgOp.hasPureBufferSemantics())
335 return op->emitOpError("expected operation to have buffer semantics");
336
337 SmallVector<Value> indexedValues;
338 indexedValues.reserve(linalgOp->getNumOperands());
339 Location linalgOpLoc = op->getLoc();
340 /// Load the data corresponding to the block arguments that
341 /// represent input operands.
342 for (OpOperand &operand : linalgOp->getOpOperands()) {
343 if (!linalgOp.payloadUsesValueFromOperand(&operand)) {
344 indexedValues.push_back(nullptr);
345 continue;
346 }
347 if (linalgOp.isScalar(&operand)) {
348 indexedValues.push_back(operand.get());
349 continue;
350 }
352 builder, linalgOpLoc, linalgOp.getMatchingIndexingMap(&operand), ivs);
353 Value load =
354 memref::LoadOp::create(builder, linalgOpLoc, operand.get(), indices);
355 indexedValues.push_back(load);
356 }
357
358 /// Inline the op payload and store the result.
359 return inlinePayload(builder, linalgOp, ivs, indexedValues);
360 }
361
362 bool isOpFusableWithConsumerSlice(Operation *op, unsigned resultNumber,
364 ArrayRef<OpFoldResult> sizes) const {
365 // The verifier gives all the necessary requirements for consumer fusion.
366 return true;
367 }
368
369 bool isOpFusableWithProducerSlices(
370 Operation *op, ArrayRef<unsigned> operandNumbers,
372 ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
373
374 auto linalgOp = cast<LinalgOp>(op);
375 SmallVector<AffineMap> indexingMaps =
376 llvm::map_to_vector(operandNumbers, [&](unsigned operandNumber) {
377 OpOperand &opOperand = linalgOp->getOpOperand(operandNumber);
378 return linalgOp.getMatchingIndexingMap(&opOperand);
379 });
380 // Check that offsets/sizes are consistent across all operands.
381 OpBuilder b(op);
382 SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
383 return succeeded(getMappedOffsetAndSize(linalgOp, b, indexingMaps,
384 allOffsets, allSizes, mappedOffsets,
385 mappedSizes));
386 }
387};
388
389//===----------------------------------------------------------------------===//
390// External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
391//===----------------------------------------------------------------------===//
392
393/// In a given set vector, get the position of a particular element.
394std::optional<int> getPositionIn(const llvm::SetVector<unsigned> &reductionDims,
395 unsigned value) {
396 for (auto [index, reductionDim] : llvm::enumerate(reductionDims)) {
397 if (reductionDim == value) {
398 return index;
399 }
400 }
401 return std::nullopt;
402}
403
404/// Return an AffineMaps to use for the `outs` operands of the linalg op
405/// generated for partial results. The new AffineMap is the AffineMap of the
406/// untiled op with reduction dimensions appended at end in order in which they
407/// were specified during tiling.
409getPartialResultAffineMaps(LinalgOp linalgOp,
410 const SetVector<unsigned> &reductionDims) {
411 auto partialReductionMaps = llvm::map_to_vector(
412 linalgOp.getDpsInitsMutable(), [&](OpOperand &opOperand) {
413 AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
414 for (auto redPos : reductionDims) {
415 map =
416 map.insertResult(getAffineDimExpr(redPos, linalgOp.getContext()),
417 map.getNumResults());
418 }
419 return map;
420 });
421 return partialReductionMaps;
422}
423
424struct InitSliceInfo {
425 SmallVector<int64_t> resultShape;
426 SmallVector<OpFoldResult> offsets;
427 SmallVector<OpFoldResult> sizes;
428 SmallVector<OpFoldResult> strides;
429};
430
431/// Return the result shape, offsets, sizes and strides of the slice of the
432/// `initValue` to use as the destination of the partial reduction op generated
433/// with outer reduction strategy.
434static InitSliceInfo getInitSliceInfoForOuterReduction(
435 MLIRContext *context, ArrayRef<OpFoldResult> offsets,
436 ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
437 ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap,
438 ArrayRef<OpFoldResult> initOperandShape) {
439 int64_t initRank = partialReductionMap.getNumResults();
440 SmallVector<OpFoldResult> initOffsets, initSizes;
441 Attribute zero = IntegerAttr::get(IndexType::get(context), 0);
442 Attribute one = IntegerAttr::get(IndexType::get(context), 1);
443 SmallVector<OpFoldResult> initStrides(initRank, one);
444 for (auto [resultIdx, dimExpr] :
445 llvm::enumerate(partialReductionMap.getResults())) {
446 if (isa<AffineConstantExpr>(dimExpr)) {
447 // A constant index in the output map accesses a fixed position; keep
448 // the full output dimension to match the original output operand shape.
449 initOffsets.push_back(zero);
450 initSizes.push_back(initOperandShape[resultIdx]);
451 continue;
452 }
453 unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
454 if (reductionDims.contains(dim)) {
455 initOffsets.push_back(zero);
456 } else {
457 initOffsets.push_back(offsets[dim]);
458 }
459 initSizes.push_back(sizes[dim]);
460 }
461 SmallVector<int64_t> resultShape;
462 std::tie(resultShape, std::ignore) = decomposeMixedValues(initSizes);
463 return {resultShape, initOffsets, initSizes, initStrides};
464}
465
466/// Return the result shape, offsets, sizes and strides of the slice of the
467/// `initValue` to use as destination of the partial reduction op generated with
468/// outer parallel strategy.
469static InitSliceInfo getInitSliceInfoForOuterParallel(
470 MLIRContext *context, ArrayRef<OpFoldResult> offsets,
471 ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
472 ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap,
473 ArrayRef<OpFoldResult> initOperandShape) {
474 int64_t initRank = partialReductionMap.getNumResults();
475 SmallVector<OpFoldResult> initOffsets, initSizes;
476 Attribute zero = IntegerAttr::get(IndexType::get(context), 0);
477 Attribute one = IntegerAttr::get(IndexType::get(context), 1);
478 SmallVector<OpFoldResult> initStrides(initRank, one);
479 SmallVector<OpFoldResult> resultShape;
480 for (auto [resultIdx, dimExpr] :
481 llvm::enumerate(partialReductionMap.getResults())) {
482 if (isa<AffineConstantExpr>(dimExpr)) {
483 // A constant index accesses a fixed position; keep the full output
484 // dimension to match the original output operand shape.
485 initOffsets.push_back(zero);
486 initSizes.push_back(initOperandShape[resultIdx]);
487 resultShape.push_back(initOperandShape[resultIdx]);
488 continue;
489 }
490 unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
491 if (std::optional<unsigned> dimPos = getPositionIn(reductionDims, dim)) {
492 initOffsets.push_back(splitReductionIvs[dimPos.value()]);
493 initSizes.push_back(one);
494 } else {
495 initOffsets.push_back(offsets[dim]);
496 initSizes.push_back(sizes[dim]);
497 resultShape.push_back(sizes[dim]);
498 }
499 }
500 SmallVector<int64_t> staticShapes;
501 std::tie(staticShapes, std::ignore) = decomposeMixedValues(resultShape);
502 return {staticShapes, initOffsets, initSizes, initStrides};
503}
504
505/// Return the result shape, offsets, sizes and strides of the slice of the
506/// `initValue` to use as destination of the partial reduction op.
507static InitSliceInfo getInitSliceInfo(MLIRContext *context,
511 const SetVector<unsigned> &reductionDims,
512 ArrayRef<OpFoldResult> splitReductionIvs,
513 AffineMap partialReductionMap,
514 ArrayRef<OpFoldResult> initOperandShape) {
516 return getInitSliceInfoForOuterReduction(
517 context, offsets, sizes, reductionDims, splitReductionIvs,
518 partialReductionMap, initOperandShape);
519 }
521 "unexpected ReductionTilingStrategy");
522 return getInitSliceInfoForOuterParallel(
523 context, offsets, sizes, reductionDims, splitReductionIvs,
524 partialReductionMap, initOperandShape);
525}
526
527/// External model implementation of PartialReductionInterface for
528/// LinalgOps.
529template <typename LinalgOpTy>
530struct LinalgOpPartialReductionInterface
531 : public PartialReductionOpInterface::ExternalModel<
532 LinalgOpPartialReductionInterface<LinalgOpTy>, LinalgOpTy> {
533 FailureOr<SmallVector<Value>> generateInitialTensorForPartialReduction(
534 Operation *op, OpBuilder &b, Location loc, ArrayRef<OpFoldResult> sizes,
535 const SetVector<unsigned> &reductionDims) const {
536 auto linalgOp = cast<LinalgOp>(op);
537
538 OpBuilder::InsertionGuard guard(b);
539 if (linalgOp.hasPureBufferSemantics())
540 return op->emitOpError("expected operation to have tensor semantics");
541
542 SmallVector<AffineMap> partialResultMaps =
543 getPartialResultAffineMaps(linalgOp, reductionDims);
544
545 SmallVector<Value> inits;
546 for (auto [initIdx, result, partialMap] :
547 llvm::enumerate(linalgOp->getResults(), partialResultMaps)) {
548 SmallVector<Operation *, 4> combinerOps;
549 if (!matchReduction(linalgOp.getRegionOutputArgs(), initIdx,
550 combinerOps) ||
551 combinerOps.size() != 1)
552 return op->emitOpError("Failed to anaysis the reduction operation.");
553
554 Operation *reductionOp = combinerOps[0];
555 std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp);
556 if (!identity.has_value())
557 return op->emitOpError(
558 "Failed to get an identity value for the reduction operation.");
559
560 // Append the new partial result dimensions.
561 SmallVector<OpFoldResult> partialResultShape;
562 Value initValue = linalgOp.getDpsInits()[initIdx];
563 SmallVector<OpFoldResult> initShape =
564 tensor::getMixedSizes(b, loc, initValue);
565 for (auto [resultIdx, dimExpr] :
566 llvm::enumerate(partialMap.getResults())) {
567 if (isa<AffineConstantExpr>(dimExpr)) {
568 // A constant index in the output map accesses a fixed position; use
569 // the actual output dimension size (not a hardcoded 1).
570 partialResultShape.push_back(initShape[resultIdx]);
571 continue;
572 }
573 auto dim = cast<AffineDimExpr>(dimExpr);
574 partialResultShape.push_back(sizes[dim.getPosition()]);
575 }
576
577 Type elType = getElementTypeOrSelf(result.getType());
578 Value emptyTensor =
579 tensor::EmptyOp::create(b, loc, partialResultShape, elType);
580 Value constantOp = arith::ConstantOp::create(b, loc, *identity);
581 auto identityTensor =
582 linalg::FillOp::create(b, loc, constantOp, emptyTensor);
583 inits.push_back(identityTensor.getResult(0));
584 }
585
586 return inits;
587 }
588
589 FailureOr<TilingResult>
590 tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
591 ReductionTilingStrategy tilingStrategy,
592 ValueRange init, ArrayRef<OpFoldResult> offsets,
593 ArrayRef<OpFoldResult> sizes,
594 const SetVector<unsigned> &reductionDims,
595 ArrayRef<OpFoldResult> splitReductionIvs) const {
596 OpBuilder::InsertionGuard guard(b);
597 auto linalgOp = cast<LinalgOp>(op);
598
599 SmallVector<AffineMap> partialReductionMaps =
600 getPartialResultAffineMaps(linalgOp, reductionDims);
601
602 // Step 1. Extend init maps to have reduction dimension dims, since we
603 // are converting them to parallel dimensions.
604 SmallVector<AffineMap> newInitMaps;
605 if (tilingStrategy ==
606 ReductionTilingStrategy::PartialReductionOuterReduction) {
607 newInitMaps = llvm::to_vector(partialReductionMaps);
608 } else {
609 newInitMaps = llvm::map_to_vector(
610 linalgOp.getDpsInitsMutable(), [&](OpOperand &opOperand) {
611 return linalgOp.getMatchingIndexingMap(&opOperand);
612 });
613 }
614
615 // Step 2a: Extract a slice of the input operands.
616 SmallVector<Value> tiledInputs = makeTiledShapes(
617 b, loc, linalgOp, linalgOp.getDpsInputs(), offsets, sizes, {}, true);
618 SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
619 llvm::make_filter_range(
620 tiledInputs, [](Value v) -> bool { return v.getDefiningOp(); }),
621 [](Value v) -> Operation * { return v.getDefiningOp(); });
622
623 // Step 2b: Extract a slice of the init operands.
624 SmallVector<Value, 1> tiledInits;
625 for (auto [partialReductionMap, valueToTile, initOperandValue] :
626 llvm::zip_equal(partialReductionMaps, init, linalgOp.getDpsInits())) {
627 // Compute the actual shape of the original init operand for handling
628 // constant expressions in the partial reduction map.
629 SmallVector<OpFoldResult> initOperandShape =
630 tensor::getMixedSizes(b, loc, initOperandValue);
631 InitSliceInfo sliceInfo = getInitSliceInfo(
632 b.getContext(), tilingStrategy, offsets, sizes, reductionDims,
633 splitReductionIvs, partialReductionMap, initOperandShape);
634 auto valueToTileType = cast<RankedTensorType>(valueToTile.getType());
635 RankedTensorType sliceResultType = RankedTensorType::get(
636 sliceInfo.resultShape, valueToTileType.getElementType(),
637 valueToTileType.getEncoding());
638 auto sliceOp = tensor::ExtractSliceOp::create(
639 b, loc, sliceResultType, valueToTile, sliceInfo.offsets,
640 sliceInfo.sizes, sliceInfo.strides);
641 tiledInits.push_back(sliceOp.getResult());
642 generatedSlices.push_back(sliceOp);
643 }
644
645 // Update the indexing maps.
646 SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
647 for (auto [initOperand, newInitMap] :
648 llvm::zip_equal(linalgOp.getDpsInitsMutable(), newInitMaps)) {
649 int mapIdx = linalgOp.getIndexingMapIndex(&initOperand);
650 newMaps[mapIdx] = newInitMap;
651 }
652
653 // Step 3. Change the reduction dim iterator types.
654 SmallVector<utils::IteratorType> newIteratorTypes =
655 linalgOp.getIteratorTypesArray();
656 if (tilingStrategy ==
657 ReductionTilingStrategy::PartialReductionOuterReduction) {
658 for (int dim : reductionDims)
659 newIteratorTypes[dim] = utils::IteratorType::parallel;
660 }
661
662 // Step 4. Create the new generic op.
663 Operation *partialReductionOp;
664 auto resultTypes = ValueRange(tiledInits).getTypes();
665 if (tilingStrategy ==
666 ReductionTilingStrategy::PartialReductionOuterReduction) {
667 auto genericOp = GenericOp::create(b, loc, resultTypes, tiledInputs,
668 tiledInits, newMaps, newIteratorTypes);
669 IRMapping mapping;
670 op->getRegion(0).cloneInto(&genericOp.getRegion(),
671 genericOp.getRegion().begin(), mapping);
672 offsetIndices(b, genericOp, offsets);
673 partialReductionOp = genericOp.getOperation();
674 } else {
675 SmallVector<Value> operands = std::move(tiledInputs);
676 llvm::append_range(operands, tiledInits);
677 partialReductionOp = mlir::clone(b, op, resultTypes, operands);
678 offsetIndices(b, cast<LinalgOp>(partialReductionOp), offsets);
679 }
680 return TilingResult{
681 {partialReductionOp},
682 llvm::map_to_vector(partialReductionOp->getResults(),
683 [](OpResult r) -> Value { return r; }),
684 generatedSlices};
685 }
686
687 FailureOr<MergeResult>
688 mergeReductions(Operation *op, OpBuilder &b, Location loc,
689 ValueRange partialReduce,
690 const SetVector<unsigned> &reductionDims) const {
691 auto linalgOp = cast<LinalgOp>(op);
692 SmallVector<AffineMap> partialReductionMaps =
693 getPartialResultAffineMaps(linalgOp, reductionDims);
694
695 // Permute the reduction dims as permuted by the partial result map.
696 SmallVector<Operation *> mergeOperations;
697 SmallVector<Value> replacements;
698 for (auto [idx, init, partialResult, partialMap] : llvm::enumerate(
699 linalgOp.getDpsInits(), partialReduce, partialReductionMaps)) {
700 unsigned initIdx = idx;
701 // linalg.reduce's iteration space is the tiled result's iteration space
702 // (and not the tiled operation's iteration space). To account for this,
703 // permute the reduction dimensions based on the partial result map of the
704 // tiled result.
705 SmallVector<int64_t> partialReductionDims;
706 for (auto [resultNum, dimExpr] :
707 llvm::enumerate(partialMap.getResults())) {
708 if (isa<AffineConstantExpr>(dimExpr))
709 continue; // Constant dims are never reduction dims.
710 unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
711 if (llvm::is_contained(reductionDims, dim)) {
712 partialReductionDims.push_back(resultNum);
713 }
714 }
715
716 auto reduction = linalg::ReduceOp::create(
717 b, loc, partialResult, init, partialReductionDims,
718 [&linalgOp, &initIdx](OpBuilder &b, Location loc, ValueRange inputs) {
719 // Get the combiner op.
720 SmallVector<Operation *, 4> combinerOps;
721 matchReduction(linalgOp.getRegionOutputArgs(), initIdx,
722 combinerOps);
723 Operation *clonedReductionOp = b.clone(*combinerOps[0]);
724 // Combine the input at idx and output at numInits + idx.
725 clonedReductionOp->setOperand(0, inputs[0]);
726 clonedReductionOp->setOperand(1, inputs[1]);
727 linalg::YieldOp::create(b, loc, clonedReductionOp->getResult(0));
728 });
729
730 mergeOperations.push_back(reduction);
731 replacements.push_back(reduction->getResult(0));
732 }
733
734 return MergeResult{mergeOperations, replacements};
735 }
736
737 LogicalResult getPartialResultTilePosition(
738 Operation *op, OpBuilder &b, unsigned resultNumber,
739 ReductionTilingStrategy tilingStrategy, ArrayRef<OpFoldResult> offsets,
740 ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
741 ArrayRef<OpFoldResult> splitReductionIvs,
742 SmallVector<OpFoldResult> &resultOffsets,
743 SmallVector<OpFoldResult> &resultSizes) const {
744 auto linalgOp = cast<LinalgOp>(op);
745 SmallVector<AffineMap> partialReductionMaps =
746 getPartialResultAffineMaps(linalgOp, reductionDims);
747 // Compute the actual shape of the init operand for handling constant
748 // expressions in the partial reduction map.
749 Value initOperandValue = linalgOp.getDpsInits()[resultNumber];
750 Location loc = op->getLoc();
751 SmallVector<OpFoldResult> initOperandShape =
752 tensor::getMixedSizes(b, loc, initOperandValue);
753 InitSliceInfo sliceInfo =
754 getInitSliceInfo(b.getContext(), tilingStrategy, offsets, sizes,
755 reductionDims, splitReductionIvs,
756 partialReductionMaps[resultNumber], initOperandShape);
757 std::swap(resultOffsets, sliceInfo.offsets);
758 std::swap(resultSizes, sliceInfo.sizes);
759
760 return success();
761 }
762};
763
764template <typename OpTy>
765static SmallVector<Range> getPackUnPackIterationDomain(OpTy op,
766 OpBuilder &builder) {
767 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
768 "applies to only pack or unpack operations");
769 OpBuilder::InsertionGuard g(builder);
770 int64_t rank = (std::is_same<OpTy, PackOp>::value) ? op.getSourceRank()
771 : op.getDestRank();
772 OpFoldResult zero = builder.getIndexAttr(0);
773 OpFoldResult one = builder.getIndexAttr(1);
774 ReifiedRankedShapedTypeDims resultShape;
775 (void)op.reifyResultShapes(builder, resultShape);
776 SmallVector<Range> loopBounds(rank);
777 for (auto dim : llvm::seq<int64_t>(0, rank)) {
778 loopBounds[dim].offset = zero;
779 loopBounds[dim].stride = one;
780 loopBounds[dim].size = resultShape[0][dim];
781 }
782 return loopBounds;
783}
784
785static void applyPermToRange(SmallVector<OpFoldResult> &offsets,
787 ArrayRef<int64_t> permutation) {
788 if (permutation.empty())
789 return;
790 applyPermutationToVector<OpFoldResult>(offsets, permutation);
791 applyPermutationToVector<OpFoldResult>(sizes, permutation);
792}
793
794/// Compute the permutation vector to interchange `elements` such that the
795/// elements at positions in `dimsPos` are moved to the positions `[0, ...,
796/// dimsPos.size())` in order.
798computeInterchangeFromDimPos(ArrayRef<int64_t> dimsPos, int64_t rank) {
799 SmallVector<int64_t> interchangeVector;
800 interchangeVector.reserve(dimsPos.size());
801 // First map dims and their position. For example, dims_pos = [2, 0] will map
802 // to:
803 // [
804 // [ key: 2, value: 0]
805 // [ key: 0, value: 1]
806 // ]
807 // where key is the idx in dims_pos while value its position in dims_pos.
808 DenseMap<int64_t, int64_t> dimsAndPosMapping;
809 for (int64_t dimsIdx = 0, end = dimsPos.size(); dimsIdx < end; dimsIdx++)
810 dimsAndPosMapping[dimsPos[dimsIdx]] = dimsIdx;
811
812 // Scan the position in order and insert the value in the map
813 // to compute the interchange vector.
814 for (int64_t dimsIdx = 0; dimsIdx < rank; dimsIdx++) {
815 if (dimsAndPosMapping.count(dimsIdx))
816 interchangeVector.push_back(dimsAndPosMapping[dimsIdx]);
817 }
818 return interchangeVector;
819}
820
821/// Permute the elements of `vec` starting at position `offset` according to
822/// `interchangeVector`. The permutation maps position `i` in the permuted range
823/// to position `interchangeVector[i]` in the original range. Elements before
824/// `offset` are unchanged.
825///
826/// Example: interchange([a, b, c, d, e], [2, 0, 1], offset=2)
827/// returns [a, b, e, c, d] (permutes the suffix [c, d, e])
828///
829/// Note: This is similar to `applyPermutationToVector` but supports an offset
830/// for permuting a suffix of the vector. It is only used for pack/unpack scalar
831/// implementation where we need to permute inner tile dimensions which are
832/// stored at the end of the index vector.
833template <typename T>
834static SmallVector<T> interchange(ArrayRef<T> elements,
835 ArrayRef<int64_t> interchangeVector,
836 int offset = 0) {
837 SmallVector<T> vec = llvm::to_vector(elements);
838 for (auto [idx, val] : llvm::enumerate(interchangeVector))
839 vec[idx + offset] = elements[val + offset];
840 return vec;
841}
842
843/// Generate the body of the innermost loop of the scalar implementation
844/// of `pack` operation.
845static void generatePackOpScalarImplementationBody(PackOp packOp,
846 OpBuilder &builder,
847 Location loc,
848 ValueRange ivs) {
849 // Note: `ivs` are already in the correct order, possibly interchanged based
850 // on `dims_pos`. However, connecting the loops with the access patterns is
851 // difficult - What is the relation between the position of the tile loop and
852 // the point loop? However, if we interchange `ivs` once more to go to the
853 // canonical blocking format: ABCabc, this connection becomes trivial: Each
854 // point loop is pointLoopsOffset + inputRank away from the tiled loop.
855 ArrayRef<int64_t> dimsToInnerBlock = packOp.getInnerDimsPos();
856 ArrayRef<int64_t> dimsToOuterBlock = packOp.getOuterDimsPerm();
857
858 SmallVector<Value> interchangedIvs = ivs;
859 SmallVector<int64_t> interchangeVector =
860 computeInterchangeFromDimPos(dimsToInnerBlock, packOp.getSourceRank());
861 interchangedIvs = interchange<Value>(interchangedIvs, interchangeVector,
862 /*offset=*/packOp.getSourceRank());
863 if (!dimsToOuterBlock.empty()) {
864 interchangeVector =
865 computeInterchangeFromDimPos(dimsToOuterBlock, packOp.getSourceRank());
866 interchangedIvs =
867 interchange<Value>(interchangedIvs, interchangeVector, /*offset=*/0);
868 }
869 DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
870 packOp.getDimAndTileMapping();
871 SmallVector<OpFoldResult> sourceIndices;
872 size_t pointLoopsOffset = 0;
873 int64_t sourceRank = packOp.getSourceRank();
874 for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
875 if (dimAndTileMapping.contains(dim)) {
876 AffineExpr i, j, tile;
877 bindDims(builder.getContext(), i, j);
878 bindSymbols(builder.getContext(), tile);
880 builder, loc, i * tile + j,
882 interchangedIvs[dim],
883 interchangedIvs[pointLoopsOffset + packOp.getSourceRank()],
884 dimAndTileMapping[dim]});
885 sourceIndices.push_back(sourceIndex);
886 ++pointLoopsOffset;
887 } else {
888 sourceIndices.push_back(interchangedIvs[dim]);
889 }
890 }
891
892 auto createLoad = [&]() -> Value {
893 return memref::LoadOp::create(
894 builder, loc, packOp.getSource(),
895 getValueOrCreateConstantIndexOp(builder, loc, sourceIndices));
896 };
897 Value scalar;
898 if (auto paddingValue = packOp.getPaddingValue()) {
899 ArithBuilder arithBuilder(builder, loc);
901 for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
902 Value idx =
903 getValueOrCreateConstantIndexOp(builder, loc, sourceIndices[dim]);
904 Value cond = arithBuilder.slt(
905 idx, createOrFoldDimOp(builder, loc, packOp.getSource(), dim));
906 isInBounds = dim == 0 ? cond : arithBuilder._and(isInBounds, cond);
907 }
908 scalar = scf::IfOp::create(
909 builder, loc, isInBounds, /*thenBuilder=*/
910 [&](OpBuilder &b, Location l) {
911 scf::YieldOp::create(b, l, createLoad());
912 },
913 /*elseBuilder=*/
914 [&](OpBuilder &b, Location l) {
915 scf::YieldOp::create(b, l, paddingValue);
916 })
917 .getResult(0);
918 } else {
919 scalar = createLoad();
920 }
921
922 memref::StoreOp::create(builder, loc, scalar, packOp.getDest(), ivs);
923}
924
925struct PackOpTiling
926 : public TilingInterface::ExternalModel<PackOpTiling, linalg::PackOp> {
927
928 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
929 // Note that here we only consider untiled dimensions and outer tiled data
930 // dimensions, the inner tiled data dimensions are materialized when
931 // building the body of the operation.
932 auto packOp = cast<PackOp>(op);
933 SmallVector<utils::IteratorType> iteratorTypes(
934 packOp.getSourceRank(), utils::IteratorType::parallel);
935 return iteratorTypes;
936 }
937
938 SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
939 return getPackUnPackIterationDomain<PackOp>(cast<PackOp>(op), b);
940 }
941
942 FailureOr<TilingResult>
943 getTiledImplementation(Operation *op, OpBuilder &b,
944 ArrayRef<OpFoldResult> offsets,
945 ArrayRef<OpFoldResult> sizes) const {
946 auto packOp = cast<PackOp>(op);
947 // TODO: Support Memref PackOp. Temporarily return failure.
948 if (!packOp.hasPureTensorSemantics())
949 return failure();
950
951 Location loc = packOp.getLoc();
952
953 // The tiling is applied on interchanged dimensions. We have to undo the
954 // interchange to map sizes and offsets to the original input.
955 int64_t inputRank = packOp.getSourceRank();
956 SmallVector<OpFoldResult> origOffsets(offsets);
957 SmallVector<OpFoldResult> origSizes(sizes);
958 applyPermToRange(origOffsets, origSizes,
959 invertPermutationVector(packOp.getOuterDimsPerm()));
960
961 DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
962 packOp.getDimAndTileMapping();
963 SmallVector<OpFoldResult> srcDimValues =
964 tensor::getMixedSizes(b, loc, packOp.getSource());
965 SmallVector<OpFoldResult> inputIndices, inputSizes;
966 for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
967 using AV = affine::AffineValueExpr;
968 affine::AffineBuilder ab(b, loc);
969 AffineExpr dim0, dim1, sym;
970 bindDims(b.getContext(), dim0, dim1);
971 bindSymbols(b.getContext(), sym);
972 if (dimAndTileMapping.count(dim)) {
973 // If the data dimension is tiled, the i-th index is the product of
974 // offset_i and tile_i, and the i-th size is the product of sizes_i and
975 // tile_i.
976 auto avOffset = AV(dim0).bind(origOffsets[dim]);
977 auto avSize = AV(dim0).bind(origSizes[dim]);
978 auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
979 inputIndices.push_back(ab.mul(avOffset, avTileSize));
980 inputSizes.push_back(ab.mul(avSize, avTileSize));
981 } else {
982 inputIndices.push_back(origOffsets[dim]);
983 inputSizes.push_back(origSizes[dim]);
984 }
985
986 // Limit the size of the input operand for incomplete tiles.
987 if (packOp.getPaddingValue()) {
988 OpFoldResult dimSize = srcDimValues[dim];
989 auto avDimSize = AV(dim0).bind(dimSize);
990 auto avInputIdx = AV(dim1).bind(inputIndices.back());
991 inputSizes.back() =
992 ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)});
993 }
994 }
995
996 auto oneAttr = b.getI64IntegerAttr(1);
997 SmallVector<OpFoldResult> strides(inputRank, oneAttr);
998
999 SmallVector<Value> tiledOperands;
1000 auto sourceSlice = tensor::ExtractSliceOp::create(
1001 b, loc, packOp.getSource(), inputIndices, inputSizes, strides);
1002 tiledOperands.push_back(sourceSlice);
1003
1004 SmallVector<OpFoldResult> outputOffsets, outputSizes;
1005 if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets,
1006 outputSizes)))
1007 return {};
1008
1009 strides.append(packOp.getDestRank() - inputRank, oneAttr);
1010 auto outSlice = tensor::ExtractSliceOp::create(
1011 b, loc, packOp.getDest(), outputOffsets, outputSizes, strides);
1012 tiledOperands.push_back(outSlice);
1013
1014 if (auto val = packOp.getPaddingValue())
1015 tiledOperands.push_back(val);
1016 for (auto tile : packOp.getInnerTiles())
1017 tiledOperands.push_back(tile);
1018
1019 Operation *tiledPackOp = PackOp::create(
1020 b, loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());
1021
1022 return TilingResult{
1023 {tiledPackOp},
1024 SmallVector<Value>(tiledPackOp->getResults()),
1025 llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
1026 }
1027
1028 LogicalResult
1029 getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
1030 ArrayRef<OpFoldResult> offsets,
1031 ArrayRef<OpFoldResult> sizes,
1032 SmallVector<OpFoldResult> &resultOffsets,
1033 SmallVector<OpFoldResult> &resultSizes) const {
1034 // The iteration domain is over outer dimensions of packed layout. In this
1035 // context, the outer dimensions of `resultOffsets` are `offsets`. The
1036 // inner dimensions of `resultOffsets` are zeros because tiling is not
1037 // applied to them.
1038 auto packOp = cast<PackOp>(op);
1039 int64_t inputRank = packOp.getSourceRank();
1040 int64_t outputRank = packOp.getDestRank();
1041 auto zeroAttr = b.getI64IntegerAttr(0);
1042 resultOffsets.assign(offsets.begin(), offsets.end());
1043 resultOffsets.append(outputRank - inputRank, zeroAttr);
1044
1045 ReifiedRankedShapedTypeDims outputShape;
1046 (void)reifyResultShapes(b, packOp, outputShape);
1047 resultSizes.assign(sizes.begin(), sizes.end());
1048 for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
1049 resultSizes.push_back(outputShape[0][dataTileDim]);
1050
1051 return success();
1052 }
1053
1054 FailureOr<TilingResult>
1055 generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
1056 ArrayRef<OpFoldResult> offsets,
1057 ArrayRef<OpFoldResult> sizes) const {
1058 auto packOp = cast<PackOp>(op);
1059 int64_t numTiles = packOp.getInnerDimsPos().size();
1060
1061 // tensor.pack op is fusible (as a producer) only if full inner tiles are
1062 // iterated or inner dims are not tiled. Otherwise, it will generate a
1063 // sequence of non-trivial ops (for partial tiles).
1064 for (auto offset : offsets.take_back(numTiles))
1065 if (!isZeroInteger(offset))
1066 return failure();
1067
1068 for (auto iter :
1069 llvm::zip_equal(packOp.getMixedTiles(), sizes.take_back(numTiles)))
1070 if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter)))
1071 return failure();
1072
1073 FailureOr<TilingResult> tilingResult = getTiledImplementation(
1074 op, b, offsets.drop_back(numTiles), sizes.drop_back(numTiles));
1075 if (failed(tilingResult))
1076 return failure();
1077 return tilingResult.value();
1078 }
1079
1080 LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder,
1081 Location loc,
1082 ValueRange ivs) const {
1083 auto packOp = cast<PackOp>(op);
1084 assert(packOp.hasPureBufferSemantics() &&
1085 "expected operation to have buffer semantics");
1086 OpBuilder::InsertionGuard g(builder);
1087 // The `ivs` already represent the position into the output for the non
1088 // data-tile dimensions.
1089 SmallVector<Value> ivVec(ivs);
1090
1091 // Get output shape - for memrefs, get dimensions from dest directly.
1092 SmallVector<OpFoldResult> outputShape;
1093 Value dest = packOp.getDest();
1094 for (auto dim : llvm::seq<int64_t>(0, packOp.getDestRank()))
1095 outputShape.push_back(createOrFoldDimOp(builder, loc, dest, dim));
1096
1097 // Generate the loops that iterate over the data tile.
1098 Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
1099 Value one = arith::ConstantIndexOp::create(builder, loc, 1);
1100
1101 // All loops except the innermost are simple loops that just iterate
1102 // over the tile dimensions.
1103 for (auto dataTileDim : llvm::seq<unsigned>(packOp.getSourceRank(),
1104 packOp.getDestRank() - 1)) {
1105 Value ub = getValueOrCreateConstantIndexOp(builder, loc,
1106 outputShape[dataTileDim]);
1107 scf::ForOp loop = scf::ForOp::create(builder, loc, zero, ub, one);
1108 builder.setInsertionPointToStart(loop.getBody());
1109 ivVec.push_back(loop.getInductionVar());
1110 }
1111 // The body of the innermost loops does the actual data movement.
1112 scf::ForOp::create(
1113 builder, loc, zero,
1114 getValueOrCreateConstantIndexOp(builder, loc, outputShape.back()), one,
1115 ValueRange{},
1116 [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
1117 ValueRange regionIterArgs) {
1118 ivVec.push_back(iv);
1119 generatePackOpScalarImplementationBody(packOp, bodyBuilder, bodyLoc,
1120 ivVec);
1121 scf::YieldOp::create(bodyBuilder, bodyLoc);
1122 });
1123 return success();
1124 }
1125
1126 /// Method to return the position of iteration domain tile computed by the
1127 /// tiled operation. In current `tensor.pack` context, the `resultOffsets` and
1128 /// `resultSizes` only cover outer dimensions.
1129 LogicalResult getIterationDomainTileFromOperandTiles(
1130 Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
1131 ArrayRef<SmallVector<OpFoldResult>> allOffsets,
1132 ArrayRef<SmallVector<OpFoldResult>> allSizes,
1133 SmallVectorImpl<OpFoldResult> &resultOffsets,
1134 SmallVectorImpl<OpFoldResult> &resultSizes) const {
1135 if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
1136 LLVM_DEBUG(
1137 { llvm::dbgs() << "unsupported operands for consumer fusion"; });
1138 return failure();
1139 }
1140
1141 ArrayRef<OpFoldResult> offsets(allOffsets[0]);
1142 ArrayRef<OpFoldResult> sizes(allSizes[0]);
1143 auto packOp = cast<PackOp>(op);
1144 Location loc = packOp.getLoc();
1145 SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
1146 DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
1147 packOp.getDimAndTileMapping();
1148 SmallVector<int64_t> outerShapeWithoutTranspose(
1149 packOp.getDestType().getShape().take_front(packOp.getSourceRank()));
1150 if (!packOp.getOuterDimsPerm().empty()) {
1152 outerShapeWithoutTranspose,
1153 invertPermutationVector(packOp.getOuterDimsPerm()));
1154 }
1155 for (auto dim : llvm::seq<int64_t>(packOp.getSourceRank())) {
1156 if (dimAndTileMapping.count(dim)) {
1157 FailureOr<int64_t> cstTileSize =
1159 presburger::BoundType::UB, sizes[dim],
1160 /*stopCondition=*/nullptr, /*closedUB=*/true);
1161 std::optional<int64_t> cstInnerSize =
1162 getConstantIntValue(dimAndTileMapping[dim]);
1163
1164 // If a dimension is not tiled, it is always valid to fuse the pack op,
1165 // even if the op has padding semantics. Because it always generates a
1166 // full slice along the dimension. The tile sizes are for unpacked
1167 // domain, i.e., `srcDimSize`, so `tileSize < srcDimSize` means that the
1168 // dimension is tiled.
1169 // TODO: It could be untiled if the `srcDimSize` is dynamic. It is a
1170 // hard check to determine if a dimension is tiled or not.
1171 int64_t srcDimSize = packOp.getSourceType().getDimSize(dim);
1172 int64_t destDimSize = outerShapeWithoutTranspose[dim];
1173 bool isTiled = failed(cstTileSize) ||
1174 ShapedType::isDynamic(srcDimSize) ||
1175 cstTileSize.value() < srcDimSize;
1176 if (!isTiled) {
1177 outerDimOffsets.push_back(offsets[dim]);
1178 if (ShapedType::isStatic(destDimSize)) {
1179 outerDimSizes.push_back(b.getIndexAttr(destDimSize));
1180 } else {
1181 outerDimSizes.push_back(
1182 b.createOrFold<tensor::DimOp>(loc, packOp.getDest(), dim));
1183 }
1184 continue;
1185 }
1186
1187 // Currently fusing `packOp` as consumer only expects perfect tiling
1188 // scenario because even if without padding semantic, the `packOp` may
1189 // also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>,
1190 // where the `tileSize` from operand of `packOp` is 5, which is not
1191 // exactly divided by `innerTile`(=6) of `packOp`. As the result:
1192 // 1. the first slice is extracted from (0) to (4) and inserted into
1193 // (0,0)~(0,4) at first row.
1194 // 2. the second slice is extracted from (5) to (9) and SHOULD BE
1195 // respectively inserted into two rows with different length, including
1196 // first row: (0,5) and second row (1,0)~(1,3). It is hard to coordinate
1197 // them, thus adding below constraint to bypass them temporarily. In
1198 // another word, we can only support tiling with consumer if the tile
1199 // size for the producer is a multiple of the inner tile size for the
1200 // packed dimensions at this moment.
1201 if ((failed(cstTileSize) || !cstInnerSize ||
1202 *cstTileSize % *cstInnerSize != 0))
1203 return failure();
1204
1205 using AV = affine::AffineValueExpr;
1206 affine::AffineBuilder ab(b, loc);
1207 AffineExpr dim0, sym;
1208 bindDims(b.getContext(), dim0);
1209 bindSymbols(b.getContext(), sym);
1210 auto avOffset = AV(dim0).bind(offsets[dim]);
1211 auto avSize = AV(dim0).bind(sizes[dim]);
1212 auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
1213 outerDimOffsets.push_back(ab.floor(avOffset, avTileSize));
1214 outerDimSizes.push_back(ab.ceil(avSize, avTileSize));
1215 } else {
1216 outerDimOffsets.push_back(offsets[dim]);
1217 outerDimSizes.push_back(sizes[dim]);
1218 }
1219 }
1220 applyPermToRange(outerDimOffsets, outerDimSizes, packOp.getOuterDimsPerm());
1221 resultOffsets = outerDimOffsets;
1222 resultSizes = outerDimSizes;
1223 return success();
1224 }
1225
1226 /// Method to return the tiled implementation of tensor.pack as a consumer.
1227 FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
1228 Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
1229 ArrayRef<SmallVector<OpFoldResult>> allOffsets,
1230 ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
1231 if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
1232 LLVM_DEBUG(
1233 { llvm ::dbgs() << "unhandled operands for consumer fusion"; });
1234 return failure();
1235 }
1236
1237 ArrayRef<OpFoldResult> offsets(allOffsets[0]);
1238 ArrayRef<OpFoldResult> sizes(allSizes[0]);
1239
1240 auto packOp = cast<PackOp>(op);
1241 // TODO: Support Memref UnPackOp. Temporarily return failure.
1242 if (!packOp.hasPureTensorSemantics())
1243 return failure();
1244
1245 Location loc = packOp.getLoc();
1246
1247 int64_t inputRank = packOp.getSourceRank();
1248 auto oneAttr = b.getI64IntegerAttr(1);
1249 SmallVector<OpFoldResult> strides(inputRank, oneAttr);
1250
1251 SmallVector<Value> tiledOperands;
1252 auto sourceSlice = tensor::ExtractSliceOp::create(
1253 b, loc, packOp.getSource(), offsets, sizes, strides);
1254 tiledOperands.push_back(sourceSlice);
1255
1256 SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
1257 if (failed(getIterationDomainTileFromOperandTiles(
1258 op, b, operandNumbers, allOffsets, allSizes, outerDimOffsets,
1259 outerDimSizes)))
1260 return failure();
1261
1262 SmallVector<OpFoldResult> outputOffsets, outputSizes;
1263 if (failed(getResultTilePosition(op, b, 0, outerDimOffsets, outerDimSizes,
1264 outputOffsets, outputSizes)))
1265 return failure();
1266
1267 strides.append(packOp.getDestRank() - inputRank, oneAttr);
1268 auto outSlice = tensor::ExtractSliceOp::create(
1269 b, loc, packOp.getDest(), outputOffsets, outputSizes, strides);
1270 tiledOperands.push_back(outSlice);
1271
1272 if (auto val = packOp.getPaddingValue())
1273 tiledOperands.push_back(val);
1274 for (auto tile : packOp.getInnerTiles())
1275 tiledOperands.push_back(tile);
1276
1277 Operation *tiledPackOp = PackOp::create(
1278 b, loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());
1279
1280 return TilingResult{
1281 {tiledPackOp},
1282 SmallVector<Value>(tiledPackOp->getResults()),
1283 llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
1284 }
1285};
1286
/// Per-dimension information needed when tiling an unpack op along one
/// destination dimension. Populated by `getUnpackTileDimInfo`.
struct UnpackTileDimInfo {
  // True when the requested tile covers only complete inner tiles of the
  // source along this dimension.
  bool isAlignedToInnerTileSize;
  // Offset and size of the slice to extract from the packed source.
  OpFoldResult sourceOffset;
  OpFoldResult sourceSize;
  // Offset of the requested tile inside the (possibly expanded) result.
  OpFoldResult resultOffset;
  // Size of the expanded destination needed to hold the unpacked data.
  OpFoldResult destExpandedSize;
};
1294
1295/// Returns the needed information for tiling unpack op on `tileDim` with given
1296/// `tileOffset` and `tileSize`. For more details, see the comment of the
1297/// `getTiledImplementation`.
1298static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp,
1299 int64_t tileDim,
1300 OpFoldResult tileOffset,
1301 OpFoldResult tileSize) {
1302 UnpackTileDimInfo info;
1303 Attribute zeroAttr = b.getIndexAttr(0);
1304 Attribute oneAttr = b.getIndexAttr(1);
1305 DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
1306 unpackOp.getDimAndTileMapping();
1307 // The dimension is not one of packed data dimension.
1308 if (!dimAndTileMapping.count(tileDim)) {
1309 info.isAlignedToInnerTileSize = true;
1310 info.sourceOffset = tileOffset;
1311 info.sourceSize = tileSize;
1312 info.resultOffset = zeroAttr;
1313 info.destExpandedSize = tileSize;
1314 return info;
1315 }
1316
1317 Location loc = unpackOp.getLoc();
1318 using AV = affine::AffineValueExpr;
1319 affine::AffineBuilder ab(b, loc);
1320 AffineExpr dim0, dim1, sym0;
1321 bindDims(b.getContext(), dim0, dim1);
1322 bindSymbols(b.getContext(), sym0);
1323
1324 OpFoldResult innerTileSize = dimAndTileMapping[tileDim];
1325
1326 info.isAlignedToInnerTileSize = false;
1327 FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound(
1328 presburger::BoundType::UB, tileSize,
1329 /*stopCondition=*/nullptr, /*closedUB=*/true);
1330 std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize);
1331 if (!failed(cstSize) && cstInnerSize) {
1332 if (*cstSize % *cstInnerSize == 0)
1333 info.isAlignedToInnerTileSize = true;
1334
1335 // If the tiling size equals to the inner tiling size, the outer dims are
1336 // always 1.
1337 if (*cstInnerSize == *cstSize) {
1338 auto lhs = AV(dim0).bind(tileOffset);
1339 auto rhs = AV(dim1).bind(innerTileSize);
1340 info.sourceOffset = ab.floor(lhs, rhs);
1341 info.sourceSize = oneAttr;
1342 info.resultOffset = zeroAttr;
1343 info.destExpandedSize = tileSize;
1344 return info;
1345 }
1346 }
1347
1348 if (info.isAlignedToInnerTileSize) {
1349 info.sourceOffset =
1350 ab.floor(AV(dim0).bind(tileOffset), AV(dim1).bind(innerTileSize));
1351 info.resultOffset = zeroAttr;
1352 info.destExpandedSize = tileSize;
1353
1354 // The ceilDiv is needed here because there could be incomplete tile even
1355 // it is perfect tiling cases. E.g.,
1356 // %0 = unpack tensor<33x2xf32> into tensor<64xf32>
1357 // If the tiling size is 32, there will be 3 tiles. Two of them have
1358 // size=32; one of them have size=2. The size is represented using
1359 // affine_min op; we need ceilDiv.
1360 info.sourceSize =
1361 ab.ceil(AV(dim0).bind(tileSize), AV(dim1).bind(innerTileSize));
1362 return info;
1363 }
1364
1365 affine::DivModValue firstCoord = affine::getDivMod(
1366 b, loc, getValueOrCreateConstantIndexOp(b, loc, tileOffset),
1367 getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
1368 OpFoldResult tileExclusiveBound =
1369 ab.add(AV(dim0).bind(tileOffset), AV(dim1).bind(tileSize));
1370 affine::DivModValue lastCoord = affine::getDivMod(
1371 b, loc,
1373 b, loc,
1374 ab.sub(AV(dim0).bind(tileExclusiveBound), AV(dim1).bind(oneAttr))),
1375 getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
1376
1377 OpFoldResult lengthMinusOne = ab.sub(AV(dim0).bind(lastCoord.quotient),
1378 AV(dim1).bind(firstCoord.quotient));
1379 info.sourceSize =
1380 ab.add(AV(dim0).bind(lengthMinusOne), AV(dim1).bind(oneAttr));
1381 info.sourceOffset = firstCoord.quotient;
1382 info.resultOffset = firstCoord.remainder;
1383 // Do not create an Affine ops for expanded size because the affine op is too
1384 // complicated which would trigger an issue in affine ops simplification.
1385 info.destExpandedSize = b.createOrFold<arith::MulIOp>(
1386 loc, getValueOrCreateConstantIndexOp(b, loc, info.sourceSize),
1387 getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
1388 return info;
1389}
1390
1391struct UnPackOpTiling
1392 : public TilingInterface::ExternalModel<UnPackOpTiling, linalg::UnPackOp> {
1393
1394 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
1395 auto unpackOp = cast<UnPackOp>(op);
1396 SmallVector<utils::IteratorType> iteratorTypes(
1397 unpackOp.getDestRank(), utils::IteratorType::parallel);
1398 return iteratorTypes;
1399 }
1400
1401 SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
1402 return getPackUnPackIterationDomain<UnPackOp>(cast<UnPackOp>(op), b);
1403 }
1404
  /// There are two cases in tiling unpack ops. If the tiling size is aligned to
  /// the inner tile size, the corresponding tiles of source are all complete.
  /// Otherwise, there are in-complete tiles. We will need to expand the slice
  /// of source for getting complete tiles. The tiled unpack op unpacks more
  /// data from source, so we'll need an extract_slice op to shift and truncate
  /// the output.
  /// Take Nn_to_N as an example. Say that N=32, n=8, and tiling_size=15. The
  /// coordinates of second tile (i.e., result[15..31]) are
  /// [(1, 7), (2, 0,), (2, 1) ... (3, 6), (3, 7)]. The first row and the last
  /// row are incomplete tiles. To represent the unpack op, we have to complete
  /// the rows. I.e., the input coordinates would start with (1, 0); end with
  /// (3, 7). In this context, the tiled unpack produces a (3 * n) elements
  /// because there are 3 rows in total. Follow by a tensor.extract_slice op, we
  /// can get the actual result.
  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto unpackOp = cast<UnPackOp>(op);
    // TODO: Support Memref UnPackOp. Temporarily return failure.
    if (!unpackOp.hasPureTensorSemantics())
      return failure();

    int64_t srcRank = unpackOp.getSourceRank();
    int64_t destRank = unpackOp.getDestRank();
    int64_t numInnerTiles = srcRank - destRank;
    Location loc = unpackOp.getLoc();

    // The perfect tiling case indicates that the tiling sizes are multiple of
    // inner_tile_size. In this context, no extra data is needed when
    // representing the tiled unpack op.
    bool isPerfectTilingCase = true;
    Attribute oneAttr = b.getIndexAttr(1);
    SmallVector<OpFoldResult> sliceSrcStrides(destRank, oneAttr);
    SmallVector<OpFoldResult> sliceSrcIndices, sliceSrcSizes;
    SmallVector<OpFoldResult> destExpandedSizes, resultOffsetsFromDest;
    // Gather per-dimension slice/offset info; any unaligned dimension makes
    // the whole tiling imperfect.
    for (auto dim : llvm::seq<int64_t>(0, destRank)) {
      UnpackTileDimInfo info =
          getUnpackTileDimInfo(b, unpackOp, dim, offsets[dim], sizes[dim]);
      if (!info.isAlignedToInnerTileSize)
        isPerfectTilingCase = false;
      sliceSrcIndices.push_back(info.sourceOffset);
      sliceSrcSizes.push_back(info.sourceSize);
      destExpandedSizes.push_back(info.destExpandedSize);
      resultOffsetsFromDest.push_back(info.resultOffset);
    }

    // The tiling is applied on destination dimensions. We have to apply the
    // interchange on source dimensions if outer_dims_perm is set.
    applyPermToRange(sliceSrcIndices, sliceSrcSizes,
                     unpackOp.getOuterDimsPerm());
    // Inner tile dimensions of the source are taken whole (offset 0, full
    // tile size, stride 1).
    Attribute zeroAttr = b.getIndexAttr(0);
    sliceSrcIndices.append(numInnerTiles, zeroAttr);
    sliceSrcSizes.append(unpackOp.getMixedTiles());
    sliceSrcStrides.append(numInnerTiles, oneAttr);
    SmallVector<Operation *> generatedSlices;
    tensor::ExtractSliceOp sliceSource = tensor::ExtractSliceOp::create(
        b, loc, unpackOp.getSource(), sliceSrcIndices, sliceSrcSizes,
        sliceSrcStrides);
    generatedSlices.push_back(sliceSource);

    // Perfect tiling reuses a slice of the dest; otherwise a larger
    // tensor.empty holds the expanded unpack result, truncated below.
    SmallVector<OpFoldResult> destStrides(destRank, oneAttr);
    Value sliceDest;
    if (isPerfectTilingCase) {
      auto destSliceOp = tensor::ExtractSliceOp::create(
          b, loc, unpackOp.getDest(), offsets, sizes, destStrides);
      sliceDest = destSliceOp;
      generatedSlices.push_back(destSliceOp);
    } else {
      sliceDest = tensor::EmptyOp::create(
          b, loc, destExpandedSizes, unpackOp.getDestType().getElementType());
    }

    SmallVector<Value> tiledOperands = {sliceSource.getResult(), sliceDest};
    for (auto tile : unpackOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledUnpackOp = UnPackOp::create(
        b, loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs());

    if (isPerfectTilingCase)
      return TilingResult{{tiledUnpackOp},
                          SmallVector<Value>(tiledUnpackOp->getResults()),
                          generatedSlices};

    // Imperfect tiling: shift and truncate the expanded result to the
    // requested tile.
    auto extractSlice = tensor::ExtractSliceOp::create(
        b, loc, tiledUnpackOp->getResult(0), resultOffsetsFromDest, sizes,
        destStrides);
    return TilingResult{
        {tiledUnpackOp}, {extractSlice.getResult()}, generatedSlices};
  }
1496
1497 LogicalResult
1498 getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
1499 ArrayRef<OpFoldResult> offsets,
1500 ArrayRef<OpFoldResult> sizes,
1501 SmallVector<OpFoldResult> &resultOffsets,
1502 SmallVector<OpFoldResult> &resultSizes) const {
1503 resultOffsets = llvm::to_vector(offsets);
1504 resultSizes = llvm::to_vector(sizes);
1505 return success();
1506 }
1507
1508 FailureOr<TilingResult>
1509 generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
1510 ArrayRef<OpFoldResult> offsets,
1511 ArrayRef<OpFoldResult> sizes) const {
1512 FailureOr<TilingResult> tilingResult =
1513 getTiledImplementation(op, b, offsets, sizes);
1514 if (failed(tilingResult))
1515 return failure();
1516 return tilingResult.value();
1517 }
1518
1519 LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder,
1520 Location loc,
1521 ValueRange ivs) const {
1522 auto unpackOp = cast<UnPackOp>(op);
1523 assert(unpackOp.hasPureBufferSemantics() &&
1524 "expected operation to have buffer semantics");
1525 assert(ivs.size() == unpackOp.getDestRank() &&
1526 "number of ivs must match the rank of the output tensor");
1527 OpBuilder::InsertionGuard g(builder);
1528
1529 DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
1530 unpackOp.getDimAndTileMapping();
1531 // Untiled loops and tile loops induction variables.
1532 SmallVector<Value> inputIvs;
1533 // Point loops induction variables.
1534 SmallVector<Value> inputIvsPointLoops;
1535 inputIvs.reserve(unpackOp.getDestRank());
1536 inputIvsPointLoops.reserve(dimAndTileMapping.size());
1537 for (auto dim : llvm::seq<int64_t>(0, unpackOp.getDestRank())) {
1538 if (dimAndTileMapping.count(dim)) {
1539 affine::DivModValue divMod =
1540 affine::getDivMod(builder, loc, ivs[dim],
1542 builder, loc, dimAndTileMapping[dim]));
1543 inputIvsPointLoops.push_back(divMod.remainder);
1544 inputIvs.push_back(divMod.quotient);
1545 } else {
1546 inputIvs.push_back(ivs[dim]);
1547 }
1548 }
1549
1550 // TODO: (lorenzo) simplify the logic a bit. There is `ivs`,
1551 // `inputIvsPointLoops` and `inputIvs`.
1552 assert(inputIvsPointLoops.size() + inputIvs.size() ==
1553 unpackOp.getSourceRank() &&
1554 "expect same number of induction variables equals to input rank");
1555 // Interchange the point loops induction variables based on `inner_dim_pos`.
1556 ArrayRef<int64_t> innerDims = unpackOp.getInnerDimsPos();
1557 SmallVector<int64_t> interchangeVector =
1558 computeInterchangeFromDimPos(innerDims, unpackOp.getDestRank());
1559 SmallVector<Value> interchangedInputIvsPointLoops = inputIvsPointLoops;
1560 interchangedInputIvsPointLoops = interchange<Value>(
1561 interchangedInputIvsPointLoops, interchangeVector, /*offset=*/0);
1562 // Interchange the tiled loops induction variables based on
1563 // `outer_dims_perm`.
1564 ArrayRef<int64_t> outerDims = unpackOp.getOuterDimsPerm();
1565 if (!outerDims.empty())
1566 inputIvs = interchange<Value>(inputIvs, outerDims, /*offset=*/0);
1567
1568 llvm::append_range(inputIvs, interchangedInputIvsPointLoops);
1569 Value scalar =
1570 memref::LoadOp::create(builder, loc, unpackOp.getSource(), inputIvs);
1571 memref::StoreOp::create(builder, loc, scalar, unpackOp.getDest(), ivs);
1572 return success();
1573 }
1574
  /// Method to return the position of iteration domain tile computed by the
  /// tiled operation.
  ///
  /// Given the tile (offsets/sizes) of exactly one operand of the unpack op,
  /// computes the matching tile of the iteration domain (which coincides with
  /// the dest/unpacked domain) into `resultOffsets`/`resultSizes`. Fails when
  /// more than one operand tile is supplied or when result shapes cannot be
  /// reified.
  LogicalResult getIterationDomainTileFromOperandTiles(
      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
      ArrayRef<SmallVector<OpFoldResult>> allSizes,
      SmallVectorImpl<OpFoldResult> &resultOffsets,
      SmallVectorImpl<OpFoldResult> &resultSizes) const {
    // Only a single operand tile is handled.
    if (operandNumbers.size() != 1) {
      LLVM_DEBUG({ llvm::dbgs() << "unable to handle multiple operands"; });
      return failure();
    }
    auto unPackOp = cast<UnPackOp>(op);
    unsigned operandNumber = operandNumbers[0];
    ArrayRef<OpFoldResult> offsets(allOffsets[0]);
    ArrayRef<OpFoldResult> sizes(allSizes[0]);

    // If the operand tile is the dest, then no adjustment is needed.
    if (operandNumber == unPackOp.getDestMutable().getOperandNumber()) {
      resultOffsets = llvm::to_vector(offsets);
      resultSizes = llvm::to_vector(sizes);
      return success();
    }
    Location loc = unPackOp.getLoc();

    // Otherwise the tile is on the (packed) source. Drop the trailing
    // inner-tile dimensions; the remaining leading dims correspond to the
    // dest dims, modulo the outer_dims_perm interchange undone below.
    int64_t numTiles = unPackOp.getInnerDimsPos().size();
    auto destOffsets = offsets.drop_back(numTiles);
    auto destSizes = sizes.drop_back(numTiles);
    // The tiling is applied on interchanged dimensions. We have to undo the
    // interchange to map sizes and offsets to the original input.
    int64_t outputRank = unPackOp.getDestRank();
    // Reify the dest shape so the computed sizes can be clamped to it.
    ReifiedRankedShapedTypeDims reifiedReturnShapes;
    if (failed(reifyResultShapes(b, unPackOp, reifiedReturnShapes)))
      return failure();
    SmallVector<OpFoldResult> outputMixedSizes = reifiedReturnShapes.front();
    SmallVector<OpFoldResult> origOffsets(destOffsets);
    SmallVector<OpFoldResult> origSizes(destSizes);
    applyPermToRange(origOffsets, origSizes,
                     invertPermutationVector(unPackOp.getOuterDimsPerm()));

    // Map from dest dim -> inner tile size for the tiled dims.
    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        unPackOp.getDimAndTileMapping();

    for (auto dim : llvm::seq<int64_t>(0, outputRank)) {
      using AV = affine::AffineValueExpr;
      affine::AffineBuilder ab(b, loc);
      AffineExpr dim0, dim1, sym0;
      bindDims(b.getContext(), dim0, dim1);
      bindSymbols(b.getContext(), sym0);
      if (dimAndTileMapping.count(dim)) {
        // If the data dimension is tiled, the i-th index is the product of
        // offset_i and tile_i, and the i-th size is the product of sizes_i and
        // tile_i. The sizes must be clamped to the sizes of the unpack result.
        auto avOffset = AV(dim0).bind(origOffsets[dim]);
        auto avSize = AV(dim0).bind(origSizes[dim]);
        auto avTileSize = AV(sym0).bind(dimAndTileMapping[dim]);
        auto avResultSize = AV(dim0).bind(outputMixedSizes[dim]);
        resultOffsets.push_back(ab.mul(avOffset, avTileSize));
        auto avResultOffset = AV(dim1).bind(resultOffsets.back());
        resultSizes.push_back(ab.min({ab.mul(avSize, avTileSize),
                                      ab.sub(avResultSize, avResultOffset)}));
      } else {
        // Untiled dims map through unchanged.
        resultOffsets.push_back(origOffsets[dim]);
        resultSizes.push_back(origSizes[dim]);
      }
    }
    return success();
  }
1643
  /// Method to return the tiled implementation of tensor.unpack as a consumer.
  ///
  /// Builds extract_slice ops for the source and dest operands and clones a
  /// smaller unpack op over the slices. Only supports a single operand tile on
  /// operand 0 (the source), tensor semantics, and tiles that leave the inner
  /// tile dims untiled.
  FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
      ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
    if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
      LLVM_DEBUG({ llvm::dbgs() << "unhandled operands for consumer fusion"; });
      return failure();
    }
    auto unPackOp = cast<UnPackOp>(op);
    // TODO: Support Memref UnPackOp. Temporarily return failure.
    if (!unPackOp.hasPureTensorSemantics())
      return failure();

    ArrayRef<OpFoldResult> offsets(allOffsets[0]);
    ArrayRef<OpFoldResult> sizes(allSizes[0]);

    // tensor.unpack op is fusible (as a consumer) only if inner dims are not
    // tiled. Each trailing tile size of the operand tile must equal the op's
    // corresponding inner tile size.
    int64_t numTiles = unPackOp.getInnerDimsPos().size();
    for (auto iter :
         llvm::zip_equal(unPackOp.getMixedTiles(), sizes.take_back(numTiles))) {
      if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter)))
        return failure();
    }

    Location loc = unPackOp.getLoc();

    // Fetch offset/size for creating the slice of the dest operand of
    // unpack op.
    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getIterationDomainTileFromOperandTiles(
            op, b, operandNumbers, allOffsets, allSizes, outputOffsets,
            outputSizes)))
      return failure();

    // All slices below use unit strides.
    auto oneAttr = b.getI64IntegerAttr(1);
    int64_t outputRank = unPackOp.getDestRank();
    SmallVector<OpFoldResult> strides(outputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    // Create slice of the dest operand.
    auto extractDestSlice = tensor::ExtractSliceOp::create(
        b, loc, unPackOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(extractDestSlice);

    // Extend strides to source rank (dest rank + number of inner tiles).
    strides.append(unPackOp.getSourceRank() - outputRank, oneAttr);
    // Create slice of the source operand.
    auto extractSourceSlice = tensor::ExtractSliceOp::create(
        b, loc, unPackOp.getSource(), offsets, sizes, strides);
    // Insert at the front so operand order is (source, dest, inner tiles).
    tiledOperands.insert(tiledOperands.begin(), extractSourceSlice);
    for (auto tile : unPackOp.getInnerTiles())
      tiledOperands.push_back(tile);

    // Create tiled unpack op.
    Operation *tiledUnPackOp =
        UnPackOp::create(b, loc, TypeRange{extractDestSlice.getType()},
                         tiledOperands, op->getAttrs());

    // Report the tiled op, its results, and the generated slice ops.
    return TilingResult{{tiledUnPackOp},
                        SmallVector<Value>(tiledUnPackOp->getResults()),
                        llvm::to_vector(ArrayRef<Operation *>{
                            extractSourceSlice, extractDestSlice})};
  }
1708};
1709
1710} // namespace
1711
/// Attach the tiling-interface and partial-reduction-interface external
/// models to the single structured op type `OpType`.
template <typename OpType>
static void registerOne(MLIRContext *ctx) {
  OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx);
  OpType::template attachInterface<LinalgOpPartialReductionInterface<OpType>>(
      *ctx);
}
1718
1719/// Variadic helper function.
1720template <typename... OpTypes>
1721static void registerAll(MLIRContext *ctx) {
1722 (registerOne<OpTypes>(ctx), ...);
1723}
1724
1725#define GET_OP_LIST
1726
1728 DialectRegistry &registry) {
1729 registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
1731 linalg::PackOp::attachInterface<PackOpTiling>(*ctx);
1732 linalg::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
1734#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
1735 >(ctx);
1736 });
1737}
1738
1740 DialectRegistry &registry) {
1741 registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *dialect) {
1742 linalg::PackOp::attachInterface<PackOpTiling>(*ctx);
1743 linalg::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
1744 });
1745}
return success()
static bool isTiled(AffineExpr expr, ArrayRef< OpFoldResult > tileSizes)
Definition Utils.cpp:76
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
auto load
static RankedTensorType sliceResultType(Type operandType, GridOp grid, ArrayRef< GridAxis > gridAxes, int64_t sliceAxis)
static LogicalResult getResultTilePosition(RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize)
static FailureOr< TilingResult > getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ReductionTilingStrategy reductionStrategy, ValueRange regionIterArg, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims)
static LogicalResult inlinePayload(OpBuilder &b, LinalgOp linalgOp, ValueRange ivs, ValueRange argValues)
Method to inline the payload of a linalgOp given the iteration space point and values for the argumen...
static SmallVector< Value > getIndicesForAccess(OpBuilder &b, Location loc, AffineMap indexingMap, ValueRange ivs)
Return the SSA values that represent the data point accessed using a given indexingMap for a given po...
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx)
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgListType getArguments()
Definition Block.h:97
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:222
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
MLIRContext * getContext() const
Definition Builders.h:56
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:254
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:712
void setOperand(unsigned idx, Value value)
Definition Operation.h:377
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:538
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
result_range getResults()
Definition Operation.h:441
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void cloneInto(Region *dest, IRMapping &mapper)
Clone the internal blocks from this region into dest.
Definition Region.cpp:70
static FailureOr< int64_t > computeConstantBound(presburger::BoundType type, const Variable &var, const StopConditionFn &stopCondition=nullptr, bool closedUB=false)
Compute a constant bound for the given variable.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
SmallVector< Value > makeTiledShapes(OpBuilder &builder, Location loc, LinalgOp linalgOp, ValueRange valuesToTile, ArrayRef< OpFoldResult > ivs, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > sizeBounds, bool omitPartialTileCheck)
Creates extract_slice/subview ops for all valuesToTile of the given linalgOp with builder,...
Definition Utils.cpp:2850
void registerTilingInterfaceExternalModelsForPackUnPackOps(DialectRegistry &registry)
Similar to the above registration, but it is only for tensor.pack and tensor.unpack ops.
static void registerOne(MLIRContext *ctx)
static void registerAll(MLIRContext *ctx)
Variadic helper function.
void offsetIndices(OpBuilder &b, LinalgOp linalgOp, ArrayRef< OpFoldResult > offests)
Add the specified offsets to any linalg.index ops contained in the given linalgOp.
Definition Utils.cpp:2872
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition LinalgOps.cpp:97
void registerTilingInterfaceExternalModels(DialectRegistry &registry)
SmallVector< Type > getTensorOutputTypes(LinalgOp op, ValueRange operands)
Returns the list of tensor output types produced when the given structured operation op is applied to...
Definition Utils.cpp:2761
SliceParameters computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile, ArrayRef< OpFoldResult > tileSizes, AffineMap map, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > ubs, ArrayRef< OpFoldResult > subShapeSizes, bool omitPartialTileCheck)
Computes SliceParameters for a single valueToTile assuming that its user is being tiled with the give...
Definition Utils.cpp:2614
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:69
Include the generated interface declarations.
ReductionTilingStrategy
Tiling can be thought of as splitting a dimension into 2 and materializing the outer dimension as a l...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition AffineExpr.h:325
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling fo imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1330
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Helper struct to build simple arithmetic quantities with minimal type inference support.
Definition Utils.h:103
Container for result values of tiling.
Helper struct to build simple AffineValueExprs with minimal type inference support.
Definition Utils.h:377
A struct containg offsets-sizes-strides arguments of the tiled shape.
Definition Utils.h:172
SmallVector< OpFoldResult > sizes
Definition Utils.h:174
SmallVector< OpFoldResult > offsets
Definition Utils.h:173
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.