MLIR 22.0.0git
ShardingInterface.cpp
Go to the documentation of this file.
1//===- ShardingInterface.cpp -------------------------------------*- C++-*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
11
#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h"

#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Debug.h"

#include <utility>
22
23#define DEBUG_TYPE "sharding-interface"
24#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
25
26using namespace mlir;
27using namespace mlir::shard;
28
29#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.cpp.inc"
30
31//===----------------------------------------------------------------------===//
32// common util functions
33//===----------------------------------------------------------------------===//
34
35static LogicalResult
37 SmallVectorImpl<bool> &seenIds) {
38 switch (expr.getKind()) {
40 auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
41 AffineExpr lhs = binOpExpr.getLHS();
42 AffineExpr rhs = binOpExpr.getRHS();
43 if (failed(checkOperandAffineExprRecursively(lhs, seenIds)))
44 return failure();
45 if (failed(checkOperandAffineExprRecursively(rhs, seenIds)))
46 return failure();
47 return success();
48 }
50 auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
51 AffineExpr lhs = binOpExpr.getLHS();
52 AffineExpr rhs = binOpExpr.getRHS();
53 AffineExpr dimExpr;
54 if (lhs.getKind() == AffineExprKind::DimId &&
55 rhs.getKind() == AffineExprKind::Constant) {
56 dimExpr = lhs;
57 } else if (rhs.getKind() == AffineExprKind::DimId &&
58 lhs.getKind() == AffineExprKind::Constant) {
59 dimExpr = rhs;
60 } else {
61 return failure();
62 }
63 unsigned position = cast<AffineDimExpr>(dimExpr).getPosition();
64 if ((size_t)position >= seenIds.size() || seenIds[position])
65 return failure();
66 seenIds[position] = true;
67 return success();
68 }
70 unsigned position = cast<AffineDimExpr>(expr).getPosition();
71 if ((size_t)position >= seenIds.size() || seenIds[position])
72 return failure();
73 seenIds[position] = true;
74 return success();
75 }
76 default:
77 return failure();
78 }
79}
80
81static FailureOr<llvm::SmallSet<unsigned, 2>>
82checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
83 SmallVector<bool> seenIds(numDims, false);
84 if (failed(checkOperandAffineExprRecursively(expr, seenIds)))
85 return failure();
86
87 llvm::SmallSet<unsigned, 2> positions;
88 for (auto it : llvm::enumerate(seenIds)) {
89 if (it.value())
90 positions.insert((unsigned)it.index());
91 }
92 return positions;
93}
94
95template <typename T>
99 for (const auto &v : vec) {
100 res.emplace_back(GridAxesAttr::get(ctxt, v));
101 }
102 return res;
103}
104
105//===----------------------------------------------------------------------===//
106// shard::getSharding
107//===----------------------------------------------------------------------===//
108
// Retrieves the sharding annotation attached to `result` via its
// `shard.shard` users. Returns (annotateForUsers, sharding): the flag is
// false when the annotation constrains the defining op, true when it
// constrains the users. Fails when no `shard.shard` user exists or the use
// structure is not the expected one.
FailureOr<std::pair<bool, Sharding>> shard::getSharding(OpResult result) {
  Value val = cast<Value>(result);
  // Is there a `shard.shard` user *without* the `annotate_for_users` unit
  // attribute, i.e. one annotating the producer side?
  bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
    auto shardOp = llvm::dyn_cast<shard::ShardOp>(user);
    if (!shardOp)
      return false;
    return !shardOp.getAnnotateForUsers();
  });

  if (anyShardedForDef) {
    // expected to have exact one use if it has a use of `shard.shard` without
    // unit attr annotate_for_users
    if (!val.hasOneUse())
      return failure();
    auto shardOp = llvm::cast<shard::ShardOp>(*val.getUsers().begin());
    return std::make_pair(false, Sharding(shardOp.getSharding()));
  }

  // Otherwise, look for `shard.shard` users that carry `annotate_for_users`.
  bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) {
    auto shardOp = llvm::dyn_cast<shard::ShardOp>(user);
    if (!shardOp)
      return false;
    return shardOp.getAnnotateForUsers();
  });
  if (anyShardedForUsers) {
    // Collect every `shard.shard` user; all of them must agree on the same
    // sharding (asserted below).
    SmallVector<ShardOp> shardOps;
    for (Operation *user : val.getUsers()) {
      ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
      if (shardOp)
        shardOps.push_back(shardOp);
    }
    Sharding shardForDef = shardOps[0].getSharding();
    for (size_t i = 1; i < shardOps.size(); ++i) {
      // TODO: Deduce a reasonable grid sharding attr for def when they are
      // different
      assert(shardForDef == shardOps[i].getSharding() &&
             "only support all shard ops have the same grid sharding attr");
    }
    return std::make_pair(true, shardForDef);
  }
  return failure();
}
151
152FailureOr<std::pair<bool, Sharding>> shard::getSharding(OpOperand &opOperand) {
153 Value val = opOperand.get();
154 if (ShardOp shardOp = val.getDefiningOp<ShardOp>())
155 return std::make_pair(shardOp.getAnnotateForUsers(),
156 Sharding(shardOp.getSharding()));
157
158 return failure();
159}
160
161//===----------------------------------------------------------------------===//
162// ShardingInterface::verifyShardingInterfaceImpl
163//===----------------------------------------------------------------------===//
164
165LogicalResult shard::ShardingInterface::verifyShardingInterfaceImpl() {
166 Operation *op = getOperation();
167
168 // check operands and results type
169 for (Type type : op->getOperandTypes())
170 if (!llvm::isa<RankedTensorType>(type) && !type.isIntOrIndexOrFloat())
171 return failure();
172 for (Type type : op->getResultTypes())
173 if (!llvm::isa<RankedTensorType>(type) && !type.isIntOrIndexOrFloat())
174 return failure();
175
176 // check maps
177 SmallVector<AffineMap> maps = getIndexingMaps();
178 if (maps.empty())
179 return failure();
180 unsigned numOperands = op->getNumOperands();
181 unsigned numResults = op->getNumResults();
182 if (numOperands + numResults != maps.size())
183 return failure();
184
185 for (OpResult result : op->getResults()) {
186 auto resultType = dyn_cast<RankedTensorType>(result.getType());
187 if (!resultType)
188 return failure();
189 AffineMap map = maps[numOperands + result.getResultNumber()];
190 if (!map.isProjectedPermutation()) {
191 return failure();
192 }
193 }
194
195 return success();
196}
197
198//===----------------------------------------------------------------------===//
199// ShardingInterface::printLoopTypesAndIndexingMaps
200//===----------------------------------------------------------------------===//
201
202void shard::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
203 os << "print loop types and indexing maps for: \n";
204 getOperation()->print(os);
205 os << "\n";
206 os << "loop types: [";
207 for (utils::IteratorType type : getLoopIteratorTypes()) {
208 os << stringifyEnum(type) << " ";
209 }
210 os << "]\n";
211 os << "indexing maps: \n";
212 for (AffineMap map : getIndexingMaps())
213 os << map << "\n";
214 os << "\n";
215}
216
217//===----------------------------------------------------------------------===//
218// detail::defaultGetShardingOption
219//===----------------------------------------------------------------------===//
220
221namespace {
222
223// Update the given `shardingOption` according to `gridAxes` and `loopIdx`
224static LogicalResult fillShardingOption(Operation *op,
225 ShardingOption &shardingOption,
227 ArrayRef<GridAxis> gridAxes,
228 unsigned loopIdx) {
229 if ((shardingOption.grid && grid && shardingOption.grid != grid) ||
230 (!shardingOption.shardingArray[loopIdx].empty() &&
231 shardingOption.shardingArray[loopIdx] != gridAxes)) {
232 LLVM_DEBUG(DBGS() << "sharding option conflicts on loop iterator "
233 << loopIdx << "\n");
234 return failure();
235 }
236 for (size_t i = 0; i < shardingOption.shardingArray.size(); ++i) {
237 if (i == loopIdx)
238 continue;
239
240 for (GridAxis axis : gridAxes) {
241 if (llvm::is_contained(shardingOption.shardingArray[i], axis)) {
242 LLVM_DEBUG(DBGS() << "sharding option conflicts because grid axes "
243 << axis << " duplicate");
244 return failure();
245 }
246 }
247 }
248 if (grid)
249 shardingOption.grid = grid;
250 if (shardingOption.shardingArray[loopIdx].empty())
251 shardingOption.shardingArray[loopIdx].append(gridAxes.begin(),
252 gridAxes.end());
253 return success();
254}
255
256} // namespace
257
258FailureOr<ShardingOption>
260 ArrayRef<Sharding> operandShardings,
261 ArrayRef<Sharding> resultShardings) {
262 ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
263 ShardingOption shardingOption;
264
265 if (failed(shardingOp.verifyShardingInterfaceImpl()))
266 return op->emitOpError() << "invalid sharding interface implementation";
268 shardingOp.getLoopIteratorTypes();
269 SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
270 unsigned numOperands = op->getNumOperands();
271 shardingOption.shardingArray.resize(loopTypes.size());
272 llvm::SmallSet<unsigned, 4> visitedLoopIndices;
273 bool anyShardingInResultsOrOperands = false;
274
275 // 1. Fill sharding option based on op results
276 for (auto shardingIt : llvm::enumerate(resultShardings)) {
277 const Sharding &shardAttr = shardingIt.value();
278 if (!shardAttr)
279 continue;
280 AffineMap map = maps[numOperands + shardingIt.index()];
281 anyShardingInResultsOrOperands = true;
282 if (shardAttr.getSplitAxes().empty() || map.getResults().empty()) {
283 shardingOption.grid = shardAttr.getGridAttr();
284 } else {
285 // Handle the split axes: calculate the corresponding loop index for each
286 // split axes sub-array, and then store the sub-array to
287 // shardingOption[index]
288 for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
289 AffineExpr expr = std::get<0>(it);
290 ArrayRef<GridAxis> axes = std::get<1>(it).asArrayRef();
291 auto dim = cast<AffineDimExpr>(expr);
292 unsigned index = dim.getPosition();
293 visitedLoopIndices.insert(index);
294 if (failed(fillShardingOption(op, shardingOption,
295 shardAttr.getGridAttr(), axes, index)))
296 return failure();
297 }
298 }
299 }
300
301 // 2. Fill sharding option based on operands
302 for (auto shardingIt : llvm::enumerate(operandShardings)) {
303 const Sharding &shardAttr = shardingIt.value();
304 if (!shardAttr)
305 continue;
306
307 anyShardingInResultsOrOperands = !shardAttr.getSplitAxes().empty();
308 AffineMap map = maps[shardingIt.index()];
309 unsigned numDims = map.getNumDims();
310
311 // Handle the split axes.
312 //
313 // TODO: Change to process the operands with single loop index first and
314 // then the operands with multiple loop indices.
315 for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
316 AffineExpr expr = std::get<0>(it);
317 ArrayRef<GridAxis> axes = std::get<1>(it).asArrayRef();
318 FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
319 checkOperandAffineExpr(expr, numDims);
320 if (failed(loopIndices))
321 return op->emitOpError()
322 << "operand's affine expression is restricted to const_i * "
323 "dim_i + const_j + dim_j + ...";
324 if (loopIndices->empty())
325 continue;
326 if (loopIndices->size() == 1) {
327 unsigned loopIdx = *loopIndices->begin();
328 visitedLoopIndices.insert(loopIdx);
329 if (failed(fillShardingOption(op, shardingOption,
330 shardAttr.getGridAttr(), axes, loopIdx)))
331 return failure();
332 }
333 // If multiple loop indices correspond to a dimension of an operand, it is
334 // difficult to infer which loop indices are responsible for sharding.
335 // Therefore, the exact loop index must be specified by others.
336 if (loopIndices->size() > 1) {
337 bool seenLoopIndices = false;
338 for (unsigned loopIdx : *loopIndices) {
339 if (visitedLoopIndices.contains(loopIdx)) {
340 seenLoopIndices = true;
341 break;
342 }
343 }
344 if (!seenLoopIndices)
345 return op->emitOpError()
346 << "the operand " << shardingIt.index()
347 << " has multiple loop indices in a dimension, but none of "
348 "them could be found in the exactly specified annotation "
349 "of op results or operands.";
350 }
351 }
352 }
353
354 // 3. Finalize sharding option
356 if (!anyShardingInResultsOrOperands)
357 shardingOption.empty = true;
358 return shardingOption;
359}
360
361// Get the sharding attributed for the given result and sharding option.
363 const ShardingOption &shardingOption, AffineMap map,
365 auto resultType = cast<RankedTensorType>(result.getType());
366 SmallVector<SmallVector<GridAxis>> splitAxes(resultType.getRank());
367
368 // process the split axes
369 for (auto it : llvm::enumerate(map.getResults())) {
370 AffineExpr expr = it.value();
371 // `expr` must be an `AffineDimExpr` because `map` is verified by
372 // isProjectedPermutation
373 auto dim = cast<AffineDimExpr>(expr);
374 unsigned loopIdx = dim.getPosition();
375 if (loopIdx < shardingOption.shardingArray.size())
376 splitAxes[it.index()].append(shardingOption.shardingArray[loopIdx]);
377 }
378
380 return Sharding::get(shardingOption.grid,
381 fromArrayOfVector(result.getContext(), splitAxes));
382}
383
384static FailureOr<Sharding> getSharding(OpOperand &opOperand,
385 const ShardingOption &shardingOption,
386 AffineMap map) {
387 Value operandValue = opOperand.get();
388 auto operandType = dyn_cast<RankedTensorType>(operandValue.getType());
389 if (!operandType) {
390 if (operandValue.getType().isIntOrIndexOrFloat())
391 return Sharding();
392 return failure();
393 }
394 // 0d tensors cannot be sharded and must get replicated
395 if (operandType.getRank() == 0) {
396 return Sharding(shardingOption.grid);
397 }
398 SmallVector<SmallVector<GridAxis>> splitAxes(operandType.getRank());
399 unsigned numDims = map.getNumDims();
400 for (auto it : llvm::enumerate(map.getResults())) {
401 int64_t idx = it.index();
402 AffineExpr expr = it.value();
403 FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
404 checkOperandAffineExpr(expr, numDims);
405 if (failed(loopIndices))
406 return failure();
407 SmallVector<unsigned> shardedLoopIndices;
408 for (unsigned loopIdx : *loopIndices) {
409 if ((size_t)loopIdx < shardingOption.shardingArray.size() &&
410 !shardingOption.shardingArray[loopIdx].empty())
411 shardedLoopIndices.push_back(loopIdx);
412 }
413 // mostly one sharded loop index is accepted
414 if (shardedLoopIndices.size() > 1)
415 return failure();
416 if (shardedLoopIndices.size() == 1) {
417 splitAxes[idx].append(
418 shardingOption.shardingArray[shardedLoopIndices[0]]);
419 }
420 }
421
423 return Sharding::get(
424 shardingOption.grid,
425 fromArrayOfVector(opOperand.get().getContext(), splitAxes));
426}
427
428FailureOr<std::vector<Sharding>> shard::detail::defaultGetShardingAnnotations(
429 Operation *op, const ShardingOption &shardingOption) {
430 std::vector<Sharding> res;
431
432 ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
434 shardingOp.getLoopIteratorTypes();
435 SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
436 unsigned numOperands = op->getNumOperands();
437
438 for (OpOperand &opOperand : op->getOpOperands()) {
439 FailureOr<Sharding> shardingAttr = ::getSharding(
440 opOperand, shardingOption, maps[opOperand.getOperandNumber()]);
441 if (failed(shardingAttr))
442 return failure();
443 res.push_back(*shardingAttr);
444 }
445
446 for (OpResult result : op->getResults()) {
447 res.push_back(::getSharding(result, shardingOption,
448 maps[numOperands + result.getResultNumber()],
449 loopTypes));
450 }
451
452 return res;
453}
454
455//===----------------------------------------------------------------------===//
456// detail::defaultAddShardingAnnotations
457//===----------------------------------------------------------------------===//
458
459// To add a `shard.shard` op for the given result, based on the details provided
460// in `shardingOption`, `map`, and `loopTypes`.
461static LogicalResult addShardOp(OpBuilder &b, OpResult result,
462 const ShardingOption &shardingOption,
463 AffineMap map,
465 Sharding sharding = getSharding(result, shardingOption, map, loopTypes);
467
468 return success();
469}
470
471// To add a `shard.shard` op for the given operand, based on the details
472// provided in `shardingOption`, `map`, and `loopTypes`.
473static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
474 const ShardingOption &shardingOption,
475 AffineMap map) {
476
477 FailureOr<Sharding> sharding = getSharding(opOperand, shardingOption, map);
478 if (failed(sharding)) {
479 return failure();
480 }
482 maybeInsertSourceShardingAnnotation(sharding.value(), opOperand, b);
483
484 return success();
485}
486
488 Operation *op, OpBuilder &b, const ShardingOption &shardingOption) {
489 assert(!shardingOption.empty && shardingOption.grid);
490
491 ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
493 shardingOp.getLoopIteratorTypes();
494 SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
495 unsigned numOperands = op->getNumOperands();
496
497 // 1. add shard.shard ops for all op results
498 for (OpResult result : op->getResults()) {
499 if (failed(addShardOp(b, result, shardingOption,
500 maps[numOperands + result.getResultNumber()],
501 loopTypes)))
502 return failure();
503 }
504
505 // 2. add shard.shard ops for all operands
506 for (OpOperand &opOperand : op->getOpOperands()) {
507 if (failed(addShardOp(b, opOperand, shardingOption,
508 maps[opOperand.getOperandNumber()])))
509 return failure();
510 }
511
512 return success();
513}
514
515#ifndef NDEBUG
516static bool
518 const Sharding &sharding) {
519 if (isa<RankedTensorType>(value.getType())) {
520 return isFullReplication(sharding);
521 }
522
523 return !sharding;
524}
525
526template <typename ValueRange, typename ShardingRage>
527static bool
529 ShardingRage &&shardings) {
530 if (std::size(values) != std::size(shardings)) {
531 return false;
532 }
533 return llvm::all_of(llvm::zip_equal(std::forward<ValueRange>(values),
534 std::forward<ShardingRage>(shardings)),
535 [](auto valueAndSharding) {
537 std::get<0>(valueAndSharding),
538 std::get<1>(valueAndSharding));
539 });
540}
541#endif // NDEBUG
542
544 Operation &op, ArrayRef<Value> partitionedOperands,
545 ArrayRef<Sharding> operandShardings, ArrayRef<Sharding> resultShardings,
546 IRMapping &partitionMap, SymbolTableCollection &symbolTable,
547 OpBuilder &builder) {
548 assert(partitionedOperands.size() == operandShardings.size());
550 operandShardings));
552 resultShardings));
553 // `clone` will populate the mapping of old to new results.
554 builder.clone(op, partitionMap);
555}
556
558 ArrayRef<GridAxis> gridAxesAssignmentForTensorAxis, AffineExpr indexingExpr,
559 SmallVector<std::optional<SmallVector<GridAxis>>>
560 &gridAxesAssignmentForLoopIterators) {
561 AffineDimExpr affineDimExpr = cast<AffineDimExpr>(indexingExpr);
562 unsigned loopIteratorIdx = affineDimExpr.getPosition();
563 if (gridAxesAssignmentForLoopIterators[loopIteratorIdx]) {
564 assert(llvm::equal(gridAxesAssignmentForTensorAxis,
565 *gridAxesAssignmentForLoopIterators[loopIteratorIdx]));
566 } else {
567 gridAxesAssignmentForLoopIterators[loopIteratorIdx] =
568 llvm::to_vector(gridAxesAssignmentForTensorAxis);
569 }
570}
571
573 ArrayRef<Sharding> operandShardings, ArrayRef<Sharding> resultShardings,
574 ArrayRef<utils::IteratorType> loopIteratorTypes,
575 ArrayRef<AffineMap> indexingMaps) {
577 gridAxisAssignmentForLoopIterators(loopIteratorTypes.size());
578 std::vector<Sharding> operatorAndResultShardings;
579 operatorAndResultShardings.reserve(operandShardings.size() +
580 resultShardings.size());
581 llvm::append_range(operatorAndResultShardings, operandShardings);
582 for (auto [sharding, affineMap] :
583 llvm::zip_equal(operatorAndResultShardings, indexingMaps)) {
584 if (!sharding) {
585 continue;
586 }
587 for (auto [gridAxesAssignmentForTensorAxis, indexingExpr] :
588 llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) {
590 gridAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr,
591 gridAxisAssignmentForLoopIterators);
592 }
593 // Missing trailing split axes means replication on those tensor dimensions.
594 for (unsigned i = sharding.getSplitAxes().size();
595 i < affineMap.getNumResults(); ++i) {
597 {}, affineMap.getResults()[i], gridAxisAssignmentForLoopIterators);
598 }
599 }
600
601 ShardingArray res;
602 llvm::transform(gridAxisAssignmentForLoopIterators, std::back_inserter(res),
603 [](std::optional<SmallVector<GridAxis>> &axes) {
604 if (!axes) {
605 return SmallVector<GridAxis>();
606 };
607 return std::move(*axes);
608 });
609 return res;
610}
611
613 ArrayRef<utils::IteratorType> loopIteratorTypes,
614 ArrayRef<SmallVector<GridAxis>> gridAxisAssignmentForLoopIterators) {
615 for (auto [loopIteratorType, gridAxisAssignment] :
616 llvm::zip_equal(loopIteratorTypes, gridAxisAssignmentForLoopIterators)) {
617 if (loopIteratorType == utils::IteratorType::reduction &&
618 !gridAxisAssignment.empty()) {
619 return true;
620 }
621 }
622 return false;
623}
624
626 ArrayRef<utils::IteratorType> loopIteratorTypes,
627 ArrayRef<SmallVector<GridAxis>> gridAxisAssignmentForLoopIterators) {
628 SmallVector<GridAxis> gridAxes;
629 for (auto [loopIteratorType, gridAxisAssignment] :
630 llvm::zip_equal(loopIteratorTypes, gridAxisAssignmentForLoopIterators)) {
631 if (loopIteratorType == utils::IteratorType::reduction) {
632 llvm::append_range(gridAxes, gridAxisAssignment);
633 }
634 }
635 return gridAxes;
636}
637
639 Operation &op, ArrayRef<Value> partitionedOperands,
640 ArrayRef<Sharding> operandShardings, ArrayRef<Sharding> resultShardings,
641 IRMapping &partitionMap, SymbolTableCollection &symbolTable,
642 OpBuilder &builder) {
643 // `clone` will populate the mapping of old to new results.
644 Operation *newOp = builder.clone(op, partitionMap);
645 // Set the result types to the sharded counterparts.
646 for (auto [oldResult, newResult, sharding] :
647 llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) {
648 newResult.setType(shardType(
649 newResult.getType(),
650 getGridOrNull(&op, sharding.getGridAttr(), symbolTable), sharding));
651 }
652}
return success()
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static void updateGridAxisAssignmentForLoopIterators(ArrayRef< GridAxis > gridAxesAssignmentForTensorAxis, AffineExpr indexingExpr, SmallVector< std::optional< SmallVector< GridAxis > > > &gridAxesAssignmentForLoopIterators)
static bool isValueCompatibleWithFullReplicationSharding(Value value, const Sharding &sharding)
static FailureOr< llvm::SmallSet< unsigned, 2 > > checkOperandAffineExpr(AffineExpr expr, unsigned numDims)
static LogicalResult addShardOp(OpBuilder &b, OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef< utils::IteratorType > loopTypes)
SmallVector< GridAxesAttr > fromArrayOfVector(MLIRContext *ctxt, const SmallVector< SmallVector< T > > &vec)
static bool areValuesCompatibleWithFullReplicationShardings(ValueRange &&values, ShardingRage &&shardings)
#define DBGS()
static LogicalResult checkOperandAffineExprRecursively(AffineExpr expr, SmallVectorImpl< bool > &seenIds)
A dimensional identifier appearing in an affine expression.
Definition AffineExpr.h:223
unsigned getPosition() const
Base type for affine expression.
Definition AffineExpr.h:68
AffineExprKind getKind() const
Return the classification for this type.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
IRValueT get() const
Return the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:562
This class represents an operand of an operation.
Definition Value.h:257
This is a value defined by a result of an operation.
Definition Value.h:457
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:383
unsigned getNumOperands()
Definition Operation.h:346
operand_type_range getOperandTypes()
Definition Operation.h:397
result_type_range getResultTypes()
Definition Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:120
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition Value.h:108
Type getType() const
Return the type of this value.
Definition Value.h:105
user_range getUsers() const
Definition Value.h:218
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition Value.h:197
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int16_t > content)
static Sharding get(::mlir::FlatSymbolRefAttr grid_, ArrayRef< GridAxesAttr > split_axes_, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
Definition ShardOps.cpp:770
::mlir::FlatSymbolRefAttr getGridAttr() const
Definition ShardOps.h:60
ArrayRef< GridAxesAttr > getSplitAxes() const
Definition ShardOps.h:62
LogicalResult defaultAddShardingAnnotations(Operation *op, OpBuilder &b, const ShardingOption &shardingOption)
FailureOr< std::vector< Sharding > > defaultGetShardingAnnotations(Operation *op, const ShardingOption &shardingOption)
FailureOr< ShardingOption > defaultGetShardingOption(Operation *op, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings)
void maybeInsertTargetShardingAnnotation(Sharding sharding, OpResult result, OpBuilder &builder)
Definition ShardOps.cpp:338
void partitionFullyReplicatedOperation(Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
void removeTrailingEmptySubArray(SmallVector< SmallVector< T > > &array)
Definition ShardOps.h:100
shard::GridOp getGridOrNull(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
Definition ShardOps.h:113
bool isAtLeastOneReductionIteratorSharded(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis > > gridAxisAssignmentForLoopIterators)
bool isFullReplication(Sharding sharding)
Definition ShardOps.h:106
int16_t GridAxis
Definition ShardOps.h:26
void maybeInsertSourceShardingAnnotation(Sharding sharding, OpOperand &operand, OpBuilder &builder)
Definition ShardOps.cpp:352
FailureOr< std::pair< bool, Sharding > > getSharding(OpResult result)
void partitionTriviallyShardableOperation(Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
SmallVector< GridAxis > getReductionGridAxes(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis > > gridAxisAssignmentForLoopIterators)
Type shardType(Type type, GridOp grid, Sharding sharding)
Definition ShardOps.cpp:291
ShardingArray getGridAxisAssignmentForLoopIterators(ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
SmallVector< SmallVector< GridAxis > > ShardingArray
Include the generated interface declarations.
@ Mul
RHS of mul is always a constant or a symbolic expression.
Definition AffineExpr.h:43
@ DimId
Dimensional identifier.
Definition AffineExpr.h:59
@ Constant
Constant integer.
Definition AffineExpr.h:57