DataLayoutPropagation.cpp
1 //===- DataLayoutPropagation.cpp -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/Passes.h"
10 
11 #include "mlir/Dialect/Affine/IR/AffineOps.h"
12 #include "mlir/Dialect/Linalg/IR/Linalg.h"
13 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
14 #include "mlir/Dialect/Linalg/Utils/Utils.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/Dialect/Tensor/Utils/Utils.h"
17 #include "mlir/Dialect/Utils/IndexingUtils.h"
18 #include "mlir/IR/Dominance.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 #include "llvm/ADT/SetOperations.h"
21 #include "llvm/ADT/SetVector.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/Debug.h"
24 #include <optional>
25 
26 namespace mlir {
27 #define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
28 #include "mlir/Dialect/Linalg/Passes.h.inc"
29 } // namespace mlir
30 
31 using namespace mlir;
32 using namespace mlir::linalg;
33 
34 #define DEBUG_TYPE "linalg-data-layout-propagation"
35 
36 namespace {
37 
38 static bool hasGatherSemantics(linalg::GenericOp genericOp) {
39  for (Operation &op : genericOp.getBody()->getOperations())
40  if (isa<tensor::ExtractOp, linalg::IndexOp>(op))
41  return true;
42  return false;
43 }
44 
45 // The struct contains the information about mapping the packing data onto
46 // the iteration domain of Linalg ops.
47 struct PackInfo {
48  int64_t getNumTiledLoops() const { return tileToPointMapping.size(); };
49  // InnerDimsPos on iteration domain, which follows the order in pack ops.
50  SmallVector<int64_t> tiledDimsPos;
51  // The sizes of tiling data dimensions on iteration domain.
52  llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping;
53  // The mapping from a dimension of iteration domain to the corresponding inner
54  // tiling dimension on iteration domain.
55  llvm::DenseMap<int64_t, int64_t> tileToPointMapping;
56  // The permutation of outer dims (on domain).
57  SmallVector<int64_t> outerDimsOnDomainPerm;
58 };
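// Illustrative example (editorial addition, not from the original source): for
// a generic op with iteration domain (d0, d1) whose packed operand uses
// inner_dims_pos = [0, 1], inner_tiles = [8, 2], and no outer_dims_perm,
// getPackingInfoFromOperand below would populate this struct roughly as:
//   tiledDimsPos            = [0, 1]
//   domainDimAndTileMapping = {d0 -> 8, d1 -> 2}
//   tileToPointMapping      = {d0 -> d2, d1 -> d3}  // d2, d3 are new point loops.
//   outerDimsOnDomainPerm   = []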
59 
60 template <typename OpTy>
61 static FailureOr<PackInfo>
62 getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
63  OpTy packOrUnPackOp) {
64  static_assert(llvm::is_one_of<OpTy, linalg::PackOp, linalg::UnPackOp>::value,
65  "applies to only pack or unpack operations");
66  LLVM_DEBUG(
67  { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });
68 
69  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
70  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
71  SmallVector<utils::IteratorType> iterators =
72  genericOp.getIteratorTypesArray();
73 
74  PackInfo packInfo;
75  int64_t origNumDims = indexingMap.getNumDims();
76  SmallVector<AffineExpr> exprs(indexingMap.getResults());
77  ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos();
78  for (auto [index, innerDimPos, tileSize] :
79  llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
80  innerDimsPos, packOrUnPackOp.getMixedTiles())) {
81  auto expr = exprs[innerDimPos];
82  if (!isa<AffineDimExpr>(expr))
83  return failure();
84  int64_t domainDimPos =
85  cast<AffineDimExpr>(exprs[innerDimPos]).getPosition();
86  if (!isParallelIterator(iterators[domainDimPos]))
87  return failure();
88  packInfo.tiledDimsPos.push_back(domainDimPos);
89  packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
90  packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
91  LLVM_DEBUG({
92  llvm::dbgs() << "map innerDimPos=" << innerDimPos
93  << " to iteration dimension (d" << domainDimPos << ", d"
94  << packInfo.tileToPointMapping[domainDimPos]
95  << "), which has size=("
96  << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
97  });
98  }
99 
100  // Bail out if a tiled dimension is present in a map but not as an affine dim
101  // expression.
102  auto areAllAffineDimExpr = [&](int dim) {
103  for (AffineMap map : indexingMaps) {
104  if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) {
105  return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr);
106  })) {
107  return false;
108  }
109  }
110  return true;
111  };
112  for (int64_t i : packInfo.tiledDimsPos)
113  if (!areAllAffineDimExpr(i))
114  return failure();
115 
116  // Get the outer dims perm on the iteration domain. Start by identifying the
117  // set of domain dims affected by the outer permutation along with the
118  // permuted ordering for those dims. Then the full outer dims permutation can
119  // be constructed by replacing the affected dims with the permuted result in a
120  // numLoops-rank identity. e.g.
121  // outerDimsPerm = [1, 2, 0]
122  // indexingMap = (d0, d1, d2, d3, d4) -> (d1, d4, d3)
123  //
124  // permutedOuterDims = [4, 3, 1]
125  // outerDimsOnDomainPerm = [0, 4, 2, 3, 1]
126  //
127  // Non-affine dim expressions must not be permuted by the outer dims
128  // permutation.
129  SmallVector<int64_t> permutedOuterDims;
130  for (auto [index, dim] : llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) {
131  auto permutedExpr = indexingMap.getResult(dim);
132  if (auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) {
133  permutedOuterDims.push_back(dimExpr.getPosition());
134  continue;
135  }
136 
137  // TODO: Allow propagation with transposes on non affine dim expressions,
138  // e.g. d0 + d1 which implies transposing both dims simultaneously while
139  // maintaining the relative position between them.
140  if (static_cast<int64_t>(index) != dim)
141  return failure();
142  }
143  if (!permutedOuterDims.empty()) {
144  int64_t outerDimIndex = 0;
145  llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(),
146  permutedOuterDims.end());
147  for (int i = 0, e = indexingMap.getNumDims(); i < e; i++)
148  packInfo.outerDimsOnDomainPerm.push_back(
149  permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++]
150  : i);
151  LLVM_DEBUG({
152  llvm::dbgs() << "map outerDimsPerm to ";
153  for (auto dim : packInfo.outerDimsOnDomainPerm)
154  llvm::dbgs() << dim << " ";
155  llvm::dbgs() << "\n";
156  });
157  }
158 
159  return packInfo;
160 }
161 
162 static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
163  ArrayRef<AffineExpr> exprs) {
164  // Compute `outer_dims_perm`. See example:
165  // current exprs : (d0, d1, d2, d3) -> (d2, d3)
166  // perm : [0, 3, 1, 2]
167  // First map d2, d3 with their position in the array as:
168  // currentPositionTileLoops: dim | pos
169  // d2 | 0
170  // d3 | 1
171  // then scan `perm` in order and get the `outer_dims_perm`
172  // to be used, here it would be [1, 0].
173  assert(!perm.empty() && "expect perm not to be empty");
174  assert(!exprs.empty() && "expect exprs not to be empty");
175  if (exprs.size() == 1)
176  return {};
177  SmallVector<int64_t> outerDimsPerm;
178  DenseMap<int64_t, int64_t> currentPositionTileLoops;
179  for (auto [pos, expr] : llvm::enumerate(exprs)) {
180  // Here we rely on the assumption that the outer dims permutation
181  // when propagating currently requires that non-affine dim expressions
182  // are not permuted, thus allowing the identity assignment below.
183  if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
184  currentPositionTileLoops[dimExpr.getPosition()] = pos;
185  else
186  currentPositionTileLoops[pos] = pos;
187  }
188  for (int64_t loopIdx : perm) {
189  if (currentPositionTileLoops.count(loopIdx))
190  outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx));
191  }
192  return outerDimsPerm;
193 }
194 
195 /// Returns a tuple of the packed operand and updated indexing map, assuming:
196 /// 1) The generic op is the producer of the pack op.
197 /// 2) The generic op has only one result.
198 /// If the operand is a scalar or none of the packing dimensions apply to the
199 /// operand, the original operand and the updated indexing map are returned.
200 /// Otherwise, it returns the packed operand and the updated indexing map. E.g.,
201 ///
202 /// #map0 = affine_map<(d0, d1) -> (d0, d1)>
203 /// #map1 = affine_map<(d0, d1) -> (d0)>
204 /// #map2 = affine_map<(d0, d1) -> (d1)>
205 /// %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0],
206 /// iterator_types = ["parallel", "parallel"]}
207 /// ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
208 /// outs(%init : tensor<?x?xf32>) {
209 /// ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
210 /// %4 = arith.addf %arg3, %arg4 : f32
211 /// linalg.yield %4 : f32
212 /// } -> tensor<?x?xf32>
213 /// %1 = linalg.pack %0
214 /// inner_dims_pos = [0, 1]
215 /// inner_tiles = [8, 2]
216 /// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
217 ///
218 /// Taking the first input operand as an example, the inner tile size of d0 is
219 /// 8. Thus, the below operation and `affine_map<(d0, d1, d2, d3) ->
220 /// (d0, d2)>` will be returned.
221 ///
222 /// %pack = linalg.pack %arg0
223 /// inner_dims_pos = [0]
224 /// inner_tiles = [8]
225 /// into %init : tensor<?xf32> -> tensor<?x8xf32>
226 static std::tuple<Value, AffineMap>
227 getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
228  GenericOp genericOp, OpOperand *opOperand) {
229  int64_t numOrigLoops = genericOp.getNumLoops();
230  int64_t numInnerLoops = packInfo.getNumTiledLoops();
231  int64_t numLoops = numOrigLoops + numInnerLoops;
232  AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
233  llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
234  SmallVector<AffineExpr> exprs(origIndexingMap.getResults());
235 
236  // If the OpOperand is a scalar or a zero-rank tensor, no need to pack.
237  if (genericOp.isScalar(opOperand) || exprs.empty())
238  return std::make_tuple(opOperand->get(),
239  AffineMap::get(numLoops, 0, exprs, b.getContext()));
240 
241  // Step 1. Construct the information of packing data dimensions; append inner
242  // dimensions to the indexing maps for the operand.
243  for (auto [index, expr] : llvm::enumerate(exprs)) {
244  if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
245  int64_t dimPos = dimExpr.getPosition();
246  domainDimToOperandDim[dimPos] = index;
247  continue;
248  }
249  }
250  SmallVector<int64_t> innerDimsPos;
251  SmallVector<OpFoldResult> innerTileSizes;
252  for (auto dimPos : packInfo.tiledDimsPos) {
253  if (!domainDimToOperandDim.count(dimPos))
254  continue;
255  int64_t index = domainDimToOperandDim[dimPos];
256  innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
257  innerDimsPos.push_back(index);
258  exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos]));
259  }
260 
261  // Step 2. Handle outer dim permutations.
262  SmallVector<int64_t> outerDimsPerm;
263  if (!packInfo.outerDimsOnDomainPerm.empty()) {
264  outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs);
265 
266  // Step 2.1: Fold transpose into the linalg.generic.
267  SmallVector<int64_t> inversedOuterPerm =
268  invertPermutationVector(packInfo.outerDimsOnDomainPerm);
269  for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
270  if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) {
271  int64_t dimPos = dimExpr.getPosition();
272  exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
273  continue;
274  }
275  assert(isa<AffineConstantExpr>(exprs[i]) &&
276  "Attempted to permute non-constant and non-affine dim expression");
277  }
278  // Step 2.2: Undo the transposition on `exprs` and propagate the
279  // transposition on the pack using outerDimsPerm.
280  if (!outerDimsPerm.empty()) {
281  SmallVector<AffineExpr> auxVec = exprs;
282  for (const auto &en : enumerate(outerDimsPerm))
283  auxVec[en.index()] = exprs[en.value()];
284  exprs = auxVec;
285  }
286  }
287  auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext());
288 
289  // The operand does not have any dimension that relates to the pack op.
290  if (innerDimsPos.empty() && outerDimsPerm.empty())
291  return std::make_tuple(opOperand->get(), indexingMap);
292 
293  auto empty = linalg::PackOp::createDestinationTensor(
294  b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
295  auto packedOperand = b.create<linalg::PackOp>(
296  loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
297  /*padding=*/std::nullopt, outerDimsPerm);
298  return std::make_tuple(packedOperand, indexingMap);
299 }
300 
301 /// Pack a genericOp and return it.
302 static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
303  Value dest, AffineMap packedOutIndexingMap,
304  const PackInfo &packInfo) {
305  Location loc = genericOp.getLoc();
306  SmallVector<Value> inputOperands;
307  SmallVector<AffineMap> indexingMaps;
308  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
309  auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
310  rewriter, loc, packInfo, genericOp, inputOperand);
311  inputOperands.push_back(packedOperand);
312  indexingMaps.push_back(packedIndexingMap);
313  }
314 
315  int64_t numInnerLoops = packInfo.getNumTiledLoops();
316  SmallVector<utils::IteratorType> iterTypes =
317  genericOp.getIteratorTypesArray();
318  iterTypes.append(numInnerLoops, utils::IteratorType::parallel);
319 
320  indexingMaps.push_back(packedOutIndexingMap);
321 
322  auto newGenericOp = rewriter.create<linalg::GenericOp>(
323  loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes,
324  /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
325  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
326  newGenericOp.getRegion().begin());
327  return newGenericOp;
328 }
329 
330 /// Bubbles up linalg.pack op through a producer generic op. This
331 /// swaps pack(generic) to generic(pack). The new generic op works on the packed
332 /// domain; pack ops are created for input and output operands. E.g.,
333 ///
334 /// #map0 = affine_map<(d0, d1) -> (d0, d1)>
335 /// %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
336 /// %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
337 /// %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
338 /// %3 = linalg.generic {indexing_maps = [#map0, #map0],
339 /// iterator_types = ["parallel", "parallel"]}
340 /// ins(%arg0 : tensor<?x?xf32>)
341 /// outs(%2 : tensor<?x?xf32>) {
342 /// ^bb0(%arg3: f32, %arg4: f32):
343 /// %4 = arith.addf %arg3, %arg3 : f32
344 /// linalg.yield %4 : f32
345 /// } -> tensor<?x?xf32>
346 /// %4 = linalg.pack %3
347 /// inner_dims_pos = [0, 1]
348 /// inner_tiles = [8, 2]
349 /// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
350 ///
351 /// will be converted to
352 ///
353 /// #map = affine_map<()[s0] -> (s0 ceildiv 8)>
354 /// #map1 = affine_map<()[s0] -> (s0 ceildiv 2)>
355 /// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
356 /// %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
357 /// %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
358 /// %0 = affine.apply #map()[%dim]
359 /// %1 = affine.apply #map1()[%dim_0]
360 /// %2 = tensor.empty(%0, %1) : tensor<?x?x8x2xf32>
361 /// %pack = linalg.pack %arg0
362 /// inner_dims_pos = [0, 1]
363 /// inner_tiles = [8, 2]
364 /// into %2 : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
365 /// %3 = linalg.generic {indexing_maps = [#map2, #map2],
366 /// iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
367 /// ins(%pack : tensor<?x?x8x2xf32>)
368 /// outs(%arg1 : tensor<?x?x8x2xf32>) {
369 /// ^bb0(%in: f32, %out: f32):
370 /// %4 = arith.addf %in, %in : f32
371 /// linalg.yield %4 : f32
372 /// } -> tensor<?x?x8x2xf32>
373 static FailureOr<GenericOp>
374 bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
375  const ControlPropagationFn &controlFn) {
376  auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
377  if (!genericOp)
378  return failure();
379 
380  // User controlled propagation function.
381  if (!controlFn(&packOp.getSourceMutable()))
382  return failure();
383 
384  // TODO: Enable propagation in the presence of linalg.index and
385  // tensor.extract, likely as a separate pattern as the pack information and
386  // propagation decision needs to be inferred from the region of the generic.
387  if (hasGatherSemantics(genericOp))
388  return failure();
389 
390  // TODO: Relax the restriction. We are able to bubble up the pack op through
391  // multi-result generic op. It just needs more work.
392  if (genericOp.getNumResults() != 1)
393  return failure();
394 
395  // Bail-out if the result of the generic has multiple uses, as bubbling up
396  // creates recomputation if the generic has multiple users.
397  // TODO: Enable the case where every use is an identical pack op as no
398  // recomputation is needed in that case.
399  if (!genericOp->getResult(0).hasOneUse())
400  return failure();
401 
402  // TODO: Add an option for allowing padding values. It could introduce
403  // undefined behavior if we unconditionally propagate pack op through all
404  // the ops. E.g., if the padding value is zero and there are division ops in
405  // a generic op. Some values of padding area could be NaN (0/0).
406  if (packOp.getPaddingValue())
407  return failure();
408 
409  OpOperand *opOperand = genericOp.getDpsInitOperand(0);
410  auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
411  if (failed(packInfo))
412  return failure();
413 
414  // We want to move the pack, not the generic.
415  OpBuilder::InsertionGuard guard(rewriter);
416  rewriter.setInsertionPoint(genericOp);
417 
418  // We need to handle two cases:
419  // 1) The linalg.pack destination is a tensor.empty. If this is the case, we
420  // create a new tensor.empty to avoid breaking dominance, as we are moving the
421  // linalg.pack above the linalg.generic.
422  // 2) The destination is not a tensor.empty. In this case we can replace only
423  // if the destination of the linalg.pack dominates the linalg.generic.
424  Value packOpDest = packOp.getDest();
425  if (!packOpDest.hasOneUse())
426  return failure();
427  if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
428  packOpDest = rewriter.create<tensor::EmptyOp>(
429  genericOp->getLoc(), emptyOp.getMixedSizes(),
430  emptyOp.getType().getElementType());
431  } else {
432  DominanceInfo dom(genericOp);
433  if (!dom.properlyDominates(packOpDest, genericOp))
434  return failure();
435  }
436 
437  // Rebuild the indexing map for the corresponding init operand.
438  auto [packedOutOperand, packedOutIndexingMap] =
439  getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
440  genericOp, opOperand);
441 
442  // If the dps init operand of the generic is a tensor.empty, forward the
443  // pack op destination.
444  Value dest = packedOutOperand;
445  if (auto initTensor = genericOp.getDpsInitOperand(0)
446  ->get()
447  .getDefiningOp<tensor::EmptyOp>()) {
448  dest = packOpDest;
449  }
450  return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
451  *packInfo);
452 }
453 
454 /// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
455 struct BubbleUpPackOpThroughGenericOpPattern
456  : public OpRewritePattern<linalg::PackOp> {
457 public:
458  BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
459  ControlPropagationFn fun)
460  : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}
461 
462  LogicalResult matchAndRewrite(linalg::PackOp packOp,
463  PatternRewriter &rewriter) const override {
464  auto genericOp =
465  bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
466  if (failed(genericOp))
467  return failure();
468  rewriter.replaceOp(packOp, genericOp->getResults());
469  return success();
470  }
471 
472 private:
473  ControlPropagationFn controlFn;
474 };
475 
476 /// Propagate a linalg.pack operation up through a tensor.pad. The idea is to
477 /// append as many zero padding dimensions to `high` and `low` as there are
478 /// point loops, as sketched below.
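/// A sketch of the intended rewrite (editorial addition; the shapes and the
/// padding body are illustrative assumptions):
///
///   %padded = tensor.pad %arg0 low[1, 0] high[1, 0] { ... }
///       : tensor<6x64xf32> to tensor<8x64xf32>
///   %pack = linalg.pack %padded inner_dims_pos = [1] inner_tiles = [32]
///       into %dest : tensor<8x64xf32> -> tensor<8x2x32xf32>
///
/// becomes
///
///   %pack = linalg.pack %arg0 inner_dims_pos = [1] inner_tiles = [32]
///       into %empty : tensor<6x64xf32> -> tensor<6x2x32xf32>
///   %padded = tensor.pad %pack low[1, 0, 0] high[1, 0, 0] { ... }
///       : tensor<6x2x32xf32> to tensor<8x2x32xf32>
///
/// The padded dimension (d0) is not tiled, so zeros are simply appended to the
/// low/high padding for the new inner tile dimension.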
479 class BubbleUpPackThroughPadOp final : public OpRewritePattern<linalg::PackOp> {
480 public:
481  BubbleUpPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
482  : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}
483 
484  LogicalResult matchAndRewrite(linalg::PackOp packOp,
485  PatternRewriter &rewriter) const override {
486  auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
487  if (!padOp)
488  return failure();
489 
490  // User controlled propagation function.
491  if (!controlFn(&packOp.getSourceMutable()))
492  return failure();
493 
494  // TODO: Enable padding when the padding values are the same.
495  if (packOp.getPaddingValue())
496  return failure();
497 
498  // Fail for non-constant padding values. The body of the pad could
499  // depend on the padding indices and/or properties of the padded
500  // tensor, so for now we fail.
501  // TODO: Support non-constant padding values.
502  Value paddingVal = padOp.getConstantPaddingValue();
503  if (!paddingVal)
504  return failure();
505 
506  if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
507  return failure();
508 
509  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
510 
511  // Bail out if one of the padded dimensions is a tiled one.
512  llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
513  llvm::SmallBitVector innerDims(paddedDims.size());
514  for (int64_t dim : innerDimsPos)
515  innerDims.flip(dim);
516  if (paddedDims.anyCommon(innerDims))
517  return failure();
518 
519  Location loc = padOp->getLoc();
520  OpBuilder::InsertionGuard guard(rewriter);
521  rewriter.setInsertionPoint(padOp);
522 
523  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
524  SmallVector<OpFoldResult> mixedTiles = packOp.getMixedTiles();
525  auto empty = linalg::PackOp::createDestinationTensor(
526  rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos,
527  outerDimsPerm);
528  auto sourcePack = rewriter.create<linalg::PackOp>(
529  loc, padOp.getSource(), empty, innerDimsPos, mixedTiles,
530  /*padding=*/std::nullopt, outerDimsPerm);
531 
532  // If we have `outer_dims_perm`, we need to adjust the padded dimensions.
533  SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
534  SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
535  if (!outerDimsPerm.empty()) {
536  applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
537  applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
538  }
539  // The tiled dimensions were verified to be unpadded above, so here we
540  // just append 0 for the inner tile dimensions.
541  size_t pointLoopsSize = innerDimsPos.size();
542  lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
543  highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
544 
545  auto newPadOp = rewriter.create<tensor::PadOp>(
546  loc, /*result=*/Type(), sourcePack, lowPad, highPad, paddingVal,
547  padOp.getNofold());
548 
549  // If the pad has more than one user, create an unpack on the new pad to
550  // replace the other uses.
551  if (!padOp->hasOneUse()) {
552  auto unpackEmpty = linalg::UnPackOp::createDestinationTensor(
553  rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm);
554  Value unpackedPad = rewriter.create<linalg::UnPackOp>(
555  loc, newPadOp, unpackEmpty, innerDimsPos, mixedTiles, outerDimsPerm);
556  rewriter.replaceAllUsesExcept(padOp, unpackedPad, sourcePack);
557  }
558 
559  // Replace the pack with the new pad.
560  rewriter.replaceOp(packOp, newPadOp.getResult());
561 
562  return success();
563  }
564 
565 private:
566  ControlPropagationFn controlFn;
567 };
568 
569 /// Project dimsPos to the inner-most non-unit dim pos with reassocIndices.
570 ///
571 /// For example, given dimsPos [0, 2], reassocIndices [[0, 1], [2, 3]], and
572 /// targetShape [16, 16, 32, 1], it returns [1, 2]. This is because for pos 0,
573 /// the inner-most non-unit projected dim in pos [0, 1] is 1, and for pos 2, the
574 /// inner-most non-unit projected dim in pos [2, 3] is 2.
575 ///
576 /// If all candidates in a reassociation are unit dims, it chooses the
577 /// inner-most dim pos.
578 static SmallVector<int64_t>
579 projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
580  ArrayRef<ReassociationIndices> reassocIndices,
581  ArrayRef<int64_t> targetShape) {
582  SmallVector<int64_t> projectedDimsPos;
583  for (auto pos : dimsPos) {
584  // In the case all dims are unit, this will return the inner-most one.
585  int64_t projectedPos = reassocIndices[pos].back();
586  for (auto i : llvm::reverse(reassocIndices[pos])) {
587  int64_t dim = targetShape[i];
588  if (dim > 1 || ShapedType::isDynamic(dim)) {
589  projectedPos = i;
590  break;
591  }
592  }
593  projectedDimsPos.push_back(projectedPos);
594  }
595  return projectedDimsPos;
596 }
597 
598 /// Check if all dims in dimsPos are divisible by the corresponding tile sizes.
599 static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
600  ArrayRef<int64_t> shape,
601  ArrayRef<int64_t> tileSizes) {
602  for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
603  int64_t dim = shape[pos];
604  if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
605  return false;
606  }
607  return true;
608 }
609 
610 /// Permute the reassociation indices and reindex them in sequence order.
611 /// Returns the next dim pos in the sequence.
612 ///
613 /// For example, given reassocIndices [[0, 1], [2]] and permutation [1, 0], it
614 /// applies the permutation to get [[2], [0, 1]] and reindexes the indices into
615 /// [[0], [1, 2]].
616 static int64_t applyPermutationAndReindexReassoc(
617  SmallVector<ReassociationIndices> &reassocIndices,
618  ArrayRef<int64_t> permutation) {
619  if (!permutation.empty())
620  applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
621  int64_t nextPos = 0;
622  for (ReassociationIndices &indices : reassocIndices) {
623  for (auto &index : indices) {
624  index = nextPos;
625  nextPos += 1;
626  }
627  }
628  return nextPos;
629 }
630 
631 /// Bubble up pack op through collapse shape op when the packed dims can be
632 /// projected to the dims before collapsing. This is possible when the inner
633 /// tile sizes can divide the projected dims.
634 ///
635 /// For example:
636 ///
637 /// %collapsed = tensor.collapse_shape %in [[0, 1], [2]]
638 /// : tensor<?x16x4xf32> into tensor<?x4xf32>
639 /// %pack = linalg.pack %collapsed outer_dims_perm = [0, 1]
640 /// inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty
641 /// : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
642 ///
643 /// can be transformed into:
644 ///
645 /// %pack = linalg.pack %in outer_dims_perm = [0, 1, 2]
646 /// inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty
647 /// : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
648 /// %collapsed = tensor.collapse_shape %pack [[0, 1], [2], [3], [4]]
649 /// : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
650 static LogicalResult
651 bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
652  linalg::PackOp packOp,
653  PatternRewriter &rewriter) {
654  SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
655  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
656  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
657 
658  ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
659  SmallVector<ReassociationIndices> reassocIndices =
660  collapseOp.getReassociationIndices();
661  // Project inner tile pos to the dim pos before collapsing. For example, if
662  // dims [x, y] are collapsed into [z], packing on dim z can be projected back
663  // to pack on dim y.
664  //
665  // Project to inner-most non-unit dims to increase the chance that they can be
666  // divided by the inner tile sizes. This is correct because for [..., x, 1],
667  // packing on dim 1 is equivalent to packing on dim x.
668  SmallVector<int64_t> projectedInnerDimsPos =
669  projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);
670 
671  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
672  innerTileSizes)) {
673  return failure();
674  }
675  // Expand the outer dims permutation with the associated source dims for the
676  // new permutation after bubbling. This is because moving a collapsed dim is
677  // equivalent to moving the associated source dims together.
678  SmallVector<int64_t> newOuterDimsPerm;
679  for (auto outerPos : outerDimsPerm)
680  llvm::append_range(newOuterDimsPerm, reassocIndices[outerPos]);
681 
682  auto emptyOp = linalg::PackOp::createDestinationTensor(
683  rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
684  projectedInnerDimsPos, newOuterDimsPerm);
685  auto newPackOp = rewriter.create<linalg::PackOp>(
686  packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos,
687  packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);
688 
689  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
690  // First apply the permutation on the reassociations of the outer dims.
691  // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
692  // -> [[0], [1, 2]]
693  int64_t nextPos =
694  applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
695  // Then add direct mapping for the inner tile dims.
696  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
697  newReassocIndices.push_back({nextPos});
698  nextPos += 1;
699  }
700 
701  auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
702  collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices);
703  rewriter.replaceOp(packOp, newCollapseOp);
704 
705  return success();
706 }
707 
708 /// Project dimsPos to their collapsed positions in the reassocIndices.
709 ///
710 /// For example, given dimsPos [0, 1, 2, 4], and matching reassocIndices
711 /// [[0], [1, 2], [3], [4]], it returns [0, 1, 1, 3]. Because for pos 0,
712 /// the reassoc dim [0] is 0. For pos 1 and 2, the reassoc dim in pos
713 /// [1, 2] is 1. And for pos 4, the reassoc dim [4] is 3.
714 static SmallVector<int64_t>
715 projectDimsPosIntoReassocPos(ArrayRef<int64_t> dimsPos,
716  ArrayRef<ReassociationIndices> reassocIndices) {
717  SmallVector<int64_t> projectedPos;
718 
719  // Map each dimension to the position of corresponding reassociation index.
720  for (auto pos : dimsPos) {
721  for (auto [idx, indices] : llvm::enumerate(reassocIndices)) {
722  // If the dimension is present in the current indices group, the group
723  // position within the reassociation map is the desired projected
724  // dimension position.
725  if (llvm::is_contained(indices, pos)) {
726  projectedPos.push_back(idx);
727  break;
728  }
729  }
730  }
731  assert(projectedPos.size() == dimsPos.size() && "Invalid dim pos projection");
732 
733  return projectedPos;
734 }
735 
736 /// Bubble up pack op through expand shape op.
737 ///
738 /// For example:
739 ///
740 /// %expand = tensor.expand_shape %in [[0], [1, 2]]
741 /// : tensor<?x64xf32> into tensor<?x4x16xf32>
742 /// %pack = linalg.pack %expand outer_dims_perm = [0, 1, 2]
743 /// inner_dims_pos = [2] inner_tiles = [8] into %empty
744 /// : tensor<?x4x16xf32> -> tensor<?x4x2x8xf32>
745 ///
746 /// can be transformed into:
747 ///
748 /// %pack = linalg.pack %in
749 /// inner_dims_pos = [1] inner_tiles = [8] into %empty
750 /// : tensor<?x64xf32> -> tensor<?x8x8xf32>
751 /// %expand = tensor.expand_shape %pack [[0], [1, 2], [3]]
752 /// : tensor<?x8x8xf32> into tensor<?x4x2x8xf32>
753 static LogicalResult
754 bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
755  linalg::PackOp packOp,
756  PatternRewriter &rewriter) {
757  // Outer dimensions permutation is not supported currently.
758  // TODO: Handle outer_dims_perm variants.
759  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
760  if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
761  return rewriter.notifyMatchFailure(packOp,
762  "non-identity outer dims perm NYI");
763  }
764 
765  // Validate dimensions' relations between shape expansion and packing.
766  SmallVector<ReassociationIndices> reassoc =
767  expandOp.getReassociationIndices();
768  ArrayRef<int64_t> packInnerDims = packOp.getInnerDimsPos();
769  llvm::SetVector<int64_t> packDimsPos(llvm::from_range, packInnerDims);
770 
771  for (auto [idx, indices] : llvm::enumerate(reassoc)) {
772  // For each expand_shape reassociation, figure out which dimensions get
773  // packed if any.
774  llvm::SetVector<int64_t> expandDimPos(llvm::from_range, indices);
775  llvm::SetVector<int64_t> packedDims =
776  llvm::set_intersection(packDimsPos, expandDimPos);
777 
778  // The expanded dimension is not packed, so it does not affect moving the
779  // pack before shape expansion - simply continue.
780  if (packedDims.empty())
781  continue;
782  // Shape expansion cannot be propagated when multiple expanded dimensions are
783  // packed - in this case operation reordering would affect final element
784  // positions and/or shapes can no longer be projected.
785  if (packedDims.size() != 1)
786  return rewriter.notifyMatchFailure(
787  packOp, "only one of the expanded dimensions can be packed");
788  // Only the inner-most expanded dimension should be packed. Otherwise,
789  // element order will be affected after operation reordering.
790  if (packedDims.front() != indices.back())
791  return rewriter.notifyMatchFailure(
792  packOp, "can only pack the inner-most expanded dimension");
793  }
794 
795  // Project pack.inner_dims_pos to positions before shape expansion.
796  SmallVector<int64_t> projectedInnerDimsPos =
797  projectDimsPosIntoReassocPos(packInnerDims, reassoc);
798 
799  // Project the shape expansion to new packed shape.
800  // The pack.outer_dims_perm is restricted to identity, so the permutation can
801  // be omitted for simplicity.
802  // TODO: Account for outer dimensions permutation.
803  //
804  // If reassociation is not possible, then reordering cannot happen.
805  // This can be caused by pack padding affecting previously expanded
806  // dimensions or packing extending dimensions.
807  RankedTensorType newPackType = linalg::PackOp::inferPackedType(
808  expandOp.getSrcType(), packOp.getStaticInnerTiles(),
809  projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
810  auto reassocExpand =
811  getReassociationIndicesForReshape(newPackType, packOp.getDestType());
812  if (!reassocExpand)
813  return rewriter.notifyMatchFailure(
814  packOp, "could not reassociate dims after bubbling up");
815 
816  Value destTensor = linalg::PackOp::createDestinationTensor(
817  rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(),
818  projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
819  Value packedVal = rewriter.create<linalg::PackOp>(
820  packOp.getLoc(), expandOp.getSrc(), destTensor, projectedInnerDimsPos,
821  packOp.getMixedTiles(), packOp.getPaddingValue(),
822  /*outerDimsPerm=*/SmallVector<int64_t>{});
823 
824  Value newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
825  packOp.getLoc(), packOp.getDestType(), packedVal, *reassocExpand);
826  rewriter.replaceOp(packOp, newExpandOp);
827 
828  return success();
829 }
830 
831 class BubbleUpPackOpThroughReshapeOp final
832  : public OpRewritePattern<linalg::PackOp> {
833 public:
834  BubbleUpPackOpThroughReshapeOp(MLIRContext *context, ControlPropagationFn fun)
835  : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}
836 
837  LogicalResult matchAndRewrite(linalg::PackOp packOp,
838  PatternRewriter &rewriter) const override {
839  Operation *srcOp = packOp.getSource().getDefiningOp();
840  // Currently, only support the case where the pack op is the only user.
841  if (!srcOp || !(srcOp->getNumResults() == 1) ||
842  !srcOp->getResult(0).hasOneUse()) {
843  return failure();
844  }
845  // Currently only support static inner tile sizes.
846  if (llvm::any_of(packOp.getStaticTiles(), [](int64_t size) {
847  return ShapedType::isDynamic(size);
848  })) {
849  return failure();
850  }
851 
852  // User controlled propagation function.
853  if (!controlFn(&packOp.getSourceMutable()))
854  return failure();
855 
856  return TypeSwitch<Operation *, LogicalResult>(srcOp)
857  .Case([&](tensor::CollapseShapeOp op) {
858  return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
859  })
860  .Case([&](tensor::ExpandShapeOp op) {
861  return bubbleUpPackOpThroughExpandShape(op, packOp, rewriter);
862  })
863  .Default([](Operation *) { return failure(); });
864  }
865 
866 private:
867  ControlPropagationFn controlFn;
868 };
869 
870 /// Push down unpack op through expand shape op when the packed dims can be
871 /// projected to the dims after expanding. This is possible when the inner tile
872 /// sizes can divide the projected dims.
873 ///
874 /// For example:
875 ///
876 /// %unpack = linalg.unpack %in outer_dims_perm = [0, 1]
877 /// inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %empty
878 /// : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
879 /// %expanded = tensor.expand_shape %unpack [[0, 1], [2]]
880 /// : tensor<?x256xf32> into tensor<?x256x256xf32>
881 ///
882 /// can be transformed into:
883 ///
884 /// %expanded = tensor.expand_shape %in [[0, 1], [2], [3], [4]]
885 /// : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
886 /// %unpack = linalg.unpack %expanded outer_dims_perm = [0, 1, 2]
887 /// inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty
888 /// : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
889 static LogicalResult pushDownUnPackOpThroughExpandShape(
890  linalg::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
891  PatternRewriter &rewriter, ControlPropagationFn controlFn) {
892  // User controlled propagation function.
893  if (!controlFn(&expandOp.getSrcMutable()))
894  return failure();
895 
896  SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
897  ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
898  ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
899 
900  auto expandTy = dyn_cast<RankedTensorType>(expandOp.getType());
901  if (!expandTy)
902  return failure();
903  ArrayRef<int64_t> dstShape = expandTy.getShape();
904  SmallVector<ReassociationIndices> reassocIndices =
905  expandOp.getReassociationIndices();
906  // Project inner tile pos to the dim pos after expanding. For example, if dim
907  // [z] is expanded into [x, y], unpacking on dim z can be projected to unpack
908  // on dim y.
909  //
910  // Project to inner-most non-unit dims to increase the chance that they can be
911  // divided by the inner tile sizes. This is correct because for [..., x, 1],
912  // unpacking on dim 1 is equivalent to unpacking on dim x.
913  SmallVector<int64_t> projectedInnerDimsPos =
914  projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);
915 
916  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
917  innerTileSizes)) {
918  return failure();
919  }
920  // Expand the outer dims permutation with the associated expanded dims for the
921  // new permutation after pushing. This is because moving a source dim is
922  // equivalent to moving the associated expanded dims together.
923  SmallVector<int64_t> newOuterDimsPerm;
924  for (auto outerPos : outerDimsPerm)
925  llvm::append_range(newOuterDimsPerm, reassocIndices[outerPos]);
926 
927  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
928  // First apply the permutation on the reassociations of the outer dims.
929  // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
930  // -> [[0], [1, 2]]
931  int64_t nextPos =
932  applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
933  // Then add direct mapping for the inner tile dims.
934  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
935  newReassocIndices.push_back({nextPos});
936  nextPos += 1;
937  }
938 
939  RankedTensorType newExpandType = linalg::PackOp::inferPackedType(
940  expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm);
941  auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
942  expandOp.getLoc(), newExpandType, unPackOp.getSource(),
943  newReassocIndices);
944 
945  auto emptyOp = linalg::UnPackOp::createDestinationTensor(
946  rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
947  projectedInnerDimsPos, newOuterDimsPerm);
948  auto newUnPackOp = rewriter.create<linalg::UnPackOp>(
949  unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
950  projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
951  rewriter.replaceOp(expandOp, newUnPackOp);
952 
953  return success();
954 }
955 
956 class PushDownUnPackOpThroughReshapeOp final
957  : public OpRewritePattern<linalg::UnPackOp> {
958 public:
959  PushDownUnPackOpThroughReshapeOp(MLIRContext *context,
960  ControlPropagationFn fun)
961  : OpRewritePattern<linalg::UnPackOp>(context), controlFn(std::move(fun)) {
962  }
963 
964  LogicalResult matchAndRewrite(linalg::UnPackOp unPackOp,
965  PatternRewriter &rewriter) const override {
966  Value result = unPackOp.getResult();
967  // Currently, only support an unpack op with a single user.
968  if (!result.hasOneUse()) {
969  return failure();
970  }
971  // Currently only support static inner tile sizes.
972  if (llvm::any_of(unPackOp.getStaticTiles(), [](int64_t size) {
973  return ShapedType::isDynamic(size);
974  })) {
975  return failure();
976  }
977 
978  Operation *consumerOp = *result.user_begin();
979  return TypeSwitch<Operation *, LogicalResult>(consumerOp)
980  .Case([&](tensor::ExpandShapeOp op) {
981  return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter,
982  controlFn);
983  })
984  .Default([](Operation *) { return failure(); });
985  }
986 
987 private:
988  ControlPropagationFn controlFn;
989 };
990 
991 // TODO: Relax this restriction. We should unpack a generic op also
992 // in the presence of multiple unpack ops as producers.
993 /// Return the unpacked operand, if present, for the current generic op.
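/// For example (illustrative only), given
///   %u = linalg.unpack %src ... into %d : ... -> tensor<?x?xf32>
///   %r = linalg.generic ... ins(%u, %other : ...) outs(%init : ...) ...
/// the operand holding %u is returned. Failure is returned if no operand is
/// produced by a linalg.unpack, or if more than one such operand exists.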
994 static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
995  OpOperand *unPackedOperand = nullptr;
996  for (OpOperand &operand : genericOp->getOpOperands()) {
997  auto unPackOp = operand.get().getDefiningOp<linalg::UnPackOp>();
998  if (!unPackOp)
999  continue;
1000  if (unPackedOperand)
1001  return failure();
1002  unPackedOperand = &operand;
1003  }
1004  if (!unPackedOperand)
1005  return failure();
1006  return unPackedOperand;
1007 }
1008 
1009 /// Push down a linalg.unpack op through a generic op.
1010 /// The new generic op works on the packed domain; pack ops are created for input
1011 /// and output operands. A linalg.unpack op is inserted right after the packed
1012 /// generic. E.g.
1013 ///
1014 /// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
1015 ///
1016 /// %arg0 = tensor<12x2x56x56x32xf32> // packed arg.
1017 ///
1018 /// %0 = tensor.empty() : tensor<12x56x56x64xf32>
1019 /// %1 = linalg.unpack %arg0 outer_dims_perm = [0, 3, 1, 2]
1020 /// inner_dims_pos = [3] inner_tiles = [32] into %0
1021 /// %2 = linalg.generic {indexing_maps = [#map],
1022 /// iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
1023 /// outs(%1 : tensor<12x56x56x64xf32>) {
1024 /// ^bb0(%out : f32):
1025 /// linalg.yield %out : f32
1026 /// } -> tensor<12x56x56x64xf32>
1027 ///
1028 /// will be converted to
1029 ///
1030 /// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
1031 ///
1032 /// %0 = tensor.empty() : tensor<12x56x56x64xf32>
1033 /// %1 = linalg.generic {indexing_maps = [#map],
1034 /// iterator_types = ["parallel", "parallel", "parallel",
1035 /// "parallel", "parallel"]}
1036 /// outs(%arg0 : tensor<12x2x56x56x32xf32>) {
1037 /// ^bb0(%out : f32):
1038 /// linalg.yield %out : f32
1039 /// } -> tensor<12x2x56x56x32xf32>
1040 /// %2 = linalg.unpack %1 outer_dims_perm = [0, 3, 1, 2]
1041 /// inner_dims_pos = [3] inner_tiles = [32] into %0
1042 ///
1043 static FailureOr<std::tuple<GenericOp, Value>>
1044 pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
1045  ControlPropagationFn controlFn) {
1046  if (genericOp.getNumResults() != 1)
1047  return failure();
1048 
1049  if (hasGatherSemantics(genericOp))
1050  return failure();
1051 
1052  // Collect the unPacked operand, if present.
1053  auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
1054  if (failed(maybeUnPackedOperand))
1055  return failure();
1056  OpOperand *unPackedOperand = *(maybeUnPackedOperand);
1057 
1058  // Extract packing information.
1059  linalg::UnPackOp producerUnPackOp =
1060  unPackedOperand->get().getDefiningOp<linalg::UnPackOp>();
1061  assert(producerUnPackOp && "expect a valid UnPackOp");
1062 
1063  if (!controlFn(unPackedOperand))
1064  return failure();
1065 
1066  auto packInfo =
1067  getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
1068  if (failed(packInfo))
1069  return failure();
1070 
1071  // Rebuild the indexing map for the corresponding init operand.
1072  auto [packedOutOperand, packedOutIndexingMap] =
1073  getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
1074  genericOp, genericOp.getDpsInitOperand(0));
1075  auto destPack = packedOutOperand.getDefiningOp<linalg::PackOp>();
1076 
1077  // If the dps init operand of the generic is a tensor.empty, do not pack it
1078  // and forward the new tensor.empty as a destination.
1079  Value dest = packedOutOperand;
1080  if (auto initTensor = genericOp.getDpsInitOperand(0)
1081  ->get()
1082  .getDefiningOp<tensor::EmptyOp>()) {
1083  if (destPack)
1084  dest = destPack.getDest();
1085  }
1086 
1087  // Pack the genericOp.
1088  GenericOp newGenericOp =
1089  packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo);
1090  Value newResult =
1091  newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));
1092 
1093  // If the output is unaffected, no need to unpack.
1094  if (!destPack)
1095  return std::make_tuple(newGenericOp, newResult);
1096 
1097  auto mixedTiles = destPack.getMixedTiles();
1098  auto innerDimsPos = destPack.getInnerDimsPos();
1099  auto outerDimsPerm = destPack.getOuterDimsPerm();
1100 
1101  // Insert an unPackOp right after the packed generic.
1102  Value unPackOpRes =
1103  rewriter
1104  .create<linalg::UnPackOp>(genericOp.getLoc(), newResult,
1105  destPack.getSource(), innerDimsPos,
1106  mixedTiles, outerDimsPerm)
1107  .getResult();
1108 
1109  return std::make_tuple(newGenericOp, unPackOpRes);
1110 }
1111 
1112 // Wrapper pattern that applies pushDownUnPackOpThroughGenericOp method.
1113 struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
1114 public:
1115  PushDownUnPackOpThroughGenericOp(MLIRContext *context,
1116  ControlPropagationFn fun)
1117  : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
1118 
1119  LogicalResult matchAndRewrite(GenericOp genericOp,
1120  PatternRewriter &rewriter) const override {
1121  auto genericAndRepl =
1122  pushDownUnPackOpThroughGenericOp(rewriter, genericOp, controlFn);
1123  if (failed(genericAndRepl))
1124  return failure();
1125  rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
1126  return success();
1127  }
1128 
1129 private:
1130  ControlPropagationFn controlFn;
1131 };
1132 
1133 /// Propagate a linalg.unpack operation through a tensor.pad. The idea is to
1134 /// append as many zero padding dimensions to `high` and `low` as there are
1135 /// point loops, as sketched below.
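/// A sketch of the intended rewrite (editorial addition; the shapes and the
/// padding body are illustrative assumptions):
///
///   %unpack = linalg.unpack %arg0 inner_dims_pos = [1] inner_tiles = [32]
///       into %empty : tensor<6x2x32xf32> -> tensor<6x64xf32>
///   %padded = tensor.pad %unpack low[1, 0] high[1, 0] { ... }
///       : tensor<6x64xf32> to tensor<8x64xf32>
///
/// becomes
///
///   %padded = tensor.pad %arg0 low[1, 0, 0] high[1, 0, 0] { ... }
///       : tensor<6x2x32xf32> to tensor<8x2x32xf32>
///   %unpack = linalg.unpack %padded inner_dims_pos = [1] inner_tiles = [32]
///       into %out : tensor<8x2x32xf32> -> tensor<8x64xf32>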
1136 struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
1137  PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
1138  : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {}
1139 
1140  LogicalResult matchAndRewrite(tensor::PadOp padOp,
1141  PatternRewriter &rewriter) const override {
1142  linalg::UnPackOp unpackOp =
1143  padOp.getSource().getDefiningOp<linalg::UnPackOp>();
1144  if (!unpackOp)
1145  return failure();
1146 
1147  if (!controlFn(&padOp.getSourceMutable()))
1148  return failure();
1149 
1150  Location loc = padOp.getLoc();
1151  // Bail out if one of the padded dimensions is a tiled one.
1152  llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
1153  ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
1154  llvm::SmallBitVector innerDims(paddedDims.size());
1155  for (int64_t dim : innerDimsPos)
1156  innerDims.flip(dim);
1157  if (paddedDims.anyCommon(innerDims))
1158  return failure();
1159 
1160  Value paddingVal = padOp.getConstantPaddingValue();
1161  if (!paddingVal)
1162  return failure();
1163 
1164  // If we have `outer_dims_perm`, we need to adjust the padded dimensions.
1165  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
1166  SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
1167  SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
1168  if (!outerDimsPerm.empty()) {
1169  applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
1170  applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
1171  }
1172  // Add zero padding for the point loops.
1173  size_t pointLoopsSize = innerDimsPos.size();
1174  lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
1175  highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
1176 
1177  auto newPadOp = rewriter.create<tensor::PadOp>(
1178  loc, /*result=*/Type(), unpackOp.getSource(), lowPad, highPad,
1179  paddingVal, padOp.getNofold());
1180 
1181  // Inject the linalg.unpack right after the newly created (packed) tensor.pad.
1182  Value outputUnPack = rewriter.create<tensor::EmptyOp>(
1183  loc, padOp.getResultType().getShape(),
1184  padOp.getResultType().getElementType());
1185 
1186  Value replacement = rewriter.create<linalg::UnPackOp>(
1187  loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
1188  unpackOp.getMixedTiles(), outerDimsPerm);
1189  rewriter.replaceOp(padOp, replacement);
1190  return success();
1191  }
1192 
1193 private:
1194  ControlPropagationFn controlFn;
1195 };
1196 
1197 } // namespace
1198 
1199 void mlir::linalg::populateDataLayoutPropagationPatterns(
1200  RewritePatternSet &patterns,
1201  const ControlPropagationFn &controlPackUnPackPropagation) {
1202  patterns
1203  .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
1204  BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
1205  PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
1206  patterns.getContext(), controlPackUnPackPropagation);
1207 }
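// A minimal usage sketch (editorial addition, not part of the original file).
// It shows how these patterns are typically driven; the enclosing pass, the
// root op `rootOp`, and the greedy driver entry point are assumptions and may
// differ across MLIR versions (e.g. applyPatternsGreedily vs.
// applyPatternsAndFoldGreedily):
//
//   RewritePatternSet patterns(rootOp->getContext());
//   linalg::populateDataLayoutPropagationPatterns(
//       patterns, /*controlPackUnPackPropagation=*/[](OpOperand *opOperand) {
//         return true; // Propagate through every candidate operand.
//       });
//   if (failed(applyPatternsAndFoldGreedily(rootOp, std::move(patterns))))
//     signalPassFailure();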