33 #include "llvm/ADT/ArrayRef.h"
34 #include "llvm/ADT/STLExtras.h"
35 #include "llvm/ADT/SmallVector.h"
36 #include "llvm/ADT/TypeSwitch.h"
54 .Case([](arith::AddFOp op) {
return ReductionKind::Sum; })
55 .Case([](arith::MulFOp op) {
return ReductionKind::Product; })
57 .Case([](arith::MaximumFOp op) {
return ReductionKind::Max; })
58 .Case([](arith::MinimumFOp op) {
return ReductionKind::Min; })
60 .Case([](arith::AddIOp op) {
return ReductionKind::Sum; })
61 .Case([](arith::OrIOp op) {
return ReductionKind::BitwiseOr; })
62 .Case([](arith::XOrIOp op) {
return ReductionKind::BitwiseXor; })
63 .Case([](arith::AndIOp op) {
// arith.andi combines values with bitwise AND, so it must map to the
// bitwise-and reduction (parallel to the OrIOp/XOrIOp cases), not to Sum.
return ReductionKind::BitwiseAnd; })
73 .Case([](arith::MaxUIOp op) {
return ReductionKind::Max; })
74 .Case([](arith::MinUIOp op) {
return ReductionKind::Min; })
75 .Case([](arith::MaxSIOp op) {
return ReductionKind::Max; })
76 .Case([](arith::MinSIOp op) {
return ReductionKind::Min; })
77 .Case([](arith::MulIOp op) {
return ReductionKind::Product; })
78 .Default([](
Operation *op) {
return ReductionKind::Generic; });
84 if (!reducedValue || combinerOps.size() != 1) {
88 return combinerOps[0];
94 return ReductionKind::Generic;
96 [[maybe_unused]]
Type resultElementType =
97 llvm::cast<RankedTensorType>(op->getResult(0).getType()).getElementType();
101 assert(resultElementType == reductionOp.value()->getResult(0).getType());
136 meshOp.getSymName(), reductionMeshAxes, builder);
138 Value isLeadProcess = builder.
create<arith::CmpIOp>(
139 builder.
getI1Type(), arith::CmpIPredicate::eq,
140 processLinearIndexInReductionGroup, zero);
141 scf::IfOp ifOp = builder.
create<scf::IfOp>(spmdizedOperand.
getType(),
142 isLeadProcess,
true,
true);
147 builder.
create<scf::YieldOp>(spmdizedOperand);
156 PartialReductionOpInterface partialReductionIface =
157 llvm::cast<PartialReductionOpInterface>(op.getOperation());
158 assert(op->getNumResults() == 1 &&
"Multiple results not supported.");
159 FailureOr<SmallVector<Value>> reductionNeutralTensor =
160 partialReductionIface.generateInitialTensorForPartialReduction(
161 builder, builder.
getLoc(), shape, {});
162 assert(succeeded(reductionNeutralTensor));
163 builder.
create<scf::YieldOp>(reductionNeutralTensor.value());
165 return ifOp.getResult(0);
176 assert(op.getNumDpsInits() == 1 &&
"Multiple initial values not supported.");
178 auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
179 Value spmdizedInitOperand =
180 spmdizationMap.
lookup(op->getOperands()[operandIdx]);
182 op, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
191 llvm::copy_if(opReductionMeshAxes, std::back_inserter(allReduceMeshAxes),
193 return !llvm::is_contained(resultSharding.getPartialAxes(),
196 if (allReduceMeshAxes.empty()) {
200 Value spmdizedLinalgOpResult = spmdizationMap.
lookup(unshardedLinalgOpResult);
201 Value reducedValue = builder.
create<mesh::AllReduceOp>(
202 spmdizedLinalgOpResult, resultSharding.getMesh().getValue(),
203 allReduceMeshAxes, reductionKind);
204 spmdizationMap.
map(unshardedLinalgOpResult, reducedValue);
212 for (
auto [unshardedLinalgOpResult, resultSharding] :
213 llvm::zip_equal(unshardedOp->getResults(), resultShardings)) {
215 unshardedLinalgOpResult, opReductionMeshAxes, resultSharding,
216 reductionKind, spmdizationMap, builder);
228 MeshOp mesh =
getMesh(op, operandShardings, resultShardings, symbolTable);
230 loopIteratorTypes, meshAxisAssignmentForLoopIterators);
234 spmdizationMap, builder);
239 for (
auto [unshardedOperand, spmdizedOperand] :
240 llvm::zip_equal(op->getOperands(), spmdizedLinalgOpOperands)) {
241 internalSpmdizationMap.
map(unshardedOperand, spmdizedOperand);
244 *op, spmdizedLinalgOpOperands, operandShardings, resultShardings,
245 internalSpmdizationMap, symbolTable, builder);
246 for (
Value result : op->getResults()) {
247 spmdizationMap.
map(result, internalSpmdizationMap.
lookup(result));
252 op, reductionMeshAxes, resultShardings, spmdizationMap, builder);
260 template <
typename Op>
261 struct StructuredOpShardingInterface
262 :
public mesh::ShardingInterface::ExternalModel<
263 StructuredOpShardingInterface<Op>, Op> {
265 return llvm::cast<LinalgOp>(op).getIteratorTypesArray();
268 SmallVector<AffineMap> getIndexingMaps(Operation *op)
const {
269 LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
270 SmallVector<AffineMap> res = linalgOp.getIndexingMapsArray();
274 for (int64_t i = 0; i < linalgOp.getNumDpsInits(); ++i) {
275 res.push_back(res[linalgOp.getDpsInitOperand(i)->getOperandNumber()]);
281 SmallVector<ReductionKind>
282 getReductionLoopIteratorKinds(Operation *op)
const {
283 LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
284 SmallVector<utils::IteratorType> iteratorTypes =
285 linalgOp.getIteratorTypesArray();
286 unsigned reductionItersCount = std::accumulate(
287 iteratorTypes.begin(), iteratorTypes.end(), 0,
288 [](
unsigned count, utils::IteratorType iter) {
289 return count + (iter == utils::IteratorType::reduction);
292 return SmallVector<ReductionKind>(reductionItersCount, reductionKind);
295 LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
296 ArrayRef<MeshShardingAttr> operandShardings,
297 ArrayRef<MeshShardingAttr> resultShardings,
298 IRMapping &spmdizationMap,
299 SymbolTableCollection &symbolTable,
300 OpBuilder &builder)
const {
301 LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
303 SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
304 bool allIndexingMapsAreProjectedPermutation =
305 llvm::all_of(indexingMaps, [](AffineMap map) {
306 return map.isProjectedPermutation();
308 if (!allIndexingMapsAreProjectedPermutation) {
310 return op->emitOpError()
311 <<
"supports indexing maps that are only projected permutation.";
314 SmallVector<utils::IteratorType> loopIteratorTypes =
315 linalgOp.getIteratorTypesArray();
318 loopIteratorTypes, indexingMaps);
320 loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
321 ImplicitLocOpBuilder implicitLocBuilder(op->getLoc(), builder);
323 linalgOp, spmdizedOperands, operandShardings, resultShardings,
324 loopIteratorTypes, meshAxisAssignmentForLoopIterators, spmdizationMap,
325 symbolTable, implicitLocBuilder);
328 operandShardings, resultShardings,
329 spmdizationMap, symbolTable, builder);
338 template <
typename OpType>
340 OpType::template attachInterface<StructuredOpShardingInterface<OpType>>(*ctx);
344 template <
typename... OpTypes>
346 (registerOne<OpTypes>(ctx), ...);
352 registry.
insert<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect,
353 tensor::TensorDialect>();
358 registerOne<linalg::GenericOp>(ctx);
361 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
auto getDialectNames() const
Return the names of dialects known to this registry.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Location getLoc() const
Accessors for the implied location.
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation is the basic unit of execution within MLIR.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Specialization of arith.constant op that returns an integer of index type.
static MeshOp getMesh(Operation *op, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, SymbolTableCollection &symbolTable)
mesh::ReductionKind ReductionKind
static void createAllReduceForResultsWithoutPartialShardings(LinalgOp unshardedOp, ArrayRef< MeshAxis > opReductionMeshAxes, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder)
static ReductionKind getReductionKind(Operation *op)
static std::optional< Operation * > getCombinerOp(LinalgOp op)
static void registerOne(MLIRContext *ctx)
static void registerAll(MLIRContext *ctx)
Variadic helper function.
mesh::MeshShardingAttr MeshShardingAttr
static void spmdizeLinalgOpWithShardedReduction(LinalgOp op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, ImplicitLocOpBuilder &builder)
static Value createDestinationPassingStyleInitOperand(LinalgOp op, Value spmdizedOperand, ArrayRef< MeshAxis > reductionMeshAxes, MeshOp meshOp, ImplicitLocOpBuilder &builder)
static SmallVector< Value > createDestinationPassingStyleInitOperands(LinalgOp op, MeshOp meshOp, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshAxis > reductionMeshAxes, IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder)
mesh::ShardingArray ShardingArray
static ReductionKind getReductionKindOfLinalgOp(LinalgOp op)
static void createAllReduceForResultWithoutPartialSharding(Value unshardedLinalgOpResult, ArrayRef< MeshAxis > opReductionMeshAxes, MeshShardingAttr resultSharding, ReductionKind reductionKind, IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder)
void registerMeshShardingInterfaceExternalModels(DialectRegistry &registry)
bool isAtLeastOneReductionIteratorSharded(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
SmallVector< SmallVector< MeshAxis > > ShardingArray
SmallVector< MeshAxis > getReductionMeshAxes(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
TypedValue< IndexType > createProcessLinearIndex(StringRef mesh, ArrayRef< MeshAxis > meshAxes, ImplicitLocOpBuilder &builder)
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
ShardingArray getMeshAxisAssignmentForLoopIterators(ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
void spmdizeTriviallyShardableOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...