MLIR  20.0.0git
ShardingInterface.h
Go to the documentation of this file.
1 //===- ShardingInterface.h --------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
10 #define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
11 
14 #include "mlir/IR/Value.h"
15 #include "mlir/Support/LLVM.h"
16 
17 namespace mlir {
18 
19 class Operation;
20 class IRMapping;
21 class SymbolTableCollection;
22 
23 namespace mesh {
24 
27 
29  // An array of integer arrays. The sub-array at the i-th position lists the
30  // mesh axes that the i-th loop is sharded on.
33  // When `empty` is true, no sharding information can be inferred at
34  // present. Note that this differs from the case where an operation is
35  // not sharded.
36  bool empty = false;
37  ShardingOption() = default;
39  : shardingArray(std::move(shardingArray)), mesh(mesh) {}
41  auto res = ShardingOption();
42  res.empty = true;
43  return res;
44  }
45 };
46 
47 // Retrieves the 'MeshSharding' of the given operation result. The `bool` in
48 // the returned pair carries the 'annotate_for_users' information.
49 FailureOr<std::pair<bool, MeshSharding>> getMeshSharding(OpResult result);
50 
51 // Retrieves the 'MeshSharding' of the given operation operand. The `bool` in
52 // the returned pair carries the 'annotate_for_users' information.
53 FailureOr<std::pair<bool, MeshSharding>> getMeshSharding(OpOperand &opOperand);
54 
55 namespace detail {
56 
57 FailureOr<ShardingOption>
59  ArrayRef<MeshSharding> resultShardings);
60 
61 FailureOr<std::vector<MeshSharding>>
63  const ShardingOption &shardingOption);
64 
65 LogicalResult
67  const ShardingOption &shardingOption);
68 
69 } // namespace detail
70 
71 // Spmdizes `op`, assuming full replication on all ranked tensor arguments and results.
73  ArrayRef<Value> spmdizedOperands,
74  ArrayRef<MeshSharding> operandShardings,
75  ArrayRef<MeshSharding> resultShardings,
76  IRMapping &spmdizationMap,
77  SymbolTableCollection &symbolTable,
78  OpBuilder &builder);
79 
80 } // namespace mesh
81 } // namespace mlir
82 
83 /// Include the ODS generated interface header files.
84 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h.inc"
85 
86 #endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
This class helps build Operations.
Definition: Builders.h:212
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
FailureOr< std::vector< MeshSharding > > defaultGetShardingAnnotations(Operation *op, const ShardingOption &shardingOption)
LogicalResult defaultAddShardingAnnotations(Operation *op, OpBuilder &b, const ShardingOption &shardingOption)
FailureOr< ShardingOption > defaultGetShardingOption(Operation *op, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings)
void spmdizeFullyReplicatedOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
FailureOr< std::pair< bool, MeshSharding > > getMeshSharding(OpResult result)
Include the generated interface declarations.
static ShardingOption makeEmpty()
ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr mesh)