1 //===- DataLayoutPropagation.cpp -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/Passes.h"
10 
11 #include "mlir/Dialect/Affine/IR/AffineOps.h"
12 #include "mlir/Dialect/Linalg/IR/Linalg.h"
13 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
14 #include "mlir/Dialect/Linalg/Utils/Utils.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/Dialect/Tensor/Utils/Utils.h"
17 #include "mlir/Dialect/Utils/IndexingUtils.h"
18 #include "mlir/IR/Dominance.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 #include "llvm/Support/Debug.h"
22 #include <optional>
23 
24 namespace mlir {
25 #define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
26 #include "mlir/Dialect/Linalg/Passes.h.inc"
27 } // namespace mlir
28 
29 using namespace mlir;
30 using namespace mlir::linalg;
31 
32 #define DEBUG_TYPE "linalg-data-layout-propagation"
33 
34 namespace {
35 
36 static bool hasGatherSemantics(linalg::GenericOp genericOp) {
37  for (Operation &op : genericOp.getBody()->getOperations())
38  if (isa<tensor::ExtractOp, linalg::IndexOp>(op))
39  return true;
40  return false;
41 }
42 
43 // This struct contains the information needed to map the packing data (tile
44 // sizes and dim positions) onto the iteration domain of Linalg ops.
45 struct PackInfo {
46  int64_t getNumTiledLoops() const { return tileToPointMapping.size(); };
47  // InnerDimsPos on iteration domain, which follows the order in pack ops.
48  SmallVector<int64_t> tiledDimsPos;
49  // The sizes of tiling data dimensions on iteration domain.
50  llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping;
51  // The mapping from a dimension of iteration domain to the corresponding inner
52  // tiling dimension on iteration domain.
53  llvm::DenseMap<int64_t, int64_t> tileToPointMapping;
54  // The permutation of outer dims (on domain).
55  SmallVector<int64_t> outerDimsOnDomainPerm;
56 };
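// Illustrative sketch (not lifted from the code or tests below): for a generic
// op over a two-dimensional parallel domain (d0, d1) whose result is packed
// with inner_dims_pos = [0, 1] and inner_tiles = [8, 2],
// getPackingInfoFromOperand would fill PackInfo roughly as:
//   tiledDimsPos            = [0, 1]
//   domainDimAndTileMapping = {d0 -> 8, d1 -> 2}
//   tileToPointMapping      = {d0 -> d2, d1 -> d3}  // point loops appended
//                                                   // after the 2 original dims
//   outerDimsOnDomainPerm   = {}                    // no outer_dims_perm set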
57 
58 template <typename OpTy>
59 static FailureOr<PackInfo>
60 getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
61  OpTy packOrUnPackOp) {
62  static_assert(llvm::is_one_of<OpTy, tensor::PackOp, tensor::UnPackOp>::value,
63  "applies to only pack or unpack operations");
64  LLVM_DEBUG(
65  { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });
66 
67  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
68  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
69  SmallVector<utils::IteratorType> iterators =
70      genericOp.getIteratorTypesArray();
71 
72  PackInfo packInfo;
73  int64_t origNumDims = indexingMap.getNumDims();
74  SmallVector<AffineExpr> exprs(indexingMap.getResults());
75  ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos();
76  for (auto [index, innerDimPos, tileSize] :
77  llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
78  innerDimsPos, packOrUnPackOp.getMixedTiles())) {
79  auto expr = exprs[innerDimPos];
80  if (!isa<AffineDimExpr>(expr))
81  return failure();
82  int64_t domainDimPos =
83  cast<AffineDimExpr>(exprs[innerDimPos]).getPosition();
84  if (!isParallelIterator(iterators[domainDimPos]))
85  return failure();
86  packInfo.tiledDimsPos.push_back(domainDimPos);
87  packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
88  packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
89  LLVM_DEBUG({
90  llvm::dbgs() << "map innerDimPos=" << innerDimPos
91  << " to iteration dimension (d" << domainDimPos << ", d"
92  << packInfo.tileToPointMapping[domainDimPos]
93  << "), which has size=("
94  << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
95  });
96  }
97 
98  // Bail out if a tiled dimension is present in a map but not as an affine dim
99  // expression.
100  auto areAllAffineDimExpr = [&](int dim) {
101  for (AffineMap map : indexingMaps) {
102  if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) {
103  return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr);
104  })) {
105  return false;
106  }
107  }
108  return true;
109  };
110  for (int64_t i : packInfo.tiledDimsPos)
111  if (!areAllAffineDimExpr(i))
112  return failure();
113 
114  // Get the outer dims perm on the iteration domain. Start by identifying the
115  // set of domain dims affected by the outer permutation along with the
116  // permuted ordering for those dims. Then the full outer dims permutation can
117  // be constructed by replacing the affected dims with the permuted result in a
118  // numLoops-rank identity. e.g.
119  // outerDimsPerm = [1, 2, 0]
120  // indexingMap = (d0, d1, d2, d3, d4) -> (d1, d4, d3)
121  //
122  // permutedOuterDims = [4, 3, 1]
123  // outerDimsOnDomainPerm = [0, 4, 2, 3, 1]
124  //
125  // Non-affine dim expressions must not be permuted by the outer dims
126  // permutation.
127  SmallVector<int64_t> permutedOuterDims;
128  for (auto [index, dim] : llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) {
129  auto permutedExpr = indexingMap.getResult(dim);
130  if (auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) {
131  permutedOuterDims.push_back(dimExpr.getPosition());
132  continue;
133  }
134 
135  // TODO: Allow propagation with transposes on non affine dim expressions,
136  // e.g. d0 + d1 which implies transposing both dims simultaneously while
137  // maintaining the relative position between them.
138  if (static_cast<int64_t>(index) != dim)
139  return failure();
140  }
141  if (!permutedOuterDims.empty()) {
142  int64_t outerDimIndex = 0;
143  llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(),
144  permutedOuterDims.end());
145  for (int i = 0, e = indexingMap.getNumDims(); i < e; i++)
146  packInfo.outerDimsOnDomainPerm.push_back(
147  permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++]
148  : i);
149  LLVM_DEBUG({
150  llvm::dbgs() << "map outerDimsPerm to ";
151  for (auto dim : packInfo.outerDimsOnDomainPerm)
152  llvm::dbgs() << dim << " ";
153  llvm::dbgs() << "\n";
154  });
155  }
156 
157  return packInfo;
158 }
159 
160 static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
161  ArrayRef<AffineExpr> exprs) {
162  // Compute `outer_dims_perm`. See example:
163  // current exprs : (d0, d1, d2, d3) -> (d2, d3)
164  // perm : [0, 3, 1, 2]
165  // First map d2, d3 with their position in the array as:
166  // currentPositionTileLoops: dim | pos
167  // d2 | 0
168  // d3 | 1
169  // then scan `perm` in order and get the `outer_dims_perm`
170  // to be used, here it would be [1, 0].
171  assert(!perm.empty() && "expect perm not to be empty");
172  assert(!exprs.empty() && "expect exprs not to be empty");
173  if (exprs.size() == 1)
174  return {};
175  SmallVector<int64_t> outerDimsPerm;
176  DenseMap<int64_t, int64_t> currentPositionTileLoops;
177  for (auto [pos, expr] : llvm::enumerate(exprs)) {
178  // Here we rely on the assumption that the outer dims permutation
179  // when propagating currently requires that non-affine dim expressions
180  // are not permuted, thus allowing the identity assignment below.
181  if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
182  currentPositionTileLoops[dimExpr.getPosition()] = pos;
183  else
184  currentPositionTileLoops[pos] = pos;
185  }
186  for (int64_t loopIdx : perm) {
187  if (currentPositionTileLoops.count(loopIdx))
188  outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx));
189  }
190  return outerDimsPerm;
191 }
192 
193 /// Returns a tuple of the packed operand and the updated indexing map, with the assumptions:
194 /// 1) The generic op is the producer of the pack op.
195 /// 2) The generic op has only one result.
196 /// If the operand is a scalar or the packing dimensions are all irrelevant to the
197 /// operand, the original operand and the updated indexing map are returned.
198 /// Otherwise, the packed operand and the updated indexing map are returned. E.g.,
199 ///
200 /// #map0 = affine_map<(d0, d1) -> (d0, d1)>
201 /// #map1 = affine_map<(d0, d1) -> (d0)>
202 /// #map2 = affine_map<(d0, d1) -> (d1)>
203 /// %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0],
204 /// iterator_types = ["parallel", "parallel"]}
205 /// ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
206 /// outs(%init : tensor<?x?xf32>) {
207 /// ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
208 /// %4 = arith.addf %arg3, %arg4 : f32
209 /// linalg.yield %4 : f32
210 /// } -> tensor<?x?xf32>
211 /// %1 = tensor.pack %0
212 /// inner_dims_pos = [0, 1]
213 /// inner_tiles = [8, 2]
214 /// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
215 ///
216 /// Taking the first input operand as an example, the inner tile size of d0 is
217 /// 8. Thus, the pack operation below and the indexing map
218 /// `affine_map<(d0, d1, d2, d3) -> (d0, d2)>` will be returned.
219 ///
220 /// %pack = tensor.pack %arg0
221 /// inner_dims_pos = [0]
222 /// inner_tiles = [8]
223 /// into %init : tensor<?xf32> -> tensor<?x8xf32>
224 static std::tuple<Value, AffineMap>
225 getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
226  GenericOp genericOp, OpOperand *opOperand) {
227  int64_t numOrigLoops = genericOp.getNumLoops();
228  int64_t numInnerLoops = packInfo.getNumTiledLoops();
229  int64_t numLoops = numOrigLoops + numInnerLoops;
230  AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
231  llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
232  SmallVector<AffineExpr> exprs(origIndexingMap.getResults());
233 
234  // If the OpOperand is a scalar or a zero-rank tensor, no need to pack.
235  if (genericOp.isScalar(opOperand) || exprs.empty())
236  return std::make_tuple(opOperand->get(),
237  AffineMap::get(numLoops, 0, exprs, b.getContext()));
238 
239  // Step 1. Construct the information of packing data dimensions; append inner
240  // dimensions to the indexing maps for the operand.
241  for (auto [index, expr] : llvm::enumerate(exprs)) {
242  if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
243  int64_t dimPos = dimExpr.getPosition();
244  domainDimToOperandDim[dimPos] = index;
245  continue;
246  }
247  }
248  SmallVector<int64_t> innerDimsPos;
249  SmallVector<OpFoldResult> innerTileSizes;
250  for (auto dimPos : packInfo.tiledDimsPos) {
251  if (!domainDimToOperandDim.count(dimPos))
252  continue;
253  int64_t index = domainDimToOperandDim[dimPos];
254  innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
255  innerDimsPos.push_back(index);
256  exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos]));
257  }
258 
259  // Step 2. Handle outer dim permutations.
260  SmallVector<int64_t> outerDimsPerm;
261  if (!packInfo.outerDimsOnDomainPerm.empty()) {
262  outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs);
263 
264  // Step 2.1: Fold transpose into the linalg.generic.
265  SmallVector<int64_t> inversedOuterPerm =
266  invertPermutationVector(packInfo.outerDimsOnDomainPerm);
267  for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
268  if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) {
269  int64_t dimPos = dimExpr.getPosition();
270  exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
271  continue;
272  }
273  assert(isa<AffineConstantExpr>(exprs[i]) &&
274  "Attempted to permute non-constant and non-affine dim expression");
275  }
276  // Step 2.2: Undo the transposition on `exprs` and propagate the
277  // transposition on the pack using outerDimsPerm.
278  if (!outerDimsPerm.empty()) {
279  SmallVector<AffineExpr> auxVec = exprs;
280  for (const auto &en : enumerate(outerDimsPerm))
281  auxVec[en.index()] = exprs[en.value()];
282  exprs = auxVec;
283  }
284  }
285  auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext());
286 
287  // The operand does not have dimensions that relate to the pack op.
288  if (innerDimsPos.empty() && outerDimsPerm.empty())
289  return std::make_tuple(opOperand->get(), indexingMap);
290 
291  auto empty = tensor::PackOp::createDestinationTensor(
292  b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
293  auto packedOperand = b.create<tensor::PackOp>(
294  loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
295  /*padding=*/std::nullopt, outerDimsPerm);
296  return std::make_tuple(packedOperand, indexingMap);
297 }
298 
299 /// Pack a genericOp and return it.
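/// Summary of the steps below: every DPS input operand is replaced by a packed
/// view via getOrCreatePackedViewOfOperand, one parallel iterator type is
/// appended per inner (point) loop, the caller-provided packed init value and
/// indexing map are attached, and the original body region is cloned into the
/// new linalg.generic.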
300 static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
301  Value dest, AffineMap packedOutIndexingMap,
302  const PackInfo &packInfo) {
303  Location loc = genericOp.getLoc();
304  SmallVector<Value> inputOperands;
305  SmallVector<AffineMap> indexingMaps;
306  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
307  auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
308  rewriter, loc, packInfo, genericOp, inputOperand);
309  inputOperands.push_back(packedOperand);
310  indexingMaps.push_back(packedIndexingMap);
311  }
312 
313  int64_t numInnerLoops = packInfo.getNumTiledLoops();
314  SmallVector<utils::IteratorType> iterTypes =
315      genericOp.getIteratorTypesArray();
316  iterTypes.append(numInnerLoops, utils::IteratorType::parallel);
317 
318  indexingMaps.push_back(packedOutIndexingMap);
319 
320  auto newGenericOp = rewriter.create<linalg::GenericOp>(
321  loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes,
322  /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
323  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
324  newGenericOp.getRegion().begin());
325  return newGenericOp;
326 }
327 
328 /// Bubbles up a tensor.pack op through a producer generic op. This
329 /// swaps pack(generic) into generic(pack). The new generic op works on the
330 /// packed domain; pack ops are created for input and output operands. E.g.,
331 ///
332 /// #map0 = affine_map<(d0, d1) -> (d0, d1)>
333 /// %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
334 /// %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
335 /// %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
336 /// %3 = linalg.generic {indexing_maps = [#map0, #map0],
337 /// iterator_types = ["parallel", "parallel"]}
338 /// ins(%arg0 : tensor<?x?xf32>)
339 /// outs(%2 : tensor<?x?xf32>) {
340 /// ^bb0(%arg3: f32, %arg4: f32):
341 /// %4 = arith.addf %arg3, %arg3 : f32
342 /// linalg.yield %4 : f32
343 /// } -> tensor<?x?xf32>
344 /// %4 = tensor.pack %3
345 /// inner_dims_pos = [0, 1]
346 /// inner_tiles = [8, 2]
347 /// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
348 ///
349 /// will be converted to
350 ///
351 /// #map = affine_map<()[s0] -> (s0 ceildiv 8)>
352 /// #map1 = affine_map<()[s0] -> (s0 ceildiv 2)>
353 /// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
354 /// %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
355 /// %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
356 /// %0 = affine.apply #map()[%dim]
357 /// %1 = affine.apply #map1()[%dim_0]
358 /// %2 = tensor.empty(%0, %1) : tensor<?x?x8x2xf32>
359 /// %pack = tensor.pack %arg0
360 /// inner_dims_pos = [0, 1]
361 /// inner_tiles = [8, 2]
362 /// into %2 : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
363 /// %3 = linalg.generic {indexing_maps = [#map2, #map2],
364 /// iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
365 /// ins(%pack : tensor<?x?x8x2xf32>)
366 /// outs(%arg1 : tensor<?x?x8x2xf32>) {
367 /// ^bb0(%in: f32, %out: f32):
368 /// %4 = arith.addf %in, %in : f32
369 /// linalg.yield %4 : f32
370 /// } -> tensor<?x?x8x2xf32>
371 static FailureOr<GenericOp>
372 bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
373  const ControlPropagationFn &controlFn) {
374  auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
375  if (!genericOp)
376  return failure();
377 
378  // User controlled propagation function.
379  if (!controlFn(genericOp))
380  return failure();
381 
382  // TODO: Enable propagation in the presence of linalg.index and
383  // tensor.extract, likely as a separate pattern as the pack information and
384  // propagation decision needs to be inferred from the region of the generic.
385  if (hasGatherSemantics(genericOp))
386  return failure();
387 
388  // TODO: Relax the restriction. We are able to bubble up the pack op through
389  // multi-result generic op. It just needs more work.
390  if (genericOp.getNumResults() != 1)
391  return failure();
392 
393  // Bail-out if the result of the generic has multiple uses, as bubbling up
394  // creates recomputation if the generic has multiple users.
395  // TODO: Enable the case where every use is an identical pack op as no
396  // recomputation is needed in that case.
397  if (!genericOp->getResult(0).hasOneUse())
398  return failure();
399 
400  // We want to move the pack not the generic.
401  OpBuilder::InsertionGuard guard(rewriter);
402  rewriter.setInsertionPoint(genericOp);
403 
404  // We need to handle two cases:
405  // 1) The tensor.pack destination is a tensor.empty. If this is the case, we
406  // create a new tensor.empty to avoid breaking dominance, as we are moving the
407  // tensor.pack above the linalg.generic.
408  // 2) The destination is not a tensor.empty. In this case we can replace only
409  // if the destination of the tensor.pack dominates the linalg.generic.
410  Value packOpDest = packOp.getDest();
411  if (!packOpDest.hasOneUse())
412  return failure();
413  if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
414  packOpDest = rewriter.create<tensor::EmptyOp>(
415  genericOp->getLoc(), emptyOp.getMixedSizes(),
416  emptyOp.getType().getElementType());
417  } else {
418  DominanceInfo dom(genericOp);
419  if (!dom.properlyDominates(packOpDest, genericOp))
420  return failure();
421  }
422 
423  // TODO: Add an option for allowing padding values. It could introduce
424  // undefined behavior if we unconditionally propagate pack ops through all
425  // the ops. E.g., if the padding value is zero and there are division ops in
426  // a generic op, some values in the padding area could become NaN (0/0).
427  if (packOp.getPaddingValue())
428  return failure();
429 
430  OpOperand *opOperand = genericOp.getDpsInitOperand(0);
431  auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
432  if (failed(packInfo))
433  return failure();
434 
435  // Rebuild the indexing map for the corresponding init operand.
436  auto [packedOutOperand, packedOutIndexingMap] =
437  getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
438  genericOp, opOperand);
439 
440  // If the dps init operand of the generic is a tensor.empty, forward the pack
441  // op destination.
442  Value dest = packedOutOperand;
443  if (auto initTensor = genericOp.getDpsInitOperand(0)
444  ->get()
445  .getDefiningOp<tensor::EmptyOp>()) {
446  dest = packOpDest;
447  }
448  return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
449  *packInfo);
450 }
451 
452 /// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
453 struct BubbleUpPackOpThroughGenericOpPattern
454  : public OpRewritePattern<tensor::PackOp> {
455 public:
456  BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
457                                        ControlPropagationFn fun)
458      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
459 
460  LogicalResult matchAndRewrite(tensor::PackOp packOp,
461  PatternRewriter &rewriter) const override {
462  auto genericOp =
463  bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
464  if (failed(genericOp))
465  return failure();
466  rewriter.replaceOp(packOp, genericOp->getResults());
467  return success();
468  }
469 
470 private:
471  ControlPropagationFn controlFn;
472 };
473 
474 /// Propagate a tensor.pack operation up through a tensor.pad. The idea is to
475 /// add as many zero padding dimensions to `high` and `low` as there are
476 /// point loops.
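///
/// Rough sketch of the rewrite (illustrative shapes, not taken from the
/// upstream tests); the pattern bails out when a padded dimension is also a
/// tiled dimension or when the padding value is not a constant:
///
///   %padded = tensor.pad %src low[0, 0] high[0, 15] { ... }
///       : tensor<64x49xf32> to tensor<64x64xf32>
///   %packed = tensor.pack %padded inner_dims_pos = [0] inner_tiles = [32]
///       into %dest : tensor<64x64xf32> -> tensor<2x64x32xf32>
///
/// becomes
///
///   %packed = tensor.pack %src inner_dims_pos = [0] inner_tiles = [32]
///       into %empty : tensor<64x49xf32> -> tensor<2x49x32xf32>
///   %padded = tensor.pad %packed low[0, 0, 0] high[0, 15, 0] { ... }
///       : tensor<2x49x32xf32> to tensor<2x64x32xf32>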
477 class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
478 public:
479  BubbleUpPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
480  : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
481 
482  LogicalResult matchAndRewrite(tensor::PackOp packOp,
483  PatternRewriter &rewriter) const override {
484  auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
485  if (!padOp)
486  return failure();
487 
488  // User controlled propagation function.
489  if (!controlFn(padOp))
490  return failure();
491 
492  if (!padOp.getResult().hasOneUse())
493  return failure();
494 
495  // TODO: Enable padding when the padding values are the same.
496  if (packOp.getPaddingValue())
497  return failure();
498 
499  // Fail for non-constant padding values. The body of the pad could
500  // depend on the padding indices and/or properties of the padded
501  // tensor so for now we fail.
502  // TODO: Support non-constant padding values.
503  Value paddingVal = padOp.getConstantPaddingValue();
504  if (!paddingVal)
505  return failure();
506 
507  if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
508  return failure();
509 
510  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
511  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
512 
513  // Bail out if one of the padded dimensions is a tiled one.
514  llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
515  llvm::SmallBitVector innerDims(paddedDims.size());
516  for (int64_t dim : innerDimsPos)
517  innerDims.flip(dim);
518  if (paddedDims.anyCommon(innerDims))
519  return failure();
520 
521  Location loc = padOp->getLoc();
522  OpBuilder::InsertionGuard guard(rewriter);
523  rewriter.setInsertionPoint(padOp);
524 
525  auto empty = tensor::PackOp::createDestinationTensor(
526  rewriter, loc, padOp.getSource(), packOp.getMixedTiles(), innerDimsPos,
527  outerDimsPerm);
528  Value packedSource = rewriter.create<tensor::PackOp>(
529  loc, padOp.getSource(), empty, innerDimsPos, packOp.getMixedTiles(),
530  /*padding=*/std::nullopt, outerDimsPerm);
531 
532  // If we have an `outer_dims_perm`, we need to adjust the padded dimensions.
533  SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
534  SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
535  if (!outerDimsPerm.empty()) {
536  applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
537  applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
538  }
539  // The tiled dimensions were verified to be unpadded above, so here we
540  // just append 0 for the inner tile dimensions.
541  size_t pointLoopsSize = innerDimsPos.size();
542  lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
543  highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
544 
545  auto newPadOp = rewriter.create<tensor::PadOp>(
546  loc, /*result=*/Type(), packedSource, lowPad, highPad, paddingVal,
547  padOp.getNofold());
548  rewriter.replaceOp(packOp, newPadOp.getResult());
549  return success();
550  }
551 
552 private:
553  ControlPropagationFn controlFn;
554 };
555 
556 /// Project dimsPos to the inner-most non-unit dim pos with reassocIndices.
557 ///
558 /// For example, given dimsPos [0, 2], reassocIndices [[0, 1], [2, 3]], and
559 /// targetShape [16, 16, 32, 1], it returns [1, 2]: for pos 0 the inner-most
560 /// non-unit projected dim in [0, 1] is 1, and for pos 2 the inner-most
561 /// non-unit projected dim in [2, 3] is 2.
562 ///
563 /// If all candidates in a reassociation are unit dims, it chooses the
564 /// inner-most dim pos.
565 static SmallVector<int64_t>
566 projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
567  ArrayRef<ReassociationIndices> reassocIndices,
568  ArrayRef<int64_t> targetShape) {
569  SmallVector<int64_t> projectedDimsPos;
570  for (auto pos : dimsPos) {
571  // In the case all dims are unit, this will return the inner-most one.
572  int64_t projectedPos = reassocIndices[pos].back();
573  for (auto i : llvm::reverse(reassocIndices[pos])) {
574  int64_t dim = targetShape[i];
575  if (dim > 1 || ShapedType::isDynamic(dim)) {
576  projectedPos = i;
577  break;
578  }
579  }
580  projectedDimsPos.push_back(projectedPos);
581  }
582  return projectedDimsPos;
583 }
584 
585 /// Check if all dims in dimsPos are divisible by the corresponding tile sizes.
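/// For example (illustrative values): dimsPos = [1, 2], shape = [10, 16, 32],
/// and tileSizes = [8, 16] returns true (16 % 8 == 0 and 32 % 16 == 0), whereas
/// a dynamic dim or a non-zero remainder in either position returns false.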
586 static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
587  ArrayRef<int64_t> shape,
588  ArrayRef<int64_t> tileSizes) {
589  for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
590  int64_t dim = shape[pos];
591  if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
592  return false;
593  }
594  return true;
595 }
596 
597 /// Permute the reassociation indices and reindex them in sequence order.
598 /// Returns the next dim pos in the sequence.
599 ///
600 /// For example, given reassocIndices [[0, 1], [2]] and permutation [1, 0], it
601 /// applies the permutation to get [[2], [0, 1]] and reindexes the indices into
602 /// [[0], [1, 2]].
603 static int64_t applyPermutationAndReindexReassoc(
604  SmallVector<ReassociationIndices> &reassocIndices,
605  ArrayRef<int64_t> permutation) {
606  applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
607  int64_t nextPos = 0;
608  for (ReassociationIndices &indices : reassocIndices) {
609  for (auto &index : indices) {
610  index = nextPos;
611  nextPos += 1;
612  }
613  }
614  return nextPos;
615 }
616 
617 /// Bubble up pack op through collapse shape op when the packed dims can be
618 /// projected to the dims before collapsing. This is possible when the inner
619 /// tile sizes can divide the projected dims.
620 ///
621 /// For example:
622 ///
623 /// %collapsed = tensor.collapse_shape %in [[0, 1], 2]
624 /// : tensor<?x16x4xf32> into tensor<?x4xf32>
625 /// %pack = tensor.pack %collapsed outer_dims_perm = [0, 1]
626 /// inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty
627 /// : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
628 ///
629 /// can be transformed into:
630 ///
631 /// %pack = tensor.pack %in outer_dims_perm = [1, 2]
632 /// inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty
633 /// : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
634 /// %collapsed = tensor.collapse_shape %pack [[0, 1], 2, 3, 4]
635 /// : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
636 static LogicalResult
637 bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
638  tensor::PackOp packOp,
639  PatternRewriter &rewriter) {
640  SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
641  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
642  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
643 
644  ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
645  SmallVector<ReassociationIndices> reassocIndices =
646  collapseOp.getReassociationIndices();
647  // Project inner tile pos to the dim pos before collapsing. For example, if
648  // dims [x, y] are collapsed into [z], packing on dim z can be projected back
649  // to packing on dim y.
650  //
651  // Project to inner-most non-unit dims to increase the chance that they can be
652  // divided by the inner tile sizes. This is correct because for [..., x, 1],
653  // packing on dim 1 is equivalent to packing on dim x.
654  SmallVector<int64_t> projectedInnerDimsPos =
655  projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);
656 
657  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
658  innerTileSizes)) {
659  return failure();
660  }
661  // Expand the outer dims permutation with the associated source dims for the
662  // new permutation after bubbling. This is because moving a collapsed dim is
663  // equivalent to moving the associated source dims together.
664  SmallVector<int64_t> newOuterDimsPerm;
665  for (auto outerPos : outerDimsPerm) {
666  newOuterDimsPerm.insert(newOuterDimsPerm.end(),
667  reassocIndices[outerPos].begin(),
668  reassocIndices[outerPos].end());
669  }
670 
671  auto emptyOp = tensor::PackOp::createDestinationTensor(
672  rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
673  projectedInnerDimsPos, newOuterDimsPerm);
674  auto newPackOp = rewriter.create<tensor::PackOp>(
675  packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos,
676  packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);
677 
678  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
679  // First apply the permutation on the reassociations of the outer dims.
680  // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
681  // -> [[0], [1, 2]]
682  int64_t nextPos =
683  applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
684  // Then add direct mapping for the inner tile dims.
685  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
686  newReassocIndices.push_back({nextPos});
687  nextPos += 1;
688  }
689 
690  auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
691  collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices);
692  rewriter.replaceOp(packOp, newCollapseOp);
693 
694  return success();
695 }
696 
697 class BubbleUpPackOpThroughReshapeOp final
698  : public OpRewritePattern<tensor::PackOp> {
699 public:
700  BubbleUpPackOpThroughReshapeOp(MLIRContext *context, ControlPropagationFn fun)
701  : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
702 
703  LogicalResult matchAndRewrite(tensor::PackOp packOp,
704  PatternRewriter &rewriter) const override {
705  Operation *srcOp = packOp.getSource().getDefiningOp();
706  // Currently only support when the pack op is the only user.
707  if (!srcOp || !(srcOp->getNumResults() == 1) ||
708  !srcOp->getResult(0).hasOneUse()) {
709  return failure();
710  }
711  // Currently only support static inner tile sizes.
712  if (llvm::any_of(packOp.getStaticTiles(), [](int64_t size) {
713  return ShapedType::isDynamic(size);
714  })) {
715  return failure();
716  }
717 
718  // User controlled propagation function.
719  if (!controlFn(srcOp))
720  return failure();
721 
722  return TypeSwitch<Operation *, LogicalResult>(srcOp)
723      .Case([&](tensor::CollapseShapeOp op) {
724  return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
725  })
726  .Default([](Operation *) { return failure(); });
727  }
728 
729 private:
730  ControlPropagationFn controlFn;
731 };
732 
733 /// Push down unpack op through expand shape op when the packed dims can be
734 /// projected to the dims after expanding. This is possible when the inner tile
735 /// sizes can divide the projected dims.
736 ///
737 /// For example:
738 ///
739 /// %unpack = tensor.unpack %in outer_dims_perm = [0, 1]
740 /// inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %empty
741 /// : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
742 /// %expanded = tensor.expand_shape %unpack [[0, 1], [2]]
743 /// : tensor<?x256xf32> into tensor<?x256x256xf32>
744 ///
745 /// can be transformed into:
746 ///
747 /// %expanded = tensor.expand_shape %in [[0, 1], [2], [3], [4]]
748 /// : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
749 /// %unpack = tensor.unpack %expanded outer_dims_perm = [0, 1, 2]
750 /// inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty
751 /// : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
752 static LogicalResult
753 pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
754  tensor::ExpandShapeOp expandOp,
755  PatternRewriter &rewriter) {
756  SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
757  ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
758  ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
759 
760  ArrayRef<int64_t> dstShape = expandOp.getType().getShape();
761  SmallVector<ReassociationIndices> reassocIndices =
762  expandOp.getReassociationIndices();
763  // Project inner tile pos to the dim pos after expanding. For example, if dim
764  // [z] is expanded into dims [x, y], unpacking on dim z can be projected to
765  // unpacking on dim y.
766  //
767  // Project to inner-most non-unit dims to increase the chance that they can be
768  // divided by the inner tile sizes. This is correct because for [..., x, 1],
769  // unpacking on dim 1 is equivalent to unpacking on dim x.
770  SmallVector<int64_t> projectedInnerDimsPos =
771  projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);
772 
773  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
774  innerTileSizes)) {
775  return failure();
776  }
777  // Expand the outer dims permutation with the associated expanded dims for the
778  // new permutation after pushing. This is because moving a source dim is
779  // equivalent to moving the associated expanded dims together.
780  SmallVector<int64_t> newOuterDimsPerm;
781  for (auto outerPos : outerDimsPerm) {
782  newOuterDimsPerm.insert(newOuterDimsPerm.end(),
783  reassocIndices[outerPos].begin(),
784  reassocIndices[outerPos].end());
785  }
786 
787  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
788  // First apply the permutation on the reassociations of the outer dims.
789  // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
790  // -> [[0], [1, 2]]
791  int64_t nextPos =
792  applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
793  // Then add direct mapping for the inner tile dims.
794  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
795  newReassocIndices.push_back({nextPos});
796  nextPos += 1;
797  }
798 
799  RankedTensorType newExpandType =
800  tensor::PackOp::inferPackedType(expandOp.getType(), innerTileSizes,
801  projectedInnerDimsPos, newOuterDimsPerm);
802  auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
803  expandOp.getLoc(), newExpandType, unPackOp.getSource(),
804  newReassocIndices);
805 
806  auto emptyOp = tensor::UnPackOp::createDestinationTensor(
807  rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
808  projectedInnerDimsPos, newOuterDimsPerm);
809  auto newUnPackOp = rewriter.create<tensor::UnPackOp>(
810  unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
811  projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
812  rewriter.replaceOp(expandOp, newUnPackOp);
813 
814  return success();
815 }
816 
817 class PushDownUnPackOpThroughReshapeOp final
818  : public OpRewritePattern<tensor::UnPackOp> {
819 public:
820  PushDownUnPackOpThroughReshapeOp(MLIRContext *context,
821                                   ControlPropagationFn fun)
822      : OpRewritePattern<tensor::UnPackOp>(context), controlFn(std::move(fun)) {
823  }
824 
825  LogicalResult matchAndRewrite(tensor::UnPackOp unPackOp,
826  PatternRewriter &rewriter) const override {
827  Value result = unPackOp.getResult();
828  // Currently only support unpack ops with a single user.
829  if (!result.hasOneUse()) {
830  return failure();
831  }
832  // Currently only support static inner tile sizes.
833  if (llvm::any_of(unPackOp.getStaticTiles(), [](int64_t size) {
834  return ShapedType::isDynamic(size);
835  })) {
836  return failure();
837  }
838 
839  Operation *consumerOp = *result.user_begin();
840  // User controlled propagation function.
841  if (!controlFn(consumerOp))
842  return failure();
843 
844  return TypeSwitch<Operation *, LogicalResult>(consumerOp)
845  .Case([&](tensor::ExpandShapeOp op) {
846  return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter);
847  })
848  .Default([](Operation *) { return failure(); });
849  }
850 
851 private:
852  ControlPropagationFn controlFn;
853 };
854 
855 // TODO: Relax this restriction. We should unpack a generic op also
856 // in the presence of multiple unpack ops as producers.
857 /// Return the unpacked operand, if present, for the current generic op.
858 static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
859  OpOperand *unPackedOperand = nullptr;
860  for (OpOperand &operand : genericOp->getOpOperands()) {
861  auto unPackOp = operand.get().getDefiningOp<tensor::UnPackOp>();
862  if (!unPackOp)
863  continue;
864  if (unPackedOperand)
865  return failure();
866  unPackedOperand = &operand;
867  }
868  if (!unPackedOperand)
869  return failure();
870  return unPackedOperand;
871 }
872 
873 /// Push down a tensor.unpack op through a generic op.
874 /// The new generic op works on the packed domain; pack ops are created for input
875 /// and output operands. A tensor.unpack op is inserted right after the packed
876 /// generic. E.g.
877 ///
878 /// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
879 ///
880 /// %arg0 = tensor<12x2x56x56x32xf32> // packed arg.
881 ///
882 /// %0 = tensor.empty() : tensor<12x56x56x64xf32>
883 /// %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2]
884 /// inner_dims_pos = [3] inner_tiles = [32] into %0
885 /// %2 = linalg.generic {indexing_maps = [#map],
886 /// iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
887 /// outs(%1 : tensor<12x56x56x64xf32>) {
888 /// ^bb0(%out : f32):
889 /// linalg.yield %out : f32
890 /// } -> tensor<12x56x56x64xf32>
891 ///
892 /// will be converted to
893 ///
894 /// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
895 ///
896 /// %0 = tensor.empty() : tensor<12x56x56x64xf32>
897 /// %1 = linalg.generic {indexing_maps = [#map],
898 /// iterator_types = ["parallel", "parallel", "parallel",
899 /// "parallel", "parallel"]}
900 /// outs(%arg0 : tensor<12x2x56x56x32xf32>) {
901 /// ^bb0(%out : f32):
902 /// linalg.yield %out : f32
903 /// } -> tensor<12x2x56x56x32xf32>
904 /// %2 = tensor.unpack %1 outer_dims_perm = [0, 3, 1, 2]
905 /// inner_dims_pos = [3] inner_tiles = [32] into %0
906 ///
907 static FailureOr<std::tuple<GenericOp, Value>>
908 pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp) {
909  if (genericOp.getNumResults() != 1)
910  return failure();
911 
912  if (hasGatherSemantics(genericOp))
913  return failure();
914 
915  // Collect the unPacked operand, if present.
916  auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
917  if (failed(maybeUnPackedOperand))
918  return failure();
919  OpOperand *unPackedOperand = *(maybeUnPackedOperand);
920 
921  // Extract packing information.
922  tensor::UnPackOp producerUnPackOp =
923  unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
924  assert(producerUnPackOp && "expect a valid UnPackOp");
925  auto packInfo =
926  getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
927  if (failed(packInfo))
928  return failure();
929 
930  // Rebuild the indexing map for the corresponding init operand.
931  auto [packedOutOperand, packedOutIndexingMap] =
932  getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
933  genericOp, genericOp.getDpsInitOperand(0));
934  auto destPack = packedOutOperand.getDefiningOp<tensor::PackOp>();
935 
936  // If the dps init operand of the generic is a tensor.empty, do not pack it
937  // and forward the new tensor.empty as a destination.
938  Value dest = packedOutOperand;
939  if (auto initTensor = genericOp.getDpsInitOperand(0)
940  ->get()
941  .getDefiningOp<tensor::EmptyOp>()) {
942  if (destPack)
943  dest = destPack.getDest();
944  }
945 
946  // Pack the genericOp.
947  GenericOp newGenericOp =
948  packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo);
949  Value newResult =
950  newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));
951 
952  // If the output is unaffected, no need to unpack.
953  if (!destPack)
954  return std::make_tuple(newGenericOp, newResult);
955 
956  auto mixedTiles = destPack.getMixedTiles();
957  auto innerDimsPos = destPack.getInnerDimsPos();
958  auto outerDimsPerm = destPack.getOuterDimsPerm();
959 
960  // If the output type for the generic differs from the source
961  // unpack op, we need to create a new destination tensor. In the
962  // dynamic case we always need a new destination.
963  auto loc = genericOp.getLoc();
964  Value unPackDest = producerUnPackOp.getDest();
965  auto genericOutType =
966  cast<RankedTensorType>(genericOp.getDpsInitOperand(0)->get().getType());
967  if (producerUnPackOp.getDestType() != genericOutType ||
968  !genericOutType.hasStaticShape()) {
969  unPackDest = tensor::UnPackOp::createDestinationTensor(
970  rewriter, loc, newResult, mixedTiles, innerDimsPos, outerDimsPerm);
971  }
972 
973  // Insert an unPackOp right after the packed generic.
974  Value unPackOpRes =
975  rewriter
976  .create<tensor::UnPackOp>(loc, newResult, unPackDest, innerDimsPos,
977  mixedTiles, outerDimsPerm)
978  .getResult();
979 
980  return std::make_tuple(newGenericOp, unPackOpRes);
981 }
982 
983 // Wrapper pattern that applies pushDownUnPackOpThroughGenericOp method.
984 struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
985 public:
986  PushDownUnPackOpThroughGenericOp(MLIRContext *context,
987                                   ControlPropagationFn fun)
988      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
989 
990  LogicalResult matchAndRewrite(GenericOp genericOp,
991  PatternRewriter &rewriter) const override {
992  if (!controlFn(genericOp))
993  return failure();
994 
995  auto genericAndRepl = pushDownUnPackOpThroughGenericOp(rewriter, genericOp);
996  if (failed(genericAndRepl))
997  return failure();
998  rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
999  return success();
1000  }
1001 
1002 private:
1003  ControlPropagationFn controlFn;
1004 };
1005 
1006 /// Propagate a tensor.unpack operation through a tensor.pad. The idea is to
1007 /// add as many zero padding dimensions to `high` and `low` as there are
1008 /// point loops.
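///
/// Rough sketch (illustrative shapes), mirroring BubbleUpPackThroughPadOp
/// above: the pad is re-created on the still-packed source and the unpack is
/// re-inserted after it:
///
///   pad(unpack(%src : tensor<2x49x32xf32>) : tensor<64x49xf32>) high[0, 15]
///
/// becomes
///
///   unpack(pad(%src) high[0, 15, 0] : tensor<2x64x32xf32>) : tensor<64x64xf32>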
1009 struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
1010  PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
1011  : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {}
1012 
1013  LogicalResult matchAndRewrite(tensor::PadOp padOp,
1014  PatternRewriter &rewriter) const override {
1015  tensor::UnPackOp unpackOp =
1016  padOp.getSource().getDefiningOp<tensor::UnPackOp>();
1017  if (!unpackOp)
1018  return failure();
1019 
1020  if (!controlFn(padOp))
1021  return failure();
1022 
1023  Location loc = padOp.getLoc();
1024  // Bail out if one of the padded dimensions is a tiled one.
1025  llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
1026  ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
1027  llvm::SmallBitVector innerDims(paddedDims.size());
1028  for (int64_t dim : innerDimsPos)
1029  innerDims.flip(dim);
1030  if (paddedDims.anyCommon(innerDims))
1031  return failure();
1032 
1033  Value paddingVal = padOp.getConstantPaddingValue();
1034  if (!paddingVal)
1035  return failure();
1036 
1037  // If we have an `outer_dims_perm`, we need to adjust the padded dimensions.
1038  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
1039  SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
1040  SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
1041  if (!outerDimsPerm.empty()) {
1042  applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
1043  applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
1044  }
1045  // Add zero padding for the point loops.
1046  size_t pointLoopsSize = innerDimsPos.size();
1047  lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
1048  highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
1049 
1050  auto newPadOp = rewriter.create<tensor::PadOp>(
1051  loc, /*result=*/Type(), unpackOp.getSource(), lowPad, highPad,
1052  paddingVal, padOp.getNofold());
1053 
1054  // Inject the tensor.unpack right after the packed padOp.
1055  Value outputUnPack = rewriter.create<tensor::EmptyOp>(
1056  loc, padOp.getResultType().getShape(),
1057  padOp.getResultType().getElementType());
1058 
1059  Value replacement = rewriter.create<tensor::UnPackOp>(
1060  loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
1061  unpackOp.getMixedTiles(), outerDimsPerm);
1062  rewriter.replaceOp(padOp, replacement);
1063  return success();
1064  }
1065 
1066 private:
1067  ControlPropagationFn controlFn;
1068 };
1069 
1070 } // namespace
1071 
1072 void mlir::linalg::populateDataLayoutPropagationPatterns(
1073     RewritePatternSet &patterns,
1074  const ControlPropagationFn &controlPackUnPackPropagation) {
1075  patterns
1076  .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
1077  BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
1078  PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
1079  patterns.getContext(), controlPackUnPackPropagation);
1080 }
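
// A minimal usage sketch (not part of this file; the pass entry point and the
// "propagate everywhere" policy below are illustrative assumptions): a client
// would typically collect these patterns and run them with the greedy pattern
// rewrite driver, using the control function to limit where layout propagation
// is allowed.
//
//   void runDataLayoutPropagation(Operation *root) {
//     RewritePatternSet patterns(root->getContext());
//     linalg::populateDataLayoutPropagationPatterns(
//         patterns, /*controlPackUnPackPropagation=*/[](Operation *op) {
//           // Propagate through every candidate op; a real client could
//           // inspect `op` and return false to stop at certain boundaries.
//           return true;
//         });
//     if (failed(applyPatternsAndFoldGreedily(root, std::move(patterns))))
//       root->emitWarning("data layout propagation did not converge");
//   }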