MLIR  22.0.0git
X86VectorDialect.cpp
Go to the documentation of this file.
1 //===- X86VectorDialect.cpp - MLIR X86Vector ops implementation -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the X86Vector dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 #include "mlir/IR/Builders.h"
16 #include "mlir/IR/TypeUtilities.h"
17 
18 using namespace mlir;
19 
20 #include "mlir/Dialect/X86Vector/X86VectorInterfaces.cpp.inc"
21 
22 #include "mlir/Dialect/X86Vector/X86VectorDialect.cpp.inc"
23 
24 void x86vector::X86VectorDialect::initialize() {
25  addOperations<
26 #define GET_OP_LIST
27 #include "mlir/Dialect/X86Vector/X86Vector.cpp.inc"
28  >();
29 }
30 
31 static Value getMemrefBuffPtr(Location loc, MemRefType type, Value buffer,
32  const LLVMTypeConverter &typeConverter,
33  RewriterBase &rewriter) {
34  MemRefDescriptor memRefDescriptor(buffer);
35  return memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, type);
36 }
37 
38 LogicalResult x86vector::MaskCompressOp::verify() {
39  if (getSrc() && getConstantSrc())
40  return emitError("cannot use both src and constant_src");
41 
42  if (getSrc() && (getSrc().getType() != getDst().getType()))
43  return emitError("failed to verify that src and dst have same type");
44 
45  if (getConstantSrc() && (getConstantSrc()->getType() != getDst().getType()))
46  return emitError(
47  "failed to verify that constant_src and dst have same type");
48 
49  return success();
50 }
51 
52 SmallVector<Value> x86vector::MaskCompressOp::getIntrinsicOperands(
53  ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
54  RewriterBase &rewriter) {
55  auto loc = getLoc();
56  Adaptor adaptor(operands, *this);
57 
58  auto opType = adaptor.getA().getType();
59  Value src;
60  if (adaptor.getSrc()) {
61  src = adaptor.getSrc();
62  } else if (adaptor.getConstantSrc()) {
63  src = LLVM::ConstantOp::create(rewriter, loc, opType,
64  adaptor.getConstantSrcAttr());
65  } else {
66  auto zeroAttr = rewriter.getZeroAttr(opType);
67  src = LLVM::ConstantOp::create(rewriter, loc, opType, zeroAttr);
68  }
69 
70  return SmallVector<Value>{adaptor.getA(), src, adaptor.getK()};
71 }
72 
74 x86vector::DotOp::getIntrinsicOperands(ArrayRef<Value> operands,
75  const LLVMTypeConverter &typeConverter,
76  RewriterBase &rewriter) {
77  SmallVector<Value> intrinsicOperands(operands);
78  // Dot product of all elements, broadcasted to all elements.
79  Value scale =
80  LLVM::ConstantOp::create(rewriter, getLoc(), rewriter.getI8Type(), 0xff);
81  intrinsicOperands.push_back(scale);
82 
83  return intrinsicOperands;
84 }
85 
86 SmallVector<Value> x86vector::BcstToPackedF32Op::getIntrinsicOperands(
87  ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
88  RewriterBase &rewriter) {
89  Adaptor adaptor(operands, *this);
90  return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(),
91  typeConverter, rewriter)};
92 }
93 
94 SmallVector<Value> x86vector::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
95  ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
96  RewriterBase &rewriter) {
97  Adaptor adaptor(operands, *this);
98  return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(),
99  typeConverter, rewriter)};
100 }
101 
102 SmallVector<Value> x86vector::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
103  ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
104  RewriterBase &rewriter) {
105  Adaptor adaptor(operands, *this);
106  return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(),
107  typeConverter, rewriter)};
108 }
109 
110 #define GET_OP_CLASSES
111 #include "mlir/Dialect/X86Vector/X86Vector.cpp.inc"
static Value getMemrefBuffPtr(Location loc, MemRefType type, Value buffer, const LLVMTypeConverter &typeConverter, RewriterBase &rewriter)
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:324
IntegerType getI8Type()
Definition: Builders.cpp:59
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
Value bufferPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type)
Builds IR for getting the start address of the buffer represented by this memref: memref....
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:368
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423