30 #define DEBUG_TYPE "mesh-to-mpi"
31 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
34 #define GEN_PASS_DEF_CONVERTMESHTOMPIPASS
35 #include "mlir/Conversion/Passes.h.inc"
// NOTE(review): fragmentary excerpt — the function header and several lines
// are missing from this view; the leading numerals are source line numbers
// from the original file, not code.
// Body of linearToMultiIndex (called below in the ProcessMultiIndexOp
// pattern): decomposes `linearIndex` into a per-dimension multi-index by
// repeated signed rem/div over `dimensions`, walking from the last
// (fastest-varying) dimension to the first.
46 int n = dimensions.size();
49 for (
int i = n - 1; i >= 0; --i) {
// multiIndex[i] = linearIndex % dimensions[i]
50 multiIndex[i] = b.
create<arith::RemSIOp>(loc, linearIndex, dimensions[i]);
// linearIndex /= dimensions[i] for the next (slower-varying) dimension
52 linearIndex = b.
create<arith::DivSIOp>(loc, linearIndex, dimensions[i]);
// NOTE(review): fragmentary excerpt — the function header is missing from
// this view; leading numerals are original source line numbers.
// Body of multiToLinearIndex (called from the neighbor-index pattern below):
// folds `multiIndex` into a single linear index by accumulating
// multiIndex[i] * stride, with the stride growing by dimensions[i] per step,
// iterating from the last dimension backwards (row-major linearization).
63 auto linearIndex = b.
create<arith::ConstantIndexOp>(loc, 0).getResult();
64 auto stride = b.
create<arith::ConstantIndexOp>(loc, 1).getResult();
66 for (
int i = multiIndex.size() - 1; i >= 0; --i) {
// off = multiIndex[i] * stride; linearIndex += off; stride *= dimensions[i]
67 auto off = b.
create<arith::MulIOp>(loc, multiIndex[i], stride);
68 linearIndex = b.
create<arith::AddIOp>(loc, linearIndex, off);
69 stride = b.
create<arith::MulIOp>(loc, stride, dimensions[i]);
// Rewrite pattern fragment: lowers mesh::ProcessMultiIndexOp by materializing
// the mesh shape as constants, creating a ProcessLinearIndexOp, and splitting
// its result into a multi-index via linearToMultiIndex. Bails out (failure)
// when the mesh shape is dynamic. If the op restricts itself to specific
// axes, only those components are collected.
// NOTE(review): interior lines are missing from this excerpt; presumably an
// OpRewritePattern/OpConversionPattern subclass — confirm against the full
// file.
75 struct ConvertProcessMultiIndexOp
80 matchAndRewrite(mlir::mesh::ProcessMultiIndexOp op,
86 auto loc = op.getLoc();
87 auto meshOp =
getMesh(op, symbolTableCollection);
// Dynamic mesh shapes cannot be lowered here — constants are required below.
89 if (ShapedType::isDynamicShape(meshOp.getShape())) {
90 return mlir::failure();
// Materialize each static mesh dimension as an arith.constant index value.
95 meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
96 return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
99 rewriter.
create<ProcessLinearIndexOp>(op.getLoc(), meshOp).getResult();
100 auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);
// If only a subset of axes was requested, keep just those components.
103 auto axes = op.getAxes();
106 for (
auto axis : axes) {
107 subIndex.push_back(mIdx[axis]);
113 return mlir::success();
// Rewrite pattern fragment: lowers mesh::ProcessLinearIndexOp. If a
// memref.global named "static_mpi_rank" with a splat initial value is found
// via symbol lookup, the op is replaced by that compile-time constant rank.
// NOTE(review): the fallback path (line 147's success) presumably emits a
// runtime MPI rank query — the emitting lines are missing from this excerpt;
// confirm against the full file.
117 struct ConvertProcessLinearIndexOp
122 matchAndRewrite(mlir::mesh::ProcessLinearIndexOp op,
128 auto loc = op.getLoc();
129 auto rankOpName =
StringAttr::get(op->getContext(),
"static_mpi_rank");
// Fast path: a statically-known MPI rank provided as a global.
130 if (
auto globalOp = SymbolTable::lookupNearestSymbolFrom<memref::GlobalOp>(
132 if (
auto initTnsr = globalOp.getInitialValueAttr()) {
133 auto val = cast<DenseElementsAttr>(initTnsr).getSplatValue<int64_t>();
135 rewriter.
create<arith::ConstantIndexOp>(loc, val));
136 return mlir::success();
147 return mlir::success();
// Rewrite pattern fragment: lowers mesh::NeighborsLinearIndicesOp for the
// single-split-axis case. Computes the linear indices of the "down"
// (index-1) and "up" (index+1) neighbors along that axis, yielding -1 via
// scf.if when the current index sits at the respective mesh border.
// NOTE(review): interior lines are missing from this excerpt.
151 struct ConvertNeighborsLinearIndicesOp
156 matchAndRewrite(mlir::mesh::NeighborsLinearIndicesOp op,
// Only exactly one split axis is supported by this lowering.
163 auto axes = op.getSplitAxes();
165 if (axes.size() != 1) {
166 return mlir::failure();
169 auto loc = op.getLoc();
171 auto meshOp =
getMesh(op, symbolTableCollection);
172 auto mIdx = op.getDevice();
173 auto orgIdx = mIdx[axes[0]];
// Materialize mesh dimensions as constants (needed for linearization).
176 meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
177 return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
179 auto dimSz = dims[axes[0]];
180 auto one = rewriter.
create<arith::ConstantIndexOp>(loc, 1).getResult();
181 auto minus1 = rewriter.
create<arith::ConstantIndexOp>(loc, -1).getResult();
// atBorder: orgIdx <= 0, i.e. no "down" neighbor exists.
182 auto atBorder = rewriter.
create<arith::CmpIOp>(
183 loc, arith::CmpIPredicate::sle, orgIdx,
184 rewriter.
create<arith::ConstantIndexOp>(loc, 0).getResult());
// Down neighbor: -1 at the border, otherwise linearized index of idx-1.
185 auto down = rewriter.
create<scf::IfOp>(
188 builder.create<scf::YieldOp>(loc, minus1);
193 rewriter.
create<arith::SubIOp>(op.getLoc(), orgIdx, one)
195 builder.create<scf::YieldOp>(
196 loc, multiToLinearIndex(loc, rewriter, tmp, dims));
// atBorder (reused): orgIdx >= dimSz - 1, i.e. no "up" neighbor exists.
198 atBorder = rewriter.
create<arith::CmpIOp>(
199 loc, arith::CmpIPredicate::sge, orgIdx,
200 rewriter.
create<arith::SubIOp>(loc, dimSz, one).getResult());
// Up neighbor: -1 at the border, otherwise linearized index of idx+1.
201 auto up = rewriter.
create<scf::IfOp>(
204 builder.create<scf::YieldOp>(loc, minus1);
209 rewriter.
create<arith::AddIOp>(op.getLoc(), orgIdx, one)
211 builder.create<scf::YieldOp>(
212 loc, multiToLinearIndex(loc, rewriter, tmp, dims));
215 return mlir::success();
// Rewrite pattern fragment: lowers mesh::UpdateHaloOp to explicit MPI
// point-to-point communication (mpi.send / mpi.recv) over memref subviews.
// Visible flow: bufferize a tensor operand if needed; compute per-dimension
// offsets/sizes from halo sizes; obtain this process's multi-index; then for
// each split dimension (last to first) fetch neighbor ranks and exchange the
// lower/upper halo slices through a temporary buffer, each transfer guarded
// by an scf.if on "neighbor exists" (rank >= 0).
// NOTE(review): many interior lines are missing from this excerpt; leading
// numerals are original source line numbers.
219 struct ConvertUpdateHaloOp
224 matchAndRewrite(mlir::mesh::UpdateHaloOp op,
240 auto loc = op.getLoc();
// toValue-style helper fragment: materialize an attribute-held halo size as
// an arith.constant.
246 : rewriter.
create<::mlir::arith::ConstantOp>(
249 cast<IntegerAttr>(v.get<
Attribute>()).getInt()));
252 auto dest = op.getDestination();
253 auto dstShape = cast<ShapedType>(dest.getType()).getShape();
// Tensor input: convert to a memref so subviews/copies below can be used.
255 if (isa<RankedTensorType>(array.
getType())) {
258 dstShape, cast<ShapedType>(array.
getType()).getElementType());
259 array = rewriter.
create<bufferization::ToMemrefOp>(loc, tensorType, array)
262 auto rank = cast<ShapedType>(array.
getType()).getRank();
263 auto opSplitAxes = op.getSplitAxes().getAxes();
264 auto mesh = op.getMesh();
265 auto meshOp =
getMesh(op, symbolTableCollection);
// Halo sizes as OpFoldResults (mixed static attrs / dynamic values).
267 getMixedValues(op.getStaticHaloSizes(), op.getHaloSizes(), rewriter);
269 for (
auto &sz : haloSizes) {
270 if (sz.is<
Value>()) {
// currHaloDim indexes halo-size pairs; advanced per split dimension.
282 auto currHaloDim = -1;
284 for (
auto i = 0; i < rank; ++i) {
285 auto s = dstShape[i];
// Dynamic extents are queried at runtime via memref.dim.
286 if (ShapedType::isDynamic(s)) {
287 shape[i] = rewriter.
create<memref::DimOp>(loc, array, s).getResult();
// Split dimension: interior region = shape - (lowerHalo + upperHalo),
// starting at offset lowerHalo.
292 if ((
size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
295 offsets[i] = haloSizes[currHaloDim * 2];
300 .
create<arith::AddIOp>(loc, toValue(haloSizes[currHaloDim * 2]),
301 toValue(haloSizes[currHaloDim * 2 + 1]))
305 rewriter.
create<arith::SubIOp>(loc, toValue(shape[i]), _haloSz)
308 dimSizes[i] = shape[i];
// Constants shared by all transfers: MPI tag and zero (rank validity check).
313 auto tag = rewriter.
create<::mlir::arith::ConstantOp>(loc, tagAttr);
315 auto zero = rewriter.
create<::mlir::arith::ConstantOp>(loc, zeroAttr);
320 rewriter.
create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
// Iterate split dimensions from last to first.
323 for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
324 auto splitAxes = opSplitAxes[dim];
325 if (splitAxes.empty()) {
328 assert(currHaloDim >= 0 && (
size_t)currHaloDim < haloSizes.size() / 2);
// Neighbor ranks along this dimension (down/up), cast to the MPI int type.
332 .
create<NeighborsLinearIndicesOp>(loc, mesh, myMultiIndex,
336 Value neighbourIDs[2] = {rewriter.
create<arith::IndexCastOp>(
338 rewriter.
create<arith::IndexCastOp>(
// Offsets of the slices exchanged along `dim`.
342 auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]);
343 auto upperRecvOffset = rewriter.
create<arith::SubIOp>(
344 loc, toValue(shape[dim]), toValue(haloSizes[currHaloDim * 2 + 1]));
345 auto upperSendOffset = rewriter.
create<arith::SubIOp>(
346 loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2]));
// genSendRecv: perform one send + one recv for either the upper or the
// lower halo of the current dimension, staged through a temp buffer.
354 auto genSendRecv = [&](
bool upperHalo) {
355 auto orgOffset = offsets[dim];
356 dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
357 : haloSizes[currHaloDim * 2];
360 auto to = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
361 auto from = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
// A neighbor exists iff its rank is >= 0 (borders yield -1 above).
362 auto hasFrom = rewriter.
create<arith::CmpIOp>(
363 loc, arith::CmpIPredicate::sge, from, zero);
364 auto hasTo = rewriter.
create<arith::CmpIOp>(
365 loc, arith::CmpIPredicate::sge, to, zero);
// Contiguous staging buffer for the halo slice.
366 auto buffer = rewriter.
create<memref::AllocOp>(
367 loc, dimSizes, cast<ShapedType>(array.
getType()).getElementType());
// Guarded send: copy our send slice into the buffer, then mpi.send.
369 rewriter.
create<scf::IfOp>(
371 offsets[dim] = upperHalo ?
OpFoldResult(lowerSendOffset)
373 auto subview = builder.create<memref::SubViewOp>(
374 loc, array, offsets, dimSizes, strides);
375 builder.create<memref::CopyOp>(loc, subview, buffer);
376 builder.create<mpi::SendOp>(loc,
TypeRange{}, buffer, tag, to);
377 builder.create<scf::YieldOp>(loc);
// Guarded recv: mpi.recv into the buffer, then copy into our halo slice.
380 rewriter.
create<scf::IfOp>(
382 offsets[dim] = upperHalo ?
OpFoldResult(upperRecvOffset)
384 builder.create<mpi::RecvOp>(loc,
TypeRange{}, buffer, tag, from);
385 auto subview = builder.create<memref::SubViewOp>(
386 loc, array, offsets, dimSizes, strides);
387 builder.create<memref::CopyOp>(loc, buffer, subview);
388 builder.create<scf::YieldOp>(loc);
// Release the staging buffer and restore the offset for the next transfer.
390 rewriter.
create<memref::DeallocOp>(loc, buffer);
391 offsets[dim] = orgOffset;
398 dimSizes[dim] = shape[dim];
// Result handling: memref results pass through; tensor results are
// presumably rebuilt from `array` (construction lines missing — confirm).
405 if (isa<MemRefType>(op.getResult().getType())) {
408 assert(isa<RankedTensorType>(op.getResult().getType()));
410 loc, op.getResult().getType(), array,
413 return mlir::success();
// Pass driver fragment: registers the four conversion patterns above and
// (per the applyPatternsGreedily reference in this excerpt) presumably
// applies them greedily over the operation. The final return is the body of
// the createConvertMeshToMPIPass() factory.
// NOTE(review): interior lines are missing from this excerpt.
417 struct ConvertMeshToMPIPass
418 :
public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> {
422 void runOnOperation()
override {
426 patterns.insert<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
427 ConvertProcessLinearIndexOp, ConvertProcessMultiIndexOp>(
438 return std::make_unique<ConvertMeshToMPIPass>();
static MLIRContext * getContext(OpFoldResult val)
Attributes are known-constant values of operations.
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getI32IntegerAttr(int32_t value)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a collection of SymbolTables.
This class provides an abstraction over the various different ranges of value types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
std::unique_ptr<::mlir::Pass > createConvertMeshToMPIPass()
Lowers Mesh communication operations (updateHalo, AllGather, ...) to MPI primitives.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...