//===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns/a pass to remove uses of unit-extent
// dimensions that specify broadcasting, in favor of a more canonical
// representation of the computation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

namespace mlir {
#define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "linalg-drop-unit-dims"

using namespace mlir;
using namespace mlir::linalg;

namespace {
/// Pattern to move init operands to ins when all the loops are parallel and
/// the block argument corresponding to the init is used inside the region.
/// This is a fix-up for when unit reduction dimensions are all folded away,
/// at which point the op becomes an elementwise generic op. E.g., it converts
///
///  %0 = tensor.empty() : tensor<1x1xf32>
///  %1 = linalg.fill
///    ins(%cst : f32)
///    outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
///  %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
///                                        affine_map<(d0) -> (0, d0)>],
///                       iterator_types = ["parallel"]}
///    ins(%arg0 : tensor<1x?x1x1xf32>)
///    outs(%1 : tensor<1x1xf32>) {
///  ^bb0(%in: f32, %out: f32):
///    %3 = arith.addf %in, %out : f32
///    linalg.yield %3 : f32
///  } -> tensor<1x1xf32>
///
/// into
///
///  %0 = tensor.empty() : tensor<1x1xf32>
///  %1 = linalg.fill
///    ins(%cst : f32)
///    outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
///  %2 = tensor.empty() : tensor<1x1xf32>
///  %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
///                                        affine_map<(d0) -> (0, d0)>,
///                                        affine_map<(d0) -> (0, d0)>],
///                       iterator_types = ["parallel"]}
///    ins(%arg0, %1 : tensor<1x?x1x1xf32>, tensor<1x1xf32>)
///    outs(%2 : tensor<1x1xf32>) {
///  ^bb0(%in: f32, %in_0: f32, %out: f32):
///    %4 = arith.addf %in, %in_0 : f32
///    linalg.yield %4 : f32
///  } -> tensor<1x1xf32>
struct MoveInitOperandsToInput : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasTensorSemantics())
      return failure();
    if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
      return failure();

    auto outputOperands = genericOp.getDpsInitsMutable();
    SetVector<OpOperand *> candidates;
    for (OpOperand &op : outputOperands) {
      if (genericOp.getMatchingBlockArgument(&op).use_empty())
        continue;
      candidates.insert(&op);
    }

    if (candidates.empty())
      return failure();

    // Compute the modified indexing maps.
    int64_t origNumInput = genericOp.getNumDpsInputs();
    SmallVector<Value> newInputOperands = genericOp.getDpsInputs();
    SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
    SmallVector<AffineMap> newIndexingMaps;
    newIndexingMaps.append(indexingMaps.begin(),
                           std::next(indexingMaps.begin(), origNumInput));
    for (OpOperand *op : candidates) {
      newInputOperands.push_back(op->get());
      newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(op));
    }
    newIndexingMaps.append(std::next(indexingMaps.begin(), origNumInput),
                           indexingMaps.end());

    Location loc = genericOp.getLoc();
    SmallVector<Value> newOutputOperands =
        llvm::to_vector(genericOp.getDpsInits());
    for (OpOperand *op : candidates) {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointAfterValue(op->get());
      auto elemType = cast<ShapedType>(op->get().getType()).getElementType();
      auto empty = rewriter.create<tensor::EmptyOp>(
          loc, tensor::getMixedSizes(rewriter, loc, op->get()), elemType);

      unsigned start = genericOp.getDpsInits().getBeginOperandIndex();
      newOutputOperands[op->getOperandNumber() - start] = empty.getResult();
    }

    auto newOp = rewriter.create<GenericOp>(
        loc, genericOp.getResultTypes(), newInputOperands, newOutputOperands,
        newIndexingMaps, genericOp.getIteratorTypesArray(),
        /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));

    Region &region = newOp.getRegion();
    Block *block = new Block();
    region.push_back(block);
    IRMapping mapper;
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(block);
    for (auto bbarg : genericOp.getRegionInputArgs())
      mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));

    for (OpOperand *op : candidates) {
      BlockArgument bbarg = genericOp.getMatchingBlockArgument(op);
      mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
    }

    for (OpOperand &op : outputOperands) {
      BlockArgument bbarg = genericOp.getMatchingBlockArgument(&op);
      if (candidates.count(&op))
        block->addArgument(bbarg.getType(), loc);
      else
        mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
    }

    for (auto &op : genericOp.getBody()->getOperations()) {
      rewriter.clone(op, mapper);
    }
    rewriter.replaceOp(genericOp, newOp.getResults());

    return success();
  }
};
} // namespace

//===---------------------------------------------------------------------===//
// Drop loops that are unit-extents within Linalg operations.
//===---------------------------------------------------------------------===//

/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
/// broadcasting. For example,
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (0, d1)>,
///   affine_map<(d0, d1) -> (d0, 0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
///     tensor<5x5xf32>
/// {
///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<1x5xf32>
///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<5x1xf32>
///   %2 = linalg.generic #trait %0, %1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = arith.addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
///   return %2 : tensor<5x5xf32>
/// }
/// ```
///
/// would canonicalize to
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (d1)>,
///   affine_map<(d0, d1) -> (d0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
///     tensor<5x5xf32>
/// {
///   %0 = linalg.generic #trait %arg0, %arg1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = arith.addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
///   return %0 : tensor<5x5xf32>
/// }
/// ```

/// Update the index accesses of linalg operations having index semantics.
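///
/// For instance (a hypothetical illustration, assuming unit dims {0, 2} were
/// dropped from a four-loop op): `linalg.index 0` and `linalg.index 2` fold
/// to a constant 0, while `linalg.index 1` and `linalg.index 3` are
/// renumbered to `linalg.index 0` and `linalg.index 1` respectively.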
static void
replaceUnitDimIndexOps(GenericOp genericOp,
                       const llvm::SmallDenseSet<unsigned> &unitDims,
                       RewriterBase &rewriter) {
  for (IndexOp indexOp :
       llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(indexOp);
    if (unitDims.count(indexOp.getDim()) != 0) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(indexOp, 0);
    } else {
      // Update the dimension of the index operation if needed.
      unsigned droppedDims = llvm::count_if(
          unitDims, [&](unsigned dim) { return dim < indexOp.getDim(); });
      if (droppedDims != 0)
        rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
                                             indexOp.getDim() - droppedDims);
    }
  }
}

/// Expand the given `value` so that the type matches the type of `origDest`.
/// The `reassociation` is used when `rankReductionStrategy` is set to
/// `RankReductionStrategy::ReassociativeReshape`.
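///
/// For example (a sketch): expanding a `tensor<5xf32>` result back to an
/// original `tensor<1x5x1xf32>` destination uses either a
/// `tensor.expand_shape` with reassociation [[0, 1, 2]] or a full-size
/// `tensor.insert_slice` into `origDest`, depending on the strategy.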
static Value
expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
            ArrayRef<ReassociationIndices> reassociation,
            ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
  // There are no results for memref outputs.
  auto origResultType = cast<RankedTensorType>(origDest.getType());
  if (rankReductionStrategy ==
      ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
    unsigned rank = origResultType.getRank();
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes =
        tensor::getMixedSizes(rewriter, loc, origDest);
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    return rewriter.createOrFold<tensor::InsertSliceOp>(
        loc, result, origDest, offsets, sizes, strides);
  }

  assert(rankReductionStrategy ==
             ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
         "unknown rank reduction strategy");
  return rewriter.create<tensor::ExpandShapeOp>(loc, origResultType, result,
                                                reassociation);
}

/// Collapse the given `value` so that the type matches the type of
/// `origOutput`. The `reassociation` is used when `rankReductionStrategy` is
/// set to `RankReductionStrategy::ReassociativeReshape`.
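///
/// For example (a sketch): collapsing a `tensor<1x5x1xf32>` operand to
/// `tensor<5xf32>` uses either a `tensor.collapse_shape` with reassociation
/// [[0, 1, 2]] or a rank-reducing `tensor.extract_slice`, depending on the
/// strategy; memrefs are handled analogously with `memref.collapse_shape`
/// and rank-reducing `memref.subview`.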
static Value collapseValue(
    RewriterBase &rewriter, Location loc, Value operand,
    ArrayRef<int64_t> targetShape, ArrayRef<ReassociationIndices> reassociation,
    ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
  if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
    if (rankReductionStrategy ==
        ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
      FailureOr<Value> rankReducingExtract =
          memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand,
                                                targetShape);
      assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
      return *rankReducingExtract;
    }

    assert(
        rankReductionStrategy ==
            ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
        "unknown rank reduction strategy");
    MemRefLayoutAttrInterface layout;
    auto targetType = MemRefType::get(targetShape, memrefType.getElementType(),
                                      layout, memrefType.getMemorySpace());
    return rewriter.create<memref::CollapseShapeOp>(loc, targetType, operand,
                                                    reassociation);
  }
  if (auto tensorType = dyn_cast<RankedTensorType>(operand.getType())) {
    if (rankReductionStrategy ==
        ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
      FailureOr<Value> rankReducingExtract =
          tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand,
                                                     targetShape);
      assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
      return *rankReducingExtract;
    }

    assert(
        rankReductionStrategy ==
            ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
        "unknown rank reduction strategy");
    auto targetType =
        RankedTensorType::get(targetShape, tensorType.getElementType());
    return rewriter.create<tensor::CollapseShapeOp>(loc, targetType, operand,
                                                    reassociation);
  }
  llvm_unreachable("unsupported operand type");
}

/// Compute the modified metadata for an operand of an operation whose unit
/// dims are being dropped. Returns the new indexing map to use, the shape of
/// the operand in the replacement op, and the `reassociation` to use to go
/// from the original operand shape to the modified operand shape.
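///
/// A worked illustration (hypothetical): for an operand of type
/// `tensor<1x5x1xf32>` indexed by `(d0, d1, d2) -> (d0, d1, d2)` where d0 and
/// d2 are dropped unit loops, the result is `indexMap = (d0) -> (d0)`,
/// `targetShape = [5]`, and `reassociation = [[0, 1, 2]]`.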
struct UnitExtentReplacementInfo {
  AffineMap indexMap;
  SmallVector<ReassociationIndices> reassociation;
  SmallVector<int64_t> targetShape;
};
static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
    MLIRContext *context, GenericOp genericOp, OpOperand *opOperand,
    llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap,
    ArrayRef<AffineExpr> dimReplacements) {
  UnitExtentReplacementInfo info;
  ReassociationIndices reassociationGroup;
  SmallVector<AffineExpr> newIndexExprs;
  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
  ArrayRef<int64_t> operandShape = genericOp.getShape(opOperand);
  ArrayRef<AffineExpr> exprs = indexingMap.getResults();

  auto isUnitDim = [&](unsigned dim) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[dim])) {
      unsigned oldPosition = dimExpr.getPosition();
      return !oldDimsToNewDimsMap.count(oldPosition);
    }
    // Handle the other case where the shape is 1, and is accessed using a
    // constant 0.
    if (operandShape[dim] == 1) {
      auto constAffineExpr = dyn_cast<AffineConstantExpr>(exprs[dim]);
      return constAffineExpr && constAffineExpr.getValue() == 0;
    }
    return false;
  };

  unsigned dim = 0;
  while (dim < operandShape.size() && isUnitDim(dim))
    reassociationGroup.push_back(dim++);
  while (dim < operandShape.size()) {
    assert(!isUnitDim(dim) && "expected non unit-extent");
    reassociationGroup.push_back(dim);
    AffineExpr newExpr = exprs[dim].replaceDims(dimReplacements);
    newIndexExprs.push_back(newExpr);
    info.targetShape.push_back(operandShape[dim]);
    ++dim;
    // Fold all following dimensions that are unit-extent.
    while (dim < operandShape.size() && isUnitDim(dim)) {
      reassociationGroup.push_back(dim++);
    }
    info.reassociation.push_back(reassociationGroup);
    reassociationGroup.clear();
  }
  info.indexMap =
      AffineMap::get(oldDimsToNewDimsMap.size(), indexingMap.getNumSymbols(),
                     newIndexExprs, context);
  return info;
}

LogicalResult linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
                                   const ControlDropUnitDims &options) {
  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
  if (indexingMaps.empty())
    return failure();

  // 1. Check if any of the iteration dimensions are unit-trip count. They will
  //    end up being unit-trip count if they are used to index into a unit-dim
  //    tensor/memref.
  AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
  if (!invertedMap) {
    return rewriter.notifyMatchFailure(genericOp,
                                       "invalid indexing maps for operation");
  }
  SmallVector<int64_t> dims = genericOp.getStaticShape();

  // 1a. Get the allowed list of dimensions to drop from the `options`.
  SmallVector<unsigned> allowedUnitDims = options.controlFn(genericOp);
  if (allowedUnitDims.empty()) {
    return rewriter.notifyMatchFailure(
        genericOp, "control function returns no allowed unit dims to prune");
  }
  llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
                                               allowedUnitDims.end());
  llvm::SmallDenseSet<unsigned> unitDims;
  for (const auto &expr : enumerate(invertedMap.getResults())) {
    if (AffineDimExpr dimExpr = dyn_cast<AffineDimExpr>(expr.value())) {
      if (dims[dimExpr.getPosition()] == 1 &&
          unitDimsFilter.count(expr.index()))
        unitDims.insert(expr.index());
    }
  }

  // 2. Compute the iterator types of the modified op by dropping the one-trip
  //    count loops.
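  // For example (hypothetical): with iterator types [parallel, parallel,
  // parallel] and unitDims = {1}, the new iterator types are
  // [parallel, parallel], oldDimToNewDimMap = {0 -> 0, 2 -> 1}, and
  // dimReplacements = [d0, 0, d1].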
  SmallVector<utils::IteratorType> newIteratorTypes;
  llvm::SmallDenseMap<unsigned, unsigned> oldDimToNewDimMap;
  SmallVector<AffineExpr> dimReplacements;
  unsigned newDims = 0;
  for (auto [index, attr] :
       llvm::enumerate(genericOp.getIteratorTypesArray())) {
    if (unitDims.count(index)) {
      dimReplacements.push_back(
          getAffineConstantExpr(0, rewriter.getContext()));
    } else {
      newIteratorTypes.push_back(attr);
      oldDimToNewDimMap[index] = newDims;
      dimReplacements.push_back(
          getAffineDimExpr(newDims, rewriter.getContext()));
      newDims++;
    }
  }

  // 3. For each of the operands, find the
  //    - modified affine map to use.
  //    - shape of the operand after the unit-dims are dropped.
  //    - the reassociation indices used to convert from the original
  //      operand type to the modified operand type (needed only when using
  //      reshapes for the rank reduction strategy).
  // Note that the indexing maps might need changing even if there are no
  // unit dimensions that are dropped, to handle cases where `0` is used to
  // access a unit-extent tensor. Consider moving this out of this specific
  // transformation as a stand-alone transformation. Kept here right now due
  // to legacy.
  SmallVector<AffineMap> newIndexingMaps;
  SmallVector<SmallVector<ReassociationIndices>> reassociations;
  SmallVector<SmallVector<int64_t>> targetShapes;
  SmallVector<bool> collapsed;
  auto hasCollapsibleType = [](OpOperand &operand) {
    Type operandType = operand.get().getType();
    if (auto memrefOperandType = dyn_cast_or_null<MemRefType>(operandType)) {
      return memrefOperandType.getLayout().isIdentity();
    } else if (auto tensorOperandType =
                   dyn_cast<RankedTensorType>(operandType)) {
      return tensorOperandType.getEncoding() == nullptr;
    }
    return false;
  };
  for (OpOperand &opOperand : genericOp->getOpOperands()) {
    auto indexingMap = genericOp.getMatchingIndexingMap(&opOperand);
    ArrayRef<int64_t> shape = genericOp.getShape(&opOperand);
    if (!hasCollapsibleType(opOperand)) {
      AffineMap newIndexingMap = indexingMap.replaceDimsAndSymbols(
          dimReplacements, ArrayRef<AffineExpr>{}, oldDimToNewDimMap.size(), 0);
      newIndexingMaps.push_back(newIndexingMap);
      targetShapes.push_back(llvm::to_vector(shape));
      collapsed.push_back(false);
      reassociations.push_back({});
      continue;
    }
    auto replacementInfo = dropUnitExtentFromOperandMetadata(
        rewriter.getContext(), genericOp, &opOperand, oldDimToNewDimMap,
        dimReplacements);
    reassociations.push_back(replacementInfo.reassociation);
    newIndexingMaps.push_back(replacementInfo.indexMap);
    targetShapes.push_back(replacementInfo.targetShape);
    collapsed.push_back(!(replacementInfo.indexMap.getNumResults() ==
                          indexingMap.getNumResults()));
  }

  // Abort if the indexing maps of the result operation are not invertible
  // (i.e. not legal) or if no dimension was reduced.
  if (newIndexingMaps == indexingMaps ||
      !inversePermutation(concatAffineMaps(newIndexingMaps)))
    return failure();

  Location loc = genericOp.getLoc();
  // 4. For each of the operands, collapse the operand to convert
  //    from original shape to shape in the modified operation if needed,
  //    either through use of reshapes or rank-reducing slices as
  //    specified in `options`.
  SmallVector<Value> newOperands;
  for (OpOperand &opOperand : genericOp->getOpOperands()) {
    int64_t idx = opOperand.getOperandNumber();
    if (!collapsed[idx]) {
      newOperands.push_back(opOperand.get());
      continue;
    }
    newOperands.push_back(collapseValue(rewriter, loc, opOperand.get(),
                                        targetShapes[idx], reassociations[idx],
                                        options.rankReductionStrategy));
  }

  // 5. Create the `linalg.generic` operation with the new operands,
  //    indexing maps, iterator types and result types.
  ArrayRef<Value> newInputs =
      ArrayRef<Value>(newOperands).take_front(genericOp.getNumDpsInputs());
  ArrayRef<Value> newOutputs =
      ArrayRef<Value>(newOperands).take_back(genericOp.getNumDpsInits());
  SmallVector<Type> resultTypes;
  resultTypes.reserve(genericOp.getNumResults());
  for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
    resultTypes.push_back(newOutputs[i].getType());
  GenericOp replacementOp =
      rewriter.create<GenericOp>(loc, resultTypes, newInputs, newOutputs,
                                 newIndexingMaps, newIteratorTypes);
  rewriter.inlineRegionBefore(genericOp.getRegion(), replacementOp.getRegion(),
                              replacementOp.getRegion().begin());
  // 5a. Replace `linalg.index` operations that refer to the dropped unit
  //     dimensions.
  replaceUnitDimIndexOps(replacementOp, unitDims, rewriter);

  // 6. If any result type changes, insert a reshape/slice to convert from the
  //    original type to the new type.
  SmallVector<Value> resultReplacements;
  for (auto [index, result] : llvm::enumerate(replacementOp.getResults())) {
    unsigned opOperandIndex = index + replacementOp.getNumDpsInputs();
    Value origDest = genericOp.getDpsInitOperand(index)->get();
    if (!collapsed[opOperandIndex]) {
      resultReplacements.push_back(result);
      continue;
    }
    resultReplacements.push_back(expandValue(rewriter, loc, result, origDest,
                                             reassociations[opOperandIndex],
                                             options.rankReductionStrategy));
  }

  rewriter.replaceOp(genericOp, resultReplacements);
  return success();
}

namespace {
struct DropUnitDims : public OpRewritePattern<GenericOp> {
  DropUnitDims(MLIRContext *context, ControlDropUnitDims options = {},
               PatternBenefit benefit = 1)
      : OpRewritePattern(context, benefit), options(std::move(options)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    return dropUnitDims(rewriter, genericOp, options);
  }

private:
  ControlDropUnitDims options;
};
} // namespace

namespace {
/// Convert `extract_slice` operations to rank-reduced versions.
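/// For example (a sketch): an op yielding `tensor<1x5x1xf32>` from a
/// `tensor<4x5x3xf32>` source becomes a rank-reduced `extract_slice` yielding
/// `tensor<5xf32>`, followed by a `tensor.expand_shape` back to
/// `tensor<1x5x1xf32>`.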
struct RankReducedExtractSliceOp
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = sliceOp.getType();
    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(resultType.getRank()))
      return failure();
    auto rankReducedType = cast<RankedTensorType>(
        tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
            reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
            strides));

    Location loc = sliceOp.getLoc();
    Value newSlice = rewriter.create<tensor::ExtractSliceOp>(
        loc, rankReducedType, sliceOp.getSource(), offsets, sizes, strides);
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        sliceOp, resultType, newSlice, *reassociation);
    return success();
  }
};

/// Convert `insert_slice` operations to rank-reduced versions.
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
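/// For example (a sketch): inserting a `tensor<1x5x1xf32>` source is
/// rewritten to first `tensor.collapse_shape` it to `tensor<5xf32>` and then
/// perform a rank-reduced insertion into the original destination.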
template <typename InsertOpTy>
struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType sourceType = insertSliceOp.getSourceType();
    SmallVector<OpFoldResult> offsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = insertSliceOp.getMixedStrides();
    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(sourceType.getRank()))
      return failure();
    Location loc = insertSliceOp.getLoc();
    tensor::CollapseShapeOp reshapedSource;
    {
      OpBuilder::InsertionGuard g(rewriter);
      // The only difference between InsertSliceOp and ParallelInsertSliceOp
      // is that the insertion point is just before the ParallelCombiningOp in
      // the parallel case.
      if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
      reshapedSource = rewriter.create<tensor::CollapseShapeOp>(
          loc, insertSliceOp.getSource(), *reassociation);
    }
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, reshapedSource, insertSliceOp.getDest(),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());
    return success();
  }
};
} // namespace

/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting.
static void
populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns,
                                              ControlDropUnitDims &options) {
  auto *context = patterns.getContext();
  patterns.add<DropUnitDims>(context, options);
  // TODO: Patterns unrelated to unit dim folding should be factored out.
  patterns.add<RankReducedExtractSliceOp,
               RankReducedInsertSliceOp<tensor::InsertSliceOp>,
               RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
      context);
  linalg::FillOp::getCanonicalizationPatterns(patterns, context);
  tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
  tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
  tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
  tensor::populateFoldTensorEmptyPatterns(patterns);
  memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
  memref::populateResolveShapedTypeResultDimsPatterns(patterns);
}

static void
populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns,
                                            ControlDropUnitDims &options) {
  auto *context = patterns.getContext();
  options.rankReductionStrategy =
      ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
  patterns.add<DropUnitDims>(context, options);
  // TODO: Patterns unrelated to unit dim folding should be factored out.
  linalg::FillOp::getCanonicalizationPatterns(patterns, context);
  tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
  tensor::populateFoldTensorEmptyPatterns(patterns);
  memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
  memref::populateResolveShapedTypeResultDimsPatterns(patterns);
}

void mlir::linalg::populateFoldUnitExtentDimsPatterns(
    RewritePatternSet &patterns, linalg::ControlDropUnitDims &options) {
  if (options.rankReductionStrategy ==
      linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
    populateFoldUnitExtentDimsViaSlicesPatterns(patterns, options);
  } else if (options.rankReductionStrategy ==
             linalg::ControlDropUnitDims::RankReductionStrategy::
                 ReassociativeReshape) {
    populateFoldUnitExtentDimsViaReshapesPatterns(patterns, options);
  }
}

void linalg::populateMoveInitOperandsToInputPattern(
    RewritePatternSet &patterns) {
  patterns.add<MoveInitOperandsToInput>(patterns.getContext());
}
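
// Example client usage (a minimal sketch; `funcOp` is a hypothetical name for
// whatever isolated-from-above op the caller wants to rewrite):
//
//   RewritePatternSet patterns(funcOp->getContext());
//   linalg::ControlDropUnitDims options;
//   linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
//   linalg::populateMoveInitOperandsToInputPattern(patterns);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));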

namespace {
/// Pass that removes unit-extent dims within generic ops.
struct LinalgFoldUnitExtentDimsPass
    : public impl::LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *context = op->getContext();
    RewritePatternSet patterns(context);
    ControlDropUnitDims options;
    if (useRankReducingSlices) {
      options.rankReductionStrategy =
          linalg::ControlDropUnitDims::RankReductionStrategy::
              ExtractInsertSlice;
    }
    linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
    populateMoveInitOperandsToInputPattern(patterns);
    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createLinalgFoldUnitExtentDimsPass() {
  return std::make_unique<LinalgFoldUnitExtentDimsPass>();
}