1 //===- DataLayoutPropagation.cpp -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/Passes.h"
10 
11 #include "mlir/Dialect/Linalg/IR/Linalg.h"
12 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
13 #include "mlir/Dialect/Linalg/Utils/Utils.h"
14 #include "mlir/Dialect/Tensor/IR/Tensor.h"
15 #include "mlir/Dialect/Utils/IndexingUtils.h"
18 #include "mlir/IR/Dominance.h"
20 #include "llvm/ADT/SetOperations.h"
21 #include "llvm/ADT/SetVector.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/Debug.h"
24 #include <optional>
25 
26 namespace mlir {
27 #define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
28 #include "mlir/Dialect/Linalg/Passes.h.inc"
29 } // namespace mlir
30 
31 using namespace mlir;
32 using namespace mlir::linalg;
33 
34 #define DEBUG_TYPE "linalg-data-layout-propagation"
35 
36 namespace {
37 
38 static bool hasGatherSemantics(linalg::GenericOp genericOp) {
39  for (Operation &op : genericOp.getBody()->getOperations())
40  if (isa<tensor::ExtractOp, linalg::IndexOp>(op))
41  return true;
42  return false;
43 }
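// For illustration (hand-written example, not from the upstream sources): a
// generic op whose body contains ops such as
//   %e = tensor.extract %t[%i, %j] : tensor<?x?xf32>
// or
//   %idx = linalg.index 0 : index
// is treated as having gather semantics, and the propagation patterns below
// bail out on it.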
44 
45 // The struct contains the information about how the packing metadata (tile
46 // sizes and dim positions) maps onto the iteration domain of Linalg ops.
47 struct PackInfo {
48  int64_t getNumTiledLoops() const { return tileToPointMapping.size(); };
49  // InnerDimsPos on iteration domain, which follows the order in pack ops.
50  SmallVector<int64_t> tiledDimsPos;
51  // The sizes of tiling data dimensions on iteration domain.
52  llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping;
53  // The mapping from a dimension of iteration domain to the corresponding inner
54  // tiling dimension on iteration domain.
55  llvm::DenseMap<int64_t, int64_t> tileToPointMapping;
56  // The permutation of outer dims (on domain).
57  SmallVector<int64_t> outerDimsOnDomainPerm;
58 };
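// A worked example (illustrative, not from the upstream sources): given a
// generic op with two parallel loops (d0, d1), an output indexing map
// (d0, d1) -> (d0, d1), and a consuming pack with
//   inner_dims_pos = [0, 1], inner_tiles = [8, 2], no outer_dims_perm,
// getPackingInfoFromOperand below would produce:
//   tiledDimsPos            = [0, 1]
//   domainDimAndTileMapping = {d0 -> 8, d1 -> 2}
//   tileToPointMapping      = {d0 -> d2, d1 -> d3}  // appended point loops
//   outerDimsOnDomainPerm   = []                    // no transpose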
59 
60 template <typename OpTy>
61 static FailureOr<PackInfo>
62 getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
63  OpTy packOrUnPackOp) {
64  static_assert(llvm::is_one_of<OpTy, tensor::PackOp, tensor::UnPackOp>::value,
65  "applies to only pack or unpack operations");
66  LLVM_DEBUG(
67  { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });
68 
69  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
70  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
71  SmallVector<utils::IteratorType> iterators =
72  genericOp.getIteratorTypesArray();
73 
74  PackInfo packInfo;
75  int64_t origNumDims = indexingMap.getNumDims();
76  SmallVector<AffineExpr> exprs(indexingMap.getResults());
77  ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos();
78  for (auto [index, innerDimPos, tileSize] :
79  llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
80  innerDimsPos, packOrUnPackOp.getMixedTiles())) {
81  auto expr = exprs[innerDimPos];
82  if (!isa<AffineDimExpr>(expr))
83  return failure();
84  int64_t domainDimPos =
85  cast<AffineDimExpr>(exprs[innerDimPos]).getPosition();
86  if (!isParallelIterator(iterators[domainDimPos]))
87  return failure();
88  packInfo.tiledDimsPos.push_back(domainDimPos);
89  packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
90  packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
91  LLVM_DEBUG({
92  llvm::dbgs() << "map innerDimPos=" << innerDimPos
93  << " to iteration dimension (d" << domainDimPos << ", d"
94  << packInfo.tileToPointMapping[domainDimPos]
95  << "), which has size=("
96  << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
97  });
98  }
99 
100  // Bail out if a tiled dimension is present in a map but not as an affine dim
101  // expression.
102  auto areAllAffineDimExpr = [&](int dim) {
103  for (AffineMap map : indexingMaps) {
104  if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) {
105  return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr);
106  })) {
107  return false;
108  }
109  }
110  return true;
111  };
112  for (int64_t i : packInfo.tiledDimsPos)
113  if (!areAllAffineDimExpr(i))
114  return failure();
115 
116  // Get the outer dims perm on the iteration domain. Start by identifying the
117  // set of domain dims affected by the outer permutation along with the
118  // permuted ordering for those dims. Then the full outer dims permutation can
119  // be constructed by replacing the affected dims with the permuted result in a
120  // numLoops-rank identity. e.g.
121  // outerDimsPerm = [1, 2, 0]
122  // indexingMap = (d0, d1, d2, d3, d4) -> (d1, d4, d3)
123  //
124  // permutedOuterDims = [4, 3, 1]
125  // outerDimsOnDomainPerm = [0, 4, 2, 3, 1]
126  //
127  // Non-affine dim expressions must not be permuted by the outer dims
128  // permutation.
129  SmallVector<int64_t> permutedOuterDims;
130  for (auto [index, dim] : llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) {
131  auto permutedExpr = indexingMap.getResult(dim);
132  if (auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) {
133  permutedOuterDims.push_back(dimExpr.getPosition());
134  continue;
135  }
136 
137  // TODO: Allow propagation with transposes on non affine dim expressions,
138  // e.g. d0 + d1 which implies transposing both dims simultaneously while
139  // maintaining the relative position between them.
140  if (static_cast<int64_t>(index) != dim)
141  return failure();
142  }
143  if (!permutedOuterDims.empty()) {
144  int64_t outerDimIndex = 0;
145  llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(),
146  permutedOuterDims.end());
147  for (int i = 0, e = indexingMap.getNumDims(); i < e; i++)
148  packInfo.outerDimsOnDomainPerm.push_back(
149  permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++]
150  : i);
151  LLVM_DEBUG({
152  llvm::dbgs() << "map outerDimsPerm to ";
153  for (auto dim : packInfo.outerDimsOnDomainPerm)
154  llvm::dbgs() << dim << " ";
155  llvm::dbgs() << "\n";
156  });
157  }
158 
159  return packInfo;
160 }
161 
162 static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
163  ArrayRef<AffineExpr> exprs) {
164  // Compute `outer_dims_perm`. See example:
165  // current exprs : (d0, d1, d2, d3) -> (d2, d3)
166  // perm : [0, 3, 1, 2]
167  // First map d2, d3 with their position in the array as:
168  // currentPositionTileLoops: dim | pos
169  // d2 | 0
170  // d3 | 1
171  // then scan `perm` in order and get the `outer_dims_perm`
172  // to be used, here it would be [1, 0].
173  assert(!perm.empty() && "expect perm not to be empty");
174  assert(!exprs.empty() && "expect exprs not to be empty");
175  if (exprs.size() == 1)
176  return {};
177  SmallVector<int64_t> outerDimsPerm;
178  DenseMap<int64_t, int64_t> currentPositionTileLoops;
179  for (auto [pos, expr] : llvm::enumerate(exprs)) {
180  // Here we rely on the assumption that the outer dims permutation
181  // when propagating currently requires that non-affine dim expressions
182  // are not permuted, thus allowing the identity assignment below.
183  if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
184  currentPositionTileLoops[dimExpr.getPosition()] = pos;
185  else
186  currentPositionTileLoops[pos] = pos;
187  }
188  for (int64_t loopIdx : perm) {
189  if (currentPositionTileLoops.count(loopIdx))
190  outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx));
191  }
192  return outerDimsPerm;
193 }
194 
195 /// Returns a tuple of the packed operand and the updated indexing map, under the assumptions:
196 /// 1) The generic op is the producer of the pack op.
197 /// 2) The generic op has only one result.
198 /// If the operand is a scalar or none of the packed dimensions are relevant
199 /// to the operand, the original operand and the updated indexing map are
200 /// returned; otherwise the packed operand and the updated map are returned. E.g.,
201 ///
202 /// #map0 = affine_map<(d0, d1) -> (d0, d1)>
203 /// #map1 = affine_map<(d0, d1) -> (d0)>
204 /// #map2 = affine_map<(d0, d1) -> (d1)>
205 /// %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0],
206 /// iterator_types = ["parallel", "parallel"]}
207 /// ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
208 /// outs(%init : tensor<?x?xf32>) {
209 /// ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
210 /// %4 = arith.addf %arg3, %arg4 : f32
211 /// linalg.yield %4 : f32
212 /// } -> tensor<?x?xf32>
213 /// %1 = tensor.pack %0
214 /// inner_dims_pos = [0, 1]
215 /// inner_tiles = [8, 2]
216 /// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
217 ///
218 /// Taking the first input operand as an example, the inner tile size of d0 is
219 /// 8. Thus, the operation below and the indexing map
220 /// `affine_map<(d0, d1, d2, d3) -> (d0, d2)>` will be returned.
221 ///
222 /// %pack = tensor.pack %arg0
223 /// inner_dims_pos = [0]
224 /// inner_tiles = [8]
225 /// into %init : tensor<?xf32> -> tensor<?x8xf32>
226 static std::tuple<Value, AffineMap>
227 getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
228  GenericOp genericOp, OpOperand *opOperand) {
229  int64_t numOrigLoops = genericOp.getNumLoops();
230  int64_t numInnerLoops = packInfo.getNumTiledLoops();
231  int64_t numLoops = numOrigLoops + numInnerLoops;
232  AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
233  llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
234  SmallVector<AffineExpr> exprs(origIndexingMap.getResults());
235 
236  // If the OpOperand is a scalar or a zero-rank tensor, no need to pack.
237  if (genericOp.isScalar(opOperand) || exprs.empty())
238  return std::make_tuple(opOperand->get(),
239  AffineMap::get(numLoops, 0, exprs, b.getContext()));
240 
241  // Step 1. Construct the information of packing data dimensions; append inner
242  // dimensions to the indexing maps for the operand.
243  for (auto [index, expr] : llvm::enumerate(exprs)) {
244  if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
245  int64_t dimPos = dimExpr.getPosition();
246  domainDimToOperandDim[dimPos] = index;
247  continue;
248  }
249  }
250  SmallVector<int64_t> innerDimsPos;
251  SmallVector<OpFoldResult> innerTileSizes;
252  for (auto dimPos : packInfo.tiledDimsPos) {
253  if (!domainDimToOperandDim.count(dimPos))
254  continue;
255  int64_t index = domainDimToOperandDim[dimPos];
256  innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
257  innerDimsPos.push_back(index);
258  exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos]));
259  }
260 
261  // Step 2. Handle outer dim permutations.
262  SmallVector<int64_t> outerDimsPerm;
263  if (!packInfo.outerDimsOnDomainPerm.empty()) {
264  outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs);
265 
266  // Step 2.1: Fold transpose into the linalg.generic.
267  SmallVector<int64_t> inversedOuterPerm =
268  invertPermutationVector(packInfo.outerDimsOnDomainPerm);
269  for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
270  if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) {
271  int64_t dimPos = dimExpr.getPosition();
272  exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
273  continue;
274  }
275  assert(isa<AffineConstantExpr>(exprs[i]) &&
276  "Attempted to permute non-constant and non-affine dim expression");
277  }
278  // Step 2.2: Undo the transposition on `exprs` and propagate the
279  // transposition on the pack using outerDimsPerm.
280  if (!outerDimsPerm.empty()) {
281  SmallVector<AffineExpr> auxVec = exprs;
282  for (const auto &en : enumerate(outerDimsPerm))
283  auxVec[en.index()] = exprs[en.value()];
284  exprs = auxVec;
285  }
286  }
287  auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext());
288 
289  // The operand does not have any dimension that relates to the pack op.
290  if (innerDimsPos.empty() && outerDimsPerm.empty())
291  return std::make_tuple(opOperand->get(), indexingMap);
292 
293  auto empty = tensor::PackOp::createDestinationTensor(
294  b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
295  auto packedOperand = b.create<tensor::PackOp>(
296  loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
297  /*padding=*/std::nullopt, outerDimsPerm);
298  return std::make_tuple(packedOperand, indexingMap);
299 }
300 
301 /// Pack a genericOp and return it.
302 static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
303  Value dest, AffineMap packedOutIndexingMap,
304  const PackInfo &packInfo) {
305  Location loc = genericOp.getLoc();
306  SmallVector<Value> inputOperands;
307  SmallVector<AffineMap> indexingMaps;
308  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
309  auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
310  rewriter, loc, packInfo, genericOp, inputOperand);
311  inputOperands.push_back(packedOperand);
312  indexingMaps.push_back(packedIndexingMap);
313  }
314 
315  int64_t numInnerLoops = packInfo.getNumTiledLoops();
316  SmallVector<utils::IteratorType> iterTypes =
317  genericOp.getIteratorTypesArray();
318  iterTypes.append(numInnerLoops, utils::IteratorType::parallel);
319 
320  indexingMaps.push_back(packedOutIndexingMap);
321 
322  auto newGenericOp = rewriter.create<linalg::GenericOp>(
323  loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes,
324  /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
325  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
326  newGenericOp.getRegion().begin());
327  return newGenericOp;
328 }
329 
330 /// Bubbles up a tensor.pack op through a producer generic op. This swaps
331 /// pack(generic) to generic(pack). The new generic op works on the packed
332 /// domain; pack ops are created for input and output operands. E.g.,
333 ///
334 /// #map0 = affine_map<(d0, d1) -> (d0, d1)>
335 /// %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
336 /// %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
337 /// %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
338 /// %3 = linalg.generic {indexing_maps = [#map0, #map0],
339 /// iterator_types = ["parallel", "parallel"]}
340 /// ins(%arg0 : tensor<?x?xf32>)
341 /// outs(%2 : tensor<?x?xf32>) {
342 /// ^bb0(%arg3: f32, %arg4: f32):
343 /// %4 = arith.addf %arg3, %arg3 : f32
344 /// linalg.yield %4 : f32
345 /// } -> tensor<?x?xf32>
346 /// %4 = tensor.pack %3
347 /// inner_dims_pos = [0, 1]
348 /// inner_tiles = [8, 2]
349 /// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
350 ///
351 /// will be converted to
352 ///
353 /// #map = affine_map<()[s0] -> (s0 ceildiv 8)>
354 /// #map1 = affine_map<()[s0] -> (s0 ceildiv 2)>
355 /// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
356 /// %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
357 /// %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
358 /// %0 = affine.apply #map()[%dim]
359 /// %1 = affine.apply #map1()[%dim_0]
360 /// %2 = tensor.empty(%0, %1) : tensor<?x?x8x2xf32>
361 /// %pack = tensor.pack %arg0
362 /// inner_dims_pos = [0, 1]
363 /// inner_tiles = [8, 2]
364 /// into %2 : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
365 /// %3 = linalg.generic {indexing_maps = [#map2, #map2],
366 /// iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
367 /// ins(%pack : tensor<?x?x8x2xf32>)
368 /// outs(%arg1 : tensor<?x?x8x2xf32>) {
369 /// ^bb0(%in: f32, %out: f32):
370 /// %4 = arith.addf %in, %in : f32
371 /// linalg.yield %4 : f32
372 /// } -> tensor<?x?x8x2xf32>
373 static FailureOr<GenericOp>
374 bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
375  const ControlPropagationFn &controlFn) {
376  auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
377  if (!genericOp)
378  return failure();
379 
380  // User controlled propagation function.
381  if (!controlFn(&packOp.getSourceMutable()))
382  return failure();
383 
384  // TODO: Enable propagation in the presence of linalg.index and
385  // tensor.extract, likely as a separate pattern as the pack information and
386  // propagation decision needs to be inferred from the region of the generic.
387  if (hasGatherSemantics(genericOp))
388  return failure();
389 
390  // TODO: Relax the restriction. We are able to bubble up the pack op through
391  // multi-result generic op. It just needs more work.
392  if (genericOp.getNumResults() != 1)
393  return failure();
394 
395  // Bail-out if the result of the generic has multiple uses, as bubbling up
396  // creates recomputation if the generic has multiple users.
397  // TODO: Enable the case where every use is an identical pack op as no
398  // recomputation is needed in that case.
399  if (!genericOp->getResult(0).hasOneUse())
400  return failure();
401 
402  // We want to move the pack, not the generic.
403  OpBuilder::InsertionGuard guard(rewriter);
404  rewriter.setInsertionPoint(genericOp);
405 
406  // We need to handle two cases:
407  // 1) The tensor.pack destination is a tensor.empty. If this is the case, we
408  // create a new tensor.empty to avoid breaking dominance, as we are moving the
409  // tensor.pack above the linalg.generic.
410  // 2) The destination is not a tensor.empty. In this case we can replace only
411  // if the destination of the tensor.pack dominates the linalg.generic.
412  Value packOpDest = packOp.getDest();
413  if (!packOpDest.hasOneUse())
414  return failure();
415  if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
416  packOpDest = rewriter.create<tensor::EmptyOp>(
417  genericOp->getLoc(), emptyOp.getMixedSizes(),
418  emptyOp.getType().getElementType());
419  } else {
420  DominanceInfo dom(genericOp);
421  if (!dom.properlyDominates(packOpDest, genericOp))
422  return failure();
423  }
424 
425  // TODO: Add an option for allowing padding values. It could introduce
426  // undefined behavior if we unconditionally propagate pack op through all
427  // the ops. E.g., if the padding value is zero and there are division ops in
428  // a generic op. Some values of padding area could be NaN (0/0).
429  if (packOp.getPaddingValue())
430  return failure();
431 
432  OpOperand *opOperand = genericOp.getDpsInitOperand(0);
433  auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
434  if (failed(packInfo))
435  return failure();
436 
437  // Rebuild the indexing map for the corresponding init operand.
438  auto [packedOutOperand, packedOutIndexingMap] =
439  getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
440  genericOp, opOperand);
441 
442  // If the dps init operand of the generic is a tensor.empty, forward the
443  // pack op destination.
444  Value dest = packedOutOperand;
445  if (auto initTensor = genericOp.getDpsInitOperand(0)
446  ->get()
447  .getDefiningOp<tensor::EmptyOp>()) {
448  dest = packOpDest;
449  }
450  return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
451  *packInfo);
452 }
453 
454 /// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
455 struct BubbleUpPackOpThroughGenericOpPattern
456  : public OpRewritePattern<tensor::PackOp> {
457 public:
458  BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
459  ControlPropagationFn fun)
460  : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
461 
462  LogicalResult matchAndRewrite(tensor::PackOp packOp,
463  PatternRewriter &rewriter) const override {
464  auto genericOp =
465  bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
466  if (failed(genericOp))
467  return failure();
468  rewriter.replaceOp(packOp, genericOp->getResults());
469  return success();
470  }
471 
472 private:
473  ControlPropagationFn controlFn;
474 };
475 
476 /// Propagate a tensor.pack operation up through a tensor.pad. The idea is to
477 /// append as many zero-padding entries to `high` and `low` as there are point
478 /// loops.
479 class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
480 public:
481  BubbleUpPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
482  : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
483 
484  LogicalResult matchAndRewrite(tensor::PackOp packOp,
485  PatternRewriter &rewriter) const override {
486  auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
487  if (!padOp)
488  return failure();
489 
490  // User controlled propagation function.
491  if (!controlFn(&packOp.getSourceMutable()))
492  return failure();
493 
494  // TODO: Enable padding when the padding values are the same.
495  if (packOp.getPaddingValue())
496  return failure();
497 
498  // Fail for non-constant padding values. The body of the pad could
499  // depend on the padding indices and/or properties of the padded
500  // tensor, so for now we fail.
501  // TODO: Support non-constant padding values.
502  Value paddingVal = padOp.getConstantPaddingValue();
503  if (!paddingVal)
504  return failure();
505 
506  if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
507  return failure();
508 
509  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
510 
511  // Bail out if one of the padded dimensions is a tiled one.
512  llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
513  llvm::SmallBitVector innerDims(paddedDims.size());
514  for (int64_t dim : innerDimsPos)
515  innerDims.flip(dim);
516  if (paddedDims.anyCommon(innerDims))
517  return failure();
518 
519  Location loc = padOp->getLoc();
520  OpBuilder::InsertionGuard guard(rewriter);
521  rewriter.setInsertionPoint(padOp);
522 
523  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
524  SmallVector<OpFoldResult> mixedTiles = packOp.getMixedTiles();
525  auto empty = tensor::PackOp::createDestinationTensor(
526  rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos,
527  outerDimsPerm);
528  auto sourcePack = rewriter.create<tensor::PackOp>(
529  loc, padOp.getSource(), empty, innerDimsPos, mixedTiles,
530  /*padding=*/std::nullopt, outerDimsPerm);
531 
532  // If we have `outer_dims_perm`, we need to adjust the padded dimensions.
533  SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
534  SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
535  if (!outerDimsPerm.empty()) {
536  applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
537  applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
538  }
539  // The tiled dimensions were verified to be unpadded above, so here we
540  // just append 0 for the inner tile dimensions.
541  size_t pointLoopsSize = innerDimsPos.size();
542  lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
543  highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
544 
545  auto newPadOp = rewriter.create<tensor::PadOp>(
546  loc, /*result=*/Type(), sourcePack, lowPad, highPad, paddingVal,
547  padOp.getNofold());
548 
549  // If the pad has more than one user, create an unpack on the new pad to
550  // replace the other uses.
551  if (!padOp->hasOneUse()) {
552  auto unpackEmpty = tensor::UnPackOp::createDestinationTensor(
553  rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm);
554  Value unpackedPad = rewriter.create<tensor::UnPackOp>(
555  loc, newPadOp, unpackEmpty, innerDimsPos, mixedTiles, outerDimsPerm);
556  rewriter.replaceAllUsesExcept(padOp, unpackedPad, sourcePack);
557  }
558 
559  // Replace the pack with the new pad.
560  rewriter.replaceOp(packOp, newPadOp.getResult());
561 
562  return success();
563  }
564 
565 private:
566  ControlPropagationFn controlFn;
567 };
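// An illustrative sketch of this rewrite (hand-written example, assuming a
// 64x49 source padded on dim 1 and packed on dim 0 with tile size 8; names
// such as %src, %dest, %empty are placeholders):
//
//   %padded = tensor.pad %src low[0, 0] high[0, 15] { ... }
//       : tensor<64x49xf32> to tensor<64x64xf32>
//   %packed = tensor.pack %padded inner_dims_pos = [0] inner_tiles = [8]
//       into %dest : tensor<64x64xf32> -> tensor<8x64x8xf32>
//
// becomes
//
//   %packed = tensor.pack %src inner_dims_pos = [0] inner_tiles = [8]
//       into %empty : tensor<64x49xf32> -> tensor<8x49x8xf32>
//   %padded = tensor.pad %packed low[0, 0, 0] high[0, 15, 0] { ... }
//       : tensor<8x49x8xf32> to tensor<8x64x8xf32>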
568 
569 /// Project dimsPos to the inner-most non-unit dim pos with reassocIndices.
570 ///
571 /// For example, given dimsPos [0, 2], reassocIndices [[0, 1], [2, 3]], and
572 /// targetShape [16, 16, 32, 1], it returns [1, 2]: for pos 0, the inner-most
573 /// non-unit projected dim in [0, 1] is 1; for pos 2, the inner-most non-unit
574 /// projected dim in [2, 3] is 2.
575 ///
576 /// If all candidates in a reassociation are unit dims, it chooses the
577 /// inner-most dim pos.
578 static SmallVector<int64_t>
579 projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
580  ArrayRef<ReassociationIndices> reassocIndices,
581  ArrayRef<int64_t> targetShape) {
582  SmallVector<int64_t> projectedDimsPos;
583  for (auto pos : dimsPos) {
584  // In the case all dims are unit, this will return the inner-most one.
585  int64_t projectedPos = reassocIndices[pos].back();
586  for (auto i : llvm::reverse(reassocIndices[pos])) {
587  int64_t dim = targetShape[i];
588  if (dim > 1 || ShapedType::isDynamic(dim)) {
589  projectedPos = i;
590  break;
591  }
592  }
593  projectedDimsPos.push_back(projectedPos);
594  }
595  return projectedDimsPos;
596 }
597 
598 /// Check if all dims in dimsPos are divisible by the corresponding tile sizes.
599 static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
600  ArrayRef<int64_t> shape,
601  ArrayRef<int64_t> tileSizes) {
602  for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
603  int64_t dim = shape[pos];
604  if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
605  return false;
606  }
607  return true;
608 }
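// For example (illustration only): with shape = [?, 16, 32], dimsPos = [1, 2]
// and tileSizes = [8, 16] this returns true (16 % 8 == 0 and 32 % 16 == 0);
// if dimsPos included the dynamic dim 0, or a dim were not divisible by its
// tile size, it would return false.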
609 
610 /// Permute the reassociation indices and reindex them in sequence order.
611 /// Returns the next available dim position after reindexing.
612 ///
613 /// For example, given reassocIndices [[0, 1], [2]] and permutation [1, 0], it
614 /// applies the permutation to get [[2], [0, 1]] and reindexes the indices into
615 /// [[0], [1, 2]].
616 static int64_t applyPermutationAndReindexReassoc(
617  SmallVector<ReassociationIndices> &reassocIndices,
618  ArrayRef<int64_t> permutation) {
619  if (!permutation.empty())
620  applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
621  int64_t nextPos = 0;
622  for (ReassociationIndices &indices : reassocIndices) {
623  for (auto &index : indices) {
624  index = nextPos;
625  nextPos += 1;
626  }
627  }
628  return nextPos;
629 }
630 
631 /// Bubble up pack op through collapse shape op when the packed dims can be
632 /// projected to the dims before collapsing. This is possible when the inner
633 /// tile sizes can divide the projected dims.
634 ///
635 /// For example:
636 ///
637 /// %collapsed = tensor.collapse_shape %in [[0, 1], 2]
638 /// : tensor<?x16x4xf32> into tensor<?x4xf32>
639 /// %pack = tensor.pack %collapsed outer_dims_perm = [0, 1]
640 /// inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty
641 /// : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
642 ///
643 /// can be transformed into:
644 ///
645 /// %pack = tensor.pack %in outer_dims_perm = [1, 2]
646 /// inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty
647 /// : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
648 /// %collapsed = tensor.collapse_shape %pack [[0, 1], 2, 3, 4]
649 /// : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
650 static LogicalResult
651 bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
652  tensor::PackOp packOp,
653  PatternRewriter &rewriter) {
654  SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
655  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
656  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
657 
658  ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
659  SmallVector<ReassociationIndices> reassocIndices =
660  collapseOp.getReassociationIndices();
661  // Project inner tile pos to the dim pos before collapsing. For example, if
662  // dims [x, y] are collapsed into [z], packing on dim z can be projected back
663  // to pack on dim y.
664  //
665  // Project to inner-most non-unit dims to increase the chance that they can be
666  // divided by the inner tile sizes. This is correct because for [..., x, 1],
667  // packing on dim 1 is equivalent to packing on dim x.
668  SmallVector<int64_t> projectedInnerDimsPos =
669  projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);
670 
671  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
672  innerTileSizes)) {
673  return failure();
674  }
675  // Expand the outer dims permutation with the associated source dims for the
676  // new permutation after bubbling. This is because moving a collapsed dim is
677  // equivalent to moving the associated source dims together.
678  SmallVector<int64_t> newOuterDimsPerm;
679  for (auto outerPos : outerDimsPerm) {
680  newOuterDimsPerm.insert(newOuterDimsPerm.end(),
681  reassocIndices[outerPos].begin(),
682  reassocIndices[outerPos].end());
683  }
684 
685  auto emptyOp = tensor::PackOp::createDestinationTensor(
686  rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
687  projectedInnerDimsPos, newOuterDimsPerm);
688  auto newPackOp = rewriter.create<tensor::PackOp>(
689  packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos,
690  packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);
691 
692  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
693  // First apply the permutation on the reassociations of the outer dims.
694  // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
695  // -> [[0], [1, 2]]
696  int64_t nextPos =
697  applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
698  // Then add direct mapping for the inner tile dims.
699  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
700  newReassocIndices.push_back({nextPos});
701  nextPos += 1;
702  }
703 
704  auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
705  collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices);
706  rewriter.replaceOp(packOp, newCollapseOp);
707 
708  return success();
709 }
710 
711 /// Project dimsPos to their collapsed positions in the reassocIndices.
712 ///
713 /// For example, given dimsPos [0, 1, 2, 4], and matching reassocIndices
714 /// [[0], [1, 2], [3], [4]], it returns [0, 1, 1, 3]. Because for pos 0,
715 /// the reassoc dim [0] is 0. For pos 1 and 2, the reassoc dim in pos
716 /// [1, 2] is 1. And for pos 4, the reassoc dim [4] is 3.
717 static SmallVector<int64_t>
718 projectDimsPosIntoReassocPos(ArrayRef<int64_t> dimsPos,
719  ArrayRef<ReassociationIndices> reassocIndices) {
720  SmallVector<int64_t> projectedPos;
721 
722  // Map each dimension to the position of corresponding reassociation index.
723  for (auto pos : dimsPos) {
724  for (auto [idx, indices] : llvm::enumerate(reassocIndices)) {
725  // If the dimension is present in the current indices group, the group
726  // position within the reassociation map is the desired projected
727  // dimension position.
728  if (llvm::is_contained(indices, pos)) {
729  projectedPos.push_back(idx);
730  break;
731  }
732  }
733  }
734  assert(projectedPos.size() == dimsPos.size() && "Invalid dim pos projection");
735 
736  return projectedPos;
737 }
738 
739 /// Bubble up pack op through expand shape op.
740 ///
741 /// For example:
742 ///
743 /// %expand = tensor.expand_shape %in [[0], [1, 2]]
744 /// : tensor<?x64xf32> into tensor<?x4x16xf32>
745 /// %pack = tensor.pack %expand outer_dims_perm = [0, 1]
746 /// inner_dims_pos = [2] inner_tiles = [8] into %empty
747 /// : tensor<?x4x16xf32> -> tensor<?x4x2x8xf32>
748 ///
749 /// can be transformed into:
750 ///
751 /// %pack = tensor.pack %in outer_dims_perm = [1, 2]
752 /// inner_dims_pos = [1] inner_tiles = [8] into %empty
753 /// : tensor<?x64xf32> -> tensor<?x8x8xf32>
754 /// %expand = tensor.expand_shape %pack [[0], [1, 2], [3]]
755 /// : tensor<?x8x8xf32> into tensor<?x4x2x8xf32>
756 static LogicalResult
757 bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
758  tensor::PackOp packOp,
759  PatternRewriter &rewriter) {
760  // Outer dimensions permutation is not supported currently.
761  // TODO: Handle outer_dims_perm variants.
762  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
763  if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
764  return rewriter.notifyMatchFailure(packOp,
765  "non-identity outer dims perm NYI");
766  }
767 
768  // Validate dimensions' relations between shape expansion and packing.
769  SmallVector<ReassociationIndices> reassoc =
770  expandOp.getReassociationIndices();
771  ArrayRef<int64_t> packInnerDims = packOp.getInnerDimsPos();
772  llvm::SetVector<int64_t> packDimsPos(packInnerDims.begin(),
773  packInnerDims.end());
774 
775  for (auto [idx, indices] : llvm::enumerate(reassoc)) {
776  // For each expand_shape reassociation, figure out which dimensions get
777  // packed if any.
778  llvm::SetVector<int64_t> expandDimPos(indices.begin(), indices.end());
779  llvm::SetVector<int64_t> packedDims =
780  llvm::set_intersection(packDimsPos, expandDimPos);
781 
782  // The expanded dimension is not packed, so it does not affect moving the
783  // pack before the shape expansion - simply continue.
784  if (packedDims.empty())
785  continue;
786  // Shape expansion cannot be propagated when multiple expanded dimensions are
787  // packed - in this case operation reordering would affect final element
788  // positions and/or shapes can no longer be projected.
789  if (packedDims.size() != 1)
790  return rewriter.notifyMatchFailure(
791  packOp, "only one of the expanded dimensions can be packed");
792  // Only the inner-most expanded dimension should be packed. Otherwise,
793  // element order will be affected after operation reordering.
794  if (packedDims.front() != indices.back())
795  return rewriter.notifyMatchFailure(
796  packOp, "can only pack the inner-most expanded dimension");
797  }
798 
799  // Project pack.inner_dims_pos to positions before shape expansion.
800  SmallVector<int64_t> projectedInnerDimsPos =
801  projectDimsPosIntoReassocPos(packInnerDims, reassoc);
802 
803  // Project the shape expansion to new packed shape.
804  // The pack.outer_dims_perm is restricted to identity, so the permutation can
805  // be omitted for simplicity.
806  // TODO: Account for outer dimensions permutation.
807  //
808  // If reassociation is not possible, then reordering cannot happen.
809  // This can be caused by pack padding affecting previously expanded
810  // dimensions or packing extending dimensions.
811  RankedTensorType newPackType = tensor::PackOp::inferPackedType(
812  expandOp.getSrcType(), packOp.getStaticInnerTiles(),
813  projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
814  auto reassocExpand =
815  getReassociationIndicesForReshape(newPackType, packOp.getDestType());
816  if (!reassocExpand)
817  return rewriter.notifyMatchFailure(
818  packOp, "could not reassociate dims after bubbling up");
819 
820  Value destTensor = tensor::PackOp::createDestinationTensor(
821  rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(),
822  projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
823  Value packedVal = rewriter.create<tensor::PackOp>(
824  packOp.getLoc(), expandOp.getSrc(), destTensor, projectedInnerDimsPos,
825  packOp.getMixedTiles(), packOp.getPaddingValue(),
826  /*outerDimsPerm=*/SmallVector<int64_t>{});
827 
828  Value newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
829  packOp.getLoc(), packOp.getDestType(), packedVal, *reassocExpand);
830  rewriter.replaceOp(packOp, newExpandOp);
831 
832  return success();
833 }
834 
835 class BubbleUpPackOpThroughReshapeOp final
836  : public OpRewritePattern<tensor::PackOp> {
837 public:
838  BubbleUpPackOpThroughReshapeOp(MLIRContext *context, ControlPropagationFn fun)
839  : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
840 
841  LogicalResult matchAndRewrite(tensor::PackOp packOp,
842  PatternRewriter &rewriter) const override {
843  Operation *srcOp = packOp.getSource().getDefiningOp();
844  // Currently only support when the pack op is the only user.
845  if (!srcOp || !(srcOp->getNumResults() == 1) ||
846  !srcOp->getResult(0).hasOneUse()) {
847  return failure();
848  }
849  // Currently only support static inner tile sizes.
850  if (llvm::any_of(packOp.getStaticTiles(), [](int64_t size) {
851  return ShapedType::isDynamic(size);
852  })) {
853  return failure();
854  }
855 
856  // User controlled propagation function.
857  if (!controlFn(&packOp.getSourceMutable()))
858  return failure();
859 
860  return TypeSwitch<Operation *, LogicalResult>(srcOp)
861  .Case([&](tensor::CollapseShapeOp op) {
862  return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
863  })
864  .Case([&](tensor::ExpandShapeOp op) {
865  return bubbleUpPackOpThroughExpandShape(op, packOp, rewriter);
866  })
867  .Default([](Operation *) { return failure(); });
868  }
869 
870 private:
871  ControlPropagationFn controlFn;
872 };
873 
874 /// Push down unpack op through expand shape op when the packed dims can be
875 /// projected to the dims after expanding. This is possible when the inner tile
876 /// sizes can divide the projected dims.
877 ///
878 /// For example:
879 ///
880 /// %unpack = tensor.unpack %in outer_dims_perm = [0, 1]
881 /// inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %empty
882 /// : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
883 /// %expanded = tensor.expand_shape %unpack [[0, 1], [2]]
884 /// : tensor<?x256xf32> into tensor<?x256x256xf32>
885 ///
886 /// can be transformed into:
887 ///
888 /// %expanded = tensor.expand_shape %in [[0, 1], [2], [3], [4]]
889 /// : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
890 /// %unpack = tensor.unpack %expanded outer_dims_perm = [0, 1, 2]
891 /// inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty
892 /// : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
893 static LogicalResult pushDownUnPackOpThroughExpandShape(
894  tensor::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
895  PatternRewriter &rewriter, ControlPropagationFn controlFn) {
896  // User controlled propagation function.
897  if (!controlFn(&expandOp.getSrcMutable()))
898  return failure();
899 
900  SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
901  ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
902  ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
903 
904  auto expandTy = dyn_cast<RankedTensorType>(expandOp.getType());
905  if (!expandTy)
906  return failure();
907  ArrayRef<int64_t> dstShape = expandTy.getShape();
908  SmallVector<ReassociationIndices> reassocIndices =
909  expandOp.getReassociationIndices();
910  // Project inner tile pos to the dim pos after expanding. For example, if dim
911  // [z] is expanded into [x, y], unpacking on dim z can be projected to unpack
912  // on dim y.
913  //
914  // Project to inner-most non-unit dims to increase the chance that they can be
915  // divided by the inner tile sizes. This is correct because for [..., x, 1],
916  // unpacking on dim 1 is equivalent to unpacking on dim x.
917  SmallVector<int64_t> projectedInnerDimsPos =
918  projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);
919 
920  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
921  innerTileSizes)) {
922  return failure();
923  }
924  // Expand the outer dims permutation with the associated expanded dims for the
925  // new permutation after pushing. This is because moving a source dim is
926  // equivalent to moving the associated expanded dims together.
927  SmallVector<int64_t> newOuterDimsPerm;
928  for (auto outerPos : outerDimsPerm) {
929  newOuterDimsPerm.insert(newOuterDimsPerm.end(),
930  reassocIndices[outerPos].begin(),
931  reassocIndices[outerPos].end());
932  }
933 
934  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
935  // First apply the permutation on the reassociations of the outer dims.
936  // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
937  // -> [[0], [1, 2]]
938  int64_t nextPos =
939  applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
940  // Then add direct mapping for the inner tile dims.
941  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
942  newReassocIndices.push_back({nextPos});
943  nextPos += 1;
944  }
945 
946  RankedTensorType newExpandType = tensor::PackOp::inferPackedType(
947  expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm);
948  auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
949  expandOp.getLoc(), newExpandType, unPackOp.getSource(),
950  newReassocIndices);
951 
952  auto emptyOp = tensor::UnPackOp::createDestinationTensor(
953  rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
954  projectedInnerDimsPos, newOuterDimsPerm);
955  auto newUnPackOp = rewriter.create<tensor::UnPackOp>(
956  unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
957  projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
958  rewriter.replaceOp(expandOp, newUnPackOp);
959 
960  return success();
961 }
962 
963 class PushDownUnPackOpThroughReshapeOp final
964  : public OpRewritePattern<tensor::UnPackOp> {
965 public:
966  PushDownUnPackOpThroughReshapeOp(MLIRContext *context,
967  ControlPropagationFn fun)
968  : OpRewritePattern<tensor::UnPackOp>(context), controlFn(std::move(fun)) {
969  }
970 
971  LogicalResult matchAndRewrite(tensor::UnPackOp unPackOp,
972  PatternRewriter &rewriter) const override {
973  Value result = unPackOp.getResult();
974  // Currently only support an unpack op with a single user.
975  if (!result.hasOneUse()) {
976  return failure();
977  }
978  // Currently only support static inner tile sizes.
979  if (llvm::any_of(unPackOp.getStaticTiles(), [](int64_t size) {
980  return ShapedType::isDynamic(size);
981  })) {
982  return failure();
983  }
984 
985  Operation *consumerOp = *result.user_begin();
986  return TypeSwitch<Operation *, LogicalResult>(consumerOp)
987  .Case([&](tensor::ExpandShapeOp op) {
988  return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter,
989  controlFn);
990  })
991  .Default([](Operation *) { return failure(); });
992  }
993 
994 private:
995  ControlPropagationFn controlFn;
996 };
997 
998 // TODO: Relax this restriction. We should unpack a generic op also
999 // in the presence of multiple unpack ops as producers.
1000 /// Return the unpacked operand, if present, for the current generic op.
1001 static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
1002  OpOperand *unPackedOperand = nullptr;
1003  for (OpOperand &operand : genericOp->getOpOperands()) {
1004  auto unPackOp = operand.get().getDefiningOp<tensor::UnPackOp>();
1005  if (!unPackOp)
1006  continue;
1007  if (unPackedOperand)
1008  return failure();
1009  unPackedOperand = &operand;
1010  }
1011  if (!unPackedOperand)
1012  return failure();
1013  return unPackedOperand;
1014 }
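// For example (illustrative): for a generic op where exactly one operand,
// say the second input, is produced by a tensor.unpack, that OpOperand is
// returned; if no operand or more than one operand comes from a
// tensor.unpack, the helper fails.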
1015 
1016 /// Push down a tensor.unpack op through a generic op.
1017 /// The new generic op works on the packed domain; pack ops are created for input
1018 /// and output operands. A tensor.unpack op is inserted right after the packed
1019 /// generic. E.g.
1020 ///
1021 /// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
1022 ///
1023 /// %arg0 = tensor<12x2x56x56x32xf32> // packed arg.
1024 ///
1025 /// %0 = tensor.empty() : tensor<12x56x56x64xf32>
1026 /// %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2]
1027 /// inner_dims_pos = [3] inner_tiles = [32] into %0
1028 /// %2 = linalg.generic {indexing_maps = [#map],
1029 /// iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
1030 /// outs(%1 : tensor<12x56x56x64xf32>) {
1031 /// ^bb0(%out : f32):
1032 /// linalg.yield %out : f32
1033 /// } -> tensor<12x56x56x64xf32>
1034 ///
1035 /// will be converted to
1036 ///
1037 /// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
1038 ///
1039 /// %0 = tensor.empty() : tensor<12x56x56x64xf32>
1040 /// %1 = linalg.generic {indexing_maps = [#map],
1041 /// iterator_types = ["parallel", "parallel", "parallel",
1042 /// "parallel", "parallel"]}
1043 /// outs(%arg0 : tensor<12x2x56x56x32xf32>) {
1044 /// ^bb0(%out : f32):
1045 /// linalg.yield %out : f32
1046 /// } -> tensor<12x2x56x56x32xf32>
1047 /// %2 = tensor.unpack %1 outer_dims_perm = [0, 3, 1, 2]
1048 /// inner_dims_pos = [3] inner_tiles = [32] into %0
1049 ///
1050 static FailureOr<std::tuple<GenericOp, Value>>
1051 pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
1052  ControlPropagationFn controlFn) {
1053  if (genericOp.getNumResults() != 1)
1054  return failure();
1055 
1056  if (hasGatherSemantics(genericOp))
1057  return failure();
1058 
1059  // Collect the unPacked operand, if present.
1060  auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
1061  if (failed(maybeUnPackedOperand))
1062  return failure();
1063  OpOperand *unPackedOperand = *(maybeUnPackedOperand);
1064 
1065  // Extract packing information.
1066  tensor::UnPackOp producerUnPackOp =
1067  unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
1068  assert(producerUnPackOp && "expect a valid UnPackOp");
1069 
1070  if (!controlFn(unPackedOperand))
1071  return failure();
1072 
1073  auto packInfo =
1074  getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
1075  if (failed(packInfo))
1076  return failure();
1077 
1078  // Rebuild the indexing map for the corresponding init operand.
1079  auto [packedOutOperand, packedOutIndexingMap] =
1080  getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
1081  genericOp, genericOp.getDpsInitOperand(0));
1082  auto destPack = packedOutOperand.getDefiningOp<tensor::PackOp>();
1083 
1084  // If the dps init operand of the generic is a tensor.empty, do not pack it
1085  // and forward the new tensor.empty as a destination.
1086  Value dest = packedOutOperand;
1087  if (auto initTensor = genericOp.getDpsInitOperand(0)
1088  ->get()
1089  .getDefiningOp<tensor::EmptyOp>()) {
1090  if (destPack)
1091  dest = destPack.getDest();
1092  }
1093 
1094  // Pack the genericOp.
1095  GenericOp newGenericOp =
1096  packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo);
1097  Value newResult =
1098  newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));
1099 
1100  // If the output is unaffected, no need to unpack.
1101  if (!destPack)
1102  return std::make_tuple(newGenericOp, newResult);
1103 
1104  auto mixedTiles = destPack.getMixedTiles();
1105  auto innerDimsPos = destPack.getInnerDimsPos();
1106  auto outerDimsPerm = destPack.getOuterDimsPerm();
1107 
1108  // Insert an unPackOp right after the packed generic.
1109  Value unPackOpRes =
1110  rewriter
1111  .create<tensor::UnPackOp>(genericOp.getLoc(), newResult,
1112  destPack.getSource(), innerDimsPos,
1113  mixedTiles, outerDimsPerm)
1114  .getResult();
1115 
1116  return std::make_tuple(newGenericOp, unPackOpRes);
1117 }
1118 
1119 // Wrapper pattern that applies pushDownUnPackOpThroughGenericOp method.
1120 struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
1121 public:
1122  PushDownUnPackOpThroughGenericOp(MLIRContext *context,
1123  ControlPropagationFn fun)
1124  : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
1125 
1126  LogicalResult matchAndRewrite(GenericOp genericOp,
1127  PatternRewriter &rewriter) const override {
1128  auto genericAndRepl =
1129  pushDownUnPackOpThroughGenericOp(rewriter, genericOp, controlFn);
1130  if (failed(genericAndRepl))
1131  return failure();
1132  rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
1133  return success();
1134  }
1135 
1136 private:
1137  ControlPropagationFn controlFn;
1138 };
1139 
1140 /// Propagate a tensor.unpack operation through a tensor.pad. The idea is to
1141 /// append as many zero-padding entries to `high` and `low` as there are point
1142 /// loops.
1143 struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
1144  PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
1145  : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {}
1146 
1147  LogicalResult matchAndRewrite(tensor::PadOp padOp,
1148  PatternRewriter &rewriter) const override {
1149  tensor::UnPackOp unpackOp =
1150  padOp.getSource().getDefiningOp<tensor::UnPackOp>();
1151  if (!unpackOp)
1152  return failure();
1153 
1154  if (!controlFn(&padOp.getSourceMutable()))
1155  return failure();
1156 
1157  Location loc = padOp.getLoc();
1158  // Bail out if one of the padded dimensions is a tiled one.
1159  llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
1160  ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
1161  llvm::SmallBitVector innerDims(paddedDims.size());
1162  for (int64_t dim : innerDimsPos)
1163  innerDims.flip(dim);
1164  if (paddedDims.anyCommon(innerDims))
1165  return failure();
1166 
1167  Value paddingVal = padOp.getConstantPaddingValue();
1168  if (!paddingVal)
1169  return failure();
1170 
1171  // If we have `outer_dims_perm`, we need to adjust the padded dimensions.
1172  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
1173  SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
1174  SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
1175  if (!outerDimsPerm.empty()) {
1176  applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
1177  applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
1178  }
1179  // Add zero padding for the point loops.
1180  size_t pointLoopsSize = innerDimsPos.size();
1181  lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
1182  highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
1183 
1184  auto newPadOp = rewriter.create<tensor::PadOp>(
1185  loc, /*result=*/Type(), unpackOp.getSource(), lowPad, highPad,
1186  paddingVal, padOp.getNofold());
1187 
1188  // Inject the tensor.unpack right after the packed padOp.
1189  Value outputUnPack = rewriter.create<tensor::EmptyOp>(
1190  loc, padOp.getResultType().getShape(),
1191  padOp.getResultType().getElementType());
1192 
1193  Value replacement = rewriter.create<tensor::UnPackOp>(
1194  loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
1195  unpackOp.getMixedTiles(), outerDimsPerm);
1196  rewriter.replaceOp(padOp, replacement);
1197  return success();
1198  }
1199 
1200 private:
1201  ControlPropagationFn controlFn;
1202 };
1203 
1204 } // namespace
1205 
1206 void mlir::linalg::populateDataLayoutPropagationPatterns(
1207  RewritePatternSet &patterns,
1208  const ControlPropagationFn &controlPackUnPackPropagation) {
1209  patterns
1210  .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
1211  BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
1212  PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
1213  patterns.getContext(), controlPackUnPackPropagation);
1214 }
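// --- Illustrative usage (not part of this file) ---
// A minimal sketch of driving these patterns from a pass, assuming a
// hypothetical pass class `MyDataLayoutPropagationPass`; the always-true
// control function and the greedy rewrite driver are placeholders for
// whatever policy and driver the caller actually uses.
//
//   #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
//
//   void MyDataLayoutPropagationPass::runOnOperation() {
//     MLIRContext *ctx = &getContext();
//     RewritePatternSet patterns(ctx);
//     // Propagate through every candidate operand.
//     linalg::populateDataLayoutPropagationPatterns(
//         patterns, [](OpOperand *) { return true; });
//     if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                             std::move(patterns))))
//       signalPassFailure();
//   }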