MLIR 23.0.0git
ShardToMPI.cpp
Go to the documentation of this file.
1//===- ShardToMPI.cpp - Shard to MPI dialect conversion -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a translation of Shard communication ops to MPI ops.
10//
11//===----------------------------------------------------------------------===//
12
15
33#include "mlir/IR/AffineMap.h"
34#include "mlir/IR/Builders.h"
38#include "mlir/IR/SymbolTable.h"
41
42#define DEBUG_TYPE "shard-to-mpi"
43
44namespace mlir {
45#define GEN_PASS_DEF_CONVERTSHARDTOMPIPASS
46#include "mlir/Conversion/Passes.h.inc"
47} // namespace mlir
48
49using namespace mlir;
50using namespace shard;
51
52namespace {
/// Converts a vector of OpFoldResults (ints) into vector of Values of the
/// provided type.
/// NOTE(review): the leading signature line(s) (builder `b`, location `loc`,
/// and the ArrayRef<int64_t> `statics` parameter) are not visible in this
/// excerpt — confirm against the original source.
                                             ValueRange dynamics,
                                             Type type = Type()) {
  SmallVector<Value> values;
  // Each kDynamic placeholder in `statics` consumes the next dynamic operand.
  auto dyn = dynamics.begin();
  Type i64 = b.getI64Type();
  // Default to i64 when no explicit result type was requested.
  if (!type)
    type = i64;
  // Only i64 and index constants are supported.
  // (NOTE(review): "intex" below is a typo for "index" in the assert message.)
  assert((i64 == type || b.getIndexType() == type) &&
         "expected an i64 or an intex type");
  for (auto s : statics) {
    if (s == ShapedType::kDynamic) {
      // Dynamic placeholder: take the next dynamic value.
      values.emplace_back(*(dyn++));
    } else {
      // Static value: materialize an arith.constant of the requested type.
      TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s);
      values.emplace_back(arith::ConstantOp::create(b, loc, type, val));
    }
  }
  return values;
}
76
77/// Create operations converting a linear index to a multi-dimensional index.
78[[maybe_unused]] static SmallVector<Value>
79linearToMultiIndex(Location loc, OpBuilder b, Value linearIndex,
80 ValueRange dimensions) {
81 int n = dimensions.size();
82 SmallVector<Value> multiIndex(n);
83
84 for (int i = n - 1; i >= 0; --i) {
85 multiIndex[i] = arith::RemSIOp::create(b, loc, linearIndex, dimensions[i]);
86 if (i > 0)
87 linearIndex = arith::DivSIOp::create(b, loc, linearIndex, dimensions[i]);
88 }
89
90 return multiIndex;
91}
92
93/// Create operations converting a multi-dimensional index to a linear index.
94Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
95 ValueRange dimensions) {
96
97 Value linearIndex = arith::ConstantIndexOp::create(b, loc, 0);
98 Value stride = arith::ConstantIndexOp::create(b, loc, 1);
99
100 for (int i = multiIndex.size() - 1; i >= 0; --i) {
101 Value off = arith::MulIOp::create(b, loc, multiIndex[i], stride);
102 linearIndex = arith::AddIOp::create(b, loc, linearIndex, off);
103 stride = arith::MulIOp::create(b, loc, stride, dimensions[i]);
104 }
105
106 return linearIndex;
107}
108
109/// Replace GetShardingOp with related/dependent ShardingOp.
110struct ConvertGetShardingOp : public OpConversionPattern<GetShardingOp> {
111 using OpConversionPattern::OpConversionPattern;
112
113 LogicalResult
114 matchAndRewrite(GetShardingOp op, OpAdaptor adaptor,
115 ConversionPatternRewriter &rewriter) const override {
116 auto shardOp = adaptor.getSource().getDefiningOp<ShardOp>();
117 if (!shardOp)
118 return failure();
119 auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
120 if (!shardingOp)
121 return failure();
122
123 rewriter.replaceOp(op, shardingOp.getResult());
124 return success();
125 }
126};
127
/// Convert a sharding op to a tuple of tensors of its components
/// (SplitAxes, HaloSizes, ShardedDimsOffsets)
/// as defined by type converter.
struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShardingOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto splitAxes = op.getSplitAxes().getAxes();
    // The widest split-group determines the column count of the 2d tensor.
    int64_t maxNAxes = 0;
    for (auto axes : splitAxes)
      maxNAxes = std::max<int64_t>(maxNAxes, axes.size());

    // To hold the split axes, create empty 2d tensor with shape
    // {splitAxes.size(), max-size-of-split-groups}.
    // Set trailing elements for smaller split-groups to -1.
    Location loc = op.getLoc();
    auto i16 = rewriter.getI16Type();
    auto i64 = rewriter.getI64Type();
    std::array<int64_t, 2> shape = {static_cast<int64_t>(splitAxes.size()),
                                    maxNAxes};
    Value resSplitAxes = tensor::EmptyOp::create(rewriter, loc, shape, i16);
    // -1 marks "no axis" in the padded trailing positions.
    auto attr = IntegerAttr::get(i16, -1);
    Value fillValue = arith::ConstantOp::create(rewriter, loc, i16, attr);
    resSplitAxes =
        linalg::FillOp::create(rewriter, loc, fillValue, resSplitAxes)
            .getResult(0);

    // explicitly write values into tensor row by row
    std::array<int64_t, 2> strides = {1, 1};
    int64_t nSplits = 0;
    ValueRange empty = {};
    for (auto [i, axes] : llvm::enumerate(splitAxes)) {
      int64_t size = axes.size();
      if (size > 0)
        ++nSplits;
      std::array<int64_t, 2> offs = {(int64_t)i, 0};
      std::array<int64_t, 2> sizes = {1, size};
      // Row i holds this dimension's split axes as an i16 constant tensor.
      auto tensorType = RankedTensorType::get({size}, i16);
      auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef());
      auto vals = arith::ConstantOp::create(rewriter, loc, tensorType, attrs);
      resSplitAxes = tensor::InsertSliceOp::create(rewriter, loc, vals,
                                                   resSplitAxes, empty, empty,
                                                   empty, offs, sizes, strides);
    }

    // To hold halos sizes, create 2d Tensor with shape {nSplits, 2}.
    // Store the halo sizes in the tensor.
    SmallVector<Value> haloSizes =
        getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(),
                         adaptor.getDynamicHaloSizes());
    auto type = RankedTensorType::get({nSplits, 2}, i64);
    // No halo sizes -> empty 0x0 tensor; otherwise pack them directly.
    Value resHaloSizes =
        haloSizes.empty()
            ? tensor::EmptyOp::create(rewriter, loc,
                                      std::array<int64_t, 2>{0, 0}, i64)
                  .getResult()
            : tensor::FromElementsOp::create(rewriter, loc, type, haloSizes)
                  .getResult();

    // To hold sharded dims offsets, create Tensor with shape {nSplits,
    // maxSplitSize+1}. Store the offsets in the tensor but set trailing
    // elements for smaller split-groups to -1. Computing the max size of the
    // split groups needs using collectiveProcessGroupSize (which needs the
    // GridOp)
    Value resOffsets;
    if (adaptor.getStaticShardedDimsOffsets().empty()) {
      // No sharded-dims offsets at all -> empty 0x0 tensor.
      resOffsets = tensor::EmptyOp::create(rewriter, loc,
                                           std::array<int64_t, 2>{0, 0}, i64);
    } else {
      SymbolTableCollection symbolTableCollection;
      auto gridOp = getGrid(op, symbolTableCollection);
      int64_t maxSplitSize = 0;
      for (auto axes : splitAxes) {
        int64_t splitSize =
            collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
        assert(splitSize != ShapedType::kDynamic);
        maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
      }
      assert(maxSplitSize);
      ++maxSplitSize; // add one for the total size

      resOffsets = tensor::EmptyOp::create(
          rewriter, loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
      // NOTE(review): despite the name, `zero` holds ShapedType::kDynamic,
      // which is used as the trailing-fill marker — confirm intent.
      Value zero = arith::ConstantOp::create(
          rewriter, loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
      resOffsets =
          linalg::FillOp::create(rewriter, loc, zero, resOffsets).getResult(0);
      SmallVector<Value> offsets =
          getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
                           adaptor.getDynamicShardedDimsOffsets());
      // `curr` tracks the flat read position within `offsets` across rows.
      int64_t curr = 0;
      for (auto [i, axes] : llvm::enumerate(splitAxes)) {
        int64_t splitSize =
            collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
        assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
        ++splitSize; // add one for the total size
        ArrayRef<Value> values(&offsets[curr], splitSize);
        Value vals = tensor::FromElementsOp::create(rewriter, loc, values);
        std::array<int64_t, 2> offs = {static_cast<int64_t>(i), 0};
        std::array<int64_t, 2> sizes = {1, splitSize};
        resOffsets = tensor::InsertSliceOp::create(rewriter, loc, vals,
                                                   resOffsets, empty, empty,
                                                   empty, offs, sizes, strides);
        curr += splitSize;
      }
    }

    // return a tuple of tensors as defined by type converter
    SmallVector<Type> resTypes;
    if (failed(getTypeConverter()->convertType(op.getResult().getType(),
                                               resTypes)))
      return failure();

    // Cast each component to the converter-mandated (possibly dynamic) type.
    resSplitAxes =
        tensor::CastOp::create(rewriter, loc, resTypes[0], resSplitAxes);
    resHaloSizes =
        tensor::CastOp::create(rewriter, loc, resTypes[1], resHaloSizes);
    resOffsets = tensor::CastOp::create(rewriter, loc, resTypes[2], resOffsets);

    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
        op, TupleType::get(op.getContext(), resTypes),
        ValueRange{resSplitAxes, resHaloSizes, resOffsets});

    return success();
  }
};
256
257class ConvertProcessLinearIndexOp
258 : public OpConversionPattern<ProcessLinearIndexOp> {
259
260public:
261 using OpConversionPattern::OpConversionPattern;
262
263 LogicalResult
264 matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
265 ConversionPatternRewriter &rewriter) const override {
266 // Create mpi::CommRankOp
267 Location loc = op.getLoc();
268 auto *ctx = op.getContext();
269 Value commWorld =
270 mpi::CommWorldOp::create(rewriter, loc, mpi::CommType::get(ctx));
271 auto rank = mpi::CommRankOp::create(
272 rewriter, loc,
273 TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()},
274 commWorld)
275 .getRank();
276 rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
277 rank);
278 return success();
279 }
280};
281
/// Lower NeighborsLinearIndicesOp: compute the linearized indices of the two
/// neighbors (down and up) along a single split axis.
struct ConvertNeighborsLinearIndicesOp
    : public OpConversionPattern<NeighborsLinearIndicesOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(NeighborsLinearIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Computes the neighbors indices along a split axis by simply
    // adding/subtracting 1 to the current index in that dimension.
    // Assigns -1 if neighbor is out of bounds.

    auto axes = adaptor.getSplitAxes();
    // For now only single axis sharding is supported
    if (axes.size() != 1)
      return failure();

    Location loc = op.getLoc();
    SymbolTableCollection symbolTableCollection;
    auto gridOp = getGrid(op, symbolTableCollection);
    // Multi-dimensional index of this process; orgIdx is its coordinate
    // along the (single) split axis.
    auto mIdx = adaptor.getDevice();
    auto orgIdx = mIdx[axes[0]];
    // Materialize all grid dimension extents as index constants.
    SmallVector<Value> dims;
    llvm::transform(
        gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
          return arith::ConstantIndexOp::create(rewriter, loc, i).getResult();
        });
    Value dimSz = dims[axes[0]];
    Value one = arith::ConstantIndexOp::create(rewriter, loc, 1);
    Value minus1 = arith::ConstantIndexOp::create(rewriter, loc, -1);
    // Down neighbor: -1 when at the lower border, otherwise linearize the
    // index with the split-axis coordinate decremented by one.
    Value atBorder =
        arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sle, orgIdx,
                              arith::ConstantIndexOp::create(rewriter, loc, 0));
    auto down = scf::IfOp::create(
        rewriter, loc, atBorder,
        [&](OpBuilder &builder, Location loc) {
          scf::YieldOp::create(builder, loc, minus1);
        },
        [&](OpBuilder &builder, Location loc) {
          SmallVector<Value> tmp = mIdx;
          // NOTE(review): ops here are created via the outer `rewriter` and
          // `op.getLoc()` rather than the callback's `builder`/`loc` —
          // confirm the intended insertion point inside the scf.if region.
          tmp[axes[0]] =
              arith::SubIOp::create(rewriter, op.getLoc(), orgIdx, one)
                  .getResult();
          scf::YieldOp::create(builder, loc,
                               multiToLinearIndex(loc, rewriter, tmp, dims));
        });
    // Up neighbor: -1 when at the upper border, otherwise linearize the
    // index with the split-axis coordinate incremented by one.
    atBorder = arith::CmpIOp::create(
        rewriter, loc, arith::CmpIPredicate::sge, orgIdx,
        arith::SubIOp::create(rewriter, loc, dimSz, one).getResult());
    auto up = scf::IfOp::create(
        rewriter, loc, atBorder,
        [&](OpBuilder &builder, Location loc) {
          scf::YieldOp::create(builder, loc, minus1);
        },
        [&](OpBuilder &builder, Location loc) {
          SmallVector<Value> tmp = mIdx;
          // NOTE(review): same `rewriter` vs `builder` question as above.
          tmp[axes[0]] =
              arith::AddIOp::create(rewriter, op.getLoc(), orgIdx, one);
          scf::YieldOp::create(builder, loc,
                               multiToLinearIndex(loc, rewriter, tmp, dims));
        });
    rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
    return success();
  }
};
347
/// Lower ShardShapeOp: compute the local (per-process) shape of a sharded
/// tensor from the global shape, the sharding, and the process index.
struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShardShapeOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
    if (!sharding) {
      return op->emitError()
             << "Expected ShardingOp as defining op for sharding"
             << " but found " << adaptor.getSharding()[0].getDefiningOp();
    }

    // Compute the sharded shape by applying the sharding to the input shape.
    // If shardedDimsOffsets is not defined in the sharding, the shard shape is
    // computed by dividing the dimension size by the number of shards in that
    // dimension (which is given by the size of the grid axes provided in
    // split-axes). Odd elements get distributed to trailing shards. If a
    // shardedDimsOffsets is provided, the shard shape is computed by
    // subtracting the offset of the current shard from the offset of the next
    // shard.

    Location loc = op.getLoc();
    Type index = rewriter.getIndexType();

    // This is a 1:N conversion because the sharding op is a 1:3 conversion.
    // The operands in the adaptor are a vector<ValeRange>. For dims and device
    // we have a 1:1 conversion.
    // For simpler access fill a vector with the dynamic dims.
    SmallVector<Value> dynDims, dynDevice;
    for (auto dim : adaptor.getDimsDynamic()) {
      // type conversion should be 1:1 for ints
      dynDims.emplace_back(llvm::getSingleElement(dim));
    }
    // same for device
    for (auto device : adaptor.getDeviceDynamic()) {
      dynDevice.emplace_back(llvm::getSingleElement(device));
    }

    // To keep the code simple, convert dims/device to values when they are
    // attributes. Count on canonicalization to fold static values.
    SmallVector<Value> shape =
        getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index);
    SmallVector<Value> multiIdx =
        getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index);

    // Get the GridOp, the grid shape is needed to compute the sharded shape.
    SymbolTableCollection symbolTableCollection;
    auto gridOp = getGrid(sharding, symbolTableCollection);
    // For now we only support static grid shapes
    if (ShapedType::isDynamicShape(gridOp.getShape()))
      return failure();

    auto splitAxes = sharding.getSplitAxes().getAxes();
    // shardedDimsOffsets are optional and might be Values (not attributes).
    // Also, the shardId might be dynamic which means the position in the
    // shardedDimsOffsets is not statically known. Create a tensor of the
    // shardedDimsOffsets and later extract the offsets for computing the
    // local shard-size.
    Value shardedDimsOffs;
    {
      SmallVector<Value> tmp = getMixedAsValues(
          rewriter, loc, sharding.getStaticShardedDimsOffsets(),
          sharding.getDynamicShardedDimsOffsets(), index);
      if (!tmp.empty())
        shardedDimsOffs = tensor::FromElementsOp::create(
            rewriter, loc, RankedTensorType::get({(int64_t)tmp.size()}, index),
            tmp);
    }

    // With static grid shape the sizes of the split axes are known.
    // Hence the start/pos for each split axes in shardDimsOffsets can be
    // computed statically.
    int64_t pos = 0;
    SmallVector<Value> shardShape;
    Value zero =
        arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(index));
    Value one =
        arith::ConstantOp::create(rewriter, loc, rewriter.getOneAttr(index));

    // Iterate over the dimensions of the tensor shape, get their split Axes,
    // and compute the sharded shape.
    for (auto [i, dim] : llvm::enumerate(shape)) {
      // Trailing dimensions might not be annotated.
      if (i < splitAxes.size() && !splitAxes[i].empty()) {
        auto axes = splitAxes[i];
        // The current dimension might not be sharded.
        // Create a value from the static position in shardDimsOffsets.
        Value posVal = arith::ConstantOp::create(rewriter, loc,
                                                 rewriter.getIndexAttr(pos));
        // Get the index of the local shard in the grid axis.
        Value idx = multiIdx[axes[0]];
        auto numShards =
            collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
        if (shardedDimsOffs) {
          // If sharded dims offsets are provided, use them to compute the
          // sharded shape.
          if (axes.size() > 1) {
            return op->emitError() << "Only single axis sharding is "
                                   << "supported for each dimension.";
          }
          idx = arith::AddIOp::create(rewriter, loc, posVal, idx);
          // Compute size = shardedDimsOffs[idx+1] - shardedDimsOffs[idx].
          Value off =
              tensor::ExtractOp::create(rewriter, loc, shardedDimsOffs, idx);
          idx = arith::AddIOp::create(rewriter, loc, idx, one);
          Value nextOff =
              tensor::ExtractOp::create(rewriter, loc, shardedDimsOffs, idx);
          Value sz = arith::SubIOp::create(rewriter, loc, nextOff, off);
          shardShape.emplace_back(sz);
        } else {
          Value numShardsVal = arith::ConstantOp::create(
              rewriter, loc, rewriter.getIndexAttr(numShards));
          // Compute shard dim size by distributing odd elements to trailing
          // shards:
          // sz = dim / numShards
          //      + (idx >= (numShards - (dim % numShards)) ? 1 : 0)
          Value sz = arith::DivSIOp::create(rewriter, loc, dim, numShardsVal);
          Value sz1 = arith::RemSIOp::create(rewriter, loc, dim, numShardsVal);
          sz1 = arith::SubIOp::create(rewriter, loc, numShardsVal, sz1);
          auto cond = arith::CmpIOp::create(
              rewriter, loc, arith::CmpIPredicate::sge, idx, sz1);
          Value odd = arith::SelectOp::create(rewriter, loc, cond, one, zero);
          sz = arith::AddIOp::create(rewriter, loc, sz, odd);
          shardShape.emplace_back(sz);
        }
        pos += numShards + 1; // add one for the total size.
      } // else no sharding if split axis is empty or no split axis
      // If no size was added -> no sharding in this dimension.
      if (shardShape.size() <= i)
        shardShape.emplace_back(dim);
    }
    assert(shardShape.size() == shape.size());
    rewriter.replaceOp(op, shardShape);
    return success();
  }
};
485
486static mpi::MPI_ReductionOpEnumAttr getMPIReductionOp(ReductionKindAttr kind) {
487 auto *ctx = kind.getContext();
488 auto getReductionOp = [ctx](mpi::MPI_ReductionOpEnum redOp) {
489 return mpi::MPI_ReductionOpEnumAttr::get(ctx, redOp);
490 };
491
492 switch (kind.getValue()) {
493 case ReductionKind::Sum:
494 return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_SUM);
495 case ReductionKind::Product:
496 return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_PROD);
497 case ReductionKind::Min:
498 return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_MIN);
499 case ReductionKind::Max:
500 return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_MAX);
501 case ReductionKind::BitwiseAnd:
502 return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BAND);
503 case ReductionKind::BitwiseOr:
504 return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BOR);
505 case ReductionKind::BitwiseXor:
506 return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BXOR);
507 default:
508 llvm_unreachable("Unknown/unsupported reduction kind");
509 }
510}
511
/// Shared base for patterns lowering shard communication ops to MPI.
/// Provides helpers for memref/tensor conversion, grid validation, and
/// communicator creation.
template <typename CommOp>
struct CommOpPattern : public OpConversionPattern<CommOp> {
  using OpConversionPattern<CommOp>::OpConversionPattern;

