MLIR  21.0.0git
ShardingInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- ShardingInterfaceImpl.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
14 #include "mlir/IR/AffineMap.h"
16 #include "llvm/Support/Debug.h"
17 
18 #define DEBUG_TYPE "tosa-sharding-impl"
19 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
20 
21 using namespace mlir;
22 using namespace mlir::tosa;
23 using namespace mlir::mesh;
24 
25 namespace {
26 
27 // loop types: [parallel, parallel, parallel, reduction_sum]
28 // indexing maps:
29 // (d0, d1, d2, d3) -> (d0, d1, d3)
30 // (d0, d1, d2, d3) -> (d0, d3, d2)
31 // (d0, d1, d2, d3) -> (d0, d1, d2)
32 struct MatMulOpSharding
33  : public ShardingInterface::ExternalModel<MatMulOpSharding, MatMulOp> {
34  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
35  auto tensorType = dyn_cast<RankedTensorType>(op->getResult(0).getType());
36  if (!tensorType)
37  return {};
38 
39  SmallVector<utils::IteratorType> types(tensorType.getRank() + 1,
40  utils::IteratorType::parallel);
41  types[tensorType.getRank()] = utils::IteratorType::reduction;
42  return types;
43  }
44 
46  getReductionLoopIteratorKinds(Operation *op) const {
47  return SmallVector<ReductionKind>(1, ReductionKind::Sum);
48  }
49 
50  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
51  auto tensorType = dyn_cast<RankedTensorType>(op->getResult(0).getType());
52  if (!tensorType)
53  return {};
54  MLIRContext *ctx = op->getContext();
56  maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 3}, ctx));
57  maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 3, 2}, ctx));
58  maps.push_back(AffineMap::get(0, 0, {}, ctx));
59  maps.push_back(AffineMap::get(0, 0, {}, ctx));
60  maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 2}, ctx));
61  return maps;
62  }
63 };
64 
65 struct NegateOpSharding
66  : public ShardingInterface::ExternalModel<NegateOpSharding, NegateOp> {
67  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
68  Value val = op->getOperand(0);
69  auto type = dyn_cast<RankedTensorType>(val.getType());
70  if (!type)
71  return {};
72  SmallVector<utils::IteratorType> types(type.getRank(),
73  utils::IteratorType::parallel);
74  return types;
75  }
76 
77  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
78  MLIRContext *ctx = op->getContext();
79  Value val = op->getOperand(0);
80  auto type = dyn_cast<RankedTensorType>(val.getType());
81  if (!type)
82  return {};
83  int64_t rank = type.getRank();
84  SmallVector<AffineMap> maps = {
86  AffineMap::get(0, 0, {}, ctx), AffineMap::get(0, 0, {}, ctx),
88  return maps;
89  }
90 
91  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
92  ArrayRef<MeshSharding> operandShardings,
93  ArrayRef<MeshSharding> resultShardings,
94  IRMapping &spmdizationMap,
95  SymbolTableCollection &symbolTable,
96  OpBuilder &builder) const {
97  spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
98  resultShardings, spmdizationMap,
99  symbolTable, builder);
100  return success();
101  }
102 };
103 
104 template <typename OpType>
105 static void registerElemwiseOne(MLIRContext *ctx) {
106  OpType::template attachInterface<ElementwiseShardingInterface<OpType>>(*ctx);
107 }
108 
109 /// Variadic helper function.
110 template <typename... OpTypes>
111 static void registerElemwiseAll(MLIRContext *ctx) {
112  (registerElemwiseOne<OpTypes>(ctx), ...);
113 }
114 
115 } // namespace
116 
118  DialectRegistry &registry) {
119 
120  registry.addExtension(+[](MLIRContext *ctx, TosaDialect *dialect) {
121  registerElemwiseAll<
122  ClampOp, SigmoidOp, TanhOp, AddOp, ArithmeticRightShiftOp, BitwiseAndOp,
123  BitwiseOrOp, BitwiseXorOp, IntDivOp, LogicalAndOp, LogicalLeftShiftOp,
124  LogicalRightShiftOp, LogicalOrOp, LogicalXorOp, MaximumOp, MinimumOp,
125  MulOp, PowOp, SubOp, AbsOp, BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp,
126  LogOp, LogicalNotOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp,
127  GreaterOp, GreaterEqualOp>(ctx);
128 
129  MatMulOp::attachInterface<MatMulOpSharding>(*ctx);
130  NegateOp::attachInterface<NegateOpSharding>(*ctx);
131  });
132 }
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:334
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static AffineMap getMultiDimMapWithTargets(unsigned numDims, ArrayRef< unsigned > targets, MLIRContext *context)
Returns an affine map with numDims input dimensions and results specified by targets.
Definition: AffineMap.cpp:280
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:204
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
void spmdizeTriviallyShardableOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
void registerShardingInterfaceExternalModels(DialectRegistry &registry)
Registers the TOSA sharding-interface external models (elementwise ops, MatMulOp, NegateOp) with the given DialectRegistry.