DataLayoutPropagation.cpp
1//===- DataLayoutPropagation.cpp -----------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
16#include "mlir/IR/Dominance.h"
18#include "llvm/ADT/SetOperations.h"
19#include "llvm/ADT/SetVector.h"
20#include "llvm/ADT/TypeSwitch.h"
21#include "llvm/Support/Debug.h"
22#include <optional>
23
24namespace mlir {
25#define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
26#include "mlir/Dialect/Linalg/Passes.h.inc"
27} // namespace mlir
28
29using namespace mlir;
30using namespace mlir::linalg;
31
32#define DEBUG_TYPE "linalg-data-layout-propagation"
33
34namespace {
35
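// Returns true if the generic op has gather-like semantics, i.e. its body
// reads values at computed locations through tensor.extract or depends on
// linalg.index.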
36static bool hasGatherSemantics(linalg::GenericOp genericOp) {
37 for (Operation &op : genericOp.getBody()->getOperations())
38 if (isa<tensor::ExtractOp, linalg::IndexOp>(op))
39 return true;
40 return false;
41}
42
43// This struct contains the information needed to map packing information onto
44// the iteration domain of Linalg ops.
45struct PackInfo {
46 int64_t getNumTiledLoops() const { return tileToPointMapping.size(); };
47 // InnerDimsPos on iteration domain, which follows the order in pack ops.
48 SmallVector<int64_t> tiledDimsPos;
49 // The sizes of tiling data dimensions on iteration domain.
50 llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping;
51 // The mapping from a dimension of iteration domain to the corresponding inner
52 // tiling dimension on iteration domain.
53 llvm::DenseMap<int64_t, int64_t> tileToPointMapping;
54 // The permutation of outer dims (on domain).
55 SmallVector<int64_t> outerDimsOnDomainPerm;
56};
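// For example (an illustrative sketch): if a generic op with iteration domain
// (d0, d1) produces a result, through indexing map (d0, d1) -> (d0, d1), that
// is consumed by
//   linalg.pack ... inner_dims_pos = [1] inner_tiles = [8]
// then tiledDimsPos = [1], domainDimAndTileMapping = {1 : 8},
// tileToPointMapping = {1 : 2}, and outerDimsOnDomainPerm is empty, i.e. the
// tile of d1 becomes the new point loop d2 of the packed iteration domain.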
57
58template <typename OpTy>
59static FailureOr<PackInfo>
60getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
61 OpTy packOrUnPackOp) {
62 static_assert(llvm::is_one_of<OpTy, linalg::PackOp, linalg::UnPackOp>::value,
63 "applies to only pack or unpack operations");
64 LLVM_DEBUG(
65 { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });
66
67 AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
68 SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
69 SmallVector<utils::IteratorType> iterators =
70 genericOp.getIteratorTypesArray();
71
72 PackInfo packInfo;
73 int64_t origNumDims = indexingMap.getNumDims();
74 SmallVector<AffineExpr> exprs(indexingMap.getResults());
75 ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos();
76 for (auto [index, innerDimPos, tileSize] :
77 llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
78 innerDimsPos, packOrUnPackOp.getMixedTiles())) {
79 auto expr = exprs[innerDimPos];
80 if (!isa<AffineDimExpr>(expr))
81 return failure();
82 int64_t domainDimPos =
83 cast<AffineDimExpr>(exprs[innerDimPos]).getPosition();
84 if (!isParallelIterator(iterators[domainDimPos]))
85 return failure();
86 packInfo.tiledDimsPos.push_back(domainDimPos);
87 packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
88 packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
89 LLVM_DEBUG({
90 llvm::dbgs() << "map innerDimPos=" << innerDimPos
91 << " to iteration dimension (d" << domainDimPos << ", d"
92 << packInfo.tileToPointMapping[domainDimPos]
93 << "), which has size=("
94 << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
95 });
96 }
97
98 // Bail out if a tiled dimension is present in a map but not as an affine dim
99 // expression.
100 auto areAllAffineDimExpr = [&](int dim) {
101 for (AffineMap map : indexingMaps) {
102 if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) {
103 return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr);
104 })) {
105 return false;
106 }
107 }
108 return true;
109 };
110 for (int64_t i : packInfo.tiledDimsPos)
111 if (!areAllAffineDimExpr(i))
112 return failure();
113
114 // Get the outer dims perm on the iteration domain. Start by identifying the
115 // set of domain dims affected by the outer permutation along with the
116 // permuted ordering for those dims. Then the full outer dims permutation can
117 // be constructed by replacing the affected dims with the permuted result in a
118 // numLoops-rank identity. e.g.
119 // outerDimsPerm = [1, 2, 0]
120 // indexingMap = (d0, d1, d2, d3, d4) -> (d1, d4, d3)
121 //
122 // permutedOuterDims = [4, 3, 1]
123 // outerDimsOnDomainPerm = [0, 4, 2, 3, 1]
124 //
125 // Non-affine dim expressions must not be permuted by the outer dims
126 // permutation.
127 SmallVector<int64_t> permutedOuterDims;
128 for (auto [index, dim] : llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) {
129 auto permutedExpr = indexingMap.getResult(dim);
130 if (auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) {
131 permutedOuterDims.push_back(dimExpr.getPosition());
132 continue;
133 }
134
135 // TODO: Allow propagation with transposes on non-affine dim expressions,
136 // e.g. d0 + d1 which implies transposing both dims simultaneously while
137 // maintaining the relative position between them.
138 if (static_cast<int64_t>(index) != dim)
139 return failure();
140 }
141 if (!permutedOuterDims.empty()) {
142 int64_t outerDimIndex = 0;
143 llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(),
144 permutedOuterDims.end());
145 for (int i = 0, e = indexingMap.getNumDims(); i < e; i++)
146 packInfo.outerDimsOnDomainPerm.push_back(
147 permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++]
148 : i);
149 LLVM_DEBUG({
150 llvm::dbgs() << "map outer dimsDimsPerm to ";
151 for (auto dim : packInfo.outerDimsOnDomainPerm)
152 llvm::dbgs() << dim << " ";
153 llvm::dbgs() << "\n";
154 });
155 }
156
157 return packInfo;
158}
159
160static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
161 ArrayRef<AffineExpr> exprs) {
162 // Compute `outer_dims_perm`. See example:
163 // current exprs : (d0, d1, d2, d3) -> (d2, d3)
164 // perm : [0, 3, 1, 2]
165 // First map d2, d3 with their position in the array as:
166 // currentPositionTileLoops: dim | pos
167 // d2 | 0
168 // d3 | 1
169 // then scan `perm` in order and get the `outer_dims_perm`
170 // to be used, here it would be [1, 0].
171 assert(!perm.empty() && "expect perm not to be empty");
172 assert(!exprs.empty() && "expect exprs not to be empty");
173 if (exprs.size() == 1)
174 return {};
175 SmallVector<int64_t> outerDimsPerm;
176 DenseMap<int64_t, int64_t> currentPositionTileLoops;
177 for (auto [pos, expr] : llvm::enumerate(exprs)) {
178 // Here we rely on the assumption that the outer dims permutation
179 // when propagating currently requires that non-affine dim expressions
180 // are not permuted, thus allowing the identity assignment below.
181 if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
182 currentPositionTileLoops[dimExpr.getPosition()] = pos;
183 else
184 currentPositionTileLoops[pos] = pos;
185 }
186 for (int64_t loopIdx : perm) {
187 if (currentPositionTileLoops.count(loopIdx))
188 outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx));
189 }
190 return outerDimsPerm;
191}
192
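/// Describes how a single operand of a generic op is to be packed: the inner
/// tile sizes and positions, the outer dims permutation, and the indexing map
/// the operand will use in the packed generic op.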
193struct PackedOperandDetails {
194 SmallVector<OpFoldResult> innerTileSizes;
195 SmallVector<int64_t> innerDimsPos;
196 SmallVector<int64_t> outerDimsPerm;
197 AffineMap indexingMap;
198};
199
200/// Helper function for getOrCreatePackedViewOfOperand that populates
201/// the details of the packedOperand that needs to be formed and also
202/// returns whether the packing would require padding.
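/// For instance (an illustrative sketch), packing a tensor<30x?xf32> operand
/// with an inner tile of 8 on dimension 0 requires padding (30 is not a
/// multiple of 8), so the function returns true for that operand.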
203static bool getPackedOperandDetails(
204 OpBuilder &b, PackInfo packInfo, GenericOp genericOp, OpOperand *opOperand,
205 DenseMap<OpOperand *, PackedOperandDetails> &packedOperandMap) {
206 PackedOperandDetails currOperandDetails;
207 int64_t numOrigLoops = genericOp.getNumLoops();
208 int64_t numInnerLoops = packInfo.getNumTiledLoops();
209 int64_t numLoops = numOrigLoops + numInnerLoops;
210 AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
211 llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
212 SmallVector<AffineExpr> exprs(origIndexingMap.getResults());
213
214 // If the OpOperand is a scalar or a zero-rank tensor, no need to pack.
215 if (genericOp.isScalar(opOperand) || exprs.empty()) {
216 currOperandDetails.indexingMap =
217 AffineMap::get(numLoops, 0, exprs, b.getContext());
218 packedOperandMap[opOperand] = currOperandDetails;
219 return false;
220 }
221
222 // Step 1. Construct the information of packing data dimensions; append inner
223 // dimensions to the indexing maps for the operand.
224 for (auto [index, expr] : llvm::enumerate(exprs)) {
225 if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
226 int64_t dimPos = dimExpr.getPosition();
227 domainDimToOperandDim[dimPos] = index;
228 continue;
229 }
230 }
231 SmallVector<int64_t> innerDimsPos;
232 SmallVector<OpFoldResult> innerTileSizes;
233 for (auto dimPos : packInfo.tiledDimsPos) {
234 if (!domainDimToOperandDim.count(dimPos))
235 continue;
236 int64_t index = domainDimToOperandDim[dimPos];
237 innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
238 innerDimsPos.push_back(index);
239 exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos]));
240 }
241
242 // Step 2. Handle outer dim permutations.
243 SmallVector<int64_t> outerDimsPerm;
244 if (!packInfo.outerDimsOnDomainPerm.empty()) {
245 outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs);
246
247 // Step 2.1: Fold transpose into the linalg.generic.
248 SmallVector<int64_t> inversedOuterPerm =
249 invertPermutationVector(packInfo.outerDimsOnDomainPerm);
250 for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
251 if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) {
252 int64_t dimPos = dimExpr.getPosition();
253 exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
254 continue;
255 }
256 assert(isa<AffineConstantExpr>(exprs[i]) &&
257 "Attempted to permute non-constant and non-affine dim expression");
258 }
259 // Step 2.2: Undo the transposition on `exprs` and propagate the
260 // transposition on the pack using outerDimsPerm.
261 if (!outerDimsPerm.empty()) {
262 SmallVector<AffineExpr> auxVec = exprs;
263 for (const auto &en : enumerate(outerDimsPerm))
264 auxVec[en.index()] = exprs[en.value()];
265 exprs = auxVec;
266 }
267 }
268 currOperandDetails.indexingMap =
269 AffineMap::get(numLoops, 0, exprs, b.getContext());
270
271 // The operand does not have dimensions that relate to the pack op.
272 if (innerDimsPos.empty() && outerDimsPerm.empty()) {
273 packedOperandMap[opOperand] = currOperandDetails;
274 return false;
275 }
276 auto inputType = cast<RankedTensorType>(opOperand->get().getType());
277
278 auto maybeIntInnerTileSizes =
279 llvm::map_to_vector(innerTileSizes, [](OpFoldResult ofr) -> int64_t {
280 std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
281 return maybeCst.value_or(ShapedType::kDynamic);
282 });
283 bool requirePadding = linalg::PackOp::requirePaddingValueStrict(
284 inputType.getShape(), innerDimsPos,
285 linalg::PackOp::inferPackedTensorType(inputType, maybeIntInnerTileSizes,
286 innerDimsPos, outerDimsPerm)
287 .getShape(),
288 outerDimsPerm, innerTileSizes);
289 currOperandDetails.innerDimsPos = innerDimsPos;
290 currOperandDetails.innerTileSizes = innerTileSizes;
291 currOperandDetails.outerDimsPerm = outerDimsPerm;
292 packedOperandMap[opOperand] = currOperandDetails;
293
294 return requirePadding;
295}
296
297/// Returns a tuple for packed operand and indexing_map with the assumptions:
298/// 1) The generic op is the producer of the pack op.
299/// 2) The generic op has only one result.
300/// If the operand is a scalar or packing dimensions are all irrelevant to the
301/// operand, the operand and the updated indexing map will be returned.
302/// Otherwise, it returns the packed operand and the updated indexing map. E.g.,
303///
304/// #map0 = affine_map<(d0, d1) -> (d0, d1)>
305/// #map1 = affine_map<(d0, d1) -> (d0)>
306/// #map2 = affine_map<(d0, d1) -> (d1)>
307/// %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0],
308/// iterator_types = ["parallel", "parallel"]}
309/// ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
310/// outs(%init : tensor<?x?xf32>) {
311/// ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
312/// %4 = arith.addf %arg3, %arg4 : f32
313/// linalg.yield %4 : f32
314/// } -> tensor<?x?xf32>
315/// %1 = linalg.pack %0
316/// inner_dims_pos = [0, 1]
317/// inner_tiles = [8, 2]
318/// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
319///
320/// Taking the first input operand as an example, the inner tile size of d0 is
321/// 8. Thus, the below operation and the updated indexing map
322/// `affine_map<(d0, d1, d2, d3) -> (d0, d2)>` will be returned.
323///
324/// %pack = linalg.pack %arg0
325/// inner_dims_pos = [0]
326/// inner_tiles = [8]
327/// into %init : tensor<?xf32> -> tensor<?x8xf32>
328static std::tuple<Value, AffineMap> getOrCreatePackedViewOfOperand(
329 OpBuilder &b, Location loc, OpOperand *opOperand,
330 const DenseMap<OpOperand *, PackedOperandDetails> &packedOperandMap) {
331 assert(packedOperandMap.contains(opOperand) &&
332 "packed operand details expected to be populated");
333 auto currOperandDetails = packedOperandMap.at(opOperand);
334 auto innerDimsPos = currOperandDetails.innerDimsPos;
335 auto outerDimsPerm = currOperandDetails.outerDimsPerm;
336 auto innerTileSizes = currOperandDetails.innerTileSizes;
337 if (innerDimsPos.empty() && outerDimsPerm.empty())
338 return std::make_tuple(opOperand->get(), currOperandDetails.indexingMap);
339
340 auto empty = linalg::PackOp::createDestinationTensor(
341 b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
342 auto poison = ub::PoisonOp::create(
343 b, loc, getElementTypeOrSelf(opOperand->get().getType()));
344 PackOp packedOperand =
345 linalg::PackOp::create(b, loc, opOperand->get(), empty, innerDimsPos,
346 innerTileSizes, poison, outerDimsPerm);
347 return std::make_tuple(packedOperand.getResult(),
348 currOperandDetails.indexingMap);
349}
350
351/// This function is a helper subroutine to pack a genericOp and return it. It
352/// will create a new generic op with the packed operand and the packed output
353/// according to packInfo when we attempt to push down unpack or bubble up pack
354/// around it. Implicitly this will only work when a packInfo can be obtained.
355/// This makes sure that we only use this function on parallel, permuted
356/// dimensions.
357static FailureOr<GenericOp>
358packGenericOp(RewriterBase &rewriter, GenericOp genericOp, Value dest,
359 AffineMap packedOutIndexingMap, const PackInfo &packInfo,
360 bool isFoldableUnpackPack, bool poisonPaddingOk) {
361 Location loc = genericOp.getLoc();
362 SmallVector<Value> inputOperands;
363 SmallVector<Value> inputOperandsFromUnpackedSource;
364 SmallVector<AffineMap> indexingMaps;
365 auto hasEquivalentTiles = [](PackOp packOp, UnPackOp unPackOp) {
366 return packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm() &&
367 packOp.getInnerDimsPos() == unPackOp.getInnerDimsPos() &&
368 llvm::equal(packOp.getMixedTiles(), unPackOp.getMixedTiles());
369 };
370 DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap;
371 bool requiresPadding = false;
372 for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
373 requiresPadding |= getPackedOperandDetails(rewriter, packInfo, genericOp,
374 inputOperand, packedOperandMap);
375 }
376 if (requiresPadding && !poisonPaddingOk)
377 return failure();
378
379 for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
380 auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
381 rewriter, loc, inputOperand, packedOperandMap);
382 auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>();
383 auto packOp = packedOperand.getDefiningOp<linalg::PackOp>();
384 if (packOp && unpackOp && hasEquivalentTiles(packOp, unpackOp)) {
385 inputOperandsFromUnpackedSource.push_back(unpackOp.getSource());
386 } else {
387 inputOperandsFromUnpackedSource.push_back(packedOperand);
388 }
389 inputOperands.push_back(packedOperand);
390 indexingMaps.push_back(packedIndexingMap);
391 }
392
393 // If the unpack->pack sequences can be folded, use the sources of the
394 // unpack ops in any unpack->pack chains on the generic op operands.
395 if (isFoldableUnpackPack) {
396 inputOperands = inputOperandsFromUnpackedSource;
397 if (auto destPack = dest.getDefiningOp<linalg::PackOp>()) {
398 auto destUnPack = destPack.getSource().getDefiningOp<linalg::UnPackOp>();
399 if (destUnPack && hasEquivalentTiles(destPack, destUnPack)) {
400 dest = destUnPack.getSource();
401 }
402 }
403 }
404
405 int64_t numInnerLoops = packInfo.getNumTiledLoops();
406 SmallVector<utils::IteratorType> iterTypes =
407 genericOp.getIteratorTypesArray();
408 iterTypes.append(numInnerLoops, utils::IteratorType::parallel);
409
410 indexingMaps.push_back(packedOutIndexingMap);
411
412 auto newGenericOp = linalg::GenericOp::create(
413 rewriter, loc, dest.getType(), inputOperands, dest, indexingMaps,
414 iterTypes,
415 /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
416 rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
417 newGenericOp.getRegion().begin());
418 return newGenericOp;
419}
420
421static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
422 return llvm::all_of(genericOp.getDpsInitsMutable(), [&](OpOperand &operand) {
423 return genericOp.getMatchingBlockArgument(&operand).use_empty();
424 });
425}
426
427/// Bubbles up a linalg.pack op through a producer generic op. This
428/// swaps pack(generic) to generic(pack). The new generic op works on the packed
429/// domain; pack ops are created for input and output operands. E.g.,
430///
431/// #map0 = affine_map<(d0, d1) -> (d0, d1)>
432/// %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
433/// %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
434/// %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
435/// %3 = linalg.generic {indexing_maps = [#map0, #map0],
436/// iterator_types = ["parallel", "parallel"]}
437/// ins(%arg0 : tensor<?x?xf32>)
438/// outs(%2 : tensor<?x?xf32>) {
439/// ^bb0(%arg3: f32, %arg4: f32):
440/// %4 = arith.addf %arg3, %arg3 : f32
441/// linalg.yield %4 : f32
442/// } -> tensor<?x?xf32>
443/// %4 = linalg.pack %3
444/// inner_dims_pos = [0, 1]
445/// inner_tiles = [8, 2]
446/// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
447///
448/// will be converted to
449///
450/// #map = affine_map<()[s0] -> (s0 ceildiv 8)>
451/// #map1 = affine_map<()[s0] -> (s0 ceildiv 2)>
452/// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
453/// %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
454/// %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
455/// %0 = affine.apply #map()[%dim]
456/// %1 = affine.apply #map1()[%dim_0]
457/// %2 = tensor.empty(%0, %1) : tensor<?x?x8x2xf32>
458/// %pack = linalg.pack %arg0
459/// inner_dims_pos = [0, 1]
460/// inner_tiles = [8, 2]
461/// into %2 : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
462/// %3 = linalg.generic {indexing_maps = [#map2, #map2],
463/// iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
464/// ins(%pack : tensor<?x?x8x2xf32>)
465/// outs(%arg1 : tensor<?x?x8x2xf32>) {
466/// ^bb0(%in: f32, %out: f32):
467/// %4 = arith.addf %in, %in : f32
468/// linalg.yield %4 : f32
469/// } -> tensor<?x?x8x2xf32>
470static FailureOr<GenericOp>
471bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
472 const ControlPropagationFn &controlFn,
473 bool poisonPaddingOk) {
474 auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
475 if (!genericOp)
476 return failure();
477
478 // User controlled propagation function.
479 if (!controlFn(&packOp.getSourceMutable()))
480 return failure();
481
482 // TODO: Enable propagation in the presence of linalg.index and
483 // tensor.extract, likely as a separate pattern as the pack information and
484 // propagation decision needs to be inferred from the region of the generic.
485 if (hasGatherSemantics(genericOp))
486 return failure();
487
488 // TODO: Relax the restriction. We are able to bubble up the pack op through
489 // multi-result generic op. It just needs more work.
490 if (genericOp.getNumResults() != 1)
491 return failure();
492
493 // Bail-out if the result of the generic has multiple uses, as bubbling up
494 // creates recomputation if the generic has multiple users.
495 // TODO: Enable the case where every use is an identical pack op as no
496 // recomputation is needed in that case.
497 if (!genericOp->getResult(0).hasOneUse())
498 return failure();
499
500 // TODO: Add an option for allowing padding values. It could introduce
501 // undefined behavior if we unconditionally propagate pack op through all
502 // the ops. E.g., if the padding value is zero and there are division ops in
503 // a generic op. Some values of padding area could be NaN (0/0).
504 if (packOp.getPaddingValue())
505 return failure();
506
507 OpOperand *opOperand = genericOp.getDpsInitOperand(0);
508 auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
509 if (failed(packInfo))
510 return failure();
511
512 // We want to move the pack not the generic.
513 OpBuilder::InsertionGuard guard(rewriter);
514 rewriter.setInsertionPoint(genericOp);
515
516 // We need to handle two cases:
517 // 1) The linalg.pack destination is a tensor.empty. If this is the case, we
518 // create a new tensor.empty to avoid breaking dominance, as we are moving the
519 // linalg.pack above the linalg.generic.
520 // 2) The destination is not a tensor.empty. In this case we can replace only
521 // if the destination of the linalg.pack dominates the linalg.generic.
522 Value packOpDest = packOp.getDest();
523 if (!packOpDest.hasOneUse())
524 return failure();
525 if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
526 packOpDest = tensor::EmptyOp::create(rewriter, genericOp->getLoc(),
527 emptyOp.getMixedSizes(),
528 emptyOp.getType().getElementType());
529 } else {
530 DominanceInfo dom(genericOp);
531 if (!dom.properlyDominates(packOpDest, genericOp))
532 return failure();
533 }
534
535 // Rebuild the indexing map for the corresponding init operand.
536 DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap;
537 bool requiresPadding = getPackedOperandDetails(rewriter, *packInfo, genericOp,
538 opOperand, packedOperandMap);
539 if (requiresPadding && !poisonPaddingOk)
540 return failure();
541
542 auto [packedOutOperand, packedOutIndexingMap] =
543 getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), opOperand,
544 packedOperandMap);
545 // Forward the new tensor.empty as a destination in one of the following
546 // situations:
547 // 1) The dps init operand is a tensor.empty.
548 // 2) The dps init is a write-only operand, i.e., it is not used in the
549 // genericOp.
550 Value dest = packedOutOperand;
551 auto initTensor =
552 genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>();
553 if (initTensor || isGenericOutsNotUsed(genericOp)) {
554 dest = packOpDest;
555 }
556 // pack(unpack) isn't naively foldable because the unpack op can be from
557 // an arbitrary domain so we need to keep both.
558 return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
559 *packInfo, /*isFoldableUnpackPack=*/false,
560 poisonPaddingOk);
561}
562
563/// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
564struct BubbleUpPackOpThroughGenericOpPattern
565 : public OpRewritePattern<linalg::PackOp> {
566public:
567 BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
568 ControlPropagationFn fun,
569 bool poisonPaddingOk)
570 : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)),
571 poisonPaddingOk(std::move(poisonPaddingOk)) {}
572
573 LogicalResult matchAndRewrite(linalg::PackOp packOp,
574 PatternRewriter &rewriter) const override {
575 if (!packOp.hasPureTensorSemantics())
576 return failure();
577
578 auto genericOp = bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn,
579 poisonPaddingOk);
580 if (failed(genericOp))
581 return failure();
582 rewriter.replaceOp(packOp, genericOp->getResults());
583 return success();
584 }
585
586private:
587 ControlPropagationFn controlFn;
588 bool poisonPaddingOk;
589};
590
591/// Propagate a linalg.pack operation up through a tensor.pad. The idea is to
592/// add as many zero padding dimensions in `high` and `low` based on the number
593/// of point loops.
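/// A sketch with made-up shapes and tile sizes (pad bodies elided):
///
///   %padded = tensor.pad %src low[0, 0] high[0, 4]
///       : tensor<16x60xf32> to tensor<16x64xf32>
///   %pack = linalg.pack %padded
///     inner_dims_pos = [0]
///     inner_tiles = [8]
///     into %dest : tensor<16x64xf32> -> tensor<2x64x8xf32>
///
/// becomes
///
///   %pack = linalg.pack %src
///     inner_dims_pos = [0]
///     inner_tiles = [8]
///     into %empty : tensor<16x60xf32> -> tensor<2x60x8xf32>
///   %padded = tensor.pad %pack low[0, 0, 0] high[0, 4, 0]
///       : tensor<2x60x8xf32> to tensor<2x64x8xf32>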
594class BubbleUpPackThroughPadOp final : public OpRewritePattern<linalg::PackOp> {
595public:
596 BubbleUpPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
597 : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}
598
599 LogicalResult matchAndRewrite(linalg::PackOp packOp,
600 PatternRewriter &rewriter) const override {
601 if (!packOp.hasPureTensorSemantics())
602 return failure();
603
604 auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
605 if (!padOp)
606 return failure();
607
608 // User controlled propagation function.
609 if (!controlFn(&packOp.getSourceMutable()))
610 return failure();
611
612 // TODO: Enable padding when the padding values are the same.
613 if (packOp.getPaddingValue())
614 return failure();
615
616 // Fail for non-constant padding values. The body of the pad could
617 // depend on the padding indices and/or properties of the padded
618 // tensor so for now we fail.
619 // TODO: Support non-constant padding values.
620 Value paddingVal = padOp.getConstantPaddingValue();
621 if (!paddingVal)
622 return failure();
623
624 if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
625 return failure();
626
627 ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
628
629 // Bail out if one of the padded dimensions is a tiled one.
630 llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
631 llvm::SmallBitVector innerDims(paddedDims.size());
632 for (int64_t dim : innerDimsPos)
633 innerDims.flip(dim);
634 if (paddedDims.anyCommon(innerDims))
635 return failure();
636
637 Location loc = padOp->getLoc();
638 OpBuilder::InsertionGuard guard(rewriter);
639 rewriter.setInsertionPoint(padOp);
640
641 ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
642 SmallVector<OpFoldResult> mixedTiles = packOp.getMixedTiles();
643 auto empty = linalg::PackOp::createDestinationTensor(
644 rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos,
645 outerDimsPerm);
646 auto sourcePack = linalg::PackOp::create(
647 rewriter, loc, padOp.getSource(), empty, innerDimsPos, mixedTiles,
648 /*padding=*/std::nullopt, outerDimsPerm);
649
650 // If we have `outer_dims_perm` we need to adjust the padded dimensions.
651 SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
652 SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
653 if (!outerDimsPerm.empty()) {
654 applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
655 applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
656 }
657 // The tiled dimensions were verified to be unpadded above, so here we
658 // just append 0 for the inner tile dimensions.
659 size_t pointLoopsSize = innerDimsPos.size();
660 lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
661 highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
662
663 auto newPadOp = tensor::PadOp::create(
664 rewriter, loc, /*result=*/Type(), sourcePack.getResult(), lowPad,
665 highPad, paddingVal, padOp.getNofold());
666
667 // If the pad has more than one user, create an unpack on the new pad to
668 // replace the other uses.
669 if (!padOp->hasOneUse()) {
670 auto unpackEmpty = linalg::UnPackOp::createDestinationTensor(
671 rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm);
672 UnPackOp unpackedPad =
673 linalg::UnPackOp::create(rewriter, loc, newPadOp, unpackEmpty,
674 innerDimsPos, mixedTiles, outerDimsPerm);
675 rewriter.replaceAllUsesExcept(padOp, unpackedPad.getResult(), sourcePack);
676 }
677
678 // Replace the pack with the new pad.
679 rewriter.replaceOp(packOp, newPadOp.getResult());
680
681 return success();
682 }
683
684private:
685 ControlPropagationFn controlFn;
686};
687
688/// Project dimsPos to the inner-most non-unit dim pos with reassocIndices.
689///
690/// For example, given dimsPos [0, 2], reassocIndices [[0, 1], [2, 3]], and
691/// targetShape [16, 16, 32, 1], it returns [1, 2]: for pos 0, the inner-most
692/// non-unit projected dim in [0, 1] is 1, and for pos 2, the inner-most
693/// non-unit projected dim in [2, 3] is 2.
694///
695/// If all candidates in a reassociation are unit dims, it chooses the
696/// inner-most dim pos.
697static SmallVector<int64_t>
698projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
699 ArrayRef<ReassociationIndices> reassocIndices,
700 ArrayRef<int64_t> targetShape) {
701 SmallVector<int64_t> projectedDimsPos;
702 for (auto pos : dimsPos) {
703 // In the case all dims are unit, this will return the inner-most one.
704 int64_t projectedPos = reassocIndices[pos].back();
705 for (auto i : llvm::reverse(reassocIndices[pos])) {
706 int64_t dim = targetShape[i];
707 if (dim > 1 || ShapedType::isDynamic(dim)) {
708 projectedPos = i;
709 break;
710 }
711 }
712 projectedDimsPos.push_back(projectedPos);
713 }
714 return projectedDimsPos;
715}
716
717/// Check if all dims in dimsPos are divisible by the corresponding tile sizes.
718static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
719 ArrayRef<int64_t> shape,
720 ArrayRef<int64_t> tileSizes) {
721 for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
722 int64_t dim = shape[pos];
723 if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
724 return false;
725 }
726 return true;
727}
728
729/// Permute the reassociation indices and reindex them in sequence order.
730/// Returns the next dim pos in the sequence.
731///
732/// For example, given reassocIndices [[0, 1], [2]] and permutation [1, 0], it
733/// applies the permutation to get [[2], [0, 1]] and reindexes the indices into
734/// [[0], [1, 2]].
735static int64_t applyPermutationAndReindexReassoc(
736 SmallVector<ReassociationIndices> &reassocIndices,
737 ArrayRef<int64_t> permutation) {
738 if (!permutation.empty())
739 applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
740 int64_t nextPos = 0;
741 for (ReassociationIndices &indices : reassocIndices) {
742 for (auto &index : indices) {
743 index = nextPos;
744 nextPos += 1;
745 }
746 }
747 return nextPos;
748}
749
750/// Bubble up pack op through collapse shape op when the packed dims can be
751/// projected to the dims before collapsing. This is possible when the inner
752/// tile sizes can divide the projected dims.
753///
754/// For example:
755///
756/// %collapsed = tensor.collapse_shape %in [[0, 1], 2]
757/// : tensor<?x16x4xf32> into tensor<?x4xf32>
758/// %pack = linalg.pack %collapsed outer_dims_perm = [0, 1]
759/// inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty
760/// : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
761///
762/// can be transformed into:
763///
764/// %pack = linalg.pack %in outer_dims_perm = [1, 2]
765/// inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty
766/// : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
767/// %collapsed = tensor.collapse_shape %pack [[0, 1], 2, 3, 4]
768/// : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
769static LogicalResult
770bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
771 linalg::PackOp packOp,
772 PatternRewriter &rewriter) {
773 if (!packOp.hasPureTensorSemantics())
774 return failure();
775
776 SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
777 ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
778 ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
779
780 ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
781 SmallVector<ReassociationIndices> reassocIndices =
782 collapseOp.getReassociationIndices();
783 // Project inner tile pos to the dim pos before collapsing. For example, if
784 // dims [x, y] are collapsed into [z], packing on dim z can be projected back
785 // to pack on dim y.
786 //
787 // Project to inner-most non-unit dims to increase the chance that they can be
788 // divided by the inner tile sizes. This is correct because for [..., x, 1],
789 // packing on dim 1 is equivalent to packing on dim x.
790 SmallVector<int64_t> projectedInnerDimsPos =
791 projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);
792
793 if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
794 innerTileSizes)) {
795 return failure();
796 }
797 // Expand the outer dims permutation with the associated source dims for the
798 // new permutation after bubbling. This is because moving a collapsed dim is
799 // equivalent to moving the associated source dims together.
800 SmallVector<int64_t> newOuterDimsPerm;
801 for (auto outerPos : outerDimsPerm)
802 llvm::append_range(newOuterDimsPerm, reassocIndices[outerPos]);
803
804 auto emptyOp = linalg::PackOp::createDestinationTensor(
805 rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
806 projectedInnerDimsPos, newOuterDimsPerm);
807 auto newPackOp = linalg::PackOp::create(
808 rewriter, packOp.getLoc(), collapseOp.getSrc(), emptyOp,
809 projectedInnerDimsPos, packOp.getMixedTiles(), packOp.getPaddingValue(),
810 newOuterDimsPerm);
811
812 SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
813 // First apply the permutation on the reassociations of the outer dims.
814 // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
815 // -> [[0], [1, 2]]
816 int64_t nextPos =
817 applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
818 // Then add direct mapping for the inner tile dims.
819 for (size_t i = 0; i < innerDimsPos.size(); ++i) {
820 newReassocIndices.push_back({nextPos});
821 nextPos += 1;
822 }
823
824 auto newCollapseOp = tensor::CollapseShapeOp::create(
825 rewriter, collapseOp.getLoc(), packOp.getResult().getType(),
826 newPackOp.getResult(), newReassocIndices);
827 rewriter.replaceOp(packOp, newCollapseOp);
828
829 return success();
830}
831
832/// Project dimsPos to their collapsed positions in the reassocIndices.
833///
834/// For example, given dimsPos [0, 1, 2, 4], and matching reassocIndices
835/// [[0], [1, 2], [3], [4]], it returns [0, 1, 1, 3]. Because for pos 0,
836/// the reassoc dim [0] is 0. For pos 1 and 2, the reassoc dim in pos
837/// [1, 2] is 1. And for pos 4, the reassoc dim [4] is 3.
838static SmallVector<int64_t>
839projectDimsPosIntoReassocPos(ArrayRef<int64_t> dimsPos,
840 ArrayRef<ReassociationIndices> reassocIndices) {
841 SmallVector<int64_t> projectedPos;
842
843 // Map each dimension to the position of corresponding reassociation index.
844 for (auto pos : dimsPos) {
845 for (auto [idx, indices] : llvm::enumerate(reassocIndices)) {
846 // If the dimension is present in the current indices group, the group
847 // position within the reassociation map is the desired projected
848 // dimension position.
849 if (llvm::is_contained(indices, pos)) {
850 projectedPos.push_back(idx);
851 break;
852 }
853 }
854 }
855 assert(projectedPos.size() == dimsPos.size() && "Invalid dim pos projection");
856
857 return projectedPos;
858}
859
860/// Bubble up pack op through expand shape op.
861///
862/// For example:
863///
864/// %expand = tensor.expand_shape %in [[0], [1, 2]]
865/// : tensor<?x64xf32> into tensor<?x4x16xf32>
866/// %pack = linalg.pack %expand outer_dims_perm = [0, 1]
867/// inner_dims_pos = [2] inner_tiles = [8] into %empty
868/// : tensor<?x4x16xf32> -> tensor<?x4x2x8xf32>
869///
870/// can be transformed into:
871///
872/// %pack = linalg.pack %in outer_dims_perm = [1, 2]
873/// inner_dims_pos = [1] inner_tiles = [8] into %empty
874/// : tensor<?x64xf32> -> tensor<?x8x8xf32>
875/// %expand = tensor.expand_shape %pack [[0], [1, 2], [3]]
876/// : tensor<?x8x8xf32> into tensor<?x4x2x8xf32>
877static LogicalResult
878bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
879 linalg::PackOp packOp,
880 PatternRewriter &rewriter) {
881 if (!packOp.hasPureTensorSemantics())
882 return failure();
883
884 // Outer dimensions permutation is not supported currently.
885 // TODO: Handle outer_dims_perm variants.
886 ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
887 if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
888 return rewriter.notifyMatchFailure(packOp,
889 "non-identity outer dims perm NYI");
890 }
891
892 // Validate dimensions' relations between shape expansion and packing.
893 SmallVector<ReassociationIndices> reassoc =
894 expandOp.getReassociationIndices();
895 ArrayRef<int64_t> packInnerDims = packOp.getInnerDimsPos();
896 llvm::SetVector<int64_t> packDimsPos(llvm::from_range, packInnerDims);
897
898 for (auto [idx, indices] : llvm::enumerate(reassoc)) {
899 // For each expand_shape reassociation, figure out which dimensions get
900 // packed if any.
901 llvm::SetVector<int64_t> expandDimPos(llvm::from_range, indices);
902 llvm::SetVector<int64_t> packedDims =
903 llvm::set_intersection(packDimsPos, expandDimPos);
904
905 // The expanded dimension is not packed, so it does not affect moving the
906 // pack before the shape expansion - simply continue.
907 if (packedDims.empty())
908 continue;
909 // Shape expansion cannot be propagated when multiple expanded dimensions are
910 // packed - in this case operation reordering would affect final element
911 // positions and/or shapes can no longer be projected.
912 if (packedDims.size() != 1)
913 return rewriter.notifyMatchFailure(
914 packOp, "only one of the expanded dimensions can be packed");
915 // Only the inner-most expanded dimension should be packed. Otherwise,
916 // element order will be affected after operation reordering.
917 if (packedDims.front() != indices.back())
918 return rewriter.notifyMatchFailure(
919 packOp, "can only pack the inner-most expanded dimension");
920 }
921
922 // Project pack.inner_dims_pos to positions before shape expansion.
923 SmallVector<int64_t> projectedInnerDimsPos =
924 projectDimsPosIntoReassocPos(packInnerDims, reassoc);
925
926 // Project the shape expansion to new packed shape.
927 // The pack.outer_dims_perm is restricted to identity, so the permutation can
928 // be omitted for simplicity.
929 // TODO: Account for outer dimensions permutation.
930 //
931 // If reassociation is not possible, then reordering cannot happen.
932 // This can be caused by pack padding affecting previously expanded
933 // dimensions or packing extending dimensions.
934 RankedTensorType newPackType = linalg::PackOp::inferPackedTensorType(
935 expandOp.getSrcType(), packOp.getStaticInnerTiles(),
936 projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
937 auto reassocExpand =
938 getReassociationIndicesForReshape(newPackType, packOp.getDestType());
939 if (!reassocExpand)
940 return rewriter.notifyMatchFailure(
941 packOp, "could not reassociate dims after bubbling up");
942
943 Value destTensor = linalg::PackOp::createDestinationTensor(
944 rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(),
945 projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
946 PackOp packedVal = linalg::PackOp::create(
947 rewriter, packOp.getLoc(), expandOp.getSrc(), destTensor,
948 projectedInnerDimsPos, packOp.getMixedTiles(), packOp.getPaddingValue(),
949 /*outerDimsPerm=*/SmallVector<int64_t>{});
950
951 Value newExpandOp = tensor::ExpandShapeOp::create(
952 rewriter, packOp.getLoc(), packOp.getDestType(), packedVal.getResult(),
953 *reassocExpand);
954 rewriter.replaceOp(packOp, newExpandOp);
955
956 return success();
957}
958
959class BubbleUpPackOpThroughReshapeOp final
960 : public OpRewritePattern<linalg::PackOp> {
961public:
962 BubbleUpPackOpThroughReshapeOp(MLIRContext *context, ControlPropagationFn fun)
963 : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}
964
965 LogicalResult matchAndRewrite(linalg::PackOp packOp,
966 PatternRewriter &rewriter) const override {
967 if (!packOp.hasPureTensorSemantics())
968 return failure();
969
970 Operation *srcOp = packOp.getSource().getDefiningOp();
971 // Currently, this is only supported when the pack op is the only user.
972 if (!srcOp || srcOp->getNumResults() != 1 ||
973 !srcOp->getResult(0).hasOneUse()) {
974 return failure();
975 }
976 // Currently only support static inner tile sizes.
977 if (llvm::any_of(packOp.getStaticTiles(), ShapedType::isDynamic))
978 return failure();
979
980 // User controlled propagation function.
981 if (!controlFn(&packOp.getSourceMutable()))
982 return failure();
983
984 return TypeSwitch<Operation *, LogicalResult>(srcOp)
985 .Case([&](tensor::CollapseShapeOp op) {
986 return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
987 })
988 .Case([&](tensor::ExpandShapeOp op) {
989 return bubbleUpPackOpThroughExpandShape(op, packOp, rewriter);
990 })
991 .Default(failure());
992 }
993
994private:
995 ControlPropagationFn controlFn;
996};
997
998/// Push down unpack op through expand shape op when the packed dims can be
999/// projected to the dims after expanding. This is possible when the inner tile
1000/// sizes can divide the projected dims.
1001///
1002/// For example:
1003///
1004/// %unpack = linalg.unpack %in outer_dims_perm = [0, 1]
1005/// inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %empty
1006/// : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
1007/// %expanded = tensor.expand_shape %unpack [[0, 1], [2]]
1008/// : tensor<?x256xf32> into tensor<?x256x256xf32>
1009///
1010/// can be transformed into:
1011///
1012/// %expanded = tensor.expand_shape %in [[0, 1], [2], [3], [4]]
1013/// : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
1014/// %unpack = linalg.unpack %expanded outer_dims_perm = [0, 1, 2]
1015/// inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty
1016/// : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
1017static LogicalResult pushDownUnPackOpThroughExpandShape(
1018 linalg::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
1019 PatternRewriter &rewriter, ControlPropagationFn controlFn) {
1020 if (!unPackOp.hasPureTensorSemantics())
1021 return failure();
1022
1023 // User controlled propagation function.
1024 if (!controlFn(&expandOp.getSrcMutable()))
1025 return failure();
1026
1027 SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
1028 ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
1029 ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
1030
1031 auto expandTy = dyn_cast<RankedTensorType>(expandOp.getType());
1032 if (!expandTy)
1033 return failure();
1034 ArrayRef<int64_t> dstShape = expandTy.getShape();
1035 SmallVector<ReassociationIndices> reassocIndices =
1036 expandOp.getReassociationIndices();
1037 // Project inner tile pos to the dim pos after expanding. For example, if dim
1038 // z is expanded into [x, y], unpacking on dim z can be projected to unpack
1039 // on dim y.
1040 //
1041 // Project to inner-most non-unit dims to increase the chance that they can be
1042 // divided by the inner tile sizes. This is correct because for [..., x, 1],
1043 // unpacking on dim 1 is equivalent to unpacking on dim x.
1044 SmallVector<int64_t> projectedInnerDimsPos =
1045 projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);
1046
1047 if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
1048 innerTileSizes)) {
1049 return failure();
1050 }
1051 // Expand the outer dims permutation with the associated expanded dims for the
1052 // new permutation after pushing. This is because moving a source dim is
1053 // equivalent to moving the associated expanded dims together.
1054 SmallVector<int64_t> newOuterDimsPerm;
1055 for (auto outerPos : outerDimsPerm)
1056 llvm::append_range(newOuterDimsPerm, reassocIndices[outerPos]);
1057
1058 SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
1059 // First apply the permutation on the reassociations of the outer dims.
1060 // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
1061 // -> [[0], [1, 2]]
1062 int64_t nextPos =
1063 applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
1064 // Then add direct mapping for the inner tile dims.
1065 for (size_t i = 0; i < innerDimsPos.size(); ++i) {
1066 newReassocIndices.push_back({nextPos});
1067 nextPos += 1;
1068 }
1069
1070 RankedTensorType newExpandType = linalg::PackOp::inferPackedTensorType(
1071 expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm);
1072 auto newExpandOp =
1073 tensor::ExpandShapeOp::create(rewriter, expandOp.getLoc(), newExpandType,
1074 unPackOp.getSource(), newReassocIndices);
1075
1076 auto emptyOp = linalg::UnPackOp::createDestinationTensor(
1077 rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
1078 projectedInnerDimsPos, newOuterDimsPerm);
1079 auto newUnPackOp = linalg::UnPackOp::create(
1080 rewriter, unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
1081 projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
1082 rewriter.replaceOp(expandOp, newUnPackOp);
1083
1084 return success();
1085}
1086
1087class PushDownUnPackOpThroughReshapeOp final
1088 : public OpRewritePattern<linalg::UnPackOp> {
1089public:
1090 PushDownUnPackOpThroughReshapeOp(MLIRContext *context,
1091 ControlPropagationFn fun)
1092 : OpRewritePattern<linalg::UnPackOp>(context), controlFn(std::move(fun)) {
1093 }
1094
1095 LogicalResult matchAndRewrite(linalg::UnPackOp unPackOp,
1096 PatternRewriter &rewriter) const override {
1097 if (!unPackOp.hasPureTensorSemantics())
1098 return failure();
1099
1100 Value result = unPackOp.getResult();
1101 // Currently, only an unpack op with a single user is supported.
1102 if (!result.hasOneUse()) {
1103 return failure();
1104 }
1105 // Currently only support static inner tile sizes.
1106 if (llvm::any_of(unPackOp.getStaticTiles(), ShapedType::isDynamic))
1107 return failure();
1108
1109 Operation *consumerOp = *result.user_begin();
1110 return TypeSwitch<Operation *, LogicalResult>(consumerOp)
1111 .Case([&](tensor::ExpandShapeOp op) {
1112 return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter,
1113 controlFn);
1114 })
1115 .Default(failure());
1116 }
1117
1118private:
1119 ControlPropagationFn controlFn;
1120};
1121
1122// TODO: Relax this restriction. We should unpack a generic op also
1123// in the presence of multiple unpack ops as producers.
1124/// Return the unpacked operand, if present, for the current generic op.
1125static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
1126 OpOperand *unPackedOperand = nullptr;
1127 for (OpOperand &operand : genericOp->getOpOperands()) {
1128 auto unPackOp = operand.get().getDefiningOp<linalg::UnPackOp>();
1129 if (!unPackOp)
1130 continue;
1131 if (unPackedOperand)
1132 return failure();
1133 unPackedOperand = &operand;
1134 }
1135 if (!unPackedOperand)
1136 return failure();
1137 return unPackedOperand;
1138}
1139
1140/// Push down a linalg.unpack op through a generic op.
1141/// The new generic op works on packed domain; pack ops are created for input
1142/// and output operands. A linalg.unpack op is inserted right after the packed
1143/// generic. E.g.
1144///
1145/// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
1146///
1147/// %arg0 = tensor<12x2x56x56x32xf32> // packed arg.
1148///
1149/// %0 = tensor.empty() : tensor<12x56x56x64xf32>
1150/// %1 = linalg.unpack %arg0 outer_dims_perm = [0, 3, 1, 2]
1151/// inner_dims_pos = [3] inner_tiles = [32] into %0
1152/// %2 = linalg.generic {indexing_maps = [#map],
1153/// iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
1154/// outs(%1 : tensor<12x56x56x64xf32>) {
1155/// ^bb0(%out : f32):
1156/// linalg.yield %out : f32
1157/// } -> tensor<12x56x56x64xf32>
1158///
1159/// will be converted to
1160///
1161/// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
1162///
1163/// %0 = tensor.empty() : tensor<12x56x56x64xf32>
1164/// %1 = linalg.generic {indexing_maps = [#map],
1165/// iterator_types = ["parallel", "parallel", "parallel",
1166/// "parallel", "parallel"]}
1167/// outs(%arg0 : tensor<12x2x56x56x32xf32>) {
1168/// ^bb0(%out : f32):
1169/// linalg.yield %out : f32
1170/// } -> tensor<12x2x56x56x32xf32>
1171/// %2 = linalg.unpack %1 outer_dims_perm = [0, 3, 1, 2]
1172/// inner_dims_pos = [3] inner_tiles = [32] into %0
1173///
1174static FailureOr<std::tuple<GenericOp, Value>>
1175pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
1176 ControlPropagationFn controlFn,
1177 bool poisonPaddingOk) {
1178 if (genericOp.getNumResults() != 1)
1179 return failure();
1180
1181 if (hasGatherSemantics(genericOp))
1182 return failure();
1183
1184 // Collect the unPacked operand, if present.
1185 auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
1186 if (failed(maybeUnPackedOperand))
1187 return failure();
1188 OpOperand *unPackedOperand = *(maybeUnPackedOperand);
1189
1190 // Extract packing information.
1191 linalg::UnPackOp producerUnPackOp =
1192 unPackedOperand->get().getDefiningOp<linalg::UnPackOp>();
1193 assert(producerUnPackOp && "expect a valid UnPackOp");
1194
1195 if (!controlFn(unPackedOperand))
1196 return failure();
1197
1198 auto packInfo =
1199 getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
1200 if (failed(packInfo))
1201 return failure();
1202
1203 // Rebuild the indexing map for the corresponding init operand.
1204 DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap;
1205 bool requiresPadding =
1206 getPackedOperandDetails(rewriter, *packInfo, genericOp,
1207 genericOp.getDpsInitOperand(0), packedOperandMap);
1208 if (requiresPadding && !poisonPaddingOk)
1209 return failure();
1210
1211 auto [packedOutOperand, packedOutIndexingMap] =
1212 getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(),
1213 genericOp.getDpsInitOperand(0),
1214 packedOperandMap);
1215 auto destPack = packedOutOperand.getDefiningOp<linalg::PackOp>();
1216
1217 // Forward the new tensor.empty as a destination in one of the following
1218 // situations:
1219 // 1) The dps init operand is a tensor.empty.
1220 // 2) The dps init is a write-only operand, i.e., it is not used in the
1221 // genericOp.
1222 Value dest = packedOutOperand;
1223 auto initTensor =
1224 genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>();
1225 if (initTensor || isGenericOutsNotUsed(genericOp)) {
1226 if (destPack)
1227 dest = destPack.getDest();
1228 }
1229
1230 // Pack the genericOp.
1231 // pack(unpack) is foldable in this case. This is because in pushing down the
1232 // unpack, by default we will populate an additional pack op after the unpack.
1233 // This guarantees that they are foldable.
1234 auto maybeGenericOp =
1235 packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo,
1236 /*isFoldableUnpackPack=*/true, poisonPaddingOk);
1237 if (failed(maybeGenericOp))
1238 return failure();
1239 GenericOp newGenericOp = *maybeGenericOp;
1240 Value newResult =
1241 newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));
1242
1243 // If the output is unaffected, no need to unpack.
1244 if (!destPack)
1245 return std::make_tuple(newGenericOp, newResult);
1246
1247 auto mixedTiles = destPack.getMixedTiles();
1248 auto innerDimsPos = destPack.getInnerDimsPos();
1249 auto outerDimsPerm = destPack.getOuterDimsPerm();
1250
1251 // Insert an unPackOp right after the packed generic.
1252 Value unPackOpRes =
1253 linalg::UnPackOp::create(rewriter, genericOp.getLoc(), newResult,
1254 destPack.getSource(), innerDimsPos, mixedTiles,
1255 outerDimsPerm)
1256 .getResult();
1257
1258 return std::make_tuple(newGenericOp, unPackOpRes);
1259}
1260
1261// Wrapper pattern that applies pushDownUnPackOpThroughGenericOp method.
1262struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
1263public:
1264 PushDownUnPackOpThroughGenericOp(MLIRContext *context,
1265 ControlPropagationFn fun,
1266 bool poisonPaddingOk)
1267 : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)),
1268 poisonPaddingOk(std::move(poisonPaddingOk)) {}
1269
1270 LogicalResult matchAndRewrite(GenericOp genericOp,
1271 PatternRewriter &rewriter) const override {
1272 auto genericAndRepl = pushDownUnPackOpThroughGenericOp(
1273 rewriter, genericOp, controlFn, poisonPaddingOk);
1274 if (failed(genericAndRepl))
1275 return failure();
1276 rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
1277 return success();
1278 }
1279
1280private:
1281 ControlPropagationFn controlFn;
1282 bool poisonPaddingOk;
1283};
1284
1285/// Propagate a linalg.unpack operation through a tensor.pad. The idea is to
1286/// append as many zero padding entries to `high` and `low` as there are new
1287/// point loops, as shown in the sketch below.
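/// A sketch with made-up shapes and tile sizes (pad bodies elided):
///
///   %unpack = linalg.unpack %src
///     inner_dims_pos = [0]
///     inner_tiles = [8]
///     into %empty : tensor<2x60x8xf32> -> tensor<16x60xf32>
///   %padded = tensor.pad %unpack low[0, 0] high[0, 4]
///       : tensor<16x60xf32> to tensor<16x64xf32>
///
/// becomes
///
///   %padded = tensor.pad %src low[0, 0, 0] high[0, 4, 0]
///       : tensor<2x60x8xf32> to tensor<2x64x8xf32>
///   %unpack = linalg.unpack %padded
///     inner_dims_pos = [0]
///     inner_tiles = [8]
///     into %init : tensor<2x64x8xf32> -> tensor<16x64xf32>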
1288struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
1289 PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
1290 : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {}
1291
1292 LogicalResult matchAndRewrite(tensor::PadOp padOp,
1293 PatternRewriter &rewriter) const override {
1294 linalg::UnPackOp unpackOp =
1295 padOp.getSource().getDefiningOp<linalg::UnPackOp>();
1296 if (!unpackOp)
1297 return failure();
1298
1299 if (!unpackOp.hasPureTensorSemantics())
1300 return failure();
1301
1302 if (!controlFn(&padOp.getSourceMutable()))
1303 return failure();
1304
1305 Location loc = padOp.getLoc();
1306 // Bail out if one of the padded dimensions is a tiled one.
1307 llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
1308 ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
1309 llvm::SmallBitVector innerDims(paddedDims.size());
1310 for (int64_t dim : innerDimsPos)
1311 innerDims.flip(dim);
1312 if (paddedDims.anyCommon(innerDims))
1313 return failure();
1314
1315 Value paddingVal = padOp.getConstantPaddingValue();
1316 if (!paddingVal)
1317 return failure();
1318
1319 // If we have `outer_dims_perm` we need to adjust the padded dimensions.
1320 ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
1321 SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
1322 SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
1323 if (!outerDimsPerm.empty()) {
1324 applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
1325 applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
1326 }
1327 // Add zero padding for the point loops.
1328 size_t pointLoopsSize = innerDimsPos.size();
1329 lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
1330 highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
1331
1332 auto newPadOp = tensor::PadOp::create(rewriter, loc, /*result=*/Type(),
1333 unpackOp.getSource(), lowPad, highPad,
1334 paddingVal, padOp.getNofold());
1335
1336 // Inject the linalg.unpack right after the new tensor.pad on the packed data.
1337 Value outputUnPack =
1338 tensor::EmptyOp::create(rewriter, loc, padOp.getResultType().getShape(),
1339 padOp.getResultType().getElementType());
1340
1341 UnPackOp replacement = linalg::UnPackOp::create(
1342 rewriter, loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
1343 unpackOp.getMixedTiles(), outerDimsPerm);
1344 rewriter.replaceOp(padOp, replacement);
1345 return success();
1346 }
1347
1348private:
1349 ControlPropagationFn controlFn;
1350};
1351
1352// This struct contains information about extract_slice dims.
1353struct SliceDimInfo {
1354 OpFoldResult offset;
1355 OpFoldResult sliceSize;
1356 OpFoldResult outputSize;
1357};
1358
1359/// Return all extract slice operands, if present, for the current
1360/// generic op.
1361static FailureOr<SmallVector<OpOperand *>>
1362getSliceOperands(GenericOp genericOp) {
1363 SmallVector<OpOperand *> sliceOperands;
1364 for (auto operand : genericOp.getDpsInputOperands()) {
1365 auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
1366 if (!extractOp)
1367 continue;
1368 sliceOperands.push_back(operand);
1369 }
1370 if (sliceOperands.empty()) {
1371 return failure();
1372 }
1373 return sliceOperands;
1374}
1375
1376// Return a map of dims that have partial slices on them so that other operands
1377// can use this information. Also return a bool indicating whether a reduction
1378// dim has a non-full slice, as that can be used to fold the original extract slice.
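// E.g. (an illustrative sketch): for an operand with indexing map
// (d0, d1) -> (d0, d1) that is defined by
//   tensor.extract_slice %t[0, 2] [10, 4] [1, 1]
//       : tensor<10x8xf32> to tensor<10x4xf32>
// only d1 is partially sliced, so the returned map is
//   {1 : SliceDimInfo{offset = 2, sliceSize = 4, outputSize = 8}}.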
1379static FailureOr<llvm::DenseMap<int64_t, SliceDimInfo>>
1380getPartialSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand) {
1381 tensor::ExtractSliceOp producerSliceOp =
1382 sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
1383 assert(producerSliceOp && "expect a valid ExtractSliceOp");
1384 llvm::DenseMap<int64_t, SliceDimInfo> partialSliceDimMap;
1385 SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
1386 SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
1387
1388 SmallVector<OpFoldResult> shape = getAsIndexOpFoldResult(
1389 genericOp.getContext(), producerSliceOp.getSourceType().getShape());
1390
1391 for (auto [idx, expr] : llvm::enumerate(
1392 genericOp.getMatchingIndexingMap(sliceOperand).getResults())) {
1393 // If we have a full slice in a dimension then we don't need to add it to
1394 // the partial slice map.
1395 if (isConstantIntValue(offsets[idx], 0) &&
1396 isEqualConstantIntOrValue(sizes[idx], shape[idx])) {
1397 continue;
1398 }
1399 // We only support partial slices of AffineDimExprs, so bail out if that's not
1400 // the case.
1401 if (!isa<AffineDimExpr>(expr)) {
1402 return failure();
1403 }
1404 SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]};
1405 int64_t dimPos = cast<AffineDimExpr>(expr).getPosition();
1406 partialSliceDimMap[dimPos] = sliceDimInfo;
1407 }
1408 // Next, check whether the dims with partial slice info are used by non
1409 // AffineDimExpr results in other operands; if they are, bail out.
1410 for (OpOperand &operand : genericOp->getOpOperands()) {
1411 if (operand == *sliceOperand) {
1412 continue;
1413 }
1414 AffineMap IndexingMap = genericOp.getMatchingIndexingMap(&operand);
1415 if (llvm::any_of(IndexingMap.getResults(), [&](AffineExpr expr) {
1416 if (isa<AffineDimExpr>(expr)) {
1417 return false;
1418 }
1419 WalkResult status = expr.walk([&](AffineExpr expr) {
1420 if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
1421 if (partialSliceDimMap.contains(dimExpr.getPosition())) {
1422 return WalkResult::interrupt();
1423 }
1424 }
1425 return WalkResult::advance();
1426 });
1427 if (status.wasInterrupted()) {
1428 return true;
1429 }
1430 return false;
1431 })) {
1432 return failure();
1433 }
1434 }
1435 return partialSliceDimMap;
1436}
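// Worked example (illustrative, not taken from the upstream source): for a
// slice operand produced by
//   %s = tensor.extract_slice %src[0, 2] [8, 5] [1, 1]
//          : tensor<8x16xf32> to tensor<8x5xf32>
// and mapped with the identity indexing map (d0, d1), dimension 0 is a full
// slice and is skipped, while dimension 1 is recorded as
//   d1 -> SliceDimInfo{/*offset=*/2, /*sliceSize=*/5, /*outputSize=*/16},
// provided no other operand uses d1 inside a non-dim affine expression.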
1437
1438static FailureOr<std::tuple<GenericOp, Value>>
1439pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
1440 GenericOp genericOp,
1441 ControlPropagationFn controlFn) {
1442 if (genericOp.getNumResults() != 1)
1443 return rewriter.notifyMatchFailure(
1444 genericOp, "propagation through multi-result generic is unsupported.");
1445 if (hasGatherSemantics(genericOp))
1446 return rewriter.notifyMatchFailure(
1447 genericOp,
1448 "propagation through generic with gather semantics is unsupported.");
1449 // Collect the sliced operands, if present.
1450 auto maybeSliceOperands = getSliceOperands(genericOp);
1451 if (failed(maybeSliceOperands))
1452 return failure();
1453 SmallVector<OpOperand *> sliceOperands = *maybeSliceOperands;
1454 OpOperand *sliceOperand;
1455
1456 bool foundValidOperand = false;
1457 for (auto currSliceOperand : sliceOperands) {
1458 if (controlFn(currSliceOperand)) {
1459 sliceOperand = currSliceOperand;
1460 foundValidOperand = true;
1461 break;
1462 }
1463 }
1464 if (!foundValidOperand) {
1465 return failure();
1466 }
1467 unsigned OperandIndex = sliceOperand->getOperandNumber();
1468
1469 tensor::ExtractSliceOp producerSliceOp =
1470 sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
1471 assert(producerSliceOp && "expect a valid ExtractSliceOp");
1472
1473 if (producerSliceOp.getSource().getType().getRank() !=
1474 producerSliceOp.getResult().getType().getRank()) {
1475 return rewriter.notifyMatchFailure(
1476 genericOp,
1477 "propagation of rank-reducing extract slice is unsupported.");
1478 }
1479
1480 SmallVector<OpFoldResult> strides = producerSliceOp.getMixedStrides();
1481 if (!areAllConstantIntValue(strides, 1))
1482 return rewriter.notifyMatchFailure(
1483 genericOp, "propagation of strided extract slice is unsupported.");
1484
1485 // Check if we can support the propagation of this extractSlice through the
1486 // generic op and, if so, collect the dimensions that have partial slices.
1487
1488 auto maybePartialSliceDimMap =
1489 getPartialSliceDimInfo(genericOp, sliceOperand);
1490
1491 if (failed(maybePartialSliceDimMap)) {
1492 return failure();
1493 }
1494
1495 auto partialSliceDimMap = *maybePartialSliceDimMap;
1496
1497 SmallVector<utils::IteratorType> iterators =
1498 genericOp.getIteratorTypesArray();
1499 bool hasPartialReductionDimSlice =
1500 llvm::any_of(partialSliceDimMap, [&](const auto &slice) {
1501 int64_t sliceDim = slice.first;
1502 return iterators[sliceDim] == utils::IteratorType::reduction;
1503 });
1504
1505 // Helper to compute high padding as (outputSize - offset - sliceSize).
1506 Location loc = genericOp->getLoc();
1507 AffineExpr dim0, dim1;
1508 bindDims(rewriter.getContext(), dim0, dim1);
1509 auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
1510 auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
1511 return affine::makeComposedFoldedAffineApply(rewriter, loc, subMap,
1512 {v1, v2});
1513 };
1514
1515 MLIRContext *ctx = genericOp.getContext();
1516 SmallVector<Value> paddedInputs;
1517 for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
1518 if (idx == OperandIndex && !hasPartialReductionDimSlice) {
1519 paddedInputs.push_back(producerSliceOp.getSource());
1520 continue;
1521 }
1522 AffineMap IndexingMap = genericOp.getMatchingIndexingMap(operand);
1523 if (IndexingMap.getNumResults() == 0) {
1524 paddedInputs.push_back(operand->get());
1525 continue;
1526 }
1527 SmallVector<OpFoldResult> operandLowPads(IndexingMap.getNumResults(),
1528 getAsIndexOpFoldResult(ctx, 0));
1529 SmallVector<OpFoldResult> operandHighPads(IndexingMap.getNumResults(),
1530 getAsIndexOpFoldResult(ctx, 0));
1531 for (auto [idx, expr] : llvm::enumerate(IndexingMap.getResults())) {
1532 if (!isa<AffineDimExpr>(expr)) {
1533 continue;
1534 }
1535 AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
1536 if (!partialSliceDimMap.contains(dimExpr.getPosition())) {
1537 continue;
1538 }
1539 SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()];
1540 operandLowPads[idx] = sliceDimInfo.offset;
1541 operandHighPads[idx] =
1542 sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
1543 sliceDimInfo.sliceSize);
1544 }
1545 auto paddingValue = ub::PoisonOp::create(
1546 rewriter, loc, getElementTypeOrSelf(operand->get().getType()));
1547 auto paddedOperand = tensor::PadOp::create(
1548 rewriter, loc, Type(), operand->get(), operandLowPads, operandHighPads,
1549 paddingValue, /*nofold=*/false);
1550 paddedInputs.push_back(paddedOperand);
1551 }
1552 AffineMap outputIndexingMap =
1553 genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));
1554
1555 auto outputShapeType =
1556 llvm::cast<ShapedType>(genericOp.getDpsInitOperand(0)->get().getType());
1557 SmallVector<OpFoldResult> OutputShape = llvm::map_to_vector(
1558 outputShapeType.getShape(),
1559 [&](int64_t sz) -> OpFoldResult { return rewriter.getIndexAttr(sz); });
1560 SmallVector<OpFoldResult> newSizes = OutputShape;
1561 SmallVector<OpFoldResult> outputLowPads(outputIndexingMap.getNumResults(),
1562 getAsIndexOpFoldResult(ctx, 0));
1563 SmallVector<OpFoldResult> outputHighPads(outputIndexingMap.getNumResults(),
1564 getAsIndexOpFoldResult(ctx, 0));
1565 SmallVector<OpFoldResult> newStrides(outputIndexingMap.getNumResults(),
1566 getAsIndexOpFoldResult(ctx, 1));
1567 for (auto [idx, expr] : llvm::enumerate(outputIndexingMap.getResults())) {
1568 if (!isa<AffineDimExpr>(expr)) {
1569 continue;
1570 }
1571 AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
1572 if (!partialSliceDimMap.contains(dimExpr.getPosition())) {
1573 continue;
1574 }
1575 SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()];
1576 outputLowPads[idx] = sliceDimInfo.offset;
1577 outputHighPads[idx] = sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
1578 sliceDimInfo.sliceSize);
1579 OutputShape[idx] = sliceDimInfo.outputSize;
1580 newSizes[idx] = sliceDimInfo.sliceSize;
1581 }
1582 Value newPadOutput;
1583 auto outputElType =
1584 getElementTypeOrSelf(genericOp.getDpsInits()[0].getType());
1585 if (isGenericOutsNotUsed(genericOp)) {
1586 newPadOutput =
1587 tensor::EmptyOp::create(rewriter, loc, OutputShape, outputElType);
1588 } else {
1589 auto paddingValue = ub::PoisonOp::create(rewriter, loc, outputElType);
1590 newPadOutput = tensor::PadOp::create(
1591 rewriter, loc, Type(), genericOp.getDpsInits()[0], outputLowPads,
1592 outputHighPads, paddingValue, /*nofold=*/false);
1593 }
1594
1595 auto newGenericOp = linalg::GenericOp::create(
1596 rewriter, loc, newPadOutput.getType(), paddedInputs, {newPadOutput},
1597 genericOp.getIndexingMapsArray(), genericOp.getIteratorTypesArray(),
1598 /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
1599 rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
1600 newGenericOp.getRegion().begin());
1601
1602 auto extractOp = tensor::ExtractSliceOp::create(
1603 rewriter, loc,
1604 newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)),
1605 outputLowPads, newSizes, newStrides);
1606 Value extractRes = extractOp.getResult();
1607
1608 return std::make_tuple(newGenericOp, extractRes);
1609}
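// Hedged sketch of the overall effect (illustrative shapes and ops, not taken
// from the upstream source): for an elementwise generic whose input is
//   %in = tensor.extract_slice %src[2, 0] [6, 16] [1, 1]
//          : tensor<8x16xf32> to tensor<6x16xf32>
// the rewrite replaces %in with %src, pads the remaining inputs (and the init,
// when its value is used) with ub.poison up to the un-sliced sizes, runs the
// generic at the full 8x16 size, and recovers the original result with
//   %res = tensor.extract_slice %newGeneric[2, 0] [6, 16] [1, 1]
//          : tensor<8x16xf32> to tensor<6x16xf32>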
1610
1611class PushDownExtractSliceOpThroughGenericOp final
1612 : public OpRewritePattern<GenericOp> {
1613public:
1614 PushDownExtractSliceOpThroughGenericOp(MLIRContext *context,
1615 ControlPropagationFn fun)
1616 : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
1617
1618 LogicalResult matchAndRewrite(GenericOp genericOp,
1619 PatternRewriter &rewriter) const override {
1620 auto genericAndRepl =
1621 pushDownExtractSliceOpThroughGenericOp(rewriter, genericOp, controlFn);
1622 if (failed(genericAndRepl))
1623 return failure();
1624 rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
1625 return success();
1626 }
1627
1628private:
1629 ControlPropagationFn controlFn;
1630};
1631
1632} // namespace
1633
1634 void mlir::linalg::populateDataLayoutPropagationPatterns(
1635 RewritePatternSet &patterns,
1636 const ControlPropagationFn &controlPackUnPackPropagation,
1637 bool PoisonPaddingOk) {
1638 patterns.insert<BubbleUpPackThroughPadOp, BubbleUpPackOpThroughReshapeOp,
1639 PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
1640 patterns.getContext(), controlPackUnPackPropagation);
1641 patterns.insert<BubbleUpPackOpThroughGenericOpPattern,
1642 PushDownUnPackOpThroughGenericOp>(
1643 patterns.getContext(), controlPackUnPackPropagation, PoisonPaddingOk);
1644}
1645
1646 void mlir::linalg::populateExtractSliceSinkingPatterns(
1647 RewritePatternSet &patterns,
1648 const ControlPropagationFn &controlPackUnPackPropagation) {
1649 patterns.insert<PushDownExtractSliceOpThroughGenericOp>(
1650 patterns.getContext(), controlPackUnPackPropagation);
1651}
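// A minimal usage sketch (not part of this file; `funcOp` and the greedy
// driver from mlir/Transforms/GreedyPatternRewriteDriver.h are assumed to be
// available in the caller's context) showing how a client pass might wire up
// the two entry points above:
//
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populateDataLayoutPropagationPatterns(
//       patterns, /*controlPackUnPackPropagation=*/
//       [](OpOperand *) { return true; }, /*PoisonPaddingOk=*/false);
//   linalg::populateExtractSliceSinkingPatterns(
//       patterns, [](OpOperand *) { return true; });
//   if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
//     funcOp->emitError("data layout propagation did not converge");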
A dimensional identifier appearing in an affine expression.
Definition AffineExpr.h:223
unsigned getPosition() const
Base type for affine expression.
Definition AffineExpr.h:68
RetT walk(FnT &&callback) const
Walk all of the AffineExpr's in this expression in postorder.
Definition AffineExpr.h:117
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
MLIRContext * getContext() const
Definition Builders.h:56
A class for computing basic dominance information.
Definition Dominance.h:140
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition Builders.cpp:593
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition Value.h:197
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
void populateDataLayoutPropagationPatterns(RewritePatternSet &patterns, const ControlPropagationFn &controlPackUnPackPropagation, bool PoisonPaddingOk=false)
Patterns to bubble up or down data layout ops across other operations.
void populateExtractSliceSinkingPatterns(RewritePatternSet &patterns, const ControlPropagationFn &controlPackUnPackPropagation)
Patterns to sink extract slice across other operations.
bool isParallelIterator(utils::IteratorType iteratorType)
Check if iterator type has "parallel" semantics.
Definition Utils.cpp:232
std::function< bool(OpOperand *opOperand)> ControlPropagationFn
Function type which is used to control propagation of linalg.pack/unpack ops.
SmallVector< NamedAttribute > getPrunedAttributeList(OpTy op)
Returns an attribute list that excludes pre-defined attributes.
Definition Utils.h:402
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:578
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)
Return the reassociations maps to use to reshape given the source type and the target type when possi...
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:136
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...