  /// Build a memref type with the same shape/element type as `tensorType`.
  MemRefType getMemrefType(ShapedType tensorType) const {
    return MemRefType::get(tensorType.getShape(), tensorType.getElementType());
  }

  /// Return `input` as a memref value: ranked tensors are materialized as
  /// buffers via bufferization.to_buffer, memrefs pass through unchanged.
  Value getAsMemref(Value input, ImplicitLocOpBuilder &iBuilder) const {
    auto itype = input.getType();
    // If the source is a ranked tensor, convert it to a memref (buffer).
    if (isa<RankedTensorType>(itype)) {
      auto memrefType = getMemrefType(cast<ShapedType>(itype));
      input = bufferization::ToBufferOp::create(iBuilder, memrefType, input);
    } else {
      assert(isa<MemRefType>(itype) &&
             "expected input to be of MemRefType or TensorType");
    }
    return input;
  }

  /// Look up the grid for `op` and verify it exists; unless `allowDynamic`,
  /// also reject grids with a dynamic shape.
  FailureOr<GridOp> checkGrid(CommOp op,
                              SymbolTableCollection &symbolTableCollection,
                              bool allowDynamic = false) const {
    GridOp gridOp = getGrid(op, symbolTableCollection);
    if (!gridOp)
      return op->emitError() << "Missing grid symbol.";
    if (!allowDynamic && ShapedType::isDynamicShape(gridOp.getShape()))
      return op->emitError() << "Dynamic grid shape not supported.";
    return gridOp;
  }

