//===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns/a pass to remove the use of unit-extent
// dimensions that specify broadcasting, in favor of a more canonical
// representation of the computation.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "linalg-drop-unit-dims"

using namespace mlir;
using namespace mlir::linalg;

/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
/// broadcasting. For example,
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (0, d1)>,
///   affine_map<(d0, d1) -> (d0, 0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
///     tensor<5x5xf32> {
///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<1x5xf32>
///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<5x1xf32>
///   %2 = linalg.generic #trait %0, %1 {
///   ^bb0(%arg2: f32, %arg3: f32):
///     %3 = arith.addf %arg2, %arg3 : f32
///     linalg.yield %3 : f32
///   } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
///   return %2 : tensor<5x5xf32>
/// }
/// ```
///
/// would canonicalize to
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (d1)>,
///   affine_map<(d0, d1) -> (d0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
///     tensor<5x5xf32> {
///   %0 = linalg.generic #trait %arg0, %arg1 {
///   ^bb0(%arg2: f32, %arg3: f32):
///     %3 = arith.addf %arg2, %arg3 : f32
///     linalg.yield %3 : f32
///   } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
///   return %0 : tensor<5x5xf32>
/// }
/// ```

/// Given dims of the iteration space of a structured op that are known to be
/// single trip count (`unitDims`), return the indexing maps to use in the
/// canonicalized op with these dims removed, given the original `indexingMaps`.
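///
/// As an illustrative sketch: with `unitDims = {1}`, an indexing map
///   `affine_map<(d0, d1, d2) -> (d2, d1, d0)>`
/// becomes
///   `affine_map<(d0, d1) -> (d1, 0, d0)>`,
/// i.e. the unit dimension is replaced by the constant 0 and the remaining
/// dimensions are renumbered.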
static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
                                 ArrayRef<AffineMap> indexingMaps,
                                 MLIRContext *context) {
  if (indexingMaps.empty())
    return nullptr;
  unsigned numIterationDims = indexingMaps.front().getNumDims();
  unsigned numSymbols = indexingMaps.front().getNumSymbols();

  // Compute the replacement for each dim expr.
  SmallVector<AffineExpr, 4> dimReplacements;
  dimReplacements.reserve(numIterationDims);
  unsigned numKeptDims = 0;
  for (unsigned dim : llvm::seq<unsigned>(0, numIterationDims)) {
    if (unitDims.count(dim))
      dimReplacements.push_back(getAffineConstantExpr(0, context));
    else
      dimReplacements.push_back(getAffineDimExpr(numKeptDims++, context));
  }

  // Symbols remain the same.
  SmallVector<AffineExpr, 4> symReplacements;
  symReplacements.reserve(numSymbols);
  for (unsigned symbol : llvm::seq<unsigned>(0, numSymbols))
    symReplacements.push_back(getAffineSymbolExpr(symbol, context));

  SmallVector<AffineMap, 4> newIndexingMaps;
  newIndexingMaps.reserve(indexingMaps.size());
  for (AffineMap operandMap : indexingMaps) {
    // Expected indexing maps to have no symbols.
    if (operandMap.getNumSymbols())
      return nullptr;
    newIndexingMaps.push_back(simplifyAffineMap(
        operandMap.replaceDimsAndSymbols(dimReplacements, symReplacements,
                                         numIterationDims - unitDims.size(),
                                         numSymbols)));
  }

  // Check that the new index maps are invertible. If not, something went
  // wrong, so abort.
  if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
    return nullptr;
  return ArrayAttr::get(context,
                        llvm::to_vector<4>(llvm::map_range(
                            newIndexingMaps, [](AffineMap map) -> Attribute {
                              return AffineMapAttr::get(map);
                            })));
}

/// Update the index accesses of linalg operations having index semantics.
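/// For example (illustrative): if dimension 1 of a 3-d op is a unit dimension,
/// `linalg.index 1` is replaced by `arith.constant 0 : index`, `linalg.index 2`
/// is renumbered to `linalg.index 1` to account for the dropped dimension, and
/// `linalg.index 0` is left unchanged.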
static void replaceUnitDimIndexOps(GenericOp genericOp,
                                   const DenseSet<unsigned> &unitDims,
                                   PatternRewriter &rewriter) {
  for (IndexOp indexOp :
       llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(indexOp);
    if (unitDims.count(indexOp.dim()) != 0) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(indexOp, 0);
    } else {
      // Update the dimension of the index operation if needed.
      unsigned droppedDims = llvm::count_if(
          unitDims, [&](unsigned dim) { return dim < indexOp.dim(); });
      if (droppedDims != 0)
        rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
                                             indexOp.dim() - droppedDims);
    }
  }
}

namespace {
/// Pattern to fold unit-trip count loops in GenericOps.
struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> indexingMaps = genericOp.getIndexingMaps();
    if (indexingMaps.empty())
      return failure();

    // Check if any of the iteration dimensions are unit-trip count. They will
    // end up being unit-trip count if they are used to index into a unit-dim
    // tensor/memref.
    AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
    if (!invertedMap)
      return failure();
    SmallVector<int64_t> dims = genericOp.getStaticShape();

    DenseSet<unsigned> unitDims;
    SmallVector<unsigned, 4> unitDimsReductionLoops;
    ArrayAttr iteratorTypes = genericOp.iterator_types();
    for (const auto &expr : enumerate(invertedMap.getResults())) {
      if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
        if (dims[dimExpr.getPosition()] == 1)
          unitDims.insert(expr.index());
    }

    if (unitDims.empty())
      return failure();

    // Compute the modified indexing maps.
    MLIRContext *context = rewriter.getContext();
    ArrayAttr newIndexingMapAttr =
        replaceUnitDims(unitDims, indexingMaps, context);
    if (!newIndexingMapAttr)
      return genericOp.emitError("unable to compute modified indexing_maps");

    // Compute the iterator types of the modified op by dropping the one-trip
    // count loops.
    SmallVector<Attribute, 4> newIteratorTypes;
    for (const auto &attr : llvm::enumerate(iteratorTypes)) {
      if (!unitDims.count(attr.index()))
        newIteratorTypes.push_back(attr.value());
    }

    rewriter.startRootUpdate(genericOp);
    genericOp.indexing_mapsAttr(newIndexingMapAttr);
    genericOp.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes));
    replaceUnitDimIndexOps(genericOp, unitDims, rewriter);
    rewriter.finalizeRootUpdate(genericOp);
    return success();
  }
};

struct UnitExtentReplacementInfo {
  Type type;
  AffineMap indexMap;
  ArrayAttr reassociation;
};
} // namespace

/// Utility function for replacing operands/results to a linalg generic
/// operation with unit-extent dimensions. These can be replaced with
/// an operand/result with the unit-extent dimension removed. This is only done
/// if the indexing map used to access that dimension has an AffineConstantExpr
/// of value 0. Given the `type` of a result/operand of a Linalg op, and its
/// `indexMap`, the utility function returns:
/// - the new type with dimensions of size 1 removed.
/// - modified index map that can be used to access the replaced result/operand
/// - the reassociation that converts from the original tensor type to the
///   modified tensor type.
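///
/// For example (illustrative): for an operand of type `tensor<1x5xf32>`
/// accessed through `affine_map<(d0, d1) -> (0, d1)>`, this returns the type
/// `tensor<5xf32>`, the index map `affine_map<(d0, d1) -> (d1)>`, and the
/// reassociation `[affine_map<(d0, d1) -> (d0, d1)>]`, which collapses both
/// original dimensions into the single remaining one.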
static llvm::Optional<UnitExtentReplacementInfo>
replaceUnitExtents(GenericOp genericOp, OpOperand *opOperand,
                   MLIRContext *context) {
  AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
  ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
  ArrayRef<AffineExpr> exprs = indexingMap.getResults();
  SmallVector<AffineExpr> reassociations;
  SmallVector<Attribute> reassociationMaps;
  SmallVector<AffineExpr> newIndexExprs;
  SmallVector<int64_t> newShape;

  int64_t origRank = genericOp.getRank(opOperand);
  AffineExpr zeroExpr = getAffineConstantExpr(0, context);
  auto isUnitExtent = [&](int64_t dim) -> bool {
    return shape[dim] == 1 && exprs[dim] == zeroExpr;
  };

  // Early return for memrefs with non-identity layout maps: they are always
  // left unchanged.
  Type actualType = opOperand->get().getType();
  if (auto memref = actualType.dyn_cast<MemRefType>()) {
    if (!memref.getLayout().isIdentity())
      return llvm::None;
  }

  int64_t dim = 0;
  // Fold dimensions that are unit-extent at the beginning of the tensor.
  while (dim < origRank && isUnitExtent(dim))
    reassociations.push_back(getAffineDimExpr(dim++, context));
  while (dim < origRank) {
    reassociations.push_back(getAffineDimExpr(dim, context));
    newIndexExprs.push_back(exprs[dim]);
    newShape.push_back(shape[dim]);
    // Fold all following dimensions that are unit-extent.
    while (dim + 1 < origRank && isUnitExtent(dim + 1)) {
      ++dim;
      reassociations.push_back(getAffineDimExpr(dim, context));
    }
    reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get(
        origRank, /*symbolCount=*/0, reassociations, context)));
    reassociations.clear();
    ++dim;
  }

  // Compute the tensor or scalar replacement type.
  Type elementType = getElementTypeOrSelf(opOperand->get());
  Type replacementType;
  if (elementType == opOperand->get().getType()) {
    replacementType = elementType;
  } else if (actualType.isa<RankedTensorType>()) {
    replacementType = RankedTensorType::get(newShape, elementType);
  } else if (actualType.isa<MemRefType>()) {
    replacementType = MemRefType::get(newShape, elementType);
  }
  assert(replacementType && "unsupported shaped type");
  UnitExtentReplacementInfo info = {replacementType,
                                    AffineMap::get(indexingMap.getNumDims(),
                                                   indexingMap.getNumSymbols(),
                                                   newIndexExprs, context),
                                    ArrayAttr::get(context, reassociationMaps)};
  return info;
}

namespace {

/// Convert reassociation maps expressed as an array of affine-map attributes
/// into a list of reassociation expressions.
static SmallVector<ReassociationExprs, 2>
convertAffineMapArrayToExprs(ArrayAttr affineMapArrayAttr) {
  SmallVector<ReassociationExprs, 2> reassociationExprs;
  for (auto attr : affineMapArrayAttr)
    reassociationExprs.push_back(
        llvm::to_vector<4>(attr.cast<AffineMapAttr>().getValue().getResults()));
  return reassociationExprs;
}

/// Pattern to replace tensor/buffer operands/results that are unit extents.
struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  // Return the original value if the type is unchanged, or reshape it. Return
  // nullptr if this is an unsupported type.
  Value maybeExpand(Value result, Type origResultType,
                    ArrayAttr reassociationMap, Location loc,
                    PatternRewriter &rewriter) const {
    if (origResultType == result.getType())
      return result;
    if (origResultType.isa<RankedTensorType>()) {
      return rewriter.create<tensor::ExpandShapeOp>(
          loc, origResultType, result,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    if (origResultType.isa<MemRefType>()) {
      return rewriter.create<memref::ExpandShapeOp>(
          loc, origResultType, result,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    return nullptr;
  }

  // Return the original value if the type is unchanged, or reshape it. Return
  // nullptr if this is an unsupported type.
  Value maybeCollapse(Value operand, Type newInputOutputType,
                      ArrayAttr reassociationMap, Location loc,
                      PatternRewriter &rewriter) const {
    auto operandType = operand.getType();
    if (operandType == newInputOutputType)
      return operand;
    if (operandType.isa<MemRefType>()) {
      return rewriter.create<memref::CollapseShapeOp>(
          loc, newInputOutputType, operand,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    if (operandType.isa<RankedTensorType>()) {
      return rewriter.create<tensor::CollapseShapeOp>(
          loc, newInputOutputType, operand,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    return nullptr;
  }

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Skip the pattern if the op has any tensor with special encoding.
    if (llvm::any_of(genericOp->getOperandTypes(), [](Type type) {
          auto tensorType = type.dyn_cast<RankedTensorType>();
          return tensorType && tensorType.getEncoding() != nullptr;
        }))
      return failure();
    MLIRContext *context = rewriter.getContext();
    Location loc = genericOp.getLoc();

    SmallVector<AffineMap> newIndexingMaps;
    SmallVector<ArrayAttr> reassociationMaps;
    SmallVector<Type> newInputOutputTypes;
    bool doCanonicalization = false;
    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
      auto replacementInfo = replaceUnitExtents(genericOp, opOperand, context);
      if (replacementInfo) {
        reassociationMaps.push_back(replacementInfo->reassociation);
        newIndexingMaps.push_back(replacementInfo->indexMap);
        newInputOutputTypes.push_back(replacementInfo->type);
        doCanonicalization |=
            replacementInfo->type != opOperand->get().getType();
      } else {
        // If replaceUnitExtents cannot handle this case, maintain the same
        // type, indexing map, and create a set of mappings representing an
        // identity matrix.
        newInputOutputTypes.push_back(opOperand->get().getType());
        newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
        int64_t origRank = genericOp.getRank(opOperand);
        auto maps = llvm::to_vector<8>(llvm::map_range(
            llvm::seq<int64_t>(0, origRank), [&](int64_t dim) -> Attribute {
              return AffineMapAttr::get(
                  AffineMap::get(origRank, /*symbolCount=*/0,
                                 getAffineDimExpr(dim, context), context));
            }));
        reassociationMaps.push_back(ArrayAttr::get(context, maps));
      }
    }

    // If the indexing maps of the result operation are not invertible (i.e.
    // not legal), abort.
    if (!doCanonicalization ||
        !inversePermutation(concatAffineMaps(newIndexingMaps)))
      return failure();

    // If any operand type changes, insert a reshape to convert from the
    // original type to the new type.
    // TODO: get rid of flattenedIdx which assumes operand order and contiguity.
    unsigned flattenedIdx = 0;
    auto insertReshapes = [&](ValueRange values) {
      SmallVector<Value, 4> res;
      res.reserve(values.size());
      for (auto operand : values) {
        auto reshapedValue =
            maybeCollapse(operand, newInputOutputTypes[flattenedIdx],
                          reassociationMaps[flattenedIdx], loc, rewriter);
        assert(reshapedValue &&
               "expected ranked MemRef or Tensor operand type");
        res.push_back(reshapedValue);
        ++flattenedIdx;
      }
      return res;
    };

    SmallVector<Value, 4> newInputs = insertReshapes(genericOp.inputs());
    SmallVector<Value, 4> newOutputs = insertReshapes(genericOp.outputs());

    // If any result type changes, insert a reshape to convert from the
    // original type to the new type.
    SmallVector<Type, 4> resultTypes;
    resultTypes.reserve(genericOp.getNumResults());
    for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
      resultTypes.push_back(newInputOutputTypes[i + genericOp.getNumInputs()]);
    GenericOp replacementOp = rewriter.create<GenericOp>(
        loc, resultTypes, newInputs, newOutputs, newIndexingMaps,
        llvm::to_vector<4>(
            genericOp.iterator_types().template getAsValueRange<StringAttr>()));
    rewriter.inlineRegionBefore(genericOp.region(), replacementOp.region(),
                                replacementOp.region().begin());

    // If any result tensor has a modified shape, then add a reshape to recover
    // the original shape.
    SmallVector<Value, 4> resultReplacements;
    for (const auto &result : llvm::enumerate(replacementOp.getResults())) {
      unsigned index = result.index() + replacementOp.getNumInputs();
      auto origResultType = genericOp.getResult(result.index()).getType();

      auto newResult = maybeExpand(result.value(), origResultType,
                                   reassociationMaps[index], loc, rewriter);
      assert(newResult &&
             "unexpected output type other than ranked MemRef or Tensor");
      resultReplacements.push_back(newResult);
    }
    rewriter.replaceOp(genericOp, resultReplacements);
    return success();
  }
};
} // namespace

/// Get the reassociation maps to fold the result of an extract_slice (or
/// source of an insert_slice) operation with the given offsets and sizes to
/// its rank-reduced version. This is only done for the cases where the size is
/// 1 and the offset is 0. Strictly speaking the offset 0 is not required in
/// general, but non-zero offsets are not handled by the SPIR-V backend at this
/// point (and potentially cannot be handled).
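///
/// For example (illustrative): for mixed sizes [1, 4, 1, 8] this returns the
/// reassociation [[0, 1], [2, 3]], i.e. unit dimensions are folded into the
/// nearest following non-unit dimension (trailing unit dimensions into the
/// last group). If all sizes are 1, the result is empty, which folds
/// everything into a rank-0 value.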
static llvm::Optional<SmallVector<ReassociationIndices>>
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<ReassociationIndices> reassociation;
  ReassociationIndices curr;
  for (const auto &it : llvm::enumerate(mixedSizes)) {
    auto dim = it.index();
    auto size = it.value();
    curr.push_back(dim);
    auto attr = size.dyn_cast<Attribute>();
    if (attr && attr.cast<IntegerAttr>().getInt() == 1)
      continue;
    reassociation.emplace_back(ReassociationIndices{});
    std::swap(reassociation.back(), curr);
  }
  // When the reassociations are not empty, fold the remaining unit dimensions
  // into the last dimension. If the reassociations so far are empty, leave
  // them empty; this folds everything to a rank-0 value.
  if (!curr.empty() && !reassociation.empty())
    reassociation.back().append(curr.begin(), curr.end());
  return reassociation;
}

namespace {
/// Convert `extract_slice` operations to rank-reduced versions.
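///
/// For example (illustrative): an `extract_slice` with sizes [1, 4, 1] that
/// produces a `tensor<1x4x1xf32>` is rewritten into a rank-reduced
/// `tensor.extract_slice` producing `tensor<4xf32>`, followed by a
/// `tensor.expand_shape` back to the original `tensor<1x4x1xf32>` type.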
struct UseRankReducedExtractSliceOp
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = sliceOp.getType();
    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(resultType.getRank()))
      return failure();
    auto rankReducedType = tensor::ExtractSliceOp::inferRankReducedResultType(
                               reassociation->size(), sliceOp.getSourceType(),
                               offsets, sizes, strides)
                               .cast<RankedTensorType>();

    Location loc = sliceOp.getLoc();
    Value newSlice = rewriter.create<tensor::ExtractSliceOp>(
        loc, rankReducedType, sliceOp.source(), offsets, sizes, strides);
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        sliceOp, resultType, newSlice, *reassociation);
    return success();
  }
};

/// Convert `insert_slice` operations to rank-reduced versions.
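///
/// For example (illustrative): inserting a `tensor<1x4xf32>` source with sizes
/// [1, 4] is rewritten so that the source is first collapsed to `tensor<4xf32>`
/// with `tensor.collapse_shape`, and then inserted with the original offsets,
/// sizes, and strides using the rank-reduced form of `tensor.insert_slice`.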
struct UseRankReducedInsertSliceOp
    : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType sourceType = insertOp.getSourceType();
    SmallVector<OpFoldResult> offsets = insertOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = insertOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = insertOp.getMixedStrides();
    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(sourceType.getRank()))
      return failure();
    Location loc = insertOp.getLoc();
    auto reshapedSource = rewriter.create<tensor::CollapseShapeOp>(
        loc, insertOp.source(), *reassociation);
    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        insertOp, reshapedSource, insertOp.dest(), insertOp.getMixedOffsets(),
        insertOp.getMixedSizes(), insertOp.getMixedStrides());
    return success();
  }
};
} // namespace

/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting.
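///
/// A minimal usage sketch (illustrative; `ctx` and `op` are placeholders for a
/// client's MLIRContext and root operation), e.g. from a custom pass or test:
///
///   RewritePatternSet patterns(ctx);
///   linalg::populateFoldUnitExtentDimsPatterns(patterns);
///   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));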
void mlir::linalg::populateFoldUnitExtentDimsPatterns(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<FoldUnitDimLoops, ReplaceUnitExtents,
               UseRankReducedExtractSliceOp, UseRankReducedInsertSliceOp>(
      context);
  linalg::FillOp::getCanonicalizationPatterns(patterns, context);
  linalg::InitTensorOp::getCanonicalizationPatterns(patterns, context);
  tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
  tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
}

namespace {
/// Pass that removes unit-extent dims within generic ops.
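/// Assuming the standard pass registration, this is exposed through mlir-opt
/// as `-linalg-fold-unit-extent-dims`; the `fold-one-trip-loops-only` option
/// restricts the rewrite to the FoldUnitDimLoops pattern alone.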
struct LinalgFoldUnitExtentDimsPass
    : public LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *context = op->getContext();
    RewritePatternSet patterns(context);
    if (foldOneTripLoopsOnly)
      patterns.add<FoldUnitDimLoops>(context);
    else
      populateFoldUnitExtentDimsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
  }
};
} // namespace

std::unique_ptr<Pass> mlir::linalg::createLinalgFoldUnitExtentDimsPass() {
  return std::make_unique<LinalgFoldUnitExtentDimsPass>();
}