MLIR  22.0.0git
Transforms.h
Go to the documentation of this file.
1 //===- Transforms.h - Shard Transforms --------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_SHARD_TRANSFORMS_TRANSFORMS_H
10 #define MLIR_DIALECT_SHARD_TRANSFORMS_TRANSFORMS_H
11 
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/Value.h"
15 #include "mlir/Support/LLVM.h"
16 #include "llvm/ADT/ArrayRef.h"
17 
18 namespace mlir {
19 class RewritePatternSet;
20 class SymbolTableCollection;
21 class DialectRegistry;
22 class ImplicitLocOpBuilder;
23 namespace shard {
24 
// Declaration only; the definition is in Transforms.cpp.
// Presumably appends the rewrite patterns that lower the process-multi-index
// op to `patterns`, with `symbolTableCollection` used for symbol lookups --
// confirm against the definition.
25 void populateProcessMultiIndexOpLoweringPatterns(
26  RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
// Declaration only; the definition is in Transforms.cpp.
// Presumably inserts into `registry` the dialects needed by the
// process-multi-index op lowering -- confirm against the definition.
27 void registerProcessMultiIndexOpLoweringDialects(DialectRegistry &registry);
28 
// Declaration only; the definition is in Transforms.cpp.
// Presumably appends the rewrite patterns that lower the all-slice op to
// `patterns`, with `symbolTableCollection` used for symbol lookups --
// confirm against the definition.
29 void populateAllSliceOpLoweringPatterns(
30  RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
// Declaration only; the definition is in Transforms.cpp.
// Presumably inserts into `registry` the dialects needed by the all-slice op
// lowering -- confirm against the definition.
31 void registerAllSliceOpLoweringDialects(DialectRegistry &registry);
32 
// Declaration only; the definition is in Transforms.cpp.
// Presumably appends every op-lowering pattern provided by this header to
// `patterns`, with `symbolTableCollection` used for symbol lookups --
// confirm against the definition.
33 void populateAllOpLoweringPatterns(
34  RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
// Declaration only; the definition is in Transforms.cpp.
// Presumably inserts into `registry` the dialects needed by all of the op
// lowerings above -- confirm against the definition.
35 void registerAllOpLoweringDialects(DialectRegistry &registry);
36 
// Declaration only; the definition is in Transforms.cpp.
// Takes a grid op, a list of grid axes and an ImplicitLocOpBuilder, and
// returns an index-typed value.
// NOTE(review): presumably the result is the number of processes in the
// collective group spanned by `axes` on `grid`, materialized with `builder`
// -- confirm against the definition.
37 TypedValue<IndexType>
38 createCollectiveProcessGroupSize(GridOp grid, ArrayRef<GridAxis> axes,
39  ImplicitLocOpBuilder &builder);
40 
41 // Get process linear index along the given grid axes.
// Overload without an explicit multi-index; the result is an index-typed
// value. `grid` is passed as a StringRef -- presumably the grid's symbol
// name; confirm against the definition in Transforms.cpp.
42 TypedValue<IndexType> createProcessLinearIndex(StringRef grid,
43  ArrayRef<GridAxis> gridAxes,
44  ImplicitLocOpBuilder &builder);
45 // Get process linear index from a multi-index along the given grid axes.
// Overload that takes the multi-index explicitly as
// `processInGroupMultiIndex`; the result is an index-typed value.
// NOTE(review): presumably one multi-index value is expected per entry of
// `gridAxes` -- confirm against the definition.
46 TypedValue<IndexType>
47 createProcessLinearIndex(StringRef grid, ValueRange processInGroupMultiIndex,
48  ArrayRef<GridAxis> gridAxes,
49  ImplicitLocOpBuilder &builder);
50 
51 } // namespace shard
52 } // namespace mlir
53 
54 #endif // MLIR_DIALECT_SHARD_TRANSFORMS_TRANSFORMS_H
shard::GridOp GridOp
void populateAllOpLoweringPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
Definition: Transforms.cpp:190
void registerAllSliceOpLoweringDialects(DialectRegistry &registry)
Definition: Transforms.cpp:184
void populateAllSliceOpLoweringPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
Definition: Transforms.cpp:178
void populateProcessMultiIndexOpLoweringPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
Definition: Transforms.cpp:168
void registerAllOpLoweringDialects(DialectRegistry &registry)
Definition: Transforms.cpp:196
TypedValue< IndexType > createProcessLinearIndex(StringRef grid, ArrayRef< GridAxis > gridAxes, ImplicitLocOpBuilder &builder)
Definition: Transforms.cpp:228
TypedValue< IndexType > createCollectiveProcessGroupSize(GridOp grid, ArrayRef< GridAxis > axes, ImplicitLocOpBuilder &builder)
Definition: Transforms.cpp:202
void registerProcessMultiIndexOpLoweringDialects(DialectRegistry &registry)
Definition: Transforms.cpp:174
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns