MLIR  22.0.0git
DataLayoutPropagation.cpp
1 //===- DataLayoutPropagation.cpp -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/Passes.h"
10 
11 #include "mlir/Dialect/Linalg/IR/Linalg.h"
12 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
13 #include "mlir/Dialect/Tensor/IR/Tensor.h"
14 #include "mlir/Dialect/UB/IR/UBOps.h"
15 #include "mlir/Dialect/Utils/IndexingUtils.h"
16 #include "mlir/IR/Dominance.h"
17 #include "mlir/IR/TypeUtilities.h"
18 #include "llvm/ADT/SetOperations.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 #include "llvm/Support/Debug.h"
22 #include <optional>
23 
24 namespace mlir {
25 #define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
26 #include "mlir/Dialect/Linalg/Passes.h.inc"
27 } // namespace mlir
28 
29 using namespace mlir;
30 using namespace mlir::linalg;
31 
32 #define DEBUG_TYPE "linalg-data-layout-propagation"
33 
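// A typical way to exercise these patterns outside the pass (a sketch, assuming
// the public populateDataLayoutPropagationPatterns entry point and the greedy
// driver; adapt to the actual pass or driver in use):
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateDataLayoutPropagationPatterns(
//       patterns, /*controlPackUnPackPropagation=*/
//       [](OpOperand *opOperand) { return true; });
//   if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
//     return signalPassFailure();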
34 namespace {
35 
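// Returns true if the generic op reads its operands in a data-dependent way.
// For instance (an illustrative sketch), a generic whose body contains
//   %v = tensor.extract %t[%i, %j] : tensor<?x?xf32>
// or a linalg.index op has gather semantics, and the propagation patterns
// below bail out on it.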
36 static bool hasGatherSemantics(linalg::GenericOp genericOp) {
37  for (Operation &op : genericOp.getBody()->getOperations())
38  if (isa<tensor::ExtractOp, linalg::IndexOp>(op))
39  return true;
40  return false;
41 }
42 
43 // This struct maps the packing information (tiled dims, tile sizes, outer
44 // permutation) onto the iteration domain of Linalg ops.
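//
// For instance (an illustrative sketch, not taken from a test): for a generic
// with two parallel loops (d0, d1), an identity result map, and a consumer
//   linalg.pack ... outer_dims_perm = [1, 0] inner_dims_pos = [1]
//                   inner_tiles = [32] ...
// the fields would roughly be:
//   tiledDimsPos            = [1]        (d1 is tiled)
//   domainDimAndTileMapping = {1 -> 32}  (tile size of d1)
//   tileToPointMapping      = {1 -> 2}   (the appended point loop d2)
//   outerDimsOnDomainPerm   = [1, 0]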
45 struct PackInfo {
46  int64_t getNumTiledLoops() const { return tileToPointMapping.size(); };
47  // InnerDimsPos on iteration domain, which follows the order in pack ops.
48  SmallVector<int64_t> tiledDimsPos;
49  // The sizes of tiling data dimensions on iteration domain.
50  llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping;
51  // The mapping from a dimension of iteration domain to the corresponding inner
52  // tiling dimension on iteration domain.
53  llvm::DenseMap<int64_t, int64_t> tileToPointMapping;
54  // The permutation of outer dims (on domain).
55  SmallVector<int64_t> outerDimsOnDomainPerm;
56 };
57 
58 template <typename OpTy>
59 static FailureOr<PackInfo>
60 getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
61  OpTy packOrUnPackOp) {
62  static_assert(llvm::is_one_of<OpTy, linalg::PackOp, linalg::UnPackOp>::value,
63  "applies to only pack or unpack operations");
64  LLVM_DEBUG(
65  { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });
66 
67  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
68  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
69  SmallVector<utils::IteratorType> iterators =
70      genericOp.getIteratorTypesArray();
71 
72  PackInfo packInfo;
73  int64_t origNumDims = indexingMap.getNumDims();
74  SmallVector<AffineExpr> exprs(indexingMap.getResults());
75  ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos();
76  for (auto [index, innerDimPos, tileSize] :
77  llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
78  innerDimsPos, packOrUnPackOp.getMixedTiles())) {
79  auto expr = exprs[innerDimPos];
80  if (!isa<AffineDimExpr>(expr))
81  return failure();
82  int64_t domainDimPos =
83  cast<AffineDimExpr>(exprs[innerDimPos]).getPosition();
84  if (!isParallelIterator(iterators[domainDimPos]))
85  return failure();
86  packInfo.tiledDimsPos.push_back(domainDimPos);
87  packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
88  packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
89  LLVM_DEBUG({
90  llvm::dbgs() << "map innerDimPos=" << innerDimPos
91  << " to iteration dimension (d" << domainDimPos << ", d"
92  << packInfo.tileToPointMapping[domainDimPos]
93  << "), which has size=("
94  << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
95  });
96  }
97 
98  // Bail out if a tiled dimension is present in a map but not as an affine dim
99  // expression.
100  auto areAllAffineDimExpr = [&](int dim) {
101  for (AffineMap map : indexingMaps) {
102  if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) {
103  return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr);
104  })) {
105  return false;
106  }
107  }
108  return true;
109  };
110  for (int64_t i : packInfo.tiledDimsPos)
111  if (!areAllAffineDimExpr(i))
112  return failure();
113 
114  // Get the outer dims perm on the iteration domain. Start by identifying the
115  // set of domain dims affected by the outer permutation along with the
116  // permuted ordering for those dims. Then the full outer dims permutation can
117  // be constructed by replacing the affected dims with the permuted result in a
118  // numLoops-rank identity. e.g.
119  // outerDimsPerm = [1, 2, 0]
120  // indexingMap = (d0, d1, d2, d3, d4) -> (d1, d4, d3)
121  //
122  // permutedOuterDims = [4, 3, 1]
123  // outerDimsOnDomainPerm = [0, 4, 2, 3, 1]
124  //
125  // Non-affine dim expressions must not be permuted by the outer dims
126  // permutation.
127  SmallVector<int64_t> permutedOuterDims;
128  for (auto [index, dim] : llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) {
129  auto permutedExpr = indexingMap.getResult(dim);
130  if (auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) {
131  permutedOuterDims.push_back(dimExpr.getPosition());
132  continue;
133  }
134 
135  // TODO: Allow propagation with transposes on non affine dim expressions,
136  // e.g. d0 + d1 which implies transposing both dims simultaneously while
137  // maintaining the relative position between them.
138  if (static_cast<int64_t>(index) != dim)
139  return failure();
140  }
141  if (!permutedOuterDims.empty()) {
142  int64_t outerDimIndex = 0;
143  llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(),
144  permutedOuterDims.end());
145  for (int i = 0, e = indexingMap.getNumDims(); i < e; i++)
146  packInfo.outerDimsOnDomainPerm.push_back(
147  permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++]
148  : i);
149  LLVM_DEBUG({
150  llvm::dbgs() << "map outerDimsPerm to ";
151  for (auto dim : packInfo.outerDimsOnDomainPerm)
152  llvm::dbgs() << dim << " ";
153  llvm::dbgs() << "\n";
154  });
155  }
156 
157  return packInfo;
158 }
159 
160 static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
161  ArrayRef<AffineExpr> exprs) {
162  // Compute `outer_dims_perm`. See example:
163  // current exprs : (d0, d1, d2, d3) -> (d2, d3)
164  // perm : [0, 3, 1, 2]
165  // First map d2, d3 with their position in the array as:
166  // currentPositionTileLoops: dim | pos
167  // d2 | 0
168  // d3 | 1
169  // then scan `perm` in order and get the `outer_dims_perm`
170  // to be used, here it would be [1, 0].
171  assert(!perm.empty() && "expect perm not to be empty");
172  assert(!exprs.empty() && "expect exprs not to be empty");
173  if (exprs.size() == 1)
174  return {};
175  SmallVector<int64_t> outerDimsPerm;
176  DenseMap<int64_t, int64_t> currentPositionTileLoops;
177  for (auto [pos, expr] : llvm::enumerate(exprs)) {
178  // Here we rely on the assumption that the outer dims permutation
179  // when propagating currently requires that non-affine dim expressions
180  // are not permuted, thus allowing the identity assignment below.
181  if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
182  currentPositionTileLoops[dimExpr.getPosition()] = pos;
183  else
184  currentPositionTileLoops[pos] = pos;
185  }
186  for (int64_t loopIdx : perm) {
187  if (currentPositionTileLoops.count(loopIdx))
188  outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx));
189  }
190  return outerDimsPerm;
191 }
192 
193 struct PackedOperandDetails {
194  SmallVector<OpFoldResult> innerTileSizes;
195  SmallVector<int64_t> innerDimsPos;
196  SmallVector<int64_t> outerDimsPerm;
197  AffineMap indexingMap;
198 };
199 
200 /// Helper function for getOrCreatePackedViewOfOperand that populates
201 /// the details of the packedOperand that needs to be formed and also
202 /// returns whether the packing would require padding.
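///
/// For instance (an illustrative sketch), packing a tensor<16x13xf32> operand
/// along dim 1 with an inner tile of 8 infers a 16x2x8 packed shape and
/// reports that padding is required (13 is not a multiple of 8), while a
/// tensor<16x16xf32> operand does not.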
203 static bool getPackedOperandDetails(
204  OpBuilder &b, PackInfo packInfo, GenericOp genericOp, OpOperand *opOperand,
205     DenseMap<OpOperand *, PackedOperandDetails> &packedOperandMap) {
206   PackedOperandDetails currOperandDetails;
207  int64_t numOrigLoops = genericOp.getNumLoops();
208  int64_t numInnerLoops = packInfo.getNumTiledLoops();
209  int64_t numLoops = numOrigLoops + numInnerLoops;
210  AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
211  llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
212  SmallVector<AffineExpr> exprs(origIndexingMap.getResults());
213 
214  // If the OpOperand is a scalar or a zero-rank tensor, no need to pack.
215  if (genericOp.isScalar(opOperand) || exprs.empty()) {
216  currOperandDetails.indexingMap =
217  AffineMap::get(numLoops, 0, exprs, b.getContext());
218  packedOperandMap[opOperand] = currOperandDetails;
219  return false;
220  }
221 
222  // Step 1. Construct the information of packing data dimensions; append inner
223  // dimensions to the indexing maps for the operand.
224  for (auto [index, expr] : llvm::enumerate(exprs)) {
225  if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
226  int64_t dimPos = dimExpr.getPosition();
227  domainDimToOperandDim[dimPos] = index;
228  continue;
229  }
230  }
231  SmallVector<int64_t> innerDimsPos;
232  SmallVector<OpFoldResult> innerTileSizes;
233  for (auto dimPos : packInfo.tiledDimsPos) {
234  if (!domainDimToOperandDim.count(dimPos))
235  continue;
236  int64_t index = domainDimToOperandDim[dimPos];
237  innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
238  innerDimsPos.push_back(index);
239  exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos]));
240  }
241 
242  // Step 2. Handle outer dim permutations.
243  SmallVector<int64_t> outerDimsPerm;
244  if (!packInfo.outerDimsOnDomainPerm.empty()) {
245  outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs);
246 
247  // Step 2.1: Fold transpose into the linalg.generic.
248  SmallVector<int64_t> inversedOuterPerm =
249  invertPermutationVector(packInfo.outerDimsOnDomainPerm);
250  for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
251  if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) {
252  int64_t dimPos = dimExpr.getPosition();
253  exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
254  continue;
255  }
256  assert(isa<AffineConstantExpr>(exprs[i]) &&
257  "Attempted to permute non-constant and non-affine dim expression");
258  }
259  // Step 2.2: Undo the transposition on `exprs` and propagate the
260  // transposition on the pack using outerDimsPerm.
261  if (!outerDimsPerm.empty()) {
262  SmallVector<AffineExpr> auxVec = exprs;
263  for (const auto &en : enumerate(outerDimsPerm))
264  auxVec[en.index()] = exprs[en.value()];
265  exprs = auxVec;
266  }
267  }
268  currOperandDetails.indexingMap =
269  AffineMap::get(numLoops, 0, exprs, b.getContext());
270 
271  // The operand does not have dimensions that relate to the pack op.
272  if (innerDimsPos.empty() && outerDimsPerm.empty()) {
273  packedOperandMap[opOperand] = currOperandDetails;
274  return false;
275  }
276  auto inputType = cast<RankedTensorType>(opOperand->get().getType());
277 
278  auto maybeIntInnerTileSizes =
279  llvm::map_to_vector(innerTileSizes, [](OpFoldResult ofr) -> int64_t {
280  std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
281  return maybeCst.value_or(ShapedType::kDynamic);
282  });
283  bool requirePadding = linalg::PackOp::requirePaddingValueStrict(
284  inputType.getShape(), innerDimsPos,
285  linalg::PackOp::inferPackedType(inputType, maybeIntInnerTileSizes,
286                                      innerDimsPos, outerDimsPerm)
287          .getShape(),
288  outerDimsPerm, innerTileSizes);
289  currOperandDetails.innerDimsPos = innerDimsPos;
290  currOperandDetails.innerTileSizes = innerTileSizes;
291  currOperandDetails.outerDimsPerm = outerDimsPerm;
292  packedOperandMap[opOperand] = currOperandDetails;
293 
294  return requirePadding;
295 }
296 
297 /// Returns a tuple for packed operand and indexing_map with the assumptions:
298 /// 1) The generic op is the producer of the pack op.
299 /// 2) The generic op has only one result.
300 /// If the operand is a scalar or packing dimensions are all irrelevant to the
301 /// operand, the operand and the updated indexing map will be returned.
302 /// Otherwise, it returns the packed operand and the updated indexing map. E.g.,
303 ///
304 /// #map0 = affine_map<(d0, d1) -> (d0, d1)>
305 /// #map1 = affine_map<(d0, d1) -> (d0)>
306 /// #map2 = affine_map<(d0, d1) -> (d1)>
307 /// %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0],
308 /// iterator_types = ["parallel", "parallel"]}
309 /// ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
310 /// outs(%init : tensor<?x?xf32>) {
311 /// ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
312 /// %4 = arith.addf %arg3, %arg4 : f32
313 /// linalg.yield %4 : f32
314 /// } -> tensor<?x?xf32>
315 /// %1 = linalg.pack %0
316 /// inner_dims_pos = [0, 1]
317 /// inner_tiles = [8, 2]
318 /// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
319 ///
320 /// Taking the first input operand as an example, the inner tile size of d1 is
321 /// 8. Thus, the below operation and `affine_map<(d0, d1, d2, d3)> ->
322 /// affine_map<(d1, d3)>` will be returned.
323 ///
324 /// %pack = linalg.pack %arg0
325 /// inner_dims_pos = [0]
326 /// inner_tiles = [8]
327 /// into %init : tensor<?xf32> -> tensor<?x8xf32>
328 static std::tuple<Value, AffineMap> getOrCreatePackedViewOfOperand(
329  OpBuilder &b, Location loc, OpOperand *opOperand,
330  const DenseMap<OpOperand *, PackedOperandDetails> &packedOperandMap) {
331  assert(packedOperandMap.contains(opOperand) &&
332  "packed operand details expected to be populated");
333  auto currOperandDetails = packedOperandMap.at(opOperand);
334  auto innerDimsPos = currOperandDetails.innerDimsPos;
335  auto outerDimsPerm = currOperandDetails.outerDimsPerm;
336  auto innerTileSizes = currOperandDetails.innerTileSizes;
337  if (innerDimsPos.empty() && outerDimsPerm.empty())
338  return std::make_tuple(opOperand->get(), currOperandDetails.indexingMap);
339 
340  auto empty = linalg::PackOp::createDestinationTensor(
341  b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
342  auto poison = ub::PoisonOp::create(
343  b, loc, getElementTypeOrSelf(opOperand->get().getType()));
344  Value packedOperand =
345  linalg::PackOp::create(b, loc, opOperand->get(), empty, innerDimsPos,
346  innerTileSizes, poison, outerDimsPerm);
347  return std::make_tuple(packedOperand, currOperandDetails.indexingMap);
348 }
349 
350 /// This function is a helper subroutine to pack a genericOp and return it. It
351 /// will create a new generic op with the packed operand and the packed output
352 /// according to packInfo when we attempt to push down unpack or bubble up pack
353 /// around it. Implicitly this will only work when a packInfo can be obtained.
354 /// This makes sure that we are only using this function on parallel permuted
355 /// dimensions.
356 static FailureOr<GenericOp>
357 packGenericOp(RewriterBase &rewriter, GenericOp genericOp, Value dest,
358  AffineMap packedOutIndexingMap, const PackInfo &packInfo,
359  bool isFoldableUnpackPack, bool poisonPaddingOk) {
360  Location loc = genericOp.getLoc();
361  SmallVector<Value> inputOperands;
362  SmallVector<Value> inputOperandsFromUnpackedSource;
363  SmallVector<AffineMap> indexingMaps;
364  auto hasEquivalentTiles = [](PackOp packOp, UnPackOp unPackOp) {
365  return packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm() &&
366  packOp.getInnerDimsPos() == unPackOp.getInnerDimsPos() &&
367  llvm::equal(packOp.getMixedTiles(), unPackOp.getMixedTiles());
368  };
369  DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap;
370  bool requiresPadding = false;
371  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
372  requiresPadding |= getPackedOperandDetails(rewriter, packInfo, genericOp,
373  inputOperand, packedOperandMap);
374  }
375  if (requiresPadding && !poisonPaddingOk)
376  return failure();
377 
378  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
379  auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
380  rewriter, loc, inputOperand, packedOperandMap);
381  auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>();
382  auto packOp = packedOperand.getDefiningOp<linalg::PackOp>();
383  if (packOp && unpackOp && hasEquivalentTiles(packOp, unpackOp)) {
384  inputOperandsFromUnpackedSource.push_back(unpackOp.getSource());
385  } else {
386  inputOperandsFromUnpackedSource.push_back(packedOperand);
387  }
388  inputOperands.push_back(packedOperand);
389  indexingMaps.push_back(packedIndexingMap);
390  }
391 
392  // If the unpack->pack sequences can be folded, use the sources of the unpack
393  // ops in any unpack->pack chains on the generic op operands.
394  if (isFoldableUnpackPack) {
395  inputOperands = inputOperandsFromUnpackedSource;
396  if (auto destPack = dest.getDefiningOp<linalg::PackOp>()) {
397  auto destUnPack = destPack.getSource().getDefiningOp<linalg::UnPackOp>();
398  if (destUnPack && hasEquivalentTiles(destPack, destUnPack)) {
399  dest = destUnPack.getSource();
400  }
401  }
402  }
403 
404  int64_t numInnerLoops = packInfo.getNumTiledLoops();
405  SmallVector<utils::IteratorType> iterTypes =
406      genericOp.getIteratorTypesArray();
407  iterTypes.append(numInnerLoops, utils::IteratorType::parallel);
408 
409  indexingMaps.push_back(packedOutIndexingMap);
410 
411  auto newGenericOp = linalg::GenericOp::create(
412  rewriter, loc, dest.getType(), inputOperands, dest, indexingMaps,
413  iterTypes,
414  /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
415  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
416  newGenericOp.getRegion().begin());
417  return newGenericOp;
418 }
419 
420 static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
421  return llvm::all_of(genericOp.getDpsInitsMutable(), [&](OpOperand &operand) {
422  return genericOp.getMatchingBlockArgument(&operand).use_empty();
423  });
424 }
425 
426 /// Bubbles up linalg.pack op through a producer generic op. This
427 /// swaps pack(generic) to generic(pack). The new generic op works on packed
428 /// domain; pack ops are created for input and output operands. E.g.,
429 ///
430 /// #map0 = affine_map<(d0, d1) -> (d0, d1)>
431 /// %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
432 /// %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
433 /// %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
434 /// %3 = linalg.generic {indexing_maps = [#map0, #map0],
435 /// iterator_types = ["parallel", "parallel"]}
436 /// ins(%arg0 : tensor<?x?xf32>)
437 /// outs(%2 : tensor<?x?xf32>) {
438 /// ^bb0(%arg3: f32, %arg4: f32):
439 /// %4 = arith.addf %arg3, %arg3 : f32
440 /// linalg.yield %4 : f32
441 /// } -> tensor<?x?xf32>
442 /// %4 = linalg.pack %3
443 /// inner_dims_pos = [0, 1]
444 /// inner_tiles = [8, 2]
445 /// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
446 ///
447 /// will be converted to
448 ///
449 /// #map = affine_map<()[s0] -> (s0 ceildiv 8)>
450 /// #map1 = affine_map<()[s0] -> (s0 ceildiv 2)>
451 /// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
452 /// %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
453 /// %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
454 /// %0 = affine.apply #map()[%dim]
455 /// %1 = affine.apply #map1()[%dim_0]
456 /// %2 = tensor.empty(%0, %1) : tensor<?x?x8x2xf32>
457 /// %pack = linalg.pack %arg0
458 /// inner_dims_pos = [0, 1]
459 /// inner_tiles = [8, 2]
460 /// into %2 : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
461 /// %3 = linalg.generic {indexing_maps = [#map2, #map2],
462 /// iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
463 /// ins(%pack : tensor<?x?x8x2xf32>)
464 /// outs(%arg1 : tensor<?x?x8x2xf32>) {
465 /// ^bb0(%in: f32, %out: f32):
466 /// %4 = arith.addf %in, %in : f32
467 /// linalg.yield %4 : f32
468 /// } -> tensor<?x?x8x2xf32>
469 static FailureOr<GenericOp>
470 bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
471  const ControlPropagationFn &controlFn,
472  bool poisonPaddingOk) {
473  auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
474  if (!genericOp)
475  return failure();
476 
477  // User controlled propagation function.
478  if (!controlFn(&packOp.getSourceMutable()))
479  return failure();
480 
481  // TODO: Enable propagation in the presence of linalg.index and
482  // tensor.extract, likely as a separate pattern as the pack information and
483  // propagation decision needs to be inferred from the region of the generic.
484  if (hasGatherSemantics(genericOp))
485  return failure();
486 
487  // TODO: Relax the restriction. We are able to bubble up the pack op through
488  // multi-result generic op. It just needs more work.
489  if (genericOp.getNumResults() != 1)
490  return failure();
491 
492  // Bail-out if the result of the generic has multiple uses, as bubbling up
493  // creates recomputation if the generic has multiple users.
494  // TODO: Enable the case where every use is an identical pack op as no
495  // recomputation is needed in that case.
496  if (!genericOp->getResult(0).hasOneUse())
497  return failure();
498 
499  // TODO: Add an option for allowing padding values. It could introduce
500  // undefined behavior if we unconditionally propagate pack op through all
501  // the ops. E.g., if the padding value is zero and there are division ops in
502  // a generic op. Some values of padding area could be NaN (0/0).
503  if (packOp.getPaddingValue())
504  return failure();
505 
506  OpOperand *opOperand = genericOp.getDpsInitOperand(0);
507  auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
508  if (failed(packInfo))
509  return failure();
510 
511  // We want to move the pack not the generic.
512  OpBuilder::InsertionGuard guard(rewriter);
513  rewriter.setInsertionPoint(genericOp);
514 
515  // We need to handle two cases:
516  // 1) The linalg.pack destination is a tensor.empty. If this is the case, we
517  // create a new tensor.empty to avoid breaking dominance, as we are moving the
518  // linalg.pack above the linalg.generic.
519  // 2) The destination is not a tensor.empty. In this case we can replace only
520  // if the destination of the linalg.pack dominates the linalg.generic.
521  Value packOpDest = packOp.getDest();
522  if (!packOpDest.hasOneUse())
523  return failure();
524  if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
525  packOpDest = tensor::EmptyOp::create(rewriter, genericOp->getLoc(),
526  emptyOp.getMixedSizes(),
527  emptyOp.getType().getElementType());
528  } else {
529  DominanceInfo dom(genericOp);
530  if (!dom.properlyDominates(packOpDest, genericOp))
531  return failure();
532  }
533 
534  // Rebuild the indexing map for the corresponding init operand.
535  DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap;
536  bool requiresPadding = getPackedOperandDetails(rewriter, *packInfo, genericOp,
537  opOperand, packedOperandMap);
538  if (requiresPadding && !poisonPaddingOk)
539  return failure();
540 
541  auto [packedOutOperand, packedOutIndexingMap] =
542  getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), opOperand,
543  packedOperandMap);
544  // Forward the new tensor.empty as a destination if it is one of the following
545  // situations:
546  // 1) The dps init operand is a tensor.empty.
547  // 2) The dps init is a write-only operand, i.e., it is not used in the
548  // genericOp
549  Value dest = packedOutOperand;
550  auto initTensor =
551  genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>();
552  if (initTensor || isGenericOutsNotUsed(genericOp)) {
553  dest = packOpDest;
554  }
555  // pack(unpack) isn't naively foldable because the unpack op can be from
556  // an arbitrary domain so we need to keep both.
557  return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
558  *packInfo, /*isFoldableUnpackPack=*/false,
559  poisonPaddingOk);
560 }
561 
562 /// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
563 struct BubbleUpPackOpThroughGenericOpPattern
564  : public OpRewritePattern<linalg::PackOp> {
565 public:
566  BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
567                                        ControlPropagationFn fun,
568                                        bool poisonPaddingOk)
569  : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)),
570  poisonPaddingOk(std::move(poisonPaddingOk)) {}
571 
572  LogicalResult matchAndRewrite(linalg::PackOp packOp,
573  PatternRewriter &rewriter) const override {
574  auto genericOp = bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn,
575  poisonPaddingOk);
576  if (failed(genericOp))
577  return failure();
578  rewriter.replaceOp(packOp, genericOp->getResults());
579  return success();
580  }
581 
582 private:
583  ControlPropagationFn controlFn;
584  bool poisonPaddingOk;
585 };
586 
587 /// Propagate a linalg.pack operation up through a tensor.pad. The idea is to
588 /// add as many zero padding dimensions to `high` and `low` as there are
589 /// point loops.
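///
/// A rough before/after sketch (illustrative only, shapes made up; the padded
/// dim must not be a tiled dim):
///
///   %padded = tensor.pad %src low[0, 0] high[0, 1] {...}
///       : tensor<?x15xf32> to tensor<?x16xf32>
///   %pack = linalg.pack %padded inner_dims_pos = [0] inner_tiles = [8]
///       into %dest : tensor<?x16xf32> -> tensor<?x16x8xf32>
///
/// becomes
///
///   %pack = linalg.pack %src inner_dims_pos = [0] inner_tiles = [8]
///       into %empty : tensor<?x15xf32> -> tensor<?x15x8xf32>
///   %padded = tensor.pad %pack low[0, 0, 0] high[0, 1, 0] {...}
///       : tensor<?x15x8xf32> to tensor<?x16x8xf32>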
590 class BubbleUpPackThroughPadOp final : public OpRewritePattern<linalg::PackOp> {
591 public:
592  BubbleUpPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
593  : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}
594 
595  LogicalResult matchAndRewrite(linalg::PackOp packOp,
596  PatternRewriter &rewriter) const override {
597  auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
598  if (!padOp)
599  return failure();
600 
601  // User controlled propagation function.
602  if (!controlFn(&packOp.getSourceMutable()))
603  return failure();
604 
605  // TODO: Enable padding when the padding values are the same.
606  if (packOp.getPaddingValue())
607  return failure();
608 
609  // Fail for non-constant padding values. The body of the pad could
610  // depend on the padding indices and/or properties of the padded
611  // tensor so for now we fail.
612  // TODO: Support non-constant padding values.
613  Value paddingVal = padOp.getConstantPaddingValue();
614  if (!paddingVal)
615  return failure();
616 
617  if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
618  return failure();
619 
620  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
621 
622  // Bail out if one of the padded dimensions is a tiled one.
623  llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
624  llvm::SmallBitVector innerDims(paddedDims.size());
625  for (int64_t dim : innerDimsPos)
626  innerDims.flip(dim);
627  if (paddedDims.anyCommon(innerDims))
628  return failure();
629 
630  Location loc = padOp->getLoc();
631  OpBuilder::InsertionGuard guard(rewriter);
632  rewriter.setInsertionPoint(padOp);
633 
634  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
635  SmallVector<OpFoldResult> mixedTiles = packOp.getMixedTiles();
636  auto empty = linalg::PackOp::createDestinationTensor(
637  rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos,
638  outerDimsPerm);
639  auto sourcePack = linalg::PackOp::create(
640  rewriter, loc, padOp.getSource(), empty, innerDimsPos, mixedTiles,
641  /*padding=*/std::nullopt, outerDimsPerm);
642 
643  // If we have `outer_dims_perm` we need to adjust the padded dimensions.
644  SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
645  SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
646  if (!outerDimsPerm.empty()) {
647  applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
648  applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
649  }
650  // The tiled dimensions were verified to be unpadded above, so here we
651  // just append 0 for the inner tile dimensions.
652  size_t pointLoopsSize = innerDimsPos.size();
653  lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
654  highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
655 
656  auto newPadOp =
657  tensor::PadOp::create(rewriter, loc, /*result=*/Type(), sourcePack,
658  lowPad, highPad, paddingVal, padOp.getNofold());
659 
660  // If the pad has more than one user, create an unpack on the new pad to
661  // replace the other uses.
662  if (!padOp->hasOneUse()) {
663  auto unpackEmpty = linalg::UnPackOp::createDestinationTensor(
664  rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm);
665  Value unpackedPad =
666  linalg::UnPackOp::create(rewriter, loc, newPadOp, unpackEmpty,
667  innerDimsPos, mixedTiles, outerDimsPerm);
668  rewriter.replaceAllUsesExcept(padOp, unpackedPad, sourcePack);
669  }
670 
671  // Replace the pack with the new pad.
672  rewriter.replaceOp(packOp, newPadOp.getResult());
673 
674  return success();
675  }
676 
677 private:
678  ControlPropagationFn controlFn;
679 };
680 
681 /// Project dimsPos to the inner-most non-unit dim pos with reassocIndices.
682 ///
683 /// For example, given dimsPos [0, 2], reassocIndices [[0, 1], [2, 3]], and
684 /// targetShape [16, 16, 32, 1], it returns [1, 2]: for pos 0, the inner-most
685 /// non-unit projected dim in [0, 1] is 1, and for pos 2, the inner-most
686 /// non-unit projected dim in [2, 3] is 2.
687 ///
688 /// If all candidates in a reassociation are unit dims, it chooses the
689 /// inner-most dim pos.
690 static SmallVector<int64_t>
691 projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
692  ArrayRef<ReassociationIndices> reassocIndices,
693  ArrayRef<int64_t> targetShape) {
694  SmallVector<int64_t> projectedDimsPos;
695  for (auto pos : dimsPos) {
696  // In the case all dims are unit, this will return the inner-most one.
697  int64_t projectedPos = reassocIndices[pos].back();
698  for (auto i : llvm::reverse(reassocIndices[pos])) {
699  int64_t dim = targetShape[i];
700  if (dim > 1 || ShapedType::isDynamic(dim)) {
701  projectedPos = i;
702  break;
703  }
704  }
705  projectedDimsPos.push_back(projectedPos);
706  }
707  return projectedDimsPos;
708 }
709 
710 /// Check if all dims in dimsPos are divisible by the corresponding tile sizes.
711 static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
712  ArrayRef<int64_t> shape,
713  ArrayRef<int64_t> tileSizes) {
714  for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
715  int64_t dim = shape[pos];
716  if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
717  return false;
718  }
719  return true;
720 }
721 
722 /// Permute the reassociation indices and reindex them in sequence order.
723 /// Returns the next dim pos in the sequence.
724 ///
725 /// For example, given reassocIndices [[0, 1], [2]] and permutation [1, 0], it
726 /// applies the permutation to get [[2], [0, 1]] and reindexes the indices into
727 /// [[0], [1, 2]].
728 static int64_t applyPermutationAndReindexReassoc(
729  SmallVector<ReassociationIndices> &reassocIndices,
730  ArrayRef<int64_t> permutation) {
731  if (!permutation.empty())
732  applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
733  int64_t nextPos = 0;
734  for (ReassociationIndices &indices : reassocIndices) {
735  for (auto &index : indices) {
736  index = nextPos;
737  nextPos += 1;
738  }
739  }
740  return nextPos;
741 }
742 
743 /// Bubble up pack op through collapse shape op when the packed dims can be
744 /// projected to the dims before collapsing. This is possible when the inner
745 /// tile sizes can divide the projected dims.
746 ///
747 /// For example:
748 ///
749 /// %collapsed = tensor.collapse_shape %in [[0, 1], 2]
750 /// : tensor<?x16x4xf32> into tensor<?x4xf32>
751 /// %pack = linalg.pack %collapsed outer_dims_perm = [0, 1]
752 /// inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty
753 /// : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
754 ///
755 /// can be transformed into:
756 ///
757 /// %pack = linalg.pack %in outer_dims_perm = [1, 2]
758 /// inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty
759 /// : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
760 /// %collapsed = tensor.collapse_shape %pack [[0, 1], 2, 3, 4]
761 ///     : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
762 static LogicalResult
763 bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
764  linalg::PackOp packOp,
765  PatternRewriter &rewriter) {
766  SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
767  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
768  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
769 
770  ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
771  SmallVector<ReassociationIndices> reassocIndices =
772  collapseOp.getReassociationIndices();
773  // Project inner tile pos to the dim pos before collapsing. For example, if
774  // dims [x, y] are collapsed into [z], packing on dim z can be projected back
775  // to pack on dim y.
776  //
777  // Project to inner-most non-unit dims to increase the chance that they can be
778  // divided by the inner tile sizes. This is correct because for [..., x, 1],
779  // packing on dim 1 is equivalent to packing on dim x.
780  SmallVector<int64_t> projectedInnerDimsPos =
781  projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);
782 
783  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
784  innerTileSizes)) {
785  return failure();
786  }
787  // Expand the outer dims permutation with the associated source dims for the
788  // new permutation after bubbling. This is because moving a collapsed dim is
789  // equivalent to moving the associated source dims together.
790  SmallVector<int64_t> newOuterDimsPerm;
791  for (auto outerPos : outerDimsPerm)
792  llvm::append_range(newOuterDimsPerm, reassocIndices[outerPos]);
793 
794  auto emptyOp = linalg::PackOp::createDestinationTensor(
795  rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
796  projectedInnerDimsPos, newOuterDimsPerm);
797  auto newPackOp = linalg::PackOp::create(
798  rewriter, packOp.getLoc(), collapseOp.getSrc(), emptyOp,
799  projectedInnerDimsPos, packOp.getMixedTiles(), packOp.getPaddingValue(),
800  newOuterDimsPerm);
801 
802  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
803  // First apply the permutation on the reassociations of the outer dims.
804  // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
805  // -> [[0], [1, 2]]
806  int64_t nextPos =
807  applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
808  // Then add direct mapping for the inner tile dims.
809  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
810  newReassocIndices.push_back({nextPos});
811  nextPos += 1;
812  }
813 
814  auto newCollapseOp = tensor::CollapseShapeOp::create(
815  rewriter, collapseOp.getLoc(), packOp.getType(), newPackOp,
816  newReassocIndices);
817  rewriter.replaceOp(packOp, newCollapseOp);
818 
819  return success();
820 }
821 
822 /// Project dimsPos to their collapsed positions in the reassocIndices.
823 ///
824 /// For example, given dimsPos [0, 1, 2, 4], and matching reassocIndices
825 /// [[0], [1, 2], [3], [4]], it returns [0, 1, 1, 3]. Because for pos 0,
826 /// the reassoc dim [0] is 0. For pos 1 and 2, the reassoc dim in pos
827 /// [1, 2] is 1. And for pos 4, the reassoc dim [4] is 3.
828 static SmallVector<int64_t>
829 projectDimsPosIntoReassocPos(ArrayRef<int64_t> dimsPos,
830  ArrayRef<ReassociationIndices> reassocIndices) {
831  SmallVector<int64_t> projectedPos;
832 
833  // Map each dimension to the position of corresponding reassociation index.
834  for (auto pos : dimsPos) {
835  for (auto [idx, indices] : llvm::enumerate(reassocIndices)) {
836  // If the dimension is present in the current indices group, the group
837  // position within the reassociation map is the desired projected
838  // dimension position.
839  if (llvm::is_contained(indices, pos)) {
840  projectedPos.push_back(idx);
841  break;
842  }
843  }
844  }
845  assert(projectedPos.size() == dimsPos.size() && "Invalid dim pos projection");
846 
847  return projectedPos;
848 }
849 
850 /// Bubble up pack op through expand shape op.
851 ///
852 /// For example:
853 ///
854 /// %expand = tensor.expand_shape %in [[0], [1, 2]]
855 /// : tensor<?x64xf32> into tensor<?x4x16xf32>
856 /// %pack = linalg.pack %expand outer_dims_perm = [0, 1]
857 /// inner_dims_pos = [2] inner_tiles = [8] into %empty
858 /// : tensor<?x4x16xf32> -> tensor<?x4x2x8xf32>
859 ///
860 /// can be transformed into:
861 ///
862 /// %pack = linalg.pack %in outer_dims_perm = [1, 2]
863 /// inner_dims_pos = [1] inner_tiles = [8] into %empty
864 /// : tensor<?x64xf32> -> tensor<?x8x8xf32>
865 /// %expand = tensor.expand_shape %pack [[0], [1, 2], [3]]
866 /// : tensor<?x8x8xf32> into tensor<?x4x2x8xf32>
867 static LogicalResult
868 bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
869  linalg::PackOp packOp,
870  PatternRewriter &rewriter) {
871  // Outer dimensions permutation is not supported currently.
872  // TODO: Handle outer_dims_perm variants.
873  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
874  if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
875    return rewriter.notifyMatchFailure(packOp,
876  "non-identity outer dims perm NYI");
877  }
878 
879  // Validate dimensions' relations between shape expansion and packing.
880  SmallVector<ReassociationIndices> reassoc =
881      expandOp.getReassociationIndices();
882  ArrayRef<int64_t> packInnerDims = packOp.getInnerDimsPos();
883  llvm::SetVector<int64_t> packDimsPos(llvm::from_range, packInnerDims);
884 
885  for (auto [idx, indices] : llvm::enumerate(reassoc)) {
886  // For each expand_shape reassociation, figure out which dimensions get
887  // packed if any.
888  llvm::SetVector<int64_t> expandDimPos(llvm::from_range, indices);
889  llvm::SetVector<int64_t> packedDims =
890  llvm::set_intersection(packDimsPos, expandDimPos);
891 
892  // The expanded dimension is not packed, so it does not affect moving the pack
893  // before shape expansion - simply continue.
894  if (packedDims.empty())
895  continue;
896  // Shape expansion cannot be propagated when multiple expanded dimensions are
897  // packed - in this case operation reordering would affect final element
898  // positions and/or shapes can no longer be projected.
899  if (packedDims.size() != 1)
900  return rewriter.notifyMatchFailure(
901  packOp, "only one of the expanded dimensions can be packed");
902  // Only the inner-most expanded dimension should be packed. Otherwise,
903  // elements order will be affected after operation reordering.
904  if (packedDims.front() != indices.back())
905  return rewriter.notifyMatchFailure(
906  packOp, "can only pack the inner-most expanded dimension");
907  }
908 
909  // Project pack.inner_dims_pos to positions before shape expansion.
910  SmallVector<int64_t> projectedInnerDimsPos =
911  projectDimsPosIntoReassocPos(packInnerDims, reassoc);
912 
913  // Project the shape expansion to new packed shape.
914  // The pack.outer_dims_perm is restricted to identity, so the permutation can
915  // be omitted for simplicity.
916  // TODO: Account for outer dimensions permutation.
917  //
918  // If reassociation is not possible, then reordering cannot happen.
919  // This can be caused by pack padding affecting previously expanded
920  // dimensions or packing extending dimensions.
921  RankedTensorType newPackType = linalg::PackOp::inferPackedType(
922  expandOp.getSrcType(), packOp.getStaticInnerTiles(),
923  projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
924  auto reassocExpand =
925  getReassociationIndicesForReshape(newPackType, packOp.getDestType());
926  if (!reassocExpand)
927  return rewriter.notifyMatchFailure(
928  packOp, "could not reassociate dims after bubbling up");
929 
930  Value destTensor = linalg::PackOp::createDestinationTensor(
931  rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(),
932  projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
933  Value packedVal = linalg::PackOp::create(
934  rewriter, packOp.getLoc(), expandOp.getSrc(), destTensor,
935  projectedInnerDimsPos, packOp.getMixedTiles(), packOp.getPaddingValue(),
936  /*outerDimsPerm=*/SmallVector<int64_t>{});
937 
938  Value newExpandOp = tensor::ExpandShapeOp::create(rewriter, packOp.getLoc(),
939  packOp.getDestType(),
940  packedVal, *reassocExpand);
941  rewriter.replaceOp(packOp, newExpandOp);
942 
943  return success();
944 }
945 
946 class BubbleUpPackOpThroughReshapeOp final
947  : public OpRewritePattern<linalg::PackOp> {
948 public:
949  BubbleUpPackOpThroughReshapeOp(MLIRContext *context, ControlPropagationFn fun)
950  : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}
951 
952  LogicalResult matchAndRewrite(linalg::PackOp packOp,
953  PatternRewriter &rewriter) const override {
954  Operation *srcOp = packOp.getSource().getDefiningOp();
955  // Currently only support when the pack op is the only user.
956  if (!srcOp || !(srcOp->getNumResults() == 1) ||
957  !srcOp->getResult(0).hasOneUse()) {
958  return failure();
959  }
960  // Currently only support static inner tile sizes.
961  if (llvm::any_of(packOp.getStaticTiles(), ShapedType::isDynamic))
962  return failure();
963 
964  // User controlled propagation function.
965  if (!controlFn(&packOp.getSourceMutable()))
966  return failure();
967 
968  return TypeSwitch<Operation *, LogicalResult>(srcOp)
969      .Case([&](tensor::CollapseShapeOp op) {
970  return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
971  })
972  .Case([&](tensor::ExpandShapeOp op) {
973  return bubbleUpPackOpThroughExpandShape(op, packOp, rewriter);
974  })
975  .Default([](Operation *) { return failure(); });
976  }
977 
978 private:
979  ControlPropagationFn controlFn;
980 };
981 
982 /// Push down unpack op through expand shape op when the packed dims can be
983 /// projected to the dims after expanding. This is possible when the inner tile
984 /// sizes can divide the projected dims.
985 ///
986 /// For example:
987 ///
988 /// %unpack = linalg.unpack %in outer_dims_perm = [0, 1]
989 /// inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %empty
990 /// : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
991 /// %expanded = tensor.expand_shape %unpack [[0, 1], [2]]
992 /// : tensor<?x256xf32> into tensor<?x256x256xf32>
993 ///
994 /// can be transformed into:
995 ///
996 /// %expanded = tensor.expand_shape %in [[0, 1], [2], [3], [4]]
997 /// : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
998 /// %unpack = linalg.unpack %expanded outer_dims_perm = [0, 1, 2]
999 /// inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty
1000 /// : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
1001 static LogicalResult pushDownUnPackOpThroughExpandShape(
1002  linalg::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
1003  PatternRewriter &rewriter, ControlPropagationFn controlFn) {
1004  // User controlled propagation function.
1005  if (!controlFn(&expandOp.getSrcMutable()))
1006  return failure();
1007 
1008  SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
1009  ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
1010  ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
1011 
1012  auto expandTy = dyn_cast<RankedTensorType>(expandOp.getType());
1013  if (!expandTy)
1014  return failure();
1015  ArrayRef<int64_t> dstShape = expandTy.getShape();
1016  SmallVector<ReassociationIndices> reassocIndices =
1017  expandOp.getReassociationIndices();
1018  // Project inner tile pos to the dim pos after expanding. For example, if dim
1019  // z is expanded into [x, y], unpacking on dim z can be projected to unpack
1020  // on dim y.
1021  //
1022  // Project to inner-most non-unit dims to increase the chance that they can be
1023  // divided by the inner tile sizes. This is correct because for [..., x, 1],
1024  // unpacking on dim 1 is equivalent to unpacking on dim x.
1025  SmallVector<int64_t> projectedInnerDimsPos =
1026  projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);
1027 
1028  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
1029  innerTileSizes)) {
1030  return failure();
1031  }
1032  // Expand the outer dims permutation with the associated expanded dims for the
1033  // new permutation after pushing. This is because moving a source dim is
1034  // equivalent to moving the associated expanded dims together.
1035  SmallVector<int64_t> newOuterDimsPerm;
1036  for (auto outerPos : outerDimsPerm)
1037  llvm::append_range(newOuterDimsPerm, reassocIndices[outerPos]);
1038 
1039  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
1040  // First apply the permutation on the reassociations of the outer dims.
1041  // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
1042  // -> [[0], [1, 2]]
1043  int64_t nextPos =
1044  applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
1045  // Then add direct mapping for the inner tile dims.
1046  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
1047  newReassocIndices.push_back({nextPos});
1048  nextPos += 1;
1049  }
1050 
1051  RankedTensorType newExpandType = linalg::PackOp::inferPackedType(
1052  expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm);
1053  auto newExpandOp =
1054  tensor::ExpandShapeOp::create(rewriter, expandOp.getLoc(), newExpandType,
1055  unPackOp.getSource(), newReassocIndices);
1056 
1057  auto emptyOp = linalg::UnPackOp::createDestinationTensor(
1058  rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
1059  projectedInnerDimsPos, newOuterDimsPerm);
1060  auto newUnPackOp = linalg::UnPackOp::create(
1061  rewriter, unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
1062  projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
1063  rewriter.replaceOp(expandOp, newUnPackOp);
1064 
1065  return success();
1066 }
1067 
1068 class PushDownUnPackOpThroughReshapeOp final
1069  : public OpRewritePattern<linalg::UnPackOp> {
1070 public:
1071  PushDownUnPackOpThroughReshapeOp(MLIRContext *context,
1072                                   ControlPropagationFn fun)
1073      : OpRewritePattern<linalg::UnPackOp>(context), controlFn(std::move(fun)) {
1074  }
1075 
1076  LogicalResult matchAndRewrite(linalg::UnPackOp unPackOp,
1077  PatternRewriter &rewriter) const override {
1078  Value result = unPackOp.getResult();
1079  // Currently only support unpack ops with a single user.
1080  if (!result.hasOneUse()) {
1081  return failure();
1082  }
1083  // Currently only support static inner tile sizes.
1084  if (llvm::any_of(unPackOp.getStaticTiles(), ShapedType::isDynamic))
1085  return failure();
1086 
1087  Operation *consumerOp = *result.user_begin();
1088  return TypeSwitch<Operation *, LogicalResult>(consumerOp)
1089  .Case([&](tensor::ExpandShapeOp op) {
1090  return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter,
1091  controlFn);
1092  })
1093  .Default([](Operation *) { return failure(); });
1094  }
1095 
1096 private:
1097  ControlPropagationFn controlFn;
1098 };
1099 
1100 // TODO: Relax this restriction. We should unpack a generic op also
1101 // in the presence of multiple unpack ops as producers.
1102 /// Return the unpacked operand, if present, for the current generic op.
1103 static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
1104  OpOperand *unPackedOperand = nullptr;
1105  for (OpOperand &operand : genericOp->getOpOperands()) {
1106  auto unPackOp = operand.get().getDefiningOp<linalg::UnPackOp>();
1107  if (!unPackOp)
1108  continue;
1109  if (unPackedOperand)
1110  return failure();
1111  unPackedOperand = &operand;
1112  }
1113  if (!unPackedOperand)
1114  return failure();
1115  return unPackedOperand;
1116 }
1117 
1118 /// Push down a linalg.unpack op through a generic op.
1119 /// The new generic op works on packed domain; pack ops are created for input
1120 /// and output operands. A linalg.unpack op is inserted right after the packed
1121 /// generic. E.g.
1122 ///
1123 /// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
1124 ///
1125 /// %arg0 = tensor<12x2x56x56x32xf32> // packed arg.
1126 ///
1127 /// %0 = tensor.empty() : tensor<12x56x56x64xf32>
1128 /// %1 = linalg.unpack %arg0 outer_dims_perm = [0, 3, 1, 2]
1129 /// inner_dims_pos = [3] inner_tiles = [32] into %0
1130 /// %2 = linalg.generic {indexing_maps = [#map],
1131 /// iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
1132 /// outs(%1 : tensor<12x56x56x64xf32>) {
1133 /// ^bb0(%out : f32):
1134 /// linalg.yield %out : f32
1135 /// } -> tensor<12x56x56x64xf32>
1136 ///
1137 /// will be converted to
1138 ///
1139 /// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
1140 ///
1141 /// %0 = tensor.empty() : tensor<12x56x56x64xf32>
1142 /// %1 = linalg.generic {indexing_maps = [#map],
1143 /// iterator_types = ["parallel", "parallel", "parallel",
1144 /// "parallel", "parallel"]}
1145 /// outs(%arg0 : tensor<12x2x56x56x32xf32>) {
1146 /// ^bb0(%out : f32):
1147 /// linalg.yield %out : f32
1148 /// } -> tensor<12x2x56x56x32xf32>
1149 /// %2 = linalg.unpack %1 outer_dims_perm = [0, 3, 1, 2]
1150 /// inner_dims_pos = [3] inner_tiles = [32] into %0
1151 ///
1152 static FailureOr<std::tuple<GenericOp, Value>>
1153 pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
1154  ControlPropagationFn controlFn,
1155  bool poisonPaddingOk) {
1156  if (genericOp.getNumResults() != 1)
1157  return failure();
1158 
1159  if (hasGatherSemantics(genericOp))
1160  return failure();
1161 
1162  // Collect the unPacked operand, if present.
1163  auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
1164  if (failed(maybeUnPackedOperand))
1165  return failure();
1166  OpOperand *unPackedOperand = *(maybeUnPackedOperand);
1167 
1168  // Extract packing information.
1169  linalg::UnPackOp producerUnPackOp =
1170  unPackedOperand->get().getDefiningOp<linalg::UnPackOp>();
1171  assert(producerUnPackOp && "expect a valid UnPackOp");
1172 
1173  if (!controlFn(unPackedOperand))
1174  return failure();
1175 
1176  auto packInfo =
1177  getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
1178  if (failed(packInfo))
1179  return failure();
1180 
1181  // Rebuild the indexing map for the corresponding init operand.
1182  DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap;
1183  bool requiresPadding =
1184  getPackedOperandDetails(rewriter, *packInfo, genericOp,
1185  genericOp.getDpsInitOperand(0), packedOperandMap);
1186  if (requiresPadding && !poisonPaddingOk)
1187  return failure();
1188 
1189  auto [packedOutOperand, packedOutIndexingMap] =
1190  getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(),
1191  genericOp.getDpsInitOperand(0),
1192  packedOperandMap);
1193  auto destPack = packedOutOperand.getDefiningOp<linalg::PackOp>();
1194 
1195  // Forward the new tensor.empty as a destination if it is one of the following
1196  // situations:
1197  // 1) The dps init operand is a tensor.empty.
1198  // 2) The dps init is a write-only operand, i.e., it is not used in the
1199  // genericOp
1200  Value dest = packedOutOperand;
1201  auto initTensor =
1202  genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>();
1203  if (initTensor || isGenericOutsNotUsed(genericOp)) {
1204  if (destPack)
1205  dest = destPack.getDest();
1206  }
1207 
1208  // Pack the genericOp.
1209  // pack(unpack) is foldable in this case. This is because in pushing down the
1210  // unpack, by default we will populate an additional pack op after the unpack.
1211  // This guarantees that they are foldable.
1212  auto maybeGenericOp =
1213  packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo,
1214  /*isFoldableUnpackPack=*/true, poisonPaddingOk);
1215  if (failed(maybeGenericOp))
1216  return failure();
1217  GenericOp newGenericOp = *maybeGenericOp;
1218  Value newResult =
1219  newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));
1220 
1221  // If the output is unaffected, no need to unpack.
1222  if (!destPack)
1223  return std::make_tuple(newGenericOp, newResult);
1224 
1225  auto mixedTiles = destPack.getMixedTiles();
1226  auto innerDimsPos = destPack.getInnerDimsPos();
1227  auto outerDimsPerm = destPack.getOuterDimsPerm();
1228 
1229  // Insert an unPackOp right after the packed generic.
1230  Value unPackOpRes =
1231  linalg::UnPackOp::create(rewriter, genericOp.getLoc(), newResult,
1232  destPack.getSource(), innerDimsPos, mixedTiles,
1233  outerDimsPerm)
1234  .getResult();
1235 
1236  return std::make_tuple(newGenericOp, unPackOpRes);
1237 }
1238 
1239 // Wrapper pattern that applies pushDownUnPackOpThroughGenericOp method.
1240 struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
1241 public:
1242  PushDownUnPackOpThroughGenericOp(MLIRContext *context,
1243                                   ControlPropagationFn fun,
1244                                   bool poisonPaddingOk)
1245  : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)),
1246  poisonPaddingOk(std::move(poisonPaddingOk)) {}
1247 
1248  LogicalResult matchAndRewrite(GenericOp genericOp,
1249  PatternRewriter &rewriter) const override {
1250  auto genericAndRepl = pushDownUnPackOpThroughGenericOp(
1251  rewriter, genericOp, controlFn, poisonPaddingOk);
1252  if (failed(genericAndRepl))
1253  return failure();
1254  rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
1255  return success();
1256  }
1257 
1258 private:
1259  ControlPropagationFn controlFn;
1260  bool poisonPaddingOk;
1261 };
1262 
1263 /// Propagate a linalg.unpack operation through a tensor.pad. The idea is to
1264 /// add as many zero padding dimensions to `high` and `low` as there are point
1265 /// loops.
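///
/// A rough before/after sketch (illustrative only, shapes made up; the padded
/// dim must not be a tiled dim):
///
///   %unpack = linalg.unpack %src inner_dims_pos = [0] inner_tiles = [8]
///       into %empty : tensor<?x15x8xf32> -> tensor<?x15xf32>
///   %padded = tensor.pad %unpack low[0, 0] high[0, 1] {...}
///       : tensor<?x15xf32> to tensor<?x16xf32>
///
/// becomes
///
///   %padded = tensor.pad %src low[0, 0, 0] high[0, 1, 0] {...}
///       : tensor<?x15x8xf32> to tensor<?x16x8xf32>
///   %unpack = linalg.unpack %padded inner_dims_pos = [0] inner_tiles = [8]
///       into %out : tensor<?x16x8xf32> -> tensor<?x16xf32>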
1266 struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
1267  PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
1268  : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {}
1269 
1270  LogicalResult matchAndRewrite(tensor::PadOp padOp,
1271  PatternRewriter &rewriter) const override {
1272  linalg::UnPackOp unpackOp =
1273  padOp.getSource().getDefiningOp<linalg::UnPackOp>();
1274  if (!unpackOp)
1275  return failure();
1276 
1277  if (!controlFn(&padOp.getSourceMutable()))
1278  return failure();
1279 
1280  Location loc = padOp.getLoc();
1281  // Bail out if one of the padded dimensions is a tiled one.
1282  llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
1283  ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
1284  llvm::SmallBitVector innerDims(paddedDims.size());
1285  for (int64_t dim : innerDimsPos)
1286  innerDims.flip(dim);
1287  if (paddedDims.anyCommon(innerDims))
1288  return failure();
1289 
1290  Value paddingVal = padOp.getConstantPaddingValue();
1291  if (!paddingVal)
1292  return failure();
1293 
1294  // If we have `outer_dims_perm` we need to adjust the padded dimensions.
1295  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
1296  SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
1297  SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
1298  if (!outerDimsPerm.empty()) {
1299  applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
1300  applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
1301  }
1302  // Add zero padding for the point loops.
1303  size_t pointLoopsSize = innerDimsPos.size();
1304  lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
1305  highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
1306 
1307  auto newPadOp = tensor::PadOp::create(rewriter, loc, /*result=*/Type(),
1308  unpackOp.getSource(), lowPad, highPad,
1309  paddingVal, padOp.getNofold());
1310 
1311  // Inject the linalg.unpack right after the packed padOp.
1312  Value outputUnPack =
1313  tensor::EmptyOp::create(rewriter, loc, padOp.getResultType().getShape(),
1314  padOp.getResultType().getElementType());
1315 
1316  Value replacement = linalg::UnPackOp::create(
1317  rewriter, loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
1318  unpackOp.getMixedTiles(), outerDimsPerm);
1319  rewriter.replaceOp(padOp, replacement);
1320  return success();
1321  }
1322 
1323 private:
1324  ControlPropagationFn controlFn;
1325 };
1326 
1327 // This struct contains information about extract_slice dims.
1328 struct SliceDimInfo {
1329  OpFoldResult offset;
1330  OpFoldResult sliceSize;
1331  OpFoldResult outputSize;
1332 };
1333 
1334 /// Return all extract slice operands, if present, for the current
1335 /// generic op.
1336 static FailureOr<SmallVector<OpOperand *>>
1337 getSliceOperands(GenericOp genericOp) {
1338  SmallVector<OpOperand *> sliceOperands;
1339  for (auto operand : genericOp.getDpsInputOperands()) {
1340  auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
1341  if (!extractOp)
1342  continue;
1343  sliceOperands.push_back(operand);
1344  }
1345  if (sliceOperands.empty()) {
1346  return failure();
1347  }
1348  return sliceOperands;
1349 }
1350 
1351 // Return a map of dims that have partial slices on them so that other operands
1352 // can use this information. Also return a bool mentioning if a reduction dim
1353 // has a non full slice as that can be used to fold the original extract slice.
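//
// For instance (an illustrative sketch): with an operand indexed by
// (d0, d1) -> (d0, d1), a tensor<16x32xf32> source, and
//   %s = tensor.extract_slice %src[0, 8] [16, 16] [1, 1]
//       : tensor<16x32xf32> to tensor<16x16xf32>
// d0 is a full slice and is skipped, while d1 gets the entry
//   partialSliceDimMap[1] = {offset = 8, sliceSize = 16, outputSize = 32}.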
1354 static FailureOr<llvm::DenseMap<int64_t, SliceDimInfo>>
1355 getPartialSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand) {
1356  tensor::ExtractSliceOp producerSliceOp =
1357  sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
1358  assert(producerSliceOp && "expect a valid ExtractSliceOp");
1359  llvm::DenseMap<int64_t, SliceDimInfo> partialSliceDimMap;
1360  SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
1361  SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
1362 
1363  SmallVector<OpFoldResult> shape = getAsIndexOpFoldResult(
1364  genericOp.getContext(), producerSliceOp.getSourceType().getShape());
1365 
1366  for (auto [idx, expr] : llvm::enumerate(
1367  genericOp.getMatchingIndexingMap(sliceOperand).getResults())) {
1368  // If we have a full slice in a dimension then we don't need to add it to
1369  // the partial slice map.
1370  if (isConstantIntValue(offsets[idx], 0) &&
1371  isEqualConstantIntOrValue(sizes[idx], shape[idx])) {
1372  continue;
1373  }
1374  // We only support partial slices of AffineDimExprs, so bail out if that's
1375  // not the case.
1376  if (!isa<AffineDimExpr>(expr)) {
1377  return failure();
1378  }
1379  SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]};
1380  int64_t dimPos = cast<AffineDimExpr>(expr).getPosition();
1381  partialSliceDimMap[dimPos] = sliceDimInfo;
1382  }
1383  // Next, check whether the dims with partial slice info are used in
1384  // non-AffineDimExpr expressions by other operands; if they are, bail out.
1385  for (OpOperand &operand : genericOp->getOpOperands()) {
1386  if (&operand == sliceOperand) {
1387  continue;
1388  }
1389  AffineMap IndexingMap = genericOp.getMatchingIndexingMap(&operand);
1390  if (llvm::any_of(IndexingMap.getResults(), [&](AffineExpr expr) {
1391  if (isa<AffineDimExpr>(expr)) {
1392  return false;
1393  }
1394  WalkResult status = expr.walk([&](AffineExpr expr) {
1395  if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
1396  if (partialSliceDimMap.contains(dimExpr.getPosition())) {
1397  return WalkResult::interrupt();
1398  }
1399  }
1400  return WalkResult::advance();
1401  });
1402  if (status.wasInterrupted()) {
1403  return true;
1404  }
1405  return false;
1406  })) {
1407  return failure();
1408  }
1409  }
1410  return partialSliceDimMap;
1411 }
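// The returned map is keyed by iteration-domain dimension positions, so any
// other operand whose indexing map uses one of these dimensions can look up
// the same offset/size information when it is padded below.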
1412 
1413 static FailureOr<std::tuple<GenericOp, Value>>
1414 pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
1415  GenericOp genericOp,
1416  ControlPropagationFn controlFn) {
1417  if (genericOp.getNumResults() != 1)
1418  return rewriter.notifyMatchFailure(
1419  genericOp, "propagation through multi-result generic is unsupported.");
1420  if (hasGatherSemantics(genericOp))
1421  return rewriter.notifyMatchFailure(
1422  genericOp,
1423  "propagation through generic with gather semantics is unsupported.");
1424  // Collect the sliced operands, if present.
1425  auto maybeSliceOperands = getSliceOperands(genericOp);
1426  if (failed(maybeSliceOperands))
1427  return failure();
1428  SmallVector<OpOperand *> sliceOperands = *maybeSliceOperands;
1429  OpOperand *sliceOperand = nullptr;
1430 
1431  bool foundValidOperand = false;
1432  for (auto currSliceOperand : sliceOperands) {
1433  if (controlFn(currSliceOperand)) {
1434  sliceOperand = currSliceOperand;
1435  foundValidOperand = true;
1436  break;
1437  }
1438  }
1439  if (!foundValidOperand) {
1440  return failure();
1441  }
1442  unsigned OperandIndex = sliceOperand->getOperandNumber();
1443 
1444  tensor::ExtractSliceOp producerSliceOp =
1445  sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
1446  assert(producerSliceOp && "expect a valid ExtractSliceOp");
1447 
1448  if (producerSliceOp.getSource().getType().getRank() !=
1449  producerSliceOp.getResult().getType().getRank()) {
1450  return rewriter.notifyMatchFailure(
1451  genericOp,
1452  "propagation of rank-reducing extract slice is unsupported.");
1453  }
1454 
1455  SmallVector<OpFoldResult> strides = producerSliceOp.getMixedStrides();
1456  if (!areAllConstantIntValue(strides, 1))
1457  return rewriter.notifyMatchFailure(
1458  genericOp, "propagation of strided extract slice is unsupported.");
1459 
1460  // Check if we can support the propagation of this extractSlice through
1461  // the generic op and, if so, get the dimensions that have partial slices.
1462 
1463  auto maybePartialSliceDimMap =
1464  getPartialSliceDimInfo(genericOp, sliceOperand);
1465 
1466  if (failed(maybePartialSliceDimMap)) {
1467  return failure();
1468  }
1469 
1470  auto partialSliceDimMap = *maybePartialSliceDimMap;
1471 
1472  SmallVector<utils::IteratorType> iterators =
1473  genericOp.getIteratorTypesArray();
1474  bool hasPartialReductionDimSlice =
1475  llvm::any_of(partialSliceDimMap, [&](const auto &slice) {
1476  int64_t sliceDim = slice.first;
1477  return iterators[sliceDim] == utils::IteratorType::reduction;
1478  });
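  // When a reduction dimension is only partially sliced, the slice cannot
  // simply be replaced by its full source; in that case the sliced operand is
  // kept and padded like every other input below.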
1479 
1480  // Store the padding information as (dimPos, lowPad, highPad, PaddedShape).
1481  Location loc = genericOp->getLoc();
1482  AffineExpr dim0, dim1;
1483  bindDims(rewriter.getContext(), dim0, dim1);
1484  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
1485  auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
1486  return affine::makeComposedFoldedAffineApply(rewriter, loc, subMap,
1487  {v1, v2});
1488  };
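  // Worked example (illustrative): for a dimension with outputSize = 10,
  // offset = 2 and sliceSize = 5, the low pad is 2 and the high pad is
  // (10 - 2) - 5 = 3, restoring the operand to its full size of 10.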
1489 
1490  MLIRContext *ctx = genericOp.getContext();
1491  SmallVector<Value> paddedInputs;
1492  for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
1493  if (idx == OperandIndex && !hasPartialReductionDimSlice) {
1494  paddedInputs.push_back(producerSliceOp.getSource());
1495  continue;
1496  }
1497  AffineMap IndexingMap = genericOp.getMatchingIndexingMap(operand);
1498  if (IndexingMap.getNumResults() == 0) {
1499  paddedInputs.push_back(operand->get());
1500  continue;
1501  }
1502  SmallVector<OpFoldResult> operandLowPads(IndexingMap.getNumResults(),
1503  getAsIndexOpFoldResult(ctx, 0));
1504  SmallVector<OpFoldResult> operandHighPads(IndexingMap.getNumResults(),
1505  getAsIndexOpFoldResult(ctx, 0));
1506  for (auto [idx, expr] : llvm::enumerate(IndexingMap.getResults())) {
1507  if (!isa<AffineDimExpr>(expr)) {
1508  continue;
1509  }
1510  AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
1511  if (!partialSliceDimMap.contains(dimExpr.getPosition())) {
1512  continue;
1513  }
1514  SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()];
1515  operandLowPads[idx] = sliceDimInfo.offset;
1516  operandHighPads[idx] =
1517  sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
1518  sliceDimInfo.sliceSize);
1519  }
1520  auto paddingValue = ub::PoisonOp::create(
1521  rewriter, loc, getElementTypeOrSelf(operand->get().getType()));
1522  auto paddedOperand = tensor::PadOp::create(
1523  rewriter, loc, Type(), operand->get(), operandLowPads, operandHighPads,
1524  paddingValue, /*nofold=*/false);
1525  paddedInputs.push_back(paddedOperand);
1526  }
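  // Every input has now been rebuilt against the full iteration-space extent:
  // dimensions with a partial slice are padded back to their original size
  // using a poison padding value, and all remaining dimensions receive zero
  // padding.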
1527  AffineMap outputIndexingMap =
1528  genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));
1529 
1530  auto outputShapeType =
1531  llvm::cast<ShapedType>(genericOp.getDpsInitOperand(0)->get().getType());
1532  SmallVector<OpFoldResult> OutputShape = llvm::map_to_vector(
1533  outputShapeType.getShape(),
1534  [&](int64_t sz) -> OpFoldResult { return rewriter.getIndexAttr(sz); });
1535  SmallVector<OpFoldResult> newSizes = OutputShape;
1536  SmallVector<OpFoldResult> outputLowPads(outputIndexingMap.getNumResults(),
1537  getAsIndexOpFoldResult(ctx, 0));
1538  SmallVector<OpFoldResult> outputHighPads(outputIndexingMap.getNumResults(),
1539  getAsIndexOpFoldResult(ctx, 0));
1540  SmallVector<OpFoldResult> newStrides(outputIndexingMap.getNumResults(),
1541  getAsIndexOpFoldResult(ctx, 1));
1542  for (auto [idx, expr] : llvm::enumerate(outputIndexingMap.getResults())) {
1543  if (!isa<AffineDimExpr>(expr)) {
1544  continue;
1545  }
1546  AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
1547  if (!partialSliceDimMap.contains(dimExpr.getPosition())) {
1548  continue;
1549  }
1550  SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()];
1551  outputLowPads[idx] = sliceDimInfo.offset;
1552  outputHighPads[idx] = sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
1553  sliceDimInfo.sliceSize);
1554  OutputShape[idx] = sliceDimInfo.outputSize;
1555  newSizes[idx] = sliceDimInfo.sliceSize;
1556  }
1557  Value newPadOutput;
1558  auto outputElType =
1559  getElementTypeOrSelf(genericOp.getDpsInits()[0].getType());
1560  if (isGenericOutsNotUsed(genericOp)) {
1561  newPadOutput =
1562  tensor::EmptyOp::create(rewriter, loc, OutputShape, outputElType);
1563  } else {
1564  auto paddingValue = ub::PoisonOp::create(rewriter, loc, outputElType);
1565  newPadOutput = tensor::PadOp::create(
1566  rewriter, loc, Type(), genericOp.getDpsInits()[0], outputLowPads,
1567  outputHighPads, paddingValue, /*nofold=*/false);
1568  }
1569 
1570  auto newGenericOp = linalg::GenericOp::create(
1571  rewriter, loc, newPadOutput.getType(), paddedInputs, {newPadOutput},
1572  genericOp.getIndexingMapsArray(), genericOp.getIteratorTypesArray(),
1573  /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
1574  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
1575  newGenericOp.getRegion().begin());
1576 
1577  auto extractOp = tensor::ExtractSliceOp::create(
1578  rewriter, loc,
1579  newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)),
1580  outputLowPads, newSizes, newStrides);
1581  Value extractRes = extractOp.getResult();
1582 
1583  return std::make_tuple(newGenericOp, extractRes);
1584 }
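// Overall effect (illustrative sketch, not from the original comments):
//
//   %s = tensor.extract_slice %src [offsets] [sizes] [1, 1]
//   %r = linalg.generic ins(%s, ...) outs(%init) {...}
//
// becomes, roughly,
//
//   %g = linalg.generic ins(%src, <poison-padded operands>)
//                       outs(<poison-padded init>) {...}
//   %r = tensor.extract_slice %g [offsets] [sizes] [1, 1]
//
// i.e. the extract_slice is sunk below the generic op, which now runs over
// the full (padded) iteration space.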
1585 
1586 class PushDownExtractSliceOpThroughGenericOp final
1587  : public OpRewritePattern<GenericOp> {
1588 public:
1589  PushDownExtractSliceOpThroughGenericOp(MLIRContext *context,
1590  ControlPropagationFn fun)
1591  : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
1592 
1593  LogicalResult matchAndRewrite(GenericOp genericOp,
1594  PatternRewriter &rewriter) const override {
1595  auto genericAndRepl =
1596  pushDownExtractSliceOpThroughGenericOp(rewriter, genericOp, controlFn);
1597  if (failed(genericAndRepl))
1598  return failure();
1599  rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
1600  return success();
1601  }
1602 
1603 private:
1604  ControlPropagationFn controlFn;
1605 };
1606 
1607 } // namespace
1608 
1609 void mlir::linalg::populateDataLayoutPropagationPatterns(
1610  RewritePatternSet &patterns,
1611  const ControlPropagationFn &controlPackUnPackPropagation,
1612  bool PoisonPaddingOk) {
1613  patterns.insert<BubbleUpPackThroughPadOp, BubbleUpPackOpThroughReshapeOp,
1614  PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
1615  patterns.getContext(), controlPackUnPackPropagation);
1616  patterns.insert<BubbleUpPackOpThroughGenericOpPattern,
1617  PushDownUnPackOpThroughGenericOp>(
1618  patterns.getContext(), controlPackUnPackPropagation, PoisonPaddingOk);
1619 }
1620 
1621 void mlir::linalg::populateExtractSliceSinkingPatterns(
1622  RewritePatternSet &patterns,
1623  const ControlPropagationFn &controlPackUnPackPropagation) {
1624  patterns.insert<PushDownExtractSliceOpThroughGenericOp>(
1625  patterns.getContext(), controlPackUnPackPropagation);
1626 }
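// Minimal usage sketch (an assumption, not part of this file): a pass could
// feed these entry points into a greedy rewrite driver roughly as follows,
// with the control function deciding per operand whether to propagate.
//
//   RewritePatternSet patterns(ctx);
//   linalg::ControlPropagationFn control = [](OpOperand *) { return true; };
//   linalg::populateDataLayoutPropagationPatterns(patterns, control,
//                                                 /*PoisonPaddingOk=*/true);
//   linalg::populateExtractSliceSinkingPatterns(patterns, control);
//   (void)applyPatternsGreedily(op, std::move(patterns));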