27 #include "llvm/ADT/TypeSwitch.h"
30 #define GEN_PASS_DEF_CONVERTLINALGTOAFFINELOOPSPASS
31 #define GEN_PASS_DEF_CONVERTLINALGTOLOOPSPASS
32 #define GEN_PASS_DEF_CONVERTLINALGTOPARALLELLOOPSPASS
33 #include "mlir/Dialect/Linalg/Passes.h.inc"
53 res.push_back(b.
create<affine::AffineApplyOp>(loc, exprMap, operands));
58 template <
typename LoadOpTy,
typename StoreOpTy,
typename OpType>
65 map.
map(block.getArguments(), indexedValues);
66 for (
auto &op : block.without_terminator()) {
67 auto *newOp = b.
clone(op, map);
71 Operation *terminator = block.getTerminator();
74 b.
create<StoreOpTy>(loc, toStore, outputBuffers[operand.getOperandNumber()],
75 indexing[operand.getOperandNumber()]);
85 template <
typename SingleInputPoolingOp>
88 SingleInputPoolingOp op) {
89 auto mapsRange = op.getIndexingMapsArray();
90 auto maps = llvm::to_vector<8>(
91 llvm::map_range(mapsRange, [](AffineMapAttr a) {
return a.getValue(); }));
127 template <
typename LoadOpTy,
typename StoreOpTy>
131 assert(linalgOp.hasPureBufferSemantics() &&
132 "expected linalg op with buffer semantics");
134 indexedValues.reserve(linalgOp->getNumOperands());
141 for (
OpOperand *inputOperand : linalgOp.getDpsInputOperands()) {
142 if (linalgOp.isScalar(inputOperand)) {
143 indexedValues.push_back(inputOperand->get());
147 b, loc, linalgOp.getMatchingIndexingMap(inputOperand), allIvsPlusDims);
148 indexedValues.push_back(
149 b.
create<LoadOpTy>(loc, inputOperand->get(), indexing));
152 for (
OpOperand &outputOperand : linalgOp.getDpsInitsMutable()) {
154 b, loc, linalgOp.getMatchingIndexingMap(&outputOperand),
156 indexedValues.push_back(
157 b.
create<LoadOpTy>(loc, outputOperand.get(), indexing));
165 for (
OpOperand &outputOperand : linalgOp.getDpsInitsMutable()) {
166 if (!isa<MemRefType>(outputOperand.get().getType()))
169 b, loc, linalgOp.getMatchingIndexingMap(&outputOperand),
171 outputBuffers.push_back(outputOperand.get());
173 inlineRegionAndEmitStore<LoadOpTy, StoreOpTy>(b, loc, linalgOp, indexedValues,
174 indexing, outputBuffers);
186 .Case([&](scf::ParallelOp parallelOp) {
187 allIvs.append(parallelOp.getInductionVars().begin(),
188 parallelOp.getInductionVars().end());
190 .Case([&](scf::ForOp forOp) {
191 allIvs.push_back(forOp.getInductionVar());
193 .Case([&](affine::AffineForOp affineForOp) {
194 allIvs.push_back(affineForOp.getInductionVar());
196 .Default([&](
Operation *op) { assert(
false &&
"unexpected op"); });
198 assert(linalgOp.getNumLoops() == allIvs.size() &&
199 "expected the number of loops and induction variables to match");
201 if (!loopOps.empty()) {
202 auto loopOp = cast<LoopLikeOpInterface>(loopOps.back());
203 for (
Region *r : loopOp.getLoopRegions())
204 for (IndexOp indexOp : llvm::make_early_inc_range(r->getOps<IndexOp>()))
205 rewriter.
replaceOp(indexOp, allIvs[indexOp.getDim()]);
209 template <
typename LoopTy>
213 std::conditional_t<std::is_same<LoopTy, affine::AffineForOp>::value,
214 affine::AffineLoadOp, memref::LoadOp>;
216 std::conditional_t<std::is_same<LoopTy, affine::AffineForOp>::value,
217 affine::AffineStoreOp, memref::StoreOp>;
221 assert(linalgOp.hasPureBufferSemantics() &&
222 "expected linalg op with buffer semantics");
224 auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc());
225 auto iteratorTypes = linalgOp.getIteratorTypesArray();
229 rewriter, linalgOp.getLoc(), loopRanges, linalgOp, iteratorTypes,
232 assert(operandValuesToUse == linalgOp->getOperands() &&
233 "expect operands are captured and not passed by loop argument");
234 allIvs.append(ivs.begin(), ivs.end());
235 emitScalarImplementation<LoadOpTy, StoreOpTy>(b, loc, allIvs, linalgOp);
236 return scf::ValueVector{};
241 for (
Value iv : allIvs) {
258 template <
typename LoopType>
266 auto linalgOp = dyn_cast<LinalgOp>(op);
267 if (!isa<LinalgOp>(op) || !linalgOp.hasPureBufferSemantics()) {
269 op,
"expected linalg op with buffer semantics");
271 if (
failed(linalgOpToLoopsImpl<LoopType>(rewriter, linalgOp)))
290 :
RewritePattern(affine::AffineApplyOp::getOperationName(), 0, context) {}
294 auto affineApplyOp = cast<affine::AffineApplyOp>(op);
295 auto map = affineApplyOp.getAffineMap();
296 if (map.getNumResults() != 1 || map.getNumInputs() > 1)
300 if (map.getNumInputs() == 0) {
301 if (
auto val = dyn_cast<AffineConstantExpr>(expr)) {
307 if (dyn_cast<AffineDimExpr>(expr) || dyn_cast<AffineSymbolExpr>(expr)) {
315 template <
typename LoopType>
316 static void lowerLinalgToLoopsImpl(
Operation *enclosingOp) {
319 patterns.add<LinalgRewritePattern<LoopType>>(context);
320 memref::DimOp::getCanonicalizationPatterns(patterns, context);
321 tensor::DimOp::getCanonicalizationPatterns(patterns, context);
322 affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
323 patterns.add<FoldAffineOp>(context);
328 struct LowerToAffineLoops
329 :
public impl::ConvertLinalgToAffineLoopsPassBase<LowerToAffineLoops> {
330 using impl::ConvertLinalgToAffineLoopsPassBase<
331 LowerToAffineLoops>::ConvertLinalgToAffineLoopsPassBase;
333 registry.
insert<memref::MemRefDialect>();
335 void runOnOperation()
override {
336 lowerLinalgToLoopsImpl<affine::AffineForOp>(getOperation());
340 struct LowerToLoops :
public impl::ConvertLinalgToLoopsPassBase<LowerToLoops> {
341 using impl::ConvertLinalgToLoopsPassBase<
342 LowerToLoops>::ConvertLinalgToLoopsPassBase;
344 registry.
insert<memref::MemRefDialect, scf::SCFDialect>();
346 void runOnOperation()
override {
347 lowerLinalgToLoopsImpl<scf::ForOp>(getOperation());
351 struct LowerToParallelLoops
352 :
public impl::ConvertLinalgToParallelLoopsPassBase<LowerToParallelLoops> {
353 using impl::ConvertLinalgToParallelLoopsPassBase<
354 LowerToParallelLoops>::ConvertLinalgToParallelLoopsPassBase;
355 void runOnOperation()
override {
356 lowerLinalgToLoopsImpl<scf::ParallelOp>(getOperation());
365 return linalgOpToLoopsImpl<affine::AffineForOp>(rewriter, linalgOp);
371 return linalgOpToLoopsImpl<scf::ForOp>(rewriter, linalgOp);
378 return linalgOpToLoopsImpl<scf::ParallelOp>(rewriter, linalgOp);
static SmallVector< Value > makeCanonicalAffineApplies(OpBuilder &b, Location loc, AffineMap map, ArrayRef< Value > vals)
static void replaceIndexOpsByInductionVariables(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< Operation * > loopOps)
Replace the index operations in the body of the loop nest by the matching induction variables.
static InputAndOutputIndices getInputAndOutputIndices(OpBuilder &b, Location loc, ArrayRef< Value > allIvs, SingleInputPoolingOp op)
static FailureOr< LinalgLoops > linalgOpToLoopsImpl(RewriterBase &rewriter, LinalgOp linalgOp)
static void emitScalarImplementation(OpBuilder &b, Location loc, ArrayRef< Value > allIvs, LinalgOp linalgOp)
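Emits the MLIR for the scalar part of the generic op by: (1) loading a value for each input and output operand at the indices obtained by applying its indexing map to the induction variables (scalar inputs are used directly), (2) inlining the op's region on those values, and (3) storing the operands of the region's terminator to the output buffers.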
static void inlineRegionAndEmitStore(OpBuilder &b, Location loc, OpType op, ArrayRef< Value > indexedValues, ArrayRef< SmallVector< Value >> indexing, ArrayRef< Value > outputBuffers)
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isEmpty() const
Returns true if this affine map is an empty map, i.e., () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
unsigned getNumInputs() const
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This class provides support for representing a failure result, or a valid value of type T.
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
MLIRContext * getContext()
Return the context this operation is associated with.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MutableArrayRef< OpOperand > getOpOperands()
result_range getResults()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePattern is the common base class for all DAG to DAG replacements.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void canonicalizeMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Modifies both map and operands in-place so as to:
FailureOr< LinalgLoops > linalgOpToLoops(RewriterBase &rewriter, LinalgOp linalgOp)
Emit a loop nest of scf.for with the proper body for linalgOp.
FailureOr< LinalgLoops > linalgOpToAffineLoops(RewriterBase &rewriter, LinalgOp linalgOp)
Emit a loop nest of affine.for with the proper body for linalgOp.
FailureOr< LinalgLoops > linalgOpToParallelLoops(RewriterBase &rewriter, LinalgOp linalgOp)
Emit a loop nest of scf.parallel with the proper body for linalgOp.
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
static void doit(OpBuilder &b, Location loc, ArrayRef< Range > loopRanges, LinalgOp linalgOp, ArrayRef< utils::IteratorType > iteratorTypes, function_ref< scf::ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilderFn, ArrayRef< linalg::ProcInfo > procInfo={})