1 //===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
18 #include "mlir/Pass/Pass.h"
20 namespace mlir {
22 #include "mlir/Conversion/"
23 } // namespace mlir
25 using namespace mlir;
26 using namespace mlir::linalg;
// Returns a memref type derived from `type` whose layout appears to be a
// strided layout with a dynamic offset and one dynamic stride per dimension
// (the visible arguments pass `ShapedType::kDynamic` for the offset and a
// vector of `type.getRank()` dynamic strides).
// NOTE(review): the statement that actually builds the returned type is
// missing from this listing; confirm against the full file.
28 static MemRefType makeStridedLayoutDynamic(MemRefType type) {
30  type.getContext(), ShapedType::kDynamic,
31  SmallVector<int64_t>(type.getRank(), ShapedType::kDynamic)));
32 }
34 /// Helper function to extract the operand types that are passed to the
35 /// generated CallOp. MemRefTypes have their layout canonicalized since the
36 /// information is not used in signature generation.
37 /// Note that static size information is not modified.
// NOTE(review): the function signature line is missing from this listing; the
// body reads an operation pointer `op` and returns the collected `result`.
39  SmallVector<Type, 4> result;
  // Exactly one entry is produced per operand, in operand order.
40  result.reserve(op->getNumOperands());
41  for (auto type : op->getOperandTypes()) {
42  // The underlying descriptor type (e.g. LLVM) does not have layout
43  // information. Canonicalizing the type at the level of std when going into
44  // a library call avoids needing to introduce DialectCastOp.
  // Memref types get a fully dynamic strided layout; any other operand type
  // is forwarded unchanged.
45  if (auto memrefType = dyn_cast<MemRefType>(type))
46  result.push_back(makeStridedLayoutDynamic(memrefType));
47  else
48  result.push_back(type);
49  }
50  return result;
51 }
// Returns a FlatSymbolRefAttr naming the library function for the Linalg op
// `op`, inserting a private declaration of that function into the enclosing
// ModuleOp when no symbol with that name exists yet. Fails (through
// notifyMatchFailure) when the op has no library call name, or when a new
// declaration would be needed for an op that has results.
// NOTE(review): the signature lines are missing from this listing; the call
// site `getLibraryCallSymbolRef(op, rewriter)` together with `failed(...)` and
// `->getValue()` suggests a FailureOr<FlatSymbolRefAttr> return — confirm.
53 // Get a SymbolRefAttr containing the library function name for the LinalgOp.
54 // If the library function does not exist, insert a declaration.
  // cast<> (not dyn_cast<>) — `op` is expected to implement LinalgOp.
57  auto linalgOp = cast<LinalgOp>(op);
58  auto fnName = linalgOp.getLibraryCallName();
  // NOTE(review): the message below ends with "for: " but nothing is appended,
  // so the diagnostic never names the offending op.
59  if (fnName.empty())
60  return rewriter.notifyMatchFailure(op, "No library call defined for: ");
62  // fnName is a dynamic std::string, unique it via a SymbolRefAttr.
63  FlatSymbolRefAttr fnNameAttr =
64  SymbolRefAttr::get(rewriter.getContext(), fnName);
  // NOTE(review): `module` is null if `op` has no ModuleOp ancestor, and it is
  // used below without a check.
65  auto module = op->getParentOfType<ModuleOp>();
  // An existing symbol with this name is reused as-is; its signature is not
  // compared against the op's operands.
  // NOTE(review): the void-return check below only runs when a new
  // declaration is created, so an op with results whose library symbol already
  // exists is not rejected here.
66  if (module.lookupSymbol(fnNameAttr.getAttr()))
67  return fnNameAttr;
  // NOTE(review): `inputTypes` (used below) is declared on lines missing from
  // this listing — presumably the output of the operand-type extraction helper.
70  if (op->getNumResults() != 0) {
71  return rewriter.notifyMatchFailure(
72  op,
73  "Library call for linalg operation can be generated only for ops that "
74  "have void return types");
75  }
  // The declared function takes the (canonicalized) operand types and returns
  // nothing.
76  auto libFnType = rewriter.getFunctionType(inputTypes, {});
  // Restores the rewriter's insertion point when this function returns.
78  OpBuilder::InsertionGuard guard(rewriter);
  // NOTE(review): this positions the rewriter before the last op of the module
  // body; if the ModuleOp in use has no terminator, the comment below is stale.
