DataLayoutPropagation.cpp
1 //===- DataLayoutPropagation.cpp -----------------------------------------===///
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
16 #include "mlir/IR/Dominance.h"
17 #include "llvm/ADT/SetOperations.h"
18 #include "llvm/ADT/SetVector.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 #include "llvm/Support/Debug.h"
21 #include <optional>
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
25 #include "mlir/Dialect/Linalg/Passes.h.inc"
26 } // namespace mlir
27 
28 using namespace mlir;
29 using namespace mlir::linalg;
30 
31 #define DEBUG_TYPE "linalg-data-layout-propagation"
32 
33 namespace {
34 
35 static bool hasGatherSemantics(linalg::GenericOp genericOp) {
36  for (Operation &op : genericOp.getBody()->getOperations())
37  if (isa<tensor::ExtractOp, linalg::IndexOp>(op))
38  return true;
39  return false;
40 }
41 
42 // This struct contains the information needed to map packing information onto
43 // the iteration domain of Linalg ops.
44 struct PackInfo {
45  int64_t getNumTiledLoops() const { return tileToPointMapping.size(); };
46  // InnerDimsPos on iteration domain, which follows the order in pack ops.
47  SmallVector<int64_t> tiledDimsPos;
48  // The sizes of tiling data dimensions on iteration domain.
49  llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping;
50  // The mapping from a dimension of iteration domain to the corresponding inner
51  // tiling dimension on iteration domain.
52  llvm::DenseMap<int64_t, int64_t> tileToPointMapping;
53  // The permutation of outer dims (on domain).
54  SmallVector<int64_t> outerDimsOnDomainPerm;
55 };
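// Illustrative sketch (not part of the original file): for a generic op with a
// two-loop iteration domain (d0, d1) whose result (identity indexing map) is
// packed with inner_dims_pos = [1], inner_tiles = [32], and no outer_dims_perm,
// the populated PackInfo would roughly be:
//   tiledDimsPos            = [1]
//   domainDimAndTileMapping = {1 -> 32}
//   tileToPointMapping      = {1 -> 2}   // d2 is the appended point loop.
//   outerDimsOnDomainPerm   = []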
56 
57 template <typename OpTy>
58 static FailureOr<PackInfo>
59 getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
60  OpTy packOrUnPackOp) {
61  static_assert(llvm::is_one_of<OpTy, linalg::PackOp, linalg::UnPackOp>::value,
62  "applies to only pack or unpack operations");
63  LLVM_DEBUG(
64  { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });
65 
66  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
67  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
68  SmallVector<utils::IteratorType> iterators =
69  genericOp.getIteratorTypesArray();
70 
71  PackInfo packInfo;
72  int64_t origNumDims = indexingMap.getNumDims();
73  SmallVector<AffineExpr> exprs(indexingMap.getResults());
74  ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos();
75  for (auto [index, innerDimPos, tileSize] :
76  llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
77  innerDimsPos, packOrUnPackOp.getMixedTiles())) {
78  auto expr = exprs[innerDimPos];
79  if (!isa<AffineDimExpr>(expr))
80  return failure();
81  int64_t domainDimPos =
82  cast<AffineDimExpr>(exprs[innerDimPos]).getPosition();
83  if (!isParallelIterator(iterators[domainDimPos]))
84  return failure();
85  packInfo.tiledDimsPos.push_back(domainDimPos);
86  packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
87  packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
88  LLVM_DEBUG({
89  llvm::dbgs() << "map innerDimPos=" << innerDimPos
90  << " to iteration dimension (d" << domainDimPos << ", d"
91  << packInfo.tileToPointMapping[domainDimPos]
92  << "), which has size=("
93  << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
94  });
95  }
96 
97  // Bail out if a tiled dimension is present in a map but not as an affine dim
98  // expression.
99  auto areAllAffineDimExpr = [&](int dim) {
100  for (AffineMap map : indexingMaps) {
101  if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) {
102  return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr);
103  })) {
104  return false;
105  }
106  }
107  return true;
108  };
109  for (int64_t i : packInfo.tiledDimsPos)
110  if (!areAllAffineDimExpr(i))
111  return failure();
112 
113  // Get the outer dims perm on the iteration domain. Start by identifying the
114  // set of domain dims affected by the outer permutation along with the
115  // permuted ordering for those dims. Then the full outer dims permutation can
116  // be constructed by replacing the affected dims with the permuted result in a
117  // numLoops-rank identity. e.g.
118  // outerDimsPerm = [1, 2, 0]
119  // indexingMap = (d0, d1, d2, d3, d4) -> (d1, d4, d3)
120  //
121  // permutedOuterDims = [4, 3, 1]
122  // outerDimsOnDomainPerm = [0, 4, 2, 3, 1]
123  //
124  // Non-affine dim expressions must not be permuted by the outer dims
125  // permutation.
126  SmallVector<int64_t> permutedOuterDims;
127  for (auto [index, dim] : llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) {
128  auto permutedExpr = indexingMap.getResult(dim);
129  if (auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) {
130  permutedOuterDims.push_back(dimExpr.getPosition());
131  continue;
132  }
133 
134  // TODO: Allow propagation with transposes on non affine dim expressions,
135  // e.g. d0 + d1 which implies transposing both dims simultaneously while
136  // maintaining the relative position between them.
137  if (static_cast<int64_t>(index) != dim)
138  return failure();
139  }
140  if (!permutedOuterDims.empty()) {
141  int64_t outerDimIndex = 0;
142  llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(),
143  permutedOuterDims.end());
144  for (int i = 0, e = indexingMap.getNumDims(); i < e; i++)
145  packInfo.outerDimsOnDomainPerm.push_back(
146  permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++]
147  : i);
148  LLVM_DEBUG({
149  llvm::dbgs() << "map outer dimsPerm to ";
150  for (auto dim : packInfo.outerDimsOnDomainPerm)
151  llvm::dbgs() << dim << " ";
152  llvm::dbgs() << "\n";
153  });
154  }
155 
156  return packInfo;
157 }
158 
159 static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
160  ArrayRef<AffineExpr> exprs) {
161  // Compute `outer_dims_perm`. See example:
162  // current exprs : (d0, d1, d2, d3) -> (d2, d3)
163  // perm : [0, 3, 1, 2]
164  // First map d2, d3 with their position in the array as:
165  // currentPositionTileLoops: dim | pos
166  // d2 | 0
167  // d3 | 1
168  // then scan `perm` in order and get the `outer_dims_perm`
169  // to be used, here it would be [1, 0].
170  assert(!perm.empty() && "expect perm not to be empty");
171  assert(!exprs.empty() && "expect exprs not to be empty");
172  if (exprs.size() == 1)
173  return {};
174  SmallVector<int64_t> outerDimsPerm;
175  DenseMap<int64_t, int64_t> currentPositionTileLoops;
176  for (auto [pos, expr] : llvm::enumerate(exprs)) {
177  // Here we rely on the assumption that the outer dims permutation
178  // when propagating currently requires that non-affine dim expressions
179  // are not permuted, thus allowing the identity assignment below.
180  if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
181  currentPositionTileLoops[dimExpr.getPosition()] = pos;
182  else
183  currentPositionTileLoops[pos] = pos;
184  }
185  for (int64_t loopIdx : perm) {
186  if (currentPositionTileLoops.count(loopIdx))
187  outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx));
188  }
189  return outerDimsPerm;
190 }
191 
192 /// Returns a tuple for packed operand and indexing_map with the assumptions:
193 /// 1) The generic op is the producer of the pack op.
194 /// 2) The generic op has only one result.
195 /// If the operand is a scalar or packing dimensions are all irrelevant to the
196 /// operand, the operand and the updated indexing map will be returned.
197 /// Otherwise, it returns the packed operand and the updated indexing map. E.g.,
198 ///
199 /// #map0 = affine_map<(d0, d1) -> (d0, d1)>
200 /// #map1 = affine_map<(d0, d1) -> (d0)>
201 /// #map2 = affine_map<(d0, d1) -> (d1)>
202 /// %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0],
203 /// iterator_types = ["parallel", "parallel"]}
204 /// ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
205 /// outs(%init : tensor<?x?xf32>) {
206 /// ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
207 /// %4 = arith.addf %arg3, %arg4 : f32
208 /// linalg.yield %4 : f32
209 /// } -> tensor<?x?xf32>
210 /// %1 = linalg.pack %0
211 /// inner_dims_pos = [0, 1]
212 /// inner_tiles = [8, 2]
213 /// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
214 ///
215 /// Taking the first input operand as an example, the inner tile size of d0 is
216 /// 8. Thus, the operation below and the updated indexing map
217 /// `affine_map<(d0, d1, d2, d3) -> (d0, d2)>` will be returned.
218 ///
219 /// %pack = linalg.pack %arg0
220 /// inner_dims_pos = [0]
221 /// inner_tiles = [8]
222 /// into %init : tensor<?xf32> -> tensor<?x8xf32>
223 static std::tuple<Value, AffineMap>
224 getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
225  GenericOp genericOp, OpOperand *opOperand) {
226  int64_t numOrigLoops = genericOp.getNumLoops();
227  int64_t numInnerLoops = packInfo.getNumTiledLoops();
228  int64_t numLoops = numOrigLoops + numInnerLoops;
229  AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
230  llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
231  SmallVector<AffineExpr> exprs(origIndexingMap.getResults());
232 
233  // If the OpOperand is a scalar or a zero-rank tensor, no need to pack.
234  if (genericOp.isScalar(opOperand) || exprs.empty())
235  return std::make_tuple(opOperand->get(),
236  AffineMap::get(numLoops, 0, exprs, b.getContext()));
237 
238  // Step 1. Construct the information of packing data dimensions; append inner
239  // dimensions to the indexing maps for the operand.
240  for (auto [index, expr] : llvm::enumerate(exprs)) {
241  if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
242  int64_t dimPos = dimExpr.getPosition();
243  domainDimToOperandDim[dimPos] = index;
244  continue;
245  }
246  }
247  SmallVector<int64_t> innerDimsPos;
248  SmallVector<OpFoldResult> innerTileSizes;
249  for (auto dimPos : packInfo.tiledDimsPos) {
250  if (!domainDimToOperandDim.count(dimPos))
251  continue;
252  int64_t index = domainDimToOperandDim[dimPos];
253  innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
254  innerDimsPos.push_back(index);
255  exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos]));
256  }
257 
258  // Step 2. Handle outer dim permutations.
259  SmallVector<int64_t> outerDimsPerm;
260  if (!packInfo.outerDimsOnDomainPerm.empty()) {
261  outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs);
262 
263  // Step 2.1: Fold transpose into the linalg.generic.
264  SmallVector<int64_t> inversedOuterPerm =
265  invertPermutationVector(packInfo.outerDimsOnDomainPerm);
266  for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
267  if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) {
268  int64_t dimPos = dimExpr.getPosition();
269  exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
270  continue;
271  }
272  assert(isa<AffineConstantExpr>(exprs[i]) &&
273  "Attempted to permute non-constant and non-affine dim expression");
274  }
275  // Step 2.2: Undo the transposition on `exprs` and propagate the
276  // transposition on the pack using outerDimsPerm.
277  if (!outerDimsPerm.empty()) {
278  SmallVector<AffineExpr> auxVec = exprs;
279  for (const auto &en : enumerate(outerDimsPerm))
280  auxVec[en.index()] = exprs[en.value()];
281  exprs = auxVec;
282  }
283  }
284  auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext());
285 
286  // The operand does not have dimensions that relate to the pack op.
287  if (innerDimsPos.empty() && outerDimsPerm.empty())
288  return std::make_tuple(opOperand->get(), indexingMap);
289 
290  auto empty = linalg::PackOp::createDestinationTensor(
291  b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
292  auto packedOperand = linalg::PackOp::create(
293  b, loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
294  /*padding=*/std::nullopt, outerDimsPerm);
295  return std::make_tuple(packedOperand, indexingMap);
296 }
297 
298 /// This function is a helper subroutine to pack a genericOp and return it. It
299 /// will create a new generic op with the packed operand and the packed output
300 /// according to packInfo when we attempt to push down unpack or bubble up pack
301 /// around it. Implicitly this will only work when a packInfo can be obtained.
302 /// This makes sure that we are only using this function on parallel permuted
303 /// dimensions.
304 static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
305  Value dest, AffineMap packedOutIndexingMap,
306  const PackInfo &packInfo,
307  bool isFoldableUnpackPack) {
308  Location loc = genericOp.getLoc();
309  SmallVector<Value> inputOperands;
310  SmallVector<Value> inputOperandsFromUnpackedSource;
311  SmallVector<AffineMap> indexingMaps;
312  auto hasEquivalentTiles = [](PackOp packOp, UnPackOp unPackOp) {
313  return packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm() &&
314  packOp.getInnerDimsPos() == unPackOp.getInnerDimsPos() &&
315  llvm::equal(packOp.getMixedTiles(), unPackOp.getMixedTiles());
316  };
317  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
318  auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
319  rewriter, loc, packInfo, genericOp, inputOperand);
320  auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>();
321  auto packOp = packedOperand.getDefiningOp<linalg::PackOp>();
322  if (packOp && unpackOp && hasEquivalentTiles(packOp, unpackOp)) {
323  inputOperandsFromUnpackedSource.push_back(unpackOp.getSource());
324  } else {
325  inputOperandsFromUnpackedSource.push_back(packedOperand);
326  }
327  inputOperands.push_back(packedOperand);
328  indexingMaps.push_back(packedIndexingMap);
329  }
330 
331  // If the unpack->pack sequences can be folded, replace the operands with the
332  // sources of the unpack ops in any unpack->pack chains on the generic op operands.
333  if (isFoldableUnpackPack) {
334  inputOperands = inputOperandsFromUnpackedSource;
335  if (auto destPack = dest.getDefiningOp<linalg::PackOp>()) {
336  auto destUnPack = destPack.getSource().getDefiningOp<linalg::UnPackOp>();
337  if (destUnPack && hasEquivalentTiles(destPack, destUnPack)) {
338  dest = destUnPack.getSource();
339  }
340  }
341  }
342 
343  int64_t numInnerLoops = packInfo.getNumTiledLoops();
344  SmallVector<utils::IteratorType> iterTypes =
345  genericOp.getIteratorTypesArray();
346  iterTypes.append(numInnerLoops, utils::IteratorType::parallel);
347 
348  indexingMaps.push_back(packedOutIndexingMap);
349 
350  auto newGenericOp = linalg::GenericOp::create(
351  rewriter, loc, dest.getType(), inputOperands, dest, indexingMaps,
352  iterTypes,
353  /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
354  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
355  newGenericOp.getRegion().begin());
356  return newGenericOp;
357 }
358 
359 static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
360  return llvm::all_of(genericOp.getDpsInitsMutable(), [&](OpOperand &operand) {
361  return genericOp.getMatchingBlockArgument(&operand).use_empty();
362  });
363 }
364 
365 /// Bubbles up a linalg.pack op through a producer generic op. This
366 /// swaps pack(generic) to generic(pack). The new generic op works on the packed
367 /// domain; pack ops are created for input and output operands. E.g.,
368 ///
369 /// #map0 = affine_map<(d0, d1) -> (d0, d1)>
370 /// %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
371 /// %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
372 /// %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
373 /// %3 = linalg.generic {indexing_maps = [#map0, #map0],
374 /// iterator_types = ["parallel", "parallel"]}
375 /// ins(%arg0 : tensor<?x?xf32>)
376 /// outs(%2 : tensor<?x?xf32>) {
377 /// ^bb0(%arg3: f32, %arg4: f32):
378 /// %4 = arith.addf %arg3, %arg3 : f32
379 /// linalg.yield %4 : f32
380 /// } -> tensor<?x?xf32>
381 /// %4 = linalg.pack %3
382 /// inner_dims_pos = [0, 1]
383 /// inner_tiles = [8, 2]
384 /// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
385 ///
386 /// will be converted to
387 ///
388 /// #map = affine_map<()[s0] -> (s0 ceildiv 8)>
389 /// #map1 = affine_map<()[s0] -> (s0 ceildiv 2)>
390 /// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
391 /// %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
392 /// %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
393 /// %0 = affine.apply #map()[%dim]
394 /// %1 = affine.apply #map1()[%dim_0]
395 /// %2 = tensor.empty(%0, %1) : tensor<?x?x8x2xf32>
396 /// %pack = linalg.pack %arg0
397 /// inner_dims_pos = [0, 1]
398 /// inner_tiles = [8, 2]
399 /// into %2 : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
400 /// %3 = linalg.generic {indexing_maps = [#map2, #map2],
401 /// iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
402 /// ins(%pack : tensor<?x?x8x2xf32>)
403 /// outs(%arg1 : tensor<?x?x8x2xf32>) {
404 /// ^bb0(%in: f32, %out: f32):
405 /// %4 = arith.addf %in, %in : f32
406 /// linalg.yield %4 : f32
407 /// } -> tensor<?x?x8x2xf32>
408 static FailureOr<GenericOp>
409 bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
410  const ControlPropagationFn &controlFn) {
411  auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
412  if (!genericOp)
413  return failure();
414 
415  // User controlled propagation function.
416  if (!controlFn(&packOp.getSourceMutable()))
417  return failure();
418 
419  // TODO: Enable propagation in the presence of linalg.index and
420  // tensor.extract, likely as a separate pattern as the pack information and
421  // propagation decision needs to be inferred from the region of the generic.
422  if (hasGatherSemantics(genericOp))
423  return failure();
424 
425  // TODO: Relax the restriction. We are able to bubble up the pack op through
426  // multi-result generic op. It just needs more work.
427  if (genericOp.getNumResults() != 1)
428  return failure();
429 
430  // Bail-out if the result of the generic has multiple uses, as bubbling up
431  // creates recomputation if the generic has multiple users.
432  // TODO: Enable the case where every use is an identical pack op as no
433  // recomputation is needed in that case.
434  if (!genericOp->getResult(0).hasOneUse())
435  return failure();
436 
437  // TODO: Add an option for allowing padding values. It could introduce
438  // undefined behavior if we unconditionally propagate pack op through all
439  // the ops. E.g., if the padding value is zero and there are division ops in
440  // a generic op. Some values of padding area could be NaN (0/0).
441  if (packOp.getPaddingValue())
442  return failure();
443 
444  OpOperand *opOperand = genericOp.getDpsInitOperand(0);
445  auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
446  if (failed(packInfo))
447  return failure();
448 
449  // We want to move the pack not the generic.
450  OpBuilder::InsertionGuard guard(rewriter);
451  rewriter.setInsertionPoint(genericOp);
452 
453  // We need to handle two cases:
454  // 1) The linalg.pack destination is a tensor.empty. If this is the case, we
455  // create a new tensor.empty to avoid breaking dominance, as we are moving the
456  // linalg.pack above the linalg.generic.
457  // 2) The destination is not a tensor.empty. In this case we can replace only
458  // if the destination of the linalg.pack dominates the linalg.generic.
459  Value packOpDest = packOp.getDest();
460  if (!packOpDest.hasOneUse())
461  return failure();
462  if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
463  packOpDest = tensor::EmptyOp::create(rewriter, genericOp->getLoc(),
464  emptyOp.getMixedSizes(),
465  emptyOp.getType().getElementType());
466  } else {
467  DominanceInfo dom(genericOp);
468  if (!dom.properlyDominates(packOpDest, genericOp))
469  return failure();
470  }
471 
472  // Rebuild the indexing map for the corresponding init operand.
473  auto [packedOutOperand, packedOutIndexingMap] =
474  getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
475  genericOp, opOperand);
476 
477  // Forward the new tensor.empty as a destination if it is one of the following
478  // situations:
479  // 1) The dps init operand is a tensor.empty.
480  // 2) The dps init is a write-only operand, i.e., it is not used in the
481  // genericOp
482  Value dest = packedOutOperand;
483  auto initTensor =
484  genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>();
485  if (initTensor || isGenericOutsNotUsed(genericOp)) {
486  dest = packOpDest;
487  }
488  // pack(unpack) isn't naively foldable because the unpack op can be from
489  // an arbitrary domain so we need to keep both.
490  return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
491  *packInfo, /*isFoldableUnpackPack=*/false);
492 }
493 
494 /// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
495 struct BubbleUpPackOpThroughGenericOpPattern
496  : public OpRewritePattern<linalg::PackOp> {
497 public:
498  BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
499  ControlPropagationFn fun)
500  : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}
501 
502  LogicalResult matchAndRewrite(linalg::PackOp packOp,
503  PatternRewriter &rewriter) const override {
504  auto genericOp =
505  bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
506  if (failed(genericOp))
507  return failure();
508  rewriter.replaceOp(packOp, genericOp->getResults());
509  return success();
510  }
511 
512 private:
513  ControlPropagationFn controlFn;
514 };
515 
516 /// Propagate a linalg.pack operation up through a tensor.pad. The idea is to
517 /// add as many zero-padding entries to `high` and `low` as there are point
518 /// loops.
519 class BubbleUpPackThroughPadOp final : public OpRewritePattern<linalg::PackOp> {
520 public:
521  BubbleUpPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
522  : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}
523 
524  LogicalResult matchAndRewrite(linalg::PackOp packOp,
525  PatternRewriter &rewriter) const override {
526  auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
527  if (!padOp)
528  return failure();
529 
530  // User controlled propagation function.
531  if (!controlFn(&packOp.getSourceMutable()))
532  return failure();
533 
534  // TODO: Enable padding when the padding values are the same.
535  if (packOp.getPaddingValue())
536  return failure();
537 
538  // Fail for non-constant padding values. The body of the pad could
539  // depend on the padding indices and/or properties of the padded
540  // tensor so for now we fail.
541  // TODO: Support non-constant padding values.
542  Value paddingVal = padOp.getConstantPaddingValue();
543  if (!paddingVal)
544  return failure();
545 
546  if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
547  return failure();
548 
549  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
550 
551  // Bail out if one of the padded dimensions is a tiled one.
552  llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
553  llvm::SmallBitVector innerDims(paddedDims.size());
554  for (int64_t dim : innerDimsPos)
555  innerDims.flip(dim);
556  if (paddedDims.anyCommon(innerDims))
557  return failure();
558 
559  Location loc = padOp->getLoc();
560  OpBuilder::InsertionGuard guard(rewriter);
561  rewriter.setInsertionPoint(padOp);
562 
563  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
564  SmallVector<OpFoldResult> mixedTiles = packOp.getMixedTiles();
565  auto empty = linalg::PackOp::createDestinationTensor(
566  rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos,
567  outerDimsPerm);
568  auto sourcePack = linalg::PackOp::create(
569  rewriter, loc, padOp.getSource(), empty, innerDimsPos, mixedTiles,
570  /*padding=*/std::nullopt, outerDimsPerm);
571 
572  // If we have `outer_dims_perm` we need to adjust the padded dimensions.
573  SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
574  SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
575  if (!outerDimsPerm.empty()) {
576  applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
577  applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
578  }
579  // The tiled dimensions were verified to be unpadded above, so here we
580  // just append 0 for the inner tile dimensions.
581  size_t pointLoopsSize = innerDimsPos.size();
582  lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
583  highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
584 
585  auto newPadOp =
586  tensor::PadOp::create(rewriter, loc, /*result=*/Type(), sourcePack,
587  lowPad, highPad, paddingVal, padOp.getNofold());
588 
589  // If the pad has more than one user, create an unpack on the new pad to
590  // replace the other uses.
591  if (!padOp->hasOneUse()) {
592  auto unpackEmpty = linalg::UnPackOp::createDestinationTensor(
593  rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm);
594  Value unpackedPad =
595  linalg::UnPackOp::create(rewriter, loc, newPadOp, unpackEmpty,
596  innerDimsPos, mixedTiles, outerDimsPerm);
597  rewriter.replaceAllUsesExcept(padOp, unpackedPad, sourcePack);
598  }
599 
600  // Replace the pack with the new pad.
601  rewriter.replaceOp(packOp, newPadOp.getResult());
602 
603  return success();
604  }
605 
606 private:
607  ControlPropagationFn controlFn;
608 };
609 
610 /// Project dimsPos to the inner-most non-unit dim pos with reassocIndices.
611 ///
612 /// For example, given dimsPos [0, 2], reassocIndices [[0, 1], [2, 3]], and
613 /// targetShape [16, 16, 32, 1], it returns [1, 2], because for pos 0 the
614 /// inner-most non-unit projected dim in [0, 1] is 1, and for pos 2 the
615 /// inner-most non-unit projected dim in [2, 3] is 2.
616 ///
617 /// If all candidates in a reassociation are unit dims, it chooses the
618 /// inner-most dim pos.
619 static SmallVector<int64_t>
620 projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
621  ArrayRef<ReassociationIndices> reassocIndices,
622  ArrayRef<int64_t> targetShape) {
623  SmallVector<int64_t> projectedDimsPos;
624  for (auto pos : dimsPos) {
625  // In the case all dims are unit, this will return the inner-most one.
626  int64_t projectedPos = reassocIndices[pos].back();
627  for (auto i : llvm::reverse(reassocIndices[pos])) {
628  int64_t dim = targetShape[i];
629  if (dim > 1 || ShapedType::isDynamic(dim)) {
630  projectedPos = i;
631  break;
632  }
633  }
634  projectedDimsPos.push_back(projectedPos);
635  }
636  return projectedDimsPos;
637 }
638 
639 /// Check if all dims in dimsPos are divisible by the corresponding tile sizes.
640 static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
641  ArrayRef<int64_t> shape,
642  ArrayRef<int64_t> tileSizes) {
643  for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
644  int64_t dim = shape[pos];
645  if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
646  return false;
647  }
648  return true;
649 }
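// A minimal worked example (assumed values, not from the original file): with
// dimsPos = [0, 2], shape = [16, 7, 32], and tileSizes = [8, 16], shape[0] = 16
// and shape[2] = 32 are divisible by 8 and 16 respectively, so this returns
// true; a dynamic dim or, e.g., shape[2] = 30 would make it return false.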
650 
651 /// Permute the reassociation indices and reindex them in sequence order.
652 /// Returns the next dim pos in the sequence.
653 ///
654 /// For example, given reassocIndices [[0, 1], [2]] and permutation [1, 0], it
655 /// applies the permutation to get [[2], [0, 1]] and reindexes the indices into
656 /// [[0], [1, 2]].
657 static int64_t applyPermutationAndReindexReassoc(
658  SmallVector<ReassociationIndices> &reassocIndices,
659  ArrayRef<int64_t> permutation) {
660  if (!permutation.empty())
661  applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
662  int64_t nextPos = 0;
663  for (ReassociationIndices &indices : reassocIndices) {
664  for (auto &index : indices) {
665  index = nextPos;
666  nextPos += 1;
667  }
668  }
669  return nextPos;
670 }
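// For the example above, the reindexed reassociations are [[0], [1, 2]] and the
// returned nextPos is 3; the callers below use it as the first position when
// appending direct [nextPos] entries for the inner tile dims.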
671 
672 /// Bubble up pack op through collapse shape op when the packed dims can be
673 /// projected to the dims before collapsing. This is possible when the inner
674 /// tile sizes can divide the projected dims.
675 ///
676 /// For example:
677 ///
678 /// %collapsed = tensor.collapse_shape %in [[0, 1], 2]
679 /// : tensor<?x16x4xf32> into tensor<?x4xf32>
680 /// %pack = linalg.pack %collapsed outer_dims_perm = [0, 1]
681 /// inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty
682 /// : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
683 ///
684 /// can be transformed into:
685 ///
686 /// %pack = linalg.pack %in outer_dims_perm = [1, 2]
687 /// inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty
688 /// : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
689 /// %collapsed = tensor.collapse_shape %pack [[0, 1], 2, 3, 4]
690 /// : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
691 static LogicalResult
692 bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
693  linalg::PackOp packOp,
694  PatternRewriter &rewriter) {
695  SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
696  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
697  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
698 
699  ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
700  SmallVector<ReassociationIndices> reassocIndices =
701  collapseOp.getReassociationIndices();
702  // Project inner tile pos to the dim pos before collapsing. For example, if
703  // dims [x, y] are collapsed into [z], packing on dim z can be projected back
704  // to pack on dim y.
705  //
706  // Project to inner-most non-unit dims to increase the chance that they can be
707  // divided by the inner tile sizes. This is correct because for [..., x, 1],
708  // packing on dim 1 is equivalent to packing on dim x.
709  SmallVector<int64_t> projectedInnerDimsPos =
710  projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);
711 
712  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
713  innerTileSizes)) {
714  return failure();
715  }
716  // Expand the outer dims permutation with the associated source dims for the
717  // new permutation after bubbling. This is because moving a collapsed dim is
718  // equivalent to moving the associated source dims together.
719  SmallVector<int64_t> newOuterDimsPerm;
720  for (auto outerPos : outerDimsPerm)
721  llvm::append_range(newOuterDimsPerm, reassocIndices[outerPos]);
722 
723  auto emptyOp = linalg::PackOp::createDestinationTensor(
724  rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
725  projectedInnerDimsPos, newOuterDimsPerm);
726  auto newPackOp = linalg::PackOp::create(
727  rewriter, packOp.getLoc(), collapseOp.getSrc(), emptyOp,
728  projectedInnerDimsPos, packOp.getMixedTiles(), packOp.getPaddingValue(),
729  newOuterDimsPerm);
730 
731  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
732  // First apply the permutation on the reassociations of the outer dims.
733  // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
734  // -> [[0], [1, 2]]
735  int64_t nextPos =
736  applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
737  // Then add direct mapping for the inner tile dims.
738  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
739  newReassocIndices.push_back({nextPos});
740  nextPos += 1;
741  }
742 
743  auto newCollapseOp = tensor::CollapseShapeOp::create(
744  rewriter, collapseOp.getLoc(), packOp.getType(), newPackOp,
745  newReassocIndices);
746  rewriter.replaceOp(packOp, newCollapseOp);
747 
748  return success();
749 }
750 
751 /// Project dimsPos to their collapsed positions in the reassocIndices.
752 ///
753 /// For example, given dimsPos [0, 1, 2, 4], and matching reassocIndices
754 /// [[0], [1, 2], [3], [4]], it returns [0, 1, 1, 3]. Because for pos 0,
755 /// the reassoc dim [0] is 0. For pos 1 and 2, the reassoc dim in pos
756 /// [1, 2] is 1. And for pos 4, the reassoc dim [4] is 3.
757 static SmallVector<int64_t>
758 projectDimsPosIntoReassocPos(ArrayRef<int64_t> dimsPos,
759  ArrayRef<ReassociationIndices> reassocIndices) {
760  SmallVector<int64_t> projectedPos;
761 
762  // Map each dimension to the position of corresponding reassociation index.
763  for (auto pos : dimsPos) {
764  for (auto [idx, indices] : llvm::enumerate(reassocIndices)) {
765  // If the dimension is present in the current indices group, the group
766  // position within the reassociation map is the desired projected
767  // dimension position.
768  if (llvm::is_contained(indices, pos)) {
769  projectedPos.push_back(idx);
770  break;
771  }
772  }
773  }
774  assert(projectedPos.size() == dimsPos.size() && "Invalid dim pos projection");
775 
776  return projectedPos;
777 }
778 
779 /// Bubble up pack op through expand shape op.
780 ///
781 /// For example:
782 ///
783 /// %expand = tensor.expand_shape %in [[0], [1, 2]]
784 /// : tensor<?x64xf32> into tensor<?x4x16xf32>
785 /// %pack = linalg.pack %expand outer_dims_perm = [0, 1]
786 /// inner_dims_pos = [2] inner_tiles = [8] into %empty
787 /// : tensor<?x4x16xf32> -> tensor<?x4x2x8xf32>
788 ///
789 /// can be transformed into:
790 ///
791 /// %pack = linalg.pack %in outer_dims_perm = [1, 2]
792 /// inner_dims_pos = [1] inner_tiles = [8] into %empty
793 /// : tensor<?x64xf32> -> tensor<?x8x8xf32>
794 /// %expand = tensor.expand_shape %pack [[0], [1, 2], [3]]
795 /// : tensor<?x8x8xf32> into tensor<?x4x2x8xf32>
796 static LogicalResult
797 bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
798  linalg::PackOp packOp,
799  PatternRewriter &rewriter) {
800  // Outer dimensions permutation is not supported currently.
801  // TODO: Handle outer_dims_perm variants.
802  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
803  if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
804  return rewriter.notifyMatchFailure(packOp,
805  "non-identity outer dims perm NYI");
806  }
807 
808  // Validate dimensions' relations between shape expansion and packing.
809  SmallVector<ReassociationIndices> reassoc =
810  expandOp.getReassociationIndices();
811  ArrayRef<int64_t> packInnerDims = packOp.getInnerDimsPos();
812  llvm::SetVector<int64_t> packDimsPos(llvm::from_range, packInnerDims);
813 
814  for (auto [idx, indices] : llvm::enumerate(reassoc)) {
815  // For each expand_shape reassociation, figure out which dimensions get
816  // packed if any.
817  llvm::SetVector<int64_t> expandDimPos(llvm::from_range, indices);
818  llvm::SetVector<int64_t> packedDims =
819  llvm::set_intersection(packDimsPos, expandDimPos);
820 
821  // The expanded dimension is not packed, so it does not affect moving the pack
822  // before the shape expansion; simply continue.
823  if (packedDims.empty())
824  continue;
825  // Shape expansion cannot be propagated when multiple expanded dimensions are
826  // packed; in this case operation reordering would affect final element
827  // positions and/or shapes can no longer be projected.
828  if (packedDims.size() != 1)
829  return rewriter.notifyMatchFailure(
830  packOp, "only one of the expanded dimensions can be packed");
831  // Only the inner-most expanded dimension should be packed. Otherwise,
832  // element order will be affected after operation reordering.
833  if (packedDims.front() != indices.back())
834  return rewriter.notifyMatchFailure(
835  packOp, "can only pack the inner-most expanded dimension");
836  }
837 
838  // Project pack.inner_dims_pos to positions before shape expansion.
839  SmallVector<int64_t> projectedInnerDimsPos =
840  projectDimsPosIntoReassocPos(packInnerDims, reassoc);
841 
842  // Project the shape expansion to new packed shape.
843  // The pack.outer_dims_perm is restricted to identity, so the permutation can
844  // be omitted for simplicity.
845  // TODO: Account for outer dimensions permutation.
846  //
847  // If reassociation is not possible, then reordering cannot happen.
848  // This can be caused by pack padding affecting previously expanded
849  // dimensions or packing extending dimensions.
850  RankedTensorType newPackType = linalg::PackOp::inferPackedType(
851  expandOp.getSrcType(), packOp.getStaticInnerTiles(),
852  projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
853  auto reassocExpand =
854  getReassociationIndicesForReshape(newPackType, packOp.getDestType());
855  if (!reassocExpand)
856  return rewriter.notifyMatchFailure(
857  packOp, "could not reassociate dims after bubbling up");
858 
859  Value destTensor = linalg::PackOp::createDestinationTensor(
860  rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(),
861  projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
862  Value packedVal = linalg::PackOp::create(
863  rewriter, packOp.getLoc(), expandOp.getSrc(), destTensor,
864  projectedInnerDimsPos, packOp.getMixedTiles(), packOp.getPaddingValue(),
865  /*outerDimsPerm=*/SmallVector<int64_t>{});
866 
867  Value newExpandOp = tensor::ExpandShapeOp::create(rewriter, packOp.getLoc(),
868  packOp.getDestType(),
869  packedVal, *reassocExpand);
870  rewriter.replaceOp(packOp, newExpandOp);
871 
872  return success();
873 }
874 
875 class BubbleUpPackOpThroughReshapeOp final
876  : public OpRewritePattern<linalg::PackOp> {
877 public:
878  BubbleUpPackOpThroughReshapeOp(MLIRContext *context, ControlPropagationFn fun)
879  : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}
880 
881  LogicalResult matchAndRewrite(linalg::PackOp packOp,
882  PatternRewriter &rewriter) const override {
883  Operation *srcOp = packOp.getSource().getDefiningOp();
884  // Currently only support when the pack op is the only user.
885  if (!srcOp || !(srcOp->getNumResults() == 1) ||
886  !srcOp->getResult(0).hasOneUse()) {
887  return failure();
888  }
889  // Currently only support static inner tile sizes.
890  if (llvm::any_of(packOp.getStaticTiles(), ShapedType::isDynamic))
891  return failure();
892 
893  // User controlled propagation function.
894  if (!controlFn(&packOp.getSourceMutable()))
895  return failure();
896 
897  return TypeSwitch<Operation *, LogicalResult>(srcOp)
898  .Case([&](tensor::CollapseShapeOp op) {
899  return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
900  })
901  .Case([&](tensor::ExpandShapeOp op) {
902  return bubbleUpPackOpThroughExpandShape(op, packOp, rewriter);
903  })
904  .Default([](Operation *) { return failure(); });
905  }
906 
907 private:
908  ControlPropagationFn controlFn;
909 };
910 
911 /// Push down unpack op through expand shape op when the packed dims can be
912 /// projected to the dims after expanding. This is possible when the inner tile
913 /// sizes can divide the projected dims.
914 ///
915 /// For example:
916 ///
917 /// %unpack = linalg.unpack %in outer_dims_perm = [0, 1]
918 /// inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %empty
919 /// : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
920 /// %expanded = tensor.expand_shape %unpack [[0, 1], [2]]
921 /// : tensor<?x256xf32> into tensor<?x256x256xf32>
922 ///
923 /// can be transformed into:
924 ///
925 /// %expanded = tensor.expand_shape %in [[0, 1], [2], [3], [4]]
926 /// : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
927 /// %unpack = linalg.unpack %expanded outer_dims_perm = [0, 1, 2]
928 /// inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty
929 /// : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
930 static LogicalResult pushDownUnPackOpThroughExpandShape(
931  linalg::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
932  PatternRewriter &rewriter, ControlPropagationFn controlFn) {
933  // User controlled propagation function.
934  if (!controlFn(&expandOp.getSrcMutable()))
935  return failure();
936 
937  SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
938  ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
939  ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
940 
941  auto expandTy = dyn_cast<RankedTensorType>(expandOp.getType());
942  if (!expandTy)
943  return failure();
944  ArrayRef<int64_t> dstShape = expandTy.getShape();
945  SmallVector<ReassociationIndices> reassocIndices =
946  expandOp.getReassociationIndices();
947  // Project inner tile pos to the dim pos after expanding. For example, if dim
948  // [z] is expanded into [x, y], unpacking on dim z can be projected to unpacking
949  // on dim y.
950  //
951  // Project to inner-most non-unit dims to increase the chance that they can be
952  // divided by the inner tile sizes. This is correct because for [..., x, 1],
953  // unpacking on dim 1 is equivalent to unpacking on dim x.
954  SmallVector<int64_t> projectedInnerDimsPos =
955  projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);
956 
957  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
958  innerTileSizes)) {
959  return failure();
960  }
961  // Expand the outer dims permutation with the associated expanded dims for the
962  // new permutation after pushing. This is because moving a source dim is
963  // equivalent to moving the associated expanded dims together.
964  SmallVector<int64_t> newOuterDimsPerm;
965  for (auto outerPos : outerDimsPerm)
966  llvm::append_range(newOuterDimsPerm, reassocIndices[outerPos]);
967 
968  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
969  // First apply the permutation on the reassociations of the outer dims.
970  // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
971  // -> [[0], [1, 2]]
972  int64_t nextPos =
973  applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
974  // Then add direct mapping for the inner tile dims.
975  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
976  newReassocIndices.push_back({nextPos});
977  nextPos += 1;
978  }
979 
980  RankedTensorType newExpandType = linalg::PackOp::inferPackedType(
981  expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm);
982  auto newExpandOp =
983  tensor::ExpandShapeOp::create(rewriter, expandOp.getLoc(), newExpandType,
984  unPackOp.getSource(), newReassocIndices);
985 
986  auto emptyOp = linalg::UnPackOp::createDestinationTensor(
987  rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
988  projectedInnerDimsPos, newOuterDimsPerm);
989  auto newUnPackOp = linalg::UnPackOp::create(
990  rewriter, unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
991  projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
992  rewriter.replaceOp(expandOp, newUnPackOp);
993 
994  return success();
995 }
996 
997 class PushDownUnPackOpThroughReshapeOp final
998  : public OpRewritePattern<linalg::UnPackOp> {
999 public:
1000  PushDownUnPackOpThroughReshapeOp(MLIRContext *context,
1001  ControlPropagationFn fun)
1002  : OpRewritePattern<linalg::UnPackOp>(context), controlFn(std::move(fun)) {
1003  }
1004 
1005  LogicalResult matchAndRewrite(linalg::UnPackOp unPackOp,
1006  PatternRewriter &rewriter) const override {
1007  Value result = unPackOp.getResult();
1008  // Currently only support an unpack op with a single user.
1009  if (!result.hasOneUse()) {
1010  return failure();
1011  }
1012  // Currently only support static inner tile sizes.
1013  if (llvm::any_of(unPackOp.getStaticTiles(), ShapedType::isDynamic))
1014  return failure();
1015 
1016  Operation *consumerOp = *result.user_begin();
1017  return TypeSwitch<Operation *, LogicalResult>(consumerOp)
1018  .Case([&](tensor::ExpandShapeOp op) {
1019  return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter,
1020  controlFn);
1021  })
1022  .Default([](Operation *) { return failure(); });
1023  }
1024 
1025 private:
1026  ControlPropagationFn controlFn;
1027 };
1028 
1029 // TODO: Relax this restriction. We should unpack a generic op also
1030 // in the presence of multiple unpack ops as producers.
1031 /// Return the unpacked operand, if present, for the current generic op.
1032 static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
1033  OpOperand *unPackedOperand = nullptr;
1034  for (OpOperand &operand : genericOp->getOpOperands()) {
1035  auto unPackOp = operand.get().getDefiningOp<linalg::UnPackOp>();
1036  if (!unPackOp)
1037  continue;
1038  if (unPackedOperand)
1039  return failure();
1040  unPackedOperand = &operand;
1041  }
1042  if (!unPackedOperand)
1043  return failure();
1044  return unPackedOperand;
1045 }
1046 
1047 /// Push down a linalg.unpack op through a generic op.
1048 /// The new generic op works on packed domain; pack ops are created for input
1049 /// and output operands. A linalg.unpack op is inserted right after the packed
1050 /// generic. E.g.
1051 ///
1052 /// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
1053 ///
1054 /// %arg0 = tensor<12x2x56x56x32xf32> // packed arg.
1055 ///
1056 /// %0 = tensor.empty() : tensor<12x56x56x64xf32>
1057 /// %1 = linalg.unpack %arg0 outer_dims_perm = [0, 3, 1, 2]
1058 /// inner_dims_pos = [3] inner_tiles = [32] into %0
1059 /// %2 = linalg.generic {indexing_maps = [#map],
1060 /// iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
1061 /// outs(%1 : tensor<12x56x56x64xf32>) {
1062 /// ^bb0(%out : f32):
1063 /// linalg.yield %out : f32
1064 /// } -> tensor<12x56x56x64xf32>
1065 ///
1066 /// will be converted to
1067 ///
1068 /// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
1069 ///
1070 /// %0 = tensor.empty() : tensor<12x56x56x64xf32>
1071 /// %1 = linalg.generic {indexing_maps = [#map],
1072 /// iterator_types = ["parallel", "parallel", "parallel",
1073 /// "parallel", "parallel"]}
1074 /// outs(%arg0 : tensor<12x2x56x56x32xf32>) {
1075 /// ^bb0(%out : f32):
1076 /// linalg.yield %out : f32
1077 /// } -> tensor<12x2x56x56x32xf32>
1078 /// %2 = linalg.unpack %1 outer_dims_perm = [0, 3, 1, 2]
1079 /// inner_dims_pos = [3] inner_tiles = [32] into %0
1080 ///
1081 static FailureOr<std::tuple<GenericOp, Value>>
1082 pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
1083  ControlPropagationFn controlFn) {
1084  if (genericOp.getNumResults() != 1)
1085  return failure();
1086 
1087  if (hasGatherSemantics(genericOp))
1088  return failure();
1089 
1090  // Collect the unPacked operand, if present.
1091  auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
1092  if (failed(maybeUnPackedOperand))
1093  return failure();
1094  OpOperand *unPackedOperand = *(maybeUnPackedOperand);
1095 
1096  // Extract packing information.
1097  linalg::UnPackOp producerUnPackOp =
1098  unPackedOperand->get().getDefiningOp<linalg::UnPackOp>();
1099  assert(producerUnPackOp && "expect a valid UnPackOp");
1100 
1101  if (!controlFn(unPackedOperand))
1102  return failure();
1103 
1104  auto packInfo =
1105  getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
1106  if (failed(packInfo))
1107  return failure();
1108 
1109  // Rebuild the indexing map for the corresponding init operand.
1110  auto [packedOutOperand, packedOutIndexingMap] =
1111  getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
1112  genericOp, genericOp.getDpsInitOperand(0));
1113  auto destPack = packedOutOperand.getDefiningOp<linalg::PackOp>();
1114 
1115  // Forward the new tensor.empty as a destination if it is one of the following
1116  // situations:
1117  // 1) The dps init operand is a tensor.empty.
1118  // 2) The dps init is a write-only operand, i.e., it is not used in the
1119  // genericOp
1120  Value dest = packedOutOperand;
1121  auto initTensor =
1122  genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>();
1123  if (initTensor || isGenericOutsNotUsed(genericOp)) {
1124  if (destPack)
1125  dest = destPack.getDest();
1126  }
1127 
1128  // Pack the genericOp.
1129  // pack(unpack) is foldable in this case. This is because in pushing down the
1130  // unpack, by default we will populate an additional pack op after the unpack.
1131  // This guarantees that they are foldable.
1132  GenericOp newGenericOp =
1133  packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo,
1134  /*isFoldableUnpackPack=*/true);
1135  Value newResult =
1136  newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));
1137 
1138  // If the output is unaffected, no need to unpack.
1139  if (!destPack)
1140  return std::make_tuple(newGenericOp, newResult);
1141 
1142  auto mixedTiles = destPack.getMixedTiles();
1143  auto innerDimsPos = destPack.getInnerDimsPos();
1144  auto outerDimsPerm = destPack.getOuterDimsPerm();
1145 
1146  // Insert an unPackOp right after the packed generic.
1147  Value unPackOpRes =
1148  linalg::UnPackOp::create(rewriter, genericOp.getLoc(), newResult,
1149  destPack.getSource(), innerDimsPos, mixedTiles,
1150  outerDimsPerm)
1151  .getResult();
1152 
1153  return std::make_tuple(newGenericOp, unPackOpRes);
1154 }
1155 
1156 // Wrapper pattern that applies pushDownUnPackOpThroughGenericOp method.
1157 struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
1158 public:
1159  PushDownUnPackOpThroughGenericOp(MLIRContext *context,
1160  ControlPropagationFn fun)
1161  : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
1162 
1163  LogicalResult matchAndRewrite(GenericOp genericOp,
1164  PatternRewriter &rewriter) const override {
1165  auto genericAndRepl =
1166  pushDownUnPackOpThroughGenericOp(rewriter, genericOp, controlFn);
1167  if (failed(genericAndRepl))
1168  return failure();
1169  rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
1170  return success();
1171  }
1172 
1173 private:
1174  ControlPropagationFn controlFn;
1175 };
1176 
1177 /// Propagate a linalg.unpack operation through a tensor.pad. The idea is to
1178 /// add as many zero-padding entries to `high` and `low` as there are point
1179 /// loops.
1180 struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
1181  PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
1182  : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {}
1183 
1184  LogicalResult matchAndRewrite(tensor::PadOp padOp,
1185  PatternRewriter &rewriter) const override {
1186  linalg::UnPackOp unpackOp =
1187  padOp.getSource().getDefiningOp<linalg::UnPackOp>();
1188  if (!unpackOp)
1189  return failure();
1190 
1191  if (!controlFn(&padOp.getSourceMutable()))
1192  return failure();
1193 
1194  Location loc = padOp.getLoc();
1195  // Bail out if one of the padded dimensions is a tiled one.
1196  llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
1197  ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
1198  llvm::SmallBitVector innerDims(paddedDims.size());
1199  for (int64_t dim : innerDimsPos)
1200  innerDims.flip(dim);
1201  if (paddedDims.anyCommon(innerDims))
1202  return failure();
1203 
1204  Value paddingVal = padOp.getConstantPaddingValue();
1205  if (!paddingVal)
1206  return failure();
1207 
1208  // If we have `outer_dims_perm` we need to adjust the padded dimensions.
1209  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
1210  SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
1211  SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
1212  if (!outerDimsPerm.empty()) {
1213  applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
1214  applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
1215  }
1216  // Add zero padding for the point loops.
1217  size_t pointLoopsSize = innerDimsPos.size();
1218  lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
1219  highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
1220 
1221  auto newPadOp = tensor::PadOp::create(rewriter, loc, /*result=*/Type(),
1222  unpackOp.getSource(), lowPad, highPad,
1223  paddingVal, padOp.getNofold());
1224 
1225  // Inject the linalg.unpack right after the packed padOp.
1226  Value outputUnPack =
1227  tensor::EmptyOp::create(rewriter, loc, padOp.getResultType().getShape(),
1228  padOp.getResultType().getElementType());
1229 
1230  Value replacement = linalg::UnPackOp::create(
1231  rewriter, loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
1232  unpackOp.getMixedTiles(), outerDimsPerm);
1233  rewriter.replaceOp(padOp, replacement);
1234  return success();
1235  }
1236 
1237 private:
1238  ControlPropagationFn controlFn;
1239 };
1240 
1241 // This struct contains information about extract_slice dims.
1242 struct SliceDimInfo {
1243  OpFoldResult offset;
1244  OpFoldResult sliceSize;
1245  OpFoldResult outputSize;
1246 };
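// Illustrative sketch (assumed example, not from the original file): for
//   %s = tensor.extract_slice %t[4] [16] [1] : tensor<64xf32> to tensor<16xf32>
// the sliced dim is recorded as offset = 4, sliceSize = 16, outputSize = 64
// (the source dim size); the push-down logic below turns this into a low pad of
// 4 and a high pad of 64 - 4 - 16 = 44 on operands indexed by that dim.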
1247 
1248 /// Return the first input extract slice operand, if present, for the current
1249 /// generic op.
1250 static FailureOr<OpOperand *> getSliceOperand(GenericOp genericOp) {
1251  OpOperand *sliceOperand = nullptr;
1252  for (auto operand : genericOp.getDpsInputOperands()) {
1253  auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
1254  if (!extractOp)
1255  continue;
1256  sliceOperand = operand;
1257  break;
1258  }
1259  if (!sliceOperand) {
1260  return failure();
1261  }
1262  return sliceOperand;
1263 }
1264 
1265 // Return a map of dims that have partial slices on them so that other operands
1266 // can use this information. Whether a reduction dim has a non-full slice also
1267 // matters, as that can be used to fold the original extract slice.
1268 static FailureOr<llvm::DenseMap<int64_t, SliceDimInfo>>
1269 getPartialSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand) {
1270  tensor::ExtractSliceOp producerSliceOp =
1271  sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
1272  assert(producerSliceOp && "expect a valid ExtractSliceOp");
1273  llvm::DenseMap<int64_t, SliceDimInfo> partialSliceDimMap;
1274  SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
1275  SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
1276 
1277  SmallVector<OpFoldResult> shape = getAsIndexOpFoldResult(
1278  genericOp.getContext(), producerSliceOp.getSourceType().getShape());
1279 
1280  for (auto [idx, expr] : llvm::enumerate(
1281  genericOp.getMatchingIndexingMap(sliceOperand).getResults())) {
1282  // If we have a full slice in a dimension then we don't need to add it to
1283  // the partial slice map.
1284  if (isConstantIntValue(offsets[idx], 0) &&
1285  isEqualConstantIntOrValue(sizes[idx], shape[idx])) {
1286  continue;
1287  }
1288  // We only support partial slices of AffineDimExprs, so bail out if that's not
1289  // the case.
1290  if (!isa<AffineDimExpr>(expr)) {
1291  return failure();
1292  }
1293  SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]};
1294  int64_t dimPos = cast<AffineDimExpr>(expr).getPosition();
1295  partialSliceDimMap[dimPos] = sliceDimInfo;
1296  }
1297  // Next, check if the dims with partial slice info are used in non-AffineDimExpr
1298  // expressions in other operands, and if they are, bail out.
1299  for (OpOperand &operand : genericOp->getOpOperands()) {
1300  if (operand == *sliceOperand) {
1301  continue;
1302  }
1303  AffineMap IndexingMap = genericOp.getMatchingIndexingMap(&operand);
1304  if (llvm::any_of(IndexingMap.getResults(), [&](AffineExpr expr) {
1305  if (isa<AffineDimExpr>(expr)) {
1306  return false;
1307  }
1308  WalkResult status = expr.walk([&](AffineExpr expr) {
1309  if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
1310  if (partialSliceDimMap.contains(dimExpr.getPosition())) {
1311  return WalkResult::interrupt();
1312  }
1313  }
1314  return WalkResult::advance();
1315  });
1316  if (status.wasInterrupted()) {
1317  return true;
1318  }
1319  return false;
1320  })) {
1321  return failure();
1322  }
1323  }
1324  return partialSliceDimMap;
1325 }
1326 
1327 static FailureOr<std::tuple<GenericOp, Value>>
1328 pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
1329  GenericOp genericOp,
1330  ControlPropagationFn controlFn) {
1331  if (genericOp.getNumResults() != 1)
1332  return rewriter.notifyMatchFailure(
1333  genericOp, "propagation through multi-result generic is unsupported.");
1334  if (hasGatherSemantics(genericOp))
1335  return rewriter.notifyMatchFailure(
1336  genericOp,
1337  "propagation through generic with gather semantics is unsupported.");
1338  // Collect the sliced operand, if present.
1339  auto maybeSliceOperand = getSliceOperand(genericOp);
1340  if (failed(maybeSliceOperand))
1341  return failure();
1342  OpOperand *sliceOperand = *maybeSliceOperand;
1343  unsigned OperandIndex = sliceOperand->getOperandNumber();
1344 
1345  if (!controlFn(sliceOperand))
1346  return failure();
1347 
1348  tensor::ExtractSliceOp producerSliceOp =
1349  sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
1350  assert(producerSliceOp && "expect a valid ExtractSliceOp");
1351 
1352  if (producerSliceOp.getSource().getType().getRank() !=
1353  producerSliceOp.getResult().getType().getRank()) {
1354  return rewriter.notifyMatchFailure(
1355  genericOp,
1356  "propagation of rank-reducing extract slice is unsupported.");
1357  }
1358 
1359  SmallVector<OpFoldResult> strides = producerSliceOp.getMixedStrides();
1360  if (!areAllConstantIntValue(strides, 1))
1361  return rewriter.notifyMatchFailure(
1362  genericOp, "propagation of strided extract slice is unsupported.");
1363 
1364  // Check if we can support propagating this extractSlice through the
1365  // generic op and, if so, collect the partially sliced dimensions.
1366 
1367  auto maybePartialSliceDimMap =
1368  getPartialSliceDimInfo(genericOp, sliceOperand);
1369 
1370  if (failed(maybePartialSliceDimMap)) {
1371  return failure();
1372  }
1373 
1374  auto partialSliceDimMap = *maybePartialSliceDimMap;
1375 
1376  SmallVector<utils::IteratorType> iterators =
1377  genericOp.getIteratorTypesArray();
1378  bool hasPartialReductionDimSlice =
1379  llvm::any_of(partialSliceDimMap, [&](const auto &slice) {
1380  int64_t sliceDim = slice.first;
1381  return iterators[sliceDim] == utils::IteratorType::reduction;
1382  });
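An editorial note (not an upstream comment) on how this flag is consumed below:

// hasPartialReductionDimSlice records whether any partially sliced dimension
// is a reduction dimension. Only when it is false is the sliced operand
// swapped directly for the unsliced source below; otherwise it is padded
// like the remaining operands.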
1383 
1384  // Compute per-dim padding: lowPad = offset, highPad = outputSize - offset - sliceSize.
1385  Location loc = genericOp->getLoc();
1386  AffineExpr dim0, dim1;
1387  bindDims(rewriter.getContext(), dim0, dim1);
1388  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
1389  auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
1390  return affine::makeComposedFoldedAffineApply(rewriter, loc, subMap,
1391  {v1, v2});
1392  };
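A small editorial worked example (made-up numbers) of the padding amounts the sub helper is used to compute below:

// For SliceDimInfo{offset = 4, sliceSize = 8, outputSize = 32} on some dim:
//   lowPad  = offset                            = 4
//   highPad = (outputSize - offset) - sliceSize = (32 - 4) - 8 = 20
// so 4 + 8 + 20 restores the full 32 elements, with the original slice
// occupying positions [4, 12).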
1393 
1394  MLIRContext *ctx = genericOp.getContext();
1395  SmallVector<Value> paddedInputs;
1396  for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
1397  if (idx == OperandIndex && !hasPartialReductionDimSlice) {
1398  paddedInputs.push_back(producerSliceOp.getSource());
1399  continue;
1400  }
1401  AffineMap IndexingMap = genericOp.getMatchingIndexingMap(operand);
1402  SmallVector<OpFoldResult> operandLowPads(IndexingMap.getNumResults(),
1403  getAsIndexOpFoldResult(ctx, 0));
1404  SmallVector<OpFoldResult> operandHighPads(IndexingMap.getNumResults(),
1405  getAsIndexOpFoldResult(ctx, 0));
1406  for (auto [idx, expr] : llvm::enumerate(IndexingMap.getResults())) {
1407  if (!isa<AffineDimExpr>(expr)) {
1408  continue;
1409  }
1410  AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
1411  if (!partialSliceDimMap.contains(dimExpr.getPosition())) {
1412  continue;
1413  }
1414  SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()];
1415  operandLowPads[idx] = sliceDimInfo.offset;
1416  operandHighPads[idx] =
1417  sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
1418  sliceDimInfo.sliceSize);
1419  }
1420  auto paddingValue = ub::PoisonOp::create(
1421  rewriter, loc, getElementTypeOrSelf(operand->get().getType()));
1422  auto paddedOperand = tensor::PadOp::create(
1423  rewriter, loc, Type(), operand->get(), operandLowPads, operandHighPads,
1424  paddingValue, /*nofold=*/false);
1425  paddedInputs.push_back(paddedOperand);
1426  }
1427  AffineMap outputIndexingMap =
1428  genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));
1429 
1430  auto outputShapeType =
1431  llvm::cast<ShapedType>(genericOp.getDpsInitOperand(0)->get().getType());
1432  SmallVector<OpFoldResult> OutputShape = llvm::map_to_vector(
1433  outputShapeType.getShape(),
1434  [&](int64_t sz) -> OpFoldResult { return rewriter.getIndexAttr(sz); });
1435  SmallVector<OpFoldResult> newSizes = OutputShape;
1436  SmallVector<OpFoldResult> outputLowPads(outputIndexingMap.getNumResults(),
1437  getAsIndexOpFoldResult(ctx, 0));
1438  SmallVector<OpFoldResult> outputHighPads(outputIndexingMap.getNumResults(),
1439  getAsIndexOpFoldResult(ctx, 0));
1440  SmallVector<OpFoldResult> newStrides(outputIndexingMap.getNumResults(),
1441  getAsIndexOpFoldResult(ctx, 1));
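An editorial summary of the loop that follows (not an upstream comment):

// For every output dimension with partial slice info, grow the result shape
// to the unsliced outputSize and record the offset/size needed to extract
// the originally requested window from the new generic's result afterwards.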
1442  for (auto [idx, expr] : llvm::enumerate(outputIndexingMap.getResults())) {
1443  if (!isa<AffineDimExpr>(expr)) {
1444  continue;
1445  }
1446  AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
1447  if (!partialSliceDimMap.contains(dimExpr.getPosition())) {
1448  continue;
1449  }
1450  SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()];
1451  outputLowPads[idx] = sliceDimInfo.offset;
1452  outputHighPads[idx] = sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
1453  sliceDimInfo.sliceSize);
1454  OutputShape[idx] = sliceDimInfo.outputSize;
1455  newSizes[idx] = sliceDimInfo.sliceSize;
1456  }
1457  Value newPadOutput;
1458  auto outputElType =
1459  getElementTypeOrSelf(genericOp.getDpsInits()[0].getType());
1460  if (isGenericOutsNotUsed(genericOp)) {
1461  newPadOutput =
1462  tensor::EmptyOp::create(rewriter, loc, OutputShape, outputElType);
1463  } else {
1464  auto paddingValue = ub::PoisonOp::create(rewriter, loc, outputElType);
1465  newPadOutput = tensor::PadOp::create(
1466  rewriter, loc, Type(), genericOp.getDpsInits()[0], outputLowPads,
1467  outputHighPads, paddingValue, /*nofold=*/false);
1468  }
1469 
1470  auto newGenericOp = linalg::GenericOp::create(
1471  rewriter, loc, newPadOutput.getType(), paddedInputs, {newPadOutput},
1472  genericOp.getIndexingMapsArray(), genericOp.getIteratorTypesArray(),
1473  /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
1474  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
1475  newGenericOp.getRegion().begin());
1476 
1477  auto extractOp = tensor::ExtractSliceOp::create(
1478  rewriter, loc,
1479  newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)),
1480  outputLowPads, newSizes, newStrides);
1481  Value extractRes = extractOp.getResult();
1482 
1483  return std::make_tuple(newGenericOp, extractRes);
1484 }
1485 
1486 class PushDownExtractSliceOpThroughGenericOp final
1487  : public OpRewritePattern<GenericOp> {
1488 public:
1489  PushDownExtractSliceOpThroughGenericOp(MLIRContext *context,
1490  ControlPropagationFn fun)
1491  : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
1492 
1493  LogicalResult matchAndRewrite(GenericOp genericOp,
1494  PatternRewriter &rewriter) const override {
1495  auto genericAndRepl =
1496  pushDownExtractSliceOpThroughGenericOp(rewriter, genericOp, controlFn);
1497  if (failed(genericAndRepl))
1498  return failure();
1499  rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
1500  return success();
1501  }
1502 
1503 private:
1504  ControlPropagationFn controlFn;
1505 };
1506 
1507 } // namespace
1508 
1509 void mlir::linalg::populateDataLayoutPropagationPatterns(
1510  RewritePatternSet &patterns,
1511  const ControlPropagationFn &controlPackUnPackPropagation) {
1512  patterns
1513  .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
1514  BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
1515  PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
1516  patterns.getContext(), controlPackUnPackPropagation);
1517 }
1518 
1519 void mlir::linalg::populateExtractSliceSinkingPatterns(
1520  RewritePatternSet &patterns,
1521  const ControlPropagationFn &controlPackUnPackPropagation) {
1522  patterns.insert<PushDownExtractSliceOpThroughGenericOp>(
1523  patterns.getContext(), controlPackUnPackPropagation);
1524 }
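Finally, a minimal editorial usage sketch, not part of this file: it assumes the greedy driver entry point mlir::applyPatternsGreedily from mlir/Transforms/GreedyPatternRewriteDriver.h (present in recent MLIR); the helper name sinkExtractSlices and the always-true control callback are illustrative only.

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Editorial sketch: collect the sinking pattern and run the greedy rewriter.
static mlir::LogicalResult sinkExtractSlices(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  // Sink every eligible extract_slice; real clients usually pass a more
  // selective ControlPropagationFn.
  mlir::linalg::populateExtractSliceSinkingPatterns(
      patterns, [](mlir::OpOperand *) { return true; });
  return mlir::applyPatternsGreedily(root, std::move(patterns));
}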