21 #define GEN_PASS_DEF_CONVERTLINALGTOSTANDARD
22 #include "mlir/Conversion/Passes.h.inc"
30 type.getContext(), ShapedType::kDynamic,
45 if (
auto memrefType = dyn_cast<MemRefType>(type))
48 result.push_back(type);
55 static FailureOr<FlatSymbolRefAttr>
57 auto linalgOp = cast<LinalgOp>(op);
58 auto fnName = linalgOp.getLibraryCallName();
66 if (module.lookupSymbol(fnNameAttr.
getAttr()))
73 "Library call for linalg operation can be generated only for ops that "
74 "have void return types");
81 std::prev(module.getBody()->end()));
82 func::FuncOp funcOp = rewriter.
create<func::FuncOp>(
87 funcOp->
setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
97 res.reserve(operands.size());
98 for (
auto op : operands) {
99 auto memrefType = dyn_cast<MemRefType>(op.getType());
114 if (failed(libraryCallName))
120 op, libraryCallName->getValue(),
TypeRange(),
135 struct ConvertLinalgToStandardPass
136 :
public impl::ConvertLinalgToStandardBase<ConvertLinalgToStandardPass> {
137 void runOnOperation()
override;
141 void ConvertLinalgToStandardPass::runOnOperation() {
142 auto module = getOperation();
144 target.addLegalDialect<affine::AffineDialect, arith::ArithDialect,
145 func::FuncDialect, memref::MemRefDialect,
147 target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();
154 std::unique_ptr<OperationPass<ModuleOp>>
156 return std::make_unique<ConvertLinalgToStandardPass>();
static MLIRContext * getContext(OpFoldResult val)
static SmallVector< Value, 4 > createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc, ValueRange operands)
static SmallVector< Type, 4 > extractOperandTypes(Operation *op)
Helper function to extract the operand types that are passed to the generated CallOp.
static FailureOr< FlatSymbolRefAttr > getLibraryCallSymbolRef(Operation *op, PatternRewriter &rewriter)
static MemRefType makeStridedLayoutDynamic(MemRefType type)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
MLIRContext * getContext() const
This class describes a specific conversion target.
A symbol reference with a reference path containing a single element.
StringRef getValue() const
Returns the name of the held symbol reference.
StringAttr getAttr() const
Returns the name of the held symbol reference as a StringAttr.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This is a builder type that keeps local references to arguments.
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
operand_type_range getOperandTypes()
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override
void populateLinalgToStandardConversionPatterns(RewritePatternSet &patterns)
Populate the given list with patterns that convert from Linalg to Standard.
Include the generated interface declarations.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction code.
std::unique_ptr< OperationPass< ModuleOp > > createConvertLinalgToStandardPass()
Create a pass to convert Linalg operations to the Standard dialect.