  // Get an MPI_Comm_split for a given grid and axes.
  // The color is the linear index of the process in the grid along the
  // non-'grid-axes'. The key is the linear index of the process in the grid
  // along the grid-axes.
  Value getComm(GridOp &gridOp, ::llvm::ArrayRef<int16_t> gridAxes,
                ImplicitLocOpBuilder &iBuilder) const {
    size_t gridDims = gridOp.getShape().size();
    auto commType = mpi::CommType::get(gridOp->getContext());
    Value commWorld = mpi::CommWorldOp::create(iBuilder, commType);

    // No axes or all axes -> the communication spans the whole grid, so
    // MPI_COMM_WORLD can be used directly.
    if (gridAxes.empty() || gridAxes.size() >= gridDims) {
      return commWorld;
    }

    // Axes NOT taking part in the communication; they determine the color.
    SmallVector<GridAxis> otherAxes;
    for (GridAxis i = 0; i < static_cast<GridAxis>(gridDims); ++i) {
      if (!llvm::is_contained(gridAxes, i))
        otherAxes.emplace_back(i);
    }

    // NOTE(review): `indexResultTypes` appears unused below — candidate for
    // removal; confirm against the original source.
    SmallVector<Type> indexResultTypes(otherAxes.size(),
                                       iBuilder.getIndexType());

    Value color =
        createProcessLinearIndex(iBuilder, gridOp.getSymName(), otherAxes);
    color = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), color);

    Value key =
        createProcessLinearIndex(iBuilder, gridOp.getSymName(), gridAxes);
    key = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), key);

    // Finally split the communicator
    return mpi::CommSplitOp::create(iBuilder, commType, commWorld, color, key)
        .getNewcomm();
  }
};
580
/// Lower shard.all_reduce to mpi.allreduce on a buffer copy of the input.
struct ConvertAllReduceOp : public CommOpPattern<AllReduceOp> {
  using CommOpPattern::CommOpPattern;

