1 //===- DataLayoutPropagation.cpp -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/Passes.h"
10 
11 #include "mlir/Dialect/Affine/IR/AffineOps.h"
12 #include "mlir/Dialect/Linalg/IR/Linalg.h"
13 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
14 #include "mlir/Dialect/Linalg/Utils/Utils.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/Dialect/Tensor/Utils/Utils.h"
17 #include "mlir/Dialect/Utils/IndexingUtils.h"
18 #include "mlir/IR/Dominance.h"
20 #include "llvm/ADT/SetOperations.h"
21 #include "llvm/ADT/SetVector.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/Debug.h"
24 #include <optional>
25 
26 namespace mlir {
27 #define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
28 #include "mlir/Dialect/Linalg/Passes.h.inc"
29 } // namespace mlir
30 
31 using namespace mlir;
32 using namespace mlir::linalg;
33 
34 #define DEBUG_TYPE "linalg-data-layout-propagation"
35 
36 namespace {
37 
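/// Returns true if the generic op body contains a tensor.extract or
/// linalg.index op, i.e. the op has gather semantics. Such ops currently block
/// pack/unpack propagation (see the TODOs below).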
38 static bool hasGatherSemantics(linalg::GenericOp genericOp) {
39  for (Operation &op : genericOp.getBody()->getOperations())
40  if (isa<tensor::ExtractOp, linalg::IndexOp>(op))
41  return true;
42  return false;
43 }
44 
45 // The struct contains the information needed to map packing metadata (inner
46 // tiles and outer permutation) onto the iteration domain of Linalg ops.
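// For example, for a pack with inner_dims_pos = [0] and inner_tiles = [8]
// whose source is the identity-mapped result of a 2-D generic, this records
// tiledDimsPos = [0], domainDimAndTileMapping = {0 -> 8}, and
// tileToPointMapping = {0 -> 2} (d2 being the newly created point loop).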
47 struct PackInfo {
48  int64_t getNumTiledLoops() const { return tileToPointMapping.size(); };
49  // InnerDimsPos on iteration domain, which follows the order in pack ops.
50  SmallVector<int64_t> tiledDimsPos;
51  // The sizes of tiling data dimensions on iteration domain.
52  llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping;
53  // The mapping from a dimension of iteration domain to the corresponding inner
54  // tiling dimension on iteration domain.
55  llvm::DenseMap<int64_t, int64_t> tileToPointMapping;
56  // The permutation of outer dims (on domain).
57  SmallVector<int64_t> outerDimsOnDomainPerm;
58 };
59 
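/// Infers a PackInfo for `genericOp` from the pack or unpack op associated
/// with `opOperand`. Inner tile positions are translated from operand
/// dimensions to iteration-domain dimensions through the operand's indexing
/// map, and the outer dims permutation is lifted onto the domain. Fails if a
/// tiled dimension is not a plain affine dim expression in every indexing map,
/// if it maps to a non-parallel loop, or if the outer permutation would move a
/// non-dim expression.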
60 template <typename OpTy>
61 static FailureOr<PackInfo>
62 getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
63  OpTy packOrUnPackOp) {
64  static_assert(llvm::is_one_of<OpTy, tensor::PackOp, tensor::UnPackOp>::value,
65  "applies to only pack or unpack operations");
66  LLVM_DEBUG(
67  { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });
68 
69  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
70  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
71  SmallVector<utils::IteratorType> iterators =
72  genericOp.getIteratorTypesArray();
73 
74  PackInfo packInfo;
75  int64_t origNumDims = indexingMap.getNumDims();
76  SmallVector<AffineExpr> exprs(indexingMap.getResults());
77  ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos();
78  for (auto [index, innerDimPos, tileSize] :
79  llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
80  innerDimsPos, packOrUnPackOp.getMixedTiles())) {
81  auto expr = exprs[innerDimPos];
82  if (!isa<AffineDimExpr>(expr))
83  return failure();
84  int64_t domainDimPos =
85  cast<AffineDimExpr>(exprs[innerDimPos]).getPosition();
86  if (!isParallelIterator(iterators[domainDimPos]))
87  return failure();
88  packInfo.tiledDimsPos.push_back(domainDimPos);
89  packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
90  packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
91  LLVM_DEBUG({
92  llvm::dbgs() << "map innerDimPos=" << innerDimPos
93  << " to iteration dimension (d" << domainDimPos << ", d"
94  << packInfo.tileToPointMapping[domainDimPos]
95  << "), which has size=("
96  << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
97  });
98  }
99 
100  // Bail out if a tiled dimension is present in a map but not as an affine dim
101  // expression.
102  auto areAllAffineDimExpr = [&](int dim) {
103  for (AffineMap map : indexingMaps) {
104  if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) {
105  return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr);
106  })) {
107  return false;
108  }
109  }
110  return true;
111  };
112  for (int64_t i : packInfo.tiledDimsPos)
113  if (!areAllAffineDimExpr(i))
114  return failure();
115 
116  // Get the outer dims perm on the iteration domain. Start by identifying the
117  // set of domain dims affected by the outer permutation along with the
118  // permuted ordering for those dims. Then the full outer dims permutation can
119  // be constructed by replacing the affected dims with the permuted result in a
120  // numLoops-rank identity. e.g.
121  // outerDimsPerm = [1, 2, 0]
122  // indexingMap = (d0, d1, d2, d3, d4) -> (d1, d4, d3)
123  //
124  // permutedOuterDims = [4, 3, 1]
125  // outerDimsOnDomainPerm = [0, 4, 2, 3, 1]
126  //
127  // Non-affine dim expressions must not be permuted by the outer dims
128  // permutation.
129  SmallVector<int64_t> permutedOuterDims;
130  for (auto [index, dim] : llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) {
131  auto permutedExpr = indexingMap.getResult(dim);
132  if (auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) {
133  permutedOuterDims.push_back(dimExpr.getPosition());
134  continue;
135  }
136 
137  // TODO: Allow propagation with transposes on non-affine dim expressions,
138  // e.g. d0 + d1 which implies transposing both dims simultaneously while
139  // maintaining the relative position between them.
140  if (static_cast<int64_t>(index) != dim)
141  return failure();
142  }
143  if (!permutedOuterDims.empty()) {
144  int64_t outerDimIndex = 0;
145  llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(),
146  permutedOuterDims.end());
147  for (int i = 0, e = indexingMap.getNumDims(); i < e; i++)
148  packInfo.outerDimsOnDomainPerm.push_back(
149  permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++]
150  : i);
151  LLVM_DEBUG({
152  llvm::dbgs() << "map outerDimsPerm to ";
153  for (auto dim : packInfo.outerDimsOnDomainPerm)
154  llvm::dbgs() << dim << " ";
155  llvm::dbgs() << "\n";
156  });
157  }
158 
159  return packInfo;
160 }
161 
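/// Maps the iteration-domain outer permutation `perm` onto the operand result
/// expressions `exprs` to produce the operand-level `outer_dims_perm`. Returns
/// an empty permutation for rank-1 operands.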
162 static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
163  ArrayRef<AffineExpr> exprs) {
164  // Compute `outer_dims_perm`. See example:
165  // current exprs : (d0, d1, d2, d3) -> (d2, d3)
166  // perm : [0, 3, 1, 2]
167  // First map d2, d3 with their position in the array as:
168  // currentPositionTileLoops: dim | pos
169  // d2 | 0
170  // d3 | 1
171  // then scan `perm` in order and get the `outer_dims_perm`
172  // to be used, here it would be [1, 0].
173  assert(!perm.empty() && "expect perm not to be empty");
174  assert(!exprs.empty() && "expect exprs not to be empty");
175  if (exprs.size() == 1)
176  return {};
177  SmallVector<int64_t> outerDimsPerm;
178  DenseMap<int64_t, int64_t> currentPositionTileLoops;
179  for (auto [pos, expr] : llvm::enumerate(exprs)) {
180  // Here we rely on the assumption that the outer dims permutation
181  // when propagating currently requires that non-affine dim expressions
182  // are not permuted, thus allowing the identity assignment below.
183  if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
184  currentPositionTileLoops[dimExpr.getPosition()] = pos;
185  else
186  currentPositionTileLoops[pos] = pos;
187  }
188  for (int64_t loopIdx : perm) {
189  if (currentPositionTileLoops.count(loopIdx))
190  outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx));
191  }
192  return outerDimsPerm;
193 }
194 
195 /// Returns a tuple of the packed operand and its indexing map, assuming:
196 /// 1) The generic op is the producer of the pack op.
197 /// 2) The generic op has only one result.
198 /// If the operand is a scalar or no packing dimension is relevant to the
199 /// operand, the original operand and the updated indexing map are returned.
200 /// Otherwise, the packed operand and the updated indexing map are returned. E.g.,
201 ///
202 /// #map0 = affine_map<(d0, d1) -> (d0, d1)>
203 /// #map1 = affine_map<(d0, d1) -> (d0)>
204 /// #map2 = affine_map<(d0, d1) -> (d1)>
205 /// %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0],
206 /// iterator_types = ["parallel", "parallel"]}
207 /// ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
208 /// outs(%init : tensor<?x?xf32>) {
209 /// ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
210 /// %4 = arith.addf %arg3, %arg4 : f32
211 /// linalg.yield %4 : f32
212 /// } -> tensor<?x?xf32>
213 /// %1 = tensor.pack %0
214 /// inner_dims_pos = [0, 1]
215 /// inner_tiles = [8, 2]
216 /// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
217 ///
218 /// Taking the first input operand as an example, the inner tile size of d0 is
219 /// 8. Thus, the operation below and the indexing map
220 /// `affine_map<(d0, d1, d2, d3) -> (d0, d2)>` will be returned.
221 ///
222 /// %pack = tensor.pack %arg0
223 /// inner_dims_pos = [0]
224 /// inner_tiles = [8]
225 /// into %init : tensor<?xf32> -> tensor<?x8xf32>
226 static std::tuple<Value, AffineMap>
227 getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
228  GenericOp genericOp, OpOperand *opOperand) {
229  int64_t numOrigLoops = genericOp.getNumLoops();
230  int64_t numInnerLoops = packInfo.getNumTiledLoops();
231  int64_t numLoops = numOrigLoops + numInnerLoops;
232  AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
233  llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
234  SmallVector<AffineExpr> exprs(origIndexingMap.getResults());
235 
236  // If the OpOperand is a scalar or a zero-rank tensor, no need to pack.
237  if (genericOp.isScalar(opOperand) || exprs.empty())
238  return std::make_tuple(opOperand->get(),
239  AffineMap::get(numLoops, 0, exprs, b.getContext()));
240 
241  // Step 1. Construct the information of packing data dimensions; append inner
242  // dimensions to the indexing maps for the operand.
243  for (auto [index, expr] : llvm::enumerate(exprs)) {
244  if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
245  int64_t dimPos = dimExpr.getPosition();
246  domainDimToOperandDim[dimPos] = index;
247  continue;
248  }
249  }
250  SmallVector<int64_t> innerDimsPos;
251  SmallVector<OpFoldResult> innerTileSizes;
252  for (auto dimPos : packInfo.tiledDimsPos) {
253  if (!domainDimToOperandDim.count(dimPos))
254  continue;
255  int64_t index = domainDimToOperandDim[dimPos];
256  innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
257  innerDimsPos.push_back(index);
258  exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos]));
259  }
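  // At this point `exprs` holds the original operand dim expressions followed
  // by one point-loop dim expression per packed operand dimension.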
260 
261  // Step 2. Handle outer dim permutations.
262  SmallVector<int64_t> outerDimsPerm;
263  if (!packInfo.outerDimsOnDomainPerm.empty()) {
264  outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs);
265 
266  // Step 2.1: Fold transpose into the linalg.generic.
267  SmallVector<int64_t> inversedOuterPerm =
268  invertPermutationVector(packInfo.outerDimsOnDomainPerm);
269  for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
270  if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) {
271  int64_t dimPos = dimExpr.getPosition();
272  exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
273  continue;
274  }
275  assert(isa<AffineConstantExpr>(exprs[i]) &&
276  "Attempted to permute non-constant and non-affine dim expression");
277  }
278  // Step 2.2: Undo the transposition on `exprs` and propagate the
279  // transposition on the pack using outerDimsPerm.
280  if (!outerDimsPerm.empty()) {
281  SmallVector<AffineExpr> auxVec = exprs;
282  for (const auto &en : enumerate(outerDimsPerm))
283  auxVec[en.index()] = exprs[en.value()];
284  exprs = auxVec;
285  }
286  }
287  auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext());
288 
289  // The operand does not have dimensions that relate to the pack op.
290  if (innerDimsPos.empty() && outerDimsPerm.empty())
291  return std::make_tuple(opOperand->get(), indexingMap);
292 
293  auto empty = tensor::PackOp::createDestinationTensor(
294  b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
295  auto packedOperand = b.create<tensor::PackOp>(
296  loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
297  /*padding=*/std::nullopt, outerDimsPerm);
298  return std::make_tuple(packedOperand, indexingMap);
299 }
300 
301 /// Pack a genericOp and return it.
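/// All DPS input operands are replaced with packed views (see
/// getOrCreatePackedViewOfOperand), `getNumTiledLoops()` extra parallel
/// iterators are appended for the point loops, and the original body region is
/// cloned into the new generic op.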
302 static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
303  Value dest, AffineMap packedOutIndexingMap,
304  const PackInfo &packInfo) {
305  Location loc = genericOp.getLoc();
306  SmallVector<Value> inputOperands;
307  SmallVector<AffineMap> indexingMaps;
308  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
309  auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
310  rewriter, loc, packInfo, genericOp, inputOperand);
311  inputOperands.push_back(packedOperand);
312  indexingMaps.push_back(packedIndexingMap);
313  }
314 
315  int64_t numInnerLoops = packInfo.getNumTiledLoops();
316  SmallVector<utils::IteratorType> iterTypes =
317  genericOp.getIteratorTypesArray();
318  iterTypes.append(numInnerLoops, utils::IteratorType::parallel);
319 
320  indexingMaps.push_back(packedOutIndexingMap);
321 
322  auto newGenericOp = rewriter.create<linalg::GenericOp>(
323  loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes,
324  /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
325  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
326  newGenericOp.getRegion().begin());
327  return newGenericOp;
328 }
329 
330 /// Bubbles up a tensor.pack op through a producer generic op. This swaps
331 /// pack(generic) to generic(pack). The new generic op works on the packed
332 /// domain; pack ops are created for input and output operands. E.g.,
333 ///
334 /// #map0 = affine_map<(d0, d1) -> (d0, d1)>
335 /// %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
336 /// %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
337 /// %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
338 /// %3 = linalg.generic {indexing_maps = [#map0, #map0],
339 /// iterator_types = ["parallel", "parallel"]}
340 /// ins(%arg0 : tensor<?x?xf32>)
341 /// outs(%2 : tensor<?x?xf32>) {
342 /// ^bb0(%arg3: f32, %arg4: f32):
343 /// %4 = arith.addf %arg3, %arg3 : f32
344 /// linalg.yield %4 : f32
345 /// } -> tensor<?x?xf32>
346 /// %4 = tensor.pack %3
347 /// inner_dims_pos = [0, 1]
348 /// inner_tiles = [8, 2]
349 /// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
350 ///
351 /// will be converted to
352 ///
353 /// #map = affine_map<()[s0] -> (s0 ceildiv 8)>
354 /// #map1 = affine_map<()[s0] -> (s0 ceildiv 2)>
355 /// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
356 /// %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
357 /// %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
358 /// %0 = affine.apply #map()[%dim]
359 /// %1 = affine.apply #map1()[%dim_0]
360 /// %2 = tensor.empty(%0, %1) : tensor<?x?x8x2xf32>
361 /// %pack = tensor.pack %arg0
362 /// inner_dims_pos = [0, 1]
363 /// inner_tiles = [8, 2]
364 /// into %2 : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
365 /// %3 = linalg.generic {indexing_maps = [#map2, #map2],
366 /// iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
367 /// ins(%pack : tensor<?x?x8x2xf32>)
368 /// outs(%arg1 : tensor<?x?x8x2xf32>) {
369 /// ^bb0(%in: f32, %out: f32):
370 /// %4 = arith.addf %in, %in : f32
371 /// linalg.yield %4 : f32
372 /// } -> tensor<?x?x8x2xf32>
373 static FailureOr<GenericOp>
374 bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
375  const ControlPropagationFn &controlFn) {
376  auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
377  if (!genericOp)
378  return failure();
379 
380  // User controlled propagation function.
381  if (!controlFn(&packOp.getSourceMutable()))
382  return failure();
383 
384  // TODO: Enable propagation in the presence of linalg.index and
385  // tensor.extract, likely as a separate pattern as the pack information and
386  // propagation decision needs to be inferred from the region of the generic.
387  if (hasGatherSemantics(genericOp))
388  return failure();
389 
390  // TODO: Relax the restriction. We are able to bubble up the pack op through
391  // multi-result generic op. It just needs more work.
392  if (genericOp.getNumResults() != 1)
393  return failure();
394 
395  // Bail-out if the result of the generic has multiple uses, as bubbling up
396  // creates recomputation if the generic has multiple users.
397  // TODO: Enable the case where every use is an identical pack op as no
398  // recomputation is needed in that case.
399  if (!genericOp->getResult(0).hasOneUse())
400  return failure();
401 
402  // We want to move the pack, not the generic.
403  OpBuilder::InsertionGuard guard(rewriter);
404  rewriter.setInsertionPoint(genericOp);
405 
406  // We need to handle two cases:
407  // 1) The tensor.pack destination is a tensor.empty. If this is the case, we
408  // create a new tensor.empty to avoid breaking dominance, as we are moving the
409  // tensor.pack above the linalg.generic.
410  // 2) The destination is not a tensor.empty. In this case we can replace only
411  // if the destination of the tensor.pack dominates the linalg.generic.
412  Value packOpDest = packOp.getDest();
413  if (!packOpDest.hasOneUse())
414  return failure();
415  if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
416  packOpDest = rewriter.create<tensor::EmptyOp>(
417  genericOp->getLoc(), emptyOp.getMixedSizes(),
418  emptyOp.getType().getElementType());
419  } else {
420  DominanceInfo dom(genericOp);
421  if (!dom.properlyDominates(packOpDest, genericOp))
422  return failure();
423  }
424 
425  // TODO: Add an option for allowing padding values. It could introduce
426  // undefined behavior if we unconditionally propagate pack op through all
427  // the ops. E.g., if the padding value is zero and there are division ops in
428  // a generic op. Some values in the padding area could be NaN (0/0).
429  if (packOp.getPaddingValue())
430  return failure();
431 
432  OpOperand *opOperand = genericOp.getDpsInitOperand(0);
433  auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
434  if (failed(packInfo))
435  return failure();
436 
437  // Rebuild the indexing map for the corresponding init operand.
438  auto [packedOutOperand, packedOutIndexingMap] =
439  getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
440  genericOp, opOperand);
441 
442  // If the dps init operand of the generic is a tensor.empty, forward the
443  // pack op destination.
444  Value dest = packedOutOperand;
445  if (auto initTensor = genericOp.getDpsInitOperand(0)
446  ->get()
447  .getDefiningOp<tensor::EmptyOp>()) {
448  dest = packOpDest;
449  }
450  return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
451  *packInfo);
452 }
453 
454 /// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
455 struct BubbleUpPackOpThroughGenericOpPattern
456  : public OpRewritePattern<tensor::PackOp> {
457 public:
458  BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
459  ControlPropagationFn fun)
460  : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
461 
462  LogicalResult matchAndRewrite(tensor::PackOp packOp,
463  PatternRewriter &rewriter) const override {
464  auto genericOp =
465  bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
466  if (failed(genericOp))
467  return failure();
468  rewriter.replaceOp(packOp, genericOp->getResults());
469  return success();
470  }
471 
472 private:
473  ControlPropagationFn controlFn;
474 };
475 
476 /// Propagate a tensor.pack operation up through a tensor.pad. The idea is to
477 /// add as many zero-padding dimensions to `high` and `low` as there are
478 /// point loops.
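/// For example, packing a padded value with inner_tiles = [8] becomes a pack
/// of the unpadded source followed by a tensor.pad whose `low` and `high`
/// arrays get a trailing 0 appended for the new point loop.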
479 class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
480 public:
481  BubbleUpPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
482  : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
483 
484  LogicalResult matchAndRewrite(tensor::PackOp packOp,
485  PatternRewriter &rewriter) const override {
486  auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
487  if (!padOp)
488  return failure();
489 
490  // User controlled propagation function.
491  if (!controlFn(&packOp.getSourceMutable()))
492  return failure();
493 
494  // TODO: Enable padding when the padding values are the same.
495  if (packOp.getPaddingValue())
496  return failure();
497 
498  // Fail for non-constant padding values. The body of the pad could
499  // depend on the padding indices and/or properties of the padded
500  // tensor so for now we fail.
501  // TODO: Support non-constant padding values.
502  Value paddingVal = padOp.getConstantPaddingValue();
503  if (!paddingVal)
504  return failure();
505 
506  if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
507  return failure();
508 
509  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
510 
511  // Bail out if one of the padded dimensions is a tiled one.
512  llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
513  llvm::SmallBitVector innerDims(paddedDims.size());
514  for (int64_t dim : innerDimsPos)
515  innerDims.flip(dim);
516  if (paddedDims.anyCommon(innerDims))
517  return failure();
518 
519  Location loc = padOp->getLoc();
520  OpBuilder::InsertionGuard guard(rewriter);
521  rewriter.setInsertionPoint(padOp);
522 
523  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
524  SmallVector<OpFoldResult> mixedTiles = packOp.getMixedTiles();
525  auto empty = tensor::PackOp::createDestinationTensor(
526  rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos,
527  outerDimsPerm);
528  auto sourcePack = rewriter.create<tensor::PackOp>(
529  loc, padOp.getSource(), empty, innerDimsPos, mixedTiles,
530  /*padding=*/std::nullopt, outerDimsPerm);
531 
532  // If we have `outer_dims_perm` we need to adjust the padded dimensions.
533  SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
534  SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
535  if (!outerDimsPerm.empty()) {
536  applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
537  applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
538  }
539  // The tiled dimensions were verified to be unpadded above, so here we
540  // just append 0 for the inner tile dimensions.
541  size_t pointLoopsSize = innerDimsPos.size();
542  lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
543  highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
544 
545  auto newPadOp = rewriter.create<tensor::PadOp>(
546  loc, /*result=*/Type(), sourcePack, lowPad, highPad, paddingVal,
547  padOp.getNofold());
548 
549  // If the pad has more than one user, create an unpack on the new pad to
550  // replace the other uses.
551  if (!padOp->hasOneUse()) {
552  auto unpackEmpty = tensor::UnPackOp::createDestinationTensor(
553  rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm);
554  Value unpackedPad = rewriter.create<tensor::UnPackOp>(
555  loc, newPadOp, unpackEmpty, innerDimsPos, mixedTiles, outerDimsPerm);
556  rewriter.replaceAllUsesExcept(padOp, unpackedPad, sourcePack);
557  }
558 
559  // Replace the pack with the new pad.
560  rewriter.replaceOp(packOp, newPadOp.getResult());
561 
562  return success();
563  }
564 
565 private:
566  ControlPropagationFn controlFn;
567 };
568 
569 /// Project dimsPos to the inner-most non-unit dim pos with reassocIndices.
570 ///
571 /// For example, given dimsPos [0, 2], reassocIndices [[0, 1], [2, 3]], and
572 /// For example, given dimsPos [0, 2], reassocIndices [[0, 1], [2, 3]], and
573 /// targetShape [16, 16, 32, 1], it returns [1, 2]: for pos 0, the inner-most
574 /// non-unit projected dim in [0, 1] is 1; for pos 2, the inner-most non-unit
575 /// projected dim in [2, 3] is 2.
576 /// If all candidates in a reassociation are unit dims, it chooses the
577 /// inner-most dim pos.
578 static SmallVector<int64_t>
579 projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
580  ArrayRef<ReassociationIndices> reassocIndices,
581  ArrayRef<int64_t> targetShape) {
582  SmallVector<int64_t> projectedDimsPos;
583  for (auto pos : dimsPos) {
584  // In the case all dims are unit, this will return the inner-most one.
585  int64_t projectedPos = reassocIndices[pos].back();
586  for (auto i : llvm::reverse(reassocIndices[pos])) {
587  int64_t dim = targetShape[i];
588  if (dim > 1 || ShapedType::isDynamic(dim)) {
589  projectedPos = i;
590  break;
591  }
592  }
593  projectedDimsPos.push_back(projectedPos);
594  }
595  return projectedDimsPos;
596 }
597 
598 /// Check if all dims in dimsPos are divisible by the corresponding tile sizes.
599 static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
600  ArrayRef<int64_t> shape,
601  ArrayRef<int64_t> tileSizes) {
602  for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
603  int64_t dim = shape[pos];
604  if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
605  return false;
606  }
607  return true;
608 }
609 
610 /// Permute the reassociation indices and reindex them in sequence order.
611 /// Returns the next dim pos in the sequence.
612 ///
613 /// For example, given reassocIndices [[0, 1], [2]] and permutation [1, 0], it
614 /// applies the permutation to get [[2], [0, 1]] and reindexes the indices into
615 /// [[0], [1, 2]].
616 static int64_t applyPermutationAndReindexReassoc(
617  SmallVector<ReassociationIndices> &reassocIndices,
618  ArrayRef<int64_t> permutation) {
619  if (!permutation.empty())
620  applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
621  int64_t nextPos = 0;
622  for (ReassociationIndices &indices : reassocIndices) {
623  for (auto &index : indices) {
624  index = nextPos;
625  nextPos += 1;
626  }
627  }
628  return nextPos;
629 }
630 
631 /// Bubble up pack op through collapse shape op when the packed dims can be
632 /// projected to the dims before collapsing. This is possible when the inner
633 /// tile sizes can divide the projected dims.
634 ///
635 /// For example:
636 ///
637 /// %collapsed = tensor.collapse_shape %in [[0, 1], 2]
638 /// : tensor<?x16x4xf32> into tensor<?x4xf32>
639 /// %pack = tensor.pack %collapsed outer_dims_perm = [0, 1]
640 /// inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty
641 /// : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
642 ///
643 /// can be transformed into:
644 ///
645 /// %pack = tensor.pack %in outer_dims_perm = [1, 2]
646 /// inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty
647 /// : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
648 /// %collapsed = tensor.collapse_shape %pack [[0, 1], 2, 3, 4]
649 /// : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
650 static LogicalResult
651 bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
652  tensor::PackOp packOp,
653  PatternRewriter &rewriter) {
654  SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
655  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
656  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
657 
658  ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
659  SmallVector<ReassociationIndices> reassocIndices =
660  collapseOp.getReassociationIndices();
661  // Project inner tile pos to the dim pos before collapsing. For example, if
662  // dims [x, y] are collapsed into [z], packing on dim z can be projected back
663  // to pack on dim y.
664  //
665  // Project to inner-most non-unit dims to increase the chance that they can be
666  // divided by the inner tile sizes. This is correct because for [..., x, 1],
667  // packing on dim 1 is equivalent to packing on dim x.
668  SmallVector<int64_t> projectedInnerDimsPos =
669  projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);
670 
671  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
672  innerTileSizes)) {
673  return failure();
674  }
675  // Expand the outer dims permutation with the associated source dims for the
676  // new permutation after bubbling. This is because moving a collapsed dim is
677  // equivalent to moving the associated source dims together.
678  SmallVector<int64_t> newOuterDimsPerm;
679  for (auto outerPos : outerDimsPerm) {
680  newOuterDimsPerm.insert(newOuterDimsPerm.end(),
681  reassocIndices[outerPos].begin(),
682  reassocIndices[outerPos].end());
683  }
684 
685  auto emptyOp = tensor::PackOp::createDestinationTensor(
686  rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
687  projectedInnerDimsPos, newOuterDimsPerm);
688  auto newPackOp = rewriter.create<tensor::PackOp>(
689  packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos,
690  packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);
691 
692  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
693  // First apply the permutation on the reassociations of the outer dims.
694  // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
695  // -> [[0], [1, 2]]
696  int64_t nextPos =
697  applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
698  // Then add direct mapping for the inner tile dims.
699  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
700  newReassocIndices.push_back({nextPos});
701  nextPos += 1;
702  }
703 
704  auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
705  collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices);
706  rewriter.replaceOp(packOp, newCollapseOp);
707 
708  return success();
709 }
710 
711 /// Project dimsPos to their collapsed positions in the reassocIndices.
712 ///
713 /// For example, given dimsPos [0, 1, 2, 4], and matching reassocIndices
714 /// [[0], [1, 2], [3], [4]], it returns [0, 1, 1, 3]. Because for pos 0,
715 /// the reassoc dim [0] is 0. For pos 1 and 2, the reassoc dim in pos
716 /// [1, 2] is 1. And for pos 4, the reassoc dim [4] is 3.
717 static SmallVector<int64_t>
718 projectDimsPosIntoReassocPos(ArrayRef<int64_t> dimsPos,
719  ArrayRef<ReassociationIndices> reassocIndices) {
720  SmallVector<int64_t> projectedPos;
721 
722  // Map each dimension to the position of corresponding reassociation index.
723  for (auto pos : dimsPos) {
724  for (auto [idx, indices] : llvm::enumerate(reassocIndices)) {
725  // If the dimension is present in the current indices group, the group
726  // position within the reassociation map is the desired projected
727  // dimension position.
728  if (llvm::any_of(indices,
729  [&](int64_t expandDim) { return expandDim == pos; })) {
730  projectedPos.push_back(idx);
731  break;
732  }
733  }
734  }
735  assert(projectedPos.size() == dimsPos.size() && "Invalid dim pos projection");
736 
737  return projectedPos;
738 }
739 
740 /// Bubble up pack op through expand shape op.
741 ///
742 /// For example:
743 ///
744 /// %expand = tensor.expand_shape %in [[0], [1, 2]]
745 /// : tensor<?x64xf32> into tensor<?x4x16xf32>
746 /// %pack = tensor.pack %expand outer_dims_perm = [0, 1]
747 /// inner_dims_pos = [2] inner_tiles = [8] into %empty
748 /// : tensor<?x4x16xf32> -> tensor<?x4x2x8xf32>
749 ///
750 /// can be transformed into:
751 ///
752 /// %pack = tensor.pack %in outer_dims_perm = [1, 2]
753 /// inner_dims_pos = [1] inner_tiles = [8] into %empty
754 /// : tensor<?x64xf32> -> tensor<?x8x8xf32>
755 /// %expand = tensor.expand_shape %pack [[0], [1, 2], [3]]
756 /// : tensor<?x8x8xf32> into tensor<?x4x2x8xf32>
757 static LogicalResult
758 bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
759  tensor::PackOp packOp,
760  PatternRewriter &rewriter) {
761  // Outer dimensions permutation is not supported currently.
762  // TODO: Handle outer_dims_perm variants.
763  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
764  if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
765  return rewriter.notifyMatchFailure(packOp,
766  "non-identity outer dims perm NYI");
767  }
768 
769  // Validate dimensions' relations between shape expansion and packing.
770  SmallVector<ReassociationIndices> reassoc =
771  expandOp.getReassociationIndices();
772  ArrayRef<int64_t> packInnerDims = packOp.getInnerDimsPos();
773  llvm::SetVector<int64_t> packDimsPos(packInnerDims.begin(),
774  packInnerDims.end());
775 
776  for (auto [idx, indices] : llvm::enumerate(reassoc)) {
777  // For each expand_shape reassociation, figure out which dimensions get
778  // packed if any.
779  llvm::SetVector<int64_t> expandDimPos(indices.begin(), indices.end());
780  llvm::SetVector<int64_t> packedDims =
781  llvm::set_intersection(packDimsPos, expandDimPos);
782 
783  // The expanded dimension is not packed, so it does not affect moving the
784  // pack before the shape expansion; simply continue.
785  if (packedDims.empty())
786  continue;
787  // Shape expansion cannot be propagated when multiple expanded dimensions are
788  // packed; in this case operation reordering would affect final element
789  // positions and/or shapes could no longer be projected.
790  if (packedDims.size() != 1)
791  return rewriter.notifyMatchFailure(
792  packOp, "only one of the expanded dimensions can be packed");
793  // Only the inner-most expanded dimension should be packed. Otherwise,
794  // the element order would be affected by the operation reordering.
795  if (packedDims.front() != indices.back())
796  return rewriter.notifyMatchFailure(
797  packOp, "can only pack the inner-most expanded dimension");
798  }
799 
800  // Project pack.inner_dims_pos to positions before shape expansion.
801  SmallVector<int64_t> projectedInnerDimsPos =
802  projectDimsPosIntoReassocPos(packInnerDims, reassoc);
803 
804  // Project the shape expansion to new packed shape.
805  // The pack.outer_dims_perm is restricted to identity, so the permutation can
806  // be omitted for simplicity.
807  // TODO: Account for outer dimensions permutation.
808  //
809  // If reassociation is not possible, then reordering cannot happen.
810  // This can be caused by pack padding affecting previously expanded
811  // dimensions or packing extending dimensions.
812  RankedTensorType newPackType = tensor::PackOp::inferPackedType(
813  expandOp.getSrcType(), packOp.getStaticInnerTiles(),
814  projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
815  auto reassocExpand =
816  getReassociationIndicesForReshape(newPackType, packOp.getDestType());
817  if (!reassocExpand)
818  return rewriter.notifyMatchFailure(
819  packOp, "could not reassociate dims after bubbling up");
820 
821  Value destTensor = tensor::PackOp::createDestinationTensor(
822  rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(),
823  projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
824  Value packedVal = rewriter.create<tensor::PackOp>(
825  packOp.getLoc(), expandOp.getSrc(), destTensor, projectedInnerDimsPos,
826  packOp.getMixedTiles(), packOp.getPaddingValue(),
827  /*outerDimsPerm=*/SmallVector<int64_t>{});
828 
829  Value newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
830  packOp.getLoc(), packOp.getDestType(), packedVal, *reassocExpand);
831  rewriter.replaceOp(packOp, newExpandOp);
832 
833  return success();
834 }
835 
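/// Wrapper pattern that bubbles up a tensor.pack through its defining
/// tensor.collapse_shape or tensor.expand_shape, dispatching to the helpers
/// above.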
836 class BubbleUpPackOpThroughReshapeOp final
837  : public OpRewritePattern<tensor::PackOp> {
838 public:
839  BubbleUpPackOpThroughReshapeOp(MLIRContext *context, ControlPropagationFn fun)
840  : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
841 
842  LogicalResult matchAndRewrite(tensor::PackOp packOp,
843  PatternRewriter &rewriter) const override {
844  Operation *srcOp = packOp.getSource().getDefiningOp();
845  // Currently only support when the pack op is the only user.
846  if (!srcOp || !(srcOp->getNumResults() == 1) ||
847  !srcOp->getResult(0).hasOneUse()) {
848  return failure();
849  }
850  // Currently only support static inner tile sizes.
851  if (llvm::any_of(packOp.getStaticTiles(), [](int64_t size) {
852  return ShapedType::isDynamic(size);
853  })) {
854  return failure();
855  }
856 
857  // User controlled propagation function.
858  if (!controlFn(&packOp.getSourceMutable()))
859  return failure();
860 
861  return TypeSwitch<Operation *, LogicalResult>(srcOp)
862  .Case([&](tensor::CollapseShapeOp op) {
863  return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
864  })
865  .Case([&](tensor::ExpandShapeOp op) {
866  return bubbleUpPackOpThroughExpandShape(op, packOp, rewriter);
867  })
868  .Default([](Operation *) { return failure(); });
869  }
870 
871 private:
872  ControlPropagationFn controlFn;
873 };
874 
875 /// Push down unpack op through expand shape op when the packed dims can be
876 /// projected to the dims after expanding. This is possible when the inner tile
877 /// sizes can divide the projected dims.
878 ///
879 /// For example:
880 ///
881 /// %unpack = tensor.unpack %in outer_dims_perm = [0, 1]
882 /// inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %empty
883 /// : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
884 /// %expanded = tensor.expand_shape %unpack [[0, 1], [2]]
885 /// : tensor<?x256xf32> into tensor<?x256x256xf32>
886 ///
887 /// can be transformed into:
888 ///
889 /// %expanded = tensor.expand_shape %in [[0, 1], [2], [3], [4]]
890 /// : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
891 /// %unpack = tensor.unpack %expanded outer_dims_perm = [0, 1, 2]
892 /// inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty
893 /// : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
894 static LogicalResult pushDownUnPackOpThroughExpandShape(
895  tensor::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
896  PatternRewriter &rewriter, ControlPropagationFn controlFn) {
897  // User controlled propagation function.
898  if (!controlFn(&expandOp.getSrcMutable()))
899  return failure();
900 
901  SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
902  ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
903  ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
904 
905  auto expandTy = dyn_cast<RankedTensorType>(expandOp.getType());
906  if (!expandTy)
907  return failure();
908  ArrayRef<int64_t> dstShape = expandTy.getShape();
909  SmallVector<ReassociationIndices> reassocIndices =
910  expandOp.getReassociationIndices();
911  // Project inner tile pos to the dim pos after expanding. For example, if dim
912  // [z] is expanded into [x, y], unpacking on dim z can be projected to
913  // unpacking on dim y.
914  //
915  // Project to inner-most non-unit dims to increase the chance that they can be
916  // divided by the inner tile sizes. This is correct because for [..., x, 1],
917  // unpacking on dim 1 is equivalent to unpacking on dim x.
918  SmallVector<int64_t> projectedInnerDimsPos =
919  projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);
920 
921  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
922  innerTileSizes)) {
923  return failure();
924  }
925  // Expand the outer dims permutation with the associated expanded dims for the
926  // new permutation after pushing. This is because moving a source dim is
927  // equivalent to moving the associated expanded dims together.
928  SmallVector<int64_t> newOuterDimsPerm;
929  for (auto outerPos : outerDimsPerm) {
930  newOuterDimsPerm.insert(newOuterDimsPerm.end(),
931  reassocIndices[outerPos].begin(),
932  reassocIndices[outerPos].end());
933  }
934 
935  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
936  // First apply the permutation on the reassociations of the outer dims.
937  // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
938  // -> [[0], [1, 2]]
939  int64_t nextPos =
940  applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
941  // Then add direct mapping for the inner tile dims.
942  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
943  newReassocIndices.push_back({nextPos});
944  nextPos += 1;
945  }
946 
947  RankedTensorType newExpandType = tensor::PackOp::inferPackedType(
948  expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm);
949  auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
950  expandOp.getLoc(), newExpandType, unPackOp.getSource(),
951  newReassocIndices);
952 
953  auto emptyOp = tensor::UnPackOp::createDestinationTensor(
954  rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
955  projectedInnerDimsPos, newOuterDimsPerm);
956  auto newUnPackOp = rewriter.create<tensor::UnPackOp>(
957  unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
958  projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
959  rewriter.replaceOp(expandOp, newUnPackOp);
960 
961  return success();
962 }
963 
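/// Wrapper pattern that pushes down a tensor.unpack through its single
/// tensor.expand_shape consumer.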
964 class PushDownUnPackOpThroughReshapeOp final
965  : public OpRewritePattern<tensor::UnPackOp> {
966 public:
967  PushDownUnPackOpThroughReshapeOp(MLIRContext *context,
968  ControlPropagationFn fun)
969  : OpRewritePattern<tensor::UnPackOp>(context), controlFn(std::move(fun)) {
970  }
971 
972  LogicalResult matchAndRewrite(tensor::UnPackOp unPackOp,
973  PatternRewriter &rewriter) const override {
974  Value result = unPackOp.getResult();
975  // Currently only support unpack op with a single user.
976  if (!result.hasOneUse()) {
977  return failure();
978  }
979  // Currently only support static inner tile sizes.
980  if (llvm::any_of(unPackOp.getStaticTiles(), [](int64_t size) {
981  return ShapedType::isDynamic(size);
982  })) {
983  return failure();
984  }
985 
986  Operation *consumerOp = *result.user_begin();
987  return TypeSwitch<Operation *, LogicalResult>(consumerOp)
988  .Case([&](tensor::ExpandShapeOp op) {
989  return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter,
990  controlFn);
991  })
992  .Default([](Operation *) { return failure(); });
993  }
994 
995 private:
996  ControlPropagationFn controlFn;
997 };
998 
999 // TODO: Relax this restriction. We should unpack a generic op also
1000 // in the presence of multiple unpack ops as producers.
1001 /// Return the unpacked operand, if present, for the current generic op.
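/// Fails if no operand is produced by a tensor.unpack, or if more than one
/// operand is.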
1002 static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
1003  OpOperand *unPackedOperand = nullptr;
1004  for (OpOperand &operand : genericOp->getOpOperands()) {
1005  auto unPackOp = operand.get().getDefiningOp<tensor::UnPackOp>();
1006  if (!unPackOp)
1007  continue;
1008  if (unPackedOperand)
1009  return failure();
1010  unPackedOperand = &operand;
1011  }
1012  if (!unPackedOperand)
1013  return failure();
1014  return unPackedOperand;
1015 }
1016 
1017 /// Push down a tensor.unpack op through a generic op.
1018 /// The new generic op works on packed domain; pack ops are created for input
1019 /// and output operands. A tensor.unpack op is inserted right after the packed
1020 /// generic. E.g.
1021 ///
1022 /// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
1023 ///
1024 /// %arg0 = tensor<12x2x56x56x32xf32> // packed arg.
1025 ///
1026 /// %0 = tensor.empty() : tensor<12x56x56x64xf32>
1027 /// %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2]
1028 /// inner_dims_pos = [3] inner_tiles = [32] into %0
1029 /// %2 = linalg.generic {indexing_maps = [#map],
1030 /// iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
1031 /// outs(%1 : tensor<12x56x56x64xf32>) {
1032 /// ^bb0(%out : f32):
1033 /// linalg.yield %out : f32
1034 /// } -> tensor<12x56x56x64xf32>
1035 ///
1036 /// will be converted to
1037 ///
1038 /// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
1039 ///
1040 /// %0 = tensor.empty() : tensor<12x56x56x64xf32>
1041 /// %1 = linalg.generic {indexing_maps = [#map],
1042 /// iterator_types = ["parallel", "parallel", "parallel",
1043 /// "parallel", "parallel"]}
1044 /// outs(%arg0 : tensor<12x2x56x56x32xf32>) {
1045 /// ^bb0(%out : f32):
1046 /// linalg.yield %out : f32
1047 /// } -> tensor<12x2x56x56x32xf32>
1048 /// %2 = tensor.unpack %1 outer_dims_perm = [0, 3, 1, 2]
1049 /// inner_dims_pos = [3] inner_tiles = [32] into %0
1050 ///
1051 static FailureOr<std::tuple<GenericOp, Value>>
1052 pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
1053  ControlPropagationFn controlFn) {
1054  if (genericOp.getNumResults() != 1)
1055  return failure();
1056 
1057  if (hasGatherSemantics(genericOp))
1058  return failure();
1059 
1060  // Collect the unPacked operand, if present.
1061  auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
1062  if (failed(maybeUnPackedOperand))
1063  return failure();
1064  OpOperand *unPackedOperand = *(maybeUnPackedOperand);
1065 
1066  // Extract packing information.
1067  tensor::UnPackOp producerUnPackOp =
1068  unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
1069  assert(producerUnPackOp && "expect a valid UnPackOp");
1070 
1071  if (!controlFn(unPackedOperand))
1072  return failure();
1073 
1074  auto packInfo =
1075  getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
1076  if (failed(packInfo))
1077  return failure();
1078 
1079  // Rebuild the indexing map for the corresponding init operand.
1080  auto [packedOutOperand, packedOutIndexingMap] =
1081  getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
1082  genericOp, genericOp.getDpsInitOperand(0));
1083  auto destPack = packedOutOperand.getDefiningOp<tensor::PackOp>();
1084 
1085  // If the dps init operand of the generic is a tensor.empty, do not pack it
1086  // and forward the new tensor.empty as a destination.
1087  Value dest = packedOutOperand;
1088  if (auto initTensor = genericOp.getDpsInitOperand(0)
1089  ->get()
1090  .getDefiningOp<tensor::EmptyOp>()) {
1091  if (destPack)
1092  dest = destPack.getDest();
1093  }
1094 
1095  // Pack the genericOp.
1096  GenericOp newGenericOp =
1097  packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo);
1098  Value newResult =
1099  newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));
1100 
1101  // If the output is unaffected, no need to unpack.
1102  if (!destPack)
1103  return std::make_tuple(newGenericOp, newResult);
1104 
1105  auto mixedTiles = destPack.getMixedTiles();
1106  auto innerDimsPos = destPack.getInnerDimsPos();
1107  auto outerDimsPerm = destPack.getOuterDimsPerm();
1108 
1109  // If the output type for the generic differs from the source
1110  // unpack op, we need to create a new destination tensor. In the
1111  // dynamic case we always need a new destination.
1112  auto loc = genericOp.getLoc();
1113  Value unPackDest = producerUnPackOp.getDest();
1114  auto genericOutType =
1115  cast<RankedTensorType>(genericOp.getDpsInitOperand(0)->get().getType());
1116  if (producerUnPackOp.getDestType() != genericOutType ||
1117  !genericOutType.hasStaticShape()) {
1118  unPackDest = tensor::UnPackOp::createDestinationTensor(
1119  rewriter, loc, newResult, mixedTiles, innerDimsPos, outerDimsPerm);
1120  }
1121 
1122  // Insert an unPackOp right after the packed generic.
1123  Value unPackOpRes =
1124  rewriter
1125  .create<tensor::UnPackOp>(loc, newResult, unPackDest, innerDimsPos,
1126  mixedTiles, outerDimsPerm)
1127  .getResult();
1128 
1129  return std::make_tuple(newGenericOp, unPackOpRes);
1130 }
1131 
1132 // Wrapper pattern that applies pushDownUnPackOpThroughGenericOp method.
1133 struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
1134 public:
1135  PushDownUnPackOpThroughGenericOp(MLIRContext *context,
1136  ControlPropagationFn fun)
1137  : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
1138 
1139  LogicalResult matchAndRewrite(GenericOp genericOp,
1140  PatternRewriter &rewriter) const override {
1141  auto genericAndRepl =
1142  pushDownUnPackOpThroughGenericOp(rewriter, genericOp, controlFn);
1143  if (failed(genericAndRepl))
1144  return failure();
1145  rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
1146  return success();
1147  }
1148 
1149 private:
1150  ControlPropagationFn controlFn;
1151 };
1152 
1153 /// Propagate a tensor.unpack operation through a tensor.pad. The idea is to
1154 /// add as many zero-padding dimensions to `high` and `low` as there are
1155 /// point loops.
1156 struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
1157  PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
1158  : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {}
1159 
1160  LogicalResult matchAndRewrite(tensor::PadOp padOp,
1161  PatternRewriter &rewriter) const override {
1162  tensor::UnPackOp unpackOp =
1163  padOp.getSource().getDefiningOp<tensor::UnPackOp>();
1164  if (!unpackOp)
1165  return failure();
1166 
1167  if (!controlFn(&padOp.getSourceMutable()))
1168  return failure();
1169 
1170  Location loc = padOp.getLoc();
1171  // Bail out if one of the padded dimensions is a tiled one.
1172  llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
1173  ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
1174  llvm::SmallBitVector innerDims(paddedDims.size());
1175  for (int64_t dim : innerDimsPos)
1176  innerDims.flip(dim);
1177  if (paddedDims.anyCommon(innerDims))
1178  return failure();
1179 
1180  Value paddingVal = padOp.getConstantPaddingValue();
1181  if (!paddingVal)
1182  return failure();
1183 
1184  // If we have `outer_dims_perm` we need to adjust the padded dimensions.
1185  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
1186  SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
1187  SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
1188  if (!outerDimsPerm.empty()) {
1189  applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
1190  applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
1191  }
1192  // Add zero padding for the point loops.
1193  size_t pointLoopsSize = innerDimsPos.size();
1194  lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
1195  highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
1196 
1197  auto newPadOp = rewriter.create<tensor::PadOp>(
1198  loc, /*result=*/Type(), unpackOp.getSource(), lowPad, highPad,
1199  paddingVal, padOp.getNofold());
1200 
1201  // Inject the tensor.unpack right after the packed padOp.
1202  Value outputUnPack = rewriter.create<tensor::EmptyOp>(
1203  loc, padOp.getResultType().getShape(),
1204  padOp.getResultType().getElementType());
1205 
1206  Value replacement = rewriter.create<tensor::UnPackOp>(
1207  loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
1208  unpackOp.getMixedTiles(), outerDimsPerm);
1209  rewriter.replaceOp(padOp, replacement);
1210  return success();
1211  }
1212 
1213 private:
1214  ControlPropagationFn controlFn;
1215 };
1216 
1217 } // namespace
1218 
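// Registers all of the bubble-up and push-down propagation patterns defined
// above for tensor.pack and tensor.unpack.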
1219 void mlir::linalg::populateDataLayoutPropagationPatterns(
1220  RewritePatternSet &patterns,
1221  const ControlPropagationFn &controlPackUnPackPropagation) {
1222  patterns
1223  .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
1224  BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
1225  PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
1226  patterns.getContext(), controlPackUnPackPropagation);
1227 }