#include "llvm/ADT/TypeSwitch.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTLINALGTOAFFINELOOPSPASS
#define GEN_PASS_DEF_CONVERTLINALGTOLOOPSPASS
#define GEN_PASS_DEF_CONVERTLINALGTOPARALLELLOOPSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;
static SmallVector<Value> makeCanonicalAffineApplies(OpBuilder &b, Location loc,
                                                     AffineMap map,
                                                     ArrayRef<Value> vals) {
  if (map.isEmpty())
    return {};

  assert(map.getNumInputs() == vals.size());
  SmallVector<Value> res;
  res.reserve(map.getNumResults());
  for (AffineExpr e : map.getResults()) {
    // Emit one affine.apply per result expression, canonicalizing each
    // single-result submap together with its operands first.
    auto exprMap = AffineMap::get(map.getNumDims(), map.getNumSymbols(), e);
    SmallVector<Value> operands(vals.begin(), vals.end());
    affine::canonicalizeMapAndOperands(&exprMap, &operands);
    res.push_back(b.create<affine::AffineApplyOp>(loc, exprMap, operands));
  }
  return res;
}
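// For illustration (assumed input, not from the original source): applying the
// map (d0, d1, d2) -> (d0 + d2, d1) to vals = {%i, %j, %k} emits one
// affine.apply per result expression:
//   %0 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %k)
//   %1 = affine.apply affine_map<(d0) -> (d0)>(%j)
// canonicalizeMapAndOperands has already dropped the operands unused by each
// single-result submap; the FoldAffineOp pattern below later folds the
// trivial %1 back to %j.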
template <typename LoadOpTy, typename StoreOpTy, typename OpType>
static void inlineRegionAndEmitStore(OpBuilder &b, Location loc, OpType op,
                                     ArrayRef<Value> indexedValues,
                                     ArrayRef<SmallVector<Value>> indexing,
                                     ArrayRef<Value> outputBuffers) {
  auto &block = op->getRegion(0).front();
  IRMapping map;
  map.map(block.getArguments(), indexedValues);
  for (auto &op : block.without_terminator()) {
    auto *newOp = b.clone(op, map);
    map.map(op.getResults(), newOp->getResults());
  }

  Operation *terminator = block.getTerminator();
  for (OpOperand &operand : terminator->getOpOperands()) {
    Value toStore = map.lookupOrDefault(operand.get());
    b.create<StoreOpTy>(loc, toStore, outputBuffers[operand.getOperandNumber()],
                        indexing[operand.getOperandNumber()]);
  }
}
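// For illustration (assumed body, not from the original source): if the
// payload block ends in `linalg.yield %sum : f32`, with output buffer %out and
// computed indices (%i, %j), the cloned body defines %sum and the single
// terminator operand becomes:
//   memref.store %sum, %out[%i, %j] : memref<?x?xf32>
// (or an affine.store when lowering to affine.for loops).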
// Indices for a single-input pooling op.
struct InputAndOutputIndices {
  SmallVector<Value> inputs;
  SmallVector<Value> outputs;
};

template <typename SingleInputPoolingOp>
static InputAndOutputIndices
getInputAndOutputIndices(OpBuilder &b, Location loc, ArrayRef<Value> allIvs,
                         SingleInputPoolingOp op) {
  auto mapsRange = op.getIndexingMapsArray();
  auto maps = llvm::to_vector<8>(
      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
  return InputAndOutputIndices{
      makeCanonicalAffineApplies(b, loc, maps[0], allIvs),
      makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
}
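// For illustration (assumed maps, not from the original source): a 1-D pooling
// op typically carries three indexing maps, e.g.
//   input:  (d0, d1) -> (d0 + d1)
//   window: (d0, d1) -> (d1)
//   output: (d0, d1) -> (d0)
// The window map (maps[1]) only shapes the iteration space, so no indices are
// materialized for it here.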
/// Emits the MLIR for the scalar part of the generic op by:
///   1. Emitting load ops for each input and output view in order. This is
///      achieved by applying the appropriate indexing map to the enclosing
///      loop induction variables.
///   2. Inlining the region the op carries, remapping its block arguments to
///      the loaded scalars.
///   3. Emitting store ops to write the results back to the output views.
template <typename LoadOpTy, typename StoreOpTy>
static void emitScalarImplementation(OpBuilder &b, Location loc,
                                     ArrayRef<Value> allIvs,
                                     LinalgOp linalgOp) {
  assert(linalgOp.hasPureBufferSemantics() &&
         "expected linalg op with buffer semantics");
  SmallVector<Value> indexedValues;
  indexedValues.reserve(linalgOp->getNumOperands());

  SmallVector<Value> allIvsPlusDims(allIvs.begin(), allIvs.end());

  // 1.a. Emit load from input operand or, for scalars, access the operand
  // itself.
  for (OpOperand *inputOperand : linalgOp.getDpsInputOperands()) {
    if (linalgOp.isScalar(inputOperand)) {
      indexedValues.push_back(inputOperand->get());
      continue;
    }
    auto indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getMatchingIndexingMap(inputOperand), allIvsPlusDims);
    indexedValues.push_back(
        b.create<LoadOpTy>(loc, inputOperand->get(), indexing));
  }
  // 1.b. Emit load from output views.
  for (OpOperand &outputOperand : linalgOp.getDpsInitsMutable()) {
    SmallVector<Value> indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getMatchingIndexingMap(&outputOperand),
        allIvsPlusDims);
    indexedValues.push_back(
        b.create<LoadOpTy>(loc, outputOperand.get(), indexing));
  }

  // 2. Inline the region and 3. emit the stores.
  SmallVector<SmallVector<Value>, 8> indexing;
  SmallVector<Value> outputBuffers;
  for (OpOperand &outputOperand : linalgOp.getDpsInitsMutable()) {
    if (!isa<MemRefType>(outputOperand.get().getType()))
      continue;
    indexing.push_back(makeCanonicalAffineApplies(
        b, loc, linalgOp.getMatchingIndexingMap(&outputOperand),
        allIvsPlusDims));
    outputBuffers.push_back(outputOperand.get());
  }
  inlineRegionAndEmitStore<LoadOpTy, StoreOpTy>(b, loc, linalgOp, indexedValues,
                                                indexing, outputBuffers);
}
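// For illustration (assumed op, not from the original source): for a pointwise
// addition linalg.generic on 2-D memrefs, the body emitted at induction
// variables (%i, %j) is roughly:
//   %a = memref.load %lhs[%i, %j] : memref<?x?xf32>   // 1.a input loads
//   %b = memref.load %rhs[%i, %j] : memref<?x?xf32>
//   %c = memref.load %out[%i, %j] : memref<?x?xf32>   // 1.b output load
//   %s = arith.addf %a, %b : f32                      // 2. inlined region
//   memref.store %s, %out[%i, %j] : memref<?x?xf32>   // 3. store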
/// Replace the index operations in the body of the loop nest by the matching
/// induction variables.
static void replaceIndexOpsByInductionVariables(RewriterBase &rewriter,
                                                LinalgOp linalgOp,
                                                ArrayRef<Operation *> loopOps) {
  // Extract the induction variables of the loop nest from outer to inner.
  SmallVector<Value> allIvs;
  for (Operation *loopOp : loopOps) {
    llvm::TypeSwitch<Operation *>(loopOp)
        .Case([&](scf::ParallelOp parallelOp) {
          allIvs.append(parallelOp.getInductionVars());
        })
        .Case([&](scf::ForOp forOp) {
          allIvs.push_back(forOp.getInductionVar());
        })
        .Case([&](affine::AffineForOp affineForOp) {
          allIvs.push_back(affineForOp.getInductionVar());
        })
        .Default([&](Operation *op) { assert(false && "unexpected op"); });
  }
  assert(linalgOp.getNumLoops() == allIvs.size() &&
         "expected the number of loops and induction variables to match");
  // Replace the index operations in the body of the innermost loop op.
  if (!loopOps.empty()) {
    auto loopOp = cast<LoopLikeOpInterface>(loopOps.back());
    for (Region *r : loopOp.getLoopRegions())
      for (IndexOp indexOp : llvm::make_early_inc_range(r->getOps<IndexOp>()))
        rewriter.replaceOp(indexOp, allIvs[indexOp.getDim()]);
  }
}
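// For illustration (assumed body, not from the original source): after the
// loop nest is built, a `%d = linalg.index 1 : index` left in the cloned body
// is replaced by the induction variable of loop dimension 1, e.g. %j.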
template <typename LoopTy>
static FailureOr<LinalgLoops> linalgOpToLoopsImpl(RewriterBase &rewriter,
                                                  LinalgOp linalgOp) {
  using LoadOpTy =
      std::conditional_t<std::is_same<LoopTy, affine::AffineForOp>::value,
                         affine::AffineLoadOp, memref::LoadOp>;
  using StoreOpTy =
      std::conditional_t<std::is_same<LoopTy, affine::AffineForOp>::value,
                         affine::AffineStoreOp, memref::StoreOp>;

  assert(linalgOp.hasPureBufferSemantics() &&
         "expected linalg op with buffer semantics");

  auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc());
  auto iteratorTypes = linalgOp.getIteratorTypesArray();

  SmallVector<Value> allIvs;
  GenerateLoopNest<LoopTy>::doit(
      rewriter, linalgOp.getLoc(), loopRanges, linalgOp, iteratorTypes,
      [&](OpBuilder &b, Location loc, ValueRange ivs,
          ValueRange operandValuesToUse) -> scf::ValueVector {
        assert(operandValuesToUse == linalgOp->getOperands() &&
               "expect operands are captured and not passed by loop argument");
        allIvs.append(ivs.begin(), ivs.end());
        emitScalarImplementation<LoadOpTy, StoreOpTy>(b, loc, allIvs, linalgOp);
        return scf::ValueVector{};
      });
  // The number of loop ops might differ from the number of ivs since ops like
  // scf.parallel carry multiple ivs; collect each loop op only once.
  SetVector<Operation *> loopSet;
  for (Value iv : allIvs) {
    if (!iv)
      return failure();
    // The induction variable is a block argument of the entry block of the
    // loop operation.
    BlockArgument ivVal = dyn_cast<BlockArgument>(iv);
    if (!ivVal)
      return failure();
    loopSet.insert(ivVal.getOwner()->getParentOp());
  }
  LinalgLoops loops(loopSet.begin(), loopSet.end());
  // Replace all index operations in the loop body.
  replaceIndexOpsByInductionVariables(rewriter, linalgOp, loops);
  return loops;
}
namespace {
template <typename LoopType>
class LinalgRewritePattern : public RewritePattern {
public:
  LinalgRewritePattern(MLIRContext *context)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!isa<LinalgOp>(op) || !linalgOp.hasPureBufferSemantics()) {
      return rewriter.notifyMatchFailure(
          op, "expected linalg op with buffer semantics");
    }
    if (failed(linalgOpToLoopsImpl<LoopType>(rewriter, linalgOp)))
      return failure();
    rewriter.eraseOp(op);
    return success();
  }
};
/// Local folding pattern for AffineApplyOp that we can apply greedily.
/// This replaces AffineApplyOp by the proper value in cases where the
/// associated map is trivial: a single-result map that is either a constant
/// with zero inputs, or a single dim or symbol expression forwarding its one
/// operand unchanged.
struct FoldAffineOp : public RewritePattern {
  FoldAffineOp(MLIRContext *context)
      : RewritePattern(affine::AffineApplyOp::getOperationName(), 0, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto affineApplyOp = cast<affine::AffineApplyOp>(op);
    auto map = affineApplyOp.getAffineMap();
    if (map.getNumResults() != 1 || map.getNumInputs() > 1)
      return failure();

    AffineExpr expr = map.getResult(0);
    if (map.getNumInputs() == 0) {
      if (auto val = dyn_cast<AffineConstantExpr>(expr)) {
        rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, val.getValue());
        return success();
      }
      return failure();
    }
    if (dyn_cast<AffineDimExpr>(expr) || dyn_cast<AffineSymbolExpr>(expr)) {
      rewriter.replaceOp(op, op->getOperand(0));
      return success();
    }
    return failure();
  }
};
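// For illustration (assumed IR, not from the original source), the three
// trivial shapes this pattern folds:
//   affine.apply affine_map<() -> (42)>()          -> arith.constant 42 : index
//   affine.apply affine_map<(d0) -> (d0)>(%v)      -> replaced by %v
//   affine.apply affine_map<()[s0] -> (s0)>()[%v]  -> replaced by %v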
template <typename LoopType>
static void lowerLinalgToLoopsImpl(Operation *enclosingOp) {
  MLIRContext *context = enclosingOp->getContext();
  RewritePatternSet patterns(context);
  patterns.add<LinalgRewritePattern<LoopType>>(context);
  memref::DimOp::getCanonicalizationPatterns(patterns, context);
  tensor::DimOp::getCanonicalizationPatterns(patterns, context);
  affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
  patterns.add<FoldAffineOp>(context);
  // Just apply the patterns greedily.
  (void)applyPatternsGreedily(enclosingOp, std::move(patterns));
}
struct LowerToAffineLoops
    : public impl::ConvertLinalgToAffineLoopsPassBase<LowerToAffineLoops> {
  using impl::ConvertLinalgToAffineLoopsPassBase<
      LowerToAffineLoops>::ConvertLinalgToAffineLoopsPassBase;
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }
  void runOnOperation() override {
    lowerLinalgToLoopsImpl<affine::AffineForOp>(getOperation());
  }
};

struct LowerToLoops : public impl::ConvertLinalgToLoopsPassBase<LowerToLoops> {
  using impl::ConvertLinalgToLoopsPassBase<
      LowerToLoops>::ConvertLinalgToLoopsPassBase;
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, scf::SCFDialect>();
  }
  void runOnOperation() override {
    lowerLinalgToLoopsImpl<scf::ForOp>(getOperation());
  }
};

struct LowerToParallelLoops
    : public impl::ConvertLinalgToParallelLoopsPassBase<LowerToParallelLoops> {
  using impl::ConvertLinalgToParallelLoopsPassBase<
      LowerToParallelLoops>::ConvertLinalgToParallelLoopsPassBase;
  void runOnOperation() override {
    lowerLinalgToLoopsImpl<scf::ParallelOp>(getOperation());
  }
};
} // namespace
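// Hedged usage sketch (not part of the original file): these passes are
// normally driven from the command line, e.g.
//   mlir-opt -convert-linalg-to-loops
//   mlir-opt -convert-linalg-to-affine-loops
//   mlir-opt -convert-linalg-to-parallel-loops
// or from C++, assuming the tablegen-generated pass factory is available:
//   PassManager pm(module->getContext());
//   pm.addNestedPass<func::FuncOp>(createConvertLinalgToLoopsPass());
//   if (failed(pm.run(*module)))
//     signalPassFailure();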
/// Emit a loop nest of `affine.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops>
mlir::linalg::linalgOpToAffineLoops(RewriterBase &rewriter, LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<affine::AffineForOp>(rewriter, linalgOp);
}

/// Emit a loop nest of `scf.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops> mlir::linalg::linalgOpToLoops(RewriterBase &rewriter,
                                                     LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<scf::ForOp>(rewriter, linalgOp);
}

/// Emit a loop nest of `scf.parallel` with the proper body for `linalgOp`.
FailureOr<LinalgLoops>
mlir::linalg::linalgOpToParallelLoops(RewriterBase &rewriter,
                                      LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<scf::ParallelOp>(rewriter, linalgOp);
}
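// Hedged usage sketch (assumed client code; `LowerGenericToLoops` is a
// hypothetical name): calling one of the entry points above from a custom
// rewrite pattern.
namespace {
struct LowerGenericToLoops : OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.hasPureBufferSemantics())
      return rewriter.notifyMatchFailure(op, "expected buffer semantics");
    // Emit an scf.for nest; the returned ops expose the generated loops.
    FailureOr<LinalgLoops> loops = linalgOpToLoops(rewriter, op);
    if (failed(loops))
      return failure();
    // The loop nest now carries the computation; the linalg op is dead.
    rewriter.eraseOp(op);
    return success();
  }
};
} // namespace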