MLIR 22.0.0git
LinalgToStandard.cpp
Go to the documentation of this file.
1//===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
18
19namespace mlir {
20#define GEN_PASS_DEF_CONVERTLINALGTOSTANDARDPASS
21#include "mlir/Conversion/Passes.h.inc"
22} // namespace mlir
23
24using namespace mlir;
25using namespace mlir::linalg;
26
27static MemRefType makeStridedLayoutDynamic(MemRefType type) {
28 return MemRefType::Builder(type).setLayout(StridedLayoutAttr::get(
29 type.getContext(), ShapedType::kDynamic,
30 SmallVector<int64_t>(type.getRank(), ShapedType::kDynamic)));
31}
32
33/// Helper function to extract the operand types that are passed to the
34/// generated CallOp. MemRefTypes have their layout canonicalized since the
35/// information is not used in signature generation.
36/// Note that static size information is not modified.
39 result.reserve(op->getNumOperands());
40 for (auto type : op->getOperandTypes()) {
41 // The underlying descriptor type (e.g. LLVM) does not have layout
42 // information. Canonicalizing the type at the level of std when going into
43 // a library call avoids needing to introduce DialectCastOp.
44 if (auto memrefType = dyn_cast<MemRefType>(type))
45 result.push_back(makeStridedLayoutDynamic(memrefType));
46 else
47 result.push_back(type);
48 }
49 return result;
50}
51
52// Get a SymbolRefAttr containing the library function name for the LinalgOp.
53// If the library function does not exist, insert a declaration.
54static FailureOr<FlatSymbolRefAttr>
56 auto linalgOp = cast<LinalgOp>(op);
57 auto fnName = linalgOp.getLibraryCallName();
58 if (fnName.empty())
59 return rewriter.notifyMatchFailure(op, "No library call defined for: ");
60
61 // fnName is a dynamic std::string, unique it via a SymbolRefAttr.
62 FlatSymbolRefAttr fnNameAttr =
63 SymbolRefAttr::get(rewriter.getContext(), fnName);
64 auto module = op->getParentOfType<ModuleOp>();
65 if (module.lookupSymbol(fnNameAttr.getAttr()))
66 return fnNameAttr;
67
69 if (op->getNumResults() != 0) {
70 return rewriter.notifyMatchFailure(
71 op,
72 "Library call for linalg operation can be generated only for ops that "
73 "have void return types");
74 }
75 auto libFnType = rewriter.getFunctionType(inputTypes, {});
76
77 OpBuilder::InsertionGuard guard(rewriter);
78 // Insert before module terminator.
79 rewriter.setInsertionPoint(module.getBody(),
80 std::prev(module.getBody()->end()));
81 func::FuncOp funcOp = func::FuncOp::create(rewriter, op->getLoc(),
82 fnNameAttr.getValue(), libFnType);
83 // Insert a function attribute that will trigger the emission of the
84 // corresponding `_mlir_ciface_xxx` interface so that external libraries see
85 // a normalized ABI. This interface is added during std to llvm conversion.
86 funcOp->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
87 UnitAttr::get(op->getContext()));
88 funcOp.setPrivate();
89 return fnNameAttr;
90}
91
94 ValueRange operands) {
96 res.reserve(operands.size());
97 for (auto op : operands) {
98 auto memrefType = dyn_cast<MemRefType>(op.getType());
99 if (!memrefType) {
100 res.push_back(op);
101 continue;
102 }
103 Value cast = memref::CastOp::create(
104 b, loc, makeStridedLayoutDynamic(memrefType), op);
105 res.push_back(cast);
106 }
107 return res;
108}
109
111 LinalgOp op, PatternRewriter &rewriter) const {
112 auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
113 if (failed(libraryCallName))
114 return failure();
115
116 // TODO: Add support for more complex library call signatures that include
117 // indices or captured values.
118 rewriter.replaceOpWithNewOp<func::CallOp>(
119 op, libraryCallName->getValue(), TypeRange(),
120 createTypeCanonicalizedMemRefOperands(rewriter, op->getLoc(),
121 op->getOperands()));
122 return success();
123}
124
125/// Populate the given list with patterns that convert from Linalg to Standard.
128 // TODO: ConvOp conversion needs to export a descriptor with relevant
129 // attribute values such as kernel striding and dilation.
131}
132
133namespace {
134struct ConvertLinalgToStandardPass
136 ConvertLinalgToStandardPass> {
137 void runOnOperation() override;
138};
139} // namespace
140
141void ConvertLinalgToStandardPass::runOnOperation() {
142 auto module = getOperation();
143 ConversionTarget target(getContext());
144 target.addLegalDialect<affine::AffineDialect, arith::ArithDialect,
145 func::FuncDialect, memref::MemRefDialect,
146 scf::SCFDialect>();
147 target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();
148 RewritePatternSet patterns(&getContext());
150 if (failed(applyFullConversion(module, target, std::move(patterns))))
151 signalPassFailure();
152}
return success()
static SmallVector< Value, 4 > createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc, ValueRange operands)
static SmallVector< Type, 4 > extractOperandTypes(Operation *op)
Helper function to extract the operand types that are passed to the generated CallOp.
static FailureOr< FlatSymbolRefAttr > getLibraryCallSymbolRef(Operation *op, PatternRewriter &rewriter)
static MemRefType makeStridedLayoutDynamic(MemRefType type)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition Builders.cpp:76
MLIRContext * getContext() const
Definition Builders.h:56
A symbol reference with a reference path containing a single element.
StringRef getValue() const
Returns the name of the held symbol reference.
StringAttr getAttr() const
Returns the name of the held symbol reference as a StringAttr.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This is a builder type that keeps local references to arguments.
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
unsigned getNumOperands()
Definition Operation.h:346
operand_type_range getOperandTypes()
Definition Operation.h:397
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override
void populateLinalgToStandardConversionPatterns(RewritePatternSet &patterns)
Populate the given list with patterns that convert from Linalg to Standard.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns