MLIR  19.0.0git
LinalgToStandard.cpp
Go to the documentation of this file.
1 //===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
18 #include "mlir/Pass/Pass.h"
19 
20 namespace mlir {
21 #define GEN_PASS_DEF_CONVERTLINALGTOSTANDARD
22 #include "mlir/Conversion/Passes.h.inc"
23 } // namespace mlir
24 
25 using namespace mlir;
26 using namespace mlir::linalg;
27 
28 static MemRefType makeStridedLayoutDynamic(MemRefType type) {
30  type.getContext(), ShapedType::kDynamic,
31  SmallVector<int64_t>(type.getRank(), ShapedType::kDynamic)));
32 }
33 
34 /// Helper function to extract the operand types that are passed to the
35 /// generated CallOp. MemRefTypes have their layout canonicalized since the
36 /// information is not used in signature generation.
37 /// Note that static size information is not modified.
39  SmallVector<Type, 4> result;
40  result.reserve(op->getNumOperands());
41  for (auto type : op->getOperandTypes()) {
42  // The underlying descriptor type (e.g. LLVM) does not have layout
43  // information. Canonicalizing the type at the level of std when going into
44  // a library call avoids needing to introduce DialectCastOp.
45  if (auto memrefType = dyn_cast<MemRefType>(type))
46  result.push_back(makeStridedLayoutDynamic(memrefType));
47  else
48  result.push_back(type);
49  }
50  return result;
51 }
52 
53 // Get a SymbolRefAttr containing the library function name for the LinalgOp.
54 // If the library function does not exist, insert a declaration.
57  auto linalgOp = cast<LinalgOp>(op);
58  auto fnName = linalgOp.getLibraryCallName();
59  if (fnName.empty())
60  return rewriter.notifyMatchFailure(op, "No library call defined for: ");
61 
62  // fnName is a dynamic std::string, unique it via a SymbolRefAttr.
63  FlatSymbolRefAttr fnNameAttr =
64  SymbolRefAttr::get(rewriter.getContext(), fnName);
65  auto module = op->getParentOfType<ModuleOp>();
66  if (module.lookupSymbol(fnNameAttr.getAttr()))
67  return fnNameAttr;
68 
70  if (op->getNumResults() != 0) {
71  return rewriter.notifyMatchFailure(
72  op,
73  "Library call for linalg operation can be generated only for ops that "
74  "have void return types");
75  }
76  auto libFnType = rewriter.getFunctionType(inputTypes, {});
77 
78  OpBuilder::InsertionGuard guard(rewriter);
79  // Insert before module terminator.
80  rewriter.setInsertionPoint(module.getBody(),
81  std::prev(module.getBody()->end()));
82  func::FuncOp funcOp = rewriter.create<func::FuncOp>(
83  op->getLoc(), fnNameAttr.getValue(), libFnType);
84  // Insert a function attribute that will trigger the emission of the
85  // corresponding `_mlir_ciface_xxx` interface so that external libraries see
86  // a normalized ABI. This interface is added during std to llvm conversion.
87  funcOp->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
88  UnitAttr::get(op->getContext()));
89  funcOp.setPrivate();
90  return fnNameAttr;
91 }
92 
95  ValueRange operands) {
97  res.reserve(operands.size());
98  for (auto op : operands) {
99  auto memrefType = dyn_cast<MemRefType>(op.getType());
100  if (!memrefType) {
101  res.push_back(op);
102  continue;
103  }
104  Value cast =
105  b.create<memref::CastOp>(loc, makeStridedLayoutDynamic(memrefType), op);
106  res.push_back(cast);
107  }
108  return res;
109 }
110 
112  LinalgOp op, PatternRewriter &rewriter) const {
113  auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
114  if (failed(libraryCallName))
115  return failure();
116 
117  // TODO: Add support for more complex library call signatures that include
118  // indices or captured values.
119  rewriter.replaceOpWithNewOp<func::CallOp>(
120  op, libraryCallName->getValue(), TypeRange(),
122  op->getOperands()));
123  return success();
124 }
125 
126 /// Populate the given list with patterns that convert from Linalg to Standard.
128  RewritePatternSet &patterns) {
129  // TODO: ConvOp conversion needs to export a descriptor with relevant
130  // attribute values such as kernel striding and dilation.
131  patterns.add<LinalgOpToLibraryCallRewrite>(patterns.getContext());
132 }
133 
134 namespace {
135 struct ConvertLinalgToStandardPass
136  : public impl::ConvertLinalgToStandardBase<ConvertLinalgToStandardPass> {
137  void runOnOperation() override;
138 };
139 } // namespace
140 
141 void ConvertLinalgToStandardPass::runOnOperation() {
142  auto module = getOperation();
143  ConversionTarget target(getContext());
144  target.addLegalDialect<affine::AffineDialect, arith::ArithDialect,
145  func::FuncDialect, memref::MemRefDialect,
146  scf::SCFDialect>();
147  target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();
148  RewritePatternSet patterns(&getContext());
150  if (failed(applyFullConversion(module, target, std::move(patterns))))
151  signalPassFailure();
152 }
153 
154 std::unique_ptr<OperationPass<ModuleOp>>
156  return std::make_unique<ConvertLinalgToStandardPass>();
157 }
static MLIRContext * getContext(OpFoldResult val)
static SmallVector< Value, 4 > createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc, ValueRange operands)
static SmallVector< Type, 4 > extractOperandTypes(Operation *op)
Helper function to extract the operand types that are passed to the generated CallOp.
static FailureOr< FlatSymbolRefAttr > getLibraryCallSymbolRef(Operation *op, PatternRewriter &rewriter)
static MemRefType makeStridedLayoutDynamic(MemRefType type)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:96
MLIRContext * getContext() const
Definition: Builders.h:55
This class describes a specific conversion target.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
A symbol reference with a reference path containing a single element.
StringRef getValue() const
Returns the name of the held symbol reference.
StringAttr getAttr() const
Returns the name of the held symbol reference as a StringAttr.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:201
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Definition: BuiltinTypes.h:222
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:577
operand_type_range getOperandTypes()
Definition: Operation.h:392
operand_range getOperands()
Returns an iterator on the underlying Values.
Definition: Operation.h:373
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:748
MLIRContext * getContext() const
Definition: PatternMatch.h:785
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:809
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:685
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:537
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override
void populateLinalgToStandardConversionPatterns(RewritePatternSet &patterns)
Populate the given list with patterns that convert from Linalg to Standard.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns)
Apply a complete conversion on the given operations, and all nested operations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
std::unique_ptr< OperationPass< ModuleOp > > createConvertLinalgToStandardPass()
Create a pass to convert Linalg operations to the Standard dialect.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26