MLIR 22.0.0git
ShardingInterface.h
Go to the documentation of this file.
1//===- ShardingInterface.h --------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACE_H_
10#define MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACE_H_
11
14#include "mlir/IR/Value.h"
15#include "mlir/Support/LLVM.h"
16
17namespace mlir {
18
19class Operation;
20class IRMapping;
22
23namespace shard {
24
27
29 // An array of arrays of grid axes. The sub-array at the i-th position lists
30 // the grid axes that the i-th loop will be sharded on.
33 // `empty` being true indicates that no sharding information can be inferred
34 // at present. Note that this differs from the case where an operation is
35 // not sharded.
36 bool empty = false;
37 ShardingOption() = default;
43 auto res = ShardingOption();
44 res.empty = true;
45 return res;
46 }
47};
48
49// Retrieves the 'Sharding' of the given operation result, including the
50// 'annotate_for_users' information.
// NOTE(review): the `bool` in the returned pair presumably carries the
// 'annotate_for_users' flag, and a failure result presumably means no sharding
// could be determined for `result`; confirm against the implementation.
51FailureOr<std::pair<bool, Sharding>> getSharding(OpResult result);
52
53// Retrieves the 'Sharding' of the given operation operand, including the
54// 'annotate_for_users' information.
// NOTE(review): the `bool` in the returned pair presumably carries the
// 'annotate_for_users' flag, and a failure result presumably means no sharding
// could be determined for `opOperand`; confirm against the implementation.
55FailureOr<std::pair<bool, Sharding>> getSharding(OpOperand &opOperand);
56
57namespace detail {
58
59FailureOr<ShardingOption>
61 ArrayRef<Sharding> resultShardings);
62
63FailureOr<std::vector<Sharding>>
65 const ShardingOption &shardingOption);
66
67LogicalResult
69 const ShardingOption &shardingOption);
70
71} // namespace detail
72
73// Assumes full replication on all ranked tensor arguments and results.
75 ArrayRef<Value> partitionedOperands,
76 ArrayRef<Sharding> operandShardings,
77 ArrayRef<Sharding> resultShardings,
78 IRMapping &partitionMap,
79 SymbolTableCollection &symbolTable,
80 OpBuilder &builder);
81
82} // namespace shard
83} // namespace mlir
84
85/// Include the ODS generated interface header files.
86#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h.inc"
87
88#endif // MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACE_H_
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
This class helps build Operations.
Definition Builders.h:207
This class represents an operand of an operation.
Definition Value.h:257
This is a value defined by a result of an operation.
Definition Value.h:457
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
This class represents a collection of SymbolTables.
LogicalResult defaultAddShardingAnnotations(Operation *op, OpBuilder &b, const ShardingOption &shardingOption)
FailureOr< std::vector< Sharding > > defaultGetShardingAnnotations(Operation *op, const ShardingOption &shardingOption)
FailureOr< ShardingOption > defaultGetShardingOption(Operation *op, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings)
void partitionFullyReplicatedOperation(Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
ArrayRef< SmallVector< GridAxis > > ShardingArrayRef
FailureOr< std::pair< bool, Sharding > > getSharding(OpResult result)
SmallVector< SmallVector< GridAxis > > ShardingArray
Include the generated interface declarations.
ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr grid)
static ShardingOption makeEmpty()