20#define GEN_PASS_DEF_CONVERTLINALGTOSTANDARDPASS
21#include "mlir/Conversion/Passes.h.inc"
29 type.getContext(), ShapedType::kDynamic,
44 if (
auto memrefType = dyn_cast<MemRefType>(type))
54static FailureOr<FlatSymbolRefAttr>
56 auto linalgOp = cast<LinalgOp>(op);
57 auto fnName = linalgOp.getLibraryCallName();
63 SymbolRefAttr::get(rewriter.
getContext(), fnName);
64 auto module = op->getParentOfType<ModuleOp>();
65 if (module.lookupSymbol(fnNameAttr.
getAttr()))
72 "Library call for linalg operation can be generated only for ops that "
73 "have void return types");
80 std::prev(module.getBody()->end()));
81 func::FuncOp funcOp = func::FuncOp::create(rewriter, op->
getLoc(),
86 funcOp->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
96 res.reserve(operands.size());
97 for (
auto op : operands) {
98 auto memrefType = dyn_cast<MemRefType>(op.getType());
103 Value cast = memref::CastOp::create(
113 if (failed(libraryCallName))
119 op, libraryCallName->getValue(),
TypeRange(),
134struct ConvertLinalgToStandardPass
136 ConvertLinalgToStandardPass> {
137 void runOnOperation()
override;
141void ConvertLinalgToStandardPass::runOnOperation() {
142 auto module = getOperation();
144 target.addLegalDialect<affine::AffineDialect, arith::ArithDialect,
145 func::FuncDialect, memref::MemRefDialect,
147 target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();
static SmallVector< Value, 4 > createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc, ValueRange operands)
static SmallVector< Type, 4 > extractOperandTypes(Operation *op)
Helper function to extract the operand types that are passed to the generated CallOp.
static FailureOr< FlatSymbolRefAttr > getLibraryCallSymbolRef(Operation *op, PatternRewriter &rewriter)
static MemRefType makeStridedLayoutDynamic(MemRefType type)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
MLIRContext * getContext() const
A symbol reference with a reference path containing a single element.
StringRef getValue() const
Returns the name of the held symbol reference.
StringAttr getAttr() const
Returns the name of the held symbol reference as a StringAttr.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This is a builder type that keeps local references to arguments.
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
operand_type_range getOperandTypes()
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override
void populateLinalgToStandardConversionPatterns(RewritePatternSet &patterns)
Populate the given list with patterns that convert from Linalg to Standard.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns