MLIR  19.0.0git
Transforms.cpp
Go to the documentation of this file.
1 //===- Transforms.cpp ---------------------------------------------- C++ --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "TransformsDetail.h"
21 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/OpDefinition.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/IR/Value.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include <iterator>
30 #include <numeric>
31 
32 namespace mlir::mesh {
33 
34 namespace {
35 
36 /// Lower `mesh.process_multi_index` into an expression using
37 /// `mesh.process_linear_index` and `mesh.mesh_shape`.
38 struct ProcessMultiIndexOpLowering
39  : OpRewritePatternWithSymbolTableCollection<ProcessMultiIndexOp> {
42 
43  LogicalResult matchAndRewrite(ProcessMultiIndexOp op,
44  PatternRewriter &rewriter) const override {
45  MeshOp mesh = getMesh(op, symbolTableCollection);
46  if (!mesh) {
47  return failure();
48  }
49 
50  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
51  builder.setInsertionPointAfter(op.getOperation());
52  Value linearIndex = builder.create<ProcessLinearIndexOp>(mesh);
53  ValueRange meshShape = builder.create<MeshShapeOp>(mesh).getResults();
54  SmallVector<Value> completeMultiIndex =
55  builder.create<affine::AffineDelinearizeIndexOp>(linearIndex, meshShape)
56  .getMultiIndex();
57  SmallVector<Value> multiIndex;
58  ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
59  SmallVector<MeshAxis> opAxesIota;
60  if (opMeshAxes.empty()) {
61  opAxesIota.resize(mesh.getRank());
62  std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
63  opMeshAxes = opAxesIota;
64  }
65  llvm::transform(opMeshAxes, std::back_inserter(multiIndex),
66  [&completeMultiIndex](MeshAxis meshAxis) {
67  return completeMultiIndex[meshAxis];
68  });
69  rewriter.replaceAllUsesWith(op.getResults(), multiIndex);
70  return success();
71  }
72 };
73 
74 struct AllSliceOpLowering
75  : OpRewritePatternWithSymbolTableCollection<AllSliceOp> {
78 
79  LogicalResult matchAndRewrite(AllSliceOp op,
80  PatternRewriter &rewriter) const override {
81  // 1. Compute the process linear index inside the process group from its
82  // multi-index.
83  //
84  // 2. Extract a slice from the input tensor.
85  // All axes except the slicing axis are not interesting and take the full
86  // axis.
87  // The slice axis is split into equisized parts with count
88  // the number of processes in the collective process group induced by
89  // the mesh axes.
90  // The part for each process is determined by the corresponding
91  // linear-index in the process group.
92  //
93  // There are no collectives that require communication.
94  // Each process operates on its local tensor.
95 
96  MeshOp mesh = getMesh(op, symbolTableCollection);
97  if (!mesh) {
98  return failure();
99  }
100 
101  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
102  builder.setInsertionPointAfter(op.getOperation());
103 
104  Value zero = builder.create<arith::ConstantOp>(builder.getIndexAttr(0));
105 
106  Operation::result_range processInGroupMultiIndex =
107  builder.create<ProcessMultiIndexOp>(mesh.getSymName(), op.getMeshAxes())
108  .getResults();
109 
110  Operation::result_range processGroupShape =
111  builder.create<MeshShapeOp>(mesh.getSymName(), op.getMeshAxes())
112  .getResult();
113  Value processGroupSize =
114  createCollectiveProcessGroupSize(mesh, op.getMeshAxes(), builder);
115 
116  int64_t sliceAxis = op.getSliceAxis().getSExtValue();
117  Value operandSliceAxisSize =
118  builder.create<tensor::DimOp>(op.getOperand(), sliceAxis);
119  Value operandSliceAxisSizeModProcessGroupSize =
120  builder.create<arith::RemUIOp>(operandSliceAxisSize, processGroupSize);
121  Value isTargetShapeExactlyDivisible = builder.create<arith::CmpIOp>(
122  arith::CmpIPredicate::eq, operandSliceAxisSizeModProcessGroupSize,
123  zero);
124  builder.create<cf::AssertOp>(isTargetShapeExactlyDivisible,
125  "Slicing a tensor with axis size that is "
126  "not exactly divisible by the "
127  "mesh process group size is not supported.");
128  Value resultSliceAxisSize =
129  builder.create<arith::DivUIOp>(operandSliceAxisSize, processGroupSize);
130  OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
131  llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
132  llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
133 
134  // insert tensor.extract_slice
135  RankedTensorType operandType =
136  cast<RankedTensorType>(op.getOperand().getType());
137  SmallVector<OpFoldResult> sizes;
138  for (int64_t i = 0; i < operandType.getRank(); ++i) {
139  if (i == sliceAxis) {
140  sizes.emplace_back(resultSliceAxisSize);
141  } else {
142  Value dimSize = builder.create<tensor::DimOp>(op.getOperand(), i);
143  sizes.emplace_back(dimSize);
144  }
145  }
146  SmallVector<OpFoldResult> offsets(
147  operandType.getRank(), getAsIndexOpFoldResult(builder.getContext(), 0));
148  offsets[sliceAxis] =
149  ArithBuilder(builder, builder.getLoc())
150  .mul(getValueOrCreateConstantIndexOp(builder, builder.getLoc(),
151  processInGroupLinearIndex),
152  resultSliceAxisSize);
153  SmallVector<OpFoldResult> strides(
154  operandType.getRank(), getAsIndexOpFoldResult(builder.getContext(), 1));
155  Value slice = builder.create<tensor::ExtractSliceOp>(
156  op.getOperand(), offsets, sizes, strides);
157  Value newResult =
158  builder.create<tensor::CastOp>(op.getResult().getType(), slice);
159  rewriter.replaceAllUsesWith(op.getResult(), newResult);
160 
161  return success();
162  }
163 };
164 
165 } // namespace
166 
168  RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
169  patterns.add<ProcessMultiIndexOpLowering>(symbolTableCollection,
170  patterns.getContext());
171 }
172 
174  registry.insert<affine::AffineDialect, mesh::MeshDialect>();
175 }
176 
178  RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
179  patterns.add<AllSliceOpLowering>(symbolTableCollection,
180  patterns.getContext());
181 }
182 
184  registry.insert<affine::AffineDialect, arith::ArithDialect,
185  cf::ControlFlowDialect, mesh::MeshDialect,
186  tensor::TensorDialect>();
187 }
188 
190  RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
191  populateProcessMultiIndexOpLoweringPatterns(patterns, symbolTableCollection);
192  populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection);
193 }
194 
198 }
199 
202  ImplicitLocOpBuilder &builder) {
203  Operation::result_range meshShape =
204  builder.create<mesh::MeshShapeOp>(mesh, axes).getResults();
205  return cast<TypedValue<IndexType>>(arith::createProduct(
206  builder, builder.getLoc(), llvm::to_vector_of<Value>(meshShape),
207  builder.getIndexType()));
208 }
209 
211  ArrayRef<MeshAxis> meshAxes,
212  ImplicitLocOpBuilder &builder) {
213  ResultRange processInGroupMultiIndex =
214  builder.create<ProcessMultiIndexOp>(mesh, meshAxes).getResults();
215  Operation::result_range processGroupShape =
216  builder.create<MeshShapeOp>(mesh, meshAxes).getResult();
217  OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
218  llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
219  llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
220  return cast<TypedValue<IndexType>>(processInGroupLinearIndex.get<Value>());
221 }
222 
223 } // namespace mlir::mesh
IndexType getIndexType()
Definition: Builders.cpp:71
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
Location getLoc() const
Accessors for the implied location.
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
ResultRange result_range
Support result iteration.
Definition: Operation.h:405
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:242
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
OpFoldResult linearizeIndex(ArrayRef< OpFoldResult > multiIndex, ArrayRef< OpFoldResult > basis, ImplicitLocOpBuilder &builder)
Definition: Utils.cpp:1874
Value createProduct(OpBuilder &builder, Location loc, ArrayRef< Value > values)
Definition: Utils.cpp:269
void registerAllOpLoweringDialects(DialectRegistry &registry)
Definition: Transforms.cpp:195
void populateAllSliceOpLoweringPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
Definition: Transforms.cpp:177
void populateAllOpLoweringPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
Definition: Transforms.cpp:189
TypedValue< IndexType > createProcessLinearIndex(StringRef mesh, ArrayRef< MeshAxis > meshAxes, ImplicitLocOpBuilder &builder)
Definition: Transforms.cpp:210
void registerAllSliceOpLoweringDialects(DialectRegistry &registry)
Definition: Transforms.cpp:183
TypedValue< IndexType > createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef< MeshAxis > axes, ImplicitLocOpBuilder &builder)
Definition: Transforms.cpp:201
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:57
void populateProcessMultiIndexOpLoweringPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
Definition: Transforms.cpp:167
void registerProcessMultiIndexOpLoweringDialects(DialectRegistry &registry)
Definition: Transforms.cpp:173
int16_t MeshAxis
Definition: MeshOps.h:25
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:498
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:41
OpRewritePatternWithSymbolTableCollection(SymbolTableCollection &symbolTableCollection, OpRewritePatternArgs &&...opRewritePatternArgs)