23 if (!genericOp.payloadUsesValueFromOperand(outputOpOperand))
44 auto yieldOp = dyn_cast<linalg::YieldOp>(argUserOp);
57 struct DeduplicateAndRemoveDeadOperandsAndResults
59 DeduplicateAndRemoveDeadOperandsAndResults(
MLIRContext *ctx,
63 LogicalResult matchAndRewrite(GenericOp genericOp,
74 llvm::SmallDenseMap<unsigned, unsigned> origInsToNewInsPos =
75 deduplicateInputOperands(genericOp, droppedOpOperands, newInputOperands,
79 llvm::SmallDenseMap<unsigned, unsigned> origOutsToNewOutsPos =
80 deduplicateOutputOperands(genericOp, droppedOpOperands,
81 newOutputOperands, newIndexingMaps);
84 if (newInputOperands.size() + newOutputOperands.size() ==
85 genericOp->getNumOperands())
91 for (
Value v : newOutputOperands)
92 if (isa<TensorType>(v.getType()))
93 newResultTypes.push_back(v.getType());
94 auto newOp = rewriter.
create<GenericOp>(
95 loc, newResultTypes, newInputOperands, newOutputOperands,
97 genericOp.getIteratorTypes(), genericOp.getDocAttr(),
98 genericOp.getLibraryCallAttr(),
105 if (!llvm::is_contained(odsAttrs, kv.getName().getValue()))
106 newOp->setAttr(kv.getName(), kv.getValue());
109 populateOpPayload(genericOp, newOp, origInsToNewInsPos,
110 origOutsToNewOutsPos, rewriter);
115 auto it = origOutsToNewOutsPos.find(result.index());
116 if (it == origOutsToNewOutsPos.end())
118 replacementsVals[result.index()] = newOp.getResult(it->second);
120 rewriter.
replaceOp(genericOp, replacementsVals);
132 llvm::SmallDenseMap<unsigned, unsigned>
133 deduplicateInputOperands(GenericOp genericOp,
137 llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
138 llvm::SmallDenseMap<std::pair<Value, AffineMap>,
unsigned> dedupedInputs;
139 for (
const auto &en :
llvm::enumerate(genericOp.getDpsInputOperands())) {
143 if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) {
146 droppedOpOperands.push_back(inputOpOperand);
147 if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
149 droppedOpOperands.pop_back();
153 AffineMap indexingMap = genericOp.getMatchingIndexingMap(inputOpOperand);
154 auto it = dedupedInputs.find(
155 std::make_pair(inputOpOperand->
get(), indexingMap));
156 if (it != dedupedInputs.end()) {
157 origToNewPos[en.index()] = it->second;
158 droppedOpOperands.push_back(inputOpOperand);
163 origToNewPos[en.index()] = newInputOperands.size();
164 dedupedInputs[{inputOpOperand->
get(), indexingMap}] =
165 newInputOperands.size();
166 newInputOperands.push_back(inputOpOperand->
get());
167 newIndexingMaps.push_back(indexingMap);
176 llvm::SmallDenseMap<unsigned, unsigned>
177 deduplicateOutputOperands(GenericOp genericOp,
181 llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
182 llvm::SmallDenseMap<std::tuple<Value, AffineMap, Value>,
unsigned>
186 if (!genericOp.hasPureTensorSemantics() || !removeOutputs) {
187 for (
const auto &en :
llvm::enumerate(genericOp.getDpsInitsMutable())) {
188 origToNewPos[en.index()] = newOutputOperands.size();
189 newOutputOperands.push_back(en.value().get());
190 newIndexingMaps.push_back(
191 genericOp.getMatchingIndexingMap(&en.value()));
200 auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
201 for (
const auto &outputOpOperand :
203 OpResult result = genericOp.getTiedOpResult(&outputOpOperand.value());
205 genericOp.getMatchingIndexingMap(&outputOpOperand.value());
206 auto key = std::make_tuple(outputOpOperand.value().get(), indexingMap,
207 yieldOp->getOperand(outputOpOperand.index()));
213 droppedOpOperands.push_back(&outputOpOperand.value());
214 if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
217 droppedOpOperands.pop_back();
220 if (!genericOp.payloadUsesValueFromOperand(&outputOpOperand.value())) {
226 auto it = dedupedOutpts.find(key);
227 if (it != dedupedOutpts.end()) {
228 origToNewPos[outputOpOperand.index()] = it->second;
229 droppedOpOperands.push_back(&outputOpOperand.value());
234 origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
235 dedupedOutpts[key] = newOutputOperands.size();
236 newOutputOperands.push_back(outputOpOperand.value().get());
237 newIndexingMaps.push_back(
238 genericOp.getMatchingIndexingMap(&outputOpOperand.value()));
244 void populateOpPayload(
245 GenericOp genericOp, GenericOp newOp,
246 const llvm::SmallDenseMap<unsigned, unsigned> &origInsToNewInsPos,
247 const llvm::SmallDenseMap<unsigned, unsigned> &origOutsToNewOutsPos,
250 Block *newOpBlock = &newOp.getRegion().
front();
251 assert(newOpBlock->
empty() &&
"expected new op to have an empty payload");
252 Block *origOpBlock = &genericOp.getRegion().
front();
257 auto updateReplacements =
260 const llvm::SmallDenseMap<unsigned, unsigned> &map) {
262 auto it = map.find(origOperand.index());
265 OpOperand *newOperand = newOperands[it->second];
266 replacements[origOperand.value()->getOperandNumber()] =
272 genericOp.getDpsInputOperands();
274 updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos);
277 llvm::to_vector(llvm::map_range(genericOp.getDpsInitsMutable(),
280 llvm::to_vector(llvm::map_range(newOp.getDpsInitsMutable(),
282 updateReplacements(origOutputOperands, newOutputOperands,
283 origOutsToNewOutsPos);
286 if (newOp.getNumDpsInits() != genericOp.getNumDpsInits()) {
288 YieldOp origYieldOp = cast<YieldOp>(origOpBlock->
getTerminator());
292 for (
const auto &yieldOpOperands :
294 auto it = origOutsToNewOutsPos.find(yieldOpOperands.index());
295 if (it == origOutsToNewOutsPos.end())
297 newYieldVals[it->second] = yieldOpOperands.value();
302 rewriter.
mergeBlocks(origOpBlock, newOpBlock, replacements);
316 LogicalResult matchAndRewrite(GenericOp genericOp,
320 if (!genericOp.hasPureTensorSemantics())
323 bool hasRemovedCycles =
false;
325 for (
const auto &outputOpOperand :
329 Value result = genericOp.getResult(outputOpOperand.index());
335 genericOp.getRegionOutputArgs()[outputOpOperand.index()];
346 if (!isa<linalg::YieldOp>(cycleUserOp))
350 if (cycleUserOp->
getOperand(outputOpOperand.index()) !=
358 hasRemovedCycles =
true;
361 if (hasRemovedCycles) {
383 LogicalResult matchAndRewrite(GenericOp genericOp,
387 for (
int i = 0; i < genericOp.getNumDpsInputs(); ++i) {
389 if (genericOp.getBody()->getArgument(i).getUses().empty())
392 for (
int j = genericOp->getNumOperands() - 1;
j > i; --
j) {
393 if (genericOp->getOperand(i) == genericOp->getOperand(
j) &&
394 genericOp.getIndexingMapsArray()[i] ==
395 genericOp.getIndexingMapsArray()[
j]) {
403 if (replacements.empty())
408 for (
auto [before, after] : replacements) {
409 BlockArgument bbArg = genericOp.getBody()->getArgument(before);
410 BlockArgument replacement = genericOp.getBody()->getArgument(after);
423 patterns.insert<DeduplicateAndRemoveDeadOperandsAndResults>(
430 patterns.insert<DeduplicateAndRemoveDeadOperandsAndResults>(
static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result)
Return true if the result of an operation genericOp is dead.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
Operation is the basic unit of execution within MLIR.
bool use_empty()
Returns true if this operation has no uses.
Value getOperand(unsigned idx)
bool hasOneUse()
Returns true if this operation has exactly one use.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
user_iterator user_begin()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
user_iterator user_begin() const
bool hasOneUse() const
Returns true if this value has exactly one use.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns)
Pattern to remove dead operands and results of linalg.generic operations.
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.