#include <type_traits>

#include "../PassDetail.h"
// ...
#include "llvm/ADT/TypeSwitch.h"

/// For a vector transfer op `xferOp` and an AffineMap `offsetMap` of offsets,
/// fill `indices` with the op's original indices plus the offsets, applied
/// through the op's permutation map. Dimension placeholders used by
/// `offsetMap` should be provided in `dimValues`.
template <typename TransferOpType>
static void getXferIndices(OpBuilder &b, TransferOpType xferOp,
                           AffineMap offsetMap, ArrayRef<Value> dimValues,
                           SmallVector<Value, 4> &indices) {
  indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
  Location loc = xferOp.getLoc();
  unsigned offsetsIdx = 0;
  for (auto expr : xferOp.getPermutationMap().getResults()) {
    if (auto dim = expr.template dyn_cast<AffineDimExpr>()) {
      // Compose the next offset with the existing index via affine.apply.
      Value prevIdx = indices[dim.getPosition()];
      SmallVector<Value, 3> dims(dimValues.begin(), dimValues.end());
      dims.push_back(prevIdx);
      AffineExpr d0 = b.getAffineDimExpr(offsetMap.getNumDims());
      indices[dim.getPosition()] = makeComposedAffineApply(
          b, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
    }
  }
}
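// Illustrative sketch (not part of the original file; names are hypothetical):
// given a transfer op with indices [%i, %j], a minor-identity permutation map,
// and an offsetMap `(d0, d1) -> (d0, d1 + 8)` with dimValues {%lane, %val},
// the helper rewrites the indices to, schematically:
//   indices[0] = affine.apply (d0, d1, d2) -> (d0 + d2)     (%lane, %val, %i)
//   indices[1] = affine.apply (d0, d1, d2) -> (d1 + 8 + d2) (%lane, %val, %j)
// The exact composed maps depend on makeComposedAffineApply folding.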
/// Return true if the contraction can map to an MMA matmul: no masks,
/// iterators of the form (parallel, parallel, reduction), and matmul
/// indexing maps.
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
                                          bool useNvGpu) {
  if (llvm::size(contract.getMasks()) != 0)
    return false;

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(contract.getContext(), m, n, k);
  auto iteratorTypes = contract.getIteratorTypes().getValue();
  if (!(isParallelIterator(iteratorTypes[0]) &&
        isParallelIterator(iteratorTypes[1]) &&
        isReductionIterator(iteratorTypes[2])))
    return false;

  // The contract needs to represent a matmul to be convertible to MMAMatrix.
  if (!useNvGpu &&
      contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
    return false;
  if (useNvGpu &&
      contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))
    return false;
  return true;
}
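// Illustrative IR (hypothetical shapes): a vector.contract that passes this
// check on the non-nvgpu path, i.e. indexing maps {(m,k), (k,n), (m,n)}:
//
//   %d = vector.contract {
//          indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
//                           affine_map<(m, n, k) -> (k, n)>,
//                           affine_map<(m, n, k) -> (m, n)>],
//          iterator_types = ["parallel", "parallel", "reduction"],
//          kind = #vector.kind<add>}
//        %a, %b, %c : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>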
// Return the stride of the second-to-last dimension if the memref is
// statically strided, None otherwise.
static llvm::Optional<int64_t>
getMemrefConstantHorizontalStride(ShapedType type) {
  auto memrefType = type.dyn_cast<MemRefType>();
  if (!memrefType)
    return llvm::None;
  // A 0-d or 1-d memref has no horizontal stride.
  if (memrefType.getRank() < 2)
    return 0;
  int64_t offset = 0;
  SmallVector<int64_t, 2> strides;
  if (failed(getStridesAndOffset(memrefType, strides, offset)))
    return llvm::None;
  int64_t stride = strides[strides.size() - 2];
  if (stride == ShapedType::kDynamicStrideOrOffset)
    return llvm::None;
  return stride;
}
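// Worked example (hypothetical types): for memref<16x32xf16> the strides are
// [32, 1], so the horizontal (second-to-last) stride is 32. For a dynamically
// strided layout, getStridesAndOffset reports
// ShapedType::kDynamicStrideOrOffset and the helper returns llvm::None.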
/// Return true if this transfer read can be converted to an MMA matrix load.
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp,
                                              bool useNvGpu) {
  if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
      readOp.getVectorType().getRank() != 2)
    return false;
  if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
    return false;
  // Without the nvgpu path, only the minor identity map and an
  // inner-dimension broadcast, (d0, d1) -> (0, d1), are supported.
  // ...
  return true;
}
/// Return true if this transfer write can be converted to an MMA matrix store.
static bool transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
  // TODO: support 0-d corner case.
  if (writeOp.getTransferRank() == 0)
    return false;
  if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
      writeOp.getVectorType().getRank() != 2)
    return false;
  if (!getMemrefConstantHorizontalStride(writeOp.getShapedType()))
    return false;
  // TODO: Support transpose once it is added to GPU dialect ops.
  if (!writeOp.getPermutationMap().isMinorIdentity())
    return false;
  return true;
}
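// Illustrative IR (hypothetical shapes): a read/write pair satisfying the
// checks above — 2-D vectors, no masks, in-bounds, minor-identity permutation
// map, statically strided memref:
//
//   %v = vector.transfer_read %m[%c0, %c0], %pad {in_bounds = [true, true]}
//        : memref<16x16xf16>, vector<16x16xf16>
//   vector.transfer_write %v, %m[%c0, %c0] {in_bounds = [true, true]}
//        : vector<16x16xf16>, memref<16x16xf16>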
/// Return true if the constant is a splat to a 2D vector so that it can be
/// converted to an MMA constant matrix op.
static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
  auto vecType = constantOp.getType().dyn_cast<VectorType>();
  if (!vecType || vecType.getRank() != 2)
    return false;
  return constantOp.getValue().isa<SplatElementsAttr>();
}

/// Return true if this is a broadcast from scalar to a 2D vector.
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
  return broadcastOp.getVectorType().getRank() == 2 &&
         broadcastOp.getSource().getType().isa<FloatType>();
}
/// Return the MMA elementwise enum associated with `op` if it is supported.
/// Return llvm::None otherwise.
static llvm::Optional<gpu::MMAElementwiseOp>
convertElementwiseOpToMMA(Operation *op) {
  if (isa<arith::AddFOp>(op))
    return gpu::MMAElementwiseOp::ADDF;
  if (isa<arith::MulFOp>(op))
    return gpu::MMAElementwiseOp::MULF;
  if (isa<arith::MaxFOp>(op))
    return gpu::MMAElementwiseOp::MAXF;
  if (isa<arith::MinFOp>(op))
    return gpu::MMAElementwiseOp::MINF;
  if (isa<arith::DivFOp>(op))
    return gpu::MMAElementwiseOp::DIVF;
  return llvm::None;
}

/// Return true if the op is supported as elementwise op on MMAMatrix type.
static bool elementwiseSupportsMMAMatrixType(Operation *op) {
  return convertElementwiseOpToMMA(op).hasValue();
}
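// For example, under this mapping an `arith.addf` on values that were mapped
// to MMA matrices later becomes (schematically; shapes hypothetical):
//
//   %x = gpu.subgroup_mma_elementwise addf %a, %b
//        : (!gpu.mma_matrix<16x16xf16, "COp">,
//           !gpu.mma_matrix<16x16xf16, "COp">)
//        -> !gpu.mma_matrix<16x16xf16, "COp">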
/// Return true if the op can be converted to operate on the MMA matrix type.
static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
  if (isa<scf::ForOp, scf::YieldOp>(op))
    return true;
  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
    return transferReadSupportsMMAMatrixType(transferRead, useNvGpu);
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
    return transferWriteSupportsMMAMatrixType(transferWrite);
  if (auto contract = dyn_cast<vector::ContractionOp>(op))
    return contractSupportsMMAMatrixType(contract, useNvGpu);
  if (auto constant = dyn_cast<arith::ConstantOp>(op))
    return constantSupportsMMAMatrixType(constant);
  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
    return broadcastSupportsMMAMatrixType(broadcast);
  return elementwiseSupportsMMAMatrixType(op);
}
/// Return an unsorted slice handling scf.for regions differently than
/// getSlice: starting from `op`, grow the slice through backward and forward
/// slices until a fixed point is reached.
static SetVector<Operation *> getSliceContract(Operation *op,
                                               TransitiveFilter backwardFilter,
                                               TransitiveFilter forwardFilter) {
  SetVector<Operation *> slice;
  slice.insert(op);
  unsigned currentIndex = 0;
  SetVector<Operation *> backwardSlice;
  SetVector<Operation *> forwardSlice;
  while (currentIndex != slice.size()) {
    auto *currentOp = slice[currentIndex];
    // Compute and insert the backward slice starting from currentOp.
    backwardSlice.clear();
    getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
    slice.insert(backwardSlice.begin(), backwardSlice.end());

    // Compute and insert the forward slice starting from currentOp. For
    // scf.for, only follow the loop results rather than the whole region.
    forwardSlice.clear();
    if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
      for (Value forOpResult : forOp.getResults())
        getForwardSlice(forOpResult, &forwardSlice, forwardFilter);
    } else {
      getForwardSlice(currentOp, &forwardSlice, forwardFilter);
    }
    slice.insert(forwardSlice.begin(), forwardSlice.end());
    ++currentIndex;
  }
  return slice;
}
/// Analyze the slice of operations around each contract op to figure out
/// whether the whole chain can be converted to MMA operations.
static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
                                             bool useNvGpu) {
  auto hasVectorDest = [](Operation *op) {
    return llvm::any_of(op->getResultTypes(),
                        [](Type t) { return t.isa<VectorType>(); });
  };
  auto hasVectorSrc = [](Operation *op) {
    return llvm::any_of(op->getOperandTypes(),
                        [](Type t) { return t.isa<VectorType>(); });
  };
  SetVector<Operation *> opToConvert;
  op->walk([&](vector::ContractionOp contract) {
    if (opToConvert.contains(contract.getOperation()))
      return;
    SetVector<Operation *> dependentOps =
        getSliceContract(contract, hasVectorDest, hasVectorSrc);
    // If any op in the chain cannot use the MMA matrix type, drop the whole
    // chain: MMA matrices are opaque and not usable by arbitrary ops.
    if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
          return !supportsMMaMatrixType(op, useNvGpu);
        }))
      return;
    opToConvert.insert(dependentOps.begin(), dependentOps.end());
  });
  // Sort the operations so they can be converted in topological order.
  return topologicalSort(opToConvert);
}
namespace {

/// Canonicalize the contraction into the (m, k) x (k, n) -> (m, n) form
/// expected by the GPU MMA matmul, inserting vector.transpose ops (and
/// swapping operands) as needed.
struct PrepareContractToGPUMMA
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();

    // Set up the parallel/reduction structure in the right form.
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.getIteratorTypes().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
    if (!(isParallelIterator(iteratorTypes[0]) &&
          isParallelIterator(iteratorTypes[1]) &&
          isReductionIterator(iteratorTypes[2])))
      return failure();

    // The canonical row-major matmul form: nothing to do.
    if (maps == infer({{m, k}, {k, n}, {m, n}}))
      return failure();
    // Otherwise transpose (and possibly swap) the operands into that form.
    if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else {
      return failure();
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res,
        rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
        op.getIteratorTypes());
    return success();
  }
};
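// Illustrative before/after (hypothetical shapes): with indexing maps
// {(m,k), (n,k), (m,n)} — i.e. B given in (n, k) layout — the pattern
// transposes %b and rewrites the contract to the canonical form:
//
//   %bt = vector.transpose %b, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
//   %d  = vector.contract { /* maps now {(m,k), (k,n), (m,n)} */ }
//         %a, %bt, %c : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>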
/// Fold a vector.transpose into a preceding vector.transfer_read: transposes
/// are not supported on MMA types, but the MMA load can transpose while
/// loading, so merge the permutation into the read's permutation map.
struct CombineTransferReadOpTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto transferReadOp =
        op.getVector().getDefiningOp<vector::TransferReadOp>();
    if (!transferReadOp)
      return failure();

    // TODO: support 0-d corner case.
    if (transferReadOp.getTransferRank() == 0)
      return failure();
    if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
      return failure();

    SmallVector<int64_t, 2> perm;
    op.getTransp(perm);
    SmallVector<unsigned, 2> permU;
    for (int64_t o : perm)
      permU.push_back(unsigned(o));
    AffineMap permutationMap =
        AffineMap::getPermutationMap(permU, op.getContext());
    AffineMap newMap =
        permutationMap.compose(transferReadOp.getPermutationMap());
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        op, op.getType(), transferReadOp.getSource(),
        transferReadOp.getIndices(), AffineMapAttr::get(newMap),
        transferReadOp.getPadding(), transferReadOp.getMask(),
        transferReadOp.getInBoundsAttr());
    return success();
  }
};
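// Illustrative before/after (hypothetical shapes):
//
//   %v = vector.transfer_read %m[%i, %j], %pad : memref<?x?xf16>, vector<4x8xf16>
//   %t = vector.transpose %v, [1, 0] : vector<4x8xf16> to vector<8x4xf16>
//
// becomes a single read with the permutation folded into its map:
//
//   %t = vector.transfer_read %m[%i, %j], %pad
//        {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
//        : memref<?x?xf16>, vector<8x4xf16>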
} // namespace

/// Return the GPU MMA fragment name ("AOp", "BOp", or "COp") for the operand
/// position this op's result occupies in a consuming vector.contract.
template <typename OpTy>
static const char *inferFragType(OpTy op) {
  for (Operation *users : op->getUsers()) {
    auto contract = dyn_cast<vector::ContractionOp>(users);
    if (!contract)
      continue;
    if (contract.getLhs() == op.getResult())
      return "AOp";
    if (contract.getRhs() == op.getResult())
      return "BOp";
  }
  return "COp";
}
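// For example, a transfer_read feeding the LHS of a vector.contract is typed
// as !gpu.mma_matrix<16x16xf16, "AOp">, while one feeding only the
// accumulator (or no contract at all) gets "COp". (Shapes hypothetical.)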
static void convertTransferReadOp(vector::TransferReadOp op,
                                  llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
  assert(transferReadSupportsMMAMatrixType(op, /*useNvGpu=*/false));
  llvm::Optional<int64_t> stride =
      getMemrefConstantHorizontalStride(op.getShapedType());
  assert(stride);
  // ... the broadcast case is handled by using a zero stride ...
  const char *fragType = inferFragType(op);
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      op.getVectorType().getShape(),
      op.getVectorType().getElementType(), fragType);
  Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
      op.getLoc(), type, op.getSource(), op.getIndices(),
      b.getIndexAttr(*stride));
  valueMapping[op.getResult()] = load;
}
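// Illustrative lowering (hypothetical shapes): the transfer_read becomes
//
//   %0 = gpu.subgroup_mma_load_matrix %m[%c0, %c0] {leadDimension = 16 : index}
//        : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">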
static void convertTransferWriteOp(vector::TransferWriteOp op,
                                   llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  assert(transferWriteSupportsMMAMatrixType(op));
  llvm::Optional<int64_t> stride =
      getMemrefConstantHorizontalStride(op.getShapedType());
  assert(stride);
  Value matrix = valueMapping.find(op.getVector())->second;
  b.create<gpu::SubgroupMmaStoreMatrixOp>(op.getLoc(), matrix, op.getSource(),
                                          op.getIndices(),
                                          b.getIndexAttr(*stride));
  op.erase();
}
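// Illustrative lowering (hypothetical shapes):
//
//   gpu.subgroup_mma_store_matrix %mat, %m[%c0, %c0] {leadDimension = 16 : index}
//        : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>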
/// Returns the vector type which represents a matrix fragment.
static VectorType
getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
  SmallVector<int64_t> shape{regInfo.numRegistersPerFragment,
                             regInfo.elementsPerRegister};
  Type elType = regInfo.registerLLVMType;
  if (auto vecType = elType.dyn_cast<VectorType>())
    elType = vecType.getElementType();
  return VectorType::get(shape, elType);
}
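// Worked example (assumed fragment layout): if a thread owns 2 registers per
// fragment and each register packs 2 f16 elements (registerLLVMType =
// vector<2xf16>), the per-thread operand type is vector<2x2xf16>.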
/// Convert a 2D splat ConstantOp to a distributed per-thread vector constant
/// compatible with nvgpu.mma.sync.
static LogicalResult
convertConstantOpMmaSync(arith::ConstantOp op,
                         llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return failure();
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo))
    return failure();
  // Re-splat the constant over the per-thread fragment vector type.
  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
  auto dense = op.getValue().dyn_cast<SplatElementsAttr>();
  if (!dense)
    return failure();
  Value result = b.create<arith::ConstantOp>(
      op.getLoc(), vectorType,
      DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
  valueMapping[op.getResult()] = result;
  return success();
}
static LogicalResult
creatLdMatrixCompatibleLoads(vector::TransferReadOp op, OpBuilder &builder,
                             llvm::DenseMap<Value, Value> &valueMapping) {
  Location loc = op->getLoc();
  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return failure();
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo))
    return failure();

  FailureOr<nvgpu::LdMatrixParams> params = nvgpu::getLdMatrixParams(
      *warpMatrixInfo,
      /*transpose=*/!op.getPermutationMap().isMinorIdentity());
  if (failed(params)) {
    return op->emitError()
           << "failed to convert vector.transfer_read to ldmatrix; this op "
              "should not be converted to a nvgpu.ldmatrix call.";
  }

  // Adjust the load offset: each lane reads from its own matrix coordinate.
  auto laneId = builder.create<gpu::LaneIdOp>(loc);
  FailureOr<AffineMap> offsets =
      nvgpu::getLaneIdToLdMatrixMatrixCoord(loc, builder, *params);
  if (failed(offsets))
    return failure();

  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
  SmallVector<Value, 4> indices;
  getXferIndices<vector::TransferReadOp>(builder, op, *offsets, {laneId},
                                         indices);
  nvgpu::LdMatrixOp newOp = builder.create<nvgpu::LdMatrixOp>(
      loc, vectorType, op.getSource(), indices,
      !op.getPermutationMap().isMinorIdentity(), params->numTiles);
  valueMapping[op] = newOp->getResult(0);
  return success();
}
static LogicalResult
createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder,
                       llvm::DenseMap<Value, Value> &valueMapping) {
  Location loc = op.getLoc();
  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return failure();
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo)) {
    op->emitError() << "Failed to deduce register fragment type during "
                       "conversion to distributed non-ldmatrix compatible load";
    return failure();
  }

  Value laneId = builder.create<gpu::LaneIdOp>(loc);

  // This is the individual element type.
  Type loadedElType = regInfo->registerLLVMType;
  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);

  Value fill = builder.create<arith::ConstantOp>(
      op.getLoc(), vectorType.getElementType(),
      builder.getZeroAttr(vectorType.getElementType()));
  Value result = builder.create<vector::SplatOp>(op.getLoc(), fill, vectorType);

  bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();

  // If we are not transposing, we can use vectorized loads. Otherwise, we
  // must load each element individually.
  if (!isTransposeLoad) {
    if (!loadedElType.isa<VectorType>()) {
      loadedElType = VectorType::get({1}, loadedElType);
    }
    for (int i = 0; i < vectorType.getShape()[0]; i++) {
      FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
          op.getLoc(), builder, *warpMatrixInfo);
      if (failed(coords))
        return failure();
      Value logicalValueId = builder.create<arith::ConstantOp>(
          loc, builder.getIndexType(),
          builder.getIndexAttr(i * regInfo->elementsPerRegister));
      SmallVector<Value, 4> newIndices;
      getXferIndices<vector::TransferReadOp>(
          builder, op, *coords, {laneId, logicalValueId}, newIndices);

      Value el = builder.create<vector::LoadOp>(loc, loadedElType,
                                                op.getSource(), newIndices);
      result = builder.create<vector::InsertOp>(loc, el, result,
                                                builder.getI64ArrayAttr(i));
    }
  } else {
    if (auto vecType = loadedElType.dyn_cast<VectorType>()) {
      loadedElType = vecType.getElementType();
    }
    for (int i = 0; i < vectorType.getShape()[0]; i++) {
      for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
           innerIdx++) {
        Value logicalValueId = builder.create<arith::ConstantOp>(
            loc, builder.getIndexType(),
            builder.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
        FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
            op.getLoc(), builder, *warpMatrixInfo);
        if (failed(coords))
          return failure();

        SmallVector<Value, 4> newIndices;
        getXferIndices<vector::TransferReadOp>(
            builder, op, *coords, {laneId, logicalValueId}, newIndices);
        Value el = builder.create<memref::LoadOp>(op.getLoc(), loadedElType,
                                                  op.getSource(), newIndices);
        result = builder.create<vector::InsertOp>(
            op.getLoc(), el, result, builder.getI64ArrayAttr({i, innerIdx}));
      }
    }
  }

  valueMapping[op.getResult()] = result;
  return success();
}
/// Converts a vector.transfer_read operation directly to either a vector.load
/// or a nvgpu.ldmatrix operation.
static LogicalResult
convertTransferReadToLoads(vector::TransferReadOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return failure();

  bool isLdMatrixCompatible =
      op.getSource().getType().cast<MemRefType>().getMemorySpaceAsInt() == 3 &&
      nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;

  VectorType vecTy = op.getVectorType();
  int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();

  // When transposing the B operand, ldmatrix only works with 16-bit elements,
  // at least 8 rows to read, and a 128-bit transpose read width.
  if (!op.getPermutationMap().isMinorIdentity() &&
      (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
       vecTy.getDimSize(0) * bitWidth < 128))
    isLdMatrixCompatible = false;

  if (!isLdMatrixCompatible)
    return createNonLdMatrixLoads(op, b, valueMapping);

  return creatLdMatrixCompatibleLoads(op, b, valueMapping);
}
static LogicalResult
convertTransferWriteToStores(vector::TransferWriteOp op,
                             llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  Location loc = op->getLoc();
  Value matrix = valueMapping.find(op.getVector())->second;

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return failure();
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo))
    return failure();

  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
  Value laneId = b.create<gpu::LaneIdOp>(loc);

  for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
    Value logicalValueId = b.create<arith::ConstantOp>(
        loc, b.getIndexType(),
        b.getIndexAttr(i * regInfo->elementsPerRegister));
    FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
        op.getLoc(), b, *warpMatrixInfo);
    if (failed(coords))
      return failure();

    Value el = b.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
    SmallVector<Value, 4> newIndices;
    getXferIndices<vector::TransferWriteOp>(
        b, op, *coords, {laneId, logicalValueId}, newIndices);
    b.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
  }
  op->erase();
  return success();
}
static void convertContractOp(vector::ContractionOp op,
                              llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  Value opA = valueMapping.find(op.getLhs())->second;
  Value opB = valueMapping.find(op.getRhs())->second;
  Value opC = valueMapping.find(op.getAcc())->second;
  Value matmul = b.create<gpu::SubgroupMmaComputeOp>(
      op.getLoc(), opC.getType(), opA, opB, opC);
  valueMapping[op.getResult()] = matmul;
}
static LogicalResult
convertContractOpToMmaSync(vector::ContractionOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  Value opA = valueMapping.find(op.getLhs())->second;
  Value opB = valueMapping.find(op.getRhs())->second;
  Value opC = valueMapping.find(op.getAcc())->second;
  int64_t m = op.getLhs().getType().cast<VectorType>().getShape()[0];
  int64_t n = op.getRhs().getType().cast<VectorType>().getShape()[0];
  int64_t k = op.getLhs().getType().cast<VectorType>().getShape()[1];
  Value matmul = b.create<nvgpu::MmaSyncOp>(op.getLoc(), opA, opB, opC,
                                            b.getI64ArrayAttr({m, n, k}));
  valueMapping[op.getResult()] = matmul;
  return success();
}
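// Illustrative result (per-thread fragment types are assumptions): for a
// 16x8x16 f16 tile this produces something like
//
//   %d = nvgpu.mma.sync(%a, %b, %c) {mmaShape = [16, 8, 16]}
//        : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>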
/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static void convertConstantOp(arith::ConstantOp op,
                              llvm::DenseMap<Value, Value> &valueMapping) {
  assert(constantSupportsMMAMatrixType(op));
  OpBuilder b(op);
  Attribute splat =
      op.getValue().cast<SplatElementsAttr>().getSplatValue<Attribute>();
  auto scalarConstant =
      b.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
  const char *fragType = inferFragType(op);
  auto vecType = op.getType().cast<VectorType>();
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
                                                           scalarConstant);
  valueMapping[op.getResult()] = matrix;
}
/// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
static void convertBroadcastOp(vector::BroadcastOp op,
                               llvm::DenseMap<Value, Value> &valueMapping) {
  assert(broadcastSupportsMMAMatrixType(op));
  OpBuilder b(op);
  const char *fragType = inferFragType(op);
  auto vecType = op.getVectorType();
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
                                                           op.getSource());
  valueMapping[op.getResult()] = matrix;
}
// Replace the scf.for op with a new scf.for op that has the extra iter
// operands appended. The yield op is not updated here; it has to be adjusted
// separately for the loop to remain correct.
static scf::ForOp replaceForOpWithNewSignature(OpBuilder &b, scf::ForOp loop,
                                               ValueRange newIterOperands) {
  // Create a new loop before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(loop);
  auto operands = llvm::to_vector<4>(loop.getIterOperands());
  operands.append(newIterOperands.begin(), newIterOperands.end());
  scf::ForOp newLoop =
      b.create<scf::ForOp>(loop.getLoc(), loop.getLowerBound(),
                           loop.getUpperBound(), loop.getStep(), operands);
  // Steal the body of the old loop and add block arguments for the new
  // iter operands.
  newLoop.getBody()->erase();
  newLoop.getLoopBody().getBlocks().splice(
      newLoop.getLoopBody().getBlocks().begin(),
      loop.getLoopBody().getBlocks());
  for (Value operand : newIterOperands)
    newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());

  for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
                                                  loop.getNumResults())))
    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
  loop.erase();
  return newLoop;
}
static void convertForOp(scf::ForOp op,
                         llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  SmallVector<Value> newOperands;
  SmallVector<std::pair<size_t, size_t>> argMapping;
  for (const auto &operand : llvm::enumerate(op.getIterOperands())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end())
      continue;
    argMapping.push_back(std::make_pair(
        operand.index(), op.getNumIterOperands() + newOperands.size()));
    newOperands.push_back(it->second);
  }
  scf::ForOp newForOp = replaceForOpWithNewSignature(b, op, newOperands);
  Block &loopBody = *newForOp.getBody();
  for (auto mapping : argMapping) {
    valueMapping[newForOp.getResult(mapping.first)] =
        newForOp.getResult(mapping.second);
    valueMapping[loopBody.getArgument(mapping.first +
                                      newForOp.getNumInductionVars())] =
        loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
  }
}
static void convertYieldOp(scf::YieldOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  auto loop = cast<scf::ForOp>(op->getParentOp());
  auto yieldOperands = llvm::to_vector<4>(op.getOperands());
  for (const auto &operand : llvm::enumerate(op.getOperands())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end())
      continue;
    // Replace the yield of the old value with the for op argument to make it
    // easier to remove the dead code.
    yieldOperands[operand.index()] = loop.getIterOperands()[operand.index()];
    yieldOperands.push_back(it->second);
  }
  b.create<scf::YieldOp>(op.getLoc(), yieldOperands);
  op.erase();
}
/// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
static void convertElementwiseOp(Operation *op, gpu::MMAElementwiseOp opType,
                                 llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  SmallVector<Value> matrixOperands;
  for (Value operand : op->getOperands())
    matrixOperands.push_back(valueMapping.find(operand)->second);
  Value newOp = b.create<gpu::SubgroupMmaElementwiseOp>(
      op->getLoc(), matrixOperands[0].getType(), matrixOperands, opType);
  valueMapping[op->getResult(0)] = newOp;
}
void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
                                              bool useNvGpu) {
  if (!useNvGpu) {
    patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
        patterns.getContext());
    return;
  }
  // ... nvgpu-specific preparation patterns ...
}
/// Convert vector ops to MMA matrix operations nested under `rootOp`.
void mlir::convertVectorToMMAOps(Operation *rootOp) {
  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
  llvm::DenseMap<Value, Value> valueMapping;
  for (Operation *op : ops) {
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
      convertTransferReadOp(transferRead, valueMapping);
    } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
      convertTransferWriteOp(transferWrite, valueMapping);
    } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
      convertContractOp(contractOp, valueMapping);
    } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
      convertConstantOp(constantOp, valueMapping);
    } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
      convertBroadcastOp(broadcastOp, valueMapping);
    } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
      convertForOp(forOp, valueMapping);
    } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
      convertYieldOp(yieldOp, valueMapping);
    } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
      convertElementwiseOp(op, *elementwiseType, valueMapping);
    }
  }
}
/// Convert vector ops nested under `rootOp` to vector and GPU operations
/// compatible with the nvvm.mma.sync lowering path.
LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(Operation *rootOp) {
  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
  llvm::DenseMap<Value, Value> valueMapping;
  for (Operation *op : ops) {
    if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
            .Case([&](vector::TransferReadOp transferReadOp) {
              return convertTransferReadToLoads(transferReadOp, valueMapping);
            })
            .Case([&](vector::TransferWriteOp transferWriteOp) {
              return convertTransferWriteToStores(transferWriteOp,
                                                  valueMapping);
            })
            .Case([&](vector::ContractionOp contractionOp) {
              return convertContractOpToMmaSync(contractionOp, valueMapping);
            })
            .Case([&](scf::ForOp forOp) {
              convertForOp(forOp, valueMapping);
              return success();
            })
            .Case([&](scf::YieldOp yieldOp) {
              convertYieldOp(yieldOp, valueMapping);
              return success();
            })
            .Case([&](arith::ConstantOp constOp) {
              return convertConstantOpMmaSync(constOp, valueMapping);
            })
            .Default([&](Operation *op) {
              op->emitError() << "unhandled vector to mma type: " << *op;
              return failure();
            })
            .failed()) {
      op->emitError() << "failed to convert op " << *op;
      return failure();
    }
  }
  return success();
}
namespace {

struct ConvertVectorToGPUPass
    : public ConvertVectorToGPUBase<ConvertVectorToGPUPass> {

  explicit ConvertVectorToGPUPass(bool useNvGpu_) {
    useNvGpu.setValue(useNvGpu_);
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
    if (failed(
            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();

    if (useNvGpu.getValue()) {
      if (failed(convertVectorToNVVMCompatibleMMASync(getOperation())))
        return signalPassFailure();
    }
    convertVectorToMMAOps(getOperation());
  }
};

} // namespace

std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
  return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
}
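// Usage sketch (assumption: a module whose functions contain vector.contract
// ops; the pass is also exposed to mlir-opt as -convert-vector-to-gpu):
//
//   PassManager pm(&context);
//   pm.addNestedPass<func::FuncOp>(
//       createConvertVectorToGPUPass(/*useNvGpu=*/false));
//   if (failed(pm.run(module)))
//     module.emitError("vector-to-gpu conversion failed");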
static bool transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp)
void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns, bool useNvGpu=false)
Patterns to transform vector ops into a canonical form to convert to MMA matrix operations.
gpu::MMAMatrixType
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations...
static void convertTransferWriteOp(vector::TransferWriteOp op, llvm::DenseMap< Value, Value > &valueMapping)
static scf::ForOp replaceForOpWithNewSignature(OpBuilder &b, scf::ForOp loop, ValueRange newIterOperands)
static llvm::Optional< int64_t > getMemrefConstantHorizontalStride(ShapedType type)
static LogicalResult convertConstantOpMmaSync(arith::ConstantOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static SetVector< Operation * > getSliceContract(Operation *op, TransitiveFilter backwardFilter, TransitiveFilter forwardFilter)
Return an unsorted slice handling scf.for regions differently than getSlice.
FailureOr< FragmentElementInfo > getMmaSyncRegisterType(const WarpMatrixInfo &type)
Returns a FragmentElementInfo struct describing the register types for the given matrix fragment type...
static void convertBroadcastOp(vector::BroadcastOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
FailureOr< AffineMap > getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder, const WarpMatrixInfo &fragmentType)
Returns an AffineMap which maps two dimensions representing (laneId, logicalValueId) to two results representing offsets within a matrix operand.
static LogicalResult convertContractOpToMmaSync(vector::ContractionOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertTransferReadToLoads(vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
Converts a vector.transfer_read operation directly to either a vector.load or a nvgpu.ldmatrix operation.
static void convertElementwiseOp(Operation *op, gpu::MMAElementwiseOp opType, llvm::DenseMap< Value, Value > &valueMapping)
Convert an elementwise op to the equivalent elementwise op on MMA matrix.
static LogicalResult creatLdMatrixCompatibleLoads(vector::TransferReadOp op, OpBuilder &builder, llvm::DenseMap< Value, Value > &valueMapping)
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp)
Return true if this is a broadcast from scalar to a 2D vector.
static const char * inferFragType(OpTy op)
static void convertForOp(scf::ForOp op, llvm::DenseMap< Value, Value > &valueMapping)
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction invariants.
static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp)
Return true if the constant is a splat to a 2D vector so that it can be converted to an MMA constant matrix op.
FailureOr< WarpMatrixInfo > getWarpMatrixInfo(Operation *op)
Given an op that operates on a VectorType representing a warp-level matrix operand, the function returns a struct containing relevant type information.
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract, bool useNvGpu)
nvgpu::FragmentElementInfo
Specifies information about the registers which compose a matrix fragment according to the PTX documentation.
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp, bool useNvGpu)
static void convertTransferReadOp(vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
FailureOr< AffineMap > getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder, const LdMatrixParams &params)
Returns an AffineMap which maps a single dimension representing the laneId to two results representing...
static SetVector< Operation * > getOpToConvert(mlir::Operation *op, bool useNvGpu)
static LogicalResult createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder, llvm::DenseMap< Value, Value > &valueMapping)
static void convertContractOp(vector::ContractionOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool supportsMMaMatrixType(Operation *op, bool useNvGpu)
static llvm::Optional< gpu::MMAElementwiseOp > convertElementwiseOpToMMA(Operation *op)
Return the MMA elementwise enum associated with op if it is supported.
std::unique_ptr< Pass > createConvertVectorToGPUPass(bool useNvGpu=false)
Convert from vector to GPU ops.
static VectorType getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo)
Returns the vector type which represents a matrix fragment.
void convertVectorToMMAOps(Operation *rootOp)
Convert vector ops to MMA matrix operations nested under rootOp.
int64_t inferTileWidthInBits(const WarpMatrixInfo &type)
Returns the number of bits in a single tile row.
FailureOr< nvgpu::LdMatrixParams > getLdMatrixParams(const WarpMatrixInfo &type, bool transpose)
static LogicalResult convertTransferWriteToStores(vector::TransferWriteOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool elementwiseSupportsMMAMatrixType(Operation *op)
Return true if the op is supported as elementwise op on MMAMatrix type.
LogicalResult convertVectorToNVVMCompatibleMMASync(Operation *rootOp)
Convert vector ops nested under rootOp to vector and GPU operations compatible with the nvvm.mma.sync lowering path.
static void getXferIndices(OpBuilder &b, TransferOpType xferOp, AffineMap offsetMap, ArrayRef< Value > dimValues, SmallVector< Value, 4 > &indices)
For a vector TransferOpType xferOp, an empty indices vector, and an AffineMap representing offsets to...
static void convertYieldOp(scf::YieldOp op, llvm::DenseMap< Value, Value > &valueMapping)
static void convertConstantOp(arith::ConstantOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.