MLIR  22.0.0git
ShardingInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- ShardingInterfaceImpl.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
14 #include "mlir/IR/AffineMap.h"
16 
17 #define DEBUG_TYPE "tosa-sharding-impl"
18 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
19 
20 using namespace mlir;
21 using namespace mlir::tosa;
22 using namespace mlir::shard;
23 
24 namespace {
25 
26 // loop types: [parallel, parallel, parallel, reduction_sum]
27 // indexing maps:
28 // (d0, d1, d2, d3) -> (d0, d1, d3)
29 // (d0, d1, d2, d3) -> (d0, d3, d2)
30 // (d0, d1, d2, d3) -> (d0, d1, d2)
31 struct MatMulOpSharding
32  : public ShardingInterface::ExternalModel<MatMulOpSharding, MatMulOp> {
33  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
34  auto tensorType = dyn_cast<RankedTensorType>(op->getResult(0).getType());
35  if (!tensorType)
36  return {};
37 
38  SmallVector<utils::IteratorType> types(tensorType.getRank() + 1,
39  utils::IteratorType::parallel);
40  types[tensorType.getRank()] = utils::IteratorType::reduction;
41  return types;
42  }
43 
45  getReductionLoopIteratorKinds(Operation *op) const {
46  return SmallVector<ReductionKind>(1, ReductionKind::Sum);
47  }
48 
49  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
50  auto tensorType = dyn_cast<RankedTensorType>(op->getResult(0).getType());
51  if (!tensorType)
52  return {};
53  MLIRContext *ctx = op->getContext();
55  maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 3}, ctx));
56  maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 3, 2}, ctx));
57  maps.push_back(AffineMap::get(0, 0, {}, ctx));
58  maps.push_back(AffineMap::get(0, 0, {}, ctx));
59  maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 2}, ctx));
60  return maps;
61  }
62 };
63 
64 struct NegateOpSharding
65  : public ShardingInterface::ExternalModel<NegateOpSharding, NegateOp> {
66  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
67  Value val = op->getOperand(0);
68  auto type = dyn_cast<RankedTensorType>(val.getType());
69  if (!type)
70  return {};
71  SmallVector<utils::IteratorType> types(type.getRank(),
72  utils::IteratorType::parallel);
73  return types;
74  }
75 
76  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
77  MLIRContext *ctx = op->getContext();
78  Value val = op->getOperand(0);
79  auto type = dyn_cast<RankedTensorType>(val.getType());
80  if (!type)
81  return {};
82  int64_t rank = type.getRank();
83  SmallVector<AffineMap> maps = {
85  AffineMap::get(0, 0, {}, ctx), AffineMap::get(0, 0, {}, ctx),
87  return maps;
88  }
89 
90  LogicalResult partition(Operation *op, ArrayRef<Value> partitiondOperands,
91  ArrayRef<Sharding> operandShardings,
92  ArrayRef<Sharding> resultShardings,
93  IRMapping &partitionMap,
94  SymbolTableCollection &symbolTable,
95  OpBuilder &builder) const {
96  partitionTriviallyShardableOperation(*op, partitiondOperands,
97  operandShardings, resultShardings,
98  partitionMap, symbolTable, builder);
99  return success();
100  }
101 };
102 
103 template <typename OpType>
104 static void registerElemwiseOne(MLIRContext *ctx) {
105  OpType::template attachInterface<ElementwiseShardingInterface<OpType>>(*ctx);
106 }
107 
108 /// Variadic helper function.
109 template <typename... OpTypes>
110 static void registerElemwiseAll(MLIRContext *ctx) {
111  (registerElemwiseOne<OpTypes>(ctx), ...);
112 }
113 
114 } // namespace
115 
117  DialectRegistry &registry) {
118 
119  registry.addExtension(+[](MLIRContext *ctx, TosaDialect *dialect) {
120  registerElemwiseAll<
121  ClampOp, SigmoidOp, TanhOp, AddOp, ArithmeticRightShiftOp, BitwiseAndOp,
122  BitwiseOrOp, BitwiseXorOp, IntDivOp, LogicalAndOp, LogicalLeftShiftOp,
123  LogicalRightShiftOp, LogicalOrOp, LogicalXorOp, MaximumOp, MinimumOp,
124  MulOp, PowOp, SubOp, AbsOp, BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp,
125  LogOp, LogicalNotOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp,
126  GreaterOp, GreaterEqualOp>(ctx);
127 
128  MatMulOp::attachInterface<MatMulOpSharding>(*ctx);
129  NegateOp::attachInterface<NegateOpSharding>(*ctx);
130  });
131 }
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:330
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static AffineMap getMultiDimMapWithTargets(unsigned numDims, ArrayRef< unsigned > targets, MLIRContext *context)
Returns an affine map with numDims input dimensions and results specified by targets.
Definition: AffineMap.cpp:276
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class helps build Operations.
Definition: Builders.h:207
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
void partitionTriviallyShardableOperation(Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
void registerShardingInterfaceExternalModels(DialectRegistry &registry)
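Registers the sharding-interface external models for TOSA operations (the elementwise ops, MatMulOp and NegateOp) with the given DialectRegistry.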