26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SmallVector.h"
// Rewrite pattern for ProcessMultiIndexOp. The visible lines rebuild the op's
// results from the process linear index and the grid shape, then redirect all
// uses of the op's results to the rebuilt values.
// NOTE(review): this listing omits some original lines (for example where
// `grid` is obtained, the tail of the delinearize call, and the return), so
// only the visible steps are described below.
37 struct ProcessMultiIndexOpLowering
38 : OpRewritePatternWithSymbolTableCollection<ProcessMultiIndexOp> {
42 LogicalResult matchAndRewrite(ProcessMultiIndexOp op,
43 PatternRewriter &rewriter)
const override {
// Builder that reuses the op's location; new ops are inserted right after `op`.
49 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
50 builder.setInsertionPointAfter(op.getOperation());
// Linear index of the process and the shape of `grid`.
51 Value linearIndex = ProcessLinearIndexOp::create(builder, grid);
52 ValueRange gridShape = GridShapeOp::create(builder, grid).getResults();
// Split the linear index into a multi-index that is later indexed by grid axis.
// The remaining arguments of this call are not visible here (presumably the
// grid shape is the basis -- confirm against the full file).
53 SmallVector<Value> completeMultiIndex =
54 affine::AffineDelinearizeIndexOp::create(builder, linearIndex,
// Axes requested by the op. An empty list is replaced by 0..rank-1.
// opAxesIota is declared outside the `if` because opGridAxes is an ArrayRef
// that only refers to its storage.
57 SmallVector<Value> multiIndex;
58 ArrayRef<GridAxis> opGridAxes = op.getAxes();
59 SmallVector<GridAxis> opAxesIota;
60 if (opGridAxes.empty()) {
61 opAxesIota.resize(grid.getRank());
62 std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
63 opGridAxes = opAxesIota;
// For each requested axis, in the order given, pick the matching entry of the
// complete multi-index.
65 llvm::transform(opGridAxes, std::back_inserter(multiIndex),
66 [&completeMultiIndex](
GridAxis gridAxis) {
67 return completeMultiIndex[gridAxis];
// Redirect every use of the op's results to the newly built index values.
69 rewriter.replaceAllUsesWith(op.getResults(), multiIndex);
// Rewrite pattern for AllSliceOp. The visible lines emit a runtime divisibility
// check on the slice axis, compute the per-process slice size, extract that
// slice from the operand with tensor.extract_slice, cast it to the op's result
// type and redirect all uses of the op's result to it.
// NOTE(review): this listing omits many original lines (including where `grid`
// is obtained and several statement heads), so comments on partially visible
// statements are hedged.
74 struct AllSliceOpLowering
75 : OpRewritePatternWithSymbolTableCollection<AllSliceOp> {
79 LogicalResult matchAndRewrite(AllSliceOp op,
80 PatternRewriter &rewriter)
const override {
// Builder that reuses the op's location; new ops are inserted right after `op`.
101 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
102 builder.setInsertionPointAfter(op.getOperation());
// Index constant 0, compared against the remainder below.
104 Value zero = arith::ConstantOp::create(builder, builder.getIndexAttr(0));
// The next three lines are fragments of statements whose other lines are not
// visible: a ProcessMultiIndexOp and a GridShapeOp created for the grid's
// symbol name, and the declaration of processGroupSize (initializer not shown).
107 ProcessMultiIndexOp::create(builder, grid.getSymName(),
112 GridShapeOp::create(builder, grid.getSymName(), op.getGridAxes())
114 Value processGroupSize =
// Size of the operand along the slice axis.
117 int64_t sliceAxis = op.getSliceAxis().getSExtValue();
118 Value operandSliceAxisSize =
119 tensor::DimOp::create(builder, op.getOperand(), sliceAxis);
// Unsigned remainder of that size by the process group size.
120 Value operandSliceAxisSizeModProcessGroupSize =
121 arith::RemUIOp::create(builder, operandSliceAxisSize, processGroupSize);
// The remainder must equal zero; cf.assert carries the message below.
122 Value isTargetShapeExactlyDivisible =
123 arith::CmpIOp::create(builder, arith::CmpIPredicate::eq,
124 operandSliceAxisSizeModProcessGroupSize, zero);
125 cf::AssertOp::create(builder, isTargetShapeExactlyDivisible,
126 "Slicing a tensor with axis size that is "
127 "not exactly divisible by the "
128 "grid process group size is not supported.");
// Result size along the slice axis: operand size divided (unsigned) by the
// process group size.
129 Value resultSliceAxisSize =
130 arith::DivUIOp::create(builder, operandSliceAxisSize, processGroupSize);
// Arguments of a call whose head is not visible here. Presumably it linearizes
// the in-group multi-index against the group shape to produce
// processInGroupLinearIndex, which is used below -- confirm against the full file.
132 llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
133 llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
// Slice sizes: resultSliceAxisSize on the slice axis and a tensor.dim of the
// operand for the other case (the line separating the two branches is not
// visible; presumably an `else`).
136 RankedTensorType operandType =
137 cast<RankedTensorType>(op.getOperand().getType());
138 SmallVector<OpFoldResult> sizes;
139 for (int64_t i = 0; i < operandType.getRank(); ++i) {
140 if (i == sliceAxis) {
141 sizes.emplace_back(resultSliceAxisSize);
143 Value dimSize = tensor::DimOp::create(builder, op.getOperand(), i);
144 sizes.emplace_back(dimSize);
// Slice offsets. Only part of this statement is visible: an ArithBuilder
// expression over processInGroupLinearIndex and resultSliceAxisSize.
// NOTE(review): presumably the slice-axis offset is their product -- confirm.
147 SmallVector<OpFoldResult> offsets(
150 ArithBuilder(builder, builder.getLoc())
152 processInGroupLinearIndex),
153 resultSliceAxisSize);
// Slice strides (the initializer is not visible here).
154 SmallVector<OpFoldResult> strides(
// Extract the slice from the operand using the offsets, sizes and strides.
156 Value slice = tensor::ExtractSliceOp::create(builder, op.getOperand(),
157 offsets, sizes, strides);
// Cast the slice to the op's declared result type (the declaration of
// newResult is on a line not visible here), then redirect all uses of the op's
// result to it.
159 tensor::CastOp::create(builder, op.getResult().getType(), slice);
160 rewriter.replaceAllUsesWith(op.getResult(), newResult);
// Adds the ProcessMultiIndexOpLowering pattern to `patterns`, passing the
// symbol table collection as a constructor argument (the remaining arguments
// and the enclosing function header are not visible in this listing).
170 patterns.add<ProcessMultiIndexOpLowering>(symbolTableCollection,
// Inserts the affine and shard dialects into `registry`.
// NOTE(review): the enclosing function header is not visible; presumably this
// is the body of registerProcessMultiIndexOpLoweringDialects -- confirm.
175 registry.
insert<affine::AffineDialect, shard::ShardDialect>();
// Adds the AllSliceOpLowering pattern to `patterns`, passing the symbol table
// collection as a constructor argument (the remaining arguments and the
// enclosing function header are not visible in this listing).
180 patterns.add<AllSliceOpLowering>(symbolTableCollection,
// Inserts the affine, arith, cf (control flow), shard and tensor dialects into
// `registry`.
// NOTE(review): the enclosing function header is not visible; presumably this
// is the body of registerAllSliceOpLoweringDialects -- confirm.
185 registry.
insert<affine::AffineDialect, arith::ArithDialect,
186 cf::ControlFlowDialect, shard::ShardDialect,
187 tensor::TensorDialect>();
// Results of a GridShapeOp created for `grid` restricted to `axes`. The
// declaration receiving them is on a line not visible here.
205 GridShapeOp::create(builder, grid, axes).getResults();
// Arguments of a call whose head is not visible: the builder, its location and
// the grid-shape values as a vector of Value.
// NOTE(review): presumably arith::createProduct, giving the product of the
// selected grid dimensions -- confirm against the full file.
207 builder, builder.
getLoc(), llvm::to_vector_of<Value>(gridShape),
// Result of a GridShapeOp created for `grid` restricted to `gridAxes`. The
// declaration receiving it is on a line not visible here.
216 GridShapeOp::create(builder, grid, gridAxes).getResult();
// Arguments of a call whose head is not visible. Presumably it linearizes the
// in-group multi-index against the group shape into processInGroupLinearIndex
// -- confirm against the full file.
218 llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
219 llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
// processInGroupLinearIndex may hold a Value or an attribute; dyn_cast yields a
// null Value in the attribute case.
220 auto res = dyn_cast<Value>(processInGroupLinearIndex);
// Reads the integer out of the attribute form. The enclosing statement is not
// visible; presumably it materializes a constant index when `res` is null.
224 cast<IntegerAttr>(cast<Attribute>(processInGroupLinearIndex)).getInt());
// The result is returned as an index-typed value.
225 return cast<TypedValue<IndexType>>(res);
// Arguments of a call whose head is not visible here: the grid and the results
// of a ProcessMultiIndexOp newly created for `gridAxes`.
232 grid, ProcessMultiIndexOp::create(builder, grid, gridAxes).getResults(),
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
Location getLoc() const
Accessors for the implied location.
This class represents a single result from folding an operation.
ResultRange result_range
Support result iteration.
This class implements the result iterators for the Operation class.
This class represents a collection of SymbolTables.
This class provides an abstraction over the different types of ranges over Values.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
OpFoldResult linearizeIndex(ArrayRef< OpFoldResult > multiIndex, ArrayRef< OpFoldResult > basis, ImplicitLocOpBuilder &builder)
Value createProduct(OpBuilder &builder, Location loc, ArrayRef< Value > values)
void populateAllOpLoweringPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void registerAllSliceOpLoweringDialects(DialectRegistry &registry)
void populateAllSliceOpLoweringPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void populateProcessMultiIndexOpLoweringPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void registerAllOpLoweringDialects(DialectRegistry &registry)
TypedValue< IndexType > createProcessLinearIndex(StringRef grid, ArrayRef< GridAxis > gridAxes, ImplicitLocOpBuilder &builder)
TypedValue< IndexType > createCollectiveProcessGroupSize(GridOp grid, ArrayRef< GridAxis > axes, ImplicitLocOpBuilder &builder)
void registerProcessMultiIndexOpLoweringDialects(DialectRegistry &registry)
shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
OpRewritePatternWithSymbolTableCollection(SymbolTableCollection &symbolTableCollection, OpRewritePatternArgs &&...opRewritePatternArgs)