MLIR  14.0.0git
LinalgToStandard.cpp
Go to the documentation of this file.
1 //===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 #include "../PassDetail.h"
16 #include "mlir/Dialect/SCF/SCF.h"
18 
19 using namespace mlir;
20 using namespace mlir::linalg;
21 
22 /// Helper function to extract the operand types that are passed to the
23 /// generated CallOp. MemRefTypes have their layout canonicalized since the
24 /// information is not used in signature generation.
25 /// Note that static size information is not modified.
27  SmallVector<Type, 4> result;
28  result.reserve(op->getNumOperands());
29  for (auto type : op->getOperandTypes()) {
30  // The underlying descriptor type (e.g. LLVM) does not have layout
31  // information. Canonicalizing the type at the level of std when going into
32  // a library call avoids needing to introduce DialectCastOp.
33  if (auto memrefType = type.dyn_cast<MemRefType>())
34  result.push_back(eraseStridedLayout(memrefType));
35  else
36  result.push_back(type);
37  }
38  return result;
39 }
40 
41 // Get a SymbolRefAttr containing the library function name for the LinalgOp.
42 // If the library function does not exist, insert a declaration.
44  PatternRewriter &rewriter) {
45  auto linalgOp = cast<LinalgOp>(op);
46  auto fnName = linalgOp.getLibraryCallName();
47  if (fnName.empty()) {
48  op->emitWarning("No library call defined for: ") << *op;
49  return {};
50  }
51 
52  // fnName is a dynamic std::string, unique it via a SymbolRefAttr.
53  FlatSymbolRefAttr fnNameAttr =
54  SymbolRefAttr::get(rewriter.getContext(), fnName);
55  auto module = op->getParentOfType<ModuleOp>();
56  if (module.lookupSymbol(fnNameAttr.getAttr()))
57  return fnNameAttr;
58 
60  assert(op->getNumResults() == 0 &&
61  "Library call for linalg operation can be generated only for ops that "
62  "have void return types");
63  auto libFnType = rewriter.getFunctionType(inputTypes, {});
64 
65  OpBuilder::InsertionGuard guard(rewriter);
66  // Insert before module terminator.
67  rewriter.setInsertionPoint(module.getBody(),
68  std::prev(module.getBody()->end()));
69  FuncOp funcOp =
70  rewriter.create<FuncOp>(op->getLoc(), fnNameAttr.getValue(), libFnType);
71  // Insert a function attribute that will trigger the emission of the
72  // corresponding `_mlir_ciface_xxx` interface so that external libraries see
73  // a normalized ABI. This interface is added during std to llvm conversion.
74  funcOp->setAttr("llvm.emit_c_interface", UnitAttr::get(op->getContext()));
75  funcOp.setPrivate();
76  return fnNameAttr;
77 }
78 
81  ValueRange operands) {
83  res.reserve(operands.size());
84  for (auto op : operands) {
85  auto memrefType = op.getType().dyn_cast<MemRefType>();
86  if (!memrefType) {
87  res.push_back(op);
88  continue;
89  }
90  Value cast =
91  b.create<memref::CastOp>(loc, eraseStridedLayout(memrefType), op);
92  res.push_back(cast);
93  }
94  return res;
95 }
96 
98  LinalgOp op, PatternRewriter &rewriter) const {
99  // Only LinalgOp for which there is no specialized pattern go through this.
100  if (isa<CopyOp>(op))
101  return failure();
102 
103  auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
104  if (!libraryCallName)
105  return failure();
106 
107  // TODO: Add support for more complex library call signatures that include
108  // indices or captured values.
109  rewriter.replaceOpWithNewOp<mlir::CallOp>(
110  op, libraryCallName.getValue(), TypeRange(),
111  createTypeCanonicalizedMemRefOperands(rewriter, op->getLoc(),
112  op->getOperands()));
113  return success();
114 }
115 
117  CopyOp op, PatternRewriter &rewriter) const {
118  auto inputPerm = op.inputPermutation();
119  if (inputPerm.hasValue() && !inputPerm->isIdentity())
120  return failure();
121  auto outputPerm = op.outputPermutation();
122  if (outputPerm.hasValue() && !outputPerm->isIdentity())
123  return failure();
124 
125  auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
126  if (!libraryCallName)
127  return failure();
128 
129  rewriter.replaceOpWithNewOp<mlir::CallOp>(
130  op, libraryCallName.getValue(), TypeRange(),
131  createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(),
132  op.getOperands()));
133  return success();
134 }
135 
137  CopyOp op, PatternRewriter &rewriter) const {
138  Value in = op.input(), out = op.output();
139 
140  // If either inputPerm or outputPerm are non-identities, insert transposes.
141  auto inputPerm = op.inputPermutation();
142  if (inputPerm.hasValue() && !inputPerm->isIdentity())
143  in = rewriter.create<memref::TransposeOp>(op.getLoc(), in,
144  AffineMapAttr::get(*inputPerm));
145  auto outputPerm = op.outputPermutation();
146  if (outputPerm.hasValue() && !outputPerm->isIdentity())
147  out = rewriter.create<memref::TransposeOp>(op.getLoc(), out,
148  AffineMapAttr::get(*outputPerm));
149 
150  // If nothing was transposed, fail and let the conversion kick in.
151  if (in == op.input() && out == op.output())
152  return failure();
153 
154  auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
155  if (!libraryCallName)
156  return failure();
157 
158  rewriter.replaceOpWithNewOp<mlir::CallOp>(
159  op, libraryCallName.getValue(), TypeRange(),
160  createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), {in, out}));
161  return success();
162 }
163 
164 /// Populate the given list with patterns that convert from Linalg to Standard.
166  RewritePatternSet &patterns) {
167  // TODO: ConvOp conversion needs to export a descriptor with relevant
168  // attribute values such as kernel striding and dilation.
169  // clang-format off
170  patterns.add<
174  // clang-format on
175 }
176 
177 namespace {
178 struct ConvertLinalgToStandardPass
179  : public ConvertLinalgToStandardBase<ConvertLinalgToStandardPass> {
180  void runOnOperation() override;
181 };
182 } // namespace
183 
184 void ConvertLinalgToStandardPass::runOnOperation() {
185  auto module = getOperation();
186  ConversionTarget target(getContext());
187  target.addLegalDialect<AffineDialect, arith::ArithmeticDialect,
188  memref::MemRefDialect, scf::SCFDialect,
189  StandardOpsDialect>();
190  target.addLegalOp<ModuleOp, FuncOp, ReturnOp>();
191  RewritePatternSet patterns(&getContext());
193  if (failed(applyFullConversion(module, target, std::move(patterns))))
194  signalPassFailure();
195 }
196 
197 std::unique_ptr<OperationPass<ModuleOp>>
199  return std::make_unique<ConvertLinalgToStandardPass>();
200 }
Include the generated interface declarations.
OpTy create(Location location, Args &&...args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:430
LogicalResult matchAndRewrite(CopyOp op, PatternRewriter &rewriter) const override
static SmallVector< Type, 4 > extractOperandTypes(Operation *op)
Helper function to extract the operand types that are passed to the generated CallOp.
LogicalResult matchAndRewrite(CopyOp op, PatternRewriter &rewriter) const override
MLIRContext * getContext() const
Definition: Builders.h:54
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:881
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
MemRefType eraseStridedLayout(MemRefType t)
Return a version of t with a layout that has all dynamic offset and strides.
A symbol reference with a reference path containing a single element.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:329
static SmallVector< Value, 4 > createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc, ValueRange operands)
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
unsigned getNumOperands()
Definition: Operation.h:215
operand_type_range getOperandTypes()
Definition: Operation.h:266
static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op, PatternRewriter &rewriter)
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:120
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:99
LogicalResult applyFullConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns)
Apply a complete conversion on the given operations, and all nested operations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
std::unique_ptr< OperationPass< ModuleOp > > createConvertLinalgToStandardPass()
Create a pass to convert Linalg operations to the Standard dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:38
StringRef getValue() const
Returns the name of the held symbol reference.
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:106
Rewrite CopyOp with permutations into a sequence of TransposeOp and permutation-free CopyOp...
LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening...
Definition: Operation.cpp:243
Rewrite pattern specialization for CopyOp, kicks in when both input and output permutations are left ...
StringAttr getAttr() const
Returns the name of the held symbol reference as a StringAttr.
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:279
OpTy replaceOpWithNewOp(Operation *op, Args &&... args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:741
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
Definition: PatternMatch.h:930
void addLegalOp()
Register the given operations as legal.
void populateLinalgToStandardConversionPatterns(RewritePatternSet &patterns)
Populate the given list with patterns that convert from Linalg to Standard.
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:273
This class describes a specific conversion target.
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:67
This class helps build Operations.
Definition: Builders.h:177
This class provides an abstraction over the different types of ranges over Values.
MLIRContext * getContext() const
Definition: PatternMatch.h:906