  LogicalResult
  matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SymbolTableCollection symbolTableCollection;
    FailureOr<GridOp> gridOp = checkGrid(op, symbolTableCollection);
    if (failed(gridOp))
      return failure();
    ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
    Value input = getAsMemref(adaptor.getInput(), iBuilder);
    MemRefType inType = cast<MemRefType>(input.getType());
    // NOTE(review): the guard condition for the following emitError is not
    // visible in this view (a source line appears to be missing here); as
    // written the return would be unconditional — confirm against the
    // original source.
      return op.emitError(
          "Expected static shaped memref in contiguous row-major layout.");
    MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
    // NOTE(review): same here — the guard condition for this emitError is
    // missing from this view.
      return op.emitError(
          "Expected static shaped memref in contiguous row-major layout.");

    // Allocate buffer and copy input to buffer.
    Value buffer = memref::AllocOp::create(iBuilder, outType);
    linalg::CopyOp::create(iBuilder, input, buffer);
    // Get the right communicator
    Value comm = getComm(*gridOp, adaptor.getGridAxes(), iBuilder);
    // Create the MPI AllReduce operation.
    mpi::AllReduceOp::create(iBuilder, TypeRange(), buffer, buffer,
                             getMPIReductionOp(adaptor.getReductionAttr()),
                             comm);

