13#include "llvm/ADT/SmallVectorExtras.h"
86 using OpRewritePattern<GenericOp>::OpRewritePattern;
88 LogicalResult matchAndRewrite(GenericOp genericOp,
89 PatternRewriter &rewriter)
const override;
94 GenericOp createPeeledGenericOp(GenericOp genericOp,
95 PatternRewriter &rewriter)
const;
99 GenericOp createResidualGenericOp(GenericOp genericOp,
100 GenericOp peeledGenericOp,
101 PatternRewriter &rewriter)
const;
109 b.setInsertionPoint(op);
111 auto allShapesSizes =
112 cast<LinalgOp>(op.getOperation()).createFlatListOfOperandDims(
b, loc);
113 AffineMap map = op.getShapesToLoopsMap();
122 "expected scalar type while computing zero value");
123 if (isa<IntegerType>(elementType))
128 auto floatType = cast<FloatType>(elementType);
130 b, loc, floatType, APFloat::getZero(floatType.getFloatSemantics()));
134DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp,
136 Block *body = genericOp.getBody();
137 Operation *peeledScalarOperation = &(*body->
begin());
138 SmallVector<AffineMap> peeledGenericOpIndexingMaps =
139 genericOp.getIndexingMapsArray();
143 Location loc = genericOp.getLoc();
145 SmallVector<Value> newInitValues;
146 SmallVector<Type> newResultTypes;
149 for (
auto scalarOpResult : peeledScalarOperation->
getResults()) {
153 std::optional<unsigned> resultNumber;
154 for (
auto *user : scalarOpResult.getUsers()) {
155 if (
auto yieldOp = dyn_cast<YieldOp>(user)) {
157 for (OpOperand &yieldOperand : yieldOp->getOpOperands()) {
158 if (yieldOperand.get() == scalarOpResult) {
159 resultNumber = yieldOperand.getOperandNumber();
163 assert(resultNumber &&
"unable to find use of a value in its user");
168 newInitValues.push_back(
169 genericOp.getDpsInitOperand(*resultNumber)->get());
170 OpResult
result = cast<OpResult>(genericOp.getResult(*resultNumber));
171 newResultTypes.push_back(
result.getType());
172 peeledGenericOpIndexingMaps.push_back(
173 genericOp.getIndexingMapMatchingResult(
result));
179 Value emptyTensor = tensor::EmptyOp::create(rewriter, loc, domain,
180 scalarOpResult.getType());
181 newInitValues.push_back(emptyTensor);
182 newResultTypes.push_back(emptyTensor.
getType());
183 peeledGenericOpIndexingMaps.push_back(indexingMap);
187 SmallVector<Value> outsOperands = genericOp.getOutputs();
188 outsOperands.append(newInitValues.begin(), newInitValues.end());
189 SmallVector<Type> resultTypes = llvm::to_vector(genericOp.getResultTypes());
190 resultTypes.append(newResultTypes.begin(), newResultTypes.end());
191 auto indexingMapAttr =
193 return GenericOp::create(
194 rewriter, loc, resultTypes, genericOp.getInputs(), outsOperands,
195 indexingMapAttr, genericOp.getIteratorTypes(),
nullptr,
196 nullptr, [](OpBuilder, Location,
ValueRange) {});
200DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp,
201 GenericOp peeledGenericOp,
202 PatternRewriter &rewriter)
const {
205 SmallVector<Value> residualGenericOpOperands = genericOp.getInputs();
206 unsigned origNumResults = genericOp.getNumResults();
207 unsigned peeledGenericOpNumResults = peeledGenericOp.getNumResults();
208 SmallVector<Value> extraIns;
209 for (
auto resultNum :
210 llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults))
211 extraIns.push_back(peeledGenericOp->getResult(resultNum));
212 residualGenericOpOperands.append(extraIns);
216 auto indexingMaps = llvm::map_to_vector(
217 genericOp.getDpsInputOperands(), [&](OpOperand *operand) {
218 return genericOp.getMatchingIndexingMap(operand);
220 for (
auto resultNum :
221 llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults)) {
222 OpResult
result = cast<OpResult>(peeledGenericOp.getResult(resultNum));
223 indexingMaps.push_back(
224 peeledGenericOp.getIndexingMapMatchingResult(
result));
226 for (OpOperand &outOperand : genericOp.getDpsInitsMutable())
227 indexingMaps.push_back(genericOp.getMatchingIndexingMap(&outOperand));
230 return GenericOp::create(
231 rewriter, genericOp->getLoc(), genericOp->getResultTypes(),
232 residualGenericOpOperands, genericOp.getOutputs(), indexingMapAttr,
233 genericOp.getIteratorTypes(),
nullptr,
nullptr,
238DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
239 PatternRewriter &rewriter)
const {
241 if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) {
243 "unhandled decomposition of operation "
244 "with non-parallel iterator types");
249 if (!genericOp.hasPureTensorSemantics()) {
251 genericOp,
"only operations with tensor semantics are handled");
254 if (llvm::any_of(genericOp.getDpsInitsMutable(), [&](OpOperand &outOperand) {
255 return !genericOp.getMatchingIndexingMap(&outOperand).isPermutation();
258 genericOp,
"unhandled decomposition of generic op with out operand not "
259 "accessed using a permutation");
263 Block *body = genericOp.getBody();
266 "operation has less than 3 statements");
270 if (llvm::any_of(body->
getOperations().begin()->getResultTypes(),
271 [](Type t) { return !t.isIntOrIndexOrFloat(); })) {
274 "expected return type to be only int, index or float");
277 GenericOp peeledGenericOp = createPeeledGenericOp(genericOp, rewriter);
278 GenericOp residualGenericOp =
279 createResidualGenericOp(genericOp, peeledGenericOp, rewriter);
283 Block *peeledGenericOpBody = peeledGenericOp.getBody();
284 Block *residualGenericOpBody = residualGenericOp.getBody();
285 assert(peeledGenericOpBody->
empty() && residualGenericOpBody->
empty() &&
286 "expected split generic ops to have empty region");
292 Operation *peeledScalarOperation = &(*peeledGenericOpBody->
begin());
296 OpBuilder::InsertionGuard g(rewriter);
298 SmallVector<Value> yieldedVals;
299 for (
auto origYield : yieldOp->getOperands()) {
300 if (origYield.getDefiningOp() == peeledScalarOperation) {
301 yieldedVals.push_back(origYield);
306 OpBuilder::InsertionGuard g(rewriter);
308 yieldedVals.push_back(
309 getZero(rewriter, genericOp.getLoc(), origYield.getType()));
313 llvm::map_to_vector(peeledScalarOperation->
getResults(),
314 [](OpResult opr) -> Value { return opr; }));
315 YieldOp::create(rewriter, genericOp.getLoc(), yieldedVals);
320 unsigned origNumInputs = genericOp.getNumDpsInputs();
321 for (
const auto &inputBlockArg :
322 llvm::enumerate(genericOp.getBody()->getArguments())) {
323 Value residualOpReplacementArg =
324 residualGenericOpBody->
getArgument(inputBlockArg.index());
326 inputBlockArg.value(), residualOpReplacementArg, [&](OpOperand &use) {
327 return use.getOwner()->getBlock() == residualGenericOpBody;
330 Value peeledOpReplacementArg =
331 peeledGenericOpBody->
getArgument(inputBlockArg.index());
333 inputBlockArg.value(), peeledOpReplacementArg, [&](OpOperand &use) {
334 return use.getOwner()->getBlock() == peeledGenericOpBody;
342 SmallVector<Value> replacements;
343 for (
const auto &yieldValue : llvm::enumerate(yieldOp->getOperands())) {
344 OpResult opr = dyn_cast<OpResult>(yieldValue.value());
345 if (!opr || opr.
getOwner() != peeledScalarOperation)
346 replacements.push_back(residualGenericOp.getResult(yieldValue.index()));
348 replacements.push_back(peeledGenericOp->getResult(yieldValue.index()));
354 SmallVector<Value> scalarReplacements;
355 unsigned peeledScalarOpNumResults = peeledScalarOperation->
getNumResults();
356 scalarReplacements.reserve(peeledScalarOpNumResults);
357 for (
auto num : llvm::seq<unsigned>(0, peeledScalarOpNumResults))
358 scalarReplacements.push_back(
359 residualGenericOpBody->
getArgument(num + origNumInputs));
360 bool allUsesReplaced =
false;
362 residualGenericOpBody, &allUsesReplaced);
363 assert(!allUsesReplaced &&
364 "peeled scalar operation is erased when it wasnt expected to be");
368 rewriter.
replaceOp(genericOp, replacements);
376 if (removeDeadArgsAndResults)
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static SmallVector< OpFoldResult > getGenericOpLoopRange(OpBuilder &b, GenericOp op)
Helper method to compute the range of a generic op.
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
BlockArgument getArgument(unsigned i)
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
AffineMap getMultiDimIdentityMap(unsigned rank)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation * getOwner() const
Returns the operation that owns this result.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
MLIRContext * getContext() const
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of `from` and replace them with `to` if the functor returns true.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void replaceOpUsesWithinBlock(Operation *op, ValueRange newValues, Block *block, bool *allUsesReplaced=nullptr)
Find uses of `from` within `block` and replace them with `to`.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantFloatOp create(OpBuilder &builder, Location location, FloatType type, const APFloat &value)
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns)
Pattern to remove dead operands and results of linalg.generic operations.
void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns, bool removeDeadArgsAndResults=true)
Populate patterns for splitting a LinalgOp with multiple statements within its payload into multiple ...
Include the generated interface declarations.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...