// DropUnitDims.cpp — source listing from MLIR 22.0.0git documentation.
1 //===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns/pass to remove usage of unit-extent dimensions
10 // to specify broadcasting in favor of more canonical representation of the
11 // computation
12 //
13 //===----------------------------------------------------------------------===//
14 
#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"
32 
33 namespace mlir {
34 #define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMSPASS
35 #include "mlir/Dialect/Linalg/Passes.h.inc"
36 } // namespace mlir
37 
38 #define DEBUG_TYPE "linalg-drop-unit-dims"
39 
40 using namespace mlir;
41 using namespace mlir::linalg;
42 
43 namespace {
44 /// Pattern to move init operands to ins when all the loops are parallel and
45 /// blockArgument corresponding to init is used in the region. This is a fix-up
46 /// when unit reduction dimensions are all folded away. In this context, it
47 /// becomes a elementwise generic op. E.g., it converts
48 ///
49 /// %0 = tensor.empty() : tensor<1x1xf32>
50 /// %1 = linalg.fill
51 /// ins(%cst : f32)
52 /// outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
53 /// %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
54 /// affine_map<(d0) -> (0, d0)>],
55 /// iterator_types = ["parallel"]}
56 /// ins(%arg0 : tensor<1x?x1x1xf32>)
57 /// outs(%1 : tensor<1x1xf32>) {
58 /// ^bb0(%in: f32, %out: f32):
59 /// %3 = arith.addf %in, %out : f32
60 /// linalg.yield %3 : f32
61 /// } -> tensor<1x1xf32>
62 ///
63 /// into
64 ///
65 /// %0 = tensor.empty() : tensor<1x1xf32>
66 /// %1 = linalg.fill
67 /// ins(%cst : f32)
68 /// outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
69 /// %2 = tensor.empty() : tensor<1x1xf32>
70 /// %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
71 /// affine_map<(d0) -> (0, d0)>,
72 /// affine_map<(d0) -> (0, d0)>],
73 /// iterator_types = ["parallel"]}
74 /// ins(%arg0, %1 : tensor<1x?x1x1xf32>, tensor<1x1xf32>)
75 /// outs(%2 : tensor<1x1xf32>) {
76 /// ^bb0(%in: f32, %in_0: f32, %out: f32):
77 /// %4 = arith.addf %in, %in_0 : f32
78 /// linalg.yield %4 : f32
79 /// } -> tensor<1x1xf32>
80 struct MoveInitOperandsToInput : public OpRewritePattern<GenericOp> {
82  LogicalResult matchAndRewrite(GenericOp genericOp,
83  PatternRewriter &rewriter) const override {
84  if (!genericOp.hasPureTensorSemantics())
85  return failure();
86  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
87  return failure();
88 
89  auto outputOperands = genericOp.getDpsInitsMutable();
90  SetVector<OpOperand *> candidates;
91  for (OpOperand &op : outputOperands) {
92  if (genericOp.getMatchingBlockArgument(&op).use_empty())
93  continue;
94  candidates.insert(&op);
95  }
96 
97  if (candidates.empty())
98  return failure();
99 
100  // Compute the modified indexing maps.
101  int64_t origNumInput = genericOp.getNumDpsInputs();
102  SmallVector<Value> newInputOperands = genericOp.getDpsInputs();
103  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
104  SmallVector<AffineMap> newIndexingMaps;
105  newIndexingMaps.append(indexingMaps.begin(),
106  std::next(indexingMaps.begin(), origNumInput));
107  for (OpOperand *op : candidates) {
108  newInputOperands.push_back(op->get());
109  newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(op));
110  }
111  newIndexingMaps.append(std::next(indexingMaps.begin(), origNumInput),
112  indexingMaps.end());
113 
114  Location loc = genericOp.getLoc();
115  SmallVector<Value> newOutputOperands =
116  llvm::to_vector(genericOp.getDpsInits());
117  for (OpOperand *op : candidates) {
118  OpBuilder::InsertionGuard guard(rewriter);
119  rewriter.setInsertionPointAfterValue(op->get());
120  auto elemType = cast<ShapedType>(op->get().getType()).getElementType();
121  auto empty = tensor::EmptyOp::create(
122  rewriter, loc, tensor::getMixedSizes(rewriter, loc, op->get()),
123  elemType);
124 
125  unsigned start = genericOp.getDpsInits().getBeginOperandIndex();
126  newOutputOperands[op->getOperandNumber() - start] = empty.getResult();
127  }
128 
129  auto newOp = GenericOp::create(
130  rewriter, loc, genericOp.getResultTypes(), newInputOperands,
131  newOutputOperands, newIndexingMaps, genericOp.getIteratorTypesArray(),
132  /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
133 
134  OpBuilder::InsertionGuard guard(rewriter);
135  Region &region = newOp.getRegion();
136  Block *block = rewriter.createBlock(&region);
137  IRMapping mapper;
138  for (auto bbarg : genericOp.getRegionInputArgs())
139  mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
140 
141  for (OpOperand *op : candidates) {
142  BlockArgument bbarg = genericOp.getMatchingBlockArgument(op);
143  mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
144  }
145 
146  for (OpOperand &op : outputOperands) {
147  BlockArgument bbarg = genericOp.getMatchingBlockArgument(&op);
148  if (candidates.count(&op))
149  block->addArgument(bbarg.getType(), loc);
150  else
151  mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
152  }
153 
154  for (auto &op : genericOp.getBody()->getOperations()) {
155  rewriter.clone(op, mapper);
156  }
157  rewriter.replaceOp(genericOp, newOp.getResults());
158 
159  return success();
160  }
161 };
162 } // namespace
163 
164 //===---------------------------------------------------------------------===//
165 // Drop loops that are unit-extents within Linalg operations.
166 //===---------------------------------------------------------------------===//
167 
168 /// Implements a pass that canonicalizes the uses of unit-extent dimensions for
169 /// broadcasting. For example,
170 ///
171 /// ```mlir
172 /// #accesses = [
173 /// affine_map<(d0, d1) -> (0, d1)>,
174 /// affine_map<(d0, d1) -> (d0, 0)>,
175 /// affine_map<(d0, d1) -> (d0, d1)>
176 /// ]
177 ///
178 /// #trait = {
179 /// indexing_maps = #accesses,
180 /// iterator_types = ["parallel", "parallel"],
181 /// library_call = "some_external_fn"
182 /// }
183 ///
184 /// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
185 /// tensor<5x5xf32>
186 /// {
187 /// %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
188 /// tensor<5xf32> into tensor<1x5xf32>
189 /// %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
190 /// tensor<5xf32> into tensor<5x1xf32>
191 /// %2 = linalg.generic #trait %0, %1 {
192 /// ^bb0(%arg2: f32, %arg3: f32):
193 /// %3 = arith.addf %arg2, %arg3 : f32
194 /// linalg.yield %3 : f32
195 /// } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
196 /// return %2 : tensor<5x5xf32>
197 /// }
198 ///
199 /// would canonicalize to
200 ///
201 /// ```mlir
202 /// #accesses = [
203 /// affine_map<(d0, d1) -> (d1)>,
204 /// affine_map<(d0, d1) -> (d0)>,
205 /// affine_map<(d0, d1) -> (d0, d1)>
206 /// ]
207 ///
208 /// #trait = {
209 /// indexing_maps = #accesses,
210 /// iterator_types = ["parallel", "parallel"],
211 /// library_call = "some_external_fn"
212 /// }
213 ///
214 /// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
215 /// tensor<5x5xf32>
216 /// {
217 /// %0 = linalg.generic #trait %arg0, %arg1 {
218 /// ^bb0(%arg2: f32, %arg3: f32):
219 /// %3 = arith.addf %arg2, %arg3 : f32
220 /// linalg.yield %3 : f32
221 /// } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
222 /// return %0 : tensor<5x5xf32>
223 /// }
224 
225 /// Update the index accesses of linalg operations having index semantics.
226 static void
227 replaceUnitDimIndexOps(GenericOp genericOp,
228  const llvm::SmallDenseSet<unsigned> &unitDims,
229  RewriterBase &rewriter) {
230  for (IndexOp indexOp :
231  llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
232  OpBuilder::InsertionGuard guard(rewriter);
233  rewriter.setInsertionPoint(indexOp);
234  if (unitDims.count(indexOp.getDim()) != 0) {
235  rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(indexOp, 0);
236  } else {
237  // Update the dimension of the index operation if needed.
238  unsigned droppedDims = llvm::count_if(
239  unitDims, [&](unsigned dim) { return dim < indexOp.getDim(); });
240  if (droppedDims != 0)
241  rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
242  indexOp.getDim() - droppedDims);
243  }
244  }
245 }
246 
247 /// Expand the given `value` so that the type matches the type of `origDest`.
248 /// The `reassociation` is used when `rankReductionStrategy` is set to
249 /// `RankReductionStrategy::ReassociativeReshape`.
250 static Value
251 expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
252  ArrayRef<ReassociationIndices> reassociation,
253  ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
254  // There are no results for memref outputs.
255  auto origResultType = cast<RankedTensorType>(origDest.getType());
256  if (rankReductionStrategy ==
258  unsigned rank = origResultType.getRank();
259  SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
261  tensor::getMixedSizes(rewriter, loc, origDest);
262  SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
263  return rewriter.createOrFold<tensor::InsertSliceOp>(
264  loc, result, origDest, offsets, sizes, strides);
265  }
266 
267  assert(rankReductionStrategy ==
269  "unknown rank reduction strategy");
270  return tensor::ExpandShapeOp::create(rewriter, loc, origResultType, result,
271  reassociation)
272  .getResult();
273 }
274 
275 /// Collapse the given `value` so that the type matches the type of
276 /// `origOutput`. The `reassociation` is used when `rankReductionStrategy` is
277 /// set to `RankReductionStrategy::ReassociativeReshape`.
279  RewriterBase &rewriter, Location loc, Value operand,
280  ArrayRef<int64_t> targetShape, ArrayRef<ReassociationIndices> reassociation,
281  ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
282  if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
283  if (rankReductionStrategy ==
285  FailureOr<Value> rankReducingExtract =
286  memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand,
287  targetShape);
288  assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
289  return *rankReducingExtract;
290  }
291 
292  assert(
293  rankReductionStrategy ==
295  "unknown rank reduction strategy");
296  MemRefLayoutAttrInterface layout;
297  auto targetType = MemRefType::get(targetShape, memrefType.getElementType(),
298  layout, memrefType.getMemorySpace());
299  return memref::CollapseShapeOp::create(rewriter, loc, targetType, operand,
300  reassociation);
301  }
302  if (auto tensorType = dyn_cast<RankedTensorType>(operand.getType())) {
303  if (rankReductionStrategy ==
305  FailureOr<Value> rankReducingExtract =
306  tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand,
307  targetShape);
308  assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
309  return *rankReducingExtract;
310  }
311 
312  assert(
313  rankReductionStrategy ==
315  "unknown rank reduction strategy");
316  auto targetType =
317  RankedTensorType::get(targetShape, tensorType.getElementType());
318  return tensor::CollapseShapeOp::create(rewriter, loc, targetType, operand,
319  reassociation);
320  }
321  llvm_unreachable("unsupported operand type");
322 }
323 
324 /// Compute the modified metadata for an operands of operation
325 /// whose unit dims are being dropped. Return the new indexing map
326 /// to use, the shape of the operand in the replacement op
327 /// and the `reassocation` to use to go from original operand shape
328 /// to modified operand shape.
333 };
335  MLIRContext *context, IndexingMapOpInterface op, OpOperand *opOperand,
336  llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap,
337  ArrayRef<AffineExpr> dimReplacements) {
339  ReassociationIndices reassociationGroup;
340  SmallVector<AffineExpr> newIndexExprs;
341  AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
342  SmallVector<int64_t> operandShape = op.getStaticOperandShape(opOperand);
343  ArrayRef<AffineExpr> exprs = indexingMap.getResults();
344 
345  auto isUnitDim = [&](unsigned dim) {
346  if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[dim])) {
347  unsigned oldPosition = dimExpr.getPosition();
348  return !oldDimsToNewDimsMap.count(oldPosition) &&
349  (operandShape[dim] == 1);
350  }
351  // Handle the other case where the shape is 1, and is accessed using a
352  // constant 0.
353  if (operandShape[dim] == 1) {
354  auto constAffineExpr = dyn_cast<AffineConstantExpr>(exprs[dim]);
355  return constAffineExpr && constAffineExpr.getValue() == 0;
356  }
357  return false;
358  };
359 
360  unsigned dim = 0;
361  while (dim < operandShape.size() && isUnitDim(dim))
362  reassociationGroup.push_back(dim++);
363  while (dim < operandShape.size()) {
364  assert(!isUnitDim(dim) && "expected non unit-extent");
365  reassociationGroup.push_back(dim);
366  AffineExpr newExpr = exprs[dim].replaceDims(dimReplacements);
367  newIndexExprs.push_back(newExpr);
368  info.targetShape.push_back(operandShape[dim]);
369  ++dim;
370  // Fold all following dimensions that are unit-extent.
371  while (dim < operandShape.size() && isUnitDim(dim)) {
372  reassociationGroup.push_back(dim++);
373  }
374  info.reassociation.push_back(reassociationGroup);
375  reassociationGroup.clear();
376  }
377  info.indexMap =
378  AffineMap::get(oldDimsToNewDimsMap.size(), indexingMap.getNumSymbols(),
379  newIndexExprs, context);
380  return info;
381 }
382 
383 FailureOr<DropUnitDimsResult>
384 linalg::dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op,
385  const DroppedUnitDimsBuilder &droppedUnitDimsBuilder,
386  const ControlDropUnitDims &options) {
387  auto dpsOp = dyn_cast<DestinationStyleOpInterface>(op.getOperation());
388  if (!dpsOp) {
389  return rewriter.notifyMatchFailure(
390  op, "op should implement DestinationStyleOpInterface");
391  }
392 
393  SmallVector<AffineMap> indexingMaps = op.getIndexingMapsArray();
394  if (indexingMaps.empty())
395  return failure();
396 
397  // 1. Check if any of the iteration dimensions are unit-trip count. They will
398  // end up being unit-trip count if they are used to index into a unit-dim
399  // tensor/memref.
400  AffineMap invertedMap =
401  inversePermutation(concatAffineMaps(indexingMaps, rewriter.getContext()));
402  if (!invertedMap) {
403  return rewriter.notifyMatchFailure(op,
404  "invalid indexing maps for operation");
405  }
406 
407  SmallVector<int64_t> allShapesSizes;
408  for (OpOperand &opOperand : op->getOpOperands())
409  llvm::append_range(allShapesSizes, op.getStaticOperandShape(&opOperand));
410 
411  // 1a. Get the allowed list of dimensions to drop from the `options`.
412  SmallVector<unsigned> allowedUnitDims = options.controlFn(op);
413  if (allowedUnitDims.empty()) {
414  return rewriter.notifyMatchFailure(
415  op, "control function returns no allowed unit dims to prune");
416  }
417  llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
418  allowedUnitDims.end());
419  llvm::SmallDenseSet<unsigned> unitDims;
420  for (const auto &expr : enumerate(invertedMap.getResults())) {
421  if (AffineDimExpr dimExpr = dyn_cast<AffineDimExpr>(expr.value())) {
422  if (allShapesSizes[dimExpr.getPosition()] == 1 &&
423  unitDimsFilter.count(expr.index()))
424  unitDims.insert(expr.index());
425  }
426  }
427 
428  // 2. Compute the new loops of the modified op by dropping the one-trip
429  // count loops.
430  llvm::SmallDenseMap<unsigned, unsigned> oldDimToNewDimMap;
431  SmallVector<AffineExpr> dimReplacements;
432  unsigned newDims = 0;
433  for (auto index : llvm::seq<int64_t>(op.getStaticLoopRanges().size())) {
434  if (unitDims.count(index)) {
435  dimReplacements.push_back(
436  getAffineConstantExpr(0, rewriter.getContext()));
437  } else {
438  oldDimToNewDimMap[index] = newDims;
439  dimReplacements.push_back(
440  getAffineDimExpr(newDims, rewriter.getContext()));
441  newDims++;
442  }
443  }
444 
445  // 3. For each of the operands, find the
446  // - modified affine map to use.
447  // - shape of the operands after the unit-dims are dropped.
448  // - the reassociation indices used to convert from the original
449  // operand type to modified operand (needed only when using reshapes
450  // for rank reduction strategy)
451  // Note that the indexing maps might need changing even if there are no
452  // unit dimensions that are dropped to handle cases where `0` is used to
453  // access a unit-extent tensor. Consider moving this out of this specific
454  // transformation as a stand-alone transformation. Kept here right now due
455  // to legacy.
456  SmallVector<AffineMap> newIndexingMaps;
458  SmallVector<SmallVector<int64_t>> targetShapes;
459  SmallVector<bool> collapsed;
460  auto hasCollapsibleType = [](OpOperand &operand) {
461  Type operandType = operand.get().getType();
462  if (auto memrefOperandType = dyn_cast_or_null<MemRefType>(operandType)) {
463  return memrefOperandType.getLayout().isIdentity();
464  }
465  if (auto tensorOperandType = dyn_cast<RankedTensorType>(operandType)) {
466  return tensorOperandType.getEncoding() == nullptr;
467  }
468  return false;
469  };
470  for (OpOperand &opOperand : op->getOpOperands()) {
471  auto indexingMap = op.getMatchingIndexingMap(&opOperand);
472  SmallVector<int64_t> shape = op.getStaticOperandShape(&opOperand);
473  if (!hasCollapsibleType(opOperand)) {
474  AffineMap newIndexingMap = indexingMap.replaceDimsAndSymbols(
475  dimReplacements, ArrayRef<AffineExpr>{}, oldDimToNewDimMap.size(), 0);
476  newIndexingMaps.push_back(newIndexingMap);
477  targetShapes.push_back(llvm::to_vector(shape));
478  collapsed.push_back(false);
479  reassociations.push_back({});
480  continue;
481  }
482  auto replacementInfo =
483  dropUnitExtentFromOperandMetadata(rewriter.getContext(), op, &opOperand,
484  oldDimToNewDimMap, dimReplacements);
485  reassociations.push_back(replacementInfo.reassociation);
486  newIndexingMaps.push_back(replacementInfo.indexMap);
487  targetShapes.push_back(replacementInfo.targetShape);
488  collapsed.push_back(!(replacementInfo.indexMap.getNumResults() ==
489  indexingMap.getNumResults()));
490  }
491 
492  // Abort if the indexing maps of the result operation are not invertible
493  // (i.e. not legal) or if no dimension was reduced.
494  if (newIndexingMaps == indexingMaps ||
496  concatAffineMaps(newIndexingMaps, rewriter.getContext())))
497  return failure();
498 
499  Location loc = op.getLoc();
500  // 4. For each of the operands, collapse the operand to convert
501  // from original shape to shape in the modified operation if needed,
502  // either through use of reshapes or rank-reducing slices as
503  // specified in `options`.
504  SmallVector<Value> newOperands;
505  for (OpOperand &opOperand : op->getOpOperands()) {
506  int64_t idx = opOperand.getOperandNumber();
507  if (!collapsed[idx]) {
508  newOperands.push_back(opOperand.get());
509  continue;
510  }
511  newOperands.push_back(collapseValue(rewriter, loc, opOperand.get(),
512  targetShapes[idx], reassociations[idx],
513  options.rankReductionStrategy));
514  }
515 
516  IndexingMapOpInterface replacementOp = droppedUnitDimsBuilder(
517  loc, rewriter, op, newOperands, newIndexingMaps, unitDims);
518 
519  // 6. If any result type changes, insert a reshape/slice to convert from the
520  // original type to the new type.
521  SmallVector<Value> resultReplacements;
522  for (auto [index, result] : llvm::enumerate(replacementOp->getResults())) {
523  unsigned opOperandIndex = index + dpsOp.getNumDpsInputs();
524  Value origDest = dpsOp.getDpsInitOperand(index)->get();
525  if (!collapsed[opOperandIndex]) {
526  resultReplacements.push_back(result);
527  continue;
528  }
529  Value expandedValue = expandValue(rewriter, loc, result, origDest,
530  reassociations[opOperandIndex],
531  options.rankReductionStrategy);
532  resultReplacements.push_back(expandedValue);
533  }
534 
535  return DropUnitDimsResult{replacementOp, resultReplacements};
536 }
537 
538 FailureOr<DropUnitDimsResult>
539 linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
540  const ControlDropUnitDims &options) {
541 
542  DroppedUnitDimsBuilder build =
543  [](Location loc, OpBuilder &b, IndexingMapOpInterface op,
544  ArrayRef<Value> newOperands, ArrayRef<AffineMap> newIndexingMaps,
545  const llvm::SmallDenseSet<unsigned> &droppedDims)
546  -> IndexingMapOpInterface {
547  auto genericOp = cast<GenericOp>(op);
548  // Compute the iterator types of the modified op by dropping the one-trip
549  // count loops.
550  SmallVector<utils::IteratorType> newIteratorTypes;
551  for (auto [index, attr] :
552  llvm::enumerate(genericOp.getIteratorTypesArray())) {
553  if (!droppedDims.count(index))
554  newIteratorTypes.push_back(attr);
555  }
556 
557  // Create the `linalg.generic` operation with the new operands,
558  // indexing maps, iterator types and result types.
559  ArrayRef<Value> newInputs =
560  ArrayRef<Value>(newOperands).take_front(genericOp.getNumDpsInputs());
561  ArrayRef<Value> newOutputs =
562  ArrayRef<Value>(newOperands).take_back(genericOp.getNumDpsInits());
563  SmallVector<Type> resultTypes;
564  resultTypes.reserve(genericOp.getNumResults());
565  for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
566  resultTypes.push_back(newOutputs[i].getType());
567  GenericOp replacementOp =
568  GenericOp::create(b, loc, resultTypes, newInputs, newOutputs,
569  newIndexingMaps, newIteratorTypes);
570  b.cloneRegionBefore(genericOp.getRegion(), replacementOp.getRegion(),
571  replacementOp.getRegion().begin());
572  // 5a. Replace `linalg.index` operations that refer to the dropped unit
573  // dimensions.
574  IRRewriter rewriter(b);
575  replaceUnitDimIndexOps(replacementOp, droppedDims, rewriter);
576 
577  return replacementOp;
578  };
579 
580  return dropUnitDims(rewriter, genericOp, build, options);
581 }
582 
583 namespace {
584 struct DropUnitDims : public OpRewritePattern<GenericOp> {
585  DropUnitDims(MLIRContext *context, ControlDropUnitDims options = {},
586  PatternBenefit benefit = 1)
587  : OpRewritePattern(context, benefit), options(std::move(options)) {}
588 
589  LogicalResult matchAndRewrite(GenericOp genericOp,
590  PatternRewriter &rewriter) const override {
591  FailureOr<DropUnitDimsResult> result =
592  dropUnitDims(rewriter, genericOp, options);
593  if (failed(result)) {
594  return failure();
595  }
596  rewriter.replaceOp(genericOp, result->replacements);
597  return success();
598  }
599 
600 private:
602 };
603 } // namespace
604 
605 //===---------------------------------------------------------------------===//
606 // Drop dimensions that are unit-extents within tensor operations.
607 //===---------------------------------------------------------------------===//
608 
609 namespace {
610 struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
611  DropPadUnitDims(MLIRContext *context, ControlDropUnitDims options = {},
612  PatternBenefit benefit = 1)
613  : OpRewritePattern(context, benefit), options(std::move(options)) {}
614 
615  LogicalResult matchAndRewrite(tensor::PadOp padOp,
616  PatternRewriter &rewriter) const override {
617  // 1a. Get the allowed list of dimensions to drop from the `options`.
618  SmallVector<unsigned> allowedUnitDims = options.controlFn(padOp);
619  if (allowedUnitDims.empty()) {
620  return rewriter.notifyMatchFailure(
621  padOp, "control function returns no allowed unit dims to prune");
622  }
623 
624  if (padOp.getSourceType().getEncoding()) {
625  return rewriter.notifyMatchFailure(
626  padOp, "cannot collapse dims of tensor with encoding");
627  }
628 
629  // Fail for non-constant padding values. The body of the pad could
630  // depend on the padding indices and/or properties of the padded
631  // tensor so for now we fail.
632  // TODO: Support non-constant padding values.
633  Value paddingVal = padOp.getConstantPaddingValue();
634  if (!paddingVal) {
635  return rewriter.notifyMatchFailure(
636  padOp, "unimplemented: non-constant padding value");
637  }
638 
639  ArrayRef<int64_t> sourceShape = padOp.getSourceType().getShape();
640  ArrayRef<int64_t> resultShape = padOp.getResultType().getShape();
641  int64_t padRank = sourceShape.size();
642 
643  auto isStaticZero = [](OpFoldResult f) {
644  return getConstantIntValue(f) == 0;
645  };
646 
647  llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
648  allowedUnitDims.end());
649  llvm::SmallDenseSet<unsigned> unitDims;
650  SmallVector<int64_t> newShape;
651  SmallVector<int64_t> newResultShape;
652  SmallVector<OpFoldResult> newLowPad;
653  SmallVector<OpFoldResult> newHighPad;
654  for (const auto [dim, size, outSize, low, high] : zip_equal(
655  llvm::seq(static_cast<int64_t>(0), padRank), sourceShape,
656  resultShape, padOp.getMixedLowPad(), padOp.getMixedHighPad())) {
657  if (unitDimsFilter.contains(dim) && size == 1 && isStaticZero(low) &&
658  isStaticZero(high)) {
659  unitDims.insert(dim);
660  } else {
661  newShape.push_back(size);
662  newResultShape.push_back(outSize);
663  newLowPad.push_back(low);
664  newHighPad.push_back(high);
665  }
666  }
667 
668  if (unitDims.empty()) {
669  return rewriter.notifyMatchFailure(padOp, "no unit dims to collapse");
670  }
671 
672  ReassociationIndices reassociationGroup;
673  SmallVector<ReassociationIndices> reassociationMap;
674  int64_t dim = 0;
675  while (dim < padRank && unitDims.contains(dim))
676  reassociationGroup.push_back(dim++);
677  while (dim < padRank) {
678  assert(!unitDims.contains(dim) && "expected non unit-extent");
679  reassociationGroup.push_back(dim);
680  dim++;
681  // Fold all following dimensions that are unit-extent.
682  while (dim < padRank && unitDims.contains(dim))
683  reassociationGroup.push_back(dim++);
684  reassociationMap.push_back(reassociationGroup);
685  reassociationGroup.clear();
686  }
687 
688  Value collapsedSource =
689  collapseValue(rewriter, padOp.getLoc(), padOp.getSource(), newShape,
690  reassociationMap, options.rankReductionStrategy);
691 
692  auto newResultType = RankedTensorType::get(
693  newResultShape, padOp.getResultType().getElementType());
694  auto newPadOp = tensor::PadOp::create(
695  rewriter, padOp.getLoc(), /*result=*/newResultType, collapsedSource,
696  newLowPad, newHighPad, paddingVal, padOp.getNofold());
697 
698  Value dest = padOp.getResult();
699  if (options.rankReductionStrategy ==
701  SmallVector<OpFoldResult> expandedSizes;
702  int64_t numUnitDims = 0;
703  for (auto dim : llvm::seq(static_cast<int64_t>(0), padRank)) {
704  if (unitDims.contains(dim)) {
705  expandedSizes.push_back(rewriter.getIndexAttr(1));
706  numUnitDims++;
707  continue;
708  }
709  expandedSizes.push_back(tensor::getMixedSize(
710  rewriter, padOp.getLoc(), newPadOp, dim - numUnitDims));
711  }
712  dest = tensor::EmptyOp::create(rewriter, padOp.getLoc(), expandedSizes,
713  padOp.getResultType().getElementType());
714  }
715 
716  Value expandedValue =
717  expandValue(rewriter, padOp.getLoc(), newPadOp.getResult(), dest,
718  reassociationMap, options.rankReductionStrategy);
719  rewriter.replaceOp(padOp, expandedValue);
720  return success();
721  }
722 
723 private:
725 };
726 } // namespace
727 
728 namespace {
729 /// Convert `extract_slice` operations to rank-reduced versions.
730 struct RankReducedExtractSliceOp
731  : public OpRewritePattern<tensor::ExtractSliceOp> {
733 
734  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
735  PatternRewriter &rewriter) const override {
736  RankedTensorType resultType = sliceOp.getType();
737  SmallVector<OpFoldResult> targetShape;
738  for (auto size : resultType.getShape())
739  targetShape.push_back(rewriter.getIndexAttr(size));
740  auto reassociation = getReassociationMapForFoldingUnitDims(targetShape);
741  if (!reassociation ||
742  reassociation->size() == static_cast<size_t>(resultType.getRank()))
743  return failure();
744 
745  SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
746  SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
747  SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
748  auto rankReducedType = cast<RankedTensorType>(
749  tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
750  reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
751  strides));
752 
753  Location loc = sliceOp.getLoc();
754  Value newSlice = tensor::ExtractSliceOp::create(
755  rewriter, loc, rankReducedType, sliceOp.getSource(), offsets, sizes,
756  strides);
757  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
758  sliceOp, resultType, newSlice, *reassociation);
759  return success();
760  }
761 };
762 
763 /// Convert `insert_slice` operations to rank-reduced versions.
764 /// This patterns works with both InsertSliceOp and ParallelInsertSliceOp.
765 template <typename InsertOpTy>
766 struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
768 
769  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
770  PatternRewriter &rewriter) const override {
771  RankedTensorType sourceType = insertSliceOp.getSourceType();
772  SmallVector<OpFoldResult> targetShape;
773  for (auto size : sourceType.getShape())
774  targetShape.push_back(rewriter.getIndexAttr(size));
775  auto reassociation = getReassociationMapForFoldingUnitDims(targetShape);
776  if (!reassociation ||
777  reassociation->size() == static_cast<size_t>(sourceType.getRank()))
778  return failure();
779 
780  Location loc = insertSliceOp.getLoc();
781  tensor::CollapseShapeOp reshapedSource;
782  {
783  OpBuilder::InsertionGuard g(rewriter);
784  // The only difference between InsertSliceOp and ParallelInsertSliceOp
785  // is the insertion point is just before the ParallelCombiningOp in the
786  // parallel case.
787  if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
788  rewriter.setInsertionPoint(insertSliceOp->getParentOp());
789  reshapedSource = tensor::CollapseShapeOp::create(
790  rewriter, loc, insertSliceOp.getSource(), *reassociation);
791  }
792  rewriter.replaceOpWithNewOp<InsertOpTy>(
793  insertSliceOp, reshapedSource, insertSliceOp.getDest(),
794  insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
795  insertSliceOp.getMixedStrides());
796  return success();
797  }
798 };
799 } // namespace
800 
801 /// Patterns that are used to canonicalize the use of unit-extent dims for
802 /// broadcasting.
803 static void
806  auto *context = patterns.getContext();
807  patterns.add<DropUnitDims>(context, options);
808  patterns.add<DropPadUnitDims>(context, options);
809  // TODO: Patterns unrelated to unit dim folding should be factored out.
810  patterns.add<RankReducedExtractSliceOp,
811  RankReducedInsertSliceOp<tensor::InsertSliceOp>,
812  RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
813  context);
814  linalg::FillOp::getCanonicalizationPatterns(patterns, context);
815  tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
816  tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
817  tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
821 }
822 
823 static void
826  auto *context = patterns.getContext();
827  patterns.add<DropUnitDims>(context, options);
828  patterns.add<DropPadUnitDims>(context, options);
829  // TODO: Patterns unrelated to unit dim folding should be factored out.
830  linalg::FillOp::getCanonicalizationPatterns(patterns, context);
831  tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
835 }
836 
839  if (options.rankReductionStrategy ==
842  } else if (options.rankReductionStrategy ==
844  ReassociativeReshape) {
846  }
847 }
848 
851  patterns.add<MoveInitOperandsToInput>(patterns.getContext());
852 }
853 
854 namespace {
855 /// Pass that removes unit-extent dims within generic ops.
856 struct LinalgFoldUnitExtentDimsPass
857  : public impl::LinalgFoldUnitExtentDimsPassBase<
858  LinalgFoldUnitExtentDimsPass> {
859  using impl::LinalgFoldUnitExtentDimsPassBase<
860  LinalgFoldUnitExtentDimsPass>::LinalgFoldUnitExtentDimsPassBase;
861  void runOnOperation() override {
862  Operation *op = getOperation();
863  MLIRContext *context = op->getContext();
864  RewritePatternSet patterns(context);
866  if (useRankReducingSlices) {
867  options.rankReductionStrategy = linalg::ControlDropUnitDims::
869  }
872  (void)applyPatternsGreedily(op, std::move(patterns));
873  }
874 };
875 
876 } // namespace
877 
878 namespace {
879 
880 /// Returns reassociation indices for collapsing/expanding a
881 /// tensor of rank `rank` at position `pos`.
883 getReassociationForReshapeAtDim(int64_t rank, int64_t pos) {
884  SmallVector<ReassociationIndices> reassociation(rank - 1, {0, 1});
885  bool lastDim = pos == rank - 1;
886  if (rank > 2) {
887  for (int64_t i = 0; i < rank - 1; i++) {
888  if (i == pos || (lastDim && i == pos - 1))
889  reassociation[i] = ReassociationIndices{i, i + 1};
890  else if (i < pos)
891  reassociation[i] = ReassociationIndices{i};
892  else
893  reassociation[i] = ReassociationIndices{i + 1};
894  }
895  }
896  return reassociation;
897 }
898 
899 /// Returns a collapsed `val` where the collapsing occurs at dim `pos`.
900 /// If `pos < 0`, then don't collapse.
901 static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val,
902  int64_t pos) {
903  if (pos < 0)
904  return val;
905  auto valType = cast<ShapedType>(val.getType());
906  SmallVector<int64_t> collapsedShape(valType.getShape());
907  collapsedShape.erase(collapsedShape.begin() + pos);
908  return collapseValue(
909  rewriter, val.getLoc(), val, collapsedShape,
910  getReassociationForReshapeAtDim(valType.getRank(), pos),
912 }
913 
914 /// Base class for all rank reduction patterns for contraction ops
915 /// with unit dimensions. All patterns should convert one named op
916 /// to another named op. Intended to reduce only one iteration space dim
917 /// at a time.
918 /// Reducing multiple dims will happen with recusive application of
919 /// pattern rewrites.
920 template <typename FromOpTy, typename ToOpTy>
921 struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
923 
924  /// Collapse all collapsable operands.
926  collapseOperands(PatternRewriter &rewriter, ArrayRef<Value> operands,
927  ArrayRef<int64_t> operandCollapseDims) const {
928  assert(operandCollapseDims.size() == 3 && operands.size() == 3 &&
929  "expected 3 operands and dims");
930  return llvm::map_to_vector(
931  llvm::zip(operands, operandCollapseDims), [&](auto pair) {
932  return collapseSingletonDimAt(rewriter, std::get<0>(pair),
933  std::get<1>(pair));
934  });
935  }
936 
937  /// Expand result tensor.
938  Value expandResult(PatternRewriter &rewriter, Value result,
939  RankedTensorType expandedType, int64_t dim) const {
940  return tensor::ExpandShapeOp::create(
941  rewriter, result.getLoc(), expandedType, result,
942  getReassociationForReshapeAtDim(expandedType.getRank(), dim));
943  }
944 
945  LogicalResult matchAndRewrite(FromOpTy contractionOp,
946  PatternRewriter &rewriter) const override {
947  if (contractionOp.hasUserDefinedMaps()) {
948  return rewriter.notifyMatchFailure(
949  contractionOp, "ops with user-defined maps are not supported");
950  }
951 
952  auto loc = contractionOp.getLoc();
953  auto inputs = contractionOp.getDpsInputs();
954  auto inits = contractionOp.getDpsInits();
955  if (inputs.size() != 2 || inits.size() != 1)
956  return rewriter.notifyMatchFailure(contractionOp,
957  "expected 2 inputs and 1 init");
958  auto lhs = inputs[0];
959  auto rhs = inputs[1];
960  auto init = inits[0];
961  SmallVector<Value> operands{lhs, rhs, init};
962 
963  SmallVector<int64_t> operandUnitDims;
964  if (failed(getOperandUnitDims(contractionOp, operandUnitDims)))
965  return rewriter.notifyMatchFailure(contractionOp,
966  "no reducable dims found");
967 
968  SmallVector<Value> collapsedOperands =
969  collapseOperands(rewriter, operands, operandUnitDims);
970  Value collapsedLhs = collapsedOperands[0];
971  Value collapsedRhs = collapsedOperands[1];
972  Value collapsedInit = collapsedOperands[2];
973  SmallVector<Type, 1> collapsedResultTy;
974  if (isa<RankedTensorType>(collapsedInit.getType()))
975  collapsedResultTy.push_back(collapsedInit.getType());
976  auto collapsedOp = ToOpTy::create(rewriter, loc, collapsedResultTy,
977  ValueRange{collapsedLhs, collapsedRhs},
978  ValueRange{collapsedInit});
979  for (auto attr : contractionOp->getAttrs()) {
980  if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName ||
981  attr.getName() == "indexing_maps")
982  continue;
983  collapsedOp->setAttr(attr.getName(), attr.getValue());
984  }
985 
986  auto results = contractionOp.getResults();
987  assert(results.size() < 2 && "expected at most one result");
988  if (results.empty()) {
989  rewriter.replaceOp(contractionOp, collapsedOp);
990  } else {
991  rewriter.replaceOp(
992  contractionOp,
993  expandResult(rewriter, collapsedOp.getResultTensors()[0],
994  cast<RankedTensorType>(results[0].getType()),
995  operandUnitDims[2]));
996  }
997 
998  return success();
999  }
1000 
1001  /// Populate `operandUnitDims` with 3 indices indicating the unit dim
1002  /// for each operand that should be collapsed in this pattern. If an
1003  /// operand shouldn't be collapsed, the index should be negative.
1004  virtual LogicalResult
1005  getOperandUnitDims(LinalgOp op,
1006  SmallVectorImpl<int64_t> &operandUnitDims) const = 0;
1007 };
1008 
1009 /// Patterns for unbatching batched contraction ops
1010 template <typename FromOpTy, typename ToOpTy>
1011 struct RankReduceToUnBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
1012  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
1013 
1014  /// Look for unit batch dims to collapse.
1015  LogicalResult
1016  getOperandUnitDims(LinalgOp op,
1017  SmallVectorImpl<int64_t> &operandUnitDims) const override {
1018  FailureOr<ContractionDimensions> maybeContractionDims =
1020  if (failed(maybeContractionDims)) {
1021  LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
1022  return failure();
1023  }
1024  ContractionDimensions contractionDims = maybeContractionDims.value();
1025 
1026  if (contractionDims.batch.size() != 1)
1027  return failure();
1028  auto batchDim = contractionDims.batch[0];
1030  op.mapIterationSpaceDimToAllOperandDims(batchDim, bOperands);
1031  if (bOperands.size() != 3 || llvm::any_of(bOperands, [](auto pair) {
1032  return cast<ShapedType>(std::get<0>(pair).getType())
1033  .getShape()[std::get<1>(pair)] != 1;
1034  })) {
1035  LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
1036  return failure();
1037  }
1038 
1039  operandUnitDims = SmallVector<int64_t>{std::get<1>(bOperands[0]),
1040  std::get<1>(bOperands[1]),
1041  std::get<1>(bOperands[2])};
1042  return success();
1043  }
1044 };
1045 
1046 /// Patterns for reducing non-batch dimensions
1047 template <typename FromOpTy, typename ToOpTy>
1048 struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
1049  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
1050 
1051  /// Helper for determining whether the lhs/init or rhs/init are reduced.
1052  static bool constexpr reduceLeft =
1053  (std::is_same_v<FromOpTy, BatchMatmulOp> &&
1054  std::is_same_v<ToOpTy, BatchVecmatOp>) ||
1055  (std::is_same_v<FromOpTy, MatmulOp> &&
1056  std::is_same_v<ToOpTy, VecmatOp>) ||
1057  (std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);
1058 
1059  /// Look for non-batch spatial dims to collapse.
1060  LogicalResult
1061  getOperandUnitDims(LinalgOp op,
1062  SmallVectorImpl<int64_t> &operandUnitDims) const override {
1063  FailureOr<ContractionDimensions> maybeContractionDims =
1065  if (failed(maybeContractionDims)) {
1066  LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
1067  return failure();
1068  }
1069  ContractionDimensions contractionDims = maybeContractionDims.value();
1070 
1071  if constexpr (reduceLeft) {
1072  auto m = contractionDims.m[0];
1074  op.mapIterationSpaceDimToAllOperandDims(m, mOperands);
1075  if (mOperands.size() != 2)
1076  return failure();
1077  if (llvm::all_of(mOperands, [](auto pair) {
1078  return cast<ShapedType>(std::get<0>(pair).getType())
1079  .getShape()[std::get<1>(pair)] == 1;
1080  })) {
1081  operandUnitDims = SmallVector<int64_t>{std::get<1>(mOperands[0]), -1,
1082  std::get<1>(mOperands[1])};
1083  return success();
1084  }
1085  } else {
1086  auto n = contractionDims.n[0];
1088  op.mapIterationSpaceDimToAllOperandDims(n, nOperands);
1089  if (nOperands.size() != 2)
1090  return failure();
1091  if (llvm::all_of(nOperands, [](auto pair) {
1092  return cast<ShapedType>(std::get<0>(pair).getType())
1093  .getShape()[std::get<1>(pair)] == 1;
1094  })) {
1095  operandUnitDims = SmallVector<int64_t>{-1, std::get<1>(nOperands[0]),
1096  std::get<1>(nOperands[1])};
1097  return success();
1098  }
1099  }
1100  LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
1101  return failure();
1102  }
1103 };
1104 
1105 } // namespace
1106 
1109  MLIRContext *context = patterns.getContext();
1110  // Unbatching patterns for unit batch size
1111  patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
1112  patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
1113  patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);
1114 
1115  // Non-batch rank 1 reducing patterns
1116  patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
1117  patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
1118  // Batch rank 1 reducing patterns
1119  patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
1120  patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);
1121 
1122  // Non-batch rank 0 reducing patterns
1123  patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
1124  patterns.add<RankReduceMatmul<VecmatOp, DotOp>>(context);
1125 }
static Value expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest, ArrayRef< ReassociationIndices > reassociation, ControlDropUnitDims::RankReductionStrategy rankReductionStrategy)
Expand the given value so that the type matches the type of origDest.
static void replaceUnitDimIndexOps(GenericOp genericOp, const llvm::SmallDenseSet< unsigned > &unitDims, RewriterBase &rewriter)
Implements a pass that canonicalizes the uses of unit-extent dimensions for broadcasting.
static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(MLIRContext *context, IndexingMapOpInterface op, OpOperand *opOperand, llvm::SmallDenseMap< unsigned, unsigned > &oldDimsToNewDimsMap, ArrayRef< AffineExpr > dimReplacements)
static void populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns that are used to canonicalize the use of unit-extent dims for broadcasting.
static void populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
static Value collapseValue(RewriterBase &rewriter, Location loc, Value operand, ArrayRef< int64_t > targetShape, ArrayRef< ReassociationIndices > reassociation, ControlDropUnitDims::RankReductionStrategy rankReductionStrategy)
Collapse the given value so that the type matches the type of origOutput.
static llvm::ManagedStatic< PassManagerOptions > options
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:223
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
Definition: AffineMap.cpp:394
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:403
AffineMap replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements, unsigned numResultDims, unsigned numResultSyms) const
This method substitutes any uses of dimensions and symbols (e.g.
Definition: AffineMap.cpp:496
This class represents an argument of a Block.
Definition: Value.h:309
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:153
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:103
MLIRContext * getContext() const
Definition: Builders.h:55
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:764
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:425
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:548
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:517
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition: Builders.h:419
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
This class represents an operand of an operation.
Definition: Value.h:257
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void populateMoveInitOperandsToInputPattern(RewritePatternSet &patterns)
A pattern that converts init operands to input operands.
void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns)
Adds patterns that reduce the rank of named contraction ops that have unit dimensions in the operand(...
std::function< IndexingMapOpInterface(Location loc, OpBuilder &, IndexingMapOpInterface, ArrayRef< Value > newOperands, ArrayRef< AffineMap > newIndexingMaps, const llvm::SmallDenseSet< unsigned > &droppedDims)> DroppedUnitDimsBuilder
Definition: Transforms.h:548
SmallVector< NamedAttribute > getPrunedAttributeList(OpTy op)
Returns an attribute list that excludes pre-defined attributes.
Definition: Utils.h:385
std::optional< SmallVector< ReassociationIndices > > getReassociationMapForFoldingUnitDims(ArrayRef< OpFoldResult > mixedSizes)
Get the reassociation maps to fold the result of a extract_slice (or source of a insert_slice) operat...
Definition: Utils.cpp:920
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors via reassociativ...
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
FailureOr< DropUnitDimsResult > dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op, const DroppedUnitDimsBuilder &droppedUnitDimsBuilder, const ControlDropUnitDims &options)
void populateResolveRankedShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns, bool foldSingleUseOnly=false)
Populates patterns with patterns that fold tensor.empty with its consumers.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition: TensorOps.cpp:61
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:70
Include the generated interface declarations.
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
Definition: AffineMap.cpp:829
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:784
const FrozenRewritePatternSet & patterns
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:643
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:619
Compute the modified metadata for an operands of operation whose unit dims are being dropped.
SmallVector< ReassociationIndices > reassociation
SmallVector< int64_t > targetShape
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
Positions of a Linalg op loops that correspond to different kinds of a contraction dimension.
SmallVector< unsigned, 2 > batch
SmallVector< unsigned, 2 > m
SmallVector< unsigned, 2 > n
Transformation to drop unit-extent dimensions from linalg.generic operations.
Definition: Transforms.h:522