    // If the destination is a tensor, cast it to a tensor
    if (isa<RankedTensorType>(op.getType()))
      buffer = bufferization::ToTensorOp::create(iBuilder, op.getType(), buffer,
                                                 true);
    rewriter.replaceOp(op, buffer);
    return success();
  }
};
620
/// Lower shard.allgather to mpi.allgather, permuting/collapsing the result
/// buffer when the gather axis is not the leading dimension.
struct ConvertAllGatherOp : public CommOpPattern<AllGatherOp> {
  using CommOpPattern::CommOpPattern;

  // shard.allgather concatenates along a specified gather-axis.
  // mpi.allgather always concatenates along the first dimension and
  // there is no MPI operation that allows gathering along an arbitrary axis.
  // Hence, if gather-axis != 0, we need to permute the output buffer
  // accordingly.

  LogicalResult
  matchAndRewrite(AllGatherOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SymbolTableCollection symbolTableCollection;
    FailureOr<GridOp> gridOp = checkGrid(op, symbolTableCollection);
    if (failed(gridOp))
      return failure();

    ImplicitLocOpBuilder ib(op.getLoc(), rewriter);
    Value input = getAsMemref(adaptor.getInput(), ib);
    MemRefType inType = cast<MemRefType>(input.getType());
    MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
    auto inputShape = inType.getShape();
    auto outputShape = outType.getShape();
    int64_t gatherAxis = adaptor.getGatherAxisAttr().getInt();
    int64_t inputDimOnAxis = inputShape[gatherAxis];
    int64_t outputDimOnAxis = outputShape[gatherAxis];

    // Validate shapes: only the gather axis may differ between input and
    // result.
    for (size_t i = 0; i < outputShape.size(); ++i)
      if (outputShape[i] != inputShape[i] && i != (size_t)gatherAxis)
        return op.emitError(
            "Result and input shapes must match along non-gather axes.");
    if (inputDimOnAxis == 0)
      return op.emitError("Input size along the gather axis must be non-zero.");
    if (inputDimOnAxis == 1) {
      assert(outputDimOnAxis == inputDimOnAxis);
      rewriter.replaceOp(op, adaptor.getInput());
      return success();
    }
    if (outputDimOnAxis % inputDimOnAxis != 0)
      return op.emitError("Result size along the gather axis must be an exact "
                          "multiple of the input size along the gather axis.");

    // NOTE(review): the guard condition for the following emitError is not
    // visible in this view (source lines appear to be missing here); as
    // written the return would be unconditional — confirm against the
    // original source.
      return op.emitError("Input/result must be statically shaped memrefs in "
                          "contiguous row-major layout.");

    // Get the right communicator.
    Value comm = getComm(*gridOp, adaptor.getGridAxes(), ib);
    Value nRanksV =
        mpi::CommSizeOp::create(ib, ib.getI32Type(), comm).getSize();
    nRanksV = arith::IndexCastOp::create(ib, ib.getIndexType(), nRanksV);
    // Statically expected rank count, derived from the shapes; verified
    // against the runtime communicator size below.
    int64_t nRanks = outputDimOnAxis / inputDimOnAxis;
    Value nRanksC = arith::ConstantIndexOp::create(ib, nRanks);
    Value notError =
        arith::CmpIOp::create(ib, arith::CmpIPredicate::eq, nRanksV, nRanksC);
    cf::AssertOp::create(ib, notError,
                         "Expected number of ranks in the communicator to "
                         "match the output size along the gather axis divided "
                         "by the input size along the gather axis.");

    // mpi.allgather always concatenates along the first dimension, so
    // get a output buffer of shape {nRanks, dim0, ...}.
    SmallVector<int64_t> gatherShape;
    gatherShape.emplace_back(nRanks);
    gatherShape.append(inputShape.begin(), inputShape.end());
    auto gatherType = MemRefType::get(gatherShape, outType.getElementType());
    Value finalOutput = memref::AllocOp::create(ib, gatherType);
    // Create the MPI AllGather operation.
    mpi::AllGatherOp::create(ib, TypeRange(), input, finalOutput, comm);

    if (gatherAxis == 0) {
      // If gather axis == 0, simply collapse the first 2 dims from {nRanks,
      // dim0, ...} to {nRanks*dim0, ...}.
      SmallVector<ReassociationIndices> reassociation;
      reassociation.push_back({0, 1});
      int64_t numGatherDims = gatherShape.size();
      for (int64_t i = 2; i < numGatherDims; ++i)
        reassociation.push_back({i});
      finalOutput = memref::CollapseShapeOp::create(ib, outType, finalOutput,
                                                    reassociation);

      // If the op's result is a tensor, cast it to a tensor.
      if (isa<RankedTensorType>(op.getType()))
        finalOutput = bufferization::ToTensorOp::create(ib, op.getType(),
                                                        finalOutput, true);
    } else {
      // 1. Enter tensor-land.
      auto inType =
          RankedTensorType::get(gatherShape, outType.getElementType());
      finalOutput =
          bufferization::ToTensorOp::create(ib, inType, finalOutput, true);

      // 2. Permute the output buffer from {nRanks, dim0, ..., gatherAxis, ...}
      // to {dim0, ..., nRanks, dim1,...}.
      SmallVector<int64_t> outShapePermuted, permutation;
      for (int i = 1; i <= gatherAxis; ++i) {
        outShapePermuted.emplace_back(gatherShape[i]);
        permutation.emplace_back(i);
      }
      outShapePermuted.emplace_back(gatherShape[0]);
      permutation.emplace_back(0);
      for (size_t i = gatherAxis + 1; i < gatherShape.size(); ++i) {
        outShapePermuted.emplace_back(gatherShape[i]);
        permutation.emplace_back(i);
      }
      Value permOutput = tensor::EmptyOp::create(ib, outShapePermuted,
                                                 outType.getElementType());
      finalOutput =
          linalg::TransposeOp::create(ib, finalOutput, permOutput, permutation)
              ->getResult(0);

      // 3. Collapse the output buffer from {dim0, ..., nRanks, gatherAxis, ...}
      // to {dim0, ..., nRanks*gatherAxis, ...}.
      SmallVector<ReassociationIndices> reassociation;
      for (int64_t i = 0; i < gatherAxis; ++i) {
        reassociation.push_back({i});
      }
      reassociation.push_back({gatherAxis, gatherAxis + 1});
      for (int64_t i = gatherAxis + 2; i < (int64_t)outShapePermuted.size();
           ++i) {
        reassociation.push_back({i});
      }
      auto outTType =
          RankedTensorType::get(outputShape, outType.getElementType());
      finalOutput = tensor::CollapseShapeOp::create(ib, outTType, finalOutput,
                                                    reassociation);

      // 4. Cast back to memref if needed.
      if (isa<MemRefType>(op.getType()))
        finalOutput =
            bufferization::ToBufferOp::create(ib, outType, finalOutput);
    }

    rewriter.replaceOp(op, finalOutput);
    return success();
  }
};
759
760struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
761 using OpConversionPattern::OpConversionPattern;
762
763 LogicalResult
764 matchAndRewrite(UpdateHaloOp op, OpAdaptor adaptor,
765 ConversionPatternRewriter &rewriter) const override {
766
767 // The input/output memref is assumed to be in C memory order.
768 // Halos are exchanged as 2 blocks per dimension (one for each side: down
769 // and up). For each haloed dimension `d`, the exchanged blocks are
770 // expressed as multi-dimensional subviews. The subviews include potential
771 // halos of higher dimensions `dh > d`, no halos for the lower dimensions
772 // `dl < d` and for dimension `d` the currently exchanged halo only.
773 // By iterating form higher to lower dimensions this also updates the halos
774 // in the 'corners'.
775 // memref.subview is used to read and write the halo data from and to the
776 // local data. Because subviews and halos can have mixed dynamic and static
777 // shapes, OpFoldResults are used whenever possible.
778
779 auto haloSizes = getMixedValues(adaptor.getStaticHaloSizes(),
780 adaptor.getHaloSizes(), rewriter);
781 if (haloSizes.empty()) {
782 // no halos -> nothing to do
783 rewriter.replaceOp(op, adaptor.getDestination());
784 return success();
785 }
786
787 SymbolTableCollection symbolTableCollection;
788 Location loc = op.getLoc();
789
790 // convert a OpFoldResult into a Value
791 auto toValue = [&rewriter, &loc](OpFoldResult &v) -> Value {
792 if (auto value = dyn_cast<Value>(v))
793 return value;
794 return arith::ConstantOp::create(
795 rewriter, loc,
796 rewriter.getIndexAttr(
797 cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
798 };
799
800 auto dest = adaptor.getDestination();
801 auto dstShape = cast<ShapedType>(dest.getType()).getShape();
802 Value array = dest;
803 if (isa<RankedTensorType>(array.getType())) {
804 // If the destination is a memref, we need to cast it to a tensor
805 auto mmemrefType = MemRefType::get(
806 dstShape, cast<ShapedType>(array.getType()).getElementType());
807 array =
808 bufferization::ToBufferOp::create(rewriter, loc, mmemrefType, array);
809 }
810 auto rank = cast<ShapedType>(array.getType()).getRank();
811 auto opSplitAxes = adaptor.getSplitAxes().getAxes();
812 auto grid = adaptor.getGrid();
813 auto gridOp = getGrid(op, symbolTableCollection);
814 // subviews need Index values
815 for (auto &sz : haloSizes) {
816 if (auto value = dyn_cast<Value>(sz))
817 sz = arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(),
818 value)
819 .getResult();
820 }
821
822 // most of the offset/size/stride data is the same for all dims
823 SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
824 SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
825 SmallVector<OpFoldResult> shape(rank), dimSizes(rank);
826 auto currHaloDim = -1; // halo sizes are provided for split dimensions only
827 // we need the actual shape to compute offsets and sizes
828 for (auto i = 0; i < rank; ++i) {
829 auto s = dstShape[i];
830 if (ShapedType::isDynamic(s))
831 shape[i] = memref::DimOp::create(rewriter, loc, array, s).getResult();
832 else
833 shape[i] = rewriter.getIndexAttr(s);
834
835 if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
836 ++currHaloDim;
837 // the offsets for lower dim sstarts after their down halo
838 offsets[i] = haloSizes[currHaloDim * 2];
839
840 // prepare shape and offsets of highest dim's halo exchange
841 Value _haloSz = arith::AddIOp::create(
842 rewriter, loc, toValue(haloSizes[currHaloDim * 2]),
843 toValue(haloSizes[currHaloDim * 2 + 1]));
844 // the halo shape of lower dims exlude the halos
845 dimSizes[i] =
846 arith::SubIOp::create(rewriter, loc, toValue(shape[i]), _haloSz)
847 .getResult();
848 } else {
849 dimSizes[i] = shape[i];
850 }
851 }
852
853 auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
854 auto tag = arith::ConstantOp::create(rewriter, loc, tagAttr);
855 auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0
856 auto zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
857
858 SmallVector<Type> indexResultTypes(gridOp.getShape().size(),
859 rewriter.getIndexType());
860 auto myMultiIndex =
861 ProcessMultiIndexOp::create(rewriter, loc, indexResultTypes, grid)
862 .getResult();
863 // traverse all split axes from high to low dim
864 for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
865 auto splitAxes = opSplitAxes[dim];
866 if (splitAxes.empty())
867 continue;
868 assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);
869 // Get the linearized ids of the neighbors (down and up) for the
870 // given split
871 auto tmp = NeighborsLinearIndicesOp::create(rewriter, loc, grid,
872 myMultiIndex, splitAxes)
873 .getResults();
874 // MPI operates on i32...
875 Value neighbourIDs[2] = {
876 arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(),
877 tmp[0]),
878 arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(),
879 tmp[1])};
880
881 auto lowerRecvOffset = rewriter.getIndexAttr(0);
882 auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]);
883 auto upperRecvOffset =
884 arith::SubIOp::create(rewriter, loc, toValue(shape[dim]),
885 toValue(haloSizes[currHaloDim * 2 + 1]));
886 auto upperSendOffset = arith::SubIOp::create(
887 rewriter, loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2]));
888
889 Value commWorld = mpi::CommWorldOp::create(
890 rewriter, loc, mpi::CommType::get(op->getContext()));
891
892 // Make sure we send/recv in a way that does not lead to a dead-lock.
893 // The current approach is by far not optimal, this should be at least
894 // be a red-black pattern or using MPI_sendrecv.
895 // Also, buffers should be re-used.
896 // Still using temporary contiguous buffers for MPI communication...
897 // Still yielding a "serialized" communication pattern...
898 auto genSendRecv = [&](bool upperHalo) {
899 auto orgOffset = offsets[dim];
900 dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
901 : haloSizes[currHaloDim * 2];
902 // Check if we need to send and/or receive
903 // Processes on the grid borders have only one neighbor
904 auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
905 auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
906 auto hasFrom = arith::CmpIOp::create(
907 rewriter, loc, arith::CmpIPredicate::sge, from, zero);
908 auto hasTo = arith::CmpIOp::create(rewriter, loc,
909 arith::CmpIPredicate::sge, to, zero);
910 auto buffer = memref::AllocOp::create(
911 rewriter, loc, dimSizes,
912 cast<ShapedType>(array.getType()).getElementType());
913 // if has neighbor: copy halo data from array to buffer and send
914 scf::IfOp::create(
915 rewriter, loc, hasTo, [&](OpBuilder &builder, Location loc) {
916 offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset)
917 : OpFoldResult(upperSendOffset);
918 auto subview = memref::SubViewOp::create(
919 builder, loc, array, offsets, dimSizes, strides);
920 memref::CopyOp::create(builder, loc, subview, buffer);
921 mpi::SendOp::create(builder, loc, TypeRange{}, buffer, tag, to,
922 commWorld);
923 scf::YieldOp::create(builder, loc);
924 });
925 // if has neighbor: receive halo data into buffer and copy to array
926 scf::IfOp::create(
927 rewriter, loc, hasFrom, [&](OpBuilder &builder, Location loc) {
928 offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset)
929 : OpFoldResult(lowerRecvOffset);
930 mpi::RecvOp::create(builder, loc, TypeRange{}, buffer, tag, from,
931 commWorld);
932 auto subview = memref::SubViewOp::create(
933 builder, loc, array, offsets, dimSizes, strides);
934 memref::CopyOp::create(builder, loc, buffer, subview);
935 scf::YieldOp::create(builder, loc);
936 });
937 memref::DeallocOp::create(rewriter, loc, buffer);
938 offsets[dim] = orgOffset;
939 };
940
941 auto doSendRecv = [&](int upOrDown) {
942 OpFoldResult &v = haloSizes[currHaloDim * 2 + upOrDown];
943 Value haloSz = dyn_cast<Value>(v);
944 if (!haloSz)
945 haloSz = arith::ConstantOp::create(
946 rewriter, loc,
947 rewriter.getI32IntegerAttr(
948 cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
949 auto hasSize = arith::CmpIOp::create(
950 rewriter, loc, arith::CmpIPredicate::sgt, haloSz, zero);
951 scf::IfOp::create(rewriter, loc, hasSize,
952 [&](OpBuilder &builder, Location loc) {
953 genSendRecv(upOrDown > 0);
954 scf::YieldOp::create(builder, loc);
955 });
956 };
957
958 doSendRecv(0);
959 doSendRecv(1);
960
961 // the shape for lower dims include higher dims' halos
962 dimSizes[dim] = shape[dim];
963 // -> the offset for higher dims is always 0
964 offsets[dim] = rewriter.getIndexAttr(0);
965 // on to next halo
966 --currHaloDim;
967 }
968
969 if (isa<MemRefType>(op.getResult().getType())) {
970 rewriter.replaceOp(op, array);
971 } else {
972 assert(isa<RankedTensorType>(op.getResult().getType()));
973 rewriter.replaceOp(op, bufferization::ToTensorOp::create(
974 rewriter, loc, op.getResult().getType(), array,
975 /*restrict=*/true, /*writable=*/true));
976 }
977 return success();
978 }
979};
980
981struct ConvertShardToMPIPass
982 : public impl::ConvertShardToMPIPassBase<ConvertShardToMPIPass> {
983 using Base::Base;
984
985 /// Run the dialect converter on the module.
986 void runOnOperation() override {
987 auto *ctxt = &getContext();
988 RewritePatternSet patterns(ctxt);
989 ConversionTarget target(getContext());
990
991 // Define a type converter to convert shard::ShardingType,
992 // mostly for use in return operations.
993 TypeConverter typeConverter;
994 typeConverter.addConversion([](Type type) { return type; });
995
996 // convert shard::ShardingType to a tuple of RankedTensorTypes
997 typeConverter.addConversion(
998 [](ShardingType type,
999 SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
1000 auto i16 = IntegerType::get(type.getContext(), 16);
1001 auto i64 = IntegerType::get(type.getContext(), 64);
1002 std::array<int64_t, 2> shp = {ShapedType::kDynamic,
1003 ShapedType::kDynamic};
1004 results.emplace_back(RankedTensorType::get(shp, i16));
1005 results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
1006 results.emplace_back(RankedTensorType::get(shp, i64));
1007 return success();
1008 });
1009
1010 // To 'extract' components, a UnrealizedConversionCastOp is expected
1011 // to define the input
1012 typeConverter.addTargetMaterialization(
1013 [&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
1014 Location loc) {
1015 // Expecting a single input.
1016 if (inputs.size() != 1 || !isa<TupleType>(inputs[0].getType()))
1017 return SmallVector<Value>();
1018 auto castOp = inputs[0].getDefiningOp<UnrealizedConversionCastOp>();
1019 // Expecting an UnrealizedConversionCastOp.
1020 if (!castOp)
1021 return SmallVector<Value>();
1022 // Fill a vector with elements of the tuple/castOp.
1023 SmallVector<Value> results;
1024 for (auto oprnd : castOp.getInputs()) {
1025 if (!isa<RankedTensorType>(oprnd.getType()))
1026 return SmallVector<Value>();
1027 results.emplace_back(oprnd);
1028 }
1029 return results;
1030 });
1031
1032 // No shard dialect should left after conversion...
1033 target.addIllegalDialect<shard::ShardDialect>();
1034 // ...except the global GridOp. GridShapeOp which will get folded later.
1035 target.addLegalOp<shard::GridOp, shard::GridShapeOp>();
1036 // Allow all the stuff that our patterns will convert to
1037 target.addLegalDialect<BuiltinDialect, mpi::MPIDialect, scf::SCFDialect,
1038 arith::ArithDialect, tensor::TensorDialect,
1039 bufferization::BufferizationDialect,
1040 linalg::LinalgDialect, memref::MemRefDialect,
1041 affine::AffineDialect, cf::ControlFlowDialect>();
1042 // Make sure the function signature, calls etc. are legal
1043 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1044 return typeConverter.isSignatureLegal(op.getFunctionType());
1045 });
1046 target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
1047 [&](Operation *op) { return typeConverter.isLegal(op); });
1048
1049 patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
1050 ConvertGetShardingOp, ConvertShardingOp, ConvertShardShapeOp,
1051 ConvertAllGatherOp, ConvertAllReduceOp,
1052 ConvertProcessLinearIndexOp>(typeConverter, ctxt);
1053 SymbolTableCollection stc;
1056
1057 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
1058 patterns, typeConverter);
1061
1062 (void)applyPartialConversion(getOperation(), target, std::move(patterns));
1063
1064 // Folding patterns cannot be mixed with conversion patterns -> extra pass.
1065 patterns.clear();
1066 SymbolTableCollection symbolTableCollection;
1067 mlir::shard::populateFoldingPatterns(patterns, symbolTableCollection);
1068 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
1069 }
1070};
1071
1072} // namespace
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static void shardShape(const InShape &inShape, const GridShape &gridShape, const SplitAxes &splitAxes, OutShape &outShape, ArrayRef< int64_t > shardedDimsOffsets={}, ArrayRef< int64_t > haloSizes={})
Definition ShardOps.cpp:214
IntegerType getI32Type()
Definition Builders.cpp:67
IndexType getIndexType()
Definition Builders.cpp:55
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext * getContext() const
Return the context this location is uniqued in.
Definition Location.h:86
This class helps build Operations.
Definition Builders.h:209
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
bool isStaticShapeAndContiguousRowMajor(MemRefType type)
Returns true, if the memref type has static shapes and represents a contiguous chunk of memory.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
void populateFoldingPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void populateAllSliceOpLoweringPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void populateProcessMultiIndexOpLoweringPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
int16_t GridAxis
Definition ShardOps.h:27
TypedValue< IndexType > createProcessLinearIndex(ImplicitLocOpBuilder &builder, StringRef grid, ArrayRef< GridAxis > gridAxes={})
SmallVector< Value > getMixedAsValues(OpBuilder b, const Location &loc, llvm::ArrayRef< int64_t > statics, ValueRange dynamics, Type type=Type())
Converts a vector of OpFoldResults (ints) into vector of Values of the provided type.
Definition ShardOps.cpp:77
int64_t collectiveProcessGroupSize(GridAxesRange &&gridAxes, GridShapeRange &&gridShape)
Definition ShardOps.h:156
shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
Definition ShardOps.h:131
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalize...
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the ...
const FrozenRewritePatternSet & patterns