13#include "llvm/ADT/SmallVectorExtras.h"
// Member declarations of the DecomposeLinalgOp rewrite pattern: it splits a
// linalg.generic whose payload has multiple statements into a "peeled"
// generic op (computing the first payload statement) and a "residual"
// generic op (computing the rest).
// NOTE(review): this listing is a fragmented extraction — the leading numbers
// on each line are original source line numbers fused into the text, and the
// enclosing class header is not visible here.
86 using OpRewritePattern<GenericOp>::OpRewritePattern;
// Entry point: matches a multi-statement linalg.generic and rewrites it as
// two generic ops (see the definitions below).
88 LogicalResult matchAndRewrite(GenericOp genericOp,
89 PatternRewriter &rewriter)
const override;
// Builds the generic op that computes only the first scalar operation of the
// original payload, with extra results for values consumed downstream.
94 GenericOp createPeeledGenericOp(GenericOp genericOp,
95 PatternRewriter &rewriter)
const;
// Builds the generic op that computes the remaining payload statements,
// consuming the peeled op's extra results as additional inputs.
99 GenericOp createResidualGenericOp(GenericOp genericOp,
100 GenericOp peeledGenericOp,
101 PatternRewriter &rewriter)
const;
// Fragment of getGenericOpLoopRange (see the Doxygen one-liner below in this
// listing: "Helper method to compute the range of a generic op"): derives the
// op's iteration domain from its operand dimensions.
// NOTE(review): fragmented extraction — the function header and some interior
// lines are missing from this view.
109 b.setInsertionPoint(op);
// Flattened list of all operand dimension sizes, in operand order.
111 auto allShapesSizes =
112 cast<LinalgOp>(op.getOperation()).createFlatListOfOperandDims(
b, loc);
// Map from the flattened operand-shape dims to the loop dimensions;
// presumably applied to `allShapesSizes` on the (missing) following lines.
113 AffineMap map = op.getShapesToLoopsMap();
// Fragment of permuteValues: permutes `values` according to the AffineDimExpr
// positions appearing in the map's results, so that
// permutedValues[dim position] = values[result index].
// NOTE(review): fragmented extraction — enclosing declarations are missing.
124 for (
const auto &position :
// Each map result is expected to be an AffineDimExpr; its position is the
// destination slot for the value at the corresponding result index.
126 return cast<AffineDimExpr>(expr).getPosition();
128 permutedValues[position.value()] = values[position.index()];
129 return permutedValues;
// Fragment of getZero: materializes a zero constant for a scalar element
// type. Integer types are handled first (constant-building line not visible
// here); the float path builds an APFloat zero in the type's semantics.
// NOTE(review): fragmented extraction — the assert condition and the integer
// constant creation lines are missing from this view.
135 "expected scalar type while computing zero value");
136 if (isa<IntegerType>(elementType))
141 auto floatType = cast<FloatType>(elementType);
143 b, loc, floatType, APFloat::getZero(floatType.getFloatSemantics()));
// Builds the "peeled" generic op: a copy of `genericOp` whose payload will
// contain only the first scalar operation of the original body. Results of
// that scalar op that are consumed by later payload statements are exposed as
// extra op results (with new init tensors and indexing maps).
// NOTE(review): fragmented extraction — the return-type line and several
// interior lines are missing from this view.
147DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp,
149 Block *body = genericOp.getBody();
// The first statement of the payload — the operation being peeled off.
150 Operation *peeledScalarOperation = &(*body->
begin());
151 SmallVector<AffineMap> peeledGenericOpIndexingMaps =
152 genericOp.getIndexingMapsArray();
156 Location loc = genericOp.getLoc();
// Extra inits/result types accumulated for scalar results that must be
// carried across to the residual op.
158 SmallVector<Value> newInitValues;
159 SmallVector<Type> newResultTypes;
// For each result of the peeled scalar operation, decide how it is exposed.
162 for (
auto scalarOpResult : peeledScalarOperation->
getResults()) {
// If the scalar result is already yielded, reuse the matching original
// init operand / result type / indexing map instead of adding new ones.
166 std::optional<unsigned> resultNumber;
167 for (
auto *user : scalarOpResult.getUsers()) {
168 if (
auto yieldOp = dyn_cast<YieldOp>(user)) {
170 for (OpOperand &yieldOperand : yieldOp->getOpOperands()) {
171 if (yieldOperand.get() == scalarOpResult) {
172 resultNumber = yieldOperand.getOperandNumber();
176 assert(resultNumber &&
"unable to find use of a value in its user");
181 newInitValues.push_back(
182 genericOp.getDpsInitOperand(*resultNumber)->get());
183 OpResult
result = cast<OpResult>(genericOp.getResult(*resultNumber));
184 newResultTypes.push_back(
result.getType());
185 peeledGenericOpIndexingMaps.push_back(
186 genericOp.getIndexingMapMatchingResult(
result));
// Otherwise create a fresh empty tensor sized to the iteration domain to
// carry the value out of the peeled op. (`domain` and `indexingMap` are
// defined on lines missing from this view — presumably the loop range and
// an identity map; confirm against the full source.)
192 Value emptyTensor = tensor::EmptyOp::create(rewriter, loc, domain,
193 scalarOpResult.getType());
194 newInitValues.push_back(emptyTensor);
195 newResultTypes.push_back(emptyTensor.
getType());
196 peeledGenericOpIndexingMaps.push_back(indexingMap);
// Assemble the peeled op: original outputs plus the new inits, original
// result types plus the new ones.
200 SmallVector<Value> outsOperands = genericOp.getOutputs();
201 outsOperands.append(newInitValues.begin(), newInitValues.end());
202 SmallVector<Type> resultTypes = llvm::to_vector(genericOp.getResultTypes());
203 resultTypes.append(newResultTypes.begin(), newResultTypes.end());
204 auto indexingMapAttr =
// Created with an empty body-builder: the payload is populated later by
// matchAndRewrite.
206 return GenericOp::create(
207 rewriter, loc, resultTypes, genericOp.getInputs(), outsOperands,
208 indexingMapAttr, genericOp.getIteratorTypes(),
nullptr,
209 nullptr, [](OpBuilder, Location,
ValueRange) {});
// Builds the "residual" generic op: computes everything after the peeled
// statement. Inputs are the original inputs plus the peeled op's extra
// results; outputs/result types are the original ones.
// NOTE(review): fragmented extraction — the return-type line and several
// interior lines are missing from this view.
213DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp,
214 GenericOp peeledGenericOp,
215 PatternRewriter &rewriter)
const {
// Extra inputs: the peeled op's results beyond the original result count,
// i.e. the values carried across from the peeled scalar operation.
218 SmallVector<Value> residualGenericOpOperands = genericOp.getInputs();
219 unsigned origNumResults = genericOp.getNumResults();
220 unsigned peeledGenericOpNumResults = peeledGenericOp.getNumResults();
221 SmallVector<Value> extraIns;
222 for (
auto resultNum :
223 llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults))
224 extraIns.push_back(peeledGenericOp->getResult(resultNum));
225 residualGenericOpOperands.append(extraIns);
// Indexing maps: original input maps, then the maps of the carried-over
// peeled results (now inputs), then the original output maps.
229 auto indexingMaps = llvm::map_to_vector(
230 genericOp.getDpsInputOperands(), [&](OpOperand *operand) {
231 return genericOp.getMatchingIndexingMap(operand);
233 for (
auto resultNum :
234 llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults)) {
235 OpResult
result = cast<OpResult>(peeledGenericOp.getResult(resultNum));
236 indexingMaps.push_back(
237 peeledGenericOp.getIndexingMapMatchingResult(
result));
239 for (OpOperand &outOperand : genericOp.getDpsInitsMutable())
240 indexingMaps.push_back(genericOp.getMatchingIndexingMap(&outOperand));
// (`indexingMapAttr` construction is on a line missing from this view.)
// Like the peeled op, created with an empty payload to be filled in later.
243 return GenericOp::create(
244 rewriter, genericOp->getLoc(), genericOp->getResultTypes(),
245 residualGenericOpOperands, genericOp.getOutputs(), indexingMapAttr,
246 genericOp.getIteratorTypes(),
nullptr,
nullptr,
// Driver of the decomposition: validates preconditions, creates the peeled
// and residual generic ops, clones the payload statements into them, and
// replaces the original op.
// NOTE(review): fragmented extraction — the return-type line and many
// interior lines are missing from this view.
251DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
252 PatternRewriter &rewriter)
const {
// Precondition 1: all loops must be parallel.
254 if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) {
256 "unhandled decomposition of operation "
257 "with non-parallel iterator types");
// Precondition 2: tensor semantics only (no memref/buffer semantics).
262 if (!genericOp.hasPureTensorSemantics()) {
264 genericOp,
"only operations with tensor semantics are handled");
// Precondition 3: every output must be accessed through a permutation map.
267 if (llvm::any_of(genericOp.getDpsInitsMutable(), [&](OpOperand &outOperand) {
268 return !genericOp.getMatchingIndexingMap(&outOperand).isPermutation();
271 genericOp,
"unhandled decomposition of generic op with out operand not "
272 "accessed using a permutation");
// Precondition 4: the payload must have something to split (the check
// itself is on lines missing from this view; the message says >= 3
// statements, presumably first op + rest + yield).
276 Block *body = genericOp.getBody();
279 "operation has less than 3 statements");
// Precondition 5: the peeled (first) statement must produce only scalar
// int/index/float results.
283 if (llvm::any_of(body->
getOperations().begin()->getResultTypes(),
284 [](Type t) { return !t.isIntOrIndexOrFloat(); })) {
287 "expected return type to be only int, index or float");
// Create the two replacement ops; their payloads are still empty here.
290 GenericOp peeledGenericOp = createPeeledGenericOp(genericOp, rewriter);
291 GenericOp residualGenericOp =
292 createResidualGenericOp(genericOp, peeledGenericOp, rewriter);
296 Block *peeledGenericOpBody = peeledGenericOp.getBody();
297 Block *residualGenericOpBody = residualGenericOp.getBody();
298 assert(peeledGenericOpBody->
empty() && residualGenericOpBody->
empty() &&
299 "expected split generic ops to have empty region");
// The clone of the first payload statement inside the peeled op (cloning
// lines are missing from this view).
305 Operation *peeledScalarOperation = &(*peeledGenericOpBody->
begin());
309 OpBuilder::InsertionGuard g(rewriter);
// Build the peeled op's yield: keep original yields produced by the peeled
// statement; replace the others with zeros (they are computed by the
// residual op instead).
311 SmallVector<Value> yieldedVals;
312 for (
auto origYield : yieldOp->getOperands()) {
313 if (origYield.getDefiningOp() == peeledScalarOperation) {
314 yieldedVals.push_back(origYield);
319 OpBuilder::InsertionGuard g(rewriter);
321 yieldedVals.push_back(
322 getZero(rewriter, genericOp.getLoc(), origYield.getType()));
// Also yield every result of the peeled scalar op so the residual op can
// consume them as inputs.
326 llvm::map_to_vector(peeledScalarOperation->
getResults(),
327 [](OpResult opr) -> Value { return opr; }));
328 YieldOp::create(rewriter, genericOp.getLoc(), yieldedVals);
// Rewire the original block arguments: uses inside each new body are
// redirected to that body's corresponding argument.
333 unsigned origNumInputs = genericOp.getNumDpsInputs();
334 for (
const auto &inputBlockArg :
335 llvm::enumerate(genericOp.getBody()->getArguments())) {
336 Value residualOpReplacementArg =
337 residualGenericOpBody->
getArgument(inputBlockArg.index());
339 inputBlockArg.value(), residualOpReplacementArg, [&](OpOperand &use) {
340 return use.getOwner()->getBlock() == residualGenericOpBody;
343 Value peeledOpReplacementArg =
344 peeledGenericOpBody->
getArgument(inputBlockArg.index());
346 inputBlockArg.value(), peeledOpReplacementArg, [&](OpOperand &use) {
347 return use.getOwner()->getBlock() == peeledGenericOpBody;
// Choose replacements for the original op's results: results produced by
// the peeled statement come from the peeled op, the rest from the residual
// op.
355 SmallVector<Value> replacements;
356 for (
const auto &yieldValue : llvm::enumerate(yieldOp->getOperands())) {
357 OpResult opr = dyn_cast<OpResult>(yieldValue.value());
358 if (!opr || opr.
getOwner() != peeledScalarOperation)
359 replacements.push_back(residualGenericOp.getResult(yieldValue.index()));
361 replacements.push_back(peeledGenericOp->getResult(yieldValue.index()));
// Inside the residual body, uses of the peeled scalar op's results are
// replaced by the residual op's extra block arguments (appended after the
// original inputs).
367 SmallVector<Value> scalarReplacements;
368 unsigned peeledScalarOpNumResults = peeledScalarOperation->
getNumResults();
369 scalarReplacements.reserve(peeledScalarOpNumResults);
370 for (
auto num : llvm::seq<unsigned>(0, peeledScalarOpNumResults))
371 scalarReplacements.push_back(
372 residualGenericOpBody->
getArgument(num + origNumInputs));
373 bool allUsesReplaced =
false;
375 residualGenericOpBody, &allUsesReplaced);
376 assert(!allUsesReplaced &&
377 "peeled scalar operation is erased when it wasnt expected to be");
// Finally replace the original op with the combined results.
381 rewriter.
replaceOp(genericOp, replacements);
// Fragment of populateDecomposeLinalgOpsPattern: when requested, presumably
// also registers the dead-operand/result cleanup patterns (the guarded call
// is on a line missing from this view — likely
// populateEraseUnusedOperandsAndResultsPatterns; confirm against full source).
389 if (removeDeadArgsAndResults)
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static SmallVector< OpFoldResult > getGenericOpLoopRange(OpBuilder &b, GenericOp op)
Helper method to compute the range of a generic op.
SmallVector< OpFoldResult > permuteValues(ArrayRef< OpFoldResult > values, AffineMap map)
Helper method to permute the list of values based on the map.
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
ArrayRef< AffineExpr > getResults() const
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
BlockArgument getArgument(unsigned i)
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
AffineMap getMultiDimIdentityMap(unsigned rank)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation * getOwner() const
Returns the operation that owns this result.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void replaceOpUsesWithinBlock(Operation *op, ValueRange newValues, Block *block, bool *allUsesReplaced=nullptr)
Find uses of from within block and replace them with to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantFloatOp create(OpBuilder &builder, Location location, FloatType type, const APFloat &value)
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns)
Pattern to remove dead operands and results of linalg.generic operations.
void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns, bool removeDeadArgsAndResults=true)
Populate patterns for splitting a LinalgOp with multiple statements within its payload into multiple ...
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...