23 #include "llvm/ADT/STLExtras.h"
40 auto [dim, indexValue] = dimAndIndex;
41 assert(dim < sliceParams.size() &&
"slice should be non rank-reducing");
42 return std::make_pair(
45 {indexValue, sliceParams[dim].offset, sliceParams[dim].stride}));
53 const auto &[dim, indexValue] = dimAndIndex;
55 for (int64_t i : reassociation[dim])
56 basis.push_back(reshapeSourceShape[i]);
58 b.
create<AffineDelinearizeIndexOp>(loc, indexValue, basis);
62 FailureOr<ExtractSliceFromCollapseHelper>
64 OpBuilder &b, tensor::CollapseShapeOp collapseOp,
65 tensor::ExtractSliceOp extractOp) {
66 if (extractOp.getSource().getDefiningOp<tensor::CollapseShapeOp>() !=
70 ranges.reserve(extractOp.getSourceType().getRank());
71 for (
const auto &[o, s, st] :
72 llvm::zip(extractOp.getMixedOffsets(), extractOp.getMixedSizes(),
73 extractOp.getMixedStrides())) {
74 ranges.push_back({o, s, st});
76 return ExtractSliceFromCollapseHelper::create(b, collapseOp, ranges);
79 FailureOr<ExtractSliceFromCollapseHelper>
81 tensor::CollapseShapeOp op,
85 if (succeeded(mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo(
86 op.getSrcType(), op.getReassociationIndices())))
96 op.getReassociationIndices();
100 llvm::SmallBitVector linearizedDimensions =
102 llvm::SmallBitVector slicedDimensions =
105 auto collapseShapeInputShape =
109 for (
unsigned i = 0; i < sliceParams.size(); i++) {
110 if (slicedDimensions[i] && linearizedDimensions[i])
116 op, collapseShapeInputShape, collapseShapeOutputShape, sliceParams,
117 linearizedDimensions, slicedDimensions, tileSizes);
120 std::pair<Value, SmallVector<Range>>
125 collapseShapeOp.getReassociationIndices();
126 SliceFromCollapseHelper helper(reassociationIndices, collapseShapeInputShape,
127 collapseShapeOutputShape, sliceParams);
133 unsigned loopIdx = 0;
134 for (
unsigned i = 0, e = linearizedDimensions.size(); i < e; i++) {
135 if (linearizedDimensions[i] && slicedDimensions[i]) {
138 std::make_tuple(i, tileInductionVars[loopIdx++]));
140 builder, loc, reassociationIndices, collapseShapeInputShape, tb));
145 helper.getExtractSliceParams(builder.
getContext(), multiIndices);
147 Value subTileResult = builder.
create<tensor::ExtractSliceOp>(
148 loc, collapseShapeOp.getSrc(), extractParams);
151 helper.getInsertSliceParams(builder.
getContext(), tileInductionVars);
154 Value collapsedResult = builder.
create<tensor::CollapseShapeOp>(
155 loc, subTileResult, reassociationIndices);
156 return std::make_pair(collapsedResult, insertParams);
159 FailureOr<Operation *>
163 op.getReassociationIndices();
164 RankedTensorType sourceType = op.getSrcType();
165 FailureOr<CollapseShapeRankReducingSliceSimplificationInfo> info =
166 getSimplifyCollapseShapeWithRankReducingSliceInfo(sourceType,
167 reassociationIndices);
178 auto sliceOp = rewriter.
create<tensor::ExtractSliceOp>(
179 op.
getLoc(), info->sliceResultType, op.getSrc(), offsets, sizes, strides);
181 if (!info->newReassociationIndices.has_value()) {
183 return sliceOp.getOperation();
188 op, sliceOp.
getResult(), *info->newReassociationIndices)
Base type for affine expression.
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
result_range getResults()
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
FailureOr< Operation * > simplifyCollapseShapeWithRankReducingExtractSlice(tensor::CollapseShapeOp op, RewriterBase &rewriter)
Tries to simplify a tensor.collapse_shape operation by inserting a single rank-reducing tensor....
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
llvm::SmallBitVector getSlicedDimensions(ArrayRef< OpFoldResult > sliceInputShape, ArrayRef< Range > sliceParams)
The input parameters offsets, sizes, strides specify a rectangular non rank-reducing slice of the col...
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)]
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)]
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
llvm::SmallBitVector getLinearizedDimensions(ArrayRef< ReassociationIndices > reassociationIndices)
Determine which dimensions are linearized by a tensor.collapse_shape op by inspecting its reassociati...