MLIR  19.0.0git
Generalization.cpp
Go to the documentation of this file.
1 //===- Generalization.cpp - linalg named ops to generic ops --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Linalg generalization pass. It converts named
10 // Linalg ops to linalg.generic ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
19 #include "mlir/IR/AffineMap.h"
20 #include "mlir/IR/Attributes.h"
21 #include "mlir/IR/Builders.h"
23 #include "mlir/IR/PatternMatch.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/Support/Debug.h"
27 
28 namespace mlir {
29 #define GEN_PASS_DEF_LINALGGENERALIZENAMEDOPSPASS
30 #include "mlir/Dialect/Linalg/Passes.h.inc"
31 } // namespace mlir
32 
33 #define DEBUG_TYPE "linalg-generalization"
34 
35 using namespace mlir;
36 using namespace mlir::linalg;
37 
38 static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp) {
39  // Bailout if `linalgOp` is already a generic or a linalg.map. We cannot
40  // trivially generalize a `linalg.map`, as it does not use the output as
41  // region arguments in the block.
42  if (isa<GenericOp>(linalgOp) || isa<MapOp>(linalgOp))
43  return failure();
44  // Check if the operation has exactly one region.
45  if (linalgOp->getNumRegions() != 1) {
46  assert(linalgOp->getNumRegions() == 0 && "op with multiple regions");
47  // TOD: Otherwise it needs to be built explicitly from the region builder.
48  return failure();
49  }
50  return success();
51 }
52 
54  LinalgOp linalgOp) {
56  return rewriter.notifyMatchFailure(linalgOp, "preconditions not met");
57 
58  SmallVector<Value> inputs = linalgOp.getDpsInputs();
59  ValueRange outputs = linalgOp.getDpsInits();
60  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
61  SmallVector<utils::IteratorType> iterators = linalgOp.getIteratorTypesArray();
62  SmallVector<Type> resultTypes = linalgOp.hasPureTensorSemantics()
63  ? TypeRange(ValueRange(outputs))
64  : TypeRange{};
65 
66  // All named ops have a region attached that can be inlined.
67  assert(linalgOp->getNumRegions() == 1 &&
68  "expect named op to have one region attached");
69  GenericOp genericOp = rewriter.create<GenericOp>(
70  linalgOp.getLoc(), resultTypes, inputs, outputs, indexingMaps, iterators);
71  rewriter.inlineRegionBefore(linalgOp->getRegion(0), genericOp.getRegion(),
72  genericOp.getRegion().begin());
73  rewriter.replaceOp(linalgOp, genericOp->getResults());
74  return genericOp;
75 }
76 
77 namespace {
78 
79 struct LinalgGeneralizeNamedOpsPass
80  : public impl::LinalgGeneralizeNamedOpsPassBase<
81  LinalgGeneralizeNamedOpsPass> {
82  using impl::LinalgGeneralizeNamedOpsPassBase<
83  LinalgGeneralizeNamedOpsPass>::LinalgGeneralizeNamedOpsPassBase;
84  void runOnOperation() override;
85 };
86 
87 } // namespace
88 
89 void LinalgGeneralizeNamedOpsPass::runOnOperation() {
90  RewritePatternSet patterns(&getContext());
92  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
93 }
94 
96  RewritePatternSet &patterns) {
97  patterns.add<LinalgGeneralizationPattern>(patterns.getContext());
98 }
static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp)
static MLIRContext * getContext(OpFoldResult val)
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
MLIRContext * getContext() const
Definition: PatternMatch.h:785
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:809
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Definition: PatternMatch.h:399
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure, and to provide a callback that populates a diagnostic with the reason the failure occurred.
Definition: PatternMatch.h:685
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp namedOp)
Create a GenericOp from the given named operation namedOp and replace namedOp.
void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns)
Linalg generalization patterns.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest-benefit patterns in a greedy, worklist-driven manner until a fixpoint is reached.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Linalg generalization pattern.
Definition: Transforms.h:1309