31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/SmallVector.h"
33 #include "llvm/ADT/TypeSwitch.h"
49 .Case([](arith::AddFOp op) {
return ReductionKind::Sum; })
50 .Case([](arith::MulFOp op) {
return ReductionKind::Product; })
52 .Case([](arith::MaximumFOp op) {
return ReductionKind::Max; })
53 .Case([](arith::MinimumFOp op) {
return ReductionKind::Min; })
55 .Case([](arith::AddIOp op) {
return ReductionKind::Sum; })
56 .Case([](arith::OrIOp op) {
return ReductionKind::BitwiseOr; })
57 .Case([](arith::XOrIOp op) {
return ReductionKind::BitwiseXor; })
58 .Case([](arith::AndIOp op) {
return ReductionKind::Sum; })
68 .Case([](arith::MaxUIOp op) {
return ReductionKind::Max; })
69 .Case([](arith::MinUIOp op) {
return ReductionKind::Min; })
70 .Case([](arith::MaxSIOp op) {
return ReductionKind::Max; })
71 .Case([](arith::MinSIOp op) {
return ReductionKind::Min; })
72 .Case([](arith::MulIOp op) {
return ReductionKind::Product; })
73 .Default([](
Operation *op) {
return ReductionKind::Generic; });
79 if (!reducedValue || combinerOps.size() != 1) {
83 return combinerOps[0];
89 return ReductionKind::Generic;
91 [[maybe_unused]]
Type resultElementType =
92 llvm::cast<RankedTensorType>(op->getResult(0).getType()).getElementType();
96 assert(resultElementType == reductionOp.value()->getResult(0).getType());
103 for (
const Sharding &sharding : operandShardings) {
109 for (
const Sharding &sharding : resultShardings) {
127 LinalgOp op,
int operandNumber,
Value partitionedOperand,
131 gridOp.getSymName(), reductionGridAxes, builder);
133 Value isLeadProcess = arith::CmpIOp::create(
134 builder, builder.
getI1Type(), arith::CmpIPredicate::eq,
135 processLinearIndexInReductionGroup, zero);
136 scf::IfOp ifOp = scf::IfOp::create(builder, partitionedOperand.
getType(),
137 isLeadProcess,
true,
true);
142 scf::YieldOp::create(builder, partitionedOperand);
153 matchReduction(op.getRegionOutputArgs(), operandNumber, combinerOps);
154 assert(combinerOps.size() == 1);
155 std::optional<TypedAttr> neutralEl =
158 Value init = tensor::EmptyOp::create(builder, op.
getLoc(), shape,
159 neutralEl.value().getType());
161 arith::ConstantOp::create(builder, op.
getLoc(), neutralEl.value());
162 Value fill = linalg::FillOp::create(builder, op.
getLoc(), constant, init)
165 scf::YieldOp::create(builder, fill);
167 return ifOp.getResult(0);
178 assert(op.getNumDpsInits() == 1 &&
"Multiple initial values not supported.");
180 auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
181 Value partitionedInitOperand =
182 partitionMap.
lookup(op->getOperands()[operandIdx]);
184 op, 0, partitionedInitOperand, reductionGridAxes, gridOp, builder);
193 for (
auto [unshardedLinalgOpResult, resultSharding] :
194 llvm::zip_equal(unshardedOp->getResults(), resultShardings)) {
195 Value partitionedLinalgOpResult =
196 partitionMap.
lookup(unshardedLinalgOpResult);
197 Value reducedValue = shard::AllReduceOp::create(
198 builder, partitionedLinalgOpResult, resultSharding.getGrid(),
199 opReductionGridAxes, reductionKind);
200 partitionMap.
map(unshardedLinalgOpResult, reducedValue);
211 GridOp grid =
getGrid(op, operandShardings, resultShardings, symbolTable);
213 loopIteratorTypes, gridAxisAssignmentForLoopIterators);
216 reductionGridAxes, partitionMap,
222 for (
auto [unshardedOperand, partitionedOperand] :
223 llvm::zip_equal(op->getOperands(), partitionedLinalgOpOperands)) {
224 internalPartitionMap.
map(unshardedOperand, partitionedOperand);
227 *op, partitionedLinalgOpOperands, operandShardings, resultShardings,
228 internalPartitionMap, symbolTable, builder);
229 for (
Value result : op->getResults()) {
230 partitionMap.
map(result, internalPartitionMap.
lookup(result));
235 op, reductionGridAxes, resultShardings, partitionMap, builder);
243 template <
typename Op>
244 struct StructuredOpShardingInterface
245 :
public shard::ShardingInterface::ExternalModel<
246 StructuredOpShardingInterface<Op>, Op> {
248 return llvm::cast<LinalgOp>(op).getIteratorTypesArray();
251 SmallVector<AffineMap> getIndexingMaps(Operation *op)
const {
252 LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
253 SmallVector<AffineMap> res = linalgOp.getIndexingMapsArray();
257 for (int64_t i = 0; i < linalgOp.getNumDpsInits(); ++i) {
258 res.push_back(res[linalgOp.getDpsInitOperand(i)->getOperandNumber()]);
264 SmallVector<ReductionKind>
265 getReductionLoopIteratorKinds(Operation *op)
const {
266 LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
267 SmallVector<utils::IteratorType> iteratorTypes =
268 linalgOp.getIteratorTypesArray();
269 unsigned reductionItersCount = std::accumulate(
270 iteratorTypes.begin(), iteratorTypes.end(), 0,
271 [](
unsigned count, utils::IteratorType iter) {
272 return count + (iter == utils::IteratorType::reduction);
275 return SmallVector<ReductionKind>(reductionItersCount, reductionKind);
278 LogicalResult partition(Operation *op, ArrayRef<Value> partitionedOperands,
279 ArrayRef<Sharding> operandShardings,
280 ArrayRef<Sharding> resultShardings,
281 IRMapping &partitionMap,
282 SymbolTableCollection &symbolTable,
283 OpBuilder &builder)
const {
284 LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
286 SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
287 bool allIndexingMapsAreProjectedPermutation =
288 llvm::all_of(indexingMaps, [](AffineMap map) {
289 return map.isProjectedPermutation();
291 if (!allIndexingMapsAreProjectedPermutation) {
293 return op->emitOpError()
294 <<
"supports indexing maps that are only projected permutation.";
297 SmallVector<utils::IteratorType> loopIteratorTypes =
298 linalgOp.getIteratorTypesArray();
301 loopIteratorTypes, indexingMaps);
303 loopIteratorTypes, gridAxisAssignmentForLoopIterators)) {
304 ImplicitLocOpBuilder implicitLocBuilder(op->getLoc(), builder);
306 linalgOp, partitionedOperands, operandShardings, resultShardings,
307 loopIteratorTypes, gridAxisAssignmentForLoopIterators, partitionMap,
308 symbolTable, implicitLocBuilder);
311 operandShardings, resultShardings,
312 partitionMap, symbolTable, builder);
321 template <
typename OpType>
323 OpType::template attachInterface<StructuredOpShardingInterface<OpType>>(*ctx);
327 template <
typename... OpTypes>
329 (registerOne<OpTypes>(ctx), ...);
335 registry.
insert<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect,
336 tensor::TensorDialect>();
341 registerOne<linalg::GenericOp>(ctx);
344 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
auto getDialectNames() const
Return the names of dialects known to this registry.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
Location getLoc() const
Accessors for the implied location.
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation is the basic unit of execution within MLIR.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated with the given op.
static ReductionKind getReductionKind(Operation *op)
static std::optional< Operation * > getCombinerOp(LinalgOp op)
static Value createDestinationPassingStyleInitOperand(LinalgOp op, int operandNumber, Value partitionedOperand, ArrayRef< GridAxis > reductionGridAxes, GridOp gridOp, ImplicitLocOpBuilder &builder)
static void createAllReduceForResultsWithoutPartialShardings(LinalgOp unshardedOp, ArrayRef< GridAxis > opReductionGridAxes, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, ImplicitLocOpBuilder &builder)
shard::ShardingArray ShardingArray
shard::ReductionKind ReductionKind
static void registerOne(MLIRContext *ctx)
static void registerAll(MLIRContext *ctx)
Variadic helper function.
static SmallVector< Value > createDestinationPassingStyleInitOperands(LinalgOp op, GridOp gridOp, ArrayRef< Value > partitionedOperands, ArrayRef< GridAxis > reductionGridAxes, IRMapping &partitionMap, ImplicitLocOpBuilder &builder)
static void partitionLinalgOpWithShardedReduction(LinalgOp op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis >> gridAxisAssignmentForLoopIterators, IRMapping &partitionMap, SymbolTableCollection &symbolTable, ImplicitLocOpBuilder &builder)
void registerShardingInterfaceExternalModels(DialectRegistry &registry)
static ReductionKind getReductionKindOfLinalgOp(LinalgOp op)
static GridOp getGrid(Operation *op, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, SymbolTableCollection &symbolTable)
SmallVector< GridAxis > getReductionGridAxes(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis >> gridAxisAssignmentForLoopIterators)
bool isAtLeastOneReductionIteratorSharded(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis >> gridAxisAssignmentForLoopIterators)
SmallVector< SmallVector< GridAxis > > ShardingArray
TypedValue< IndexType > createProcessLinearIndex(StringRef grid, ArrayRef< GridAxis > gridAxes, ImplicitLocOpBuilder &builder)
void partitionTriviallyShardableOperation(Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
ShardingArray getGridAxisAssignmentForLoopIterators(ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
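Utility to match a generic reduction given a list of iteration-carried arguments, `iterCarriedArgs`, and the position of the potential reduction argument within the list, `redPos`. If a reduction is matched, returns the reduced value and fills `combinerOps` with the operations in the reduction chain; otherwise, returns a null value.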