MLIR 22.0.0git
ShardingInterfaceImpl.cpp
Go to the documentation of this file.
1//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
14#include "mlir/IR/AffineMap.h"
16
17#define DEBUG_TYPE "tosa-sharding-impl"
18#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
19
20using namespace mlir;
21using namespace mlir::tosa;
22using namespace mlir::shard;
23
24namespace {
25
26// loop types: [parallel, parallel, parallel, reduction_sum]
27// indexing maps:
28// (d0, d1, d2, d3) -> (d0, d1, d3)
29// (d0, d1, d2, d3) -> (d0, d3, d2)
30// (d0, d1, d2, d3) -> (d0, d1, d2)
31struct MatMulOpSharding
32 : public ShardingInterface::ExternalModel<MatMulOpSharding, MatMulOp> {
33 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
34 auto tensorType = dyn_cast<RankedTensorType>(op->getResult(0).getType());
35 if (!tensorType)
36 return {};
37
38 SmallVector<utils::IteratorType> types(tensorType.getRank() + 1,
39 utils::IteratorType::parallel);
40 types[tensorType.getRank()] = utils::IteratorType::reduction;
41 return types;
42 }
43
45 getReductionLoopIteratorKinds(Operation *op) const {
46 return SmallVector<ReductionKind>(1, ReductionKind::Sum);
47 }
48
49 SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
50 auto tensorType = dyn_cast<RankedTensorType>(op->getResult(0).getType());
51 if (!tensorType)
52 return {};
53 MLIRContext *ctx = op->getContext();
55 maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 3}, ctx));
56 maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 3, 2}, ctx));
57 maps.push_back(AffineMap::get(0, 0, {}, ctx));
58 maps.push_back(AffineMap::get(0, 0, {}, ctx));
59 maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 2}, ctx));
60 return maps;
61 }
62};
63
64struct NegateOpSharding
65 : public ShardingInterface::ExternalModel<NegateOpSharding, NegateOp> {
66 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
67 Value val = op->getOperand(0);
68 auto type = dyn_cast<RankedTensorType>(val.getType());
69 if (!type)
70 return {};
71 SmallVector<utils::IteratorType> types(type.getRank(),
72 utils::IteratorType::parallel);
73 return types;
74 }
75
76 SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
77 MLIRContext *ctx = op->getContext();
78 Value val = op->getOperand(0);
79 auto type = dyn_cast<RankedTensorType>(val.getType());
80 if (!type)
81 return {};
82 int64_t rank = type.getRank();
85 AffineMap::get(0, 0, {}, ctx), AffineMap::get(0, 0, {}, ctx),
87 return maps;
88 }
89
90 LogicalResult partition(Operation *op, ArrayRef<Value> partitiondOperands,
91 ArrayRef<Sharding> operandShardings,
92 ArrayRef<Sharding> resultShardings,
93 IRMapping &partitionMap,
94 SymbolTableCollection &symbolTable,
95 OpBuilder &builder) const {
96 partitionTriviallyShardableOperation(*op, partitiondOperands,
97 operandShardings, resultShardings,
98 partitionMap, symbolTable, builder);
99 return success();
100 }
101};
102
103template <typename OpType>
104static void registerElemwiseOne(MLIRContext *ctx) {
105 OpType::template attachInterface<ElementwiseShardingInterface<OpType>>(*ctx);
106}
107
108/// Variadic helper function.
109template <typename... OpTypes>
110static void registerElemwiseAll(MLIRContext *ctx) {
111 (registerElemwiseOne<OpTypes>(ctx), ...);
112}
113
114} // namespace
115
117 DialectRegistry &registry) {
118
119 registry.addExtension(+[](MLIRContext *ctx, TosaDialect *dialect) {
120 registerElemwiseAll<
121 ClampOp, SigmoidOp, TanhOp, AddOp, ArithmeticRightShiftOp, BitwiseAndOp,
122 BitwiseOrOp, BitwiseXorOp, IntDivOp, LogicalAndOp, LogicalLeftShiftOp,
123 LogicalRightShiftOp, LogicalOrOp, LogicalXorOp, MaximumOp, MinimumOp,
124 MulOp, PowOp, SubOp, AbsOp, BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp,
125 LogOp, LogicalNotOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp,
126 GreaterOp, GreaterEqualOp>(ctx);
127
128 MatMulOp::attachInterface<MatMulOpSharding>(*ctx);
129 NegateOp::attachInterface<NegateOpSharding>(*ctx);
130 });
131}
return success()
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static AffineMap getMultiDimMapWithTargets(unsigned numDims, ArrayRef< unsigned > targets, MLIRContext *context)
Returns an affine map with numDims input dimensions and results specified by targets.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:207
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:350
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
This class represents a collection of SymbolTables.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
void partitionTriviallyShardableOperation(Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
void registerShardingInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.