MLIR  22.0.0git
LinalgToStandard.cpp
Go to the documentation of this file.
1 //===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
18 
// Includes the generated pass code from Passes.h.inc inside namespace mlir,
// with GEN_PASS_DEF_CONVERTLINALGTOSTANDARDPASS defined first. The pass struct
// further down derives from impl::ConvertLinalgToStandardPassBase.
// NOTE(review): presumably the macro selects the section of the .inc file that
// defines that base class - confirm against the generated Passes.h.inc.
19 namespace mlir {
20 #define GEN_PASS_DEF_CONVERTLINALGTOSTANDARDPASS
21 #include "mlir/Conversion/Passes.h.inc"
22 } // namespace mlir
23 
24 using namespace mlir;
25 using namespace mlir::linalg;
26 
/// Returns a MemRefType derived from `type`. It is used below both for the
/// declared library-call signature and for the casts applied to call operands.
/// The visible arguments are the type's context, ShapedType::kDynamic, and a
/// vector holding ShapedType::kDynamic once per dimension (`type.getRank()`).
// NOTE(review): the statement these arguments belong to (listing line 28) is
// missing from this listing. Presumably it builds a strided layout with a
// dynamic offset and all-dynamic strides and installs it with setLayout -
// confirm against the upstream file before relying on this description.
27 static MemRefType makeStridedLayoutDynamic(MemRefType type) {
29  type.getContext(), ShapedType::kDynamic,
30  SmallVector<int64_t>(type.getRank(), ShapedType::kDynamic)));
31 }
32 
33 /// Helper function to extract the operand types that are passed to the
34 /// generated CallOp. MemRefTypes have their layout canonicalized since the
35 /// information is not used in signature generation.
36 /// Note that static size information is not modified.
/// The result has one entry per operand of `op`, in operand order; operand
/// types that are not MemRefType are copied through unchanged.
// NOTE(review): the signature line (listing line 37) is missing from this
// listing; the symbol index at the end of the page gives it as
// `static SmallVector<Type, 4> extractOperandTypes(Operation *op)`.
38  SmallVector<Type, 4> result;
// Size the vector up front: exactly one type is pushed per operand.
39  result.reserve(op->getNumOperands());
40  for (auto type : op->getOperandTypes()) {
41  // The underlying descriptor type (e.g. LLVM) does not have layout
42  // information. Canonicalizing the type at the level of std when going into
43  // a library call avoids needing to introduce DialectCastOp.
44  if (auto memrefType = dyn_cast<MemRefType>(type))
45  result.push_back(makeStridedLayoutDynamic(memrefType));
46  else
47  result.push_back(type);
48  }
49  return result;
50 }
51 
52 // Get a SymbolRefAttr containing the library function name for the LinalgOp.
53 // If the library function does not exist, insert a declaration.
//
// Reports a match failure (through `rewriter`) when the op has no library call
// name, or when a declaration would have to be created for an op that has
// results. On success the returned attribute names either a symbol that
// already existed in the enclosing module or the private func::FuncOp
// created here.
// NOTE(review): the signature line (listing line 55) is missing from this
// listing; the symbol index at the end of the page gives it as
// `getLibraryCallSymbolRef(Operation *op, PatternRewriter &rewriter)`.
54 static FailureOr<FlatSymbolRefAttr>
56  auto linalgOp = cast<LinalgOp>(op);
57  auto fnName = linalgOp.getLibraryCallName();
// NOTE(review): the message below ends in ": " with nothing appended after it.
58  if (fnName.empty())
59  return rewriter.notifyMatchFailure(op, "No library call defined for: ");
60 
61  // fnName is a dynamic std::string, unique it via a SymbolRefAttr.
62  FlatSymbolRefAttr fnNameAttr =
63  SymbolRefAttr::get(rewriter.getContext(), fnName);
// Reuse an existing symbol of the same name instead of declaring it twice.
// NOTE(review): `module` is used without a null check; this assumes `op` is
// always nested inside a ModuleOp - confirm.
64  auto module = op->getParentOfType<ModuleOp>();
65  if (module.lookupSymbol(fnNameAttr.getAttr()))
66  return fnNameAttr;
67 
// NOTE(review): `inputTypes`, used a few lines below, is declared on listing
// line 68, which is missing here; presumably it holds extractOperandTypes(op).
69  if (op->getNumResults() != 0) {
70  return rewriter.notifyMatchFailure(
71  op,
72  "Library call for linalg operation can be generated only for ops that "
73  "have void return types");
74  }
// The declared function takes `inputTypes` and returns nothing.
75  auto libFnType = rewriter.getFunctionType(inputTypes, {});
76 
// The guard restores the rewriter's insertion point when this function returns.
77  OpBuilder::InsertionGuard guard(rewriter);
78  // Insert before module terminator.
// Concretely: before the last operation currently in the module body.
// NOTE(review): whether that operation is a terminator depends on ModuleOp's
// definition, which is not visible here - confirm the comment above is current.
79  rewriter.setInsertionPoint(module.getBody(),
80  std::prev(module.getBody()->end()));
81  func::FuncOp funcOp = func::FuncOp::create(rewriter, op->getLoc(),
82  fnNameAttr.getValue(), libFnType);
83  // Insert a function attribute that will trigger the emission of the
84  // corresponding `_mlir_ciface_xxx` interface so that external libraries see
85  // a normalized ABI. This interface is added during std to llvm conversion.
86  funcOp->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
87  UnitAttr::get(op->getContext()));
// The declaration is given private visibility.
88  funcOp.setPrivate();
89  return fnNameAttr;
90 }
91 
/// Builds the operand list for the generated library call. Operands whose type
/// is not a MemRefType are forwarded unchanged; each memref operand is replaced
/// by the result of a memref::CastOp, created with builder `b` at `loc`, to the
/// type returned by makeStridedLayoutDynamic. The result keeps operand order.
// NOTE(review): the first signature lines (listing lines 92-93) and the
// declaration of `res` (listing line 95) are missing from this listing; the
// symbol index at the end of the page gives the signature as
// `static SmallVector<Value, 4> createTypeCanonicalizedMemRefOperands(
//      OpBuilder &b, Location loc, ValueRange operands)`.
94  ValueRange operands) {
96  res.reserve(operands.size());
97  for (auto op : operands) {
98  auto memrefType = dyn_cast<MemRefType>(op.getType());
// Non-memref values are passed through without a cast.
99  if (!memrefType) {
100  res.push_back(op);
101  continue;
102  }
// Cast to the same type that extractOperandTypes produces for memref operands.
103  Value cast = memref::CastOp::create(
104  b, loc, makeStridedLayoutDynamic(memrefType), op);
105  res.push_back(cast);
106  }
107  return res;
108 }
109 
/// Rewrites a LinalgOp into a func::CallOp to its library function.
/// Returns failure() when getLibraryCallSymbolRef fails; otherwise replaces
/// `op` with a call that has no result types and takes the op's operands after
/// createTypeCanonicalizedMemRefOperands has processed them, then returns
/// success().
// NOTE(review): the line carrying the return type and qualified name (listing
// line 110) is missing from this listing; the symbol index at the end of the
// page lists `LogicalResult matchAndRewrite(LinalgOp op,
// PatternRewriter &rewriter) const override`.
111  LinalgOp op, PatternRewriter &rewriter) const {
112  auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
113  if (failed(libraryCallName))
114  return failure();
115 
116  // TODO: Add support for more complex library call signatures that include
117  // indices or captured values.
// TypeRange() is empty: the generated call produces no results.
118  rewriter.replaceOpWithNewOp<func::CallOp>(
119  op, libraryCallName->getValue(), TypeRange(),
120  createTypeCanonicalizedMemRefOperands(rewriter, op->getLoc(),
121  op->getOperands()));
122  return success();
123 }
124 
125 /// Populate the given list with patterns that convert from Linalg to Standard.
// NOTE(review): the signature (listing lines 126-127) and the statement that
// registers the pattern (listing line 130) are missing from this listing; the
// symbol index at the end of the page gives the signature as
// `void populateLinalgToStandardConversionPatterns(RewritePatternSet &patterns)`.
128  // TODO: ConvOp conversion needs to export a descriptor with relevant
129  // attribute values such as kernel striding and dilation.
131 }
132 
133 namespace {
/// Pass type declared in an anonymous namespace. It derives from the
/// impl::ConvertLinalgToStandardPassBase template, instantiated with itself,
/// and overrides runOnOperation(), which is defined out of line below.
134 struct ConvertLinalgToStandardPass
135  : public impl::ConvertLinalgToStandardPassBase<
136  ConvertLinalgToStandardPass> {
137  void runOnOperation() override;
138 };
139 } // namespace
140 
/// Runs a full dialect conversion over the operation this pass is applied to
/// and signals pass failure if the conversion fails.
/// The target marks the affine, arith, func, memref and scf dialects legal,
/// plus ModuleOp, func::FuncOp and func::ReturnOp. The Linalg dialect is not
/// marked legal.
// NOTE(review): presumably any Linalg op the patterns cannot rewrite therefore
// makes applyFullConversion fail - confirm against the dialect-conversion docs.
141 void ConvertLinalgToStandardPass::runOnOperation() {
142  auto module = getOperation();
143  ConversionTarget target(getContext());
144  target.addLegalDialect<affine::AffineDialect, arith::ArithDialect,
145  func::FuncDialect, memref::MemRefDialect,
146  scf::SCFDialect>();
147  target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();
// NOTE(review): `patterns` is declared and filled on listing lines 148-149,
// which are missing from this listing; presumably via
// populateLinalgToStandardConversionPatterns.
150  if (failed(applyFullConversion(module, target, std::move(patterns))))
151  signalPassFailure();
152 }
static MLIRContext * getContext(OpFoldResult val)
static SmallVector< Value, 4 > createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc, ValueRange operands)
static SmallVector< Type, 4 > extractOperandTypes(Operation *op)
Helper function to extract the operand types that are passed to the generated CallOp.
static FailureOr< FlatSymbolRefAttr > getLibraryCallSymbolRef(Operation *op, PatternRewriter &rewriter)
static MemRefType makeStridedLayoutDynamic(MemRefType type)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:75
MLIRContext * getContext() const
Definition: Builders.h:55
This class describes a specific conversion target.
A symbol reference with a reference path containing a single element.
StringRef getValue() const
Returns the name of the held symbol reference.
StringAttr getAttr() const
Returns the name of the held symbol reference as a StringAttr.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:182
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Definition: BuiltinTypes.h:203
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
operand_type_range getOperandTypes()
Definition: Operation.h:397
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:769
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:702
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override
void populateLinalgToStandardConversionPatterns(RewritePatternSet &patterns)
Populate the given list with patterns that convert from Linalg to Standard.
Include the generated interface declarations.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...