MLIR 22.0.0git
Transforms.cpp
Go to the documentation of this file.
1//===- Transforms.cpp ---------------------------------------------- C++ --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "TransformsDetail.h"
25#include "mlir/IR/Value.h"
26#include "llvm/ADT/STLExtras.h"
27#include "llvm/ADT/SmallVector.h"
28#include <iterator>
29#include <numeric>
30
31namespace mlir::shard {
32
33namespace {
34
35/// Lower `shard.process_multi_index` into expression using
36/// `shard.process_linear_index` and `shard.grid_shape`.
37struct ProcessMultiIndexOpLowering
38 : OpRewritePatternWithSymbolTableCollection<ProcessMultiIndexOp> {
39 using OpRewritePatternWithSymbolTableCollection::
40 OpRewritePatternWithSymbolTableCollection;
41
42 LogicalResult matchAndRewrite(ProcessMultiIndexOp op,
43 PatternRewriter &rewriter) const override {
44 GridOp grid = getGrid(op, symbolTableCollection);
45 if (!grid) {
46 return failure();
47 }
48
49 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
50 builder.setInsertionPointAfter(op.getOperation());
51 Value linearIndex = ProcessLinearIndexOp::create(builder, grid);
52 ValueRange gridShape = GridShapeOp::create(builder, grid).getResults();
53 SmallVector<Value> completeMultiIndex =
54 affine::AffineDelinearizeIndexOp::create(builder, linearIndex,
55 gridShape)
56 .getMultiIndex();
57 SmallVector<Value> multiIndex;
58 ArrayRef<GridAxis> opGridAxes = op.getAxes();
59 SmallVector<GridAxis> opAxesIota;
60 if (opGridAxes.empty()) {
61 opAxesIota.resize(grid.getRank());
62 std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
63 opGridAxes = opAxesIota;
64 }
65 llvm::transform(opGridAxes, std::back_inserter(multiIndex),
66 [&completeMultiIndex](GridAxis gridAxis) {
67 return completeMultiIndex[gridAxis];
68 });
69 rewriter.replaceAllUsesWith(op.getResults(), multiIndex);
70 return success();
71 }
72};
73
74struct AllSliceOpLowering
76 using OpRewritePatternWithSymbolTableCollection::
77 OpRewritePatternWithSymbolTableCollection;
78
79 LogicalResult matchAndRewrite(AllSliceOp op,
80 PatternRewriter &rewriter) const override {
81 // 1. Compute the process linear index inside the process group from its
82 // multi-index.
83 //
84 // 2. Extract a slice from the input tensor.
85 // All axes except the slicing axis are not interesting and take the full
86 // axis.
87 // The slice axis is split into equisized parts with count
88 // the number of processes in the collective process group induced by
89 // the grid axes.
90 // The part for each process is determined by the corresponding
91 // linear-index in the process group.
92 //
93 // There are no collectives that require communication.
94 // Each process operates on its local tensor.
95
96 GridOp grid = getGrid(op, symbolTableCollection);
97 if (!grid) {
98 return failure();
99 }
100
101 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
102 builder.setInsertionPointAfter(op.getOperation());
103
104 Value zero = arith::ConstantOp::create(builder, builder.getIndexAttr(0));
105
106 Operation::result_range processInGroupMultiIndex =
107 ProcessMultiIndexOp::create(builder, grid.getSymName(),
108 op.getGridAxes())
109 .getResults();
110
111 Operation::result_range processGroupShape =
112 GridShapeOp::create(builder, grid.getSymName(), op.getGridAxes())
113 .getResult();
114 Value processGroupSize =
115 createCollectiveProcessGroupSize(grid, op.getGridAxes(), builder);
116
117 int64_t sliceAxis = op.getSliceAxis().getSExtValue();
118 Value operandSliceAxisSize =
119 tensor::DimOp::create(builder, op.getOperand(), sliceAxis);
120 Value operandSliceAxisSizeModProcessGroupSize =
121 arith::RemUIOp::create(builder, operandSliceAxisSize, processGroupSize);
122 Value isTargetShapeExactlyDivisible =
123 arith::CmpIOp::create(builder, arith::CmpIPredicate::eq,
124 operandSliceAxisSizeModProcessGroupSize, zero);
125 cf::AssertOp::create(builder, isTargetShapeExactlyDivisible,
126 "Slicing a tensor with axis size that is "
127 "not exactly divisible by the "
128 "grid process group size is not supported.");
129 Value resultSliceAxisSize =
130 arith::DivUIOp::create(builder, operandSliceAxisSize, processGroupSize);
131 OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
132 llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
133 llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
134
135 // insert tensor.extract_slice
136 RankedTensorType operandType =
137 cast<RankedTensorType>(op.getOperand().getType());
138 SmallVector<OpFoldResult> sizes;
139 for (int64_t i = 0; i < operandType.getRank(); ++i) {
140 if (i == sliceAxis) {
141 sizes.emplace_back(resultSliceAxisSize);
142 } else {
143 Value dimSize = tensor::DimOp::create(builder, op.getOperand(), i);
144 sizes.emplace_back(dimSize);
145 }
146 }
147 SmallVector<OpFoldResult> offsets(
148 operandType.getRank(), getAsIndexOpFoldResult(builder.getContext(), 0));
149 offsets[sliceAxis] =
150 ArithBuilder(builder, builder.getLoc())
151 .mul(getValueOrCreateConstantIndexOp(builder, builder.getLoc(),
152 processInGroupLinearIndex),
153 resultSliceAxisSize);
154 SmallVector<OpFoldResult> strides(
155 operandType.getRank(), getAsIndexOpFoldResult(builder.getContext(), 1));
156 Value slice = tensor::ExtractSliceOp::create(builder, op.getOperand(),
157 offsets, sizes, strides);
158 Value newResult =
159 tensor::CastOp::create(builder, op.getResult().getType(), slice);
160 rewriter.replaceAllUsesWith(op.getResult(), newResult);
161
162 return success();
163 }
164};
165
166} // namespace
167
169 RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
170 patterns.add<ProcessMultiIndexOpLowering>(symbolTableCollection,
171 patterns.getContext());
172}
173
175 registry.insert<affine::AffineDialect, shard::ShardDialect>();
176}
177
179 RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
180 patterns.add<AllSliceOpLowering>(symbolTableCollection,
181 patterns.getContext());
182}
183
185 registry.insert<affine::AffineDialect, arith::ArithDialect,
186 cf::ControlFlowDialect, shard::ShardDialect,
187 tensor::TensorDialect>();
188}
189
195
200
203 ImplicitLocOpBuilder &builder) {
204 Operation::result_range gridShape =
205 GridShapeOp::create(builder, grid, axes).getResults();
206 return cast<TypedValue<IndexType>>(arith::createProduct(
207 builder, builder.getLoc(), llvm::to_vector_of<Value>(gridShape),
208 builder.getIndexType()));
209}
210
212createProcessLinearIndex(StringRef grid, ValueRange processInGroupMultiIndex,
213 ArrayRef<GridAxis> gridAxes,
214 ImplicitLocOpBuilder &builder) {
215 Operation::result_range processGroupShape =
216 GridShapeOp::create(builder, grid, gridAxes).getResult();
217 OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
218 llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
219 llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
220 auto res = dyn_cast<Value>(processInGroupLinearIndex);
221 if (!res)
223 builder,
224 cast<IntegerAttr>(cast<Attribute>(processInGroupLinearIndex)).getInt());
225 return cast<TypedValue<IndexType>>(res);
226}
227
229 ArrayRef<GridAxis> gridAxes,
230 ImplicitLocOpBuilder &builder) {
232 grid, ProcessMultiIndexOp::create(builder, grid, gridAxes).getResults(),
233 gridAxes, builder);
234}
235} // namespace mlir::shard
return success()
IndexType getIndexType()
Definition Builders.cpp:51
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
Definition Builders.h:630
Location getLoc() const
Accessors for the implied location.
Definition Builders.h:663
This class represents a single result from folding an operation.
ResultRange result_range
Support result iteration.
Definition Operation.h:410
This class represents a collection of SymbolTables.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
Value createProduct(OpBuilder &builder, Location loc, ArrayRef< Value > values)
Definition Utils.cpp:345
void populateAllOpLoweringPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void registerAllSliceOpLoweringDialects(DialectRegistry &registry)
void populateAllSliceOpLoweringPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void populateProcessMultiIndexOpLoweringPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
int16_t GridAxis
Definition ShardOps.h:26
void registerAllOpLoweringDialects(DialectRegistry &registry)
TypedValue< IndexType > createProcessLinearIndex(StringRef grid, ArrayRef< GridAxis > gridAxes, ImplicitLocOpBuilder &builder)
TypedValue< IndexType > createCollectiveProcessGroupSize(GridOp grid, ArrayRef< GridAxis > axes, ImplicitLocOpBuilder &builder)
void registerProcessMultiIndexOpLoweringDialects(DialectRegistry &registry)
shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
Definition ShardOps.h:121
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:497
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:111