79  // Insert before module terminator.
80  rewriter.setInsertionPoint(module.getBody(),
81  std::prev(module.getBody()->end()));
82  func::FuncOp funcOp = rewriter.create<func::FuncOp>(
83  op->getLoc(), fnNameAttr.getValue(), libFnType);
84  // Insert a function attribute that will trigger the emission of the
85  // corresponding `_mlir_ciface_xxx` interface so that external libraries see
86  // a normalized ABI. This interface is added during std to llvm conversion.
87  funcOp->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
88  UnitAttr::get(op->getContext()));
  // The inserted function is only a declaration, so it is given private
  // visibility.
89  funcOp.setPrivate();
90  return fnNameAttr;
91 }
95  ValueRange operands) {
  // Builds the operand list for the library call. Each memref operand is
  // wrapped in a memref.cast to the same memref type with the fully dynamic
  // strided layout produced by makeStridedLayoutDynamic. Every other value is
  // forwarded unchanged. The output has the same count and order as `operands`.
  // This applies the same memref canonicalization used when computing the
  // declared callee's input types, presumably so call operand types match the
  // declaration.
  // NOTE(review): the start of the signature and the declaration of `res` are
  // missing from this listing; `b` and `loc` are presumably the builder and
  // location parameters — confirm.
97  res.reserve(operands.size());
98  for (auto op : operands) {
99  auto memrefType = dyn_cast<MemRefType>(op.getType());
100  if (!memrefType) {
101  res.push_back(op);
102  continue;
103  }
104  Value cast =
105  b.create<memref::CastOp>(loc, makeStridedLayoutDynamic(memrefType), op);
106  res.push_back(cast);
107  }
108  return res;
109 }
112  LinalgOp op, PatternRewriter &rewriter) const {
  // Rewrites `op` into a func.call to its library function. The callee symbol
  // comes from getLibraryCallSymbolRef, which may insert a private declaration.
  // Failure there (no library name, or an op with results needing a new
  // declaration) makes this pattern fail without rewriting.
  // NOTE(review): the pattern class header and one argument line of the call
  // (between the callee arguments and `op->getOperands()`) are missing from
  // this listing; presumably the operands pass through the memref-layout
  // canonicalizing helper.
113  auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
114  if (failed(libraryCallName))
115  return failure();
117  // TODO: Add support for more complex library call signatures that include
118  // indices or captured values.
  // The call is built with no result types (TypeRange()), matching the
  // void-returning declarations that getLibraryCallSymbolRef creates.
119  rewriter.replaceOpWithNewOp<func::CallOp>(
120  op, libraryCallName->getValue(), TypeRange(),
122  op->getOperands()));
123  return success();
124 }
126 /// Populate the given list with patterns that convert from Linalg to Standard.
// Currently registers exactly one pattern, LinalgOpToLibraryCallRewrite.
// NOTE(review): the first signature line is missing from this listing.
128  RewritePatternSet &patterns) {
129  // TODO: ConvOp conversion needs to export a descriptor with relevant
130  // attribute values such as kernel striding and dilation.
131  patterns.add<LinalgOpToLibraryCallRewrite>(patterns.getContext());
132 }
134 namespace {
135 struct ConvertLinalgToStandardPass
136  : public impl::ConvertLinalgToStandardBase<ConvertLinalgToStandardPass> {
137  void runOnOperation() override;
138 };
139 } // namespace
141 void ConvertLinalgToStandardPass::runOnOperation() {
142  auto module = getOperation();
143  ConversionTarget target(getContext());
144  target.addLegalDialect<affine::AffineDialect, arith::ArithDialect,
145  func::FuncDialect, memref::MemRefDialect,
146  scf::SCFDialect>();
147  target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();
148  RewritePatternSet patterns(&getContext());
150  if (failed(applyFullConversion(module, target, std::move(patterns))))
151  signalPassFailure();
152 }
// Factory for the Linalg-to-Standard pass: returns a newly allocated
// ConvertLinalgToStandardPass, typed as an OperationPass anchored on ModuleOp.
// NOTE(review): the line carrying this function's name is missing from this
// listing.
154 std::unique_ptr<OperationPass<ModuleOp>>
156  return std::make_unique<ConvertLinalgToStandardPass>();
157 }
