#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"

using namespace mlir;
using namespace mlir::tensor;

namespace {
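/// Fold an expand_shape whose source is a rank-reducing extract_slice when the
/// expansion exactly restores the dimensions dropped by the slice; the pair is
/// replaced by a single non-rank-reducing extract_slice. Illustrative example
/// (constructed for this comment, not taken from the upstream tests):
///   %0 = tensor.extract_slice %src[0, 0] [1, 10] [1, 1]
///       : tensor<4x10xf32> to tensor<10xf32>
///   %1 = tensor.expand_shape %0 [[0, 1]] output_shape [1, 10]
///       : tensor<10xf32> into tensor<1x10xf32>
/// becomes
///   %1 = tensor.extract_slice %src[0, 0] [1, 10] [1, 1]
///       : tensor<4x10xf32> to tensor<1x10xf32>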
struct FoldExpandOfRankReducingExtract
    : public OpRewritePattern<ExpandShapeOp> {
  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = expandShapeOp.getResultType();
    auto extractSliceOp =
        expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
    if (!extractSliceOp)
      return failure();
    RankedTensorType srcType = extractSliceOp.getSourceType();
    RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
        srcType, extractSliceOp.getStaticOffsets(),
        extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
    if (nonReducingExtractType != resultType)
      return failure();
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
        expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
        mixedStrides);
    return success();
  }
};
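/// Fold a collapse_shape that only removes unit dimensions ("unpadding") into
/// its producing extract_slice by re-extracting directly at the collapsed
/// type. Illustrative example (constructed for this comment):
///   %0 = tensor.extract_slice %src[0, 0] [1, 10] [1, 1]
///       : tensor<4x10xf32> to tensor<1x10xf32>
///   %1 = tensor.collapse_shape %0 [[0, 1]]
///       : tensor<1x10xf32> into tensor<10xf32>
/// becomes a single rank-reducing extract_slice
///   %1 = tensor.extract_slice %src[0, 0] [1, 10] [1, 1]
///       : tensor<4x10xf32> to tensor<10xf32>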
struct FoldUnPaddingCollapseIntoExtract
    : public OpRewritePattern<tensor::CollapseShapeOp> {
  using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
                                PatternRewriter &rewriter) const override {
    auto extractSliceOp =
        collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
    // Bail out on multiple users: folding would only duplicate the slice.
    if (!extractSliceOp || !extractSliceOp->hasOneUse())
      return failure();
    // Only fold away a simple collapse that drops static unit dimensions.
    SliceVerificationResult res = isRankReducedType(
        collapseShapeOp.getSrcType(), collapseShapeOp.getResultType());
    if (res != SliceVerificationResult::Success)
      return rewriter.notifyMatchFailure(collapseShapeOp,
                                         "expected unpadding collapse");
    Value unPaddedExtractSlice = rewriter.create<tensor::ExtractSliceOp>(
        extractSliceOp.getLoc(), collapseShapeOp.getResultType(),
        extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(),
        extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
    rewriter.replaceOp(collapseShapeOp, unPaddedExtractSlice);
    return success();
  }
};
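/// Fold a collapse_shape feeding an insert_slice when the slice's static sizes
/// already match the un-collapsed source type: the collapse is dropped and the
/// original value is inserted directly. Instantiated for tensor.insert_slice
/// and tensor.parallel_insert_slice via OpTy.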
template <typename OpTy>
struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    auto collapseShapeOp =
        insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
    if (!collapseShapeOp)
      return failure();
    RankedTensorType srcType = collapseShapeOp.getSrcType();
    RankedTensorType nonReducingInsertType =
        RankedTensorType::get(insertSliceOp.getStaticSizes(),
                              insertSliceOp.getDestType().getElementType());
    if (nonReducingInsertType != srcType)
      return failure();
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, collapseShapeOp.getSrc(),
                                      insertSliceOp.getDest(), mixedOffsets,
                                      mixedSizes, mixedStrides);
    return success();
  }
};
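/// Fold an expand_shape that only adds static unit dimensions ("padding") into
/// a consuming insert_slice by assigning the un-expanded value as the insert's
/// source; the insert then acts as a rank-reducing insert. Instantiated for
/// tensor.insert_slice and tensor.parallel_insert_slice via OpTy.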
template <typename OpTy>
struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    auto expandShapeOp = insertSliceOp.getSource()
                             .template getDefiningOp<tensor::ExpandShapeOp>();
    if (!expandShapeOp)
      return failure();
    SliceVerificationResult res = isRankReducedType(
        expandShapeOp.getResultType(), expandShapeOp.getSrcType());
    if (res != SliceVerificationResult::Success)
      return rewriter.notifyMatchFailure(insertSliceOp,
                                         "expected rank increasing expansion");
    rewriter.modifyOpInPlace(insertSliceOp, [&]() {
      insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc());
    });
    return success();
  }
};
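/// Bubble a tensor.expand_shape up through a producing tensor.collapse_shape
/// when the two reshapes are "parallel", i.e. no reassociation group is
/// non-trivial in both reshapes. The order is then swapped: expand first, then
/// collapse. Illustrative example (constructed for this comment):
///   %c = tensor.collapse_shape %src [[0, 1], [2]]
///       : tensor<2x3x10xf32> into tensor<6x10xf32>
///   %e = tensor.expand_shape %c [[0], [1, 2]] output_shape [6, 2, 5]
///       : tensor<6x10xf32> into tensor<6x2x5xf32>
/// becomes
///   %e2 = tensor.expand_shape %src [[0], [1], [2, 3]]
///       output_shape [2, 3, 2, 5]
///       : tensor<2x3x10xf32> into tensor<2x3x2x5xf32>
///   %r = tensor.collapse_shape %e2 [[0, 1], [2], [3]]
///       : tensor<2x3x2x5xf32> into tensor<6x2x5xf32>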
struct BubbleUpExpandThroughParallelCollapse
    : public OpRewritePattern<tensor::ExpandShapeOp> {
  using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    auto collapseOp =
        expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseOp)
      return failure();
    auto expandReInds = expandOp.getReassociationIndices();
    auto collapseReInds = collapseOp.getReassociationIndices();
    // A 0-d collapsed tensor has empty reassociation maps; bail out.
    if (expandReInds.size() == 0) {
      return failure();
    }
    // The reshapes are parallel if no reassociation group has more than one
    // index in both reshapes.
    for (auto [expandReassociation, collapseReassociation] :
         llvm::zip_equal(expandReInds, collapseReInds)) {
      if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
        return failure();
    }
    // Compute the new reassociation indices and expanded/collapsed sizes.
    SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
    Location loc = expandOp->getLoc();
    SmallVector<OpFoldResult> collapseSizes =
        tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
    SmallVector<OpFoldResult> expandSizes(getMixedValues(
        expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
    SmallVector<OpFoldResult> newExpandSizes;
    int64_t index = 0, expandIndex = 0, collapseIndex = 0;
    for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {
      if (collapseReassociation.size() != 1) {
        // Non-trivial collapse group: each source dim becomes its own unit
        // group in the new expand.
        ReassociationIndices newCollapseReassociation;
        for (size_t i = 0; i < collapseReassociation.size(); ++i) {
          newCollapseReassociation.push_back(index);
          newExpandReInds.push_back({index++});
          newExpandSizes.push_back(collapseSizes[collapseIndex++]);
        }
        newCollapseReInds.push_back(newCollapseReassociation);
        expandIndex++;
        continue;
      }
      // Otherwise the expand group may be non-trivial: each result dim becomes
      // its own unit group in the new collapse.
      ReassociationIndices newExpandReassociation;
      auto expandReassociation = expandReInds[idx];
      for (size_t i = 0; i < expandReassociation.size(); ++i) {
        newExpandReassociation.push_back(index);
        newCollapseReInds.push_back({index++});
        newExpandSizes.push_back(expandSizes[expandIndex++]);
      }
      newExpandReInds.push_back(newExpandReassociation);
      collapseIndex++;
    }
    // Swap the reshape order: expand the collapse source first, then collapse.
    SmallVector<Value> dynamicSizes;
    SmallVector<int64_t> staticSizes;
    dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
    auto expandResultType = expandOp.getResultType().clone(staticSizes);
    auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
        loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
        newExpandSizes);
    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        expandOp, newExpand.getResult(), newCollapseReInds);
    return success();
  }
};
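/// Convert `tensor.extract_slice(tensor.expand_shape)` into
/// `tensor.expand_shape(tensor.extract_slice)`, i.e. bubble the slice up above
/// its producing expand_shape. Only unit-stride slices that are contiguous
/// within every reassociation group are supported. Illustrative example
/// (constructed for this comment); the collapsed offset 32 (= 2 * 16) is
/// produced by an affine.linearize_index op in the actual rewrite:
///   %e = tensor.expand_shape %src [[0, 1]] output_shape [8, 16]
///       : tensor<128xf32> into tensor<8x16xf32>
///   %s = tensor.extract_slice %e[2, 0] [3, 16] [1, 1]
///       : tensor<8x16xf32> to tensor<3x16xf32>
/// becomes
///   %s0 = tensor.extract_slice %src[32] [48] [1]
///       : tensor<128xf32> to tensor<48xf32>
///   %r = tensor.expand_shape %s0 [[0, 1]] output_shape [3, 16]
///       : tensor<48xf32> into tensor<3x16xf32>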
struct BubbleUpExpandShapeThroughExtractSlice
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    auto expandShapeOp =
        sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();

    if (failed(checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp,
                                                        rewriter)))
      return failure();

    // "expanded" refers to the state before the rewrite (the slice of the
    // expanded tensor), "collapsed" to the state after (the slice of the
    // expand_shape source).
    SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
    SmallVector<OpFoldResult> expandedShape =
        getMixedValues(expandShapeOp.getStaticOutputShape(),
                       expandShapeOp.getOutputShape(), rewriter);
    // Helper that multiplies two OpFoldResults with a composed affine apply.
    Location loc = expandShapeOp->getLoc();
    AffineExpr d0, d1;
    bindDims(rewriter.getContext(), d0, d1);
    auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
      auto mulMap = AffineMap::get(2, 0, d0 * d1);
      return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
                                                   {v1, v2});
    };

    SmallVector<OpFoldResult> collapsedOffsets, collapsedSizes,
        collapsedStrides;
    // Compute one offset/size/stride triple per reassociation group.
    for (const ReassociationIndices &indices :
         expandShapeOp.getReassociationIndices()) {
      OpFoldResult collapsedSize = rewriter.getIndexAttr(1);
      SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets;
      for (long expandedDim : indices) {
        reassocGroupSizes.push_back(expandedShape[expandedDim]);
        reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
        collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
      }
      // Linearize the group's offsets into a single offset on the collapsed
      // dimension.
      SmallVector<Value> offsetVals =
          llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) {
            return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
          });
      OpFoldResult collapsedOffset =
          rewriter
              .create<affine::AffineLinearizeIndexOp>(loc, offsetVals,
                                                      reassocGroupSizes,
                                                      /*disjoint=*/true)
              .getResult();
      collapsedOffsets.push_back(collapsedOffset);
      collapsedSizes.push_back(collapsedSize);
      // Only unit strides are supported (checked in the precondition).
      collapsedStrides.push_back(rewriter.getIndexAttr(1));
    }
    SmallVector<Value> dynDims;
    SmallVector<int64_t> shape;
    dispatchIndexOpFoldResults(expandedSizes, dynDims, shape);
    RankedTensorType resultType = RankedTensorType::get(
        shape, expandShapeOp.getResultType().getElementType());
    Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        loc, expandShapeOp.getSrc(), collapsedOffsets, collapsedSizes,
        collapsedStrides);
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        sliceOp, resultType, newSliceOp,
        expandShapeOp.getReassociationIndices(), expandedSizes);
    return success();
  }
  // Check that the slice can be bubbled up: the source must be produced by an
  // expand_shape, strides must be unit, the slice must not be rank-reducing,
  // and it must be contiguous within each reassociation group.
  LogicalResult
  checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp,
                                           tensor::ExpandShapeOp expandShapeOp,
                                           PatternRewriter &rewriter) const {
    if (!expandShapeOp) {
      return rewriter.notifyMatchFailure(
          sliceOp, "tensor.extract_slice source not produced by expand_shape");
    }
    if (!sliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
                   "be supported in this transformation.");
    }
    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();

    if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
        sizes.size()) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "unimplemented: rank reducing slice");
    }
    SmallVector<OpFoldResult> outputShape =
        getMixedValues(expandShapeOp.getStaticOutputShape(),
                       expandShapeOp.getOutputShape(), rewriter);
    // A dimension is sliced "fully" if its offset is 0 and its slice size
    // equals the whole dimension size.
    auto isZeroOffsetAndFullSize = [](OpFoldResult offset,
                                      OpFoldResult sliceSize,
                                      OpFoldResult size) {
      if (!isConstantIntValue(offset, 0))
        return false;
      FailureOr<bool> maybeEqual =
          ValueBoundsConstraintSet::areEqual(sliceSize, size);
      return llvm::succeeded(maybeEqual) && maybeEqual.value();
    };
    // Within each reassociation group, once a non-unit slice size is taken on
    // some dimension, all remaining dimensions of the group must be sliced
    // with offset 0 and their full size; otherwise the slice is not
    // contiguous.
    for (const ReassociationIndices &indices :
         expandShapeOp.getReassociationIndices()) {
      int64_t i = 0;
      int64_t e = indices.size();
      // Skip up to and including the first dimension with a non-unit slice
      // size.
      for (; i < e; ++i) {
        if (!isConstantIntValue(sizes[indices[i]], 1)) {
          ++i;
          break;
        }
      }
      for (; i < e; ++i) {
        int64_t expandedDim = indices[i];
        if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
                                     outputShape[expandedDim])) {
          return rewriter.notifyMatchFailure(
              sliceOp, "Not a contiguous slice of the expanded tensor.");
        }
      }
    }

    return success();
  }
};

} // namespace
void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract,
           FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
           FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
           FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
           FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
          patterns.getContext());
}

void mlir::tensor::populateBubbleUpExpandShapePatterns(
    RewritePatternSet &patterns) {
  patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
}

void mlir::tensor::populateBubbleUpExtractSliceOpPatterns(
    RewritePatternSet &patterns) {
  patterns.add<BubbleUpExpandShapeThroughExtractSlice>(patterns.getContext());
}
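
// A minimal usage sketch (not part of this file; the pass name and structure
// below are illustrative assumptions): a pass can register the populate
// functions above and run them with the greedy pattern rewrite driver. The
// driver entry point is spelled applyPatternsAndFoldGreedily in older MLIR
// versions.
//
//   #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
//
//   struct FoldTensorReshapesPass
//       : PassWrapper<FoldTensorReshapesPass, OperationPass<>> {
//     void runOnOperation() override {
//       RewritePatternSet patterns(&getContext());
//       tensor::populateReassociativeReshapeFoldingPatterns(patterns);
//       tensor::populateBubbleUpExpandShapePatterns(patterns);
//       tensor::populateBubbleUpExtractSliceOpPatterns(patterns);
//       if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
//         signalPassFailure();
//     }
//   };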