MLIR  22.0.0git
X86VectorDialect.cpp
Go to the documentation of this file.
1 //===- X86VectorDialect.cpp - MLIR X86Vector ops implementation -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the X86Vector dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 #include "mlir/IR/Builders.h"
16 #include "mlir/IR/TypeUtilities.h"
17 
18 using namespace mlir;
19 
20 #include "mlir/Dialect/X86Vector/X86VectorInterfaces.cpp.inc"
21 
22 #include "mlir/Dialect/X86Vector/X86VectorDialect.cpp.inc"
23 
/// Dialect initialization hook: adds the X86Vector operations to the dialect
/// via `addOperations<...>()`.
///
/// The template argument list is not written by hand. Defining GET_OP_LIST
/// right before including the TableGen-generated X86Vector.cpp.inc makes that
/// file emit the list of op classes in place (presumably comma-separated, as
/// `addOperations` requires; confirm against the generated file).
void x86vector::X86VectorDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/X86Vector/X86Vector.cpp.inc"
      >();
}
30 
31 static Value getMemrefBuffPtr(Location loc, MemRefType type, Value buffer,
32  const LLVMTypeConverter &typeConverter,
33  RewriterBase &rewriter) {
34  MemRefDescriptor memRefDescriptor(buffer);
35  return memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, type);
36 }
37 
38 LogicalResult x86vector::MaskCompressOp::verify() {
39  if (getSrc() && getConstantSrc())
40  return emitError("cannot use both src and constant_src");
41 
42  if (getSrc() && (getSrc().getType() != getDst().getType()))
43  return emitError("failed to verify that src and dst have same type");
44 
45  if (getConstantSrc() && (getConstantSrc()->getType() != getDst().getType()))
46  return emitError(
47  "failed to verify that constant_src and dst have same type");
48 
49  return success();
50 }
51 
52 SmallVector<Value> x86vector::MaskCompressOp::getIntrinsicOperands(
53  ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
54  RewriterBase &rewriter) {
55  auto loc = getLoc();
56  Adaptor adaptor(operands, *this);
57 
58  auto opType = adaptor.getA().getType();
59  Value src;
60  if (adaptor.getSrc()) {
61  src = adaptor.getSrc();
62  } else if (adaptor.getConstantSrc()) {
63  src = LLVM::ConstantOp::create(rewriter, loc, opType,
64  adaptor.getConstantSrcAttr());
65  } else {
66  auto zeroAttr = rewriter.getZeroAttr(opType);
67  src = LLVM::ConstantOp::create(rewriter, loc, opType, zeroAttr);
68  }
69 
70  return SmallVector<Value>{adaptor.getA(), src, adaptor.getK()};
71 }
72 
x86vector::DotOp::getIntrinsicOperands(ArrayRef<Value> operands,
                                       const LLVMTypeConverter &typeConverter,
                                       RewriterBase &rewriter) {
  // Builds the intrinsic operand list: the converted `operands`, unchanged and
  // in order, followed by one extra i8 constant operand (0xff).
  // `typeConverter` is not used here.
  SmallVector<Value> intrinsicOperands(operands);
  // Dot product of all elements, broadcasted to all elements.
  // NOTE(review): 0xff sets every bit of the i8 immediate. Presumably this is
  // a dpps-style mask selecting all input lanes and all output lanes; confirm
  // against the target intrinsic's definition.
  Value scale =
      LLVM::ConstantOp::create(rewriter, getLoc(), rewriter.getI8Type(), 0xff);
  intrinsicOperands.push_back(scale);

  return intrinsicOperands;
}
86 SmallVector<Value> x86vector::DotInt8Op::getIntrinsicOperands(
87  ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
88  RewriterBase &rewriter) {
89  SmallVector<Value> intrinsicOprnds;
90  Adaptor adaptor(operands, *this);
91  intrinsicOprnds.push_back(adaptor.getW());
92  // Bitcast `a` and `b` to i32
93  Value bitcast_a = LLVM::BitcastOp::create(
94  rewriter, getLoc(),
95  VectorType::get((getA().getType().getShape()[0] / 4),
96  rewriter.getIntegerType(32)),
97  adaptor.getA());
98  intrinsicOprnds.push_back(bitcast_a);
99  Value bitcast_b = LLVM::BitcastOp::create(
100  rewriter, getLoc(),
101  VectorType::get((getB().getType().getShape()[0] / 4),
102  rewriter.getIntegerType(32)),
103  adaptor.getB());
104  intrinsicOprnds.push_back(bitcast_b);
105 
106  return intrinsicOprnds;
107 }
108 
109 SmallVector<Value> x86vector::BcstToPackedF32Op::getIntrinsicOperands(
110  ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
111  RewriterBase &rewriter) {
112  Adaptor adaptor(operands, *this);
113  return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(),
114  typeConverter, rewriter)};
115 }
116 
117 SmallVector<Value> x86vector::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
118  ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
119  RewriterBase &rewriter) {
120  Adaptor adaptor(operands, *this);
121  return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(),
122  typeConverter, rewriter)};
123 }
124 
125 SmallVector<Value> x86vector::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
126  ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
127  RewriterBase &rewriter) {
128  Adaptor adaptor(operands, *this);
129  return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(),
130  typeConverter, rewriter)};
131 }
132 
133 #define GET_OP_CLASSES
134 #include "mlir/Dialect/X86Vector/X86Vector.cpp.inc"
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
static Value getMemrefBuffPtr(Location loc, MemRefType type, Value buffer, const LLVMTypeConverter &typeConverter, RewriterBase &rewriter)
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:319
IntegerType getI8Type()
Definition: Builders.cpp:58
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
Value bufferPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type)
Builds IR for getting the start address of the buffer represented by this memref: memref....
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423