23 if (!genericOp.payloadUsesValueFromOperand(outputOpOperand))
44 auto yieldOp = dyn_cast<linalg::YieldOp>(argUserOp);
67 llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
68 llvm::SmallDenseMap<std::pair<Value, AffineMap>,
unsigned> dedupedInputs;
69 for (
const auto &en :
llvm::enumerate(genericOp.getDpsInputOperands())) {
73 if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) {
76 droppedOpOperands.push_back(inputOpOperand);
77 if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
79 droppedOpOperands.pop_back();
83 AffineMap indexingMap = genericOp.getMatchingIndexingMap(inputOpOperand);
85 dedupedInputs.find(std::make_pair(inputOpOperand->
get(), indexingMap));
86 if (it != dedupedInputs.end()) {
87 origToNewPos[en.index()] = it->second;
88 droppedOpOperands.push_back(inputOpOperand);
93 origToNewPos[en.index()] = newInputOperands.size();
94 dedupedInputs[{inputOpOperand->
get(), indexingMap}] =
95 newInputOperands.size();
96 newInputOperands.push_back(inputOpOperand->
get());
97 newIndexingMaps.push_back(indexingMap);
110 llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
111 llvm::SmallDenseMap<std::tuple<Value, AffineMap, Value>,
unsigned>
115 if (!genericOp.hasPureTensorSemantics() || !removeOutputs) {
116 for (
const auto &en :
llvm::enumerate(genericOp.getDpsInitsMutable())) {
117 origToNewPos[en.index()] = newOutputOperands.size();
118 newOutputOperands.push_back(en.value().get());
119 newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(&en.value()));
128 auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
129 for (
const auto &outputOpOperand :
131 OpResult result = genericOp.getTiedOpResult(&outputOpOperand.value());
133 genericOp.getMatchingIndexingMap(&outputOpOperand.value());
134 auto key = std::make_tuple(outputOpOperand.value().get(), indexingMap,
135 yieldOp->getOperand(outputOpOperand.index()));
141 droppedOpOperands.push_back(&outputOpOperand.value());
142 if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
145 droppedOpOperands.pop_back();
148 if (!genericOp.payloadUsesValueFromOperand(&outputOpOperand.value())) {
154 auto it = dedupedOutpts.find(key);
155 if (it != dedupedOutpts.end()) {
156 origToNewPos[outputOpOperand.index()] = it->second;
157 droppedOpOperands.push_back(&outputOpOperand.value());
162 origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
163 dedupedOutpts[key] = newOutputOperands.size();
164 newOutputOperands.push_back(outputOpOperand.value().get());
165 newIndexingMaps.push_back(
166 genericOp.getMatchingIndexingMap(&outputOpOperand.value()));
173 GenericOp genericOp, GenericOp newOp,
174 const llvm::SmallDenseMap<unsigned, unsigned> &origInsToNewInsPos,
175 const llvm::SmallDenseMap<unsigned, unsigned> &origOutsToNewOutsPos,
178 Block *newOpBlock = &newOp.getRegion().
front();
179 assert(newOpBlock->
empty() &&
"expected new op to have an empty payload");
180 Block *origOpBlock = &genericOp.getRegion().
front();
185 auto updateReplacements =
188 const llvm::SmallDenseMap<unsigned, unsigned> &map) {
190 auto it = map.find(origOperand.index());
193 OpOperand *newOperand = newOperands[it->second];
194 replacements[origOperand.value()->getOperandNumber()] =
201 updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos);
204 genericOp.getDpsInitsMutable(), [](
OpOperand &o) { return &o; }));
206 newOp.getDpsInitsMutable(), [](
OpOperand &o) { return &o; }));
207 updateReplacements(origOutputOperands, newOutputOperands,
208 origOutsToNewOutsPos);
211 if (newOp.getNumDpsInits() != genericOp.getNumDpsInits()) {
213 YieldOp origYieldOp = cast<YieldOp>(origOpBlock->
getTerminator());
217 for (
const auto &yieldOpOperands :
219 auto it = origOutsToNewOutsPos.find(yieldOpOperands.index());
220 if (it == origOutsToNewOutsPos.end())
222 newYieldVals[it->second] = yieldOpOperands.value();
227 rewriter.
mergeBlocks(origOpBlock, newOpBlock, replacements);
230 FailureOr<linalg::GenericOp>
232 RewriterBase &rewriter, linalg::GenericOp genericOp,
bool removeOutputs) {
242 llvm::SmallDenseMap<unsigned, unsigned> origInsToNewInsPos =
247 llvm::SmallDenseMap<unsigned, unsigned> origOutsToNewOutsPos =
249 newIndexingMaps, removeOutputs);
252 if (newInputOperands.size() + newOutputOperands.size() ==
253 genericOp->getNumOperands())
259 for (
Value v : newOutputOperands)
260 if (isa<TensorType>(v.getType()))
261 newResultTypes.push_back(v.getType());
262 auto newOp = rewriter.
create<GenericOp>(
263 loc, newResultTypes, newInputOperands, newOutputOperands,
265 genericOp.getIteratorTypes(), genericOp.getDocAttr(),
266 genericOp.getLibraryCallAttr(),
273 if (!llvm::is_contained(odsAttrs, kv.getName().getValue()))
274 newOp->setAttr(kv.getName(), kv.getValue());
283 auto it = origOutsToNewOutsPos.find(result.index());
284 if (it == origOutsToNewOutsPos.end())
286 replacementsVals[result.index()] = newOp.getResult(it->second);
288 rewriter.
replaceOp(genericOp, replacementsVals);
294 struct DeduplicateAndRemoveDeadOperandsAndResults
296 DeduplicateAndRemoveDeadOperandsAndResults(
MLIRContext *ctx,
300 LogicalResult matchAndRewrite(GenericOp genericOp,
303 rewriter, genericOp, removeOutputs);
304 if (failed(newOp) || newOp.value() == genericOp) {
306 genericOp,
"failed to dedup operands/remove dead results");
326 LogicalResult matchAndRewrite(GenericOp genericOp,
330 if (!genericOp.hasPureTensorSemantics())
333 bool hasRemovedCycles =
false;
335 for (
const auto &outputOpOperand :
339 Value result = genericOp.getResult(outputOpOperand.index());
345 genericOp.getRegionOutputArgs()[outputOpOperand.index()];
356 if (!isa<linalg::YieldOp>(cycleUserOp))
360 if (cycleUserOp->
getOperand(outputOpOperand.index()) !=
368 hasRemovedCycles =
true;
371 if (hasRemovedCycles) {
393 LogicalResult matchAndRewrite(GenericOp genericOp,
397 for (
int i = 0; i < genericOp.getNumDpsInputs(); ++i) {
399 if (genericOp.getBody()->getArgument(i).getUses().empty())
402 for (
int j = genericOp->getNumOperands() - 1;
j > i; --
j) {
403 if (genericOp->getOperand(i) == genericOp->getOperand(
j) &&
404 genericOp.getIndexingMapsArray()[i] ==
405 genericOp.getIndexingMapsArray()[
j]) {
413 if (replacements.empty())
418 for (
auto [before, after] : replacements) {
419 BlockArgument bbArg = genericOp.getBody()->getArgument(before);
420 BlockArgument replacement = genericOp.getBody()->getArgument(after);
433 patterns.insert<DeduplicateAndRemoveDeadOperandsAndResults>(
440 patterns.insert<DeduplicateAndRemoveDeadOperandsAndResults>(
static llvm::SmallDenseMap< unsigned, unsigned > deduplicateOutputOperands(GenericOp genericOp, SmallVector< OpOperand * > &droppedOpOperands, SmallVector< Value > &newOutputOperands, SmallVector< AffineMap > &newIndexingMaps, bool removeOutputs)
static llvm::SmallDenseMap< unsigned, unsigned > deduplicateInputOperands(GenericOp genericOp, SmallVector< OpOperand * > &droppedOpOperands, SmallVector< Value > &newInputOperands, SmallVector< AffineMap > &newIndexingMaps)
static void populateOpPayload(GenericOp genericOp, GenericOp newOp, const llvm::SmallDenseMap< unsigned, unsigned > &origInsToNewInsPos, const llvm::SmallDenseMap< unsigned, unsigned > &origOutsToNewOutsPos, RewriterBase &rewriter)
static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result)
Return true if the result of an operation genericOp is dead.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
Operation is the basic unit of execution within MLIR.
bool use_empty()
Returns true if this operation has no uses.
Value getOperand(unsigned idx)
bool hasOneUse()
Returns true if this operation has exactly one use.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
user_iterator user_begin()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
user_iterator user_begin() const
bool hasOneUse() const
Returns true if this value has exactly one use.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< linalg::GenericOp > deduplicateOperandsAndRemoveDeadResults(RewriterBase &rewriter, linalg::GenericOp genericOp, bool removeOutputs)
Method to deduplicate operands and remove dead results of linalg.generic operations.
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns)
Pattern to remove dead operands and results of linalg.generic operations.
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.