MLIR  20.0.0git
ShardingInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- ShardingInterfaceImpl.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
14 #include "mlir/IR/AffineMap.h"
16 #include "llvm/Support/Debug.h"
17 
18 #define DEBUG_TYPE "tosa-sharding-impl"
19 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
20 
21 using namespace mlir;
22 using namespace mlir::tosa;
23 using namespace mlir::mesh;
24 
25 namespace {
26 
27 // loop types: [parallel, parallel, parallel, reduction_sum]
28 // indexing maps:
29 // (d0, d1, d2, d3) -> (d0, d1, d3)
30 // (d0, d1, d2, d3) -> (d0, d3, d2)
31 // (d0, d1, d2, d3) -> (d0, d1, d2)
32 struct MatMulOpSharding
33  : public ShardingInterface::ExternalModel<MatMulOpSharding, MatMulOp> {
34  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
35  auto tensorType = dyn_cast<RankedTensorType>(op->getResult(0).getType());
36  if (!tensorType)
37  return {};
38 
39  SmallVector<utils::IteratorType> types(tensorType.getRank() + 1,
40  utils::IteratorType::parallel);
41  types[tensorType.getRank()] = utils::IteratorType::reduction;
42  return types;
43  }
44 
46  getReductionLoopIteratorKinds(Operation *op) const {
47  return SmallVector<ReductionKind>(1, ReductionKind::Sum);
48  }
49 
50  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
51  auto tensorType = dyn_cast<RankedTensorType>(op->getResult(0).getType());
52  if (!tensorType)
53  return {};
54  MLIRContext *ctx = op->getContext();
56  maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 3}, ctx));
57  maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 3, 2}, ctx));
58  maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 2}, ctx));
59  return maps;
60  }
61 };
62 
63 template <typename OpType>
64 static void registerElemwiseOne(MLIRContext *ctx) {
65  OpType::template attachInterface<ElementwiseShardingInterface<OpType>>(*ctx);
66 }
67 
68 /// Variadic helper function.
69 template <typename... OpTypes>
70 static void registerElemwiseAll(MLIRContext *ctx) {
71  (registerElemwiseOne<OpTypes>(ctx), ...);
72 }
73 
74 } // namespace
75 
77  DialectRegistry &registry) {
78 
79  registry.addExtension(+[](MLIRContext *ctx, TosaDialect *dialect) {
80  registerElemwiseAll<
81  ClampOp, SigmoidOp, TanhOp, AddOp, ArithmeticRightShiftOp, BitwiseAndOp,
82  BitwiseOrOp, BitwiseXorOp, IntDivOp, LogicalAndOp, LogicalLeftShiftOp,
83  LogicalRightShiftOp, LogicalOrOp, LogicalXorOp, MaximumOp, MinimumOp,
84  MulOp, PowOp, SubOp, AbsOp, BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp,
85  LogOp, LogicalNotOp, NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp,
86  GreaterOp, GreaterEqualOp>(ctx);
87 
88  MatMulOp::attachInterface<MatMulOpSharding>(*ctx);
89  });
90 }
static AffineMap getMultiDimMapWithTargets(unsigned numDims, ArrayRef< unsigned > targets, MLIRContext *context)
Returns an affine map with numDims input dimensions and results specified by targets.
Definition: AffineMap.cpp:280
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Type getType() const
Return the type of this value.
Definition: Value.h:129
void registerShardingInterfaceExternalModels(DialectRegistry &registry)
Registers the TOSA sharding interface external models (the listed elementwise ops and MatMulOp) with the given registry.