87 LogicalResult matchAndRewrite(GenericOp genericOp,
93 GenericOp createPeeledGenericOp(GenericOp genericOp,
98 GenericOp createResidualGenericOp(GenericOp genericOp,
99 GenericOp peeledGenericOp,
110 auto allShapesSizes =
111 cast<LinalgOp>(op.getOperation()).createFlatListOfOperandDims(b, loc);
112 AffineMap map = op.getShapesToLoopsMap();
123 for (
const auto &position :
125 return cast<AffineDimExpr>(expr).getPosition();
127 permutedValues[position.value()] = values[position.index()];
128 return permutedValues;
134 "expected scalar type while computing zero value");
135 if (isa<IntegerType>(elementType))
136 return b.
create<arith::ConstantIntOp>(loc, 0, elementType);
138 return b.
create<arith::ConstantIndexOp>(loc, 0);
140 auto floatType = cast<FloatType>(elementType);
141 return b.
create<arith::ConstantFloatOp>(
146 DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp,
148 Block *body = genericOp.getBody();
151 genericOp.getIndexingMapsArray();
161 for (
auto scalarOpResult : peeledScalarOperation->
getResults()) {
165 std::optional<unsigned> resultNumber;
166 for (
auto *user : scalarOpResult.getUsers()) {
167 if (
auto yieldOp = dyn_cast<YieldOp>(user)) {
169 for (
OpOperand &yieldOperand : yieldOp->getOpOperands()) {
170 if (yieldOperand.get() == scalarOpResult) {
171 resultNumber = yieldOperand.getOperandNumber();
175 assert(resultNumber &&
"unable to find use of a value in its user");
180 newInitValues.push_back(
181 genericOp.getDpsInitOperand(*resultNumber)->get());
182 OpResult result = cast<OpResult>(genericOp.getResult(*resultNumber));
183 newResultTypes.push_back(result.
getType());
184 peeledGenericOpIndexingMaps.push_back(
185 genericOp.getIndexingMapMatchingResult(result));
192 rewriter.
create<tensor::EmptyOp>(loc, domain, scalarOpResult.getType());
193 newInitValues.push_back(emptyTensor);
194 newResultTypes.push_back(emptyTensor.
getType());
195 peeledGenericOpIndexingMaps.push_back(indexingMap);
200 outsOperands.append(newInitValues.begin(), newInitValues.end());
202 resultTypes.append(newResultTypes.begin(), newResultTypes.end());
203 auto indexingMapAttr =
205 return rewriter.
create<GenericOp>(
206 loc, resultTypes, genericOp.getInputs(), outsOperands, indexingMapAttr,
207 genericOp.getIteratorTypes(),
nullptr,
nullptr,
212 DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp,
213 GenericOp peeledGenericOp,
218 unsigned origNumResults = genericOp.getNumResults();
219 unsigned peeledGenericOpNumResults = peeledGenericOp.getNumResults();
221 for (
auto resultNum :
222 llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults))
223 extraIns.push_back(peeledGenericOp->getResult(resultNum));
224 residualGenericOpOperands.append(extraIns);
228 auto indexingMaps = llvm::to_vector(
229 llvm::map_range(genericOp.getDpsInputOperands(), [&](
OpOperand *operand) {
230 return genericOp.getMatchingIndexingMap(operand);
232 for (
auto resultNum :
233 llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults)) {
234 OpResult result = cast<OpResult>(peeledGenericOp.getResult(resultNum));
235 indexingMaps.push_back(
236 peeledGenericOp.getIndexingMapMatchingResult(result));
238 for (
OpOperand &outOperand : genericOp.getDpsInitsMutable())
239 indexingMaps.push_back(genericOp.getMatchingIndexingMap(&outOperand));
242 return rewriter.
create<GenericOp>(
243 genericOp->getLoc(), genericOp->getResultTypes(),
244 residualGenericOpOperands, genericOp.getOutputs(), indexingMapAttr,
245 genericOp.getIteratorTypes(),
nullptr,
nullptr,
250 DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
253 if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) {
255 "unhandled decomposition of operation "
256 "with non-parallel iterator types");
261 if (!genericOp.hasPureTensorSemantics()) {
263 genericOp,
"only operations with tensor semantics are handled");
266 if (llvm::any_of(genericOp.getDpsInitsMutable(), [&](
OpOperand &outOperand) {
267 return !genericOp.getMatchingIndexingMap(&outOperand).isPermutation();
270 genericOp,
"unhandled decomposition of generic op with out operand not "
271 "accessed using a permutation");
275 Block *body = genericOp.getBody();
278 "operation has less than 3 statements");
282 if (llvm::any_of(body->
getOperations().begin()->getResultTypes(),
283 [](
Type t) { return !t.isIntOrIndexOrFloat(); })) {
286 "expected return type to be only int, index or float");
289 GenericOp peeledGenericOp = createPeeledGenericOp(genericOp, rewriter);
290 GenericOp residualGenericOp =
291 createResidualGenericOp(genericOp, peeledGenericOp, rewriter);
295 Block *peeledGenericOpBody = peeledGenericOp.getBody();
296 Block *residualGenericOpBody = residualGenericOp.getBody();
297 assert(peeledGenericOpBody->
empty() && residualGenericOpBody->
empty() &&
298 "expected split generic ops to have empty region");
304 Operation *peeledScalarOperation = &(*peeledGenericOpBody->
begin());
311 for (
auto origYield : yieldOp->getOperands()) {
312 if (origYield.getDefiningOp() == peeledScalarOperation) {
313 yieldedVals.push_back(origYield);
320 yieldedVals.push_back(
321 getZero(rewriter, genericOp.getLoc(), origYield.getType()));
324 yieldedVals.append(llvm::to_vector(
325 llvm::map_range(peeledScalarOperation->
getResults(),
327 rewriter.
create<YieldOp>(genericOp.getLoc(), yieldedVals);
332 unsigned origNumInputs = genericOp.getNumDpsInputs();
333 for (
const auto &inputBlockArg :
335 Value residualOpReplacementArg =
336 residualGenericOpBody->
getArgument(inputBlockArg.index());
338 inputBlockArg.value(), residualOpReplacementArg, [&](
OpOperand &use) {
339 return use.getOwner()->getBlock() == residualGenericOpBody;
342 Value peeledOpReplacementArg =
343 peeledGenericOpBody->
getArgument(inputBlockArg.index());
345 inputBlockArg.value(), peeledOpReplacementArg, [&](
OpOperand &use) {
346 return use.getOwner()->getBlock() == peeledGenericOpBody;
355 for (
const auto &yieldValue :
llvm::enumerate(yieldOp->getOperands())) {
356 OpResult opr = dyn_cast<OpResult>(yieldValue.value());
357 if (!opr || opr.
getOwner() != peeledScalarOperation)
358 replacements.push_back(residualGenericOp.getResult(yieldValue.index()));
360 replacements.push_back(peeledGenericOp->getResult(yieldValue.index()));
367 unsigned peeledScalarOpNumResults = peeledScalarOperation->
getNumResults();
368 scalarReplacements.reserve(peeledScalarOpNumResults);
369 for (
auto num : llvm::seq<unsigned>(0, peeledScalarOpNumResults))
370 scalarReplacements.push_back(
371 residualGenericOpBody->
getArgument(num + origNumInputs));
372 bool allUsesReplaced =
false;
374 residualGenericOpBody, &allUsesReplaced);
375 assert(!allUsesReplaced &&
376 "peeled scalar operation is erased when it wasnt expected to be");
380 rewriter.
replaceOp(genericOp, replacements);
388 if (removeDeadArgsAndResults)
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
SmallVector< OpFoldResult > permuteValues(ArrayRef< OpFoldResult > values, AffineMap map)
Helper method to permute the list of values based on the map.
static SmallVector< OpFoldResult > getGenericOpLoopRange(OpBuilder &b, GenericOp op)
Helper method to compute the range of a generic op.
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
ArrayRef< AffineExpr > getResults() const
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
AffineMap getMultiDimIdentityMap(unsigned rank)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation * getOwner() const
Returns the operation that owns this result.
Operation is the basic unit of execution within MLIR.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
MLIRContext * getContext() const
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void replaceOpUsesWithinBlock(Operation *op, ValueRange newValues, Block *block, bool *allUsesReplaced=nullptr)
Find uses of from within block and replace them with to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns)
Pattern to remove dead operands and results of linalg.generic operations.
void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns, bool removeDeadArgsAndResults=true)
Populate patterns for splitting a LinalgOp with multiple statements within its payload into multiple ...
Include the generated interface declarations.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...