MLIR 22.0.0git
VectorPattern.cpp
Go to the documentation of this file.
1//===- VectorPattern.cpp - Vector conversion pattern to the LLVM dialect --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
11
12using namespace mlir;
13
14// For >1-D vector types, extracts the necessary information to iterate over all
15// 1-D subvectors in the underlying LLVM representation of the n-D vector.
16// Iterates on the llvm array type until we hit a non-array type (which is
17// asserted to be an llvm vector type).
20 const LLVMTypeConverter &converter) {
21 assert(vectorType.getRank() > 1 && "expected >1D vector type");
23 info.llvmNDVectorTy = converter.convertType(vectorType);
25 info.llvmNDVectorTy = nullptr;
26 return info;
27 }
28 info.arraySizes.reserve(vectorType.getRank() - 1);
29 auto llvmTy = info.llvmNDVectorTy;
30 while (isa<LLVM::LLVMArrayType>(llvmTy)) {
31 info.arraySizes.push_back(
32 cast<LLVM::LLVMArrayType>(llvmTy).getNumElements());
33 llvmTy = cast<LLVM::LLVMArrayType>(llvmTy).getElementType();
34 }
36 return info;
37 info.llvm1DVectorTy = llvmTy;
38 return info;
39}
40
41// Express `linearIndex` in terms of coordinates of `basis`.
42// Returns the empty vector when linearIndex is out of the range [0, P) where
43// P is the product of all the basis coordinates.
44//
45// Prerequisites:
46// Basis is an array of nonnegative integers (signed type inherited from
47// vector shape type).
49 unsigned linearIndex) {
51 res.reserve(basis.size());
52 for (unsigned basisElement : llvm::reverse(basis)) {
53 res.push_back(linearIndex % basisElement);
54 linearIndex = linearIndex / basisElement;
55 }
56 if (linearIndex > 0)
57 return {};
58 std::reverse(res.begin(), res.end());
59 return res;
60}
61
62// Iterate over the linear index, convert to coords space and insert splatted 1-D
63// vector in each position.
65 OpBuilder &builder,
67 unsigned ub = 1;
68 for (auto s : info.arraySizes)
69 ub *= s;
70 for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) {
71 auto coords = getCoordinates(info.arraySizes, linearIndex);
72 // Linear index is out of bounds, we are done.
73 if (coords.empty())
74 break;
75 assert(coords.size() == info.arraySizes.size());
76 fun(coords);
77 }
78}
79
81 Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter,
82 std::function<Value(Type, ValueRange)> createOperand,
83 ConversionPatternRewriter &rewriter) {
84 auto resultNDVectorType = cast<VectorType>(op->getResult(0).getType());
85 auto resultTypeInfo =
86 extractNDVectorTypeInfo(resultNDVectorType, typeConverter);
87 auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy;
88 auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy;
89 auto loc = op->getLoc();
90 Value desc = LLVM::PoisonOp::create(rewriter, loc, resultNDVectoryTy);
91 nDVectorIterate(resultTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
92 // For this unrolled `position` corresponding to the `linearIndex`^th
93 // element, extract operand vectors
94 SmallVector<Value, 4> extractedOperands;
95 for (const auto &operand : llvm::enumerate(operands)) {
96 extractedOperands.push_back(LLVM::ExtractValueOp::create(
97 rewriter, loc, operand.value(), position));
98 }
99 Value newVal = createOperand(result1DVectorTy, extractedOperands);
100 desc = LLVM::InsertValueOp::create(rewriter, loc, desc, newVal, position);
101 });
102 rewriter.replaceOp(op, desc);
103 return success();
104}
105
107 Operation *op, StringRef targetOp, ValueRange operands,
108 ArrayRef<NamedAttribute> targetAttrs, Attribute propertiesAttr,
109 const LLVMTypeConverter &typeConverter,
110 ConversionPatternRewriter &rewriter) {
111 assert(!operands.empty());
112
113 // Cannot convert ops if their operands are not of LLVM type.
114 if (!llvm::all_of(operands.getTypes(), isCompatibleType))
115 return failure();
116
117 auto llvmNDVectorTy = operands[0].getType();
118 if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy))
119 return oneToOneRewrite(op, targetOp, operands, targetAttrs, propertiesAttr,
120 typeConverter, rewriter);
121 auto callback = [op, targetOp, targetAttrs, propertiesAttr,
122 &rewriter](Type llvm1DVectorTy, ValueRange operands) {
123 OperationState state(op->getLoc(), rewriter.getStringAttr(targetOp),
124 operands, llvm1DVectorTy, targetAttrs);
125 state.propertiesAttr = propertiesAttr;
126 Operation *newOp = rewriter.create(state);
127 return newOp->getResult(0);
128 };
129
130 return handleMultidimensionalVectors(op, operands, typeConverter, callback,
131 rewriter);
132}
return success()
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static Value handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, ValueRange operands, int64_t vectorWidth, llvm::function_ref< Value(ValueRange)> compute)
Attributes are known-constant values of operations.
Definition Attributes.h:25
Conversion from types to the LLVM IR dialect.
This class helps build Operations.
Definition Builders.h:207
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, Attribute propertiesAttr, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, function_ref< void(ArrayRef< int64_t >)> fun)
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, Attribute propertiesAttr, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition Pattern.cpp:301
NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, const LLVMTypeConverter &converter)
SmallVector< int64_t, 4 > getCoordinates(ArrayRef< int64_t > basis, unsigned linearIndex)
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Include the generated interface declarations.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
SmallVector< int64_t, 4 > arraySizes
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.