//===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns/pass to remove usage of unit-extent dimensions
// to specify broadcasting in favor of a more canonical representation of the
// computation.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "linalg-drop-unit-dims"

using namespace mlir;
using namespace mlir::linalg;

/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
/// broadcasting. For example,
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (0, d1)>,
///   affine_map<(d0, d1) -> (d0, 0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
///     tensor<5x5xf32>
/// {
///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<1x5xf32>
///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<5x1xf32>
///   %2 = linalg.generic #trait %0, %1 {
///     ^bb0(%arg2: f32, %arg3: f32):
///       %3 = arith.addf %arg2, %arg3 : f32
///       linalg.yield %3 : f32
///   } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
///   return %2 : tensor<5x5xf32>
/// }
/// ```
///
/// would canonicalize to
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (d1)>,
///   affine_map<(d0, d1) -> (d0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
///     tensor<5x5xf32>
/// {
///   %0 = linalg.generic #trait %arg0, %arg1 {
///     ^bb0(%arg2: f32, %arg3: f32):
///       %3 = arith.addf %arg2, %arg3 : f32
///       linalg.yield %3 : f32
///   } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
///   return %0 : tensor<5x5xf32>
/// }
/// ```

/// Given dims of the iteration space of a structured op that are known to be
/// single trip count (`unitDims`), return the indexing maps to use in the
/// canonicalized op with these dims removed, given the original `indexingMaps`.
static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
                                 ArrayRef<AffineMap> indexingMaps,
                                 MLIRContext *context) {
  if (indexingMaps.empty())
    return nullptr;
  unsigned numIterationDims = indexingMaps.front().getNumDims();
  unsigned numSymbols = indexingMaps.front().getNumSymbols();

  // Compute the replacement for each dim expr.
  SmallVector<AffineExpr, 4> dimReplacements;
  dimReplacements.reserve(numIterationDims);
  unsigned numKeptDims = 0;
  for (unsigned dim : llvm::seq<unsigned>(0, numIterationDims)) {
    if (unitDims.count(dim))
      dimReplacements.push_back(getAffineConstantExpr(0, context));
    else
      dimReplacements.push_back(getAffineDimExpr(numKeptDims++, context));
  }

  // Symbols remain the same.
  SmallVector<AffineExpr, 4> symReplacements;
  symReplacements.reserve(numSymbols);
  for (unsigned symbol : llvm::seq<unsigned>(0, numSymbols))
    symReplacements.push_back(getAffineSymbolExpr(symbol, context));

  SmallVector<AffineMap, 4> newIndexingMaps;
  newIndexingMaps.reserve(indexingMaps.size());
  for (AffineMap operandMap : indexingMaps) {
    // Expected indexing maps to have no symbols.
    if (operandMap.getNumSymbols())
      return nullptr;
    newIndexingMaps.push_back(simplifyAffineMap(
        operandMap.replaceDimsAndSymbols(dimReplacements, symReplacements,
                                         numIterationDims - unitDims.size(),
                                         numSymbols)));
  }

  // Check that the new index maps are invertible. If not, something went
  // wrong, so abort.
  if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
    return nullptr;
  return ArrayAttr::get(context,
                        llvm::to_vector<4>(llvm::map_range(
                            newIndexingMaps, [](AffineMap map) -> Attribute {
                              return AffineMapAttr::get(map);
                            })));
}

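// Illustrative sketch (editorial note, not part of the original source): with
// unitDims = {1} and an original indexing map
//   affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
// replaceUnitDims rewrites d1 to the constant 0 and renumbers the remaining
// dims, producing
//   affine_map<(d0, d1) -> (0, d0, d1)>
// over the reduced two-dimensional iteration space.
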
/// Update the index accesses of linalg operations having index semantics.
static void replaceUnitDimIndexOps(GenericOp genericOp,
                                   const DenseSet<unsigned> &unitDims,
                                   PatternRewriter &rewriter) {
  for (IndexOp indexOp :
       llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(indexOp);
    if (unitDims.count(indexOp.getDim()) != 0) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(indexOp, 0);
    } else {
      // Update the dimension of the index operation if needed.
      unsigned droppedDims = llvm::count_if(
          unitDims, [&](unsigned dim) { return dim < indexOp.getDim(); });
      if (droppedDims != 0)
        rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
                                             indexOp.getDim() - droppedDims);
    }
  }
}

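// Illustrative sketch (editorial note, not part of the original source): with
// unitDims = {0, 2},
//   linalg.index 0  becomes  arith.constant 0 : index
//   linalg.index 1  becomes  linalg.index 0   (one smaller unit dim dropped)
//   linalg.index 3  becomes  linalg.index 1   (two smaller unit dims dropped)
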
namespace {
/// Pattern to fold unit-trip count loops in GenericOps.
struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> indexingMaps = genericOp.getIndexingMapsArray();
    if (indexingMaps.empty())
      return failure();

    // Check if any of the iteration dimensions are unit-trip count. They will
    // end up being unit-trip count if they are used to index into a unit-dim
    // tensor/memref.
    AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
    if (!invertedMap)
      return failure();
    SmallVector<int64_t> dims = genericOp.getStaticShape();

    DenseSet<unsigned> unitDims;
    SmallVector<unsigned, 4> unitDimsReductionLoops;
    ArrayAttr iteratorTypes = genericOp.getIteratorTypes();
    for (const auto &expr : enumerate(invertedMap.getResults())) {
      if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
        if (dims[dimExpr.getPosition()] == 1)
          unitDims.insert(expr.index());
    }

    if (unitDims.empty())
      return failure();

    // Compute the modified indexing maps.
    MLIRContext *context = rewriter.getContext();
    ArrayAttr newIndexingMapAttr =
        replaceUnitDims(unitDims, indexingMaps, context);
    if (!newIndexingMapAttr)
      return genericOp.emitError("unable to compute modified indexing_maps");

    // Compute the iterator types of the modified op by dropping the one-trip
    // count loops.
    SmallVector<Attribute, 4> newIteratorTypes;
    for (const auto &attr : llvm::enumerate(iteratorTypes)) {
      if (!unitDims.count(attr.index()))
        newIteratorTypes.push_back(attr.value());
    }

    rewriter.startRootUpdate(genericOp);
    genericOp.setIndexingMapsAttr(newIndexingMapAttr);
    genericOp.setIteratorTypesAttr(ArrayAttr::get(context, newIteratorTypes));
    replaceUnitDimIndexOps(genericOp, unitDims, rewriter);
    rewriter.finalizeRootUpdate(genericOp);
    return success();
  }
};

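// Illustrative sketch of FoldUnitDimLoops (hypothetical IR, not from the
// original source): for an elementwise generic over tensor<1x8xf32> with
//   indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, ...]
//   iterator_types = ["parallel", "parallel"]
// the d0 loop has a static trip count of 1, so the pattern updates the op
// in place to
//   indexing_maps = [affine_map<(d0) -> (0, d0)>, ...]
//   iterator_types = ["parallel"]
// leaving the operand and result types untouched; ReplaceUnitExtents below is
// what later drops the unit dimension from the types themselves.
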
struct UnitExtentReplacementInfo {
  Type type;
  AffineMap indexMap;
  ArrayAttr reassociation;
};
} // namespace

/// Utility function for replacing operands/results to a linalg generic
/// operation with unit-extent dimensions. These can be replaced with
/// an operand/result with the unit-extent dimension removed. This is only done
/// if the indexing map used to access that dimension has an AffineConstantExpr
/// of value 0. Given the `type` of a result/operand of a Linalg op, and its
/// `indexMap`, the utility function returns:
/// - the new type with dimensions of size 1 removed.
/// - modified index map that can be used to access the replaced result/operand
/// - the reassociation that converts from the original tensor type to the
///   modified tensor type.
static llvm::Optional<UnitExtentReplacementInfo>
replaceUnitExtents(GenericOp genericOp, OpOperand *opOperand,
                   MLIRContext *context) {
  AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
  ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
  ArrayRef<AffineExpr> exprs = indexingMap.getResults();
  SmallVector<AffineExpr> reassociations;
  SmallVector<Attribute> reassociationMaps;
  SmallVector<AffineExpr> newIndexExprs;
  SmallVector<int64_t> newShape;

  int64_t origRank = genericOp.getRank(opOperand);
  AffineExpr zeroExpr = getAffineConstantExpr(0, context);
  auto isUnitExtent = [&](int64_t dim) -> bool {
    return shape[dim] == 1 && exprs[dim] == zeroExpr;
  };

  // Early return for memrefs with non-identity layout maps; they are always
  // left unchanged.
  Type actualType = opOperand->get().getType();
  if (auto memref = actualType.dyn_cast<MemRefType>()) {
    if (!memref.getLayout().isIdentity())
      return llvm::None;
  }

  int64_t dim = 0;
  // Fold dimensions that are unit-extent at the beginning of the tensor.
  while (dim < origRank && isUnitExtent(dim))
    reassociations.push_back(getAffineDimExpr(dim++, context));
  while (dim < origRank) {
    reassociations.push_back(getAffineDimExpr(dim, context));
    newIndexExprs.push_back(exprs[dim]);
    newShape.push_back(shape[dim]);
    // Fold all following dimensions that are unit-extent.
    while (dim + 1 < origRank && isUnitExtent(dim + 1)) {
      ++dim;
      reassociations.push_back(getAffineDimExpr(dim, context));
    }
    reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get(
        origRank, /*symbolCount=*/0, reassociations, context)));
    reassociations.clear();
    ++dim;
  }

  // Compute the tensor or scalar replacement type.
  Type elementType = getElementTypeOrSelf(opOperand->get());
  Type replacementType;
  if (elementType == opOperand->get().getType()) {
    replacementType = elementType;
  } else if (actualType.isa<RankedTensorType>()) {
    replacementType = RankedTensorType::get(newShape, elementType);
  } else if (actualType.isa<MemRefType>()) {
    replacementType = MemRefType::get(newShape, elementType);
  }
  assert(replacementType && "unsupported shaped type");
  UnitExtentReplacementInfo info = {replacementType,
                                    AffineMap::get(indexingMap.getNumDims(),
                                                   indexingMap.getNumSymbols(),
                                                   newIndexExprs, context),
                                    ArrayAttr::get(context, reassociationMaps)};
  return info;
}

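// Illustrative sketch (hypothetical values, not from the original source):
// for an operand of type tensor<1x5x1xf32> accessed through
//   affine_map<(d0, d1) -> (0, d1, 0)>,
// all three operand dims fold into a single group, and the returned info is
//   type          = tensor<5xf32>
//   indexMap      = affine_map<(d0, d1) -> (d1)>
//   reassociation = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
// i.e. dims 0, 1 and 2 of the original operand collapse into one dimension.
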
namespace {

static SmallVector<ReassociationExprs, 2>
convertAffineMapArrayToExprs(ArrayAttr affineMapArrayAttr) {
  SmallVector<ReassociationExprs, 2> reassociationExprs;
  for (auto attr : affineMapArrayAttr)
    reassociationExprs.push_back(
        llvm::to_vector<4>(attr.cast<AffineMapAttr>().getValue().getResults()));
  return reassociationExprs;
}

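// Illustrative sketch (editorial note, not part of the original source): an
// attribute such as
//   [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
// is converted to the reassociation expression groups [[d0, d1], [d2]], the
// form expected by the expand_shape/collapse_shape builders used below.
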
/// Pattern to replace tensor/buffer operands/results that are unit extents.
struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  // Return the original value if the type is unchanged, or reshape it. Return
  // nullptr if this is an unsupported type.
  Value maybeExpand(Value result, Type origResultType,
                    ArrayAttr reassociationMap, Location loc,
                    PatternRewriter &rewriter) const {
    if (origResultType == result.getType())
      return result;
    if (origResultType.isa<RankedTensorType>()) {
      return rewriter.create<tensor::ExpandShapeOp>(
          loc, origResultType, result,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    if (origResultType.isa<MemRefType>()) {
      return rewriter.create<memref::ExpandShapeOp>(
          loc, origResultType, result,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    return nullptr;
  };

  // Return the original value if the type is unchanged, or reshape it. Return
  // nullptr if this is an unsupported type.
  Value maybeCollapse(Value operand, Type newInputOutputType,
                      ArrayAttr reassociationMap, Location loc,
                      PatternRewriter &rewriter) const {
    auto operandType = operand.getType();
    if (operandType == newInputOutputType)
      return operand;
    if (operandType.isa<MemRefType>()) {
      return rewriter.create<memref::CollapseShapeOp>(
          loc, newInputOutputType, operand,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    if (operandType.isa<RankedTensorType>()) {
      return rewriter.create<tensor::CollapseShapeOp>(
          loc, newInputOutputType, operand,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    return nullptr;
  };

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Skip the pattern if the op has any tensor with special encoding.
    if (llvm::any_of(genericOp->getOperandTypes(), [](Type type) {
          auto tensorType = type.dyn_cast<RankedTensorType>();
          return tensorType && tensorType.getEncoding() != nullptr;
        }))
      return failure();
    MLIRContext *context = rewriter.getContext();
    Location loc = genericOp.getLoc();

    SmallVector<AffineMap> newIndexingMaps;
    SmallVector<ArrayAttr> reassociationMaps;
    SmallVector<Type> newInputOutputTypes;
    bool doCanonicalization = false;
    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
      auto replacementInfo = replaceUnitExtents(genericOp, opOperand, context);
      if (replacementInfo) {
        reassociationMaps.push_back(replacementInfo->reassociation);
        newIndexingMaps.push_back(replacementInfo->indexMap);
        newInputOutputTypes.push_back(replacementInfo->type);
        doCanonicalization |=
            replacementInfo->type != opOperand->get().getType();
      } else {
        // If replaceUnitExtents cannot handle this case, maintain the same
        // type, indexing map, and create a set of mappings representing an
        // identity matrix.
        newInputOutputTypes.push_back(opOperand->get().getType());
        newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
        int64_t origRank = genericOp.getRank(opOperand);
        auto maps = llvm::to_vector<8>(llvm::map_range(
            llvm::seq<int64_t>(0, origRank), [&](int64_t dim) -> Attribute {
              return AffineMapAttr::get(
                  AffineMap::get(origRank, /*symbolCount=*/0,
                                 getAffineDimExpr(dim, context), context));
            }));
        reassociationMaps.push_back(ArrayAttr::get(context, maps));
      }
    }

    // Abort if nothing would change, or if the indexing maps of the result
    // operation are not invertible (i.e. not legal).
    if (!doCanonicalization ||
        !inversePermutation(concatAffineMaps(newIndexingMaps)))
      return failure();

    // If any operand type changes, insert a reshape to convert from the
    // original type to the new type.
    // TODO: get rid of flattenedIdx which assumes operand order and contiguity.
    unsigned flattenedIdx = 0;
    auto insertReshapes = [&](ValueRange values) {
      SmallVector<Value, 4> res;
      res.reserve(values.size());
      for (auto operand : values) {
        auto reshapedValue =
            maybeCollapse(operand, newInputOutputTypes[flattenedIdx],
                          reassociationMaps[flattenedIdx], loc, rewriter);
        assert(reshapedValue &&
               "expected ranked MemRef or Tensor operand type");
        res.push_back(reshapedValue);
        ++flattenedIdx;
      }
      return res;
    };

    SmallVector<Value, 4> newInputs = insertReshapes(genericOp.getInputs());
    SmallVector<Value, 4> newOutputs = insertReshapes(genericOp.getOutputs());

    // If any result type changes, insert a reshape to convert from the
    // original type to the new type.
    SmallVector<Type, 4> resultTypes;
    resultTypes.reserve(genericOp.getNumResults());
    for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
      resultTypes.push_back(newInputOutputTypes[i + genericOp.getNumInputs()]);
    GenericOp replacementOp = rewriter.create<GenericOp>(
        loc, resultTypes, newInputs, newOutputs, newIndexingMaps,
        llvm::to_vector<4>(genericOp.getIteratorTypes()
                               .template getAsValueRange<StringAttr>()));
    rewriter.inlineRegionBefore(genericOp.getRegion(),
                                replacementOp.getRegion(),
                                replacementOp.getRegion().begin());

    // If any result tensor has a modified shape, then add a reshape to recover
    // the original shape.
    SmallVector<Value, 4> resultReplacements;
    for (const auto &result : llvm::enumerate(replacementOp.getResults())) {
      unsigned index = result.index() + replacementOp.getNumInputs();
      auto origResultType = genericOp.getResult(result.index()).getType();

      auto newResult = maybeExpand(result.value(), origResultType,
                                   reassociationMaps[index], loc, rewriter);
      assert(newResult &&
             "unexpected output type other than ranked MemRef or Tensor");
      resultReplacements.push_back(newResult);
    }
    rewriter.replaceOp(genericOp, resultReplacements);
    return success();
  }
};
} // namespace

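// Illustrative sketch of ReplaceUnitExtents (hypothetical IR, not from the
// original source): a generic op consuming tensor<1x5xf32> is rewritten to
// operate on the collapsed tensor<5xf32>, with the original shapes recovered
// around it:
//   %in  = tensor.collapse_shape %arg0 [[0, 1]]
//            : tensor<1x5xf32> into tensor<5xf32>
//   %out = linalg.generic ... ins(%in : tensor<5xf32>) ... -> tensor<5xf32>
//   %res = tensor.expand_shape %out [[0, 1]]
//            : tensor<5xf32> into tensor<1x5xf32>
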
namespace {
/// Convert `extract_slice` operations to rank-reduced versions.
struct RankReducedExtractSliceOp
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = sliceOp.getType();
    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(resultType.getRank()))
      return failure();
    auto rankReducedType =
        tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
            reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
            strides)
            .cast<RankedTensorType>();

    Location loc = sliceOp.getLoc();
    Value newSlice = rewriter.create<tensor::ExtractSliceOp>(
        loc, rankReducedType, sliceOp.getSource(), offsets, sizes, strides);
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        sliceOp, resultType, newSlice, *reassociation);
    return success();
  }
};

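// Illustrative sketch (hypothetical IR, not from the original source): a slice
// whose sizes contain unit extents, e.g.
//   %0 = tensor.extract_slice %src[0, 0, 0] [1, 16, 4] [1, 1, 1]
//          : tensor<8x16x4xf32> to tensor<1x16x4xf32>
// is rewritten to a rank-reduced slice plus a reshape back to the old type:
//   %1 = tensor.extract_slice %src[0, 0, 0] [1, 16, 4] [1, 1, 1]
//          : tensor<8x16x4xf32> to tensor<16x4xf32>
//   %0 = tensor.expand_shape %1 [[0, 1], [2]]
//          : tensor<16x4xf32> into tensor<1x16x4xf32>
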
/// Convert `insert_slice` operations to rank-reduced versions.
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
template <typename InsertOpTy>
struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType sourceType = insertSliceOp.getSourceType();
    SmallVector<OpFoldResult> offsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = insertSliceOp.getMixedStrides();
    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(sourceType.getRank()))
      return failure();
    Location loc = insertSliceOp.getLoc();
    tensor::CollapseShapeOp reshapedSource;
    {
      OpBuilder::InsertionGuard g(rewriter);
      // The only difference between InsertSliceOp and ParallelInsertSliceOp is
      // that the insertion point is just before the ParallelCombiningOp in the
      // parallel case.
      if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
      reshapedSource = rewriter.create<tensor::CollapseShapeOp>(
          loc, insertSliceOp.getSource(), *reassociation);
    }
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, reshapedSource, insertSliceOp.getDest(),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());
    return success();
  }
};
} // namespace

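// Illustrative sketch (hypothetical IR, not from the original source): an
// insert_slice whose source carries a unit dimension, e.g.
//   %r = tensor.insert_slice %src into %dst[0, 0, 0] [1, 16, 4] [1, 1, 1]
//          : tensor<1x16x4xf32> into tensor<8x16x4xf32>
// is rewritten so the source is collapsed first and the insertion is expressed
// in rank-reduced form:
//   %c = tensor.collapse_shape %src [[0, 1], [2]]
//          : tensor<1x16x4xf32> into tensor<16x4xf32>
//   %r = tensor.insert_slice %c into %dst[0, 0, 0] [1, 16, 4] [1, 1, 1]
//          : tensor<16x4xf32> into tensor<8x16x4xf32>
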
/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting.
void mlir::linalg::populateFoldUnitExtentDimsPatterns(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<FoldUnitDimLoops, ReplaceUnitExtents, RankReducedExtractSliceOp,
               RankReducedInsertSliceOp<tensor::InsertSliceOp>,
               RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
      context);
  linalg::FillOp::getCanonicalizationPatterns(patterns, context);
  linalg::InitTensorOp::getCanonicalizationPatterns(patterns, context);
  tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
  tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
}

namespace {
/// Pass that removes unit-extent dims within generic ops.
struct LinalgFoldUnitExtentDimsPass
    : public LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *context = op->getContext();
    RewritePatternSet patterns(context);
    if (foldOneTripLoopsOnly)
      patterns.add<FoldUnitDimLoops>(context);
    else
      populateFoldUnitExtentDimsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createLinalgFoldUnitExtentDimsPass() {
  return std::make_unique<LinalgFoldUnitExtentDimsPass>();
}
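
// Usage sketch (editorial note, not part of the original source): assuming the
// standard pass registration, the rewrite can be exercised from the command
// line with
//   mlir-opt --linalg-fold-unit-extent-dims input.mlir
// and, assuming the boolean option is registered as fold-one-trip-loops-only,
// restricted to the iteration-space folding with
//   mlir-opt --linalg-fold-unit-extent-dims="fold-one-trip-loops-only=true" input.mlir
// Programmatically, populateFoldUnitExtentDimsPatterns(patterns) adds the same
// pattern set to any RewritePatternSet.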