MLIR  22.0.0git
Transforms.cpp
Go to the documentation of this file.
1 //===- Transforms.cpp ---------------------------------------------- C++ --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "TransformsDetail.h"
21 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/OpDefinition.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/IR/Value.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include <iterator>
29 #include <numeric>
30 
31 namespace mlir::shard {
32 
33 namespace {
34 
/// Lower `shard.process_multi_index` into an expression using
/// `shard.process_linear_index` and `shard.grid_shape`.
37 struct ProcessMultiIndexOpLowering
38  : OpRewritePatternWithSymbolTableCollection<ProcessMultiIndexOp> {
41 
42  LogicalResult matchAndRewrite(ProcessMultiIndexOp op,
43  PatternRewriter &rewriter) const override {
44  GridOp grid = getGrid(op, symbolTableCollection);
45  if (!grid) {
46  return failure();
47  }
48 
49  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
50  builder.setInsertionPointAfter(op.getOperation());
51  Value linearIndex = ProcessLinearIndexOp::create(builder, grid);
52  ValueRange gridShape = GridShapeOp::create(builder, grid).getResults();
53  SmallVector<Value> completeMultiIndex =
54  affine::AffineDelinearizeIndexOp::create(builder, linearIndex,
55  gridShape)
56  .getMultiIndex();
57  SmallVector<Value> multiIndex;
58  ArrayRef<GridAxis> opGridAxes = op.getAxes();
59  SmallVector<GridAxis> opAxesIota;
60  if (opGridAxes.empty()) {
61  opAxesIota.resize(grid.getRank());
62  std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
63  opGridAxes = opAxesIota;
64  }
65  llvm::transform(opGridAxes, std::back_inserter(multiIndex),
66  [&completeMultiIndex](GridAxis gridAxis) {
67  return completeMultiIndex[gridAxis];
68  });
69  rewriter.replaceAllUsesWith(op.getResults(), multiIndex);
70  return success();
71  }
72 };
73 
74 struct AllSliceOpLowering
75  : OpRewritePatternWithSymbolTableCollection<AllSliceOp> {
78 
79  LogicalResult matchAndRewrite(AllSliceOp op,
80  PatternRewriter &rewriter) const override {
81  // 1. Compute the process linear index inside the process group from its
82  // multi-index.
83  //
84  // 2. Extract a slice from the input tensor.
85  // All axes except the slicing axis are not interesting and take the full
86  // axis.
87  // The slice axis is split into equisized parts with count
88  // the number of processes in the collective process group induced by
89  // the grid axes.
90  // The part for each process is determined by the corresponding
91  // linear-index in the process group.
92  //
93  // There are no collectives that require communication.
94  // Each process operates on its local tensor.
95 
96  GridOp grid = getGrid(op, symbolTableCollection);
97  if (!grid) {
98  return failure();
99  }
100 
101  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
102  builder.setInsertionPointAfter(op.getOperation());
103 
104  Value zero = arith::ConstantOp::create(builder, builder.getIndexAttr(0));
105 
106  Operation::result_range processInGroupMultiIndex =
107  ProcessMultiIndexOp::create(builder, grid.getSymName(),
108  op.getGridAxes())
109  .getResults();
110 
111  Operation::result_range processGroupShape =
112  GridShapeOp::create(builder, grid.getSymName(), op.getGridAxes())
113  .getResult();
114  Value processGroupSize =
115  createCollectiveProcessGroupSize(grid, op.getGridAxes(), builder);
116 
117  int64_t sliceAxis = op.getSliceAxis().getSExtValue();
118  Value operandSliceAxisSize =
119  tensor::DimOp::create(builder, op.getOperand(), sliceAxis);
120  Value operandSliceAxisSizeModProcessGroupSize =
121  arith::RemUIOp::create(builder, operandSliceAxisSize, processGroupSize);
122  Value isTargetShapeExactlyDivisible =
123  arith::CmpIOp::create(builder, arith::CmpIPredicate::eq,
124  operandSliceAxisSizeModProcessGroupSize, zero);
125  cf::AssertOp::create(builder, isTargetShapeExactlyDivisible,
126  "Slicing a tensor with axis size that is "
127  "not exactly divisible by the "
128  "grid process group size is not supported.");
129  Value resultSliceAxisSize =
130  arith::DivUIOp::create(builder, operandSliceAxisSize, processGroupSize);
131  OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
132  llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
133  llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
134 
135  // insert tensor.extract_slice
136  RankedTensorType operandType =
137  cast<RankedTensorType>(op.getOperand().getType());
138  SmallVector<OpFoldResult> sizes;
139  for (int64_t i = 0; i < operandType.getRank(); ++i) {
140  if (i == sliceAxis) {
141  sizes.emplace_back(resultSliceAxisSize);
142  } else {
143  Value dimSize = tensor::DimOp::create(builder, op.getOperand(), i);
144  sizes.emplace_back(dimSize);
145  }
146  }
147  SmallVector<OpFoldResult> offsets(
148  operandType.getRank(), getAsIndexOpFoldResult(builder.getContext(), 0));
149  offsets[sliceAxis] =
150  ArithBuilder(builder, builder.getLoc())
151  .mul(getValueOrCreateConstantIndexOp(builder, builder.getLoc(),
152  processInGroupLinearIndex),
153  resultSliceAxisSize);
154  SmallVector<OpFoldResult> strides(
155  operandType.getRank(), getAsIndexOpFoldResult(builder.getContext(), 1));
156  Value slice = tensor::ExtractSliceOp::create(builder, op.getOperand(),
157  offsets, sizes, strides);
158  Value newResult =
159  tensor::CastOp::create(builder, op.getResult().getType(), slice);
160  rewriter.replaceAllUsesWith(op.getResult(), newResult);
161 
162  return success();
163  }
164 };
165 
166 } // namespace
167 
169  RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
170  patterns.add<ProcessMultiIndexOpLowering>(symbolTableCollection,
171  patterns.getContext());
172 }
173 
175  registry.insert<affine::AffineDialect, shard::ShardDialect>();
176 }
177 
179  RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
180  patterns.add<AllSliceOpLowering>(symbolTableCollection,
181  patterns.getContext());
182 }
183 
185  registry.insert<affine::AffineDialect, arith::ArithDialect,
186  cf::ControlFlowDialect, shard::ShardDialect,
187  tensor::TensorDialect>();
188 }
189 
191  RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
193  populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection);
194 }
195 
199 }
200 
203  ImplicitLocOpBuilder &builder) {
204  Operation::result_range gridShape =
205  GridShapeOp::create(builder, grid, axes).getResults();
206  return cast<TypedValue<IndexType>>(arith::createProduct(
207  builder, builder.getLoc(), llvm::to_vector_of<Value>(gridShape),
208  builder.getIndexType()));
209 }
210 
212 createProcessLinearIndex(StringRef grid, ValueRange processInGroupMultiIndex,
213  ArrayRef<GridAxis> gridAxes,
214  ImplicitLocOpBuilder &builder) {
215  Operation::result_range processGroupShape =
216  GridShapeOp::create(builder, grid, gridAxes).getResult();
217  OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
218  llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
219  llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
220  auto res = dyn_cast<Value>(processInGroupLinearIndex);
221  if (!res)
223  builder,
224  cast<IntegerAttr>(cast<Attribute>(processInGroupLinearIndex)).getInt());
225  return cast<TypedValue<IndexType>>(res);
226 }
227 
229  ArrayRef<GridAxis> gridAxes,
230  ImplicitLocOpBuilder &builder) {
232  grid, ProcessMultiIndexOp::create(builder, grid, gridAxes).getResults(),
233  gridAxes, builder);
234 }
235 } // namespace mlir::shard
IndexType getIndexType()
Definition: Builders.cpp:50
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition: Builders.h:623
Location getLoc() const
Accessors for the implied location.
Definition: Builders.h:656
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
ResultRange result_range
Support result iteration.
Definition: Operation.h:410
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:247
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
OpFoldResult linearizeIndex(ArrayRef< OpFoldResult > multiIndex, ArrayRef< OpFoldResult > basis, ImplicitLocOpBuilder &builder)
Definition: Utils.cpp:2027
Value createProduct(OpBuilder &builder, Location loc, ArrayRef< Value > values)
Definition: Utils.cpp:345
shard::GridOp GridOp
int16_t GridAxis
Definition: ShardOps.h:26
void populateAllOpLoweringPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
Definition: Transforms.cpp:190
void registerAllSliceOpLoweringDialects(DialectRegistry &registry)
Definition: Transforms.cpp:184
void populateAllSliceOpLoweringPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
Definition: Transforms.cpp:178
void populateProcessMultiIndexOpLoweringPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
Definition: Transforms.cpp:168
void registerAllOpLoweringDialects(DialectRegistry &registry)
Definition: Transforms.cpp:196
TypedValue< IndexType > createProcessLinearIndex(StringRef grid, ArrayRef< GridAxis > gridAxes, ImplicitLocOpBuilder &builder)
Definition: Transforms.cpp:228
TypedValue< IndexType > createCollectiveProcessGroupSize(GridOp grid, ArrayRef< GridAxis > axes, ImplicitLocOpBuilder &builder)
Definition: Transforms.cpp:202
void registerProcessMultiIndexOpLoweringDialects(DialectRegistry &registry)
Definition: Transforms.cpp:174
shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
Definition: ShardOps.h:121
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:488
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:111
OpRewritePatternWithSymbolTableCollection(SymbolTableCollection &symbolTableCollection, OpRewritePatternArgs &&...opRewritePatternArgs)