MLIR  16.0.0git
VectorPattern.cpp
Go to the documentation of this file.
1 //===- VectorPattern.cpp - Vector conversion pattern to the LLVM dialect --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 
12 using namespace mlir;
13 
14 // For >1-D vector types, extracts the necessary information to iterate over all
15 // 1-D subvectors in the underlying LLVM representation of the n-D vector.
16 // Iterates on the llvm array type until we hit a non-array type (which is
17 // asserted to be an llvm vector type).
20  LLVMTypeConverter &converter) {
21  assert(vectorType.getRank() > 1 && "expected >1D vector type");
22  NDVectorTypeInfo info;
23  info.llvmNDVectorTy = converter.convertType(vectorType);
24  if (!info.llvmNDVectorTy || !LLVM::isCompatibleType(info.llvmNDVectorTy)) {
25  info.llvmNDVectorTy = nullptr;
26  return info;
27  }
28  info.arraySizes.reserve(vectorType.getRank() - 1);
29  auto llvmTy = info.llvmNDVectorTy;
30  while (llvmTy.isa<LLVM::LLVMArrayType>()) {
31  info.arraySizes.push_back(
32  llvmTy.cast<LLVM::LLVMArrayType>().getNumElements());
33  llvmTy = llvmTy.cast<LLVM::LLVMArrayType>().getElementType();
34  }
35  if (!LLVM::isCompatibleVectorType(llvmTy))
36  return info;
37  info.llvm1DVectorTy = llvmTy;
38  return info;
39 }
40 
41 // Express `linearIndex` in terms of coordinates of `basis`.
42 // Returns the empty vector when linearIndex is out of the range [0, P) where
43 // P is the product of all the basis coordinates.
44 //
45 // Prerequisites:
46 // Basis is an array of nonnegative integers (signed type inherited from
47 // vector shape type).
49  unsigned linearIndex) {
51  res.reserve(basis.size());
52  for (unsigned basisElement : llvm::reverse(basis)) {
53  res.push_back(linearIndex % basisElement);
54  linearIndex = linearIndex / basisElement;
55  }
56  if (linearIndex > 0)
57  return {};
58  std::reverse(res.begin(), res.end());
59  return res;
60 }
61 
62 // Iterate over the linear index, convert it to coordinate space, and invoke
63 // `fun` on each resulting position (e.g. to insert a 1-D vector there).
65  OpBuilder &builder,
66  function_ref<void(ArrayRef<int64_t>)> fun) {
67  unsigned ub = 1;
68  for (auto s : info.arraySizes)
69  ub *= s;
70  for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) {
71  auto coords = getCoordinates(info.arraySizes, linearIndex);
72  // Linear index is out of bounds, we are done.
73  if (coords.empty())
74  break;
75  assert(coords.size() == info.arraySizes.size());
76  fun(coords);
77  }
78 }
79 
81  Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter,
82  std::function<Value(Type, ValueRange)> createOperand,
83  ConversionPatternRewriter &rewriter) {
84  auto resultNDVectorType = op->getResult(0).getType().cast<VectorType>();
85 
86  SmallVector<Type> operand1DVectorTypes;
87  for (Value operand : op->getOperands()) {
88  auto operandNDVectorType = operand.getType().cast<VectorType>();
89  auto operandTypeInfo =
90  extractNDVectorTypeInfo(operandNDVectorType, typeConverter);
91  operand1DVectorTypes.push_back(operandTypeInfo.llvm1DVectorTy);
92  }
93  auto resultTypeInfo =
94  extractNDVectorTypeInfo(resultNDVectorType, typeConverter);
95  auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy;
96  auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy;
97  auto loc = op->getLoc();
98  Value desc = rewriter.create<LLVM::UndefOp>(loc, resultNDVectoryTy);
99  nDVectorIterate(resultTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
100  // For this unrolled `position` corresponding to the `linearIndex`^th
101  // element, extract operand vectors
102  SmallVector<Value, 4> extractedOperands;
103  for (const auto &operand : llvm::enumerate(operands)) {
104  extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
105  loc, operand.value(), position));
106  }
107  Value newVal = createOperand(result1DVectorTy, extractedOperands);
108  desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, newVal, position);
109  });
110  rewriter.replaceOp(op, desc);
111  return success();
112 }
113 
115  Operation *op, StringRef targetOp, ValueRange operands,
116  LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
117  assert(!operands.empty());
118 
119  // Cannot convert ops if their operands are not of LLVM type.
120  if (!llvm::all_of(operands.getTypes(), isCompatibleType))
121  return failure();
122 
123  auto llvmNDVectorTy = operands[0].getType();
124  if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>())
125  return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter);
126 
127  auto callback = [op, targetOp, &rewriter](Type llvm1DVectorTy,
128  ValueRange operands) {
129  return rewriter
130  .create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
131  llvm1DVectorTy, op->getAttrs())
132  ->getResult(0);
133  };
134 
135  return handleMultidimensionalVectors(op, operands, typeConverter, callback,
136  rewriter);
137 }
Include the generated interface declarations.
NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, LLVMTypeConverter &converter)
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:295
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:356
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:685
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
SmallVector< int64_t, 4 > arraySizes
Definition: VectorPattern.h:27
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:404
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:854
SmallVector< int64_t, 4 > getCoordinates(ArrayRef< int64_t > basis, unsigned linearIndex)
type_range getTypes() const
Definition: ValueRange.cpp:44
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:233
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:324
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:154
LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
LLVM dialect array type.
Definition: LLVMTypes.h:75
unsigned getNumElements() const
Returns the number of elements in the array type.
Definition: LLVMTypes.cpp:54
Type getType() const
Return the type of this value.
Definition: Value.h:118
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect. ...
Definition: LLVMTypes.cpp:872
type_range getType() const
Definition: ValueRange.cpp:46
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:30
static VectorType vectorType(CodeGen &codegen, Type etp)
Constructs vector type.
This class implements a pattern rewriter for use with ConversionPatterns.
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands...
Definition: Pattern.cpp:309
static Value handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, ValueRange operands, int64_t vectorWidth, llvm::function_ref< Value(ValueRange)> compute)
This class helps build Operations.
Definition: Builders.h:192
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:345
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:225
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, function_ref< void(ArrayRef< int64_t >)> fun)
U cast() const
Definition: Types.h:278