//===- DataLayoutPropagation.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
#include "llvm/Support/Debug.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

#define DEBUG_TYPE "linalg-data-layout-propagation"

namespace {

static bool hasGatherSemantics(linalg::GenericOp genericOp) {
  for (Operation &op : genericOp.getBody()->getOperations())
    if (isa<tensor::ExtractOp, linalg::IndexOp>(op))
      return true;
  return false;
}

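// Illustrative sketch (hypothetical IR, not taken from a test): a generic
// whose body reads another tensor at a data-dependent index has gather
// semantics, and the propagation patterns below bail out on it. E.g.
//
//   %0 = linalg.generic {indexing_maps = [#id1d, #id1d],
//                        iterator_types = ["parallel"]}
//       ins(%indices : tensor<?xindex>) outs(%init : tensor<?xf32>) {
//     ^bb0(%idx: index, %out: f32):
//       %v = tensor.extract %src[%idx] : tensor<?xf32>
//       linalg.yield %v : f32
//   } -> tensor<?xf32>
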
// This struct describes how the packing information of a pack/unpack op maps
// onto the iteration domain of Linalg ops.
struct PackInfo {
  int64_t getNumTiledLoops() const { return tileToPointMapping.size(); };
  // InnerDimsPos on the iteration domain, which follows the order in pack ops.
  SmallVector<int64_t> tiledDimsPos;
  // The sizes of the tiling data dimensions on the iteration domain.
  llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping;
  // The mapping from a dimension of the iteration domain to the corresponding
  // inner tiling dimension on the iteration domain.
  llvm::DenseMap<int64_t, int64_t> tileToPointMapping;
  // The permutation of outer dims (on the iteration domain).
  SmallVector<int64_t> outerDimsOnDomainPerm;
};

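// Illustrative sketch (hypothetical values, not from the code below): for a
// generic with iteration domain (d0, d1) and a pack with inner_dims_pos = [1],
// inner_tiles = [32] on an operand whose indexing map is (d0, d1) -> (d0, d1),
// the resulting PackInfo would be
//   tiledDimsPos            = [1]
//   domainDimAndTileMapping = {1 -> 32}
//   tileToPointMapping      = {1 -> 2}   // d2 is the new point loop.
//   outerDimsOnDomainPerm   = []         // no outer_dims_perm on the pack.
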
template <typename OpTy>
static FailureOr<PackInfo>
getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
                          OpTy packOrUnPackOp) {
  static_assert(llvm::is_one_of<OpTy, tensor::PackOp, tensor::UnPackOp>::value,
                "applies to only pack or unpack operations");
  LLVM_DEBUG(
      { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });

  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
  SmallVector<utils::IteratorType> iterators =
      genericOp.getIteratorTypesArray();

  PackInfo packInfo;
  int64_t origNumDims = indexingMap.getNumDims();
  SmallVector<AffineExpr> exprs(indexingMap.getResults());
  ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos();
  for (auto [index, innerDimPos, tileSize] :
       llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
                       innerDimsPos, packOrUnPackOp.getMixedTiles())) {
    auto expr = exprs[innerDimPos];
    if (!isa<AffineDimExpr>(expr))
      return failure();
    int64_t domainDimPos =
        cast<AffineDimExpr>(exprs[innerDimPos]).getPosition();
    if (!isParallelIterator(iterators[domainDimPos]))
      return failure();
    packInfo.tiledDimsPos.push_back(domainDimPos);
    packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
    packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
    LLVM_DEBUG({
      llvm::dbgs() << "map innerDimPos=" << innerDimPos
                   << " to iteration dimension (d" << domainDimPos << ", d"
                   << packInfo.tileToPointMapping[domainDimPos]
                   << "), which has size=("
                   << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
    });
  }

  // Bail out if a tiled dimension is present in a map but not as an affine dim
  // expression.
  auto areAllAffineDimExpr = [&](int dim) {
    for (AffineMap map : indexingMaps) {
      if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) {
            return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr);
          })) {
        return false;
      }
    }
    return true;
  };
  for (int64_t i : packInfo.tiledDimsPos)
    if (!areAllAffineDimExpr(i))
      return failure();

  // Get the outer dims perm on the iteration domain. Start by identifying the
  // set of domain dims affected by the outer permutation along with the
  // permuted ordering for those dims. Then the full outer dims permutation can
  // be constructed by replacing the affected dims with the permuted result in
  // a numLoops-rank identity. E.g.
  //   outerDimsPerm = [1, 2, 0]
  //   indexingMap = (d0, d1, d2, d3, d4) -> (d1, d4, d3)
  //
  //   permutedOuterDims =        [4,    3, 1]
  //   outerDimsOnDomainPerm = [0, 4, 2, 3, 1]
  //
  // Non-affine dim expressions must not be permuted by the outer dims
  // permutation.
  SmallVector<int64_t> permutedOuterDims;
  for (auto [index, dim] : llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) {
    auto permutedExpr = indexingMap.getResult(dim);
    if (auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) {
      permutedOuterDims.push_back(dimExpr.getPosition());
      continue;
    }

    // TODO: Allow propagation with transposes on non affine dim expressions,
    // e.g. d0 + d1 which implies transposing both dims simultaneously while
    // maintaining the relative position between them.
    if (static_cast<int64_t>(index) != dim)
      return failure();
  }
  if (!permutedOuterDims.empty()) {
    int64_t outerDimIndex = 0;
    llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(),
                                               permutedOuterDims.end());
    for (int i = 0, e = indexingMap.getNumDims(); i < e; i++)
      packInfo.outerDimsOnDomainPerm.push_back(
          permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++]
                                         : i);
    LLVM_DEBUG({
      llvm::dbgs() << "map outerDimsPerm to ";
      for (auto dim : packInfo.outerDimsOnDomainPerm)
        llvm::dbgs() << dim << " ";
      llvm::dbgs() << "\n";
    });
  }

  return packInfo;
}

static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
                                             ArrayRef<AffineExpr> exprs) {
  // Compute `outer_dims_perm`. See example:
  //   current exprs : (d0, d1, d2, d3) -> (d2, d3)
  //   perm          : [0, 3, 1, 2]
  // First map d2, d3 with their position in the array as:
  //   currentPositionTileLoops: dim | pos
  //                              d2 | 0
  //                              d3 | 1
  // then scan `perm` in order and get the `outer_dims_perm` to be used,
  // here it would be [1, 0].
  assert(!perm.empty() && "expect perm not to be empty");
  assert(!exprs.empty() && "expect exprs not to be empty");
  if (exprs.size() == 1)
    return {};
  SmallVector<int64_t> outerDimsPerm;
  DenseMap<int64_t, int64_t> currentPositionTileLoops;
  for (auto [pos, expr] : llvm::enumerate(exprs)) {
    // Here we rely on the assumption that the outer dims permutation
    // when propagating currently requires that non-affine dim expressions
    // are not permuted, thus allowing the identity assignment below.
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
      currentPositionTileLoops[dimExpr.getPosition()] = pos;
    else
      currentPositionTileLoops[pos] = pos;
  }
  for (int64_t loopIdx : perm) {
    if (currentPositionTileLoops.count(loopIdx))
      outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx));
  }
  return outerDimsPerm;
}

/// Returns a tuple of the packed operand and the updated indexing map, under
/// the assumptions:
///   1) The generic op is the producer of the pack op.
///   2) The generic op has only one result.
/// If the operand is a scalar or the packing dimensions are all irrelevant to
/// the operand, the original operand and the updated indexing map will be
/// returned. Otherwise, it returns the packed operand and the updated indexing
/// map. E.g.,
///
///   #map0 = affine_map<(d0, d1) -> (d0, d1)>
///   #map1 = affine_map<(d0, d1) -> (d0)>
///   #map2 = affine_map<(d0, d1) -> (d1)>
///   %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0],
///                        iterator_types = ["parallel", "parallel"]}
///       ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
///       outs(%init : tensor<?x?xf32>) {
///     ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
///       %4 = arith.addf %arg3, %arg4 : f32
///       linalg.yield %4 : f32
///   } -> tensor<?x?xf32>
///   %1 = tensor.pack %0
///     inner_dims_pos = [0, 1]
///     inner_tiles = [8, 2]
///     into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///
/// Taking the first input operand as an example, the inner tile size of d0 is
/// 8. Thus, the operation below and the indexing map
/// `affine_map<(d0, d1, d2, d3) -> (d0, d2)>` will be returned.
///
///   %pack = tensor.pack %arg0
///     inner_dims_pos = [0]
///     inner_tiles = [8]
///     into %init : tensor<?xf32> -> tensor<?x8xf32>
static std::tuple<Value, AffineMap>
getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
                               GenericOp genericOp, OpOperand *opOperand) {
  int64_t numOrigLoops = genericOp.getNumLoops();
  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  int64_t numLoops = numOrigLoops + numInnerLoops;
  AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
  llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
  SmallVector<AffineExpr> exprs(origIndexingMap.getResults());

  // If the OpOperand is a scalar or a zero-rank tensor, no need to pack.
  if (genericOp.isScalar(opOperand) || exprs.empty())
    return std::make_tuple(opOperand->get(),
                           AffineMap::get(numLoops, 0, exprs, b.getContext()));

  // Step 1. Construct the information of packing data dimensions; append inner
  // dimensions to the indexing maps for the operand.
  for (auto [index, expr] : llvm::enumerate(exprs)) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
      int64_t dimPos = dimExpr.getPosition();
      domainDimToOperandDim[dimPos] = index;
      continue;
    }
  }
  SmallVector<int64_t> innerDimsPos;
  SmallVector<OpFoldResult> innerTileSizes;
  for (auto dimPos : packInfo.tiledDimsPos) {
    if (!domainDimToOperandDim.count(dimPos))
      continue;
    int64_t index = domainDimToOperandDim[dimPos];
    innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
    innerDimsPos.push_back(index);
    exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos]));
  }

  // Step 2. Handle outer dim permutations.
  SmallVector<int64_t> outerDimsPerm;
  if (!packInfo.outerDimsOnDomainPerm.empty()) {
    outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs);

    // Step 2.1: Fold transpose into the linalg.generic.
    SmallVector<int64_t> inversedOuterPerm =
        invertPermutationVector(packInfo.outerDimsOnDomainPerm);
    for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
      if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) {
        int64_t dimPos = dimExpr.getPosition();
        exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
        continue;
      }
      assert(isa<AffineConstantExpr>(exprs[i]) &&
             "Attempted to permute non-constant and non-affine dim expression");
    }
    // Step 2.2: Undo the transposition on `exprs` and propagate the
    // transposition on the pack using outerDimsPerm.
    if (!outerDimsPerm.empty()) {
      SmallVector<AffineExpr> auxVec = exprs;
      for (const auto &en : enumerate(outerDimsPerm))
        auxVec[en.index()] = exprs[en.value()];
      exprs = auxVec;
    }
  }
  auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext());

  // The operand does not have dimensions that relate to the pack op.
  if (innerDimsPos.empty() && outerDimsPerm.empty())
    return std::make_tuple(opOperand->get(), indexingMap);

  auto empty = tensor::PackOp::createDestinationTensor(
      b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
  auto packedOperand = b.create<tensor::PackOp>(
      loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
      /*padding=*/std::nullopt, outerDimsPerm);
  return std::make_tuple(packedOperand, indexingMap);
}

/// Pack a genericOp and return it: pack each DPS input operand, append the
/// inner point loops as parallel iterators, and clone the original body into
/// the new generic.
static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                               Value dest, AffineMap packedOutIndexingMap,
                               const PackInfo &packInfo) {
  Location loc = genericOp.getLoc();
  SmallVector<Value> inputOperands;
  SmallVector<AffineMap> indexingMaps;
  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
    auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
        rewriter, loc, packInfo, genericOp, inputOperand);
    inputOperands.push_back(packedOperand);
    indexingMaps.push_back(packedIndexingMap);
  }

  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  SmallVector<utils::IteratorType> iterTypes =
      genericOp.getIteratorTypesArray();
  iterTypes.append(numInnerLoops, utils::IteratorType::parallel);

  indexingMaps.push_back(packedOutIndexingMap);

  auto newGenericOp = rewriter.create<linalg::GenericOp>(
      loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes,
      /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
                             newGenericOp.getRegion().begin());
  return newGenericOp;
}

/// Bubbles up a tensor.pack op through a producer generic op. This swaps
/// pack(generic) to generic(pack). The new generic op works on the packed
/// domain; pack ops are created for input and output operands. E.g.,
///
///   #map0 = affine_map<(d0, d1) -> (d0, d1)>
///   %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
///   %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
///   %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
///   %3 = linalg.generic {indexing_maps = [#map0, #map0],
///                        iterator_types = ["parallel", "parallel"]}
///       ins(%arg0 : tensor<?x?xf32>)
///       outs(%2 : tensor<?x?xf32>) {
///     ^bb0(%arg3: f32, %arg4: f32):
///       %4 = arith.addf %arg3, %arg3 : f32
///       linalg.yield %4 : f32
///   } -> tensor<?x?xf32>
///   %4 = tensor.pack %3
///     inner_dims_pos = [0, 1]
///     inner_tiles = [8, 2]
///     into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///
/// will be converted to
///
///   #map = affine_map<()[s0] -> (s0 ceildiv 8)>
///   #map1 = affine_map<()[s0] -> (s0 ceildiv 2)>
///   #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
///   %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
///   %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
///   %0 = affine.apply #map()[%dim]
///   %1 = affine.apply #map1()[%dim_0]
///   %2 = tensor.empty(%0, %1) : tensor<?x?x8x2xf32>
///   %pack = tensor.pack %arg0
///     inner_dims_pos = [0, 1]
///     inner_tiles = [8, 2]
///     into %2 : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///   %3 = linalg.generic {indexing_maps = [#map2, #map2],
///       iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
///       ins(%pack : tensor<?x?x8x2xf32>)
///       outs(%arg1 : tensor<?x?x8x2xf32>) {
///     ^bb0(%in: f32, %out: f32):
///       %4 = arith.addf %in, %in : f32
///       linalg.yield %4 : f32
///   } -> tensor<?x?x8x2xf32>
static FailureOr<GenericOp>
bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
                               ControlPropagationFn controlFn) {
  auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
  if (!genericOp)
    return failure();

  // User controlled propagation function.
  if (!controlFn(genericOp))
    return failure();

  // TODO: Enable propagation in the presence of linalg.index and
  // tensor.extract, likely as a separate pattern as the pack information and
  // propagation decision needs to be inferred from the region of the generic.
  if (hasGatherSemantics(genericOp))
    return failure();

  // TODO: Relax the restriction. We are able to bubble up the pack op through
  // a multi-result generic op. It just needs more work.
  if (genericOp.getNumResults() != 1)
    return failure();

  // Bail out if the result of the generic has multiple uses, as bubbling up
  // creates recomputation if the generic has multiple users.
  // TODO: Enable the case where every use is an identical pack op as no
  // recomputation is needed in that case.
  if (!genericOp->getResult(0).hasOneUse())
    return failure();

  // We want to move the pack, not the generic.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(genericOp);

  // We need to handle two cases:
  // 1) The tensor.pack destination is a tensor.empty. If this is the case, we
  //    create a new tensor.empty to avoid breaking dominance, as we are moving
  //    the tensor.pack above the linalg.generic.
  // 2) The destination is not a tensor.empty. In this case we can replace only
  //    if the destination of the tensor.pack dominates the linalg.generic.
  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();
  if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
    packOpDest = rewriter.create<tensor::EmptyOp>(
        genericOp->getLoc(), emptyOp.getMixedSizes(),
        emptyOp.getType().getElementType());
  } else {
    DominanceInfo dom(genericOp);
    if (!dom.properlyDominates(packOpDest, genericOp))
      return failure();
  }

  // TODO: Add an option for allowing padding values. It could introduce
  // undefined behavior if we unconditionally propagate pack op through all
  // the ops. E.g., if the padding value is zero and there are division ops in
  // a generic op, some values of the padding area could be NaN (0/0).
  if (packOp.getPaddingValue())
    return failure();

  OpOperand *opOperand = genericOp.getDpsInitOperand(0);
  auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
  if (failed(packInfo))
    return failure();

  // Rebuild the indexing map for the corresponding init operand.
  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                     genericOp, opOperand);

  // If the dps init operand of the generic is a tensor.empty, forward the pack
  // op destination.
  Value dest = packedOutOperand;
  if (auto initTensor = genericOp.getDpsInitOperand(0)
                            ->get()
                            .getDefiningOp<tensor::EmptyOp>()) {
    dest = packOpDest;
  }
  return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
                       *packInfo);
}

/// Wrapper pattern that applies the bubbleUpPackOpThroughGenericOp method.
struct BubbleUpPackOpThroughGenericOpPattern
    : public OpRewritePattern<tensor::PackOp> {
public:
  BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
                                        ControlPropagationFn fun)
      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto genericOp =
        bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
    if (failed(genericOp))
      return failure();
    rewriter.replaceOp(packOp, genericOp->getResults());
    return success();
  }

private:
  ControlPropagationFn controlFn;
};

// TODO: Relax this restriction. We should unpack a generic op also
// in the presence of multiple unpack ops as producers.
/// Return the unpacked operand, if present, for the current generic op.
static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
  OpOperand *unPackedOperand = nullptr;
  for (OpOperand &operand : genericOp->getOpOperands()) {
    auto unPackOp = operand.get().getDefiningOp<tensor::UnPackOp>();
    if (!unPackOp)
      continue;
    if (unPackedOperand)
      return failure();
    unPackedOperand = &operand;
  }
  if (!unPackedOperand)
    return failure();
  return unPackedOperand;
}

/// Push down a tensor.unpack op through a generic op.
/// The new generic op works on the packed domain; pack ops are created for
/// input and output operands. A tensor.unpack op is inserted right after the
/// packed generic. E.g.
///
///   #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
///
///   %arg0 = tensor<12x2x56x56x32xf32> // packed arg.
///
///   %0 = tensor.empty() : tensor<12x56x56x64xf32>
///   %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2]
///       inner_dims_pos = [3] inner_tiles = [32] into %0
///   %2 = linalg.generic {indexing_maps = [#map],
///       iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
///       outs(%1 : tensor<12x56x56x64xf32>) {
///     ^bb0(%out : f32):
///       linalg.yield %out : f32
///   } -> tensor<12x56x56x64xf32>
///
/// will be converted to
///
///   #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
///
///   %0 = tensor.empty() : tensor<12x56x56x64xf32>
///   %1 = linalg.generic {indexing_maps = [#map],
///       iterator_types = ["parallel", "parallel", "parallel",
///                         "parallel", "parallel"]}
///       outs(%arg0 : tensor<12x2x56x56x32xf32>) {
///     ^bb0(%out : f32):
///       linalg.yield %out : f32
///   } -> tensor<12x2x56x56x32xf32>
///   %2 = tensor.unpack %1 outer_dims_perm = [0, 3, 1, 2]
///       inner_dims_pos = [3] inner_tiles = [32] into %0
///
static FailureOr<std::tuple<GenericOp, Value>>
pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp) {
  if (genericOp.getNumResults() != 1)
    return failure();

  if (hasGatherSemantics(genericOp))
    return failure();

  // Collect the unPacked operand, if present.
  auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
  if (failed(maybeUnPackedOperand))
    return failure();
  OpOperand *unPackedOperand = *(maybeUnPackedOperand);

  // Extract packing information.
  tensor::UnPackOp producerUnPackOp =
      unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
  assert(producerUnPackOp && "expect a valid UnPackOp");
  auto packInfo =
      getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
  if (failed(packInfo))
    return failure();

  // Rebuild the indexing map for the corresponding init operand.
  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                     genericOp, genericOp.getDpsInitOperand(0));
  auto destPack = packedOutOperand.getDefiningOp<tensor::PackOp>();

  // If the dps init operand of the generic is a tensor.empty, do not pack it
  // and forward the new tensor.empty as a destination.
  Value dest = packedOutOperand;
  if (auto initTensor = genericOp.getDpsInitOperand(0)
                            ->get()
                            .getDefiningOp<tensor::EmptyOp>()) {
    if (destPack)
      dest = destPack.getDest();
  }

  // Pack the genericOp.
  GenericOp newGenericOp =
      packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo);
  Value newResult =
      newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));

  // If the output is unaffected, no need to unpack.
  if (!destPack)
    return std::make_tuple(newGenericOp, newResult);

  auto mixedTiles = destPack.getMixedTiles();
  auto innerDimsPos = destPack.getInnerDimsPos();
  auto outerDimsPerm = destPack.getOuterDimsPerm();

  // If the output type for the generic differs from the source
  // unpack op, we need to create a new destination tensor. In the
  // dynamic case we always need a new destination.
  auto loc = genericOp.getLoc();
  Value unPackDest = producerUnPackOp.getDest();
  auto genericOutType =
      cast<RankedTensorType>(genericOp.getDpsInitOperand(0)->get().getType());
  if (producerUnPackOp.getDestType() != genericOutType ||
      !genericOutType.hasStaticShape()) {
    unPackDest = tensor::UnPackOp::createDestinationTensor(
        rewriter, loc, newResult, mixedTiles, innerDimsPos, outerDimsPerm);
  }

  // Insert an unPackOp right after the packed generic.
  Value unPackOpRes =
      rewriter
          .create<tensor::UnPackOp>(loc, newResult, unPackDest, innerDimsPos,
                                    mixedTiles, outerDimsPerm)
          .getResult();

  return std::make_tuple(newGenericOp, unPackOpRes);
}

// Wrapper pattern that applies the pushDownUnPackOpThroughGenericOp method.
struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
public:
  PushDownUnPackOpThroughGenericOp(MLIRContext *context,
                                   ControlPropagationFn fun)
      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!controlFn(genericOp))
      return failure();

    auto genericAndRepl = pushDownUnPackOpThroughGenericOp(rewriter, genericOp);
    if (failed(genericAndRepl))
      return failure();
    rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
    return success();
  }

private:
  ControlPropagationFn controlFn;
};

/// Propagate a tensor.unpack operation through a tensor.pad. The idea is to
/// append as many zero-padding entries to `low` and `high` as there are point
/// loops, so that the pad is applied directly on the packed layout.
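/// For illustration (a sketch, not taken verbatim from a test), with an inner
/// tile of 32 on the innermost dimension and padding only on the spatial
/// dimensions:
///
///   %unpack = tensor.unpack %arg0 inner_dims_pos = [3] inner_tiles = [32]
///       into %empty : tensor<1x16x16x2x32xf32> -> tensor<1x16x16x64xf32>
///   %pad = tensor.pad %unpack low[0, 1, 1, 0] high[0, 1, 1, 0] { ... }
///       : tensor<1x16x16x64xf32> to tensor<1x18x18x64xf32>
///
/// would become
///
///   %pad = tensor.pad %arg0 low[0, 1, 1, 0, 0] high[0, 1, 1, 0, 0] { ... }
///       : tensor<1x16x16x2x32xf32> to tensor<1x18x18x2x32xf32>
///   %unpack = tensor.unpack %pad inner_dims_pos = [3] inner_tiles = [32]
///       into %empty2 : tensor<1x18x18x2x32xf32> -> tensor<1x18x18x64xf32>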
struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
  PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::UnPackOp unpackOp =
        padOp.getSource().getDefiningOp<tensor::UnPackOp>();
    if (!unpackOp)
      return failure();

    if (!controlFn(padOp))
      return failure();

    Location loc = padOp.getLoc();
    // Bail out if one of the padded dimensions is a tiled one.
    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
    ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
    llvm::SmallBitVector innerDims(paddedDims.size());
    for (int64_t dim : innerDimsPos)
      innerDims.flip(dim);
    if (paddedDims.anyCommon(innerDims))
      return failure();

    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal)
      return failure();

    // If we have `outer_dims_perm`, we need to adjust the padded dimensions.
    ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
    if (!outerDimsPerm.empty()) {
      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
    }
    // Add zero padding for the point loops.
    size_t pointLoopsSize = innerDimsPos.size();
    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));

    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, /*result=*/Type(), unpackOp.getSource(), lowPad, highPad,
        paddingVal, padOp.getNofold());

    // Inject the tensor.unpack right after the packed padOp.
    Value outputUnPack = rewriter.create<tensor::EmptyOp>(
        loc, padOp.getResultType().getShape(),
        padOp.getResultType().getElementType());

    Value replacement = rewriter.create<tensor::UnPackOp>(
        loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
        unpackOp.getMixedTiles(), outerDimsPerm);
    rewriter.replaceOp(padOp, replacement);
    return success();
  }

private:
  ControlPropagationFn controlFn;
};

} // namespace

void mlir::linalg::populateDataLayoutPropagationPatterns(
    RewritePatternSet &patterns,
    const ControlPropagationFn &controlPackUnPackPropagation) {
  patterns.insert<BubbleUpPackOpThroughGenericOpPattern,
                  PushDownUnPackOpThroughGenericOp, PushDownUnPackThroughPadOp>(
      patterns.getContext(), controlPackUnPackPropagation);
}

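// A minimal usage sketch (illustrative; the pass name and the always-true
// control lambda are assumptions, not part of this file): the patterns are
// typically driven by the greedy rewriter from a pass, propagating through
// every op the control function accepts.
//
//   void MyPropagationPass::runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     linalg::populateDataLayoutPropagationPatterns(
//         patterns, [](Operation *op) { return true; });
//     if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                             std::move(patterns))))
//       return signalPassFailure();
//   }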