27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/SmallVector.h"
38 struct ProcessMultiIndexOpLowering
39 : OpRewritePatternWithSymbolTableCollection<ProcessMultiIndexOp> {
43 LogicalResult matchAndRewrite(ProcessMultiIndexOp op,
44 PatternRewriter &rewriter)
const override {
50 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
51 builder.setInsertionPointAfter(op.getOperation());
52 Value linearIndex = builder.create<ProcessLinearIndexOp>(mesh);
53 ValueRange meshShape = builder.create<MeshShapeOp>(mesh).getResults();
54 SmallVector<Value> completeMultiIndex =
55 builder.create<affine::AffineDelinearizeIndexOp>(linearIndex, meshShape)
57 SmallVector<Value> multiIndex;
58 ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
59 SmallVector<MeshAxis> opAxesIota;
60 if (opMeshAxes.empty()) {
61 opAxesIota.resize(mesh.getRank());
62 std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
63 opMeshAxes = opAxesIota;
65 llvm::transform(opMeshAxes, std::back_inserter(multiIndex),
66 [&completeMultiIndex](
MeshAxis meshAxis) {
67 return completeMultiIndex[meshAxis];
69 rewriter.replaceAllUsesWith(op.getResults(), multiIndex);
74 struct AllSliceOpLowering
75 : OpRewritePatternWithSymbolTableCollection<AllSliceOp> {
79 LogicalResult matchAndRewrite(AllSliceOp op,
80 PatternRewriter &rewriter)
const override {
101 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
102 builder.setInsertionPointAfter(op.getOperation());
104 Value zero = builder.create<arith::ConstantOp>(builder.getIndexAttr(0));
107 builder.create<ProcessMultiIndexOp>(mesh.getSymName(), op.getMeshAxes())
111 builder.create<MeshShapeOp>(mesh.getSymName(), op.getMeshAxes())
113 Value processGroupSize =
116 int64_t sliceAxis = op.getSliceAxis().getSExtValue();
117 Value operandSliceAxisSize =
118 builder.create<tensor::DimOp>(op.getOperand(), sliceAxis);
119 Value operandSliceAxisSizeModProcessGroupSize =
120 builder.create<arith::RemUIOp>(operandSliceAxisSize, processGroupSize);
121 Value isTargetShapeExactlyDivisible = builder.create<arith::CmpIOp>(
122 arith::CmpIPredicate::eq, operandSliceAxisSizeModProcessGroupSize,
124 builder.create<cf::AssertOp>(isTargetShapeExactlyDivisible,
125 "Slicing a tensor with axis size that is "
126 "not exactly divisible by the "
127 "mesh process group size is not supported.");
128 Value resultSliceAxisSize =
129 builder.create<arith::DivUIOp>(operandSliceAxisSize, processGroupSize);
131 llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
132 llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
135 RankedTensorType operandType =
136 cast<RankedTensorType>(op.getOperand().getType());
137 SmallVector<OpFoldResult> sizes;
138 for (int64_t i = 0; i < operandType.getRank(); ++i) {
139 if (i == sliceAxis) {
140 sizes.emplace_back(resultSliceAxisSize);
142 Value dimSize = builder.create<tensor::DimOp>(op.getOperand(), i);
143 sizes.emplace_back(dimSize);
146 SmallVector<OpFoldResult> offsets(
149 ArithBuilder(builder, builder.getLoc())
151 processInGroupLinearIndex),
152 resultSliceAxisSize);
153 SmallVector<OpFoldResult> strides(
155 Value slice = builder.create<tensor::ExtractSliceOp>(
156 op.getOperand(), offsets, sizes, strides);
158 builder.create<tensor::CastOp>(op.getResult().getType(), slice);
159 rewriter.replaceAllUsesWith(op.getResult(), newResult);
169 patterns.
add<ProcessMultiIndexOpLowering>(symbolTableCollection,
174 registry.
insert<affine::AffineDialect, mesh::MeshDialect>();
179 patterns.
add<AllSliceOpLowering>(symbolTableCollection,
184 registry.
insert<affine::AffineDialect, arith::ArithDialect,
185 cf::ControlFlowDialect, mesh::MeshDialect,
186 tensor::TensorDialect>();
204 builder.
create<mesh::MeshShapeOp>(mesh, axes).getResults();
206 builder, builder.
getLoc(), llvm::to_vector_of<Value>(meshShape),
214 builder.
create<ProcessMultiIndexOp>(mesh, meshAxes).getResults();
216 builder.
create<MeshShapeOp>(mesh, meshAxes).getResult();
218 llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
219 llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
220 return cast<TypedValue<IndexType>>(processInGroupLinearIndex.get<
Value>());
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
Location getLoc() const
Accessors for the implied location.
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class represents a single result from folding an operation.
ResultRange result_range
Support result iteration.
This class implements the result iterators for the Operation class.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class represents a collection of SymbolTables.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
OpFoldResult linearizeIndex(ArrayRef< OpFoldResult > multiIndex, ArrayRef< OpFoldResult > basis, ImplicitLocOpBuilder &builder)
Value createProduct(OpBuilder &builder, Location loc, ArrayRef< Value > values)
void registerAllOpLoweringDialects(DialectRegistry &registry)
void populateAllSliceOpLoweringPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void populateAllOpLoweringPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
TypedValue< IndexType > createProcessLinearIndex(StringRef mesh, ArrayRef< MeshAxis > meshAxes, ImplicitLocOpBuilder &builder)
void registerAllSliceOpLoweringDialects(DialectRegistry &registry)
TypedValue< IndexType > createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef< MeshAxis > axes, ImplicitLocOpBuilder &builder)
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
void populateProcessMultiIndexOpLoweringPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void registerProcessMultiIndexOpLoweringDialects(DialectRegistry &registry)
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
OpRewritePatternWithSymbolTableCollection(SymbolTableCollection &symbolTableCollection, OpRewritePatternArgs &&...opRewritePatternArgs)