MLIR  19.0.0git
ExtractSliceFromReshapeUtils.cpp
Go to the documentation of this file.
1 //===- ExtractSliceFromReshapeUtils.cpp - Slice reshape rewrites ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewrites that replace slices of reshape results with
10 // aggregated slices of the reshape source.
11 //
12 //===----------------------------------------------------------------------===//
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/OpDefinition.h"
23 #include "llvm/ADT/STLExtras.h"
24 
25 using namespace mlir;
26 using namespace mlir::affine;
27 using namespace mlir::tensor;
28 
29 /// A tuple that represents (dimension number, dimension value).
/// The second element is an SSA index value *along* that dimension (e.g. a
/// loop induction variable), not the dimension's extent.
30 using DimAndIndex = std::tuple<unsigned, Value>;
31 
32 /// Transform `dimAndIndex` from the output index space of a (non-rank-reducing)
33 /// slice described by `sliceParams` into the input index space.
///
/// For the sliced dimension `dim`, the input-space index is
/// `offset + index * stride`, using that dimension's slice offset and stride.
/// Asserts that `dim` indexes into `sliceParams` (i.e. the slice is not
/// rank-reducing).
// NOTE(review): listing line 34 (the start of the signature) is missing from
// this page; the declaration summary gives it as
// `static DimAndIndex invertSliceIndexing(OpBuilder &b, Location loc, ...)`.
35  ArrayRef<Range> sliceParams,
36  const DimAndIndex &dimAndIndex) {
37  AffineExpr d0, s0, s1;
38  bindDims(b.getContext(), d0);
39  bindSymbols(b.getContext(), s0, s1);
40  auto [dim, indexValue] = dimAndIndex;
41  assert(dim < sliceParams.size() && "slice should be non rank-reducing");
// Expression s0 + d0 * s1 with operands (d0 <- indexValue, s0 <- slice
// offset, s1 <- slice stride), i.e. offset + index * stride.
// NOTE(review): listing line 43 is missing; it presumably holds the pair's
// first element and the affine-apply call consuming this expression -- confirm.
42  return std::make_pair(
44  b, loc, s0 + d0 * s1,
45  {indexValue, sliceParams[dim].offset, sliceParams[dim].stride}));
46 }
47 
48 /// Transform `dimAndIndex` from the result tensor index space of a
49 /// CollapseShapeOp to the source tensor index space.
///
/// The index along result dimension `dim` is delinearized over the sizes of
/// the source dimensions listed in `reassociation[dim]` (taken from
/// `reshapeSourceShape`). Returns the results of the created
/// AffineDelinearizeIndexOp.
// NOTE(review): listing line 50 (the start of the signature) is missing; the
// declaration summary gives `static ValueRange invertCollapseShapeIndexing(`.
51  OpBuilder &b, Location loc, ArrayRef<ReassociationIndices> reassociation,
52  ArrayRef<OpFoldResult> reshapeSourceShape, const DimAndIndex &dimAndIndex) {
53  const auto &[dim, indexValue] = dimAndIndex;
// NOTE(review): listing line 54 is missing; it presumably declares `basis`
// (a container of OpFoldResult) -- confirm against the real source.
55  for (int64_t i : reassociation[dim])
56  basis.push_back(reshapeSourceShape[i]);
57  auto delinearized =
58  b.create<AffineDelinearizeIndexOp>(loc, indexValue, basis);
59  return delinearized->getResults();
60 }
61 
// NOTE(review): listing lines 62-63 (this function's doc comment and the start
// of its signature) are missing from this page; it is presumably the
// ExtractSliceFromCollapseHelper::create overload taking an extract_slice.
//
// Builds the slice description from `extractOp` and forwards to the
// Range-based create(). Fails unless `extractOp`'s source is produced
// directly by `collapseOp`. One Range {offset, size, stride} is formed per
// entry of the extract_slice's mixed offsets/sizes/strides.
64  OpBuilder &b, tensor::CollapseShapeOp collapseOp,
65  tensor::ExtractSliceOp extractOp) {
66  if (extractOp.getSource().getDefiningOp<tensor::CollapseShapeOp>() !=
67  collapseOp)
68  return failure();
69  SmallVector<Range> ranges;
70  ranges.reserve(extractOp.getSourceType().getRank());
71  for (const auto &[o, s, st] :
72  llvm::zip(extractOp.getMixedOffsets(), extractOp.getMixedSizes(),
73  extractOp.getMixedStrides())) {
74  ranges.push_back({o, s, st});
75  }
76  return ExtractSliceFromCollapseHelper::create(b, collapseOp, ranges);
77 }
78 
// NOTE(review): listing lines 79-80 (doc comment and signature start) are
// missing; the declaration summary gives
// `static FailureOr<ExtractSliceFromCollapseHelper> create(OpBuilder &b,
// tensor::CollapseShapeOp op, ArrayRef<Range> sliceParams)`.
//
// Gathers everything needed to later materialize a slice of `op`'s result
// from its source. Fails if `op` is better handled by a rank-reducing
// extract_slice, or if `op`'s result shape cannot be reified.
81  tensor::CollapseShapeOp op,
82  ArrayRef<Range> sliceParams) {
83  // Don't perform this pattern if the collapse op can be simplified by
84  // a rank-reducing extract slice.
85  if (succeeded(mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo(
86  op.getSrcType(), op.getReassociationIndices())))
87  return failure();
88 
89  // Materialize the output shape of the collapse_shape operation. This will
90  // create IR describing the output shape in terms of the input shape.
91  ReifiedRankedShapedTypeDims reifiedShapes;
92  if (failed(reifyResultShapes(b, op, reifiedShapes)))
93  return failure();
94  SmallVector<OpFoldResult> &collapseShapeOutputShape = reifiedShapes[0];
95  SmallVector<ReassociationIndices> reassociationIndices =
96  op.getReassociationIndices();
97 
98  // Determine which of the CollapseShapeOp's result dimensions are sliced
99  // and/or linearized.
100  llvm::SmallBitVector linearizedDimensions =
101  getLinearizedDimensions(reassociationIndices);
102  llvm::SmallBitVector slicedDimensions =
103  getSlicedDimensions(collapseShapeOutputShape, sliceParams);
104 
105  auto collapseShapeInputShape =
106  tensor::getMixedSizes(b, op.getLoc(), op.getSrc());
107 
// Record one tile size (the slice size, materialized as an index Value) for
// every result dimension that is both sliced and linearized. These are the
// same dimensions for which emitLoopNestBody consumes an induction variable.
108  SmallVector<Value> tileSizes;
109  for (unsigned i = 0; i < sliceParams.size(); i++) {
110  if (slicedDimensions[i] && linearizedDimensions[i])
111  tileSizes.push_back(
112  getValueOrCreateConstantIndexOp(b, op.getLoc(), sliceParams[i].size));
113  }
114 
// NOTE(review): listing line 115 is missing; it presumably begins the
// `return` that constructs the helper from the arguments below -- confirm.
116  op, collapseShapeInputShape, collapseShapeOutputShape, sliceParams,
117  linearizedDimensions, slicedDimensions, tileSizes);
118 }
119 
/// Emits the IR for one iteration of the caller's tiling loop nest.
///
/// `tileInductionVars` must supply one value per collapse_shape result
/// dimension that is both linearized and sliced, in increasing dimension
/// order. This function does not check that count. Each such value is mapped
/// back through the slice (invertSliceIndexing) and the collapse_shape
/// (invertCollapseShapeIndexing) to source indices. A tensor.extract_slice of
/// the collapse_shape source is then taken at those indices and collapsed
/// again with the original reassociation.
///
/// Returns the collapsed tile, together with the insert-slice parameters the
/// helper computes from `tileInductionVars`.
// NOTE(review): listing line 121 (qualified name and leading parameters) is
// missing from this page.
120 std::pair<Value, SmallVector<Range>>
122  OpBuilder &builder, Location loc, ValueRange tileInductionVars) {
123  // Create the helper class for forming the slice parameters.
124  const SmallVector<ReassociationIndices> reassociationIndices =
125  collapseShapeOp.getReassociationIndices();
126  SliceFromCollapseHelper helper(reassociationIndices, collapseShapeInputShape,
127  collapseShapeOutputShape, sliceParams);
128 
129  // For each tiled dim (linearized by the collapse_shape and sliced by the
130  // extract_slice), invert the index-space transformations to recover the
131  // corresponding multi-index into the collapse_shape source.
132  SmallVector<ValueRange> multiIndices;
133  unsigned loopIdx = 0;
134  for (unsigned i = 0, e = linearizedDimensions.size(); i < e; i++) {
135  if (linearizedDimensions[i] && slicedDimensions[i]) {
136  DimAndIndex tb =
137  invertSliceIndexing(builder, loc, sliceParams,
138  std::make_tuple(i, tileInductionVars[loopIdx++]));
139  multiIndices.push_back(invertCollapseShapeIndexing(
140  builder, loc, reassociationIndices, collapseShapeInputShape, tb));
141  }
142  }
143 
144  SmallVector<Range> extractParams =
145  helper.getExtractSliceParams(builder.getContext(), multiIndices);
146 
147  Value subTileResult = builder.create<tensor::ExtractSliceOp>(
148  loc, collapseShapeOp.getSrc(), extractParams);
149 
150  SmallVector<Range> insertParams =
151  helper.getInsertSliceParams(builder.getContext(), tileInductionVars);
152 
153  // Collapse the dimensions of the source slice back down.
154  Value collapsedResult = builder.create<tensor::CollapseShapeOp>(
155  loc, subTileResult, reassociationIndices);
156  return std::make_pair(collapsedResult, insertParams);
157 }
158 
// NOTE(review): listing lines 159-160 (doc comment and signature start) are
// missing; the declaration summary gives
// `FailureOr<Operation *> simplifyCollapseShapeWithRankReducingExtractSlice(
// tensor::CollapseShapeOp op, RewriterBase &rewriter)`.
//
// Replaces `op` with a rank-reducing tensor.extract_slice of its source,
// followed by a tensor.collapse_shape when the simplification info reports
// new reassociation indices. Fails if no such simplification applies.
// Returns the operation that replaced `op`.
161  tensor::CollapseShapeOp op, RewriterBase &rewriter) {
162  SmallVector<ReassociationIndices> reassociationIndices =
163  op.getReassociationIndices();
164  RankedTensorType sourceType = op.getSrcType();
// NOTE(review): listing line 165 is missing; it presumably declares `info`
// as the FailureOr result of the call below -- confirm.
166  getSimplifyCollapseShapeWithRankReducingSliceInfo(sourceType,
167  reassociationIndices);
168  if (failed(info))
169  return failure();
170 
171  // Create the rank-reducing extract slice op.
// Zero offsets, unit strides and the full source sizes: the slice covers the
// whole source tensor, and only the result type (info->sliceResultType)
// drops dimensions.
172  auto zero = rewriter.getIndexAttr(0);
173  auto one = rewriter.getIndexAttr(1);
174  SmallVector<OpFoldResult> offsets(sourceType.getRank(), zero);
// NOTE(review): listing line 175 is missing; it presumably declares `sizes`
// as the result of the getMixedSizes call below -- confirm.
176  tensor::getMixedSizes(rewriter, op.getLoc(), op.getSrc());
177  SmallVector<OpFoldResult> strides(sourceType.getRank(), one);
178  auto sliceOp = rewriter.create<tensor::ExtractSliceOp>(
179  op.getLoc(), info->sliceResultType, op.getSrc(), offsets, sizes, strides);
180 
// If no further collapsing is needed, the slice alone replaces `op`;
// otherwise `op` becomes a collapse_shape of the slice using the new
// reassociation indices.
181  if (!info->newReassociationIndices.has_value()) {
182  rewriter.replaceOp(op, sliceOp.getResult());
183  return sliceOp.getOperation();
184  }
185 
186  return rewriter
187  .replaceOpWithNewOp<tensor::CollapseShapeOp>(
188  op, sliceOp.getResult(), *info->newReassociationIndices)
189  .getOperation();
190 }
static ValueRange invertCollapseShapeIndexing(OpBuilder &b, Location loc, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > reshapeSourceShape, const DimAndIndex &dimAndIndex)
Transform dimAndIndex from the result tensor index space of a CollapseShapeOp to the source tensor index space.
static DimAndIndex invertSliceIndexing(OpBuilder &b, Location loc, ArrayRef< Range > sliceParams, const DimAndIndex &dimAndIndex)
Transform dimAndIndex from the output index space of a (non-rank-reducing) slice described by sliceParams into the input index space.
std::tuple< unsigned, Value > DimAndIndex
A tuple that represents (dimension number, dimension value).
Base type for affine expression.
Definition: AffineExpr.h:69
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
MLIRContext * getContext() const
Definition: Builders.h:55
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_range getResults()
Definition: Operation.h:410
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
This class assists with generating IR required to materialize an arbitrary-sized slice from the resul...
std::pair< Value, SmallVector< Range > > emitLoopNestBody(OpBuilder &builder, Location loc, ValueRange tileInductionVars)
Generates the IR inside of the caller's loop nest for 1) inverting the index mappings of the ExtractS...
static FailureOr< ExtractSliceFromCollapseHelper > create(OpBuilder &b, tensor::CollapseShapeOp op, ArrayRef< Range > sliceParams)
Given a CollapseShapeOp and a set of ranges describing the desired slice of its result,...
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1138
FailureOr< Operation * > simplifyCollapseShapeWithRankReducingExtractSlice(tensor::CollapseShapeOp op, RewriterBase &rewriter)
Tries to simplify a tensor.collapse_shape operation by inserting a single rank-reducing tensor....
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:61
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
llvm::SmallBitVector getSlicedDimensions(ArrayRef< OpFoldResult > sliceInputShape, ArrayRef< Range > sliceParams)
The input parameters offsets, sizes, strides specify a rectangular non rank-reducing slice of the col...
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:349
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:363
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:41
llvm::SmallBitVector getLinearizedDimensions(ArrayRef< ReassociationIndices > reassociationIndices)
Determine which dimensions are linearized by a tensor.collapse_shape op by inspecting its reassociati...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72