DataLayoutPropagation.cpp
1 //===- DataLayoutPropagation.cpp -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
18 #include "mlir/IR/Dominance.h"
20 #include "llvm/ADT/SetOperations.h"
21 #include "llvm/ADT/SetVector.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/Debug.h"
24 #include <optional>
25 
26 namespace mlir {
27 #define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
28 #include "mlir/Dialect/Linalg/Passes.h.inc"
29 } // namespace mlir
30 
31 using namespace mlir;
32 using namespace mlir::linalg;
33 
34 #define DEBUG_TYPE "linalg-data-layout-propagation"
35 
36 namespace {
37 
38 static bool hasGatherSemantics(linalg::GenericOp genericOp) {
39  for (Operation &op : genericOp.getBody()->getOperations())
40  if (isa<tensor::ExtractOp, linalg::IndexOp>(op))
41  return true;
42  return false;
43 }
44 
45 // This struct contains the information needed to map packing metadata onto
46 // the iteration domain of Linalg ops; see the illustrative example after it.
47 struct PackInfo {
48  int64_t getNumTiledLoops() const { return tileToPointMapping.size(); };
49  // InnerDimsPos on iteration domain, which follows the order in pack ops.
50  SmallVector<int64_t> tiledDimsPos;
51  // The mapping from a dimension of the iteration domain to its tile size.
52  llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping;
53  // The mapping from a dimension of iteration domain to the corresponding inner
54  // tiling dimension on iteration domain.
55  llvm::DenseMap<int64_t, int64_t> tileToPointMapping;
56  // The permutation of outer dims (on domain).
57  SmallVector<int64_t> outerDimsOnDomainPerm;
58 };
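// As a rough illustration (a hypothetical example, not taken from a test): for
// a generic op with iteration domain (d0, d1) whose result map is
// (d0, d1) -> (d0, d1) and whose result is packed with
//   inner_dims_pos = [0, 1], inner_tiles = [8, 2]
// getPackingInfoFromOperand below would populate PackInfo roughly as:
//   tiledDimsPos            = [0, 1]
//   domainDimAndTileMapping = {d0 -> 8, d1 -> 2}
//   tileToPointMapping      = {d0 -> d2, d1 -> d3}  // d2, d3 are new inner loops
//   outerDimsOnDomainPerm   = []                    // no outer_dims_perm given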
59 
60 template <typename OpTy>
61 static FailureOr<PackInfo>
62 getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
63  OpTy packOrUnPackOp) {
64  static_assert(llvm::is_one_of<OpTy, linalg::PackOp, linalg::UnPackOp>::value,
65  "applies to only pack or unpack operations");
66  LLVM_DEBUG(
67  { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });
68 
69  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
70  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
71   SmallVector<utils::IteratorType> iterators =
72       genericOp.getIteratorTypesArray();
73 
74  PackInfo packInfo;
75  int64_t origNumDims = indexingMap.getNumDims();
76  SmallVector<AffineExpr> exprs(indexingMap.getResults());
77  ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos();
78  for (auto [index, innerDimPos, tileSize] :
79  llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
80  innerDimsPos, packOrUnPackOp.getMixedTiles())) {
81  auto expr = exprs[innerDimPos];
82  if (!isa<AffineDimExpr>(expr))
83  return failure();
84  int64_t domainDimPos =
85  cast<AffineDimExpr>(exprs[innerDimPos]).getPosition();
86  if (!isParallelIterator(iterators[domainDimPos]))
87  return failure();
88  packInfo.tiledDimsPos.push_back(domainDimPos);
89  packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
90  packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
91  LLVM_DEBUG({
92  llvm::dbgs() << "map innerDimPos=" << innerDimPos
93  << " to iteration dimension (d" << domainDimPos << ", d"
94  << packInfo.tileToPointMapping[domainDimPos]
95  << "), which has size=("
96  << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
97  });
98  }
99 
100  // Bail out if a tiled dimension is present in a map but not as an affine dim
101  // expression.
102  auto areAllAffineDimExpr = [&](int dim) {
103  for (AffineMap map : indexingMaps) {
104  if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) {
105  return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr);
106  })) {
107  return false;
108  }
109  }
110  return true;
111  };
112  for (int64_t i : packInfo.tiledDimsPos)
113  if (!areAllAffineDimExpr(i))
114  return failure();
115 
116  // Get the outer dims perm on the iteration domain. Start by identifying the
117  // set of domain dims affected by the outer permutation along with the
118  // permuted ordering for those dims. Then the full outer dims permutation can
119  // be constructed by replacing the affected dims with the permuted result in a
120  // numLoops-rank identity. e.g.
121  // outerDimsPerm = [1, 2, 0]
122  // indexingMap = (d0, d1, d2, d3, d4) -> (d1, d4, d3)
123  //
124  // permutedOuterDims = [4, 3, 1]
125  // outerDimsOnDomainPerm = [0, 4, 2, 3, 1]
126  //
127  // Non-affine dim expressions must not be permuted by the outer dims
128  // permutation.
129  SmallVector<int64_t> permutedOuterDims;
130  for (auto [index, dim] : llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) {
131  auto permutedExpr = indexingMap.getResult(dim);
132  if (auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) {
133  permutedOuterDims.push_back(dimExpr.getPosition());
134  continue;
135  }
136 
137  // TODO: Allow propagation with transposes on non affine dim expressions,
138  // e.g. d0 + d1 which implies transposing both dims simultaneously while
139  // maintaining the relative position between them.
140  if (static_cast<int64_t>(index) != dim)
141  return failure();
142  }
143  if (!permutedOuterDims.empty()) {
144  int64_t outerDimIndex = 0;
145  llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(),
146  permutedOuterDims.end());
147  for (int i = 0, e = indexingMap.getNumDims(); i < e; i++)
148  packInfo.outerDimsOnDomainPerm.push_back(
149  permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++]
150  : i);
151  LLVM_DEBUG({
152   llvm::dbgs() << "map outer dims perm to ";
153  for (auto dim : packInfo.outerDimsOnDomainPerm)
154  llvm::dbgs() << dim << " ";
155  llvm::dbgs() << "\n";
156  });
157  }
158 
159  return packInfo;
160 }
161 
162 static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
163  ArrayRef<AffineExpr> exprs) {
164  // Compute `outer_dims_perm`. See example:
165  // current exprs : (d0, d1, d2, d3) -> (d2, d3)
166  // perm : [0, 3, 1, 2]
167  // First map d2, d3 with their position in the array as:
168  // currentPositionTileLoops: dim | pos
169  // d2 | 0
170  // d3 | 1
171  // then scan `perm` in order and get the `outer_dims_perm`
172  // to be used, here it would be [1, 0].
173  assert(!perm.empty() && "expect perm not to be empty");
174  assert(!exprs.empty() && "expect exprs not to be empty");
175  if (exprs.size() == 1)
176  return {};
177   SmallVector<int64_t> outerDimsPerm;
178   DenseMap<int64_t, int64_t> currentPositionTileLoops;
179  for (auto [pos, expr] : llvm::enumerate(exprs)) {
180  // Here we rely on the assumption that the outer dims permutation
181  // when propagating currently requires that non-affine dim expressions
182  // are not permuted, thus allowing the identity assignment below.
183  if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
184  currentPositionTileLoops[dimExpr.getPosition()] = pos;
185  else
186  currentPositionTileLoops[pos] = pos;
187  }
188  for (int64_t loopIdx : perm) {
189  if (currentPositionTileLoops.count(loopIdx))
190  outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx));
191  }
192  return outerDimsPerm;
193 }
194 
195 /// Returns a tuple of (packed operand, updated indexing map) with the assumptions:
196 /// 1) The generic op is the producer of the pack op.
197 /// 2) The generic op has only one result.
198 /// If the operand is a scalar or packing dimensions are all irrelevant to the
199 /// operand, the operand and the updated indexing map will be returned.
200 /// Otherwise, it returns the packed operand and the updated indexing map. E.g.,
201 ///
202 /// #map0 = affine_map<(d0, d1) -> (d0, d1)>
203 /// #map1 = affine_map<(d0, d1) -> (d0)>
204 /// #map2 = affine_map<(d0, d1) -> (d1)>
205 /// %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0],
206 /// iterator_types = ["parallel", "parallel"]}
207 /// ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
208 /// outs(%init : tensor<?x?xf32>) {
209 /// ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
210 /// %4 = arith.addf %arg3, %arg4 : f32
211 /// linalg.yield %4 : f32
212 /// } -> tensor<?x?xf32>
213 /// %1 = linalg.pack %0
214 /// inner_dims_pos = [0, 1]
215 /// inner_tiles = [8, 2]
216 /// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
217 ///
218 /// Taking the first input operand as an example, the inner tile size of d0 is
219 /// 8. Thus, the operation below and the indexing map
220 /// `affine_map<(d0, d1, d2, d3) -> (d0, d2)>` will be returned.
221 ///
222 /// %pack = linalg.pack %arg0
223 /// inner_dims_pos = [0]
224 /// inner_tiles = [8]
225 /// into %init : tensor<?xf32> -> tensor<?x8xf32>
226 static std::tuple<Value, AffineMap>
227 getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
228  GenericOp genericOp, OpOperand *opOperand) {
229  int64_t numOrigLoops = genericOp.getNumLoops();
230  int64_t numInnerLoops = packInfo.getNumTiledLoops();
231  int64_t numLoops = numOrigLoops + numInnerLoops;
232  AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
233  llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
234  SmallVector<AffineExpr> exprs(origIndexingMap.getResults());
235 
236  // If the OpOperand is a scalar or a zero-rank tensor, no need to pack.
237  if (genericOp.isScalar(opOperand) || exprs.empty())
238  return std::make_tuple(opOperand->get(),
239  AffineMap::get(numLoops, 0, exprs, b.getContext()));
240 
241  // Step 1. Construct the information of packing data dimensions; append inner
242  // dimensions to the indexing maps for the operand.
243  for (auto [index, expr] : llvm::enumerate(exprs)) {
244  if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
245  int64_t dimPos = dimExpr.getPosition();
246  domainDimToOperandDim[dimPos] = index;
247  continue;
248  }
249  }
250   SmallVector<int64_t> innerDimsPos;
251   SmallVector<OpFoldResult> innerTileSizes;
252  for (auto dimPos : packInfo.tiledDimsPos) {
253  if (!domainDimToOperandDim.count(dimPos))
254  continue;
255  int64_t index = domainDimToOperandDim[dimPos];
256  innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
257  innerDimsPos.push_back(index);
258  exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos]));
259  }
260 
261  // Step 2. Handle outer dim permutations.
262   SmallVector<int64_t> outerDimsPerm;
263   if (!packInfo.outerDimsOnDomainPerm.empty()) {
264  outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs);
265 
266  // Step 2.1: Fold transpose into the linalg.generic.
267  SmallVector<int64_t> inversedOuterPerm =
268  invertPermutationVector(packInfo.outerDimsOnDomainPerm);
269  for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
270  if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) {
271  int64_t dimPos = dimExpr.getPosition();
272  exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
273  continue;
274  }
275  assert(isa<AffineConstantExpr>(exprs[i]) &&
276  "Attempted to permute non-constant and non-affine dim expression");
277  }
278  // Step 2.2: Undo the transposition on `exprs` and propagate the
279  // transposition on the pack using outerDimsPerm.
280  if (!outerDimsPerm.empty()) {
281  SmallVector<AffineExpr> auxVec = exprs;
282  for (const auto &en : enumerate(outerDimsPerm))
283  auxVec[en.index()] = exprs[en.value()];
284  exprs = auxVec;
285  }
286  }
287  auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext());
288 
289   // The operand does not have dimensions that relate to the pack op.
290  if (innerDimsPos.empty() && outerDimsPerm.empty())
291  return std::make_tuple(opOperand->get(), indexingMap);
292 
293  auto empty = linalg::PackOp::createDestinationTensor(
294  b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
295  auto packedOperand = b.create<linalg::PackOp>(
296  loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
297  /*padding=*/std::nullopt, outerDimsPerm);
298  return std::make_tuple(packedOperand, indexingMap);
299 }
300 
301 /// Helper subroutine that packs a genericOp and returns the packed op. It
302 /// creates a new generic op with the packed operands and the packed output
303 /// according to packInfo when we attempt to push down an unpack or bubble up a
304 /// pack around it. Implicitly, this only works when a packInfo can be
305 /// obtained, which ensures that we only use this function on parallel,
306 /// permuted dimensions.
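/// As a sketch of the `isFoldableUnpackPack` case (the value names %src, %d0
/// and %d1 are made up for illustration), an input operand chain such as
///
///   %u = linalg.unpack %src inner_dims_pos = [0] inner_tiles = [8]
///          into %d0 : tensor<?x8xf32> -> tensor<?xf32>
///   %p = linalg.pack %u inner_dims_pos = [0] inner_tiles = [8]
///          into %d1 : tensor<?xf32> -> tensor<?x8xf32>
///
/// has matching outer_dims_perm, inner_dims_pos and inner_tiles, so the new
/// packed generic can read directly from %src instead of %p.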
307 static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
308  Value dest, AffineMap packedOutIndexingMap,
309  const PackInfo &packInfo,
310  bool isFoldableUnpackPack) {
311  Location loc = genericOp.getLoc();
312  SmallVector<Value> inputOperands;
313  SmallVector<Value> inputOperandsFromUnpackedSource;
314  SmallVector<AffineMap> indexingMaps;
315  auto hasEquivalentTiles = [](PackOp packOp, UnPackOp unPackOp) {
316  return packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm() &&
317  packOp.getInnerDimsPos() == unPackOp.getInnerDimsPos() &&
318  llvm::equal(packOp.getMixedTiles(), unPackOp.getMixedTiles());
319  };
320  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
321  auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
322  rewriter, loc, packInfo, genericOp, inputOperand);
323  auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>();
324  auto packOp = packedOperand.getDefiningOp<linalg::PackOp>();
325  if (packOp && unpackOp && hasEquivalentTiles(packOp, unpackOp)) {
326  inputOperandsFromUnpackedSource.push_back(unpackOp.getSource());
327  } else {
328  inputOperandsFromUnpackedSource.push_back(packedOperand);
329  }
330  inputOperands.push_back(packedOperand);
331  indexingMaps.push_back(packedIndexingMap);
332  }
333 
334   // If the unpack->pack sequences can be folded, use the sources of the
335   // unpack ops in any unpack->pack chains on the generic op operands.
336  if (isFoldableUnpackPack) {
337  inputOperands = inputOperandsFromUnpackedSource;
338  if (auto destPack = dest.getDefiningOp<linalg::PackOp>()) {
339  auto destUnPack = destPack.getSource().getDefiningOp<linalg::UnPackOp>();
340  if (destUnPack && hasEquivalentTiles(destPack, destUnPack)) {
341  dest = destUnPack.getSource();
342  }
343  }
344  }
345 
346  int64_t numInnerLoops = packInfo.getNumTiledLoops();
347   SmallVector<utils::IteratorType> iterTypes =
348       genericOp.getIteratorTypesArray();
349  iterTypes.append(numInnerLoops, utils::IteratorType::parallel);
350 
351  indexingMaps.push_back(packedOutIndexingMap);
352 
353  auto newGenericOp = rewriter.create<linalg::GenericOp>(
354  loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes,
355  /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
356  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
357  newGenericOp.getRegion().begin());
358  return newGenericOp;
359 }
360 
361 /// Bubbles up a linalg.pack op through a producer generic op. This swaps
362 /// pack(generic) to generic(pack). The new generic op works on the packed
363 /// domain; pack ops are created for the input and output operands. E.g.,
364 ///
365 /// #map0 = affine_map<(d0, d1) -> (d0, d1)>
366 /// %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
367 /// %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
368 /// %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
369 /// %3 = linalg.generic {indexing_maps = [#map0, #map0],
370 /// iterator_types = ["parallel", "parallel"]}
371 /// ins(%arg0 : tensor<?x?xf32>)
372 /// outs(%2 : tensor<?x?xf32>) {
373 /// ^bb0(%arg3: f32, %arg4: f32):
374 /// %4 = arith.addf %arg3, %arg3 : f32
375 /// linalg.yield %4 : f32
376 /// } -> tensor<?x?xf32>
377 /// %4 = linalg.pack %3
378 /// inner_dims_pos = [0, 1]
379 /// inner_tiles = [8, 2]
380 /// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
381 ///
382 /// will be converted to
383 ///
384 /// #map = affine_map<()[s0] -> (s0 ceildiv 8)>
385 /// #map1 = affine_map<()[s0] -> (s0 ceildiv 2)>
386 /// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
387 /// %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
388 /// %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
389 /// %0 = affine.apply #map()[%dim]
390 /// %1 = affine.apply #map1()[%dim_0]
391 /// %2 = tensor.empty(%0, %1) : tensor<?x?x8x2xf32>
392 /// %pack = linalg.pack %arg0
393 /// inner_dims_pos = [0, 1]
394 /// inner_tiles = [8, 2]
395 /// into %2 : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
396 /// %3 = linalg.generic {indexing_maps = [#map2, #map2],
397 /// iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
398 /// ins(%pack : tensor<?x?x8x2xf32>)
399 /// outs(%arg1 : tensor<?x?x8x2xf32>) {
400 /// ^bb0(%in: f32, %out: f32):
401 /// %4 = arith.addf %in, %in : f32
402 /// linalg.yield %4 : f32
403 /// } -> tensor<?x?x8x2xf32>
404 static FailureOr<GenericOp>
405 bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
406  const ControlPropagationFn &controlFn) {
407  auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
408  if (!genericOp)
409  return failure();
410 
411  // User controlled propagation function.
412  if (!controlFn(&packOp.getSourceMutable()))
413  return failure();
414 
415  // TODO: Enable propagation in the presence of linalg.index and
416  // tensor.extract, likely as a separate pattern as the pack information and
417  // propagation decision needs to be inferred from the region of the generic.
418  if (hasGatherSemantics(genericOp))
419  return failure();
420 
421  // TODO: Relax the restriction. We are able to bubble up the pack op through
422  // multi-result generic op. It just needs more work.
423  if (genericOp.getNumResults() != 1)
424  return failure();
425 
426  // Bail-out if the result of the generic has multiple uses, as bubbling up
427  // creates recomputation if the generic has multiple users.
428  // TODO: Enable the case where every use is an identical pack op as no
429  // recomputation is needed in that case.
430  if (!genericOp->getResult(0).hasOneUse())
431  return failure();
432 
433  // TODO: Add an option for allowing padding values. It could introduce
434  // undefined behavior if we unconditionally propagate pack op through all
435  // the ops. E.g., if the padding value is zero and there are division ops in
436  // a generic op. Some values of padding area could be NaN (0/0).
437  if (packOp.getPaddingValue())
438  return failure();
439 
440  OpOperand *opOperand = genericOp.getDpsInitOperand(0);
441  auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
442  if (failed(packInfo))
443  return failure();
444 
445   // We want to move the pack, not the generic.
446  OpBuilder::InsertionGuard guard(rewriter);
447  rewriter.setInsertionPoint(genericOp);
448 
449  // We need to handle two cases:
450  // 1) The linalg.pack destination is a tensor.empty. If this is the case, we
451  // create a new tensor.empty to avoid breaking dominance, as we are moving the
452  // linalg.pack above the linalg.generic.
453  // 2) The destination is not a tensor.empty. In this case we can replace only
454  // if the destination of the linalg.pack dominates the linalg.generic.
455  Value packOpDest = packOp.getDest();
456  if (!packOpDest.hasOneUse())
457  return failure();
458  if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
459  packOpDest = rewriter.create<tensor::EmptyOp>(
460  genericOp->getLoc(), emptyOp.getMixedSizes(),
461  emptyOp.getType().getElementType());
462  } else {
463  DominanceInfo dom(genericOp);
464  if (!dom.properlyDominates(packOpDest, genericOp))
465  return failure();
466  }
467 
468  // Rebuild the indexing map for the corresponding init operand.
469  auto [packedOutOperand, packedOutIndexingMap] =
470  getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
471  genericOp, opOperand);
472 
473   // If the dps init operand of the generic is a tensor.empty, forward the pack
474  // op destination.
475  Value dest = packedOutOperand;
476  if (auto initTensor = genericOp.getDpsInitOperand(0)
477  ->get()
478  .getDefiningOp<tensor::EmptyOp>()) {
479  dest = packOpDest;
480  }
481  // pack(unpack) isn't naively foldable because the unpack op can be from
482  // an arbitrary domain so we need to keep both.
483  return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
484  *packInfo, /*isFoldableUnpackPack=*/false);
485 }
486 
487 /// Wrapper pattern that applies the bubbleUpPackOpThroughGenericOp method.
488 struct BubbleUpPackOpThroughGenericOpPattern
489  : public OpRewritePattern<linalg::PackOp> {
490 public:
491  BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
492                                         ControlPropagationFn fun)
493       : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}
494 
495  LogicalResult matchAndRewrite(linalg::PackOp packOp,
496  PatternRewriter &rewriter) const override {
497  auto genericOp =
498  bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
499  if (failed(genericOp))
500  return failure();
501  rewriter.replaceOp(packOp, genericOp->getResults());
502  return success();
503  }
504 
505 private:
506  ControlPropagationFn controlFn;
507 };
508 
509 /// Propagate a linalg.pack operation up through a tensor.pad. The idea is to
510 /// append as many zero-padding entries to `high` and `low` as there are point
511 /// loops introduced by the pack.
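/// As a sketch (the values %arg0, %dest, %empty, %cst and the sizes are made
/// up for illustration), a pair like
///
///   %padded = tensor.pad %arg0 low[0, 0] high[0, 5] {
///   ^bb0(%i: index, %j: index):
///     tensor.yield %cst : f32
///   } : tensor<?x?xf32> to tensor<?x?xf32>
///   %pack = linalg.pack %padded inner_dims_pos = [0] inner_tiles = [8]
///       into %dest : tensor<?x?xf32> -> tensor<?x?x8xf32>
///
/// would become a pack of the unpadded source followed by a pad with zero
/// entries appended for the point loop:
///
///   %pack = linalg.pack %arg0 inner_dims_pos = [0] inner_tiles = [8]
///       into %empty : tensor<?x?xf32> -> tensor<?x?x8xf32>
///   %padded = tensor.pad %pack low[0, 0, 0] high[0, 5, 0] { ... }
///       : tensor<?x?x8xf32> to tensor<?x?x8xf32>
///
/// This is only legal here because the padded dim (d1) is not a tiled dim.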
512 class BubbleUpPackThroughPadOp final : public OpRewritePattern<linalg::PackOp> {
513 public:
514  BubbleUpPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
515  : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}
516 
517  LogicalResult matchAndRewrite(linalg::PackOp packOp,
518  PatternRewriter &rewriter) const override {
519  auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
520  if (!padOp)
521  return failure();
522 
523  // User controlled propagation function.
524  if (!controlFn(&packOp.getSourceMutable()))
525  return failure();
526 
527  // TODO: Enable padding when the padding values are the same.
528  if (packOp.getPaddingValue())
529  return failure();
530 
531  // Fail for non-constant padding values. The body of the pad could
532  // depend on the padding indices and/or properties of the padded
533  // tensor so for now we fail.
534  // TODO: Support non-constant padding values.
535  Value paddingVal = padOp.getConstantPaddingValue();
536  if (!paddingVal)
537  return failure();
538 
539  if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
540  return failure();
541 
542  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
543 
544   // Bail out if one of the padded dimensions is a tiled one.
545  llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
546  llvm::SmallBitVector innerDims(paddedDims.size());
547  for (int64_t dim : innerDimsPos)
548  innerDims.flip(dim);
549  if (paddedDims.anyCommon(innerDims))
550  return failure();
551 
552  Location loc = padOp->getLoc();
553  OpBuilder::InsertionGuard guard(rewriter);
554  rewriter.setInsertionPoint(padOp);
555 
556  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
557  SmallVector<OpFoldResult> mixedTiles = packOp.getMixedTiles();
558  auto empty = linalg::PackOp::createDestinationTensor(
559  rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos,
560  outerDimsPerm);
561  auto sourcePack = rewriter.create<linalg::PackOp>(
562  loc, padOp.getSource(), empty, innerDimsPos, mixedTiles,
563  /*padding=*/std::nullopt, outerDimsPerm);
564 
565   // If we have `outer_dims_perm` we need to adjust the padded dimensions.
566  SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
567  SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
568  if (!outerDimsPerm.empty()) {
569  applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
570  applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
571  }
572  // The tiled dimensions were verified to be unpadded above, so here we
573  // just append 0 for the inner tile dimensions.
574  size_t pointLoopsSize = innerDimsPos.size();
575  lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
576  highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
577 
578  auto newPadOp = rewriter.create<tensor::PadOp>(
579  loc, /*result=*/Type(), sourcePack, lowPad, highPad, paddingVal,
580  padOp.getNofold());
581 
582  // If the pad has more than one user, create an unpack on the new pad to
583  // replace the other uses.
584  if (!padOp->hasOneUse()) {
585  auto unpackEmpty = linalg::UnPackOp::createDestinationTensor(
586  rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm);
587  Value unpackedPad = rewriter.create<linalg::UnPackOp>(
588  loc, newPadOp, unpackEmpty, innerDimsPos, mixedTiles, outerDimsPerm);
589  rewriter.replaceAllUsesExcept(padOp, unpackedPad, sourcePack);
590  }
591 
592  // Replace the pack with the new pad.
593  rewriter.replaceOp(packOp, newPadOp.getResult());
594 
595  return success();
596  }
597 
598 private:
599  ControlPropagationFn controlFn;
600 };
601 
602 /// Project dimsPos to the inner-most non-unit dim pos with reassocIndices.
603 ///
604 /// For example, given dimsPos [0, 1], reassocIndices [[0, 1], [2, 3]], and
605 /// targetShape [16, 16, 32, 1], it returns [1, 2]: for pos 0, the inner-most
606 /// projected dim among [0, 1] is 1; for pos 1, the inner-most non-unit
607 /// projected dim among [2, 3] is 2.
608 ///
609 /// If all candidates in a reassociation are unit dims, it chooses the
610 /// inner-most dim pos.
611 static SmallVector<int64_t>
612 projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
613  ArrayRef<ReassociationIndices> reassocIndices,
614  ArrayRef<int64_t> targetShape) {
615  SmallVector<int64_t> projectedDimsPos;
616  for (auto pos : dimsPos) {
617  // In the case all dims are unit, this will return the inner-most one.
618  int64_t projectedPos = reassocIndices[pos].back();
619  for (auto i : llvm::reverse(reassocIndices[pos])) {
620  int64_t dim = targetShape[i];
621  if (dim > 1 || ShapedType::isDynamic(dim)) {
622  projectedPos = i;
623  break;
624  }
625  }
626  projectedDimsPos.push_back(projectedPos);
627  }
628  return projectedDimsPos;
629 }
630 
631 /// Check if all dims in dimsPos are divisible by the corresponding tile sizes.
632 static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
633  ArrayRef<int64_t> shape,
634  ArrayRef<int64_t> tileSizes) {
635  for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
636  int64_t dim = shape[pos];
637  if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
638  return false;
639  }
640  return true;
641 }
642 
643 /// Permute the reassociation indices and reindex them in sequence order.
644 /// Returns the next dim pos in the sequence.
645 ///
646 /// For example, given reassocIndices [[0, 1], [2]] and permutation [1, 0], it
647 /// applies the permutation to get [[2], [0, 1]] and reindexes the indices into
648 /// [[0], [1, 2]].
649 static int64_t applyPermutationAndReindexReassoc(
650  SmallVector<ReassociationIndices> &reassocIndices,
651  ArrayRef<int64_t> permutation) {
652  if (!permutation.empty())
653  applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
654  int64_t nextPos = 0;
655  for (ReassociationIndices &indices : reassocIndices) {
656  for (auto &index : indices) {
657  index = nextPos;
658  nextPos += 1;
659  }
660  }
661  return nextPos;
662 }
663 
664 /// Bubble up pack op through collapse shape op when the packed dims can be
665 /// projected to the dims before collapsing. This is possible when the inner
666 /// tile sizes can divide the projected dims.
667 ///
668 /// For example:
669 ///
670 /// %collapsed = tensor.collapse_shape %in [[0, 1], 2]
671 /// : tensor<?x16x4xf32> into tensor<?x4xf32>
672 /// %pack = linalg.pack %collapsed outer_dims_perm = [0, 1]
673 /// inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty
674 /// : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
675 ///
676 /// can be transformed into:
677 ///
678 /// %pack = linalg.pack %in outer_dims_perm = [0, 1, 2]
679 /// inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty
680 /// : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
681 /// %collapsed = tensor.collapse_shape %pack [[0, 1], 2, 3, 4]
682 ///   : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
683 static LogicalResult
684 bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
685  linalg::PackOp packOp,
686  PatternRewriter &rewriter) {
687  SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
688  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
689  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
690 
691  ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
692  SmallVector<ReassociationIndices> reassocIndices =
693  collapseOp.getReassociationIndices();
694  // Project inner tile pos to the dim pos before collapsing. For example, if
695   // dims [x, y] are collapsed into [z], packing on dim z can be projected back
696  // to pack on dim y.
697  //
698  // Project to inner-most non-unit dims to increase the chance that they can be
699  // divided by the inner tile sizes. This is correct because for [..., x, 1],
700  // packing on dim 1 is equivalent to packing on dim x.
701  SmallVector<int64_t> projectedInnerDimsPos =
702  projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);
703 
704  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
705  innerTileSizes)) {
706  return failure();
707  }
708  // Expand the outer dims permutation with the associated source dims for the
709  // new permutation after bubbling. This is because moving a collapsed dim is
710  // equivalent to moving the associated source dims together.
711  SmallVector<int64_t> newOuterDimsPerm;
712  for (auto outerPos : outerDimsPerm)
713  llvm::append_range(newOuterDimsPerm, reassocIndices[outerPos]);
714 
715  auto emptyOp = linalg::PackOp::createDestinationTensor(
716  rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
717  projectedInnerDimsPos, newOuterDimsPerm);
718  auto newPackOp = rewriter.create<linalg::PackOp>(
719  packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos,
720  packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);
721 
722  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
723  // First apply the permutation on the reassociations of the outer dims.
724  // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
725  // -> [[0], [1, 2]]
726  int64_t nextPos =
727  applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
728  // Then add direct mapping for the inner tile dims.
729  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
730  newReassocIndices.push_back({nextPos});
731  nextPos += 1;
732  }
733 
734  auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
735  collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices);
736  rewriter.replaceOp(packOp, newCollapseOp);
737 
738  return success();
739 }
740 
741 /// Project dimsPos to their collapsed positions in the reassocIndices.
742 ///
743 /// For example, given dimsPos [0, 1, 2, 4], and matching reassocIndices
744 /// [[0], [1, 2], [3], [4]], it returns [0, 1, 1, 3]. Because for pos 0,
745 /// the reassoc dim [0] is 0. For pos 1 and 2, the reassoc dim in pos
746 /// [1, 2] is 1. And for pos 4, the reassoc dim [4] is 3.
747 static SmallVector<int64_t>
748 projectDimsPosIntoReassocPos(ArrayRef<int64_t> dimsPos,
749  ArrayRef<ReassociationIndices> reassocIndices) {
750  SmallVector<int64_t> projectedPos;
751 
752  // Map each dimension to the position of corresponding reassociation index.
753  for (auto pos : dimsPos) {
754  for (auto [idx, indices] : llvm::enumerate(reassocIndices)) {
755  // If the dimension is present in the current indices group, the group
756  // position within the reassociation map is the desired projected
757  // dimension position.
758  if (llvm::is_contained(indices, pos)) {
759  projectedPos.push_back(idx);
760  break;
761  }
762  }
763  }
764  assert(projectedPos.size() == dimsPos.size() && "Invalid dim pos projection");
765 
766  return projectedPos;
767 }
768 
769 /// Bubble up pack op through expand shape op.
770 ///
771 /// For example:
772 ///
773 /// %expand = tensor.expand_shape %in [[0], [1, 2]]
774 /// : tensor<?x64xf32> into tensor<?x4x16xf32>
775 /// %pack = linalg.pack %expand outer_dims_perm = [0, 1]
776 /// inner_dims_pos = [2] inner_tiles = [8] into %empty
777 /// : tensor<?x4x16xf32> -> tensor<?x4x2x8xf32>
778 ///
779 /// can be transformed into:
780 ///
781 /// %pack = linalg.pack %in
782 /// inner_dims_pos = [1] inner_tiles = [8] into %empty
783 /// : tensor<?x64xf32> -> tensor<?x8x8xf32>
784 /// %expand = tensor.expand_shape %pack [[0], [1, 2], [3]]
785 /// : tensor<?x8x8xf32> into tensor<?x4x2x8xf32>
786 static LogicalResult
787 bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
788  linalg::PackOp packOp,
789  PatternRewriter &rewriter) {
790  // Outer dimensions permutation is not supported currently.
791  // TODO: Handle outer_dims_perm variants.
792  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
793   if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
794     return rewriter.notifyMatchFailure(packOp,
795  "non-identity outer dims perm NYI");
796  }
797 
798  // Validate dimensions' relations between shape expansion and packing.
799   SmallVector<ReassociationIndices> reassoc =
800       expandOp.getReassociationIndices();
801  ArrayRef<int64_t> packInnerDims = packOp.getInnerDimsPos();
802  llvm::SetVector<int64_t> packDimsPos(llvm::from_range, packInnerDims);
803 
804  for (auto [idx, indices] : llvm::enumerate(reassoc)) {
805  // For each expand_shape reassociation, figure out which dimensions get
806  // packed if any.
807  llvm::SetVector<int64_t> expandDimPos(llvm::from_range, indices);
808  llvm::SetVector<int64_t> packedDims =
809  llvm::set_intersection(packDimsPos, expandDimPos);
810 
811   // The expanded dimension is not packed, so it does not affect moving the
812   // pack before the shape expansion - simply continue.
813  if (packedDims.empty())
814  continue;
815   // Shape expansion cannot be propagated when multiple expanded dimensions are
816  // packed - in this case operation reordering would affect final element
817  // positions and/or shapes can no longer be projected.
818  if (packedDims.size() != 1)
819  return rewriter.notifyMatchFailure(
820  packOp, "only one of the expanded dimensions can be packed");
821  // Only the inner-most expanded dimension should be packed. Otherwise,
822   // the element order will be affected after operation reordering.
823  if (packedDims.front() != indices.back())
824  return rewriter.notifyMatchFailure(
825  packOp, "can only pack the inner-most expanded dimension");
826  }
827 
828  // Project pack.inner_dims_pos to positions before shape expansion.
829  SmallVector<int64_t> projectedInnerDimsPos =
830  projectDimsPosIntoReassocPos(packInnerDims, reassoc);
831 
832  // Project the shape expansion to new packed shape.
833   // The pack.outer_dims_perm is restricted to identity, so the permutation can
834  // be omitted for simplicity.
835  // TODO: Account for outer dimensions permutation.
836  //
837  // If reassociation is not possible, then reordering cannot happen.
838  // This can be caused by pack padding affecting previously expanded
839  // dimensions or packing extending dimensions.
840  RankedTensorType newPackType = linalg::PackOp::inferPackedType(
841  expandOp.getSrcType(), packOp.getStaticInnerTiles(),
842  projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
843  auto reassocExpand =
844  getReassociationIndicesForReshape(newPackType, packOp.getDestType());
845  if (!reassocExpand)
846  return rewriter.notifyMatchFailure(
847  packOp, "could not reassociate dims after bubbling up");
848 
849  Value destTensor = linalg::PackOp::createDestinationTensor(
850  rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(),
851  projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
852  Value packedVal = rewriter.create<linalg::PackOp>(
853  packOp.getLoc(), expandOp.getSrc(), destTensor, projectedInnerDimsPos,
854  packOp.getMixedTiles(), packOp.getPaddingValue(),
855  /*outerDimsPerm=*/SmallVector<int64_t>{});
856 
857  Value newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
858  packOp.getLoc(), packOp.getDestType(), packedVal, *reassocExpand);
859  rewriter.replaceOp(packOp, newExpandOp);
860 
861  return success();
862 }
863 
864 class BubbleUpPackOpThroughReshapeOp final
865  : public OpRewritePattern<linalg::PackOp> {
866 public:
867  BubbleUpPackOpThroughReshapeOp(MLIRContext *context, ControlPropagationFn fun)
868  : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}
869 
870  LogicalResult matchAndRewrite(linalg::PackOp packOp,
871  PatternRewriter &rewriter) const override {
872  Operation *srcOp = packOp.getSource().getDefiningOp();
873   // Currently this is only supported when the pack op is the only user.
874  if (!srcOp || !(srcOp->getNumResults() == 1) ||
875  !srcOp->getResult(0).hasOneUse()) {
876  return failure();
877  }
878  // Currently only support static inner tile sizes.
879  if (llvm::any_of(packOp.getStaticTiles(), ShapedType::isDynamic))
880  return failure();
881 
882  // User controlled propagation function.
883  if (!controlFn(&packOp.getSourceMutable()))
884  return failure();
885 
886     return TypeSwitch<Operation *, LogicalResult>(srcOp)
887         .Case([&](tensor::CollapseShapeOp op) {
888  return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
889  })
890  .Case([&](tensor::ExpandShapeOp op) {
891  return bubbleUpPackOpThroughExpandShape(op, packOp, rewriter);
892  })
893  .Default([](Operation *) { return failure(); });
894  }
895 
896 private:
897  ControlPropagationFn controlFn;
898 };
899 
900 /// Push down unpack op through expand shape op when the packed dims can be
901 /// projected to the dims after expanding. This is possible when the inner tile
902 /// sizes can divide the projected dims.
903 ///
904 /// For example:
905 ///
906 /// %unpack = linalg.unpack %in outer_dims_perm = [0, 1]
907 /// inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %empty
908 /// : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
909 /// %expanded = tensor.expand_shape %unpack [[0, 1], [2]]
910 /// : tensor<?x256xf32> into tensor<?x256x256xf32>
911 ///
912 /// can be transformed into:
913 ///
914 /// %expanded = tensor.expand_shape %in [[0, 1], [2], [3], [4]]
915 /// : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
916 /// %unpack = linalg.unpack %expanded outer_dims_perm = [0, 1, 2]
917 /// inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty
918 /// : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
919 static LogicalResult pushDownUnPackOpThroughExpandShape(
920  linalg::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
921  PatternRewriter &rewriter, ControlPropagationFn controlFn) {
922  // User controlled propagation function.
923  if (!controlFn(&expandOp.getSrcMutable()))
924  return failure();
925 
926  SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
927  ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
928  ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
929 
930  auto expandTy = dyn_cast<RankedTensorType>(expandOp.getType());
931  if (!expandTy)
932  return failure();
933  ArrayRef<int64_t> dstShape = expandTy.getShape();
934  SmallVector<ReassociationIndices> reassocIndices =
935  expandOp.getReassociationIndices();
936   // Project inner tile pos to the dim pos after expanding. For example, if dim
937   // z is expanded into [x, y], unpacking on dim z can be projected to unpack
938  // on dim y.
939  //
940  // Project to inner-most non-unit dims to increase the chance that they can be
941  // divided by the inner tile sizes. This is correct because for [..., x, 1],
942  // unpacking on dim 1 is equivalent to unpacking on dim x.
943  SmallVector<int64_t> projectedInnerDimsPos =
944  projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);
945 
946  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
947  innerTileSizes)) {
948  return failure();
949  }
950  // Expand the outer dims permutation with the associated expanded dims for the
951  // new permutation after pushing. This is because moving a source dim is
952  // equivalent to moving the associated expanded dims together.
953  SmallVector<int64_t> newOuterDimsPerm;
954  for (auto outerPos : outerDimsPerm)
955  llvm::append_range(newOuterDimsPerm, reassocIndices[outerPos]);
956 
957  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
958  // First apply the permutation on the reassociations of the outer dims.
959  // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
960  // -> [[0], [1, 2]]
961  int64_t nextPos =
962  applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
963  // Then add direct mapping for the inner tile dims.
964  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
965  newReassocIndices.push_back({nextPos});
966  nextPos += 1;
967  }
968 
969  RankedTensorType newExpandType = linalg::PackOp::inferPackedType(
970  expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm);
971  auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
972  expandOp.getLoc(), newExpandType, unPackOp.getSource(),
973  newReassocIndices);
974 
975  auto emptyOp = linalg::UnPackOp::createDestinationTensor(
976  rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
977  projectedInnerDimsPos, newOuterDimsPerm);
978  auto newUnPackOp = rewriter.create<linalg::UnPackOp>(
979  unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
980  projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
981  rewriter.replaceOp(expandOp, newUnPackOp);
982 
983  return success();
984 }
985 
986 class PushDownUnPackOpThroughReshapeOp final
987  : public OpRewritePattern<linalg::UnPackOp> {
988 public:
989  PushDownUnPackOpThroughReshapeOp(MLIRContext *context,
990                                    ControlPropagationFn fun)
991       : OpRewritePattern<linalg::UnPackOp>(context), controlFn(std::move(fun)) {
992  }
993 
994  LogicalResult matchAndRewrite(linalg::UnPackOp unPackOp,
995  PatternRewriter &rewriter) const override {
996  Value result = unPackOp.getResult();
997   // Currently only support an unpack op with a single user.
998  if (!result.hasOneUse()) {
999  return failure();
1000  }
1001  // Currently only support static inner tile sizes.
1002  if (llvm::any_of(unPackOp.getStaticTiles(), ShapedType::isDynamic))
1003  return failure();
1004 
1005  Operation *consumerOp = *result.user_begin();
1006  return TypeSwitch<Operation *, LogicalResult>(consumerOp)
1007  .Case([&](tensor::ExpandShapeOp op) {
1008  return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter,
1009  controlFn);
1010  })
1011  .Default([](Operation *) { return failure(); });
1012  }
1013 
1014 private:
1015  ControlPropagationFn controlFn;
1016 };
1017 
1018 // TODO: Relax this restriction. We should unpack a generic op also
1019 // in the presence of multiple unpack ops as producers.
1020 /// Return the unpacked operand, if present, for the current generic op.
1021 static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
1022  OpOperand *unPackedOperand = nullptr;
1023  for (OpOperand &operand : genericOp->getOpOperands()) {
1024  auto unPackOp = operand.get().getDefiningOp<linalg::UnPackOp>();
1025  if (!unPackOp)
1026  continue;
1027  if (unPackedOperand)
1028  return failure();
1029  unPackedOperand = &operand;
1030  }
1031  if (!unPackedOperand)
1032  return failure();
1033  return unPackedOperand;
1034 }
1035 
1036 /// Push down a linalg.unpack op through a generic op.
1037 /// The new generic op works on the packed domain; pack ops are created for the
1038 /// input and output operands. A linalg.unpack op is inserted right after the
1039 /// packed generic. E.g.
1040 ///
1041 /// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
1042 ///
1043 /// %arg0 = tensor<12x2x56x56x32xf32> // packed arg.
1044 ///
1045 /// %0 = tensor.empty() : tensor<12x56x56x64xf32>
1046 /// %1 = linalg.unpack %arg0 outer_dims_perm = [0, 3, 1, 2]
1047 /// inner_dims_pos = [3] inner_tiles = [32] into %0
1048 /// %2 = linalg.generic {indexing_maps = [#map],
1049 /// iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
1050 /// outs(%1 : tensor<12x56x56x64xf32>) {
1051 /// ^bb0(%out : f32):
1052 /// linalg.yield %out : f32
1053 /// } -> tensor<12x56x56x64xf32>
1054 ///
1055 /// will be converted to
1056 ///
1057 /// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
1058 ///
1059 /// %0 = tensor.empty() : tensor<12x56x56x64xf32>
1060 /// %1 = linalg.generic {indexing_maps = [#map],
1061 /// iterator_types = ["parallel", "parallel", "parallel",
1062 /// "parallel", "parallel"]}
1063 /// outs(%arg0 : tensor<12x2x56x56x32xf32>) {
1064 /// ^bb0(%out : f32):
1065 /// linalg.yield %out : f32
1066 /// } -> tensor<12x2x56x56x32xf32>
1067 /// %2 = linalg.unpack %1 outer_dims_perm = [0, 3, 1, 2]
1068 /// inner_dims_pos = [3] inner_tiles = [32] into %0
1069 ///
1070 static FailureOr<std::tuple<GenericOp, Value>>
1071 pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
1072  ControlPropagationFn controlFn) {
1073  if (genericOp.getNumResults() != 1)
1074  return failure();
1075 
1076  if (hasGatherSemantics(genericOp))
1077  return failure();
1078 
1079  // Collect the unPacked operand, if present.
1080  auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
1081  if (failed(maybeUnPackedOperand))
1082  return failure();
1083  OpOperand *unPackedOperand = *(maybeUnPackedOperand);
1084 
1085  // Extract packing information.
1086  linalg::UnPackOp producerUnPackOp =
1087  unPackedOperand->get().getDefiningOp<linalg::UnPackOp>();
1088  assert(producerUnPackOp && "expect a valid UnPackOp");
1089 
1090  if (!controlFn(unPackedOperand))
1091  return failure();
1092 
1093  auto packInfo =
1094  getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
1095  if (failed(packInfo))
1096  return failure();
1097 
1098  // Rebuild the indexing map for the corresponding init operand.
1099  auto [packedOutOperand, packedOutIndexingMap] =
1100  getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
1101  genericOp, genericOp.getDpsInitOperand(0));
1102  auto destPack = packedOutOperand.getDefiningOp<linalg::PackOp>();
1103 
1104  // If the dps init operand of the generic is a tensor.empty, do not pack it
1105  // and forward the new tensor.empty as a destination.
1106  Value dest = packedOutOperand;
1107  if (auto initTensor = genericOp.getDpsInitOperand(0)
1108  ->get()
1109  .getDefiningOp<tensor::EmptyOp>()) {
1110  if (destPack)
1111  dest = destPack.getDest();
1112  }
1113 
1114  // Pack the genericOp.
1115  // pack(unpack) is foldable in this case. This is because in pushing down the
1116  // unpack, by default we will populate an additional pack op after the unpack.
1117   // This guarantees that they are foldable.
1118  GenericOp newGenericOp =
1119  packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo,
1120  /*isFoldableUnpackPack=*/true);
1121  Value newResult =
1122  newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));
1123 
1124  // If the output is unaffected, no need to unpack.
1125  if (!destPack)
1126  return std::make_tuple(newGenericOp, newResult);
1127 
1128  auto mixedTiles = destPack.getMixedTiles();
1129  auto innerDimsPos = destPack.getInnerDimsPos();
1130  auto outerDimsPerm = destPack.getOuterDimsPerm();
1131 
1132  // Insert an unPackOp right after the packed generic.
1133  Value unPackOpRes =
1134  rewriter
1135  .create<linalg::UnPackOp>(genericOp.getLoc(), newResult,
1136  destPack.getSource(), innerDimsPos,
1137  mixedTiles, outerDimsPerm)
1138  .getResult();
1139 
1140  return std::make_tuple(newGenericOp, unPackOpRes);
1141 }
1142 
1143 // Wrapper pattern that applies the pushDownUnPackOpThroughGenericOp method.
1144 struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
1145 public:
1146  PushDownUnPackOpThroughGenericOp(MLIRContext *context,
1147                                    ControlPropagationFn fun)
1148       : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
1149 
1150  LogicalResult matchAndRewrite(GenericOp genericOp,
1151  PatternRewriter &rewriter) const override {
1152  auto genericAndRepl =
1153  pushDownUnPackOpThroughGenericOp(rewriter, genericOp, controlFn);
1154  if (failed(genericAndRepl))
1155  return failure();
1156  rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
1157  return success();
1158  }
1159 
1160 private:
1161  ControlPropagationFn controlFn;
1162 };
1163 
1164 /// Propagate a linalg.unpack operation through a tensor.pad. The idea is to
1165 /// append as many zero-padding entries to `high` and `low` as there are point
1166 /// loops introduced by the unpack.
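/// As a sketch (the values %arg0, %dest, %empty, %cst and the sizes are made
/// up for illustration), a pair like
///
///   %unpack = linalg.unpack %arg0 inner_dims_pos = [0] inner_tiles = [8]
///       into %dest : tensor<?x?x8xf32> -> tensor<?x?xf32>
///   %padded = tensor.pad %unpack low[0, 0] high[0, 5] {
///   ^bb0(%i: index, %j: index):
///     tensor.yield %cst : f32
///   } : tensor<?x?xf32> to tensor<?x?xf32>
///
/// would become a pad of the packed source followed by an unpack, with zero
/// entries appended to `low`/`high` for the point loop:
///
///   %padded = tensor.pad %arg0 low[0, 0, 0] high[0, 5, 0] { ... }
///       : tensor<?x?x8xf32> to tensor<?x?x8xf32>
///   %unpack = linalg.unpack %padded inner_dims_pos = [0] inner_tiles = [8]
///       into %empty : tensor<?x?x8xf32> -> tensor<?x?xf32>
///
/// Again, this only applies because the padded dim (d1) is not a tiled dim.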
1167 struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
1168  PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
1169  : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {}
1170 
1171  LogicalResult matchAndRewrite(tensor::PadOp padOp,
1172  PatternRewriter &rewriter) const override {
1173  linalg::UnPackOp unpackOp =
1174  padOp.getSource().getDefiningOp<linalg::UnPackOp>();
1175  if (!unpackOp)
1176  return failure();
1177 
1178  if (!controlFn(&padOp.getSourceMutable()))
1179  return failure();
1180 
1181  Location loc = padOp.getLoc();
1182   // Bail out if one of the padded dimensions is a tiled one.
1183  llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
1184  ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
1185  llvm::SmallBitVector innerDims(paddedDims.size());
1186  for (int64_t dim : innerDimsPos)
1187  innerDims.flip(dim);
1188  if (paddedDims.anyCommon(innerDims))
1189  return failure();
1190 
1191  Value paddingVal = padOp.getConstantPaddingValue();
1192  if (!paddingVal)
1193  return failure();
1194 
1195   // If we have `outer_dims_perm` we need to adjust the padded dimensions.
1196  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
1197  SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
1198  SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
1199  if (!outerDimsPerm.empty()) {
1200  applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
1201  applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
1202  }
1203  // Add zero padding for the point loops.
1204  size_t pointLoopsSize = innerDimsPos.size();
1205  lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
1206  highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
1207 
1208  auto newPadOp = rewriter.create<tensor::PadOp>(
1209  loc, /*result=*/Type(), unpackOp.getSource(), lowPad, highPad,
1210  paddingVal, padOp.getNofold());
1211 
1212  // Inject the linalg.unpack right after the packed padOp.
1213  Value outputUnPack = rewriter.create<tensor::EmptyOp>(
1214  loc, padOp.getResultType().getShape(),
1215  padOp.getResultType().getElementType());
1216 
1217  Value replacement = rewriter.create<linalg::UnPackOp>(
1218  loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
1219  unpackOp.getMixedTiles(), outerDimsPerm);
1220  rewriter.replaceOp(padOp, replacement);
1221  return success();
1222  }
1223 
1224 private:
1225  ControlPropagationFn controlFn;
1226 };
1227 
1228 } // namespace
1229 
1230 void mlir::linalg::populateDataLayoutPropagationPatterns(
1231     RewritePatternSet &patterns,
1232     const ControlPropagationFn &controlPackUnPackPropagation) {
1233  patterns
1234  .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
1235  BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
1236  PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
1237  patterns.getContext(), controlPackUnPackPropagation);
1238 }
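// A minimal usage sketch (not part of this file; the surrounding pass skeleton
// and the greedy-driver entry point are assumptions that may differ between
// MLIR versions): a client pass would typically do something like
//
//   RewritePatternSet patterns(&getContext());
//   linalg::populateDataLayoutPropagationPatterns(
//       patterns, /*controlPackUnPackPropagation=*/[](OpOperand *) {
//         return true; // propagate across every supported producer/consumer
//       });
//   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
//     return signalPassFailure();
//
// The ControlPropagationFn callback receives the OpOperand across which a
// pack/unpack would be propagated and can veto individual rewrites.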