MLIR 22.0.0git
Generalization.cpp
Go to the documentation of this file.
1//===- Generalization.cpp - linalg named ops to generic ops --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the Linalg generalization pass. It converts named
10// Linalg ops to linalg.generic ops.
11//
12//===----------------------------------------------------------------------===//
13
15
18#include "mlir/IR/AffineMap.h"
19#include "mlir/IR/Builders.h"
22
23namespace mlir {
24#define GEN_PASS_DEF_LINALGGENERALIZENAMEDOPSPASS
25#include "mlir/Dialect/Linalg/Passes.h.inc"
26} // namespace mlir
27
28#define DEBUG_TYPE "linalg-generalization"
29
30using namespace mlir;
31using namespace mlir::linalg;
32
33static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp) {
34 // Bailout if `linalgOp` is already a generic.
35 if (isa<GenericOp>(linalgOp))
36 return failure();
37 // Check if the operation has exactly one region.
38 if (linalgOp->getNumRegions() != 1) {
39 assert(linalgOp->getNumRegions() == 0 && "op with multiple regions");
40 // TOD: Otherwise it needs to be built explicitly from the region builder.
41 return failure();
42 }
43 return success();
44}
45
46FailureOr<GenericOp> mlir::linalg::generalizeNamedOp(RewriterBase &rewriter,
47 LinalgOp linalgOp) {
48 if (failed(generalizeNamedOpPrecondition(linalgOp)))
49 return rewriter.notifyMatchFailure(linalgOp, "preconditions not met");
50
51 SmallVector<Value> inputs = linalgOp.getDpsInputs();
52 ValueRange outputs = linalgOp.getDpsInits();
53 SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
54 SmallVector<utils::IteratorType> iterators = linalgOp.getIteratorTypesArray();
55 SmallVector<Type> resultTypes = linalgOp.hasPureTensorSemantics()
56 ? TypeRange(ValueRange(outputs))
57 : TypeRange{};
58
59 // All named ops have a region attached that can be inlined.
60 assert(linalgOp->getNumRegions() == 1 &&
61 "expect named op to have one region attached");
62 GenericOp genericOp =
63 GenericOp::create(rewriter, linalgOp.getLoc(), resultTypes, inputs,
64 outputs, indexingMaps, iterators);
65 rewriter.inlineRegionBefore(linalgOp->getRegion(0), genericOp.getRegion(),
66 genericOp.getRegion().begin());
67 rewriter.replaceOp(linalgOp, genericOp->getResults());
68 return genericOp;
69}
70
71namespace {
72
73struct LinalgGeneralizeNamedOpsPass
75 LinalgGeneralizeNamedOpsPass> {
77 LinalgGeneralizeNamedOpsPass>::LinalgGeneralizeNamedOpsPassBase;
78 void runOnOperation() override;
79};
80
81} // namespace
82
83void LinalgGeneralizeNamedOpsPass::runOnOperation() {
84 RewritePatternSet patterns(&getContext());
86 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
87}
88
return success()
static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp)
b getContext())
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)
Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.
void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns)
Linalg generalization patterns.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
Linalg generalization pattern.