#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTLINALGTOAFFINELOOPSPASS
#define GEN_PASS_DEF_CONVERTLINALGTOLOOPSPASS
#define GEN_PASS_DEF_CONVERTLINALGTOPARALLELLOOPSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;
static SmallVector<Value> makeCanonicalAffineApplies(OpBuilder &b, Location loc,
                                                     AffineMap map,
                                                     ArrayRef<Value> vals) {
  if (map.isEmpty())
    return {};

  assert(map.getNumInputs() == vals.size());
  SmallVector<Value> res;
  res.reserve(map.getNumResults());
  unsigned dims = map.getNumDims();
  // Emit one canonicalized affine.apply per result expression of `map`.
  for (AffineExpr e : map.getResults()) {
    auto exprMap = AffineMap::get(dims, map.getNumSymbols(), e);
    SmallVector<Value> operands(vals.begin(), vals.end());
    affine::canonicalizeMapAndOperands(&exprMap, &operands);
    res.push_back(b.create<affine::AffineApplyOp>(loc, exprMap, operands));
  }
  return res;
}
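// For illustration (a hand-written sketch, not verbatim pass output): applied
// to affine_map<(d0, d1) -> (d0 + d1, d1)> with induction variables %i, %j,
// this helper emits one affine.apply per result expression, roughly:
//   %0 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %j)
//   %1 = affine.apply affine_map<(d0) -> (d0)>(%j)
// Trivial applies such as %1 are cleaned up later by FoldAffineOp below.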
/// Inline the region of `op` (currently restricted to a single basic block),
/// remapping its block arguments to `indexedValues`, then emit one store per
/// terminator operand into the matching output buffer.
template <typename LoadOpTy, typename StoreOpTy, typename OpType>
static void inlineRegionAndEmitStore(OpBuilder &b, Location loc, OpType op,
                                     ArrayRef<Value> indexedValues,
                                     ArrayRef<SmallVector<Value>> indexing,
                                     ArrayRef<Value> outputBuffers) {
  auto &block = op->getRegion(0).front();
  IRMapping map;
  map.map(block.getArguments(), indexedValues);
  for (auto &op : block.without_terminator()) {
    Operation *newOp = b.clone(op, map);
    map.map(op.getResults(), newOp->getResults());
  }

  Operation *terminator = block.getTerminator();
  for (OpOperand &operand : terminator->getOpOperands()) {
    Value toStore = map.lookupOrDefault(operand.get());
    b.create<StoreOpTy>(loc, toStore, outputBuffers[operand.getOperandNumber()],
                        indexing[operand.getOperandNumber()]);
  }
}
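// For example (a sketch, assuming a linalg.generic whose body is):
//   %0 = arith.addf %in0, %in1 : f32
//   linalg.yield %0 : f32
// The block arguments %in0/%in1 are remapped to the previously loaded
// scalars, the arith.addf is cloned, and the yielded value is stored back:
//   %sum = arith.addf %a, %b : f32
//   memref.store %sum, %out[%i, %j] : memref<?x?xf32>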
namespace {
/// Per-iteration input and output access indices computed from the op's
/// indexing maps.
struct InputAndOutputIndices {
  SmallVector<Value> inputs;
  SmallVector<Value> outputs;
};
} // namespace

template <typename SingleInputPoolingOp>
static InputAndOutputIndices
getInputAndOutputIndices(OpBuilder &b, Location loc, ArrayRef<Value> allIvs,
                         SingleInputPoolingOp op) {
  auto mapsRange = op.getIndexingMapsArray();
  auto maps = llvm::to_vector<8>(
      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
  return InputAndOutputIndices{
      makeCanonicalAffineApplies(b, loc, maps[0], allIvs),
      makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
}
/// Emits the MLIR for the scalar part of the generic op by:
///   1. Emitting load ops for each input and output view in order, obtained
///      by applying the matching indexing map to the enclosing induction
///      variables.
///   2. Inlining the op's region with the loaded scalars as arguments.
///   3. Emitting store ops for the results into the output views.
template <typename LoadOpTy, typename StoreOpTy>
static void emitScalarImplementation(OpBuilder &b, Location loc,
                                     ArrayRef<Value> allIvs,
                                     LinalgOp linalgOp) {
  assert(linalgOp.hasPureBufferSemantics() &&
         "expected linalg op with buffer semantics");
  SmallVector<Value> indexedValues;
  indexedValues.reserve(linalgOp->getNumOperands());

  SmallVector<Value> allIvsPlusDims(allIvs.begin(), allIvs.end());

  // 1.a. Emit load from input operand or, for scalars, use the operand
  // itself.
  for (OpOperand *inputOperand : linalgOp.getDpsInputOperands()) {
    if (linalgOp.isScalar(inputOperand)) {
      indexedValues.push_back(inputOperand->get());
      continue;
    }
    auto indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getMatchingIndexingMap(inputOperand), allIvsPlusDims);
    indexedValues.push_back(
        b.create<LoadOpTy>(loc, inputOperand->get(), indexing));
  }
  // 1.b. Emit load from output views.
  for (OpOperand &outputOperand : linalgOp.getDpsInitsMutable()) {
    SmallVector<Value> indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getMatchingIndexingMap(&outputOperand),
        allIvsPlusDims);
    indexedValues.push_back(
        b.create<LoadOpTy>(loc, outputOperand.get(), indexing));
  }

  // 2. Inline the region (currently a single basic block) and 3. emit the
  // stores.
  SmallVector<SmallVector<Value>, 8> indexing;
  SmallVector<Value> outputBuffers;
  for (OpOperand &outputOperand : linalgOp.getDpsInitsMutable()) {
    if (!isa<MemRefType>(outputOperand.get().getType()))
      continue;
    indexing.push_back(makeCanonicalAffineApplies(
        b, loc, linalgOp.getMatchingIndexingMap(&outputOperand),
        allIvsPlusDims));
    outputBuffers.push_back(outputOperand.get());
  }
  inlineRegionAndEmitStore<LoadOpTy, StoreOpTy>(b, loc, linalgOp, indexedValues,
                                                indexing, outputBuffers);
}
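// Putting the three steps together (a sketch, not verbatim output): for an
// elementwise add over 1-d memrefs, the scalar body emitted at iv %i is
// roughly:
//   %a = memref.load %A[%i] : memref<?xf32>
//   %b = memref.load %B[%i] : memref<?xf32>
//   %c = memref.load %C[%i] : memref<?xf32>    // 1.b load from output view
//   %sum = arith.addf %a, %b : f32             // inlined region
//   memref.store %sum, %C[%i] : memref<?xf32>  // store to output buffer
// (with affine.load/affine.store instead when lowering to affine.for).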
/// Replace the index operations in the body of the loop nest by the matching
/// induction variables.
static void replaceIndexOpsByInductionVariables(RewriterBase &rewriter,
                                                LinalgOp linalgOp,
                                                ArrayRef<Operation *> loopOps) {
  // Extract the induction variables of the loop nest from outer to inner.
  SmallVector<Value> allIvs;
  for (Operation *loopOp : loopOps) {
    llvm::TypeSwitch<Operation *>(loopOp)
        .Case([&](scf::ParallelOp parallelOp) {
          allIvs.append(parallelOp.getInductionVars());
        })
        .Case([&](scf::ForOp forOp) {
          allIvs.push_back(forOp.getInductionVar());
        })
        .Case([&](affine::AffineForOp affineForOp) {
          allIvs.push_back(affineForOp.getInductionVar());
        })
        .Default([&](Operation *op) { assert(false && "unexpected op"); });
  }
  assert(linalgOp.getNumLoops() == allIvs.size() &&
         "expected the number of loops and induction variables to match");
  // Replace the index operations in the body of the innermost loop op.
  if (!loopOps.empty()) {
    auto loopOp = cast<LoopLikeOpInterface>(loopOps.back());
    for (Region *r : loopOp.getLoopRegions())
      for (IndexOp indexOp : llvm::make_early_inc_range(r->getOps<IndexOp>()))
        rewriter.replaceOp(indexOp, allIvs[indexOp.getDim()]);
  }
}
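// For example (sketch): in the nest produced for a 2-d op,
//   scf.for %i = ... {
//     scf.for %j = ... {
//       %0 = linalg.index 0 : index
//       %1 = linalg.index 1 : index
// each linalg.index is replaced by the induction variable for its dimension,
// i.e. uses of %0 become %i and uses of %1 become %j.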
template <typename LoopTy>
static FailureOr<LinalgLoops> linalgOpToLoopsImpl(RewriterBase &rewriter,
                                                  LinalgOp linalgOp) {
  using LoadOpTy =
      std::conditional_t<std::is_same<LoopTy, affine::AffineForOp>::value,
                         affine::AffineLoadOp, memref::LoadOp>;
  using StoreOpTy =
      std::conditional_t<std::is_same<LoopTy, affine::AffineForOp>::value,
                         affine::AffineStoreOp, memref::StoreOp>;

  assert(linalgOp.hasPureBufferSemantics() &&
         "expected linalg op with buffer semantics");

  auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc());
  auto iteratorTypes = linalgOp.getIteratorTypesArray();

  SmallVector<Value> allIvs;
  GenerateLoopNest<LoopTy>::doit(
      rewriter, linalgOp.getLoc(), loopRanges, linalgOp, iteratorTypes,
      [&](OpBuilder &b, Location loc, ValueRange ivs,
          ValueRange operandValuesToUse) -> scf::ValueVector {
        assert(operandValuesToUse == linalgOp->getOperands() &&
               "expect operands are captured and not passed by loop argument");
        allIvs.append(ivs.begin(), ivs.end());
        emitScalarImplementation<LoadOpTy, StoreOpTy>(b, loc, allIvs, linalgOp);
        return scf::ValueVector{};
      });

  // The number of loop ops may differ from the number of ivs since ops such
  // as scf.parallel carry multiple induction variables.
  llvm::SetVector<Operation *> loopSet;
  for (Value iv : allIvs) {
    if (!iv)
      return failure();
    // The induction variable is a block argument of the entry block of the
    // loop operation.
    auto ivVal = dyn_cast<BlockArgument>(iv);
    if (!ivVal)
      return failure();
    loopSet.insert(ivVal.getOwner()->getParentOp());
  }
  LinalgLoops loops(loopSet.begin(), loopSet.end());
  // Replace all index operations in the loop body.
  replaceIndexOpsByInductionVariables(rewriter, linalgOp, loops);
  return loops;
}
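// End-to-end (a sketch, assuming a linalg.matmul on memref<?x?xf32> buffers):
// linalgOpToLoopsImpl<scf::ForOp> produces roughly
//   scf.for %i = %c0 to %M step %c1 {
//     scf.for %j = %c0 to %N step %c1 {
//       scf.for %k = %c0 to %K step %c1 {
//         %a = memref.load %A[%i, %k] : memref<?x?xf32>
//         %b = memref.load %B[%k, %j] : memref<?x?xf32>
//         %c = memref.load %C[%i, %j] : memref<?x?xf32>
//         %0 = arith.mulf %a, %b : f32
//         %1 = arith.addf %c, %0 : f32
//         memref.store %1, %C[%i, %j] : memref<?x?xf32>
// and the returned LinalgLoops holds the three scf.for ops, outermost first.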
namespace {
/// Converts any LinalgOp with pure buffer semantics into a loop nest of type
/// `LoopType`.
template <typename LoopType>
class LinalgRewritePattern : public RewritePattern {
public:
  LinalgRewritePattern(MLIRContext *context)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp || !linalgOp.hasPureBufferSemantics()) {
      return rewriter.notifyMatchFailure(
          op, "expected linalg op with buffer semantics");
    }
    if (failed(linalgOpToLoopsImpl<LoopType>(rewriter, linalgOp)))
      return failure();
    rewriter.eraseOp(op);
    return success();
  }
};
/// Local folding pattern for AffineApplyOp that we can apply greedily.
/// This replaces AffineApplyOp by the proper value in cases where the
/// associated map is trivial.
/// A trivial map here is defined as a map with a single result and either:
///   1. Zero operands + a single AffineConstantExpr result
///   2. One operand + a single AffineDimExpr result
///   3. One operand + a single AffineSymbolExpr result
struct FoldAffineOp : public RewritePattern {
  FoldAffineOp(MLIRContext *context)
      : RewritePattern(affine::AffineApplyOp::getOperationName(), 0, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto affineApplyOp = cast<affine::AffineApplyOp>(op);
    AffineMap map = affineApplyOp.getAffineMap();
    if (map.getNumResults() != 1 || map.getNumInputs() > 1)
      return failure();

    AffineExpr expr = map.getResult(0);
    if (map.getNumInputs() == 0) {
      if (auto val = dyn_cast<AffineConstantExpr>(expr)) {
        rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op,
                                                            val.getValue());
        return success();
      }
      return failure();
    }
    if (isa<AffineDimExpr, AffineSymbolExpr>(expr)) {
      rewriter.replaceOp(op, op->getOperand(0));
      return success();
    }
    return failure();
  }
};
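// Concretely (sketch): this pattern rewrites
//   %0 = affine.apply affine_map<() -> (42)>()      // case 1
// into "%0 = arith.constant 42 : index", and
//   %1 = affine.apply affine_map<(d0) -> (d0)>(%i)  // cases 2/3
// into plain uses of %i; composed maps such as (d0) -> (d0 + 1) are left
// alone.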
template <typename LoopType>
static void lowerLinalgToLoopsImpl(Operation *enclosingOp) {
  MLIRContext *context = enclosingOp->getContext();
  RewritePatternSet patterns(context);
  patterns.add<LinalgRewritePattern<LoopType>>(context);
  memref::DimOp::getCanonicalizationPatterns(patterns, context);
  tensor::DimOp::getCanonicalizationPatterns(patterns, context);
  affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
  patterns.add<FoldAffineOp>(context);
  // Just apply the patterns greedily.
  (void)applyPatternsGreedily(enclosingOp, std::move(patterns));
}
struct LowerToAffineLoops
    : public impl::ConvertLinalgToAffineLoopsPassBase<LowerToAffineLoops> {
  using impl::ConvertLinalgToAffineLoopsPassBase<
      LowerToAffineLoops>::ConvertLinalgToAffineLoopsPassBase;
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }
  void runOnOperation() override {
    lowerLinalgToLoopsImpl<affine::AffineForOp>(getOperation());
  }
};
struct LowerToLoops : public impl::ConvertLinalgToLoopsPassBase<LowerToLoops> {
  using impl::ConvertLinalgToLoopsPassBase<
      LowerToLoops>::ConvertLinalgToLoopsPassBase;
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, scf::SCFDialect>();
  }
  void runOnOperation() override {
    lowerLinalgToLoopsImpl<scf::ForOp>(getOperation());
  }
};
struct LowerToParallelLoops
    : public impl::ConvertLinalgToParallelLoopsPassBase<LowerToParallelLoops> {
  using impl::ConvertLinalgToParallelLoopsPassBase<
      LowerToParallelLoops>::ConvertLinalgToParallelLoopsPassBase;
  void runOnOperation() override {
    lowerLinalgToLoopsImpl<scf::ParallelOp>(getOperation());
  }
};

} // namespace
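// From the command line, the three passes registered above can be exercised
// as, e.g. (illustrative invocations):
//   mlir-opt input.mlir -convert-linalg-to-loops
//   mlir-opt input.mlir -convert-linalg-to-affine-loops
//   mlir-opt input.mlir -convert-linalg-to-parallel-loops
// Each expects linalg ops in buffer (memref) form, e.g. after bufferization.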
/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops>
mlir::linalg::linalgOpToAffineLoops(RewriterBase &rewriter, LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<affine::AffineForOp>(rewriter, linalgOp);
}

/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops> mlir::linalg::linalgOpToLoops(RewriterBase &rewriter,
                                                     LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<scf::ForOp>(rewriter, linalgOp);
}

/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
FailureOr<LinalgLoops>
mlir::linalg::linalgOpToParallelLoops(RewriterBase &rewriter,
                                      LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<scf::ParallelOp>(rewriter, linalgOp);
}
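// Programmatic use (a minimal sketch; `rewriter` and `linalgOp` are assumed
// to come from an enclosing rewrite):
//
//   FailureOr<LinalgLoops> loops =
//       linalg::linalgOpToParallelLoops(rewriter, linalgOp);
//   if (failed(loops))
//     return failure();
//   // On success the original op still exists; callers such as
//   // LinalgRewritePattern above erase it explicitly.
//   rewriter.eraseOp(linalgOp);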