21 #include "llvm/ADT/STLExtras.h"
38 auto [dim, indexValue] = dimAndIndex;
39 assert(dim < sliceParams.size() &&
"slice should be non rank-reducing");
40 return std::make_pair(
43 {indexValue, sliceParams[dim].offset, sliceParams[dim].stride}));
51 const auto &[dim, indexValue] = dimAndIndex;
53 for (int64_t i : reassociation[dim])
54 basis.push_back(reshapeSourceShape[i]);
56 AffineDelinearizeIndexOp::create(b, loc, indexValue, basis);
57 return delinearized->getResults();
60 FailureOr<ExtractSliceFromCollapseHelper>
62 OpBuilder &b, tensor::CollapseShapeOp collapseOp,
63 tensor::ExtractSliceOp extractOp) {
64 if (extractOp.getSource().getDefiningOp<tensor::CollapseShapeOp>() !=
68 ranges.reserve(extractOp.getSourceType().getRank());
69 for (
const auto &[o, s, st] :
70 llvm::zip(extractOp.getMixedOffsets(), extractOp.getMixedSizes(),
71 extractOp.getMixedStrides())) {
72 ranges.push_back({o, s, st});
74 return ExtractSliceFromCollapseHelper::create(b, collapseOp, ranges);
77 FailureOr<ExtractSliceFromCollapseHelper>
79 tensor::CollapseShapeOp op,
83 if (succeeded(mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo(
84 op.getSrcType(), op.getReassociationIndices())))
94 op.getReassociationIndices();
98 llvm::SmallBitVector linearizedDimensions =
100 llvm::SmallBitVector slicedDimensions =
103 auto collapseShapeInputShape =
107 for (
unsigned i = 0; i < sliceParams.size(); i++) {
108 if (slicedDimensions[i] && linearizedDimensions[i])
114 op, collapseShapeInputShape, collapseShapeOutputShape, sliceParams,
115 linearizedDimensions, slicedDimensions, tileSizes);
118 std::pair<Value, SmallVector<Range>>
123 collapseShapeOp.getReassociationIndices();
124 SliceFromCollapseHelper helper(reassociationIndices, collapseShapeInputShape,
125 collapseShapeOutputShape, sliceParams);
131 unsigned loopIdx = 0;
132 for (
unsigned i = 0, e = linearizedDimensions.size(); i < e; i++) {
133 if (linearizedDimensions[i] && slicedDimensions[i]) {
136 std::make_tuple(i, tileInductionVars[loopIdx++]));
138 builder, loc, reassociationIndices, collapseShapeInputShape, tb));
143 helper.getExtractSliceParams(builder.
getContext(), multiIndices);
145 Value subTileResult = tensor::ExtractSliceOp::create(
146 builder, loc, collapseShapeOp.getSrc(), extractParams);
149 helper.getInsertSliceParams(builder.
getContext(), tileInductionVars);
152 Value collapsedResult = tensor::CollapseShapeOp::create(
153 builder, loc, subTileResult, reassociationIndices);
154 return std::make_pair(collapsedResult, insertParams);
157 FailureOr<Operation *>
161 op.getReassociationIndices();
162 RankedTensorType sourceType = op.getSrcType();
163 FailureOr<CollapseShapeRankReducingSliceSimplificationInfo> info =
164 getSimplifyCollapseShapeWithRankReducingSliceInfo(sourceType,
165 reassociationIndices);
176 auto sliceOp = tensor::ExtractSliceOp::create(
177 rewriter, op.getLoc(), info->sliceResultType, op.getSrc(), offsets, sizes,
180 if (!info->newReassociationIndices.has_value()) {
181 rewriter.
replaceOp(op, sliceOp.getResult());
182 return sliceOp.getOperation();
187 op, sliceOp.getResult(), *info->newReassociationIndices)
Base type for affine expression.
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
FailureOr< Operation * > simplifyCollapseShapeWithRankReducingExtractSlice(tensor::CollapseShapeOp op, RewriterBase &rewriter)
Tries to simplify a tensor.collapse_shape operation by inserting a single rank-reducing tensor....
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
llvm::SmallBitVector getSlicedDimensions(ArrayRef< OpFoldResult > sliceInputShape, ArrayRef< Range > sliceParams)
The input parameters offsets, sizes, strides specify a rectangular non rank-reducing slice of the col...
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
llvm::SmallBitVector getLinearizedDimensions(ArrayRef< ReassociationIndices > reassociationIndices)
Determine which dimensions are linearized by a tensor.collapse_shape op by inspecting its reassociati...