MLIR  16.0.0git
Generalization.cpp
Go to the documentation of this file.
1 //===- Generalization.cpp - linalg named ops to generic ops --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Linalg generalization pass. It converts named
10 // Linalg ops to linalg.generic ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
19 #include "mlir/IR/AffineMap.h"
20 #include "mlir/IR/Attributes.h"
21 #include "mlir/IR/Builders.h"
23 #include "mlir/IR/PatternMatch.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/Support/Debug.h"
27 
28 namespace mlir {
29 #define GEN_PASS_DEF_LINALGGENERALIZATION
30 #include "mlir/Dialect/Linalg/Passes.h.inc"
31 } // namespace mlir
32 
33 #define DEBUG_TYPE "linalg-generalization"
34 
35 using namespace mlir;
36 using namespace mlir::linalg;
37 
38 static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp) {
39  // Check if the operation is a LinalgOp but not a GenericOp.
40  if (isa<GenericOp>(linalgOp))
41  return failure();
42  // Check if the operation has a region builder.
43  if (!linalgOp.getRegionBuilder())
44  return failure();
45  return success();
46 }
47 
49  LinalgOp linalgOp) {
51  return rewriter.notifyMatchFailure(linalgOp, "preconditions not met");
52 
53  SmallVector<Value> inputOperands = linalgOp.getInputOperands();
54  SmallVector<Value> outputOperands = linalgOp.getOutputOperands();
55  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
56  SmallVector<StringRef> iterators = linalgOp.getIteratorTypesArray();
57  SmallVector<RankedTensorType> resultTypes = linalgOp.getOutputTensorTypes();
58  SmallVector<Type> types(resultTypes.begin(), resultTypes.end());
59 
60  // All named ops have a region attached that can be inlined.
61  assert(linalgOp->getNumRegions() == 1 &&
62  "expect named op to have one region attached");
63  GenericOp genericOp =
64  rewriter.create<GenericOp>(linalgOp.getLoc(), types, inputOperands,
65  outputOperands, indexingMaps, iterators);
66  rewriter.inlineRegionBefore(linalgOp->getRegion(0), genericOp.getRegion(),
67  genericOp.getRegion().begin());
68  rewriter.replaceOp(linalgOp, genericOp->getResults());
69  return genericOp;
70 }
71 
72 namespace {
73 
74 struct LinalgGeneralizationPass
75  : public impl::LinalgGeneralizationBase<LinalgGeneralizationPass> {
76  void runOnOperation() override;
77 };
78 
79 } // namespace
80 
81 void LinalgGeneralizationPass::runOnOperation() {
82  func::FuncOp func = getOperation();
83  RewritePatternSet patterns(&getContext());
85  (void)applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns));
86 }
87 
89  RewritePatternSet &patterns, const LinalgTransformationFilter &marker) {
90  patterns.add<LinalgGeneralizationPattern>(patterns.getContext(), marker);
91 }
92 
93 std::unique_ptr<OperationPass<func::FuncOp>>
95  return std::make_unique<LinalgGeneralizationPass>();
96 }
Include the generated interface declarations.
Helper class to control application of linalg transformation patterns.
Definition: Transforms.h:362
static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp)
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp namedOp)
Create a GenericOp from the given named operation namedOp and replace namedOp.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent"...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:418
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
This class provides support for representing a failure result, or a valid value of type T...
Definition: LogicalResult.h:78
std::unique_ptr< OperationPass< func::FuncOp > > createLinalgGeneralizationPass()
Create a pass to convert named Linalg operations to Linalg generic operations.
Linalg generalization pattern.
Definition: Transforms.h:855
void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns, const LinalgTransformationFilter &filter=LinalgTransformationFilter())
Linalg generalization patterns.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure...
Definition: PatternMatch.h:512
LogicalResult applyPatternsAndFoldGreedily(MutableArrayRef< Region > regions, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig())
Rewrite the regions of the specified operation, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy work-list driven manner.
MLIRContext * getContext() const
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:398