MLIR  19.0.0git
ShardingInterface.h
Go to the documentation of this file.
1 //===- ShardingInterface.h --------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
10 #define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
11 
14 #include "mlir/IR/Value.h"
15 #include "mlir/Support/LLVM.h"
16 
17 namespace mlir {
18 
19 class Operation;
20 class IRMapping;
21 class SymbolTableCollection;
22 
23 namespace mesh {
24 
27 
29  // An array of int arrays. The sub-array at the i-th position lists the
30  // mesh axes that the i-th loop will be sharded on.
33  // When `empty` is true, no sharding information can be inferred at
34  // present. Note that this differs from the case where an operation is
35  // not sharded.
36  bool empty = false;
37  ShardingOption() = default;
39  : shardingArray(std::move(shardingArray)), mesh(mesh) {}
41  auto res = ShardingOption();
42  res.empty = true;
43  return res;
44  }
45 };
46 
47 // This function retrieves the 'MeshShardingAttr' attribute from a given
48 // operation result and returns it together with the 'annotate_for_users' flag.
51 
52 // This function retrieves the 'MeshShardingAttr' attribute from a given
53 // operation operand and returns it together with the 'annotate_for_users' flag.
55 getMeshShardingAttr(OpOperand &opOperand);
56 
57 namespace detail {
58 
61  ArrayRef<MeshShardingAttr> operandShardings,
62  ArrayRef<MeshShardingAttr> resultShardings);
63 
66  const ShardingOption &shardingOption);
67 
70  const ShardingOption &shardingOption);
71 
72 } // namespace detail
73 
74 // Assumes full replication on all ranked tensor arguments and results.
76  Operation &op, ArrayRef<Value> spmdizedOperands,
77  ArrayRef<MeshShardingAttr> operandShardings,
78  ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
79  SymbolTableCollection &symbolTable, OpBuilder &builder);
80 
81 } // namespace mesh
82 } // namespace mlir
83 
84 /// Include the ODS generated interface header files.
85 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h.inc"
86 
87 #endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
This class helps build Operations.
Definition: Builders.h:209
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
FailureOr< SmallVector< MeshShardingAttr > > defaultGetShardingAnnotations(Operation *op, const ShardingOption &shardingOption)
FailureOr< ShardingOption > defaultGetShardingOption(Operation *op, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings)
LogicalResult defaultAddShardingAnnotations(Operation *op, OpBuilder &b, const ShardingOption &shardingOption)
FailureOr< std::pair< bool, MeshShardingAttr > > getMeshShardingAttr(OpResult result)
void spmdizeFullyReplicatedOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
Include the generated interface declarations.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
static ShardingOption makeEmpty()
ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr mesh)