21 #include "llvm/ADT/STLExtras.h" 
   38   auto [dim, indexValue] = dimAndIndex;
 
   39   assert(dim < sliceParams.size() && 
"slice should be non rank-reducing");
 
   40   return std::make_pair(
 
   43                {indexValue, sliceParams[dim].offset, sliceParams[dim].stride}));
 
   51   const auto &[dim, indexValue] = dimAndIndex;
 
   53   for (int64_t i : reassociation[dim])
 
   54     basis.push_back(reshapeSourceShape[i]);
 
   56       AffineDelinearizeIndexOp::create(b, loc, indexValue, basis);
 
   57   return delinearized->getResults();
 
   60 FailureOr<ExtractSliceFromCollapseHelper>
 
   62     OpBuilder &b, tensor::CollapseShapeOp collapseOp,
 
   63     tensor::ExtractSliceOp extractOp) {
 
   64   if (extractOp.getSource().getDefiningOp<tensor::CollapseShapeOp>() !=
 
   68   ranges.reserve(extractOp.getSourceType().getRank());
 
   69   for (
const auto &[o, s, st] :
 
   70        llvm::zip(extractOp.getMixedOffsets(), extractOp.getMixedSizes(),
 
   71                  extractOp.getMixedStrides())) {
 
   72     ranges.push_back({o, s, st});
 
   74   return ExtractSliceFromCollapseHelper::create(b, collapseOp, ranges);
 
   77 FailureOr<ExtractSliceFromCollapseHelper>
 
   79                                                tensor::CollapseShapeOp op,
 
   83   if (succeeded(mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo(
 
   84           op.getSrcType(), op.getReassociationIndices())))
 
   94       op.getReassociationIndices();
 
   98   llvm::SmallBitVector linearizedDimensions =
 
  100   llvm::SmallBitVector slicedDimensions =
 
  103   auto collapseShapeInputShape =
 
  107   for (
unsigned i = 0; i < sliceParams.size(); i++) {
 
  108     if (slicedDimensions[i] && linearizedDimensions[i])
 
  114       op, collapseShapeInputShape, collapseShapeOutputShape, sliceParams,
 
  115       linearizedDimensions, slicedDimensions, tileSizes);
 
  118 std::pair<Value, SmallVector<Range>>
 
  123       collapseShapeOp.getReassociationIndices();
 
  124   SliceFromCollapseHelper helper(reassociationIndices, collapseShapeInputShape,
 
  125                                  collapseShapeOutputShape, sliceParams);
 
  131   unsigned loopIdx = 0;
 
  132   for (
unsigned i = 0, e = linearizedDimensions.size(); i < e; i++) {
 
  133     if (linearizedDimensions[i] && slicedDimensions[i]) {
 
  136                               std::make_tuple(i, tileInductionVars[loopIdx++]));
 
  138           builder, loc, reassociationIndices, collapseShapeInputShape, tb));
 
  143       helper.getExtractSliceParams(builder.
getContext(), multiIndices);
 
  145   Value subTileResult = tensor::ExtractSliceOp::create(
 
  146       builder, loc, collapseShapeOp.getSrc(), extractParams);
 
  149       helper.getInsertSliceParams(builder.
getContext(), tileInductionVars);
 
  152   Value collapsedResult = tensor::CollapseShapeOp::create(
 
  153       builder, loc, subTileResult, reassociationIndices);
 
  154   return std::make_pair(collapsedResult, insertParams);
 
  157 FailureOr<Operation *>
 
  161       op.getReassociationIndices();
 
  162   RankedTensorType sourceType = op.getSrcType();
 
  163   FailureOr<CollapseShapeRankReducingSliceSimplificationInfo> info =
 
  164       getSimplifyCollapseShapeWithRankReducingSliceInfo(sourceType,
 
  165                                                         reassociationIndices);
 
  176   auto sliceOp = tensor::ExtractSliceOp::create(
 
  177       rewriter, op.getLoc(), info->sliceResultType, op.getSrc(), offsets, sizes,
 
  180   if (!info->newReassociationIndices.has_value()) {
 
  181     rewriter.
replaceOp(op, sliceOp.getResult());
 
  182     return sliceOp.getOperation();
 
  187           op, sliceOp.getResult(), *info->newReassociationIndices)
 
Base type for affine expression.
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
FailureOr< Operation * > simplifyCollapseShapeWithRankReducingExtractSlice(tensor::CollapseShapeOp op, RewriterBase &rewriter)
Tries to simplify a tensor.collapse_shape operation by inserting a single rank-reducing tensor....
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
llvm::SmallBitVector getSlicedDimensions(ArrayRef< OpFoldResult > sliceInputShape, ArrayRef< Range > sliceParams)
The input parameters offsets, sizes, strides specify a rectangular non rank-reducing slice of the col...
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
llvm::SmallBitVector getLinearizedDimensions(ArrayRef< ReassociationIndices > reassociationIndices)
Determine which dimensions are linearized by a tensor.collapse_shape op by inspecting its reassociati...