MLIR  22.0.0git
ShardingInterface.h
Go to the documentation of this file.
1 //===- ShardingInterface.h --------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACE_H_
10 #define MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACE_H_
11 
14 #include "mlir/IR/Value.h"
15 #include "mlir/Support/LLVM.h"
16 
17 namespace mlir {
18 
19 class Operation;
20 class IRMapping;
21 class SymbolTableCollection;
22 
23 namespace shard {
24 
27 
28 struct ShardingOption {
29  // An array of integer arrays. The sub-array at the i-th position lists the
30  // grid axes the i-th loop will be sharded on.
31  ShardingArray shardingArray = {};
32  FlatSymbolRefAttr grid = nullptr;
33  // `empty` being true indicates that no sharding information can be inferred
34  // at present. Note that it is different from the case where an operation is
35  // not sharded.
36  bool empty = false;
37  ShardingOption() = default;
38  ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr grid)
39  : shardingArray(std::move(shardingArray)), grid(grid) {
40  assert(this->grid); // A non-default option must name a grid.
41  }
42  static ShardingOption makeEmpty() {
43  auto res = ShardingOption();
44  res.empty = true;
45  return res;
46  }
47 };
48 
49 // Retrieves the `Sharding` of the given operation result together with its
50 // 'annotate_for_users' information (presumably the `bool`); may fail.
51 FailureOr<std::pair<bool, Sharding>> getSharding(OpResult result);
52 
53 // Retrieves the `Sharding` of the given operation operand together with its
54 // 'annotate_for_users' information (presumably the `bool`); may fail.
55 FailureOr<std::pair<bool, Sharding>> getSharding(OpOperand &opOperand);
56 
57 namespace detail { // default* helpers; presumably interface fallbacks.
58 
59 FailureOr<ShardingOption>
60 defaultGetShardingOption(Operation *op, ArrayRef<Sharding> operandShardings,
61  ArrayRef<Sharding> resultShardings);
62 
63 FailureOr<std::vector<Sharding>>
64 defaultGetShardingAnnotations(Operation *op,
65  const ShardingOption &shardingOption);
66 
67 LogicalResult
68 defaultAddShardingAnnotations(Operation *op, OpBuilder &b,
69  const ShardingOption &shardingOption);
70 
71 } // namespace detail
72 
73 // Assumes full replication on all ranked tensor arguments and results.
74 void partitionFullyReplicatedOperation(Operation &op,
75  ArrayRef<Value> partitionedOperands,
76  ArrayRef<Sharding> operandShardings,
77  ArrayRef<Sharding> resultShardings,
78  IRMapping &partitionMap,
79  SymbolTableCollection &symbolTable,
80  OpBuilder &builder);
81 
82 } // namespace shard
83 } // namespace mlir
84 
85 /// Include the ODS generated interface header files.
86 #include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h.inc"
87 
88 #endif // MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACE_H_
FlatSymbolRefAttr: A symbol reference with a reference path containing a single element.
IRMapping: This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
OpBuilder: This class helps build Operations.
Definition: Builders.h:205
OpOperand: This class represents an operand of an operation.
Definition: Value.h:257
OpResult: This is a value defined by a result of an operation.
Definition: Value.h:447
Operation: Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
SymbolTableCollection: This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
LogicalResult defaultAddShardingAnnotations(Operation *op, OpBuilder &b, const ShardingOption &shardingOption)
FailureOr< std::vector< Sharding > > defaultGetShardingAnnotations(Operation *op, const ShardingOption &shardingOption)
FailureOr< ShardingOption > defaultGetShardingOption(Operation *op, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings)
void partitionFullyReplicatedOperation(Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
FailureOr< std::pair< bool, Sharding > > getSharding(OpResult result)
namespace mlir: Include the generated interface declarations.
ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr grid)
static ShardingOption makeEmpty()