MLIR 22.0.0git
DistributionUtils.h
Go to the documentation of this file.
1//===- DistributionUtils.h - Distribution Utilities -------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_GPU_TRANSFORMS_DISTRIBUTIONUTILS_H_
10#define MLIR_DIALECT_GPU_TRANSFORMS_DISTRIBUTIONUTILS_H_
11
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
17
18namespace mlir::gpu {
19struct WarpDistributionPattern : OpRewritePattern<WarpExecuteOnLane0Op> {
22
23 virtual LogicalResult
24 matchAndRewrite(WarpExecuteOnLane0Op op,
25 PatternRewriter &rewriter) const override = 0;
26
27protected:
28 /// Return a value yielded by `warpOp` which statifies the filter lamdba
29 /// condition and is not dead.
30 OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
31 llvm::function_ref<bool(Operation *)> fn) const;
32
33 /// Helper to create a new WarpExecuteOnLane0Op with different signature.
34 WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
35 RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
36 ValueRange newYieldedValues, TypeRange newReturnTypes) const;
37
38 /// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
39 /// `indices` return the index of each new output.
40 WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
41 RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
42 ValueRange newYieldedValues, TypeRange newReturnTypes,
44
45 /// Delinearize the given `laneId` into multiple dimensions, where each
46 /// dimension's size is determined by `originalShape` and `distributedShape`
47 /// together. This function expects the total numbers of threads needed for
48 /// distribution is equal to `warpSize`. Returns true and updates
49 /// `delinearizedIds` if so.
50 bool delinearizeLaneId(OpBuilder &builder, Location loc,
51 ArrayRef<int64_t> originalShape,
52 ArrayRef<int64_t> distributedShape, int64_t warpSize,
53 Value laneId,
54 SmallVectorImpl<Value> &delinearizedIds) const;
55};
56
57} // namespace mlir::gpu
58
59#endif // MLIR_DIALECT_GPU_TRANSFORMS_DISTRIBUTIONUTILS_H_
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:207
This class represents an operand of an operation.
Definition Value.h:257
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes, SmallVector< size_t > &indices) const
Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
bool delinearizeLaneId(OpBuilder &builder, Location loc, ArrayRef< int64_t > originalShape, ArrayRef< int64_t > distributedShape, int64_t warpSize, Value laneId, SmallVectorImpl< Value > &delinearizedIds) const
Delinearize the given laneId into multiple dimensions, where each dimension's size is determined by originalShape and distributedShape together.
WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes) const
Helper to create a new WarpExecuteOnLane0Op with different signature.
virtual LogicalResult matchAndRewrite(WarpExecuteOnLane0Op op, PatternRewriter &rewriter) const override=0
OpOperand * getWarpResult(WarpExecuteOnLane0Op warpOp, llvm::function_ref< bool(Operation *)> fn) const
Return a value yielded by warpOp which satisfies the filter lambda condition and is not dead.