29 #include "llvm/ADT/MapVector.h" 30 #include "llvm/ADT/ScopeExit.h" 31 #include "llvm/Support/CommandLine.h" 32 #include "llvm/Support/Debug.h" 36 #define DEBUG_TYPE "linalg-fusion" 69 bool fromSubViewOpOnly =
false) {
72 for (
OpOperand *opOperand : op.getInputAndOutputOperands()) {
79 if (fromSubViewOpOnly &&
80 !isa_and_nonnull<memref::SubViewOp, tensor::ExtractSliceOp>(
81 opOperand->get().getDefiningOp()))
84 AffineMap map = op.getTiedIndexingMap(opOperand);
85 LLVM_DEBUG(llvm::dbgs() <<
"getShapeDefiningLoopRange I/O idx: " 86 << opOperand->getOperandNumber() <<
"\n");
87 LLVM_DEBUG(llvm::dbgs()
88 <<
"getShapeDefiningLoopRange map: " << map <<
"\n");
94 if (loopDepth == en.value().cast<
AffineDimExpr>().getPosition()) {
95 LLVM_DEBUG(llvm::dbgs() <<
"getShapeDefiningLoopRange loopDepth: " 96 << loopDepth <<
"\n");
97 LLVM_DEBUG(llvm::dbgs() <<
"getShapeDefiningLoopRange shape: " 98 << opOperand->get() <<
"\n");
100 static_cast<unsigned>(en.index())};
104 llvm_unreachable(
"Expect to be able to extract a shape defining loop range");
108 return producer.getInputAndOutputOperands();
120 for (
unsigned i = 0, e = producer.getNumLoops(); i < e; ++i) {
124 sizeBounds.push_back(dim);
125 auto it = fusedLoopsAndRanges.find(i);
126 if (it != fusedLoopsAndRanges.end()) {
127 ivs.push_back(it->second.offset);
128 tileSizes.push_back(it->second.size);
129 loopRanges.push_back(it->second);
130 LLVM_DEBUG(llvm::dbgs() <<
"tiled loop#" << i <<
" with LoopRange " 131 << loopRanges.back() <<
"\n");
135 LLVM_DEBUG(llvm::dbgs() <<
"full loop#" << i <<
" with LoopRange " 136 << loopRanges.back() <<
"\n");
141 clonedShapes.reserve(producer.getNumInputsAndOutputs());
153 resultTypes.reserve(producer->getNumResults());
154 for (RankedTensorType t : producer.getOutputTensorTypes()) {
155 unsigned rank = t.getRank();
157 rank, ShapedType::kDynamicStrideOrOffset);
160 rank, ShapedType::kDynamicStrideOrOffset);
161 resultTypes.push_back(tensor::ExtractSliceOp::inferResultType(
162 t.cast<RankedTensorType>(), staticOffsetsVector, staticSizesVector,
163 staticStridesVector));
166 Operation *clonedOp = producer.
clone(b, loc, resultTypes, clonedShapes);
170 llvm::map_range(loopRanges, [&](
Range range) {
return range.
offset; }));
179 Value shapedOperand,
unsigned dim) {
181 if (
auto subViewOp = dyn_cast<memref::SubViewOp>(shapeProducingOp))
182 return subViewOp.getOrCreateRanges(b, loc)[dim];
183 if (
auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(shapeProducingOp))
184 return sliceOp.getOrCreateRanges(b, loc)[dim];
185 llvm_unreachable(
"SubviewOp or ExtractSliceOp expected");
193 LLVM_DEBUG(llvm::dbgs() <<
"Producer map: " << producerMap <<
"\n");
195 Value shapedOperand = consumerOpOperand.
get();
199 b, consumerOpOperand.
getOwner()->
getLoc(), shapedOperand, en.index());
201 return fuse(b, producerOp, fusedLoopsAndRanges);
208 assert(producer.hasBufferSemantics() &&
209 "expected linalg op with buffer semantics");
210 assert(consumer.hasBufferSemantics() &&
211 "expected linalg op with buffer semantics");
212 if (producer.getNumOutputs() != 1) {
213 LLVM_DEBUG(llvm::dbgs() <<
"\nNot structurally fusable (multi-output)");
218 if (!dom.dominates(producer->getBlock(), consumer->getBlock())) {
221 <<
"\nNot structurally fusable (producer block does not dominate)");
231 assert(producer.hasBufferSemantics() &&
232 "expected linalg op with buffer semantics");
233 assert(consumer.hasBufferSemantics() &&
234 "expected linalg op with buffer semantics");
238 LLVM_DEBUG(llvm::dbgs() <<
"\n***Not static last write due to structure:\t" 239 << *producer.getOperation());
244 LLVM_DEBUG(llvm::dbgs() <<
"\n***Not fusable due to interleaved write:\t" 245 << *producer.getOperation());
252 LinalgOp consumer,
Value consumedView,
254 assert(producer.hasBufferSemantics() &&
255 "expected linalg op with buffer semantics");
256 assert(consumer.hasBufferSemantics() &&
257 "expected linalg op with buffer semantics");
263 LLVM_DEBUG(llvm::dbgs()
264 <<
"\n***Not fusable due to an interleaved dependence:\t" 265 << *producer.getOperation());
277 LLVM_DEBUG(llvm::dbgs() <<
"findFusableProducer for: " 278 << consumerOpOperand.
get() <<
" @" 280 << *consumerOpOperand.
getOwner() <<
"\n");
281 LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.
getOwner());
286 for (
auto depType : {
287 LinalgDependenceGraph::DependenceType::RAW,
288 LinalgDependenceGraph::DependenceType::WAW,
290 LLVM_DEBUG(llvm::dbgs()
291 <<
"Dependencies into: " << *consumerOp.getOperation() <<
"\n");
292 for (
auto dependence : llvm::make_filter_range(
295 LLVM_DEBUG(llvm::dbgs() <<
"Inspect dependence btw: " 296 << elem.getIndexingValue() <<
" and " 297 << elem.getDependentValue() <<
"\n");
298 Value v = elem.getIndexingValue();
300 elem.getIndexingOpViewOperandNum();
301 return isa<LinalgOp>(elem.getDependentOp()) &&
302 v == consumerOpOperand.
get() && operandNum &&
307 auto producer = cast<LinalgOp>(dependence.getDependentOp());
308 LLVM_DEBUG(llvm::dbgs()
311 <<
"producer: " << *dependence.getDependentOp()
312 <<
" view: " << dependence.getDependentValue() <<
"\n");
317 if (producer.hasBufferSemantics() && consumerOp.hasBufferSemantics() &&
321 if (producer.hasTensorSemantics() && consumerOp.hasTensorSemantics()) {
322 assert(dependence.dependenceType ==
323 LinalgDependenceGraph::DependenceType::RAW);
336 if (!fusableDependence)
339 LinalgOp producerOp = dyn_cast<LinalgOp>(fusableDependence->getDependentOp());
345 fusableDependence->getDependentValue().getParentBlock())
349 fusableDependence->getDependentOpViewIndexingMap();
357 LLVM_DEBUG(llvm::dbgs() <<
"\nNot fusable (not a subview)");
364 LLVM_DEBUG(llvm::dbgs() <<
"Fuse into consumer: " 365 << *consumerOpOperand.
getOwner() <<
"\n");
367 auto fusedProducer =
fuse(b, producerOp, *producerMap, consumerOpOperand);
378 if (!tensor.
getType().
isa<RankedTensorType>())
382 LLVM_DEBUG(llvm::dbgs() <<
"\ngetProducerOfTensor: " << tensor);
387 if (
auto sliceOp = tensor.
getDefiningOp<tensor::ExtractSliceOp>()) {
388 tensor = sliceOp.getSource();
392 if (
auto forOp = blockArg.getDefiningOp<scf::ForOp>()) {
393 tensor = *(forOp.getIterOperands().begin() + blockArg.getArgNumber());
403 Value inputTensor = consumerOpOperand.
get();
406 if (!producerOpResult) {
407 LLVM_DEBUG(llvm::dbgs() <<
"\nUnable to find producer");
416 auto producerOp = dyn_cast<LinalgOp>(producerOpResult.
getOwner());
420 LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.
getOwner());
424 Value inputTensor = consumerOpOperand.
get();
427 auto sliceOp = inputTensor.
getDefiningOp<tensor::ExtractSliceOp>();
429 LLVM_DEBUG(llvm::dbgs()
430 <<
"\nNot fusable, not an extract_slice op: " << inputTensor);
442 LLVM_DEBUG(llvm::dbgs() <<
"Fuse into consumer: " << *consumerOp <<
"\n");
445 LinalgOp fusedProducer =
446 fuse(b, producerOp, producerOp.getTiedIndexingMap(opOperand),
456 if (consumerType != def.getType())
457 def = b.
create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def);
458 consumerOpOperand.
set(def);
Include the generated interface declarations.
dependence_range getDependencesInto(Operation *dst, DependenceType dt) const
Returns the X such that X -> op is a dependence of type dt.
Operation is a basic unit of execution within MLIR.
bool isProducerLastWriteOfView(const LinalgDependenceGraph &graph, LinalgOp consumer, Value consumedView, LinalgOp producer)
Checks whether the specific producer is the last write to exactly the whole consumedView.
This is a value defined by a result of an operation.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
static Range getRangeFromOperandShape(OpBuilder &b, Location loc, Value shapedOperand, unsigned dim)
Get the loop range for a dimension dim based on the shapedOperand.
This class represents a single result from folding an operation.
A class for computing basic dominance information.
SmallVector< Operation *, 8 > findCoveringDependences(LinalgOp srcLinalgOp, LinalgOp dstLinalgOp) const
Returns the operations that are interleaved between srcLinalgOp and dstLinalgOp and that are involved...
FailureOr< FusionInfo > fuseProducerOfBuffer(OpBuilder &b, OpOperand &consumerOpOperand, const LinalgDependenceGraph &graph)
Fuses producer into consumer if the producer is structurally feasible and the fusion would not violat...
Implements a simple high-level fusion pass on linalg structured operations.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Auxiliary range data structure to unpack the offset, size and stride operands into a list of triples...
Operation * getOwner() const
Returns the operation that owns this result.
static void getProducerOfTensor(Value tensor, OpResult &opResult)
Walk back use-def chain through scf::For yields.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer, Value consumedView, LinalgOp producer)
Checks whether fusing the specific producer of the consumedView is feasible.
This class provides support for representing a failure result, or a valid value of type T...
static FailureOr< LinalgDependenceGraph::LinalgDependenceGraphElem > findFusableProducer(OpOperand &consumerOpOperand, const LinalgDependenceGraph &dependenceGraph)
For consumer with buffer semantics, find the Linalg operation on buffers that is the last writer of c...
void set(IRValueT newValue)
Set the current value being used by this operand.
Operation * clone(BlockAndValueMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
unsigned getResultNumber() const
Returns the number of this result.
Block * getParentBlock()
Return the Block in which this Value is defined.
Data structure for holding a dependence graph that operates on LinalgOp and views as SSA values...
FailureOr< FusionInfo > fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand)
Tensor counterpart of fuseProducerOfBuffer.
unsigned getNumResults() const
Location getLoc()
The source location the operation was defined or derived from.
IRValueT get() const
Return the current value being used by this operand.
SmallVector< Value > makeTiledShapes(OpBuilder &builder, Location loc, LinalgOp linalgOp, ValueRange valuesToTile, ArrayRef< OpFoldResult > ivs, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > sizeBounds, bool omitPartialTileCheck)
Creates extract_slice/subview ops for all valuesToTile of the given linalgOp with builder...
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued...
This class represents an argument of a Block.
ArrayRef< AffineExpr > getResults() const
SmallVector< Operation *, 8 > findCoveringWrites(LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view) const
Returns the operations that are interleaved between srcLinalgOp and dstLinalgOp and that are involved...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
static LinalgOp fuse(OpBuilder &b, LinalgOp producer, const DenseMap< unsigned, Range > &fusedLoopsAndRanges)
Fuses the producer by cloning the producer.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value source, int64_t dim)
static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView, LinalgOp consumer)
RAII guard to reset the insertion point of the builder when destroyed.
Type getType() const
Return the type of this value.
A dimensional identifier appearing in an affine expression.
static ShapeDimension getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth, bool fromSubViewOpOnly=false)
Operation * getOwner() const
Return the owner of this operand.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This class represents an operand of an operation.
A struct containing the Linalg producer before and after fusion.
static SmallVector< Value > getTiledOperands(LinalgOp producer)
static StringRef getDependenceTypeStr(DependenceType depType)
void offsetIndices(OpBuilder &b, LinalgOp linalgOp, ArrayRef< OpFoldResult > offests)
Add the specified offsets to any linalg.index ops contained in the given linalgOp.
This class helps build Operations.
IntegerAttr getIndexAttr(int64_t value)