22 genericOp.getDpsInitOperand(
result.getResultNumber());
23 if (!genericOp.payloadUsesValueFromOperand(outputOpOperand))
34 genericOp.getRegionOutputArgs()[
result.getResultNumber()];
44 auto yieldOp = dyn_cast<linalg::YieldOp>(argUserOp);
49 if (yieldOp.getOperand(
result.getResultNumber()) != outputArg)
67 llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
68 llvm::SmallDenseMap<std::pair<Value, AffineMap>,
unsigned> dedupedInputs;
69 for (
const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) {
73 if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) {
76 droppedOpOperands.push_back(inputOpOperand);
77 if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
79 droppedOpOperands.pop_back();
83 AffineMap indexingMap = genericOp.getMatchingIndexingMap(inputOpOperand);
85 dedupedInputs.find(std::make_pair(inputOpOperand->
get(), indexingMap));
86 if (it != dedupedInputs.end()) {
87 origToNewPos[en.index()] = it->second;
88 droppedOpOperands.push_back(inputOpOperand);
93 origToNewPos[en.index()] = newInputOperands.size();
94 dedupedInputs[{inputOpOperand->
get(), indexingMap}] =
95 newInputOperands.size();
96 newInputOperands.push_back(inputOpOperand->
get());
97 newIndexingMaps.push_back(indexingMap);
110 llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
111 llvm::SmallDenseMap<std::tuple<Value, AffineMap, Value>,
unsigned>
115 if (!genericOp.hasPureTensorSemantics() || !removeOutputs) {
116 for (
const auto &en : llvm::enumerate(genericOp.getDpsInitsMutable())) {
117 origToNewPos[en.index()] = newOutputOperands.size();
118 newOutputOperands.push_back(en.value().get());
119 newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(&en.value()));
128 auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
129 for (
const auto &outputOpOperand :
130 llvm::enumerate(genericOp.getDpsInitsMutable())) {
131 OpResult result = genericOp.getTiedOpResult(&outputOpOperand.value());
133 genericOp.getMatchingIndexingMap(&outputOpOperand.value());
134 auto key = std::make_tuple(outputOpOperand.value().get(), indexingMap,
135 yieldOp->getOperand(outputOpOperand.index()));
141 droppedOpOperands.push_back(&outputOpOperand.value());
142 if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
145 droppedOpOperands.pop_back();
148 if (!genericOp.payloadUsesValueFromOperand(&outputOpOperand.value())) {
154 auto it = dedupedOutpts.find(key);
155 if (it != dedupedOutpts.end()) {
156 origToNewPos[outputOpOperand.index()] = it->second;
157 droppedOpOperands.push_back(&outputOpOperand.value());
162 origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
163 dedupedOutpts[key] = newOutputOperands.size();
164 newOutputOperands.push_back(outputOpOperand.value().get());
165 newIndexingMaps.push_back(
166 genericOp.getMatchingIndexingMap(&outputOpOperand.value()));
173 GenericOp genericOp, GenericOp newOp,
174 const llvm::SmallDenseMap<unsigned, unsigned> &origInsToNewInsPos,
175 const llvm::SmallDenseMap<unsigned, unsigned> &origOutsToNewOutsPos,
178 Block *newOpBlock = &newOp.getRegion().front();
179 assert(newOpBlock->
empty() &&
"expected new op to have an empty payload");
180 Block *origOpBlock = &genericOp.getRegion().front();
185 auto updateReplacements =
188 const llvm::SmallDenseMap<unsigned, unsigned> &map) {
189 for (
const auto &origOperand : llvm::enumerate(origOperands)) {
190 auto it = map.find(origOperand.index());
193 OpOperand *newOperand = newOperands[it->second];
194 replacements[origOperand.value()->getOperandNumber()] =
201 updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos);
204 genericOp.getDpsInitsMutable(), [](
OpOperand &o) { return &o; }));
206 newOp.getDpsInitsMutable(), [](
OpOperand &o) { return &o; }));
207 updateReplacements(origOutputOperands, newOutputOperands,
208 origOutsToNewOutsPos);
211 if (newOp.getNumDpsInits() != genericOp.getNumDpsInits()) {
213 YieldOp origYieldOp = cast<YieldOp>(origOpBlock->
getTerminator());
217 for (
const auto &yieldOpOperands :
218 llvm::enumerate(origYieldOp.getValues())) {
219 auto it = origOutsToNewOutsPos.find(yieldOpOperands.index());
220 if (it == origOutsToNewOutsPos.end())
222 newYieldVals[it->second] = yieldOpOperands.value();
227 rewriter.
mergeBlocks(origOpBlock, newOpBlock, replacements);
230FailureOr<linalg::GenericOp>
232 RewriterBase &rewriter, linalg::GenericOp genericOp,
bool removeOutputs) {
242 llvm::SmallDenseMap<unsigned, unsigned> origInsToNewInsPos =
247 llvm::SmallDenseMap<unsigned, unsigned> origOutsToNewOutsPos =
249 newIndexingMaps, removeOutputs);
252 if (newInputOperands.size() + newOutputOperands.size() ==
253 genericOp->getNumOperands())
259 for (
Value v : newOutputOperands)
260 if (isa<TensorType>(v.getType()))
261 newResultTypes.push_back(v.getType());
262 auto newOp = GenericOp::create(
263 rewriter, loc, newResultTypes, newInputOperands, newOutputOperands,
265 genericOp.getIteratorTypes(), genericOp.getDocAttr(),
266 genericOp.getLibraryCallAttr(),
273 if (!llvm::is_contained(odsAttrs, kv.getName().getValue()))
274 newOp->setAttr(kv.getName(), kv.getValue());
282 for (
const auto &
result : llvm::enumerate(genericOp.getResults())) {
283 auto it = origOutsToNewOutsPos.find(
result.index());
284 if (it == origOutsToNewOutsPos.end())
286 replacementsVals[
result.index()] = newOp.getResult(it->second);
288 rewriter.
replaceOp(genericOp, replacementsVals);
294struct DeduplicateAndRemoveDeadOperandsAndResults
296 DeduplicateAndRemoveDeadOperandsAndResults(
MLIRContext *ctx,
300 LogicalResult matchAndRewrite(GenericOp genericOp,
303 rewriter, genericOp, removeOutputs);
304 if (failed(newOp) || newOp.value() == genericOp) {
306 genericOp,
"failed to dedup operands/remove dead results");
324 using OpRewritePattern<GenericOp>::OpRewritePattern;
326 LogicalResult matchAndRewrite(GenericOp genericOp,
327 PatternRewriter &rewriter)
const override {
330 if (!genericOp.hasPureTensorSemantics())
333 bool hasRemovedCycles =
false;
335 for (
const auto &outputOpOperand :
336 llvm::enumerate(genericOp.getDpsInits())) {
339 Value
result = genericOp.getResult(outputOpOperand.index());
344 BlockArgument outputArg =
345 genericOp.getRegionOutputArgs()[outputOpOperand.index()];
355 Operation *cycleUserOp = *cycleOp->
user_begin();
356 if (!isa<linalg::YieldOp>(cycleUserOp))
360 if (cycleUserOp->
getOperand(outputOpOperand.index()) !=
368 hasRemovedCycles =
true;
371 if (hasRemovedCycles) {
391 using OpRewritePattern<GenericOp>::OpRewritePattern;
393 LogicalResult matchAndRewrite(GenericOp genericOp,
394 PatternRewriter &rewriter)
const override {
397 for (
int i = 0; i < genericOp.getNumDpsInputs(); ++i) {
399 if (genericOp.getBody()->getArgument(i).getUses().empty())
402 for (
int j = genericOp->getNumOperands() - 1; j > i; --j) {
403 if (genericOp->getOperand(i) == genericOp->getOperand(j) &&
404 genericOp.getIndexingMapsArray()[i] ==
405 genericOp.getIndexingMapsArray()[j]) {
413 if (replacements.empty())
418 for (
auto [before, after] : replacements) {
419 BlockArgument bbArg = genericOp.getBody()->getArgument(before);
420 BlockArgument
replacement = genericOp.getBody()->getArgument(after);
433 patterns.insert<DeduplicateAndRemoveDeadOperandsAndResults>(
440 patterns.insert<DeduplicateAndRemoveDeadOperandsAndResults>(
static llvm::SmallDenseMap< unsigned, unsigned > deduplicateOutputOperands(GenericOp genericOp, SmallVector< OpOperand * > &droppedOpOperands, SmallVector< Value > &newOutputOperands, SmallVector< AffineMap > &newIndexingMaps, bool removeOutputs)
static llvm::SmallDenseMap< unsigned, unsigned > deduplicateInputOperands(GenericOp genericOp, SmallVector< OpOperand * > &droppedOpOperands, SmallVector< Value > &newInputOperands, SmallVector< AffineMap > &newIndexingMaps)
static void populateOpPayload(GenericOp genericOp, GenericOp newOp, const llvm::SmallDenseMap< unsigned, unsigned > &origInsToNewInsPos, const llvm::SmallDenseMap< unsigned, unsigned > &origOutsToNewOutsPos, RewriterBase &rewriter)
static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result)
Return true if the result of an operation genericOp is dead.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
bool use_empty()
Returns true if this operation has no uses.
Value getOperand(unsigned idx)
bool hasOneUse()
Returns true if this operation has exactly one use.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
user_iterator user_begin()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
user_iterator user_begin() const
bool hasOneUse() const
Returns true if this value has exactly one use.
FailureOr< linalg::GenericOp > deduplicateOperandsAndRemoveDeadResults(RewriterBase &rewriter, linalg::GenericOp genericOp, bool removeOutputs)
Method to deduplicate operands and remove dead results of linalg.generic operations.
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns)
Pattern to remove dead operands and results of linalg.generic operations.
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...